| from abc import ABC, abstractmethod |
| import re |
| from typing import ( |
| List, |
| Dict, |
| Optional, |
| Any, |
| TextIO, |
| Union, |
| Literal, |
| Annotated, |
| ClassVar, |
| cast, |
| Callable, |
| Awaitable, |
| TypeVar, |
| ) |
| import threading |
| import asyncio |
| from contextlib import AsyncExitStack |
| from shutil import which |
| from datetime import timedelta |
| import json |
| from helpers import errors |
| from helpers import settings |
| from helpers.log import LogItem |
|
|
| import httpx |
|
|
| from mcp import ClientSession, StdioServerParameters |
| from mcp.client.stdio import stdio_client |
| from mcp.client.sse import sse_client |
| from mcp.client.streamable_http import streamablehttp_client |
| from mcp.shared.message import SessionMessage |
| from mcp.types import CallToolResult, ListToolsResult |
| from anyio.streams.memory import ( |
| MemoryObjectReceiveStream, |
| MemoryObjectSendStream, |
| ) |
|
|
| from pydantic import BaseModel, Field, Discriminator, Tag, PrivateAttr |
| from helpers import dirty_json |
| from helpers.print_style import PrintStyle |
| from helpers.tool import Tool, Response |
|
|
|
|
def normalize_name(name: str) -> str:
    """Turn a server name into a lowercase, identifier-safe token.

    Surrounding whitespace is stripped and the text is lowercased; every
    character that is not a word character (Unicode letter, digit or
    underscore) is then replaced with an underscore.
    """
    cleaned = name.strip().lower()
    # Spaces, dashes, dots and any other non-word character become "_".
    return re.sub(r"[^\w]", "_", cleaned, flags=re.UNICODE)
|
|
|
|
| def _determine_server_type(config_dict: dict) -> str: |
| """Determine the server type based on configuration, with backward compatibility.""" |
| |
| if "type" in config_dict: |
| server_type = config_dict["type"].lower() |
| if server_type in ["sse", "http-stream", "streaming-http", "streamable-http", "http-streaming"]: |
| return "MCPServerRemote" |
| elif server_type == "stdio": |
| return "MCPServerLocal" |
| |
| else: |
| |
| |
| pass |
|
|
| |
| if "url" in config_dict or "serverUrl" in config_dict: |
| return "MCPServerRemote" |
| else: |
| return "MCPServerLocal" |
|
|
|
|
| def _is_streaming_http_type(server_type: str) -> bool: |
| """Check if the server type is a streaming HTTP variant.""" |
| return server_type.lower() in ["http-stream", "streaming-http", "streamable-http", "http-streaming"] |
|
|
|
|
def initialize_mcp(mcp_servers_config: str):
    """Apply *mcp_servers_config* to the MCPConfig singleton on first use.

    Does nothing when the singleton already reports itself initialized. A
    failure while updating is reported to all agent contexts as a warning and
    printed to the console; it is not re-raised.
    """
    if MCPConfig.get_instance().is_initialized():
        return

    try:
        MCPConfig.update(mcp_servers_config)
    except Exception as e:
        # Imported lazily, only on the failure path.
        from agent import AgentContext

        failure_text = f"Failed to update MCP settings: {e}"
        AgentContext.log_to_all(
            type="warning",
            content=failure_text,
        )
        PrintStyle(
            background_color="black", font_color="red", padding=True
        ).print(failure_text)
