"""Agent class for executing multi-step reasoning with tools.""" from dataclasses import dataclass, field from typing import List, Optional, Type, Callable, Literal from pydantic import BaseModel from .tools import tool import inspect import json from .models import ( ExecutionContext, Event, Message, ToolCall, ToolResult, PendingToolCall, ToolConfirmation, BaseSessionManager, InMemorySessionManager ) from .tools import BaseTool from .llm import LlmClient, LlmRequest, LlmResponse @dataclass class AgentResult: """Result of an agent execution.""" output: str | BaseModel context: ExecutionContext status: Literal["complete", "pending", "error"] = "complete" pending_tool_calls: list[PendingToolCall] = field(default_factory=list) class Agent: """Agent that can reason and use tools to solve tasks.""" def __init__( self, model: LlmClient, tools: List[BaseTool] = None, instructions: str = "", max_steps: int = 5, name: str = "agent", output_type: Optional[Type[BaseModel]] = None, before_tool_callbacks: List[Callable] = None, after_tool_callbacks: List[Callable] = None, session_manager: BaseSessionManager | None = None ): self.model = model self.instructions = instructions self.max_steps = max_steps self.name = name self.output_type = output_type self.output_tool_name = None self.tools = self._setup_tools(tools or []) # Initialize callback lists self.before_tool_callbacks = before_tool_callbacks or [] self.after_tool_callbacks = after_tool_callbacks or [] # Session manager self.session_manager = session_manager or InMemorySessionManager() def _setup_tools(self, tools: List[BaseTool]) -> List[BaseTool]: if self.output_type is not None: @tool( name="final_answer", description="Return the final structured answer matching the required schema." ) def final_answer(output: self.output_type) -> self.output_type: return output tools = list(tools) # Create a copy to avoid modifying the original tools.append(final_answer) self.output_tool_name = "final_answer" return tools async def run( self, user_input: str, context: ExecutionContext = None, session_id: Optional[str] = None, tool_confirmations: Optional[List[ToolConfirmation]] = None ) -> AgentResult: """Execute the agent with optional session support. Args: user_input: User's input message context: Optional execution context (creates new if None) session_id: Optional session ID for persistent conversations tool_confirmations: Optional list of tool confirmations for pending calls """ # Load or create session if session_id is provided session = None if session_id and self.session_manager: session = await self.session_manager.get_or_create(session_id) # Load session data into context if context is new if context is None: context = ExecutionContext() # Restore events and state from session context.events = session.events.copy() context.state = session.state.copy() context.execution_id = session.session_id context.session_id = session_id if tool_confirmations: if context is None: context = ExecutionContext() context.state["tool_confirmations"] = [ c.model_dump() for c in tool_confirmations ] # Create or reuse context if context is None: context = ExecutionContext() # Add user input as the first event user_event = Event( execution_id=context.execution_id, author="user", content=[Message(role="user", content=user_input)] ) context.add_event(user_event) # Execute steps until completion or max steps reached while not context.final_result and context.current_step < self.max_steps: await self.step(context) # Check for pending confirmations after each step if context.state.get("pending_tool_calls"): pending_calls = [ PendingToolCall.model_validate(p) for p in context.state["pending_tool_calls"] ] # Save session state before returning if session: session.events = context.events session.state = context.state await self.session_manager.save(session) return AgentResult( status="pending", context=context, pending_tool_calls=pending_calls, ) # Check if the last event is a final response last_event = context.events[-1] if self._is_final_response(last_event): context.final_result = self._extract_final_result(last_event) # Save session after execution completes if session: session.events = context.events session.state = context.state await self.session_manager.save(session) return AgentResult(output=context.final_result, context=context) def _is_final_response(self, event: Event) -> bool: """Check if this event contains a final response.""" if self.output_tool_name: # For structured output: check if final_answer tool succeeded for item in event.content: if (isinstance(item, ToolResult) and item.name == self.output_tool_name and item.status == "success"): return True return False has_tool_calls = any(isinstance(c, ToolCall) for c in event.content) has_tool_results = any(isinstance(c, ToolResult) for c in event.content) return not has_tool_calls and not has_tool_results def _extract_final_result(self, event: Event) -> str: if self.output_tool_name: # Extract structured output from final_answer tool result for item in event.content: if (isinstance(item, ToolResult) and item.name == self.output_tool_name and item.status == "success" and item.content): return item.content[0] for item in event.content: if isinstance(item, Message) and item.role == "assistant": return item.content return None async def step(self, context: ExecutionContext): """Execute one step of the agent loop.""" # Process pending confirmations if both are present (before preparing request) if ("pending_tool_calls" in context.state and "tool_confirmations" in context.state): confirmation_results = await self._process_confirmations(context) # Add results as an event so they appear in contents if confirmation_results: confirmation_event = Event( execution_id=context.execution_id, author=self.name, content=confirmation_results, ) context.add_event(confirmation_event) # Clear processed state del context.state["pending_tool_calls"] del context.state["tool_confirmations"] llm_request = self._prepare_llm_request(context) # Get LLM's decision llm_response = await self.think(llm_request) # Handle LLM errors - surface them instead of silently failing if llm_response.error_message: error_content = [Message( role="assistant", content=f"Error from LLM: {llm_response.error_message}" )] error_event = Event( execution_id=context.execution_id, author=self.name, content=error_content, ) context.add_event(error_event) context.final_result = error_content[0].content return # Record LLM response as an event response_event = Event( execution_id=context.execution_id, author=self.name, content=llm_response.content, ) context.add_event(response_event) # Execute tools if the LLM requested any tool_calls = [c for c in llm_response.content if isinstance(c, ToolCall)] if tool_calls: tool_results = await self.act(context, tool_calls) tool_event = Event( execution_id=context.execution_id, author=self.name, content=tool_results, ) context.add_event(tool_event) context.increment_step() def _prepare_llm_request(self, context: ExecutionContext) -> LlmRequest: """Convert execution context to LLM request. Args: context: Execution context with conversation history enforce_output_type: If True, enforce structured output format. Only set to True when expecting final answer. """ # Flatten events into content items flat_contents = [] for event in context.events: flat_contents.extend(event.content) # Determine tool choice strategy if self.output_tool_name: tool_choice = "required" # Force tool usage for structured output elif self.tools: tool_choice = "auto" else: tool_choice = None return LlmRequest( instructions=[self.instructions] if self.instructions else [], contents=flat_contents, tools=self.tools, tool_choice = tool_choice ) async def think(self, llm_request: LlmRequest) -> LlmResponse: """Get LLM's response/decision.""" return await self.model.generate(llm_request) async def act( self, context: ExecutionContext, tool_calls: List[ToolCall] ) -> List[ToolResult]: tools_dict = {tool.name: tool for tool in self.tools} results = [] pending_calls = [] # ADD THIS for tool_call in tool_calls: if tool_call.name not in tools_dict: raise ValueError(f"Tool '{tool_call.name}' not found") tool = tools_dict[tool_call.name] tool_response = None status = "success" # Stage 1: Execute before_tool_callbacks for callback in self.before_tool_callbacks: result = callback(context, tool_call) if inspect.isawaitable(result): result = await result if result is not None: tool_response = result break # Check if confirmation is required if tool.requires_confirmation: pending = PendingToolCall( tool_call=tool_call, confirmation_message=tool.get_confirmation_message( tool_call.arguments ) ) pending_calls.append(pending) continue # Stage 2: Execute actual tool only if callback didn't provide a result if tool_response is None: try: tool_response = await tool(context, **tool_call.arguments) except Exception as e: tool_response = str(e) status = "error" tool_result = ToolResult( tool_call_id=tool_call.tool_call_id, name=tool_call.name, status=status, content=[tool_response], ) # Stage 3: Execute after_tool_callbacks for callback in self.after_tool_callbacks: result = callback(context, tool_result) if inspect.isawaitable(result): result = await result if result is not None: tool_result = result break results.append(tool_result) if pending_calls: context.state["pending_tool_calls"] = [p.model_dump() for p in pending_calls] return results async def _process_confirmations( self, context: ExecutionContext ) -> List[ToolResult]: tools_dict = {tool.name: tool for tool in self.tools} results = [] # Restore pending tool calls from state pending_map = { p["tool_call"]["tool_call_id"]: PendingToolCall.model_validate(p) for p in context.state["pending_tool_calls"] } # Build confirmation lookup by tool_call_id confirmation_map = { c["tool_call_id"]: ToolConfirmation.model_validate(c) for c in context.state["tool_confirmations"] } # Process ALL pending tool calls for tool_call_id, pending in pending_map.items(): tool = tools_dict.get(pending.tool_call.name) confirmation = confirmation_map.get(tool_call_id) if confirmation and confirmation.approved: # Merge original arguments with modifications arguments = { **pending.tool_call.arguments, **(confirmation.modified_arguments or {}) } # Execute the approved tool try: output = await tool(context, **arguments) results.append(ToolResult( tool_call_id=tool_call_id, name=pending.tool_call.name, status="success", content=[output], )) except Exception as e: results.append(ToolResult( tool_call_id=tool_call_id, name=pending.tool_call.name, status="error", content=[str(e)], )) else: # Rejected: either explicitly or not in confirmation list if confirmation: reason = confirmation.reason or "Tool execution was rejected by user." else: reason = "Tool execution was not approved." results.append(ToolResult( tool_call_id=tool_call_id, name=pending.tool_call.name, status="error", content=[reason], )) return results # List of dangerous tools requiring approval DANGEROUS_TOOLS = ["delete_file", "send_email", "execute_sql"] def approval_callback(context: ExecutionContext, tool_call: ToolCall): """Requests user approval before executing dangerous tools.""" # Execute immediately if not a dangerous tool if tool_call.name not in DANGEROUS_TOOLS: return None print(f"\n Dangerous tool execution requested") print(f"Tool: {tool_call.name}") print(f"Arguments: {tool_call.arguments}") response = input("Do you want to execute? (y/n): ").lower().strip() if response == 'y': print(" Approved. Executing...\n") return None # Proceed with actual tool execution else: print(" Denied. Skipping execution.\n") return f"User denied execution of {tool_call.name}"