Spaces:
Running
Running
| """Agent controller - orchestrates the agent execution flow.""" | |
| from __future__ import annotations | |
| import json | |
| from typing import Any | |
| from src.agent.models import ( | |
| AgentState, | |
| ExecutionPlan, | |
| Intent, | |
| IntentType, | |
| PlanStep, | |
| ThoughtStep, | |
| WorkflowStrategy, | |
| ) | |
| from src.llm import LLMClient, Message, MessageRole | |
| from src.llm.prompts import format_prompt, get_system_prompt, PromptNames | |
| from src.tools.base import ToolRegistry | |
| from src.utils.config import settings | |
| from src.utils.exceptions import MaxIterationsError, PlanningError | |
| from src.utils.logging import get_logger, log_agent_step | |
# Module-level logger used by AgentController: fallback warnings when LLM
# replies cannot be parsed, and step tracing via log_agent_step.
logger = get_logger(__name__)
class AgentController:
    """Controller that orchestrates agent activities.

    Combines an LLM client and a tool registry to:

    * parse a user query into an ``Intent`` (``parse_intent``),
    * turn that intent into an ``ExecutionPlan`` (``plan_workflow``),
    * execute single plan steps (``execute_step``), and
    * run a ReACT thought/action/observation loop (``run_react_loop``).
    """

    # Section headers of the ReACT transcript format that _build_react_context
    # produces (e.g. "**THOUGHT 2:**"). Models tend to echo that format, so the
    # optional iteration number and trailing ":" / "*" / whitespace belong to
    # the header and must not leak into the parsed thought or action name.
    # The leading \b keeps words such as "TRANSACTION" from matching.
    _THOUGHT_HEADER = r"\**\bTHOUGHT\b(?:\s*\d+)?[\s:*]*"
    _ACTION_HEADER = r"\**\bACTION\b(?:\s*\d+)?[\s:*]*"
    # Start of a section the model may hallucinate after its ACTION (an
    # invented OBSERVATION or the next THOUGHT); everything from there on is
    # not part of the action.
    _ACTION_TERMINATOR = r"\n[ \t]*\**\b(?:OBSERVATION|THOUGHT)\b(?:\s*\d+)?[ \t*]*:"

    def __init__(self, llm_client: LLMClient, tool_registry: ToolRegistry):
        """Initialize the agent controller.

        Args:
            llm_client: LLM client for reasoning
            tool_registry: Registry of available tools
        """
        self.llm = llm_client
        self.tools = tool_registry
        # System prompt prepended to every LLM conversation this class starts.
        self.system_prompt = get_system_prompt()

    @staticmethod
    def _extract_json(content: str | None) -> Any:
        """Decode the JSON payload of an LLM reply.

        If the reply contains a Markdown code fence (```json or a bare ```),
        only the text inside the first fence is decoded. An empty or None
        reply decodes as ``{}``.

        Args:
            content: Raw reply text

        Returns:
            The decoded JSON value

        Raises:
            json.JSONDecodeError: If the (unfenced) text is not valid JSON.
        """
        text = content or "{}"
        if "```json" in text:
            text = text.split("```json", 1)[1].split("```", 1)[0]
        elif "```" in text:
            text = text.split("```", 2)[1]
        return json.loads(text)

    @staticmethod
    def _format_observation(result: Any) -> str:
        """Render a tool result as the observation text fed back to the LLM.

        Args:
            result: Tool result exposing ``success``, ``data`` and ``error``

        Returns:
            JSON-encoded data on success, otherwise ``"Error: <error>"``
        """
        if result.success:
            # default=str: tool payloads may hold values json cannot encode
            # natively (dates, sets, ...); stringify them rather than abort
            # the whole reasoning loop.
            return json.dumps(result.data, default=str)
        return f"Error: {result.error}"

    async def parse_intent(self, query: str) -> Intent:
        """Parse user query to determine intent.

        Asks the LLM (temperature 0.3) to classify the query and decodes its
        JSON reply. If the reply cannot be decoded or validated, the error is
        logged and a default factual-query intent is returned instead.

        Args:
            query: User's query string

        Returns:
            Parsed Intent object
        """
        prompt = format_prompt(PromptNames.INTENT_PARSER, user_query=query)
        messages = [
            Message(role=MessageRole.SYSTEM, content=self.system_prompt),
            Message(role=MessageRole.USER, content=prompt),
        ]
        response = await self.llm.chat(messages, temperature=0.3)

        try:
            data = self._extract_json(response.content)

            secondary_intents: list[IntentType] = []
            for raw in data.get("secondary_intents") or []:
                try:
                    secondary_intents.append(IntentType(str(raw).lower()))
                except ValueError:
                    # One unrecognized secondary label should not throw away an
                    # otherwise valid primary intent; drop just that label.
                    logger.debug(f"Ignoring unknown secondary intent: {raw!r}")

            return Intent(
                intent_type=IntentType(data.get("intent", "factual_query").lower()),
                confidence=data.get("confidence", 0.5),
                secondary_intents=secondary_intents,
                entities=data.get("entities", {}),
                requires_web_search=data.get("requires_web_search", True),
                complexity=data.get("complexity", "simple"),
            )
        except Exception as e:
            # Best-effort boundary: the LLM reply is untrusted free text, so any
            # decoding/validation failure falls back to a safe default intent.
            logger.warning(f"Failed to parse intent, using defaults: {e}")
            return Intent(
                intent_type=IntentType.FACTUAL_QUERY,
                confidence=0.5,
                requires_web_search=True,
            )

    async def plan_workflow(self, query: str, intent: Intent) -> ExecutionPlan:
        """Create an execution plan based on intent.

        Asks the LLM (temperature 0.3) for a JSON plan. If the reply cannot be
        decoded or validated, the error is logged and
        ``_create_default_plan`` is used instead.

        Args:
            query: User's query
            intent: Parsed intent

        Returns:
            ExecutionPlan for the query
        """
        prompt = format_prompt(
            PromptNames.WORKFLOW_PLANNER,
            user_query=query,
            intent_analysis=json.dumps({
                "intent": intent.intent_type.value,
                "confidence": intent.confidence,
                "requires_web_search": intent.requires_web_search,
                "complexity": intent.complexity,
                "entities": intent.entities,
            }),
            context="",
        )
        messages = [
            Message(role=MessageRole.SYSTEM, content=self.system_prompt),
            Message(role=MessageRole.USER, content=prompt),
        ]
        response = await self.llm.chat(messages, temperature=0.3)

        try:
            data = self._extract_json(response.content)
            steps = [
                PlanStep(
                    # Steps without an explicit number are numbered by position.
                    step_number=step_data.get("step", position),
                    action=step_data.get("action", ""),
                    tool=step_data.get("tool"),
                    parameters=step_data.get("parameters", {}),
                    purpose=step_data.get("purpose", ""),
                    depends_on=step_data.get("depends_on", []),
                )
                for position, step_data in enumerate(data.get("plan", []), start=1)
            ]
            return ExecutionPlan(
                strategy=WorkflowStrategy(data.get("strategy", "single_search").lower()),
                reasoning=data.get("reasoning", ""),
                steps=steps,
                max_iterations=data.get("max_iterations", settings.max_iterations),
                fallback_strategy=(
                    WorkflowStrategy(data["fallback_strategy"].lower())
                    if data.get("fallback_strategy")
                    else None
                ),
                success_criteria=data.get("success_criteria", ""),
            )
        except Exception as e:
            logger.warning(f"Failed to parse plan, using default: {e}")
            # Return a simple default plan
            return self._create_default_plan(query, intent)

    def _create_default_plan(self, query: str, intent: Intent) -> ExecutionPlan:
        """Create a default execution plan.

        Uses a web_search step followed by a synthesize step when the intent
        requires web search, otherwise a single direct "respond" step.

        Args:
            query: User's query
            intent: Parsed intent

        Returns:
            Default ExecutionPlan
        """
        if intent.requires_web_search:
            return ExecutionPlan(
                strategy=WorkflowStrategy.SINGLE_SEARCH,
                reasoning="Default plan with web search",
                steps=[
                    PlanStep(
                        step_number=1,
                        action="search",
                        tool="web_search",
                        parameters={"query": query, "num_results": 5},
                        purpose="Search for relevant information",
                    ),
                    PlanStep(
                        step_number=2,
                        action="synthesize",
                        tool=None,
                        parameters={},
                        purpose="Synthesize search results into answer",
                        depends_on=[1],
                    ),
                ],
            )
        return ExecutionPlan(
            strategy=WorkflowStrategy.DIRECT_ANSWER,
            reasoning="Query can be answered directly",
            steps=[
                PlanStep(
                    step_number=1,
                    action="respond",
                    tool=None,
                    parameters={},
                    purpose="Generate direct response",
                ),
            ],
        )

    async def execute_step(
        self, state: AgentState, step: PlanStep
    ) -> dict[str, Any]:
        """Execute a single step in the plan.

        Steps that name a tool are dispatched through the tool registry.
        Steps without a tool (synthesize, respond, ...) are recorded as
        successful no-ops whose ``data`` is None.

        Args:
            state: Current agent state (not used by this method)
            step: Step to execute

        Returns:
            Step result dict with ``step``, ``success``, ``data`` and ``error``
            keys, plus ``tool`` for tool steps or ``action`` otherwise
        """
        log_agent_step(logger, "executing", {"step": step.step_number, "action": step.action})

        if not step.tool:
            # Non-tool action. "error" is included so both result shapes expose
            # the same keys to callers.
            return {
                "step": step.step_number,
                "action": step.action,
                "success": True,
                "data": None,
                "error": None,
            }

        result = await self.tools.execute(step.tool, **step.parameters)
        return {
            "step": step.step_number,
            "tool": step.tool,
            "success": result.success,
            "data": result.data,
            "error": result.error,
        }

    async def run_react_loop(
        self, state: AgentState
    ) -> tuple[str, list[ThoughtStep]]:
        """Run the ReACT reasoning loop.

        Each iteration asks the LLM (with the tool schemas available) for a
        thought and an action, executes the action as a tool call, and records
        the observation in both the thought history and
        ``state.working_memory``. The loop ends when the model chooses the
        "finish" action.

        Args:
            state: Current agent state

        Returns:
            Tuple of (final_answer, thought_history)

        Raises:
            MaxIterationsError: If no "finish" action arrives within
                ``settings.max_iterations`` iterations.
        """
        thought_history: list[ThoughtStep] = []
        # Build tool descriptions for the prompt once; they do not change.
        tool_schemas = self.tools.get_schemas()

        for iteration in range(1, settings.max_iterations + 1):
            log_agent_step(logger, "react_iteration", {"iteration": iteration})

            prompt = format_prompt(
                PromptNames.REACT_REASONING,
                user_query=state.query,
                iteration_number=iteration,
                max_iterations=settings.max_iterations,
                previous_steps=self._build_react_context(state, thought_history),
                # default=str keeps prompt construction from crashing on
                # values json cannot encode natively.
                working_memory=json.dumps(state.working_memory, default=str),
            )
            messages = [
                Message(role=MessageRole.SYSTEM, content=self.system_prompt),
                Message(role=MessageRole.USER, content=prompt),
            ]

            response = await self.llm.chat(messages, tools=tool_schemas, temperature=0.5)
            # For native tool calls this already yields the call's name/arguments.
            thought, action, action_input = self._parse_react_response(response)
            log_agent_step(
                logger,
                "thought",
                {"thought": thought, "action": action},
                iteration=iteration,
            )

            if action.lower() == "finish":
                thought_history.append(
                    ThoughtStep(
                        iteration=iteration,
                        thought=thought,
                        action="finish",
                        action_input=action_input,
                        observation=action_input.get("answer", ""),
                    )
                )
                return action_input.get("answer", response.content or ""), thought_history

            if action:
                result = await self.tools.execute(action, **action_input)
                observation = self._format_observation(result)
            else:
                observation = "No action taken"

            thought_history.append(
                ThoughtStep(
                    iteration=iteration,
                    thought=thought,
                    action=action,
                    action_input=action_input,
                    observation=observation,
                )
            )
            state.working_memory[f"step_{iteration}"] = {
                "action": action,
                "result": observation,
            }

        raise MaxIterationsError(f"Reached maximum iterations ({settings.max_iterations})")

    def _build_react_context(
        self, state: AgentState, thought_history: list[ThoughtStep]
    ) -> str:
        """Build context string from thought history.

        Each step is rendered as numbered, bold THOUGHT / ACTION / OBSERVATION
        lines, and steps are separated by a blank line.

        Args:
            state: Current state (not used by this method)
            thought_history: List of thought steps

        Returns:
            Formatted context string, or "No previous steps." when empty
        """
        if not thought_history:
            return "No previous steps."
        return "\n\n".join(
            f"**THOUGHT {step.iteration}:** {step.thought}\n"
            f"**ACTION {step.iteration}:** {step.action}[{json.dumps(step.action_input)}]\n"
            f"**OBSERVATION {step.iteration}:** {step.observation}"
            for step in thought_history
        )

    def _parse_react_response(
        self, response: Any
    ) -> tuple[str, str, dict[str, Any]]:
        """Parse thought and action from LLM response.

        Native tool calls take precedence. The first call supplies the action
        name and arguments, and the text before any ACTION header becomes the
        thought.

        Otherwise the text format ``THOUGHT ... ACTION name[input]`` is
        parsed. ``input`` is either a JSON object or free text, which is
        wrapped as ``{"answer": input}``. Decorated headers such as
        ``**THOUGHT 2:**`` are accepted. Any action name containing "finish"
        is normalized to ``"finish"``.

        Args:
            response: LLM response exposing ``content``, ``has_tool_calls``
                and ``tool_calls``

        Returns:
            Tuple of (thought, action, action_input). The action is "" and
            action_input is {} when no action could be found.
        """
        import re

        content = response.content or ""

        thought_header = re.compile(self._THOUGHT_HEADER).search(content)
        body_start = thought_header.end() if thought_header else 0
        action_header = re.compile(self._ACTION_HEADER).search(content, body_start)
        body_end = action_header.start() if action_header else len(content)
        thought_text = content[body_start:body_end].strip(" \t\r\n*:")

        # Handle tool calls from LLM
        if response.has_tool_calls:
            tool_call = response.tool_calls[0]
            return thought_text, tool_call.name, tool_call.arguments or {}

        # In the text format a thought only counts when explicitly labelled.
        thought = thought_text if thought_header else ""
        if action_header is None:
            return thought, "", {}

        action_text = content[action_header.end():]
        terminator = re.compile(self._ACTION_TERMINATOR).search(action_text)
        if terminator:
            action_text = action_text[: terminator.start()]

        action_input: dict[str, Any] = {}
        open_idx = action_text.find("[")
        close_idx = action_text.rfind("]")
        if 0 <= open_idx < close_idx:
            # action[input] format
            action = action_text[:open_idx].strip()
            input_str = action_text[open_idx + 1 : close_idx].strip()
            if input_str.startswith("{"):
                try:
                    action_input = json.loads(input_str)
                except json.JSONDecodeError:
                    action_input = {"answer": input_str}
            else:
                action_input = {"answer": input_str}
        else:
            action = action_text.strip().split("\n", 1)[0].strip()

        if "finish" in action.lower():
            action = "finish"
        return thought, action, action_input