|
|
|
|
class MCPTool(Tool):
    """MCP Tool wrapper.

    ``self.name`` is the qualified "server.tool" name; calls are routed through
    the ``MCPConfig`` singleton and the textual result is handed back to the
    agent as a ``Response`` that never breaks the agent loop.
    """

    def get_log_object(self) -> LogItem:
        """Create the context log entry that tracks this tool invocation."""
        import uuid

        return self.agent.context.log.log(
            type="mcp",
            heading=f"icon://extension {self.agent.agent_name}: Using MCP tool '{self.name}'",
            content="",
            kvps={"tool_name": self.name, **self.args},
            id=str(uuid.uuid4()),
        )

    async def execute(self, **kwargs: Any):
        """Call the MCP tool with *kwargs* and wrap its text output.

        Only content items of type "text" are forwarded (joined by blank lines).
        Raised exceptions and results flagged ``isError`` are printed and logged
        as warnings.
        """
        error = ""
        try:
            response: CallToolResult = await MCPConfig.get_instance().call_tool(
                self.name, kwargs
            )
            message = "\n\n".join(
                [item.text for item in response.content if item.type == "text"]
            )
            if response.isError:
                # An error result may carry no text at all; without this the
                # failure would be reported to the agent as an empty success.
                if not message:
                    message = "ERROR: MCP tool reported an error without textual content"
                error = message
        except Exception as e:
            error = f"MCP Tool Exception: {str(e)}"
            message = f"ERROR: {str(e)}"

        if error:
            PrintStyle(
                background_color="#CC34C3", font_color="white", bold=True, padding=True
            ).print(f"MCPTool::Failed to call mcp tool {self.name}:")
            PrintStyle(
                background_color="#AA4455", font_color="white", padding=False
            ).print(error)

            self.agent.context.log.log(
                type="warning",
                content=f"{self.name}: {error}",
            )

        return Response(message=message, break_loop=False)

    async def before_execution(self, **kwargs: Any):
        """Print the tool header and each argument, and open the log entry."""
        PrintStyle(
            font_color="#1B4F72", padding=True, background_color="white", bold=True
        ).print(f"{self.agent.agent_name}: Using tool '{self.name}'")
        self.log = self.get_log_object()

        for key, value in self.args.items():
            PrintStyle(font_color="#85C1E9", bold=True).stream(
                self.nice_key(key) + ": "
            )
            PrintStyle(
                font_color="#85C1E9", padding=isinstance(value, str) and "\n" in value
            ).stream(value)
            PrintStyle().print()

    async def after_execution(self, response: Response, **kwargs: Any):
        """Record the tool result in the agent history, print it and update the log.

        An empty/whitespace-only result is replaced by a placeholder so the agent
        always receives some text.
        """
        raw_tool_response = response.message.strip() if response.message else ""
        if not raw_tool_response:
            PrintStyle(font_color="red").print(
                f"Warning: Tool '{self.name}' returned an empty message."
            )
            raw_tool_response = "[Tool returned no textual content]"

        final_text_for_agent = raw_tool_response

        # self.log is set by before_execution; tolerate it being absent or None.
        log_item = getattr(self, "log", None)

        self.agent.hist_add_tool_result(
            self.name, final_text_for_agent, id=log_item.id if log_item else ""
        )
        PrintStyle(
            font_color="#1B4F72", background_color="white", padding=True, bold=True
        ).print(
            f"{self.agent.agent_name}: Response from tool '{self.name}' (plus context added)"
        )
        # raw_tool_response is never empty here (placeholder substituted above).
        PrintStyle(font_color="#85C1E9").print(raw_tool_response)
        if log_item:
            log_item.update(content=final_text_for_agent)
|
|
|
|
class MCPServerRemote(BaseModel):
    """Configuration and access wrapper for a remote (SSE / streamable HTTP) MCP server.

    Holds the connection settings and delegates tool discovery and tool calls
    to a private ``MCPClientRemote`` instance created in ``__init__``.
    """

    name: str = Field(default_factory=str)
    description: Optional[str] = Field(default="Remote SSE Server")
    type: str = Field(default="sse", description="Server connection type")
    url: str = Field(default_factory=str)
    headers: dict[str, Any] | None = Field(default_factory=dict[str, Any])
    init_timeout: int = Field(default=0)
    tool_timeout: int = Field(default=0)
    verify: bool = Field(default=True, description="Verify SSL certificates")
    disabled: bool = Field(default=False)

    __lock: ClassVar[threading.Lock] = PrivateAttr(default=threading.Lock())
    __client: Optional["MCPClientRemote"] = PrivateAttr(default=None)

    def __init__(self, config: dict[str, Any]):
        super().__init__()
        self.__client = MCPClientRemote(self)
        self.update(config)

    def get_error(self) -> str:
        """Return the client's last recorded error text."""
        with self.__lock:
            return self.__client.error

    def get_log(self) -> str:
        """Return the client's captured log text."""
        with self.__lock:
            return self.__client.get_log()

    def get_tools(self) -> List[dict[str, Any]]:
        """Get all tools from the server"""
        with self.__lock:
            return self.__client.tools

    def has_tool(self, tool_name: str) -> bool:
        """Check if a tool is available"""
        with self.__lock:
            return self.__client.has_tool(tool_name)

    async def call_tool(
        self, tool_name: str, input_data: Dict[str, Any]
    ) -> CallToolResult:
        """Call a tool with the given input data.

        The lock only guards reading the client reference and is released
        before awaiting: a ``threading.Lock`` held across an ``await`` blocks
        the whole event-loop thread for any other coroutine that touches this
        server, and deadlocks if that coroutine runs on the same loop.
        """
        with self.__lock:
            client = self.__client
        return await client.call_tool(tool_name, input_data)

    def update(self, config: dict[str, Any]) -> "MCPServerRemote":
        """Copy recognised keys from *config* onto this model; unknown keys are ignored.

        ``name`` is normalized and ``serverUrl`` is accepted as an alias for ``url``.
        """
        with self.__lock:
            for key, value in config.items():
                if key in [
                    "name",
                    "description",
                    "type",
                    "url",
                    "serverUrl",
                    "headers",
                    "init_timeout",
                    "tool_timeout",
                    "disabled",
                    "verify",
                ]:
                    if key == "name":
                        value = normalize_name(value)
                    if key == "serverUrl":
                        key = "url"
                    setattr(self, key, value)
            return self

    async def initialize(self) -> "MCPServerRemote":
        """Fetch the tool list from the server via the client."""
        await self.__client.update_tools()
        return self
