# SCoDA / tests/test_agents.py
# Last commit by vanishingradient: "Added init files" (9281fab)
"""
Unit tests for agent implementations.
"""
import pytest
from unittest.mock import Mock, MagicMock
import json
from coda.core.memory import SharedMemory
from coda.core.llm import LLMResponse
from coda.core.base_agent import AgentContext
from coda.agents.query_analyzer import QueryAnalyzerAgent, QueryAnalysis
from coda.agents.data_processor import DataProcessorAgent, DataAnalysis
from coda.agents.viz_mapping import VizMappingAgent, VisualMapping
class MockLLM:
    """Fake LLM client that answers every request with one fixed string.

    The tests hand an instance of this class to agent constructors in place of
    a real client, so agent parsing can be exercised deterministically. The
    prompt, system prompt, image path and any extra keyword arguments are
    accepted only so the call signatures line up; they do not affect the reply.
    """

    def __init__(self, response_content: str):
        # Text returned as ``content`` by every completion call.
        self._response = response_content

    def complete(self, prompt, system_prompt=None, **kwargs):
        """Wrap the canned text in an ``LLMResponse``.

        The reply always reports model ``"mock"``, zero for every token
        counter, and finish reason ``"stop"``.
        """
        token_counts = dict.fromkeys(
            ("prompt_tokens", "completion_tokens", "total_tokens"), 0
        )
        return LLMResponse(
            content=self._response,
            model="mock",
            usage=token_counts,
            finish_reason="stop",
        )

    def complete_with_image(self, prompt, image_path, **kwargs):
        """Image-capable variant: ``image_path`` is ignored; delegates to ``complete``."""
        return self.complete(prompt, **kwargs)
class TestQueryAnalyzerAgent:
    """Tests for the Query Analyzer agent."""

    @pytest.fixture
    def mock_response(self):
        """Canned JSON reply the mock LLM returns for every analysis call."""
        return json.dumps({
            "visualization_types": ["line chart", "bar chart"],
            "key_points": ["sales trends", "monthly data"],
            "todo_list": ["Load data", "Create chart", "Add labels"],
            "data_requirements": ["date", "sales"],
            "constraints": ["use blue colors"],
            "ambiguities": []
        })

    @pytest.fixture
    def memory(self):
        """A new ``SharedMemory`` instance per test.

        Shared with the ``agent`` fixture so a test can inspect what the agent
        stored.
        """
        return SharedMemory()

    @pytest.fixture
    def agent(self, mock_response, memory):
        """Query analyzer wired to the canned-response LLM and ``memory``."""
        return QueryAnalyzerAgent(MockLLM(mock_response), memory)

    def test_execute(self, agent):
        """``execute`` returns a ``QueryAnalysis`` built from the LLM reply."""
        context = AgentContext(query="Show sales trends over time")
        result = agent.execute(context)
        assert isinstance(result, QueryAnalysis)
        assert "line chart" in result.visualization_types
        assert "sales trends" in result.key_points
        assert len(result.todo_list) == 3

    def test_stores_in_memory(self, agent, memory):
        """``execute`` saves its analysis under the ``"query_analysis"`` key."""
        context = AgentContext(query="Test query")
        agent.execute(context)
        stored = memory.retrieve("query_analysis")
        assert stored is not None
        assert "visualization_types" in stored
class TestVizMappingAgent:
    """Exercises the VizMapping agent against a canned LLM reply."""

    @pytest.fixture
    def mock_response(self):
        """Serialized visual mapping describing a date-vs-sales line chart."""
        mapping = {
            "chart_type": "line",
            "chart_subtype": None,
            "x_axis": {"column": "date", "label": "Date", "type": "temporal"},
            "y_axis": {"column": "sales", "label": "Sales", "type": "numerical"},
            "color_encoding": None,
            "size_encoding": None,
            "transformations": [],
            "styling_hints": {"theme": "modern"},
            "visualization_goals": ["Show trends"],
            "rationale": "Line chart best for trends"
        }
        return json.dumps(mapping)

    def test_execute(self, mock_response):
        """``execute`` yields a ``VisualMapping`` mirroring the canned reply."""
        shared = SharedMemory()
        # Seed an upstream analysis under "query_analysis" before the agent runs
        # (presumably consumed by VizMappingAgent — confirm against the agent).
        upstream_analysis = {
            "visualization_types": ["line chart"],
            "key_points": ["trends"],
            "data_requirements": ["date", "sales"]
        }
        shared.store("query_analysis", upstream_analysis, "test")

        mapper = VizMappingAgent(MockLLM(mock_response), shared)
        mapping = mapper.execute(AgentContext(query="Show sales trends"))

        assert isinstance(mapping, VisualMapping)
        assert mapping.chart_type == "line"
        assert mapping.x_axis["column"] == "date"
class TestAgentContext:
    """Checks construction and defaults of the ``AgentContext`` model."""

    def test_basic_context(self):
        """Unset fields fall back to iteration 0 and no feedback."""
        paths = ["file1.csv", "file2.csv"]
        ctx = AgentContext(query="Test query", data_paths=paths)

        assert ctx.query == "Test query"
        assert len(ctx.data_paths) == 2
        # Defaults for fields not passed to the constructor.
        assert ctx.iteration == 0
        assert ctx.feedback is None

    def test_context_with_feedback(self):
        """Refinement fields passed explicitly are kept as given."""
        ctx = AgentContext(query="Test", iteration=2, feedback="Improve colors")

        assert ctx.iteration == 2
        assert ctx.feedback == "Improve colors"
class TestBaseAgentJsonExtraction:
    """Tests for JSON extraction from LLM responses.

    ``_extract_json`` is exercised through ``QueryAnalyzerAgent``, used here as
    a concrete agent. The class name suggests the method lives on the base
    agent; that is not visible from this file.
    """

    @pytest.fixture
    def agent(self):
        """Agent instance whose ``_extract_json`` is under test.

        The mock LLM is never called by these tests; its reply is irrelevant.
        """
        return QueryAnalyzerAgent(MockLLM("{}"), SharedMemory())

    def test_extract_json_plain(self, agent):
        """A bare JSON object string is parsed directly."""
        result = agent._extract_json('{"key": "value"}')
        assert result == {"key": "value"}

    def test_extract_json_markdown(self, agent):
        """JSON wrapped in a fenced ```json block is found and parsed."""
        # Built by concatenation so the method's indentation cannot leak into
        # the text the way it would inside a triple-quoted literal.
        text = (
            "Here is the response:\n"
            "```json\n"
            '{"key": "value"}\n'
            "```\n"
        )
        result = agent._extract_json(text)
        assert result == {"key": "value"}

    def test_extract_json_invalid(self, agent):
        """Text containing no JSON raises ``ValueError``."""
        with pytest.raises(ValueError):
            agent._extract_json("not valid json")