Spaces:
Sleeping
Sleeping
| """Unit tests for agents.""" | |
| import pytest | |
| from unittest.mock import Mock, patch, AsyncMock | |
| from src.agents.local_data_agent import LocalDataAgent | |
| from src.agents.search_agent import SearchAgent | |
| from src.agents.cloud_agent import CloudAgent | |
| from src.agents.aggregator_agent import AggregatorAgent | |
@pytest.fixture
def mock_vector_store():
    """Provide a stand-in vector store with one canned search hit.

    The returned ``Mock`` answers ``search(...)`` synchronously with a dict of
    one-element lists: document ``"Test document content"``, id ``"doc1"``,
    metadata ``{"source": "test"}`` and distance ``0.1``.
    """
    mock_store = Mock()
    mock_store.search.return_value = {
        "documents": ["Test document content"],
        "ids": ["doc1"],
        "metadatas": [{"source": "test"}],
        "distances": [0.1],
    }
    return mock_store
@pytest.fixture
def mock_web_search():
    """Provide a stand-in async web-search client with one canned result.

    Child attributes of an ``AsyncMock`` are themselves ``AsyncMock``s, so
    ``await mock.search(...)`` yields ``{"success": True, "results": [...]}``.
    The result list holds a single entry titled ``"Test Result"``.
    """
    mock_search = AsyncMock()
    mock_search.search.return_value = {
        "success": True,
        "results": [
            {
                "title": "Test Result",
                "url": "https://example.com",
                "content": "Test content",
            }
        ],
    }
    return mock_search
class TestLocalDataAgent:
    """Tests for LocalDataAgent."""

    # Coroutine tests carry an explicit asyncio marker so the async plugin runs
    # them (assumes pytest-asyncio is the project's async plugin — confirm).
    @pytest.mark.asyncio
    async def test_retrieve_context(self, mock_vector_store):
        """Retrieved context includes the mocked store's document and metadata."""
        # Patch get_vector_store where local_data_agent looks it up. The agent is
        # built and queried while the patch is active, so it sees the canned store.
        with patch("src.agents.local_data_agent.get_vector_store", return_value=mock_vector_store):
            agent = LocalDataAgent()
            context = await agent.retrieve_context("test query")
        assert "Test document content" in context
        # Lower-case "test" is the fixture's metadata source value.
        assert "test" in context

    @pytest.mark.asyncio
    async def test_process_query(self, mock_vector_store):
        """process() returns the successful payload produced by _process_direct."""
        with patch("src.agents.local_data_agent.get_vector_store", return_value=mock_vector_store), \
                patch.object(LocalDataAgent, "_process_direct", new_callable=AsyncMock) as mock_process:
            mock_process.return_value = {
                "success": True,
                "answer": "Test answer",
                "agent": "local_data_agent",
            }
            agent = LocalDataAgent()
            response = await agent.process("test query")
        assert response["success"] is True
        assert "answer" in response
class TestSearchAgent:
    """Tests for SearchAgent."""

    # Coroutine tests carry an explicit asyncio marker so the async plugin runs
    # them (assumes pytest-asyncio is the project's async plugin — confirm).
    @pytest.mark.asyncio
    async def test_retrieve_context(self, mock_web_search):
        """Test web search context retrieval."""
        with patch("src.agents.search_agent.get_web_search", return_value=mock_web_search):
            agent = SearchAgent(use_planning=False)
            # Also pin the instance attribute directly, in case the agent obtains
            # its search client some other way than get_web_search().
            agent.web_search = mock_web_search
            context = await agent.retrieve_context("test query")
        # NOTE(review): the `or` branch also accepts a context that merely echoes
        # the query; tighten to the "Test Result" check once the format is confirmed.
        assert "Test Result" in context or "test query" in context.lower()

    @pytest.mark.asyncio
    async def test_process_query(self, mock_web_search):
        """process() returns the successful payload produced by _process_direct."""
        with patch("src.agents.search_agent.get_web_search", return_value=mock_web_search), \
                patch.object(SearchAgent, "_process_direct", new_callable=AsyncMock) as mock_process:
            mock_process.return_value = {
                "success": True,
                "answer": "Test answer from web",
                "agent": "search_agent",
            }
            agent = SearchAgent(use_planning=False)
            agent.web_search = mock_web_search
            response = await agent.process("test query")
        assert response["success"] is True