|
|
|
|
class MCPServerLocal(BaseModel):
    """Configuration and access wrapper for a local (stdio subprocess) MCP server.

    Holds the launch settings and delegates tool discovery and tool calls to a
    private ``MCPClientLocal`` instance created in ``__init__``.
    """

    name: str = Field(default_factory=str)
    description: Optional[str] = Field(default="Local StdIO Server")
    type: str = Field(default="stdio", description="Server connection type")
    command: str = Field(default_factory=str)
    args: list[str] = Field(default_factory=list)
    env: dict[str, str] | None = Field(default_factory=dict[str, str])
    encoding: str = Field(default="utf-8")
    encoding_error_handler: Literal["strict", "ignore", "replace"] = Field(
        default="strict"
    )
    init_timeout: int = Field(default=0)
    tool_timeout: int = Field(default=0)
    verify: bool = Field(default=True, description="Verify SSL certificates")
    disabled: bool = Field(default=False)

    __lock: ClassVar[threading.Lock] = PrivateAttr(default=threading.Lock())
    __client: Optional["MCPClientLocal"] = PrivateAttr(default=None)

    def __init__(self, config: dict[str, Any]):
        super().__init__()
        self.__client = MCPClientLocal(self)
        self.update(config)

    def get_error(self) -> str:
        """Return the client's last recorded error text."""
        with self.__lock:
            return self.__client.error

    def get_log(self) -> str:
        """Return the client's captured stderr log text."""
        with self.__lock:
            return self.__client.get_log()

    def get_tools(self) -> List[dict[str, Any]]:
        """Get all tools from the server"""
        with self.__lock:
            return self.__client.tools

    def has_tool(self, tool_name: str) -> bool:
        """Check if a tool is available"""
        with self.__lock:
            return self.__client.has_tool(tool_name)

    async def call_tool(
        self, tool_name: str, input_data: Dict[str, Any]
    ) -> CallToolResult:
        """Call a tool with the given input data.

        The lock only guards reading the client reference and is released
        before awaiting: a ``threading.Lock`` held across an ``await`` blocks
        the whole event-loop thread for any other coroutine that touches this
        server, and deadlocks if that coroutine runs on the same loop.
        """
        with self.__lock:
            client = self.__client
        return await client.call_tool(tool_name, input_data)

    def update(self, config: dict[str, Any]) -> "MCPServerLocal":
        """Copy recognised keys from *config* onto this model; unknown keys are ignored.

        ``name`` is normalized before being stored.
        """
        with self.__lock:
            for key, value in config.items():
                if key in [
                    "name",
                    "description",
                    "type",
                    "command",
                    "args",
                    "env",
                    "encoding",
                    "encoding_error_handler",
                    "init_timeout",
                    "tool_timeout",
                    "disabled",
                ]:
                    if key == "name":
                        value = normalize_name(value)
                    setattr(self, key, value)
            return self

    async def initialize(self) -> "MCPServerLocal":
        """Fetch the tool list from the server via the client."""
        await self.__client.update_tools()
        return self
