# coda/orchestrator.py
"""
Workflow Orchestrator for CoDA.
Manages the multi-agent pipeline, coordinating agent execution,
handling feedback loops, and implementing quality-driven halting.
"""
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Callable
from coda.config import Config, get_config
from coda.core.llm import GroqLLM, LLMProvider
from coda.core.memory import SharedMemory
from coda.core.base_agent import AgentContext
from coda.agents.query_analyzer import QueryAnalyzerAgent, QueryAnalysis
from coda.agents.data_processor import DataProcessorAgent, DataAnalysis
from coda.agents.viz_mapping import VizMappingAgent, VisualMapping
from coda.agents.search_agent import SearchAgent, SearchResult
from coda.agents.design_explorer import DesignExplorerAgent, DesignSpec
from coda.agents.code_generator import CodeGeneratorAgent, GeneratedCode
from coda.agents.debug_agent import DebugAgent, ExecutionResult
from coda.agents.visual_evaluator import VisualEvaluatorAgent, VisualEvaluation
logger = logging.getLogger(__name__)
@dataclass
class PipelineResult:
    """Terminal outcome of a CoDA pipeline run.

    Captures whether a usable chart was produced, where the rendered
    file lives, the final visual evaluation, how many pipeline passes
    ran, and an error message when the run failed.
    """
    success: bool
    output_file: Optional[str]
    evaluation: Optional[VisualEvaluation]
    iterations: int
    error: Optional[str] = None

    @property
    def scores(self) -> Optional[dict]:
        """Quality scores as a plain dict, or None when no evaluation exists."""
        evaluation = self.evaluation
        return evaluation.scores.model_dump() if evaluation else None
