| """ |
| Robust Hugging Face MCP Client - Optimized for HF Spaces |
| |
| This module provides a robust client for interacting with Hugging Face's MCP endpoint |
| with better error handling, TaskGroup avoidance, and compatibility for Hugging Face Spaces. |
| """ |
|
|
| import asyncio |
| import json |
| import logging |
| import os |
| from typing import Any, Dict, List, Optional, Union |
| from datetime import timedelta |
| from contextlib import asynccontextmanager |
|
|
| from mcp.shared.message import SessionMessage |
| from mcp.types import ( |
| JSONRPCMessage, |
| JSONRPCRequest, |
| JSONRPCNotification, |
| JSONRPCResponse, |
| JSONRPCError, |
| ) |
| from mcp.client.streamable_http import streamablehttp_client |
|
|
# Module-wide logger, named after this module so handlers/levels can be configured per module.
logger = logging.getLogger(name=__name__)
|
|
|
|
class RobustHFMCPClient:
    """Robust client for interacting with Hugging Face MCP endpoint optimized for Spaces.

    Each public call (``get_all_tools``, ``call_tool``) opens a fresh
    streamable-HTTP session. It then performs the MCP ``initialize`` handshake,
    sends one request and waits for the matching reply, all sequentially in the
    calling coroutine.
    """

    # Upper bound for a single receive. The caller's overall deadline is
    # enforced separately, so it is never overshot by up to this amount.
    _RECEIVE_POLL_SECONDS = 10.0

    # Names of stream exceptions meaning the transport has gone away (anyio raises
    # ClosedResourceError on a closed stream and EndOfStream when all senders are
    # closed). Matched by class name so this module does not need to import anyio.
    _CLOSED_STREAM_ERRORS = frozenset({"ClosedResourceError", "EndOfStream"})

    def __init__(self, hf_token: str, timeout: int = 120):
        """
        Initialize the Robust Hugging Face MCP client.

        Args:
            hf_token: Hugging Face API token, sent as a Bearer token.
            timeout: Timeout in seconds for HTTP requests. The SSE read timeout
                is set to twice this value.
        """
        self.hf_token = hf_token
        self.url = "https://huggingface.co/mcp"
        self.headers = {
            "Authorization": f"Bearer {hf_token}",
            "User-Agent": "robust-hf-mcp-client/2.0.0",
            "Accept": "application/json, text/event-stream",
            "Content-Type": "application/json"
        }
        self.timeout = timedelta(seconds=timeout)
        self.sse_read_timeout = timedelta(seconds=timeout * 2)
        self.request_id_counter = 0

    def _get_next_request_id(self) -> int:
        """Return the next JSON-RPC request id (1, 2, 3, ... per instance)."""
        self.request_id_counter += 1
        return self.request_id_counter

    async def _execute_single_request_session(
        self,
        method: str,
        params: Optional[Dict[str, Any]] = None
    ) -> Any:
        """
        Execute a complete MCP session for a single request.

        This avoids TaskGroup issues by handling everything in sequence. It
        opens the transport, initializes the session, sends ``method`` and waits
        up to 90 s for its response.

        Args:
            method: JSON-RPC method name, e.g. ``"tools/list"``.
            params: Optional JSON-RPC params for the request.

        Returns:
            The ``result`` field of the matching JSON-RPC response.

        Raises:
            asyncio.TimeoutError: A response did not arrive in time.
            Exception: The server returned a JSON-RPC error, or the connection
                was closed.
        """
        # The id is allocated before the initialize request's id (original ordering).
        request_id = self._get_next_request_id()
        main_request = JSONRPCRequest(
            jsonrpc="2.0",
            id=request_id,
            method=method,
            params=params
        )

        async with streamablehttp_client(
            url=self.url,
            headers=self.headers,
            timeout=self.timeout,
            sse_read_timeout=self.sse_read_timeout,
            terminate_on_close=False
        ) as (read_stream, write_stream, _get_session_id):
            logger.info("Starting MCP session initialization...")
            await self._initialize_session(read_stream, write_stream)

            logger.info(f"Sending main request: {method}")
            await write_stream.send(SessionMessage(JSONRPCMessage(main_request)))

            logger.info("Waiting for main request response...")
            return await self._wait_for_response(read_stream, request_id, timeout=90)

    async def _initialize_session(self, read_stream, write_stream) -> None:
        """
        Initialize the MCP session with proper handshake.

        Sends ``initialize``, waits up to 60 s for its response, then sends the
        ``notifications/initialized`` notification.

        Raises:
            asyncio.TimeoutError, Exception: Propagated from ``_wait_for_response``.
        """
        init_request_id = self._get_next_request_id()

        init_request = JSONRPCRequest(
            jsonrpc="2.0",
            id=init_request_id,
            method="initialize",
            params={
                "protocolVersion": "2024-11-05",
                # NOTE(review): in the MCP spec "tools"/"resources"/"prompts" are
                # server-side capabilities; kept unchanged for compatibility -
                # confirm whether they are needed here.
                "capabilities": {
                    "tools": {},
                    "resources": {},
                    "prompts": {}
                },
                "clientInfo": {
                    "name": "robust-hf-mcp-client",
                    "version": "2.0.0"
                }
            }
        )
        await write_stream.send(SessionMessage(JSONRPCMessage(init_request)))

        # The initialize result (server info/capabilities) is not used by this client.
        await self._wait_for_response(read_stream, init_request_id, timeout=60)
        logger.info("MCP session initialized successfully")

        initialized_notification = JSONRPCNotification(
            jsonrpc="2.0",
            method="notifications/initialized"
        )
        await write_stream.send(SessionMessage(JSONRPCMessage(initialized_notification)))

        # Short grace period before the next request is sent.
        await asyncio.sleep(1.0)

    async def _wait_for_response(
        self,
        read_stream,
        expected_id: int,
        timeout: float = 60
    ) -> Any:
        """
        Wait for the response whose JSON-RPC id equals ``expected_id``.

        Messages with other ids, or without ids, are skipped. ``timeout`` is an
        overall deadline. Each receive waits for ``_RECEIVE_POLL_SECONDS`` or the
        time remaining, whichever is shorter.

        Args:
            read_stream: Object with an awaitable ``receive()`` yielding
                ``SessionMessage`` objects or ``Exception`` instances.
            expected_id: JSON-RPC id of the request being answered.
            timeout: Overall deadline in seconds.

        Returns:
            The ``result`` payload of the matching JSON-RPC response.

        Raises:
            asyncio.TimeoutError: No matching response arrived before the deadline.
            Exception: The server replied with a JSON-RPC error, or the stream
                was closed; in the latter case the original error is chained.
            The exception object itself, if the transport pushed one onto the stream.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout

        while True:
            remaining = deadline - loop.time()
            if remaining <= 0:
                raise asyncio.TimeoutError(f"Timeout waiting for response to request {expected_id}")

            # Only the receive is guarded, so errors raised below are not
            # re-caught, re-logged or misclassified.
            try:
                response = await asyncio.wait_for(
                    read_stream.receive(),
                    timeout=min(self._RECEIVE_POLL_SECONDS, remaining)
                )
            except asyncio.TimeoutError:
                logger.debug("Receive timeout, continuing to wait...")
                continue
            except Exception as e:
                if type(e).__name__ in self._CLOSED_STREAM_ERRORS or "StreamClosed" in str(e):
                    raise Exception("Connection closed while waiting for response") from e
                logger.error(f"Error while waiting for response: {e}")
                raise

            if isinstance(response, Exception):
                logger.error(f"Received exception in stream: {response}")
                raise response

            if not isinstance(response, SessionMessage):
                continue

            msg_root = response.message.root
            if isinstance(msg_root, JSONRPCResponse) and msg_root.id == expected_id:
                logger.info(f"Received successful response for request {expected_id}")
                return msg_root.result

            if isinstance(msg_root, JSONRPCError) and msg_root.id == expected_id:
                error_msg = f"Server error for request {expected_id}: {msg_root.error}"
                logger.error(error_msg)
                raise Exception(error_msg)

            logger.debug(f"Received unexpected message type: {type(msg_root)} with ID: {getattr(msg_root, 'id', 'N/A')}")

    async def get_all_tools(self) -> List[Dict[str, Any]]:
        """
        Get all available tools from the Hugging Face MCP endpoint.

        Returns:
            The ``tools`` list from the ``tools/list`` result, or ``[]`` if the
            result has an unexpected shape.

        Raises:
            Exception: Any error from the session is logged and re-raised.
        """
        try:
            logger.info("Fetching all available tools from Hugging Face MCP")
            result = await self._execute_single_request_session("tools/list")

            if isinstance(result, dict) and "tools" in result:
                tools = result["tools"]
                logger.info(f"Successfully fetched {len(tools)} tools")
                return tools

            logger.warning(f"Unexpected response format for tools/list: {result}")
            return []

        except Exception as e:
            logger.error(f"Failed to get tools: {e}")
            raise

    async def call_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
        """
        Call a specific tool with the given arguments.

        Args:
            tool_name: Name of the tool to call
            args: Arguments to pass to the tool

        Returns:
            The raw ``result`` of the ``tools/call`` response.

        Raises:
            Exception: Any error from the session is logged and re-raised.
        """
        try:
            logger.info(f"Calling tool '{tool_name}' with args: {args}")

            params = {
                "name": tool_name,
                "arguments": args
            }

            result = await self._execute_single_request_session("tools/call", params)
            logger.info(f"Tool '{tool_name}' executed successfully")
            return result

        except Exception as e:
            logger.error(f"Failed to call tool '{tool_name}': {e}")
            raise
|
|
|
|
class SimplifiedHFMCPClient:
    """Ultra-simplified client that avoids all TaskGroup usage.

    Each call opens a fresh streamable-HTTP session. A background task copies
    incoming JSON-RPC responses into a dict keyed by request id, and the calling
    coroutine polls that dict for the id it is waiting on.
    """

    # Maximum waits, equal to the previous 300 x 0.1 s and 600 x 0.1 s polling loops.
    _INIT_WAIT_SECONDS = 30.0
    _MAIN_WAIT_SECONDS = 60.0
    _POLL_INTERVAL_SECONDS = 0.1

    def __init__(self, hf_token: str, timeout: int = 90):
        """
        Args:
            hf_token: Hugging Face API token, sent as a Bearer token.
            timeout: HTTP timeout in seconds. The SSE read timeout is twice this.
        """
        self.hf_token = hf_token
        self.timeout = timeout
        self.headers = {
            "Authorization": f"Bearer {hf_token}",
            "User-Agent": "simplified-hf-mcp-client/1.0.0"
        }
        self.request_counter = 0

    def _next_id(self) -> int:
        """Return the next JSON-RPC request id (1, 2, 3, ... per instance)."""
        self.request_counter += 1
        return self.request_counter

    @staticmethod
    async def _collect_responses(read_stream, responses: Dict[Any, Any]) -> None:
        """
        Drain ``read_stream`` into ``responses`` until it fails or ends.

        JSON-RPC responses and errors are stored under their id. An exception,
        whether pushed onto the stream or raised while iterating, is stored
        under ``'error'``. If the stream simply ends, a ``ConnectionError`` is
        stored under ``'error'`` so waiters fail fast instead of polling until
        their timeout.
        """
        try:
            async for message in read_stream:
                if isinstance(message, Exception):
                    responses['error'] = message
                    return
                if isinstance(message, SessionMessage):
                    msg_root = message.message.root
                    # Record only replies to our requests: a server-initiated
                    # request (e.g. a ping) may carry an id that collides with ours.
                    if isinstance(msg_root, (JSONRPCResponse, JSONRPCError)) and msg_root.id is not None:
                        responses[msg_root.id] = msg_root
        except Exception as e:
            responses['error'] = e
            return
        responses.setdefault('error', ConnectionError("MCP stream closed before a response was received"))

    @staticmethod
    async def _wait_for_id(
        responses: Dict[Any, Any],
        request_id: Any,
        timeout: float,
        poll_interval: float = 0.1
    ) -> Any:
        """
        Poll ``responses`` until ``request_id`` is present or ``timeout`` expires.

        A response that has already arrived takes precedence over a stored
        ``'error'``.

        Returns:
            The stored message, or ``None`` if the deadline passed.

        Raises:
            The exception stored under ``'error'`` by the collector.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while request_id not in responses:
            if 'error' in responses:
                raise responses['error']
            if loop.time() >= deadline:
                return None
            await asyncio.sleep(poll_interval)
        return responses[request_id]

    async def _simple_mcp_call(self, method: str, params: Optional[Dict[str, Any]] = None) -> Any:
        """
        Make a simple MCP call without complex async patterns.

        Performs ``initialize`` and ``notifications/initialized``, then sends
        ``method`` and returns the ``result`` of its response.

        Raises:
            Exception: On timeout ("Initialization timeout" / "Main request
                timeout"), on a JSON-RPC error ("Server error: ..."), or with any
                transport error captured by the collector.
        """
        async with streamablehttp_client(
            url="https://huggingface.co/mcp",
            headers=self.headers,
            timeout=timedelta(seconds=self.timeout),
            sse_read_timeout=timedelta(seconds=self.timeout * 2),
            terminate_on_close=False
        ) as (read_stream, write_stream, _get_session_id):

            responses: Dict[Any, Any] = {}
            collector_task = asyncio.create_task(self._collect_responses(read_stream, responses))

            try:
                init_id = self._next_id()
                init_req = JSONRPCRequest(
                    jsonrpc="2.0",
                    id=init_id,
                    method="initialize",
                    params={
                        "protocolVersion": "2024-11-05",
                        "capabilities": {"tools": {}},
                        "clientInfo": {"name": "simple-hf-mcp", "version": "1.0.0"}
                    }
                )
                await write_stream.send(SessionMessage(JSONRPCMessage(init_req)))

                init_result = await self._wait_for_id(
                    responses, init_id, self._INIT_WAIT_SECONDS, self._POLL_INTERVAL_SECONDS
                )
                if init_result is None:
                    raise Exception("Initialization timeout")
                # A failed handshake leaves the session unusable; surface it now.
                if isinstance(init_result, JSONRPCError):
                    raise Exception(f"Server error: {init_result.error}")

                notif = JSONRPCNotification(
                    jsonrpc="2.0",
                    method="notifications/initialized"
                )
                await write_stream.send(SessionMessage(JSONRPCMessage(notif)))
                await asyncio.sleep(0.5)

                main_id = self._next_id()
                main_req = JSONRPCRequest(
                    jsonrpc="2.0",
                    id=main_id,
                    method=method,
                    params=params
                )
                await write_stream.send(SessionMessage(JSONRPCMessage(main_req)))

                result = await self._wait_for_id(
                    responses, main_id, self._MAIN_WAIT_SECONDS, self._POLL_INTERVAL_SECONDS
                )
                if result is None:
                    raise Exception("Main request timeout")
                if isinstance(result, JSONRPCResponse):
                    return result.result
                if isinstance(result, JSONRPCError):
                    raise Exception(f"Server error: {result.error}")
                raise Exception(f"Unexpected response type: {type(result)}")

            finally:
                collector_task.cancel()
                try:
                    await collector_task
                except asyncio.CancelledError:
                    pass

    async def get_tools(self) -> List[Dict[str, Any]]:
        """Return the ``tools`` list from ``tools/list``, or ``[]`` for an unexpected shape."""
        result = await self._simple_mcp_call("tools/list")
        if isinstance(result, dict) and "tools" in result:
            return result["tools"]
        return []

    async def call_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
        """Call ``tool_name`` with ``args`` and return the raw ``tools/call`` result."""
        params = {
            "name": tool_name,
            "arguments": args
        }
        return await self._simple_mcp_call("tools/call", params)