|
|
|
|
# Tagged union of the two server models. Pydantic picks the member by calling
# _determine_server_type on the input and matching the returned string against
# the Tag names below.
MCPServer = Annotated[
    Union[
        Annotated[MCPServerRemote, Tag("MCPServerRemote")],
        Annotated[MCPServerLocal, Tag("MCPServerLocal")],
    ],
    Discriminator(_determine_server_type),
]
|
|
|
|
class MCPConfig(BaseModel):
    """Process-wide registry of configured MCP servers.

    A single shared instance (see ``get_instance``) holds the successfully
    constructed servers in ``servers`` and, in ``disconnected_servers``, the
    entries that were disabled, invalid or failed to construct. ``update``
    re-parses a JSON configuration string and rebuilds that instance in place.
    """

    servers: list[MCPServer] = Field(default_factory=list)
    disconnected_servers: list[dict[str, Any]] = Field(default_factory=list)
    __lock: ClassVar[threading.Lock] = PrivateAttr(default=threading.Lock())
    __instance: ClassVar[Any] = PrivateAttr(default=None)
    __initialized: ClassVar[bool] = PrivateAttr(default=False)

    @classmethod
    def get_instance(cls) -> "MCPConfig":
        """Return the shared instance, creating an empty one on first use."""
        if cls.__instance is None:
            cls.__instance = cls(servers_list=[])
        return cls.__instance

    @classmethod
    def wait_for_lock(cls):
        """Block until the class lock is free (e.g. an ``update`` has finished)."""
        with cls.__lock:
            return

    @classmethod
    def update(cls, config_str: str) -> Any:
        """Parse *config_str* and rebuild the shared instance from it.

        Parse errors are printed and result in an empty server list. The same
        singleton object is re-initialized (not replaced) and then marked as
        initialized.

        Returns:
            The shared ``MCPConfig`` instance.
        """
        with cls.__lock:
            servers_data: List[Dict[str, Any]] = []

            if config_str and config_str.strip():
                try:
                    parsed_value = dirty_json.try_parse(config_str)
                    normalized = cls.normalize_config(parsed_value)

                    if isinstance(normalized, list):
                        valid_servers = []
                        for item in normalized:
                            if isinstance(item, dict):
                                valid_servers.append(item)
                            else:
                                PrintStyle(
                                    background_color="yellow",
                                    font_color="black",
                                    padding=True,
                                ).print(
                                    f"Warning: MCP config item (from json.loads) was not a dictionary and was ignored: {item}"
                                )
                        servers_data = valid_servers
                    else:
                        PrintStyle(
                            background_color="red", font_color="white", padding=True
                        ).print(
                            f"Error: Parsed MCP config (from json.loads) top-level structure is not a list. Config string was: '{config_str}'"
                        )
                except Exception as e_json:
                    PrintStyle.error(
                        f"Error parsing MCP config string: {e_json}. Config string was: '{config_str}'"
                    )

            instance = cls.get_instance()
            # Re-run __init__ on the existing singleton rather than replacing it.
            instance.__init__(servers_list=servers_data)

            cls.__initialized = True
            return instance

    @classmethod
    def normalize_config(cls, servers: Any):
        """Flatten the supported config shapes into a list of server dicts.

        Accepted shapes: a list of server dicts; ``{"mcpServers": {name: cfg}}``
        (each cfg gets ``name`` set from its key, mutating it in place);
        ``{"mcpServers": [cfg, ...]}``; or a single server dict. Non-dict
        entries are dropped.
        """
        normalized = []
        if isinstance(servers, list):
            for server in servers:
                if isinstance(server, dict):
                    normalized.append(server)
        elif isinstance(servers, dict):
            if "mcpServers" in servers:
                if isinstance(servers["mcpServers"], dict):
                    for key, value in servers["mcpServers"].items():
                        if isinstance(value, dict):
                            value["name"] = key
                            normalized.append(value)
                elif isinstance(servers["mcpServers"], list):
                    for server in servers["mcpServers"]:
                        if isinstance(server, dict):
                            normalized.append(server)
            else:
                normalized.append(servers)
        return normalized

    def __init__(self, servers_list: List[Dict[str, Any]]):
        """Build server objects from *servers_list* and fetch their tool lists.

        Entries that are not mappings, are disabled, lack a ``name`` or fail to
        construct are recorded in ``disconnected_servers`` instead. Tool
        discovery for all constructed servers runs concurrently and this call
        blocks until it finishes.
        """
        from collections.abc import Mapping, Iterable

        super().__init__()

        self.servers = []
        self.disconnected_servers = []

        if not isinstance(servers_list, Iterable):
            PrintStyle(
                background_color="grey", font_color="red", padding=True
            ).print("MCPConfig::__init__::servers_list must be a list")
            return

        for server_item in servers_list:
            if not isinstance(server_item, Mapping):
                error_msg = "server_item must be a mapping"
                PrintStyle(
                    background_color="grey", font_color="red", padding=True
                ).print(f"MCPConfig::__init__::{error_msg}")
                self.disconnected_servers.append(
                    {
                        "config": (
                            server_item
                            if isinstance(server_item, dict)
                            else {"raw": str(server_item)}
                        ),
                        "error": error_msg,
                        "name": "invalid_server_config",
                    }
                )
                continue

            if server_item.get("disabled", False):
                server_name = server_item.get("name", "unnamed_server")
                if server_name != "unnamed_server":
                    server_name = normalize_name(server_name)
                self.disconnected_servers.append(
                    {
                        "config": server_item,
                        "error": "Disabled in config",
                        "name": server_name,
                    }
                )
                continue

            server_name = server_item.get("name", "__not__found__")
            if server_name == "__not__found__":
                error_msg = "server_name is required"
                PrintStyle(
                    background_color="grey", font_color="red", padding=True
                ).print(f"MCPConfig::__init__::{error_msg}")
                self.disconnected_servers.append(
                    {
                        "config": server_item,
                        "error": error_msg,
                        "name": "unnamed_server",
                    }
                )
                continue

            try:
                # A truthy url/serverUrl selects the remote implementation.
                if server_item.get("url", None) or server_item.get("serverUrl", None):
                    self.servers.append(MCPServerRemote(server_item))
                else:
                    self.servers.append(MCPServerLocal(server_item))
            except Exception as e:
                error_msg = str(e)
                PrintStyle(
                    background_color="grey", font_color="red", padding=True
                ).print(
                    f"MCPConfig::__init__: Failed to create MCPServer '{server_name}': {error_msg}"
                )
                self.disconnected_servers.append(
                    {"config": server_item, "error": error_msg, "name": server_name}
                )

        if self.servers:

            async def _init_server(server):
                # Per-server failures are printed but never abort the others.
                try:
                    await server.initialize()
                except Exception as e:
                    error_msg = str(e)
                    PrintStyle(
                        background_color="grey", font_color="red", padding=True
                    ).print(
                        f"MCPConfig::__init__: Failed to initialize MCPServer '{server.name}': {error_msg}"
                    )

            async def _init_all():
                await asyncio.gather(*[_init_server(s) for s in self.servers])

            def _run_init() -> None:
                asyncio.run(_init_all())

            try:
                asyncio.get_running_loop()
            except RuntimeError:
                # No event loop in this thread: run the initialization directly.
                _run_init()
            else:
                # asyncio.run() refuses to start while a loop is running in this
                # thread, so run it on a short-lived worker thread (with its own
                # loop) and wait for it, preserving the blocking semantics.
                worker = threading.Thread(target=_run_init, name="mcp-config-init")
                worker.start()
                worker.join()

    def get_server_log(self, server_name: str) -> str:
        """Return the captured log of the named server, or "" if unknown."""
        with self.__lock:
            for server in self.servers:
                if server.name == server_name:
                    return server.get_log()
            return ""

    def get_servers_status(self) -> list[dict[str, Any]]:
        """Get status of all servers.

        Constructed servers are always reported with ``connected=True`` (any
        failure shows up in ``error``); disconnected entries follow with
        ``connected=False`` and zero tools.
        """
        result = []
        with self.__lock:
            for server in self.servers:
                name = server.name
                tool_count = len(server.get_tools())
                connected = True
                error = server.get_error()
                has_log = server.get_log() != ""
                result.append(
                    {
                        "name": name,
                        "connected": connected,
                        "error": error,
                        "tool_count": tool_count,
                        "has_log": has_log,
                    }
                )

            for disconnected in self.disconnected_servers:
                result.append(
                    {
                        "name": disconnected["name"],
                        "connected": False,
                        "error": disconnected["error"],
                        "tool_count": 0,
                        "has_log": False,
                    }
                )

        return result

    def get_server_detail(self, server_name: str) -> dict[str, Any]:
        """Return name, description and tools of the named server, or {} if unknown."""
        with self.__lock:
            for server in self.servers:
                if server.name == server_name:
                    try:
                        tools = server.get_tools()
                    except Exception:
                        tools = []
                    return {
                        "name": server.name,
                        "description": server.description,
                        "tools": tools,
                    }
            return {}

    def is_initialized(self) -> bool:
        """Check if the client is initialized"""
        with self.__lock:
            return self.__initialized

    def get_tools(self) -> List[dict[str, dict[str, Any]]]:
        """Get all tools from all servers as ``[{"server.tool": tool_copy}, ...]``."""
        with self.__lock:
            tools = []
            for server in self.servers:
                for tool in server.get_tools():
                    tool_copy = tool.copy()
                    tool_copy["server"] = server.name
                    tools.append({f"{server.name}.{tool['name']}": tool_copy})
            return tools

    def get_tools_prompt(self, server_name: str = "") -> str:
        """Get a prompt for all tools (or only those of *server_name*).

        Raises:
            ValueError: if *server_name* is given but no such server exists.
        """
        # Acquire and immediately release the lock; presumably meant to wait for
        # a concurrent holder (e.g. an update) to finish — confirm intent.
        with self.__lock:
            pass

        prompt = '## "Remote (MCP Server) Agent Tools" available:\n\n'
        server_names = []
        for server in self.servers:
            if not server_name or server.name == server_name:
                server_names.append(server.name)

        if server_name and server_name not in server_names:
            raise ValueError(f"Server {server_name} not found")

        for server in self.servers:
            if server.name not in server_names:
                continue
            current_name = server.name
            prompt += f"### {current_name}\n"
            prompt += f"{server.description}\n"
            tools = server.get_tools()

            for tool in tools:
                prompt += (
                    f"\n### {current_name}.{tool['name']}:\n"
                    f"{tool['description']}\n\n"
                )

                input_schema = (
                    json.dumps(tool["input_schema"]) if tool["input_schema"] else ""
                )

                prompt += f"#### Input schema for tool_args:\n{input_schema}\n"

                prompt += "\n"

                prompt += (
                    f"#### Usage:\n"
                    f"{{\n"
                    f' "thoughts": ["..."],\n'
                    f" \"tool_name\": \"{current_name}.{tool['name']}\",\n"
                    f' "tool_args": !follow schema above\n'
                    f"}}\n"
                )

        return prompt

    def has_tool(self, tool_name: str) -> bool:
        """Check if a "server.tool" name refers to an available tool.

        Only the first dot separates the server from the tool (server names are
        normalized and contain no dots), so tool names with dots are supported.
        """
        server_name_part, separator, tool_name_part = tool_name.partition(".")
        if not separator:
            return False
        with self.__lock:
            for server in self.servers:
                if server.name == server_name_part:
                    return server.has_tool(tool_name_part)
            return False

    def get_tool(self, agent: Any, tool_name: str) -> MCPTool | None:
        """Return an MCPTool for *tool_name*, or None if it is not available."""
        if not self.has_tool(tool_name):
            return None
        return MCPTool(agent=agent, name=tool_name, method=None, args={}, message="", loop_data=None)

    async def call_tool(
        self, tool_name: str, input_data: Dict[str, Any]
    ) -> CallToolResult:
        """Call a "server.tool" tool with the given input data.

        The lock is only held while locating the server; it is released before
        awaiting so a long tool call cannot block (or deadlock) other users of
        this registry on the same event-loop thread.

        Raises:
            ValueError: if the name has no server part or no server offers the tool.
        """
        server_name_part, separator, tool_name_part = tool_name.partition(".")
        if not separator:
            raise ValueError(f"Tool {tool_name} not found")
        with self.__lock:
            target = next(
                (
                    server
                    for server in self.servers
                    if server.name == server_name_part
                    and server.has_tool(tool_name_part)
                ),
                None,
            )
        if target is None:
            raise ValueError(f"Tool {tool_name} not found")
        return await target.call_tool(tool_name_part, input_data)
