"""
Base agent interface for CoDA.

Defines the contract that all specialized agents must implement,
providing common functionality for LLM interaction and memory access.
"""

import json
import logging
import re
from abc import ABC, abstractmethod
from typing import Any, Optional, TypeVar, Generic

from pydantic import BaseModel

from coda.core.llm import LLMProvider
from coda.core.memory import SharedMemory

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=BaseModel)

# Content of the first markdown code fence (optionally tagged ``json``).
_CODE_FENCE_RE = re.compile(r"```(?:json)?\s*([\s\S]*?)```")
# Greedy span from the first "{" to the last "}" in a text.
_JSON_OBJECT_RE = re.compile(r"(\{[\s\S]*\})")
# Control characters that are dropped outright; \t, \n and \r are kept
# because they are legal whitespace between JSON tokens.
_INVALID_CONTROL_CHARS_RE = re.compile(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]")
# Short JSON escapes used when repairing raw control characters in strings.
_STRING_ESCAPES = {"\n": "\\n", "\r": "\\r", "\t": "\\t"}


class AgentContext(BaseModel):
    """Context passed to an agent during execution."""

    # The query the agent is asked to work on.
    query: str
    # Paths of the data files relevant to the query. A mutable default is
    # fine on a pydantic model: pydantic copies field defaults per instance.
    data_paths: list[str] = []
    # Iteration counter; defaults to 0.
    iteration: int = 0
    # Optional feedback text; None when there is none.
    feedback: Optional[str] = None


