Spaces:
Paused
Paused
| from abc import ABC, abstractmethod | |
| import re | |
| from typing import ( | |
| List, | |
| Dict, | |
| Optional, | |
| Any, | |
| TextIO, | |
| Union, | |
| Literal, | |
| Annotated, | |
| ClassVar, | |
| cast, | |
| Callable, | |
| Awaitable, | |
| TypeVar, | |
| ) | |
| import threading | |
| import asyncio | |
| from contextlib import AsyncExitStack | |
| from shutil import which | |
| from datetime import timedelta | |
| import json | |
| from python.helpers import errors | |
| from python.helpers import settings | |
| import httpx | |
| from mcp import ClientSession, StdioServerParameters | |
| from mcp.client.stdio import stdio_client | |
| from mcp.client.sse import sse_client | |
| from mcp.client.streamable_http import streamablehttp_client | |
| from mcp.shared.message import SessionMessage | |
| from mcp.types import CallToolResult, ListToolsResult | |
| from anyio.streams.memory import ( | |
| MemoryObjectReceiveStream, | |
| MemoryObjectSendStream, | |
| ) | |
| from pydantic import BaseModel, Field, Discriminator, Tag, PrivateAttr | |
| from python.helpers import dirty_json | |
| from python.helpers.print_style import PrintStyle | |
| from python.helpers.tool import Tool, Response | |
def normalize_name(name: str) -> str:
    """Turn an arbitrary display name into a lowercase, identifier-like key.

    Surrounding whitespace is removed and the text is lowercased. Every
    character that is not a Unicode "word" character (letter, digit or
    underscore) is then replaced by ``_``. Non-Latin letters are kept.
    """
    cleaned = name.strip().lower()
    return re.sub(r"[^\w]", "_", cleaned, flags=re.UNICODE)
| def _determine_server_type(config_dict: dict) -> str: | |
| """Determine the server type based on configuration, with backward compatibility.""" | |
| # First check if type is explicitly specified | |
| if "type" in config_dict: | |
| server_type = config_dict["type"].lower() | |
| if server_type in ["sse", "http-stream", "streaming-http", "streamable-http", "http-streaming"]: | |
| return "MCPServerRemote" | |
| elif server_type == "stdio": | |
| return "MCPServerLocal" | |
| # For future types, we could add more cases here | |
| else: | |
| # For unknown types, fall back to URL-based detection | |
| # This allows for graceful handling of new types | |
| pass | |
| # Backward compatibility: if no type specified, use URL-based detection | |
| if "url" in config_dict or "serverUrl" in config_dict: | |
| return "MCPServerRemote" | |
| else: | |
| return "MCPServerLocal" | |
| def _is_streaming_http_type(server_type: str) -> bool: | |
| """Check if the server type is a streaming HTTP variant.""" | |
| return server_type.lower() in ["http-stream", "streaming-http", "streamable-http", "http-streaming"] | |
def initialize_mcp(mcp_servers_config: str):
    """Load the MCP server configuration unless the registry is already initialized.

    A failure is not re-raised. It is broadcast as a warning to all agent
    contexts and printed to the console.
    """
    if MCPConfig.get_instance().is_initialized():
        return
    try:
        MCPConfig.update(mcp_servers_config)
    except Exception as e:
        # Imported lazily; presumably to avoid a circular import with the
        # agent module (TODO confirm).
        from agent import AgentContext

        failure_text = f"Failed to update MCP settings: {e}"
        AgentContext.log_to_all(type="warning", content=failure_text, temp=False)
        PrintStyle(
            background_color="black", font_color="red", padding=True
        ).print(failure_text)