|
|
|
|
# Generic result type of the coroutine passed to MCPClientBase._execute_with_session.
T = TypeVar("T")
|
|
|
|
class MCPClientBase(ABC):
    """Shared MCP client logic: per-operation sessions, cached tools and error state.

    Every operation opens a fresh transport + ``ClientSession`` and tears it
    down afterwards. Subclasses provide the transport via
    ``_create_stdio_transport``.
    """

    # Class-level lock (shared by all client instances) guarding tools/error.
    __lock: ClassVar[threading.Lock] = threading.Lock()

    def __init__(self, server: Union[MCPServerLocal, MCPServerRemote]):
        self.server = server
        self.tools: List[dict[str, Any]] = []
        self.error: str = ""
        self.log: List[str] = []
        self.log_file: Optional[TextIO] = None

    @abstractmethod
    async def _create_stdio_transport(
        self, current_exit_stack: AsyncExitStack
    ) -> tuple[
        MemoryObjectReceiveStream[SessionMessage | Exception],
        MemoryObjectSendStream[SessionMessage],
    ]:
        """Create stdio/write streams using the provided exit_stack."""
        ...

    async def _execute_with_session(
        self,
        coro_func: Callable[[ClientSession], Awaitable[T]],
        read_timeout_seconds=60,
    ) -> T:
        """
        Manages the lifecycle of an MCP session for a single operation.
        Creates a temporary session, executes coro_func with it, and ensures cleanup.

        If the failure is an exception group (has ``.exceptions``), its first
        member is re-raised instead of the group.
        """
        operation_name = coro_func.__name__

        original_exception = None
        try:
            async with AsyncExitStack() as temp_stack:
                try:
                    stdio, write = await self._create_stdio_transport(temp_stack)
                    session = await temp_stack.enter_async_context(
                        ClientSession(
                            stdio,
                            write,
                            read_timeout_seconds=timedelta(
                                seconds=read_timeout_seconds
                            ),
                        )
                    )
                    await session.initialize()

                    result = await coro_func(session)

                    return result
                except Exception as e:
                    # Remember the real cause, then leave the async-with block so
                    # the exit stack unwinds before the error is reported.
                    excs = getattr(e, "exceptions", None)
                    if excs:
                        original_exception = excs[0]
                    else:
                        original_exception = e
                    raise RuntimeError("Dummy exception to break out of async block")
        except Exception as e:
            if original_exception is not None:
                e = original_exception
            PrintStyle(
                background_color="#AA4455", font_color="white", padding=False
            ).print(
                f"MCPClientBase ({self.server.name} - {operation_name}): Error during operation: {type(e).__name__}: {e}"
            )
            raise e

        # Only reachable if an exit callback suppressed the exception above.
        raise RuntimeError(
            f"MCPClientBase ({self.server.name} - {operation_name}): _execute_with_session exited 'async with' block unexpectedly."
        )

    async def update_tools(self) -> "MCPClientBase":
        """Refresh the cached tool list from the server.

        On success the cache is replaced and any previous error is cleared; on
        failure the cache is emptied and ``error`` holds a truncated message.
        Never raises for operation failures.
        """

        async def list_tools_op(current_session: ClientSession):
            response: ListToolsResult = await current_session.list_tools()
            with self.__lock:
                self.tools = [
                    {
                        "name": tool.name,
                        "description": tool.description,
                        "input_schema": tool.inputSchema,
                    }
                    for tool in response.tools
                ]
                # A successful refresh supersedes any earlier failure message.
                self.error = ""
            PrintStyle(font_color="green").print(
                f"MCPClientBase ({self.server.name}): Tools updated. Found {len(self.tools)} tools."
            )

        try:
            cfg = settings.get_settings()
            await self._execute_with_session(
                list_tools_op,
                read_timeout_seconds=self.server.init_timeout
                or cfg["mcp_client_init_timeout"],
            )
        except Exception as e:
            error_text = errors.format_error(e, 0, 0)
            PrintStyle(
                background_color="#CC34C3", font_color="white", bold=True, padding=True
            ).print(
                f"MCPClientBase ({self.server.name}): 'update_tools' operation failed: {error_text}"
            )
            with self.__lock:
                self.tools = []
                self.error = f"Failed to initialize. {error_text[:200]}{'...' if len(error_text) > 200 else ''}"
        return self

    def has_tool(self, tool_name: str) -> bool:
        """Check if a tool is available (uses cached tools)"""
        with self.__lock:
            for tool in self.tools:
                if tool["name"] == tool_name:
                    return True
            return False

    def get_tools(self) -> List[dict[str, Any]]:
        """Get all tools from the server (uses cached tools)"""
        with self.__lock:
            return self.tools

    async def call_tool(
        self, tool_name: str, input_data: Dict[str, Any]
    ) -> CallToolResult:
        """Call *tool_name* with *input_data*, refreshing the tool cache on a miss.

        The per-request timeout is the server's ``tool_timeout`` when set,
        otherwise the global ``mcp_client_tool_timeout`` setting.

        Raises:
            ValueError: if the tool is still unknown after a refresh.
            ConnectionError: if the call itself fails (original error chained).
        """
        if not self.has_tool(tool_name):
            PrintStyle(font_color="orange").print(
                f"MCPClientBase ({self.server.name}): Tool '{tool_name}' not in cache for 'call_tool', refreshing tools..."
            )
            await self.update_tools()
            if not self.has_tool(tool_name):
                PrintStyle(font_color="red").print(
                    f"MCPClientBase ({self.server.name}): Tool '{tool_name}' not found after refresh. Raising ValueError."
                )
                raise ValueError(
                    f"Tool {tool_name} not found after refreshing tool list for server {self.server.name}."
                )
            PrintStyle(font_color="green").print(
                f"MCPClientBase ({self.server.name}): Tool '{tool_name}' found after updating tools."
            )

        async def call_tool_op(current_session: ClientSession):
            cfg = settings.get_settings()
            # Honour the per-server override, as the remote transport already does.
            timeout_seconds = (
                self.server.tool_timeout or cfg["mcp_client_tool_timeout"]
            )
            response: CallToolResult = await current_session.call_tool(
                tool_name,
                input_data,
                read_timeout_seconds=timedelta(seconds=timeout_seconds),
            )
            return response

        try:
            return await self._execute_with_session(call_tool_op)
        except Exception as e:
            PrintStyle(
                background_color="#AA4455", font_color="white", padding=True
            ).print(
                f"MCPClientBase ({self.server.name}): 'call_tool' operation for '{tool_name}' failed: {type(e).__name__}: {e}"
            )
            raise ConnectionError(
                f"MCPClientBase::Failed to call tool '{tool_name}' on server '{self.server.name}'. Original error: {type(e).__name__}: {e}"
            ) from e

    def get_log(self):
        """Return the full captured log text, or "" if there is none or it is unreadable."""
        log_file = getattr(self, "log_file", None)
        if log_file is None:
            return ""
        try:
            # seek() is inside the try too: it fails the same way read() does
            # if the file is no longer usable.
            log_file.seek(0)
            return log_file.read()
        except Exception:
            return ""
