Spaces:
Sleeping
Sleeping
| '''Classes for handling MCP server connection and operations.''' | |
| import asyncio | |
| import logging | |
| from typing import Any, Dict, List, Optional | |
| from urllib.parse import urlparse | |
| from dataclasses import dataclass | |
| from mcp import ClientSession | |
| from mcp.client.sse import sse_client | |
@dataclass
class ToolParameter:
    '''Represents a parameter for a tool.

    Declared as a dataclass so it can be built with keyword arguments
    (as ``MCPClientWrapper.list_tools`` does) and compared by value.

    Attributes:
        name: Parameter name
        parameter_type: Parameter type (e.g., 'string', 'number')
        description: Parameter description
        required: Whether the parameter is required
        default: Default value for the parameter
    '''
    name: str
    parameter_type: str
    description: str
    required: bool = False
    default: Any = None
@dataclass
class ToolDef:
    '''Represents a tool definition.

    Attributes:
        name: Tool name
        description: Tool description
        parameters: List of ToolParameter objects
        metadata: Optional dictionary of additional metadata
        identifier: Tool identifier; an empty value falls back to ``name``
    '''
    name: str
    description: str
    # Forward reference so this block does not depend on definition order.
    parameters: List['ToolParameter']
    metadata: Optional[Dict[str, Any]] = None
    identifier: str = ''

    def __post_init__(self) -> None:
        # Honour the documented contract: the identifier defaults to the name.
        if not self.identifier:
            self.identifier = self.name
@dataclass
class ToolInvocationResult:
    '''Represents the result of a tool invocation.

    Declared as a dataclass so it can be built with keyword arguments
    (as ``MCPClientWrapper.invoke_tool`` does) and compared by value.

    Attributes:
        content: Result content as a string
        error_code: Error code (0 for success, 1 for error)
    '''
    content: str
    error_code: int
class MCPConnectionError(Exception):
    '''Signals that an MCP connection or operation could not be completed.'''
class MCPTimeoutError(Exception):
    '''Signals that an MCP operation did not finish within its time limit.'''