class MCPTool(Tool):
    """Agent tool that forwards a call to a tool exposed by a configured MCP server.

    ``self.name`` is the qualified ``"<server>.<tool>"`` name; dispatch goes
    through ``MCPConfig.get_instance().call_tool``.
    """

    async def execute(self, **kwargs: Any):
        """Invoke the MCP tool with ``kwargs`` and wrap its text output in a Response.

        Text content items of the result are joined with blank lines; other
        content types are dropped. An ``isError`` result, or any exception
        raised by the call, is printed and logged as a warning on the agent
        context. Exceptions are caught and turned into an ``ERROR: ...``
        message.
        """
        error = ""
        try:
            response: CallToolResult = await MCPConfig.get_instance().call_tool(
                self.name, kwargs
            )
            message = "\n\n".join(
                item.text for item in response.content if item.type == "text"
            )
            if response.isError:
                # An error result may carry no text; still surface it as an error.
                error = message or "MCP tool reported an error without textual content"
        except Exception as e:
            error = f"MCP Tool Exception: {str(e)}"
            message = f"ERROR: {str(e)}"

        if error:
            PrintStyle(
                background_color="#CC34C3", font_color="white", bold=True, padding=True
            ).print(f"MCPTool::Failed to call mcp tool {self.name}:")
            PrintStyle(
                background_color="#AA4455", font_color="white", padding=False
            ).print(error)
            self.agent.context.log.log(
                type="warning",
                content=f"{self.name}: {error}",
            )
        return Response(message=message, break_loop=False)

    async def before_execution(self, **kwargs: Any):
        """Print a usage banner and the tool arguments, and create the log entry."""
        PrintStyle(
            font_color="#1B4F72", padding=True, background_color="white", bold=True
        ).print(f"{self.agent.agent_name}: Using tool '{self.name}'")
        self.log = self.get_log_object()
        for key, value in self.args.items():
            PrintStyle(font_color="#85C1E9", bold=True).stream(
                self.nice_key(key) + ": "
            )
            # Multi-line string values get padding so they print as a block.
            PrintStyle(
                font_color="#85C1E9", padding=isinstance(value, str) and "\n" in value
            ).stream(value)
            PrintStyle().print()

    async def after_execution(self, response: Response, **kwargs: Any):
        """Record the tool's textual result in the agent history, console and log.

        An empty or whitespace-only result is replaced by a placeholder so
        the agent always receives an explicit answer.
        """
        raw_tool_response = response.message.strip() if response.message else ""
        if not raw_tool_response:
            PrintStyle(font_color="red").print(
                f"Warning: Tool '{self.name}' returned an empty message."
            )
            raw_tool_response = "[Tool returned no textual content]"

        # NOTE: no extra context is appended any more; the agent receives
        # exactly the (stripped) tool output.
        self.agent.hist_add_tool_result(self.name, raw_tool_response)
        PrintStyle(
            font_color="#1B4F72", background_color="white", padding=True, bold=True
        ).print(
            f"{self.agent.agent_name}: Response from tool '{self.name}' (plus context added)"
        )
        # raw_tool_response is never empty here (placeholder above).
        PrintStyle(font_color="#85C1E9").print(raw_tool_response)
        if self.log:
            self.log.update(content=raw_tool_response)
class MCPServerRemote(BaseModel):
    """Configuration plus client wrapper for a remote (SSE / streamable-HTTP) MCP server.

    Construction applies ``config`` via :meth:`update`, which also fetches
    the server's tool list through the attached ``MCPClientRemote``.
    """

    name: str = Field(default_factory=str)
    description: Optional[str] = Field(default="Remote SSE Server")
    type: str = Field(default="sse", description="Server connection type")
    url: str = Field(default_factory=str)
    headers: dict[str, Any] | None = Field(default_factory=dict[str, Any])
    init_timeout: int = Field(default=0)
    tool_timeout: int = Field(default=0)
    verify: bool = Field(default=True, description="Verify SSL certificates")
    disabled: bool = Field(default=False)
    __lock: ClassVar[threading.Lock] = PrivateAttr(default=threading.Lock())
    __client: Optional["MCPClientRemote"] = PrivateAttr(default=None)

    def __init__(self, config: dict[str, Any]):
        super().__init__()
        self.__client = MCPClientRemote(self)
        self.update(config)

    def get_error(self) -> str:
        """Return the client's last stored error text."""
        with self.__lock:
            return self.__client.error  # type: ignore

    def get_log(self) -> str:
        """Return the client's log contents."""
        with self.__lock:
            return self.__client.get_log()  # type: ignore

    def get_tools(self) -> List[dict[str, Any]]:
        """Get all tools from the server (the client's cached list)."""
        with self.__lock:
            return self.__client.tools  # type: ignore

    def has_tool(self, tool_name: str) -> bool:
        """Check if a tool is available in the client's cache."""
        with self.__lock:
            return self.__client.has_tool(tool_name)  # type: ignore

    async def call_tool(
        self, tool_name: str, input_data: Dict[str, Any]
    ) -> CallToolResult:
        """Call a tool with the given input data."""
        with self.__lock:
            client = self.__client
        # Await outside the threading.Lock: holding a blocking lock across an
        # await stalls other threads for the whole call and deadlocks if a
        # second coroutine on the same event loop tries to acquire it.
        return await client.call_tool(tool_name, input_data)  # type: ignore

    def update(self, config: dict[str, Any]) -> "MCPServerRemote":
        """Apply the recognised keys of ``config`` and refresh the tool list.

        ``name`` is normalized with ``normalize_name``; ``serverUrl`` is
        accepted as an alias for ``url``. Unknown keys are ignored.
        """
        accepted_keys = (
            "name",
            "description",
            "type",
            "url",
            "serverUrl",
            "headers",
            "init_timeout",
            "tool_timeout",
            "disabled",
            "verify",
        )
        with self.__lock:
            for key, value in config.items():
                if key not in accepted_keys:
                    continue
                if key == "name":
                    value = normalize_name(value)
                if key == "serverUrl":
                    key = "url"  # remap serverUrl to url
                setattr(self, key, value)
            # asyncio.run() starts a fresh event loop and raises RuntimeError if
            # a loop is already running in this thread; callers presumably invoke
            # update() from a thread without a running loop (TODO confirm).
            return asyncio.run(self.__on_update())

    async def __on_update(self) -> "MCPServerRemote":
        await self.__client.update_tools()  # type: ignore
        return self
