Spaces:
Sleeping
Sleeping
Commit
·
9281fab
0
Parent(s):
Added init files
Browse files- .gitignore +6 -0
- LICENSE +21 -0
- README.md +190 -0
- app.py +220 -0
- coda/__init__.py +31 -0
- coda/agents/__init__.py +29 -0
- coda/agents/code_generator.py +162 -0
- coda/agents/data_processor.py +280 -0
- coda/agents/debug_agent.py +252 -0
- coda/agents/design_explorer.py +207 -0
- coda/agents/query_analyzer.py +112 -0
- coda/agents/search_agent.py +295 -0
- coda/agents/visual_evaluator.py +228 -0
- coda/agents/viz_mapping.py +164 -0
- coda/config.py +91 -0
- coda/core/__init__.py +17 -0
- coda/core/agent_factory.py +135 -0
- coda/core/base_agent.py +220 -0
- coda/core/llm.py +192 -0
- coda/core/memory.py +148 -0
- coda/orchestrator.py +293 -0
- main.py +162 -0
- requirements.txt +22 -0
- sample_data.csv +16 -0
- tests/__init__.py +1 -0
- tests/test_agents.py +177 -0
- tests/test_llm.py +126 -0
- tests/test_memory.py +165 -0
.gitignore
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv/
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
.env
|
| 5 |
+
outputs/
|
| 6 |
+
*.log
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 M Saqlain
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: CoDA
|
| 3 |
+
emoji: 🎨
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 4.0.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# CoDA: Collaborative Data Visualization Agents
|
| 13 |
+
|
| 14 |
+
A production-grade multi-agent system for automated data visualization from natural language queries.
|
| 15 |
+
|
| 16 |
+
[](https://huggingface.co/spaces)
|
| 17 |
+
[](https://www.python.org/downloads/)
|
| 18 |
+
[](https://opensource.org/licenses/MIT)
|
| 19 |
+
|
| 20 |
+
## Overview
|
| 21 |
+
|
| 22 |
+
CoDA reframes data visualization as a collaborative multi-agent problem. Instead of treating it as a monolithic task, CoDA employs specialized LLM agents that work together:
|
| 23 |
+
|
| 24 |
+
- **Query Analyzer** - Interprets natural language and extracts visualization intent
|
| 25 |
+
- **Data Processor** - Extracts metadata without token-heavy data loading
|
| 26 |
+
- **VizMapping Agent** - Maps semantics to visualization primitives
|
| 27 |
+
- **Search Agent** - Retrieves relevant code patterns
|
| 28 |
+
- **Design Explorer** - Generates aesthetic specifications
|
| 29 |
+
- **Code Generator** - Synthesizes executable Python code
|
| 30 |
+
- **Debug Agent** - Executes code and fixes errors
|
| 31 |
+
- **Visual Evaluator** - Assesses quality and triggers refinement
|
| 32 |
+
|
| 33 |
+
## Quick Start
|
| 34 |
+
|
| 35 |
+
### Installation
|
| 36 |
+
|
| 37 |
+
```bash
|
| 38 |
+
# Clone the repository
|
| 39 |
+
git clone https://github.com/yourusername/CoDA.git
|
| 40 |
+
cd CoDA
|
| 41 |
+
|
| 42 |
+
# Install dependencies
|
| 43 |
+
pip install -r requirements.txt
|
| 44 |
+
|
| 45 |
+
# Configure API key
|
| 46 |
+
cp .env.example .env
|
| 47 |
+
# Edit .env and add your GROQ_API_KEY
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
### Usage
|
| 51 |
+
|
| 52 |
+
#### Web Interface (Gradio)
|
| 53 |
+
|
| 54 |
+
```bash
|
| 55 |
+
python app.py
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
Open http://localhost:7860 in your browser.
|
| 59 |
+
|
| 60 |
+
#### Command Line
|
| 61 |
+
|
| 62 |
+
```bash
|
| 63 |
+
python main.py --query "Create a bar chart of sales by category" --data sales.csv
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
Options:
|
| 67 |
+
- `-q, --query`: Visualization query (required)
|
| 68 |
+
- `-d, --data`: Data file path(s) (required)
|
| 69 |
+
- `-o, --output`: Output directory (default: outputs)
|
| 70 |
+
- `--max-iterations`: Refinement iterations (default: 3)
|
| 71 |
+
- `--min-score`: Quality threshold (default: 7.0)
|
| 72 |
+
|
| 73 |
+
### Python API
|
| 74 |
+
|
| 75 |
+
```python
|
| 76 |
+
from coda.orchestrator import CodaOrchestrator
|
| 77 |
+
|
| 78 |
+
orchestrator = CodaOrchestrator()
|
| 79 |
+
result = orchestrator.run(
|
| 80 |
+
query="Show sales trends over time",
|
| 81 |
+
data_paths=["sales_data.csv"]
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
if result.success:
|
| 85 |
+
print(f"Visualization saved to: {result.output_file}")
|
| 86 |
+
print(f"Quality Score: {result.scores['overall']}/10")
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
## Hugging Face Spaces Deployment
|
| 90 |
+
|
| 91 |
+
1. Create a new Space on [Hugging Face](https://huggingface.co/new-space)
|
| 92 |
+
2. Select "Gradio" as the SDK
|
| 93 |
+
3. Upload all files from this repository
|
| 94 |
+
4. Add `GROQ_API_KEY` as a Secret in Space Settings
|
| 95 |
+
5. The Space will automatically build and deploy
|
| 96 |
+
|
| 97 |
+
## Architecture
|
| 98 |
+
|
| 99 |
+
```
|
| 100 |
+
Natural Language Query + Data Files
|
| 101 |
+
│
|
| 102 |
+
▼
|
| 103 |
+
┌───────────────┐
|
| 104 |
+
│ Query Analyzer │ ─── Extracts intent, TODO list
|
| 105 |
+
└───────────────┘
|
| 106 |
+
│
|
| 107 |
+
▼
|
| 108 |
+
┌───────────────┐
|
| 109 |
+
│ Data Processor │ ─── Metadata extraction (no full load)
|
| 110 |
+
└───────────────┘
|
| 111 |
+
│
|
| 112 |
+
▼
|
| 113 |
+
┌───────────────┐
|
| 114 |
+
│ VizMapping │ ─── Chart type, encodings
|
| 115 |
+
└───────────────┘
|
| 116 |
+
│
|
| 117 |
+
▼
|
| 118 |
+
┌───────────────┐
|
| 119 |
+
│ Search Agent │ ─── Code examples
|
| 120 |
+
└───────────────┘
|
| 121 |
+
│
|
| 122 |
+
▼
|
| 123 |
+
┌───────────────┐
|
| 124 |
+
│Design Explorer│ ─── Colors, layout, styling
|
| 125 |
+
└───────────────┘
|
| 126 |
+
│
|
| 127 |
+
▼
|
| 128 |
+
┌───────────────┐
|
| 129 |
+
│Code Generator │ ─── Python visualization code
|
| 130 |
+
└───────────────┘
|
| 131 |
+
│
|
| 132 |
+
▼
|
| 133 |
+
┌───────────────┐
|
| 134 |
+
│ Debug Agent │ ─── Execute & fix errors
|
| 135 |
+
└───────────────┘
|
| 136 |
+
│
|
| 137 |
+
▼
|
| 138 |
+
┌───────────────┐
|
| 139 |
+
│Visual Evaluator│ ─── Quality assessment
|
| 140 |
+
└───────────────┘
|
| 141 |
+
│
|
| 142 |
+
───────┴───────
|
| 143 |
+
↓ Feedback Loop ↓
|
| 144 |
+
(if quality < threshold)
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
## Configuration
|
| 148 |
+
|
| 149 |
+
Environment variables (in `.env`):
|
| 150 |
+
|
| 151 |
+
| Variable | Default | Description |
|
| 152 |
+
|----------|---------|-------------|
|
| 153 |
+
| `GROQ_API_KEY` | Required | Your Groq API key |
|
| 154 |
+
| `CODA_DEFAULT_MODEL` | llama-3.3-70b-versatile | Text model |
|
| 155 |
+
| `CODA_VISION_MODEL` | llama-3.2-90b-vision-preview | Vision model |
|
| 156 |
+
| `CODA_MIN_OVERALL_SCORE` | 7.0 | Quality threshold |
|
| 157 |
+
| `CODA_MAX_ITERATIONS` | 3 | Max refinement loops |
|
| 158 |
+
|
| 159 |
+
## Supported Data Formats
|
| 160 |
+
|
| 161 |
+
- CSV (`.csv`)
|
| 162 |
+
- JSON (`.json`)
|
| 163 |
+
- Excel (`.xlsx`, `.xls`)
|
| 164 |
+
- Parquet (`.parquet`)
|
| 165 |
+
|
| 166 |
+
## Requirements
|
| 167 |
+
|
| 168 |
+
- Python 3.10+
|
| 169 |
+
- Groq API key ([Get one free](https://console.groq.com))
|
| 170 |
+
|
| 171 |
+
## License
|
| 172 |
+
|
| 173 |
+
MIT License - See LICENSE for details.
|
| 174 |
+
|
| 175 |
+
## Citation
|
| 176 |
+
|
| 177 |
+
If you use CoDA in your research, please cite:
|
| 178 |
+
|
| 179 |
+
```bibtex
|
| 180 |
+
@article{chen2025coda,
|
| 181 |
+
title={CoDA: Agentic Systems for Collaborative Data Visualization},
|
| 182 |
+
author={Chen, Zichen and Chen, Jiefeng and Arik, Sercan {\"O}. and Sra, Misha and Pfister, Tomas and Yoon, Jinsung},
|
| 183 |
+
journal={arXiv preprint arXiv:2510.03194},
|
| 184 |
+
year={2025},
|
| 185 |
+
url={https://arxiv.org/abs/2510.03194},
|
| 186 |
+
doi={10.48550/arXiv.2510.03194}
|
| 187 |
+
}
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
**Paper**: [arXiv:2510.03194](https://arxiv.org/abs/2510.03194)
|
app.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio Web Interface for CoDA.
|
| 3 |
+
|
| 4 |
+
Provides a user-friendly web UI for the CoDA visualization system,
|
| 5 |
+
designed for deployment on Hugging Face Spaces.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
import tempfile
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Optional
|
| 13 |
+
|
| 14 |
+
import gradio as gr
|
| 15 |
+
|
| 16 |
+
logging.basicConfig(
|
| 17 |
+
level=logging.INFO,
|
| 18 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
| 19 |
+
)
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def create_coda_interface():
|
| 24 |
+
"""Create the Gradio interface for CoDA."""
|
| 25 |
+
|
| 26 |
+
def process_visualization(
|
| 27 |
+
query: str,
|
| 28 |
+
data_file,
|
| 29 |
+
progress=gr.Progress()
|
| 30 |
+
) -> tuple[Optional[str], str, str]:
|
| 31 |
+
"""
|
| 32 |
+
Process a visualization request.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
query: Natural language visualization query
|
| 36 |
+
data_file: Uploaded data file
|
| 37 |
+
progress: Gradio progress tracker
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
Tuple of (image_path, status_message, details)
|
| 41 |
+
"""
|
| 42 |
+
if not query.strip():
|
| 43 |
+
return None, "❌ Error", "Please enter a visualization query."
|
| 44 |
+
|
| 45 |
+
if data_file is None:
|
| 46 |
+
return None, "❌ Error", "Please upload a data file."
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
from coda.config import Config
|
| 50 |
+
from coda.orchestrator import CodaOrchestrator
|
| 51 |
+
except ImportError as e:
|
| 52 |
+
return None, "❌ Import Error", f"Failed to import CoDA: {e}"
|
| 53 |
+
|
| 54 |
+
groq_api_key = os.getenv("GROQ_API_KEY", "")
|
| 55 |
+
if not groq_api_key:
|
| 56 |
+
return (
|
| 57 |
+
None,
|
| 58 |
+
"❌ Configuration Error",
|
| 59 |
+
"GROQ_API_KEY environment variable is not set. "
|
| 60 |
+
"Please add your API key in the Spaces settings."
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
| 64 |
+
data_path = Path(temp_dir) / Path(data_file.name).name
|
| 65 |
+
|
| 66 |
+
with open(data_file.name, 'rb') as src:
|
| 67 |
+
with open(data_path, 'wb') as dst:
|
| 68 |
+
dst.write(src.read())
|
| 69 |
+
|
| 70 |
+
def update_progress(status: str, pct: float):
|
| 71 |
+
progress(pct, desc=status)
|
| 72 |
+
|
| 73 |
+
try:
|
| 74 |
+
config = Config(
|
| 75 |
+
groq_api_key=groq_api_key,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
orchestrator = CodaOrchestrator(
|
| 79 |
+
config=config,
|
| 80 |
+
progress_callback=update_progress,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
result = orchestrator.run(
|
| 84 |
+
query=query,
|
| 85 |
+
data_paths=[str(data_path)],
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
if result.success and result.output_file:
|
| 89 |
+
scores = result.scores or {}
|
| 90 |
+
details = format_results(result, scores)
|
| 91 |
+
return result.output_file, "✅ Success", details
|
| 92 |
+
else:
|
| 93 |
+
error_msg = result.error or "Unknown error occurred"
|
| 94 |
+
return None, "❌ Failed", f"Visualization failed: {error_msg}"
|
| 95 |
+
|
| 96 |
+
except Exception as e:
|
| 97 |
+
logger.exception("Pipeline error")
|
| 98 |
+
return None, "❌ Error", f"An error occurred: {str(e)}"
|
| 99 |
+
|
| 100 |
+
def format_results(result, scores: dict) -> str:
|
| 101 |
+
"""Format the results for display."""
|
| 102 |
+
lines = [
|
| 103 |
+
f"**Iterations:** {result.iterations}",
|
| 104 |
+
"",
|
| 105 |
+
"### Quality Scores",
|
| 106 |
+
]
|
| 107 |
+
|
| 108 |
+
if scores:
|
| 109 |
+
for key, value in scores.items():
|
| 110 |
+
emoji = "🟢" if value >= 7 else "🟡" if value >= 5 else "🔴"
|
| 111 |
+
lines.append(f"- {key.title()}: {emoji} {value:.1f}/10")
|
| 112 |
+
|
| 113 |
+
if result.evaluation:
|
| 114 |
+
if result.evaluation.strengths:
|
| 115 |
+
lines.extend(["", "### Strengths"])
|
| 116 |
+
for s in result.evaluation.strengths[:3]:
|
| 117 |
+
lines.append(f"- {s}")
|
| 118 |
+
|
| 119 |
+
if result.evaluation.recommendations:
|
| 120 |
+
lines.extend(["", "### Recommendations"])
|
| 121 |
+
for r in result.evaluation.recommendations[:3]:
|
| 122 |
+
lines.append(f"- {r}")
|
| 123 |
+
|
| 124 |
+
return "\n".join(lines)
|
| 125 |
+
|
| 126 |
+
with gr.Blocks(
|
| 127 |
+
title="CoDA - Collaborative Data Visualization",
|
| 128 |
+
theme=gr.themes.Soft(),
|
| 129 |
+
css="""
|
| 130 |
+
.main-title {
|
| 131 |
+
text-align: center;
|
| 132 |
+
margin-bottom: 1rem;
|
| 133 |
+
}
|
| 134 |
+
.status-box {
|
| 135 |
+
padding: 1rem;
|
| 136 |
+
border-radius: 8px;
|
| 137 |
+
margin-top: 1rem;
|
| 138 |
+
}
|
| 139 |
+
"""
|
| 140 |
+
) as interface:
|
| 141 |
+
gr.Markdown(
|
| 142 |
+
"""
|
| 143 |
+
# 🎨 CoDA: Collaborative Data Visualization Agents
|
| 144 |
+
|
| 145 |
+
Transform your data into beautiful visualizations using natural language.
|
| 146 |
+
Simply upload your data and describe what you want to see!
|
| 147 |
+
""",
|
| 148 |
+
elem_classes=["main-title"]
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
with gr.Row():
|
| 152 |
+
with gr.Column(scale=1):
|
| 153 |
+
query_input = gr.Textbox(
|
| 154 |
+
label="Visualization Query",
|
| 155 |
+
placeholder="e.g., 'Create a line chart showing sales trends over time'",
|
| 156 |
+
lines=3,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
file_input = gr.File(
|
| 160 |
+
label="Upload Data File",
|
| 161 |
+
file_types=[".csv", ".json", ".xlsx", ".xls", ".parquet"],
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
submit_btn = gr.Button(
|
| 165 |
+
"🚀 Generate Visualization",
|
| 166 |
+
variant="primary",
|
| 167 |
+
size="lg",
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
gr.Markdown(
|
| 171 |
+
"""
|
| 172 |
+
### Supported Formats
|
| 173 |
+
- CSV, JSON, Excel (.xlsx, .xls), Parquet
|
| 174 |
+
|
| 175 |
+
### Example Queries
|
| 176 |
+
- "Show me a bar chart of sales by category"
|
| 177 |
+
- "Create a scatter plot of price vs quantity"
|
| 178 |
+
- "Plot the distribution of ages as a histogram"
|
| 179 |
+
"""
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
with gr.Column(scale=2):
|
| 183 |
+
output_image = gr.Image(
|
| 184 |
+
label="Generated Visualization",
|
| 185 |
+
type="filepath",
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
with gr.Row():
|
| 189 |
+
status_output = gr.Textbox(
|
| 190 |
+
label="Status",
|
| 191 |
+
interactive=False,
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
details_output = gr.Markdown(
|
| 195 |
+
label="Details",
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
gr.Examples(
|
| 199 |
+
examples=[
|
| 200 |
+
["Create a bar chart showing the top 10 values", None],
|
| 201 |
+
["Plot a line chart of trends over time", None],
|
| 202 |
+
["Show a scatter plot with correlation", None],
|
| 203 |
+
["Create a pie chart of category distribution", None],
|
| 204 |
+
],
|
| 205 |
+
inputs=[query_input, file_input],
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
submit_btn.click(
|
| 209 |
+
fn=process_visualization,
|
| 210 |
+
inputs=[query_input, file_input],
|
| 211 |
+
outputs=[output_image, status_output, details_output],
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
return interface
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
app = create_coda_interface()
|
| 218 |
+
|
| 219 |
+
if __name__ == "__main__":
|
| 220 |
+
app.launch()
|
coda/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CoDA - Collaborative Data Visualization Agents
|
| 3 |
+
|
| 4 |
+
A multi-agent system for automated data visualization from natural language queries.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from coda.config import Config, get_config
|
| 8 |
+
from coda.orchestrator import CodaOrchestrator, PipelineResult
|
| 9 |
+
from coda.core import (
|
| 10 |
+
LLMProvider,
|
| 11 |
+
GroqLLM,
|
| 12 |
+
SharedMemory,
|
| 13 |
+
BaseAgent,
|
| 14 |
+
AgentContext,
|
| 15 |
+
AgentFactory,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
__version__ = "1.0.0"
|
| 19 |
+
|
| 20 |
+
__all__ = [
|
| 21 |
+
"Config",
|
| 22 |
+
"get_config",
|
| 23 |
+
"CodaOrchestrator",
|
| 24 |
+
"PipelineResult",
|
| 25 |
+
"LLMProvider",
|
| 26 |
+
"GroqLLM",
|
| 27 |
+
"SharedMemory",
|
| 28 |
+
"BaseAgent",
|
| 29 |
+
"AgentContext",
|
| 30 |
+
"AgentFactory",
|
| 31 |
+
]
|
coda/agents/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Agent implementations for CoDA visualization pipeline."""
|
| 2 |
+
|
| 3 |
+
from coda.agents.query_analyzer import QueryAnalyzerAgent, QueryAnalysis
|
| 4 |
+
from coda.agents.data_processor import DataProcessorAgent, DataAnalysis
|
| 5 |
+
from coda.agents.viz_mapping import VizMappingAgent, VisualMapping
|
| 6 |
+
from coda.agents.search_agent import SearchAgent, SearchResult
|
| 7 |
+
from coda.agents.design_explorer import DesignExplorerAgent, DesignSpec
|
| 8 |
+
from coda.agents.code_generator import CodeGeneratorAgent, GeneratedCode
|
| 9 |
+
from coda.agents.debug_agent import DebugAgent, ExecutionResult
|
| 10 |
+
from coda.agents.visual_evaluator import VisualEvaluatorAgent, VisualEvaluation
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"QueryAnalyzerAgent",
|
| 14 |
+
"QueryAnalysis",
|
| 15 |
+
"DataProcessorAgent",
|
| 16 |
+
"DataAnalysis",
|
| 17 |
+
"VizMappingAgent",
|
| 18 |
+
"VisualMapping",
|
| 19 |
+
"SearchAgent",
|
| 20 |
+
"SearchResult",
|
| 21 |
+
"DesignExplorerAgent",
|
| 22 |
+
"DesignSpec",
|
| 23 |
+
"CodeGeneratorAgent",
|
| 24 |
+
"GeneratedCode",
|
| 25 |
+
"DebugAgent",
|
| 26 |
+
"ExecutionResult",
|
| 27 |
+
"VisualEvaluatorAgent",
|
| 28 |
+
"VisualEvaluation",
|
| 29 |
+
]
|
coda/agents/code_generator.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Code Generator Agent for CoDA.
|
| 3 |
+
|
| 4 |
+
Synthesizes executable Python visualization code by integrating
|
| 5 |
+
specifications from upstream agents.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
from pydantic import BaseModel, Field
|
| 11 |
+
|
| 12 |
+
from coda.core.base_agent import AgentContext, BaseAgent
|
| 13 |
+
from coda.core.llm import LLMProvider
|
| 14 |
+
from coda.core.memory import SharedMemory
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class GeneratedCode(BaseModel):
|
| 18 |
+
"""Structured output from the Code Generator."""
|
| 19 |
+
|
| 20 |
+
code: str = Field(default="", description="The generated Python code")
|
| 21 |
+
dependencies: list[str] = Field(
|
| 22 |
+
default_factory=lambda: ["matplotlib", "pandas"],
|
| 23 |
+
description="Required Python packages"
|
| 24 |
+
)
|
| 25 |
+
output_filename: str = Field(
|
| 26 |
+
default="output.png",
|
| 27 |
+
description="Name of the output visualization file"
|
| 28 |
+
)
|
| 29 |
+
documentation: str = Field(
|
| 30 |
+
default="Generated visualization code",
|
| 31 |
+
description="Brief documentation of the code"
|
| 32 |
+
)
|
| 33 |
+
quality_score: float = Field(
|
| 34 |
+
default=5.0,
|
| 35 |
+
description="Self-assessed code quality (0-10)"
|
| 36 |
+
)
|
| 37 |
+
potential_issues: list[str] = Field(
|
| 38 |
+
default_factory=list,
|
| 39 |
+
description="Potential issues or edge cases"
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class CodeGeneratorAgent(BaseAgent[GeneratedCode]):
|
| 44 |
+
"""
|
| 45 |
+
Generates executable Python visualization code.
|
| 46 |
+
|
| 47 |
+
Integrates all upstream specifications (data processing, visual mapping,
|
| 48 |
+
design specs) into working code.
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
MEMORY_KEY = "generated_code"
|
| 52 |
+
|
| 53 |
+
def __init__(
|
| 54 |
+
self,
|
| 55 |
+
llm: LLMProvider,
|
| 56 |
+
memory: SharedMemory,
|
| 57 |
+
name: Optional[str] = None,
|
| 58 |
+
) -> None:
|
| 59 |
+
super().__init__(llm, memory, name or "CodeGenerator")
|
| 60 |
+
|
| 61 |
+
def _get_system_prompt(self) -> str:
|
| 62 |
+
return """You are an expert Python Developer specializing in data visualization.
|
| 63 |
+
|
| 64 |
+
Your expertise is in writing clean, efficient, and well-documented Python code for data visualization using matplotlib, seaborn, and pandas.
|
| 65 |
+
|
| 66 |
+
Your responsibilities:
|
| 67 |
+
1. Generate complete, executable Python code
|
| 68 |
+
2. Integrate all specifications from the design and mapping agents
|
| 69 |
+
3. Handle data loading and transformation correctly
|
| 70 |
+
4. Apply proper styling and formatting
|
| 71 |
+
5. Include error handling for robustness
|
| 72 |
+
6. Write clear documentation
|
| 73 |
+
|
| 74 |
+
Code requirements:
|
| 75 |
+
- Use matplotlib and seaborn as primary libraries
|
| 76 |
+
- Include all necessary imports at the top
|
| 77 |
+
- Load data from the specified file paths
|
| 78 |
+
- Apply all transformations before plotting
|
| 79 |
+
- Set figure size, colors, and labels as specified
|
| 80 |
+
- Save the output to a file (PNG format)
|
| 81 |
+
- Use descriptive variable names
|
| 82 |
+
- Add comments for complex operations
|
| 83 |
+
|
| 84 |
+
IMPORTANT styling rules:
|
| 85 |
+
- For seaborn barplots, ALWAYS use hue parameter: sns.barplot(..., hue='category_column', legend=False)
|
| 86 |
+
- Use ONLY these reliable palettes: 'viridis', 'plasma', 'inferno', 'magma', 'cividis', 'Deep', 'Muted', 'Pastel'
|
| 87 |
+
- DO NOT use complex or custom named palettes like 'tableau10' or 'husl' unless you are sure.
|
| 88 |
+
- When in doubt, omit the palette argument or use 'viridis'.
|
| 89 |
+
- Always use plt.tight_layout() before saving.
|
| 90 |
+
|
| 91 |
+
Always respond with a valid JSON object containing the code and metadata."""
|
| 92 |
+
|
| 93 |
+
def _build_prompt(self, context: AgentContext) -> str:
|
| 94 |
+
data_analysis = self._get_from_memory("data_analysis") or {}
|
| 95 |
+
visual_mapping = self._get_from_memory("visual_mapping") or {}
|
| 96 |
+
design_spec = self._get_from_memory("design_spec") or {}
|
| 97 |
+
search_results = self._get_from_memory("search_results") or {}
|
| 98 |
+
|
| 99 |
+
file_info = data_analysis.get("files", [])
|
| 100 |
+
data_paths = [f.get("file_path", "") for f in file_info] if file_info else context.data_paths
|
| 101 |
+
|
| 102 |
+
code_examples = search_results.get("examples", [])
|
| 103 |
+
examples_section = ""
|
| 104 |
+
if code_examples:
|
| 105 |
+
examples_section = "\nReference Code Examples:\n"
|
| 106 |
+
for ex in code_examples[:2]:
|
| 107 |
+
if isinstance(ex, dict):
|
| 108 |
+
examples_section += f"```python\n# {ex.get('title', 'Example')}\n{ex.get('code', '')}\n```\n"
|
| 109 |
+
|
| 110 |
+
feedback_section = ""
|
| 111 |
+
if context.feedback:
|
| 112 |
+
feedback_section = f"""
|
| 113 |
+
Code Feedback (iteration {context.iteration}):
|
| 114 |
+
{context.feedback}
|
| 115 |
+
|
| 116 |
+
Fix the issues mentioned in the feedback.
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
return f"""Generate Python visualization code based on the following specifications.
|
| 120 |
+
|
| 121 |
+
User Query: {context.query}
|
| 122 |
+
|
| 123 |
+
Data Files: {data_paths}
|
| 124 |
+
|
| 125 |
+
Visual Mapping:
|
| 126 |
+
- Chart Type: {visual_mapping.get('chart_type', 'line')}
|
| 127 |
+
- X-Axis: {visual_mapping.get('x_axis') or 'Not specified (infer from data or chart type)'}
|
| 128 |
+
- Y-Axis: {visual_mapping.get('y_axis') or 'Not specified (infer from data or chart type)'}
|
| 129 |
+
- Color Encoding: {visual_mapping.get('color_encoding')}
|
| 130 |
+
- Transformations: {visual_mapping.get('transformations', [])}
|
| 131 |
+
|
| 132 |
+
Design Specification:
|
| 133 |
+
- Colors: {design_spec.get('color_scheme', {})}
|
| 134 |
+
- Layout: {design_spec.get('layout', {})}
|
| 135 |
+
- Typography: {design_spec.get('typography', {})}
|
| 136 |
+
- Annotations: {design_spec.get('annotations', [])}
|
| 137 |
+
- Guidelines: {design_spec.get('implementation_guidelines', [])}
|
| 138 |
+
{examples_section}{feedback_section}
|
| 139 |
+
|
| 140 |
+
Generate a complete Python script that:
|
| 141 |
+
1. Imports all necessary libraries
|
| 142 |
+
2. Loads the data file(s)
|
| 143 |
+
3. Applies required transformations
|
| 144 |
+
4. Creates the visualization with specified styling
|
| 145 |
+
5. Saves to 'output.png'
|
| 146 |
+
|
| 147 |
+
Respond with a JSON object:
|
| 148 |
+
- code: Complete Python code as a string
|
| 149 |
+
- dependencies: List of required packages
|
| 150 |
+
- output_filename: Output file name
|
| 151 |
+
- documentation: Brief description
|
| 152 |
+
- quality_score: Self-assessment 0-10
|
| 153 |
+
- potential_issues: List of potential issues
|
| 154 |
+
|
| 155 |
+
JSON Response:"""
|
| 156 |
+
|
| 157 |
+
def _parse_response(self, response: str) -> GeneratedCode:
|
| 158 |
+
data = self._extract_json(response)
|
| 159 |
+
return GeneratedCode(**data)
|
| 160 |
+
|
| 161 |
+
def _get_output_key(self) -> str:
|
| 162 |
+
return self.MEMORY_KEY
|
coda/agents/data_processor.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data Processor Agent for CoDA.
|
| 3 |
+
|
| 4 |
+
Extracts metadata and insights from data files without loading full datasets,
|
| 5 |
+
enabling the system to work within token limits while providing rich context
|
| 6 |
+
for visualization decisions.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any, Optional
|
| 12 |
+
|
| 13 |
+
import pandas as pd
|
| 14 |
+
from pydantic import BaseModel, Field
|
| 15 |
+
|
| 16 |
+
from coda.core.base_agent import AgentContext, BaseAgent
|
| 17 |
+
from coda.core.llm import LLMProvider
|
| 18 |
+
from coda.core.memory import SharedMemory
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ColumnInfo(BaseModel):
    """Summary statistics and example values for one dataframe column.

    Built by ``DataProcessorAgent._analyze_columns`` and serialized into
    shared memory so downstream agents can reason about the schema without
    seeing the full dataset.
    """

    name: str  # column name, stringified
    dtype: str  # pandas dtype rendered as a string, e.g. "int64", "object"
    non_null_count: int  # number of non-null values (pandas Series.count())
    unique_count: int  # number of distinct non-null values (Series.nunique())
    sample_values: list[Any]  # up to five non-null example values from the head
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class DataFileInfo(BaseModel):
    """Metadata about a single data file.

    Produced by ``DataProcessorAgent._extract_file_metadata`` after loading
    the file with pandas; one instance per successfully processed file.
    """

    file_path: str  # path as given (stringified Path)
    file_type: str  # lowercased file extension, e.g. ".csv"
    shape: tuple[int, int]  # (row count, column count)
    columns: list[ColumnInfo]  # per-column metadata
    memory_usage_mb: float  # deep pandas memory usage of the loaded frame, in MiB
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class DataAnalysis(BaseModel):
    """Structured output from the Data Processor.

    ``files`` is filled locally from pandas metadata extraction; the other
    list fields come from the LLM response and are normalized to plain
    strings in ``DataProcessorAgent._parse_response``.
    """

    files: list[DataFileInfo] = Field(
        description="Metadata for each processed data file"
    )
    insights: list[str] = Field(
        description="Key insights about the data (patterns, outliers, etc.)"
    )
    processing_steps: list[str] = Field(
        description="Recommended data processing steps"
    )
    aggregations_needed: list[str] = Field(
        default_factory=list,
        description="Suggested aggregations for visualization"
    )
    visualization_hints: list[str] = Field(
        default_factory=list,
        description="Hints for visualization based on data characteristics"
    )
    potential_issues: list[str] = Field(
        default_factory=list,
        description="Potential data quality issues"
    )
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class DataProcessorAgent(BaseAgent[DataAnalysis]):
    """
    Processes data files to extract metadata and insights.

    Uses lightweight analysis to avoid token limits while providing
    comprehensive data understanding for downstream agents.

    Flow: ``execute`` loads each file with pandas, stores per-column
    metadata in shared memory under ``"raw_file_info"``, then delegates to
    the base class, whose prompt (built here) sends only the metadata and a
    few sample values — not the raw rows — to the LLM.
    """

    # Shared-memory key under which the final DataAnalysis is stored.
    MEMORY_KEY = "data_analysis"
    # File extensions this agent knows how to load via pandas.
    SUPPORTED_EXTENSIONS = {".csv", ".json", ".xlsx", ".xls", ".parquet"}

    def __init__(
        self,
        llm: LLMProvider,
        memory: SharedMemory,
        name: Optional[str] = None,
    ) -> None:
        """Create the agent; the display name defaults to "DataProcessor"."""
        super().__init__(llm, memory, name or "DataProcessor")

    def execute(self, context: AgentContext) -> DataAnalysis:
        """Override to include data extraction before LLM analysis."""
        logger.info(f"[{self._name}] Processing {len(context.data_paths)} data files")

        # Extract metadata locally first; unreadable or unsupported files
        # return None and are skipped rather than aborting the run.
        file_infos = []
        for path in context.data_paths:
            info = self._extract_file_metadata(path)
            if info:
                file_infos.append(info)

        # Persist raw metadata so _build_prompt/_parse_response (and any
        # other agent) can read it back from shared memory.
        self._memory.store(
            key="raw_file_info",
            value=[f.model_dump() for f in file_infos],
            agent_name=self._name,
        )

        # Base class performs the prompt -> LLM -> parse -> store cycle.
        return super().execute(context)

    def _extract_file_metadata(self, file_path: str) -> Optional[DataFileInfo]:
        """Extract metadata from a data file using pandas.

        Returns None (with a log message) for missing files, unsupported
        extensions, or any load/analysis failure.

        NOTE(review): the entire file is read into memory here; "lightweight"
        refers to what is sent to the LLM, not to file loading.
        """
        path = Path(file_path)

        if not path.exists():
            logger.warning(f"File not found: {path}")
            return None

        if path.suffix.lower() not in self.SUPPORTED_EXTENSIONS:
            logger.warning(f"Unsupported file type: {path.suffix}")
            return None

        try:
            df = self._load_dataframe(path)
            columns = self._analyze_columns(df)

            return DataFileInfo(
                file_path=str(path),
                file_type=path.suffix.lower(),
                shape=(len(df), len(df.columns)),
                columns=columns,
                # deep=True accounts for object (string) columns accurately.
                memory_usage_mb=df.memory_usage(deep=True).sum() / (1024 * 1024),
            )
        except Exception as e:
            # Broad catch is deliberate: one bad file must not kill the run.
            logger.error(f"Failed to process {path}: {e}")
            return None

    def _load_dataframe(self, path: Path) -> pd.DataFrame:
        """Load a dataframe from various file formats.

        Dispatches on the lowercased extension; raises ValueError for
        anything outside SUPPORTED_EXTENSIONS (caught by the caller).
        """
        suffix = path.suffix.lower()

        if suffix == ".csv":
            return pd.read_csv(path)
        elif suffix == ".json":
            return pd.read_json(path)
        elif suffix in {".xlsx", ".xls"}:
            return pd.read_excel(path)
        elif suffix == ".parquet":
            return pd.read_parquet(path)
        else:
            raise ValueError(f"Unsupported format: {suffix}")

    def _analyze_columns(self, df: pd.DataFrame) -> list[ColumnInfo]:
        """Analyze each column in the dataframe.

        Collects dtype, null/unique counts, and up to five non-null sample
        values per column.
        """
        columns = []

        for col in df.columns:
            series = df[col]
            sample_values = series.dropna().head(5).tolist()

            columns.append(ColumnInfo(
                name=str(col),
                dtype=str(series.dtype),
                non_null_count=int(series.count()),
                unique_count=int(series.nunique()),
                sample_values=sample_values,
            ))

        return columns

    def _get_system_prompt(self) -> str:
        """Return the role/system prompt for the data-analysis LLM call."""
        return """You are a Data Analyst specialist in a data visualization team.

Your expertise is in understanding data structures, identifying patterns, and recommending processing steps for effective visualization.

Your responsibilities:
1. Analyze metadata to understand data characteristics
2. Identify insights and patterns relevant to visualization
3. Recommend data processing and aggregation steps
4. Suggest visualization approaches based on data types
5. Flag potential data quality issues

Always respond with a valid JSON object matching the required schema."""

    def _build_prompt(self, context: AgentContext) -> str:
        """Build the user prompt from stored file metadata and any prior
        query analysis; requests flat string lists to ease parsing."""
        file_info = self._get_from_memory("raw_file_info") or []
        query_analysis = self._get_from_memory("query_analysis") or {}

        file_summary = self._format_file_info(file_info)

        # Include the QueryAnalyzer's output when available so the data
        # analysis is steered toward the requested visualization.
        query_context = ""
        if query_analysis:
            query_context = f"""
Query Analysis:
- Visualization Types: {query_analysis.get('visualization_types', [])}
- Key Points: {query_analysis.get('key_points', [])}
- Data Requirements: {query_analysis.get('data_requirements', [])}
"""

        return f"""Analyze the following data files for visualization purposes.

User Query: {context.query}
{query_context}

Data Files:
{file_summary}

Based on this metadata, provide a JSON object with these fields.
IMPORTANT: All list fields must contain SIMPLE STRINGS, not objects.

{{
    "insights": ["string1", "string2", ...], // Simple string descriptions of patterns
    "processing_steps": ["step1", "step2", ...], // Simple string descriptions of steps
    "aggregations_needed": ["agg1", "agg2", ...], // Simple string descriptions
    "visualization_hints": ["hint1", "hint2", ...], // Simple string hints
    "potential_issues": ["issue1", "issue2", ...] // Simple string issues
}}

JSON Response:"""

    def _format_file_info(self, file_info: list[dict]) -> str:
        """Format file information for the prompt.

        Renders each file's shape and a compact per-column line with at most
        three sample values to keep the prompt small.
        """
        if not file_info:
            return "No data files available."

        lines = []
        for f in file_info:
            lines.append(f"\nFile: {f['file_path']}")
            lines.append(f"  Type: {f['file_type']}")
            lines.append(f"  Shape: {f['shape'][0]} rows × {f['shape'][1]} columns")
            lines.append("  Columns:")

            for col in f.get("columns", []):
                samples = ", ".join(str(v) for v in col.get("sample_values", [])[:3])
                lines.append(
                    f"    - {col['name']} ({col['dtype']}): "
                    f"{col['unique_count']} unique, samples: [{samples}]"
                )

        return "\n".join(lines)

    def _normalize_list_field(self, value: Any) -> list[str]:
        """Normalize a field that should be a list of strings.

        Defensive against LLMs returning dicts or lists of objects instead
        of plain strings. None -> [], dict -> "key: value" lines, list items
        are flattened to strings, anything else is stringified and wrapped.
        """
        if value is None:
            return []

        if isinstance(value, dict):
            return [f"{k}: {v}" for k, v in value.items()]

        if isinstance(value, list):
            result = []
            for item in value:
                if isinstance(item, str):
                    result.append(item)
                elif isinstance(item, dict):
                    # Prefer a known descriptive key if the model emitted an
                    # object; fall back to str(item) via for/else otherwise.
                    desc_keys = ["description", "desc", "text", "value", "step", "hint", "issue"]
                    for key in desc_keys:
                        if key in item:
                            result.append(str(item[key]))
                            break
                    else:
                        result.append(str(item))
                else:
                    result.append(str(item))
            return result

        return [str(value)]

    def _parse_response(self, response: str) -> DataAnalysis:
        """Parse the LLM JSON into a DataAnalysis, sanitizing list fields.

        ``files`` is injected from locally extracted metadata, never from
        the LLM response.
        """
        data = self._extract_json(response)

        data["insights"] = self._normalize_list_field(data.get("insights"))
        data["processing_steps"] = self._normalize_list_field(data.get("processing_steps"))
        data["aggregations_needed"] = self._normalize_list_field(data.get("aggregations_needed"))
        data["visualization_hints"] = self._normalize_list_field(data.get("visualization_hints"))
        data["potential_issues"] = self._normalize_list_field(data.get("potential_issues"))

        file_info = self._get_from_memory("raw_file_info") or []
        data["files"] = file_info

        return DataAnalysis(**data)

    def _get_output_key(self) -> str:
        """Shared-memory key for this agent's result."""
        return self.MEMORY_KEY
|
| 280 |
+
|
coda/agents/debug_agent.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Debug Agent for CoDA.
|
| 3 |
+
|
| 4 |
+
Executes generated code, diagnoses errors, and applies fixes
|
| 5 |
+
to produce working visualizations.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
import subprocess
|
| 11 |
+
import sys
|
| 12 |
+
import tempfile
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Optional
|
| 15 |
+
|
| 16 |
+
from pydantic import BaseModel, Field
|
| 17 |
+
|
| 18 |
+
from coda.core.base_agent import AgentContext, BaseAgent
|
| 19 |
+
from coda.core.llm import LLMProvider
|
| 20 |
+
from coda.core.memory import SharedMemory
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ExecutionResult(BaseModel):
    """Structured output from the Debug Agent.

    Captures the outcome of running generated visualization code in a
    subprocess, plus any diagnosis/fix produced when the first run failed.
    """

    success: bool = Field(
        description="Whether execution succeeded"
    )
    output_file: Optional[str] = Field(
        default=None,
        description="Path to the generated visualization"
    )
    stdout: str = Field(
        default="",
        description="Standard output from execution"
    )
    stderr: str = Field(
        default="",
        description="Error output from execution"
    )
    error_diagnosis: Optional[str] = Field(
        default=None,
        description="Diagnosis of any errors"
    )
    corrected_code: Optional[str] = Field(
        default=None,
        description="Fixed code if errors occurred"
    )
    fix_applied: bool = Field(
        default=False,
        description="Whether a fix was applied"
    )
    execution_time_seconds: float = Field(
        default=0.0,
        description="Time taken to execute"
    )
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class DebugAgent(BaseAgent[ExecutionResult]):
    """
    Executes generated code and handles errors.

    Runs the visualization code in a subprocess with a timeout, diagnoses
    errors via the LLM, and attempts one automatic fix. Overrides
    ``execute`` entirely, so the base-class prompt/parse hooks are stubs.
    """

    # Shared-memory key under which the ExecutionResult is stored.
    MEMORY_KEY = "execution_result"

    def __init__(
        self,
        llm: LLMProvider,
        memory: SharedMemory,
        timeout_seconds: int = 60,
        output_directory: str = "outputs",
        name: Optional[str] = None,
    ) -> None:
        """Create the agent and ensure the output directory exists.

        Args:
            llm: Provider used for the error-fixing LLM call.
            memory: Shared memory for reading generated code / storing results.
            timeout_seconds: Per-run subprocess timeout.
            output_directory: Where visualization images are written.
            name: Optional display name (defaults to "DebugAgent").
        """
        super().__init__(llm, memory, name or "DebugAgent")
        self._timeout = timeout_seconds
        self._output_dir = Path(output_directory)
        self._output_dir.mkdir(parents=True, exist_ok=True)

    def execute(self, context: AgentContext) -> ExecutionResult:
        """Execute the generated code and handle errors.

        Reads "generated_code" from shared memory, runs it, and on failure
        makes a single LLM-driven fix attempt. The final result (fixed or
        not) is stored in shared memory before returning.
        """
        logger.info(f"[{self._name}] Starting code execution")

        generated_code = self._get_from_memory("generated_code")
        if not generated_code:
            # Nothing to run: upstream CodeGenerator did not produce output.
            return ExecutionResult(
                success=False,
                stderr="No generated code found in memory",
            )

        code = generated_code.get("code", "")
        output_filename = generated_code.get("output_filename", "output.png")

        code = self._prepare_code(code, output_filename)

        result = self._execute_code(code)

        if not result.success and result.stderr:
            logger.warning(f"[{self._name}] Code execution failed: {result.stderr[:500]}")
            logger.info(f"[{self._name}] Attempting to fix errors")
            fixed_result = self._attempt_fix(code, result.stderr, context)
            if fixed_result.success:
                self._store_result(fixed_result)
                logger.info(f"[{self._name}] Fix successful!")
                return fixed_result
            # Fix failed: keep the original failure but carry over the
            # diagnosis/corrected code for reporting.
            logger.warning(f"[{self._name}] Fix attempt failed")
            result.error_diagnosis = fixed_result.error_diagnosis
            result.corrected_code = fixed_result.corrected_code

        self._store_result(result)
        logger.info(f"[{self._name}] Execution complete: success={result.success}")
        return result

    def _prepare_code(self, code: str, output_filename: str) -> str:
        """Prepare code for execution by setting up paths.

        Rewrites bare occurrences of the output filename (in either quote
        style) to the full path inside the output directory, and appends a
        savefig call if the code never saves its figure.
        """
        output_path = self._output_dir / output_filename

        code = code.replace(
            f"'{output_filename}'",
            f"r'{output_path}'"
        )
        code = code.replace(
            f'"{output_filename}"',
            f"r'{output_path}'"
        )

        if "plt.savefig" not in code and "fig.savefig" not in code:
            code += f"\nplt.savefig(r'{output_path}', dpi=150, bbox_inches='tight')\n"

        return code

    def _execute_code(self, code: str) -> ExecutionResult:
        """Execute Python code in a subprocess.

        Writes ``code`` to a temporary script, runs it with the current
        interpreter under a timeout, and reports the newest PNG written to
        the output directory *during this run* (if any) as the output file.
        """
        import time
        start_time = time.time()

        with tempfile.NamedTemporaryFile(
            mode="w",
            suffix=".py",
            delete=False,
            encoding="utf-8"
        ) as f:
            f.write(code)
            temp_file = f.name

        try:
            result = subprocess.run(
                [sys.executable, temp_file],
                capture_output=True,
                text=True,
                timeout=self._timeout,
                cwd=str(self._output_dir.parent),
            )

            execution_time = time.time() - start_time

            # Only consider PNGs modified during this run (1s slack for
            # coarse filesystem mtime resolution), and pick the newest —
            # glob() order is arbitrary and stale files from earlier runs
            # must not be reported as this run's output.
            fresh_pngs = [
                p for p in self._output_dir.glob("*.png")
                if p.stat().st_mtime >= start_time - 1.0
            ]
            output_file = (
                str(max(fresh_pngs, key=lambda p: p.stat().st_mtime))
                if fresh_pngs else None
            )

            return ExecutionResult(
                success=result.returncode == 0,
                output_file=output_file,
                stdout=result.stdout,
                stderr=result.stderr,
                execution_time_seconds=execution_time,
            )

        except subprocess.TimeoutExpired:
            return ExecutionResult(
                success=False,
                stderr=f"Execution timed out after {self._timeout} seconds",
            )
        except Exception as e:
            return ExecutionResult(
                success=False,
                stderr=str(e),
            )
        finally:
            # Best-effort cleanup of the temporary script.
            try:
                os.unlink(temp_file)
            except OSError:
                pass

    def _attempt_fix(
        self,
        original_code: str,
        error_message: str,
        context: AgentContext,
    ) -> ExecutionResult:
        """Attempt to fix code errors using the LLM.

        Asks the LLM for a diagnosis plus corrected code, re-executes the
        corrected code, and returns that result. If no corrected code is
        produced (or parsing fails), returns a failure that preserves any
        diagnosis the LLM did provide.
        """
        fix_prompt = f"""The following Python visualization code produced an error. Please fix it.

Original Code:
```python
{original_code}
```

Error Message:
{error_message}

Provide a JSON response with:
- diagnosis: What caused the error
- corrected_code: The fixed Python code

IMPORTANT: Return ONLY valid JSON. Do not include markdown formatting or explanations outside the JSON.
Safe to assume standard libraries (matplotlib, seaborn, pandas, numpy) are available.

JSON Response:"""

        response = self._llm.complete(
            prompt=fix_prompt,
            system_prompt="You are an expert Python debugger. Fix the code error and provide corrected code.",
        )

        diagnosis: Optional[str] = None
        try:
            data = self._extract_json(response.content)
            diagnosis = data.get("diagnosis", "Unknown error")
            corrected_code = data.get("corrected_code", "")

            if corrected_code:
                output_filename = "output.png"
                corrected_code = self._prepare_code(corrected_code, output_filename)
                result = self._execute_code(corrected_code)
                result.error_diagnosis = diagnosis
                result.corrected_code = corrected_code
                result.fix_applied = result.success
                return result

        except Exception as e:
            logger.error(f"Failed to parse fix response: {e}")

        # Reached when the LLM gave no corrected code or its response could
        # not be parsed; keep the parsed diagnosis when we have one instead
        # of discarding it.
        return ExecutionResult(
            success=False,
            stderr=error_message,
            error_diagnosis=diagnosis or "Failed to automatically fix the error",
        )

    def _build_prompt(self, context: AgentContext) -> str:
        """Unused: execute() is fully overridden, so no prompt is built."""
        return ""

    def _get_system_prompt(self) -> str:
        """Unused: execute() is fully overridden."""
        return ""

    def _parse_response(self, response: str) -> ExecutionResult:
        """Unused: execute() is fully overridden; returns a failure stub."""
        return ExecutionResult(success=False)

    def _get_output_key(self) -> str:
        """Shared-memory key for this agent's result."""
        return self.MEMORY_KEY
|
coda/agents/design_explorer.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Design Explorer Agent for CoDA.
|
| 3 |
+
|
| 4 |
+
Generates aesthetic and content specifications for visualizations,
|
| 5 |
+
optimizing for user experience and effective communication.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
from typing import Any, Optional
|
| 10 |
+
|
| 11 |
+
from pydantic import BaseModel, Field
|
| 12 |
+
|
| 13 |
+
from coda.core.base_agent import AgentContext, BaseAgent
|
| 14 |
+
from coda.core.llm import LLMProvider
|
| 15 |
+
from coda.core.memory import SharedMemory
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ColorScheme(BaseModel):
    """Color scheme specification.

    All values are hex color strings; defaults form a usable light theme so
    an empty dict from the LLM still yields a valid scheme.
    """

    primary: str = "#4A90D9"  # main series color
    secondary: list[str] = Field(default_factory=lambda: ["#67B7DC", "#A5D6A7"])  # additional series colors
    background: str = "#FFFFFF"  # figure background
    text: str = "#333333"  # title/label/tick text color
    accent: Optional[str] = "#FF6B6B"  # optional highlight color
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class LayoutSpec(BaseModel):
    """Layout specification.

    figure_size is in inches (matplotlib convention); values that look like
    pixels are clamped in DesignExplorerAgent._parse_response.
    """

    figure_size: tuple[int, int] = (10, 6)  # (width, height) in inches
    margins: dict[str, float] = Field(default_factory=lambda: {"top": 0.1, "bottom": 0.15, "left": 0.1, "right": 0.1})  # fractional figure margins
    title_position: str = "top"  # where the title is placed
    legend_position: str = "right"  # "right", "bottom", or "none"
    grid_visible: bool = True  # whether to draw gridlines
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class DesignSpec(BaseModel):
    """Structured output from the Design Explorer.

    Every field has a safe default, so the spec degrades gracefully when the
    LLM response is missing or malformed — _parse_response blanks invalid
    nested sections to {} so these defaults apply.
    """

    color_scheme: ColorScheme = Field(default_factory=ColorScheme)
    layout: LayoutSpec = Field(default_factory=LayoutSpec)
    typography: dict = Field(
        default_factory=lambda: {"title_size": 14, "label_size": 12, "tick_size": 10, "font_family": "sans-serif"}
    )
    annotations: list[dict] = Field(
        default_factory=list,
        description="Text annotations to add"
    )
    implementation_guidelines: list[str] = Field(
        default_factory=list,
        description="Specific implementation instructions"
    )
    quality_metrics: dict = Field(
        default_factory=lambda: {"readability": "high", "aesthetics": "clean", "clarity": "clear"}
    )
    alternatives: list[dict] = Field(
        default_factory=list,
        description="Alternative design approaches"
    )
    success_indicators: list[str] = Field(
        default_factory=list,
        description="How to know if the design is successful"
    )
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class DesignExplorerAgent(BaseAgent[DesignSpec]):
    """
    Generates aesthetic design specifications for visualizations.

    Focuses on creating visually appealing and effective designs
    that communicate data insights clearly. Reads prior agents' outputs
    from shared memory and supports feedback-driven revision iterations
    via ``context.feedback``.
    """

    # Shared-memory key under which the DesignSpec is stored.
    MEMORY_KEY = "design_spec"

    def __init__(
        self,
        llm: LLMProvider,
        memory: SharedMemory,
        name: Optional[str] = None,
    ) -> None:
        """Create the agent; the display name defaults to "DesignExplorer"."""
        super().__init__(llm, memory, name or "DesignExplorer")

    def _get_system_prompt(self) -> str:
        """Return the role/system prompt for the design LLM call."""
        return """You are a Visualization Design specialist in a data visualization team.

Your expertise is in creating aesthetically pleasing and effective data visualizations that communicate insights clearly.

Your responsibilities:
1. Design harmonious color schemes suitable for the data
2. Specify optimal layouts for readability
3. Choose appropriate typography
4. Plan meaningful annotations
5. Define quality metrics for evaluation
6. Consider accessibility and best practices

Design principles to follow:
- Use color purposefully, not decoratively
- Ensure sufficient contrast for readability
- Maintain consistent visual hierarchy
- Minimize chart junk and maximize data-ink ratio
- Consider colorblind-friendly palettes when appropriate

Always respond with a valid JSON object matching the required schema."""

    def _build_prompt(self, context: AgentContext) -> str:
        """Build the design prompt from the visual mapping and any feedback.

        NOTE(review): ``data_analysis`` is fetched here but never used in
        the prompt — either wire it in or drop the lookup.
        """
        query_analysis = self._get_from_memory("query_analysis") or {}
        data_analysis = self._get_from_memory("data_analysis") or {}
        visual_mapping = self._get_from_memory("visual_mapping") or {}

        # On revision iterations, surface the evaluator's feedback so the
        # LLM can address it explicitly.
        feedback_section = ""
        if context.feedback:
            feedback_section = f"""
Design Feedback (iteration {context.iteration}):
{context.feedback}

Please address the feedback in your revised design.
"""

        return f"""Create a design specification for the following visualization.

User Query: {context.query}

Visualization Type: {visual_mapping.get('chart_type', 'Unknown')}
Visualization Goals: {visual_mapping.get('visualization_goals', [])}
Styling Hints: {visual_mapping.get('styling_hints', {})}
{feedback_section}

Provide a JSON object containing:
- color_scheme: {{
    "primary": "#hex",
    "secondary": ["#hex", ...],
    "background": "#hex",
    "text": "#hex",
    "accent": "#hex" (optional)
}}
- layout: {{
    "figure_size": [width, height],
    "margins": {{"top": 0.1, "bottom": 0.1, "left": 0.1, "right": 0.1}},
    "title_position": "top|center",
    "legend_position": "right|bottom|none",
    "grid_visible": true|false
}}
- typography: {{
    "title_size": 16,
    "label_size": 12,
    "tick_size": 10,
    "font_family": "sans-serif"
}}
- annotations: List of {{"text": "...", "position": "...", "style": "..."}}
- implementation_guidelines: Specific instructions for implementation
- quality_metrics: {{"readability": "...", "aesthetics": "...", "clarity": "..."}}
- alternatives: Alternative design approaches
- success_indicators: How to evaluate design success

JSON Response:"""

    def _normalize_to_list(self, value: Any, as_dicts: bool = False) -> list:
        """Normalize a value to a list.

        Strings become single-element lists (or single-entry dicts when
        ``as_dicts``); list items are coerced to str or wrapped in
        {"description": ...} dicts; anything else yields [].
        """
        if value is None:
            return []
        if isinstance(value, str):
            if as_dicts:
                return [{"description": value}]
            return [value] if value.strip() else []
        if isinstance(value, list):
            if as_dicts:
                return [item if isinstance(item, dict) else {"description": str(item)} for item in value]
            return [str(item) if not isinstance(item, str) else item for item in value]
        return []

    def _parse_response(self, response: str) -> DesignSpec:
        """Parse the LLM JSON into a DesignSpec, sanitizing shapes.

        Invalid nested sections are blanked to {} so the model defaults
        apply; implausibly large figure sizes are clamped.
        """
        data = self._extract_json(response)

        # Normalize list fields
        data["implementation_guidelines"] = self._normalize_to_list(data.get("implementation_guidelines"))
        data["success_indicators"] = self._normalize_to_list(data.get("success_indicators"))
        data["alternatives"] = self._normalize_to_list(data.get("alternatives"), as_dicts=True)
        data["annotations"] = self._normalize_to_list(data.get("annotations"), as_dicts=True)

        # Ensure nested models have valid data
        if "color_scheme" not in data or not isinstance(data["color_scheme"], dict):
            data["color_scheme"] = {}
        if "layout" not in data or not isinstance(data["layout"], dict):
            data["layout"] = {}
        else:
            # Sanitize figure_size to prevent crashes (e.g. LLM giving pixels instead of inches)
            figsize = data["layout"].get("figure_size")
            if isinstance(figsize, (list, tuple)) and len(figsize) == 2:
                w, h = figsize
                # If width > 50, assume pixels and scale down, or just clamp
                if w > 50 or h > 50:
                    logger.warning(f"Extremely large figure size detected: {figsize}. Clamping to (12, 8).")
                    data["layout"]["figure_size"] = (12, 8)

        if "typography" not in data or not isinstance(data["typography"], dict):
            data["typography"] = {}
        if "quality_metrics" not in data or not isinstance(data["quality_metrics"], dict):
            data["quality_metrics"] = {}

        return DesignSpec(**data)

    def _get_output_key(self) -> str:
        """Shared-memory key for this agent's result."""
        return self.MEMORY_KEY
|
coda/agents/query_analyzer.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Query Analyzer Agent for CoDA.
|
| 3 |
+
|
| 4 |
+
Interprets natural language queries to extract visualization intent,
|
| 5 |
+
decompose requirements into actionable items, and provide guidance
|
| 6 |
+
for downstream agents.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
from pydantic import BaseModel, Field
|
| 12 |
+
|
| 13 |
+
from coda.core.base_agent import AgentContext, BaseAgent
|
| 14 |
+
from coda.core.llm import LLMProvider
|
| 15 |
+
from coda.core.memory import SharedMemory
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class QueryAnalysis(BaseModel):
    """Structured output from the Query Analyzer.

    Built from the LLM's JSON reply; the first four fields are required,
    the last two default to empty lists.
    """

    # Candidate chart types for the request (e.g. "line chart", "bar chart").
    visualization_types: list[str] = Field(
        description="Suggested visualization types (e.g., line chart, bar chart)"
    )
    # Aspects of the data the visualization should emphasize.
    key_points: list[str] = Field(
        description="Key data points or aspects to visualize"
    )
    # Ordered task breakdown consumed by downstream pipeline agents.
    todo_list: list[str] = Field(
        description="Decomposed list of tasks for visualization creation"
    )
    # Columns/features the data must provide for the visualization.
    data_requirements: list[str] = Field(
        description="Required data columns or features"
    )
    # Optional: explicit constraints or preferences stated in the query.
    constraints: list[str] = Field(
        default_factory=list,
        description="Any constraints mentioned in the query"
    )
    # Optional: unclear aspects that may warrant user clarification.
    ambiguities: list[str] = Field(
        default_factory=list,
        description="Ambiguous aspects that may need clarification"
    )
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class QueryAnalyzerAgent(BaseAgent[QueryAnalysis]):
    """
    Analyzes natural language queries to extract visualization intent.

    This agent is the first in the pipeline, responsible for understanding
    what the user wants to visualize and breaking it down into actionable steps.
    """

    # Shared-memory key for the agent's structured output.
    MEMORY_KEY = "query_analysis"

    def __init__(
        self,
        llm: LLMProvider,
        memory: SharedMemory,
        name: Optional[str] = None,
    ) -> None:
        """
        Args:
            llm: Provider used for text completion.
            memory: Shared memory for reading pipeline state and storing results.
            name: Optional display name; defaults to "QueryAnalyzer".
        """
        super().__init__(llm, memory, name or "QueryAnalyzer")

    def _get_system_prompt(self) -> str:
        """Return the role/system prompt for the analyzer persona."""
        return """You are a Query Analyzer specialist in a data visualization team.

Your expertise lies in interpreting natural language requests and extracting clear, actionable requirements for creating visualizations.

Your responsibilities:
1. Identify the type(s) of visualizations that best suit the request
2. Extract key data points and features to be visualized
3. Decompose the request into a clear TODO list for the visualization pipeline
4. Identify required data columns or features
5. Note any constraints or preferences mentioned
6. Flag ambiguities that might affect the visualization

Always respond with a valid JSON object matching the required schema."""

    def _build_prompt(self, context: AgentContext) -> str:
        """Compose the user prompt, folding in dataset metadata and prior feedback when present."""
        metadata = self._get_from_memory("metadata_summary")
        metadata_section = ""
        if metadata:
            metadata_section = f"""
Available Metadata:
{metadata}
"""

        feedback_section = ""
        if context.feedback:
            feedback_section = f"""
Previous Feedback (iteration {context.iteration}):
{context.feedback}
"""

        return f"""Analyze the following visualization query and extract structured requirements.

Query: {context.query}
{metadata_section}{feedback_section}

Respond with a JSON object containing:
- visualization_types: List of suggested chart types
- key_points: Key aspects or data points to highlight
- todo_list: Step-by-step tasks for creating the visualization
- data_requirements: Required data columns or features
- constraints: Any mentioned constraints or preferences
- ambiguities: Unclear aspects that may need clarification

JSON Response:"""

    def _parse_response(self, response: str) -> QueryAnalysis:
        """Parse the LLM's JSON reply into a QueryAnalysis.

        The four required list fields of QueryAnalysis have no defaults, so a
        partially-formed LLM response would raise a pydantic ValidationError.
        Normalize them first (mirroring the defensive parsing done by the
        other agents in this package): a bare string becomes a one-element
        list, anything else malformed becomes an empty list.
        """
        data = self._extract_json(response)

        for field in ("visualization_types", "key_points", "todo_list", "data_requirements"):
            value = data.get(field)
            if isinstance(value, str):
                data[field] = [value]
            elif not isinstance(value, list):
                data[field] = []

        return QueryAnalysis(**data)

    def _get_output_key(self) -> str:
        """Return the shared-memory key under which this agent stores its result."""
        return self.MEMORY_KEY
|
coda/agents/search_agent.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Search Agent for CoDA.
|
| 3 |
+
|
| 4 |
+
Retrieves relevant code examples and patterns from a knowledge base
|
| 5 |
+
to assist code generation with proven implementations.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
from pydantic import BaseModel, Field
|
| 11 |
+
|
| 12 |
+
from coda.core.base_agent import AgentContext, BaseAgent
|
| 13 |
+
from coda.core.llm import LLMProvider
|
| 14 |
+
from coda.core.memory import SharedMemory
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class CodeExample(BaseModel):
    """A retrieved code example.

    Every field carries a safe default so partially-specified examples from
    the LLM still validate after SearchAgent's normalization pass.
    """

    # Short human-readable name of the example.
    title: str = ""
    # One-line summary of what the example demonstrates.
    description: str = ""
    # The example source code itself.
    code: str = ""
    # Visualization library the example targets.
    library: str = "matplotlib"
    # Relevance to the current request, 0.0 (unrelated) to 1.0 (exact match).
    relevance_score: float = 0.5
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SearchResult(BaseModel):
    """Structured output from the Search Agent.

    All fields default to empty lists, so even an uninformative LLM reply
    yields a valid (if empty) result.
    """

    # Queries the agent would issue to locate relevant patterns.
    search_queries: list[str] = Field(
        default_factory=list,
        description="Queries used to find examples"
    )
    # Selected code examples; may be backfilled from CODE_EXAMPLES_DB.
    examples: list[CodeExample] = Field(
        default_factory=list,
        description="Retrieved code examples"
    )
    # Libraries judged best suited to the requested visualization.
    recommended_libraries: list[str] = Field(
        default_factory=list,
        description="Recommended visualization libraries"
    )
    # Free-form implementation guidance for the code generator.
    implementation_notes: list[str] = Field(
        default_factory=list,
        description="Notes on implementing the visualization"
    )
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Built-in code examples for common visualization patterns.
# Keys are lowercase chart types; each value is a list of example dicts whose
# shape matches the CodeExample model. The "code" strings are templates: names
# like x_data, values, categories are placeholders the code generator replaces.
CODE_EXAMPLES_DB: dict[str, list[dict]] = {
    "line": [
        {
            "title": "Basic Line Chart",
            "description": "Simple line chart with matplotlib",
            "code": """import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x_data, y_data, marker='o', linewidth=2)
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_title('Line Chart Title')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('output.png', dpi=150, bbox_inches='tight')""",
            "library": "matplotlib",
            "relevance_score": 0.9,
        },
        {
            "title": "Multi-line Chart with Legend",
            "description": "Multiple lines with different colors and legend",
            "code": """import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 6))
for label, data in grouped_data.items():
    ax.plot(data['x'], data['y'], label=label, marker='o')
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_title('Multi-line Chart')
ax.legend(loc='best')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('output.png', dpi=150, bbox_inches='tight')""",
            "library": "matplotlib",
            "relevance_score": 0.85,
        },
    ],
    "bar": [
        {
            "title": "Basic Bar Chart",
            "description": "Vertical bar chart with matplotlib",
            "code": """import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(categories, values, color='steelblue', edgecolor='black')
ax.set_xlabel('Category')
ax.set_ylabel('Value')
ax.set_title('Bar Chart Title')
ax.bar_label(bars, fmt='%.1f')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.savefig('output.png', dpi=150, bbox_inches='tight')""",
            "library": "matplotlib",
            "relevance_score": 0.9,
        },
        {
            "title": "Grouped Bar Chart",
            "description": "Side-by-side bars for comparison",
            "code": """import matplotlib.pyplot as plt
import numpy as np

x = np.arange(len(categories))
width = 0.35

fig, ax = plt.subplots(figsize=(10, 6))
bars1 = ax.bar(x - width/2, values1, width, label='Group 1')
bars2 = ax.bar(x + width/2, values2, width, label='Group 2')
ax.set_xlabel('Category')
ax.set_ylabel('Value')
ax.set_title('Grouped Bar Chart')
ax.set_xticks(x)
ax.set_xticklabels(categories)
ax.legend()
plt.tight_layout()
plt.savefig('output.png', dpi=150, bbox_inches='tight')""",
            "library": "matplotlib",
            "relevance_score": 0.85,
        },
    ],
    "scatter": [
        {
            "title": "Basic Scatter Plot",
            "description": "Scatter plot with optional color encoding",
            "code": """import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 6))
scatter = ax.scatter(x_data, y_data, c=color_data, s=50, alpha=0.7, cmap='viridis')
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_title('Scatter Plot Title')
plt.colorbar(scatter, label='Color Label')
plt.tight_layout()
plt.savefig('output.png', dpi=150, bbox_inches='tight')""",
            "library": "matplotlib",
            "relevance_score": 0.9,
        },
    ],
    "pie": [
        {
            "title": "Pie Chart",
            "description": "Pie chart with percentages",
            "code": """import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 8))
wedges, texts, autotexts = ax.pie(
    values, labels=labels, autopct='%1.1f%%',
    startangle=90, colors=plt.cm.Pastel1.colors
)
ax.set_title('Pie Chart Title')
plt.tight_layout()
plt.savefig('output.png', dpi=150, bbox_inches='tight')""",
            "library": "matplotlib",
            "relevance_score": 0.9,
        },
    ],
    "heatmap": [
        {
            "title": "Heatmap with Seaborn",
            "description": "Correlation or matrix heatmap",
            "code": """import matplotlib.pyplot as plt
import seaborn as sns

fig, ax = plt.subplots(figsize=(12, 8))
sns.heatmap(data_matrix, annot=True, fmt='.2f', cmap='coolwarm', ax=ax)
ax.set_title('Heatmap Title')
plt.tight_layout()
plt.savefig('output.png', dpi=150, bbox_inches='tight')""",
            "library": "seaborn",
            "relevance_score": 0.9,
        },
    ],
    "histogram": [
        {
            "title": "Histogram",
            "description": "Distribution histogram with optional KDE",
            "code": """import matplotlib.pyplot as plt
import seaborn as sns

fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(data, kde=True, ax=ax, color='steelblue')
ax.set_xlabel('Value')
ax.set_ylabel('Frequency')
ax.set_title('Histogram Title')
plt.tight_layout()
plt.savefig('output.png', dpi=150, bbox_inches='tight')""",
            "library": "seaborn",
            "relevance_score": 0.9,
        },
    ],
}
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class SearchAgent(BaseAgent[SearchResult]):
    """
    Searches for relevant code examples to guide code generation.

    Uses a built-in knowledge base of visualization patterns
    and can be extended to search external resources.
    """

    # Shared-memory key for the agent's structured output.
    MEMORY_KEY = "search_results"

    def __init__(
        self,
        llm: LLMProvider,
        memory: SharedMemory,
        name: Optional[str] = None,
    ) -> None:
        """
        Args:
            llm: Provider used for text completion.
            memory: Shared memory for reading pipeline state and storing results.
            name: Optional display name; defaults to "SearchAgent".
        """
        super().__init__(llm, memory, name or "SearchAgent")

    def _get_system_prompt(self) -> str:
        """Return the role/system prompt for the search persona."""
        return """You are a Code Search specialist in a data visualization team.

Your expertise is in finding and recommending relevant code examples and patterns for visualization implementation.

Your responsibilities:
1. Formulate effective search queries based on requirements
2. Select the most relevant examples from available patterns
3. Recommend appropriate libraries for the task
4. Provide implementation guidance

Consider the specific chart type, data characteristics, and styling requirements when selecting examples.

Always respond with a valid JSON object matching the required schema."""

    def _build_prompt(self, context: AgentContext) -> str:
        """Compose the user prompt from the visual mapping stored in shared memory."""
        # NOTE: only the visual mapping is needed here; the previous revision
        # also fetched "query_analysis" but never used it.
        visual_mapping = self._get_from_memory("visual_mapping") or {}

        chart_type = visual_mapping.get("chart_type", "")
        chart_subtype = visual_mapping.get("chart_subtype", "")
        styling_hints = visual_mapping.get("styling_hints", {})

        available_examples = list(CODE_EXAMPLES_DB.keys())

        return f"""Find relevant code examples for the following visualization requirements.

User Query: {context.query}

Visualization Mapping:
- Chart Type: {chart_type}
- Chart Subtype: {chart_subtype}
- Styling: {styling_hints}
- Goals: {visual_mapping.get('visualization_goals', [])}

Available Example Categories: {available_examples}

Provide a JSON object containing:
- search_queries: List of search queries you would use
- examples: List of relevant examples (select from available categories)
- recommended_libraries: Libraries best suited for this visualization
- implementation_notes: Tips for implementing this specific visualization

For examples, include:
- title, description, code (from your knowledge of matplotlib/seaborn)
- library used
- relevance_score (0.0 to 1.0)

JSON Response:"""

    def _parse_response(self, response: str) -> SearchResult:
        """Parse the LLM JSON reply, normalizing examples and backfilling from built-ins."""
        data = self._extract_json(response)

        # Normalize examples so every dict has the fields CodeExample expects;
        # non-dict entries are dropped.
        if isinstance(data.get("examples"), list):
            normalized_examples = []
            for ex in data["examples"]:
                if isinstance(ex, dict):
                    ex.setdefault("library", "matplotlib")
                    ex.setdefault("title", "")
                    ex.setdefault("description", "")
                    ex.setdefault("code", "")
                    ex.setdefault("relevance_score", 0.5)
                    normalized_examples.append(ex)
            data["examples"] = normalized_examples

        visual_mapping = self._get_from_memory("visual_mapping") or {}
        # "or" (not a .get default) guards against an explicit None stored
        # under "chart_type", which would otherwise crash on .lower().
        chart_type = (visual_mapping.get("chart_type") or "line").lower()

        # Fall back to the built-in knowledge base if the LLM provided nothing.
        if chart_type in CODE_EXAMPLES_DB and not data.get("examples"):
            data["examples"] = CODE_EXAMPLES_DB[chart_type]

        return SearchResult(**data)

    def _get_output_key(self) -> str:
        """Return the shared-memory key under which this agent stores its result."""
        return self.MEMORY_KEY
|
coda/agents/visual_evaluator.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Visual Evaluator Agent for CoDA.
|
| 3 |
+
|
| 4 |
+
Assesses generated visualizations across multiple quality dimensions
|
| 5 |
+
using multimodal LLM capabilities to analyze the output image.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Optional
|
| 11 |
+
|
| 12 |
+
from pydantic import BaseModel, Field
|
| 13 |
+
|
| 14 |
+
from coda.core.base_agent import AgentContext, BaseAgent
|
| 15 |
+
from coda.core.llm import LLMProvider
|
| 16 |
+
from coda.core.memory import SharedMemory
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class QualityScores(BaseModel):
    """Quality scores for different dimensions.

    All scores are constrained to the inclusive range 0-10; values outside it
    fail pydantic validation.
    """

    overall: float = Field(ge=0, le=10, description="Overall quality score")
    readability: float = Field(ge=0, le=10, description="How easy to read and understand")
    accuracy: float = Field(ge=0, le=10, description="How accurately it represents the data")
    aesthetics: float = Field(ge=0, le=10, description="Visual appeal and design quality")
    layout: float = Field(ge=0, le=10, description="Layout and spacing quality")
    correctness: float = Field(ge=0, le=10, description="Technical correctness")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class VisualEvaluation(BaseModel):
    """Structured output from the Visual Evaluator.

    Defaults are chosen so an empty/partial LLM reply still validates:
    neutral (5.0) scores, empty feedback lists, and a failing threshold.
    """

    # Per-dimension quality scores; defaults to all-neutral 5.0.
    scores: QualityScores = Field(default_factory=lambda: QualityScores(
        overall=5.0, readability=5.0, accuracy=5.0, aesthetics=5.0, layout=5.0, correctness=5.0
    ))
    # Positive aspects observed in the image.
    strengths: list[str] = Field(default_factory=list)
    # Problems found in the image.
    issues: list[str] = Field(default_factory=list)
    # Most important fixes to apply next iteration.
    priority_fixes: list[str] = Field(default_factory=list)
    # Map of TODO item -> whether the visualization satisfies it.
    todo_completion: dict[str, bool] = Field(default_factory=dict)
    # General improvement suggestions.
    recommendations: list[str] = Field(default_factory=list)
    # True when the overall score meets the evaluator's minimum threshold.
    passes_threshold: bool = Field(default=False)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class VisualEvaluatorAgent(BaseAgent[VisualEvaluation]):
    """
    Evaluates visualization quality using multimodal analysis.

    Analyzes the output image against the original requirements
    and provides detailed feedback for iterative refinement.
    """

    # Shared-memory key for the agent's structured output.
    MEMORY_KEY = "visual_evaluation"

    def __init__(
        self,
        llm: LLMProvider,
        memory: SharedMemory,
        min_overall_score: float = 7.0,
        name: Optional[str] = None,
    ) -> None:
        """
        Args:
            llm: Provider with multimodal (image + text) completion support.
            memory: Shared memory for reading pipeline state and storing results.
            min_overall_score: Overall score (0-10) required to pass the threshold.
            name: Optional display name; defaults to "VisualEvaluator".
        """
        super().__init__(llm, memory, name or "VisualEvaluator")
        self._min_score = min_overall_score

    def _hard_failure(self, issue: str, fix: str, recommendation: str) -> VisualEvaluation:
        """Build and persist an all-zero evaluation for cases where no image can be assessed."""
        evaluation = VisualEvaluation(
            scores=QualityScores(
                overall=0, readability=0, accuracy=0,
                aesthetics=0, layout=0, correctness=0
            ),
            strengths=[],
            issues=[issue],
            priority_fixes=[fix],
            todo_completion={},
            recommendations=[recommendation],
            passes_threshold=False,
        )
        # Store the failure too, so consumers of MEMORY_KEY never read a stale
        # evaluation from a previous iteration (the success and fallback paths
        # already store their results).
        self._store_result(evaluation)
        return evaluation

    def execute(self, context: AgentContext) -> VisualEvaluation:
        """Execute visual evaluation using the vision model."""
        logger.info(f"[{self._name}] Evaluating visualization quality")

        execution_result = self._get_from_memory("execution_result")
        if not execution_result or not execution_result.get("success"):
            return self._hard_failure(
                issue="Visualization generation failed",
                fix="Fix code execution errors",
                recommendation="Debug and fix code errors first",
            )

        output_file = execution_result.get("output_file")
        if not output_file or not Path(output_file).exists():
            return self._hard_failure(
                issue="Output file not found",
                fix="Ensure code saves output correctly",
                recommendation="Check savefig call in code",
            )

        prompt = self._build_evaluation_prompt(context)
        system_prompt = self._get_system_prompt()

        try:
            response = self._llm.complete_with_image(
                prompt=prompt,
                image_path=output_file,
                system_prompt=system_prompt,
            )

            result = self._parse_response(response.content)
            self._store_result(result)

            logger.info(
                f"[{self._name}] Evaluation complete: "
                f"overall={result.scores.overall}, passes={result.passes_threshold}"
            )
            return result
        except Exception as e:
            logger.error(f"[{self._name}] Evaluation failed: {e}")
            # Return a neutral fallback evaluation instead of crashing the pipeline.
            fallback = VisualEvaluation(
                scores=QualityScores(
                    overall=5.0, readability=5.0, accuracy=5.0,
                    aesthetics=5.0, layout=5.0, correctness=5.0
                ),
                strengths=["Backup evaluation (parsing failed)"],
                issues=[f"Evaluation parsing error: {str(e)}"],
                priority_fixes=[],
                todo_completion={},
                recommendations=[],
                passes_threshold=False
            )
            self._store_result(fallback)
            return fallback

    def _get_system_prompt(self) -> str:
        """Return the role/system prompt for the evaluator persona."""
        return """You are a Visualization Quality Evaluator specialist.

Your expertise is in assessing data visualizations for quality, effectiveness, and adherence to best practices.

Evaluate visualizations on these dimensions:
1. Readability: Clear labels, appropriate font sizes, uncluttered design
2. Accuracy: Correct representation of data, appropriate scales
3. Aesthetics: Visual appeal, harmonious colors, professional appearance
4. Layout: Good use of space, proper alignment, balanced composition
5. Correctness: Technically correct chart type, proper axis handling

Be rigorous but fair in your assessment. Provide specific, actionable feedback.

Always respond with a valid JSON object matching the required schema."""

    def _build_evaluation_prompt(self, context: AgentContext) -> str:
        """Compose the evaluation prompt from the upstream agents' stored outputs."""
        query_analysis = self._get_from_memory("query_analysis") or {}
        visual_mapping = self._get_from_memory("visual_mapping") or {}
        design_spec = self._get_from_memory("design_spec") or {}

        todo_list = query_analysis.get("todo_list", [])

        return f"""Evaluate this visualization against the original requirements.

Original Query: {context.query}

Requirements:
- Visualization Type: {visual_mapping.get('chart_type', 'Unknown')}
- Goals: {visual_mapping.get('visualization_goals', [])}
- TODO Items: {todo_list}

Design Specifications:
- Color Scheme: {design_spec.get('color_scheme', {})}
- Success Indicators: {design_spec.get('success_indicators', [])}

Evaluate the visualization image and provide a JSON response with:
- scores: {{
    "overall": 0-10,
    "readability": 0-10,
    "accuracy": 0-10,
    "aesthetics": 0-10,
    "layout": 0-10,
    "correctness": 0-10
  }}
- strengths: List of positive aspects
- issues: List of problems found
- priority_fixes: Most important fixes (max 3)
- todo_completion: {{"todo_item": true/false}} for each TODO
- recommendations: Improvement suggestions
- passes_threshold: true if overall >= {self._min_score}

JSON Response:"""

    def _build_prompt(self, context: AgentContext) -> str:
        """Delegate to the evaluation prompt builder (satisfies the BaseAgent interface)."""
        return self._build_evaluation_prompt(context)

    def _parse_response(self, response: str) -> VisualEvaluation:
        """Coerce the LLM's JSON into a VisualEvaluation, defaulting malformed fields."""
        data = self._extract_json(response)

        # Ensure scores exists and is properly formatted. A non-dict value
        # (e.g. a bare number returned by the LLM) is discarded so that
        # validation does not crash; missing dimensions default to neutral 5.0.
        scores_data = data.get("scores", {})
        if not isinstance(scores_data, dict):
            scores_data = {}
        scores_data.setdefault("overall", 5.0)
        scores_data.setdefault("readability", 5.0)
        scores_data.setdefault("accuracy", 5.0)
        scores_data.setdefault("aesthetics", 5.0)
        scores_data.setdefault("layout", 5.0)
        scores_data.setdefault("correctness", 5.0)
        data["scores"] = QualityScores(**scores_data)

        # Ensure list fields are lists (a bare string becomes a one-element list).
        for field in ["strengths", "issues", "priority_fixes", "recommendations"]:
            if field not in data or not isinstance(data[field], list):
                data[field] = [data[field]] if isinstance(data.get(field), str) else []

        # Ensure todo_completion is a dict.
        if not isinstance(data.get("todo_completion"), dict):
            data["todo_completion"] = {}

        # Calculate passes_threshold if the LLM did not provide it.
        if "passes_threshold" not in data:
            overall = data.get("scores")
            if isinstance(overall, QualityScores):
                data["passes_threshold"] = overall.overall >= self._min_score
            else:
                data["passes_threshold"] = False

        return VisualEvaluation(**data)

    def _get_output_key(self) -> str:
        """Return the shared-memory key under which this agent stores its result."""
        return self.MEMORY_KEY
|
coda/agents/viz_mapping.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VizMapping Agent for CoDA.
|
| 3 |
+
|
| 4 |
+
Maps query semantics and data characteristics to visualization primitives,
|
| 5 |
+
selecting appropriate chart types and defining data-to-visual bindings.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
from pydantic import BaseModel, Field
|
| 11 |
+
|
| 12 |
+
from coda.core.base_agent import AgentContext, BaseAgent
|
| 13 |
+
from coda.core.llm import LLMProvider
|
| 14 |
+
from coda.core.memory import SharedMemory
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class DataTransformation(BaseModel):
    """A data transformation step.

    NOTE(review): field semantics inferred from names — presumably `operation`
    names the transform (e.g. an aggregation), `columns` lists its inputs, and
    `parameters` holds operation-specific options; confirm against consumers.
    """

    # Name of the transformation to apply.
    operation: str
    # Columns the transformation operates on.
    columns: list[str]
    # Operation-specific keyword parameters.
    parameters: dict
|
| 24 |
+
|
| 25 |
+
class VisualMapping(BaseModel):
    """Structured output from the VizMapping Agent.

    `chart_type`, `visualization_goals`, and `rationale` are required; all
    encoding/axis fields are optional and default to None or empty containers.
    """

    # Required: primary chart family.
    chart_type: str = Field(
        description="Primary chart type (e.g., line, bar, scatter)"
    )
    chart_subtype: Optional[str] = Field(
        default=None,
        description="Chart subtype if applicable (e.g., stacked, grouped)"
    )
    x_axis: Optional[dict] = Field(
        default=None,
        description="X-axis configuration (column, label, type)"
    )
    y_axis: Optional[dict] = Field(
        default=None,
        description="Y-axis configuration (column, label, type)"
    )
    color_encoding: Optional[dict] = Field(
        default=None,
        description="Color encoding configuration"
    )
    size_encoding: Optional[dict] = Field(
        default=None,
        description="Size encoding for scatter plots"
    )
    # Loose dicts rather than DataTransformation models, so the LLM's free-form
    # transformation specs validate without a strict schema.
    transformations: list[dict] = Field(
        default_factory=list,
        description="Data transformations to apply"
    )
    styling_hints: dict = Field(
        default_factory=dict,
        description="Visual styling recommendations"
    )
    # Required: high-level intent guiding downstream design/evaluation agents.
    visualization_goals: list[str] = Field(
        description="High-level goals for the visualization"
    )
    # Required: human-readable justification of the chosen approach.
    rationale: str = Field(
        description="Explanation for the chosen visualization approach"
    )
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class VizMappingAgent(BaseAgent[VisualMapping]):
    """
    Maps query semantics to visualization specifications.

    Bridges the gap between data analysis and code generation by
    defining exactly how data should be visualized.
    """

    # Shared-memory key under which this agent's VisualMapping is stored.
    MEMORY_KEY = "visual_mapping"

    def __init__(
        self,
        llm: LLMProvider,
        memory: SharedMemory,
        name: Optional[str] = None,
    ) -> None:
        """Initialize the agent; the display name defaults to "VizMapping"."""
        super().__init__(llm, memory, name or "VizMapping")

    def _get_system_prompt(self) -> str:
        """Return the persona/system prompt sent with every LLM call."""
        return """You are a Visualization Mapping specialist in a data visualization team.

Your expertise is in translating data analysis requirements into concrete visualization specifications that can be implemented in code.

Your responsibilities:
1. Select the optimal chart type based on data and query requirements
2. Define data-to-visual mappings (axes, colors, sizes)
3. Specify required data transformations
4. Provide styling hints for aesthetics
5. Document the rationale for visualization choices

Consider:
- Data types when choosing encodings (categorical vs numerical)
- Query intent when selecting chart types
- Readability and best practices in visualization design

Always respond with a valid JSON object matching the required schema."""

    def _build_prompt(self, context: AgentContext) -> str:
        """Assemble the user prompt from upstream agents' memory entries.

        Pulls optional "query_analysis" and "data_analysis" results from
        shared memory; missing entries simply produce empty prompt sections,
        so this agent also works when run without the upstream agents.
        """
        query_analysis = self._get_from_memory("query_analysis") or {}
        data_analysis = self._get_from_memory("data_analysis") or {}

        query_section = ""
        if query_analysis:
            query_section = f"""
Query Analysis:
- Suggested Types: {query_analysis.get('visualization_types', [])}
- Key Points: {query_analysis.get('key_points', [])}
- Data Requirements: {query_analysis.get('data_requirements', [])}
"""

        data_section = ""
        if data_analysis:
            files = data_analysis.get('files', [])
            if files:
                columns_info = []
                for f in files:
                    for col in f.get('columns', []):
                        # Assumes each column dict carries "name" and "dtype"
                        # keys (as produced by the data analysis agent) —
                        # a KeyError here means that contract changed.
                        columns_info.append(f" - {col['name']} ({col['dtype']})")
                # chr(10) is "\n": backslashes are not allowed inside
                # f-string expressions on the Python versions targeted here.
                data_section = f"""
Available Data:
- Columns:
{chr(10).join(columns_info)}
- Insights: {data_analysis.get('insights', [])}
- Suggested Aggregations: {data_analysis.get('aggregations_needed', [])}
"""

        feedback_section = ""
        if context.feedback:
            feedback_section = f"""
Refinement Feedback (iteration {context.iteration}):
{context.feedback}
"""

        return f"""Create a visualization mapping for the following query.

User Query: {context.query}
{query_section}{data_section}{feedback_section}

Provide a JSON object containing:
- chart_type: Primary chart type (line, bar, scatter, pie, heatmap, etc.)
- chart_subtype: Optional subtype (stacked, grouped, etc.)
- x_axis: {{"column": "...", "label": "...", "type": "categorical|numerical|temporal"}}
- y_axis: {{"column": "...", "label": "...", "type": "numerical"}}
- color_encoding: Optional color mapping {{"column": "...", "palette": "..."}}
- size_encoding: Optional size mapping for scatter {{"column": "..."}}
- transformations: List of {{"operation": "...", "columns": [...], "parameters": {{}}}}
- styling_hints: {{"theme": "...", "annotations": [...], "legend_position": "..."}}
- visualization_goals: List of high-level goals
- rationale: Brief explanation of choices

JSON Response:"""

    def _parse_response(self, response: str) -> VisualMapping:
        """Parse the LLM's JSON response into a validated VisualMapping."""
        data = self._extract_json(response)
        return VisualMapping(**data)

    def _get_output_key(self) -> str:
        """Return the shared-memory key for this agent's output."""
        return self.MEMORY_KEY
|
coda/config.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration management for CoDA.
|
| 3 |
+
|
| 4 |
+
Centralizes all configuration values including API keys, model settings,
|
| 5 |
+
quality thresholds, and execution parameters.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from typing import Optional
|
| 11 |
+
|
| 12 |
+
from dotenv import load_dotenv
|
| 13 |
+
|
| 14 |
+
load_dotenv()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass(frozen=True)
class ModelConfig:
    """Configuration for LLM models."""

    # Text model used for standard agent reasoning.
    default_model: str = "llama-3.3-70b-versatile"
    # Multimodal model used for image-based calls.
    vision_model: str = "meta-llama/llama-4-maverick-17b-128e-instruct"
    # Sampling temperature applied when a call does not override it.
    temperature: float = 0.7
    # Upper bound on completion tokens per call.
    max_tokens: int = 4096
    # Number of API attempts before giving up.
    max_retries: int = 3
    # Base delay in seconds; the client backs off as retry_delay * 2**attempt.
    retry_delay: float = 1.0
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass(frozen=True)
class QualityThresholds:
    """Quality score thresholds for the feedback loop.

    Scores below these minimums trigger another refinement iteration.
    NOTE(review): presumably a 0-10 scale — confirm against the visual
    evaluator's scoring rubric.
    """

    minimum_overall_score: float = 7.0
    minimum_readability_score: float = 6.0
    minimum_accuracy_score: float = 7.0
    minimum_aesthetics_score: float = 6.0
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dataclass(frozen=True)
class ExecutionConfig:
    """Configuration for code execution."""

    # Hard time limit in seconds for running generated visualization code.
    code_timeout_seconds: int = 60
    # Maximum passes through the quality-feedback refinement loop.
    max_refinement_iterations: int = 3
    # Directory where generated artifacts are written.
    output_directory: str = "outputs"
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@dataclass
class Config:
    """Main configuration container for CoDA.

    Aggregates the Groq API key with model, quality, and execution settings.
    Validates on construction: a missing API key raises immediately instead
    of failing later at the first API call.
    """

    # Read at instantiation time (default_factory) so the environment can be
    # patched in tests before a Config is built.
    groq_api_key: str = field(default_factory=lambda: os.getenv("GROQ_API_KEY", ""))
    model: ModelConfig = field(default_factory=ModelConfig)
    quality: QualityThresholds = field(default_factory=QualityThresholds)
    execution: ExecutionConfig = field(default_factory=ExecutionConfig)

    def __post_init__(self) -> None:
        # Fail fast with an actionable message rather than a confusing
        # authentication error deep inside the first LLM call.
        if not self.groq_api_key:
            raise ValueError(
                "GROQ_API_KEY environment variable is required. "
                "Get your API key at https://console.groq.com"
            )

    @classmethod
    def from_env(cls) -> "Config":
        """Create configuration from environment variables.

        Every tunable has a CODA_* environment override; unset variables
        fall back to the library defaults.

        Returns:
            A fully-populated, validated Config.

        Raises:
            ValueError: If GROQ_API_KEY is not set (via __post_init__).
        """
        return cls(
            groq_api_key=os.getenv("GROQ_API_KEY", ""),
            model=ModelConfig(
                default_model=os.getenv("CODA_DEFAULT_MODEL", "llama-3.3-70b-versatile"),
                vision_model=os.getenv(
                    "CODA_VISION_MODEL",
                    "meta-llama/llama-4-maverick-17b-128e-instruct",
                ),
                temperature=float(os.getenv("CODA_TEMPERATURE", "0.7")),
                max_tokens=int(os.getenv("CODA_MAX_TOKENS", "4096")),
                # FIX: retry settings were the only ModelConfig fields with no
                # environment override; defaults are unchanged, so existing
                # deployments behave identically.
                max_retries=int(os.getenv("CODA_MAX_RETRIES", "3")),
                retry_delay=float(os.getenv("CODA_RETRY_DELAY", "1.0")),
            ),
            quality=QualityThresholds(
                minimum_overall_score=float(os.getenv("CODA_MIN_OVERALL_SCORE", "7.0")),
                minimum_readability_score=float(os.getenv("CODA_MIN_READABILITY_SCORE", "6.0")),
                minimum_accuracy_score=float(os.getenv("CODA_MIN_ACCURACY_SCORE", "7.0")),
                minimum_aesthetics_score=float(os.getenv("CODA_MIN_AESTHETICS_SCORE", "6.0")),
            ),
            execution=ExecutionConfig(
                code_timeout_seconds=int(os.getenv("CODA_CODE_TIMEOUT", "60")),
                max_refinement_iterations=int(os.getenv("CODA_MAX_ITERATIONS", "3")),
                output_directory=os.getenv("CODA_OUTPUT_DIR", "outputs"),
            ),
        )
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def get_config() -> Config:
    """Get the application configuration.

    Thin convenience wrapper around :meth:`Config.from_env`. A fresh Config
    is built on every call — there is no caching, so environment changes
    take effect on the next call.
    """
    return Config.from_env()
|
coda/core/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Core module for CoDA - contains LLM, memory, and base agent abstractions."""
|
| 2 |
+
|
| 3 |
+
from coda.core.llm import LLMProvider, GroqLLM, LLMResponse
|
| 4 |
+
from coda.core.memory import SharedMemory, MemoryEntry
|
| 5 |
+
from coda.core.base_agent import BaseAgent, AgentContext
|
| 6 |
+
from coda.core.agent_factory import AgentFactory
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
"LLMProvider",
|
| 10 |
+
"GroqLLM",
|
| 11 |
+
"LLMResponse",
|
| 12 |
+
"SharedMemory",
|
| 13 |
+
"MemoryEntry",
|
| 14 |
+
"BaseAgent",
|
| 15 |
+
"AgentContext",
|
| 16 |
+
"AgentFactory",
|
| 17 |
+
]
|
coda/core/agent_factory.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Agent Factory for CoDA.
|
| 3 |
+
|
| 4 |
+
Provides factory methods for creating and configuring agents,
|
| 5 |
+
enabling flexible agent composition and testing.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
from coda.core.llm import LLMProvider
|
| 11 |
+
from coda.core.memory import SharedMemory
|
| 12 |
+
from coda.core.base_agent import BaseAgent
|
| 13 |
+
|
| 14 |
+
from coda.agents.query_analyzer import QueryAnalyzerAgent
|
| 15 |
+
from coda.agents.data_processor import DataProcessorAgent
|
| 16 |
+
from coda.agents.viz_mapping import VizMappingAgent
|
| 17 |
+
from coda.agents.search_agent import SearchAgent
|
| 18 |
+
from coda.agents.design_explorer import DesignExplorerAgent
|
| 19 |
+
from coda.agents.code_generator import CodeGeneratorAgent
|
| 20 |
+
from coda.agents.debug_agent import DebugAgent
|
| 21 |
+
from coda.agents.visual_evaluator import VisualEvaluatorAgent
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class AgentFactory:
    """
    Factory for creating CoDA agents with shared dependencies.

    Every agent produced by one factory instance shares the same LLM
    provider and SharedMemory, which is how agents exchange results.
    """

    # Registry of constructable agent types, keyed by their pipeline name.
    AGENT_TYPES = {
        "query_analyzer": QueryAnalyzerAgent,
        "data_processor": DataProcessorAgent,
        "viz_mapping": VizMappingAgent,
        "search_agent": SearchAgent,
        "design_explorer": DesignExplorerAgent,
        "code_generator": CodeGeneratorAgent,
        "debug_agent": DebugAgent,
        "visual_evaluator": VisualEvaluatorAgent,
    }

    def __init__(
        self,
        llm: LLMProvider,
        memory: Optional[SharedMemory] = None,
    ) -> None:
        self._llm = llm
        # Fall back to a fresh memory buffer when none is supplied.
        self._memory = memory or SharedMemory()

    @property
    def memory(self) -> SharedMemory:
        """The SharedMemory instance handed to every created agent."""
        return self._memory

    def create(
        self,
        agent_type: str,
        **kwargs,
    ) -> BaseAgent:
        """
        Create an agent by type name.

        Args:
            agent_type: Name of the agent type (a key of AGENT_TYPES)
            **kwargs: Extra arguments forwarded to the agent constructor

        Returns:
            Configured agent instance wired to this factory's LLM and memory

        Raises:
            ValueError: If agent_type is not recognized
        """
        agent_cls = self.AGENT_TYPES.get(agent_type)
        if agent_cls is None:
            raise ValueError(
                f"Unknown agent type: {agent_type}. "
                f"Available types: {list(self.AGENT_TYPES.keys())}"
            )
        return agent_cls(llm=self._llm, memory=self._memory, **kwargs)

    def create_all(self, **agent_kwargs) -> dict[str, BaseAgent]:
        """
        Create all available agent types.

        Args:
            **agent_kwargs: Per-type constructor arguments, keyed by type name

        Returns:
            Dictionary mapping agent type names to instances
        """
        agents: dict[str, BaseAgent] = {}
        for type_name in self.AGENT_TYPES:
            extra = agent_kwargs.get(type_name, {})
            agents[type_name] = self.create(type_name, **extra)
        return agents

    def create_pipeline_agents(
        self,
        code_timeout: int = 60,
        output_directory: str = "outputs",
        min_quality_score: float = 7.0,
    ) -> dict[str, BaseAgent]:
        """
        Create agents configured for the standard visualization pipeline.

        Args:
            code_timeout: Timeout for code execution in seconds
            output_directory: Directory for output files
            min_quality_score: Minimum quality score threshold

        Returns:
            Dictionary of configured agents for the pipeline
        """
        # The analysis/generation agents take no extra configuration.
        agents: dict[str, BaseAgent] = {
            name: self.create(name)
            for name in (
                "query_analyzer",
                "data_processor",
                "viz_mapping",
                "search_agent",
                "design_explorer",
                "code_generator",
            )
        }
        # These two carry execution- and quality-specific settings.
        agents["debug_agent"] = self.create(
            "debug_agent",
            timeout_seconds=code_timeout,
            output_directory=output_directory,
        )
        agents["visual_evaluator"] = self.create(
            "visual_evaluator",
            min_overall_score=min_quality_score,
        )
        return agents
|
coda/core/base_agent.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Base agent interface for CoDA.
|
| 3 |
+
|
| 4 |
+
Defines the contract that all specialized agents must implement,
|
| 5 |
+
providing common functionality for LLM interaction and memory access.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import logging
|
| 10 |
+
import re
|
| 11 |
+
from abc import ABC, abstractmethod
|
| 12 |
+
from typing import Any, Optional, TypeVar, Generic
|
| 13 |
+
|
| 14 |
+
from pydantic import BaseModel
|
| 15 |
+
|
| 16 |
+
from coda.core.llm import LLMProvider
|
| 17 |
+
from coda.core.memory import SharedMemory
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
T = TypeVar("T", bound=BaseModel)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class AgentContext(BaseModel):
    """Context passed to an agent during execution."""

    # The user's natural-language visualization request.
    query: str
    # Paths to the data files available to the pipeline. (Pydantic copies
    # field defaults per instance, so the shared-mutable-default pitfall of
    # plain classes does not apply here.)
    data_paths: list[str] = []
    # Current refinement iteration; 0 on the first pass.
    iteration: int = 0
    # Evaluator feedback carried into a refinement iteration, if any.
    feedback: Optional[str] = None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class BaseAgent(ABC, Generic[T]):
    """
    Abstract base class for all CoDA agents.

    Each agent specializes in a specific aspect of the visualization pipeline.
    Agents communicate through shared memory and use an LLM for reasoning.
    ``execute`` is a template method; subclasses supply the prompt, system
    persona, response parser, and memory key via the abstract hooks.
    """

    def __init__(
        self,
        llm: LLMProvider,
        memory: SharedMemory,
        name: Optional[str] = None,
    ) -> None:
        self._llm = llm
        self._memory = memory
        # Default the display name to the concrete class name.
        self._name = name or self.__class__.__name__

    @property
    def name(self) -> str:
        """Get the agent's name."""
        return self._name

    def execute(self, context: AgentContext) -> T:
        """
        Execute the agent's task.

        Builds the prompt, calls the LLM, parses the response into the
        agent's output model, and stores it in shared memory.

        Args:
            context: The execution context containing query and data info

        Returns:
            The agent's structured output
        """
        # Lazy %-style args avoid formatting when the level is disabled.
        logger.info("[%s] Starting execution", self._name)

        prompt = self._build_prompt(context)
        system_prompt = self._get_system_prompt()

        response = self._llm.complete(
            prompt=prompt,
            system_prompt=system_prompt,
        )

        result = self._parse_response(response.content)
        self._store_result(result)

        logger.info("[%s] Execution complete", self._name)
        return result

    @abstractmethod
    def _build_prompt(self, context: AgentContext) -> str:
        """
        Build the prompt for the LLM.

        Args:
            context: The execution context

        Returns:
            The formatted prompt string
        """
        pass

    @abstractmethod
    def _get_system_prompt(self) -> str:
        """
        Get the system prompt defining the agent's persona.

        Returns:
            The system prompt string
        """
        pass

    @abstractmethod
    def _parse_response(self, response: str) -> T:
        """
        Parse the LLM response into a structured output.

        Args:
            response: The raw LLM response

        Returns:
            The parsed and validated output
        """
        pass

    @abstractmethod
    def _get_output_key(self) -> str:
        """
        Get the key used to store this agent's output in memory.

        Returns:
            The memory key string
        """
        pass

    def _store_result(self, result: T) -> None:
        """Store the agent's result in shared memory."""
        self._memory.store(
            key=self._get_output_key(),
            value=result.model_dump(),
            agent_name=self._name,
        )

    def _get_from_memory(self, key: str) -> Optional[Any]:
        """Retrieve a value from shared memory."""
        return self._memory.retrieve(key)

    def _extract_json(self, text: str) -> dict[str, Any]:
        """
        Extract JSON from LLM response text.

        Handles responses where JSON is wrapped in markdown code blocks
        and sanitizes control characters that can break JSON parsing.

        Args:
            text: The raw LLM response text

        Returns:
            The parsed JSON object

        Raises:
            ValueError: If no parseable JSON object can be recovered
        """
        # Prefer the contents of a fenced ``` block if one is present.
        json_match = re.search(r"```(?:json)?\s*([\s\S]*?)```", text)
        if json_match:
            text = json_match.group(1)

        text = text.strip()

        try:
            return json.loads(text)
        except json.JSONDecodeError:
            pass

        # Fall back: isolate the outermost {...} span and sanitize it.
        # ([\s\S] already matches newlines, so re.DOTALL is unnecessary.)
        obj_match = re.search(r"(\{[\s\S]*\})", text)
        if obj_match:
            json_text = obj_match.group(1)

            # Strip truly invalid control characters (NUL, etc.) but keep
            # \n, \r, \t, which are legal between JSON tokens.
            json_text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]", "", json_text)

            try:
                return json.loads(json_text)
            except json.JSONDecodeError:
                pass

            # Last resort: escape raw newlines/tabs that appear *inside*
            # string literals — a common LLM formatting mistake.
            try:
                return json.loads(self._fix_json_strings(json_text))
            except json.JSONDecodeError:
                pass

        # FIX: dropped pointless f-string prefixes (no placeholders) and
        # made the debug log lazy.
        logger.error("Failed to parse JSON after sanitization attempts")
        logger.debug("Raw text: %s...", text[:500])
        raise ValueError("Invalid JSON in response: Could not parse after sanitization")

    def _fix_json_strings(self, text: str) -> str:
        """Fix unescaped newlines and control characters inside JSON strings.

        Single-pass scanner that tracks whether the cursor is inside a
        double-quoted string (honoring backslash escapes) and escapes raw
        \\n, \\r, \\t only there; text outside strings is left untouched.
        """
        result = []
        in_string = False
        escape_next = False

        for char in text:
            if escape_next:
                result.append(char)
                escape_next = False
                continue

            if char == '\\':
                result.append(char)
                escape_next = True
                continue

            if char == '"':
                in_string = not in_string
                result.append(char)
                continue

            if in_string:
                # Escape problematic characters inside strings.
                if char == '\n':
                    result.append('\\n')
                elif char == '\r':
                    result.append('\\r')
                elif char == '\t':
                    result.append('\\t')
                else:
                    result.append(char)
            else:
                result.append(char)

        return ''.join(result)
|
coda/core/llm.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLM abstraction layer for CoDA.
|
| 3 |
+
|
| 4 |
+
Provides a clean interface for interacting with language models,
|
| 5 |
+
with ChatGroq as the default implementation.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import base64
|
| 9 |
+
import logging
|
| 10 |
+
import time
|
| 11 |
+
from abc import ABC, abstractmethod
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Any, Optional, Union
|
| 14 |
+
|
| 15 |
+
from groq import Groq
|
| 16 |
+
from pydantic import BaseModel
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class LLMResponse(BaseModel):
    """Structured response from an LLM call."""

    # Generated text; empty string if the API returned no content.
    content: str
    # Model identifier reported by the provider for this call.
    model: str
    # Token accounting with keys prompt_tokens/completion_tokens/total_tokens
    # (zeros when the provider omits usage data).
    usage: dict[str, int]
    # Provider stop reason; "unknown" when the provider reports none.
    finish_reason: str
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class LLMProvider(ABC):
    """Abstract interface for language model providers.

    Agents depend only on this interface; concrete implementations
    (e.g. GroqLLM below) supply text and multimodal completion.
    """

    @abstractmethod
    def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> LLMResponse:
        """Generate a text completion.

        Args:
            prompt: User prompt text.
            system_prompt: Optional system message prepended to the chat.
            temperature: Per-call sampling override; None means provider default.
            max_tokens: Per-call completion-length override; None means default.
        """
        pass

    @abstractmethod
    def complete_with_image(
        self,
        prompt: str,
        image_path: Union[str, Path],
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> LLMResponse:
        """Generate a completion with image input (multimodal).

        Args:
            prompt: User prompt text accompanying the image.
            image_path: Path to the image file to include with the prompt.
        """
        pass
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class GroqLLM(LLMProvider):
    """ChatGroq implementation of the LLM provider.

    Wraps the Groq SDK with exponential-backoff retries and both text
    and vision (image + text) completion entry points.
    """

    def __init__(
        self,
        api_key: str,
        default_model: str = "llama-3.3-70b-versatile",
        vision_model: str = "meta-llama/llama-4-maverick-17b-128e-instruct",
        temperature: float = 0.7,
        max_tokens: int = 4096,
        max_retries: int = 3,
        retry_delay: float = 1.0,
    ) -> None:
        self._client = Groq(api_key=api_key)
        self._default_model = default_model
        self._vision_model = vision_model
        self._temperature = temperature
        self._max_tokens = max_tokens
        self._max_retries = max_retries
        self._retry_delay = retry_delay

    def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> LLMResponse:
        """Generate a text completion using ChatGroq.

        Args:
            prompt: The user prompt.
            system_prompt: Optional system message.
            temperature: Per-call override; instance default when None.
            max_tokens: Per-call override; instance default when None.
        """
        messages = self._build_messages(prompt, system_prompt)
        return self._call_with_retry(
            messages=messages,
            model=self._default_model,
            # BUGFIX: `temperature or self._temperature` silently discarded a
            # caller-supplied 0.0 (and max_tokens=0) because 0 is falsy.
            # Explicit None checks make 0.0 a valid, respected override.
            temperature=self._temperature if temperature is None else temperature,
            max_tokens=self._max_tokens if max_tokens is None else max_tokens,
        )

    def complete_with_image(
        self,
        prompt: str,
        image_path: Union[str, Path],
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> LLMResponse:
        """Generate a completion with image input using the vision model.

        Args:
            prompt: The user prompt accompanying the image.
            image_path: Path to the image file (sent base64-inlined).
        """
        image_data = self._encode_image(image_path)

        # NOTE(review): the MIME type is hard-coded to image/png regardless
        # of the file's actual format — confirm callers only pass PNGs.
        user_content = [
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/png;base64,{image_data}"
                }
            },
            {
                "type": "text",
                "text": prompt
            }
        ]

        messages: list[dict[str, Any]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_content})

        return self._call_with_retry(
            messages=messages,
            model=self._vision_model,
            # BUGFIX: same falsy-zero handling as complete().
            temperature=self._temperature if temperature is None else temperature,
            max_tokens=self._max_tokens if max_tokens is None else max_tokens,
        )

    def _build_messages(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
    ) -> list[dict[str, str]]:
        """Build the message list for the API call."""
        messages: list[dict[str, str]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        return messages

    def _call_with_retry(
        self,
        messages: list[dict[str, Any]],
        model: str,
        temperature: float,
        max_tokens: int,
    ) -> LLMResponse:
        """Execute API call with exponential backoff retry.

        Raises:
            RuntimeError: After ``max_retries`` consecutive failures; the
                last underlying exception is chained as the cause.
        """
        last_exception: Optional[Exception] = None

        for attempt in range(self._max_retries):
            try:
                response = self._client.chat.completions.create(
                    model=model,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=max_tokens,
                )

                return LLMResponse(
                    content=response.choices[0].message.content or "",
                    model=response.model,
                    usage={
                        "prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
                        "completion_tokens": response.usage.completion_tokens if response.usage else 0,
                        "total_tokens": response.usage.total_tokens if response.usage else 0,
                    },
                    finish_reason=response.choices[0].finish_reason or "unknown",
                )

            except Exception as e:
                # Broad on purpose: the SDK raises a variety of transient
                # network/rate-limit errors; all are retried with backoff.
                last_exception = e
                logger.warning(
                    "API call failed (attempt %d/%d): %s",
                    attempt + 1, self._max_retries, e,
                )
                if attempt < self._max_retries - 1:
                    sleep_time = self._retry_delay * (2 ** attempt)
                    time.sleep(sleep_time)

        raise RuntimeError(
            f"API call failed after {self._max_retries} attempts"
        ) from last_exception

    def _encode_image(self, image_path: Union[str, Path]) -> str:
        """Encode an image file to base64.

        Raises:
            FileNotFoundError: If the image path does not exist.
        """
        path = Path(image_path)
        if not path.exists():
            raise FileNotFoundError(f"Image not found: {path}")

        with open(path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")
|
coda/core/memory.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Shared memory buffer for inter-agent communication in CoDA.
|
| 3 |
+
|
| 4 |
+
Provides thread-safe storage for agents to exchange context,
|
| 5 |
+
results, and feedback during the visualization pipeline.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import threading
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from typing import Any, Optional
|
| 11 |
+
|
| 12 |
+
from pydantic import BaseModel, Field
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class MemoryEntry(BaseModel):
    """A single entry in the shared memory."""

    # Unique identifier used to look the entry up in SharedMemory.
    key: str
    # The stored payload; expected to be JSON-serializable.
    value: Any
    # Name of the agent that wrote this entry (for tracing/debugging).
    agent_name: str
    # Creation time; defaults to local wall-clock time at construction.
    timestamp: datetime = Field(default_factory=datetime.now)
    # Optional free-form context attached by the writing agent.
    metadata: dict[str, Any] = Field(default_factory=dict)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class SharedMemory:
    """
    Thread-safe key/value buffer that agents use to exchange data.

    Values are addressed by string keys. Every write records its source
    agent and timestamp and is appended to a chronological history, so
    the full sequence of operations can be inspected for debugging.
    """

    def __init__(self) -> None:
        # Latest value per key; the history keeps every write ever made.
        self._storage: dict[str, MemoryEntry] = {}
        self._lock = threading.RLock()
        self._history: list[MemoryEntry] = []

    def store(
        self,
        key: str,
        value: Any,
        agent_name: str,
        metadata: Optional[dict[str, Any]] = None,
    ) -> None:
        """
        Store a value in shared memory.

        Args:
            key: Unique identifier for the data
            value: The data to store (should be JSON-serializable)
            agent_name: Name of the agent storing the data
            metadata: Optional additional context
        """
        with self._lock:
            record = MemoryEntry(
                key=key,
                value=value,
                agent_name=agent_name,
                metadata=metadata or {},
            )
            self._storage[key] = record
            self._history.append(record)

    def retrieve(self, key: str) -> Optional[Any]:
        """
        Retrieve a stored value by key.

        Args:
            key: The key to look up

        Returns:
            The stored value, or None if the key is absent
        """
        with self._lock:
            record = self._storage.get(key)
        return None if record is None else record.value

    def retrieve_entry(self, key: str) -> Optional[MemoryEntry]:
        """
        Retrieve the full entry, including provenance metadata.

        Args:
            key: The key to look up

        Returns:
            The full MemoryEntry, or None if the key is absent
        """
        with self._lock:
            return self._storage.get(key)

    def get_context(self, keys: list[str]) -> dict[str, Any]:
        """
        Fetch several values at once as a context dictionary.

        Args:
            keys: List of keys to retrieve

        Returns:
            Dictionary mapping keys to their values; keys that are not
            present are silently omitted
        """
        with self._lock:
            selected: dict[str, Any] = {}
            for name in keys:
                record = self._storage.get(name)
                if record is not None:
                    selected[name] = record.value
            return selected

    def get_all(self) -> dict[str, Any]:
        """
        Fetch every stored value.

        Returns:
            Dictionary mapping all keys to their current values
        """
        with self._lock:
            return {name: record.value for name, record in self._storage.items()}

    def get_history(self, agent_name: Optional[str] = None) -> list[MemoryEntry]:
        """
        Return the chronological log of all writes.

        Args:
            agent_name: Optional filter by agent name

        Returns:
            List of memory entries in the order they were written
        """
        with self._lock:
            entries = list(self._history)
        if not agent_name:
            return entries
        return [entry for entry in entries if entry.agent_name == agent_name]

    def has_key(self, key: str) -> bool:
        """Return True if *key* currently exists in memory."""
        with self._lock:
            return key in self._storage

    def clear(self) -> None:
        """Drop all stored values and the write history."""
        with self._lock:
            self._storage.clear()
            self._history.clear()

    def keys(self) -> list[str]:
        """Return a list of all currently stored keys."""
        with self._lock:
            return list(self._storage.keys())
|
coda/orchestrator.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Workflow Orchestrator for CoDA.
|
| 3 |
+
|
| 4 |
+
Manages the multi-agent pipeline, coordinating agent execution,
|
| 5 |
+
handling feedback loops, and implementing quality-driven halting.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Optional, Callable
|
| 12 |
+
|
| 13 |
+
from coda.config import Config, get_config
|
| 14 |
+
from coda.core.llm import GroqLLM, LLMProvider
|
| 15 |
+
from coda.core.memory import SharedMemory
|
| 16 |
+
from coda.core.base_agent import AgentContext
|
| 17 |
+
|
| 18 |
+
from coda.agents.query_analyzer import QueryAnalyzerAgent, QueryAnalysis
|
| 19 |
+
from coda.agents.data_processor import DataProcessorAgent, DataAnalysis
|
| 20 |
+
from coda.agents.viz_mapping import VizMappingAgent, VisualMapping
|
| 21 |
+
from coda.agents.search_agent import SearchAgent, SearchResult
|
| 22 |
+
from coda.agents.design_explorer import DesignExplorerAgent, DesignSpec
|
| 23 |
+
from coda.agents.code_generator import CodeGeneratorAgent, GeneratedCode
|
| 24 |
+
from coda.agents.debug_agent import DebugAgent, ExecutionResult
|
| 25 |
+
from coda.agents.visual_evaluator import VisualEvaluatorAgent, VisualEvaluation
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
class PipelineResult:
    """Final result from the CoDA pipeline."""

    # True when an output file was produced and exists on disk.
    success: bool
    # Path to the rendered visualization, if one was produced.
    output_file: Optional[str]
    # Quality assessment from the visual evaluator, if it ran.
    evaluation: Optional[VisualEvaluation]
    # Number of pipeline passes performed (initial run + refinements).
    iterations: int
    # Human-readable failure description; None on success.
    error: Optional[str] = None

    @property
    def scores(self) -> Optional[dict]:
        """Get quality scores if evaluation exists."""
        # model_dump() serializes the pydantic scores model to a plain dict.
        if self.evaluation:
            return self.evaluation.scores.model_dump()
        return None
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class CodaOrchestrator:
    """
    Orchestrates the CoDA multi-agent visualization pipeline.

    Coordinates agent execution in sequence, manages the shared memory,
    and implements iterative refinement through feedback loops.
    """

    def __init__(
        self,
        config: Optional[Config] = None,
        llm: Optional[LLMProvider] = None,
        progress_callback: Optional[Callable[[str, float], None]] = None,
    ) -> None:
        """Set up configuration, LLM, shared memory, and all agents.

        Args:
            config: Pipeline configuration; falls back to get_config().
            llm: LLM provider; a GroqLLM is built from config when omitted.
            progress_callback: Optional callable(status, progress) invoked
                before each pipeline stage.
        """
        self._config = config or get_config()
        self._llm = llm or self._create_llm()
        self._memory = SharedMemory()
        self._progress_callback = progress_callback

        self._agents = self._create_agents()

    def _create_llm(self) -> GroqLLM:
        """Create the LLM instance."""
        return GroqLLM(
            api_key=self._config.groq_api_key,
            default_model=self._config.model.default_model,
            vision_model=self._config.model.vision_model,
            temperature=self._config.model.temperature,
            max_tokens=self._config.model.max_tokens,
            max_retries=self._config.model.max_retries,
        )

    def _create_agents(self) -> dict:
        """Initialize all agents with shared resources."""
        # All agents share one LLM and one memory buffer so results stored
        # by earlier stages are visible to later ones.
        return {
            "query_analyzer": QueryAnalyzerAgent(self._llm, self._memory),
            "data_processor": DataProcessorAgent(self._llm, self._memory),
            "viz_mapping": VizMappingAgent(self._llm, self._memory),
            "search_agent": SearchAgent(self._llm, self._memory),
            "design_explorer": DesignExplorerAgent(self._llm, self._memory),
            "code_generator": CodeGeneratorAgent(self._llm, self._memory),
            "debug_agent": DebugAgent(
                self._llm,
                self._memory,
                timeout_seconds=self._config.execution.code_timeout_seconds,
                output_directory=self._config.execution.output_directory,
            ),
            "visual_evaluator": VisualEvaluatorAgent(
                self._llm,
                self._memory,
                min_overall_score=self._config.quality.minimum_overall_score,
            ),
        }

    def run(
        self,
        query: str,
        data_paths: list[str],
    ) -> PipelineResult:
        """
        Execute the full visualization pipeline.

        Args:
            query: Natural language visualization request
            data_paths: Paths to data files

        Returns:
            PipelineResult with output file and evaluation
        """
        logger.info(f"Starting CoDA pipeline for query: {query[:50]}...")

        # Start from a clean slate so state from earlier runs cannot leak.
        self._memory.clear()

        validated_paths = self._validate_data_paths(data_paths)
        if not validated_paths:
            return PipelineResult(
                success=False,
                output_file=None,
                evaluation=None,
                iterations=0,
                error="No valid data files provided",
            )

        context = AgentContext(
            query=query,
            data_paths=validated_paths,
            iteration=0,
        )

        try:
            self._run_initial_pipeline(context)
        except Exception as e:
            logger.error(f"Initial pipeline failed: {e}")
            return PipelineResult(
                success=False,
                output_file=None,
                evaluation=None,
                iterations=0,
                error=str(e),
            )

        # Re-run the later stages until the quality threshold is met or
        # the iteration budget is exhausted.
        max_iterations = self._config.execution.max_refinement_iterations
        final_result = self._run_refinement_loop(context, max_iterations)

        return final_result

    def _validate_data_paths(self, data_paths: list[str]) -> list[str]:
        """Validate that data files exist.

        Missing files are logged and dropped rather than treated as fatal.
        """
        valid_paths = []
        for path in data_paths:
            if Path(path).exists():
                valid_paths.append(path)
            else:
                logger.warning(f"Data file not found: {path}")
        return valid_paths

    def _run_initial_pipeline(self, context: AgentContext) -> None:
        """Run the initial agent pipeline.

        Executes every agent once, in order, reporting an approximate
        progress fraction before each stage.
        """
        steps = [
            ("query_analyzer", "Analyzing query...", 0.1),
            ("data_processor", "Processing data...", 0.2),
            ("viz_mapping", "Mapping visualization...", 0.3),
            ("search_agent", "Searching examples...", 0.4),
            ("design_explorer", "Designing visualization...", 0.5),
            ("code_generator", "Generating code...", 0.7),
            ("debug_agent", "Executing code...", 0.85),
            ("visual_evaluator", "Evaluating output...", 0.95),
        ]

        for agent_name, status, progress in steps:
            self._report_progress(status, progress)
            agent = self._agents[agent_name]
            agent.execute(context)

    def _run_refinement_loop(
        self,
        context: AgentContext,
        max_iterations: int,
    ) -> PipelineResult:
        """Run the iterative refinement loop."""
        for iteration in range(max_iterations):
            evaluation = self._memory.retrieve("visual_evaluation")

            # No evaluation stored means the evaluator produced nothing
            # usable — there is no feedback to refine against.
            if not evaluation:
                break

            # Memory may hold either a plain dict or a VisualEvaluation
            # instance; normalize to a model object either way.
            if isinstance(evaluation, dict):
                passes = evaluation.get("passes_threshold", False)
                eval_obj = VisualEvaluation(**evaluation)
            else:
                passes = evaluation.passes_threshold
                eval_obj = evaluation

            if passes:
                logger.info(f"Quality threshold met at iteration {iteration}")
                return self._create_success_result(eval_obj, iteration + 1)

            # Don't start a refinement pass we have no budget to evaluate.
            if iteration >= max_iterations - 1:
                logger.info("Max iterations reached")
                break

            logger.info(f"Refinement iteration {iteration + 1}")
            context = self._create_refinement_context(context, eval_obj, iteration + 1)

            self._report_progress(f"Refining (iteration {iteration + 2})...", 0.5)

            try:
                self._run_refinement_agents(context)
            except Exception as e:
                logger.error(f"Refinement failed: {e}")
                break

        # Fell out of the loop: package whatever the latest evaluation was.
        final_eval = self._memory.retrieve("visual_evaluation")
        if isinstance(final_eval, dict):
            final_eval = VisualEvaluation(**final_eval)

        return self._create_success_result(final_eval, max_iterations)

    def _run_refinement_agents(self, context: AgentContext) -> None:
        """Run agents that participate in refinement."""
        # Only the design/codegen/execution/evaluation stages re-run; the
        # query and data analyses from the initial pass are reused.
        refinement_agents = [
            "design_explorer",
            "code_generator",
            "debug_agent",
            "visual_evaluator",
        ]

        for agent_name in refinement_agents:
            agent = self._agents[agent_name]
            agent.execute(context)

    def _create_refinement_context(
        self,
        original_context: AgentContext,
        evaluation: VisualEvaluation,
        iteration: int,
    ) -> AgentContext:
        """Create context for refinement iteration."""
        feedback_parts = []

        # Keep feedback compact: top 3 issues and top 2 priority fixes.
        if evaluation.issues:
            feedback_parts.append(f"Issues: {', '.join(evaluation.issues[:3])}")

        if evaluation.priority_fixes:
            feedback_parts.append(f"Fix: {', '.join(evaluation.priority_fixes[:2])}")

        feedback = " | ".join(feedback_parts)

        return AgentContext(
            query=original_context.query,
            data_paths=original_context.data_paths,
            iteration=iteration,
            feedback=feedback,
        )

    def _create_success_result(
        self,
        evaluation: Optional[VisualEvaluation],
        iterations: int,
    ) -> PipelineResult:
        """Create a successful pipeline result."""
        execution_result = self._memory.retrieve("execution_result")
        output_file = None

        # The stored execution result may be a dict or an object with an
        # output_file attribute; handle both shapes.
        if execution_result:
            if isinstance(execution_result, dict):
                output_file = execution_result.get("output_file")
            else:
                output_file = execution_result.output_file

        return PipelineResult(
            # Only count the run as successful if the file really exists.
            success=output_file is not None and Path(output_file).exists(),
            output_file=output_file,
            evaluation=evaluation,
            iterations=iterations,
        )

    def _report_progress(self, status: str, progress: float) -> None:
        """Report progress to callback if set."""
        if self._progress_callback:
            self._progress_callback(status, progress)
        logger.info(f"[{progress:.0%}] {status}")

    def get_memory_state(self) -> dict:
        """Get the current state of shared memory for debugging."""
        return self._memory.get_all()
|
main.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Command-Line Interface for CoDA.
|
| 3 |
+
|
| 4 |
+
Provides a CLI for running the CoDA visualization pipeline locally.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import logging
|
| 9 |
+
import sys
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
# Configure root logging once at import time so the CLI and all CoDA
# modules share one timestamped, named format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _build_parser() -> argparse.ArgumentParser:
    """Construct the argument parser for the CoDA command-line interface."""
    parser = argparse.ArgumentParser(
        description="CoDA - Collaborative Data Visualization Agents",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python main.py --query "Show sales trends" --data sales.csv
  python main.py -q "Bar chart of categories" -d data.xlsx
  python main.py --query "Scatter plot" --data file1.csv file2.csv
"""
    )

    parser.add_argument(
        "-q", "--query",
        type=str,
        required=True,
        help="Natural language visualization query"
    )
    parser.add_argument(
        "-d", "--data",
        type=str,
        nargs="+",
        required=True,
        help="Path(s) to data file(s)"
    )
    parser.add_argument(
        "-o", "--output",
        type=str,
        default="outputs",
        help="Output directory for visualizations (default: outputs)"
    )
    parser.add_argument(
        "--max-iterations",
        type=int,
        default=3,
        help="Maximum refinement iterations (default: 3)"
    )
    parser.add_argument(
        "--min-score",
        type=float,
        default=7.0,
        help="Minimum quality score threshold (default: 7.0)"
    )
    parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Enable verbose logging"
    )
    return parser


def _progress_callback(status: str, progress: float) -> None:
    """Render a single-line text progress bar on stdout."""
    bar_length = 30
    filled = int(bar_length * progress)
    bar = "█" * filled + "░" * (bar_length - filled)
    print(f"\r[{bar}] {progress:.0%} - {status}", end="", flush=True)
    if progress >= 1.0:
        print()


def _print_banner(args) -> None:
    """Print the run header showing query, data files, and output dir."""
    print(f"\n{'=' * 60}")
    print("CoDA - Collaborative Data Visualization Agents")
    print(f"{'=' * 60}\n")
    print(f"Query: {args.query}")
    print(f"Data: {', '.join(args.data)}")
    print(f"Output: {args.output}/")
    print()


def _print_results(result) -> None:
    """Print the pipeline outcome, quality scores, and strengths."""
    print()
    print("=" * 60)
    print("Results")
    print(f"{'=' * 60}\n")

    if result.success:
        # Plain strings here — the originals were f-strings with no
        # placeholders.
        print("✅ Visualization generated successfully!")
        print(f"📁 Output: {result.output_file}")
        print(f"🔄 Iterations: {result.iterations}")

        if result.scores:
            print("\n📊 Quality Scores:")
            for key, value in result.scores.items():
                emoji = "🟢" if value >= 7 else "🟡" if value >= 5 else "🔴"
                print(f"  {key.title()}: {emoji} {value:.1f}/10")

        if result.evaluation and result.evaluation.strengths:
            print("\n💪 Strengths:")
            for s in result.evaluation.strengths[:3]:
                print(f"  • {s}")
    else:
        print("❌ Visualization failed!")
        if result.error:
            print(f"  Error: {result.error}")


def main():
    """Main entry point for the CLI.

    Parses arguments, validates inputs, builds the pipeline configuration,
    runs the orchestrator, and prints a human-readable summary. Exits with
    status 1 on any validation, import, configuration, or pipeline failure.
    """
    args = _build_parser().parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Fail fast on missing inputs before importing the heavier pipeline.
    for path in args.data:
        if not Path(path).exists():
            logger.error(f"Data file not found: {path}")
            sys.exit(1)

    # Import lazily so a broken environment gives a friendly message
    # instead of a traceback at module load.
    try:
        from coda.config import Config, ExecutionConfig, QualityThresholds
        from coda.orchestrator import CodaOrchestrator
    except ImportError as e:
        logger.error(f"Failed to import CoDA modules: {e}")
        logger.error("Make sure you have installed all dependencies: pip install -r requirements.txt")
        sys.exit(1)

    try:
        config = Config(
            execution=ExecutionConfig(
                max_refinement_iterations=args.max_iterations,
                output_directory=args.output,
            ),
            quality=QualityThresholds(
                minimum_overall_score=args.min_score,
            ),
        )
    except ValueError as e:
        logger.error(f"Configuration error: {e}")
        sys.exit(1)

    _print_banner(args)

    orchestrator = CodaOrchestrator(
        config=config,
        progress_callback=_progress_callback,
    )

    result = orchestrator.run(
        query=args.query,
        data_paths=args.data,
    )

    _print_results(result)
    if not result.success:
        sys.exit(1)

    print()


if __name__ == "__main__":
    main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CoDA Dependencies
|
| 2 |
+
|
| 3 |
+
# Core
|
| 4 |
+
groq>=0.4.0
|
| 5 |
+
pydantic>=2.0.0
|
| 6 |
+
python-dotenv>=1.0.0
|
| 7 |
+
|
| 8 |
+
# Data Processing
|
| 9 |
+
pandas>=2.0.0
|
| 10 |
+
openpyxl>=3.1.0
|
| 11 |
+
pyarrow>=14.0.0
|
| 12 |
+
|
| 13 |
+
# Visualization
|
| 14 |
+
matplotlib>=3.7.0
|
| 15 |
+
seaborn>=0.13.0
|
| 16 |
+
|
| 17 |
+
# Web Interface
|
| 18 |
+
gradio>=4.0.0
|
| 19 |
+
|
| 20 |
+
# Development (optional)
|
| 21 |
+
pytest>=7.0.0
|
| 22 |
+
pytest-asyncio>=0.21.0
|
sample_data.csv
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
date,sales,category,region
|
| 2 |
+
2024-01-01,1500,Electronics,North
|
| 3 |
+
2024-01-02,2300,Electronics,South
|
| 4 |
+
2024-01-03,1800,Clothing,North
|
| 5 |
+
2024-01-04,3200,Electronics,East
|
| 6 |
+
2024-01-05,2100,Clothing,West
|
| 7 |
+
2024-01-06,1900,Food,North
|
| 8 |
+
2024-01-07,2800,Electronics,South
|
| 9 |
+
2024-01-08,1600,Clothing,East
|
| 10 |
+
2024-01-09,3500,Food,West
|
| 11 |
+
2024-01-10,2400,Electronics,North
|
| 12 |
+
2024-01-11,1700,Clothing,South
|
| 13 |
+
2024-01-12,2900,Food,East
|
| 14 |
+
2024-01-13,2200,Electronics,West
|
| 15 |
+
2024-01-14,1400,Clothing,North
|
| 16 |
+
2024-01-15,3100,Food,South
|
tests/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Tests package
|
tests/test_agents.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unit tests for agent implementations.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
from unittest.mock import Mock, MagicMock
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
from coda.core.memory import SharedMemory
|
| 10 |
+
from coda.core.llm import LLMResponse
|
| 11 |
+
from coda.core.base_agent import AgentContext
|
| 12 |
+
from coda.agents.query_analyzer import QueryAnalyzerAgent, QueryAnalysis
|
| 13 |
+
from coda.agents.data_processor import DataProcessorAgent, DataAnalysis
|
| 14 |
+
from coda.agents.viz_mapping import VizMappingAgent, VisualMapping
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MockLLM:
    """Stand-in LLM that always answers with one canned response string."""

    def __init__(self, response_content: str):
        self._response = response_content

    def complete(self, prompt, system_prompt=None, **kwargs):
        # Usage counters are zeroed — token accounting is irrelevant in tests.
        zero_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        return LLMResponse(
            content=self._response,
            model="mock",
            usage=zero_usage,
            finish_reason="stop",
        )

    def complete_with_image(self, prompt, image_path, **kwargs):
        # Vision calls are indistinguishable from text calls for the mock.
        return self.complete(prompt, **kwargs)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class TestQueryAnalyzerAgent:
    """Tests for the Query Analyzer agent."""

    @pytest.fixture
    def mock_response(self):
        # Canned LLM output shaped like the QueryAnalysis JSON schema.
        return json.dumps({
            "visualization_types": ["line chart", "bar chart"],
            "key_points": ["sales trends", "monthly data"],
            "todo_list": ["Load data", "Create chart", "Add labels"],
            "data_requirements": ["date", "sales"],
            "constraints": ["use blue colors"],
            "ambiguities": []
        })

    def test_execute(self, mock_response):
        """Test query analysis execution."""
        llm = MockLLM(mock_response)
        memory = SharedMemory()
        agent = QueryAnalyzerAgent(llm, memory)

        context = AgentContext(query="Show sales trends over time")
        result = agent.execute(context)

        # The agent should parse the mocked JSON into a QueryAnalysis model.
        assert isinstance(result, QueryAnalysis)
        assert "line chart" in result.visualization_types
        assert "sales trends" in result.key_points
        assert len(result.todo_list) == 3

    def test_stores_in_memory(self, mock_response):
        """Test that results are stored in memory."""
        llm = MockLLM(mock_response)
        memory = SharedMemory()
        agent = QueryAnalyzerAgent(llm, memory)

        context = AgentContext(query="Test query")
        agent.execute(context)

        # execute() is expected to persist its analysis under the
        # "query_analysis" key for downstream agents.
        stored = memory.retrieve("query_analysis")
        assert stored is not None
        assert "visualization_types" in stored
|
| 76 |
+
|
| 77 |
+
class TestVizMappingAgent:
    """Tests for the VizMapping agent."""

    @pytest.fixture
    def mock_response(self):
        # Canned LLM output shaped like the VisualMapping JSON schema.
        return json.dumps({
            "chart_type": "line",
            "chart_subtype": None,
            "x_axis": {"column": "date", "label": "Date", "type": "temporal"},
            "y_axis": {"column": "sales", "label": "Sales", "type": "numerical"},
            "color_encoding": None,
            "size_encoding": None,
            "transformations": [],
            "styling_hints": {"theme": "modern"},
            "visualization_goals": ["Show trends"],
            "rationale": "Line chart best for trends"
        })

    def test_execute(self, mock_response):
        """Test visualization mapping execution."""
        llm = MockLLM(mock_response)
        memory = SharedMemory()

        # Seed the upstream analysis the mapping agent reads from memory.
        memory.store("query_analysis", {
            "visualization_types": ["line chart"],
            "key_points": ["trends"],
            "data_requirements": ["date", "sales"]
        }, "test")

        agent = VizMappingAgent(llm, memory)
        context = AgentContext(query="Show sales trends")
        result = agent.execute(context)

        assert isinstance(result, VisualMapping)
        assert result.chart_type == "line"
        assert result.x_axis["column"] == "date"
|
| 114 |
+
|
| 115 |
+
class TestAgentContext:
    """Tests for the AgentContext model."""

    def test_basic_context(self):
        """Provided fields round-trip; iteration and feedback use defaults."""
        ctx = AgentContext(
            query="Test query",
            data_paths=["file1.csv", "file2.csv"],
        )

        assert ctx.query == "Test query"
        assert len(ctx.data_paths) == 2
        assert ctx.iteration == 0
        assert ctx.feedback is None

    def test_context_with_feedback(self):
        """Iteration count and refinement feedback survive construction."""
        ctx = AgentContext(
            query="Test",
            iteration=2,
            feedback="Improve colors",
        )

        assert ctx.iteration == 2
        assert ctx.feedback == "Improve colors"
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class TestBaseAgentJsonExtraction:
    """Tests for JSON extraction from LLM responses."""

    @staticmethod
    def _make_agent():
        """Build a throwaway agent purely to reach _extract_json."""
        return QueryAnalyzerAgent(MockLLM("{}"), SharedMemory())

    def test_extract_json_plain(self):
        """A bare JSON object parses as-is."""
        agent = self._make_agent()

        assert agent._extract_json('{"key": "value"}') == {"key": "value"}

    def test_extract_json_markdown(self):
        """JSON inside a fenced ```json code block is still recovered."""
        agent = self._make_agent()
        fenced = 'Here is the response:\n```json\n{"key": "value"}\n```\n'

        assert agent._extract_json(fenced) == {"key": "value"}

    def test_extract_json_invalid(self):
        """Input that is not JSON raises ValueError."""
        agent = self._make_agent()

        with pytest.raises(ValueError):
            agent._extract_json("not valid json")
|
tests/test_llm.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unit tests for the LLM abstraction layer.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
from unittest.mock import Mock, patch, MagicMock
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
from coda.core.llm import LLMProvider, GroqLLM, LLMResponse
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TestLLMResponse:
    """Tests for the LLMResponse model."""

    def test_response_creation(self):
        """All constructor fields are preserved on the model."""
        usage = {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}
        resp = LLMResponse(
            content="Test content",
            model="test-model",
            usage=usage,
            finish_reason="stop",
        )

        assert resp.content == "Test content"
        assert resp.model == "test-model"
        assert resp.usage["total_tokens"] == 30
        assert resp.finish_reason == "stop"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class TestGroqLLM:
    """Tests for the GroqLLM implementation."""

    @pytest.fixture
    def mock_groq_client(self):
        """Patch the Groq SDK and yield the mocked client instance."""
        with patch("coda.core.llm.Groq") as groq_cls:
            instance = Mock()
            groq_cls.return_value = instance
            yield instance

    def test_initialization(self, mock_groq_client):
        """Custom constructor parameters land on the private attributes."""
        llm = GroqLLM(
            api_key="test-key",
            default_model="custom-model",
            temperature=0.5,
            max_tokens=2048,
        )

        assert llm._default_model == "custom-model"
        assert llm._temperature == 0.5
        assert llm._max_tokens == 2048

    def test_complete_success(self, mock_groq_client):
        """A successful API reply is unwrapped into an LLMResponse."""
        choice = MagicMock()
        choice.message.content = "Generated text"
        choice.finish_reason = "stop"

        api_reply = MagicMock()
        api_reply.choices = [choice]
        api_reply.model = "llama-3.3-70b-versatile"
        api_reply.usage.prompt_tokens = 10
        api_reply.usage.completion_tokens = 20
        api_reply.usage.total_tokens = 30

        mock_groq_client.chat.completions.create.return_value = api_reply

        llm = GroqLLM(api_key="test-key")
        response = llm.complete(prompt="Test prompt", system_prompt="System prompt")

        assert response.content == "Generated text"
        assert response.finish_reason == "stop"
        mock_groq_client.chat.completions.create.assert_called_once()

    def test_complete_with_retry(self, mock_groq_client):
        """The first failure is retried and the second attempt succeeds."""
        success = MagicMock(
            choices=[MagicMock(message=MagicMock(content="Success"), finish_reason="stop")],
            model="test",
            usage=MagicMock(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        )
        mock_groq_client.chat.completions.create.side_effect = [
            Exception("Rate limited"),
            success,
        ]

        llm = GroqLLM(api_key="test-key", retry_delay=0.01)
        response = llm.complete(prompt="Test")

        assert response.content == "Success"
        assert mock_groq_client.chat.completions.create.call_count == 2

    def test_build_messages(self, mock_groq_client):
        """System prompt becomes the first message, user prompt the second."""
        llm = GroqLLM(api_key="test-key")

        messages = llm._build_messages(
            prompt="User message",
            system_prompt="System message",
        )

        assert len(messages) == 2
        system_msg, user_msg = messages
        assert system_msg["role"] == "system"
        assert system_msg["content"] == "System message"
        assert user_msg["role"] == "user"
        assert user_msg["content"] == "User message"

    def test_build_messages_no_system(self, mock_groq_client):
        """Without a system prompt only the user message is produced."""
        llm = GroqLLM(api_key="test-key")

        messages = llm._build_messages(prompt="User message")

        assert len(messages) == 1
        assert messages[0]["role"] == "user"
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class TestLLMProviderInterface:
    """Tests for the abstract interface."""

    def test_interface_methods(self):
        """LLMProvider must expose both completion entry points."""
        for method_name in ("complete", "complete_with_image"):
            assert hasattr(LLMProvider, method_name)
|
tests/test_memory.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unit tests for the SharedMemory class.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
import threading
|
| 8 |
+
import time
|
| 9 |
+
|
| 10 |
+
from coda.core.memory import SharedMemory, MemoryEntry
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class TestSharedMemory:
    """Tests for the SharedMemory class."""

    @pytest.fixture
    def memory(self):
        """Create a fresh SharedMemory instance."""
        return SharedMemory()

    def test_store_and_retrieve(self, memory):
        """A stored value comes back unchanged."""
        memory.store(
            key="test_key",
            value={"data": "value"},
            agent_name="TestAgent"
        )

        assert memory.retrieve("test_key") == {"data": "value"}

    def test_retrieve_nonexistent(self, memory):
        """Missing keys yield None rather than raising."""
        assert memory.retrieve("nonexistent") is None

    def test_retrieve_entry(self, memory):
        """retrieve_entry exposes the value plus provenance metadata."""
        memory.store(
            key="test_key",
            value="test_value",
            agent_name="TestAgent",
            metadata={"extra": "info"}
        )

        entry = memory.retrieve_entry("test_key")

        assert entry is not None
        assert entry.value == "test_value"
        assert entry.agent_name == "TestAgent"
        assert entry.metadata == {"extra": "info"}
        assert isinstance(entry.timestamp, datetime)

    def test_get_context(self, memory):
        """get_context returns only the requested keys that exist."""
        memory.store("key1", "value1", "Agent1")
        memory.store("key2", "value2", "Agent2")
        memory.store("key3", "value3", "Agent3")

        context = memory.get_context(["key1", "key3", "nonexistent"])

        assert context == {"key1": "value1", "key3": "value3"}

    def test_get_all(self, memory):
        """get_all exposes every stored key/value pair."""
        memory.store("key1", "value1", "Agent")
        memory.store("key2", "value2", "Agent")

        assert memory.get_all() == {"key1": "value1", "key2": "value2"}

    def test_overwrite_value(self, memory):
        """Storing twice under one key keeps the latest value."""
        memory.store("key", "original", "Agent")
        memory.store("key", "updated", "Agent")

        assert memory.retrieve("key") == "updated"

    def test_history_tracking(self, memory):
        """Every store, including overwrites, is appended to history in order."""
        memory.store("key1", "v1", "Agent1")
        memory.store("key2", "v2", "Agent2")
        memory.store("key1", "v1_updated", "Agent1")

        history = memory.get_history()

        assert len(history) == 3
        assert history[0].key == "key1"
        assert history[1].key == "key2"
        assert history[2].value == "v1_updated"

    def test_history_filter_by_agent(self, memory):
        """History can be narrowed to a single agent's writes."""
        memory.store("k1", "v1", "Agent1")
        memory.store("k2", "v2", "Agent2")
        memory.store("k3", "v3", "Agent1")

        agent1_history = memory.get_history(agent_name="Agent1")

        assert len(agent1_history) == 2
        assert all(e.agent_name == "Agent1" for e in agent1_history)

    def test_has_key(self, memory):
        """has_key reports strict booleans for presence and absence."""
        memory.store("exists", "value", "Agent")

        assert memory.has_key("exists") is True
        assert memory.has_key("not_exists") is False

    def test_clear(self, memory):
        """clear wipes both the stored values and the history."""
        memory.store("k1", "v1", "Agent")
        memory.store("k2", "v2", "Agent")

        memory.clear()

        assert memory.retrieve("k1") is None
        assert memory.retrieve("k2") is None
        assert len(memory.get_history()) == 0

    def test_keys(self, memory):
        """keys() lists every stored key."""
        memory.store("a", 1, "Agent")
        memory.store("b", 2, "Agent")
        memory.store("c", 3, "Agent")

        assert set(memory.keys()) == {"a", "b", "c"}

    def test_thread_safety(self, memory):
        """Concurrent writers and a reader must complete without raising."""
        # NOTE: the original version also collected `results`, but never
        # used it; only the error list matters for this check.
        errors = []

        def writer(n):
            try:
                for i in range(100):
                    memory.store(f"key_{n}_{i}", i, f"Agent{n}")
            except Exception as e:
                errors.append(e)

        def reader():
            try:
                for _ in range(100):
                    memory.get_all()
                    memory.keys()
            except Exception as e:
                errors.append(e)

        threads = [
            threading.Thread(target=writer, args=(i,))
            for i in range(3)
        ]
        threads.append(threading.Thread(target=reader))

        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert len(errors) == 0
|