Spaces:
Build error
Build error
| import asyncio | |
| import logging | |
| from typing import Optional | |
| from contextlib import AsyncExitStack | |
| log = logging.getLogger(__name__) | |
| import anyio | |
| from mcp import ClientSession | |
| from mcp.client.auth import OAuthClientProvider, TokenStorage | |
| from mcp.client.streamable_http import streamablehttp_client | |
| from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken | |
| import httpx | |
| from open_webui.env import AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL, AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER | |
| def _build_httpx_client(headers=None, timeout=None, auth=None, verify=True): | |
| """Create an httpx AsyncClient for MCP transport. | |
| Falls back to AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER when the caller | |
| (i.e. the MCP SDK) does not supply an explicit timeout. | |
| Note: verify must be passed at construction time because httpx | |
| configures the SSL context during __init__. Setting client.verify = False | |
| after construction does not affect the underlying transport's SSL context. | |
| """ | |
| kwargs = { | |
| 'follow_redirects': True, | |
| 'verify': verify, | |
| } | |
| if timeout is not None: | |
| kwargs['timeout'] = timeout | |
| elif AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER is not None: | |
| kwargs['timeout'] = float(AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER) | |
| if headers is not None: | |
| kwargs['headers'] = headers | |
| if auth is not None: | |
| kwargs['auth'] = auth | |
| return httpx.AsyncClient(**kwargs) | |
| def create_httpx_client(headers=None, timeout=None, auth=None): | |
| return _build_httpx_client(headers=headers, timeout=timeout, auth=auth, verify=True) | |
| def create_insecure_httpx_client(headers=None, timeout=None, auth=None): | |
| return _build_httpx_client(headers=headers, timeout=timeout, auth=auth, verify=False) | |
| class MCPClient: | |
| def __init__(self): | |
| self.session: Optional[ClientSession] = None | |
| self.exit_stack = None | |
| async def connect(self, url: str, headers: Optional[dict] = None): | |
| async with AsyncExitStack() as exit_stack: | |
| try: | |
| self._streams_context = streamablehttp_client( | |
| url, | |
| headers=headers, | |
| httpx_client_factory=create_httpx_client | |
| if AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL | |
| else create_insecure_httpx_client, | |
| ) | |
| transport = await exit_stack.enter_async_context(self._streams_context) | |
| read_stream, write_stream, _ = transport | |
| self._session_context = ClientSession(read_stream, write_stream) # pylint: disable=W0201 | |
| self.session = await exit_stack.enter_async_context(self._session_context) | |
| with anyio.fail_after(10): | |
| await self.session.initialize() | |
| self.exit_stack = exit_stack.pop_all() | |
| except Exception as e: | |
| await asyncio.shield(self.disconnect()) | |
| raise e | |
| async def list_tool_specs(self) -> Optional[dict]: | |
| if not self.session: | |
| raise RuntimeError('MCP client is not connected.') | |
| result = await self.session.list_tools() | |
| tools = result.tools | |
| tool_specs = [] | |
| for tool in tools: | |
| name = tool.name | |
| description = tool.description | |
| inputSchema = tool.inputSchema | |
| # TODO: handle outputSchema if needed | |
| outputSchema = getattr(tool, 'outputSchema', None) | |
| tool_specs.append({'name': name, 'description': description, 'parameters': inputSchema}) | |
| return tool_specs | |
| async def call_tool(self, function_name: str, function_args: dict) -> Optional[dict]: | |
| if not self.session: | |
| raise RuntimeError('MCP client is not connected.') | |
| result = await self.session.call_tool(function_name, function_args) | |
| if not result: | |
| raise Exception('No result returned from MCP tool call.') | |
| result_dict = result.model_dump(mode='json') | |
| result_content = result_dict.get('content', {}) | |
| if result.isError: | |
| raise Exception(result_content) | |
| else: | |
| return result_content | |
| async def list_resources(self, cursor: Optional[str] = None) -> Optional[dict]: | |
| if not self.session: | |
| raise RuntimeError('MCP client is not connected.') | |
| result = await self.session.list_resources(cursor=cursor) | |
| if not result: | |
| raise Exception('No result returned from MCP list_resources call.') | |
| result_dict = result.model_dump() | |
| resources = result_dict.get('resources', []) | |
| return resources | |
| async def read_resource(self, uri: str) -> Optional[dict]: | |
| if not self.session: | |
| raise RuntimeError('MCP client is not connected.') | |
| result = await self.session.read_resource(uri) | |
| if not result: | |
| raise Exception('No result returned from MCP read_resource call.') | |
| result_dict = result.model_dump() | |
| return result_dict | |
| async def disconnect(self): | |
| """Clean up and close the session. | |
| This method is idempotent — calling it multiple times or on a | |
| client that was never connected is safe. It shields the close | |
| operation from CancelledError and adds a timeout so a hung MCP | |
| server cannot block the event loop indefinitely. | |
| """ | |
| exit_stack = self.exit_stack | |
| if exit_stack is None: | |
| return | |
| # Prevent double-close from concurrent callers | |
| self.exit_stack = None | |
| self.session = None | |
| try: | |
| # IMPORTANT: Do NOT use asyncio.shield() or asyncio.wait_for() | |
| # because they create a new asyncio task, which violates the MCP SDK's | |
| # requirement that its TaskGroup be exited in the exact same task. | |
| # ALSO do NOT use anyio.CancelScope(shield=True) or anyio.fail_after(), | |
| # because they push a new cancel scope onto the task, violating LIFO | |
| # order when aclose() attempts to exit the inner TaskGroup. | |
| # We simply call aclose() directly. If the task is cancelled, the | |
| # sockets will eventually be cleaned up by garbage collection. | |
| await exit_stack.aclose() | |
| except TimeoutError: | |
| log.warning('MCPClient.disconnect() timed out after 5 s') | |
| except RuntimeError as exc: | |
| log.debug('MCPClient.disconnect() suppressed RuntimeError: %s', exc) | |
| except Exception as exc: | |
| log.debug('MCPClient.disconnect() error: %s', exc) | |
| async def __aenter__(self): | |
| await self.exit_stack.__aenter__() | |
| return self | |
| async def __aexit__(self, exc_type, exc_value, traceback): | |
| await self.exit_stack.__aexit__(exc_type, exc_value, traceback) | |
| await self.disconnect() | |