| | from typing import Any, Dict, List, Optional, Tuple |
| |
|
| | from pydantic import Field |
| |
|
| | from app.agent.toolcall import ToolCallAgent |
| | from app.logger import logger |
| | from app.prompt.mcp import MULTIMEDIA_RESPONSE_PROMPT, NEXT_STEP_PROMPT, SYSTEM_PROMPT |
| | from app.schema import AgentState, Message |
| | from app.tool.base import ToolResult |
| | from app.tool.mcp import MCPClients |
| |
|
| |
|
class MCPAgent(ToolCallAgent):
    """Agent for interacting with MCP (Model Context Protocol) servers.

    This agent connects to an MCP server using either SSE or stdio transport
    and makes the server's tools available through the agent's tool interface.
    """

    name: str = "mcp_agent"
    description: str = "An agent that connects to an MCP server and uses its tools."

    system_prompt: str = SYSTEM_PROMPT
    next_step_prompt: str = NEXT_STEP_PROMPT

    # Client collection used for every MCP call made by this agent.
    mcp_clients: MCPClients = Field(default_factory=MCPClients)
    # Stays None until initialize() points it at `mcp_clients`.
    available_tools: Optional[MCPClients] = None

    max_steps: int = 20
    connection_type: str = "stdio"

    # Tool name -> input schema as of the last refresh; used to diff tool sets.
    tool_schemas: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    # think() refreshes the tool list whenever current_step is a multiple of this.
    _refresh_tools_interval: int = 5

    special_tool_names: List[str] = Field(default_factory=lambda: ["terminate"])

    async def initialize(
        self,
        connection_type: Optional[str] = None,
        server_url: Optional[str] = None,
        command: Optional[str] = None,
        args: Optional[List[str]] = None,
    ) -> None:
        """Initialize the MCP connection.

        Args:
            connection_type: Type of connection to use ("stdio" or "sse").
                When omitted, the current ``self.connection_type`` is used.
            server_url: URL of the MCP server (for SSE connection)
            command: Command to run (for stdio connection)
            args: Arguments for the command (for stdio connection)

        Raises:
            ValueError: If the connection type is unsupported, or the
                parameter required by the chosen transport is missing.
        """
        if connection_type:
            self.connection_type = connection_type

        # Connect using the transport selected above.
        if self.connection_type == "sse":
            if not server_url:
                raise ValueError("Server URL is required for SSE connection")
            await self.mcp_clients.connect_sse(server_url=server_url)
        elif self.connection_type == "stdio":
            if not command:
                raise ValueError("Command is required for stdio connection")
            await self.mcp_clients.connect_stdio(command=command, args=args or [])
        else:
            raise ValueError(f"Unsupported connection type: {self.connection_type}")

        # Expose the MCP clients as this agent's tool collection.
        self.available_tools = self.mcp_clients

        # Record the initial tool schemas.
        await self._refresh_tools()

        # Add the system prompt plus the names of the available tools to memory.
        tools_info = ", ".join(self.mcp_clients.tool_map.keys())
        self.memory.add_message(
            Message.system_message(
                f"{self.system_prompt}\n\nAvailable MCP tools: {tools_info}"
            )
        )

    async def _refresh_tools(self) -> Tuple[List[str], List[str]]:
        """Refresh the list of available tools from the MCP server.

        Compares the freshly listed tools against ``self.tool_schemas``,
        stores the new schemas, logs any differences, and adds a system
        message to memory for added and for removed tools.

        Returns:
            A tuple of (added_tools, removed_tools), each sorted by name so
            the messages and the return value are deterministic. Both lists
            are empty when there are no active sessions.
        """
        if not self.mcp_clients.sessions:
            return [], []

        response = await self.mcp_clients.list_tools()
        current_tools = {tool.name: tool.inputSchema for tool in response.tools}

        current_names = set(current_tools)
        previous_names = set(self.tool_schemas)

        # Set differences have no stable order; sort for reproducible output.
        added_tools = sorted(current_names - previous_names)
        removed_tools = sorted(previous_names - current_names)

        # Tools present before and after whose schema is no longer equal.
        changed_tools = sorted(
            name
            for name in current_names & previous_names
            if current_tools[name] != self.tool_schemas[name]
        )

        self.tool_schemas = current_tools

        if added_tools:
            logger.info(f"Added MCP tools: {added_tools}")
            self.memory.add_message(
                Message.system_message(f"New tools available: {', '.join(added_tools)}")
            )
        if removed_tools:
            logger.info(f"Removed MCP tools: {removed_tools}")
            self.memory.add_message(
                Message.system_message(
                    f"Tools no longer available: {', '.join(removed_tools)}"
                )
            )
        if changed_tools:
            # Schema changes are only logged; no memory message is added.
            logger.info(f"Changed MCP tools: {changed_tools}")

        return added_tools, removed_tools

    async def think(self) -> bool:
        """Process current state and decide next action.

        Returns:
            False, after setting the state to ``AgentState.FINISHED``, when
            there are no sessions or no tools (before or after a refresh);
            otherwise the result of the parent class's ``think()``.
        """
        if not self.mcp_clients.sessions or not self.mcp_clients.tool_map:
            logger.info("MCP service is no longer available, ending interaction")
            self.state = AgentState.FINISHED
            return False

        # Periodic refresh. A zero interval would raise ZeroDivisionError in
        # the modulo, so it is treated as "no periodic refresh".
        interval = self._refresh_tools_interval
        if interval and self.current_step % interval == 0:
            await self._refresh_tools()
            # The refresh may reveal that every tool has gone away.
            if not self.mcp_clients.tool_map:
                logger.info("MCP service has shut down, ending interaction")
                self.state = AgentState.FINISHED
                return False

        return await super().think()

    async def _handle_special_tool(self, name: str, result: Any, **kwargs) -> None:
        """Handle special tool execution and state changes.

        Runs the parent handler first, then adds a system message built from
        ``MULTIMEDIA_RESPONSE_PROMPT`` when ``result`` is a ``ToolResult``
        with a truthy ``base64_image``.
        """
        await super()._handle_special_tool(name, result, **kwargs)

        if isinstance(result, ToolResult) and result.base64_image:
            self.memory.add_message(
                Message.system_message(
                    MULTIMEDIA_RESPONSE_PROMPT.format(tool_name=name)
                )
            )

    def _should_finish_execution(self, name: str, **kwargs) -> bool:
        """Determine if tool execution should finish the agent.

        Returns:
            True only when ``name`` equals "terminate", ignoring case.
        """
        return name.lower() == "terminate"

    async def cleanup(self) -> None:
        """Clean up MCP connection when done.

        Does nothing when there are no active sessions.
        """
        if self.mcp_clients.sessions:
            await self.mcp_clients.disconnect()
            logger.info("MCP connection closed")

    async def run(self, request: Optional[str] = None) -> str:
        """Run the agent with cleanup when done.

        Args:
            request: Optional request passed through to the parent ``run()``.

        Returns:
            The value returned by the parent ``run()``.
        """
        try:
            return await super().run(request)
        finally:
            # Runs whether the parent run() returned or raised.
            await self.cleanup()
| |
|