class MCPServerLocal(BaseModel):
    """Configuration plus client wrapper for a local (stdio) MCP server.

    Construction applies ``config`` via :meth:`update`, which also fetches
    the server's tool list through the attached ``MCPClientLocal``.
    """

    name: str = Field(default_factory=str)
    description: Optional[str] = Field(default="Local StdIO Server")
    type: str = Field(default="stdio", description="Server connection type")
    command: str = Field(default_factory=str)
    args: list[str] = Field(default_factory=list)
    env: dict[str, str] | None = Field(default_factory=dict[str, str])
    encoding: str = Field(default="utf-8")
    encoding_error_handler: Literal["strict", "ignore", "replace"] = Field(
        default="strict"
    )
    init_timeout: int = Field(default=0)
    tool_timeout: int = Field(default=0)
    verify: bool = Field(default=True, description="Verify SSL certificates")
    disabled: bool = Field(default=False)
    __lock: ClassVar[threading.Lock] = PrivateAttr(default=threading.Lock())
    __client: Optional["MCPClientLocal"] = PrivateAttr(default=None)

    def __init__(self, config: dict[str, Any]):
        super().__init__()
        self.__client = MCPClientLocal(self)
        self.update(config)

    def get_error(self) -> str:
        """Return the client's last stored error text."""
        with self.__lock:
            return self.__client.error  # type: ignore

    def get_log(self) -> str:
        """Return the client's log contents."""
        with self.__lock:
            return self.__client.get_log()  # type: ignore

    def get_tools(self) -> List[dict[str, Any]]:
        """Get all tools from the server (the client's cached list)."""
        with self.__lock:
            return self.__client.tools  # type: ignore

    def has_tool(self, tool_name: str) -> bool:
        """Check if a tool is available in the client's cache."""
        with self.__lock:
            return self.__client.has_tool(tool_name)  # type: ignore

    async def call_tool(
        self, tool_name: str, input_data: Dict[str, Any]
    ) -> CallToolResult:
        """Call a tool with the given input data."""
        with self.__lock:
            client = self.__client
        # Await outside the threading.Lock: holding a blocking lock across an
        # await stalls other threads for the whole call and deadlocks if a
        # second coroutine on the same event loop tries to acquire it.
        return await client.call_tool(tool_name, input_data)  # type: ignore

    def update(self, config: dict[str, Any]) -> "MCPServerLocal":
        """Apply the recognised keys of ``config`` and refresh the tool list.

        ``name`` is normalized with ``normalize_name``. Unknown keys are ignored.
        """
        accepted_keys = (
            "name",
            "description",
            "type",
            "command",
            "args",
            "env",
            "encoding",
            "encoding_error_handler",
            "init_timeout",
            "tool_timeout",
            "disabled",
        )
        with self.__lock:
            for key, value in config.items():
                if key not in accepted_keys:
                    continue
                if key == "name":
                    value = normalize_name(value)
                setattr(self, key, value)
            # asyncio.run() starts a fresh event loop and raises RuntimeError if
            # a loop is already running in this thread; callers presumably invoke
            # update() from a thread without a running loop (TODO confirm).
            return asyncio.run(self.__on_update())

    async def __on_update(self) -> "MCPServerLocal":
        await self.__client.update_tools()  # type: ignore
        return self