class CodaOrchestrator:
    """
    Orchestrates the CoDA multi-agent visualization pipeline.

    Coordinates agent execution in sequence, manages the shared memory,
    and implements iterative refinement through feedback loops until the
    visual evaluation passes its quality threshold or the configured
    iteration budget is exhausted.
    """

    def __init__(
        self,
        config: Optional[Config] = None,
        llm: Optional[LLMProvider] = None,
        progress_callback: Optional[Callable[[str, float], None]] = None,
    ) -> None:
        """
        Args:
            config: Pipeline configuration; defaults to the global config.
            llm: LLM provider; a GroqLLM is built from config when omitted.
            progress_callback: Optional hook called as ``(status, fraction)``
                with ``fraction`` in [0, 1].
        """
        self._config = config or get_config()
        self._llm = llm or self._create_llm()
        self._memory = SharedMemory()
        self._progress_callback = progress_callback
        self._agents = self._create_agents()

    def _create_llm(self) -> GroqLLM:
        """Create the default Groq-backed LLM from the model config."""
        return GroqLLM(
            api_key=self._config.groq_api_key,
            default_model=self._config.model.default_model,
            vision_model=self._config.model.vision_model,
            temperature=self._config.model.temperature,
            max_tokens=self._config.model.max_tokens,
            max_retries=self._config.model.max_retries,
        )

    def _create_agents(self) -> dict:
        """Initialize all agents, wired to the shared LLM and memory."""
        return {
            "query_analyzer": QueryAnalyzerAgent(self._llm, self._memory),
            "data_processor": DataProcessorAgent(self._llm, self._memory),
            "viz_mapping": VizMappingAgent(self._llm, self._memory),
            "search_agent": SearchAgent(self._llm, self._memory),
            "design_explorer": DesignExplorerAgent(self._llm, self._memory),
            "code_generator": CodeGeneratorAgent(self._llm, self._memory),
            "debug_agent": DebugAgent(
                self._llm,
                self._memory,
                timeout_seconds=self._config.execution.code_timeout_seconds,
                output_directory=self._config.execution.output_directory,
            ),
            "visual_evaluator": VisualEvaluatorAgent(
                self._llm,
                self._memory,
                min_overall_score=self._config.quality.minimum_overall_score,
            ),
        }

    def run(
        self,
        query: str,
        data_paths: list[str],
    ) -> PipelineResult:
        """
        Execute the full visualization pipeline.

        Args:
            query: Natural language visualization request
            data_paths: Paths to data files

        Returns:
            PipelineResult with output file and evaluation
        """
        logger.info("Starting CoDA pipeline for query: %s...", query[:50])
        # Drop any state left over from a previous run.
        self._memory.clear()
        validated_paths = self._validate_data_paths(data_paths)
        if not validated_paths:
            return PipelineResult(
                success=False,
                output_file=None,
                evaluation=None,
                iterations=0,
                error="No valid data files provided",
            )
        context = AgentContext(
            query=query,
            data_paths=validated_paths,
            iteration=0,
        )
        try:
            self._run_initial_pipeline(context)
        except Exception as e:
            logger.error("Initial pipeline failed: %s", e)
            return PipelineResult(
                success=False,
                output_file=None,
                evaluation=None,
                iterations=0,
                error=str(e),
            )
        max_iterations = self._config.execution.max_refinement_iterations
        return self._run_refinement_loop(context, max_iterations)

    def _validate_data_paths(self, data_paths: list[str]) -> list[str]:
        """Return only the paths that exist on disk, warning about the rest."""
        valid_paths = []
        for path in data_paths:
            if Path(path).exists():
                valid_paths.append(path)
            else:
                logger.warning("Data file not found: %s", path)
        return valid_paths

    def _run_initial_pipeline(self, context: AgentContext) -> None:
        """Run every agent once, in order, reporting progress per step.

        Raises:
            Exception: propagated unchanged from any failing agent; the
                caller (``run``) converts it into a failed PipelineResult.
        """
        # (agent key, status message, progress fraction) per pipeline stage.
        steps = [
            ("query_analyzer", "Analyzing query...", 0.1),
            ("data_processor", "Processing data...", 0.2),
            ("viz_mapping", "Mapping visualization...", 0.3),
            ("search_agent", "Searching examples...", 0.4),
            ("design_explorer", "Designing visualization...", 0.5),
            ("code_generator", "Generating code...", 0.7),
            ("debug_agent", "Executing code...", 0.85),
            ("visual_evaluator", "Evaluating output...", 0.95),
        ]
        for agent_name, status, progress in steps:
            self._report_progress(status, progress)
            agent = self._agents[agent_name]
            agent.execute(context)

    @staticmethod
    def _coerce_evaluation(raw) -> Optional[VisualEvaluation]:
        """Normalize a memory entry into a VisualEvaluation model.

        Shared memory may hold either a serialized dict or the model
        itself; falsy entries (missing/empty) normalize to None.
        """
        if not raw:
            return None
        if isinstance(raw, dict):
            return VisualEvaluation(**raw)
        return raw

    def _run_refinement_loop(
        self,
        context: AgentContext,
        max_iterations: int,
    ) -> PipelineResult:
        """Iteratively refine the visualization until quality passes.

        Returns:
            PipelineResult reflecting the latest evaluation. The
            ``iterations`` count is the number of pipeline passes actually
            completed (initial run included) — previously the fallthrough
            path always reported ``max_iterations`` even when the loop
            exited early (e.g. no evaluation in memory).
        """
        completed = 1  # the initial pipeline has already run once
        for iteration in range(max_iterations):
            eval_obj = self._coerce_evaluation(
                self._memory.retrieve("visual_evaluation")
            )
            if eval_obj is None:
                # No evaluation was produced; nothing to refine against.
                break
            if eval_obj.passes_threshold:
                logger.info("Quality threshold met at iteration %d", iteration)
                return self._create_success_result(eval_obj, iteration + 1)
            if iteration >= max_iterations - 1:
                logger.info("Max iterations reached")
                break
            logger.info("Refinement iteration %d", iteration + 1)
            context = self._create_refinement_context(context, eval_obj, iteration + 1)
            self._report_progress(f"Refining (iteration {iteration + 2})...", 0.5)
            try:
                self._run_refinement_agents(context)
            except Exception as e:
                # Best effort: keep the last good result rather than crash.
                logger.error("Refinement failed: %s", e)
                break
            completed = iteration + 2
        final_eval = self._coerce_evaluation(
            self._memory.retrieve("visual_evaluation")
        )
        return self._create_success_result(final_eval, completed)

    def _run_refinement_agents(self, context: AgentContext) -> None:
        """Re-run only the agents that participate in refinement."""
        refinement_agents = [
            "design_explorer",
            "code_generator",
            "debug_agent",
            "visual_evaluator",
        ]
        for agent_name in refinement_agents:
            agent = self._agents[agent_name]
            agent.execute(context)

    def _create_refinement_context(
        self,
        original_context: AgentContext,
        evaluation: VisualEvaluation,
        iteration: int,
    ) -> AgentContext:
        """Build the next iteration's context with condensed evaluator feedback.

        Only the top issues/fixes are forwarded to keep the feedback short.
        """
        feedback_parts = []
        if evaluation.issues:
            feedback_parts.append(f"Issues: {', '.join(evaluation.issues[:3])}")
        if evaluation.priority_fixes:
            feedback_parts.append(f"Fix: {', '.join(evaluation.priority_fixes[:2])}")
        feedback = " | ".join(feedback_parts)
        return AgentContext(
            query=original_context.query,
            data_paths=original_context.data_paths,
            iteration=iteration,
            feedback=feedback,
        )

    def _create_success_result(
        self,
        evaluation: Optional[VisualEvaluation],
        iterations: int,
    ) -> PipelineResult:
        """Build the final result; ``success`` requires the output file to exist."""
        execution_result = self._memory.retrieve("execution_result")
        output_file = None
        if execution_result:
            # Memory may hold either a serialized dict or the result object.
            if isinstance(execution_result, dict):
                output_file = execution_result.get("output_file")
            else:
                output_file = execution_result.output_file
        return PipelineResult(
            success=output_file is not None and Path(output_file).exists(),
            output_file=output_file,
            evaluation=evaluation,
            iterations=iterations,
        )

    def _report_progress(self, status: str, progress: float) -> None:
        """Forward progress to the callback (if any) and log it."""
        if self._progress_callback:
            self._progress_callback(status, progress)
        logger.info("[%.0f%%] %s", progress * 100, status)

    def get_memory_state(self) -> dict:
        """Get the current state of shared memory for debugging."""
        return self._memory.get_all()