Spaces:
Sleeping
Sleeping
Commit ·
3bfb250
1
Parent(s): ab65628
feat: implement multi-agent system with coordinator
Browse files- backend/app/agents/__init__.py +38 -0
- backend/app/agents/__pycache__/__init__.cpython-314.pyc +0 -0
- backend/app/agents/__pycache__/base.cpython-314.pyc +0 -0
- backend/app/agents/__pycache__/coordinator.cpython-314.pyc +0 -0
- backend/app/agents/__pycache__/extractor.cpython-314.pyc +0 -0
- backend/app/agents/__pycache__/memory_agent.cpython-314.pyc +0 -0
- backend/app/agents/__pycache__/navigator.cpython-314.pyc +0 -0
- backend/app/agents/__pycache__/planner.cpython-314.pyc +0 -0
- backend/app/agents/__pycache__/verifier.cpython-314.pyc +0 -0
- backend/app/agents/base.py +127 -0
- backend/app/agents/coordinator.py +387 -0
- backend/app/agents/extractor.py +489 -0
- backend/app/agents/memory_agent.py +474 -0
- backend/app/agents/navigator.py +368 -0
- backend/app/agents/planner.py +242 -0
- backend/app/agents/verifier.py +468 -0
backend/app/agents/__init__.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Agents module for ScrapeRL.
|
| 3 |
+
|
| 4 |
+
This module contains specialized agents for web scraping with RL:
|
| 5 |
+
- BaseAgent: Abstract base class for all agents
|
| 6 |
+
- PlannerAgent: Goal decomposition and task planning
|
| 7 |
+
- NavigatorAgent: URL prioritization and page navigation
|
| 8 |
+
- ExtractorAgent: Data extraction with selectors
|
| 9 |
+
- VerifierAgent: Cross-source verification
|
| 10 |
+
- MemoryAgent: Memory operations and knowledge management
|
| 11 |
+
- AgentCoordinator: Orchestrates multiple agents with message passing
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from .base import BaseAgent
|
| 15 |
+
from .coordinator import AgentCoordinator, AgentRole, Message
|
| 16 |
+
from .extractor import ExtractorAgent
|
| 17 |
+
from .memory_agent import MemoryAgent, MemoryEntry
|
| 18 |
+
from .navigator import NavigatorAgent
|
| 19 |
+
from .planner import PlannerAgent
|
| 20 |
+
from .verifier import VerificationResult, VerifierAgent
|
| 21 |
+
|
| 22 |
+
__all__ = [
|
| 23 |
+
# Base
|
| 24 |
+
"BaseAgent",
|
| 25 |
+
# Agents
|
| 26 |
+
"PlannerAgent",
|
| 27 |
+
"NavigatorAgent",
|
| 28 |
+
"ExtractorAgent",
|
| 29 |
+
"VerifierAgent",
|
| 30 |
+
"MemoryAgent",
|
| 31 |
+
# Coordinator
|
| 32 |
+
"AgentCoordinator",
|
| 33 |
+
"AgentRole",
|
| 34 |
+
"Message",
|
| 35 |
+
# Data classes
|
| 36 |
+
"VerificationResult",
|
| 37 |
+
"MemoryEntry",
|
| 38 |
+
]
|
backend/app/agents/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (1.16 kB). View file
|
|
|
backend/app/agents/__pycache__/base.cpython-314.pyc
ADDED
|
Binary file (6.75 kB). View file
|
|
|
backend/app/agents/__pycache__/coordinator.cpython-314.pyc
ADDED
|
Binary file (19.6 kB). View file
|
|
|
backend/app/agents/__pycache__/extractor.cpython-314.pyc
ADDED
|
Binary file (18 kB). View file
|
|
|
backend/app/agents/__pycache__/memory_agent.cpython-314.pyc
ADDED
|
Binary file (20.6 kB). View file
|
|
|
backend/app/agents/__pycache__/navigator.cpython-314.pyc
ADDED
|
Binary file (16.8 kB). View file
|
|
|
backend/app/agents/__pycache__/planner.cpython-314.pyc
ADDED
|
Binary file (11.5 kB). View file
|
|
|
backend/app/agents/__pycache__/verifier.cpython-314.pyc
ADDED
|
Binary file (19.2 kB). View file
|
|
|
backend/app/agents/base.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Base agent abstract class for ScrapeRL agents."""
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
from app.core.action import Action
|
| 7 |
+
from app.core.observation import Observation
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BaseAgent(ABC):
|
| 11 |
+
"""
|
| 12 |
+
Abstract base class for all agents in the ScrapeRL system.
|
| 13 |
+
|
| 14 |
+
Each agent specializes in a specific aspect of the scraping workflow:
|
| 15 |
+
- Planning and goal decomposition
|
| 16 |
+
- Navigation and URL prioritization
|
| 17 |
+
- Data extraction
|
| 18 |
+
- Verification and validation
|
| 19 |
+
- Memory operations
|
| 20 |
+
|
| 21 |
+
Agents communicate through message passing and coordinate via
|
| 22 |
+
the AgentCoordinator.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(self, agent_id: str, config: dict[str, Any] | None = None):
|
| 26 |
+
"""
|
| 27 |
+
Initialize the agent.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
agent_id: Unique identifier for this agent instance.
|
| 31 |
+
config: Optional configuration dictionary for the agent.
|
| 32 |
+
"""
|
| 33 |
+
self.agent_id = agent_id
|
| 34 |
+
self.config = config or {}
|
| 35 |
+
self._message_queue: list[dict[str, Any]] = []
|
| 36 |
+
self._action_history: list[Action] = []
|
| 37 |
+
|
| 38 |
+
@abstractmethod
|
| 39 |
+
async def act(self, observation: Observation) -> Action:
|
| 40 |
+
"""
|
| 41 |
+
Select an action based on the current observation.
|
| 42 |
+
|
| 43 |
+
This is the main decision-making method. The agent analyzes
|
| 44 |
+
the observation and returns the best action to take.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
observation: The current state observation from the environment.
|
| 48 |
+
|
| 49 |
+
Returns:
|
| 50 |
+
The action to execute.
|
| 51 |
+
"""
|
| 52 |
+
pass
|
| 53 |
+
|
| 54 |
+
@abstractmethod
|
| 55 |
+
async def plan(self, observation: Observation) -> list[Action]:
|
| 56 |
+
"""
|
| 57 |
+
Create a plan of actions based on the current observation.
|
| 58 |
+
|
| 59 |
+
Unlike act() which returns a single action, plan() creates
|
| 60 |
+
a sequence of actions to achieve a goal.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
observation: The current state observation from the environment.
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
A list of planned actions in execution order.
|
| 67 |
+
"""
|
| 68 |
+
pass
|
| 69 |
+
|
| 70 |
+
async def explain(self, action: Action) -> str:
|
| 71 |
+
"""
|
| 72 |
+
Explain why this action was chosen.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
action: The action to explain.
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
A human-readable explanation of the action choice.
|
| 79 |
+
"""
|
| 80 |
+
return action.reasoning or "No explanation provided"
|
| 81 |
+
|
| 82 |
+
def receive_message(self, message: dict[str, Any]) -> None:
|
| 83 |
+
"""
|
| 84 |
+
Receive a message from another agent.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
message: The message dictionary containing sender, type, and content.
|
| 88 |
+
"""
|
| 89 |
+
self._message_queue.append(message)
|
| 90 |
+
|
| 91 |
+
def get_pending_messages(self) -> list[dict[str, Any]]:
|
| 92 |
+
"""
|
| 93 |
+
Get all pending messages and clear the queue.
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
List of pending messages.
|
| 97 |
+
"""
|
| 98 |
+
messages = self._message_queue.copy()
|
| 99 |
+
self._message_queue.clear()
|
| 100 |
+
return messages
|
| 101 |
+
|
| 102 |
+
def record_action(self, action: Action) -> None:
|
| 103 |
+
"""
|
| 104 |
+
Record an action in the agent's history.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
action: The action that was executed.
|
| 108 |
+
"""
|
| 109 |
+
self._action_history.append(action)
|
| 110 |
+
|
| 111 |
+
def get_action_history(self) -> list[Action]:
|
| 112 |
+
"""
|
| 113 |
+
Get the history of actions taken by this agent.
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
List of past actions.
|
| 117 |
+
"""
|
| 118 |
+
return self._action_history.copy()
|
| 119 |
+
|
| 120 |
+
def reset(self) -> None:
|
| 121 |
+
"""Reset the agent state for a new episode."""
|
| 122 |
+
self._message_queue.clear()
|
| 123 |
+
self._action_history.clear()
|
| 124 |
+
|
| 125 |
+
def __repr__(self) -> str:
|
| 126 |
+
"""String representation of the agent."""
|
| 127 |
+
return f"{self.__class__.__name__}(agent_id={self.agent_id!r})"
|
backend/app/agents/coordinator.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Agent coordinator for orchestrating multiple agents with message passing."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from enum import Enum
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
from app.core.action import Action, ActionType
|
| 9 |
+
from app.core.observation import Observation
|
| 10 |
+
|
| 11 |
+
from .base import BaseAgent
|
| 12 |
+
from .extractor import ExtractorAgent
|
| 13 |
+
from .memory_agent import MemoryAgent
|
| 14 |
+
from .navigator import NavigatorAgent
|
| 15 |
+
from .planner import PlannerAgent
|
| 16 |
+
from .verifier import VerifierAgent
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class AgentRole(str, Enum):
|
| 20 |
+
"""Roles that agents can fulfill."""
|
| 21 |
+
|
| 22 |
+
PLANNER = "planner"
|
| 23 |
+
NAVIGATOR = "navigator"
|
| 24 |
+
EXTRACTOR = "extractor"
|
| 25 |
+
VERIFIER = "verifier"
|
| 26 |
+
MEMORY = "memory"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class Message:
|
| 30 |
+
"""A message between agents."""
|
| 31 |
+
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
sender: str,
|
| 35 |
+
recipient: str,
|
| 36 |
+
message_type: str,
|
| 37 |
+
content: dict[str, Any],
|
| 38 |
+
priority: int = 0,
|
| 39 |
+
):
|
| 40 |
+
"""Initialize a message."""
|
| 41 |
+
self.sender = sender
|
| 42 |
+
self.recipient = recipient
|
| 43 |
+
self.message_type = message_type
|
| 44 |
+
self.content = content
|
| 45 |
+
self.priority = priority
|
| 46 |
+
self.timestamp = datetime.utcnow()
|
| 47 |
+
|
| 48 |
+
def to_dict(self) -> dict[str, Any]:
|
| 49 |
+
"""Convert to dictionary."""
|
| 50 |
+
return {
|
| 51 |
+
"sender": self.sender,
|
| 52 |
+
"recipient": self.recipient,
|
| 53 |
+
"message_type": self.message_type,
|
| 54 |
+
"content": self.content,
|
| 55 |
+
"priority": self.priority,
|
| 56 |
+
"timestamp": self.timestamp.isoformat(),
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class AgentCoordinator:
|
| 61 |
+
"""
|
| 62 |
+
Orchestrator for multiple specialized agents.
|
| 63 |
+
|
| 64 |
+
The AgentCoordinator manages:
|
| 65 |
+
- Agent lifecycle and initialization
|
| 66 |
+
- Message passing between agents
|
| 67 |
+
- Action selection and routing
|
| 68 |
+
- Coordination of multi-agent workflows
|
| 69 |
+
- Error handling and recovery
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
def __init__(
|
| 73 |
+
self,
|
| 74 |
+
config: dict[str, Any] | None = None,
|
| 75 |
+
):
|
| 76 |
+
"""
|
| 77 |
+
Initialize the AgentCoordinator.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
config: Optional configuration with keys:
|
| 81 |
+
- enable_parallel: Allow parallel agent execution (default: False)
|
| 82 |
+
- max_messages_per_step: Max messages per step (default: 10)
|
| 83 |
+
- default_timeout: Default timeout for agent actions (default: 30)
|
| 84 |
+
"""
|
| 85 |
+
self.config = config or {}
|
| 86 |
+
self.enable_parallel = self.config.get("enable_parallel", False)
|
| 87 |
+
self.max_messages_per_step = self.config.get("max_messages_per_step", 10)
|
| 88 |
+
self.default_timeout = self.config.get("default_timeout", 30)
|
| 89 |
+
|
| 90 |
+
# Initialize agents
|
| 91 |
+
self._agents: dict[str, BaseAgent] = {}
|
| 92 |
+
self._message_queue: list[Message] = []
|
| 93 |
+
self._action_history: list[tuple[str, Action]] = []
|
| 94 |
+
self._current_lead: str | None = None
|
| 95 |
+
|
| 96 |
+
# Initialize default agents
|
| 97 |
+
self._initialize_default_agents()
|
| 98 |
+
|
| 99 |
+
def _initialize_default_agents(self) -> None:
|
| 100 |
+
"""Initialize the default set of agents."""
|
| 101 |
+
self._agents = {
|
| 102 |
+
AgentRole.PLANNER: PlannerAgent(
|
| 103 |
+
agent_id="planner",
|
| 104 |
+
config=self.config.get("planner_config"),
|
| 105 |
+
),
|
| 106 |
+
AgentRole.NAVIGATOR: NavigatorAgent(
|
| 107 |
+
agent_id="navigator",
|
| 108 |
+
config=self.config.get("navigator_config"),
|
| 109 |
+
),
|
| 110 |
+
AgentRole.EXTRACTOR: ExtractorAgent(
|
| 111 |
+
agent_id="extractor",
|
| 112 |
+
config=self.config.get("extractor_config"),
|
| 113 |
+
),
|
| 114 |
+
AgentRole.VERIFIER: VerifierAgent(
|
| 115 |
+
agent_id="verifier",
|
| 116 |
+
config=self.config.get("verifier_config"),
|
| 117 |
+
),
|
| 118 |
+
AgentRole.MEMORY: MemoryAgent(
|
| 119 |
+
agent_id="memory",
|
| 120 |
+
config=self.config.get("memory_config"),
|
| 121 |
+
),
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
def register_agent(self, role: str, agent: BaseAgent) -> None:
|
| 125 |
+
"""
|
| 126 |
+
Register an agent for a specific role.
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
role: The role this agent fulfills.
|
| 130 |
+
agent: The agent instance.
|
| 131 |
+
"""
|
| 132 |
+
self._agents[role] = agent
|
| 133 |
+
|
| 134 |
+
def get_agent(self, role: str) -> BaseAgent | None:
|
| 135 |
+
"""
|
| 136 |
+
Get an agent by role.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
role: The role to look up.
|
| 140 |
+
|
| 141 |
+
Returns:
|
| 142 |
+
The agent if found, None otherwise.
|
| 143 |
+
"""
|
| 144 |
+
return self._agents.get(role)
|
| 145 |
+
|
| 146 |
+
async def step(self, observation: Observation) -> Action:
|
| 147 |
+
"""
|
| 148 |
+
Perform one coordination step.
|
| 149 |
+
|
| 150 |
+
Determines which agent should act, processes messages,
|
| 151 |
+
and returns the selected action.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
observation: The current state observation.
|
| 155 |
+
|
| 156 |
+
Returns:
|
| 157 |
+
The action to execute.
|
| 158 |
+
"""
|
| 159 |
+
try:
|
| 160 |
+
# Process pending messages
|
| 161 |
+
await self._process_messages()
|
| 162 |
+
|
| 163 |
+
# Determine lead agent based on state
|
| 164 |
+
lead_role = self._determine_lead_agent(observation)
|
| 165 |
+
self._current_lead = lead_role
|
| 166 |
+
|
| 167 |
+
# Get action from lead agent
|
| 168 |
+
lead_agent = self._agents.get(lead_role)
|
| 169 |
+
if not lead_agent:
|
| 170 |
+
return self._create_error_action(f"No agent for role: {lead_role}")
|
| 171 |
+
|
| 172 |
+
# Get action from the lead agent
|
| 173 |
+
action = await lead_agent.act(observation)
|
| 174 |
+
action.agent_id = lead_agent.agent_id
|
| 175 |
+
|
| 176 |
+
# Record action
|
| 177 |
+
self._action_history.append((lead_role, action))
|
| 178 |
+
lead_agent.record_action(action)
|
| 179 |
+
|
| 180 |
+
# Handle inter-agent communication actions
|
| 181 |
+
if action.action_type == ActionType.SEND_MESSAGE:
|
| 182 |
+
self._handle_send_message(action)
|
| 183 |
+
|
| 184 |
+
return action
|
| 185 |
+
|
| 186 |
+
except Exception as e:
|
| 187 |
+
return self._create_error_action(f"Coordination error: {e}")
|
| 188 |
+
|
| 189 |
+
async def plan(self, observation: Observation) -> list[Action]:
|
| 190 |
+
"""
|
| 191 |
+
Create a coordinated plan using multiple agents.
|
| 192 |
+
|
| 193 |
+
The planner agent creates the high-level plan, which is then
|
| 194 |
+
refined by other agents.
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
observation: The current state observation.
|
| 198 |
+
|
| 199 |
+
Returns:
|
| 200 |
+
A coordinated list of actions.
|
| 201 |
+
"""
|
| 202 |
+
try:
|
| 203 |
+
# Get plan from planner
|
| 204 |
+
planner = self._agents.get(AgentRole.PLANNER)
|
| 205 |
+
if not planner:
|
| 206 |
+
return []
|
| 207 |
+
|
| 208 |
+
plan = await planner.plan(observation)
|
| 209 |
+
|
| 210 |
+
# Refine with navigator for navigation steps
|
| 211 |
+
navigator = self._agents.get(AgentRole.NAVIGATOR)
|
| 212 |
+
if navigator:
|
| 213 |
+
nav_plan = await navigator.plan(observation)
|
| 214 |
+
# Insert navigation at the beginning if needed
|
| 215 |
+
if nav_plan and not observation.current_url:
|
| 216 |
+
plan = nav_plan + plan
|
| 217 |
+
|
| 218 |
+
return plan
|
| 219 |
+
|
| 220 |
+
except Exception as e:
|
| 221 |
+
return [self._create_error_action(f"Planning error: {e}")]
|
| 222 |
+
|
| 223 |
+
def send_message(
|
| 224 |
+
self,
|
| 225 |
+
sender: str,
|
| 226 |
+
recipient: str,
|
| 227 |
+
message_type: str,
|
| 228 |
+
content: dict[str, Any],
|
| 229 |
+
priority: int = 0,
|
| 230 |
+
) -> None:
|
| 231 |
+
"""
|
| 232 |
+
Send a message between agents.
|
| 233 |
+
|
| 234 |
+
Args:
|
| 235 |
+
sender: ID of the sending agent.
|
| 236 |
+
recipient: ID of the receiving agent.
|
| 237 |
+
message_type: Type of the message.
|
| 238 |
+
content: Message content.
|
| 239 |
+
priority: Message priority (higher = more urgent).
|
| 240 |
+
"""
|
| 241 |
+
message = Message(
|
| 242 |
+
sender=sender,
|
| 243 |
+
recipient=recipient,
|
| 244 |
+
message_type=message_type,
|
| 245 |
+
content=content,
|
| 246 |
+
priority=priority,
|
| 247 |
+
)
|
| 248 |
+
self._message_queue.append(message)
|
| 249 |
+
|
| 250 |
+
async def _process_messages(self) -> None:
|
| 251 |
+
"""Process queued messages and deliver to agents."""
|
| 252 |
+
# Sort by priority (highest first)
|
| 253 |
+
self._message_queue.sort(key=lambda m: -m.priority)
|
| 254 |
+
|
| 255 |
+
# Process up to max messages
|
| 256 |
+
messages_processed = 0
|
| 257 |
+
while self._message_queue and messages_processed < self.max_messages_per_step:
|
| 258 |
+
message = self._message_queue.pop(0)
|
| 259 |
+
|
| 260 |
+
# Find recipient agent
|
| 261 |
+
recipient = None
|
| 262 |
+
for role, agent in self._agents.items():
|
| 263 |
+
if agent.agent_id == message.recipient or role == message.recipient:
|
| 264 |
+
recipient = agent
|
| 265 |
+
break
|
| 266 |
+
|
| 267 |
+
if recipient:
|
| 268 |
+
recipient.receive_message(message.to_dict())
|
| 269 |
+
messages_processed += 1
|
| 270 |
+
|
| 271 |
+
def _determine_lead_agent(self, observation: Observation) -> str:
|
| 272 |
+
"""
|
| 273 |
+
Determine which agent should lead based on state.
|
| 274 |
+
|
| 275 |
+
Args:
|
| 276 |
+
observation: Current observation.
|
| 277 |
+
|
| 278 |
+
Returns:
|
| 279 |
+
The role of the agent that should lead.
|
| 280 |
+
"""
|
| 281 |
+
# If no URL, navigator should lead
|
| 282 |
+
if not observation.current_url:
|
| 283 |
+
return AgentRole.NAVIGATOR
|
| 284 |
+
|
| 285 |
+
# If there are unverified fields, verifier should lead
|
| 286 |
+
unverified = [f for f in observation.extracted_so_far if not f.verified]
|
| 287 |
+
if unverified and observation.extraction_progress > 0.5:
|
| 288 |
+
return AgentRole.VERIFIER
|
| 289 |
+
|
| 290 |
+
# If there are remaining fields to extract, extractor should lead
|
| 291 |
+
if observation.fields_remaining:
|
| 292 |
+
return AgentRole.EXTRACTOR
|
| 293 |
+
|
| 294 |
+
# If we have errors, planner should re-plan
|
| 295 |
+
if observation.consecutive_errors > 0:
|
| 296 |
+
return AgentRole.PLANNER
|
| 297 |
+
|
| 298 |
+
# Default to planner
|
| 299 |
+
return AgentRole.PLANNER
|
| 300 |
+
|
| 301 |
+
def _handle_send_message(self, action: Action) -> None:
|
| 302 |
+
"""Handle a send_message action from an agent."""
|
| 303 |
+
params = action.parameters
|
| 304 |
+
self.send_message(
|
| 305 |
+
sender=action.agent_id or "unknown",
|
| 306 |
+
recipient=params.get("target_agent", ""),
|
| 307 |
+
message_type=params.get("message_type", "generic"),
|
| 308 |
+
content=params.get("content", {}),
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
def _create_error_action(self, error: str) -> Action:
|
| 312 |
+
"""Create a fail action for errors."""
|
| 313 |
+
return Action(
|
| 314 |
+
action_type=ActionType.FAIL,
|
| 315 |
+
parameters={"success": False, "message": error},
|
| 316 |
+
reasoning=error,
|
| 317 |
+
confidence=1.0,
|
| 318 |
+
agent_id="coordinator",
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
async def run_parallel_agents(
|
| 322 |
+
self,
|
| 323 |
+
observation: Observation,
|
| 324 |
+
roles: list[str],
|
| 325 |
+
) -> dict[str, Action]:
|
| 326 |
+
"""
|
| 327 |
+
Run multiple agents in parallel.
|
| 328 |
+
|
| 329 |
+
Args:
|
| 330 |
+
observation: Current observation.
|
| 331 |
+
roles: List of agent roles to run.
|
| 332 |
+
|
| 333 |
+
Returns:
|
| 334 |
+
Dictionary mapping role to action.
|
| 335 |
+
"""
|
| 336 |
+
if not self.enable_parallel:
|
| 337 |
+
# Fallback to sequential
|
| 338 |
+
results = {}
|
| 339 |
+
for role in roles:
|
| 340 |
+
agent = self._agents.get(role)
|
| 341 |
+
if agent:
|
| 342 |
+
results[role] = await agent.act(observation)
|
| 343 |
+
return results
|
| 344 |
+
|
| 345 |
+
# Run agents in parallel
|
| 346 |
+
async def run_agent(role: str) -> tuple[str, Action]:
|
| 347 |
+
agent = self._agents.get(role)
|
| 348 |
+
if agent:
|
| 349 |
+
action = await agent.act(observation)
|
| 350 |
+
return (role, action)
|
| 351 |
+
return (role, self._create_error_action(f"No agent for role: {role}"))
|
| 352 |
+
|
| 353 |
+
tasks = [run_agent(role) for role in roles]
|
| 354 |
+
results = await asyncio.gather(*tasks)
|
| 355 |
+
|
| 356 |
+
return dict(results)
|
| 357 |
+
|
| 358 |
+
def get_action_history(self) -> list[tuple[str, Action]]:
|
| 359 |
+
"""Get the history of actions with their agent roles."""
|
| 360 |
+
return self._action_history.copy()
|
| 361 |
+
|
| 362 |
+
def get_current_lead(self) -> str | None:
|
| 363 |
+
"""Get the current lead agent role."""
|
| 364 |
+
return self._current_lead
|
| 365 |
+
|
| 366 |
+
def get_message_queue_length(self) -> int:
|
| 367 |
+
"""Get the number of pending messages."""
|
| 368 |
+
return len(self._message_queue)
|
| 369 |
+
|
| 370 |
+
def reset(self) -> None:
|
| 371 |
+
"""Reset all agents and coordinator state."""
|
| 372 |
+
for agent in self._agents.values():
|
| 373 |
+
agent.reset()
|
| 374 |
+
|
| 375 |
+
self._message_queue.clear()
|
| 376 |
+
self._action_history.clear()
|
| 377 |
+
self._current_lead = None
|
| 378 |
+
|
| 379 |
+
def get_stats(self) -> dict[str, Any]:
|
| 380 |
+
"""Get coordinator statistics."""
|
| 381 |
+
return {
|
| 382 |
+
"agents": list(self._agents.keys()),
|
| 383 |
+
"current_lead": self._current_lead,
|
| 384 |
+
"pending_messages": len(self._message_queue),
|
| 385 |
+
"action_count": len(self._action_history),
|
| 386 |
+
"enable_parallel": self.enable_parallel,
|
| 387 |
+
}
|
backend/app/agents/extractor.py
ADDED
|
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Extractor agent for data extraction with selectors."""
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
from app.core.action import Action, ActionType
|
| 7 |
+
from app.core.observation import Observation, PageElement
|
| 8 |
+
|
| 9 |
+
from .base import BaseAgent
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ExtractorAgent(BaseAgent):
|
| 13 |
+
"""
|
| 14 |
+
Agent responsible for extracting structured data from pages.
|
| 15 |
+
|
| 16 |
+
The ExtractorAgent handles:
|
| 17 |
+
- Identifying data elements using CSS/XPath selectors
|
| 18 |
+
- Extracting text, attributes, and structured content
|
| 19 |
+
- Handling tables and lists
|
| 20 |
+
- Post-processing extracted values
|
| 21 |
+
- Confidence scoring for extractions
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
agent_id: str = "extractor",
|
| 27 |
+
config: dict[str, Any] | None = None,
|
| 28 |
+
):
|
| 29 |
+
"""
|
| 30 |
+
Initialize the ExtractorAgent.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
agent_id: Unique identifier for this agent.
|
| 34 |
+
config: Optional configuration with keys:
|
| 35 |
+
- min_confidence: Minimum confidence to accept extraction
|
| 36 |
+
- extraction_timeout: Timeout for extraction operations
|
| 37 |
+
- enable_fuzzy_matching: Enable fuzzy text matching
|
| 38 |
+
"""
|
| 39 |
+
super().__init__(agent_id, config)
|
| 40 |
+
self.min_confidence = self.config.get("min_confidence", 0.5)
|
| 41 |
+
self.extraction_timeout = self.config.get("extraction_timeout", 5000)
|
| 42 |
+
self.enable_fuzzy_matching = self.config.get("enable_fuzzy_matching", True)
|
| 43 |
+
self._extraction_cache: dict[str, Any] = {}
|
| 44 |
+
self._selector_patterns: dict[str, list[str]] = self._init_selector_patterns()
|
| 45 |
+
|
| 46 |
+
def _init_selector_patterns(self) -> dict[str, list[str]]:
|
| 47 |
+
"""Initialize common selector patterns for different field types."""
|
| 48 |
+
return {
|
| 49 |
+
"price": [
|
| 50 |
+
"[class*='price']",
|
| 51 |
+
"[id*='price']",
|
| 52 |
+
"[itemprop='price']",
|
| 53 |
+
".product-price",
|
| 54 |
+
".item-price",
|
| 55 |
+
"span[data-price]",
|
| 56 |
+
],
|
| 57 |
+
"title": [
|
| 58 |
+
"h1",
|
| 59 |
+
"[class*='title']",
|
| 60 |
+
"[itemprop='name']",
|
| 61 |
+
".product-title",
|
| 62 |
+
".item-title",
|
| 63 |
+
],
|
| 64 |
+
"description": [
|
| 65 |
+
"[class*='description']",
|
| 66 |
+
"[itemprop='description']",
|
| 67 |
+
".product-description",
|
| 68 |
+
"article p",
|
| 69 |
+
".content p",
|
| 70 |
+
],
|
| 71 |
+
"image": [
|
| 72 |
+
"[class*='product-image'] img",
|
| 73 |
+
"[itemprop='image']",
|
| 74 |
+
".main-image img",
|
| 75 |
+
"figure img",
|
| 76 |
+
],
|
| 77 |
+
"date": [
|
| 78 |
+
"time",
|
| 79 |
+
"[datetime]",
|
| 80 |
+
"[class*='date']",
|
| 81 |
+
"[itemprop='datePublished']",
|
| 82 |
+
],
|
| 83 |
+
"author": [
|
| 84 |
+
"[class*='author']",
|
| 85 |
+
"[itemprop='author']",
|
| 86 |
+
"[rel='author']",
|
| 87 |
+
".byline",
|
| 88 |
+
],
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
async def act(self, observation: Observation) -> Action:
|
| 92 |
+
"""
|
| 93 |
+
Select the best extraction action based on observation.
|
| 94 |
+
|
| 95 |
+
Analyzes the page and decides what data to extract next.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
observation: The current state observation.
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
The extraction action to execute.
|
| 102 |
+
"""
|
| 103 |
+
try:
|
| 104 |
+
# Get remaining fields to extract
|
| 105 |
+
remaining_fields = observation.fields_remaining
|
| 106 |
+
|
| 107 |
+
if not remaining_fields:
|
| 108 |
+
return Action(
|
| 109 |
+
action_type=ActionType.DONE,
|
| 110 |
+
parameters={"success": True, "message": "All fields extracted"},
|
| 111 |
+
reasoning="No more fields to extract",
|
| 112 |
+
confidence=1.0,
|
| 113 |
+
agent_id=self.agent_id,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
# Pick the next field to extract
|
| 117 |
+
field_name = remaining_fields[0]
|
| 118 |
+
|
| 119 |
+
# Find best selector for the field
|
| 120 |
+
selector, confidence = await self._find_selector_for_field(
|
| 121 |
+
field_name,
|
| 122 |
+
observation,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
if selector and confidence >= self.min_confidence:
|
| 126 |
+
return self._create_extraction_action(
|
| 127 |
+
field_name,
|
| 128 |
+
selector,
|
| 129 |
+
confidence,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# Try alternative extraction methods
|
| 133 |
+
alt_action = await self._try_alternative_extraction(
|
| 134 |
+
field_name,
|
| 135 |
+
observation,
|
| 136 |
+
)
|
| 137 |
+
if alt_action:
|
| 138 |
+
return alt_action
|
| 139 |
+
|
| 140 |
+
# Cannot extract this field
|
| 141 |
+
return Action(
|
| 142 |
+
action_type=ActionType.EXTRACT_FIELD,
|
| 143 |
+
parameters={
|
| 144 |
+
"field_name": field_name,
|
| 145 |
+
"selector": None,
|
| 146 |
+
"extraction_method": "llm",
|
| 147 |
+
},
|
| 148 |
+
reasoning=f"No selector found, using LLM extraction for {field_name}",
|
| 149 |
+
confidence=0.4,
|
| 150 |
+
agent_id=self.agent_id,
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
except Exception as e:
|
| 154 |
+
return Action(
|
| 155 |
+
action_type=ActionType.FAIL,
|
| 156 |
+
parameters={"success": False, "message": str(e)},
|
| 157 |
+
reasoning=f"Extraction error: {e}",
|
| 158 |
+
confidence=1.0,
|
| 159 |
+
agent_id=self.agent_id,
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
async def plan(self, observation: Observation) -> list[Action]:
|
| 163 |
+
"""
|
| 164 |
+
Create an extraction plan for all remaining fields.
|
| 165 |
+
|
| 166 |
+
Analyzes the page structure and plans the optimal
|
| 167 |
+
extraction sequence.
|
| 168 |
+
|
| 169 |
+
Args:
|
| 170 |
+
observation: The current state observation.
|
| 171 |
+
|
| 172 |
+
Returns:
|
| 173 |
+
A list of planned extraction actions.
|
| 174 |
+
"""
|
| 175 |
+
try:
|
| 176 |
+
actions: list[Action] = []
|
| 177 |
+
remaining_fields = observation.fields_remaining
|
| 178 |
+
|
| 179 |
+
for field_name in remaining_fields:
|
| 180 |
+
selector, confidence = await self._find_selector_for_field(
|
| 181 |
+
field_name,
|
| 182 |
+
observation,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
if selector:
|
| 186 |
+
actions.append(
|
| 187 |
+
self._create_extraction_action(
|
| 188 |
+
field_name,
|
| 189 |
+
selector,
|
| 190 |
+
confidence,
|
| 191 |
+
)
|
| 192 |
+
)
|
| 193 |
+
else:
|
| 194 |
+
# Plan LLM-based extraction as fallback
|
| 195 |
+
actions.append(
|
| 196 |
+
Action(
|
| 197 |
+
action_type=ActionType.EXTRACT_FIELD,
|
| 198 |
+
parameters={
|
| 199 |
+
"field_name": field_name,
|
| 200 |
+
"extraction_method": "llm",
|
| 201 |
+
},
|
| 202 |
+
reasoning=f"Planning LLM extraction for {field_name}",
|
| 203 |
+
confidence=0.5,
|
| 204 |
+
agent_id=self.agent_id,
|
| 205 |
+
)
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
return actions
|
| 209 |
+
|
| 210 |
+
except Exception as e:
|
| 211 |
+
return [
|
| 212 |
+
Action(
|
| 213 |
+
action_type=ActionType.FAIL,
|
| 214 |
+
parameters={"message": f"Extraction planning failed: {e}"},
|
| 215 |
+
reasoning=str(e),
|
| 216 |
+
confidence=1.0,
|
| 217 |
+
agent_id=self.agent_id,
|
| 218 |
+
)
|
| 219 |
+
]
|
| 220 |
+
|
| 221 |
+
async def _find_selector_for_field(
|
| 222 |
+
self,
|
| 223 |
+
field_name: str,
|
| 224 |
+
observation: Observation,
|
| 225 |
+
) -> tuple[str | None, float]:
|
| 226 |
+
"""
|
| 227 |
+
Find the best selector for a field.
|
| 228 |
+
|
| 229 |
+
Args:
|
| 230 |
+
field_name: Name of the field to extract.
|
| 231 |
+
observation: Current observation.
|
| 232 |
+
|
| 233 |
+
Returns:
|
| 234 |
+
Tuple of (selector, confidence).
|
| 235 |
+
"""
|
| 236 |
+
best_selector: str | None = None
|
| 237 |
+
best_confidence = 0.0
|
| 238 |
+
|
| 239 |
+
# Check predefined patterns first
|
| 240 |
+
patterns = self._get_patterns_for_field(field_name)
|
| 241 |
+
for pattern in patterns:
|
| 242 |
+
element = self._find_element_by_selector(
|
| 243 |
+
pattern,
|
| 244 |
+
observation.page_elements,
|
| 245 |
+
)
|
| 246 |
+
if element:
|
| 247 |
+
confidence = self._calculate_confidence(element, field_name)
|
| 248 |
+
if confidence > best_confidence:
|
| 249 |
+
best_selector = element.selector
|
| 250 |
+
best_confidence = confidence
|
| 251 |
+
|
| 252 |
+
# Search by text content if fuzzy matching enabled
|
| 253 |
+
if self.enable_fuzzy_matching and best_confidence < 0.7:
|
| 254 |
+
element, confidence = self._find_element_by_text(
|
| 255 |
+
field_name,
|
| 256 |
+
observation.page_elements,
|
| 257 |
+
)
|
| 258 |
+
if element and confidence > best_confidence:
|
| 259 |
+
best_selector = element.selector
|
| 260 |
+
best_confidence = confidence
|
| 261 |
+
|
| 262 |
+
return best_selector, best_confidence
|
| 263 |
+
|
| 264 |
+
def _get_patterns_for_field(self, field_name: str) -> list[str]:
|
| 265 |
+
"""Get selector patterns for a field type."""
|
| 266 |
+
field_lower = field_name.lower()
|
| 267 |
+
|
| 268 |
+
# Direct match
|
| 269 |
+
if field_lower in self._selector_patterns:
|
| 270 |
+
return self._selector_patterns[field_lower]
|
| 271 |
+
|
| 272 |
+
# Partial match
|
| 273 |
+
for key, patterns in self._selector_patterns.items():
|
| 274 |
+
if key in field_lower or field_lower in key:
|
| 275 |
+
return patterns
|
| 276 |
+
|
| 277 |
+
# Generate generic patterns
|
| 278 |
+
return [
|
| 279 |
+
f"[class*='{field_lower}']",
|
| 280 |
+
f"[id*='{field_lower}']",
|
| 281 |
+
f"[data-{field_lower}]",
|
| 282 |
+
f".{field_lower}",
|
| 283 |
+
f"#{field_lower}",
|
| 284 |
+
]
|
| 285 |
+
|
| 286 |
+
def _find_element_by_selector(
|
| 287 |
+
self,
|
| 288 |
+
selector: str,
|
| 289 |
+
elements: list[PageElement],
|
| 290 |
+
) -> PageElement | None:
|
| 291 |
+
"""Find an element matching a selector pattern."""
|
| 292 |
+
selector_lower = selector.lower()
|
| 293 |
+
|
| 294 |
+
for element in elements:
|
| 295 |
+
element_selector = element.selector.lower()
|
| 296 |
+
if selector_lower in element_selector:
|
| 297 |
+
return element
|
| 298 |
+
|
| 299 |
+
# Check class and id attributes
|
| 300 |
+
classes = element.attributes.get("class", "").lower()
|
| 301 |
+
element_id = element.attributes.get("id", "").lower()
|
| 302 |
+
|
| 303 |
+
if selector_lower.strip(".[#]") in classes:
|
| 304 |
+
return element
|
| 305 |
+
if selector_lower.strip(".[#]") in element_id:
|
| 306 |
+
return element
|
| 307 |
+
|
| 308 |
+
return None
|
| 309 |
+
|
| 310 |
+
def _find_element_by_text(
|
| 311 |
+
self,
|
| 312 |
+
field_name: str,
|
| 313 |
+
elements: list[PageElement],
|
| 314 |
+
) -> tuple[PageElement | None, float]:
|
| 315 |
+
"""Find an element by text content matching."""
|
| 316 |
+
field_lower = field_name.lower().replace("_", " ")
|
| 317 |
+
best_element: PageElement | None = None
|
| 318 |
+
best_score = 0.0
|
| 319 |
+
|
| 320 |
+
for element in elements:
|
| 321 |
+
if not element.text:
|
| 322 |
+
continue
|
| 323 |
+
|
| 324 |
+
text_lower = element.text.lower()
|
| 325 |
+
|
| 326 |
+
# Check for label-like patterns
|
| 327 |
+
if f"{field_lower}:" in text_lower or f"{field_lower} :" in text_lower:
|
| 328 |
+
score = 0.9
|
| 329 |
+
elif field_lower in text_lower:
|
| 330 |
+
# Calculate similarity score
|
| 331 |
+
score = len(field_lower) / max(len(text_lower), 1) * 0.8
|
| 332 |
+
else:
|
| 333 |
+
continue
|
| 334 |
+
|
| 335 |
+
if score > best_score:
|
| 336 |
+
best_element = element
|
| 337 |
+
best_score = score
|
| 338 |
+
|
| 339 |
+
return best_element, best_score
|
| 340 |
+
|
| 341 |
+
def _calculate_confidence(self, element: PageElement, field_name: str) -> float:
|
| 342 |
+
"""Calculate extraction confidence for an element."""
|
| 343 |
+
confidence = 0.5
|
| 344 |
+
|
| 345 |
+
# Boost for visible elements
|
| 346 |
+
if element.is_visible:
|
| 347 |
+
confidence += 0.1
|
| 348 |
+
|
| 349 |
+
# Boost for semantic attributes
|
| 350 |
+
if element.attributes.get("itemprop"):
|
| 351 |
+
confidence += 0.2
|
| 352 |
+
if element.attributes.get("data-field"):
|
| 353 |
+
confidence += 0.15
|
| 354 |
+
|
| 355 |
+
# Boost if text contains field name
|
| 356 |
+
if element.text and field_name.lower() in element.text.lower():
|
| 357 |
+
confidence += 0.1
|
| 358 |
+
|
| 359 |
+
# Penalty for very long text (likely not a single field)
|
| 360 |
+
if element.text and len(element.text) > 500:
|
| 361 |
+
confidence -= 0.2
|
| 362 |
+
|
| 363 |
+
return min(1.0, max(0.0, confidence))
|
| 364 |
+
|
| 365 |
+
async def _try_alternative_extraction(
|
| 366 |
+
self,
|
| 367 |
+
field_name: str,
|
| 368 |
+
observation: Observation,
|
| 369 |
+
) -> Action | None:
|
| 370 |
+
"""Try alternative extraction methods."""
|
| 371 |
+
# Check for table data
|
| 372 |
+
for element in observation.page_elements:
|
| 373 |
+
if element.tag in ("table", "tbody"):
|
| 374 |
+
return Action(
|
| 375 |
+
action_type=ActionType.EXTRACT_TABLE,
|
| 376 |
+
parameters={
|
| 377 |
+
"table_selector": element.selector,
|
| 378 |
+
"target_field": field_name,
|
| 379 |
+
},
|
| 380 |
+
reasoning=f"Extracting {field_name} from table",
|
| 381 |
+
confidence=0.6,
|
| 382 |
+
agent_id=self.agent_id,
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
# Check for list data
|
| 386 |
+
for element in observation.page_elements:
|
| 387 |
+
if element.tag in ("ul", "ol", "dl"):
|
| 388 |
+
return Action(
|
| 389 |
+
action_type=ActionType.EXTRACT_LIST,
|
| 390 |
+
parameters={
|
| 391 |
+
"container_selector": element.selector,
|
| 392 |
+
"item_selector": "li",
|
| 393 |
+
"field_selectors": {field_name: "text"},
|
| 394 |
+
},
|
| 395 |
+
reasoning=f"Extracting {field_name} from list",
|
| 396 |
+
confidence=0.55,
|
| 397 |
+
agent_id=self.agent_id,
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
return None
|
| 401 |
+
|
| 402 |
+
def _create_extraction_action(
|
| 403 |
+
self,
|
| 404 |
+
field_name: str,
|
| 405 |
+
selector: str,
|
| 406 |
+
confidence: float,
|
| 407 |
+
) -> Action:
|
| 408 |
+
"""Create an extraction action."""
|
| 409 |
+
return Action(
|
| 410 |
+
action_type=ActionType.EXTRACT_FIELD,
|
| 411 |
+
parameters={
|
| 412 |
+
"field_name": field_name,
|
| 413 |
+
"selector": selector,
|
| 414 |
+
"extraction_method": "text",
|
| 415 |
+
},
|
| 416 |
+
reasoning=f"Extracting {field_name} using selector: {selector}",
|
| 417 |
+
confidence=confidence,
|
| 418 |
+
agent_id=self.agent_id,
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
def extract_with_regex(
|
| 422 |
+
self,
|
| 423 |
+
text: str,
|
| 424 |
+
pattern: str,
|
| 425 |
+
group: int = 0,
|
| 426 |
+
) -> str | None:
|
| 427 |
+
"""
|
| 428 |
+
Extract text using a regex pattern.
|
| 429 |
+
|
| 430 |
+
Args:
|
| 431 |
+
text: The text to search in.
|
| 432 |
+
pattern: Regex pattern.
|
| 433 |
+
group: Capture group to return.
|
| 434 |
+
|
| 435 |
+
Returns:
|
| 436 |
+
Extracted text or None.
|
| 437 |
+
"""
|
| 438 |
+
try:
|
| 439 |
+
match = re.search(pattern, text)
|
| 440 |
+
if match:
|
| 441 |
+
return match.group(group)
|
| 442 |
+
return None
|
| 443 |
+
except re.error:
|
| 444 |
+
return None
|
| 445 |
+
|
| 446 |
+
def post_process_value(
|
| 447 |
+
self,
|
| 448 |
+
value: Any,
|
| 449 |
+
field_name: str,
|
| 450 |
+
) -> Any:
|
| 451 |
+
"""
|
| 452 |
+
Post-process an extracted value based on field type.
|
| 453 |
+
|
| 454 |
+
Args:
|
| 455 |
+
value: The raw extracted value.
|
| 456 |
+
field_name: Name of the field (used to infer type).
|
| 457 |
+
|
| 458 |
+
Returns:
|
| 459 |
+
Processed value.
|
| 460 |
+
"""
|
| 461 |
+
if value is None:
|
| 462 |
+
return None
|
| 463 |
+
|
| 464 |
+
value_str = str(value).strip()
|
| 465 |
+
field_lower = field_name.lower()
|
| 466 |
+
|
| 467 |
+
# Price processing
|
| 468 |
+
if "price" in field_lower:
|
| 469 |
+
# Remove currency symbols but keep numbers and decimal
|
| 470 |
+
price_match = re.search(r"[\d,]+\.?\d*", value_str.replace(",", ""))
|
| 471 |
+
if price_match:
|
| 472 |
+
return float(price_match.group().replace(",", ""))
|
| 473 |
+
|
| 474 |
+
# Date processing
|
| 475 |
+
if "date" in field_lower:
|
| 476 |
+
return value_str # Return as-is, let caller parse
|
| 477 |
+
|
| 478 |
+
# Number processing
|
| 479 |
+
if any(x in field_lower for x in ["count", "quantity", "number"]):
|
| 480 |
+
num_match = re.search(r"\d+", value_str)
|
| 481 |
+
if num_match:
|
| 482 |
+
return int(num_match.group())
|
| 483 |
+
|
| 484 |
+
return value_str
|
| 485 |
+
|
| 486 |
+
def reset(self) -> None:
|
| 487 |
+
"""Reset the extractor state."""
|
| 488 |
+
super().reset()
|
| 489 |
+
self._extraction_cache.clear()
|
backend/app/agents/memory_agent.py
ADDED
|
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Memory agent for memory operations and knowledge management."""
|
| 2 |
+
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
from app.core.action import Action, ActionType
|
| 7 |
+
from app.core.observation import Observation
|
| 8 |
+
|
| 9 |
+
from .base import BaseAgent
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class MemoryEntry:
|
| 13 |
+
"""A single memory entry."""
|
| 14 |
+
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
key: str,
|
| 18 |
+
value: Any,
|
| 19 |
+
memory_type: str = "working",
|
| 20 |
+
ttl_seconds: int | None = None,
|
| 21 |
+
metadata: dict[str, Any] | None = None,
|
| 22 |
+
):
|
| 23 |
+
"""Initialize memory entry."""
|
| 24 |
+
self.key = key
|
| 25 |
+
self.value = value
|
| 26 |
+
self.memory_type = memory_type
|
| 27 |
+
self.ttl_seconds = ttl_seconds
|
| 28 |
+
self.metadata = metadata or {}
|
| 29 |
+
self.created_at = datetime.utcnow()
|
| 30 |
+
self.accessed_at = datetime.utcnow()
|
| 31 |
+
self.access_count = 0
|
| 32 |
+
|
| 33 |
+
def is_expired(self) -> bool:
|
| 34 |
+
"""Check if the memory entry has expired."""
|
| 35 |
+
if self.ttl_seconds is None:
|
| 36 |
+
return False
|
| 37 |
+
elapsed = (datetime.utcnow() - self.created_at).total_seconds()
|
| 38 |
+
return elapsed > self.ttl_seconds
|
| 39 |
+
|
| 40 |
+
def access(self) -> Any:
|
| 41 |
+
"""Access the memory and update metadata."""
|
| 42 |
+
self.accessed_at = datetime.utcnow()
|
| 43 |
+
self.access_count += 1
|
| 44 |
+
return self.value
|
| 45 |
+
|
| 46 |
+
def to_dict(self) -> dict[str, Any]:
|
| 47 |
+
"""Convert to dictionary."""
|
| 48 |
+
return {
|
| 49 |
+
"key": self.key,
|
| 50 |
+
"value": self.value,
|
| 51 |
+
"memory_type": self.memory_type,
|
| 52 |
+
"ttl_seconds": self.ttl_seconds,
|
| 53 |
+
"metadata": self.metadata,
|
| 54 |
+
"created_at": self.created_at.isoformat(),
|
| 55 |
+
"accessed_at": self.accessed_at.isoformat(),
|
| 56 |
+
"access_count": self.access_count,
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class MemoryAgent(BaseAgent):
|
| 61 |
+
"""
|
| 62 |
+
Agent responsible for memory operations and knowledge management.
|
| 63 |
+
|
| 64 |
+
The MemoryAgent handles:
|
| 65 |
+
- Storing and retrieving memories across different layers
|
| 66 |
+
- Managing short-term, working, and long-term memory
|
| 67 |
+
- Memory consolidation and cleanup
|
| 68 |
+
- Relevance-based memory retrieval
|
| 69 |
+
- Sharing knowledge between episodes
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
def __init__(
|
| 73 |
+
self,
|
| 74 |
+
agent_id: str = "memory",
|
| 75 |
+
config: dict[str, Any] | None = None,
|
| 76 |
+
):
|
| 77 |
+
"""
|
| 78 |
+
Initialize the MemoryAgent.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
agent_id: Unique identifier for this agent.
|
| 82 |
+
config: Optional configuration with keys:
|
| 83 |
+
- max_short_term: Max short-term memory entries (default: 100)
|
| 84 |
+
- max_working: Max working memory entries (default: 50)
|
| 85 |
+
- consolidation_threshold: Accesses before long-term (default: 3)
|
| 86 |
+
- enable_auto_cleanup: Auto cleanup expired entries (default: True)
|
| 87 |
+
"""
|
| 88 |
+
super().__init__(agent_id, config)
|
| 89 |
+
self.max_short_term = self.config.get("max_short_term", 100)
|
| 90 |
+
self.max_working = self.config.get("max_working", 50)
|
| 91 |
+
self.consolidation_threshold = self.config.get("consolidation_threshold", 3)
|
| 92 |
+
self.enable_auto_cleanup = self.config.get("enable_auto_cleanup", True)
|
| 93 |
+
|
| 94 |
+
# Memory stores
|
| 95 |
+
self._short_term: dict[str, MemoryEntry] = {}
|
| 96 |
+
self._working: dict[str, MemoryEntry] = {}
|
| 97 |
+
self._pending_operations: list[dict[str, Any]] = []
|
| 98 |
+
|
| 99 |
+
async def act(self, observation: Observation) -> Action:
|
| 100 |
+
"""
|
| 101 |
+
Select the best memory action based on observation.
|
| 102 |
+
|
| 103 |
+
Analyzes the current state and determines if any memory
|
| 104 |
+
operations are needed.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
observation: The current state observation.
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
The memory action to execute.
|
| 111 |
+
"""
|
| 112 |
+
try:
|
| 113 |
+
# Process any pending messages requesting memory operations
|
| 114 |
+
messages = self.get_pending_messages()
|
| 115 |
+
for msg in messages:
|
| 116 |
+
if msg.get("message_type") == "memory_request":
|
| 117 |
+
return self._process_memory_request(msg)
|
| 118 |
+
|
| 119 |
+
# Auto cleanup if enabled
|
| 120 |
+
if self.enable_auto_cleanup:
|
| 121 |
+
self._cleanup_expired()
|
| 122 |
+
|
| 123 |
+
# Check if we should store new information
|
| 124 |
+
store_action = self._check_for_storage(observation)
|
| 125 |
+
if store_action:
|
| 126 |
+
return store_action
|
| 127 |
+
|
| 128 |
+
# Check if any memories need consolidation
|
| 129 |
+
consolidation_action = self._check_for_consolidation()
|
| 130 |
+
if consolidation_action:
|
| 131 |
+
return consolidation_action
|
| 132 |
+
|
| 133 |
+
# No memory operations needed
|
| 134 |
+
return Action(
|
| 135 |
+
action_type=ActionType.WAIT,
|
| 136 |
+
parameters={"duration_ms": 100},
|
| 137 |
+
reasoning="No memory operations required",
|
| 138 |
+
confidence=1.0,
|
| 139 |
+
agent_id=self.agent_id,
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
except Exception as e:
|
| 143 |
+
return Action(
|
| 144 |
+
action_type=ActionType.FAIL,
|
| 145 |
+
parameters={"success": False, "message": str(e)},
|
| 146 |
+
reasoning=f"Memory operation error: {e}",
|
| 147 |
+
confidence=1.0,
|
| 148 |
+
agent_id=self.agent_id,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
async def plan(self, observation: Observation) -> list[Action]:
|
| 152 |
+
"""
|
| 153 |
+
Create a plan of memory operations.
|
| 154 |
+
|
| 155 |
+
Plans memory operations needed based on the current state
|
| 156 |
+
and extracted data.
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
observation: The current state observation.
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
A list of planned memory actions.
|
| 163 |
+
"""
|
| 164 |
+
try:
|
| 165 |
+
actions: list[Action] = []
|
| 166 |
+
|
| 167 |
+
# Plan to store extracted fields
|
| 168 |
+
for field in observation.extracted_so_far:
|
| 169 |
+
if field.verified and field.confidence > 0.8:
|
| 170 |
+
actions.append(
|
| 171 |
+
Action(
|
| 172 |
+
action_type=ActionType.STORE_MEMORY,
|
| 173 |
+
parameters={
|
| 174 |
+
"key": f"extracted:{field.field_name}",
|
| 175 |
+
"value": field.value,
|
| 176 |
+
"memory_type": "working",
|
| 177 |
+
"metadata": {
|
| 178 |
+
"source": observation.current_url,
|
| 179 |
+
"confidence": field.confidence,
|
| 180 |
+
},
|
| 181 |
+
},
|
| 182 |
+
reasoning=f"Storing verified field: {field.field_name}",
|
| 183 |
+
confidence=0.9,
|
| 184 |
+
agent_id=self.agent_id,
|
| 185 |
+
)
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# Plan to recall relevant memories for current task
|
| 189 |
+
if observation.task_context:
|
| 190 |
+
for target in observation.task_context.target_fields:
|
| 191 |
+
actions.append(
|
| 192 |
+
Action(
|
| 193 |
+
action_type=ActionType.RECALL_MEMORY,
|
| 194 |
+
parameters={
|
| 195 |
+
"key": f"pattern:{target}",
|
| 196 |
+
"memory_type": "long_term",
|
| 197 |
+
},
|
| 198 |
+
reasoning=f"Recalling patterns for field: {target}",
|
| 199 |
+
confidence=0.7,
|
| 200 |
+
agent_id=self.agent_id,
|
| 201 |
+
)
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
return actions
|
| 205 |
+
|
| 206 |
+
except Exception as e:
|
| 207 |
+
return [
|
| 208 |
+
Action(
|
| 209 |
+
action_type=ActionType.FAIL,
|
| 210 |
+
parameters={"message": f"Memory planning failed: {e}"},
|
| 211 |
+
reasoning=str(e),
|
| 212 |
+
confidence=1.0,
|
| 213 |
+
agent_id=self.agent_id,
|
| 214 |
+
)
|
| 215 |
+
]
|
| 216 |
+
|
| 217 |
+
def store(
|
| 218 |
+
self,
|
| 219 |
+
key: str,
|
| 220 |
+
value: Any,
|
| 221 |
+
memory_type: str = "working",
|
| 222 |
+
ttl_seconds: int | None = None,
|
| 223 |
+
metadata: dict[str, Any] | None = None,
|
| 224 |
+
) -> bool:
|
| 225 |
+
"""
|
| 226 |
+
Store a value in memory.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
key: The key to store under.
|
| 230 |
+
value: The value to store.
|
| 231 |
+
memory_type: Type of memory (short_term, working).
|
| 232 |
+
ttl_seconds: Optional time-to-live.
|
| 233 |
+
metadata: Optional metadata.
|
| 234 |
+
|
| 235 |
+
Returns:
|
| 236 |
+
True if stored successfully.
|
| 237 |
+
"""
|
| 238 |
+
entry = MemoryEntry(
|
| 239 |
+
key=key,
|
| 240 |
+
value=value,
|
| 241 |
+
memory_type=memory_type,
|
| 242 |
+
ttl_seconds=ttl_seconds,
|
| 243 |
+
metadata=metadata,
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
if memory_type == "short_term":
|
| 247 |
+
self._enforce_limit(self._short_term, self.max_short_term)
|
| 248 |
+
self._short_term[key] = entry
|
| 249 |
+
elif memory_type == "working":
|
| 250 |
+
self._enforce_limit(self._working, self.max_working)
|
| 251 |
+
self._working[key] = entry
|
| 252 |
+
else:
|
| 253 |
+
return False
|
| 254 |
+
|
| 255 |
+
return True
|
| 256 |
+
|
| 257 |
+
def recall(
|
| 258 |
+
self,
|
| 259 |
+
key: str,
|
| 260 |
+
memory_type: str | None = None,
|
| 261 |
+
) -> Any | None:
|
| 262 |
+
"""
|
| 263 |
+
Recall a value from memory.
|
| 264 |
+
|
| 265 |
+
Args:
|
| 266 |
+
key: The key to recall.
|
| 267 |
+
memory_type: Optional specific memory type to search.
|
| 268 |
+
|
| 269 |
+
Returns:
|
| 270 |
+
The value if found, None otherwise.
|
| 271 |
+
"""
|
| 272 |
+
# Search in order of specificity
|
| 273 |
+
stores = []
|
| 274 |
+
if memory_type == "working" or memory_type is None:
|
| 275 |
+
stores.append(self._working)
|
| 276 |
+
if memory_type == "short_term" or memory_type is None:
|
| 277 |
+
stores.append(self._short_term)
|
| 278 |
+
|
| 279 |
+
for store in stores:
|
| 280 |
+
if key in store:
|
| 281 |
+
entry = store[key]
|
| 282 |
+
if not entry.is_expired():
|
| 283 |
+
return entry.access()
|
| 284 |
+
else:
|
| 285 |
+
# Clean up expired entry
|
| 286 |
+
del store[key]
|
| 287 |
+
|
| 288 |
+
return None
|
| 289 |
+
|
| 290 |
+
def search(
|
| 291 |
+
self,
|
| 292 |
+
query: str,
|
| 293 |
+
memory_type: str | None = None,
|
| 294 |
+
limit: int = 10,
|
| 295 |
+
) -> list[dict[str, Any]]:
|
| 296 |
+
"""
|
| 297 |
+
Search memories by key prefix or content.
|
| 298 |
+
|
| 299 |
+
Args:
|
| 300 |
+
query: Search query (matches key prefix).
|
| 301 |
+
memory_type: Optional specific memory type.
|
| 302 |
+
limit: Maximum results to return.
|
| 303 |
+
|
| 304 |
+
Returns:
|
| 305 |
+
List of matching memories.
|
| 306 |
+
"""
|
| 307 |
+
results: list[dict[str, Any]] = []
|
| 308 |
+
query_lower = query.lower()
|
| 309 |
+
|
| 310 |
+
stores = []
|
| 311 |
+
if memory_type in ("working", None):
|
| 312 |
+
stores.append(("working", self._working))
|
| 313 |
+
if memory_type in ("short_term", None):
|
| 314 |
+
stores.append(("short_term", self._short_term))
|
| 315 |
+
|
| 316 |
+
for store_name, store in stores:
|
| 317 |
+
for key, entry in store.items():
|
| 318 |
+
if entry.is_expired():
|
| 319 |
+
continue
|
| 320 |
+
|
| 321 |
+
# Match by key prefix or value content
|
| 322 |
+
if (
|
| 323 |
+
key.lower().startswith(query_lower)
|
| 324 |
+
or query_lower in str(entry.value).lower()
|
| 325 |
+
):
|
| 326 |
+
results.append({
|
| 327 |
+
**entry.to_dict(),
|
| 328 |
+
"store": store_name,
|
| 329 |
+
})
|
| 330 |
+
|
| 331 |
+
if len(results) >= limit:
|
| 332 |
+
break
|
| 333 |
+
|
| 334 |
+
return results[:limit]
|
| 335 |
+
|
| 336 |
+
def _process_memory_request(self, message: dict[str, Any]) -> Action:
|
| 337 |
+
"""Process a memory request from another agent."""
|
| 338 |
+
content = message.get("content", {})
|
| 339 |
+
operation = content.get("operation", "recall")
|
| 340 |
+
key = content.get("key", "")
|
| 341 |
+
|
| 342 |
+
if operation == "store":
|
| 343 |
+
success = self.store(
|
| 344 |
+
key=key,
|
| 345 |
+
value=content.get("value"),
|
| 346 |
+
memory_type=content.get("memory_type", "working"),
|
| 347 |
+
ttl_seconds=content.get("ttl_seconds"),
|
| 348 |
+
metadata=content.get("metadata"),
|
| 349 |
+
)
|
| 350 |
+
return Action(
|
| 351 |
+
action_type=ActionType.STORE_MEMORY,
|
| 352 |
+
parameters={"key": key, "success": success},
|
| 353 |
+
reasoning=f"Processed store request for key: {key}",
|
| 354 |
+
confidence=1.0 if success else 0.5,
|
| 355 |
+
agent_id=self.agent_id,
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
elif operation == "recall":
|
| 359 |
+
value = self.recall(key, content.get("memory_type"))
|
| 360 |
+
return Action(
|
| 361 |
+
action_type=ActionType.RECALL_MEMORY,
|
| 362 |
+
parameters={"key": key, "value": value, "found": value is not None},
|
| 363 |
+
reasoning=f"Processed recall request for key: {key}",
|
| 364 |
+
confidence=1.0 if value else 0.3,
|
| 365 |
+
agent_id=self.agent_id,
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
else:
|
| 369 |
+
return Action(
|
| 370 |
+
action_type=ActionType.FAIL,
|
| 371 |
+
parameters={"message": f"Unknown memory operation: {operation}"},
|
| 372 |
+
reasoning=f"Invalid memory request",
|
| 373 |
+
confidence=1.0,
|
| 374 |
+
agent_id=self.agent_id,
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
def _check_for_storage(self, observation: Observation) -> Action | None:
|
| 378 |
+
"""Check if any new information should be stored."""
|
| 379 |
+
# Store newly extracted, verified fields
|
| 380 |
+
for field in observation.extracted_so_far:
|
| 381 |
+
key = f"field:{field.field_name}"
|
| 382 |
+
if key not in self._working and field.verified:
|
| 383 |
+
return Action(
|
| 384 |
+
action_type=ActionType.STORE_MEMORY,
|
| 385 |
+
parameters={
|
| 386 |
+
"key": key,
|
| 387 |
+
"value": {
|
| 388 |
+
"field_name": field.field_name,
|
| 389 |
+
"value": field.value,
|
| 390 |
+
"confidence": field.confidence,
|
| 391 |
+
"source": observation.current_url,
|
| 392 |
+
},
|
| 393 |
+
"memory_type": "working",
|
| 394 |
+
},
|
| 395 |
+
reasoning=f"Storing verified extraction: {field.field_name}",
|
| 396 |
+
confidence=0.85,
|
| 397 |
+
agent_id=self.agent_id,
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
return None
|
| 401 |
+
|
| 402 |
+
def _check_for_consolidation(self) -> Action | None:
|
| 403 |
+
"""Check if any memories should be consolidated to long-term."""
|
| 404 |
+
for key, entry in self._working.items():
|
| 405 |
+
if entry.access_count >= self.consolidation_threshold:
|
| 406 |
+
return Action(
|
| 407 |
+
action_type=ActionType.STORE_MEMORY,
|
| 408 |
+
parameters={
|
| 409 |
+
"key": key,
|
| 410 |
+
"value": entry.value,
|
| 411 |
+
"memory_type": "long_term",
|
| 412 |
+
"metadata": {
|
| 413 |
+
"access_count": entry.access_count,
|
| 414 |
+
"consolidated_from": "working",
|
| 415 |
+
},
|
| 416 |
+
},
|
| 417 |
+
reasoning=f"Consolidating frequently accessed memory: {key}",
|
| 418 |
+
confidence=0.8,
|
| 419 |
+
agent_id=self.agent_id,
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
return None
|
| 423 |
+
|
| 424 |
+
def _cleanup_expired(self) -> int:
|
| 425 |
+
"""Clean up expired memory entries."""
|
| 426 |
+
cleaned = 0
|
| 427 |
+
|
| 428 |
+
for store in [self._short_term, self._working]:
|
| 429 |
+
expired_keys = [
|
| 430 |
+
k for k, v in store.items()
|
| 431 |
+
if v.is_expired()
|
| 432 |
+
]
|
| 433 |
+
for key in expired_keys:
|
| 434 |
+
del store[key]
|
| 435 |
+
cleaned += 1
|
| 436 |
+
|
| 437 |
+
return cleaned
|
| 438 |
+
|
| 439 |
+
def _enforce_limit(
|
| 440 |
+
self,
|
| 441 |
+
store: dict[str, MemoryEntry],
|
| 442 |
+
limit: int,
|
| 443 |
+
) -> None:
|
| 444 |
+
"""Enforce memory limit by removing least accessed entries."""
|
| 445 |
+
if len(store) < limit:
|
| 446 |
+
return
|
| 447 |
+
|
| 448 |
+
# Sort by access count and last access time
|
| 449 |
+
sorted_entries = sorted(
|
| 450 |
+
store.items(),
|
| 451 |
+
key=lambda x: (x[1].access_count, x[1].accessed_at),
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
# Remove oldest/least accessed entries
|
| 455 |
+
to_remove = len(store) - limit + 1
|
| 456 |
+
for key, _ in sorted_entries[:to_remove]:
|
| 457 |
+
del store[key]
|
| 458 |
+
|
| 459 |
+
def get_memory_stats(self) -> dict[str, Any]:
|
| 460 |
+
"""Get statistics about memory usage."""
|
| 461 |
+
return {
|
| 462 |
+
"short_term_count": len(self._short_term),
|
| 463 |
+
"short_term_limit": self.max_short_term,
|
| 464 |
+
"working_count": len(self._working),
|
| 465 |
+
"working_limit": self.max_working,
|
| 466 |
+
"total_entries": len(self._short_term) + len(self._working),
|
| 467 |
+
}
|
| 468 |
+
|
| 469 |
+
def reset(self) -> None:
|
| 470 |
+
"""Reset the memory agent state."""
|
| 471 |
+
super().reset()
|
| 472 |
+
self._short_term.clear()
|
| 473 |
+
self._working.clear()
|
| 474 |
+
self._pending_operations.clear()
|
backend/app/agents/navigator.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Navigator agent for URL prioritization and page navigation."""
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
from urllib.parse import urljoin, urlparse
|
| 5 |
+
|
| 6 |
+
from app.core.action import Action, ActionType
|
| 7 |
+
from app.core.observation import Observation, PageElement
|
| 8 |
+
|
| 9 |
+
from .base import BaseAgent
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class NavigatorAgent(BaseAgent):
|
| 13 |
+
"""
|
| 14 |
+
Agent responsible for intelligent page navigation.
|
| 15 |
+
|
| 16 |
+
The NavigatorAgent handles:
|
| 17 |
+
- URL prioritization based on relevance to task
|
| 18 |
+
- Link discovery and scoring
|
| 19 |
+
- Navigation decision making
|
| 20 |
+
- Handling pagination and multi-page content
|
| 21 |
+
- Avoiding irrelevant or harmful URLs
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
agent_id: str = "navigator",
|
| 27 |
+
config: dict[str, Any] | None = None,
|
| 28 |
+
):
|
| 29 |
+
"""
|
| 30 |
+
Initialize the NavigatorAgent.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
agent_id: Unique identifier for this agent.
|
| 34 |
+
config: Optional configuration with keys:
|
| 35 |
+
- max_depth: Maximum navigation depth (default: 5)
|
| 36 |
+
- allowed_domains: List of allowed domains to visit
|
| 37 |
+
- blocked_patterns: URL patterns to avoid
|
| 38 |
+
- prioritize_https: Prefer HTTPS URLs (default: True)
|
| 39 |
+
"""
|
| 40 |
+
super().__init__(agent_id, config)
|
| 41 |
+
self.max_depth = self.config.get("max_depth", 5)
|
| 42 |
+
self.allowed_domains = self.config.get("allowed_domains", [])
|
| 43 |
+
self.blocked_patterns = self.config.get("blocked_patterns", [
|
| 44 |
+
"logout", "signout", "delete", "remove", "unsubscribe",
|
| 45 |
+
])
|
| 46 |
+
self.prioritize_https = self.config.get("prioritize_https", True)
|
| 47 |
+
self._visited_urls: set[str] = set()
|
| 48 |
+
self._url_scores: dict[str, float] = {}
|
| 49 |
+
|
| 50 |
+
async def act(self, observation: Observation) -> Action:
|
| 51 |
+
"""
|
| 52 |
+
Select the best navigation action based on observation.
|
| 53 |
+
|
| 54 |
+
Analyzes available links and decides whether to:
|
| 55 |
+
- Navigate to a new page
|
| 56 |
+
- Go back to a previous page
|
| 57 |
+
- Click an element to reveal more content
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
observation: The current state observation.
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
The navigation action to execute.
|
| 64 |
+
"""
|
| 65 |
+
try:
|
| 66 |
+
# Track current URL
|
| 67 |
+
if observation.current_url:
|
| 68 |
+
self._visited_urls.add(observation.current_url)
|
| 69 |
+
|
| 70 |
+
# Check if we've reached max depth
|
| 71 |
+
nav_depth = len(observation.navigation_history)
|
| 72 |
+
if nav_depth >= self.max_depth:
|
| 73 |
+
return self._create_go_back_action(
|
| 74 |
+
"Reached maximum navigation depth"
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# Find best link to follow
|
| 78 |
+
best_link = await self._find_best_link(observation)
|
| 79 |
+
|
| 80 |
+
if best_link:
|
| 81 |
+
return self._create_navigate_action(best_link, observation)
|
| 82 |
+
|
| 83 |
+
# Check for pagination
|
| 84 |
+
pagination_action = self._find_pagination(observation)
|
| 85 |
+
if pagination_action:
|
| 86 |
+
return pagination_action
|
| 87 |
+
|
| 88 |
+
# No good links, consider going back
|
| 89 |
+
if observation.can_go_back and nav_depth > 1:
|
| 90 |
+
return self._create_go_back_action(
|
| 91 |
+
"No relevant links found, going back"
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# Nothing to navigate to
|
| 95 |
+
return Action(
|
| 96 |
+
action_type=ActionType.WAIT,
|
| 97 |
+
parameters={"duration_ms": 500},
|
| 98 |
+
reasoning="No navigation targets found",
|
| 99 |
+
confidence=0.5,
|
| 100 |
+
agent_id=self.agent_id,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
except Exception as e:
|
| 104 |
+
return Action(
|
| 105 |
+
action_type=ActionType.FAIL,
|
| 106 |
+
parameters={"success": False, "message": str(e)},
|
| 107 |
+
reasoning=f"Navigation error: {e}",
|
| 108 |
+
confidence=1.0,
|
| 109 |
+
agent_id=self.agent_id,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
async def plan(self, observation: Observation) -> list[Action]:
|
| 113 |
+
"""
|
| 114 |
+
Create a navigation plan based on task requirements.
|
| 115 |
+
|
| 116 |
+
Plans a sequence of navigation actions to reach content
|
| 117 |
+
relevant to the task.
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
observation: The current state observation.
|
| 121 |
+
|
| 122 |
+
Returns:
|
| 123 |
+
A list of planned navigation actions.
|
| 124 |
+
"""
|
| 125 |
+
try:
|
| 126 |
+
actions: list[Action] = []
|
| 127 |
+
task_context = observation.task_context
|
| 128 |
+
|
| 129 |
+
if not task_context:
|
| 130 |
+
return []
|
| 131 |
+
|
| 132 |
+
# Analyze task hints for navigation targets
|
| 133 |
+
target_urls = self._extract_urls_from_hints(task_context.hints)
|
| 134 |
+
|
| 135 |
+
for url in target_urls[:3]: # Limit to top 3 URLs
|
| 136 |
+
if url not in self._visited_urls:
|
| 137 |
+
actions.append(
|
| 138 |
+
Action(
|
| 139 |
+
action_type=ActionType.NAVIGATE,
|
| 140 |
+
parameters={"url": url, "timeout_ms": 30000},
|
| 141 |
+
reasoning=f"Navigating to task-relevant URL: {url}",
|
| 142 |
+
confidence=0.85,
|
| 143 |
+
agent_id=self.agent_id,
|
| 144 |
+
)
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
# If no URLs from hints, plan a search
|
| 148 |
+
if not actions:
|
| 149 |
+
search_query = self._build_search_query(task_context)
|
| 150 |
+
actions.append(
|
| 151 |
+
Action(
|
| 152 |
+
action_type=ActionType.SEARCH_ENGINE,
|
| 153 |
+
parameters={"query": search_query, "engine": "google"},
|
| 154 |
+
reasoning=f"Searching for: {search_query}",
|
| 155 |
+
confidence=0.7,
|
| 156 |
+
agent_id=self.agent_id,
|
| 157 |
+
)
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
return actions
|
| 161 |
+
|
| 162 |
+
except Exception as e:
|
| 163 |
+
return [
|
| 164 |
+
Action(
|
| 165 |
+
action_type=ActionType.FAIL,
|
| 166 |
+
parameters={"message": f"Navigation planning failed: {e}"},
|
| 167 |
+
reasoning=str(e),
|
| 168 |
+
confidence=1.0,
|
| 169 |
+
agent_id=self.agent_id,
|
| 170 |
+
)
|
| 171 |
+
]
|
| 172 |
+
|
| 173 |
+
async def _find_best_link(self, observation: Observation) -> str | None:
|
| 174 |
+
"""Find the best link to follow based on task relevance."""
|
| 175 |
+
if not observation.task_context:
|
| 176 |
+
return None
|
| 177 |
+
|
| 178 |
+
target_fields = observation.task_context.target_fields
|
| 179 |
+
remaining_fields = observation.fields_remaining
|
| 180 |
+
|
| 181 |
+
# Score all links on the page
|
| 182 |
+
link_scores: list[tuple[str, float]] = []
|
| 183 |
+
|
| 184 |
+
for element in observation.page_elements:
|
| 185 |
+
if not element.is_interactive:
|
| 186 |
+
continue
|
| 187 |
+
|
| 188 |
+
href = element.attributes.get("href", "")
|
| 189 |
+
if not href or href.startswith("#") or href.startswith("javascript:"):
|
| 190 |
+
continue
|
| 191 |
+
|
| 192 |
+
# Resolve relative URLs
|
| 193 |
+
full_url = self._resolve_url(href, observation.current_url)
|
| 194 |
+
if not full_url:
|
| 195 |
+
continue
|
| 196 |
+
|
| 197 |
+
# Skip already visited URLs
|
| 198 |
+
if full_url in self._visited_urls:
|
| 199 |
+
continue
|
| 200 |
+
|
| 201 |
+
# Skip blocked patterns
|
| 202 |
+
if self._is_blocked_url(full_url):
|
| 203 |
+
continue
|
| 204 |
+
|
| 205 |
+
# Check domain restrictions
|
| 206 |
+
if not self._is_allowed_domain(full_url):
|
| 207 |
+
continue
|
| 208 |
+
|
| 209 |
+
# Score the link
|
| 210 |
+
score = self._score_link(element, full_url, remaining_fields)
|
| 211 |
+
if score > 0:
|
| 212 |
+
link_scores.append((full_url, score))
|
| 213 |
+
|
| 214 |
+
# Return highest scoring link
|
| 215 |
+
if link_scores:
|
| 216 |
+
link_scores.sort(key=lambda x: x[1], reverse=True)
|
| 217 |
+
return link_scores[0][0]
|
| 218 |
+
|
| 219 |
+
return None
|
| 220 |
+
|
| 221 |
+
def _score_link(
|
| 222 |
+
self,
|
| 223 |
+
element: PageElement,
|
| 224 |
+
url: str,
|
| 225 |
+
target_fields: list[str],
|
| 226 |
+
) -> float:
|
| 227 |
+
"""Score a link based on relevance to task fields."""
|
| 228 |
+
score = 0.0
|
| 229 |
+
text = (element.text or "").lower()
|
| 230 |
+
url_lower = url.lower()
|
| 231 |
+
|
| 232 |
+
# Check if link text contains target field names
|
| 233 |
+
for field in target_fields:
|
| 234 |
+
field_lower = field.lower()
|
| 235 |
+
if field_lower in text:
|
| 236 |
+
score += 0.4
|
| 237 |
+
if field_lower in url_lower:
|
| 238 |
+
score += 0.3
|
| 239 |
+
|
| 240 |
+
# Prefer HTTPS
|
| 241 |
+
if self.prioritize_https and url.startswith("https://"):
|
| 242 |
+
score += 0.1
|
| 243 |
+
|
| 244 |
+
# Boost content-like URLs
|
| 245 |
+
content_indicators = ["detail", "view", "info", "about", "product", "page"]
|
| 246 |
+
for indicator in content_indicators:
|
| 247 |
+
if indicator in url_lower:
|
| 248 |
+
score += 0.2
|
| 249 |
+
break
|
| 250 |
+
|
| 251 |
+
# Penalize non-content URLs
|
| 252 |
+
noise_indicators = ["login", "cart", "checkout", "share", "print"]
|
| 253 |
+
for indicator in noise_indicators:
|
| 254 |
+
if indicator in url_lower:
|
| 255 |
+
score -= 0.3
|
| 256 |
+
break
|
| 257 |
+
|
| 258 |
+
return max(0.0, score)
|
| 259 |
+
|
| 260 |
+
def _resolve_url(self, href: str, base_url: str | None) -> str | None:
|
| 261 |
+
"""Resolve a relative URL to an absolute URL."""
|
| 262 |
+
if not href:
|
| 263 |
+
return None
|
| 264 |
+
|
| 265 |
+
if href.startswith(("http://", "https://")):
|
| 266 |
+
return href
|
| 267 |
+
|
| 268 |
+
if not base_url:
|
| 269 |
+
return None
|
| 270 |
+
|
| 271 |
+
try:
|
| 272 |
+
return urljoin(base_url, href)
|
| 273 |
+
except Exception:
|
| 274 |
+
return None
|
| 275 |
+
|
| 276 |
+
def _is_blocked_url(self, url: str) -> bool:
|
| 277 |
+
"""Check if URL matches any blocked patterns."""
|
| 278 |
+
url_lower = url.lower()
|
| 279 |
+
for pattern in self.blocked_patterns:
|
| 280 |
+
if pattern.lower() in url_lower:
|
| 281 |
+
return True
|
| 282 |
+
return False
|
| 283 |
+
|
| 284 |
+
def _is_allowed_domain(self, url: str) -> bool:
|
| 285 |
+
"""Check if URL domain is allowed."""
|
| 286 |
+
if not self.allowed_domains:
|
| 287 |
+
return True
|
| 288 |
+
|
| 289 |
+
try:
|
| 290 |
+
parsed = urlparse(url)
|
| 291 |
+
domain = parsed.netloc.lower()
|
| 292 |
+
for allowed in self.allowed_domains:
|
| 293 |
+
if domain == allowed.lower() or domain.endswith("." + allowed.lower()):
|
| 294 |
+
return True
|
| 295 |
+
return False
|
| 296 |
+
except Exception:
|
| 297 |
+
return False
|
| 298 |
+
|
| 299 |
+
def _find_pagination(self, observation: Observation) -> Action | None:
|
| 300 |
+
"""Find and create action for pagination elements."""
|
| 301 |
+
pagination_selectors = [
|
| 302 |
+
"[aria-label*='next']",
|
| 303 |
+
"[aria-label*='Next']",
|
| 304 |
+
"a.next",
|
| 305 |
+
"button.next",
|
| 306 |
+
"[rel='next']",
|
| 307 |
+
]
|
| 308 |
+
|
| 309 |
+
for element in observation.page_elements:
|
| 310 |
+
text = (element.text or "").lower()
|
| 311 |
+
if element.is_interactive and ("next" in text or "more" in text):
|
| 312 |
+
return Action(
|
| 313 |
+
action_type=ActionType.CLICK,
|
| 314 |
+
parameters={"selector": element.selector},
|
| 315 |
+
reasoning="Clicking pagination to load more content",
|
| 316 |
+
confidence=0.7,
|
| 317 |
+
agent_id=self.agent_id,
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
return None
|
| 321 |
+
|
| 322 |
+
def _extract_urls_from_hints(self, hints: list[str]) -> list[str]:
|
| 323 |
+
"""Extract URLs from task hints."""
|
| 324 |
+
urls = []
|
| 325 |
+
for hint in hints:
|
| 326 |
+
if hint.startswith(("http://", "https://")):
|
| 327 |
+
urls.append(hint)
|
| 328 |
+
elif "://" not in hint and "." in hint:
|
| 329 |
+
# Might be a domain without protocol
|
| 330 |
+
urls.append(f"https://{hint}")
|
| 331 |
+
return urls
|
| 332 |
+
|
| 333 |
+
def _build_search_query(self, task_context: Any) -> str:
|
| 334 |
+
"""Build a search query from task context."""
|
| 335 |
+
parts = [task_context.task_name]
|
| 336 |
+
if task_context.target_fields:
|
| 337 |
+
parts.extend(task_context.target_fields[:2])
|
| 338 |
+
return " ".join(parts)
|
| 339 |
+
|
| 340 |
+
def _create_navigate_action(self, url: str, observation: Observation) -> Action:
|
| 341 |
+
"""Create a navigate action for the given URL."""
|
| 342 |
+
return Action(
|
| 343 |
+
action_type=ActionType.NAVIGATE,
|
| 344 |
+
parameters={"url": url, "timeout_ms": 30000},
|
| 345 |
+
reasoning=f"Navigating to relevant URL: {url}",
|
| 346 |
+
confidence=0.75,
|
| 347 |
+
agent_id=self.agent_id,
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
def _create_go_back_action(self, reason: str) -> Action:
|
| 351 |
+
"""Create a go back action."""
|
| 352 |
+
return Action(
|
| 353 |
+
action_type=ActionType.GO_BACK,
|
| 354 |
+
parameters={},
|
| 355 |
+
reasoning=reason,
|
| 356 |
+
confidence=0.8,
|
| 357 |
+
agent_id=self.agent_id,
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
def get_visited_urls(self) -> set[str]:
|
| 361 |
+
"""Get the set of visited URLs."""
|
| 362 |
+
return self._visited_urls.copy()
|
| 363 |
+
|
| 364 |
+
def reset(self) -> None:
|
| 365 |
+
"""Reset the navigator state."""
|
| 366 |
+
super().reset()
|
| 367 |
+
self._visited_urls.clear()
|
| 368 |
+
self._url_scores.clear()
|
backend/app/agents/planner.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Planner agent for goal decomposition and task planning."""
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
from app.core.action import Action, ActionType
|
| 6 |
+
from app.core.observation import Observation
|
| 7 |
+
|
| 8 |
+
from .base import BaseAgent
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PlannerAgent(BaseAgent):
|
| 12 |
+
"""
|
| 13 |
+
Agent responsible for high-level planning and goal decomposition.
|
| 14 |
+
|
| 15 |
+
The PlannerAgent analyzes the task requirements and creates
|
| 16 |
+
structured plans that other agents can execute. It handles:
|
| 17 |
+
- Breaking down complex tasks into subtasks
|
| 18 |
+
- Determining the optimal sequence of actions
|
| 19 |
+
- Adapting plans based on execution results
|
| 20 |
+
- Coordinating multi-step extraction workflows
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
agent_id: str = "planner",
|
| 26 |
+
config: dict[str, Any] | None = None,
|
| 27 |
+
):
|
| 28 |
+
"""
|
| 29 |
+
Initialize the PlannerAgent.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
agent_id: Unique identifier for this agent.
|
| 33 |
+
config: Optional configuration with keys:
|
| 34 |
+
- max_plan_depth: Maximum depth of nested plans (default: 5)
|
| 35 |
+
- replan_threshold: Error count before replanning (default: 2)
|
| 36 |
+
- planning_model: LLM model to use for planning
|
| 37 |
+
"""
|
| 38 |
+
super().__init__(agent_id, config)
|
| 39 |
+
self.max_plan_depth = self.config.get("max_plan_depth", 5)
|
| 40 |
+
self.replan_threshold = self.config.get("replan_threshold", 2)
|
| 41 |
+
self._current_plan: list[Action] | None = None
|
| 42 |
+
self._plan_step: int = 0
|
| 43 |
+
|
| 44 |
+
async def act(self, observation: Observation) -> Action:
|
| 45 |
+
"""
|
| 46 |
+
Select the next action based on the current plan or create a new one.
|
| 47 |
+
|
| 48 |
+
If no plan exists or the current plan has failed, creates a new plan.
|
| 49 |
+
Otherwise, returns the next action in the current plan.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
observation: The current state observation.
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
The next action to execute.
|
| 56 |
+
"""
|
| 57 |
+
try:
|
| 58 |
+
# Check if we need to replan due to errors
|
| 59 |
+
if observation.consecutive_errors >= self.replan_threshold:
|
| 60 |
+
self._current_plan = None
|
| 61 |
+
self._plan_step = 0
|
| 62 |
+
|
| 63 |
+
# Create plan if none exists
|
| 64 |
+
if self._current_plan is None or self._plan_step >= len(self._current_plan):
|
| 65 |
+
self._current_plan = await self.plan(observation)
|
| 66 |
+
self._plan_step = 0
|
| 67 |
+
|
| 68 |
+
if not self._current_plan:
|
| 69 |
+
return self._create_done_action("No actions planned")
|
| 70 |
+
|
| 71 |
+
# Get next action from plan
|
| 72 |
+
action = self._current_plan[self._plan_step]
|
| 73 |
+
action.plan_step = self._plan_step
|
| 74 |
+
action.agent_id = self.agent_id
|
| 75 |
+
self._plan_step += 1
|
| 76 |
+
|
| 77 |
+
return action
|
| 78 |
+
|
| 79 |
+
except Exception as e:
|
| 80 |
+
return self._create_error_action(f"Planning error: {e}")
|
| 81 |
+
|
| 82 |
+
async def plan(self, observation: Observation) -> list[Action]:
|
| 83 |
+
"""
|
| 84 |
+
Create a plan of actions to achieve the task goals.
|
| 85 |
+
|
| 86 |
+
Analyzes the observation to determine:
|
| 87 |
+
- What fields still need to be extracted
|
| 88 |
+
- What navigation may be required
|
| 89 |
+
- What verification steps are needed
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
observation: The current state observation.
|
| 93 |
+
|
| 94 |
+
Returns:
|
| 95 |
+
A list of planned actions in execution order.
|
| 96 |
+
"""
|
| 97 |
+
try:
|
| 98 |
+
actions: list[Action] = []
|
| 99 |
+
task_context = observation.task_context
|
| 100 |
+
|
| 101 |
+
if not task_context:
|
| 102 |
+
return [self._create_done_action("No task context provided")]
|
| 103 |
+
|
| 104 |
+
# Determine remaining fields to extract
|
| 105 |
+
remaining_fields = observation.fields_remaining
|
| 106 |
+
extracted_fields = [f.field_name for f in observation.extracted_so_far]
|
| 107 |
+
|
| 108 |
+
# If no URL loaded, plan navigation first
|
| 109 |
+
if not observation.current_url:
|
| 110 |
+
search_action = self._plan_initial_navigation(task_context)
|
| 111 |
+
if search_action:
|
| 112 |
+
actions.append(search_action)
|
| 113 |
+
|
| 114 |
+
# Plan extraction for remaining fields
|
| 115 |
+
for field in remaining_fields:
|
| 116 |
+
extraction_action = self._plan_field_extraction(
|
| 117 |
+
field,
|
| 118 |
+
observation,
|
| 119 |
+
)
|
| 120 |
+
actions.append(extraction_action)
|
| 121 |
+
|
| 122 |
+
# Plan verification if fields have been extracted
|
| 123 |
+
if extracted_fields:
|
| 124 |
+
verify_action = self._plan_verification(extracted_fields)
|
| 125 |
+
actions.append(verify_action)
|
| 126 |
+
|
| 127 |
+
# Add completion action
|
| 128 |
+
actions.append(
|
| 129 |
+
Action(
|
| 130 |
+
action_type=ActionType.DONE,
|
| 131 |
+
parameters={"success": True, "message": "Plan completed"},
|
| 132 |
+
reasoning="All planned steps completed",
|
| 133 |
+
confidence=0.9,
|
| 134 |
+
agent_id=self.agent_id,
|
| 135 |
+
)
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
return actions
|
| 139 |
+
|
| 140 |
+
except Exception as e:
|
| 141 |
+
return [self._create_error_action(f"Plan creation failed: {e}")]
|
| 142 |
+
|
| 143 |
+
def _plan_initial_navigation(self, task_context: Any) -> Action | None:
|
| 144 |
+
"""Plan initial navigation based on task context."""
|
| 145 |
+
if task_context.hints:
|
| 146 |
+
# Use hints for navigation
|
| 147 |
+
for hint in task_context.hints:
|
| 148 |
+
if hint.startswith("http"):
|
| 149 |
+
return Action(
|
| 150 |
+
action_type=ActionType.NAVIGATE,
|
| 151 |
+
parameters={"url": hint},
|
| 152 |
+
reasoning=f"Navigating to hinted URL: {hint}",
|
| 153 |
+
confidence=0.85,
|
| 154 |
+
agent_id=self.agent_id,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
# Default to search
|
| 158 |
+
search_query = f"{task_context.task_name} site information"
|
| 159 |
+
return Action(
|
| 160 |
+
action_type=ActionType.SEARCH_ENGINE,
|
| 161 |
+
parameters={"query": search_query, "engine": "google"},
|
| 162 |
+
reasoning=f"Searching for: {search_query}",
|
| 163 |
+
confidence=0.7,
|
| 164 |
+
agent_id=self.agent_id,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
def _plan_field_extraction(
|
| 168 |
+
self,
|
| 169 |
+
field_name: str,
|
| 170 |
+
observation: Observation,
|
| 171 |
+
) -> Action:
|
| 172 |
+
"""Plan extraction for a specific field."""
|
| 173 |
+
# Check if we have page elements that might contain the field
|
| 174 |
+
selector = None
|
| 175 |
+
confidence = 0.6
|
| 176 |
+
|
| 177 |
+
for element in observation.page_elements:
|
| 178 |
+
element_text = (element.text or "").lower()
|
| 179 |
+
if field_name.lower() in element_text:
|
| 180 |
+
selector = element.selector
|
| 181 |
+
confidence = 0.8
|
| 182 |
+
break
|
| 183 |
+
|
| 184 |
+
return Action(
|
| 185 |
+
action_type=ActionType.EXTRACT_FIELD,
|
| 186 |
+
parameters={
|
| 187 |
+
"field_name": field_name,
|
| 188 |
+
"selector": selector,
|
| 189 |
+
"extraction_method": "text",
|
| 190 |
+
},
|
| 191 |
+
reasoning=f"Extracting field: {field_name}",
|
| 192 |
+
confidence=confidence,
|
| 193 |
+
agent_id=self.agent_id,
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
def _plan_verification(self, fields: list[str]) -> Action:
|
| 197 |
+
"""Plan verification for extracted fields."""
|
| 198 |
+
return Action(
|
| 199 |
+
action_type=ActionType.VERIFY_FIELD,
|
| 200 |
+
parameters={
|
| 201 |
+
"field_name": fields[0] if fields else "unknown",
|
| 202 |
+
"validation_rules": ["not_empty", "format_check"],
|
| 203 |
+
},
|
| 204 |
+
reasoning=f"Verifying extracted fields: {fields}",
|
| 205 |
+
confidence=0.75,
|
| 206 |
+
agent_id=self.agent_id,
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
def _create_done_action(self, message: str) -> Action:
|
| 210 |
+
"""Create a done action."""
|
| 211 |
+
return Action(
|
| 212 |
+
action_type=ActionType.DONE,
|
| 213 |
+
parameters={"success": True, "message": message},
|
| 214 |
+
reasoning=message,
|
| 215 |
+
confidence=1.0,
|
| 216 |
+
agent_id=self.agent_id,
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
def _create_error_action(self, error: str) -> Action:
|
| 220 |
+
"""Create a fail action for errors."""
|
| 221 |
+
return Action(
|
| 222 |
+
action_type=ActionType.FAIL,
|
| 223 |
+
parameters={"success": False, "message": error},
|
| 224 |
+
reasoning=error,
|
| 225 |
+
confidence=1.0,
|
| 226 |
+
agent_id=self.agent_id,
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
def get_current_plan(self) -> list[Action] | None:
|
| 230 |
+
"""Get the current plan."""
|
| 231 |
+
return self._current_plan
|
| 232 |
+
|
| 233 |
+
def get_plan_progress(self) -> tuple[int, int]:
|
| 234 |
+
"""Get current plan progress as (current_step, total_steps)."""
|
| 235 |
+
total = len(self._current_plan) if self._current_plan else 0
|
| 236 |
+
return (self._plan_step, total)
|
| 237 |
+
|
| 238 |
+
def reset(self) -> None:
|
| 239 |
+
"""Reset the planner state."""
|
| 240 |
+
super().reset()
|
| 241 |
+
self._current_plan = None
|
| 242 |
+
self._plan_step = 0
|
backend/app/agents/verifier.py
ADDED
|
@@ -0,0 +1,468 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Verifier agent for cross-source verification."""
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
from app.core.action import Action, ActionType
|
| 7 |
+
from app.core.observation import ExtractedField, Observation
|
| 8 |
+
|
| 9 |
+
from .base import BaseAgent
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class VerificationResult:
|
| 13 |
+
"""Result of a verification check."""
|
| 14 |
+
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
field_name: str,
|
| 18 |
+
is_valid: bool,
|
| 19 |
+
confidence: float,
|
| 20 |
+
issues: list[str] | None = None,
|
| 21 |
+
sources_checked: int = 0,
|
| 22 |
+
):
|
| 23 |
+
"""Initialize verification result."""
|
| 24 |
+
self.field_name = field_name
|
| 25 |
+
self.is_valid = is_valid
|
| 26 |
+
self.confidence = confidence
|
| 27 |
+
self.issues = issues or []
|
| 28 |
+
self.sources_checked = sources_checked
|
| 29 |
+
|
| 30 |
+
def to_dict(self) -> dict[str, Any]:
|
| 31 |
+
"""Convert to dictionary."""
|
| 32 |
+
return {
|
| 33 |
+
"field_name": self.field_name,
|
| 34 |
+
"is_valid": self.is_valid,
|
| 35 |
+
"confidence": self.confidence,
|
| 36 |
+
"issues": self.issues,
|
| 37 |
+
"sources_checked": self.sources_checked,
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class VerifierAgent(BaseAgent):
|
| 42 |
+
"""
|
| 43 |
+
Agent responsible for verifying extracted data.
|
| 44 |
+
|
| 45 |
+
The VerifierAgent handles:
|
| 46 |
+
- Format validation (emails, URLs, dates, etc.)
|
| 47 |
+
- Cross-source verification
|
| 48 |
+
- Consistency checks across fields
|
| 49 |
+
- Confidence scoring for verified data
|
| 50 |
+
- Flagging suspicious or inconsistent data
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
def __init__(
|
| 54 |
+
self,
|
| 55 |
+
agent_id: str = "verifier",
|
| 56 |
+
config: dict[str, Any] | None = None,
|
| 57 |
+
):
|
| 58 |
+
"""
|
| 59 |
+
Initialize the VerifierAgent.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
agent_id: Unique identifier for this agent.
|
| 63 |
+
config: Optional configuration with keys:
|
| 64 |
+
- min_confidence: Minimum confidence to accept (default: 0.7)
|
| 65 |
+
- require_cross_validation: Require multiple sources (default: False)
|
| 66 |
+
- strict_mode: Apply stricter validation rules (default: False)
|
| 67 |
+
"""
|
| 68 |
+
super().__init__(agent_id, config)
|
| 69 |
+
self.min_confidence = self.config.get("min_confidence", 0.7)
|
| 70 |
+
self.require_cross_validation = self.config.get("require_cross_validation", False)
|
| 71 |
+
self.strict_mode = self.config.get("strict_mode", False)
|
| 72 |
+
self._validation_rules = self._init_validation_rules()
|
| 73 |
+
self._verification_history: list[VerificationResult] = []
|
| 74 |
+
|
| 75 |
+
def _init_validation_rules(self) -> dict[str, list[dict[str, Any]]]:
|
| 76 |
+
"""Initialize validation rules for common field types."""
|
| 77 |
+
return {
|
| 78 |
+
"email": [
|
| 79 |
+
{
|
| 80 |
+
"type": "regex",
|
| 81 |
+
"pattern": r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$",
|
| 82 |
+
"error": "Invalid email format",
|
| 83 |
+
},
|
| 84 |
+
],
|
| 85 |
+
"url": [
|
| 86 |
+
{
|
| 87 |
+
"type": "regex",
|
| 88 |
+
"pattern": r"^https?://[^\s]+$",
|
| 89 |
+
"error": "Invalid URL format",
|
| 90 |
+
},
|
| 91 |
+
],
|
| 92 |
+
"phone": [
|
| 93 |
+
{
|
| 94 |
+
"type": "regex",
|
| 95 |
+
"pattern": r"[\d\s\-\(\)\+]{7,}",
|
| 96 |
+
"error": "Invalid phone format",
|
| 97 |
+
},
|
| 98 |
+
],
|
| 99 |
+
"price": [
|
| 100 |
+
{
|
| 101 |
+
"type": "range",
|
| 102 |
+
"min": 0,
|
| 103 |
+
"max": 1000000,
|
| 104 |
+
"error": "Price out of reasonable range",
|
| 105 |
+
},
|
| 106 |
+
],
|
| 107 |
+
"date": [
|
| 108 |
+
{
|
| 109 |
+
"type": "regex",
|
| 110 |
+
"pattern": r"\d{1,4}[-/]\d{1,2}[-/]\d{1,4}",
|
| 111 |
+
"error": "Invalid date format",
|
| 112 |
+
},
|
| 113 |
+
],
|
| 114 |
+
"rating": [
|
| 115 |
+
{
|
| 116 |
+
"type": "range",
|
| 117 |
+
"min": 0,
|
| 118 |
+
"max": 5,
|
| 119 |
+
"error": "Rating out of range",
|
| 120 |
+
},
|
| 121 |
+
],
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
async def act(self, observation: Observation) -> Action:
|
| 125 |
+
"""
|
| 126 |
+
Select the best verification action based on observation.
|
| 127 |
+
|
| 128 |
+
Determines which extracted fields need verification and
|
| 129 |
+
selects the appropriate verification method.
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
observation: The current state observation.
|
| 133 |
+
|
| 134 |
+
Returns:
|
| 135 |
+
The verification action to execute.
|
| 136 |
+
"""
|
| 137 |
+
try:
|
| 138 |
+
# Find unverified fields
|
| 139 |
+
unverified = [
|
| 140 |
+
f for f in observation.extracted_so_far
|
| 141 |
+
if not f.verified
|
| 142 |
+
]
|
| 143 |
+
|
| 144 |
+
if not unverified:
|
| 145 |
+
return Action(
|
| 146 |
+
action_type=ActionType.DONE,
|
| 147 |
+
parameters={"success": True, "message": "All fields verified"},
|
| 148 |
+
reasoning="No unverified fields remaining",
|
| 149 |
+
confidence=1.0,
|
| 150 |
+
agent_id=self.agent_id,
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# Verify the first unverified field
|
| 154 |
+
field = unverified[0]
|
| 155 |
+
result = await self._verify_field(field, observation)
|
| 156 |
+
|
| 157 |
+
if result.is_valid and result.confidence >= self.min_confidence:
|
| 158 |
+
return Action(
|
| 159 |
+
action_type=ActionType.VERIFY_FIELD,
|
| 160 |
+
parameters={
|
| 161 |
+
"field_name": field.field_name,
|
| 162 |
+
"verified": True,
|
| 163 |
+
"confidence": result.confidence,
|
| 164 |
+
"issues": result.issues,
|
| 165 |
+
},
|
| 166 |
+
reasoning=f"Field {field.field_name} verified with confidence {result.confidence:.2f}",
|
| 167 |
+
confidence=result.confidence,
|
| 168 |
+
agent_id=self.agent_id,
|
| 169 |
+
)
|
| 170 |
+
else:
|
| 171 |
+
# Verification failed - may need re-extraction
|
| 172 |
+
return self._create_reverify_action(field, result)
|
| 173 |
+
|
| 174 |
+
except Exception as e:
|
| 175 |
+
return Action(
|
| 176 |
+
action_type=ActionType.FAIL,
|
| 177 |
+
parameters={"success": False, "message": str(e)},
|
| 178 |
+
reasoning=f"Verification error: {e}",
|
| 179 |
+
confidence=1.0,
|
| 180 |
+
agent_id=self.agent_id,
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
async def plan(self, observation: Observation) -> list[Action]:
|
| 184 |
+
"""
|
| 185 |
+
Create a verification plan for all extracted fields.
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
observation: The current state observation.
|
| 189 |
+
|
| 190 |
+
Returns:
|
| 191 |
+
A list of planned verification actions.
|
| 192 |
+
"""
|
| 193 |
+
try:
|
| 194 |
+
actions: list[Action] = []
|
| 195 |
+
|
| 196 |
+
# Plan verification for each unverified field
|
| 197 |
+
for field in observation.extracted_so_far:
|
| 198 |
+
if field.verified:
|
| 199 |
+
continue
|
| 200 |
+
|
| 201 |
+
# Basic format verification
|
| 202 |
+
actions.append(
|
| 203 |
+
Action(
|
| 204 |
+
action_type=ActionType.VERIFY_FIELD,
|
| 205 |
+
parameters={
|
| 206 |
+
"field_name": field.field_name,
|
| 207 |
+
"expected_type": self._infer_field_type(field.field_name),
|
| 208 |
+
},
|
| 209 |
+
reasoning=f"Verify format of {field.field_name}",
|
| 210 |
+
confidence=0.8,
|
| 211 |
+
agent_id=self.agent_id,
|
| 212 |
+
)
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
# Cross-source verification if required
|
| 216 |
+
if self.require_cross_validation:
|
| 217 |
+
actions.append(
|
| 218 |
+
Action(
|
| 219 |
+
action_type=ActionType.VERIFY_FACT,
|
| 220 |
+
parameters={
|
| 221 |
+
"claim": f"{field.field_name}: {field.value}",
|
| 222 |
+
"confidence_threshold": self.min_confidence,
|
| 223 |
+
},
|
| 224 |
+
reasoning=f"Cross-validate {field.field_name} with other sources",
|
| 225 |
+
confidence=0.7,
|
| 226 |
+
agent_id=self.agent_id,
|
| 227 |
+
)
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
return actions
|
| 231 |
+
|
| 232 |
+
except Exception as e:
|
| 233 |
+
return [
|
| 234 |
+
Action(
|
| 235 |
+
action_type=ActionType.FAIL,
|
| 236 |
+
parameters={"message": f"Verification planning failed: {e}"},
|
| 237 |
+
reasoning=str(e),
|
| 238 |
+
confidence=1.0,
|
| 239 |
+
agent_id=self.agent_id,
|
| 240 |
+
)
|
| 241 |
+
]
|
| 242 |
+
|
| 243 |
+
async def _verify_field(
|
| 244 |
+
self,
|
| 245 |
+
field: ExtractedField,
|
| 246 |
+
observation: Observation,
|
| 247 |
+
) -> VerificationResult:
|
| 248 |
+
"""
|
| 249 |
+
Verify a single field.
|
| 250 |
+
|
| 251 |
+
Args:
|
| 252 |
+
field: The field to verify.
|
| 253 |
+
observation: Current observation context.
|
| 254 |
+
|
| 255 |
+
Returns:
|
| 256 |
+
Verification result.
|
| 257 |
+
"""
|
| 258 |
+
issues: list[str] = []
|
| 259 |
+
confidence = field.confidence
|
| 260 |
+
sources_checked = 1
|
| 261 |
+
|
| 262 |
+
# Apply validation rules
|
| 263 |
+
field_type = self._infer_field_type(field.field_name)
|
| 264 |
+
format_valid, format_issues = self._validate_format(
|
| 265 |
+
field.value,
|
| 266 |
+
field_type,
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
if not format_valid:
|
| 270 |
+
issues.extend(format_issues)
|
| 271 |
+
confidence *= 0.5
|
| 272 |
+
|
| 273 |
+
# Check for empty or null values
|
| 274 |
+
if field.value is None or (
|
| 275 |
+
isinstance(field.value, str) and not field.value.strip()
|
| 276 |
+
):
|
| 277 |
+
issues.append("Empty value")
|
| 278 |
+
confidence = 0.0
|
| 279 |
+
|
| 280 |
+
# Check against memory context for consistency
|
| 281 |
+
consistency_issues = self._check_consistency(field, observation)
|
| 282 |
+
if consistency_issues:
|
| 283 |
+
issues.extend(consistency_issues)
|
| 284 |
+
confidence *= 0.8
|
| 285 |
+
|
| 286 |
+
# Create result
|
| 287 |
+
result = VerificationResult(
|
| 288 |
+
field_name=field.field_name,
|
| 289 |
+
is_valid=len(issues) == 0,
|
| 290 |
+
confidence=confidence,
|
| 291 |
+
issues=issues,
|
| 292 |
+
sources_checked=sources_checked,
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
self._verification_history.append(result)
|
| 296 |
+
return result
|
| 297 |
+
|
| 298 |
+
def _validate_format(
|
| 299 |
+
self,
|
| 300 |
+
value: Any,
|
| 301 |
+
field_type: str,
|
| 302 |
+
) -> tuple[bool, list[str]]:
|
| 303 |
+
"""
|
| 304 |
+
Validate value format against rules.
|
| 305 |
+
|
| 306 |
+
Args:
|
| 307 |
+
value: The value to validate.
|
| 308 |
+
field_type: The expected field type.
|
| 309 |
+
|
| 310 |
+
Returns:
|
| 311 |
+
Tuple of (is_valid, list of issues).
|
| 312 |
+
"""
|
| 313 |
+
if value is None:
|
| 314 |
+
return False, ["Value is None"]
|
| 315 |
+
|
| 316 |
+
issues: list[str] = []
|
| 317 |
+
rules = self._validation_rules.get(field_type, [])
|
| 318 |
+
|
| 319 |
+
value_str = str(value)
|
| 320 |
+
|
| 321 |
+
for rule in rules:
|
| 322 |
+
rule_type = rule.get("type")
|
| 323 |
+
|
| 324 |
+
if rule_type == "regex":
|
| 325 |
+
pattern = rule.get("pattern", "")
|
| 326 |
+
if not re.match(pattern, value_str):
|
| 327 |
+
issues.append(rule.get("error", "Format validation failed"))
|
| 328 |
+
|
| 329 |
+
elif rule_type == "range":
|
| 330 |
+
try:
|
| 331 |
+
num_value = float(value_str.replace(",", "").replace("$", ""))
|
| 332 |
+
min_val = rule.get("min", float("-inf"))
|
| 333 |
+
max_val = rule.get("max", float("inf"))
|
| 334 |
+
if not (min_val <= num_value <= max_val):
|
| 335 |
+
issues.append(rule.get("error", "Value out of range"))
|
| 336 |
+
except ValueError:
|
| 337 |
+
issues.append("Cannot convert to number for range check")
|
| 338 |
+
|
| 339 |
+
elif rule_type == "length":
|
| 340 |
+
min_len = rule.get("min", 0)
|
| 341 |
+
max_len = rule.get("max", float("inf"))
|
| 342 |
+
if not (min_len <= len(value_str) <= max_len):
|
| 343 |
+
issues.append(rule.get("error", "Length validation failed"))
|
| 344 |
+
|
| 345 |
+
return len(issues) == 0, issues
|
| 346 |
+
|
| 347 |
+
def _check_consistency(
|
| 348 |
+
self,
|
| 349 |
+
field: ExtractedField,
|
| 350 |
+
observation: Observation,
|
| 351 |
+
) -> list[str]:
|
| 352 |
+
"""
|
| 353 |
+
Check field consistency with other data.
|
| 354 |
+
|
| 355 |
+
Args:
|
| 356 |
+
field: The field to check.
|
| 357 |
+
observation: Current observation.
|
| 358 |
+
|
| 359 |
+
Returns:
|
| 360 |
+
List of consistency issues.
|
| 361 |
+
"""
|
| 362 |
+
issues: list[str] = []
|
| 363 |
+
|
| 364 |
+
# Check against other extracted fields
|
| 365 |
+
for other in observation.extracted_so_far:
|
| 366 |
+
if other.field_name == field.field_name:
|
| 367 |
+
continue
|
| 368 |
+
|
| 369 |
+
# Example: price should be less than total_price
|
| 370 |
+
if field.field_name == "price" and other.field_name == "total_price":
|
| 371 |
+
try:
|
| 372 |
+
price = float(str(field.value).replace("$", "").replace(",", ""))
|
| 373 |
+
total = float(str(other.value).replace("$", "").replace(",", ""))
|
| 374 |
+
if price > total:
|
| 375 |
+
issues.append("Price exceeds total_price")
|
| 376 |
+
except (ValueError, TypeError):
|
| 377 |
+
pass
|
| 378 |
+
|
| 379 |
+
# Check against memory for historical consistency
|
| 380 |
+
memory = observation.memory_context
|
| 381 |
+
if memory.long_term_relevant:
|
| 382 |
+
for mem in memory.long_term_relevant:
|
| 383 |
+
if mem.get("field") == field.field_name:
|
| 384 |
+
historical_value = mem.get("value")
|
| 385 |
+
if historical_value and historical_value != field.value:
|
| 386 |
+
# Different from historical - flag for review
|
| 387 |
+
issues.append(
|
| 388 |
+
f"Value differs from historical: {historical_value}"
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
return issues
|
| 392 |
+
|
| 393 |
+
def _infer_field_type(self, field_name: str) -> str:
|
| 394 |
+
"""Infer the field type from its name."""
|
| 395 |
+
field_lower = field_name.lower()
|
| 396 |
+
|
| 397 |
+
type_keywords = {
|
| 398 |
+
"email": ["email", "mail"],
|
| 399 |
+
"url": ["url", "link", "href", "website"],
|
| 400 |
+
"phone": ["phone", "tel", "mobile", "fax"],
|
| 401 |
+
"price": ["price", "cost", "amount", "total", "fee"],
|
| 402 |
+
"date": ["date", "time", "created", "updated", "published"],
|
| 403 |
+
"rating": ["rating", "score", "stars"],
|
| 404 |
+
}
|
| 405 |
+
|
| 406 |
+
for field_type, keywords in type_keywords.items():
|
| 407 |
+
for keyword in keywords:
|
| 408 |
+
if keyword in field_lower:
|
| 409 |
+
return field_type
|
| 410 |
+
|
| 411 |
+
return "text"
|
| 412 |
+
|
| 413 |
+
def _create_reverify_action(
|
| 414 |
+
self,
|
| 415 |
+
field: ExtractedField,
|
| 416 |
+
result: VerificationResult,
|
| 417 |
+
) -> Action:
|
| 418 |
+
"""Create an action to handle failed verification."""
|
| 419 |
+
if result.confidence < 0.3:
|
| 420 |
+
# Very low confidence - suggest re-extraction
|
| 421 |
+
return Action(
|
| 422 |
+
action_type=ActionType.EXTRACT_FIELD,
|
| 423 |
+
parameters={
|
| 424 |
+
"field_name": field.field_name,
|
| 425 |
+
"reason": "Re-extracting due to verification failure",
|
| 426 |
+
},
|
| 427 |
+
reasoning=f"Verification failed with issues: {result.issues}",
|
| 428 |
+
confidence=0.6,
|
| 429 |
+
agent_id=self.agent_id,
|
| 430 |
+
)
|
| 431 |
+
else:
|
| 432 |
+
# Moderate confidence - try cross-validation
|
| 433 |
+
return Action(
|
| 434 |
+
action_type=ActionType.VERIFY_FACT,
|
| 435 |
+
parameters={
|
| 436 |
+
"claim": f"{field.field_name}: {field.value}",
|
| 437 |
+
"sources": None,
|
| 438 |
+
"confidence_threshold": self.min_confidence,
|
| 439 |
+
},
|
| 440 |
+
reasoning=f"Attempting cross-validation for {field.field_name}",
|
| 441 |
+
confidence=0.5,
|
| 442 |
+
agent_id=self.agent_id,
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
def add_validation_rule(
|
| 446 |
+
self,
|
| 447 |
+
field_type: str,
|
| 448 |
+
rule: dict[str, Any],
|
| 449 |
+
) -> None:
|
| 450 |
+
"""
|
| 451 |
+
Add a custom validation rule.
|
| 452 |
+
|
| 453 |
+
Args:
|
| 454 |
+
field_type: The field type this rule applies to.
|
| 455 |
+
rule: The validation rule dictionary.
|
| 456 |
+
"""
|
| 457 |
+
if field_type not in self._validation_rules:
|
| 458 |
+
self._validation_rules[field_type] = []
|
| 459 |
+
self._validation_rules[field_type].append(rule)
|
| 460 |
+
|
| 461 |
+
def get_verification_history(self) -> list[dict[str, Any]]:
|
| 462 |
+
"""Get verification history as dictionaries."""
|
| 463 |
+
return [r.to_dict() for r in self._verification_history]
|
| 464 |
+
|
| 465 |
+
def reset(self) -> None:
|
| 466 |
+
"""Reset the verifier state."""
|
| 467 |
+
super().reset()
|
| 468 |
+
self._verification_history.clear()
|