|
|
import asyncio |
|
|
import json |
|
|
from typing import Any, List, Optional, Union |
|
|
|
|
|
from pydantic import Field |
|
|
|
|
|
from app.agent.react import ReActAgent |
|
|
from app.exceptions import TokenLimitExceeded |
|
|
from app.logger import logger |
|
|
from app.prompt.toolcall import NEXT_STEP_PROMPT, SYSTEM_PROMPT |
|
|
from app.schema import TOOL_CHOICE_TYPE, AgentState, Message, ToolCall, ToolChoice |
|
|
from app.tool import CreateChatCompletion, Terminate, ToolCollection |
|
|
|
|
|
|
|
|
# Message raised by act() when tool_choices == ToolChoice.REQUIRED but the LLM
# returned no tool calls (see ToolCallAgent.act).
TOOL_CALL_REQUIRED = "Tool calls required but none provided"
|
|
|
|
|
|
|
|
class ToolCallAgent(ReActAgent):
    """Base agent class for handling tool/function calls with enhanced abstraction.

    Implements the ReAct cycle in terms of LLM tool calling:

    * ``think()`` asks the LLM (via ``self.llm.ask_tool``) which tools to use,
      records the assistant message in memory, and reports whether to act.
    * ``act()`` executes the selected tool calls sequentially and appends a
      tool message (optionally with a base64 image) to memory for each result.
    * "Special" tools (by default ``Terminate``) flip the agent state to
      ``FINISHED`` via ``_handle_special_tool``.
    """

    # Agent identity used in log messages and by the framework.
    name: str = "toolcall"
    description: str = "an agent that can execute tool calls."

    # Prompts injected into each think() round (system + per-step user nudge).
    system_prompt: str = SYSTEM_PROMPT
    next_step_prompt: str = NEXT_STEP_PROMPT

    # Tools offered to the LLM on every think() call.
    available_tools: ToolCollection = ToolCollection(
        CreateChatCompletion(), Terminate()
    )
    # Tool-choice mode passed through to the LLM: AUTO / REQUIRED / NONE.
    tool_choices: TOOL_CHOICE_TYPE = ToolChoice.AUTO
    # Tool names whose successful execution may finish the agent run.
    special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])

    # Tool calls selected by the most recent think(); consumed by act().
    tool_calls: List[ToolCall] = Field(default_factory=list)
    # Base64 image captured from the current tool's result, if any; reset per
    # tool call in act() and attached to the corresponding tool message.
    _current_base64_image: Optional[str] = None

    max_steps: int = 30
    # If truthy, tool output is truncated to result[:max_observe] before being
    # logged and stored. NOTE(review): typed to allow a bool — True slices to
    # result[:1] (a single character); confirm that is intended.
    max_observe: Optional[Union[int, bool]] = None

    async def think(self) -> bool:
        """Process current state and decide next actions using tools.

        Appends ``next_step_prompt`` as a user message (when set), queries the
        LLM with the full message history and the available tools, records the
        assistant reply (with any tool calls) in memory, and decides whether
        ``act()`` should run.

        Returns:
            True when there is something to act on (tool calls selected, or —
            depending on ``tool_choices`` — plain content to surface); False
            when there is nothing further to do or an error was handled.

        Raises:
            ValueError: re-raised unchanged from the LLM call (other
                exceptions are either unwrapped for token-limit handling or
                re-raised).
        """
        if self.next_step_prompt:
            user_msg = Message.user_message(self.next_step_prompt)
            # NOTE(review): assumes `messages` (from ReActAgent) supports
            # in-place extension / property write-through — confirm upstream.
            self.messages += [user_msg]

        try:
            # Ask the LLM which tool(s) to call, given history + tool schemas.
            response = await self.llm.ask_tool(
                messages=self.messages,
                system_msgs=(
                    [Message.system_message(self.system_prompt)]
                    if self.system_prompt
                    else None
                ),
                tools=self.available_tools.to_params(),
                tool_choice=self.tool_choices,
            )
        except ValueError:
            # ValueError is a caller-visible contract; propagate untouched.
            raise
        except Exception as e:
            # The LLM layer may surface TokenLimitExceeded wrapped as the
            # __cause__ of another exception (per the log text, presumably a
            # retry wrapper such as tenacity's RetryError — confirm upstream).
            if hasattr(e, "__cause__") and isinstance(e.__cause__, TokenLimitExceeded):
                token_limit_error = e.__cause__
                logger.error(
                    f"π¨ Token limit error (from RetryError): {token_limit_error}"
                )
                # Record the failure in memory and stop the run gracefully
                # instead of crashing the loop.
                self.memory.add_message(
                    Message.assistant_message(
                        f"Maximum token limit reached, cannot continue execution: {str(token_limit_error)}"
                    )
                )
                self.state = AgentState.FINISHED
                return False
            raise

        # Normalize the response: guard against a None response or missing
        # tool_calls/content so the logging below never dereferences None.
        self.tool_calls = tool_calls = (
            response.tool_calls if response and response.tool_calls else []
        )
        content = response.content if response and response.content else ""

        # Log the model's reasoning and its tool selection for observability.
        logger.info(f"β¨ {self.name}'s thoughts: {content}")
        logger.info(
            f"π οΈ {self.name} selected {len(tool_calls) if tool_calls else 0} tools to use"
        )
        if tool_calls:
            logger.info(
                f"π§° Tools being prepared: {[call.function.name for call in tool_calls]}"
            )
            # Only the first call's arguments are logged here.
            logger.info(f"π§ Tool arguments: {tool_calls[0].function.arguments}")

        try:
            if response is None:
                raise RuntimeError("No response received from the LLM")

            # NONE mode: tools must not be called; surface plain content only.
            if self.tool_choices == ToolChoice.NONE:
                if tool_calls:
                    # The model attempted tool calls despite NONE mode — warn
                    # and ignore them.
                    logger.warning(
                        f"π€ Hmm, {self.name} tried to use tools when they weren't available!"
                    )
                if content:
                    self.memory.add_message(Message.assistant_message(content))
                    return True
                return False

            # Record the assistant turn: as a tool-call message when tool
            # calls were selected, otherwise as plain assistant content.
            assistant_msg = (
                Message.from_tool_calls(content=content, tool_calls=self.tool_calls)
                if self.tool_calls
                else Message.assistant_message(content)
            )
            self.memory.add_message(assistant_msg)

            # REQUIRED mode with no tool calls: still return True so that
            # act() runs and raises ValueError(TOOL_CALL_REQUIRED).
            if self.tool_choices == ToolChoice.REQUIRED and not self.tool_calls:
                return True  # Will be handled in act()

            # AUTO mode with no tool calls: continue only if the model
            # produced content worth acting on.
            if self.tool_choices == ToolChoice.AUTO and not self.tool_calls:
                return bool(content)

            return bool(self.tool_calls)
        except Exception as e:
            # Best-effort recovery: log, leave a trace in memory, and report
            # "nothing to act on" rather than aborting the run.
            logger.error(f"π¨ Oops! The {self.name}'s thinking process hit a snag: {e}")
            self.memory.add_message(
                Message.assistant_message(
                    f"Error encountered while processing: {str(e)}"
                )
            )
            return False

    async def act(self) -> str:
        """Execute tool calls and handle their results.

        Runs each call collected by ``think()`` in order, appending a tool
        message (with any captured base64 image) to memory per call.

        Returns:
            All tool results joined with blank lines; when there are no tool
            calls, the last message's content (or a placeholder string).

        Raises:
            ValueError: when ``tool_choices`` is REQUIRED but ``think()``
                produced no tool calls (message: TOOL_CALL_REQUIRED).
        """
        if not self.tool_calls:
            if self.tool_choices == ToolChoice.REQUIRED:
                raise ValueError(TOOL_CALL_REQUIRED)

            # No tools selected: just echo the latest message content.
            return self.messages[-1].content or "No content or commands to execute"

        results = []
        for command in self.tool_calls:
            # Reset per-call so an image from a previous tool is never
            # attached to this call's tool message.
            self._current_base64_image = None

            result = await self.execute_tool(command)

            # Optional truncation of the observed output (see max_observe
            # field note regarding bool values).
            if self.max_observe:
                result = result[: self.max_observe]

            logger.info(
                f"π― Tool '{command.function.name}' completed its mission! Result: {result}"
            )

            # Record the tool response in memory, correlated to the call id.
            tool_msg = Message.tool_message(
                content=result,
                tool_call_id=command.id,
                name=command.function.name,
                base64_image=self._current_base64_image,
            )
            self.memory.add_message(tool_msg)
            results.append(result)

        return "\n\n".join(results)

    async def execute_tool(self, command: ToolCall) -> str:
        """Execute a single tool call with robust error handling.

        Parses the call's JSON arguments, dispatches to ``available_tools``,
        triggers special-tool handling, and captures any base64 image from the
        result into ``_current_base64_image``.

        Args:
            command: The tool call to execute (function name + JSON args).

        Returns:
            A human-readable observation string; all failures are returned as
            ``"Error: ..."`` strings rather than raised.
        """
        # Defensive validation of the (LLM-produced) call structure.
        if not command or not command.function or not command.function.name:
            return "Error: Invalid command format"

        name = command.function.name
        if name not in self.available_tools.tool_map:
            return f"Error: Unknown tool '{name}'"

        try:
            # Arguments arrive as a JSON string; empty/None means no args.
            args = json.loads(command.function.arguments or "{}")

            logger.info(f"π§ Activating tool: '{name}'...")
            result = await self.available_tools.execute(name=name, tool_input=args)

            # Special tools (e.g. Terminate) may finish the agent run here.
            await self._handle_special_tool(name=name, result=result)

            # Capture any image payload for the tool message built in act().
            if hasattr(result, "base64_image") and result.base64_image:
                self._current_base64_image = result.base64_image

            # Falsy results are reported as "no output".
            observation = (
                f"Observed output of cmd `{name}` executed:\n{str(result)}"
                if result
                else f"Cmd `{name}` completed with no output"
            )

            return observation
        except json.JSONDecodeError:
            error_msg = f"Error parsing arguments for {name}: Invalid JSON format"
            logger.error(
                f"π Oops! The arguments for '{name}' don't make sense - invalid JSON, arguments:{command.function.arguments}"
            )
            return f"Error: {error_msg}"
        except Exception as e:
            # Catch-all so one failing tool cannot abort the whole act() loop.
            error_msg = f"β οΈ Tool '{name}' encountered a problem: {str(e)}"
            logger.exception(error_msg)
            return f"Error: {error_msg}"

    async def _handle_special_tool(self, name: str, result: Any, **kwargs) -> None:
        """Handle special tool execution and state changes.

        For tools listed in ``special_tool_names``, marks the agent FINISHED
        when ``_should_finish_execution`` approves (always, by default).
        """
        if not self._is_special_tool(name):
            return

        if self._should_finish_execution(name=name, result=result, **kwargs):
            logger.info(f"π Special tool '{name}' has completed the task!")
            self.state = AgentState.FINISHED

    @staticmethod
    def _should_finish_execution(**kwargs) -> bool:
        """Determine if tool execution should finish the agent.

        Default policy: any special tool finishes the run; subclasses can
        override to inspect ``name``/``result`` in kwargs.
        """
        return True

    def _is_special_tool(self, name: str) -> bool:
        """Check if tool name is in special tools list (case-insensitive)."""
        return name.lower() in [n.lower() for n in self.special_tool_names]

    async def cleanup(self) -> None:
        """Clean up resources used by the agent's tools.

        Awaits each tool's async ``cleanup()`` (sync cleanup methods are
        skipped); individual failures are logged but never propagated, so
        every tool gets a cleanup attempt.
        """
        logger.info(f"π§Ή Cleaning up resources for agent '{self.name}'...")
        for tool_name, tool_instance in self.available_tools.tool_map.items():
            # Only call cleanup when it exists AND is a coroutine function.
            if hasattr(tool_instance, "cleanup") and asyncio.iscoroutinefunction(
                tool_instance.cleanup
            ):
                try:
                    logger.debug(f"π§Ό Cleaning up tool: {tool_name}")
                    await tool_instance.cleanup()
                except Exception as e:
                    logger.error(
                        f"π¨ Error cleaning up tool '{tool_name}': {e}", exc_info=True
                    )
        logger.info(f"β¨ Cleanup complete for agent '{self.name}'.")

    async def run(self, request: Optional[str] = None) -> str:
        """Run the agent with cleanup when done.

        Delegates to ``ReActAgent.run`` and guarantees ``cleanup()`` executes
        even when the run raises.
        """
        try:
            return await super().run(request)
        finally:
            await self.cleanup()
|
|
|