# ask-the-web-agent: src/agent/controller.py
# Uploaded by debashis2007 via huggingface_hub (commit 75bea1c).
"""Agent controller - orchestrates the agent execution flow."""
from __future__ import annotations

import json
import re
from typing import Any

from src.agent.models import (
    AgentState,
    ExecutionPlan,
    Intent,
    IntentType,
    PlanStep,
    ThoughtStep,
    WorkflowStrategy,
)
from src.llm import LLMClient, Message, MessageRole
from src.llm.prompts import format_prompt, get_system_prompt, PromptNames
from src.tools.base import ToolRegistry
from src.utils.config import settings
from src.utils.exceptions import MaxIterationsError, PlanningError
from src.utils.logging import get_logger, log_agent_step

# Module-level logger used by AgentController for warnings and agent-step traces.
logger = get_logger(__name__)
class AgentController:
    """Controller that orchestrates agent activities.

    Uses an LLM client for reasoning and a tool registry for actions to parse
    the user's intent, plan a workflow, execute individual plan steps, and run
    the ReACT (thought / action / observation) loop.
    """

    # A ```json fenced block (tag matched case-insensitively). The closing
    # fence is optional so a truncated reply still yields its JSON body.
    _JSON_FENCE_RE = re.compile(r"```json(.*?)(?:```|$)", re.DOTALL | re.IGNORECASE)
    # Any ``` fenced block; used only when no json-tagged fence is present.
    _ANY_FENCE_RE = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
    # ReACT section labels such as "THOUGHT:", "**THOUGHT 2:**" or "ACTION 3:**".
    # The optional step number and trailing "*", ":" and whitespace are consumed
    # so they never leak into the extracted text. _build_react_context emits
    # numbered "**THOUGHT n:**" labels, so the model is likely to echo them.
    _THOUGHT_LABEL_RE = re.compile(r"\bTHOUGHTS?\b(?:\s*\d+)?[\s*:]*")
    _ACTION_LABEL_RE = re.compile(r"\bACTION\b(?:\s*\d+)?[\s*:]*")
    # Start of a (possibly model-invented) observation section; action text ends here.
    _OBSERVATION_LABEL_RE = re.compile(r"\bOBSERVATION\b")

    def __init__(self, llm_client: LLMClient, tool_registry: ToolRegistry):
        """Initialize the agent controller.

        Args:
            llm_client: LLM client for reasoning
            tool_registry: Registry of available tools
        """
        self.llm = llm_client
        self.tools = tool_registry
        self.system_prompt = get_system_prompt()

    async def parse_intent(self, query: str) -> Intent:
        """Parse user query to determine intent.

        Asks the LLM to classify the query and decodes its JSON reply. If the
        reply cannot be decoded or its values are rejected, a default
        factual-query intent that requires web search is returned instead.

        Args:
            query: User's query string

        Returns:
            Parsed Intent object
        """
        prompt = format_prompt(PromptNames.INTENT_PARSER, user_query=query)
        messages = [
            Message(role=MessageRole.SYSTEM, content=self.system_prompt),
            Message(role=MessageRole.USER, content=prompt),
        ]
        response = await self.llm.chat(messages, temperature=0.3)
        try:
            data = self._load_json_object(response.content)
            # `or` also covers an explicit JSON null, which .get()'s default does not.
            primary = str(data.get("intent") or "factual_query").lower()
            return Intent(
                intent_type=IntentType(primary),
                confidence=data.get("confidence", 0.5),
                secondary_intents=self._coerce_intent_types(
                    data.get("secondary_intents") or []
                ),
                entities=data.get("entities") or {},
                requires_web_search=data.get("requires_web_search", True),
                complexity=data.get("complexity", "simple"),
            )
        except Exception as e:
            # Deliberately broad: the reply is untrusted model output and the
            # Intent/IntentType constructors may reject its values; degrade to
            # defaults rather than failing the whole request.
            logger.warning(f"Failed to parse intent, using defaults: {e}")
            return Intent(
                intent_type=IntentType.FACTUAL_QUERY,
                confidence=0.5,
                requires_web_search=True,
            )

    async def plan_workflow(self, query: str, intent: Intent) -> ExecutionPlan:
        """Create an execution plan based on intent.

        Asks the LLM for a JSON plan. If the reply cannot be decoded or its
        values are rejected, falls back to ``_create_default_plan``.

        Args:
            query: User's query
            intent: Parsed intent

        Returns:
            ExecutionPlan for the query
        """
        prompt = format_prompt(
            PromptNames.WORKFLOW_PLANNER,
            user_query=query,
            intent_analysis=json.dumps({
                "intent": intent.intent_type.value,
                "confidence": intent.confidence,
                "requires_web_search": intent.requires_web_search,
                "complexity": intent.complexity,
                "entities": intent.entities,
            }),
            context="",
        )
        messages = [
            Message(role=MessageRole.SYSTEM, content=self.system_prompt),
            Message(role=MessageRole.USER, content=prompt),
        ]
        response = await self.llm.chat(messages, temperature=0.3)
        try:
            data = self._load_json_object(response.content)
            steps = [
                PlanStep(
                    # Fall back to the 1-based position when the model omits "step".
                    step_number=step_data.get("step", position),
                    action=step_data.get("action", ""),
                    tool=step_data.get("tool"),
                    parameters=step_data.get("parameters", {}),
                    purpose=step_data.get("purpose", ""),
                    depends_on=step_data.get("depends_on", []),
                )
                for position, step_data in enumerate(data.get("plan") or [], start=1)
            ]
            fallback = data.get("fallback_strategy")
            return ExecutionPlan(
                strategy=WorkflowStrategy(
                    str(data.get("strategy") or "single_search").lower()
                ),
                reasoning=data.get("reasoning", ""),
                steps=steps,
                max_iterations=data.get("max_iterations", settings.max_iterations),
                fallback_strategy=(
                    WorkflowStrategy(str(fallback).lower()) if fallback else None
                ),
                success_criteria=data.get("success_criteria", ""),
            )
        except Exception as e:
            # Deliberately broad: any malformed or rejected plan degrades to
            # the built-in default plan instead of failing the request.
            logger.warning(f"Failed to parse plan, using default: {e}")
            return self._create_default_plan(query, intent)

    def _create_default_plan(self, query: str, intent: Intent) -> ExecutionPlan:
        """Create a default execution plan.

        Produces a search-then-synthesize plan when the intent needs web
        search, otherwise a single direct-answer step.

        Args:
            query: User's query
            intent: Parsed intent

        Returns:
            Default ExecutionPlan
        """
        if not intent.requires_web_search:
            return ExecutionPlan(
                strategy=WorkflowStrategy.DIRECT_ANSWER,
                reasoning="Query can be answered directly",
                steps=[
                    PlanStep(
                        step_number=1,
                        action="respond",
                        tool=None,
                        parameters={},
                        purpose="Generate direct response",
                    ),
                ],
            )
        return ExecutionPlan(
            strategy=WorkflowStrategy.SINGLE_SEARCH,
            reasoning="Default plan with web search",
            steps=[
                PlanStep(
                    step_number=1,
                    action="search",
                    tool="web_search",
                    parameters={"query": query, "num_results": 5},
                    purpose="Search for relevant information",
                ),
                PlanStep(
                    step_number=2,
                    action="synthesize",
                    tool=None,
                    parameters={},
                    purpose="Synthesize search results into answer",
                    depends_on=[1],
                ),
            ],
        )

    async def execute_step(
        self, state: AgentState, step: PlanStep
    ) -> dict[str, Any]:
        """Execute a single step in the plan.

        Steps that name a tool run it through the registry; other steps
        (synthesize, respond, ...) are recorded as successful no-ops.

        Args:
            state: Current agent state (not read by this method; kept for
                interface compatibility)
            step: Step to execute

        Returns:
            Step result dict with "step", "success" and "data" keys, plus
            "tool"/"error" for tool steps or "action" for non-tool steps
        """
        log_agent_step(logger, "executing", {"step": step.step_number, "action": step.action})
        if not step.tool:
            # Non-tool action (synthesize, respond, etc.)
            return {
                "step": step.step_number,
                "action": step.action,
                "success": True,
                "data": None,
            }
        result = await self.tools.execute(step.tool, **step.parameters)
        return {
            "step": step.step_number,
            "tool": step.tool,
            "success": result.success,
            "data": result.data,
            "error": result.error,
        }

    async def run_react_loop(
        self, state: AgentState
    ) -> tuple[str, list[ThoughtStep]]:
        """Run the ReACT reasoning loop.

        Each iteration prompts the LLM with the query, the previous
        THOUGHT/ACTION/OBSERVATION steps and the working memory. It then either
        finishes or runs one tool and records the tool's observation.

        Args:
            state: Current agent state; ``state.working_memory`` gains a
                ``step_<n>`` entry after every non-finishing iteration.

        Returns:
            Tuple of (final_answer, thought_history)

        Raises:
            MaxIterationsError: If no finish action is produced within
                ``settings.max_iterations`` iterations.
        """
        thought_history: list[ThoughtStep] = []
        # Tool descriptions for the prompt; they do not change between iterations.
        tool_schemas = self.tools.get_schemas()
        max_iterations = settings.max_iterations

        for iteration in range(1, max_iterations + 1):
            log_agent_step(logger, "react_iteration", {"iteration": iteration})

            prompt = format_prompt(
                PromptNames.REACT_REASONING,
                user_query=state.query,
                iteration_number=iteration,
                max_iterations=max_iterations,
                previous_steps=self._build_react_context(state, thought_history),
                working_memory=json.dumps(state.working_memory),
            )
            messages = [
                Message(role=MessageRole.SYSTEM, content=self.system_prompt),
                Message(role=MessageRole.USER, content=prompt),
            ]
            response = await self.llm.chat(messages, tools=tool_schemas, temperature=0.5)

            thought, action, action_input = self._parse_react_response(response)
            log_agent_step(
                logger,
                "thought",
                {"thought": thought, "action": action},
                iteration=iteration,
            )

            if action.lower() == "finish":
                thought_history.append(
                    ThoughtStep(
                        iteration=iteration,
                        thought=thought,
                        action="finish",
                        action_input=action_input,
                        observation=action_input.get("answer", ""),
                    )
                )
                return action_input.get("answer", response.content or ""), thought_history

            if action:
                # _parse_react_response already unpacks native tool calls into
                # (name, arguments), so one path serves both reply formats.
                result = await self.tools.execute(action, **action_input)
                observation = self._format_observation(result)
            else:
                observation = "No action taken"

            thought_history.append(
                ThoughtStep(
                    iteration=iteration,
                    thought=thought,
                    action=action,
                    action_input=action_input,
                    observation=observation,
                )
            )
            state.working_memory[f"step_{iteration}"] = {
                "action": action,
                "result": observation,
            }

        raise MaxIterationsError(f"Reached maximum iterations ({max_iterations})")

    def _build_react_context(
        self, state: AgentState, thought_history: list[ThoughtStep]
    ) -> str:
        """Build the prompt context string from the thought history.

        Args:
            state: Current state (not read by this method)
            thought_history: List of thought steps

        Returns:
            Numbered THOUGHT/ACTION/OBSERVATION sections separated by blank
            lines, or "No previous steps." when the history is empty
        """
        if not thought_history:
            return "No previous steps."
        return "\n\n".join(
            f"**THOUGHT {step.iteration}:** {step.thought}\n"
            f"**ACTION {step.iteration}:** {step.action}[{json.dumps(step.action_input)}]\n"
            f"**OBSERVATION {step.iteration}:** {step.observation}"
            for step in thought_history
        )

    def _parse_react_response(
        self, response: Any
    ) -> tuple[str, str, dict[str, Any]]:
        """Parse thought and action from LLM response.

        Native tool calls take precedence: the first call supplies the action
        and its arguments, and only the thought is read from the text.
        Otherwise the text is parsed for ``ACTION: name[input]``. The input is
        decoded as a JSON object when it looks like one, else wrapped as
        ``{"answer": input}``. Any action containing "finish" is normalised
        to "finish".

        Args:
            response: LLM response (reads ``content``, ``has_tool_calls`` and
                ``tool_calls``)

        Returns:
            Tuple of (thought, action, action_input); missing parts are empty
        """
        content = response.content or ""

        if response.has_tool_calls:
            tool_call = response.tool_calls[0]
            thought = self._extract_thought(content, require_label=False)
            return thought, tool_call.name, tool_call.arguments

        thought = self._extract_thought(content, require_label=True)
        action = ""
        action_input: dict[str, Any] = {}

        action_label = self._ACTION_LABEL_RE.search(content)
        if action_label:
            action_part = content[action_label.end():]
            # Models often keep going and invent an OBSERVATION; anything from
            # there on is not part of the action (and may contain brackets).
            observation_label = self._OBSERVATION_LABEL_RE.search(action_part)
            if observation_label:
                action_part = action_part[:observation_label.start()]
            action_part = action_part.strip("*: \n")

            open_idx = action_part.find("[")
            close_idx = action_part.rfind("]")
            if 0 <= open_idx < close_idx:
                # action[input] format; input may span several lines.
                action = action_part[:open_idx].strip()
                action_input = self._parse_action_input(action_part[open_idx + 1:close_idx])
            else:
                action = action_part.split("\n")[0].strip()

            if "finish" in action.lower():
                action = "finish"

        return thought, action, action_input

    @classmethod
    def _extract_thought(cls, content: str, *, require_label: bool) -> str:
        """Return the THOUGHT text of a ReACT reply, cut at the ACTION label.

        With ``require_label`` the result is "" when no THOUGHT label exists;
        otherwise the text before any ACTION label is used as the thought.
        """
        label = cls._THOUGHT_LABEL_RE.search(content)
        if label:
            text = content[label.end():]
        elif require_label:
            return ""
        else:
            text = content
        action_label = cls._ACTION_LABEL_RE.search(text)
        if action_label:
            text = text[:action_label.start()]
        return text.strip("*: \n")

    @staticmethod
    def _parse_action_input(raw: str) -> dict[str, Any]:
        """Decode a bracketed action input: a JSON object, or {"answer": text}."""
        text = raw.strip()
        if text.startswith("{"):
            try:
                return json.loads(text)
            except json.JSONDecodeError:
                pass
        return {"answer": text}

    @staticmethod
    def _format_observation(result: Any) -> str:
        """Render a tool result as observation text for the next prompt.

        ``default=str`` is deliberate: the observation is only prompt text, so
        an unserializable value is stringified rather than aborting the loop.
        """
        if not result.success:
            return f"Error: {result.error}"
        return json.dumps(result.data, default=str)

    @staticmethod
    def _coerce_intent_types(values: Any) -> list[IntentType]:
        """Convert raw secondary-intent values to IntentType, skipping unknown ones."""
        if isinstance(values, str):
            values = [values]
        intents: list[IntentType] = []
        for value in values:
            try:
                intents.append(IntentType(str(value).lower()))
            except ValueError:
                logger.warning(f"Ignoring unknown secondary intent: {value!r}")
        return intents

    @classmethod
    def _extract_json_payload(cls, content: str) -> str:
        """Strip a Markdown code fence from an LLM reply, if there is one.

        A ```json fence is preferred over a bare ``` fence; text without any
        fence is returned unchanged.
        """
        match = cls._JSON_FENCE_RE.search(content) or cls._ANY_FENCE_RE.search(content)
        return match.group(1) if match else content

    @classmethod
    def _load_json_object(cls, content: str | None) -> dict[str, Any]:
        """Decode the JSON object contained in an LLM reply.

        Args:
            content: Raw reply text; None/empty is treated as "{}"

        Returns:
            The decoded JSON object

        Raises:
            ValueError: If the payload is not valid JSON
                (``json.JSONDecodeError`` is a ``ValueError``) or is not an object.
        """
        data = json.loads(cls._extract_json_payload(content or "{}"))
        if not isinstance(data, dict):
            raise ValueError(f"Expected a JSON object, got {type(data).__name__}")
        return data