Spaces:
Sleeping
Sleeping
| """ | |
| Unit tests for agent implementations. | |
| """ | |
| import pytest | |
| from unittest.mock import Mock, MagicMock | |
| import json | |
| from coda.core.memory import SharedMemory | |
| from coda.core.llm import LLMResponse | |
| from coda.core.base_agent import AgentContext | |
| from coda.agents.query_analyzer import QueryAnalyzerAgent, QueryAnalysis | |
| from coda.agents.data_processor import DataProcessorAgent, DataAnalysis | |
| from coda.agents.viz_mapping import VizMappingAgent, VisualMapping | |
class MockLLM:
    """Stand-in LLM client that always answers with a preset string.

    Agents under test receive this object in place of a real LLM; every
    completion call yields the text supplied at construction time.
    """

    def __init__(self, response_content: str):
        # Canned text handed back verbatim by every completion call.
        self._response = response_content

    def complete(self, prompt, system_prompt=None, **kwargs):
        """Return the canned text wrapped in an ``LLMResponse``.

        ``prompt``, ``system_prompt`` and any extra keyword arguments are
        accepted but not used.
        """
        # All token counters are reported as zero.
        zero_usage = dict.fromkeys(
            ("prompt_tokens", "completion_tokens", "total_tokens"), 0
        )
        canned = LLMResponse(
            content=self._response,
            model="mock",
            usage=zero_usage,
            finish_reason="stop",
        )
        return canned

    def complete_with_image(self, prompt, image_path, **kwargs):
        """Delegate to :meth:`complete`; ``image_path`` is not forwarded."""
        return self.complete(prompt, **kwargs)
class TestQueryAnalyzerAgent:
    """Tests for the Query Analyzer agent."""

    @pytest.fixture
    def mock_response(self):
        """Return the canned JSON analysis the mock LLM will answer with.

        Registered as a pytest fixture so that the tests below, which
        declare a ``mock_response`` parameter, receive this string.
        """
        return json.dumps({
            "visualization_types": ["line chart", "bar chart"],
            "key_points": ["sales trends", "monthly data"],
            "todo_list": ["Load data", "Create chart", "Add labels"],
            "data_requirements": ["date", "sales"],
            "constraints": ["use blue colors"],
            "ambiguities": []
        })

    def test_execute(self, mock_response):
        """Test query analysis execution."""
        llm = MockLLM(mock_response)
        memory = SharedMemory()
        agent = QueryAnalyzerAgent(llm, memory)
        context = AgentContext(query="Show sales trends over time")

        result = agent.execute(context)

        # The returned object should mirror the canned payload.
        assert isinstance(result, QueryAnalysis)
        assert "line chart" in result.visualization_types
        assert "sales trends" in result.key_points
        assert len(result.todo_list) == 3

    def test_stores_in_memory(self, mock_response):
        """Test that results are stored in memory."""
        llm = MockLLM(mock_response)
        memory = SharedMemory()
        agent = QueryAnalyzerAgent(llm, memory)
        context = AgentContext(query="Test query")

        agent.execute(context)

        # After execution the analysis is expected under "query_analysis".
        stored = memory.retrieve("query_analysis")
        assert stored is not None
        assert "visualization_types" in stored
class TestVizMappingAgent:
    """Tests for the VizMapping agent."""

    @pytest.fixture
    def mock_response(self):
        """Return the canned JSON visual mapping the mock LLM will answer with.

        Registered as a pytest fixture so that ``test_execute``, which
        declares a ``mock_response`` parameter, receives this string.
        """
        return json.dumps({
            "chart_type": "line",
            "chart_subtype": None,
            "x_axis": {"column": "date", "label": "Date", "type": "temporal"},
            "y_axis": {"column": "sales", "label": "Sales", "type": "numerical"},
            "color_encoding": None,
            "size_encoding": None,
            "transformations": [],
            "styling_hints": {"theme": "modern"},
            "visualization_goals": ["Show trends"],
            "rationale": "Line chart best for trends"
        })

    def test_execute(self, mock_response):
        """Test visualization mapping execution."""
        llm = MockLLM(mock_response)
        memory = SharedMemory()
        # Seed memory with a query analysis entry before running the agent.
        memory.store("query_analysis", {
            "visualization_types": ["line chart"],
            "key_points": ["trends"],
            "data_requirements": ["date", "sales"]
        }, "test")
        agent = VizMappingAgent(llm, memory)
        context = AgentContext(query="Show sales trends")

        result = agent.execute(context)

        # The returned object should mirror the canned payload.
        assert isinstance(result, VisualMapping)
        assert result.chart_type == "line"
        assert result.x_axis["column"] == "date"
class TestAgentContext:
    """Checks on the fields exposed by the AgentContext model."""

    def test_basic_context(self):
        """Build a context from a query and two data paths.

        Expects both values to be readable back, with ``iteration`` equal
        to 0 and ``feedback`` equal to ``None`` when neither is supplied.
        """
        paths = ["file1.csv", "file2.csv"]
        ctx = AgentContext(query="Test query", data_paths=paths)

        assert ctx.query == "Test query"
        assert len(ctx.data_paths) == 2
        assert ctx.iteration == 0
        assert ctx.feedback is None

    def test_context_with_feedback(self):
        """Build a context carrying an iteration number and feedback text."""
        ctx = AgentContext(query="Test", iteration=2, feedback="Improve colors")

        assert ctx.iteration == 2
        assert ctx.feedback == "Improve colors"
class TestBaseAgentJsonExtraction:
    """Checks on the agent's ``_extract_json`` helper."""

    @staticmethod
    def _build_agent():
        """Create a QueryAnalyzerAgent backed by a mock LLM and fresh memory."""
        stub_llm = MockLLM("{}")
        store = SharedMemory()
        return QueryAnalyzerAgent(stub_llm, store)

    def test_extract_json_plain(self):
        """A bare JSON object string is parsed into the matching dict."""
        parsed = self._build_agent()._extract_json('{"key": "value"}')
        assert parsed == {"key": "value"}

    def test_extract_json_markdown(self):
        """JSON inside a fenced ``json`` code block is located and parsed."""
        agent = self._build_agent()
        # Leading prose followed by a fenced block, as an LLM might reply.
        text = """Here is the response:
```json
{"key": "value"}
```
"""
        parsed = agent._extract_json(text)
        assert parsed == {"key": "value"}

    def test_extract_json_invalid(self):
        """Text containing no valid JSON is expected to raise ValueError."""
        agent = self._build_agent()
        with pytest.raises(ValueError):
            agent._extract_json("not valid json")