# Tagged union of the two server models. Pydantic calls
# _determine_server_type on the input to pick which tag (model) applies.
MCPServer = Annotated[
    Annotated[MCPServerRemote, Tag("MCPServerRemote")]
    | Annotated[MCPServerLocal, Tag("MCPServerLocal")],
    Discriminator(_determine_server_type),
]
class MCPConfig(BaseModel):
    """Process-wide registry of the configured MCP servers.

    Used as a singleton via :meth:`get_instance`. :meth:`update` rebuilds the
    server list from a JSON-like configuration string.

    Attributes:
        servers: Successfully constructed server wrappers.
        disconnected_servers: Entries that were disabled, invalid, or failed
            to construct, as ``{"config": ..., "error": ..., "name": ...}``.
    """

    servers: list[MCPServer] = Field(default_factory=list)
    disconnected_servers: list[dict[str, Any]] = Field(default_factory=list)
    __lock: ClassVar[threading.Lock] = PrivateAttr(default=threading.Lock())
    __instance: ClassVar[Any] = PrivateAttr(default=None)
    __initialized: ClassVar[bool] = PrivateAttr(default=False)

    @classmethod
    def get_instance(cls) -> "MCPConfig":
        """Return the shared instance, creating an empty one on first use."""
        # Deliberately not taking cls.__lock: update() calls this while already
        # holding it, and threading.Lock is not reentrant.
        if cls.__instance is None:
            cls.__instance = cls(servers_list=[])
        return cls.__instance

    @classmethod
    def wait_for_lock(cls):
        """Block until the class lock is free (e.g. a running update() finished)."""
        with cls.__lock:
            return

    @classmethod
    def update(cls, config_str: str) -> Any:
        """Parse ``config_str`` and rebuild the singleton's servers from it.

        The string is parsed with ``dirty_json.try_parse`` and normalized by
        :meth:`normalize_config`. Parse errors and unsupported top-level
        structures are printed and leave an empty server list; they are not
        raised. Afterwards the registry is marked as initialized.

        Returns:
            The singleton ``MCPConfig`` instance.
        """
        with cls.__lock:
            servers_data: List[Dict[str, Any]] = []
            # Only parse if non-empty and not just whitespace.
            if config_str and config_str.strip():
                try:
                    parsed_value = dirty_json.try_parse(config_str)
                    if isinstance(parsed_value, (list, dict)):
                        # normalize_config only ever yields dict entries.
                        servers_data = cls.normalize_config(parsed_value)
                    else:
                        PrintStyle(
                            background_color="red", font_color="white", padding=True
                        ).print(
                            f"Error: Parsed MCP config top-level structure is not a list or object. Config string was: '{config_str}'"
                        )
                except Exception as e_json:
                    PrintStyle.error(
                        f"Error parsing MCP config string: {e_json}. Config string was: '{config_str}'"
                    )

            instance = cls.get_instance()
            # Re-run __init__ on the existing object so every holder of the
            # singleton sees the new server list.
            instance.__init__(servers_list=servers_data)
            cls.__initialized = True
            return instance

    @classmethod
    def normalize_config(cls, servers: Any):
        """Flatten the accepted config shapes into a list of server dicts.

        Accepted shapes:
          * a list of server dicts;
          * ``{"mcpServers": {name: server_dict, ...}}``; each name is copied
            into the entry's ``"name"`` key;
          * ``{"mcpServers": [server_dict, ...]}``;
          * any other dict, treated as a single server.

        Non-dict entries are dropped. The input dicts are not mutated.
        """
        normalized = []
        if isinstance(servers, list):
            normalized.extend(item for item in servers if isinstance(item, dict))
        elif isinstance(servers, dict):
            if "mcpServers" in servers:
                entries = servers["mcpServers"]
                if isinstance(entries, dict):
                    for key, value in entries.items():
                        if isinstance(value, dict):
                            normalized.append({**value, "name": key})
                elif isinstance(entries, list):
                    normalized.extend(
                        item for item in entries if isinstance(item, dict)
                    )
            else:
                normalized.append(servers)  # single server?
        return normalized

    def __init__(self, servers_list: List[Dict[str, Any]]):
        """Build server wrappers from ``servers_list``.

        Some entries are recorded in ``disconnected_servers`` instead of being
        constructed: non-mappings, ``disabled`` entries, entries without a
        ``name``, and entries that raise during construction. An entry with a
        truthy ``url`` or ``serverUrl`` becomes ``MCPServerRemote``; all others
        become ``MCPServerLocal``.
        """
        from collections.abc import Mapping, Iterable

        super().__init__()
        # Reset both lists explicitly: update() re-runs __init__ on the live
        # singleton.
        self.servers = []
        self.disconnected_servers = []

        if not isinstance(servers_list, Iterable):
            PrintStyle(
                background_color="grey", font_color="red", padding=True
            ).print("MCPConfig::__init__::servers_list must be a list")
            return

        for server_item in servers_list:
            if not isinstance(server_item, Mapping):
                error_msg = "server_item must be a mapping"
                PrintStyle(
                    background_color="grey", font_color="red", padding=True
                ).print(f"MCPConfig::__init__::{error_msg}")
                self.disconnected_servers.append(
                    {
                        # A non-mapping can never be a dict; keep its text form.
                        "config": {"raw": str(server_item)},
                        "error": error_msg,
                        "name": "invalid_server_config",
                    }
                )
                continue

            if server_item.get("disabled", False):
                server_name = server_item.get("name", "unnamed_server")
                if server_name != "unnamed_server":
                    server_name = normalize_name(server_name)
                self.disconnected_servers.append(
                    {
                        "config": server_item,
                        "error": "Disabled in config",
                        "name": server_name,
                    }
                )
                continue

            server_name = server_item.get("name", "__not__found__")
            if server_name == "__not__found__":
                error_msg = "server_name is required"
                PrintStyle(
                    background_color="grey", font_color="red", padding=True
                ).print(f"MCPConfig::__init__::{error_msg}")
                self.disconnected_servers.append(
                    {
                        "config": server_item,
                        "error": error_msg,
                        "name": "unnamed_server",
                    }
                )
                continue

            try:
                # Construct the concrete class directly: the Annotated MCPServer
                # union itself cannot be instantiated.
                if server_item.get("url", None) or server_item.get("serverUrl", None):
                    self.servers.append(MCPServerRemote(server_item))
                else:
                    self.servers.append(MCPServerLocal(server_item))
            except Exception as e:
                error_msg = str(e)
                PrintStyle(
                    background_color="grey", font_color="red", padding=True
                ).print(
                    f"MCPConfig::__init__: Failed to create MCPServer '{server_name}': {error_msg}"
                )
                self.disconnected_servers.append(
                    {"config": server_item, "error": error_msg, "name": server_name}
                )

    def get_server_log(self, server_name: str) -> str:
        """Return the log of the named server, or "" if it is not registered."""
        with self.__lock:
            for server in self.servers:
                if server.name == server_name:
                    return server.get_log()  # type: ignore
            return ""

    def get_servers_status(self) -> list[dict[str, Any]]:
        """Get status of all servers.

        Every constructed server is reported with ``connected=True``; entries
        in ``disconnected_servers`` are reported with ``connected=False``.
        """
        result = []
        with self.__lock:
            for server in self.servers:
                result.append(
                    {
                        "name": server.name,
                        "connected": True,
                        "error": server.get_error(),
                        "tool_count": len(server.get_tools()),
                        "has_log": server.get_log() != "",
                    }
                )
            for disconnected in self.disconnected_servers:
                result.append(
                    {
                        "name": disconnected["name"],
                        "connected": False,
                        "error": disconnected["error"],
                        "tool_count": 0,
                        "has_log": False,
                    }
                )
        return result

    def get_server_detail(self, server_name: str) -> dict[str, Any]:
        """Return name, description and tools of a server, or {} if unknown."""
        with self.__lock:
            for server in self.servers:
                if server.name == server_name:
                    try:
                        tools = server.get_tools()
                    except Exception:
                        tools = []
                    return {
                        "name": server.name,
                        "description": server.description,
                        "tools": tools,
                    }
            return {}

    def is_initialized(self) -> bool:
        """Check if update() has completed at least once."""
        with self.__lock:
            return self.__initialized

    def get_tools(self) -> List[dict[str, dict[str, Any]]]:
        """Return ``[{"<server>.<tool>": tool_copy_with_server_key}, ...]`` for all servers."""
        with self.__lock:
            tools = []
            for server in self.servers:
                for tool in server.get_tools():
                    tool_copy = tool.copy()
                    tool_copy["server"] = server.name
                    tools.append({f"{server.name}.{tool['name']}": tool_copy})
            return tools

    def get_tools_prompt(self, server_name: str = "") -> str:
        """Render a Markdown prompt section describing the available MCP tools.

        Args:
            server_name: If non-empty, only this server's tools are described.

        Raises:
            ValueError: If ``server_name`` is given but no such server exists.
        """
        # Acquire and release the lock only to wait for a pending update().
        with self.__lock:
            pass

        prompt = '## "Remote (MCP Server) Agent Tools" available:\n\n'
        server_names = []
        for server in self.servers:
            if not server_name or server.name == server_name:
                server_names.append(server.name)

        if server_name and server_name not in server_names:
            raise ValueError(f"Server {server_name} not found")

        for server in self.servers:
            if server.name not in server_names:
                continue
            srv_name = server.name
            prompt += f"### {srv_name}\n"
            prompt += f"{server.description}\n"
            for tool in server.get_tools():
                prompt += (
                    f"\n### {srv_name}.{tool['name']}:\n"
                    f"{tool['description']}\n\n"
                )
                input_schema = (
                    json.dumps(tool["input_schema"]) if tool["input_schema"] else ""
                )
                prompt += f"#### Input schema for tool_args:\n{input_schema}\n"
                prompt += "\n"
                # TODO: this should be a prompt file with placeholders
                prompt += (
                    f"#### Usage:\n"
                    f"{{\n"
                    f' "thoughts": ["..."],\n'
                    f" \"tool_name\": \"{srv_name}.{tool['name']}\",\n"
                    f' "tool_args": !follow schema above\n'
                    f"}}\n"
                )
        return prompt

    def has_tool(self, tool_name: str) -> bool:
        """Check if ``"<server>.<tool>"`` names an available tool.

        Only the first dot separates the server from the tool, because server
        names are normalized and never contain dots.
        """
        server_name_part, sep, tool_name_part = tool_name.partition(".")
        if not sep:
            return False
        with self.__lock:
            for server in self.servers:
                if server.name == server_name_part:
                    return server.has_tool(tool_name_part)
        return False

    def get_tool(self, agent: Any, tool_name: str) -> MCPTool | None:
        """Return an MCPTool bound to ``agent`` for ``tool_name``, or None if unavailable."""
        if not self.has_tool(tool_name):
            return None
        return MCPTool(agent=agent, name=tool_name, method=None, args={}, message="", loop_data=None)

    async def call_tool(
        self, tool_name: str, input_data: Dict[str, Any]
    ) -> CallToolResult:
        """Call ``"<server>.<tool>"`` with ``input_data``.

        Raises:
            ValueError: If the name is unqualified or no server offers the tool.
        """
        server_name_part, sep, tool_name_part = tool_name.partition(".")
        if not sep:
            raise ValueError(f"Tool {tool_name} not found")
        with self.__lock:
            target = next(
                (
                    server
                    for server in self.servers
                    if server.name == server_name_part
                    and server.has_tool(tool_name_part)
                ),
                None,
            )
        if target is None:
            raise ValueError(f"Tool {tool_name} not found")
        # Await outside the threading.Lock so a long tool call does not block
        # other threads or other coroutines on this event loop.
        return await target.call_tool(tool_name_part, input_data)
