| """ | |
| Workflow Orchestrator for CoDA. | |
| Manages the multi-agent pipeline, coordinating agent execution, | |
| handling feedback loops, and implementing quality-driven halting. | |
| """ | |
| import logging | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Optional, Callable | |
| from coda.config import Config, get_config | |
| from coda.core.llm import GroqLLM, LLMProvider | |
| from coda.core.memory import SharedMemory | |
| from coda.core.base_agent import AgentContext | |
| from coda.agents.query_analyzer import QueryAnalyzerAgent, QueryAnalysis | |
| from coda.agents.data_processor import DataProcessorAgent, DataAnalysis | |
| from coda.agents.viz_mapping import VizMappingAgent, VisualMapping | |
| from coda.agents.search_agent import SearchAgent, SearchResult | |
| from coda.agents.design_explorer import DesignExplorerAgent, DesignSpec | |
| from coda.agents.code_generator import CodeGeneratorAgent, GeneratedCode | |
| from coda.agents.debug_agent import DebugAgent, ExecutionResult | |
| from coda.agents.visual_evaluator import VisualEvaluatorAgent, VisualEvaluation | |
| logger = logging.getLogger(__name__) | |


@dataclass
class PipelineResult:
    """Final result from the CoDA pipeline."""

    success: bool
    output_file: Optional[str]
    evaluation: Optional[VisualEvaluation]
    iterations: int
    error: Optional[str] = None

    def scores(self) -> Optional[dict]:
        """Get quality scores if evaluation exists."""
        if self.evaluation:
            return self.evaluation.scores.model_dump()
        return None


class CodaOrchestrator:
    """
    Orchestrates the CoDA multi-agent visualization pipeline.

    Coordinates agent execution in sequence, manages the shared memory,
    and implements iterative refinement through feedback loops.
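
    A minimal construction sketch (the callback shown is illustrative,
    not part of the API):

        orchestrator = CodaOrchestrator(
            progress_callback=lambda status, frac: print(f"{frac:.0%} {status}"),
        )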
| """ | |
| def __init__( | |
| self, | |
| config: Optional[Config] = None, | |
| llm: Optional[LLMProvider] = None, | |
| progress_callback: Optional[Callable[[str, float], None]] = None, | |
| ) -> None: | |
| self._config = config or get_config() | |
| self._llm = llm or self._create_llm() | |
| self._memory = SharedMemory() | |
| self._progress_callback = progress_callback | |
| self._agents = self._create_agents() | |
| def _create_llm(self) -> GroqLLM: | |
| """Create the LLM instance.""" | |
| return GroqLLM( | |
| api_key=self._config.groq_api_key, | |
| default_model=self._config.model.default_model, | |
| vision_model=self._config.model.vision_model, | |
| temperature=self._config.model.temperature, | |
| max_tokens=self._config.model.max_tokens, | |
| max_retries=self._config.model.max_retries, | |
| ) | |

    def _create_agents(self) -> dict:
        """Initialize all agents with shared resources."""
        return {
            "query_analyzer": QueryAnalyzerAgent(self._llm, self._memory),
            "data_processor": DataProcessorAgent(self._llm, self._memory),
            "viz_mapping": VizMappingAgent(self._llm, self._memory),
            "search_agent": SearchAgent(self._llm, self._memory),
            "design_explorer": DesignExplorerAgent(self._llm, self._memory),
            "code_generator": CodeGeneratorAgent(self._llm, self._memory),
            "debug_agent": DebugAgent(
                self._llm,
                self._memory,
                timeout_seconds=self._config.execution.code_timeout_seconds,
                output_directory=self._config.execution.output_directory,
            ),
            "visual_evaluator": VisualEvaluatorAgent(
                self._llm,
                self._memory,
                min_overall_score=self._config.quality.minimum_overall_score,
            ),
        }

    def run(
        self,
        query: str,
        data_paths: list[str],
    ) -> PipelineResult:
        """
        Execute the full visualization pipeline.

        Args:
            query: Natural language visualization request
            data_paths: Paths to data files

        Returns:
            PipelineResult with output file and evaluation
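
        Example (query and path are hypothetical):
            result = orchestrator.run(
                "Plot monthly revenue as a bar chart",
                ["data/revenue.csv"],
            )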
| """ | |
| logger.info(f"Starting CoDA pipeline for query: {query[:50]}...") | |
| self._memory.clear() | |
| validated_paths = self._validate_data_paths(data_paths) | |
| if not validated_paths: | |
| return PipelineResult( | |
| success=False, | |
| output_file=None, | |
| evaluation=None, | |
| iterations=0, | |
| error="No valid data files provided", | |
| ) | |
| context = AgentContext( | |
| query=query, | |
| data_paths=validated_paths, | |
| iteration=0, | |
| ) | |
| try: | |
| self._run_initial_pipeline(context) | |
| except Exception as e: | |
| logger.error(f"Initial pipeline failed: {e}") | |
| return PipelineResult( | |
| success=False, | |
| output_file=None, | |
| evaluation=None, | |
| iterations=0, | |
| error=str(e), | |
| ) | |
| max_iterations = self._config.execution.max_refinement_iterations | |
| final_result = self._run_refinement_loop(context, max_iterations) | |
| return final_result | |

    def _validate_data_paths(self, data_paths: list[str]) -> list[str]:
        """Validate that data files exist."""
        valid_paths = []
        for path in data_paths:
            if Path(path).exists():
                valid_paths.append(path)
            else:
                logger.warning(f"Data file not found: {path}")
        return valid_paths

    def _run_initial_pipeline(self, context: AgentContext) -> None:
        """Run the initial agent pipeline."""
        steps = [
            ("query_analyzer", "Analyzing query...", 0.1),
            ("data_processor", "Processing data...", 0.2),
            ("viz_mapping", "Mapping visualization...", 0.3),
            ("search_agent", "Searching examples...", 0.4),
            ("design_explorer", "Designing visualization...", 0.5),
            ("code_generator", "Generating code...", 0.7),
            ("debug_agent", "Executing code...", 0.85),
            ("visual_evaluator", "Evaluating output...", 0.95),
        ]
        for agent_name, status, progress in steps:
            self._report_progress(status, progress)
            agent = self._agents[agent_name]
            agent.execute(context)

    def _run_refinement_loop(
        self,
        context: AgentContext,
        max_iterations: int,
    ) -> PipelineResult:
        """Run the iterative refinement loop."""
        for iteration in range(max_iterations):
            evaluation = self._memory.retrieve("visual_evaluation")
            if not evaluation:
                break

            # Shared memory may hold either a serialized dict or a
            # VisualEvaluation instance; normalize to the model.
            if isinstance(evaluation, dict):
                passes = evaluation.get("passes_threshold", False)
                eval_obj = VisualEvaluation(**evaluation)
            else:
                passes = evaluation.passes_threshold
                eval_obj = evaluation

            if passes:
                logger.info(f"Quality threshold met at iteration {iteration}")
                return self._create_success_result(eval_obj, iteration + 1)

            if iteration >= max_iterations - 1:
                logger.info("Max iterations reached")
                break

            logger.info(f"Refinement iteration {iteration + 1}")
            context = self._create_refinement_context(context, eval_obj, iteration + 1)
            self._report_progress(f"Refining (iteration {iteration + 2})...", 0.5)
            try:
                self._run_refinement_agents(context)
            except Exception as e:
                logger.error(f"Refinement failed: {e}")
                break

        final_eval = self._memory.retrieve("visual_evaluation")
        if isinstance(final_eval, dict):
            final_eval = VisualEvaluation(**final_eval)
        return self._create_success_result(final_eval, max_iterations)

    def _run_refinement_agents(self, context: AgentContext) -> None:
        """Run agents that participate in refinement."""
        refinement_agents = [
            "design_explorer",
            "code_generator",
            "debug_agent",
            "visual_evaluator",
        ]
        for agent_name in refinement_agents:
            agent = self._agents[agent_name]
            agent.execute(context)

    def _create_refinement_context(
        self,
        original_context: AgentContext,
        evaluation: VisualEvaluation,
        iteration: int,
    ) -> AgentContext:
        """Create context for refinement iteration."""
        # Condense the evaluation into a compact feedback string of the
        # form "Issues: a, b, c | Fix: x, y", capped to the top items.
        feedback_parts = []
        if evaluation.issues:
            feedback_parts.append(f"Issues: {', '.join(evaluation.issues[:3])}")
        if evaluation.priority_fixes:
            feedback_parts.append(f"Fix: {', '.join(evaluation.priority_fixes[:2])}")
        feedback = " | ".join(feedback_parts)

        return AgentContext(
            query=original_context.query,
            data_paths=original_context.data_paths,
            iteration=iteration,
            feedback=feedback,
        )

    def _create_success_result(
        self,
        evaluation: Optional[VisualEvaluation],
        iterations: int,
    ) -> PipelineResult:
        """Create a successful pipeline result."""
        execution_result = self._memory.retrieve("execution_result")
        output_file = None
        if execution_result:
            if isinstance(execution_result, dict):
                output_file = execution_result.get("output_file")
            else:
                output_file = execution_result.output_file

        return PipelineResult(
            success=output_file is not None and Path(output_file).exists(),
            output_file=output_file,
            evaluation=evaluation,
            iterations=iterations,
        )

    def _report_progress(self, status: str, progress: float) -> None:
        """Report progress to callback if set."""
        if self._progress_callback:
            self._progress_callback(status, progress)
        logger.info(f"[{progress:.0%}] {status}")

    def get_memory_state(self) -> dict:
        """Get the current state of shared memory for debugging."""
        return self._memory.get_all()
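

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the pipeline itself. The
    # query and CSV path below are hypothetical placeholders, and running
    # this assumes the Groq credentials that get_config() expects are
    # present in the environment.
    logging.basicConfig(level=logging.INFO)

    orchestrator = CodaOrchestrator(
        progress_callback=lambda status, frac: print(f"{frac:.0%} {status}"),
    )
    result = orchestrator.run(
        query="Visualize the monthly trend of total sales",
        data_paths=["sales.csv"],
    )
    if result.success:
        print(f"Chart written to {result.output_file} in {result.iterations} iteration(s)")
        print("Scores:", result.scores())
    else:
        print("Pipeline failed:", result.error)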