|
|
|
|
| |
async def get_hf_tools_robust(hf_token: str, max_retries: int = 3) -> List[Dict[str, Any]]:
    """
    Get all available tools with multiple fallback strategies.

    Strategy 1 uses ``RobustHFMCPClient`` (90 s timeout). If every attempt
    fails, strategy 2 uses ``SimplifiedHFMCPClient`` (120 s timeout). Within a
    strategy, attempts are separated by exponential back-off (1 s, 2 s, 4 s, ...).

    Args:
        hf_token: Hugging Face API token
        max_retries: Attempts per strategy. Values below 1 are treated as 1 so
            at least one real attempt is always made.

    Returns:
        List of tool definitions

    Raises:
        Exception: Every attempt of both strategies failed. The last underlying
            error is chained as ``__cause__``.
    """
    attempts = max(1, max_retries)
    last_error: Optional[Exception] = None

    # Strategy 1: sequential single-coroutine client.
    for attempt in range(attempts):
        try:
            logger.info(f"Trying robust client (attempt {attempt + 1})")
            client = RobustHFMCPClient(hf_token, timeout=90)
            tools = await client.get_all_tools()
            logger.info(f"Robust client succeeded with {len(tools)} tools")
            return tools
        except Exception as e:
            last_error = e
            logger.warning(f"Robust client attempt {attempt + 1} failed: {e}")
            if attempt < attempts - 1:
                await asyncio.sleep(2 ** attempt)

    # Strategy 2: polling client with a background collector task.
    for attempt in range(attempts):
        try:
            logger.info(f"Trying simplified client (attempt {attempt + 1})")
            client = SimplifiedHFMCPClient(hf_token, timeout=120)
            tools = await client.get_tools()
            logger.info(f"Simplified client succeeded with {len(tools)} tools")
            return tools
        except Exception as e:
            last_error = e
            logger.warning(f"Simplified client attempt {attempt + 1} failed: {e}")
            if attempt < attempts - 1:
                await asyncio.sleep(2 ** attempt)

    raise Exception(f"All connection strategies failed. Last error: {last_error}") from last_error