# Result type of the operation passed to MCPClientBase._execute_with_session.
T = TypeVar("T")
| class MCPClientBase(ABC): | |
| # server: Union[MCPServerLocal, MCPServerRemote] # Defined in __init__ | |
| # tools: List[dict[str, Any]] # Defined in __init__ | |
| # No self.session, self.exit_stack, self.stdio, self.write as persistent instance fields | |
| __lock: ClassVar[threading.Lock] = threading.Lock() | |
def __init__(self, server: Union[MCPServerLocal, MCPServerRemote]):
    """Attach the client to ``server``, starting with an empty tool cache, no error and no log."""
    self.log_file: Optional[TextIO] = None
    self.log: List[str] = []
    self.error: str = ""
    # Tool descriptors are cached on the client instance between sessions.
    self.tools: List[dict[str, Any]] = []
    self.server = server
# Protected method
@abstractmethod
async def _create_stdio_transport(
    self, current_exit_stack: AsyncExitStack
) -> tuple[
    MemoryObjectReceiveStream[SessionMessage | Exception],
    MemoryObjectSendStream[SessionMessage],
]:
    """Create the (read, write) stream pair for one MCP session.

    Subclasses must implement this. Without the ``abstractmethod`` marker the
    base placeholder would silently return None, and the tuple unpacking in
    ``_execute_with_session`` would then fail with an unrelated error. Any
    context managers the transport needs should be entered on
    ``current_exit_stack``, so they are closed when the caller's session
    stack unwinds.
    """
    ...
