File size: 7,930 Bytes
8bf4d58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
"""Unit tests for agents."""

import pytest
from unittest.mock import Mock, patch, AsyncMock
from src.agents.local_data_agent import LocalDataAgent
from src.agents.search_agent import SearchAgent
from src.agents.cloud_agent import CloudAgent
from src.agents.aggregator_agent import AggregatorAgent


@pytest.fixture
def mock_vector_store():
    """Provide a synchronous stand-in for the vector store.

    Every call to ``search`` returns the same canned single-hit payload:
    document "Test document content" with id "doc1", metadata
    ``{"source": "test"}`` and distance 0.1.
    """
    store = Mock()
    canned_hits = dict(
        documents=["Test document content"],
        ids=["doc1"],
        metadatas=[{"source": "test"}],
        distances=[0.1],
    )
    store.search.return_value = canned_hits
    return store


@pytest.fixture
def mock_web_search():
    """Provide an async stand-in for the web search client.

    Awaiting ``search`` yields a successful payload holding exactly one
    result titled "Test Result" pointing at https://example.com.
    """
    hit = {
        "title": "Test Result",
        "url": "https://example.com",
        "content": "Test content",
    }
    searcher = AsyncMock()
    searcher.search.return_value = {"success": True, "results": [hit]}
    return searcher


class TestLocalDataAgent:
    """Exercise LocalDataAgent against a mocked vector store."""

    @pytest.mark.asyncio
    async def test_retrieve_context(self, mock_vector_store):
        """retrieve_context should surface the stored document text and source."""
        store_factory = "src.agents.local_data_agent.get_vector_store"
        with patch(store_factory, return_value=mock_vector_store):
            local = LocalDataAgent()
            ctx = await local.retrieve_context("test query")
            assert "Test document content" in ctx
            assert "test" in ctx

    @pytest.mark.asyncio
    async def test_process_query(self, mock_vector_store):
        """process() should succeed with _process_direct stubbed out."""
        canned = {
            "success": True,
            "answer": "Test answer",
            "agent": "local_data_agent",
        }
        with patch(
            "src.agents.local_data_agent.get_vector_store",
            return_value=mock_vector_store,
        ), patch.object(
            LocalDataAgent,
            "_process_direct",
            new_callable=AsyncMock,
            return_value=canned,
        ):
            local = LocalDataAgent()
            result = await local.process("test query")
        assert result["success"] is True
        assert "answer" in result


class TestSearchAgent:
    """Exercise SearchAgent (planning disabled) with a mocked web search."""

    @pytest.mark.asyncio
    async def test_retrieve_context(self, mock_web_search):
        """retrieve_context should mention the search hit or echo the query."""
        with patch("src.agents.search_agent.get_web_search", return_value=mock_web_search):
            searcher = SearchAgent(use_planning=False)
            # Inject the mock on the instance as well as via the factory patch.
            searcher.web_search = mock_web_search
            ctx = await searcher.retrieve_context("test query")
            found_title = "Test Result" in ctx
            echoed_query = "test query" in ctx.lower()
            assert found_title or echoed_query

    @pytest.mark.asyncio
    async def test_process_query(self, mock_web_search):
        """process() should succeed with _process_direct stubbed out."""
        canned = {
            "success": True,
            "answer": "Test answer from web",
            "agent": "search_agent",
        }
        with patch(
            "src.agents.search_agent.get_web_search",
            return_value=mock_web_search,
        ), patch.object(
            SearchAgent,
            "_process_direct",
            new_callable=AsyncMock,
            return_value=canned,
        ):
            searcher = SearchAgent(use_planning=False)
            searcher.web_search = mock_web_search
            outcome = await searcher.process("test query")
        assert outcome["success"] is True


class TestCloudAgent:
    """Exercise CloudAgent.

    NOTE(review): test_retrieve_context_no_config assumes the test
    environment has no cloud configuration — confirm against CloudAgent.
    """

    @pytest.mark.asyncio
    async def test_retrieve_context_no_config(self):
        """Without cloud configuration, retrieve_context should say so."""
        cloud = CloudAgent()
        ctx = await cloud.retrieve_context("test query")
        assert "not configured" in ctx.lower()

    @pytest.mark.asyncio
    async def test_process_query(self):
        """process() should succeed with _process_direct stubbed out."""
        canned = {
            "success": True,
            "answer": "Test answer",
            "agent": "cloud_agent",
        }
        with patch.object(
            CloudAgent, "_process_direct", new_callable=AsyncMock, return_value=canned
        ):
            cloud = CloudAgent()
            outcome = await cloud.process("test query")
        assert outcome["success"] is True


