# scrapeRL/backend/app/agents/coordinator.py
# Last commit by NeerajCodz (bfe0e24):
#   fix: replace deprecated datetime.utcnow with timezone-aware
"""Agent coordinator for orchestrating multiple agents with message passing."""
import asyncio
from datetime import datetime, timezone
from enum import Enum
from typing import Any
from app.core.action import Action, ActionType
from app.core.observation import Observation
from .base import BaseAgent
from .extractor import ExtractorAgent
from .memory_agent import MemoryAgent
from .navigator import NavigatorAgent
from .planner import PlannerAgent
from .verifier import VerifierAgent
class AgentRole(str, Enum):
    """Closed set of roles an agent can take on inside the coordinator.

    Members also subclass ``str``, so each one compares equal to its
    lowercase string value (e.g. ``AgentRole.PLANNER == "planner"``).
    """

    # Produces plans (high-level sequences of actions).
    PLANNER = "planner"
    # Handles navigation.
    NAVIGATOR = "navigator"
    # Handles field extraction.
    EXTRACTOR = "extractor"
    # Handles verification of extracted fields.
    VERIFIER = "verifier"
    # Memory agent.
    MEMORY = "memory"
class Message:
    """Envelope for a single piece of inter-agent communication."""

    def __init__(
        self,
        sender: str,
        recipient: str,
        message_type: str,
        content: dict[str, Any],
        priority: int = 0,
    ):
        """
        Build a message and stamp it with the current UTC time.

        Args:
            sender: Identifier of the originating agent.
            recipient: Identifier of the destination agent.
            message_type: Free-form label describing the message.
            content: Payload carried by the message (stored by reference).
            priority: Ordering hint; defaults to 0.
        """
        self.sender, self.recipient = sender, recipient
        self.message_type, self.content = message_type, content
        self.priority = priority
        # Timezone-aware creation time, always expressed in UTC.
        self.timestamp = datetime.now(timezone.utc)

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize the message to a plain dictionary.

        Returns:
            A dict with the keys ``sender``, ``recipient``, ``message_type``,
            ``content``, ``priority`` and ``timestamp`` (ISO-8601 string).
        """
        plain_fields = ("sender", "recipient", "message_type", "content", "priority")
        payload: dict[str, Any] = {name: getattr(self, name) for name in plain_fields}
        payload["timestamp"] = self.timestamp.isoformat()
        return payload
class AgentCoordinator:
    """
    Orchestrator for multiple specialized agents.

    The AgentCoordinator manages:
    - Agent lifecycle and initialization
    - Message passing between agents
    - Action selection and routing
    - Coordination of multi-agent workflows
    - Error handling and recovery
    """

    def __init__(
        self,
        config: dict[str, Any] | None = None,
    ):
        """
        Initialize the AgentCoordinator.

        Args:
            config: Optional configuration with keys:
                - enable_parallel: Allow parallel agent execution (default: False)
                - max_messages_per_step: Max messages per step (default: 10)
                - default_timeout: Default timeout for agent actions (default: 30)
                - planner_config / navigator_config / extractor_config /
                  verifier_config / memory_config: passed through as the
                  ``config`` argument of the corresponding default agent.
        """
        self.config = config or {}
        self.enable_parallel = self.config.get("enable_parallel", False)
        self.max_messages_per_step = self.config.get("max_messages_per_step", 10)
        # NOTE(review): stored for consumers of the coordinator; nothing in
        # this class currently enforces the timeout on agent calls.
        self.default_timeout = self.config.get("default_timeout", 30)

        # Coordinator state
        self._agents: dict[str, BaseAgent] = {}
        self._message_queue: list[Message] = []
        self._action_history: list[tuple[str, Action]] = []
        self._current_lead: str | None = None

        # Initialize default agents
        self._initialize_default_agents()

    def _initialize_default_agents(self) -> None:
        """Create one default agent per role, replacing any registered agents."""
        self._agents = {
            AgentRole.PLANNER: PlannerAgent(
                agent_id="planner",
                config=self.config.get("planner_config"),
            ),
            AgentRole.NAVIGATOR: NavigatorAgent(
                agent_id="navigator",
                config=self.config.get("navigator_config"),
            ),
            AgentRole.EXTRACTOR: ExtractorAgent(
                agent_id="extractor",
                config=self.config.get("extractor_config"),
            ),
            AgentRole.VERIFIER: VerifierAgent(
                agent_id="verifier",
                config=self.config.get("verifier_config"),
            ),
            AgentRole.MEMORY: MemoryAgent(
                agent_id="memory",
                config=self.config.get("memory_config"),
            ),
        }

    def register_agent(self, role: str, agent: BaseAgent) -> None:
        """
        Register an agent for a specific role.

        Any agent previously registered under the same role is replaced.

        Args:
            role: The role this agent fulfills.
            agent: The agent instance.
        """
        self._agents[role] = agent

    def get_agent(self, role: str) -> BaseAgent | None:
        """
        Get an agent by role.

        Args:
            role: The role to look up.

        Returns:
            The agent if found, None otherwise.
        """
        return self._agents.get(role)

    async def step(self, observation: Observation) -> Action:
        """
        Perform one coordination step.

        Delivers pending messages, picks the lead agent for the current
        state, asks it for an action, records that action, and queues a
        message if the action is a SEND_MESSAGE.

        Args:
            observation: The current state observation.

        Returns:
            The action chosen by the lead agent, or a FAIL action created by
            the coordinator when no agent is registered for the lead role or
            when any exception is raised during the step.
        """
        try:
            # Process pending messages
            await self._process_messages()

            # Determine lead agent based on state
            lead_role = self._determine_lead_agent(observation)
            self._current_lead = lead_role

            lead_agent = self._agents.get(lead_role)
            # Explicit None check: an agent object must not be treated as
            # missing merely because it evaluates as falsy.
            if lead_agent is None:
                return self._create_error_action(f"No agent for role: {lead_role}")

            # Get action from the lead agent and tag it with its origin
            action = await lead_agent.act(observation)
            action.agent_id = lead_agent.agent_id

            # Record action on both the coordinator and the agent
            self._action_history.append((lead_role, action))
            lead_agent.record_action(action)

            # Handle inter-agent communication actions
            if action.action_type == ActionType.SEND_MESSAGE:
                self._handle_send_message(action)

            return action
        except Exception as e:
            # Coordination boundary: failures are reported as a FAIL action
            # rather than propagated to the caller.
            return self._create_error_action(f"Coordination error: {e}")

    async def plan(self, observation: Observation) -> list[Action]:
        """
        Create a coordinated plan using multiple agents.

        The planner agent creates the high-level plan. The navigator's plan
        is prepended when the observation has no current URL.

        Args:
            observation: The current state observation.

        Returns:
            A coordinated list of actions; an empty list when no planner is
            registered; a single FAIL action when planning raises.
        """
        try:
            # Get plan from planner
            planner = self._agents.get(AgentRole.PLANNER)
            if planner is None:
                return []
            plan = await planner.plan(observation)

            # Refine with navigator for navigation steps
            navigator = self._agents.get(AgentRole.NAVIGATOR)
            if navigator is not None:
                nav_plan = await navigator.plan(observation)
                # Insert navigation at the beginning if needed
                if nav_plan and not observation.current_url:
                    plan = nav_plan + plan

            return plan
        except Exception as e:
            return [self._create_error_action(f"Planning error: {e}")]

    def send_message(
        self,
        sender: str,
        recipient: str,
        message_type: str,
        content: dict[str, Any],
        priority: int = 0,
    ) -> None:
        """
        Send a message between agents.

        The message is only queued here; delivery happens during the next
        call to ``step``.

        Args:
            sender: ID of the sending agent.
            recipient: ID of the receiving agent.
            message_type: Type of the message.
            content: Message content.
            priority: Message priority (higher = more urgent).
        """
        message = Message(
            sender=sender,
            recipient=recipient,
            message_type=message_type,
            content=content,
            priority=priority,
        )
        self._message_queue.append(message)

    def _find_recipient(self, recipient: str) -> BaseAgent | None:
        """Return the first agent whose agent_id or role equals ``recipient``."""
        for role, agent in self._agents.items():
            if agent.agent_id == recipient or role == recipient:
                return agent
        return None

    async def _process_messages(self) -> None:
        """
        Deliver queued messages to their recipients, highest priority first.

        At most ``max_messages_per_step`` messages are dequeued per call; the
        rest stay queued. A message whose recipient cannot be resolved is
        dropped.
        """
        # Sort by priority (highest first). list.sort is stable, so messages
        # with equal priority keep their arrival order.
        self._message_queue.sort(key=lambda m: -m.priority)

        messages_processed = 0
        while self._message_queue and messages_processed < self.max_messages_per_step:
            message = self._message_queue.pop(0)

            recipient = self._find_recipient(message.recipient)
            if recipient is not None:
                recipient.receive_message(message.to_dict())

            # Every dequeued message counts toward the per-step budget,
            # whether or not it could be delivered.
            messages_processed += 1

    def _determine_lead_agent(self, observation: Observation) -> str:
        """
        Determine which agent should lead based on state.

        Rules are checked in order; the first match wins:
        1. No current URL -> navigator.
        2. Unverified extracted fields and progress above 0.5 -> verifier.
        3. Fields remaining -> extractor.
        4. Otherwise -> planner.

        Args:
            observation: Current observation.

        Returns:
            The role of the agent that should lead.
        """
        # If no URL, navigator should lead
        if not observation.current_url:
            return AgentRole.NAVIGATOR

        # If there are unverified fields, verifier should lead
        has_unverified = any(not f.verified for f in observation.extracted_so_far)
        if has_unverified and observation.extraction_progress > 0.5:
            return AgentRole.VERIFIER

        # If there are remaining fields to extract, extractor should lead
        if observation.fields_remaining:
            return AgentRole.EXTRACTOR

        # Planner is both the re-planning agent after errors and the default,
        # so no separate check of the error count is needed here.
        return AgentRole.PLANNER

    def _handle_send_message(self, action: Action) -> None:
        """
        Queue the message described by a SEND_MESSAGE action.

        Reads ``target_agent``, ``message_type``, ``content`` and ``priority``
        from the action parameters, falling back to "", "generic", {} and 0
        respectively. Forwarding ``priority`` lets agent-originated messages
        take part in the priority ordering used at delivery time.
        """
        params = action.parameters
        self.send_message(
            sender=action.agent_id or "unknown",
            recipient=params.get("target_agent", ""),
            message_type=params.get("message_type", "generic"),
            content=params.get("content", {}),
            priority=params.get("priority", 0),
        )

    def _create_error_action(self, error: str) -> Action:
        """Create a FAIL action, attributed to the coordinator, for ``error``."""
        return Action(
            action_type=ActionType.FAIL,
            parameters={"success": False, "message": error},
            reasoning=error,
            confidence=1.0,
            agent_id="coordinator",
        )

    async def run_parallel_agents(
        self,
        observation: Observation,
        roles: list[str],
    ) -> dict[str, Action]:
        """
        Run multiple agents in parallel.

        When ``enable_parallel`` is False the agents run one after another
        instead.

        NOTE(review): the two modes treat unknown roles differently. The
        sequential path omits them from the result, while the parallel path
        maps them to a FAIL action. Confirm which behavior callers expect
        before unifying.

        Args:
            observation: Current observation.
            roles: List of agent roles to run.

        Returns:
            Dictionary mapping role to action.
        """
        if not self.enable_parallel:
            # Fallback to sequential
            sequential_results: dict[str, Action] = {}
            for role in roles:
                agent = self._agents.get(role)
                if agent is not None:
                    sequential_results[role] = await agent.act(observation)
            return sequential_results

        # Run agents in parallel
        async def run_agent(role: str) -> tuple[str, Action]:
            agent = self._agents.get(role)
            if agent is not None:
                action = await agent.act(observation)
                return (role, action)
            return (role, self._create_error_action(f"No agent for role: {role}"))

        # gather preserves input order; an exception from any agent propagates.
        results = await asyncio.gather(*(run_agent(role) for role in roles))
        return dict(results)

    def get_action_history(self) -> list[tuple[str, Action]]:
        """Get a shallow copy of the (role, action) history."""
        return self._action_history.copy()

    def get_current_lead(self) -> str | None:
        """Get the current lead agent role, or None before the first step."""
        return self._current_lead

    def get_message_queue_length(self) -> int:
        """Get the number of pending messages."""
        return len(self._message_queue)

    def reset(self) -> None:
        """Reset all agents and coordinator state."""
        for agent in self._agents.values():
            agent.reset()
        self._message_queue.clear()
        self._action_history.clear()
        self._current_lead = None

    def get_stats(self) -> dict[str, Any]:
        """Get coordinator statistics."""
        return {
            "agents": list(self._agents),
            "current_lead": self._current_lead,
            "pending_messages": len(self._message_queue),
            "action_count": len(self._action_history),
            "enable_parallel": self.enable_parallel,
        }