async def _execute_with_session(
    self,
    coro_func: Callable[[ClientSession], Awaitable[T]],
    read_timeout_seconds=60,
) -> T:
    """
    Manages the lifecycle of an MCP session for a single operation.
    Creates a temporary session, executes coro_func with it, and ensures cleanup.

    Args:
        coro_func: Coroutine function run with the initialized ClientSession;
            its result is returned.
        read_timeout_seconds: Default read timeout (seconds) for the session.

    Raises:
        The exception raised while creating the transport, initializing the
        session, or running coro_func. If it arrived as an exception group,
        only its first member is raised. Errors are printed before re-raising.
    """
    operation_name = coro_func.__name__  # For logging
    # Store the original exception outside the async block
    original_exception = None
    try:
        async with AsyncExitStack() as temp_stack:
            try:
                # Transport contexts are registered on temp_stack, so they are
                # closed when this async-with block exits.
                stdio, write = await self._create_stdio_transport(temp_stack)
                session = await temp_stack.enter_async_context(
                    ClientSession(
                        stdio,  # type: ignore
                        write,  # type: ignore
                        read_timeout_seconds=timedelta(
                            seconds=read_timeout_seconds
                        ),
                    )
                )
                await session.initialize()
                result = await coro_func(session)
                return result
            except Exception as e:
                # Store the original exception and raise a dummy exception.
                # Exception groups (objects with an "exceptions" attribute) are
                # unwrapped to their FIRST member only; nested groups are not
                # unwrapped further.
                excs = getattr(e, "exceptions", None)  # Python 3.11+ ExceptionGroup
                if excs:
                    original_exception = excs[0]
                else:
                    original_exception = e
                # Leave the async-with by raising, so the exit stack unwinds on
                # an error path. Whatever the unwinding raises (possibly wrapping
                # this dummy) is replaced by original_exception in the outer
                # handler. NOTE(review): presumably done so transport task groups
                # see an exception during cleanup - TODO confirm.
                raise RuntimeError("Dummy exception to break out of async block")
    except Exception as e:
        # If we stored an original exception, report and raise that one instead
        # of the dummy (or of whatever the exit stack raised around it).
        if original_exception is not None:
            e = original_exception
        PrintStyle(
            background_color="#AA4455", font_color="white", padding=False
        ).print(
            f"MCPClientBase ({self.server.name} - {operation_name}): Error during operation: {type(e).__name__}: {e}"
        )
        raise e  # Re-raise the original exception
    # Reached only if an exit callback on temp_stack suppressed the exception
    # raised inside the block. It also satisfies linters that cannot trace the
    # return/raise paths through async context managers.
    raise RuntimeError(
        f"MCPClientBase ({self.server.name} - {operation_name}): _execute_with_session exited 'async with' block unexpectedly."
    )