class TestAggregatorAgent:
    """Exercise AggregatorAgent routing, orchestration and synthesis."""

    @pytest.mark.asyncio
    async def test_select_agents(self):
        """Keyword-style queries should route to the matching sub-agent."""
        router = AggregatorAgent(use_planning=False)
        cases = [
            ("What is in the document?", "local"),   # local document query
            ("What is the latest news?", "search"),  # web search query
            ("What files are in S3?", "cloud"),      # cloud storage query
        ]
        for question, expected_agent in cases:
            assert expected_agent in router._select_agents(question)

    @pytest.mark.asyncio
    async def test_process_query(self):
        """process() should return the (stubbed) synthesized response."""
        local_reply = {"success": True, "answer": "Local answer"}
        search_reply = {"success": True, "answer": "Search answer"}
        with patch.object(
            LocalDataAgent, "process", new_callable=AsyncMock, return_value=local_reply
        ) as local_stub, patch.object(
            SearchAgent, "process", new_callable=AsyncMock, return_value=search_reply
        ) as search_stub:
            aggregator = AggregatorAgent(use_planning=False)
            # Bind the stubs directly on the sub-agent instances as well.
            aggregator.local_agent.process = local_stub
            aggregator.search_agent.process = search_stub

            merged = {
                "success": True,
                "answer": "Synthesized answer",
                "aggregated_by": "multiple_agents",
            }
            with patch.object(
                aggregator,
                "_synthesize_responses",
                new_callable=AsyncMock,
                return_value=merged,
            ):
                outcome = await aggregator.process("test query")
        assert outcome["success"] is True
        assert "aggregated_by" in outcome

    @pytest.mark.asyncio
    async def test_synthesize_responses(self):
        """_synthesize_responses should succeed and include the LLM's text."""
        aggregator = AggregatorAgent(use_planning=False)
        per_agent = {
            "local": {"success": True, "answer": "Answer 1"},
            "search": {"success": True, "answer": "Answer 2"},
        }

        # Fake chat-completion object: completion.choices[0].message.content
        completion = Mock()
        completion.choices = [Mock()]
        completion.choices[0].message.content = "Synthesized answer"

        with patch.object(
            aggregator.client.chat.completions, "create", return_value=completion
        ):
            outcome = await aggregator._synthesize_responses(
                query="test",
                agent_responses=per_agent,
                session_id=None,
            )
        assert outcome["success"] is True
        assert "Synthesized answer" in outcome["answer"]


class TestBaseAgent:
    """Tests for BaseAgent functionality, exercised through LocalDataAgent."""

    # Same target the TestLocalDataAgent tests patch; patching it here keeps
    # these tests from constructing/using a real vector store.
    _VECTOR_STORE_FACTORY = "src.agents.local_data_agent.get_vector_store"

    def test_agent_status(self, mock_vector_store):
        """get_status() should report the agent name, tools and memory flag."""
        # get_status() is called synchronously, so this test needs no
        # asyncio event loop (previously it was a needless async test).
        with patch(self._VECTOR_STORE_FACTORY, return_value=mock_vector_store):
            agent = LocalDataAgent()
            status = agent.get_status()
        assert status["name"] == "local_data_agent"
        assert "tools" in status
        assert "memory_enabled" in status

    def test_add_tool(self, mock_vector_store):
        """add_tool() should register the callable under the schema's name."""
        tool_schema = {
            "name": "test_tool",
            "description": "Test tool",
        }

        def tool_func(x):
            # Identity tool; its behavior is irrelevant to registration.
            return x

        with patch(self._VECTOR_STORE_FACTORY, return_value=mock_vector_store):
            agent = LocalDataAgent()
            agent.add_tool(tool_schema, tool_func)
        assert "test_tool" in agent.tool_functions


if __name__ == "__main__":
    # Propagate pytest's exit code so running this file as a script reports
    # test failures to the shell/CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))