File size: 5,710 Bytes
9281fab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
"""
Unit tests for agent implementations.
"""

import pytest
from unittest.mock import Mock, MagicMock
import json

from coda.core.memory import SharedMemory
from coda.core.llm import LLMResponse
from coda.core.base_agent import AgentContext
from coda.agents.query_analyzer import QueryAnalyzerAgent, QueryAnalysis
from coda.agents.data_processor import DataProcessorAgent, DataAnalysis
from coda.agents.viz_mapping import VizMappingAgent, VisualMapping


class MockLLM:
    """Offline stand-in for an LLM client that always gives the same answer.

    Whatever prompt is sent, ``complete`` and ``complete_with_image`` return
    an ``LLMResponse`` whose content is the string given to the constructor,
    which keeps agent tests deterministic.
    """

    def __init__(self, response_content: str):
        # Canned reply returned verbatim by every completion call.
        self._response = response_content

    def complete(self, prompt, system_prompt=None, **kwargs):
        """Return the canned reply; the prompt arguments are ignored."""
        zero_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        reply = LLMResponse(
            content=self._response,
            model="mock",
            usage=zero_usage,
            finish_reason="stop",
        )
        return reply

    def complete_with_image(self, prompt, image_path, **kwargs):
        """Disregard ``image_path`` and answer exactly like :meth:`complete`."""
        return self.complete(prompt, **kwargs)


class TestQueryAnalyzerAgent:
    """Exercise QueryAnalyzerAgent against a canned JSON LLM reply."""

    @pytest.fixture
    def mock_response(self):
        """JSON text the mocked LLM returns regardless of the prompt."""
        payload = {
            "visualization_types": ["line chart", "bar chart"],
            "key_points": ["sales trends", "monthly data"],
            "todo_list": ["Load data", "Create chart", "Add labels"],
            "data_requirements": ["date", "sales"],
            "constraints": ["use blue colors"],
            "ambiguities": [],
        }
        return json.dumps(payload)

    @staticmethod
    def _build_agent(response):
        """Return a fresh ``(agent, memory)`` pair wired to a MockLLM reply."""
        llm = MockLLM(response)
        shared = SharedMemory()
        return QueryAnalyzerAgent(llm, shared), shared

    def test_execute(self, mock_response):
        """execute() parses the reply into a populated QueryAnalysis."""
        agent, _ = self._build_agent(mock_response)

        analysis = agent.execute(AgentContext(query="Show sales trends over time"))

        assert isinstance(analysis, QueryAnalysis)
        assert "line chart" in analysis.visualization_types
        assert "sales trends" in analysis.key_points
        assert len(analysis.todo_list) == 3

    def test_stores_in_memory(self, mock_response):
        """After execute(), shared memory holds an entry under "query_analysis"."""
        agent, shared = self._build_agent(mock_response)

        agent.execute(AgentContext(query="Test query"))

        saved = shared.retrieve("query_analysis")
        assert saved is not None
        assert "visualization_types" in saved


class TestVizMappingAgent:
    """Exercise VizMappingAgent with a query analysis pre-seeded in memory."""

    @pytest.fixture
    def mock_response(self):
        """JSON mapping the mocked LLM returns: a date-vs-sales line chart."""
        mapping = {
            "chart_type": "line",
            "chart_subtype": None,
            "x_axis": {"column": "date", "label": "Date", "type": "temporal"},
            "y_axis": {"column": "sales", "label": "Sales", "type": "numerical"},
            "color_encoding": None,
            "size_encoding": None,
            "transformations": [],
            "styling_hints": {"theme": "modern"},
            "visualization_goals": ["Show trends"],
            "rationale": "Line chart best for trends",
        }
        return json.dumps(mapping)

    def test_execute(self, mock_response):
        """execute() turns the LLM reply into a VisualMapping."""
        llm = MockLLM(mock_response)
        shared = SharedMemory()

        # Seed the upstream analysis under "query_analysis"; the agent presumably
        # reads it. The third store() argument "test" is likely the writer's
        # name (TODO confirm against SharedMemory.store).
        seeded_analysis = {
            "visualization_types": ["line chart"],
            "key_points": ["trends"],
            "data_requirements": ["date", "sales"],
        }
        shared.store("query_analysis", seeded_analysis, "test")

        agent = VizMappingAgent(llm, shared)
        mapping = agent.execute(AgentContext(query="Show sales trends"))

        assert isinstance(mapping, VisualMapping)
        assert mapping.chart_type == "line"
        assert mapping.x_axis["column"] == "date"


class TestAgentContext:
    """Construction and default values of AgentContext."""

    def test_basic_context(self):
        """With only query and data_paths, iteration is 0 and feedback is None."""
        paths = ["file1.csv", "file2.csv"]
        ctx = AgentContext(query="Test query", data_paths=paths)

        assert ctx.query == "Test query"
        assert len(ctx.data_paths) == 2
        assert ctx.iteration == 0
        assert ctx.feedback is None

    def test_context_with_feedback(self):
        """A refinement context keeps the iteration number and feedback text given."""
        ctx = AgentContext(query="Test", iteration=2, feedback="Improve colors")

        assert ctx.iteration == 2
        assert ctx.feedback == "Improve colors"


class TestBaseAgentJsonExtraction:
    """Tests for JSON extraction from LLM responses.

    ``_extract_json`` is called directly on a QueryAnalyzerAgent instance. The
    agent's canned LLM reply ("{}") is presumably never consulted here.
    """

    @pytest.fixture
    def agent(self):
        """Fresh QueryAnalyzerAgent that hosts ``_extract_json`` for each test."""
        # Shared setup: previously copy-pasted into every test method.
        return QueryAnalyzerAgent(MockLLM("{}"), SharedMemory())

    def test_extract_json_plain(self, agent):
        """Test extracting plain JSON."""
        result = agent._extract_json('{"key": "value"}')

        assert result == {"key": "value"}

    def test_extract_json_markdown(self, agent):
        """Test extracting JSON from a ```json fenced block after leading prose."""
        text = """Here is the response:
```json
{"key": "value"}
```
"""
        result = agent._extract_json(text)

        assert result == {"key": "value"}

    def test_extract_json_invalid(self, agent):
        """Test that text containing no JSON raises ValueError."""
        with pytest.raises(ValueError):
            agent._extract_json("not valid json")