class TestCloudAgent:
    """Tests for CloudAgent."""

    # Coroutine tests carry an explicit asyncio marker so the async plugin runs
    # them (assumes pytest-asyncio is the project's async plugin — confirm).
    @pytest.mark.asyncio
    async def test_retrieve_context_no_config(self):
        """Context reports 'not configured' when no cloud backend is set up."""
        # NOTE(review): nothing here clears cloud configuration, so this
        # presumably relies on the test environment having none set — confirm.
        agent = CloudAgent()
        context = await agent.retrieve_context("test query")
        assert "not configured" in context.lower()

    @pytest.mark.asyncio
    async def test_process_query(self):
        """process() returns the successful payload produced by _process_direct."""
        with patch.object(CloudAgent, "_process_direct", new_callable=AsyncMock) as mock_process:
            mock_process.return_value = {
                "success": True,
                "answer": "Test answer",
                "agent": "cloud_agent",
            }
            agent = CloudAgent()
            response = await agent.process("test query")
        assert response["success"] is True
class TestAggregatorAgent:
    """Tests for AggregatorAgent."""

    def test_select_agents(self):
        """Keyword routing picks the local, search and cloud agents as expected."""
        # _select_agents is called without await and its result is tested with
        # `in`, so it is synchronous. This test therefore does not need to be async.
        agent = AggregatorAgent(use_planning=False)

        # Local document query
        selected = agent._select_agents("What is in the document?")
        assert "local" in selected

        # Web search query
        selected = agent._select_agents("What is the latest news?")
        assert "search" in selected

        # Cloud query
        selected = agent._select_agents("What files are in S3?")
        assert "cloud" in selected

    # Coroutine tests carry an explicit asyncio marker so the async plugin runs
    # them (assumes pytest-asyncio is the project's async plugin — confirm).
    @pytest.mark.asyncio
    async def test_process_query(self):
        """process() returns the synthesized response built from sub-agent answers."""
        with patch.object(LocalDataAgent, "process", new_callable=AsyncMock) as mock_local, \
                patch.object(SearchAgent, "process", new_callable=AsyncMock) as mock_search:
            mock_local.return_value = {
                "success": True,
                "answer": "Local answer",
            }
            mock_search.return_value = {
                "success": True,
                "answer": "Search answer",
            }
            agent = AggregatorAgent(use_planning=False)
            # Pin the mocks on the instances too, in case the aggregator's
            # sub-agents are not plain instances of the patched classes.
            agent.local_agent.process = mock_local
            agent.search_agent.process = mock_search
            with patch.object(agent, "_synthesize_responses", new_callable=AsyncMock) as mock_synth:
                mock_synth.return_value = {
                    "success": True,
                    "answer": "Synthesized answer",
                    "aggregated_by": "multiple_agents",
                }
                response = await agent.process("test query")
        assert response["success"] is True
        assert "aggregated_by" in response

    @pytest.mark.asyncio
    async def test_synthesize_responses(self):
        """_synthesize_responses returns the LLM's text as the combined answer."""
        agent = AggregatorAgent(use_planning=False)
        agent_responses = {
            "local": {"success": True, "answer": "Answer 1"},
            "search": {"success": True, "answer": "Answer 2"},
        }
        # patch() substitutes an AsyncMock automatically when the target is an
        # async function, so this works for both sync and async LLM clients.
        with patch.object(agent.client.chat.completions, "create") as mock_llm:
            mock_response = Mock()
            mock_response.choices = [Mock()]
            mock_response.choices[0].message.content = "Synthesized answer"
            mock_llm.return_value = mock_response
            result = await agent._synthesize_responses(
                query="test",
                agent_responses=agent_responses,
                session_id=None,
            )
        assert result["success"] is True
        assert "Synthesized answer" in result["answer"]
class TestBaseAgent:
    """Tests for BaseAgent functionality."""

    def test_agent_status(self):
        """get_status() reports the agent name, its tools and the memory flag."""
        # get_status() is used without await and its result is subscripted, so it
        # is synchronous. This test therefore does not need to be a coroutine.
        agent = LocalDataAgent()
        status = agent.get_status()
        assert status["name"] == "local_data_agent"
        assert "tools" in status
        assert "memory_enabled" in status

    def test_add_tool(self):
        """add_tool() registers the callable under the schema's tool name."""
        agent = LocalDataAgent()
        tool_schema = {
            "name": "test_tool",
            "description": "Test tool",
        }

        def tool_func(x):
            # Identity tool: only its registration is under test.
            return x

        agent.add_tool(tool_schema, tool_func)
        assert "test_tool" in agent.tool_functions
if __name__ == "__main__":
    # Propagate pytest's exit code so `python <this file>` fails when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))