vanishingradient committed on
Commit 9281fab · 0 Parent(s)

Added init files
.gitignore ADDED
@@ -0,0 +1,6 @@
+ .venv/
+ __pycache__/
+ *.pyc
+ .env
+ outputs/
+ *.log
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2026 M Saqlain
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,190 @@
+ ---
+ title: CoDA
+ emoji: 🎨
+ colorFrom: blue
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 4.0.0
+ app_file: app.py
+ pinned: false
+ ---
+
+ # CoDA: Collaborative Data Visualization Agents
+
+ A production-grade multi-agent system for automated data visualization from natural language queries.
+
+ [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces)
+ [![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/)
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
+
+ ## Overview
+
+ CoDA reframes data visualization as a collaborative multi-agent problem. Instead of treating it as a monolithic task, CoDA employs specialized LLM agents that work together:
+
+ - **Query Analyzer** - Interprets natural language and extracts visualization intent
+ - **Data Processor** - Extracts metadata without token-heavy data loading
+ - **VizMapping Agent** - Maps semantics to visualization primitives
+ - **Search Agent** - Retrieves relevant code patterns
+ - **Design Explorer** - Generates aesthetic specifications
+ - **Code Generator** - Synthesizes executable Python code
+ - **Debug Agent** - Executes code and fixes errors
+ - **Visual Evaluator** - Assesses quality and triggers refinement
+
+ ## Quick Start
+
+ ### Installation
+
+ ```bash
+ # Clone the repository
+ git clone https://github.com/yourusername/CoDA.git
+ cd CoDA
+
+ # Install dependencies
+ pip install -r requirements.txt
+
+ # Configure API key
+ cp .env.example .env
+ # Edit .env and add your GROQ_API_KEY
+ ```
+
+ ### Usage
+
+ #### Web Interface (Gradio)
+
+ ```bash
+ python app.py
+ ```
+
+ Open http://localhost:7860 in your browser.
+
+ #### Command Line
+
+ ```bash
+ python main.py --query "Create a bar chart of sales by category" --data sales.csv
+ ```
+
+ Options:
+ - `-q, --query`: Visualization query (required)
+ - `-d, --data`: Data file path(s) (required)
+ - `-o, --output`: Output directory (default: outputs)
+ - `--max-iterations`: Refinement iterations (default: 3)
+ - `--min-score`: Quality threshold (default: 7.0)
+
+ ### Python API
+
+ ```python
+ from coda.orchestrator import CodaOrchestrator
+
+ orchestrator = CodaOrchestrator()
+ result = orchestrator.run(
+     query="Show sales trends over time",
+     data_paths=["sales_data.csv"]
+ )
+
+ if result.success:
+     print(f"Visualization saved to: {result.output_file}")
+     print(f"Quality Score: {result.scores['overall']}/10")
+ ```
+
+ ## Hugging Face Spaces Deployment
+
+ 1. Create a new Space on [Hugging Face](https://huggingface.co/new-space)
+ 2. Select "Gradio" as the SDK
+ 3. Upload all files from this repository
+ 4. Add `GROQ_API_KEY` as a Secret in Space Settings
+ 5. The Space will automatically build and deploy
+
+ ## Architecture
+
+ ```
+ Natural Language Query + Data Files
+                 ↓
+       ┌──────────────────┐
+       │ Query Analyzer   │ ─── Extracts intent, TODO list
+       └──────────────────┘
+                 ↓
+       ┌──────────────────┐
+       │ Data Processor   │ ─── Metadata extraction (no full load)
+       └──────────────────┘
+                 ↓
+       ┌──────────────────┐
+       │ VizMapping       │ ─── Chart type, encodings
+       └──────────────────┘
+                 ↓
+       ┌──────────────────┐
+       │ Search Agent     │ ─── Code examples
+       └──────────────────┘
+                 ↓
+       ┌──────────────────┐
+       │ Design Explorer  │ ─── Colors, layout, styling
+       └──────────────────┘
+                 ↓
+       ┌──────────────────┐
+       │ Code Generator   │ ─── Python visualization code
+       └──────────────────┘
+                 ↓
+       ┌──────────────────┐
+       │ Debug Agent      │ ─── Execute & fix errors
+       └──────────────────┘
+                 ↓
+       ┌──────────────────┐
+       │ Visual Evaluator │ ─── Quality assessment
+       └──────────────────┘
+        ───────┴───────
+       ↓ Feedback Loop ↓
+    (if quality < threshold)
+ ```
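The feedback loop at the bottom of the diagram can be sketched in Python as follows. This is a minimal illustration, not the real orchestrator (which lives in `coda/orchestrator.py`): `generate`, `execute_and_fix`, and `evaluate` are hypothetical stand-ins for the Code Generator, Debug Agent, and Visual Evaluator, and the thresholds mirror the defaults in the Configuration section.

```python
# Minimal sketch of CoDA's generate -> debug -> evaluate refinement loop.
# All agent functions here are illustrative stubs, not the real implementations.

MAX_ITERATIONS = 3        # mirrors CODA_MAX_ITERATIONS
MIN_OVERALL_SCORE = 7.0   # mirrors CODA_MIN_OVERALL_SCORE

def run_pipeline(query, generate, execute_and_fix, evaluate):
    feedback = None
    best = None
    for iteration in range(1, MAX_ITERATIONS + 1):
        code = generate(query, feedback)      # Code Generator (uses prior feedback)
        output = execute_and_fix(code)        # Debug Agent: run and repair
        scores, feedback = evaluate(output)   # Visual Evaluator
        if best is None or scores["overall"] > best[0]:
            best = (scores["overall"], output)
        if scores["overall"] >= MIN_OVERALL_SCORE:
            break                             # quality threshold met, stop refining
    return best

# Toy agents: quality improves by 2 points per round of feedback.
def generate(query, feedback):
    return f"plot({query!r}, fixes={feedback!r})"

def execute_and_fix(code):
    return f"chart for {code}"

quality = {"score": 3.0}
def evaluate(output):
    quality["score"] += 2.0
    return {"overall": quality["score"]}, "increase font size"

score, chart = run_pipeline("sales by category", generate, execute_and_fix, evaluate)
print(score)  # the loop stops once the score reaches the 7.0 threshold
```

The key design point is that the evaluator's textual feedback flows back into the generator on the next iteration, so each pass is conditioned on what failed before.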
+
+ ## Configuration
+
+ Environment variables (in `.env`):
+
+ | Variable | Default | Description |
+ |----------|---------|-------------|
+ | `GROQ_API_KEY` | Required | Your Groq API key |
+ | `CODA_DEFAULT_MODEL` | llama-3.3-70b-versatile | Text model |
+ | `CODA_VISION_MODEL` | llama-3.2-90b-vision-preview | Vision model |
+ | `CODA_MIN_OVERALL_SCORE` | 7.0 | Quality threshold |
+ | `CODA_MAX_ITERATIONS` | 3 | Max refinement loops |
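For local runs, a minimal `.env` might look like the following (the API key value is a placeholder; the other lines simply restate the defaults from the table and can be omitted):

```shell
# .env — only GROQ_API_KEY is required; the rest are optional overrides
GROQ_API_KEY=your-groq-key-here
CODA_DEFAULT_MODEL=llama-3.3-70b-versatile
CODA_MIN_OVERALL_SCORE=7.0
CODA_MAX_ITERATIONS=3
```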
+
+ ## Supported Data Formats
+
+ - CSV (`.csv`)
+ - JSON (`.json`)
+ - Excel (`.xlsx`, `.xls`)
+ - Parquet (`.parquet`)
+
+ ## Requirements
+
+ - Python 3.10+
+ - Groq API key ([Get one free](https://console.groq.com))
+
+ ## License
+
+ MIT License - See LICENSE for details.
+
+ ## Citation
+
+ If you use CoDA in your research, please cite:
+
+ ```bibtex
+ @article{chen2025coda,
+   title={CoDA: Agentic Systems for Collaborative Data Visualization},
+   author={Chen, Zichen and Chen, Jiefeng and Arik, Sercan {\"O}. and Sra, Misha and Pfister, Tomas and Yoon, Jinsung},
+   journal={arXiv preprint arXiv:2510.03194},
+   year={2025},
+   url={https://arxiv.org/abs/2510.03194},
+   doi={10.48550/arXiv.2510.03194}
+ }
+ ```
+
+ **Paper**: [arXiv:2510.03194](https://arxiv.org/abs/2510.03194)
app.py ADDED
@@ -0,0 +1,220 @@
+ """
+ Gradio Web Interface for CoDA.
+
+ Provides a user-friendly web UI for the CoDA visualization system,
+ designed for deployment on Hugging Face Spaces.
+ """
+
+ import logging
+ import os
+ import shutil
+ import tempfile
+ from pathlib import Path
+ from typing import Optional
+
+ import gradio as gr
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+ )
+ logger = logging.getLogger(__name__)
+
+
+ def create_coda_interface():
+     """Create the Gradio interface for CoDA."""
+
+     def process_visualization(
+         query: str,
+         data_file,
+         progress=gr.Progress()
+     ) -> tuple[Optional[str], str, str]:
+         """
+         Process a visualization request.
+
+         Args:
+             query: Natural language visualization query
+             data_file: Uploaded data file
+             progress: Gradio progress tracker
+
+         Returns:
+             Tuple of (image_path, status_message, details)
+         """
+         if not query.strip():
+             return None, "❌ Error", "Please enter a visualization query."
+
+         if data_file is None:
+             return None, "❌ Error", "Please upload a data file."
+
+         try:
+             from coda.config import Config
+             from coda.orchestrator import CodaOrchestrator
+         except ImportError as e:
+             return None, "❌ Import Error", f"Failed to import CoDA: {e}"
+
+         groq_api_key = os.getenv("GROQ_API_KEY", "")
+         if not groq_api_key:
+             return (
+                 None,
+                 "❌ Configuration Error",
+                 "GROQ_API_KEY environment variable is not set. "
+                 "Please add your API key in the Spaces settings."
+             )
+
+         with tempfile.TemporaryDirectory() as temp_dir:
+             data_path = Path(temp_dir) / Path(data_file.name).name
+             shutil.copyfile(data_file.name, data_path)
+
+             def update_progress(status: str, pct: float):
+                 progress(pct, desc=status)
+
+             try:
+                 config = Config(
+                     groq_api_key=groq_api_key,
+                 )
+
+                 orchestrator = CodaOrchestrator(
+                     config=config,
+                     progress_callback=update_progress,
+                 )
+
+                 result = orchestrator.run(
+                     query=query,
+                     data_paths=[str(data_path)],
+                 )
+
+                 if result.success and result.output_file:
+                     scores = result.scores or {}
+                     details = format_results(result, scores)
+                     return result.output_file, "✅ Success", details
+                 else:
+                     error_msg = result.error or "Unknown error occurred"
+                     return None, "❌ Failed", f"Visualization failed: {error_msg}"
+
+             except Exception as e:
+                 logger.exception("Pipeline error")
+                 return None, "❌ Error", f"An error occurred: {e}"
+
+     def format_results(result, scores: dict) -> str:
+         """Format the results for display."""
+         lines = [
+             f"**Iterations:** {result.iterations}",
+             "",
+             "### Quality Scores",
+         ]
+
+         if scores:
+             for key, value in scores.items():
+                 emoji = "🟢" if value >= 7 else "🟡" if value >= 5 else "🔴"
+                 lines.append(f"- {key.title()}: {emoji} {value:.1f}/10")
+
+         if result.evaluation:
+             if result.evaluation.strengths:
+                 lines.extend(["", "### Strengths"])
+                 for s in result.evaluation.strengths[:3]:
+                     lines.append(f"- {s}")
+
+             if result.evaluation.recommendations:
+                 lines.extend(["", "### Recommendations"])
+                 for r in result.evaluation.recommendations[:3]:
+                     lines.append(f"- {r}")
+
+         return "\n".join(lines)
+
+     with gr.Blocks(
+         title="CoDA - Collaborative Data Visualization",
+         theme=gr.themes.Soft(),
+         css="""
+         .main-title {
+             text-align: center;
+             margin-bottom: 1rem;
+         }
+         .status-box {
+             padding: 1rem;
+             border-radius: 8px;
+             margin-top: 1rem;
+         }
+         """
+     ) as interface:
+         gr.Markdown(
+             """
+             # 🎨 CoDA: Collaborative Data Visualization Agents
+
+             Transform your data into beautiful visualizations using natural language.
+             Simply upload your data and describe what you want to see!
+             """,
+             elem_classes=["main-title"]
+         )
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 query_input = gr.Textbox(
+                     label="Visualization Query",
+                     placeholder="e.g., 'Create a line chart showing sales trends over time'",
+                     lines=3,
+                 )
+
+                 file_input = gr.File(
+                     label="Upload Data File",
+                     file_types=[".csv", ".json", ".xlsx", ".xls", ".parquet"],
+                 )
+
+                 submit_btn = gr.Button(
+                     "🚀 Generate Visualization",
+                     variant="primary",
+                     size="lg",
+                 )
+
+                 gr.Markdown(
+                     """
+                     ### Supported Formats
+                     - CSV, JSON, Excel (.xlsx, .xls), Parquet
+
+                     ### Example Queries
+                     - "Show me a bar chart of sales by category"
+                     - "Create a scatter plot of price vs quantity"
+                     - "Plot the distribution of ages as a histogram"
+                     """
+                 )
+
+             with gr.Column(scale=2):
+                 output_image = gr.Image(
+                     label="Generated Visualization",
+                     type="filepath",
+                 )
+
+                 with gr.Row():
+                     status_output = gr.Textbox(
+                         label="Status",
+                         interactive=False,
+                     )
+
+                 details_output = gr.Markdown(
+                     label="Details",
+                 )
+
+         gr.Examples(
+             examples=[
+                 ["Create a bar chart showing the top 10 values", None],
+                 ["Plot a line chart of trends over time", None],
+                 ["Show a scatter plot with correlation", None],
+                 ["Create a pie chart of category distribution", None],
+             ],
+             inputs=[query_input, file_input],
+         )
+
+         submit_btn.click(
+             fn=process_visualization,
+             inputs=[query_input, file_input],
+             outputs=[output_image, status_output, details_output],
+         )
+
+     return interface
+
+
+ app = create_coda_interface()
+
+ if __name__ == "__main__":
+     app.launch()
coda/__init__.py ADDED
@@ -0,0 +1,31 @@
+ """
+ CoDA - Collaborative Data Visualization Agents
+
+ A multi-agent system for automated data visualization from natural language queries.
+ """
+
+ from coda.config import Config, get_config
+ from coda.orchestrator import CodaOrchestrator, PipelineResult
+ from coda.core import (
+     LLMProvider,
+     GroqLLM,
+     SharedMemory,
+     BaseAgent,
+     AgentContext,
+     AgentFactory,
+ )
+
+ __version__ = "1.0.0"
+
+ __all__ = [
+     "Config",
+     "get_config",
+     "CodaOrchestrator",
+     "PipelineResult",
+     "LLMProvider",
+     "GroqLLM",
+     "SharedMemory",
+     "BaseAgent",
+     "AgentContext",
+     "AgentFactory",
+ ]
coda/agents/__init__.py ADDED
@@ -0,0 +1,29 @@
+ """Agent implementations for CoDA visualization pipeline."""
+
+ from coda.agents.query_analyzer import QueryAnalyzerAgent, QueryAnalysis
+ from coda.agents.data_processor import DataProcessorAgent, DataAnalysis
+ from coda.agents.viz_mapping import VizMappingAgent, VisualMapping
+ from coda.agents.search_agent import SearchAgent, SearchResult
+ from coda.agents.design_explorer import DesignExplorerAgent, DesignSpec
+ from coda.agents.code_generator import CodeGeneratorAgent, GeneratedCode
+ from coda.agents.debug_agent import DebugAgent, ExecutionResult
+ from coda.agents.visual_evaluator import VisualEvaluatorAgent, VisualEvaluation
+
+ __all__ = [
+     "QueryAnalyzerAgent",
+     "QueryAnalysis",
+     "DataProcessorAgent",
+     "DataAnalysis",
+     "VizMappingAgent",
+     "VisualMapping",
+     "SearchAgent",
+     "SearchResult",
+     "DesignExplorerAgent",
+     "DesignSpec",
+     "CodeGeneratorAgent",
+     "GeneratedCode",
+     "DebugAgent",
+     "ExecutionResult",
+     "VisualEvaluatorAgent",
+     "VisualEvaluation",
+ ]
coda/agents/code_generator.py ADDED
@@ -0,0 +1,162 @@
+ """
+ Code Generator Agent for CoDA.
+
+ Synthesizes executable Python visualization code by integrating
+ specifications from upstream agents.
+ """
+
+ from typing import Optional
+
+ from pydantic import BaseModel, Field
+
+ from coda.core.base_agent import AgentContext, BaseAgent
+ from coda.core.llm import LLMProvider
+ from coda.core.memory import SharedMemory
+
+
+ class GeneratedCode(BaseModel):
+     """Structured output from the Code Generator."""
+
+     code: str = Field(default="", description="The generated Python code")
+     dependencies: list[str] = Field(
+         default_factory=lambda: ["matplotlib", "pandas"],
+         description="Required Python packages"
+     )
+     output_filename: str = Field(
+         default="output.png",
+         description="Name of the output visualization file"
+     )
+     documentation: str = Field(
+         default="Generated visualization code",
+         description="Brief documentation of the code"
+     )
+     quality_score: float = Field(
+         default=5.0,
+         description="Self-assessed code quality (0-10)"
+     )
+     potential_issues: list[str] = Field(
+         default_factory=list,
+         description="Potential issues or edge cases"
+     )
+
+
+ class CodeGeneratorAgent(BaseAgent[GeneratedCode]):
+     """
+     Generates executable Python visualization code.
+
+     Integrates all upstream specifications (data processing, visual mapping,
+     design specs) into working code.
+     """
+
+     MEMORY_KEY = "generated_code"
+
+     def __init__(
+         self,
+         llm: LLMProvider,
+         memory: SharedMemory,
+         name: Optional[str] = None,
+     ) -> None:
+         super().__init__(llm, memory, name or "CodeGenerator")
+
+     def _get_system_prompt(self) -> str:
+         return """You are an expert Python Developer specializing in data visualization.
+
+ Your expertise is in writing clean, efficient, and well-documented Python code for data visualization using matplotlib, seaborn, and pandas.
+
+ Your responsibilities:
+ 1. Generate complete, executable Python code
+ 2. Integrate all specifications from the design and mapping agents
+ 3. Handle data loading and transformation correctly
+ 4. Apply proper styling and formatting
+ 5. Include error handling for robustness
+ 6. Write clear documentation
+
+ Code requirements:
+ - Use matplotlib and seaborn as primary libraries
+ - Include all necessary imports at the top
+ - Load data from the specified file paths
+ - Apply all transformations before plotting
+ - Set figure size, colors, and labels as specified
+ - Save the output to a file (PNG format)
+ - Use descriptive variable names
+ - Add comments for complex operations
+
+ IMPORTANT styling rules:
+ - For seaborn barplots, ALWAYS use the hue parameter: sns.barplot(..., hue='category_column', legend=False)
+ - Use ONLY these reliable palettes: 'viridis', 'plasma', 'inferno', 'magma', 'cividis', 'deep', 'muted', 'pastel'
+ - DO NOT use complex or custom palette names like 'tableau10' or 'husl' unless you are certain they are valid.
+ - When in doubt, omit the palette argument or use 'viridis'.
+ - Always call plt.tight_layout() before saving.
+
+ Always respond with a valid JSON object containing the code and metadata."""
+
+     def _build_prompt(self, context: AgentContext) -> str:
+         data_analysis = self._get_from_memory("data_analysis") or {}
+         visual_mapping = self._get_from_memory("visual_mapping") or {}
+         design_spec = self._get_from_memory("design_spec") or {}
+         search_results = self._get_from_memory("search_results") or {}
+
+         file_info = data_analysis.get("files", [])
+         data_paths = [f.get("file_path", "") for f in file_info] if file_info else context.data_paths
+
+         code_examples = search_results.get("examples", [])
+         examples_section = ""
+         if code_examples:
+             examples_section = "\nReference Code Examples:\n"
+             for ex in code_examples[:2]:
+                 if isinstance(ex, dict):
+                     examples_section += f"```python\n# {ex.get('title', 'Example')}\n{ex.get('code', '')}\n```\n"
+
+         feedback_section = ""
+         if context.feedback:
+             feedback_section = f"""
+ Code Feedback (iteration {context.iteration}):
+ {context.feedback}
+
+ Fix the issues mentioned in the feedback.
+ """
+
+         return f"""Generate Python visualization code based on the following specifications.
+
+ User Query: {context.query}
+
+ Data Files: {data_paths}
+
+ Visual Mapping:
+ - Chart Type: {visual_mapping.get('chart_type', 'line')}
+ - X-Axis: {visual_mapping.get('x_axis') or 'Not specified (infer from data or chart type)'}
+ - Y-Axis: {visual_mapping.get('y_axis') or 'Not specified (infer from data or chart type)'}
+ - Color Encoding: {visual_mapping.get('color_encoding')}
+ - Transformations: {visual_mapping.get('transformations', [])}
+
+ Design Specification:
+ - Colors: {design_spec.get('color_scheme', {})}
+ - Layout: {design_spec.get('layout', {})}
+ - Typography: {design_spec.get('typography', {})}
+ - Annotations: {design_spec.get('annotations', [])}
+ - Guidelines: {design_spec.get('implementation_guidelines', [])}
+ {examples_section}{feedback_section}
+
+ Generate a complete Python script that:
+ 1. Imports all necessary libraries
+ 2. Loads the data file(s)
+ 3. Applies required transformations
+ 4. Creates the visualization with specified styling
+ 5. Saves to 'output.png'
+
+ Respond with a JSON object:
+ - code: Complete Python code as a string
+ - dependencies: List of required packages
+ - output_filename: Output file name
+ - documentation: Brief description
+ - quality_score: Self-assessment 0-10
+ - potential_issues: List of potential issues
+
+ JSON Response:"""
+
+     def _parse_response(self, response: str) -> GeneratedCode:
+         data = self._extract_json(response)
+         return GeneratedCode(**data)
+
+     def _get_output_key(self) -> str:
+         return self.MEMORY_KEY
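To make the styling rules in the system prompt above concrete, here is the kind of script the agent is expected to emit for a simple bar chart. The data and column names are illustrative, and the sketch uses matplotlib and pandas only (seaborn is omitted to keep it dependency-light), drawing colors from the 'viridis' palette and calling `tight_layout()` before saving, as the prompt requires.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the script runs without a display

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Illustrative data standing in for a loaded CSV.
df = pd.DataFrame({"category": ["A", "B", "C"], "sales": [120, 95, 143]})

fig, ax = plt.subplots(figsize=(8, 5))
# Sample distinct colors from the 'viridis' palette, one per bar.
colors = plt.cm.viridis(np.linspace(0.2, 0.8, len(df)))
ax.bar(df["category"], df["sales"], color=colors)
ax.set_xlabel("Category")
ax.set_ylabel("Sales")
ax.set_title("Sales by Category")

plt.tight_layout()  # always applied before saving, per the styling rules
plt.savefig("output.png")
```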
coda/agents/data_processor.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data Processor Agent for CoDA.
3
+
4
+ Extracts metadata and insights from data files without loading full datasets,
5
+ enabling the system to work within token limits while providing rich context
6
+ for visualization decisions.
7
+ """
8
+
9
+ import logging
10
+ from pathlib import Path
11
+ from typing import Any, Optional
12
+
13
+ import pandas as pd
14
+ from pydantic import BaseModel, Field
15
+
16
+ from coda.core.base_agent import AgentContext, BaseAgent
17
+ from coda.core.llm import LLMProvider
18
+ from coda.core.memory import SharedMemory
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ class ColumnInfo(BaseModel):
24
+ """Information about a single column."""
25
+
26
+ name: str
27
+ dtype: str
28
+ non_null_count: int
29
+ unique_count: int
30
+ sample_values: list[Any]
31
+
32
+
33
+ class DataFileInfo(BaseModel):
34
+ """Metadata about a single data file."""
35
+
36
+ file_path: str
37
+ file_type: str
38
+ shape: tuple[int, int]
39
+ columns: list[ColumnInfo]
40
+ memory_usage_mb: float
41
+
42
+
43
+ class DataAnalysis(BaseModel):
44
+ """Structured output from the Data Processor."""
45
+
46
+ files: list[DataFileInfo] = Field(
47
+ description="Metadata for each processed data file"
48
+ )
49
+ insights: list[str] = Field(
50
+ description="Key insights about the data (patterns, outliers, etc.)"
51
+ )
52
+ processing_steps: list[str] = Field(
53
+ description="Recommended data processing steps"
54
+ )
55
+ aggregations_needed: list[str] = Field(
56
+ default_factory=list,
57
+ description="Suggested aggregations for visualization"
58
+ )
59
+ visualization_hints: list[str] = Field(
60
+ default_factory=list,
61
+ description="Hints for visualization based on data characteristics"
62
+ )
63
+ potential_issues: list[str] = Field(
64
+ default_factory=list,
65
+ description="Potential data quality issues"
66
+ )
67
+
68
+
69
+ class DataProcessorAgent(BaseAgent[DataAnalysis]):
70
+ """
71
+ Processes data files to extract metadata and insights.
72
+
73
+ Uses lightweight analysis to avoid token limits while providing
74
+ comprehensive data understanding for downstream agents.
75
+ """
76
+
77
+ MEMORY_KEY = "data_analysis"
78
+ SUPPORTED_EXTENSIONS = {".csv", ".json", ".xlsx", ".xls", ".parquet"}
79
+
80
+ def __init__(
81
+ self,
82
+ llm: LLMProvider,
83
+ memory: SharedMemory,
84
+ name: Optional[str] = None,
85
+ ) -> None:
86
+ super().__init__(llm, memory, name or "DataProcessor")
87
+
88
+ def execute(self, context: AgentContext) -> DataAnalysis:
89
+ """Override to include data extraction before LLM analysis."""
90
+ logger.info(f"[{self._name}] Processing {len(context.data_paths)} data files")
91
+
92
+ file_infos = []
93
+ for path in context.data_paths:
94
+ info = self._extract_file_metadata(path)
95
+ if info:
96
+ file_infos.append(info)
97
+
98
+ self._memory.store(
99
+ key="raw_file_info",
100
+ value=[f.model_dump() for f in file_infos],
101
+ agent_name=self._name,
102
+ )
103
+
104
+ return super().execute(context)
105
+
106
+ def _extract_file_metadata(self, file_path: str) -> Optional[DataFileInfo]:
107
+ """Extract metadata from a data file using pandas."""
108
+ path = Path(file_path)
109
+
110
+ if not path.exists():
111
+ logger.warning(f"File not found: {path}")
112
+ return None
113
+
114
+ if path.suffix.lower() not in self.SUPPORTED_EXTENSIONS:
115
+ logger.warning(f"Unsupported file type: {path.suffix}")
116
+ return None
117
+
118
+ try:
119
+ df = self._load_dataframe(path)
120
+ columns = self._analyze_columns(df)
121
+
122
+ return DataFileInfo(
123
+ file_path=str(path),
124
+ file_type=path.suffix.lower(),
125
+ shape=(len(df), len(df.columns)),
126
+ columns=columns,
127
+ memory_usage_mb=df.memory_usage(deep=True).sum() / (1024 * 1024),
128
+ )
129
+ except Exception as e:
130
+ logger.error(f"Failed to process {path}: {e}")
131
+ return None
132
+
133
+ def _load_dataframe(self, path: Path) -> pd.DataFrame:
134
+ """Load a dataframe from various file formats."""
135
+ suffix = path.suffix.lower()
136
+
137
+ if suffix == ".csv":
138
+ return pd.read_csv(path)
139
+ elif suffix == ".json":
140
+ return pd.read_json(path)
141
+ elif suffix in {".xlsx", ".xls"}:
142
+ return pd.read_excel(path)
143
+ elif suffix == ".parquet":
144
+ return pd.read_parquet(path)
145
+ else:
146
+ raise ValueError(f"Unsupported format: {suffix}")
147
+
148
+ def _analyze_columns(self, df: pd.DataFrame) -> list[ColumnInfo]:
149
+ """Analyze each column in the dataframe."""
150
+ columns = []
151
+
152
+ for col in df.columns:
153
+ series = df[col]
154
+ sample_values = series.dropna().head(5).tolist()
155
+
156
+ columns.append(ColumnInfo(
157
+ name=str(col),
158
+ dtype=str(series.dtype),
159
+ non_null_count=int(series.count()),
160
+ unique_count=int(series.nunique()),
161
+ sample_values=sample_values,
162
+ ))
163
+
164
+ return columns
165
+
166
+ def _get_system_prompt(self) -> str:
167
+ return """You are a Data Analyst specialist in a data visualization team.
168
+
169
+ Your expertise is in understanding data structures, identifying patterns, and recommending processing steps for effective visualization.
170
+
171
+ Your responsibilities:
172
+ 1. Analyze metadata to understand data characteristics
173
+ 2. Identify insights and patterns relevant to visualization
174
+ 3. Recommend data processing and aggregation steps
175
+ 4. Suggest visualization approaches based on data types
176
+ 5. Flag potential data quality issues
177
+
178
+ Always respond with a valid JSON object matching the required schema."""
179
+
180
+ def _build_prompt(self, context: AgentContext) -> str:
181
+ file_info = self._get_from_memory("raw_file_info") or []
182
+ query_analysis = self._get_from_memory("query_analysis") or {}
183
+
184
+ file_summary = self._format_file_info(file_info)
185
+
186
+ query_context = ""
187
+ if query_analysis:
188
+ query_context = f"""
189
+ Query Analysis:
190
+ - Visualization Types: {query_analysis.get('visualization_types', [])}
191
+ - Key Points: {query_analysis.get('key_points', [])}
192
+ - Data Requirements: {query_analysis.get('data_requirements', [])}
193
+ """
194
+
195
+ return f"""Analyze the following data files for visualization purposes.
196
+
197
+ User Query: {context.query}
198
+ {query_context}
199
+
200
+ Data Files:
201
+ {file_summary}
202
+
203
+ Based on this metadata, provide a JSON object with these fields.
204
+ IMPORTANT: All list fields must contain SIMPLE STRINGS, not objects.
205
+
206
+ {{
207
+ "insights": ["string1", "string2", ...], // Simple string descriptions of patterns
208
+ "processing_steps": ["step1", "step2", ...], // Simple string descriptions of steps
209
+ "aggregations_needed": ["agg1", "agg2", ...], // Simple string descriptions
210
+ "visualization_hints": ["hint1", "hint2", ...], // Simple string hints
211
+ "potential_issues": ["issue1", "issue2", ...] // Simple string issues
212
+ }}
213
+
214
+ JSON Response:"""
215
+
216
+ def _format_file_info(self, file_info: list[dict]) -> str:
217
+ """Format file information for the prompt."""
218
+ if not file_info:
219
+ return "No data files available."
220
+
221
+ lines = []
222
+ for f in file_info:
223
+ lines.append(f"\nFile: {f['file_path']}")
224
+ lines.append(f" Type: {f['file_type']}")
225
+ lines.append(f" Shape: {f['shape'][0]} rows × {f['shape'][1]} columns")
226
+ lines.append(" Columns:")
227
+
228
+ for col in f.get("columns", []):
229
+ samples = ", ".join(str(v) for v in col.get("sample_values", [])[:3])
230
+ lines.append(
231
+ f" - {col['name']} ({col['dtype']}): "
232
+ f"{col['unique_count']} unique, samples: [{samples}]"
233
+ )
234
+
235
+ return "\n".join(lines)
236
+
237
+ def _normalize_list_field(self, value: Any) -> list[str]:
238
+ """Normalize a field that should be a list of strings."""
239
+ if value is None:
240
+ return []
241
+
242
+ if isinstance(value, dict):
243
+ return [f"{k}: {v}" for k, v in value.items()]
244
+
245
+ if isinstance(value, list):
246
+ result = []
247
+ for item in value:
248
+ if isinstance(item, str):
249
+ result.append(item)
250
+ elif isinstance(item, dict):
251
+ desc_keys = ["description", "desc", "text", "value", "step", "hint", "issue"]
252
+ for key in desc_keys:
253
+ if key in item:
254
+ result.append(str(item[key]))
255
+ break
256
+ else:
257
+ result.append(str(item))
258
+ else:
259
+ result.append(str(item))
260
+ return result
261
+
262
+ return [str(value)]
263
+
264
+ def _parse_response(self, response: str) -> DataAnalysis:
265
+ data = self._extract_json(response)
266
+
267
+ data["insights"] = self._normalize_list_field(data.get("insights"))
268
+ data["processing_steps"] = self._normalize_list_field(data.get("processing_steps"))
269
+ data["aggregations_needed"] = self._normalize_list_field(data.get("aggregations_needed"))
270
+ data["visualization_hints"] = self._normalize_list_field(data.get("visualization_hints"))
271
+ data["potential_issues"] = self._normalize_list_field(data.get("potential_issues"))
272
+
273
+ file_info = self._get_from_memory("raw_file_info") or []
274
+ data["files"] = file_info
275
+
276
+ return DataAnalysis(**data)
277
+
278
+ def _get_output_key(self) -> str:
279
+ return self.MEMORY_KEY
280
+
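The `_normalize_list_field` helper above tolerates strings, dicts, and mixed lists coming back from the LLM. A standalone sketch of the same behavior (the function name here is illustrative, not part of the module):

```python
def normalize_list_field(value):
    """Coerce an LLM-returned field into a flat list of strings."""
    if value is None:
        return []
    if isinstance(value, dict):
        # A dict becomes one "key: value" string per entry
        return [f"{k}: {v}" for k, v in value.items()]
    if isinstance(value, list):
        out = []
        for item in value:
            if isinstance(item, dict):
                # Prefer a human-readable description key if present
                for key in ("description", "desc", "text", "value"):
                    if key in item:
                        out.append(str(item[key]))
                        break
                else:
                    out.append(str(item))
            else:
                out.append(str(item))
        return out
    # Scalars are wrapped in a single-element list
    return [str(value)]
```

The agent version also checks "step", "hint", and "issue" as description keys; the shape of the logic is the same.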
coda/agents/debug_agent.py ADDED
@@ -0,0 +1,252 @@
1
+ """
2
+ Debug Agent for CoDA.
3
+
4
+ Executes generated code, diagnoses errors, and applies fixes
5
+ to produce working visualizations.
6
+ """
7
+
8
+ import logging
9
+ import os
10
+ import subprocess
11
+ import sys
12
+ import tempfile
13
+ from pathlib import Path
14
+ from typing import Optional
15
+
16
+ from pydantic import BaseModel, Field
17
+
18
+ from coda.core.base_agent import AgentContext, BaseAgent
19
+ from coda.core.llm import LLMProvider
20
+ from coda.core.memory import SharedMemory
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class ExecutionResult(BaseModel):
26
+ """Structured output from the Debug Agent."""
27
+
28
+ success: bool = Field(
29
+ description="Whether execution succeeded"
30
+ )
31
+ output_file: Optional[str] = Field(
32
+ default=None,
33
+ description="Path to the generated visualization"
34
+ )
35
+ stdout: str = Field(
36
+ default="",
37
+ description="Standard output from execution"
38
+ )
39
+ stderr: str = Field(
40
+ default="",
41
+ description="Error output from execution"
42
+ )
43
+ error_diagnosis: Optional[str] = Field(
44
+ default=None,
45
+ description="Diagnosis of any errors"
46
+ )
47
+ corrected_code: Optional[str] = Field(
48
+ default=None,
49
+ description="Fixed code if errors occurred"
50
+ )
51
+ fix_applied: bool = Field(
52
+ default=False,
53
+ description="Whether a fix was applied"
54
+ )
55
+ execution_time_seconds: float = Field(
56
+ default=0.0,
57
+ description="Time taken to execute"
58
+ )
59
+
60
+
61
+ class DebugAgent(BaseAgent[ExecutionResult]):
62
+ """
63
+ Executes generated code and handles errors.
64
+
65
+ Runs the visualization code in a subprocess with timeout,
66
+ diagnoses errors, and attempts automatic fixes.
67
+ """
68
+
69
+ MEMORY_KEY = "execution_result"
70
+
71
+ def __init__(
72
+ self,
73
+ llm: LLMProvider,
74
+ memory: SharedMemory,
75
+ timeout_seconds: int = 60,
76
+ output_directory: str = "outputs",
77
+ name: Optional[str] = None,
78
+ ) -> None:
79
+ super().__init__(llm, memory, name or "DebugAgent")
80
+ self._timeout = timeout_seconds
81
+ self._output_dir = Path(output_directory)
82
+ self._output_dir.mkdir(parents=True, exist_ok=True)
83
+
84
+ def execute(self, context: AgentContext) -> ExecutionResult:
85
+ """Execute the generated code and handle errors."""
86
+ logger.info(f"[{self._name}] Starting code execution")
87
+
88
+ generated_code = self._get_from_memory("generated_code")
89
+ if not generated_code:
90
+ return ExecutionResult(
91
+ success=False,
92
+ stderr="No generated code found in memory",
93
+ )
94
+
95
+ code = generated_code.get("code", "")
96
+ output_filename = generated_code.get("output_filename", "output.png")
97
+
98
+ code = self._prepare_code(code, output_filename)
99
+
100
+ result = self._execute_code(code)
101
+
102
+ if not result.success and result.stderr:
103
+ logger.warning(f"[{self._name}] Code execution failed: {result.stderr[:500]}")
104
+ logger.info(f"[{self._name}] Attempting to fix errors")
105
+ fixed_result = self._attempt_fix(code, result.stderr, context)
106
+ if fixed_result.success:
107
+ self._store_result(fixed_result)
108
+ logger.info(f"[{self._name}] Fix successful!")
109
+ return fixed_result
110
+ logger.warning(f"[{self._name}] Fix attempt failed")
111
+ result.error_diagnosis = fixed_result.error_diagnosis
112
+ result.corrected_code = fixed_result.corrected_code
113
+
114
+ self._store_result(result)
115
+ logger.info(f"[{self._name}] Execution complete: success={result.success}")
116
+ return result
117
+
118
+ def _prepare_code(self, code: str, output_filename: str) -> str:
119
+ """Prepare code for execution by setting up paths."""
120
+ output_path = self._output_dir / output_filename
121
+
122
+ code = code.replace(
123
+ f"'{output_filename}'",
124
+ f"r'{output_path}'"
125
+ )
126
+ code = code.replace(
127
+ f'"{output_filename}"',
128
+ f"r'{output_path}'"
129
+ )
130
+
131
+ if "plt.savefig" not in code and "fig.savefig" not in code:
132
+ code += f"\nplt.savefig(r'{output_path}', dpi=150, bbox_inches='tight')\n"
133
+
134
+ return code
135
+
136
+ def _execute_code(self, code: str) -> ExecutionResult:
137
+ """Execute Python code in a subprocess."""
138
+ import time
139
+ start_time = time.time()
140
+
141
+ with tempfile.NamedTemporaryFile(
142
+ mode="w",
143
+ suffix=".py",
144
+ delete=False,
145
+ encoding="utf-8"
146
+ ) as f:
147
+ f.write(code)
148
+ temp_file = f.name
149
+
150
+ try:
151
+ result = subprocess.run(
152
+ [sys.executable, temp_file],
153
+ capture_output=True,
154
+ text=True,
155
+ timeout=self._timeout,
156
+ cwd=str(self._output_dir.parent),
157
+ )
158
+
159
+ execution_time = time.time() - start_time
160
+
161
+ # glob() order is arbitrary; sort by modification time so [-1] is the newest image
+ output_files = sorted(self._output_dir.glob("*.png"), key=lambda p: p.stat().st_mtime)
162
+ output_file = str(output_files[-1]) if output_files else None
163
+
164
+ return ExecutionResult(
165
+ success=result.returncode == 0,
166
+ output_file=output_file,
167
+ stdout=result.stdout,
168
+ stderr=result.stderr,
169
+ execution_time_seconds=execution_time,
170
+ )
171
+
172
+ except subprocess.TimeoutExpired:
173
+ return ExecutionResult(
174
+ success=False,
175
+ stderr=f"Execution timed out after {self._timeout} seconds",
176
+ )
177
+ except Exception as e:
178
+ return ExecutionResult(
179
+ success=False,
180
+ stderr=str(e),
181
+ )
182
+ finally:
183
+ try:
184
+ os.unlink(temp_file)
185
+ except OSError:
186
+ pass
187
+
188
+ def _attempt_fix(
189
+ self,
190
+ original_code: str,
191
+ error_message: str,
192
+ context: AgentContext,
193
+ ) -> ExecutionResult:
194
+ """Attempt to fix code errors using the LLM."""
195
+ fix_prompt = f"""The following Python visualization code produced an error. Please fix it.
196
+
197
+ Original Code:
198
+ ```python
199
+ {original_code}
200
+ ```
201
+
202
+ Error Message:
203
+ {error_message}
204
+
205
+ Provide a JSON response with:
206
+ - diagnosis: What caused the error
207
+ - corrected_code: The fixed Python code
208
+
209
+ IMPORTANT: Return ONLY valid JSON. Do not include markdown formatting or explanations outside the JSON.
210
+ You may assume the standard libraries (matplotlib, seaborn, pandas, numpy) are available.
211
+
212
+ JSON Response:"""
213
+
214
+ response = self._llm.complete(
215
+ prompt=fix_prompt,
216
+ system_prompt="You are an expert Python debugger. Fix the code error and provide corrected code.",
217
+ )
218
+
219
+ try:
220
+ data = self._extract_json(response.content)
221
+ diagnosis = data.get("diagnosis", "Unknown error")
222
+ corrected_code = data.get("corrected_code", "")
223
+
224
+ if corrected_code:
225
+ output_filename = "output.png"
226
+ corrected_code = self._prepare_code(corrected_code, output_filename)
227
+ result = self._execute_code(corrected_code)
228
+ result.error_diagnosis = diagnosis
229
+ result.corrected_code = corrected_code
230
+ result.fix_applied = result.success
231
+ return result
232
+
233
+ except Exception as e:
234
+ logger.error(f"Failed to parse fix response: {e}")
235
+
236
+ return ExecutionResult(
237
+ success=False,
238
+ stderr=error_message,
239
+ error_diagnosis="Failed to automatically fix the error",
240
+ )
241
+
242
+ def _build_prompt(self, context: AgentContext) -> str:
243
+ return ""
244
+
245
+ def _get_system_prompt(self) -> str:
246
+ return ""
247
+
248
+ def _parse_response(self, response: str) -> ExecutionResult:
249
+ return ExecutionResult(success=False)
250
+
251
+ def _get_output_key(self) -> str:
252
+ return self.MEMORY_KEY
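DebugAgent's execution path boils down to: write the code to a temp file, run it with the current interpreter under a timeout, and always clean up. A minimal sketch of that pattern, independent of the agent classes:

```python
import os
import subprocess
import sys
import tempfile

def run_snippet(code: str, timeout: int = 60):
    """Run Python source in a subprocess; return (returncode, stdout, stderr)."""
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False, encoding="utf-8"
    ) as f:
        f.write(code)
        path = f.name
    try:
        proc = subprocess.run(
            [sys.executable, path],  # same interpreter as the parent process
            capture_output=True, text=True, timeout=timeout,
        )
        return proc.returncode, proc.stdout, proc.stderr
    except subprocess.TimeoutExpired:
        return 1, "", f"timed out after {timeout}s"
    finally:
        os.unlink(path)  # remove the temp file even on failure
```

Note this is process isolation only, not a sandbox: the child inherits the parent's filesystem and network access, which is why the agent restricts what code it feeds in.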
coda/agents/design_explorer.py ADDED
@@ -0,0 +1,207 @@
1
+ """
2
+ Design Explorer Agent for CoDA.
3
+
4
+ Generates aesthetic and content specifications for visualizations,
5
+ optimizing for user experience and effective communication.
6
+ """
7
+
8
+ import logging
9
+ from typing import Any, Optional
10
+
11
+ from pydantic import BaseModel, Field
12
+
13
+ from coda.core.base_agent import AgentContext, BaseAgent
14
+ from coda.core.llm import LLMProvider
15
+ from coda.core.memory import SharedMemory
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ class ColorScheme(BaseModel):
21
+ """Color scheme specification."""
22
+
23
+ primary: str = "#4A90D9"
24
+ secondary: list[str] = Field(default_factory=lambda: ["#67B7DC", "#A5D6A7"])
25
+ background: str = "#FFFFFF"
26
+ text: str = "#333333"
27
+ accent: Optional[str] = "#FF6B6B"
28
+
29
+
30
+ class LayoutSpec(BaseModel):
31
+ """Layout specification."""
32
+
33
+ figure_size: tuple[int, int] = (10, 6)
34
+ margins: dict[str, float] = Field(default_factory=lambda: {"top": 0.1, "bottom": 0.15, "left": 0.1, "right": 0.1})
35
+ title_position: str = "top"
36
+ legend_position: str = "right"
37
+ grid_visible: bool = True
38
+
39
+
40
+ class DesignSpec(BaseModel):
41
+ """Structured output from the Design Explorer."""
42
+
43
+ color_scheme: ColorScheme = Field(default_factory=ColorScheme)
44
+ layout: LayoutSpec = Field(default_factory=LayoutSpec)
45
+ typography: dict = Field(
46
+ default_factory=lambda: {"title_size": 14, "label_size": 12, "tick_size": 10, "font_family": "sans-serif"}
47
+ )
48
+ annotations: list[dict] = Field(
49
+ default_factory=list,
50
+ description="Text annotations to add"
51
+ )
52
+ implementation_guidelines: list[str] = Field(
53
+ default_factory=list,
54
+ description="Specific implementation instructions"
55
+ )
56
+ quality_metrics: dict = Field(
57
+ default_factory=lambda: {"readability": "high", "aesthetics": "clean", "clarity": "clear"}
58
+ )
59
+ alternatives: list[dict] = Field(
60
+ default_factory=list,
61
+ description="Alternative design approaches"
62
+ )
63
+ success_indicators: list[str] = Field(
64
+ default_factory=list,
65
+ description="How to know if the design is successful"
66
+ )
67
+
68
+
69
+ class DesignExplorerAgent(BaseAgent[DesignSpec]):
70
+ """
71
+ Generates aesthetic design specifications for visualizations.
72
+
73
+ Focuses on creating visually appealing and effective designs
74
+ that communicate data insights clearly.
75
+ """
76
+
77
+ MEMORY_KEY = "design_spec"
78
+
79
+ def __init__(
80
+ self,
81
+ llm: LLMProvider,
82
+ memory: SharedMemory,
83
+ name: Optional[str] = None,
84
+ ) -> None:
85
+ super().__init__(llm, memory, name or "DesignExplorer")
86
+
87
+ def _get_system_prompt(self) -> str:
88
+ return """You are a Visualization Design specialist in a data visualization team.
89
+
90
+ Your expertise is in creating aesthetically pleasing and effective data visualizations that communicate insights clearly.
91
+
92
+ Your responsibilities:
93
+ 1. Design harmonious color schemes suitable for the data
94
+ 2. Specify optimal layouts for readability
95
+ 3. Choose appropriate typography
96
+ 4. Plan meaningful annotations
97
+ 5. Define quality metrics for evaluation
98
+ 6. Consider accessibility and best practices
99
+
100
+ Design principles to follow:
101
+ - Use color purposefully, not decoratively
102
+ - Ensure sufficient contrast for readability
103
+ - Maintain consistent visual hierarchy
104
+ - Minimize chart junk and maximize data-ink ratio
105
+ - Consider colorblind-friendly palettes when appropriate
106
+
107
+ Always respond with a valid JSON object matching the required schema."""
108
+
109
+ def _build_prompt(self, context: AgentContext) -> str:
110
+ query_analysis = self._get_from_memory("query_analysis") or {}
111
+ data_analysis = self._get_from_memory("data_analysis") or {}
112
+ visual_mapping = self._get_from_memory("visual_mapping") or {}
113
+
114
+ feedback_section = ""
115
+ if context.feedback:
116
+ feedback_section = f"""
117
+ Design Feedback (iteration {context.iteration}):
118
+ {context.feedback}
119
+
120
+ Please address the feedback in your revised design.
121
+ """
122
+
123
+ return f"""Create a design specification for the following visualization.
124
+
125
+ User Query: {context.query}
126
+
127
+ Visualization Type: {visual_mapping.get('chart_type', 'Unknown')}
128
+ Visualization Goals: {visual_mapping.get('visualization_goals', [])}
129
+ Styling Hints: {visual_mapping.get('styling_hints', {})}
130
+ {feedback_section}
131
+
132
+ Provide a JSON object containing:
133
+ - color_scheme: {{
134
+ "primary": "#hex",
135
+ "secondary": ["#hex", ...],
136
+ "background": "#hex",
137
+ "text": "#hex",
138
+ "accent": "#hex" (optional)
139
+ }}
140
+ - layout: {{
141
+ "figure_size": [width, height],
142
+ "margins": {{"top": 0.1, "bottom": 0.1, "left": 0.1, "right": 0.1}},
143
+ "title_position": "top|center",
144
+ "legend_position": "right|bottom|none",
145
+ "grid_visible": true|false
146
+ }}
147
+ - typography: {{
148
+ "title_size": 16,
149
+ "label_size": 12,
150
+ "tick_size": 10,
151
+ "font_family": "sans-serif"
152
+ }}
153
+ - annotations: List of {{"text": "...", "position": "...", "style": "..."}}
154
+ - implementation_guidelines: Specific instructions for implementation
155
+ - quality_metrics: {{"readability": "...", "aesthetics": "...", "clarity": "..."}}
156
+ - alternatives: Alternative design approaches
157
+ - success_indicators: How to evaluate design success
158
+
159
+ JSON Response:"""
160
+
161
+ def _normalize_to_list(self, value: Any, as_dicts: bool = False) -> list:
162
+ """Normalize a value to a list."""
163
+ if value is None:
164
+ return []
165
+ if isinstance(value, str):
166
+ if as_dicts:
167
+ return [{"description": value}]
168
+ return [value] if value.strip() else []
169
+ if isinstance(value, list):
170
+ if as_dicts:
171
+ return [item if isinstance(item, dict) else {"description": str(item)} for item in value]
172
+ return [str(item) if not isinstance(item, str) else item for item in value]
173
+ return []
174
+
175
+ def _parse_response(self, response: str) -> DesignSpec:
176
+ data = self._extract_json(response)
177
+
178
+ # Normalize list fields
179
+ data["implementation_guidelines"] = self._normalize_to_list(data.get("implementation_guidelines"))
180
+ data["success_indicators"] = self._normalize_to_list(data.get("success_indicators"))
181
+ data["alternatives"] = self._normalize_to_list(data.get("alternatives"), as_dicts=True)
182
+ data["annotations"] = self._normalize_to_list(data.get("annotations"), as_dicts=True)
183
+
184
+ # Ensure nested models have valid data
185
+ if "color_scheme" not in data or not isinstance(data["color_scheme"], dict):
186
+ data["color_scheme"] = {}
187
+ if "layout" not in data or not isinstance(data["layout"], dict):
188
+ data["layout"] = {}
189
+ else:
190
+ # Sanitize figure_size to prevent crashes (e.g. LLM giving pixels instead of inches)
191
+ figsize = data["layout"].get("figure_size")
192
+ if isinstance(figsize, (list, tuple)) and len(figsize) == 2:
193
+ w, h = figsize
194
+ # Dimensions above 50 in either direction are almost certainly pixels, not inches; clamp
195
+ if w > 50 or h > 50:
196
+ logger.warning(f"Extremely large figure size detected: {figsize}. Clamping to (12, 8).")
197
+ data["layout"]["figure_size"] = (12, 8)
198
+
199
+ if "typography" not in data or not isinstance(data["typography"], dict):
200
+ data["typography"] = {}
201
+ if "quality_metrics" not in data or not isinstance(data["quality_metrics"], dict):
202
+ data["quality_metrics"] = {}
203
+
204
+ return DesignSpec(**data)
205
+
206
+ def _get_output_key(self) -> str:
207
+ return self.MEMORY_KEY
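The figure-size sanitization in `_parse_response` guards against an LLM returning pixel dimensions where matplotlib expects inches. The rule in isolation (a sketch: the threshold and fallback mirror the code above, but here malformed input also falls back, whereas the agent leaves it for the pydantic defaults):

```python
def clamp_figure_size(figsize, fallback=(12, 8), limit=50):
    """Treat dimensions above `limit` as pixels and fall back to a sane inch size."""
    if not (isinstance(figsize, (list, tuple)) and len(figsize) == 2):
        return fallback
    w, h = figsize
    if w > limit or h > limit:
        # e.g. an LLM answering [1920, 1080] would otherwise crash or hang matplotlib
        return fallback
    return (w, h)
```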
coda/agents/query_analyzer.py ADDED
@@ -0,0 +1,112 @@
1
+ """
2
+ Query Analyzer Agent for CoDA.
3
+
4
+ Interprets natural language queries to extract visualization intent,
5
+ decompose requirements into actionable items, and provide guidance
6
+ for downstream agents.
7
+ """
8
+
9
+ from typing import Optional
10
+
11
+ from pydantic import BaseModel, Field
12
+
13
+ from coda.core.base_agent import AgentContext, BaseAgent
14
+ from coda.core.llm import LLMProvider
15
+ from coda.core.memory import SharedMemory
16
+
17
+
18
+ class QueryAnalysis(BaseModel):
19
+ """Structured output from the Query Analyzer."""
20
+
21
+ visualization_types: list[str] = Field(
22
+ description="Suggested visualization types (e.g., line chart, bar chart)"
23
+ )
24
+ key_points: list[str] = Field(
25
+ description="Key data points or aspects to visualize"
26
+ )
27
+ todo_list: list[str] = Field(
28
+ description="Decomposed list of tasks for visualization creation"
29
+ )
30
+ data_requirements: list[str] = Field(
31
+ description="Required data columns or features"
32
+ )
33
+ constraints: list[str] = Field(
34
+ default_factory=list,
35
+ description="Any constraints mentioned in the query"
36
+ )
37
+ ambiguities: list[str] = Field(
38
+ default_factory=list,
39
+ description="Ambiguous aspects that may need clarification"
40
+ )
41
+
42
+
43
+ class QueryAnalyzerAgent(BaseAgent[QueryAnalysis]):
44
+ """
45
+ Analyzes natural language queries to extract visualization intent.
46
+
47
+ This agent is the first in the pipeline, responsible for understanding
48
+ what the user wants to visualize and breaking it down into actionable steps.
49
+ """
50
+
51
+ MEMORY_KEY = "query_analysis"
52
+
53
+ def __init__(
54
+ self,
55
+ llm: LLMProvider,
56
+ memory: SharedMemory,
57
+ name: Optional[str] = None,
58
+ ) -> None:
59
+ super().__init__(llm, memory, name or "QueryAnalyzer")
60
+
61
+ def _get_system_prompt(self) -> str:
62
+ return """You are a Query Analyzer specialist in a data visualization team.
63
+
64
+ Your expertise lies in interpreting natural language requests and extracting clear, actionable requirements for creating visualizations.
65
+
66
+ Your responsibilities:
67
+ 1. Identify the type(s) of visualizations that best suit the request
68
+ 2. Extract key data points and features to be visualized
69
+ 3. Decompose the request into a clear TODO list for the visualization pipeline
70
+ 4. Identify required data columns or features
71
+ 5. Note any constraints or preferences mentioned
72
+ 6. Flag ambiguities that might affect the visualization
73
+
74
+ Always respond with a valid JSON object matching the required schema."""
75
+
76
+ def _build_prompt(self, context: AgentContext) -> str:
77
+ metadata = self._get_from_memory("metadata_summary")
78
+ metadata_section = ""
79
+ if metadata:
80
+ metadata_section = f"""
81
+ Available Metadata:
82
+ {metadata}
83
+ """
84
+
85
+ feedback_section = ""
86
+ if context.feedback:
87
+ feedback_section = f"""
88
+ Previous Feedback (iteration {context.iteration}):
89
+ {context.feedback}
90
+ """
91
+
92
+ return f"""Analyze the following visualization query and extract structured requirements.
93
+
94
+ Query: {context.query}
95
+ {metadata_section}{feedback_section}
96
+
97
+ Respond with a JSON object containing:
98
+ - visualization_types: List of suggested chart types
99
+ - key_points: Key aspects or data points to highlight
100
+ - todo_list: Step-by-step tasks for creating the visualization
101
+ - data_requirements: Required data columns or features
102
+ - constraints: Any mentioned constraints or preferences
103
+ - ambiguities: Unclear aspects that may need clarification
104
+
105
+ JSON Response:"""
106
+
107
+ def _parse_response(self, response: str) -> QueryAnalysis:
108
+ data = self._extract_json(response)
109
+ return QueryAnalysis(**data)
110
+
111
+ def _get_output_key(self) -> str:
112
+ return self.MEMORY_KEY
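Every agent ends its prompt with "JSON Response:" and parses the reply via `_extract_json`, which lives in the base class and is not part of this commit. A typical implementation strips markdown fences before parsing; this sketch is an assumption about its behavior, not the actual code:

```python
import json
import re

def extract_json(text: str) -> dict:
    """Pull the first JSON object out of an LLM reply, tolerating ```json fences."""
    # Drop markdown code fences if the model added them despite instructions
    fenced = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL)
    if fenced:
        text = fenced.group(1)
    # Fall back to the outermost {...} span in the remaining text
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end == -1:
        raise ValueError("no JSON object found")
    return json.loads(text[start : end + 1])
```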
coda/agents/search_agent.py ADDED
@@ -0,0 +1,295 @@
1
+ """
2
+ Search Agent for CoDA.
3
+
4
+ Retrieves relevant code examples and patterns from a knowledge base
5
+ to assist code generation with proven implementations.
6
+ """
7
+
8
+ from typing import Optional
9
+
10
+ from pydantic import BaseModel, Field
11
+
12
+ from coda.core.base_agent import AgentContext, BaseAgent
13
+ from coda.core.llm import LLMProvider
14
+ from coda.core.memory import SharedMemory
15
+
16
+
17
+ class CodeExample(BaseModel):
18
+ """A retrieved code example."""
19
+
20
+ title: str = ""
21
+ description: str = ""
22
+ code: str = ""
23
+ library: str = "matplotlib"
24
+ relevance_score: float = 0.5
25
+
26
+
27
+ class SearchResult(BaseModel):
28
+ """Structured output from the Search Agent."""
29
+
30
+ search_queries: list[str] = Field(
31
+ default_factory=list,
32
+ description="Queries used to find examples"
33
+ )
34
+ examples: list[CodeExample] = Field(
35
+ default_factory=list,
36
+ description="Retrieved code examples"
37
+ )
38
+ recommended_libraries: list[str] = Field(
39
+ default_factory=list,
40
+ description="Recommended visualization libraries"
41
+ )
42
+ implementation_notes: list[str] = Field(
43
+ default_factory=list,
44
+ description="Notes on implementing the visualization"
45
+ )
46
+
47
+
48
+ # Built-in code examples for common visualization patterns
49
+ CODE_EXAMPLES_DB: dict[str, list[dict]] = {
50
+ "line": [
51
+ {
52
+ "title": "Basic Line Chart",
53
+ "description": "Simple line chart with matplotlib",
54
+ "code": """import matplotlib.pyplot as plt
55
+
56
+ fig, ax = plt.subplots(figsize=(10, 6))
57
+ ax.plot(x_data, y_data, marker='o', linewidth=2)
58
+ ax.set_xlabel('X Label')
59
+ ax.set_ylabel('Y Label')
60
+ ax.set_title('Line Chart Title')
61
+ ax.grid(True, alpha=0.3)
62
+ plt.tight_layout()
63
+ plt.savefig('output.png', dpi=150, bbox_inches='tight')""",
64
+ "library": "matplotlib",
65
+ "relevance_score": 0.9,
66
+ },
67
+ {
68
+ "title": "Multi-line Chart with Legend",
69
+ "description": "Multiple lines with different colors and legend",
70
+ "code": """import matplotlib.pyplot as plt
71
+
72
+ fig, ax = plt.subplots(figsize=(10, 6))
73
+ for label, data in grouped_data.items():
74
+ ax.plot(data['x'], data['y'], label=label, marker='o')
75
+ ax.set_xlabel('X Label')
76
+ ax.set_ylabel('Y Label')
77
+ ax.set_title('Multi-line Chart')
78
+ ax.legend(loc='best')
79
+ ax.grid(True, alpha=0.3)
80
+ plt.tight_layout()
81
+ plt.savefig('output.png', dpi=150, bbox_inches='tight')""",
82
+ "library": "matplotlib",
83
+ "relevance_score": 0.85,
84
+ },
85
+ ],
86
+ "bar": [
87
+ {
88
+ "title": "Basic Bar Chart",
89
+ "description": "Vertical bar chart with matplotlib",
90
+ "code": """import matplotlib.pyplot as plt
91
+
92
+ fig, ax = plt.subplots(figsize=(10, 6))
93
+ bars = ax.bar(categories, values, color='steelblue', edgecolor='black')
94
+ ax.set_xlabel('Category')
95
+ ax.set_ylabel('Value')
96
+ ax.set_title('Bar Chart Title')
97
+ ax.bar_label(bars, fmt='%.1f')
98
+ plt.xticks(rotation=45, ha='right')
99
+ plt.tight_layout()
100
+ plt.savefig('output.png', dpi=150, bbox_inches='tight')""",
101
+ "library": "matplotlib",
102
+ "relevance_score": 0.9,
103
+ },
104
+ {
105
+ "title": "Grouped Bar Chart",
106
+ "description": "Side-by-side bars for comparison",
107
+ "code": """import matplotlib.pyplot as plt
108
+ import numpy as np
109
+
110
+ x = np.arange(len(categories))
111
+ width = 0.35
112
+
113
+ fig, ax = plt.subplots(figsize=(10, 6))
114
+ bars1 = ax.bar(x - width/2, values1, width, label='Group 1')
115
+ bars2 = ax.bar(x + width/2, values2, width, label='Group 2')
116
+ ax.set_xlabel('Category')
117
+ ax.set_ylabel('Value')
118
+ ax.set_title('Grouped Bar Chart')
119
+ ax.set_xticks(x)
120
+ ax.set_xticklabels(categories)
121
+ ax.legend()
122
+ plt.tight_layout()
123
+ plt.savefig('output.png', dpi=150, bbox_inches='tight')""",
124
+ "library": "matplotlib",
125
+ "relevance_score": 0.85,
126
+ },
127
+ ],
128
+ "scatter": [
129
+ {
130
+ "title": "Basic Scatter Plot",
131
+ "description": "Scatter plot with optional color encoding",
132
+ "code": """import matplotlib.pyplot as plt
133
+
134
+ fig, ax = plt.subplots(figsize=(10, 6))
135
+ scatter = ax.scatter(x_data, y_data, c=color_data, s=50, alpha=0.7, cmap='viridis')
136
+ ax.set_xlabel('X Label')
137
+ ax.set_ylabel('Y Label')
138
+ ax.set_title('Scatter Plot Title')
139
+ plt.colorbar(scatter, label='Color Label')
140
+ plt.tight_layout()
141
+ plt.savefig('output.png', dpi=150, bbox_inches='tight')""",
142
+ "library": "matplotlib",
143
+ "relevance_score": 0.9,
144
+ },
145
+ ],
146
+ "pie": [
147
+ {
148
+ "title": "Pie Chart",
149
+ "description": "Pie chart with percentages",
150
+ "code": """import matplotlib.pyplot as plt
151
+
152
+ fig, ax = plt.subplots(figsize=(10, 8))
153
+ wedges, texts, autotexts = ax.pie(
154
+ values, labels=labels, autopct='%1.1f%%',
155
+ startangle=90, colors=plt.cm.Pastel1.colors
156
+ )
157
+ ax.set_title('Pie Chart Title')
158
+ plt.tight_layout()
159
+ plt.savefig('output.png', dpi=150, bbox_inches='tight')""",
160
+ "library": "matplotlib",
161
+ "relevance_score": 0.9,
162
+ },
163
+ ],
164
+ "heatmap": [
165
+ {
166
+ "title": "Heatmap with Seaborn",
167
+ "description": "Correlation or matrix heatmap",
168
+ "code": """import matplotlib.pyplot as plt
169
+ import seaborn as sns
170
+
171
+ fig, ax = plt.subplots(figsize=(12, 8))
172
+ sns.heatmap(data_matrix, annot=True, fmt='.2f', cmap='coolwarm', ax=ax)
173
+ ax.set_title('Heatmap Title')
174
+ plt.tight_layout()
175
+ plt.savefig('output.png', dpi=150, bbox_inches='tight')""",
176
+ "library": "seaborn",
177
+ "relevance_score": 0.9,
178
+ },
179
+ ],
180
+ "histogram": [
181
+ {
182
+ "title": "Histogram",
183
+ "description": "Distribution histogram with optional KDE",
184
+ "code": """import matplotlib.pyplot as plt
185
+ import seaborn as sns
186
+
187
+ fig, ax = plt.subplots(figsize=(10, 6))
188
+ sns.histplot(data, kde=True, ax=ax, color='steelblue')
189
+ ax.set_xlabel('Value')
190
+ ax.set_ylabel('Frequency')
191
+ ax.set_title('Histogram Title')
192
+ plt.tight_layout()
193
+ plt.savefig('output.png', dpi=150, bbox_inches='tight')""",
194
+ "library": "seaborn",
195
+ "relevance_score": 0.9,
196
+ },
197
+ ],
198
+ }
199
+
200
+
201
+ class SearchAgent(BaseAgent[SearchResult]):
202
+ """
203
+ Searches for relevant code examples to guide code generation.
204
+
205
+ Uses a built-in knowledge base of visualization patterns
206
+ and can be extended to search external resources.
207
+ """
208
+
209
+ MEMORY_KEY = "search_results"
210
+
211
+ def __init__(
212
+ self,
213
+ llm: LLMProvider,
214
+ memory: SharedMemory,
215
+ name: Optional[str] = None,
216
+ ) -> None:
217
+ super().__init__(llm, memory, name or "SearchAgent")
218
+
219
+ def _get_system_prompt(self) -> str:
220
+ return """You are a Code Search specialist in a data visualization team.
221
+
222
+ Your expertise is in finding and recommending relevant code examples and patterns for visualization implementation.
223
+
224
+ Your responsibilities:
225
+ 1. Formulate effective search queries based on requirements
226
+ 2. Select the most relevant examples from available patterns
227
+ 3. Recommend appropriate libraries for the task
228
+ 4. Provide implementation guidance
229
+
230
+ Consider the specific chart type, data characteristics, and styling requirements when selecting examples.
231
+
232
+ Always respond with a valid JSON object matching the required schema."""
233
+
234
+ def _build_prompt(self, context: AgentContext) -> str:
235
+ visual_mapping = self._get_from_memory("visual_mapping") or {}
236
+ query_analysis = self._get_from_memory("query_analysis") or {}
237
+
238
+ chart_type = visual_mapping.get("chart_type", "")
239
+ chart_subtype = visual_mapping.get("chart_subtype", "")
240
+ styling_hints = visual_mapping.get("styling_hints", {})
241
+
242
+ available_examples = list(CODE_EXAMPLES_DB.keys())
243
+
244
+ return f"""Find relevant code examples for the following visualization requirements.
245
+
246
+ User Query: {context.query}
247
+
248
+ Visualization Mapping:
249
+ - Chart Type: {chart_type}
250
+ - Chart Subtype: {chart_subtype}
251
+ - Styling: {styling_hints}
252
+ - Goals: {visual_mapping.get('visualization_goals', [])}
253
+
254
+ Available Example Categories: {available_examples}
255
+
256
+ Provide a JSON object containing:
257
+ - search_queries: List of search queries you would use
258
+ - examples: List of relevant examples (select from available categories)
259
+ - recommended_libraries: Libraries best suited for this visualization
260
+ - implementation_notes: Tips for implementing this specific visualization
261
+
262
+ For examples, include:
263
+ - title, description, code (from your knowledge of matplotlib/seaborn)
264
+ - library used
265
+ - relevance_score (0.0 to 1.0)
266
+
267
+ JSON Response:"""
268
+
269
+ def _parse_response(self, response: str) -> SearchResult:
270
+ data = self._extract_json(response)
271
+
272
+ # Normalize examples to ensure required fields have defaults
273
+ if "examples" in data and isinstance(data["examples"], list):
274
+ normalized_examples = []
275
+ for ex in data["examples"]:
276
+ if isinstance(ex, dict):
277
+ ex.setdefault("library", "matplotlib")
278
+ ex.setdefault("title", "")
279
+ ex.setdefault("description", "")
280
+ ex.setdefault("code", "")
281
+ ex.setdefault("relevance_score", 0.5)
282
+ normalized_examples.append(ex)
283
+ data["examples"] = normalized_examples
284
+
285
+ visual_mapping = self._get_from_memory("visual_mapping") or {}
286
+ chart_type = visual_mapping.get("chart_type", "line").lower()
287
+
288
+ # Use built-in examples if LLM didn't provide any
289
+ if chart_type in CODE_EXAMPLES_DB and not data.get("examples"):
290
+ data["examples"] = CODE_EXAMPLES_DB[chart_type]
291
+
292
+ return SearchResult(**data)
293
+
294
+ def _get_output_key(self) -> str:
295
+ return self.MEMORY_KEY
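SearchAgent prefers examples the LLM selects and falls back to the built-in `CODE_EXAMPLES_DB` keyed by chart type. The lookup logic in isolation, with a stub DB standing in for the real one:

```python
# Stub of the built-in knowledge base; the real one holds full code strings
CODE_EXAMPLES_DB = {
    "line": [{"title": "Basic Line Chart", "library": "matplotlib"}],
    "bar": [{"title": "Basic Bar Chart", "library": "matplotlib"}],
}

def pick_examples(llm_examples, chart_type):
    """Prefer LLM-supplied examples; otherwise fall back to built-in patterns."""
    if llm_examples:
        return llm_examples
    # Unknown chart types yield an empty list rather than an error
    return CODE_EXAMPLES_DB.get(chart_type.lower(), [])

examples = pick_examples([], "Bar")
```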
coda/agents/visual_evaluator.py ADDED
@@ -0,0 +1,228 @@
1
+ """
2
+ Visual Evaluator Agent for CoDA.
3
+
4
+ Assesses generated visualizations across multiple quality dimensions
5
+ using multimodal LLM capabilities to analyze the output image.
6
+ """
7
+
8
+ import logging
9
+ from pathlib import Path
10
+ from typing import Optional
11
+
12
+ from pydantic import BaseModel, Field
13
+
14
+ from coda.core.base_agent import AgentContext, BaseAgent
15
+ from coda.core.llm import LLMProvider
16
+ from coda.core.memory import SharedMemory
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ class QualityScores(BaseModel):
22
+ """Quality scores for different dimensions."""
23
+
24
+ overall: float = Field(ge=0, le=10, description="Overall quality score")
25
+ readability: float = Field(ge=0, le=10, description="How easy to read and understand")
26
+ accuracy: float = Field(ge=0, le=10, description="How accurately it represents the data")
27
+ aesthetics: float = Field(ge=0, le=10, description="Visual appeal and design quality")
28
+ layout: float = Field(ge=0, le=10, description="Layout and spacing quality")
29
+ correctness: float = Field(ge=0, le=10, description="Technical correctness")
30
+
31
+
32
+ class VisualEvaluation(BaseModel):
33
+ """Structured output from the Visual Evaluator."""
34
+
35
+ scores: QualityScores = Field(default_factory=lambda: QualityScores(
36
+ overall=5.0, readability=5.0, accuracy=5.0, aesthetics=5.0, layout=5.0, correctness=5.0
37
+ ))
38
+ strengths: list[str] = Field(default_factory=list)
39
+ issues: list[str] = Field(default_factory=list)
40
+ priority_fixes: list[str] = Field(default_factory=list)
41
+ todo_completion: dict[str, bool] = Field(default_factory=dict)
42
+ recommendations: list[str] = Field(default_factory=list)
43
+ passes_threshold: bool = Field(default=False)
44
+
45
+
46
+ class VisualEvaluatorAgent(BaseAgent[VisualEvaluation]):
47
+ """
48
+ Evaluates visualization quality using multimodal analysis.
49
+
50
+ Analyzes the output image against the original requirements
51
+ and provides detailed feedback for iterative refinement.
52
+ """
53
+
54
+ MEMORY_KEY = "visual_evaluation"
55
+
56
+ def __init__(
57
+ self,
58
+ llm: LLMProvider,
59
+ memory: SharedMemory,
60
+ min_overall_score: float = 7.0,
61
+ name: Optional[str] = None,
62
+ ) -> None:
63
+ super().__init__(llm, memory, name or "VisualEvaluator")
64
+ self._min_score = min_overall_score
65
+
66
+ def execute(self, context: AgentContext) -> VisualEvaluation:
67
+ """Execute visual evaluation using the vision model."""
68
+ logger.info(f"[{self._name}] Evaluating visualization quality")
69
+
70
+ execution_result = self._get_from_memory("execution_result")
71
+ if not execution_result or not execution_result.get("success"):
72
+ return VisualEvaluation(
73
+ scores=QualityScores(
74
+ overall=0, readability=0, accuracy=0,
75
+ aesthetics=0, layout=0, correctness=0
76
+ ),
77
+ strengths=[],
78
+ issues=["Visualization generation failed"],
79
+ priority_fixes=["Fix code execution errors"],
80
+ todo_completion={},
81
+ recommendations=["Debug and fix code errors first"],
82
+ passes_threshold=False,
83
+ )
84
+
85
+ output_file = execution_result.get("output_file")
86
+ if not output_file or not Path(output_file).exists():
87
+ return VisualEvaluation(
88
+ scores=QualityScores(
89
+ overall=0, readability=0, accuracy=0,
90
+ aesthetics=0, layout=0, correctness=0
91
+ ),
92
+ strengths=[],
93
+ issues=["Output file not found"],
94
+ priority_fixes=["Ensure code saves output correctly"],
95
+ todo_completion={},
96
+ recommendations=["Check savefig call in code"],
97
+ passes_threshold=False,
98
+ )
99
+
100
+ prompt = self._build_evaluation_prompt(context)
101
+ system_prompt = self._get_system_prompt()
102
+
103
+ try:
104
+ response = self._llm.complete_with_image(
105
+ prompt=prompt,
106
+ image_path=output_file,
107
+ system_prompt=system_prompt,
108
+ )
109
+
110
+ result = self._parse_response(response.content)
111
+ self._store_result(result)
112
+
113
+ logger.info(
114
+ f"[{self._name}] Evaluation complete: "
115
+ f"overall={result.scores.overall}, passes={result.passes_threshold}"
116
+ )
117
+ return result
118
+ except Exception as e:
119
+ logger.error(f"[{self._name}] Evaluation failed: {e}")
120
+ # Return a fallback evaluation instead of crashing
121
+ fallback = VisualEvaluation(
122
+ scores=QualityScores(
123
+ overall=5.0, readability=5.0, accuracy=5.0,
124
+ aesthetics=5.0, layout=5.0, correctness=5.0
125
+ ),
126
+ strengths=["Fallback evaluation (parsing failed)"],
127
+ issues=[f"Evaluation parsing error: {str(e)}"],
128
+ priority_fixes=[],
129
+ todo_completion={},
130
+ recommendations=[],
131
+ passes_threshold=False
132
+ )
133
+ self._store_result(fallback)
134
+ return fallback
135
+
136
+ def _get_system_prompt(self) -> str:
137
+ return """You are a Visualization Quality Evaluator specialist.
138
+
139
+ Your expertise is in assessing data visualizations for quality, effectiveness, and adherence to best practices.
140
+
141
+ Evaluate visualizations on these dimensions:
142
+ 1. Readability: Clear labels, appropriate font sizes, uncluttered design
143
+ 2. Accuracy: Correct representation of data, appropriate scales
144
+ 3. Aesthetics: Visual appeal, harmonious colors, professional appearance
145
+ 4. Layout: Good use of space, proper alignment, balanced composition
146
+ 5. Correctness: Technically correct chart type, proper axis handling
147
+
148
+ Be rigorous but fair in your assessment. Provide specific, actionable feedback.
149
+
150
+ Always respond with a valid JSON object matching the required schema."""
151
+
152
+ def _build_evaluation_prompt(self, context: AgentContext) -> str:
153
+ query_analysis = self._get_from_memory("query_analysis") or {}
154
+ visual_mapping = self._get_from_memory("visual_mapping") or {}
155
+ design_spec = self._get_from_memory("design_spec") or {}
156
+
157
+ todo_list = query_analysis.get("todo_list", [])
158
+
159
+ return f"""Evaluate this visualization against the original requirements.
160
+
161
+ Original Query: {context.query}
162
+
163
+ Requirements:
164
+ - Visualization Type: {visual_mapping.get('chart_type', 'Unknown')}
165
+ - Goals: {visual_mapping.get('visualization_goals', [])}
166
+ - TODO Items: {todo_list}
167
+
168
+ Design Specifications:
169
+ - Color Scheme: {design_spec.get('color_scheme', {})}
170
+ - Success Indicators: {design_spec.get('success_indicators', [])}
171
+
172
+ Evaluate the visualization image and provide a JSON response with:
173
+ - scores: {{
174
+ "overall": 0-10,
175
+ "readability": 0-10,
176
+ "accuracy": 0-10,
177
+ "aesthetics": 0-10,
178
+ "layout": 0-10,
179
+ "correctness": 0-10
180
+ }}
181
+ - strengths: List of positive aspects
182
+ - issues: List of problems found
183
+ - priority_fixes: Most important fixes (max 3)
184
+ - todo_completion: {{"todo_item": true/false}} for each TODO
185
+ - recommendations: Improvement suggestions
186
+ - passes_threshold: true if overall >= {self._min_score}
187
+
188
+ JSON Response:"""
189
+
190
+ def _build_prompt(self, context: AgentContext) -> str:
191
+ return self._build_evaluation_prompt(context)
192
+
193
+ def _parse_response(self, response: str) -> VisualEvaluation:
194
+ data = self._extract_json(response)
195
+
196
+ # Ensure scores exists and is properly formatted
197
+ scores_data = data.get("scores", {})
198
+ if isinstance(scores_data, dict):
199
+ # Ensure all required fields have defaults
200
+ scores_data.setdefault("overall", 5.0)
201
+ scores_data.setdefault("readability", 5.0)
202
+ scores_data.setdefault("accuracy", 5.0)
203
+ scores_data.setdefault("aesthetics", 5.0)
204
+ scores_data.setdefault("layout", 5.0)
205
+ scores_data.setdefault("correctness", 5.0)
206
+ data["scores"] = QualityScores(**scores_data)
207
+
208
+ # Ensure list fields are lists
209
+ for field in ["strengths", "issues", "priority_fixes", "recommendations"]:
210
+ if field not in data or not isinstance(data[field], list):
211
+ data[field] = [data[field]] if isinstance(data.get(field), str) else []
212
+
213
+ # Ensure todo_completion is a dict
214
+ if not isinstance(data.get("todo_completion"), dict):
215
+ data["todo_completion"] = {}
216
+
217
+ # Calculate passes_threshold if not provided
218
+ if "passes_threshold" not in data:
219
+ overall = data.get("scores")
220
+ if isinstance(overall, QualityScores):
221
+ data["passes_threshold"] = overall.overall >= self._min_score
222
+ else:
223
+ data["passes_threshold"] = False
224
+
225
+ return VisualEvaluation(**data)
226
+
227
+ def _get_output_key(self) -> str:
228
+ return self.MEMORY_KEY
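A minimal standalone sketch of the coercion logic in `_parse_response` above: a loosely typed LLM reply is pushed into predictable shapes before model validation. Unlike the real method, `scores` stays a plain dict here for simplicity:

```python
def coerce_evaluation(data: dict, min_score: float = 7.0) -> dict:
    """Coerce loosely typed LLM output into predictable shapes.

    Standalone sketch: the real method builds a QualityScores model first,
    while here scores stay a plain dict.
    """
    for field in ["strengths", "issues", "priority_fixes", "recommendations"]:
        value = data.get(field)
        if isinstance(value, str):
            data[field] = [value]        # a bare string becomes a one-item list
        elif not isinstance(value, list):
            data[field] = []             # anything else becomes an empty list
    if not isinstance(data.get("todo_completion"), dict):
        data["todo_completion"] = {}
    if "passes_threshold" not in data:
        overall = data.get("scores", {}).get("overall", 0.0)
        data["passes_threshold"] = overall >= min_score
    return data

fixed = coerce_evaluation({"strengths": "clear labels", "scores": {"overall": 7.5}})
```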
coda/agents/viz_mapping.py ADDED
@@ -0,0 +1,164 @@
1
+ """
2
+ VizMapping Agent for CoDA.
3
+
4
+ Maps query semantics and data characteristics to visualization primitives,
5
+ selecting appropriate chart types and defining data-to-visual bindings.
6
+ """
7
+
8
+ from typing import Optional
9
+
10
+ from pydantic import BaseModel, Field
11
+
12
+ from coda.core.base_agent import AgentContext, BaseAgent
13
+ from coda.core.llm import LLMProvider
14
+ from coda.core.memory import SharedMemory
15
+
16
+
17
+ class DataTransformation(BaseModel):
18
+ """A data transformation step."""
19
+
20
+ operation: str
21
+ columns: list[str]
22
+ parameters: dict
23
+
24
+
25
+ class VisualMapping(BaseModel):
26
+ """Structured output from the VizMapping Agent."""
27
+
28
+ chart_type: str = Field(
29
+ description="Primary chart type (e.g., line, bar, scatter)"
30
+ )
31
+ chart_subtype: Optional[str] = Field(
32
+ default=None,
33
+ description="Chart subtype if applicable (e.g., stacked, grouped)"
34
+ )
35
+ x_axis: Optional[dict] = Field(
36
+ default=None,
37
+ description="X-axis configuration (column, label, type)"
38
+ )
39
+ y_axis: Optional[dict] = Field(
40
+ default=None,
41
+ description="Y-axis configuration (column, label, type)"
42
+ )
43
+ color_encoding: Optional[dict] = Field(
44
+ default=None,
45
+ description="Color encoding configuration"
46
+ )
47
+ size_encoding: Optional[dict] = Field(
48
+ default=None,
49
+ description="Size encoding for scatter plots"
50
+ )
51
+ transformations: list[dict] = Field(
52
+ default_factory=list,
53
+ description="Data transformations to apply"
54
+ )
55
+ styling_hints: dict = Field(
56
+ default_factory=dict,
57
+ description="Visual styling recommendations"
58
+ )
59
+ visualization_goals: list[str] = Field(
60
+ description="High-level goals for the visualization"
61
+ )
62
+ rationale: str = Field(
63
+ description="Explanation for the chosen visualization approach"
64
+ )
65
+
66
+
67
+ class VizMappingAgent(BaseAgent[VisualMapping]):
68
+ """
69
+ Maps query semantics to visualization specifications.
70
+
71
+ Bridges the gap between data analysis and code generation by
72
+ defining exactly how data should be visualized.
73
+ """
74
+
75
+ MEMORY_KEY = "visual_mapping"
76
+
77
+ def __init__(
78
+ self,
79
+ llm: LLMProvider,
80
+ memory: SharedMemory,
81
+ name: Optional[str] = None,
82
+ ) -> None:
83
+ super().__init__(llm, memory, name or "VizMapping")
84
+
85
+ def _get_system_prompt(self) -> str:
86
+ return """You are a Visualization Mapping specialist in a data visualization team.
87
+
88
+ Your expertise is in translating data analysis requirements into concrete visualization specifications that can be implemented in code.
89
+
90
+ Your responsibilities:
91
+ 1. Select the optimal chart type based on data and query requirements
92
+ 2. Define data-to-visual mappings (axes, colors, sizes)
93
+ 3. Specify required data transformations
94
+ 4. Provide styling hints for aesthetics
95
+ 5. Document the rationale for visualization choices
96
+
97
+ Consider:
98
+ - Data types when choosing encodings (categorical vs numerical)
99
+ - Query intent when selecting chart types
100
+ - Readability and best practices in visualization design
101
+
102
+ Always respond with a valid JSON object matching the required schema."""
103
+
104
+ def _build_prompt(self, context: AgentContext) -> str:
105
+ query_analysis = self._get_from_memory("query_analysis") or {}
106
+ data_analysis = self._get_from_memory("data_analysis") or {}
107
+
108
+ query_section = ""
109
+ if query_analysis:
110
+ query_section = f"""
111
+ Query Analysis:
112
+ - Suggested Types: {query_analysis.get('visualization_types', [])}
113
+ - Key Points: {query_analysis.get('key_points', [])}
114
+ - Data Requirements: {query_analysis.get('data_requirements', [])}
115
+ """
116
+
117
+ data_section = ""
118
+ if data_analysis:
119
+ files = data_analysis.get('files', [])
120
+ if files:
121
+ columns_info = []
122
+ for f in files:
123
+ for col in f.get('columns', []):
124
+ columns_info.append(f" - {col['name']} ({col['dtype']})")
125
+ data_section = f"""
126
+ Available Data:
127
+ - Columns:
128
+ {chr(10).join(columns_info)}
129
+ - Insights: {data_analysis.get('insights', [])}
130
+ - Suggested Aggregations: {data_analysis.get('aggregations_needed', [])}
131
+ """
132
+
133
+ feedback_section = ""
134
+ if context.feedback:
135
+ feedback_section = f"""
136
+ Refinement Feedback (iteration {context.iteration}):
137
+ {context.feedback}
138
+ """
139
+
140
+ return f"""Create a visualization mapping for the following query.
141
+
142
+ User Query: {context.query}
143
+ {query_section}{data_section}{feedback_section}
144
+
145
+ Provide a JSON object containing:
146
+ - chart_type: Primary chart type (line, bar, scatter, pie, heatmap, etc.)
147
+ - chart_subtype: Optional subtype (stacked, grouped, etc.)
148
+ - x_axis: {{"column": "...", "label": "...", "type": "categorical|numerical|temporal"}}
149
+ - y_axis: {{"column": "...", "label": "...", "type": "numerical"}}
150
+ - color_encoding: Optional color mapping {{"column": "...", "palette": "..."}}
151
+ - size_encoding: Optional size mapping for scatter {{"column": "..."}}
152
+ - transformations: List of {{"operation": "...", "columns": [...], "parameters": {{}}}}
153
+ - styling_hints: {{"theme": "...", "annotations": [...], "legend_position": "..."}}
154
+ - visualization_goals: List of high-level goals
155
+ - rationale: Brief explanation of choices
156
+
157
+ JSON Response:"""
158
+
159
+ def _parse_response(self, response: str) -> VisualMapping:
160
+ data = self._extract_json(response)
161
+ return VisualMapping(**data)
162
+
163
+ def _get_output_key(self) -> str:
164
+ return self.MEMORY_KEY
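The `transformations` entries produced here are plain `{operation, columns, parameters}` dicts that downstream agents interpret when generating code. A minimal interpreter for one hypothetical operation, `groupby_sum` (illustrative only, not an operation CoDA itself defines), might look like:

```python
from collections import defaultdict

def apply_transformation(rows: list[dict], spec: dict) -> list[dict]:
    """Apply one {operation, columns, parameters} spec to a list of row dicts.

    Only a hypothetical 'groupby_sum' operation is implemented here.
    """
    if spec["operation"] == "groupby_sum":
        key_col, value_col = spec["columns"]
        totals = defaultdict(float)
        for row in rows:
            totals[row[key_col]] += row[value_col]   # sum values per group key
        return [{key_col: k, value_col: v} for k, v in totals.items()]
    raise ValueError(f"Unsupported operation: {spec['operation']}")

rows = [
    {"region": "east", "sales": 10},
    {"region": "east", "sales": 5},
    {"region": "west", "sales": 7},
]
grouped = apply_transformation(
    rows, {"operation": "groupby_sum", "columns": ["region", "sales"], "parameters": {}}
)
```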
coda/config.py ADDED
@@ -0,0 +1,91 @@
1
+ """
2
+ Configuration management for CoDA.
3
+
4
+ Centralizes all configuration values including API keys, model settings,
5
+ quality thresholds, and execution parameters.
6
+ """
7
+
8
+ import os
9
+ from dataclasses import dataclass, field
10
+ from typing import Optional
11
+
12
+ from dotenv import load_dotenv
13
+
14
+ load_dotenv()
15
+
16
+
17
+ @dataclass(frozen=True)
18
+ class ModelConfig:
19
+ """Configuration for LLM models."""
20
+
21
+ default_model: str = "llama-3.3-70b-versatile"
22
+ vision_model: str = "meta-llama/llama-4-maverick-17b-128e-instruct"
23
+ temperature: float = 0.7
24
+ max_tokens: int = 4096
25
+ max_retries: int = 3
26
+ retry_delay: float = 1.0
27
+
28
+
29
+ @dataclass(frozen=True)
30
+ class QualityThresholds:
31
+ """Quality score thresholds for the feedback loop."""
32
+
33
+ minimum_overall_score: float = 7.0
34
+ minimum_readability_score: float = 6.0
35
+ minimum_accuracy_score: float = 7.0
36
+ minimum_aesthetics_score: float = 6.0
37
+
38
+
39
+ @dataclass(frozen=True)
40
+ class ExecutionConfig:
41
+ """Configuration for code execution."""
42
+
43
+ code_timeout_seconds: int = 60
44
+ max_refinement_iterations: int = 3
45
+ output_directory: str = "outputs"
46
+
47
+
48
+ @dataclass
49
+ class Config:
50
+ """Main configuration container for CoDA."""
51
+
52
+ groq_api_key: str = field(default_factory=lambda: os.getenv("GROQ_API_KEY", ""))
53
+ model: ModelConfig = field(default_factory=ModelConfig)
54
+ quality: QualityThresholds = field(default_factory=QualityThresholds)
55
+ execution: ExecutionConfig = field(default_factory=ExecutionConfig)
56
+
57
+ def __post_init__(self) -> None:
58
+ if not self.groq_api_key:
59
+ raise ValueError(
60
+ "GROQ_API_KEY environment variable is required. "
61
+ "Get your API key at https://console.groq.com"
62
+ )
63
+
64
+ @classmethod
65
+ def from_env(cls) -> "Config":
66
+ """Create configuration from environment variables."""
67
+ return cls(
68
+ groq_api_key=os.getenv("GROQ_API_KEY", ""),
69
+ model=ModelConfig(
70
+ default_model=os.getenv("CODA_DEFAULT_MODEL", "llama-3.3-70b-versatile"),
71
+ vision_model=os.getenv("CODA_VISION_MODEL", "meta-llama/llama-4-maverick-17b-128e-instruct"),
72
+ temperature=float(os.getenv("CODA_TEMPERATURE", "0.7")),
73
+ max_tokens=int(os.getenv("CODA_MAX_TOKENS", "4096")),
74
+ ),
75
+ quality=QualityThresholds(
76
+ minimum_overall_score=float(os.getenv("CODA_MIN_OVERALL_SCORE", "7.0")),
77
+ minimum_readability_score=float(os.getenv("CODA_MIN_READABILITY_SCORE", "6.0")),
78
+ minimum_accuracy_score=float(os.getenv("CODA_MIN_ACCURACY_SCORE", "7.0")),
79
+ minimum_aesthetics_score=float(os.getenv("CODA_MIN_AESTHETICS_SCORE", "6.0")),
80
+ ),
81
+ execution=ExecutionConfig(
82
+ code_timeout_seconds=int(os.getenv("CODA_CODE_TIMEOUT", "60")),
83
+ max_refinement_iterations=int(os.getenv("CODA_MAX_ITERATIONS", "3")),
84
+ output_directory=os.getenv("CODA_OUTPUT_DIR", "outputs"),
85
+ ),
86
+ )
87
+
88
+
89
+ def get_config() -> Config:
90
+ """Get the application configuration."""
91
+ return Config.from_env()
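The `from_env` overrides all follow one pattern: read the variable, fall back to a default, and cast. A self-contained sketch of that pattern (the `EXAMPLE_*` variable names are illustrative, not CoDA's):

```python
import os

def read_float(name: str, default: float) -> float:
    """Read a float setting from the environment, falling back to a default.

    Mirrors the float(os.getenv(...)) pattern in Config.from_env; a
    malformed value raises ValueError, just as it would there.
    """
    return float(os.getenv(name, str(default)))

os.environ["EXAMPLE_TEMPERATURE"] = "0.2"   # simulate a user override
os.environ.pop("EXAMPLE_UNSET_VAR", None)   # ensure the fallback path is taken

temperature = read_float("EXAMPLE_TEMPERATURE", 0.7)
fallback = read_float("EXAMPLE_UNSET_VAR", 0.7)
```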
coda/core/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ """Core module for CoDA - contains LLM, memory, and base agent abstractions."""
2
+
3
+ from coda.core.llm import LLMProvider, GroqLLM, LLMResponse
4
+ from coda.core.memory import SharedMemory, MemoryEntry
5
+ from coda.core.base_agent import BaseAgent, AgentContext
6
+ from coda.core.agent_factory import AgentFactory
7
+
8
+ __all__ = [
9
+ "LLMProvider",
10
+ "GroqLLM",
11
+ "LLMResponse",
12
+ "SharedMemory",
13
+ "MemoryEntry",
14
+ "BaseAgent",
15
+ "AgentContext",
16
+ "AgentFactory",
17
+ ]
coda/core/agent_factory.py ADDED
@@ -0,0 +1,135 @@
1
+ """
2
+ Agent Factory for CoDA.
3
+
4
+ Provides factory methods for creating and configuring agents,
5
+ enabling flexible agent composition and testing.
6
+ """
7
+
8
+ from typing import Optional
9
+
10
+ from coda.core.llm import LLMProvider
11
+ from coda.core.memory import SharedMemory
12
+ from coda.core.base_agent import BaseAgent
13
+
14
+ from coda.agents.query_analyzer import QueryAnalyzerAgent
15
+ from coda.agents.data_processor import DataProcessorAgent
16
+ from coda.agents.viz_mapping import VizMappingAgent
17
+ from coda.agents.search_agent import SearchAgent
18
+ from coda.agents.design_explorer import DesignExplorerAgent
19
+ from coda.agents.code_generator import CodeGeneratorAgent
20
+ from coda.agents.debug_agent import DebugAgent
21
+ from coda.agents.visual_evaluator import VisualEvaluatorAgent
22
+
23
+
24
+ class AgentFactory:
25
+ """
26
+ Factory for creating CoDA agents with shared dependencies.
27
+
28
+ Centralizes agent creation and configuration, making it easy to
29
+ swap implementations or configure agents for testing.
30
+ """
31
+
32
+ AGENT_TYPES = {
33
+ "query_analyzer": QueryAnalyzerAgent,
34
+ "data_processor": DataProcessorAgent,
35
+ "viz_mapping": VizMappingAgent,
36
+ "search_agent": SearchAgent,
37
+ "design_explorer": DesignExplorerAgent,
38
+ "code_generator": CodeGeneratorAgent,
39
+ "debug_agent": DebugAgent,
40
+ "visual_evaluator": VisualEvaluatorAgent,
41
+ }
42
+
43
+ def __init__(
44
+ self,
45
+ llm: LLMProvider,
46
+ memory: Optional[SharedMemory] = None,
47
+ ) -> None:
48
+ self._llm = llm
49
+ self._memory = memory or SharedMemory()
50
+
51
+ @property
52
+ def memory(self) -> SharedMemory:
53
+ """Get the shared memory instance."""
54
+ return self._memory
55
+
56
+ def create(
57
+ self,
58
+ agent_type: str,
59
+ **kwargs,
60
+ ) -> BaseAgent:
61
+ """
62
+ Create an agent by type name.
63
+
64
+ Args:
65
+ agent_type: Name of the agent type to create
66
+ **kwargs: Additional arguments passed to the agent constructor
67
+
68
+ Returns:
69
+ Configured agent instance
70
+
71
+ Raises:
72
+ ValueError: If agent_type is not recognized
73
+ """
74
+ if agent_type not in self.AGENT_TYPES:
75
+ raise ValueError(
76
+ f"Unknown agent type: {agent_type}. "
77
+ f"Available types: {list(self.AGENT_TYPES.keys())}"
78
+ )
79
+
80
+ agent_class = self.AGENT_TYPES[agent_type]
81
+ return agent_class(
82
+ llm=self._llm,
83
+ memory=self._memory,
84
+ **kwargs,
85
+ )
86
+
87
+ def create_all(self, **agent_kwargs) -> dict[str, BaseAgent]:
88
+ """
89
+ Create all available agent types.
90
+
91
+ Args:
92
+ **agent_kwargs: Per-agent keyword arguments, keyed by agent type name
93
+
94
+ Returns:
95
+ Dictionary mapping agent type names to instances
96
+ """
97
+ return {
98
+ agent_type: self.create(agent_type, **agent_kwargs.get(agent_type, {}))
99
+ for agent_type in self.AGENT_TYPES
100
+ }
101
+
102
+ def create_pipeline_agents(
103
+ self,
104
+ code_timeout: int = 60,
105
+ output_directory: str = "outputs",
106
+ min_quality_score: float = 7.0,
107
+ ) -> dict[str, BaseAgent]:
108
+ """
109
+ Create agents configured for the standard visualization pipeline.
110
+
111
+ Args:
112
+ code_timeout: Timeout for code execution in seconds
113
+ output_directory: Directory for output files
114
+ min_quality_score: Minimum quality score threshold
115
+
116
+ Returns:
117
+ Dictionary of configured agents for the pipeline
118
+ """
119
+ return {
120
+ "query_analyzer": self.create("query_analyzer"),
121
+ "data_processor": self.create("data_processor"),
122
+ "viz_mapping": self.create("viz_mapping"),
123
+ "search_agent": self.create("search_agent"),
124
+ "design_explorer": self.create("design_explorer"),
125
+ "code_generator": self.create("code_generator"),
126
+ "debug_agent": self.create(
127
+ "debug_agent",
128
+ timeout_seconds=code_timeout,
129
+ output_directory=output_directory,
130
+ ),
131
+ "visual_evaluator": self.create(
132
+ "visual_evaluator",
133
+ min_overall_score=min_quality_score,
134
+ ),
135
+ }
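The factory's registry pattern, reduced to a self-contained sketch (the `EchoAgent` class and the string/dict stand-ins for `llm` and `memory` are illustrative, not CoDA classes):

```python
class EchoAgent:
    """Stand-in agent: records the shared dependencies it was built with."""

    def __init__(self, llm, memory, tag: str = "default") -> None:
        self.llm = llm
        self.memory = memory
        self.tag = tag

class MiniFactory:
    """Registry-based factory: a class map plus shared constructor arguments."""

    TYPES = {"echo": EchoAgent}

    def __init__(self, llm, memory) -> None:
        self._llm = llm
        self._memory = memory

    def create(self, agent_type: str, **kwargs):
        if agent_type not in self.TYPES:
            raise ValueError(f"Unknown agent type: {agent_type}")
        # Inject shared dependencies; per-agent kwargs override nothing shared.
        return self.TYPES[agent_type](llm=self._llm, memory=self._memory, **kwargs)

factory = MiniFactory(llm="fake-llm", memory={})
agent = factory.create("echo", tag="demo")
```

Keeping the registry as a class attribute is what lets tests swap in fakes without touching call sites.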
coda/core/base_agent.py ADDED
@@ -0,0 +1,220 @@
1
+ """
2
+ Base agent interface for CoDA.
3
+
4
+ Defines the contract that all specialized agents must implement,
5
+ providing common functionality for LLM interaction and memory access.
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ import re
11
+ from abc import ABC, abstractmethod
12
+ from typing import Any, Optional, TypeVar, Generic
13
+
14
+ from pydantic import BaseModel
15
+
16
+ from coda.core.llm import LLMProvider
17
+ from coda.core.memory import SharedMemory
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ T = TypeVar("T", bound=BaseModel)
22
+
23
+
24
+ class AgentContext(BaseModel):
25
+ """Context passed to an agent during execution."""
26
+
27
+ query: str
28
+ data_paths: list[str] = []
29
+ iteration: int = 0
30
+ feedback: Optional[str] = None
31
+
32
+
33
+ class BaseAgent(ABC, Generic[T]):
34
+ """
35
+ Abstract base class for all CoDA agents.
36
+
37
+ Each agent specializes in a specific aspect of the visualization pipeline.
38
+ Agents communicate through shared memory and use an LLM for reasoning.
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ llm: LLMProvider,
44
+ memory: SharedMemory,
45
+ name: Optional[str] = None,
46
+ ) -> None:
47
+ self._llm = llm
48
+ self._memory = memory
49
+ self._name = name or self.__class__.__name__
50
+
51
+ @property
52
+ def name(self) -> str:
53
+ """Get the agent's name."""
54
+ return self._name
55
+
56
+ def execute(self, context: AgentContext) -> T:
57
+ """
58
+ Execute the agent's task.
59
+
60
+ Args:
61
+ context: The execution context containing query and data info
62
+
63
+ Returns:
64
+ The agent's structured output
65
+ """
66
+ logger.info(f"[{self._name}] Starting execution")
67
+
68
+ prompt = self._build_prompt(context)
69
+ system_prompt = self._get_system_prompt()
70
+
71
+ response = self._llm.complete(
72
+ prompt=prompt,
73
+ system_prompt=system_prompt,
74
+ )
75
+
76
+ result = self._parse_response(response.content)
77
+ self._store_result(result)
78
+
79
+ logger.info(f"[{self._name}] Execution complete")
80
+ return result
81
+
82
+ @abstractmethod
83
+ def _build_prompt(self, context: AgentContext) -> str:
84
+ """
85
+ Build the prompt for the LLM.
86
+
87
+ Args:
88
+ context: The execution context
89
+
90
+ Returns:
91
+ The formatted prompt string
92
+ """
93
+ pass
94
+
95
+ @abstractmethod
96
+ def _get_system_prompt(self) -> str:
97
+ """
98
+ Get the system prompt defining the agent's persona.
99
+
100
+ Returns:
101
+ The system prompt string
102
+ """
103
+ pass
104
+
105
+ @abstractmethod
106
+ def _parse_response(self, response: str) -> T:
107
+ """
108
+ Parse the LLM response into a structured output.
109
+
110
+ Args:
111
+ response: The raw LLM response
112
+
113
+ Returns:
114
+ The parsed and validated output
115
+ """
116
+ pass
117
+
118
+ @abstractmethod
119
+ def _get_output_key(self) -> str:
120
+ """
121
+ Get the key used to store this agent's output in memory.
122
+
123
+ Returns:
124
+ The memory key string
125
+ """
126
+ pass
127
+
128
+ def _store_result(self, result: T) -> None:
129
+ """Store the agent's result in shared memory."""
130
+ self._memory.store(
131
+ key=self._get_output_key(),
132
+ value=result.model_dump(),
133
+ agent_name=self._name,
134
+ )
135
+
136
+ def _get_from_memory(self, key: str) -> Optional[Any]:
137
+ """Retrieve a value from shared memory."""
138
+ return self._memory.retrieve(key)
139
+
140
+ def _extract_json(self, text: str) -> dict[str, Any]:
141
+ """
142
+ Extract JSON from LLM response text.
143
+
144
+ Handles responses where JSON is wrapped in markdown code blocks
145
+ and sanitizes control characters that can break JSON parsing.
146
+ """
147
+ json_match = re.search(r"```(?:json)?\s*([\s\S]*?)```", text)
148
+ if json_match:
149
+ text = json_match.group(1)
150
+
151
+ text = text.strip()
152
+
153
+ try:
154
+ return json.loads(text)
155
+ except json.JSONDecodeError:
156
+ pass
157
+
158
+ # Try to fix unescaped newlines/tabs inside JSON strings
159
+ # First, find the JSON object boundaries
160
+ try:
161
+ obj_match = re.search(r'(\{[\s\S]*\})', text, re.DOTALL)
162
+ if obj_match:
163
+ json_text = obj_match.group(1)
164
+
165
+ # Replace problematic control characters (but NOT newlines between key:value pairs)
166
+ # Only remove NUL and other truly invalid chars
167
+ json_text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', json_text)
168
+
169
+ try:
170
+ return json.loads(json_text)
171
+ except json.JSONDecodeError:
172
+ pass
173
+
174
+ # If still failing, try to properly escape newlines within strings
175
+ # by parsing character by character
176
+ fixed = self._fix_json_strings(json_text)
177
+ return json.loads(fixed)
178
+ except Exception:
179
+ pass
180
+
181
+ logger.error("Failed to parse JSON after sanitization attempts")
182
+ logger.debug(f"Raw text: {text[:500]}...")
183
+ raise ValueError("Invalid JSON in response: could not parse after sanitization")
184
+
185
+ def _fix_json_strings(self, text: str) -> str:
186
+ """Fix unescaped newlines and control characters inside JSON strings."""
187
+ result = []
188
+ in_string = False
189
+ escape_next = False
190
+
191
+ for char in text:
192
+ if escape_next:
193
+ result.append(char)
194
+ escape_next = False
195
+ continue
196
+
197
+ if char == '\\':
198
+ result.append(char)
199
+ escape_next = True
200
+ continue
201
+
202
+ if char == '"':
203
+ in_string = not in_string
204
+ result.append(char)
205
+ continue
206
+
207
+ if in_string:
208
+ # Escape problematic characters inside strings
209
+ if char == '\n':
210
+ result.append('\\n')
211
+ elif char == '\r':
212
+ result.append('\\r')
213
+ elif char == '\t':
214
+ result.append('\\t')
215
+ else:
216
+ result.append(char)
217
+ else:
218
+ result.append(char)
219
+
220
+ return ''.join(result)
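The character-by-character repair in `_fix_json_strings` can be exercised standalone; this sketch mirrors the method above and shows it turning an LLM reply with a raw newline inside a string literal into parseable JSON:

```python
import json

def fix_json_strings(text: str) -> str:
    """Escape raw newlines, carriage returns, and tabs inside JSON string
    literals, leaving structural whitespace between tokens untouched."""
    escapes = {"\n": "\\n", "\r": "\\r", "\t": "\\t"}
    out = []
    in_string = False
    escape_next = False
    for ch in text:
        if escape_next:                  # previous char was a backslash
            out.append(ch)
            escape_next = False
        elif ch == "\\":
            out.append(ch)
            escape_next = True
        elif ch == '"':
            in_string = not in_string    # toggle on unescaped quotes
            out.append(ch)
        elif in_string and ch in escapes:
            out.append(escapes[ch])      # escape control char inside a string
        else:
            out.append(ch)
    return "".join(out)

raw = '{"code": "plt.plot(x, y)\nplt.show()"}'   # raw newline breaks json.loads
data = json.loads(fix_json_strings(raw))
```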
coda/core/llm.py ADDED
@@ -0,0 +1,192 @@
1
+ """
2
+ LLM abstraction layer for CoDA.
3
+
4
+ Provides a clean interface for interacting with language models,
5
+ with the Groq SDK as the default implementation.
6
+ """
7
+
8
+ import base64
9
+ import logging
10
+ import time
11
+ from abc import ABC, abstractmethod
12
+ from pathlib import Path
13
+ from typing import Any, Optional, Union
14
+
15
+ from groq import Groq
16
+ from pydantic import BaseModel
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ class LLMResponse(BaseModel):
22
+ """Structured response from an LLM call."""
23
+
24
+ content: str
25
+ model: str
26
+ usage: dict[str, int]
27
+ finish_reason: str
28
+
29
+
30
+ class LLMProvider(ABC):
31
+ """Abstract interface for language model providers."""
32
+
33
+ @abstractmethod
34
+ def complete(
35
+ self,
36
+ prompt: str,
37
+ system_prompt: Optional[str] = None,
38
+ temperature: Optional[float] = None,
39
+ max_tokens: Optional[int] = None,
40
+ ) -> LLMResponse:
41
+ """Generate a text completion."""
42
+ pass
43
+
44
+ @abstractmethod
45
+ def complete_with_image(
46
+ self,
47
+ prompt: str,
48
+ image_path: Union[str, Path],
49
+ system_prompt: Optional[str] = None,
50
+ temperature: Optional[float] = None,
51
+ max_tokens: Optional[int] = None,
52
+ ) -> LLMResponse:
53
+ """Generate a completion with image input (multimodal)."""
54
+ pass
55
+
56
+
57
+ class GroqLLM(LLMProvider):
58
+ """Groq SDK implementation of the LLM provider."""
59
+
60
+ def __init__(
61
+ self,
62
+ api_key: str,
63
+ default_model: str = "llama-3.3-70b-versatile",
64
+ vision_model: str = "meta-llama/llama-4-maverick-17b-128e-instruct",
65
+ temperature: float = 0.7,
66
+ max_tokens: int = 4096,
67
+ max_retries: int = 3,
68
+ retry_delay: float = 1.0,
69
+ ) -> None:
70
+ self._client = Groq(api_key=api_key)
71
+ self._default_model = default_model
72
+ self._vision_model = vision_model
73
+ self._temperature = temperature
74
+ self._max_tokens = max_tokens
75
+ self._max_retries = max_retries
76
+ self._retry_delay = retry_delay
77
+
78
+ def complete(
79
+ self,
80
+ prompt: str,
81
+ system_prompt: Optional[str] = None,
82
+ temperature: Optional[float] = None,
83
+ max_tokens: Optional[int] = None,
84
+ ) -> LLMResponse:
85
+ """Generate a text completion using the Groq API."""
86
+ messages = self._build_messages(prompt, system_prompt)
87
+ return self._call_with_retry(
88
+ messages=messages,
89
+ model=self._default_model,
90
+ temperature=temperature or self._temperature,
91
+ max_tokens=max_tokens or self._max_tokens,
92
+ )
93
+
94
+ def complete_with_image(
95
+ self,
96
+ prompt: str,
97
+ image_path: Union[str, Path],
98
+ system_prompt: Optional[str] = None,
99
+ temperature: Optional[float] = None,
100
+ max_tokens: Optional[int] = None,
101
+ ) -> LLMResponse:
102
+ """Generate a completion with image input using the vision model."""
103
+ image_data = self._encode_image(image_path)
104
+
105
+ user_content = [
106
+ {
107
+ "type": "image_url",
108
+ "image_url": {
109
+ "url": f"data:image/png;base64,{image_data}"
110
+ }
111
+ },
112
+ {
113
+ "type": "text",
114
+ "text": prompt
115
+ }
116
+ ]
117
+
118
+ messages: list[dict[str, Any]] = []
119
+ if system_prompt:
120
+ messages.append({"role": "system", "content": system_prompt})
121
+ messages.append({"role": "user", "content": user_content})
122
+
123
+ return self._call_with_retry(
124
+ messages=messages,
125
+ model=self._vision_model,
126
+ temperature=temperature or self._temperature,
127
+ max_tokens=max_tokens or self._max_tokens,
128
+ )
129
+
130
+ def _build_messages(
131
+ self,
132
+ prompt: str,
133
+ system_prompt: Optional[str] = None,
134
+ ) -> list[dict[str, str]]:
135
+ """Build the message list for the API call."""
136
+ messages: list[dict[str, str]] = []
137
+ if system_prompt:
138
+ messages.append({"role": "system", "content": system_prompt})
139
+ messages.append({"role": "user", "content": prompt})
140
+ return messages
141
+
142
+ def _call_with_retry(
143
+ self,
144
+ messages: list[dict[str, Any]],
145
+ model: str,
146
+ temperature: float,
147
+ max_tokens: int,
148
+ ) -> LLMResponse:
149
+ """Execute API call with exponential backoff retry."""
150
+ last_exception: Optional[Exception] = None
151
+
152
+ for attempt in range(self._max_retries):
153
+ try:
154
+ response = self._client.chat.completions.create(
155
+ model=model,
156
+ messages=messages,
157
+ temperature=temperature,
158
+ max_tokens=max_tokens,
159
+ )
160
+
161
+ return LLMResponse(
162
+ content=response.choices[0].message.content or "",
163
+ model=response.model,
164
+ usage={
165
+ "prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
166
+ "completion_tokens": response.usage.completion_tokens if response.usage else 0,
167
+ "total_tokens": response.usage.total_tokens if response.usage else 0,
168
+ },
169
+ finish_reason=response.choices[0].finish_reason or "unknown",
170
+ )
171
+
172
+ except Exception as e:
173
+ last_exception = e
174
+ logger.warning(
175
+ f"API call failed (attempt {attempt + 1}/{self._max_retries}): {e}"
176
+ )
177
+ if attempt < self._max_retries - 1:
178
+ sleep_time = self._retry_delay * (2 ** attempt)
179
+ time.sleep(sleep_time)
180
+
181
+ raise RuntimeError(
182
+ f"API call failed after {self._max_retries} attempts"
183
+ ) from last_exception
184
+
185
+ def _encode_image(self, image_path: Union[str, Path]) -> str:
186
+ """Encode an image file to base64."""
187
+ path = Path(image_path)
188
+ if not path.exists():
189
+ raise FileNotFoundError(f"Image not found: {path}")
190
+
191
+ with open(path, "rb") as f:
192
+ return base64.b64encode(f.read()).decode("utf-8")
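The retry pattern in `_call_with_retry` (sleep `delay`, `2*delay`, `4*delay` between failed attempts, chain the last exception into the final error) can be exercised in isolation. This is a minimal stand-alone sketch of that pattern; `call_with_retry` and `flaky` are illustrative names, not part of the package:

```python
import time
from typing import Callable, Optional, TypeVar

T = TypeVar("T")

def call_with_retry(
    fn: Callable[[], T],
    max_retries: int = 3,
    retry_delay: float = 0.01,
) -> T:
    """Retry fn, sleeping retry_delay * 2**attempt between failed attempts."""
    last_exception: Optional[Exception] = None
    for attempt in range(max_retries):
        try:
            return fn()
        except Exception as e:
            last_exception = e
            if attempt < max_retries - 1:
                time.sleep(retry_delay * (2 ** attempt))
    raise RuntimeError(f"call failed after {max_retries} attempts") from last_exception

# A flaky callable that fails twice, then succeeds on the third attempt.
calls = {"n": 0}

def flaky() -> str:
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("transient failure")
    return "ok"

result = call_with_retry(flaky)  # succeeds on attempt 3
```

Note that only attempts before the last one sleep, matching the `attempt < max_retries - 1` guard above.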
coda/core/memory.py ADDED
@@ -0,0 +1,148 @@
+ """
+ Shared memory buffer for inter-agent communication in CoDA.
+ 
+ Provides thread-safe storage for agents to exchange context,
+ results, and feedback during the visualization pipeline.
+ """
+ 
+ import threading
+ from datetime import datetime
+ from typing import Any, Optional
+ 
+ from pydantic import BaseModel, Field
+ 
+ 
+ class MemoryEntry(BaseModel):
+     """A single entry in the shared memory."""
+ 
+     key: str
+     value: Any
+     agent_name: str
+     timestamp: datetime = Field(default_factory=datetime.now)
+     metadata: dict[str, Any] = Field(default_factory=dict)
+ 
+ 
+ class SharedMemory:
+     """
+     Thread-safe shared memory buffer for agent communication.
+ 
+     Agents can store and retrieve structured data using string keys.
+     Each entry tracks the source agent and timestamp for debugging.
+     """
+ 
+     def __init__(self) -> None:
+         self._storage: dict[str, MemoryEntry] = {}
+         self._lock = threading.RLock()
+         self._history: list[MemoryEntry] = []
+ 
+     def store(
+         self,
+         key: str,
+         value: Any,
+         agent_name: str,
+         metadata: Optional[dict[str, Any]] = None,
+     ) -> None:
+         """
+         Store a value in shared memory.
+ 
+         Args:
+             key: Unique identifier for the data
+             value: The data to store (should be JSON-serializable)
+             agent_name: Name of the agent storing the data
+             metadata: Optional additional context
+         """
+         entry = MemoryEntry(
+             key=key,
+             value=value,
+             agent_name=agent_name,
+             metadata=metadata or {},
+         )
+ 
+         with self._lock:
+             self._storage[key] = entry
+             self._history.append(entry)
+ 
+     def retrieve(self, key: str) -> Optional[Any]:
+         """
+         Retrieve a value from shared memory.
+ 
+         Args:
+             key: The key to look up
+ 
+         Returns:
+             The stored value, or None if not found
+         """
+         with self._lock:
+             entry = self._storage.get(key)
+             return entry.value if entry else None
+ 
+     def retrieve_entry(self, key: str) -> Optional[MemoryEntry]:
+         """
+         Retrieve the full memory entry including metadata.
+ 
+         Args:
+             key: The key to look up
+ 
+         Returns:
+             The full MemoryEntry, or None if not found
+         """
+         with self._lock:
+             return self._storage.get(key)
+ 
+     def get_context(self, keys: list[str]) -> dict[str, Any]:
+         """
+         Retrieve multiple values as a context dictionary.
+ 
+         Args:
+             keys: List of keys to retrieve
+ 
+         Returns:
+             Dictionary mapping keys to their values (missing keys excluded)
+         """
+         with self._lock:
+             return {
+                 key: self._storage[key].value
+                 for key in keys
+                 if key in self._storage
+             }
+ 
+     def get_all(self) -> dict[str, Any]:
+         """
+         Retrieve all stored values.
+ 
+         Returns:
+             Dictionary mapping all keys to their values
+         """
+         with self._lock:
+             return {key: entry.value for key, entry in self._storage.items()}
+ 
+     def get_history(self, agent_name: Optional[str] = None) -> list[MemoryEntry]:
+         """
+         Get the history of all memory operations.
+ 
+         Args:
+             agent_name: Optional filter by agent name
+ 
+         Returns:
+             List of memory entries in chronological order
+         """
+         with self._lock:
+             if agent_name:
+                 return [e for e in self._history if e.agent_name == agent_name]
+             return list(self._history)
+ 
+     def has_key(self, key: str) -> bool:
+         """Check if a key exists in memory."""
+         with self._lock:
+             return key in self._storage
+ 
+     def clear(self) -> None:
+         """Clear all stored data and history."""
+         with self._lock:
+             self._storage.clear()
+             self._history.clear()
+ 
+     def keys(self) -> list[str]:
+         """Get all stored keys."""
+         with self._lock:
+             return list(self._storage.keys())
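The RLock-guarded store/retrieve pattern above can be demonstrated with a minimal stand-in; `MiniSharedMemory` below is illustrative only (no pydantic entries or history) and not part of the package:

```python
import threading
from typing import Any, Optional

class MiniSharedMemory:
    """Minimal sketch of an RLock-guarded key-value buffer."""

    def __init__(self) -> None:
        self._storage: dict[str, Any] = {}
        self._lock = threading.RLock()

    def store(self, key: str, value: Any) -> None:
        # The lock serializes concurrent writers.
        with self._lock:
            self._storage[key] = value

    def retrieve(self, key: str) -> Optional[Any]:
        with self._lock:
            return self._storage.get(key)

# Several "agents" writing concurrently, as in the pipeline.
mem = MiniSharedMemory()
threads = [threading.Thread(target=mem.store, args=(f"k{i}", i)) for i in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()

values = sorted(mem.retrieve(f"k{i}") for i in range(8))  # all eight writes land
```

An `RLock` (rather than a plain `Lock`) also lets a method that already holds the lock call another locked method without deadlocking, which the combined operations above rely on.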
coda/orchestrator.py ADDED
@@ -0,0 +1,293 @@
+ """
+ Workflow Orchestrator for CoDA.
+ 
+ Manages the multi-agent pipeline, coordinating agent execution,
+ handling feedback loops, and implementing quality-driven halting.
+ """
+ 
+ import logging
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Optional, Callable
+ 
+ from coda.config import Config, get_config
+ from coda.core.llm import GroqLLM, LLMProvider
+ from coda.core.memory import SharedMemory
+ from coda.core.base_agent import AgentContext
+ 
+ from coda.agents.query_analyzer import QueryAnalyzerAgent, QueryAnalysis
+ from coda.agents.data_processor import DataProcessorAgent, DataAnalysis
+ from coda.agents.viz_mapping import VizMappingAgent, VisualMapping
+ from coda.agents.search_agent import SearchAgent, SearchResult
+ from coda.agents.design_explorer import DesignExplorerAgent, DesignSpec
+ from coda.agents.code_generator import CodeGeneratorAgent, GeneratedCode
+ from coda.agents.debug_agent import DebugAgent, ExecutionResult
+ from coda.agents.visual_evaluator import VisualEvaluatorAgent, VisualEvaluation
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ @dataclass
+ class PipelineResult:
+     """Final result from the CoDA pipeline."""
+ 
+     success: bool
+     output_file: Optional[str]
+     evaluation: Optional[VisualEvaluation]
+     iterations: int
+     error: Optional[str] = None
+ 
+     @property
+     def scores(self) -> Optional[dict]:
+         """Get quality scores if an evaluation exists."""
+         if self.evaluation:
+             return self.evaluation.scores.model_dump()
+         return None
+ 
+ 
+ class CodaOrchestrator:
+     """
+     Orchestrates the CoDA multi-agent visualization pipeline.
+ 
+     Coordinates agent execution in sequence, manages the shared memory,
+     and implements iterative refinement through feedback loops.
+     """
+ 
+     def __init__(
+         self,
+         config: Optional[Config] = None,
+         llm: Optional[LLMProvider] = None,
+         progress_callback: Optional[Callable[[str, float], None]] = None,
+     ) -> None:
+         self._config = config or get_config()
+         self._llm = llm or self._create_llm()
+         self._memory = SharedMemory()
+         self._progress_callback = progress_callback
+ 
+         self._agents = self._create_agents()
+ 
+     def _create_llm(self) -> GroqLLM:
+         """Create the LLM instance."""
+         return GroqLLM(
+             api_key=self._config.groq_api_key,
+             default_model=self._config.model.default_model,
+             vision_model=self._config.model.vision_model,
+             temperature=self._config.model.temperature,
+             max_tokens=self._config.model.max_tokens,
+             max_retries=self._config.model.max_retries,
+         )
+ 
+     def _create_agents(self) -> dict:
+         """Initialize all agents with shared resources."""
+         return {
+             "query_analyzer": QueryAnalyzerAgent(self._llm, self._memory),
+             "data_processor": DataProcessorAgent(self._llm, self._memory),
+             "viz_mapping": VizMappingAgent(self._llm, self._memory),
+             "search_agent": SearchAgent(self._llm, self._memory),
+             "design_explorer": DesignExplorerAgent(self._llm, self._memory),
+             "code_generator": CodeGeneratorAgent(self._llm, self._memory),
+             "debug_agent": DebugAgent(
+                 self._llm,
+                 self._memory,
+                 timeout_seconds=self._config.execution.code_timeout_seconds,
+                 output_directory=self._config.execution.output_directory,
+             ),
+             "visual_evaluator": VisualEvaluatorAgent(
+                 self._llm,
+                 self._memory,
+                 min_overall_score=self._config.quality.minimum_overall_score,
+             ),
+         }
+ 
+     def run(
+         self,
+         query: str,
+         data_paths: list[str],
+     ) -> PipelineResult:
+         """
+         Execute the full visualization pipeline.
+ 
+         Args:
+             query: Natural language visualization request
+             data_paths: Paths to data files
+ 
+         Returns:
+             PipelineResult with output file and evaluation
+         """
+         logger.info(f"Starting CoDA pipeline for query: {query[:50]}...")
+ 
+         self._memory.clear()
+ 
+         validated_paths = self._validate_data_paths(data_paths)
+         if not validated_paths:
+             return PipelineResult(
+                 success=False,
+                 output_file=None,
+                 evaluation=None,
+                 iterations=0,
+                 error="No valid data files provided",
+             )
+ 
+         context = AgentContext(
+             query=query,
+             data_paths=validated_paths,
+             iteration=0,
+         )
+ 
+         try:
+             self._run_initial_pipeline(context)
+         except Exception as e:
+             logger.error(f"Initial pipeline failed: {e}")
+             return PipelineResult(
+                 success=False,
+                 output_file=None,
+                 evaluation=None,
+                 iterations=0,
+                 error=str(e),
+             )
+ 
+         max_iterations = self._config.execution.max_refinement_iterations
+         final_result = self._run_refinement_loop(context, max_iterations)
+ 
+         return final_result
+ 
+     def _validate_data_paths(self, data_paths: list[str]) -> list[str]:
+         """Validate that data files exist."""
+         valid_paths = []
+         for path in data_paths:
+             if Path(path).exists():
+                 valid_paths.append(path)
+             else:
+                 logger.warning(f"Data file not found: {path}")
+         return valid_paths
+ 
+     def _run_initial_pipeline(self, context: AgentContext) -> None:
+         """Run the initial agent pipeline."""
+         steps = [
+             ("query_analyzer", "Analyzing query...", 0.1),
+             ("data_processor", "Processing data...", 0.2),
+             ("viz_mapping", "Mapping visualization...", 0.3),
+             ("search_agent", "Searching examples...", 0.4),
+             ("design_explorer", "Designing visualization...", 0.5),
+             ("code_generator", "Generating code...", 0.7),
+             ("debug_agent", "Executing code...", 0.85),
+             ("visual_evaluator", "Evaluating output...", 0.95),
+         ]
+ 
+         for agent_name, status, progress in steps:
+             self._report_progress(status, progress)
+             agent = self._agents[agent_name]
+             agent.execute(context)
+ 
+     def _run_refinement_loop(
+         self,
+         context: AgentContext,
+         max_iterations: int,
+     ) -> PipelineResult:
+         """Run the iterative refinement loop."""
+         for iteration in range(max_iterations):
+             evaluation = self._memory.retrieve("visual_evaluation")
+ 
+             if not evaluation:
+                 break
+ 
+             if isinstance(evaluation, dict):
+                 passes = evaluation.get("passes_threshold", False)
+                 eval_obj = VisualEvaluation(**evaluation)
+             else:
+                 passes = evaluation.passes_threshold
+                 eval_obj = evaluation
+ 
+             if passes:
+                 logger.info(f"Quality threshold met at iteration {iteration}")
+                 return self._create_success_result(eval_obj, iteration + 1)
+ 
+             if iteration >= max_iterations - 1:
+                 logger.info("Max iterations reached")
+                 break
+ 
+             logger.info(f"Refinement iteration {iteration + 1}")
+             context = self._create_refinement_context(context, eval_obj, iteration + 1)
+ 
+             self._report_progress(f"Refining (iteration {iteration + 2})...", 0.5)
+ 
+             try:
+                 self._run_refinement_agents(context)
+             except Exception as e:
+                 logger.error(f"Refinement failed: {e}")
+                 break
+ 
+         final_eval = self._memory.retrieve("visual_evaluation")
+         if isinstance(final_eval, dict):
+             final_eval = VisualEvaluation(**final_eval)
+ 
+         return self._create_success_result(final_eval, max_iterations)
+ 
+     def _run_refinement_agents(self, context: AgentContext) -> None:
+         """Run the agents that participate in refinement."""
+         refinement_agents = [
+             "design_explorer",
+             "code_generator",
+             "debug_agent",
+             "visual_evaluator",
+         ]
+ 
+         for agent_name in refinement_agents:
+             agent = self._agents[agent_name]
+             agent.execute(context)
+ 
+     def _create_refinement_context(
+         self,
+         original_context: AgentContext,
+         evaluation: VisualEvaluation,
+         iteration: int,
+     ) -> AgentContext:
+         """Create the context for a refinement iteration."""
+         feedback_parts = []
+ 
+         if evaluation.issues:
+             feedback_parts.append(f"Issues: {', '.join(evaluation.issues[:3])}")
+ 
+         if evaluation.priority_fixes:
+             feedback_parts.append(f"Fix: {', '.join(evaluation.priority_fixes[:2])}")
+ 
+         feedback = " | ".join(feedback_parts)
+ 
+         return AgentContext(
+             query=original_context.query,
+             data_paths=original_context.data_paths,
+             iteration=iteration,
+             feedback=feedback,
+         )
+ 
+     def _create_success_result(
+         self,
+         evaluation: Optional[VisualEvaluation],
+         iterations: int,
+     ) -> PipelineResult:
+         """Create a successful pipeline result."""
+         execution_result = self._memory.retrieve("execution_result")
+         output_file = None
+ 
+         if execution_result:
+             if isinstance(execution_result, dict):
+                 output_file = execution_result.get("output_file")
+             else:
+                 output_file = execution_result.output_file
+ 
+         return PipelineResult(
+             success=output_file is not None and Path(output_file).exists(),
+             output_file=output_file,
+             evaluation=evaluation,
+             iterations=iterations,
+         )
+ 
+     def _report_progress(self, status: str, progress: float) -> None:
+         """Report progress to the callback if one is set."""
+         if self._progress_callback:
+             self._progress_callback(status, progress)
+         logger.info(f"[{progress:.0%}] {status}")
+ 
+     def get_memory_state(self) -> dict:
+         """Get the current state of shared memory for debugging."""
+         return self._memory.get_all()
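The halting logic in `_run_refinement_loop` (return early once the evaluation passes the threshold, otherwise improve and re-evaluate up to `max_iterations`) reduces to a small generic loop. A sketch with a toy scorer; `refine` is an illustrative stand-in, not part of the package:

```python
from typing import Callable, Tuple

def refine(
    candidate: int,
    evaluate: Callable[[int], float],
    improve: Callable[[int], int],
    threshold: float,
    max_iterations: int,
) -> Tuple[int, int]:
    """Evaluate/improve until the score meets the threshold or iterations run out.

    Returns the final candidate and the number of iterations consumed.
    """
    for iteration in range(max_iterations):
        if evaluate(candidate) >= threshold:
            return candidate, iteration + 1  # quality threshold met
        if iteration >= max_iterations - 1:
            break  # max iterations reached
        candidate = improve(candidate)
    return candidate, max_iterations

# Toy scoring: quality equals the candidate value; each refinement adds 2 points.
best, iters = refine(candidate=3, evaluate=float, improve=lambda c: c + 2,
                     threshold=7.0, max_iterations=3)
```

Starting from quality 3 with a threshold of 7.0, the loop improves twice and halts on the third evaluation.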
main.py ADDED
@@ -0,0 +1,162 @@
+ """
+ Command-Line Interface for CoDA.
+ 
+ Provides a CLI for running the CoDA visualization pipeline locally.
+ """
+ 
+ import argparse
+ import logging
+ import sys
+ from pathlib import Path
+ 
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+ )
+ logger = logging.getLogger(__name__)
+ 
+ 
+ def main():
+     """Main entry point for the CLI."""
+     parser = argparse.ArgumentParser(
+         description="CoDA - Collaborative Data Visualization Agents",
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+         epilog="""
+ Examples:
+   python main.py --query "Show sales trends" --data sales.csv
+   python main.py -q "Bar chart of categories" -d data.xlsx
+   python main.py --query "Scatter plot" --data file1.csv file2.csv
+         """
+     )
+ 
+     parser.add_argument(
+         "-q", "--query",
+         type=str,
+         required=True,
+         help="Natural language visualization query"
+     )
+ 
+     parser.add_argument(
+         "-d", "--data",
+         type=str,
+         nargs="+",
+         required=True,
+         help="Path(s) to data file(s)"
+     )
+ 
+     parser.add_argument(
+         "-o", "--output",
+         type=str,
+         default="outputs",
+         help="Output directory for visualizations (default: outputs)"
+     )
+ 
+     parser.add_argument(
+         "--max-iterations",
+         type=int,
+         default=3,
+         help="Maximum refinement iterations (default: 3)"
+     )
+ 
+     parser.add_argument(
+         "--min-score",
+         type=float,
+         default=7.0,
+         help="Minimum quality score threshold (default: 7.0)"
+     )
+ 
+     parser.add_argument(
+         "-v", "--verbose",
+         action="store_true",
+         help="Enable verbose logging"
+     )
+ 
+     args = parser.parse_args()
+ 
+     if args.verbose:
+         logging.getLogger().setLevel(logging.DEBUG)
+ 
+     for path in args.data:
+         if not Path(path).exists():
+             logger.error(f"Data file not found: {path}")
+             sys.exit(1)
+ 
+     try:
+         from coda.config import Config, ExecutionConfig, QualityThresholds
+         from coda.orchestrator import CodaOrchestrator
+     except ImportError as e:
+         logger.error(f"Failed to import CoDA modules: {e}")
+         logger.error("Make sure you have installed all dependencies: pip install -r requirements.txt")
+         sys.exit(1)
+ 
+     try:
+         config = Config(
+             execution=ExecutionConfig(
+                 max_refinement_iterations=args.max_iterations,
+                 output_directory=args.output,
+             ),
+             quality=QualityThresholds(
+                 minimum_overall_score=args.min_score,
+             ),
+         )
+     except ValueError as e:
+         logger.error(f"Configuration error: {e}")
+         sys.exit(1)
+ 
+     def progress_callback(status: str, progress: float):
+         bar_length = 30
+         filled = int(bar_length * progress)
+         bar = "█" * filled + "░" * (bar_length - filled)
+         print(f"\r[{bar}] {progress:.0%} - {status}", end="", flush=True)
+         if progress >= 1.0:
+             print()
+ 
+     print(f"\n{'='*60}")
+     print("CoDA - Collaborative Data Visualization Agents")
+     print(f"{'='*60}\n")
+     print(f"Query: {args.query}")
+     print(f"Data: {', '.join(args.data)}")
+     print(f"Output: {args.output}/")
+     print()
+ 
+     orchestrator = CodaOrchestrator(
+         config=config,
+         progress_callback=progress_callback,
+     )
+ 
+     result = orchestrator.run(
+         query=args.query,
+         data_paths=args.data,
+     )
+ 
+     print()
+     print(f"{'='*60}")
+     print("Results")
+     print(f"{'='*60}\n")
+ 
+     if result.success:
+         print("✅ Visualization generated successfully!")
+         print(f"📁 Output: {result.output_file}")
+         print(f"🔄 Iterations: {result.iterations}")
+ 
+         if result.scores:
+             print("\n📊 Quality Scores:")
+             for key, value in result.scores.items():
+                 emoji = "🟢" if value >= 7 else "🟡" if value >= 5 else "🔴"
+                 print(f"  {key.title()}: {emoji} {value:.1f}/10")
+ 
+         if result.evaluation and result.evaluation.strengths:
+             print("\n💪 Strengths:")
+             for s in result.evaluation.strengths[:3]:
+                 print(f"  • {s}")
+     else:
+         print("❌ Visualization failed!")
+         if result.error:
+             print(f"  Error: {result.error}")
+         sys.exit(1)
+ 
+     print()
+ 
+ 
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
@@ -0,0 +1,22 @@
+ # CoDA Dependencies
+ 
+ # Core
+ groq>=0.4.0
+ pydantic>=2.0.0
+ python-dotenv>=1.0.0
+ 
+ # Data Processing
+ pandas>=2.0.0
+ openpyxl>=3.1.0
+ pyarrow>=14.0.0
+ 
+ # Visualization
+ matplotlib>=3.7.0
+ seaborn>=0.13.0
+ 
+ # Web Interface
+ gradio>=4.0.0
+ 
+ # Development (optional)
+ pytest>=7.0.0
+ pytest-asyncio>=0.21.0
sample_data.csv ADDED
@@ -0,0 +1,16 @@
+ date,sales,category,region
+ 2024-01-01,1500,Electronics,North
+ 2024-01-02,2300,Electronics,South
+ 2024-01-03,1800,Clothing,North
+ 2024-01-04,3200,Electronics,East
+ 2024-01-05,2100,Clothing,West
+ 2024-01-06,1900,Food,North
+ 2024-01-07,2800,Electronics,South
+ 2024-01-08,1600,Clothing,East
+ 2024-01-09,3500,Food,West
+ 2024-01-10,2400,Electronics,North
+ 2024-01-11,1700,Clothing,South
+ 2024-01-12,2900,Food,East
+ 2024-01-13,2200,Electronics,West
+ 2024-01-14,1400,Clothing,North
+ 2024-01-15,3100,Food,South
tests/__init__.py ADDED
@@ -0,0 +1 @@
+ # Tests package
tests/test_agents.py ADDED
@@ -0,0 +1,177 @@
+ """
+ Unit tests for agent implementations.
+ """
+ 
+ import pytest
+ from unittest.mock import Mock, MagicMock
+ import json
+ 
+ from coda.core.memory import SharedMemory
+ from coda.core.llm import LLMResponse
+ from coda.core.base_agent import AgentContext
+ from coda.agents.query_analyzer import QueryAnalyzerAgent, QueryAnalysis
+ from coda.agents.data_processor import DataProcessorAgent, DataAnalysis
+ from coda.agents.viz_mapping import VizMappingAgent, VisualMapping
+ 
+ 
+ class MockLLM:
+     """Mock LLM for testing agents."""
+ 
+     def __init__(self, response_content: str):
+         self._response = response_content
+ 
+     def complete(self, prompt, system_prompt=None, **kwargs):
+         return LLMResponse(
+             content=self._response,
+             model="mock",
+             usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+             finish_reason="stop"
+         )
+ 
+     def complete_with_image(self, prompt, image_path, **kwargs):
+         return self.complete(prompt, **kwargs)
+ 
+ 
+ class TestQueryAnalyzerAgent:
+     """Tests for the Query Analyzer agent."""
+ 
+     @pytest.fixture
+     def mock_response(self):
+         return json.dumps({
+             "visualization_types": ["line chart", "bar chart"],
+             "key_points": ["sales trends", "monthly data"],
+             "todo_list": ["Load data", "Create chart", "Add labels"],
+             "data_requirements": ["date", "sales"],
+             "constraints": ["use blue colors"],
+             "ambiguities": []
+         })
+ 
+     def test_execute(self, mock_response):
+         """Test query analysis execution."""
+         llm = MockLLM(mock_response)
+         memory = SharedMemory()
+         agent = QueryAnalyzerAgent(llm, memory)
+ 
+         context = AgentContext(query="Show sales trends over time")
+         result = agent.execute(context)
+ 
+         assert isinstance(result, QueryAnalysis)
+         assert "line chart" in result.visualization_types
+         assert "sales trends" in result.key_points
+         assert len(result.todo_list) == 3
+ 
+     def test_stores_in_memory(self, mock_response):
+         """Test that results are stored in memory."""
+         llm = MockLLM(mock_response)
+         memory = SharedMemory()
+         agent = QueryAnalyzerAgent(llm, memory)
+ 
+         context = AgentContext(query="Test query")
+         agent.execute(context)
+ 
+         stored = memory.retrieve("query_analysis")
+         assert stored is not None
+         assert "visualization_types" in stored
+ 
+ 
+ class TestVizMappingAgent:
+     """Tests for the VizMapping agent."""
+ 
+     @pytest.fixture
+     def mock_response(self):
+         return json.dumps({
+             "chart_type": "line",
+             "chart_subtype": None,
+             "x_axis": {"column": "date", "label": "Date", "type": "temporal"},
+             "y_axis": {"column": "sales", "label": "Sales", "type": "numerical"},
+             "color_encoding": None,
+             "size_encoding": None,
+             "transformations": [],
+             "styling_hints": {"theme": "modern"},
+             "visualization_goals": ["Show trends"],
+             "rationale": "Line chart best for trends"
+         })
+ 
+     def test_execute(self, mock_response):
+         """Test visualization mapping execution."""
+         llm = MockLLM(mock_response)
+         memory = SharedMemory()
+ 
+         memory.store("query_analysis", {
+             "visualization_types": ["line chart"],
+             "key_points": ["trends"],
+             "data_requirements": ["date", "sales"]
+         }, "test")
+ 
+         agent = VizMappingAgent(llm, memory)
+         context = AgentContext(query="Show sales trends")
+         result = agent.execute(context)
+ 
+         assert isinstance(result, VisualMapping)
+         assert result.chart_type == "line"
+         assert result.x_axis["column"] == "date"
+ 
+ 
+ class TestAgentContext:
+     """Tests for the AgentContext model."""
+ 
+     def test_basic_context(self):
+         """Test creating a basic context."""
+         context = AgentContext(
+             query="Test query",
+             data_paths=["file1.csv", "file2.csv"]
+         )
+ 
+         assert context.query == "Test query"
+         assert len(context.data_paths) == 2
+         assert context.iteration == 0
+         assert context.feedback is None
+ 
+     def test_context_with_feedback(self):
+         """Test context with feedback for refinement."""
+         context = AgentContext(
+             query="Test",
+             iteration=2,
+             feedback="Improve colors"
+         )
+ 
+         assert context.iteration == 2
+         assert context.feedback == "Improve colors"
+ 
+ 
+ class TestBaseAgentJsonExtraction:
+     """Tests for JSON extraction from LLM responses."""
+ 
+     def test_extract_json_plain(self):
+         """Test extracting plain JSON."""
+         llm = MockLLM("{}")
+         memory = SharedMemory()
+         agent = QueryAnalyzerAgent(llm, memory)
+ 
+         result = agent._extract_json('{"key": "value"}')
+ 
+         assert result == {"key": "value"}
+ 
+     def test_extract_json_markdown(self):
+         """Test extracting JSON from a markdown code block."""
+         llm = MockLLM("{}")
+         memory = SharedMemory()
+         agent = QueryAnalyzerAgent(llm, memory)
+ 
+         text = """Here is the response:
+         ```json
+         {"key": "value"}
+         ```
+         """
+         result = agent._extract_json(text)
+ 
+         assert result == {"key": "value"}
+ 
+     def test_extract_json_invalid(self):
+         """Test handling invalid JSON."""
+         llm = MockLLM("{}")
+         memory = SharedMemory()
+         agent = QueryAnalyzerAgent(llm, memory)
+ 
+         with pytest.raises(ValueError):
+             agent._extract_json("not valid json")
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unit tests for the LLM abstraction layer.
3
+ """
4
+
5
+ import pytest
6
+ from unittest.mock import Mock, patch, MagicMock
7
+ from pathlib import Path
8
+
9
+ from coda.core.llm import LLMProvider, GroqLLM, LLMResponse
10
+
11
+
12
+ class TestLLMResponse:
13
+ """Tests for the LLMResponse model."""
14
+
15
+ def test_response_creation(self):
16
+ """Test creating a valid response."""
17
+ response = LLMResponse(
18
+ content="Test content",
19
+ model="test-model",
20
+ usage={"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
21
+ finish_reason="stop"
22
+ )
23
+
24
+ assert response.content == "Test content"
25
+ assert response.model == "test-model"
26
+ assert response.usage["total_tokens"] == 30
27
+ assert response.finish_reason == "stop"
28
+
29
+
30
+ class TestGroqLLM:
31
+ """Tests for the GroqLLM implementation."""
32
+
33
+ @pytest.fixture
34
+ def mock_groq_client(self):
35
+ """Create a mock Groq client."""
36
+ with patch("coda.core.llm.Groq") as mock:
37
+ client_instance = Mock()
38
+ mock.return_value = client_instance
39
+ yield client_instance
40
+
41
+ def test_initialization(self, mock_groq_client):
42
+ """Test LLM initialization with custom parameters."""
43
+ llm = GroqLLM(
44
+ api_key="test-key",
45
+ default_model="custom-model",
46
+ temperature=0.5,
47
+ max_tokens=2048,
48
+ )
49
+
50
+ assert llm._default_model == "custom-model"
51
+ assert llm._temperature == 0.5
52
+ assert llm._max_tokens == 2048
53
+
54
+ def test_complete_success(self, mock_groq_client):
55
+ """Test successful completion."""
56
+ mock_response = MagicMock()
57
+ mock_response.choices = [MagicMock()]
58
+ mock_response.choices[0].message.content = "Generated text"
59
+ mock_response.choices[0].finish_reason = "stop"
60
+ mock_response.model = "llama-3.3-70b-versatile"
61
+ mock_response.usage = MagicMock()
62
+ mock_response.usage.prompt_tokens = 10
63
+ mock_response.usage.completion_tokens = 20
64
+ mock_response.usage.total_tokens = 30
65
+
66
+ mock_groq_client.chat.completions.create.return_value = mock_response
67
+
68
+ llm = GroqLLM(api_key="test-key")
69
+ response = llm.complete(
70
+ prompt="Test prompt",
71
+ system_prompt="System prompt"
72
+ )
73
+
74
+ assert response.content == "Generated text"
75
+ assert response.finish_reason == "stop"
76
+ mock_groq_client.chat.completions.create.assert_called_once()
77
+
78
+ def test_complete_with_retry(self, mock_groq_client):
79
+ """Test retry logic on failure."""
80
+ mock_groq_client.chat.completions.create.side_effect = [
81
+ Exception("Rate limited"),
82
+ MagicMock(
83
+ choices=[MagicMock(message=MagicMock(content="Success"), finish_reason="stop")],
84
+ model="test",
85
+ usage=MagicMock(prompt_tokens=0, completion_tokens=0, total_tokens=0)
86
+ )
87
+ ]
88
+
89
+ llm = GroqLLM(api_key="test-key", retry_delay=0.01)
90
+ response = llm.complete(prompt="Test")
91
+
92
+ assert response.content == "Success"
93
+ assert mock_groq_client.chat.completions.create.call_count == 2
94
+
95
+ def test_build_messages(self, mock_groq_client):
96
+ """Test message building with system prompt."""
97
+ llm = GroqLLM(api_key="test-key")
98
+
99
+ messages = llm._build_messages(
100
+ prompt="User message",
101
+ system_prompt="System message"
102
+ )
103
+
104
+ assert len(messages) == 2
105
+ assert messages[0]["role"] == "system"
106
+ assert messages[0]["content"] == "System message"
107
+ assert messages[1]["role"] == "user"
108
+ assert messages[1]["content"] == "User message"
109
+
110
+ def test_build_messages_no_system(self, mock_groq_client):
111
+ """Test message building without system prompt."""
112
+ llm = GroqLLM(api_key="test-key")
113
+
114
+ messages = llm._build_messages(prompt="User message")
115
+
116
+ assert len(messages) == 1
117
+ assert messages[0]["role"] == "user"
118
+
119
+
120
+ class TestLLMProviderInterface:
121
+ """Tests for the abstract interface."""
122
+
123
+ def test_interface_methods(self):
124
+ """Verify LLMProvider defines required methods."""
125
+ assert hasattr(LLMProvider, "complete")
126
+ assert hasattr(LLMProvider, "complete_with_image")
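`test_complete_with_retry` pins down the retry contract (retry on any exception, wait `retry_delay` between attempts, succeed on the second call) without showing the implementation inside `GroqLLM.complete`. A minimal sketch of the loop that contract implies might look like the following; the helper name `call_with_retry` and its defaults are hypothetical, not part of the codebase:

```python
import time


def call_with_retry(fn, max_retries=3, retry_delay=0.01):
    """Call fn(), retrying up to max_retries times with a fixed delay."""
    last_error = None
    for _ in range(max_retries):
        try:
            return fn()
        except Exception as e:  # a real client would catch narrower error types
            last_error = e
            time.sleep(retry_delay)
    raise last_error


# Mirrors the test: first call raises, second returns "Success".
calls = {"n": 0}

def flaky():
    calls["n"] += 1
    if calls["n"] == 1:
        raise RuntimeError("Rate limited")
    return "Success"

result = call_with_retry(flaky)  # succeeds on attempt 2
```

This matches the assertions in the test above: the result is the successful response and the underlying call was made exactly twice.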
tests/test_memory.py ADDED
@@ -0,0 +1,165 @@
+ """
+ Unit tests for the SharedMemory class.
+ """
+ 
+ import pytest
+ from datetime import datetime
+ import threading
+ import time
+ 
+ from coda.core.memory import SharedMemory, MemoryEntry
+ 
+ 
+ class TestSharedMemory:
+     """Tests for the SharedMemory class."""
+ 
+     @pytest.fixture
+     def memory(self):
+         """Create a fresh SharedMemory instance."""
+         return SharedMemory()
+ 
+     def test_store_and_retrieve(self, memory):
+         """Test basic store and retrieve operations."""
+         memory.store(
+             key="test_key",
+             value={"data": "value"},
+             agent_name="TestAgent"
+         )
+ 
+         result = memory.retrieve("test_key")
+ 
+         assert result == {"data": "value"}
+ 
+     def test_retrieve_nonexistent(self, memory):
+         """Test retrieving a non-existent key."""
+         result = memory.retrieve("nonexistent")
+ 
+         assert result is None
+ 
+     def test_retrieve_entry(self, memory):
+         """Test retrieving the full entry with metadata."""
+         memory.store(
+             key="test_key",
+             value="test_value",
+             agent_name="TestAgent",
+             metadata={"extra": "info"}
+         )
+ 
+         entry = memory.retrieve_entry("test_key")
+ 
+         assert entry is not None
+         assert entry.value == "test_value"
+         assert entry.agent_name == "TestAgent"
+         assert entry.metadata == {"extra": "info"}
+         assert isinstance(entry.timestamp, datetime)
+ 
+     def test_get_context(self, memory):
+         """Test retrieving multiple keys as context."""
+         memory.store("key1", "value1", "Agent1")
+         memory.store("key2", "value2", "Agent2")
+         memory.store("key3", "value3", "Agent3")
+ 
+         context = memory.get_context(["key1", "key3", "nonexistent"])
+ 
+         assert context == {"key1": "value1", "key3": "value3"}
+ 
+     def test_get_all(self, memory):
+         """Test retrieving all stored values."""
+         memory.store("key1", "value1", "Agent")
+         memory.store("key2", "value2", "Agent")
+ 
+         all_data = memory.get_all()
+ 
+         assert all_data == {"key1": "value1", "key2": "value2"}
+ 
+     def test_overwrite_value(self, memory):
+         """Test overwriting an existing value."""
+         memory.store("key", "original", "Agent")
+         memory.store("key", "updated", "Agent")
+ 
+         assert memory.retrieve("key") == "updated"
+ 
+     def test_history_tracking(self, memory):
+         """Test that history is tracked for all operations."""
+         memory.store("key1", "v1", "Agent1")
+         memory.store("key2", "v2", "Agent2")
+         memory.store("key1", "v1_updated", "Agent1")
+ 
+         history = memory.get_history()
+ 
+         assert len(history) == 3
+         assert history[0].key == "key1"
+         assert history[1].key == "key2"
+         assert history[2].value == "v1_updated"
+ 
+     def test_history_filter_by_agent(self, memory):
+         """Test filtering history by agent name."""
+         memory.store("k1", "v1", "Agent1")
+         memory.store("k2", "v2", "Agent2")
+         memory.store("k3", "v3", "Agent1")
+ 
+         agent1_history = memory.get_history(agent_name="Agent1")
+ 
+         assert len(agent1_history) == 2
+         assert all(e.agent_name == "Agent1" for e in agent1_history)
+ 
+     def test_has_key(self, memory):
+         """Test key existence check."""
+         memory.store("exists", "value", "Agent")
+ 
+         assert memory.has_key("exists") is True
+         assert memory.has_key("not_exists") is False
+ 
+     def test_clear(self, memory):
+         """Test clearing all data."""
+         memory.store("k1", "v1", "Agent")
+         memory.store("k2", "v2", "Agent")
+ 
+         memory.clear()
+ 
+         assert memory.retrieve("k1") is None
+         assert memory.retrieve("k2") is None
+         assert len(memory.get_history()) == 0
+ 
+     def test_keys(self, memory):
+         """Test getting all keys."""
+         memory.store("a", 1, "Agent")
+         memory.store("b", 2, "Agent")
+         memory.store("c", 3, "Agent")
+ 
+         keys = memory.keys()
+ 
+         assert set(keys) == {"a", "b", "c"}
+ 
+     def test_thread_safety(self, memory):
+         """Test that operations are thread-safe."""
+         results = []
+         errors = []
+ 
+         def writer(n):
+             try:
+                 for i in range(100):
+                     memory.store(f"key_{n}_{i}", i, f"Agent{n}")
+             except Exception as e:
+                 errors.append(e)
+ 
+         def reader():
+             try:
+                 for _ in range(100):
+                     memory.get_all()
+                     memory.keys()
+             except Exception as e:
+                 errors.append(e)
+ 
+         threads = [
+             threading.Thread(target=writer, args=(i,))
+             for i in range(3)
+         ]
+         threads.append(threading.Thread(target=reader))
+ 
+         for t in threads:
+             t.start()
+         for t in threads:
+             t.join()
+ 
+         assert len(errors) == 0
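Taken together, these tests fully pin down the `SharedMemory` surface: keyed store/retrieve with per-entry metadata and timestamps, an append-only history filterable by agent, and lock-protected reads and writes. A minimal lock-based sketch that would satisfy them (an illustration of the contract, not the project's actual implementation) could look like:

```python
import threading
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Optional


@dataclass
class MemoryEntry:
    """One stored value plus provenance metadata."""
    key: str
    value: Any
    agent_name: str
    timestamp: datetime = field(default_factory=datetime.now)
    metadata: dict = field(default_factory=dict)


class SharedMemory:
    """Thread-safe key-value store with an append-only write history."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._data: dict[str, MemoryEntry] = {}
        self._history: list[MemoryEntry] = []

    def store(self, key: str, value: Any, agent_name: str,
              metadata: Optional[dict] = None) -> None:
        entry = MemoryEntry(key, value, agent_name, metadata=metadata or {})
        with self._lock:
            self._data[key] = entry       # overwrite is allowed
            self._history.append(entry)   # every write is recorded

    def retrieve(self, key: str) -> Any:
        with self._lock:
            entry = self._data.get(key)
        return entry.value if entry is not None else None

    def retrieve_entry(self, key: str) -> Optional[MemoryEntry]:
        with self._lock:
            return self._data.get(key)

    def get_context(self, keys: list) -> dict:
        with self._lock:
            return {k: self._data[k].value for k in keys if k in self._data}

    def get_all(self) -> dict:
        with self._lock:
            return {k: e.value for k, e in self._data.items()}

    def get_history(self, agent_name: Optional[str] = None) -> list:
        with self._lock:
            if agent_name is None:
                return list(self._history)
            return [e for e in self._history if e.agent_name == agent_name]

    def has_key(self, key: str) -> bool:
        with self._lock:
            return key in self._data

    def clear(self) -> None:
        with self._lock:
            self._data.clear()
            self._history.clear()

    def keys(self) -> list:
        with self._lock:
            return list(self._data)
```

A single coarse `threading.Lock` is enough for the thread-safety test above; per-key locking would only pay off under much heavier contention than these tests exercise.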