|
|
|
|
async def call_hf_tool_robust(
    hf_token: str,
    tool_name: str,
    args: Dict[str, Any],
    max_retries: int = 3
) -> Any:
    """
    Call a specific Hugging Face MCP tool with multiple fallback strategies.

    Strategy 1 uses ``RobustHFMCPClient`` (120 s timeout). If every attempt
    fails, strategy 2 uses ``SimplifiedHFMCPClient`` (150 s timeout). Within a
    strategy, attempts are separated by exponential back-off (1 s, 2 s, 4 s, ...).

    Args:
        hf_token: Hugging Face API token
        tool_name: Name of the tool to call
        args: Arguments to pass to the tool
        max_retries: Attempts per strategy. Values below 1 are treated as 1 so
            at least one real attempt is always made.

    Returns:
        The tool's response

    Raises:
        Exception: Every attempt of both strategies failed. The last underlying
            error is chained as ``__cause__``.
    """
    attempts = max(1, max_retries)
    last_error: Optional[Exception] = None

    # Strategy 1: sequential single-coroutine client.
    for attempt in range(attempts):
        try:
            logger.info(f"Trying robust client for tool call (attempt {attempt + 1})")
            client = RobustHFMCPClient(hf_token, timeout=120)
            result = await client.call_tool(tool_name, args)
            logger.info("Robust client tool call succeeded")
            return result
        except Exception as e:
            last_error = e
            logger.warning(f"Robust client tool call attempt {attempt + 1} failed: {e}")
            if attempt < attempts - 1:
                await asyncio.sleep(2 ** attempt)

    # Strategy 2: polling client with a background collector task.
    for attempt in range(attempts):
        try:
            logger.info(f"Trying simplified client for tool call (attempt {attempt + 1})")
            client = SimplifiedHFMCPClient(hf_token, timeout=150)
            result = await client.call_tool(tool_name, args)
            logger.info("Simplified client tool call succeeded")
            return result
        except Exception as e:
            last_error = e
            logger.warning(f"Simplified client tool call attempt {attempt + 1} failed: {e}")
            if attempt < attempts - 1:
                await asyncio.sleep(2 ** attempt)

    raise Exception(f"All tool call strategies failed. Last error: {last_error}") from last_error
