NeerajCodz commited on
Commit
3bfb250
·
1 Parent(s): ab65628

feat: implement multi-agent system with coordinator

Browse files
backend/app/agents/__init__.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Agents module for ScrapeRL.
3
+
4
+ This module contains specialized agents for web scraping with RL:
5
+ - BaseAgent: Abstract base class for all agents
6
+ - PlannerAgent: Goal decomposition and task planning
7
+ - NavigatorAgent: URL prioritization and page navigation
8
+ - ExtractorAgent: Data extraction with selectors
9
+ - VerifierAgent: Cross-source verification
10
+ - MemoryAgent: Memory operations and knowledge management
11
+ - AgentCoordinator: Orchestrates multiple agents with message passing
12
+ """
13
+
14
+ from .base import BaseAgent
15
+ from .coordinator import AgentCoordinator, AgentRole, Message
16
+ from .extractor import ExtractorAgent
17
+ from .memory_agent import MemoryAgent, MemoryEntry
18
+ from .navigator import NavigatorAgent
19
+ from .planner import PlannerAgent
20
+ from .verifier import VerificationResult, VerifierAgent
21
+
22
+ __all__ = [
23
+ # Base
24
+ "BaseAgent",
25
+ # Agents
26
+ "PlannerAgent",
27
+ "NavigatorAgent",
28
+ "ExtractorAgent",
29
+ "VerifierAgent",
30
+ "MemoryAgent",
31
+ # Coordinator
32
+ "AgentCoordinator",
33
+ "AgentRole",
34
+ "Message",
35
+ # Data classes
36
+ "VerificationResult",
37
+ "MemoryEntry",
38
+ ]
backend/app/agents/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (1.16 kB). View file
 
backend/app/agents/__pycache__/base.cpython-314.pyc ADDED
Binary file (6.75 kB). View file
 
backend/app/agents/__pycache__/coordinator.cpython-314.pyc ADDED
Binary file (19.6 kB). View file
 
backend/app/agents/__pycache__/extractor.cpython-314.pyc ADDED
Binary file (18 kB). View file
 
backend/app/agents/__pycache__/memory_agent.cpython-314.pyc ADDED
Binary file (20.6 kB). View file
 
backend/app/agents/__pycache__/navigator.cpython-314.pyc ADDED
Binary file (16.8 kB). View file
 
backend/app/agents/__pycache__/planner.cpython-314.pyc ADDED
Binary file (11.5 kB). View file
 
backend/app/agents/__pycache__/verifier.cpython-314.pyc ADDED
Binary file (19.2 kB). View file
 
backend/app/agents/base.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Base agent abstract class for ScrapeRL agents."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any
5
+
6
+ from app.core.action import Action
7
+ from app.core.observation import Observation
8
+
9
+
10
+ class BaseAgent(ABC):
11
+ """
12
+ Abstract base class for all agents in the ScrapeRL system.
13
+
14
+ Each agent specializes in a specific aspect of the scraping workflow:
15
+ - Planning and goal decomposition
16
+ - Navigation and URL prioritization
17
+ - Data extraction
18
+ - Verification and validation
19
+ - Memory operations
20
+
21
+ Agents communicate through message passing and coordinate via
22
+ the AgentCoordinator.
23
+ """
24
+
25
+ def __init__(self, agent_id: str, config: dict[str, Any] | None = None):
26
+ """
27
+ Initialize the agent.
28
+
29
+ Args:
30
+ agent_id: Unique identifier for this agent instance.
31
+ config: Optional configuration dictionary for the agent.
32
+ """
33
+ self.agent_id = agent_id
34
+ self.config = config or {}
35
+ self._message_queue: list[dict[str, Any]] = []
36
+ self._action_history: list[Action] = []
37
+
38
+ @abstractmethod
39
+ async def act(self, observation: Observation) -> Action:
40
+ """
41
+ Select an action based on the current observation.
42
+
43
+ This is the main decision-making method. The agent analyzes
44
+ the observation and returns the best action to take.
45
+
46
+ Args:
47
+ observation: The current state observation from the environment.
48
+
49
+ Returns:
50
+ The action to execute.
51
+ """
52
+ pass
53
+
54
+ @abstractmethod
55
+ async def plan(self, observation: Observation) -> list[Action]:
56
+ """
57
+ Create a plan of actions based on the current observation.
58
+
59
+ Unlike act() which returns a single action, plan() creates
60
+ a sequence of actions to achieve a goal.
61
+
62
+ Args:
63
+ observation: The current state observation from the environment.
64
+
65
+ Returns:
66
+ A list of planned actions in execution order.
67
+ """
68
+ pass
69
+
70
+ async def explain(self, action: Action) -> str:
71
+ """
72
+ Explain why this action was chosen.
73
+
74
+ Args:
75
+ action: The action to explain.
76
+
77
+ Returns:
78
+ A human-readable explanation of the action choice.
79
+ """
80
+ return action.reasoning or "No explanation provided"
81
+
82
+ def receive_message(self, message: dict[str, Any]) -> None:
83
+ """
84
+ Receive a message from another agent.
85
+
86
+ Args:
87
+ message: The message dictionary containing sender, type, and content.
88
+ """
89
+ self._message_queue.append(message)
90
+
91
+ def get_pending_messages(self) -> list[dict[str, Any]]:
92
+ """
93
+ Get all pending messages and clear the queue.
94
+
95
+ Returns:
96
+ List of pending messages.
97
+ """
98
+ messages = self._message_queue.copy()
99
+ self._message_queue.clear()
100
+ return messages
101
+
102
+ def record_action(self, action: Action) -> None:
103
+ """
104
+ Record an action in the agent's history.
105
+
106
+ Args:
107
+ action: The action that was executed.
108
+ """
109
+ self._action_history.append(action)
110
+
111
+ def get_action_history(self) -> list[Action]:
112
+ """
113
+ Get the history of actions taken by this agent.
114
+
115
+ Returns:
116
+ List of past actions.
117
+ """
118
+ return self._action_history.copy()
119
+
120
+ def reset(self) -> None:
121
+ """Reset the agent state for a new episode."""
122
+ self._message_queue.clear()
123
+ self._action_history.clear()
124
+
125
+ def __repr__(self) -> str:
126
+ """String representation of the agent."""
127
+ return f"{self.__class__.__name__}(agent_id={self.agent_id!r})"
backend/app/agents/coordinator.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Agent coordinator for orchestrating multiple agents with message passing."""
2
+
3
+ import asyncio
4
+ from datetime import datetime
5
+ from enum import Enum
6
+ from typing import Any
7
+
8
+ from app.core.action import Action, ActionType
9
+ from app.core.observation import Observation
10
+
11
+ from .base import BaseAgent
12
+ from .extractor import ExtractorAgent
13
+ from .memory_agent import MemoryAgent
14
+ from .navigator import NavigatorAgent
15
+ from .planner import PlannerAgent
16
+ from .verifier import VerifierAgent
17
+
18
+
19
+ class AgentRole(str, Enum):
20
+ """Roles that agents can fulfill."""
21
+
22
+ PLANNER = "planner"
23
+ NAVIGATOR = "navigator"
24
+ EXTRACTOR = "extractor"
25
+ VERIFIER = "verifier"
26
+ MEMORY = "memory"
27
+
28
+
29
+ class Message:
30
+ """A message between agents."""
31
+
32
+ def __init__(
33
+ self,
34
+ sender: str,
35
+ recipient: str,
36
+ message_type: str,
37
+ content: dict[str, Any],
38
+ priority: int = 0,
39
+ ):
40
+ """Initialize a message."""
41
+ self.sender = sender
42
+ self.recipient = recipient
43
+ self.message_type = message_type
44
+ self.content = content
45
+ self.priority = priority
46
+ self.timestamp = datetime.utcnow()
47
+
48
+ def to_dict(self) -> dict[str, Any]:
49
+ """Convert to dictionary."""
50
+ return {
51
+ "sender": self.sender,
52
+ "recipient": self.recipient,
53
+ "message_type": self.message_type,
54
+ "content": self.content,
55
+ "priority": self.priority,
56
+ "timestamp": self.timestamp.isoformat(),
57
+ }
58
+
59
+
60
+ class AgentCoordinator:
61
+ """
62
+ Orchestrator for multiple specialized agents.
63
+
64
+ The AgentCoordinator manages:
65
+ - Agent lifecycle and initialization
66
+ - Message passing between agents
67
+ - Action selection and routing
68
+ - Coordination of multi-agent workflows
69
+ - Error handling and recovery
70
+ """
71
+
72
+ def __init__(
73
+ self,
74
+ config: dict[str, Any] | None = None,
75
+ ):
76
+ """
77
+ Initialize the AgentCoordinator.
78
+
79
+ Args:
80
+ config: Optional configuration with keys:
81
+ - enable_parallel: Allow parallel agent execution (default: False)
82
+ - max_messages_per_step: Max messages per step (default: 10)
83
+ - default_timeout: Default timeout for agent actions (default: 30)
84
+ """
85
+ self.config = config or {}
86
+ self.enable_parallel = self.config.get("enable_parallel", False)
87
+ self.max_messages_per_step = self.config.get("max_messages_per_step", 10)
88
+ self.default_timeout = self.config.get("default_timeout", 30)
89
+
90
+ # Initialize agents
91
+ self._agents: dict[str, BaseAgent] = {}
92
+ self._message_queue: list[Message] = []
93
+ self._action_history: list[tuple[str, Action]] = []
94
+ self._current_lead: str | None = None
95
+
96
+ # Initialize default agents
97
+ self._initialize_default_agents()
98
+
99
+ def _initialize_default_agents(self) -> None:
100
+ """Initialize the default set of agents."""
101
+ self._agents = {
102
+ AgentRole.PLANNER: PlannerAgent(
103
+ agent_id="planner",
104
+ config=self.config.get("planner_config"),
105
+ ),
106
+ AgentRole.NAVIGATOR: NavigatorAgent(
107
+ agent_id="navigator",
108
+ config=self.config.get("navigator_config"),
109
+ ),
110
+ AgentRole.EXTRACTOR: ExtractorAgent(
111
+ agent_id="extractor",
112
+ config=self.config.get("extractor_config"),
113
+ ),
114
+ AgentRole.VERIFIER: VerifierAgent(
115
+ agent_id="verifier",
116
+ config=self.config.get("verifier_config"),
117
+ ),
118
+ AgentRole.MEMORY: MemoryAgent(
119
+ agent_id="memory",
120
+ config=self.config.get("memory_config"),
121
+ ),
122
+ }
123
+
124
+ def register_agent(self, role: str, agent: BaseAgent) -> None:
125
+ """
126
+ Register an agent for a specific role.
127
+
128
+ Args:
129
+ role: The role this agent fulfills.
130
+ agent: The agent instance.
131
+ """
132
+ self._agents[role] = agent
133
+
134
+ def get_agent(self, role: str) -> BaseAgent | None:
135
+ """
136
+ Get an agent by role.
137
+
138
+ Args:
139
+ role: The role to look up.
140
+
141
+ Returns:
142
+ The agent if found, None otherwise.
143
+ """
144
+ return self._agents.get(role)
145
+
146
+ async def step(self, observation: Observation) -> Action:
147
+ """
148
+ Perform one coordination step.
149
+
150
+ Determines which agent should act, processes messages,
151
+ and returns the selected action.
152
+
153
+ Args:
154
+ observation: The current state observation.
155
+
156
+ Returns:
157
+ The action to execute.
158
+ """
159
+ try:
160
+ # Process pending messages
161
+ await self._process_messages()
162
+
163
+ # Determine lead agent based on state
164
+ lead_role = self._determine_lead_agent(observation)
165
+ self._current_lead = lead_role
166
+
167
+ # Get action from lead agent
168
+ lead_agent = self._agents.get(lead_role)
169
+ if not lead_agent:
170
+ return self._create_error_action(f"No agent for role: {lead_role}")
171
+
172
+ # Get action from the lead agent
173
+ action = await lead_agent.act(observation)
174
+ action.agent_id = lead_agent.agent_id
175
+
176
+ # Record action
177
+ self._action_history.append((lead_role, action))
178
+ lead_agent.record_action(action)
179
+
180
+ # Handle inter-agent communication actions
181
+ if action.action_type == ActionType.SEND_MESSAGE:
182
+ self._handle_send_message(action)
183
+
184
+ return action
185
+
186
+ except Exception as e:
187
+ return self._create_error_action(f"Coordination error: {e}")
188
+
189
+ async def plan(self, observation: Observation) -> list[Action]:
190
+ """
191
+ Create a coordinated plan using multiple agents.
192
+
193
+ The planner agent creates the high-level plan, which is then
194
+ refined by other agents.
195
+
196
+ Args:
197
+ observation: The current state observation.
198
+
199
+ Returns:
200
+ A coordinated list of actions.
201
+ """
202
+ try:
203
+ # Get plan from planner
204
+ planner = self._agents.get(AgentRole.PLANNER)
205
+ if not planner:
206
+ return []
207
+
208
+ plan = await planner.plan(observation)
209
+
210
+ # Refine with navigator for navigation steps
211
+ navigator = self._agents.get(AgentRole.NAVIGATOR)
212
+ if navigator:
213
+ nav_plan = await navigator.plan(observation)
214
+ # Insert navigation at the beginning if needed
215
+ if nav_plan and not observation.current_url:
216
+ plan = nav_plan + plan
217
+
218
+ return plan
219
+
220
+ except Exception as e:
221
+ return [self._create_error_action(f"Planning error: {e}")]
222
+
223
+ def send_message(
224
+ self,
225
+ sender: str,
226
+ recipient: str,
227
+ message_type: str,
228
+ content: dict[str, Any],
229
+ priority: int = 0,
230
+ ) -> None:
231
+ """
232
+ Send a message between agents.
233
+
234
+ Args:
235
+ sender: ID of the sending agent.
236
+ recipient: ID of the receiving agent.
237
+ message_type: Type of the message.
238
+ content: Message content.
239
+ priority: Message priority (higher = more urgent).
240
+ """
241
+ message = Message(
242
+ sender=sender,
243
+ recipient=recipient,
244
+ message_type=message_type,
245
+ content=content,
246
+ priority=priority,
247
+ )
248
+ self._message_queue.append(message)
249
+
250
+ async def _process_messages(self) -> None:
251
+ """Process queued messages and deliver to agents."""
252
+ # Sort by priority (highest first)
253
+ self._message_queue.sort(key=lambda m: -m.priority)
254
+
255
+ # Process up to max messages
256
+ messages_processed = 0
257
+ while self._message_queue and messages_processed < self.max_messages_per_step:
258
+ message = self._message_queue.pop(0)
259
+
260
+ # Find recipient agent
261
+ recipient = None
262
+ for role, agent in self._agents.items():
263
+ if agent.agent_id == message.recipient or role == message.recipient:
264
+ recipient = agent
265
+ break
266
+
267
+ if recipient:
268
+ recipient.receive_message(message.to_dict())
269
+ messages_processed += 1
270
+
271
+ def _determine_lead_agent(self, observation: Observation) -> str:
272
+ """
273
+ Determine which agent should lead based on state.
274
+
275
+ Args:
276
+ observation: Current observation.
277
+
278
+ Returns:
279
+ The role of the agent that should lead.
280
+ """
281
+ # If no URL, navigator should lead
282
+ if not observation.current_url:
283
+ return AgentRole.NAVIGATOR
284
+
285
+ # If there are unverified fields, verifier should lead
286
+ unverified = [f for f in observation.extracted_so_far if not f.verified]
287
+ if unverified and observation.extraction_progress > 0.5:
288
+ return AgentRole.VERIFIER
289
+
290
+ # If there are remaining fields to extract, extractor should lead
291
+ if observation.fields_remaining:
292
+ return AgentRole.EXTRACTOR
293
+
294
+ # If we have errors, planner should re-plan
295
+ if observation.consecutive_errors > 0:
296
+ return AgentRole.PLANNER
297
+
298
+ # Default to planner
299
+ return AgentRole.PLANNER
300
+
301
+ def _handle_send_message(self, action: Action) -> None:
302
+ """Handle a send_message action from an agent."""
303
+ params = action.parameters
304
+ self.send_message(
305
+ sender=action.agent_id or "unknown",
306
+ recipient=params.get("target_agent", ""),
307
+ message_type=params.get("message_type", "generic"),
308
+ content=params.get("content", {}),
309
+ )
310
+
311
+ def _create_error_action(self, error: str) -> Action:
312
+ """Create a fail action for errors."""
313
+ return Action(
314
+ action_type=ActionType.FAIL,
315
+ parameters={"success": False, "message": error},
316
+ reasoning=error,
317
+ confidence=1.0,
318
+ agent_id="coordinator",
319
+ )
320
+
321
+ async def run_parallel_agents(
322
+ self,
323
+ observation: Observation,
324
+ roles: list[str],
325
+ ) -> dict[str, Action]:
326
+ """
327
+ Run multiple agents in parallel.
328
+
329
+ Args:
330
+ observation: Current observation.
331
+ roles: List of agent roles to run.
332
+
333
+ Returns:
334
+ Dictionary mapping role to action.
335
+ """
336
+ if not self.enable_parallel:
337
+ # Fallback to sequential
338
+ results = {}
339
+ for role in roles:
340
+ agent = self._agents.get(role)
341
+ if agent:
342
+ results[role] = await agent.act(observation)
343
+ return results
344
+
345
+ # Run agents in parallel
346
+ async def run_agent(role: str) -> tuple[str, Action]:
347
+ agent = self._agents.get(role)
348
+ if agent:
349
+ action = await agent.act(observation)
350
+ return (role, action)
351
+ return (role, self._create_error_action(f"No agent for role: {role}"))
352
+
353
+ tasks = [run_agent(role) for role in roles]
354
+ results = await asyncio.gather(*tasks)
355
+
356
+ return dict(results)
357
+
358
+ def get_action_history(self) -> list[tuple[str, Action]]:
359
+ """Get the history of actions with their agent roles."""
360
+ return self._action_history.copy()
361
+
362
+ def get_current_lead(self) -> str | None:
363
+ """Get the current lead agent role."""
364
+ return self._current_lead
365
+
366
+ def get_message_queue_length(self) -> int:
367
+ """Get the number of pending messages."""
368
+ return len(self._message_queue)
369
+
370
+ def reset(self) -> None:
371
+ """Reset all agents and coordinator state."""
372
+ for agent in self._agents.values():
373
+ agent.reset()
374
+
375
+ self._message_queue.clear()
376
+ self._action_history.clear()
377
+ self._current_lead = None
378
+
379
+ def get_stats(self) -> dict[str, Any]:
380
+ """Get coordinator statistics."""
381
+ return {
382
+ "agents": list(self._agents.keys()),
383
+ "current_lead": self._current_lead,
384
+ "pending_messages": len(self._message_queue),
385
+ "action_count": len(self._action_history),
386
+ "enable_parallel": self.enable_parallel,
387
+ }
backend/app/agents/extractor.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Extractor agent for data extraction with selectors."""
2
+
3
+ import re
4
+ from typing import Any
5
+
6
+ from app.core.action import Action, ActionType
7
+ from app.core.observation import Observation, PageElement
8
+
9
+ from .base import BaseAgent
10
+
11
+
12
+ class ExtractorAgent(BaseAgent):
13
+ """
14
+ Agent responsible for extracting structured data from pages.
15
+
16
+ The ExtractorAgent handles:
17
+ - Identifying data elements using CSS/XPath selectors
18
+ - Extracting text, attributes, and structured content
19
+ - Handling tables and lists
20
+ - Post-processing extracted values
21
+ - Confidence scoring for extractions
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ agent_id: str = "extractor",
27
+ config: dict[str, Any] | None = None,
28
+ ):
29
+ """
30
+ Initialize the ExtractorAgent.
31
+
32
+ Args:
33
+ agent_id: Unique identifier for this agent.
34
+ config: Optional configuration with keys:
35
+ - min_confidence: Minimum confidence to accept extraction
36
+ - extraction_timeout: Timeout for extraction operations
37
+ - enable_fuzzy_matching: Enable fuzzy text matching
38
+ """
39
+ super().__init__(agent_id, config)
40
+ self.min_confidence = self.config.get("min_confidence", 0.5)
41
+ self.extraction_timeout = self.config.get("extraction_timeout", 5000)
42
+ self.enable_fuzzy_matching = self.config.get("enable_fuzzy_matching", True)
43
+ self._extraction_cache: dict[str, Any] = {}
44
+ self._selector_patterns: dict[str, list[str]] = self._init_selector_patterns()
45
+
46
+ def _init_selector_patterns(self) -> dict[str, list[str]]:
47
+ """Initialize common selector patterns for different field types."""
48
+ return {
49
+ "price": [
50
+ "[class*='price']",
51
+ "[id*='price']",
52
+ "[itemprop='price']",
53
+ ".product-price",
54
+ ".item-price",
55
+ "span[data-price]",
56
+ ],
57
+ "title": [
58
+ "h1",
59
+ "[class*='title']",
60
+ "[itemprop='name']",
61
+ ".product-title",
62
+ ".item-title",
63
+ ],
64
+ "description": [
65
+ "[class*='description']",
66
+ "[itemprop='description']",
67
+ ".product-description",
68
+ "article p",
69
+ ".content p",
70
+ ],
71
+ "image": [
72
+ "[class*='product-image'] img",
73
+ "[itemprop='image']",
74
+ ".main-image img",
75
+ "figure img",
76
+ ],
77
+ "date": [
78
+ "time",
79
+ "[datetime]",
80
+ "[class*='date']",
81
+ "[itemprop='datePublished']",
82
+ ],
83
+ "author": [
84
+ "[class*='author']",
85
+ "[itemprop='author']",
86
+ "[rel='author']",
87
+ ".byline",
88
+ ],
89
+ }
90
+
91
+ async def act(self, observation: Observation) -> Action:
92
+ """
93
+ Select the best extraction action based on observation.
94
+
95
+ Analyzes the page and decides what data to extract next.
96
+
97
+ Args:
98
+ observation: The current state observation.
99
+
100
+ Returns:
101
+ The extraction action to execute.
102
+ """
103
+ try:
104
+ # Get remaining fields to extract
105
+ remaining_fields = observation.fields_remaining
106
+
107
+ if not remaining_fields:
108
+ return Action(
109
+ action_type=ActionType.DONE,
110
+ parameters={"success": True, "message": "All fields extracted"},
111
+ reasoning="No more fields to extract",
112
+ confidence=1.0,
113
+ agent_id=self.agent_id,
114
+ )
115
+
116
+ # Pick the next field to extract
117
+ field_name = remaining_fields[0]
118
+
119
+ # Find best selector for the field
120
+ selector, confidence = await self._find_selector_for_field(
121
+ field_name,
122
+ observation,
123
+ )
124
+
125
+ if selector and confidence >= self.min_confidence:
126
+ return self._create_extraction_action(
127
+ field_name,
128
+ selector,
129
+ confidence,
130
+ )
131
+
132
+ # Try alternative extraction methods
133
+ alt_action = await self._try_alternative_extraction(
134
+ field_name,
135
+ observation,
136
+ )
137
+ if alt_action:
138
+ return alt_action
139
+
140
+ # Cannot extract this field
141
+ return Action(
142
+ action_type=ActionType.EXTRACT_FIELD,
143
+ parameters={
144
+ "field_name": field_name,
145
+ "selector": None,
146
+ "extraction_method": "llm",
147
+ },
148
+ reasoning=f"No selector found, using LLM extraction for {field_name}",
149
+ confidence=0.4,
150
+ agent_id=self.agent_id,
151
+ )
152
+
153
+ except Exception as e:
154
+ return Action(
155
+ action_type=ActionType.FAIL,
156
+ parameters={"success": False, "message": str(e)},
157
+ reasoning=f"Extraction error: {e}",
158
+ confidence=1.0,
159
+ agent_id=self.agent_id,
160
+ )
161
+
162
+ async def plan(self, observation: Observation) -> list[Action]:
163
+ """
164
+ Create an extraction plan for all remaining fields.
165
+
166
+ Analyzes the page structure and plans the optimal
167
+ extraction sequence.
168
+
169
+ Args:
170
+ observation: The current state observation.
171
+
172
+ Returns:
173
+ A list of planned extraction actions.
174
+ """
175
+ try:
176
+ actions: list[Action] = []
177
+ remaining_fields = observation.fields_remaining
178
+
179
+ for field_name in remaining_fields:
180
+ selector, confidence = await self._find_selector_for_field(
181
+ field_name,
182
+ observation,
183
+ )
184
+
185
+ if selector:
186
+ actions.append(
187
+ self._create_extraction_action(
188
+ field_name,
189
+ selector,
190
+ confidence,
191
+ )
192
+ )
193
+ else:
194
+ # Plan LLM-based extraction as fallback
195
+ actions.append(
196
+ Action(
197
+ action_type=ActionType.EXTRACT_FIELD,
198
+ parameters={
199
+ "field_name": field_name,
200
+ "extraction_method": "llm",
201
+ },
202
+ reasoning=f"Planning LLM extraction for {field_name}",
203
+ confidence=0.5,
204
+ agent_id=self.agent_id,
205
+ )
206
+ )
207
+
208
+ return actions
209
+
210
+ except Exception as e:
211
+ return [
212
+ Action(
213
+ action_type=ActionType.FAIL,
214
+ parameters={"message": f"Extraction planning failed: {e}"},
215
+ reasoning=str(e),
216
+ confidence=1.0,
217
+ agent_id=self.agent_id,
218
+ )
219
+ ]
220
+
221
+ async def _find_selector_for_field(
222
+ self,
223
+ field_name: str,
224
+ observation: Observation,
225
+ ) -> tuple[str | None, float]:
226
+ """
227
+ Find the best selector for a field.
228
+
229
+ Args:
230
+ field_name: Name of the field to extract.
231
+ observation: Current observation.
232
+
233
+ Returns:
234
+ Tuple of (selector, confidence).
235
+ """
236
+ best_selector: str | None = None
237
+ best_confidence = 0.0
238
+
239
+ # Check predefined patterns first
240
+ patterns = self._get_patterns_for_field(field_name)
241
+ for pattern in patterns:
242
+ element = self._find_element_by_selector(
243
+ pattern,
244
+ observation.page_elements,
245
+ )
246
+ if element:
247
+ confidence = self._calculate_confidence(element, field_name)
248
+ if confidence > best_confidence:
249
+ best_selector = element.selector
250
+ best_confidence = confidence
251
+
252
+ # Search by text content if fuzzy matching enabled
253
+ if self.enable_fuzzy_matching and best_confidence < 0.7:
254
+ element, confidence = self._find_element_by_text(
255
+ field_name,
256
+ observation.page_elements,
257
+ )
258
+ if element and confidence > best_confidence:
259
+ best_selector = element.selector
260
+ best_confidence = confidence
261
+
262
+ return best_selector, best_confidence
263
+
264
+ def _get_patterns_for_field(self, field_name: str) -> list[str]:
265
+ """Get selector patterns for a field type."""
266
+ field_lower = field_name.lower()
267
+
268
+ # Direct match
269
+ if field_lower in self._selector_patterns:
270
+ return self._selector_patterns[field_lower]
271
+
272
+ # Partial match
273
+ for key, patterns in self._selector_patterns.items():
274
+ if key in field_lower or field_lower in key:
275
+ return patterns
276
+
277
+ # Generate generic patterns
278
+ return [
279
+ f"[class*='{field_lower}']",
280
+ f"[id*='{field_lower}']",
281
+ f"[data-{field_lower}]",
282
+ f".{field_lower}",
283
+ f"#{field_lower}",
284
+ ]
285
+
286
+ def _find_element_by_selector(
287
+ self,
288
+ selector: str,
289
+ elements: list[PageElement],
290
+ ) -> PageElement | None:
291
+ """Find an element matching a selector pattern."""
292
+ selector_lower = selector.lower()
293
+
294
+ for element in elements:
295
+ element_selector = element.selector.lower()
296
+ if selector_lower in element_selector:
297
+ return element
298
+
299
+ # Check class and id attributes
300
+ classes = element.attributes.get("class", "").lower()
301
+ element_id = element.attributes.get("id", "").lower()
302
+
303
+ if selector_lower.strip(".[#]") in classes:
304
+ return element
305
+ if selector_lower.strip(".[#]") in element_id:
306
+ return element
307
+
308
+ return None
309
+
310
+ def _find_element_by_text(
311
+ self,
312
+ field_name: str,
313
+ elements: list[PageElement],
314
+ ) -> tuple[PageElement | None, float]:
315
+ """Find an element by text content matching."""
316
+ field_lower = field_name.lower().replace("_", " ")
317
+ best_element: PageElement | None = None
318
+ best_score = 0.0
319
+
320
+ for element in elements:
321
+ if not element.text:
322
+ continue
323
+
324
+ text_lower = element.text.lower()
325
+
326
+ # Check for label-like patterns
327
+ if f"{field_lower}:" in text_lower or f"{field_lower} :" in text_lower:
328
+ score = 0.9
329
+ elif field_lower in text_lower:
330
+ # Calculate similarity score
331
+ score = len(field_lower) / max(len(text_lower), 1) * 0.8
332
+ else:
333
+ continue
334
+
335
+ if score > best_score:
336
+ best_element = element
337
+ best_score = score
338
+
339
+ return best_element, best_score
340
+
341
+ def _calculate_confidence(self, element: PageElement, field_name: str) -> float:
342
+ """Calculate extraction confidence for an element."""
343
+ confidence = 0.5
344
+
345
+ # Boost for visible elements
346
+ if element.is_visible:
347
+ confidence += 0.1
348
+
349
+ # Boost for semantic attributes
350
+ if element.attributes.get("itemprop"):
351
+ confidence += 0.2
352
+ if element.attributes.get("data-field"):
353
+ confidence += 0.15
354
+
355
+ # Boost if text contains field name
356
+ if element.text and field_name.lower() in element.text.lower():
357
+ confidence += 0.1
358
+
359
+ # Penalty for very long text (likely not a single field)
360
+ if element.text and len(element.text) > 500:
361
+ confidence -= 0.2
362
+
363
+ return min(1.0, max(0.0, confidence))
364
+
365
+ async def _try_alternative_extraction(
366
+ self,
367
+ field_name: str,
368
+ observation: Observation,
369
+ ) -> Action | None:
370
+ """Try alternative extraction methods."""
371
+ # Check for table data
372
+ for element in observation.page_elements:
373
+ if element.tag in ("table", "tbody"):
374
+ return Action(
375
+ action_type=ActionType.EXTRACT_TABLE,
376
+ parameters={
377
+ "table_selector": element.selector,
378
+ "target_field": field_name,
379
+ },
380
+ reasoning=f"Extracting {field_name} from table",
381
+ confidence=0.6,
382
+ agent_id=self.agent_id,
383
+ )
384
+
385
+ # Check for list data
386
+ for element in observation.page_elements:
387
+ if element.tag in ("ul", "ol", "dl"):
388
+ return Action(
389
+ action_type=ActionType.EXTRACT_LIST,
390
+ parameters={
391
+ "container_selector": element.selector,
392
+ "item_selector": "li",
393
+ "field_selectors": {field_name: "text"},
394
+ },
395
+ reasoning=f"Extracting {field_name} from list",
396
+ confidence=0.55,
397
+ agent_id=self.agent_id,
398
+ )
399
+
400
+ return None
401
+
402
+ def _create_extraction_action(
403
+ self,
404
+ field_name: str,
405
+ selector: str,
406
+ confidence: float,
407
+ ) -> Action:
408
+ """Create an extraction action."""
409
+ return Action(
410
+ action_type=ActionType.EXTRACT_FIELD,
411
+ parameters={
412
+ "field_name": field_name,
413
+ "selector": selector,
414
+ "extraction_method": "text",
415
+ },
416
+ reasoning=f"Extracting {field_name} using selector: {selector}",
417
+ confidence=confidence,
418
+ agent_id=self.agent_id,
419
+ )
420
+
421
+ def extract_with_regex(
422
+ self,
423
+ text: str,
424
+ pattern: str,
425
+ group: int = 0,
426
+ ) -> str | None:
427
+ """
428
+ Extract text using a regex pattern.
429
+
430
+ Args:
431
+ text: The text to search in.
432
+ pattern: Regex pattern.
433
+ group: Capture group to return.
434
+
435
+ Returns:
436
+ Extracted text or None.
437
+ """
438
+ try:
439
+ match = re.search(pattern, text)
440
+ if match:
441
+ return match.group(group)
442
+ return None
443
+ except re.error:
444
+ return None
445
+
446
+ def post_process_value(
447
+ self,
448
+ value: Any,
449
+ field_name: str,
450
+ ) -> Any:
451
+ """
452
+ Post-process an extracted value based on field type.
453
+
454
+ Args:
455
+ value: The raw extracted value.
456
+ field_name: Name of the field (used to infer type).
457
+
458
+ Returns:
459
+ Processed value.
460
+ """
461
+ if value is None:
462
+ return None
463
+
464
+ value_str = str(value).strip()
465
+ field_lower = field_name.lower()
466
+
467
+ # Price processing
468
+ if "price" in field_lower:
469
+ # Remove currency symbols but keep numbers and decimal
470
+ price_match = re.search(r"[\d,]+\.?\d*", value_str.replace(",", ""))
471
+ if price_match:
472
+ return float(price_match.group().replace(",", ""))
473
+
474
+ # Date processing
475
+ if "date" in field_lower:
476
+ return value_str # Return as-is, let caller parse
477
+
478
+ # Number processing
479
+ if any(x in field_lower for x in ["count", "quantity", "number"]):
480
+ num_match = re.search(r"\d+", value_str)
481
+ if num_match:
482
+ return int(num_match.group())
483
+
484
+ return value_str
485
+
486
+ def reset(self) -> None:
487
+ """Reset the extractor state."""
488
+ super().reset()
489
+ self._extraction_cache.clear()
backend/app/agents/memory_agent.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Memory agent for memory operations and knowledge management."""
2
+
3
+ from datetime import datetime
4
+ from typing import Any
5
+
6
+ from app.core.action import Action, ActionType
7
+ from app.core.observation import Observation
8
+
9
+ from .base import BaseAgent
10
+
11
+
12
+ class MemoryEntry:
13
+ """A single memory entry."""
14
+
15
+ def __init__(
16
+ self,
17
+ key: str,
18
+ value: Any,
19
+ memory_type: str = "working",
20
+ ttl_seconds: int | None = None,
21
+ metadata: dict[str, Any] | None = None,
22
+ ):
23
+ """Initialize memory entry."""
24
+ self.key = key
25
+ self.value = value
26
+ self.memory_type = memory_type
27
+ self.ttl_seconds = ttl_seconds
28
+ self.metadata = metadata or {}
29
+ self.created_at = datetime.utcnow()
30
+ self.accessed_at = datetime.utcnow()
31
+ self.access_count = 0
32
+
33
+ def is_expired(self) -> bool:
34
+ """Check if the memory entry has expired."""
35
+ if self.ttl_seconds is None:
36
+ return False
37
+ elapsed = (datetime.utcnow() - self.created_at).total_seconds()
38
+ return elapsed > self.ttl_seconds
39
+
40
+ def access(self) -> Any:
41
+ """Access the memory and update metadata."""
42
+ self.accessed_at = datetime.utcnow()
43
+ self.access_count += 1
44
+ return self.value
45
+
46
+ def to_dict(self) -> dict[str, Any]:
47
+ """Convert to dictionary."""
48
+ return {
49
+ "key": self.key,
50
+ "value": self.value,
51
+ "memory_type": self.memory_type,
52
+ "ttl_seconds": self.ttl_seconds,
53
+ "metadata": self.metadata,
54
+ "created_at": self.created_at.isoformat(),
55
+ "accessed_at": self.accessed_at.isoformat(),
56
+ "access_count": self.access_count,
57
+ }
58
+
59
+
60
+ class MemoryAgent(BaseAgent):
61
+ """
62
+ Agent responsible for memory operations and knowledge management.
63
+
64
+ The MemoryAgent handles:
65
+ - Storing and retrieving memories across different layers
66
+ - Managing short-term, working, and long-term memory
67
+ - Memory consolidation and cleanup
68
+ - Relevance-based memory retrieval
69
+ - Sharing knowledge between episodes
70
+ """
71
+
72
+ def __init__(
73
+ self,
74
+ agent_id: str = "memory",
75
+ config: dict[str, Any] | None = None,
76
+ ):
77
+ """
78
+ Initialize the MemoryAgent.
79
+
80
+ Args:
81
+ agent_id: Unique identifier for this agent.
82
+ config: Optional configuration with keys:
83
+ - max_short_term: Max short-term memory entries (default: 100)
84
+ - max_working: Max working memory entries (default: 50)
85
+ - consolidation_threshold: Accesses before long-term (default: 3)
86
+ - enable_auto_cleanup: Auto cleanup expired entries (default: True)
87
+ """
88
+ super().__init__(agent_id, config)
89
+ self.max_short_term = self.config.get("max_short_term", 100)
90
+ self.max_working = self.config.get("max_working", 50)
91
+ self.consolidation_threshold = self.config.get("consolidation_threshold", 3)
92
+ self.enable_auto_cleanup = self.config.get("enable_auto_cleanup", True)
93
+
94
+ # Memory stores
95
+ self._short_term: dict[str, MemoryEntry] = {}
96
+ self._working: dict[str, MemoryEntry] = {}
97
+ self._pending_operations: list[dict[str, Any]] = []
98
+
99
+ async def act(self, observation: Observation) -> Action:
100
+ """
101
+ Select the best memory action based on observation.
102
+
103
+ Analyzes the current state and determines if any memory
104
+ operations are needed.
105
+
106
+ Args:
107
+ observation: The current state observation.
108
+
109
+ Returns:
110
+ The memory action to execute.
111
+ """
112
+ try:
113
+ # Process any pending messages requesting memory operations
114
+ messages = self.get_pending_messages()
115
+ for msg in messages:
116
+ if msg.get("message_type") == "memory_request":
117
+ return self._process_memory_request(msg)
118
+
119
+ # Auto cleanup if enabled
120
+ if self.enable_auto_cleanup:
121
+ self._cleanup_expired()
122
+
123
+ # Check if we should store new information
124
+ store_action = self._check_for_storage(observation)
125
+ if store_action:
126
+ return store_action
127
+
128
+ # Check if any memories need consolidation
129
+ consolidation_action = self._check_for_consolidation()
130
+ if consolidation_action:
131
+ return consolidation_action
132
+
133
+ # No memory operations needed
134
+ return Action(
135
+ action_type=ActionType.WAIT,
136
+ parameters={"duration_ms": 100},
137
+ reasoning="No memory operations required",
138
+ confidence=1.0,
139
+ agent_id=self.agent_id,
140
+ )
141
+
142
+ except Exception as e:
143
+ return Action(
144
+ action_type=ActionType.FAIL,
145
+ parameters={"success": False, "message": str(e)},
146
+ reasoning=f"Memory operation error: {e}",
147
+ confidence=1.0,
148
+ agent_id=self.agent_id,
149
+ )
150
+
151
+ async def plan(self, observation: Observation) -> list[Action]:
152
+ """
153
+ Create a plan of memory operations.
154
+
155
+ Plans memory operations needed based on the current state
156
+ and extracted data.
157
+
158
+ Args:
159
+ observation: The current state observation.
160
+
161
+ Returns:
162
+ A list of planned memory actions.
163
+ """
164
+ try:
165
+ actions: list[Action] = []
166
+
167
+ # Plan to store extracted fields
168
+ for field in observation.extracted_so_far:
169
+ if field.verified and field.confidence > 0.8:
170
+ actions.append(
171
+ Action(
172
+ action_type=ActionType.STORE_MEMORY,
173
+ parameters={
174
+ "key": f"extracted:{field.field_name}",
175
+ "value": field.value,
176
+ "memory_type": "working",
177
+ "metadata": {
178
+ "source": observation.current_url,
179
+ "confidence": field.confidence,
180
+ },
181
+ },
182
+ reasoning=f"Storing verified field: {field.field_name}",
183
+ confidence=0.9,
184
+ agent_id=self.agent_id,
185
+ )
186
+ )
187
+
188
+ # Plan to recall relevant memories for current task
189
+ if observation.task_context:
190
+ for target in observation.task_context.target_fields:
191
+ actions.append(
192
+ Action(
193
+ action_type=ActionType.RECALL_MEMORY,
194
+ parameters={
195
+ "key": f"pattern:{target}",
196
+ "memory_type": "long_term",
197
+ },
198
+ reasoning=f"Recalling patterns for field: {target}",
199
+ confidence=0.7,
200
+ agent_id=self.agent_id,
201
+ )
202
+ )
203
+
204
+ return actions
205
+
206
+ except Exception as e:
207
+ return [
208
+ Action(
209
+ action_type=ActionType.FAIL,
210
+ parameters={"message": f"Memory planning failed: {e}"},
211
+ reasoning=str(e),
212
+ confidence=1.0,
213
+ agent_id=self.agent_id,
214
+ )
215
+ ]
216
+
217
+ def store(
218
+ self,
219
+ key: str,
220
+ value: Any,
221
+ memory_type: str = "working",
222
+ ttl_seconds: int | None = None,
223
+ metadata: dict[str, Any] | None = None,
224
+ ) -> bool:
225
+ """
226
+ Store a value in memory.
227
+
228
+ Args:
229
+ key: The key to store under.
230
+ value: The value to store.
231
+ memory_type: Type of memory (short_term, working).
232
+ ttl_seconds: Optional time-to-live.
233
+ metadata: Optional metadata.
234
+
235
+ Returns:
236
+ True if stored successfully.
237
+ """
238
+ entry = MemoryEntry(
239
+ key=key,
240
+ value=value,
241
+ memory_type=memory_type,
242
+ ttl_seconds=ttl_seconds,
243
+ metadata=metadata,
244
+ )
245
+
246
+ if memory_type == "short_term":
247
+ self._enforce_limit(self._short_term, self.max_short_term)
248
+ self._short_term[key] = entry
249
+ elif memory_type == "working":
250
+ self._enforce_limit(self._working, self.max_working)
251
+ self._working[key] = entry
252
+ else:
253
+ return False
254
+
255
+ return True
256
+
257
+ def recall(
258
+ self,
259
+ key: str,
260
+ memory_type: str | None = None,
261
+ ) -> Any | None:
262
+ """
263
+ Recall a value from memory.
264
+
265
+ Args:
266
+ key: The key to recall.
267
+ memory_type: Optional specific memory type to search.
268
+
269
+ Returns:
270
+ The value if found, None otherwise.
271
+ """
272
+ # Search in order of specificity
273
+ stores = []
274
+ if memory_type == "working" or memory_type is None:
275
+ stores.append(self._working)
276
+ if memory_type == "short_term" or memory_type is None:
277
+ stores.append(self._short_term)
278
+
279
+ for store in stores:
280
+ if key in store:
281
+ entry = store[key]
282
+ if not entry.is_expired():
283
+ return entry.access()
284
+ else:
285
+ # Clean up expired entry
286
+ del store[key]
287
+
288
+ return None
289
+
290
+ def search(
291
+ self,
292
+ query: str,
293
+ memory_type: str | None = None,
294
+ limit: int = 10,
295
+ ) -> list[dict[str, Any]]:
296
+ """
297
+ Search memories by key prefix or content.
298
+
299
+ Args:
300
+ query: Search query (matches key prefix).
301
+ memory_type: Optional specific memory type.
302
+ limit: Maximum results to return.
303
+
304
+ Returns:
305
+ List of matching memories.
306
+ """
307
+ results: list[dict[str, Any]] = []
308
+ query_lower = query.lower()
309
+
310
+ stores = []
311
+ if memory_type in ("working", None):
312
+ stores.append(("working", self._working))
313
+ if memory_type in ("short_term", None):
314
+ stores.append(("short_term", self._short_term))
315
+
316
+ for store_name, store in stores:
317
+ for key, entry in store.items():
318
+ if entry.is_expired():
319
+ continue
320
+
321
+ # Match by key prefix or value content
322
+ if (
323
+ key.lower().startswith(query_lower)
324
+ or query_lower in str(entry.value).lower()
325
+ ):
326
+ results.append({
327
+ **entry.to_dict(),
328
+ "store": store_name,
329
+ })
330
+
331
+ if len(results) >= limit:
332
+ break
333
+
334
+ return results[:limit]
335
+
336
+ def _process_memory_request(self, message: dict[str, Any]) -> Action:
337
+ """Process a memory request from another agent."""
338
+ content = message.get("content", {})
339
+ operation = content.get("operation", "recall")
340
+ key = content.get("key", "")
341
+
342
+ if operation == "store":
343
+ success = self.store(
344
+ key=key,
345
+ value=content.get("value"),
346
+ memory_type=content.get("memory_type", "working"),
347
+ ttl_seconds=content.get("ttl_seconds"),
348
+ metadata=content.get("metadata"),
349
+ )
350
+ return Action(
351
+ action_type=ActionType.STORE_MEMORY,
352
+ parameters={"key": key, "success": success},
353
+ reasoning=f"Processed store request for key: {key}",
354
+ confidence=1.0 if success else 0.5,
355
+ agent_id=self.agent_id,
356
+ )
357
+
358
+ elif operation == "recall":
359
+ value = self.recall(key, content.get("memory_type"))
360
+ return Action(
361
+ action_type=ActionType.RECALL_MEMORY,
362
+ parameters={"key": key, "value": value, "found": value is not None},
363
+ reasoning=f"Processed recall request for key: {key}",
364
+ confidence=1.0 if value else 0.3,
365
+ agent_id=self.agent_id,
366
+ )
367
+
368
+ else:
369
+ return Action(
370
+ action_type=ActionType.FAIL,
371
+ parameters={"message": f"Unknown memory operation: {operation}"},
372
+ reasoning=f"Invalid memory request",
373
+ confidence=1.0,
374
+ agent_id=self.agent_id,
375
+ )
376
+
377
+ def _check_for_storage(self, observation: Observation) -> Action | None:
378
+ """Check if any new information should be stored."""
379
+ # Store newly extracted, verified fields
380
+ for field in observation.extracted_so_far:
381
+ key = f"field:{field.field_name}"
382
+ if key not in self._working and field.verified:
383
+ return Action(
384
+ action_type=ActionType.STORE_MEMORY,
385
+ parameters={
386
+ "key": key,
387
+ "value": {
388
+ "field_name": field.field_name,
389
+ "value": field.value,
390
+ "confidence": field.confidence,
391
+ "source": observation.current_url,
392
+ },
393
+ "memory_type": "working",
394
+ },
395
+ reasoning=f"Storing verified extraction: {field.field_name}",
396
+ confidence=0.85,
397
+ agent_id=self.agent_id,
398
+ )
399
+
400
+ return None
401
+
402
+ def _check_for_consolidation(self) -> Action | None:
403
+ """Check if any memories should be consolidated to long-term."""
404
+ for key, entry in self._working.items():
405
+ if entry.access_count >= self.consolidation_threshold:
406
+ return Action(
407
+ action_type=ActionType.STORE_MEMORY,
408
+ parameters={
409
+ "key": key,
410
+ "value": entry.value,
411
+ "memory_type": "long_term",
412
+ "metadata": {
413
+ "access_count": entry.access_count,
414
+ "consolidated_from": "working",
415
+ },
416
+ },
417
+ reasoning=f"Consolidating frequently accessed memory: {key}",
418
+ confidence=0.8,
419
+ agent_id=self.agent_id,
420
+ )
421
+
422
+ return None
423
+
424
+ def _cleanup_expired(self) -> int:
425
+ """Clean up expired memory entries."""
426
+ cleaned = 0
427
+
428
+ for store in [self._short_term, self._working]:
429
+ expired_keys = [
430
+ k for k, v in store.items()
431
+ if v.is_expired()
432
+ ]
433
+ for key in expired_keys:
434
+ del store[key]
435
+ cleaned += 1
436
+
437
+ return cleaned
438
+
439
+ def _enforce_limit(
440
+ self,
441
+ store: dict[str, MemoryEntry],
442
+ limit: int,
443
+ ) -> None:
444
+ """Enforce memory limit by removing least accessed entries."""
445
+ if len(store) < limit:
446
+ return
447
+
448
+ # Sort by access count and last access time
449
+ sorted_entries = sorted(
450
+ store.items(),
451
+ key=lambda x: (x[1].access_count, x[1].accessed_at),
452
+ )
453
+
454
+ # Remove oldest/least accessed entries
455
+ to_remove = len(store) - limit + 1
456
+ for key, _ in sorted_entries[:to_remove]:
457
+ del store[key]
458
+
459
+ def get_memory_stats(self) -> dict[str, Any]:
460
+ """Get statistics about memory usage."""
461
+ return {
462
+ "short_term_count": len(self._short_term),
463
+ "short_term_limit": self.max_short_term,
464
+ "working_count": len(self._working),
465
+ "working_limit": self.max_working,
466
+ "total_entries": len(self._short_term) + len(self._working),
467
+ }
468
+
469
+ def reset(self) -> None:
470
+ """Reset the memory agent state."""
471
+ super().reset()
472
+ self._short_term.clear()
473
+ self._working.clear()
474
+ self._pending_operations.clear()
backend/app/agents/navigator.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Navigator agent for URL prioritization and page navigation."""
2
+
3
+ from typing import Any
4
+ from urllib.parse import urljoin, urlparse
5
+
6
+ from app.core.action import Action, ActionType
7
+ from app.core.observation import Observation, PageElement
8
+
9
+ from .base import BaseAgent
10
+
11
+
12
+ class NavigatorAgent(BaseAgent):
13
+ """
14
+ Agent responsible for intelligent page navigation.
15
+
16
+ The NavigatorAgent handles:
17
+ - URL prioritization based on relevance to task
18
+ - Link discovery and scoring
19
+ - Navigation decision making
20
+ - Handling pagination and multi-page content
21
+ - Avoiding irrelevant or harmful URLs
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ agent_id: str = "navigator",
27
+ config: dict[str, Any] | None = None,
28
+ ):
29
+ """
30
+ Initialize the NavigatorAgent.
31
+
32
+ Args:
33
+ agent_id: Unique identifier for this agent.
34
+ config: Optional configuration with keys:
35
+ - max_depth: Maximum navigation depth (default: 5)
36
+ - allowed_domains: List of allowed domains to visit
37
+ - blocked_patterns: URL patterns to avoid
38
+ - prioritize_https: Prefer HTTPS URLs (default: True)
39
+ """
40
+ super().__init__(agent_id, config)
41
+ self.max_depth = self.config.get("max_depth", 5)
42
+ self.allowed_domains = self.config.get("allowed_domains", [])
43
+ self.blocked_patterns = self.config.get("blocked_patterns", [
44
+ "logout", "signout", "delete", "remove", "unsubscribe",
45
+ ])
46
+ self.prioritize_https = self.config.get("prioritize_https", True)
47
+ self._visited_urls: set[str] = set()
48
+ self._url_scores: dict[str, float] = {}
49
+
50
+ async def act(self, observation: Observation) -> Action:
51
+ """
52
+ Select the best navigation action based on observation.
53
+
54
+ Analyzes available links and decides whether to:
55
+ - Navigate to a new page
56
+ - Go back to a previous page
57
+ - Click an element to reveal more content
58
+
59
+ Args:
60
+ observation: The current state observation.
61
+
62
+ Returns:
63
+ The navigation action to execute.
64
+ """
65
+ try:
66
+ # Track current URL
67
+ if observation.current_url:
68
+ self._visited_urls.add(observation.current_url)
69
+
70
+ # Check if we've reached max depth
71
+ nav_depth = len(observation.navigation_history)
72
+ if nav_depth >= self.max_depth:
73
+ return self._create_go_back_action(
74
+ "Reached maximum navigation depth"
75
+ )
76
+
77
+ # Find best link to follow
78
+ best_link = await self._find_best_link(observation)
79
+
80
+ if best_link:
81
+ return self._create_navigate_action(best_link, observation)
82
+
83
+ # Check for pagination
84
+ pagination_action = self._find_pagination(observation)
85
+ if pagination_action:
86
+ return pagination_action
87
+
88
+ # No good links, consider going back
89
+ if observation.can_go_back and nav_depth > 1:
90
+ return self._create_go_back_action(
91
+ "No relevant links found, going back"
92
+ )
93
+
94
+ # Nothing to navigate to
95
+ return Action(
96
+ action_type=ActionType.WAIT,
97
+ parameters={"duration_ms": 500},
98
+ reasoning="No navigation targets found",
99
+ confidence=0.5,
100
+ agent_id=self.agent_id,
101
+ )
102
+
103
+ except Exception as e:
104
+ return Action(
105
+ action_type=ActionType.FAIL,
106
+ parameters={"success": False, "message": str(e)},
107
+ reasoning=f"Navigation error: {e}",
108
+ confidence=1.0,
109
+ agent_id=self.agent_id,
110
+ )
111
+
112
+ async def plan(self, observation: Observation) -> list[Action]:
113
+ """
114
+ Create a navigation plan based on task requirements.
115
+
116
+ Plans a sequence of navigation actions to reach content
117
+ relevant to the task.
118
+
119
+ Args:
120
+ observation: The current state observation.
121
+
122
+ Returns:
123
+ A list of planned navigation actions.
124
+ """
125
+ try:
126
+ actions: list[Action] = []
127
+ task_context = observation.task_context
128
+
129
+ if not task_context:
130
+ return []
131
+
132
+ # Analyze task hints for navigation targets
133
+ target_urls = self._extract_urls_from_hints(task_context.hints)
134
+
135
+ for url in target_urls[:3]: # Limit to top 3 URLs
136
+ if url not in self._visited_urls:
137
+ actions.append(
138
+ Action(
139
+ action_type=ActionType.NAVIGATE,
140
+ parameters={"url": url, "timeout_ms": 30000},
141
+ reasoning=f"Navigating to task-relevant URL: {url}",
142
+ confidence=0.85,
143
+ agent_id=self.agent_id,
144
+ )
145
+ )
146
+
147
+ # If no URLs from hints, plan a search
148
+ if not actions:
149
+ search_query = self._build_search_query(task_context)
150
+ actions.append(
151
+ Action(
152
+ action_type=ActionType.SEARCH_ENGINE,
153
+ parameters={"query": search_query, "engine": "google"},
154
+ reasoning=f"Searching for: {search_query}",
155
+ confidence=0.7,
156
+ agent_id=self.agent_id,
157
+ )
158
+ )
159
+
160
+ return actions
161
+
162
+ except Exception as e:
163
+ return [
164
+ Action(
165
+ action_type=ActionType.FAIL,
166
+ parameters={"message": f"Navigation planning failed: {e}"},
167
+ reasoning=str(e),
168
+ confidence=1.0,
169
+ agent_id=self.agent_id,
170
+ )
171
+ ]
172
+
173
+ async def _find_best_link(self, observation: Observation) -> str | None:
174
+ """Find the best link to follow based on task relevance."""
175
+ if not observation.task_context:
176
+ return None
177
+
178
+ target_fields = observation.task_context.target_fields
179
+ remaining_fields = observation.fields_remaining
180
+
181
+ # Score all links on the page
182
+ link_scores: list[tuple[str, float]] = []
183
+
184
+ for element in observation.page_elements:
185
+ if not element.is_interactive:
186
+ continue
187
+
188
+ href = element.attributes.get("href", "")
189
+ if not href or href.startswith("#") or href.startswith("javascript:"):
190
+ continue
191
+
192
+ # Resolve relative URLs
193
+ full_url = self._resolve_url(href, observation.current_url)
194
+ if not full_url:
195
+ continue
196
+
197
+ # Skip already visited URLs
198
+ if full_url in self._visited_urls:
199
+ continue
200
+
201
+ # Skip blocked patterns
202
+ if self._is_blocked_url(full_url):
203
+ continue
204
+
205
+ # Check domain restrictions
206
+ if not self._is_allowed_domain(full_url):
207
+ continue
208
+
209
+ # Score the link
210
+ score = self._score_link(element, full_url, remaining_fields)
211
+ if score > 0:
212
+ link_scores.append((full_url, score))
213
+
214
+ # Return highest scoring link
215
+ if link_scores:
216
+ link_scores.sort(key=lambda x: x[1], reverse=True)
217
+ return link_scores[0][0]
218
+
219
+ return None
220
+
221
+ def _score_link(
222
+ self,
223
+ element: PageElement,
224
+ url: str,
225
+ target_fields: list[str],
226
+ ) -> float:
227
+ """Score a link based on relevance to task fields."""
228
+ score = 0.0
229
+ text = (element.text or "").lower()
230
+ url_lower = url.lower()
231
+
232
+ # Check if link text contains target field names
233
+ for field in target_fields:
234
+ field_lower = field.lower()
235
+ if field_lower in text:
236
+ score += 0.4
237
+ if field_lower in url_lower:
238
+ score += 0.3
239
+
240
+ # Prefer HTTPS
241
+ if self.prioritize_https and url.startswith("https://"):
242
+ score += 0.1
243
+
244
+ # Boost content-like URLs
245
+ content_indicators = ["detail", "view", "info", "about", "product", "page"]
246
+ for indicator in content_indicators:
247
+ if indicator in url_lower:
248
+ score += 0.2
249
+ break
250
+
251
+ # Penalize non-content URLs
252
+ noise_indicators = ["login", "cart", "checkout", "share", "print"]
253
+ for indicator in noise_indicators:
254
+ if indicator in url_lower:
255
+ score -= 0.3
256
+ break
257
+
258
+ return max(0.0, score)
259
+
260
+ def _resolve_url(self, href: str, base_url: str | None) -> str | None:
261
+ """Resolve a relative URL to an absolute URL."""
262
+ if not href:
263
+ return None
264
+
265
+ if href.startswith(("http://", "https://")):
266
+ return href
267
+
268
+ if not base_url:
269
+ return None
270
+
271
+ try:
272
+ return urljoin(base_url, href)
273
+ except Exception:
274
+ return None
275
+
276
+ def _is_blocked_url(self, url: str) -> bool:
277
+ """Check if URL matches any blocked patterns."""
278
+ url_lower = url.lower()
279
+ for pattern in self.blocked_patterns:
280
+ if pattern.lower() in url_lower:
281
+ return True
282
+ return False
283
+
284
+ def _is_allowed_domain(self, url: str) -> bool:
285
+ """Check if URL domain is allowed."""
286
+ if not self.allowed_domains:
287
+ return True
288
+
289
+ try:
290
+ parsed = urlparse(url)
291
+ domain = parsed.netloc.lower()
292
+ for allowed in self.allowed_domains:
293
+ if domain == allowed.lower() or domain.endswith("." + allowed.lower()):
294
+ return True
295
+ return False
296
+ except Exception:
297
+ return False
298
+
299
+ def _find_pagination(self, observation: Observation) -> Action | None:
300
+ """Find and create action for pagination elements."""
301
+ pagination_selectors = [
302
+ "[aria-label*='next']",
303
+ "[aria-label*='Next']",
304
+ "a.next",
305
+ "button.next",
306
+ "[rel='next']",
307
+ ]
308
+
309
+ for element in observation.page_elements:
310
+ text = (element.text or "").lower()
311
+ if element.is_interactive and ("next" in text or "more" in text):
312
+ return Action(
313
+ action_type=ActionType.CLICK,
314
+ parameters={"selector": element.selector},
315
+ reasoning="Clicking pagination to load more content",
316
+ confidence=0.7,
317
+ agent_id=self.agent_id,
318
+ )
319
+
320
+ return None
321
+
322
+ def _extract_urls_from_hints(self, hints: list[str]) -> list[str]:
323
+ """Extract URLs from task hints."""
324
+ urls = []
325
+ for hint in hints:
326
+ if hint.startswith(("http://", "https://")):
327
+ urls.append(hint)
328
+ elif "://" not in hint and "." in hint:
329
+ # Might be a domain without protocol
330
+ urls.append(f"https://{hint}")
331
+ return urls
332
+
333
+ def _build_search_query(self, task_context: Any) -> str:
334
+ """Build a search query from task context."""
335
+ parts = [task_context.task_name]
336
+ if task_context.target_fields:
337
+ parts.extend(task_context.target_fields[:2])
338
+ return " ".join(parts)
339
+
340
+ def _create_navigate_action(self, url: str, observation: Observation) -> Action:
341
+ """Create a navigate action for the given URL."""
342
+ return Action(
343
+ action_type=ActionType.NAVIGATE,
344
+ parameters={"url": url, "timeout_ms": 30000},
345
+ reasoning=f"Navigating to relevant URL: {url}",
346
+ confidence=0.75,
347
+ agent_id=self.agent_id,
348
+ )
349
+
350
+ def _create_go_back_action(self, reason: str) -> Action:
351
+ """Create a go back action."""
352
+ return Action(
353
+ action_type=ActionType.GO_BACK,
354
+ parameters={},
355
+ reasoning=reason,
356
+ confidence=0.8,
357
+ agent_id=self.agent_id,
358
+ )
359
+
360
+ def get_visited_urls(self) -> set[str]:
361
+ """Get the set of visited URLs."""
362
+ return self._visited_urls.copy()
363
+
364
+ def reset(self) -> None:
365
+ """Reset the navigator state."""
366
+ super().reset()
367
+ self._visited_urls.clear()
368
+ self._url_scores.clear()
backend/app/agents/planner.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Planner agent for goal decomposition and task planning."""
2
+
3
+ from typing import Any
4
+
5
+ from app.core.action import Action, ActionType
6
+ from app.core.observation import Observation
7
+
8
+ from .base import BaseAgent
9
+
10
+
11
+ class PlannerAgent(BaseAgent):
12
+ """
13
+ Agent responsible for high-level planning and goal decomposition.
14
+
15
+ The PlannerAgent analyzes the task requirements and creates
16
+ structured plans that other agents can execute. It handles:
17
+ - Breaking down complex tasks into subtasks
18
+ - Determining the optimal sequence of actions
19
+ - Adapting plans based on execution results
20
+ - Coordinating multi-step extraction workflows
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ agent_id: str = "planner",
26
+ config: dict[str, Any] | None = None,
27
+ ):
28
+ """
29
+ Initialize the PlannerAgent.
30
+
31
+ Args:
32
+ agent_id: Unique identifier for this agent.
33
+ config: Optional configuration with keys:
34
+ - max_plan_depth: Maximum depth of nested plans (default: 5)
35
+ - replan_threshold: Error count before replanning (default: 2)
36
+ - planning_model: LLM model to use for planning
37
+ """
38
+ super().__init__(agent_id, config)
39
+ self.max_plan_depth = self.config.get("max_plan_depth", 5)
40
+ self.replan_threshold = self.config.get("replan_threshold", 2)
41
+ self._current_plan: list[Action] | None = None
42
+ self._plan_step: int = 0
43
+
44
+ async def act(self, observation: Observation) -> Action:
45
+ """
46
+ Select the next action based on the current plan or create a new one.
47
+
48
+ If no plan exists or the current plan has failed, creates a new plan.
49
+ Otherwise, returns the next action in the current plan.
50
+
51
+ Args:
52
+ observation: The current state observation.
53
+
54
+ Returns:
55
+ The next action to execute.
56
+ """
57
+ try:
58
+ # Check if we need to replan due to errors
59
+ if observation.consecutive_errors >= self.replan_threshold:
60
+ self._current_plan = None
61
+ self._plan_step = 0
62
+
63
+ # Create plan if none exists
64
+ if self._current_plan is None or self._plan_step >= len(self._current_plan):
65
+ self._current_plan = await self.plan(observation)
66
+ self._plan_step = 0
67
+
68
+ if not self._current_plan:
69
+ return self._create_done_action("No actions planned")
70
+
71
+ # Get next action from plan
72
+ action = self._current_plan[self._plan_step]
73
+ action.plan_step = self._plan_step
74
+ action.agent_id = self.agent_id
75
+ self._plan_step += 1
76
+
77
+ return action
78
+
79
+ except Exception as e:
80
+ return self._create_error_action(f"Planning error: {e}")
81
+
82
+ async def plan(self, observation: Observation) -> list[Action]:
83
+ """
84
+ Create a plan of actions to achieve the task goals.
85
+
86
+ Analyzes the observation to determine:
87
+ - What fields still need to be extracted
88
+ - What navigation may be required
89
+ - What verification steps are needed
90
+
91
+ Args:
92
+ observation: The current state observation.
93
+
94
+ Returns:
95
+ A list of planned actions in execution order.
96
+ """
97
+ try:
98
+ actions: list[Action] = []
99
+ task_context = observation.task_context
100
+
101
+ if not task_context:
102
+ return [self._create_done_action("No task context provided")]
103
+
104
+ # Determine remaining fields to extract
105
+ remaining_fields = observation.fields_remaining
106
+ extracted_fields = [f.field_name for f in observation.extracted_so_far]
107
+
108
+ # If no URL loaded, plan navigation first
109
+ if not observation.current_url:
110
+ search_action = self._plan_initial_navigation(task_context)
111
+ if search_action:
112
+ actions.append(search_action)
113
+
114
+ # Plan extraction for remaining fields
115
+ for field in remaining_fields:
116
+ extraction_action = self._plan_field_extraction(
117
+ field,
118
+ observation,
119
+ )
120
+ actions.append(extraction_action)
121
+
122
+ # Plan verification if fields have been extracted
123
+ if extracted_fields:
124
+ verify_action = self._plan_verification(extracted_fields)
125
+ actions.append(verify_action)
126
+
127
+ # Add completion action
128
+ actions.append(
129
+ Action(
130
+ action_type=ActionType.DONE,
131
+ parameters={"success": True, "message": "Plan completed"},
132
+ reasoning="All planned steps completed",
133
+ confidence=0.9,
134
+ agent_id=self.agent_id,
135
+ )
136
+ )
137
+
138
+ return actions
139
+
140
+ except Exception as e:
141
+ return [self._create_error_action(f"Plan creation failed: {e}")]
142
+
143
+ def _plan_initial_navigation(self, task_context: Any) -> Action | None:
144
+ """Plan initial navigation based on task context."""
145
+ if task_context.hints:
146
+ # Use hints for navigation
147
+ for hint in task_context.hints:
148
+ if hint.startswith("http"):
149
+ return Action(
150
+ action_type=ActionType.NAVIGATE,
151
+ parameters={"url": hint},
152
+ reasoning=f"Navigating to hinted URL: {hint}",
153
+ confidence=0.85,
154
+ agent_id=self.agent_id,
155
+ )
156
+
157
+ # Default to search
158
+ search_query = f"{task_context.task_name} site information"
159
+ return Action(
160
+ action_type=ActionType.SEARCH_ENGINE,
161
+ parameters={"query": search_query, "engine": "google"},
162
+ reasoning=f"Searching for: {search_query}",
163
+ confidence=0.7,
164
+ agent_id=self.agent_id,
165
+ )
166
+
167
+ def _plan_field_extraction(
168
+ self,
169
+ field_name: str,
170
+ observation: Observation,
171
+ ) -> Action:
172
+ """Plan extraction for a specific field."""
173
+ # Check if we have page elements that might contain the field
174
+ selector = None
175
+ confidence = 0.6
176
+
177
+ for element in observation.page_elements:
178
+ element_text = (element.text or "").lower()
179
+ if field_name.lower() in element_text:
180
+ selector = element.selector
181
+ confidence = 0.8
182
+ break
183
+
184
+ return Action(
185
+ action_type=ActionType.EXTRACT_FIELD,
186
+ parameters={
187
+ "field_name": field_name,
188
+ "selector": selector,
189
+ "extraction_method": "text",
190
+ },
191
+ reasoning=f"Extracting field: {field_name}",
192
+ confidence=confidence,
193
+ agent_id=self.agent_id,
194
+ )
195
+
196
+ def _plan_verification(self, fields: list[str]) -> Action:
197
+ """Plan verification for extracted fields."""
198
+ return Action(
199
+ action_type=ActionType.VERIFY_FIELD,
200
+ parameters={
201
+ "field_name": fields[0] if fields else "unknown",
202
+ "validation_rules": ["not_empty", "format_check"],
203
+ },
204
+ reasoning=f"Verifying extracted fields: {fields}",
205
+ confidence=0.75,
206
+ agent_id=self.agent_id,
207
+ )
208
+
209
+ def _create_done_action(self, message: str) -> Action:
210
+ """Create a done action."""
211
+ return Action(
212
+ action_type=ActionType.DONE,
213
+ parameters={"success": True, "message": message},
214
+ reasoning=message,
215
+ confidence=1.0,
216
+ agent_id=self.agent_id,
217
+ )
218
+
219
+ def _create_error_action(self, error: str) -> Action:
220
+ """Create a fail action for errors."""
221
+ return Action(
222
+ action_type=ActionType.FAIL,
223
+ parameters={"success": False, "message": error},
224
+ reasoning=error,
225
+ confidence=1.0,
226
+ agent_id=self.agent_id,
227
+ )
228
+
229
+ def get_current_plan(self) -> list[Action] | None:
230
+ """Get the current plan."""
231
+ return self._current_plan
232
+
233
+ def get_plan_progress(self) -> tuple[int, int]:
234
+ """Get current plan progress as (current_step, total_steps)."""
235
+ total = len(self._current_plan) if self._current_plan else 0
236
+ return (self._plan_step, total)
237
+
238
+ def reset(self) -> None:
239
+ """Reset the planner state."""
240
+ super().reset()
241
+ self._current_plan = None
242
+ self._plan_step = 0
backend/app/agents/verifier.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Verifier agent for cross-source verification."""
2
+
3
+ import re
4
+ from typing import Any
5
+
6
+ from app.core.action import Action, ActionType
7
+ from app.core.observation import ExtractedField, Observation
8
+
9
+ from .base import BaseAgent
10
+
11
+
12
+ class VerificationResult:
13
+ """Result of a verification check."""
14
+
15
+ def __init__(
16
+ self,
17
+ field_name: str,
18
+ is_valid: bool,
19
+ confidence: float,
20
+ issues: list[str] | None = None,
21
+ sources_checked: int = 0,
22
+ ):
23
+ """Initialize verification result."""
24
+ self.field_name = field_name
25
+ self.is_valid = is_valid
26
+ self.confidence = confidence
27
+ self.issues = issues or []
28
+ self.sources_checked = sources_checked
29
+
30
+ def to_dict(self) -> dict[str, Any]:
31
+ """Convert to dictionary."""
32
+ return {
33
+ "field_name": self.field_name,
34
+ "is_valid": self.is_valid,
35
+ "confidence": self.confidence,
36
+ "issues": self.issues,
37
+ "sources_checked": self.sources_checked,
38
+ }
39
+
40
+
41
+ class VerifierAgent(BaseAgent):
42
+ """
43
+ Agent responsible for verifying extracted data.
44
+
45
+ The VerifierAgent handles:
46
+ - Format validation (emails, URLs, dates, etc.)
47
+ - Cross-source verification
48
+ - Consistency checks across fields
49
+ - Confidence scoring for verified data
50
+ - Flagging suspicious or inconsistent data
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ agent_id: str = "verifier",
56
+ config: dict[str, Any] | None = None,
57
+ ):
58
+ """
59
+ Initialize the VerifierAgent.
60
+
61
+ Args:
62
+ agent_id: Unique identifier for this agent.
63
+ config: Optional configuration with keys:
64
+ - min_confidence: Minimum confidence to accept (default: 0.7)
65
+ - require_cross_validation: Require multiple sources (default: False)
66
+ - strict_mode: Apply stricter validation rules (default: False)
67
+ """
68
+ super().__init__(agent_id, config)
69
+ self.min_confidence = self.config.get("min_confidence", 0.7)
70
+ self.require_cross_validation = self.config.get("require_cross_validation", False)
71
+ self.strict_mode = self.config.get("strict_mode", False)
72
+ self._validation_rules = self._init_validation_rules()
73
+ self._verification_history: list[VerificationResult] = []
74
+
75
+ def _init_validation_rules(self) -> dict[str, list[dict[str, Any]]]:
76
+ """Initialize validation rules for common field types."""
77
+ return {
78
+ "email": [
79
+ {
80
+ "type": "regex",
81
+ "pattern": r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$",
82
+ "error": "Invalid email format",
83
+ },
84
+ ],
85
+ "url": [
86
+ {
87
+ "type": "regex",
88
+ "pattern": r"^https?://[^\s]+$",
89
+ "error": "Invalid URL format",
90
+ },
91
+ ],
92
+ "phone": [
93
+ {
94
+ "type": "regex",
95
+ "pattern": r"[\d\s\-\(\)\+]{7,}",
96
+ "error": "Invalid phone format",
97
+ },
98
+ ],
99
+ "price": [
100
+ {
101
+ "type": "range",
102
+ "min": 0,
103
+ "max": 1000000,
104
+ "error": "Price out of reasonable range",
105
+ },
106
+ ],
107
+ "date": [
108
+ {
109
+ "type": "regex",
110
+ "pattern": r"\d{1,4}[-/]\d{1,2}[-/]\d{1,4}",
111
+ "error": "Invalid date format",
112
+ },
113
+ ],
114
+ "rating": [
115
+ {
116
+ "type": "range",
117
+ "min": 0,
118
+ "max": 5,
119
+ "error": "Rating out of range",
120
+ },
121
+ ],
122
+ }
123
+
124
+ async def act(self, observation: Observation) -> Action:
125
+ """
126
+ Select the best verification action based on observation.
127
+
128
+ Determines which extracted fields need verification and
129
+ selects the appropriate verification method.
130
+
131
+ Args:
132
+ observation: The current state observation.
133
+
134
+ Returns:
135
+ The verification action to execute.
136
+ """
137
+ try:
138
+ # Find unverified fields
139
+ unverified = [
140
+ f for f in observation.extracted_so_far
141
+ if not f.verified
142
+ ]
143
+
144
+ if not unverified:
145
+ return Action(
146
+ action_type=ActionType.DONE,
147
+ parameters={"success": True, "message": "All fields verified"},
148
+ reasoning="No unverified fields remaining",
149
+ confidence=1.0,
150
+ agent_id=self.agent_id,
151
+ )
152
+
153
+ # Verify the first unverified field
154
+ field = unverified[0]
155
+ result = await self._verify_field(field, observation)
156
+
157
+ if result.is_valid and result.confidence >= self.min_confidence:
158
+ return Action(
159
+ action_type=ActionType.VERIFY_FIELD,
160
+ parameters={
161
+ "field_name": field.field_name,
162
+ "verified": True,
163
+ "confidence": result.confidence,
164
+ "issues": result.issues,
165
+ },
166
+ reasoning=f"Field {field.field_name} verified with confidence {result.confidence:.2f}",
167
+ confidence=result.confidence,
168
+ agent_id=self.agent_id,
169
+ )
170
+ else:
171
+ # Verification failed - may need re-extraction
172
+ return self._create_reverify_action(field, result)
173
+
174
+ except Exception as e:
175
+ return Action(
176
+ action_type=ActionType.FAIL,
177
+ parameters={"success": False, "message": str(e)},
178
+ reasoning=f"Verification error: {e}",
179
+ confidence=1.0,
180
+ agent_id=self.agent_id,
181
+ )
182
+
183
+ async def plan(self, observation: Observation) -> list[Action]:
184
+ """
185
+ Create a verification plan for all extracted fields.
186
+
187
+ Args:
188
+ observation: The current state observation.
189
+
190
+ Returns:
191
+ A list of planned verification actions.
192
+ """
193
+ try:
194
+ actions: list[Action] = []
195
+
196
+ # Plan verification for each unverified field
197
+ for field in observation.extracted_so_far:
198
+ if field.verified:
199
+ continue
200
+
201
+ # Basic format verification
202
+ actions.append(
203
+ Action(
204
+ action_type=ActionType.VERIFY_FIELD,
205
+ parameters={
206
+ "field_name": field.field_name,
207
+ "expected_type": self._infer_field_type(field.field_name),
208
+ },
209
+ reasoning=f"Verify format of {field.field_name}",
210
+ confidence=0.8,
211
+ agent_id=self.agent_id,
212
+ )
213
+ )
214
+
215
+ # Cross-source verification if required
216
+ if self.require_cross_validation:
217
+ actions.append(
218
+ Action(
219
+ action_type=ActionType.VERIFY_FACT,
220
+ parameters={
221
+ "claim": f"{field.field_name}: {field.value}",
222
+ "confidence_threshold": self.min_confidence,
223
+ },
224
+ reasoning=f"Cross-validate {field.field_name} with other sources",
225
+ confidence=0.7,
226
+ agent_id=self.agent_id,
227
+ )
228
+ )
229
+
230
+ return actions
231
+
232
+ except Exception as e:
233
+ return [
234
+ Action(
235
+ action_type=ActionType.FAIL,
236
+ parameters={"message": f"Verification planning failed: {e}"},
237
+ reasoning=str(e),
238
+ confidence=1.0,
239
+ agent_id=self.agent_id,
240
+ )
241
+ ]
242
+
243
+ async def _verify_field(
244
+ self,
245
+ field: ExtractedField,
246
+ observation: Observation,
247
+ ) -> VerificationResult:
248
+ """
249
+ Verify a single field.
250
+
251
+ Args:
252
+ field: The field to verify.
253
+ observation: Current observation context.
254
+
255
+ Returns:
256
+ Verification result.
257
+ """
258
+ issues: list[str] = []
259
+ confidence = field.confidence
260
+ sources_checked = 1
261
+
262
+ # Apply validation rules
263
+ field_type = self._infer_field_type(field.field_name)
264
+ format_valid, format_issues = self._validate_format(
265
+ field.value,
266
+ field_type,
267
+ )
268
+
269
+ if not format_valid:
270
+ issues.extend(format_issues)
271
+ confidence *= 0.5
272
+
273
+ # Check for empty or null values
274
+ if field.value is None or (
275
+ isinstance(field.value, str) and not field.value.strip()
276
+ ):
277
+ issues.append("Empty value")
278
+ confidence = 0.0
279
+
280
+ # Check against memory context for consistency
281
+ consistency_issues = self._check_consistency(field, observation)
282
+ if consistency_issues:
283
+ issues.extend(consistency_issues)
284
+ confidence *= 0.8
285
+
286
+ # Create result
287
+ result = VerificationResult(
288
+ field_name=field.field_name,
289
+ is_valid=len(issues) == 0,
290
+ confidence=confidence,
291
+ issues=issues,
292
+ sources_checked=sources_checked,
293
+ )
294
+
295
+ self._verification_history.append(result)
296
+ return result
297
+
298
+ def _validate_format(
299
+ self,
300
+ value: Any,
301
+ field_type: str,
302
+ ) -> tuple[bool, list[str]]:
303
+ """
304
+ Validate value format against rules.
305
+
306
+ Args:
307
+ value: The value to validate.
308
+ field_type: The expected field type.
309
+
310
+ Returns:
311
+ Tuple of (is_valid, list of issues).
312
+ """
313
+ if value is None:
314
+ return False, ["Value is None"]
315
+
316
+ issues: list[str] = []
317
+ rules = self._validation_rules.get(field_type, [])
318
+
319
+ value_str = str(value)
320
+
321
+ for rule in rules:
322
+ rule_type = rule.get("type")
323
+
324
+ if rule_type == "regex":
325
+ pattern = rule.get("pattern", "")
326
+ if not re.match(pattern, value_str):
327
+ issues.append(rule.get("error", "Format validation failed"))
328
+
329
+ elif rule_type == "range":
330
+ try:
331
+ num_value = float(value_str.replace(",", "").replace("$", ""))
332
+ min_val = rule.get("min", float("-inf"))
333
+ max_val = rule.get("max", float("inf"))
334
+ if not (min_val <= num_value <= max_val):
335
+ issues.append(rule.get("error", "Value out of range"))
336
+ except ValueError:
337
+ issues.append("Cannot convert to number for range check")
338
+
339
+ elif rule_type == "length":
340
+ min_len = rule.get("min", 0)
341
+ max_len = rule.get("max", float("inf"))
342
+ if not (min_len <= len(value_str) <= max_len):
343
+ issues.append(rule.get("error", "Length validation failed"))
344
+
345
+ return len(issues) == 0, issues
346
+
347
+ def _check_consistency(
348
+ self,
349
+ field: ExtractedField,
350
+ observation: Observation,
351
+ ) -> list[str]:
352
+ """
353
+ Check field consistency with other data.
354
+
355
+ Args:
356
+ field: The field to check.
357
+ observation: Current observation.
358
+
359
+ Returns:
360
+ List of consistency issues.
361
+ """
362
+ issues: list[str] = []
363
+
364
+ # Check against other extracted fields
365
+ for other in observation.extracted_so_far:
366
+ if other.field_name == field.field_name:
367
+ continue
368
+
369
+ # Example: price should be less than total_price
370
+ if field.field_name == "price" and other.field_name == "total_price":
371
+ try:
372
+ price = float(str(field.value).replace("$", "").replace(",", ""))
373
+ total = float(str(other.value).replace("$", "").replace(",", ""))
374
+ if price > total:
375
+ issues.append("Price exceeds total_price")
376
+ except (ValueError, TypeError):
377
+ pass
378
+
379
+ # Check against memory for historical consistency
380
+ memory = observation.memory_context
381
+ if memory.long_term_relevant:
382
+ for mem in memory.long_term_relevant:
383
+ if mem.get("field") == field.field_name:
384
+ historical_value = mem.get("value")
385
+ if historical_value and historical_value != field.value:
386
+ # Different from historical - flag for review
387
+ issues.append(
388
+ f"Value differs from historical: {historical_value}"
389
+ )
390
+
391
+ return issues
392
+
393
+ def _infer_field_type(self, field_name: str) -> str:
394
+ """Infer the field type from its name."""
395
+ field_lower = field_name.lower()
396
+
397
+ type_keywords = {
398
+ "email": ["email", "mail"],
399
+ "url": ["url", "link", "href", "website"],
400
+ "phone": ["phone", "tel", "mobile", "fax"],
401
+ "price": ["price", "cost", "amount", "total", "fee"],
402
+ "date": ["date", "time", "created", "updated", "published"],
403
+ "rating": ["rating", "score", "stars"],
404
+ }
405
+
406
+ for field_type, keywords in type_keywords.items():
407
+ for keyword in keywords:
408
+ if keyword in field_lower:
409
+ return field_type
410
+
411
+ return "text"
412
+
413
+ def _create_reverify_action(
414
+ self,
415
+ field: ExtractedField,
416
+ result: VerificationResult,
417
+ ) -> Action:
418
+ """Create an action to handle failed verification."""
419
+ if result.confidence < 0.3:
420
+ # Very low confidence - suggest re-extraction
421
+ return Action(
422
+ action_type=ActionType.EXTRACT_FIELD,
423
+ parameters={
424
+ "field_name": field.field_name,
425
+ "reason": "Re-extracting due to verification failure",
426
+ },
427
+ reasoning=f"Verification failed with issues: {result.issues}",
428
+ confidence=0.6,
429
+ agent_id=self.agent_id,
430
+ )
431
+ else:
432
+ # Moderate confidence - try cross-validation
433
+ return Action(
434
+ action_type=ActionType.VERIFY_FACT,
435
+ parameters={
436
+ "claim": f"{field.field_name}: {field.value}",
437
+ "sources": None,
438
+ "confidence_threshold": self.min_confidence,
439
+ },
440
+ reasoning=f"Attempting cross-validation for {field.field_name}",
441
+ confidence=0.5,
442
+ agent_id=self.agent_id,
443
+ )
444
+
445
+ def add_validation_rule(
446
+ self,
447
+ field_type: str,
448
+ rule: dict[str, Any],
449
+ ) -> None:
450
+ """
451
+ Add a custom validation rule.
452
+
453
+ Args:
454
+ field_type: The field type this rule applies to.
455
+ rule: The validation rule dictionary.
456
+ """
457
+ if field_type not in self._validation_rules:
458
+ self._validation_rules[field_type] = []
459
+ self._validation_rules[field_type].append(rule)
460
+
461
+ def get_verification_history(self) -> list[dict[str, Any]]:
462
+ """Get verification history as dictionaries."""
463
+ return [r.to_dict() for r in self._verification_history]
464
+
465
+ def reset(self) -> None:
466
+ """Reset the verifier state."""
467
+ super().reset()
468
+ self._verification_history.clear()