async def update_tools(self) -> "MCPClientBase":
    """Refresh the cached tool list from the server using a short-lived session.

    On success, ``self.tools`` is replaced and any previous ``self.error`` is
    cleared. On failure, the cache is emptied and a truncated error text is
    stored in ``self.error``. Exceptions from the refresh are caught here, not
    propagated. Returns ``self``.
    """

    async def list_tools_op(current_session: ClientSession):
        response: ListToolsResult = await current_session.list_tools()
        with self.__lock:
            self.tools = [
                {
                    "name": tool.name,
                    "description": tool.description,
                    "input_schema": tool.inputSchema,
                }
                for tool in response.tools
            ]
            # A successful refresh supersedes an error left by an earlier
            # failed attempt; otherwise the server keeps reporting a stale error.
            self.error = ""
        PrintStyle(font_color="green").print(
            f"MCPClientBase ({self.server.name}): Tools updated. Found {len(self.tools)} tools."
        )

    try:
        cfg = settings.get_settings()
        await self._execute_with_session(
            list_tools_op,
            # Per-server init_timeout; 0 (the field default) falls back to the global setting.
            read_timeout_seconds=self.server.init_timeout
            or cfg["mcp_client_init_timeout"],
        )
    except Exception as e:
        error_text = errors.format_error(e, 0, 0)
        # The raw error was already printed by _execute_with_session.
        PrintStyle(
            background_color="#CC34C3", font_color="white", bold=True, padding=True
        ).print(
            f"MCPClientBase ({self.server.name}): 'update_tools' operation failed: {error_text}"
        )
        with self.__lock:
            self.tools = []  # Ensure tools are cleared on failure
            self.error = f"Failed to initialize. {error_text[:200]}{'...' if len(error_text) > 200 else ''}"
    return self
def has_tool(self, tool_name: str) -> bool:
    """Report whether ``tool_name`` is in the cached tool list (no server round-trip)."""
    with self.__lock:
        return any(entry["name"] == tool_name for entry in self.tools)
def get_tools(self) -> List[dict[str, Any]]:
    """Return the cached tool list object itself (not a copy), read under the class lock."""
    with self.__lock:
        cached = self.tools
    return cached
async def call_tool(
    self, tool_name: str, input_data: Dict[str, Any]
) -> CallToolResult:
    """Call ``tool_name`` on the server with ``input_data`` in a fresh session.

    If the tool is not in the cache, the tool list is refreshed once first.

    Raises:
        ValueError: If the tool is still unknown after the refresh.
        ConnectionError: If the session or the call fails; the original
            exception is chained as ``__cause__``.
    """
    if not self.has_tool(tool_name):
        PrintStyle(font_color="orange").print(
            f"MCPClientBase ({self.server.name}): Tool '{tool_name}' not in cache for 'call_tool', refreshing tools..."
        )
        await self.update_tools()  # This will use its own properly managed session
        if not self.has_tool(tool_name):
            PrintStyle(font_color="red").print(
                f"MCPClientBase ({self.server.name}): Tool '{tool_name}' not found after refresh. Raising ValueError."
            )
            raise ValueError(
                f"Tool {tool_name} not found after refreshing tool list for server {self.server.name}."
            )
        PrintStyle(font_color="green").print(
            f"MCPClientBase ({self.server.name}): Tool '{tool_name}' found after updating tools."
        )

    async def call_tool_op(current_session: ClientSession):
        cfg = settings.get_settings()
        # Per-server tool_timeout wins; 0 (the field default) falls back to the
        # global setting, mirroring how update_tools() treats init_timeout.
        timeout_seconds = self.server.tool_timeout or cfg["mcp_client_tool_timeout"]
        response: CallToolResult = await current_session.call_tool(
            tool_name,
            input_data,
            read_timeout_seconds=timedelta(seconds=timeout_seconds),
        )
        return response

    try:
        return await self._execute_with_session(call_tool_op)
    except Exception as e:
        # Raw error already printed by _execute_with_session; re-raise a
        # specific error type for the caller.
        PrintStyle(
            background_color="#AA4455", font_color="white", padding=True
        ).print(
            f"MCPClientBase ({self.server.name}): 'call_tool' operation for '{tool_name}' failed: {type(e).__name__}: {e}"
        )
        raise ConnectionError(
            f"MCPClientBase::Failed to call tool '{tool_name}' on server '{self.server.name}'. Original error: {type(e).__name__}: {e}"
        ) from e
def get_log(self):
    """Return the full contents of ``self.log_file``, or ``""`` if unavailable.

    The file is rewound and read in full but deliberately left open:
    ``MCPClientLocal`` hands it to ``stdio_client`` as ``errlog``, so the
    server process may keep writing to it.

    Returns ``""`` when no log file exists, or when it cannot be rewound or
    read (for example, it has already been closed).
    """
    log_file = getattr(self, "log_file", None)
    if log_file is None:
        return ""
    try:
        # seek belongs inside the try: a closed or unseekable file raises
        # ValueError/OSError here, which must not escape a best-effort reader.
        log_file.seek(0)
        return log_file.read()
    except (OSError, ValueError):
        return ""
