# SCoDA: coda/core/base_agent.py
# Last commit: 9281fab by vanishingradient ("Added init files")
"""
Base agent interface for CoDA.
Defines the contract that all specialized agents must implement,
providing common functionality for LLM interaction and memory access.
"""
import json
import logging
import re
from abc import ABC, abstractmethod
from collections.abc import Iterator
from typing import Any, Optional, TypeVar, Generic

from pydantic import BaseModel

from coda.core.llm import LLMProvider
from coda.core.memory import SharedMemory
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Type of an agent's structured output. Bound to pydantic models so results
# can be serialized with ``model_dump()`` before being stored in shared memory.
T = TypeVar("T", bound=BaseModel)
class AgentContext(BaseModel):
    """
    Context passed to an agent during execution.

    ``BaseAgent.execute`` receives this object and forwards it unchanged to
    the subclass's ``_build_prompt`` hook; how each field is used is decided
    by that hook.
    """

    # Required. NOTE(review): presumably the user's visualization request —
    # confirm with callers.
    query: str
    # Data file paths; defaults to an empty list. Pydantic copies mutable
    # defaults per instance, so instances do not share this list.
    data_paths: list[str] = []
    # Iteration counter, starting at 0. NOTE(review): presumably advanced by
    # a refinement loop — confirm with the orchestrator.
    iteration: int = 0
    # Optional feedback text; None when absent. NOTE(review): presumably
    # feedback from a previous iteration — verify against callers.
    feedback: Optional[str] = None
class BaseAgent(ABC, Generic[T]):
    """
    Abstract base class for all CoDA agents.

    Each agent specializes in a specific aspect of the visualization pipeline.
    Agents communicate through shared memory and use an LLM for reasoning.

    Subclasses implement four hooks: ``_build_prompt``, ``_get_system_prompt``,
    ``_parse_response`` and ``_get_output_key``. ``execute`` drives them in a
    fixed order: build prompt -> call the LLM -> parse -> store the result in
    shared memory.
    """

    def __init__(
        self,
        llm: LLMProvider,
        memory: SharedMemory,
        name: Optional[str] = None,
    ) -> None:
        """
        Args:
            llm: Provider whose ``complete`` method ``execute`` calls.
            memory: Shared store that results are written to and read from.
            name: Name used in log messages and as ``agent_name`` when storing
                results. Defaults to the concrete class name.
        """
        self._llm = llm
        self._memory = memory
        # `or` means an empty string also falls back to the class name.
        self._name = name or self.__class__.__name__

    @property
    def name(self) -> str:
        """Get the agent's name."""
        return self._name

    def execute(self, context: AgentContext) -> T:
        """
        Execute the agent's task.

        Args:
            context: The execution context containing query and data info.

        Returns:
            The agent's structured output as produced by ``_parse_response``.
            The same result is also written to shared memory under
            ``_get_output_key()`` before being returned.
        """
        logger.info("[%s] Starting execution", self._name)
        prompt = self._build_prompt(context)
        system_prompt = self._get_system_prompt()
        response = self._llm.complete(
            prompt=prompt,
            system_prompt=system_prompt,
        )
        # NOTE(review): assumes the provider's response exposes the generated
        # text as `.content` — confirm against LLMProvider.
        result = self._parse_response(response.content)
        self._store_result(result)
        logger.info("[%s] Execution complete", self._name)
        return result

    @abstractmethod
    def _build_prompt(self, context: AgentContext) -> str:
        """
        Build the prompt for the LLM.

        Args:
            context: The execution context.

        Returns:
            The formatted prompt string.
        """
        ...

    @abstractmethod
    def _get_system_prompt(self) -> str:
        """
        Get the system prompt defining the agent's persona.

        Returns:
            The system prompt string.
        """
        ...

    @abstractmethod
    def _parse_response(self, response: str) -> T:
        """
        Parse the LLM response into a structured output.

        Args:
            response: The raw LLM response text.

        Returns:
            The parsed and validated output.
        """
        ...

    @abstractmethod
    def _get_output_key(self) -> str:
        """
        Get the key used to store this agent's output in memory.

        Returns:
            The memory key string.
        """
        ...

    def _store_result(self, result: T) -> None:
        """Store ``result.model_dump()`` in shared memory under this agent's key."""
        self._memory.store(
            key=self._get_output_key(),
            value=result.model_dump(),
            agent_name=self._name,
        )

    def _get_from_memory(self, key: str) -> Optional[Any]:
        """Retrieve a value from shared memory (delegates to ``memory.retrieve``)."""
        return self._memory.retrieve(key)

    def _extract_json(self, text: str) -> dict[str, Any]:
        """
        Extract JSON from LLM response text.

        Handles responses where JSON is wrapped in a markdown code block and
        sanitizes control characters that can break JSON parsing. Candidates
        from ``_iter_json_candidates`` are parsed in order, and the first one
        that decodes is returned.

        Args:
            text: Raw LLM output.

        Returns:
            The decoded JSON value. NOTE(review): if the whole (fence-stripped)
            text is a valid non-object JSON value, such as a list or number, it
            is returned as-is despite the ``dict`` annotation.

        Raises:
            ValueError: If no candidate decodes. The last decoding error is
                chained as ``__cause__`` so its message and position survive.
        """
        # Only the first fenced block is considered; the ``json`` tag is optional.
        fence = re.search(r"```(?:json)?\s*([\s\S]*?)```", text)
        if fence:
            text = fence.group(1)
        text = text.strip()

        last_error: Optional[Exception] = None
        for candidate in self._iter_json_candidates(text):
            try:
                return json.loads(candidate)
            except (ValueError, RecursionError) as exc:
                # JSONDecodeError subclasses ValueError; RecursionError comes
                # from pathologically deep nesting. Either way, try the next,
                # more aggressively sanitized candidate.
                last_error = exc

        logger.error("Failed to parse JSON after sanitization attempts")
        logger.debug("Raw text: %s...", text[:500])
        raise ValueError(
            "Invalid JSON in response: Could not parse after sanitization"
        ) from last_error

    def _iter_json_candidates(self, text: str) -> Iterator[str]:
        """Lazily yield progressively sanitized versions of ``text`` to parse."""
        # 1. The text exactly as given.
        yield text
        # 2. The greedy span from the first '{' to the last '}', with control
        #    characters other than \t, \n and \r removed.
        obj_match = re.search(r"\{.*\}", text, re.DOTALL)
        if obj_match is None:
            return
        json_text = re.sub(
            r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]", "", obj_match.group(0)
        )
        yield json_text
        # 3. The same span with raw newlines/tabs inside string literals escaped.
        #    Computed only if the cheaper candidates failed (generator is lazy).
        yield self._fix_json_strings(json_text)

    def _fix_json_strings(self, text: str) -> str:
        """
        Escape raw newlines, carriage returns and tabs inside JSON strings.

        Walks ``text`` once, tracking whether the cursor is inside a
        double-quoted string and honouring backslash escapes. Only those three
        characters, and only inside strings, are rewritten; everything else is
        copied through unchanged.
        """
        replacements = {"\n": "\\n", "\r": "\\r", "\t": "\\t"}
        out: list[str] = []
        in_string = False
        escape_next = False
        for char in text:
            if escape_next:
                # The character after a backslash is copied verbatim, so an
                # escaped quote does not toggle `in_string`.
                out.append(char)
                escape_next = False
            elif char == "\\":
                out.append(char)
                escape_next = True
            elif char == '"':
                in_string = not in_string
                out.append(char)
            elif in_string:
                out.append(replacements.get(char, char))
            else:
                out.append(char)
        return "".join(out)