|
|
|
|
class MCPClientLocal(MCPClientBase):
    """MCP client that launches the server as a local subprocess over stdio.

    The subprocess's stderr is captured in a temporary file (``log_file``),
    created on first connect and reused afterwards.
    """

    def __del__(self):
        # Best-effort close of the stderr capture file; errors are ignored.
        log_handle = getattr(self, "log_file", None)
        if log_handle is None:
            return
        try:
            log_handle.close()
        except Exception:
            pass
        self.log_file = None

    async def _create_stdio_transport(
        self, current_exit_stack: AsyncExitStack
    ) -> tuple[
        MemoryObjectReceiveStream[SessionMessage | Exception],
        MemoryObjectSendStream[SessionMessage],
    ]:
        """Spawn the configured command and return its (read, write) streams.

        The stdio client context is registered on *current_exit_stack*, so the
        subprocess lives exactly as long as that stack.

        Raises:
            ValueError: if no command is configured or ``shutil.which`` cannot
                resolve it.
        """
        local_server: MCPServerLocal = cast(MCPServerLocal, self.server)
        command = local_server.command

        if not command:
            raise ValueError("Command not specified")
        if not which(command):
            raise ValueError(f"Command '{command}' not found")

        launch_params = StdioServerParameters(
            command=command,
            args=local_server.args,
            env=local_server.env,
            encoding=local_server.encoding,
            encoding_error_handler=local_server.encoding_error_handler,
        )

        import tempfile

        if getattr(self, "log_file", None) is None:
            self.log_file = tempfile.TemporaryFile(mode="w+", encoding="utf-8")

        return await current_exit_stack.enter_async_context(
            stdio_client(launch_params, errlog=self.log_file)
        )
