test_ui / src /open_llm_vtuber /mcpp /mcp_client.py
britto224's picture
Upload 130 files
5669b22 verified
"""MCP Client for Open-LLM-Vtuber."""
from contextlib import AsyncExitStack
from datetime import timedelta
from typing import Any, Callable, Dict, List, Optional

from loguru import logger
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.types import Tool

from .server_registry import ServerRegistry

# Fallback passed as the session's read_timeout_seconds when a server's
# configuration has no (truthy) `timeout` of its own.
DEFAULT_TIMEOUT: timedelta = timedelta(seconds=30)
class MCPClient:
    """MCP Client for Open-LLM-Vtuber.

    Manages persistent stdio connections to multiple MCP servers. A server is
    started lazily the first time one of its tools is listed or called, and
    the resulting session is reused until :meth:`aclose` is called. The client
    can also be used as an async context manager, in which case every
    connection is closed on exit.
    """

    # Optional attributes copied from a response content item into its
    # dictionary form when they are present and not None.
    _CONTENT_ATTRS = ("text", "data", "mimeType", "url", "altText")

    def __init__(
        self,
        server_registery: ServerRegistry,
        send_text: Optional[Callable] = None,
        client_uid: Optional[str] = None,
    ) -> None:
        """Initialize the MCP Client.

        Args:
            server_registery: Registry used to look up server launch
                configurations by name.
            send_text: Optional callable kept on the instance as
                ``_send_text``; this class does not invoke it.
            client_uid: Optional identifier kept on the instance as
                ``_client_uid``; this class does not read it.

        Raises:
            TypeError: If ``server_registery`` is not a ``ServerRegistry``.
        """
        if not isinstance(server_registery, ServerRegistry):
            raise TypeError(
                "MCPC: Invalid server manager. Must be an instance of ServerRegistry."
            )
        self.server_registery = server_registery

        # Owns every open transport/session; closed as a unit in aclose().
        self.exit_stack: AsyncExitStack = AsyncExitStack()
        self.active_sessions: Dict[str, ClientSession] = {}
        self._list_tools_cache: Dict[str, List[Tool]] = {}  # Cache for list_tools
        self._send_text: Optional[Callable] = send_text
        self._client_uid: Optional[str] = client_uid
        logger.info("MCPC: Initialized MCPClient instance.")

    async def _ensure_server_running_and_get_session(
        self, server_name: str
    ) -> ClientSession:
        """Return the session for ``server_name``, connecting on first use.

        Args:
            server_name: Name of the server as known to the registry.

        Returns:
            The initialized session for that server.

        Raises:
            ValueError: If the registry has no server with that name.
            RuntimeError: If starting or initializing the server fails; the
                original exception is attached as ``__cause__``.
        """
        existing = self.active_sessions.get(server_name)
        if existing is not None:
            return existing

        logger.info(f"MCPC: Starting and connecting to server '{server_name}'...")
        server = self.server_registery.get_server(server_name)
        if not server:
            raise ValueError(
                f"MCPC: Server '{server_name}' not found in available servers."
            )

        timeout = server.timeout or DEFAULT_TIMEOUT
        server_params = StdioServerParameters(
            command=server.command, args=server.args, env=server.env, cwd=server.cwd
        )

        try:
            # Build the connection on a temporary stack so that a failure
            # half-way (e.g. during initialize()) tears down whatever was
            # already opened instead of leaving it registered on the
            # long-lived stack until aclose().
            async with AsyncExitStack() as pending:
                read, write = await pending.enter_async_context(
                    stdio_client(server_params)
                )
                session = await pending.enter_async_context(
                    ClientSession(read, write, read_timeout_seconds=timeout)
                )
                await session.initialize()
                # Success: hand ownership of the open contexts over to the
                # client's stack; `pending` is left empty and closes nothing.
                self.exit_stack.push_async_exit(pending.pop_all())
        except Exception as e:
            logger.exception(f"MCPC: Failed to connect to server '{server_name}': {e}")
            raise RuntimeError(
                f"MCPC: Failed to connect to server '{server_name}'."
            ) from e
        else:
            self.active_sessions[server_name] = session
            logger.info(f"MCPC: Successfully connected to server '{server_name}'.")
            return session

    async def list_tools(self, server_name: str) -> List[Tool]:
        """List all available tools on the specified server.

        The result is cached per server name until :meth:`aclose` is called,
        so only the first call for a server reaches the server.

        Args:
            server_name: Name of the server to query.

        Returns:
            The tools reported by the server.
        """
        cached = self._list_tools_cache.get(server_name)
        if cached is not None:
            logger.debug(f"MCPC: Cache hit for list_tools on server '{server_name}'.")
            return cached

        logger.debug(
            f"MCPC: Cache miss for list_tools on server '{server_name}'. Fetching..."
        )
        session = await self._ensure_server_running_and_get_session(server_name)
        response = await session.list_tools()

        tools = response.tools
        self._list_tools_cache[server_name] = tools
        logger.debug(f"MCPC: Cached list_tools result for server '{server_name}'.")
        return tools

    @staticmethod
    def _content_item_to_dict(item: Any) -> Dict[str, Any]:
        """Convert one response content item into a plain dictionary."""
        item_dict: Dict[str, Any] = {"type": getattr(item, "type", "text")}
        for attr in MCPClient._CONTENT_ATTRS:
            value = getattr(item, attr, None)
            if value is not None:
                item_dict[attr] = value
        return item_dict

    async def call_tool(
        self, server_name: str, tool_name: str, tool_args: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Call a tool on the specified server.

        Args:
            server_name: Name of the server hosting the tool.
            tool_name: Name of the tool to invoke.
            tool_args: Arguments passed to the tool.

        Returns:
            Dict with two keys:

            * ``"metadata"``: the response's ``metadata`` attribute, or ``{}``
              when the response has none.
            * ``"content_items"``: a non-empty list of dicts, each with a
              ``"type"`` key plus any available content attributes. When the
              server reports an error this is a single item of type
              ``"error"``; when the tool returns no content it is a single
              empty text item.
        """
        session = await self._ensure_server_running_and_get_session(server_name)
        logger.info(f"MCPC: Calling tool '{tool_name}' on server '{server_name}'...")
        response = await session.call_tool(tool_name, tool_args)
        metadata = getattr(response, "metadata", {})

        if response.isError:
            if response.content and hasattr(response.content[0], "text"):
                error_text = response.content[0].text
            else:
                error_text = "Unknown server error"
            logger.error(f"MCPC: Error calling tool '{tool_name}': {error_text}")
            # Return error information within the standard structure
            return {
                "metadata": metadata,
                "content_items": [{"type": "error", "text": error_text}],
            }

        if response.content:
            content_items = [
                self._content_item_to_dict(item) for item in response.content
            ]
        else:
            logger.warning(
                f"MCPC: Tool '{tool_name}' returned no content. Returning empty content_items."
            )
            # Ensure content_items is not empty
            content_items = [{"type": "text", "text": ""}]

        return {"metadata": metadata, "content_items": content_items}

    async def aclose(self) -> None:
        """Close all active server connections and reset the client state.

        The session table, the tools cache and the exit stack are reset even
        if closing a connection raises, so stale sessions are never reused;
        the exception still propagates to the caller.
        """
        logger.info(
            f"MCPC: Closing client instance and {len(self.active_sessions)} active connections..."
        )
        try:
            await self.exit_stack.aclose()
        finally:
            self.active_sessions.clear()
            self._list_tools_cache.clear()  # Clear cache on close
            self.exit_stack = AsyncExitStack()
        logger.info("MCPC: Client instance closed.")

    async def __aenter__(self) -> "MCPClient":
        """Enter the async context manager."""
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        """Exit the async context manager, closing all connections.

        An exception raised inside the ``async with`` body is logged and not
        suppressed.
        """
        await self.aclose()
        if exc_type:
            logger.error(f"MCPC: Exception in async context: {exc_value}")
# if __name__ == "__main__":
#     # Test the MCPClient.
#     import asyncio
#
#     async def main():
#         server_registery = ServerRegistry()
#         async with MCPClient(server_registery) as client:
#             # Assuming 'example' server and 'example_tool' exist
#             # The old call used: await client.call_tool("example_tool", {"arg1": "value1"})
#             # The new call needs server name:
#             try:
#                 result = await client.call_tool("example", "example_tool", {"arg1": "value1"})
#                 print(f"Tool result: {result}")
#                 # Test error handling by calling a non-existent tool
#                 await client.call_tool("example", "non_existent_tool", {})
#             except ValueError as e:
#                 print(f"Caught expected error: {e}")
#             except Exception as e:
#                 print(f"Caught unexpected error: {e}")
#
#     asyncio.run(main())