Spaces:
Sleeping
Sleeping
| """ | |
| Base agent interface for CoDA. | |
| Defines the contract that all specialized agents must implement, | |
| providing common functionality for LLM interaction and memory access. | |
| """ | |
| import json | |
| import logging | |
| import re | |
| from abc import ABC, abstractmethod | |
| from typing import Any, Optional, TypeVar, Generic | |
| from pydantic import BaseModel | |
| from coda.core.llm import LLMProvider | |
| from coda.core.memory import SharedMemory | |
# Module-level logger; BaseAgent writes its execution and JSON-parsing diagnostics here.
logger: logging.Logger = logging.getLogger(__name__)
# Type variable for an agent's structured output; bounded to pydantic BaseModel.
T = TypeVar("T", bound=BaseModel)
class AgentContext(BaseModel):
    """
    Context passed to an agent during execution.

    ``BaseAgent.execute`` receives this object and hands it to the
    subclass's ``_build_prompt`` to construct the LLM prompt.
    """

    # The request the agent is working on (required).
    query: str
    # Paths of input data; empty by default.
    # NOTE(review): pydantic is expected to copy this mutable default per
    # instance, so the shared `[]` literal should be safe — confirm for the
    # pinned pydantic version.
    data_paths: list[str] = []
    # Iteration counter, starting at 0.
    # Presumably incremented by an outer refinement loop — verify against callers.
    iteration: int = 0
    # Optional feedback text; None when absent.
    # Presumably produced by a previous iteration — verify against callers.
    feedback: Optional[str] = None
class BaseAgent(ABC, Generic[T]):
    """
    Abstract base class for all CoDA agents.

    Each agent specializes in a specific aspect of the visualization pipeline.
    Agents communicate through shared memory and use an LLM for reasoning.

    Subclasses must implement ``_build_prompt``, ``_get_system_prompt``,
    ``_parse_response`` and ``_get_output_key``. ``execute`` calls them in
    that order and stores the parsed result in shared memory.
    """

    def __init__(
        self,
        llm: LLMProvider,
        memory: SharedMemory,
        name: Optional[str] = None,
    ) -> None:
        """
        Args:
            llm: Provider whose ``complete`` method produces the LLM response.
            memory: Shared memory the agent reads from and stores its result in.
            name: Name used in log messages and as ``agent_name`` when storing
                results; defaults to the concrete class name.
        """
        self._llm = llm
        self._memory = memory
        self._name = name or self.__class__.__name__

    # NOTE(review): the docstring reads like a property getter; if callers use
    # `agent.name` without parentheses, a @property decorator may have been lost.
    # Kept as a plain method here to stay backward-compatible.
    def name(self) -> str:
        """Get the agent's name."""
        return self._name

    def execute(self, context: AgentContext) -> T:
        """
        Execute the agent's task.

        Builds the user and system prompts, queries the LLM, parses the
        response into the agent's output model and stores it in shared memory.

        Args:
            context: The execution context containing query and data info

        Returns:
            The agent's structured output
        """
        logger.info("[%s] Starting execution", self._name)
        prompt = self._build_prompt(context)
        system_prompt = self._get_system_prompt()
        response = self._llm.complete(
            prompt=prompt,
            system_prompt=system_prompt,
        )
        result = self._parse_response(response.content)
        self._store_result(result)
        logger.info("[%s] Execution complete", self._name)
        return result

    @abstractmethod
    def _build_prompt(self, context: AgentContext) -> str:
        """
        Build the prompt for the LLM.

        Args:
            context: The execution context

        Returns:
            The formatted prompt string
        """

    @abstractmethod
    def _get_system_prompt(self) -> str:
        """
        Get the system prompt defining the agent's persona.

        Returns:
            The system prompt string
        """

    @abstractmethod
    def _parse_response(self, response: str) -> T:
        """
        Parse the LLM response into a structured output.

        Args:
            response: The raw LLM response

        Returns:
            The parsed and validated output
        """

    @abstractmethod
    def _get_output_key(self) -> str:
        """
        Get the key used to store this agent's output in memory.

        Returns:
            The memory key string
        """

    def _store_result(self, result: T) -> None:
        """Store the agent's result (as ``model_dump()``) in shared memory."""
        self._memory.store(
            key=self._get_output_key(),
            value=result.model_dump(),
            agent_name=self._name,
        )

    def _get_from_memory(self, key: str) -> Optional[Any]:
        """Retrieve a value from shared memory."""
        return self._memory.retrieve(key)

    def _extract_json(self, text: str) -> dict[str, Any]:
        """
        Extract JSON from LLM response text.

        Handles responses where JSON is wrapped in markdown code blocks
        and sanitizes control characters that can break JSON parsing.
        Attempts, in order:

          1. If a fenced code block is present, use only its contents.
          2. ``json.loads`` on the stripped text.
          3. Take the span from the first ``{`` to the last ``}``, strip
             control characters except tab/newline/carriage return (plus DEL),
             and retry.
          4. Escape raw control characters inside string literals with
             ``_fix_json_strings`` and retry.

        Args:
            text: The raw LLM response.

        Returns:
            The decoded JSON value.

        Raises:
            ValueError: If every attempt fails; the last decoding error is
                chained as ``__cause__``.
        """
        fence = re.search(r"```(?:json)?\s*([\s\S]*?)```", text)
        if fence:
            text = fence.group(1)
        text = text.strip()

        try:
            return json.loads(text)
        except json.JSONDecodeError as err:
            last_error: Exception = err

        # Greedy match so nested objects between the outermost braces survive.
        obj_match = re.search(r"(\{[\s\S]*\})", text)
        if obj_match:
            # Tab, newline and CR are kept: they are legal between tokens, and
            # _fix_json_strings escapes them when they appear inside strings.
            json_text = re.sub(
                r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]", "", obj_match.group(1)
            )
            try:
                return json.loads(json_text)
            except json.JSONDecodeError as err:
                last_error = err
            try:
                return json.loads(self._fix_json_strings(json_text))
            except json.JSONDecodeError as err:
                last_error = err

        logger.error("Failed to parse JSON after sanitization attempts")
        logger.debug("Raw text: %s...", text[:500])
        raise ValueError(
            "Invalid JSON in response: Could not parse after sanitization"
        ) from last_error

    def _fix_json_strings(self, text: str) -> str:
        """
        Escape raw control characters that appear inside JSON string literals.

        Walks ``text`` once, tracking whether the cursor is inside a
        double-quoted string. Inside a string, a raw newline, carriage return
        or tab becomes its two-character escape, and any other character
        below U+0020 becomes a ``\\uXXXX`` escape. Existing escape sequences
        and everything outside strings are copied unchanged.
        """
        short_escapes = {"\n": "\\n", "\r": "\\r", "\t": "\\t"}
        out: list[str] = []
        in_string = False
        escaped = False
        for ch in text:
            if not in_string:
                # Backslashes only have escape meaning inside strings; treating
                # a stray one outside as an escape could swallow a quote and
                # invert the string tracking for the rest of the text.
                if ch == '"':
                    in_string = True
                out.append(ch)
            elif escaped:
                out.append(ch)
                escaped = False
            elif ch == "\\":
                out.append(ch)
                escaped = True
            elif ch == '"':
                in_string = False
                out.append(ch)
            elif ch in short_escapes:
                out.append(short_escapes[ch])
            elif ch < " ":
                out.append(f"\\u{ord(ch):04x}")
            else:
                out.append(ch)
        return "".join(out)