|
|
class CustomHTTPClientFactory(ABC):
    """Callable that builds ``httpx.AsyncClient`` instances for the MCP HTTP transports.

    Clients follow redirects, use a 30-second timeout when none is supplied,
    and verify SSL certificates according to the ``verify`` flag given here.
    """

    def __init__(self, verify: bool = True):
        self.verify = verify

    def __call__(
        self,
        headers: dict[str, str] | None = None,
        timeout: httpx.Timeout | None = None,
        auth: httpx.Auth | None = None,
    ) -> httpx.AsyncClient:
        client_options: dict[str, Any] = {
            "follow_redirects": True,
            "timeout": httpx.Timeout(30.0) if timeout is None else timeout,
        }
        # Optional settings are only passed through when explicitly provided.
        if headers is not None:
            client_options["headers"] = headers
        if auth is not None:
            client_options["auth"] = auth
        return httpx.AsyncClient(**client_options, verify=self.verify)
|
|
class MCPClientRemote(MCPClientBase):
    """MCP client that connects to a remote server over SSE or streamable HTTP.

    The transport is chosen from the server's ``type`` field; for streamable
    HTTP the transport's session-id getter is kept in ``session_id_callback``.
    """

    def __init__(self, server: Union[MCPServerLocal, MCPServerRemote]):
        super().__init__(server)
        self.session_id: Optional[str] = None
        self.session_id_callback: Optional[Callable[[], Optional[str]]] = None

    async def _create_stdio_transport(
        self, current_exit_stack: AsyncExitStack
    ) -> tuple[
        MemoryObjectReceiveStream[SessionMessage | Exception],
        MemoryObjectSendStream[SessionMessage],
    ]:
        """Connect to an MCP server, init client and save stdio/write streams"""
        remote: MCPServerRemote = cast(MCPServerRemote, self.server)
        cfg = settings.get_settings()

        # Per-server value first, then the global setting, then a hard fallback.
        init_timeout = remote.init_timeout or cfg["mcp_client_init_timeout"] or 5
        tool_timeout = remote.tool_timeout or cfg["mcp_client_tool_timeout"] or 10

        factory = CustomHTTPClientFactory(verify=remote.verify)

        if not _is_streaming_http_type(remote.type):
            # SSE transport: timeouts are given as plain seconds.
            return await current_exit_stack.enter_async_context(
                sse_client(
                    url=remote.url,
                    headers=remote.headers,
                    timeout=init_timeout,
                    sse_read_timeout=tool_timeout,
                    httpx_client_factory=factory,
                )
            )

        # Streamable HTTP transport: timeouts are timedeltas and a third value
        # (the session-id getter) is returned alongside the streams.
        read_stream, write_stream, session_id_getter = (
            await current_exit_stack.enter_async_context(
                streamablehttp_client(
                    url=remote.url,
                    headers=remote.headers,
                    timeout=timedelta(seconds=init_timeout),
                    sse_read_timeout=timedelta(seconds=tool_timeout),
                    httpx_client_factory=factory,
                )
            )
        )
        self.session_id_callback = session_id_getter
        return read_stream, write_stream

    def get_session_id(self) -> Optional[str]:
        """Get the current session ID if available (for streaming HTTP clients)."""
        callback = self.session_id_callback
        return callback() if callback is not None else None
|
|