class BaseAgent(ABC, Generic[T]):
    """
    Abstract base class for all CoDA agents.

    Each agent specializes in a specific aspect of the visualization
    pipeline. Agents communicate through shared memory and use an LLM
    for reasoning.

    Subclasses implement the four abstract hooks (``_build_prompt``,
    ``_get_system_prompt``, ``_parse_response``, ``_get_output_key``);
    ``execute`` wires them together.
    """

    def __init__(
        self,
        llm: LLMProvider,
        memory: SharedMemory,
        name: Optional[str] = None,
    ) -> None:
        """
        Initialize the agent.

        Args:
            llm: Provider used to run completions.
            memory: Shared memory used to store and retrieve results.
            name: Display name; defaults to the concrete class name.
        """
        self._llm = llm
        self._memory = memory
        self._name = name or self.__class__.__name__

    @property
    def name(self) -> str:
        """Get the agent's name."""
        return self._name

    def execute(self, context: AgentContext) -> T:
        """
        Execute the agent's task.

        Builds the prompt, calls the LLM, parses the response and stores
        the parsed result in shared memory before returning it.

        Args:
            context: The execution context containing query and data info

        Returns:
            The agent's structured output
        """
        logger.info("[%s] Starting execution", self._name)

        prompt = self._build_prompt(context)
        system_prompt = self._get_system_prompt()

        response = self._llm.complete(
            prompt=prompt,
            system_prompt=system_prompt,
        )

        result = self._parse_response(response.content)
        self._store_result(result)

        logger.info("[%s] Execution complete", self._name)
        return result

    @abstractmethod
    def _build_prompt(self, context: AgentContext) -> str:
        """
        Build the prompt for the LLM.

        Args:
            context: The execution context

        Returns:
            The formatted prompt string
        """
        pass

    @abstractmethod
    def _get_system_prompt(self) -> str:
        """
        Get the system prompt defining the agent's persona.

        Returns:
            The system prompt string
        """
        pass

    @abstractmethod
    def _parse_response(self, response: str) -> T:
        """
        Parse the LLM response into a structured output.

        Args:
            response: The raw LLM response

        Returns:
            The parsed and validated output
        """
        pass

    @abstractmethod
    def _get_output_key(self) -> str:
        """
        Get the key used to store this agent's output in memory.

        Returns:
            The memory key string
        """
        pass

    def _store_result(self, result: T) -> None:
        """Store the agent's result in shared memory under its output key."""
        self._memory.store(
            key=self._get_output_key(),
            value=result.model_dump(),
            agent_name=self._name,
        )

    def _get_from_memory(self, key: str) -> Optional[Any]:
        """Retrieve a value from shared memory."""
        return self._memory.retrieve(key)

    def _extract_json(self, text: str) -> dict[str, Any]:
        """
        Extract JSON from LLM response text.

        Handles responses where JSON is wrapped in markdown code blocks and
        sanitizes control characters that can break JSON parsing. The
        content of the first code fence is tried first; if that cannot be
        parsed, the whole response is tried as well, which rescues payloads
        whose string values themselves contain a code fence.

        Args:
            text: The raw LLM response.

        Returns:
            The parsed JSON value. A response that parses directly is
            returned as-is, without checking that its top level is an object.

        Raises:
            ValueError: If no candidate can be parsed, chained to the last
                ``json.JSONDecodeError`` encountered.
        """
        candidates = []
        fenced = _CODE_FENCE_RE.search(text)
        if fenced:
            candidates.append(fenced.group(1).strip())
        # Fallback: the fence regex is non-greedy, so a "```" inside a JSON
        # string truncates the fenced candidate. The full text still holds
        # the complete object.
        candidates.append(text.strip())

        last_error: Optional[json.JSONDecodeError] = None
        for candidate in candidates:
            try:
                return self._parse_json_candidate(candidate)
            except json.JSONDecodeError as exc:
                last_error = exc

        logger.error("Failed to parse JSON after sanitization attempts")
        logger.debug("Raw text: %s...", candidates[0][:500])
        raise ValueError(
            "Invalid JSON in response: Could not parse after sanitization"
        ) from last_error

    def _parse_json_candidate(self, candidate: str) -> Any:
        """
        Parse one candidate text, applying progressively stronger repairs.

        Order of attempts: the text as-is; the outermost ``{...}`` span with
        invalid control characters removed; that span with raw control
        characters inside strings escaped.

        Args:
            candidate: Text expected to contain a JSON document.

        Returns:
            The parsed JSON value.

        Raises:
            json.JSONDecodeError: If every attempt fails.
        """
        try:
            return json.loads(candidate)
        except json.JSONDecodeError as exc:
            direct_error = exc

        obj_match = _JSON_OBJECT_RE.search(candidate)
        if obj_match is None:
            # Nothing that looks like an object; report the original failure.
            raise direct_error

        # Only remove NUL and other truly invalid chars, NOT the newlines
        # and tabs that may legitimately sit between key:value pairs.
        json_text = _INVALID_CONTROL_CHARS_RE.sub("", obj_match.group(1))
        try:
            return json.loads(json_text)
        except json.JSONDecodeError:
            pass

        # Last resort: escape raw newlines/tabs that appear inside strings.
        return json.loads(self._fix_json_strings(json_text))

    def _fix_json_strings(self, text: str) -> str:
        """
        Fix unescaped newlines and control characters inside JSON strings.

        Scans the text once, tracking whether the cursor is inside a
        double-quoted string. Raw control characters (below U+0020) found
        inside a string are replaced by their JSON escape; everything else,
        including whitespace between tokens, is copied unchanged.

        Args:
            text: JSON-like text that may hold raw control characters
                inside string values.

        Returns:
            The text with in-string control characters escaped.
        """
        result = []
        in_string = False
        escape_next = False

        for char in text:
            if escape_next:
                # Character following a backslash: copy verbatim so an
                # escaped quote does not toggle the string state.
                result.append(char)
                escape_next = False
            elif char == "\\":
                result.append(char)
                escape_next = True
            elif char == '"':
                in_string = not in_string
                result.append(char)
            elif in_string and char < " ":
                # \n, \r and \t get their short escapes; any other control
                # character is written as \u00XX so the string stays valid.
                result.append(
                    _STRING_ESCAPES.get(char) or f"\\u{ord(char):04x}"
                )
            else:
                result.append(char)

        return "".join(result)