class MCPClientLocal(MCPClientBase):
    """MCP client that launches a local server command and talks over stdio.

    The server's stderr is captured into a temporary text file stored as
    ``self.log_file``.
    """

    def __del__(self):
        # Release the captured-stderr temp file, ignoring any close errors.
        log_file = getattr(self, "log_file", None)
        if log_file is not None:
            try:
                log_file.close()
            except Exception:
                pass
            self.log_file = None

    async def _create_stdio_transport(
        self, current_exit_stack: AsyncExitStack
    ) -> tuple[
        MemoryObjectReceiveStream[SessionMessage | Exception],
        MemoryObjectSendStream[SessionMessage],
    ]:
        """Spawn the configured command and return its (read, write) streams.

        The stdio client context is entered on *current_exit_stack*, so its
        lifetime is tied to that stack.

        Raises:
            ValueError: no command is configured, or ``shutil.which`` cannot
                resolve it.
        """
        import tempfile

        server: MCPServerLocal = cast(MCPServerLocal, self.server)
        command = server.command
        if not command:
            raise ValueError("Command not specified")
        if not which(command):
            raise ValueError(f"Command '{command}' not found")

        params = StdioServerParameters(
            command=command,
            args=server.args,
            env=server.env,
            encoding=server.encoding,
            encoding_error_handler=server.encoding_error_handler,
        )

        # Create the stderr capture file only once; later calls reuse it.
        # It is not read or closed here because the stdio transport keeps
        # writing to it asynchronously.
        if getattr(self, "log_file", None) is None:
            self.log_file = tempfile.TemporaryFile(mode="w+", encoding="utf-8")

        return await current_exit_stack.enter_async_context(
            stdio_client(params, errlog=self.log_file)
        )
class CustomHTTPClientFactory(ABC):
    """Build ``httpx.AsyncClient`` objects with a configurable TLS ``verify`` flag.

    An instance is passed as ``httpx_client_factory`` to the SSE and
    streamable-HTTP MCP clients; they call it with optional headers, timeout
    and auth.
    """

    def __init__(self, verify: bool = True):
        self.verify = verify

    def __call__(
        self,
        headers: dict[str, str] | None = None,
        timeout: httpx.Timeout | None = None,
        auth: httpx.Auth | None = None,
    ) -> httpx.AsyncClient:
        # Always follow redirects; fall back to a 30 s timeout when none is given.
        options: dict[str, Any] = {
            "follow_redirects": True,
            "timeout": httpx.Timeout(30.0) if timeout is None else timeout,
        }
        # Only forward optional settings that were actually supplied.
        if headers is not None:
            options["headers"] = headers
        if auth is not None:
            options["auth"] = auth
        return httpx.AsyncClient(**options, verify=self.verify)
class MCPClientRemote(MCPClientBase):
    """MCP client for remote servers reached over SSE or streamable HTTP."""

    def __init__(self, server: Union[MCPServerLocal, MCPServerRemote]):
        super().__init__(server)
        # Session ID slot for streaming HTTP clients.
        self.session_id: Optional[str] = None
        # Filled in by _create_stdio_transport for streamable-HTTP servers.
        self.session_id_callback: Optional[Callable[[], Optional[str]]] = None

    async def _create_stdio_transport(
        self, current_exit_stack: AsyncExitStack
    ) -> tuple[
        MemoryObjectReceiveStream[SessionMessage | Exception],
        MemoryObjectSendStream[SessionMessage],
    ]:
        """Open this server's HTTP-based transport and return (read, write) streams.

        The method name matches MCPClientLocal's stdio hook (presumably the
        base-class transport hook), but here the transport is streamable HTTP
        or SSE, depending on ``server.type``. The client context is entered
        on *current_exit_stack*.
        """
        server: MCPServerRemote = cast(MCPServerRemote, self.server)
        cfg = settings.get_settings()

        # Deliberately low caps for faster failure detection.
        init_timeout = min(server.init_timeout or cfg["mcp_client_init_timeout"], 5)
        tool_timeout = min(server.tool_timeout or cfg["mcp_client_tool_timeout"], 10)
        # NOTE(review): sse_read_timeout can therefore never exceed 10 s; confirm
        # that tools staying silent longer than that are not expected.
        client_factory = CustomHTTPClientFactory(verify=server.verify)

        if not _is_streaming_http_type(server.type):
            # Classic SSE transport (the default behaviour).
            return await current_exit_stack.enter_async_context(
                sse_client(
                    url=server.url,
                    headers=server.headers,
                    timeout=init_timeout,
                    sse_read_timeout=tool_timeout,
                    httpx_client_factory=client_factory,
                )
            )

        # streamablehttp_client yields (read_stream, write_stream, get_session_id).
        read_stream, write_stream, session_id_getter = (
            await current_exit_stack.enter_async_context(
                streamablehttp_client(
                    url=server.url,
                    headers=server.headers,
                    timeout=timedelta(seconds=init_timeout),
                    sse_read_timeout=timedelta(seconds=tool_timeout),
                    httpx_client_factory=client_factory,
                )
            )
        )
        # Keep the getter so get_session_id() can query it later.
        self.session_id_callback = session_id_getter
        return read_stream, write_stream

    def get_session_id(self) -> Optional[str]:
        """Return the streaming-HTTP session ID, or None if no getter is set."""
        getter = self.session_id_callback
        return None if getter is None else getter()