| """MCP Client for Open-LLM-Vtuber."""
|
|
|
| from contextlib import AsyncExitStack
|
| from typing import Dict, Any, List, Callable
|
| from loguru import logger
|
| from datetime import timedelta
|
|
|
| from mcp import ClientSession, StdioServerParameters
|
| from mcp.types import Tool
|
| from mcp.client.stdio import stdio_client
|
|
|
| from .server_registry import ServerRegistry
|
|
|
| DEFAULT_TIMEOUT = timedelta(seconds=30)
|
|
|
|
|
| class MCPClient:
|
| """MCP Client for Open-LLM-Vtuber.
|
| Manages persistent connections to multiple MCP servers.
|
| """
|
|
|
| def __init__(
|
| self,
|
| server_registery: ServerRegistry,
|
| send_text: Callable = None,
|
| client_uid: str = None,
|
| ) -> None:
|
| """Initialize the MCP Client."""
|
| self.exit_stack: AsyncExitStack = AsyncExitStack()
|
| self.active_sessions: Dict[str, ClientSession] = {}
|
| self._list_tools_cache: Dict[str, List[Tool]] = {}
|
| self._send_text: Callable = send_text
|
| self._client_uid: str = client_uid
|
|
|
| if isinstance(server_registery, ServerRegistry):
|
| self.server_registery = server_registery
|
| else:
|
| raise TypeError(
|
| "MCPC: Invalid server manager. Must be an instance of ServerRegistry."
|
| )
|
| logger.info("MCPC: Initialized MCPClient instance.")
|
|
|
| async def _ensure_server_running_and_get_session(
|
| self, server_name: str
|
| ) -> ClientSession:
|
| """Gets the existing session or creates a new one."""
|
| if server_name in self.active_sessions:
|
| return self.active_sessions[server_name]
|
|
|
| logger.info(f"MCPC: Starting and connecting to server '{server_name}'...")
|
| server = self.server_registery.get_server(server_name)
|
| if not server:
|
| raise ValueError(
|
| f"MCPC: Server '{server_name}' not found in available servers."
|
| )
|
|
|
| timeout = server.timeout if server.timeout else DEFAULT_TIMEOUT
|
|
|
| server_params = StdioServerParameters(
|
| command=server.command, args=server.args, env=server.env, cwd=server.cwd
|
| )
|
|
|
| try:
|
| stdio_transport = await self.exit_stack.enter_async_context(
|
| stdio_client(server_params)
|
| )
|
| read, write = stdio_transport
|
|
|
| session = await self.exit_stack.enter_async_context(
|
| ClientSession(read, write, read_timeout_seconds=timeout)
|
| )
|
| await session.initialize()
|
|
|
| self.active_sessions[server_name] = session
|
| logger.info(f"MCPC: Successfully connected to server '{server_name}'.")
|
| return session
|
| except Exception as e:
|
| logger.exception(f"MCPC: Failed to connect to server '{server_name}': {e}")
|
| raise RuntimeError(
|
| f"MCPC: Failed to connect to server '{server_name}'."
|
| ) from e
|
|
|
| async def list_tools(self, server_name: str) -> List[Tool]:
|
| """List all available tools on the specified server."""
|
|
|
| if server_name in self._list_tools_cache:
|
| logger.debug(f"MCPC: Cache hit for list_tools on server '{server_name}'.")
|
| return self._list_tools_cache[server_name]
|
|
|
| logger.debug(
|
| f"MCPC: Cache miss for list_tools on server '{server_name}'. Fetching..."
|
| )
|
| session = await self._ensure_server_running_and_get_session(server_name)
|
| response = await session.list_tools()
|
|
|
|
|
| self._list_tools_cache[server_name] = response.tools
|
| logger.debug(f"MCPC: Cached list_tools result for server '{server_name}'.")
|
| return response.tools
|
|
|
| async def call_tool(
|
| self, server_name: str, tool_name: str, tool_args: Dict[str, Any]
|
| ) -> Dict[str, Any]:
|
| """Call a tool on the specified server.
|
|
|
| Returns:
|
| Dict containing the metadata and content_items from the tool response.
|
| """
|
| session = await self._ensure_server_running_and_get_session(server_name)
|
| logger.info(f"MCPC: Calling tool '{tool_name}' on server '{server_name}'...")
|
| response = await session.call_tool(tool_name, tool_args)
|
|
|
| if response.isError:
|
| error_text = (
|
| response.content[0].text
|
| if response.content and hasattr(response.content[0], "text")
|
| else "Unknown server error"
|
| )
|
| logger.error(f"MCPC: Error calling tool '{tool_name}': {error_text}")
|
|
|
| return {
|
| "metadata": getattr(response, "metadata", {}),
|
| "content_items": [{"type": "error", "text": error_text}],
|
| }
|
|
|
| content_items = []
|
| if response.content:
|
| for item in response.content:
|
| item_dict = {"type": getattr(item, "type", "text")}
|
|
|
| for attr in [
|
| "text",
|
| "data",
|
| "mimeType",
|
| "url",
|
| "altText",
|
| ]:
|
| if (
|
| hasattr(item, attr) and getattr(item, attr) is not None
|
| ):
|
| item_dict[attr] = getattr(item, attr)
|
| content_items.append(item_dict)
|
| else:
|
| logger.warning(
|
| f"MCPC: Tool '{tool_name}' returned no content. Returning empty content_items."
|
| )
|
| content_items.append(
|
| {"type": "text", "text": ""}
|
| )
|
|
|
| result = {
|
| "metadata": getattr(response, "metadata", {}),
|
| "content_items": content_items,
|
| }
|
| return result
|
|
|
| async def aclose(self) -> None:
|
| """Closes all active server connections."""
|
| logger.info(
|
| f"MCPC: Closing client instance and {len(self.active_sessions)} active connections..."
|
| )
|
| await self.exit_stack.aclose()
|
| self.active_sessions.clear()
|
| self._list_tools_cache.clear()
|
| self.exit_stack = AsyncExitStack()
|
| logger.info("MCPC: Client instance closed.")
|
|
|
| async def __aenter__(self) -> "MCPClient":
|
| """Enter the async context manager."""
|
| return self
|
|
|
| async def __aexit__(self, exc_type, exc_value, traceback) -> None:
|
| """Exit the async context manager."""
|
| await self.aclose()
|
| if exc_type:
|
| logger.error(f"MCPC: Exception in async context: {exc_value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|