class MCPClientWrapper:
    '''Main client wrapper class for interacting with Model Context Protocol (MCP) endpoints.

    Each public operation opens a fresh SSE connection and MCP session,
    performs a single request, and closes the connection again. Every
    attempt is bounded by ``timeout`` seconds and failed attempts are
    retried with exponential backoff.
    '''

    def __init__(self, endpoint: str, timeout: float = 360.0, max_retries: int = 3):
        '''Initialize MCP client with endpoint URL.

        Args:
            endpoint: The MCP SSE endpoint URL (expected to be http or https;
                the scheme is not validated here).
            timeout: Per-attempt timeout in seconds; it covers connecting,
                session initialization and the request itself.
            max_retries: Maximum number of attempts per operation. Values
                below 1 are treated as a single attempt.
        '''
        self.endpoint = endpoint
        self.timeout = timeout
        self.max_retries = max_retries
        # Cache of the most recent list_tools() result; initialised here so
        # the attribute exists even before the first successful listing.
        self.tools: List['ToolDef'] = []

    async def _execute_with_retry(self, operation_name: str, operation_func):
        '''Execute an operation with retry logic and proper error handling.

        Args:
            operation_name: Name of the operation for logging
            operation_func: Zero-argument async callable; it is called again
                for every attempt.

        Returns:
            Result of the operation

        Raises:
            MCPTimeoutError: If the final attempt timed out
            MCPConnectionError: If the operation failed for any other reason,
                either after all attempts or immediately on ValueError/TypeError
        '''
        # Dot-separated name makes this a child of the module logger, so
        # configuring the module logger also configures this one.
        logger = logging.getLogger(f'{__name__}._execute_with_retry')
        # Always make at least one attempt; with max_retries <= 0 the loop
        # would otherwise never run and report a meaningless "None" failure.
        max_attempts = max(1, self.max_retries)
        last_exception: Optional[BaseException] = None
        attempts_made = 0
        for attempt in range(max_attempts):
            attempts_made = attempt + 1
            try:
                logger.debug(
                    'Attempting %s (attempt %s/%s)',
                    operation_name,
                    attempts_made,
                    max_attempts
                )
                result = await asyncio.wait_for(operation_func(), timeout=self.timeout)
                logger.debug('%s completed successfully', operation_name)
                return result
            except asyncio.TimeoutError as e:
                last_exception = MCPTimeoutError(
                    f'{operation_name} timed out after {self.timeout} seconds'
                )
                logger.warning('%s timed out on attempt %s: %s', operation_name, attempts_made, e)
            except Exception as e:  # pylint: disable=broad-exception-caught
                last_exception = e
                logger.warning('%s failed on attempt %s: %s', operation_name, attempts_made, str(e))
                # Argument/type errors will not be fixed by retrying.
                if isinstance(e, (ValueError, TypeError)):
                    break
            # Exponential backoff (1s, 2s, 4s, ...) between attempts only.
            if attempt < max_attempts - 1:
                wait_time = 2 ** attempt
                logger.debug('Waiting %s seconds before retry', wait_time)
                await asyncio.sleep(wait_time)

        if isinstance(last_exception, MCPTimeoutError):
            raise last_exception
        # Report the attempts actually made (a non-retryable error stops
        # early) and keep the underlying exception as the cause.
        raise MCPConnectionError(
            f'{operation_name} failed after {attempts_made} attempts: {str(last_exception)}'
        ) from last_exception

    async def _safe_sse_operation(self, operation_func):
        '''Open an SSE connection and MCP session, run one operation, then close both.

        Cleanup is handled by the ``async with`` blocks of ``sse_client`` and
        ``ClientSession``. Tasks outside this connection are never touched.
        The previous version cancelled every unfinished task in the event loop
        on failure, including the calling task itself.

        Args:
            operation_func: Async callable taking the initialized session

        Returns:
            Result of the operation

        Raises:
            Any exception from connecting, initializing or the operation,
            re-raised unchanged after logging.
        '''
        logger = logging.getLogger(f'{__name__}._safe_sse_operation')
        try:
            async with sse_client(self.endpoint) as streams:
                async with ClientSession(*streams) as session:
                    await session.initialize()
                    return await operation_func(session)
        except Exception as e:
            logger.error('SSE operation failed: %s', str(e))
            raise

    @staticmethod
    def _build_parameters(input_schema: Optional[Dict[str, Any]]) -> 'List[ToolParameter]':
        '''Convert a tool's JSON-schema style ``inputSchema`` into ToolParameter objects.

        Missing or ``None`` ``properties``/``required`` entries are treated as empty.
        '''
        schema = input_schema or {}
        required_params = set(schema.get('required') or [])
        properties = schema.get('properties') or {}
        return [
            ToolParameter(
                name=param_name,
                parameter_type=param_schema.get('type', 'string'),
                description=param_schema.get('description', ''),
                required=param_name in required_params,
                default=param_schema.get('default'),
            )
            for param_name, param_schema in properties.items()
        ]

    async def list_tools(self) -> 'List[ToolDef]':
        '''List available tools from the MCP endpoint.

        On success the result is also stored in ``self.tools``.

        Returns:
            List of ToolDef objects describing available tools

        Raises:
            MCPConnectionError: If connection fails
            MCPTimeoutError: If operation times out
        '''
        async def _operation(session):
            tools_result = await session.list_tools()
            tools = [
                ToolDef(
                    name=tool.name,
                    # A tool's description may be None; keep the field a str.
                    description=tool.description or '',
                    parameters=self._build_parameters(tool.inputSchema),
                    metadata={'endpoint': self.endpoint},
                    identifier=tool.name  # Using name as identifier
                )
                for tool in tools_result.tools
            ]
            self.tools = tools
            return tools

        async def _list_tools_operation():
            return await self._safe_sse_operation(_operation)

        return await self._execute_with_retry('list_tools', _list_tools_operation)

    async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> 'ToolInvocationResult':
        '''Invoke a specific tool with parameters.

        NOTE(review): a retry after a timeout or transport error sends the call
        again. If the tool is not idempotent it may run more than once.

        Args:
            tool_name: Name of the tool to invoke
            kwargs: Dictionary of parameters to pass to the tool

        Returns:
            ToolInvocationResult whose content is each returned content item
            serialized to JSON, joined by newlines; error_code is 1 when the
            server flags the result as an error, else 0.

        Raises:
            MCPConnectionError: If connection fails
            MCPTimeoutError: If operation times out
        '''
        async def _operation(session):
            call_result = await session.call_tool(tool_name, kwargs)
            return ToolInvocationResult(
                content='\n'.join(item.model_dump_json() for item in call_result.content),
                error_code=1 if call_result.isError else 0,
            )

        async def _invoke_tool_operation():
            return await self._safe_sse_operation(_operation)

        return await self._execute_with_retry(f'invoke_tool({tool_name})', _invoke_tool_operation)

    async def check_connection(self) -> bool:
        '''Check if the MCP endpoint is reachable by listing its tools.

        This goes through the normal retry/timeout path (so it can take up to
        ``max_retries`` attempts) and refreshes ``self.tools`` on success.

        Returns:
            True if connection is successful, False otherwise
        '''
        logger = logging.getLogger(f'{__name__}._check_connection')
        try:
            await self.list_tools()
            return True
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.debug('Connection check failed: %s', str(e))
            return False

    def get_endpoint_info(self) -> Dict[str, Any]:
        '''Get information about the configured endpoint.

        Returns:
            Dictionary with the raw endpoint, its parsed URL components
            (scheme, hostname, port, path) and the timeout/retry settings
        '''
        parsed = urlparse(self.endpoint)
        return {
            'endpoint': self.endpoint,
            'scheme': parsed.scheme,
            'hostname': parsed.hostname,
            'port': parsed.port,
            'path': parsed.path,
            'timeout': self.timeout,
            'max_retries': self.max_retries
        }