Spaces:
Sleeping
Sleeping
File size: 5,710 Bytes
9281fab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
"""
Unit tests for agent implementations.
"""
import pytest
from unittest.mock import Mock, MagicMock
import json
from coda.core.memory import SharedMemory
from coda.core.llm import LLMResponse
from coda.core.base_agent import AgentContext
from coda.agents.query_analyzer import QueryAnalyzerAgent, QueryAnalysis
from coda.agents.data_processor import DataProcessorAgent, DataAnalysis
from coda.agents.viz_mapping import VizMappingAgent, VisualMapping
class MockLLM:
    """Stub LLM that returns a fixed, canned response for agent tests."""

    def __init__(self, response_content: str):
        # Content handed back verbatim by every completion call.
        self._canned = response_content

    def complete(self, prompt, system_prompt=None, **kwargs):
        # Wrap the canned text in the standard response envelope with
        # zeroed-out token usage, ignoring the prompt entirely.
        usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        return LLMResponse(
            content=self._canned,
            model="mock",
            usage=usage,
            finish_reason="stop",
        )

    def complete_with_image(self, prompt, image_path, **kwargs):
        # Image input is irrelevant for these tests; delegate to text completion.
        return self.complete(prompt, **kwargs)
class TestQueryAnalyzerAgent:
    """Tests for the Query Analyzer agent."""

    @pytest.fixture
    def mock_response(self):
        # Canned JSON payload mimicking a well-formed analyzer LLM reply.
        payload = {
            "visualization_types": ["line chart", "bar chart"],
            "key_points": ["sales trends", "monthly data"],
            "todo_list": ["Load data", "Create chart", "Add labels"],
            "data_requirements": ["date", "sales"],
            "constraints": ["use blue colors"],
            "ambiguities": [],
        }
        return json.dumps(payload)

    def test_execute(self, mock_response):
        """Test query analysis execution."""
        analyzer = QueryAnalyzerAgent(MockLLM(mock_response), SharedMemory())
        outcome = analyzer.execute(AgentContext(query="Show sales trends over time"))
        assert isinstance(outcome, QueryAnalysis)
        assert "line chart" in outcome.visualization_types
        assert "sales trends" in outcome.key_points
        assert len(outcome.todo_list) == 3

    def test_stores_in_memory(self, mock_response):
        """Test that results are stored in memory."""
        shared = SharedMemory()
        analyzer = QueryAnalyzerAgent(MockLLM(mock_response), shared)
        analyzer.execute(AgentContext(query="Test query"))
        stored = shared.retrieve("query_analysis")
        assert stored is not None
        assert "visualization_types" in stored
class TestVizMappingAgent:
    """Tests for the VizMapping agent."""

    @pytest.fixture
    def mock_response(self):
        # Canned JSON payload mimicking a well-formed mapping LLM reply.
        payload = {
            "chart_type": "line",
            "chart_subtype": None,
            "x_axis": {"column": "date", "label": "Date", "type": "temporal"},
            "y_axis": {"column": "sales", "label": "Sales", "type": "numerical"},
            "color_encoding": None,
            "size_encoding": None,
            "transformations": [],
            "styling_hints": {"theme": "modern"},
            "visualization_goals": ["Show trends"],
            "rationale": "Line chart best for trends",
        }
        return json.dumps(payload)

    def test_execute(self, mock_response):
        """Test visualization mapping execution."""
        shared = SharedMemory()
        # Seed the upstream analysis the mapping agent reads from memory.
        prior_analysis = {
            "visualization_types": ["line chart"],
            "key_points": ["trends"],
            "data_requirements": ["date", "sales"],
        }
        shared.store("query_analysis", prior_analysis, "test")
        mapper = VizMappingAgent(MockLLM(mock_response), shared)
        outcome = mapper.execute(AgentContext(query="Show sales trends"))
        assert isinstance(outcome, VisualMapping)
        assert outcome.chart_type == "line"
        assert outcome.x_axis["column"] == "date"
class TestAgentContext:
    """Tests for the AgentContext model."""

    def test_basic_context(self):
        """Test creating a basic context."""
        ctx = AgentContext(
            query="Test query",
            data_paths=["file1.csv", "file2.csv"],
        )
        assert ctx.query == "Test query"
        assert len(ctx.data_paths) == 2
        # Fresh contexts start at iteration 0 with no feedback attached.
        assert ctx.iteration == 0
        assert ctx.feedback is None

    def test_context_with_feedback(self):
        """Test context with feedback for refinement."""
        ctx = AgentContext(
            query="Test",
            iteration=2,
            feedback="Improve colors",
        )
        assert ctx.iteration == 2
        assert ctx.feedback == "Improve colors"
class TestBaseAgentJsonExtraction:
    """Tests for JSON extraction from LLM responses."""

    def _make_agent(self):
        # Any concrete agent works; extraction lives on the shared base class.
        return QueryAnalyzerAgent(MockLLM("{}"), SharedMemory())

    def test_extract_json_plain(self):
        """Test extracting plain JSON."""
        parsed = self._make_agent()._extract_json('{"key": "value"}')
        assert parsed == {"key": "value"}

    def test_extract_json_markdown(self):
        """Test extracting JSON from markdown code block."""
        text = """Here is the response:
```json
{"key": "value"}
```
"""
        parsed = self._make_agent()._extract_json(text)
        assert parsed == {"key": "value"}

    def test_extract_json_invalid(self):
        """Test handling invalid JSON."""
        with pytest.raises(ValueError):
            self._make_agent()._extract_json("not valid json")