|
|
|
|
| |
async def get_hf_tools(hf_token: str) -> List[Dict[str, Any]]:
    """
    Fetch the Hugging Face MCP tool list (legacy entry point).

    Kept for backward compatibility. It delegates to ``get_hf_tools_robust``
    with that function's default retry count.
    """
    tools = await get_hf_tools_robust(hf_token)
    return tools
|
|
|
|
async def call_hf_tool(hf_token: str, tool_name: str, args: Dict[str, Any]) -> Any:
    """
    Invoke one Hugging Face MCP tool (legacy entry point).

    Kept for backward compatibility. It delegates to ``call_hf_tool_robust``
    with that function's default retry count and returns whatever it returns.
    """
    outcome = await call_hf_tool_robust(hf_token, tool_name, args)
    return outcome
|
|
|
|
| |
| async def diagnose_connection_advanced(hf_token: str) -> Dict[str, Any]: |
| """ |
| Advanced connection diagnostics with multiple test scenarios. |
| |
| Args: |
| hf_token: Hugging Face API token |
| |
| Returns: |
| Comprehensive diagnostic information |
| """ |
| diagnostics = { |
| "environment": "huggingface_spaces" if os.getenv("SPACE_ID") else "local", |
| "space_id": os.getenv("SPACE_ID"), |
| "python_version": os.sys.version, |
| "token_length": len(hf_token) if hf_token else 0, |
| "has_token": bool(hf_token), |
| "tests": { |
| "basic_connection": False, |
| "robust_client": False, |
| "simplified_client": False, |
| "tools_fetch": False, |
| "tool_call_test": False |
| }, |
| "errors": {}, |
| "tool_count": 0, |
| "sample_tools": [] |
| } |
| |
| |
| try: |
| async with streamablehttp_client( |
| url="https://huggingface.co/mcp", |
| headers={"Authorization": f"Bearer {hf_token}"}, |
| timeout=timedelta(seconds=10), |
| terminate_on_close=False |
| ) as (read_stream, write_stream, get_session_id): |
| diagnostics["tests"]["basic_connection"] = True |
| logger.info("Basic connection test passed") |
| except Exception as e: |
| diagnostics["errors"]["basic_connection"] = str(e) |
| logger.error(f"Basic connection test failed: {e}") |
| |
| |
| if diagnostics["tests"]["basic_connection"]: |
| try: |
| client = RobustHFMCPClient(hf_token, timeout=60) |
| tools = await client.get_all_tools() |
| diagnostics["tests"]["robust_client"] = True |
| diagnostics["tests"]["tools_fetch"] = True |
| diagnostics["tool_count"] = len(tools) |
| diagnostics["sample_tools"] = [ |
| {"name": tool.get("name"), "description": tool.get("description", "")[:100]} |
| for tool in tools[:3] |
| ] |
| logger.info(f"Robust client test passed - {len(tools)} tools") |
| except Exception as e: |
| diagnostics["errors"]["robust_client"] = str(e) |
| logger.error(f"Robust client test failed: {e}") |
| |
| |
| if not diagnostics["tests"]["robust_client"]: |
| try: |
| client = SimplifiedHFMCPClient(hf_token, timeout=90) |
| tools = await client.get_tools() |
| diagnostics["tests"]["simplified_client"] = True |
| if not diagnostics["tests"]["tools_fetch"]: |
| diagnostics["tests"]["tools_fetch"] = True |
| diagnostics["tool_count"] = len(tools) |
| diagnostics["sample_tools"] = [ |
| {"name": tool.get("name"), "description": tool.get("description", "")[:100]} |
| for tool in tools[:3] |
| ] |
| logger.info(f"Simplified client test passed - {len(tools)} tools") |
| except Exception as e: |
| diagnostics["errors"]["simplified_client"] = str(e) |
| logger.error(f"Simplified client test failed: {e}") |
| |
| |
| if diagnostics["tests"]["tools_fetch"] and diagnostics["sample_tools"]: |
| try: |
| |
| sample_tool_name = diagnostics["sample_tools"][0]["name"] |
| if sample_tool_name: |
| |
| if diagnostics["tests"]["robust_client"]: |
| client = RobustHFMCPClient(hf_token, timeout=60) |
| else: |
| client = SimplifiedHFMCPClient(hf_token, timeout=90) |
| |
| |
| try: |
| result = await client.call_tool(sample_tool_name, {}) |
| diagnostics["tests"]["tool_call_test"] = True |
| logger.info(f"Tool call test passed with {sample_tool_name}") |
| except Exception as tool_error: |
| |
| diagnostics["errors"]["tool_call_test"] = f"Tool call failed (might need args): {str(tool_error)}" |
| logger.warning(f"Tool call test failed: {tool_error}") |
| |
| except Exception as e: |
| diagnostics["errors"]["tool_call_test"] = str(e) |
| logger.error(f"Tool call test setup failed: {e}") |
| |
| return diagnostics |