Spaces:
Build error
Build error
File size: 6,837 Bytes
87a665c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 | import asyncio
import logging
from typing import Optional
from contextlib import AsyncExitStack
log = logging.getLogger(__name__)
import anyio
from mcp import ClientSession
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
import httpx
from open_webui.env import AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL, AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER
def _build_httpx_client(headers=None, timeout=None, auth=None, verify=True):
    """Construct the ``httpx.AsyncClient`` used by the MCP HTTP transport.

    Redirects are always followed. The timeout is chosen in this order:
    the caller's explicit ``timeout`` (the MCP SDK's factory argument), then
    ``AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER`` converted to ``float``, otherwise
    httpx's own default. ``headers`` and ``auth`` are forwarded only when
    they are not None.

    ``verify`` has to be given to the constructor: httpx builds the SSL
    context inside ``__init__``, so assigning ``client.verify = False`` later
    would not change the transport's SSL context.
    """
    options = {'follow_redirects': True, 'verify': verify}

    effective_timeout = timeout
    if effective_timeout is None and AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER is not None:
        effective_timeout = float(AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER)
    if effective_timeout is not None:
        options['timeout'] = effective_timeout

    # Only pass through the optional settings the caller actually supplied.
    for option_name, option_value in (('headers', headers), ('auth', auth)):
        if option_value is not None:
            options[option_name] = option_value

    return httpx.AsyncClient(**options)
def create_httpx_client(headers=None, timeout=None, auth=None):
    """MCP ``httpx_client_factory`` that verifies TLS certificates."""
    return _build_httpx_client(headers, timeout, auth, verify=True)
def create_insecure_httpx_client(headers=None, timeout=None, auth=None):
    """MCP ``httpx_client_factory`` with TLS certificate verification disabled."""
    return _build_httpx_client(headers, timeout, auth, verify=False)
class MCPClient:
    """Async wrapper around an MCP ``ClientSession`` over streamable HTTP.

    Call :meth:`connect` to open the transport and session, use the
    ``list_*`` / ``call_tool`` / ``read_resource`` helpers while connected,
    and call :meth:`disconnect` (or leave an ``async with`` block) to close.
    """

    def __init__(self):
        # Active MCP session; None whenever the client is not connected.
        self.session: Optional[ClientSession] = None
        # Owns the transport and session contexts after a successful connect().
        self.exit_stack = None

    async def connect(self, url: str, headers: Optional[dict] = None):
        """Open a streamable-HTTP transport to ``url`` and initialize a session.

        TLS verification follows ``AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL``,
        and ``session.initialize()`` is bounded to 10 seconds.

        On success, the transport and session contexts are moved into
        ``self.exit_stack`` so they outlive this call. On failure, the local
        ``AsyncExitStack`` closes them, ``self.session`` is reset to None,
        ``disconnect()`` is awaited, and the exception is re-raised.

        NOTE(review): calling this on an already-connected client replaces
        ``self.exit_stack`` without closing the previous connection.
        """
        client_factory = (
            create_httpx_client
            if AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL
            else create_insecure_httpx_client
        )
        async with AsyncExitStack() as exit_stack:
            try:
                self._streams_context = streamablehttp_client(
                    url,
                    headers=headers,
                    httpx_client_factory=client_factory,
                )
                transport = await exit_stack.enter_async_context(self._streams_context)
                read_stream, write_stream, _ = transport
                self._session_context = ClientSession(read_stream, write_stream)  # pylint: disable=W0201
                self.session = await exit_stack.enter_async_context(self._session_context)
                with anyio.fail_after(10):
                    await self.session.initialize()
                # Detach the contexts from the local stack so leaving the
                # `async with` below does not close the live connection.
                self.exit_stack = exit_stack.pop_all()
            except Exception:
                # The half-opened session belongs to the local stack, which
                # closes it as this exception propagates. Drop the reference so
                # the client reports "not connected" instead of reusing it.
                self.session = None
                # Awaited directly: asyncio.shield() would run the close in a
                # new task, which disconnect() documents as unsafe for MCP.
                await self.disconnect()
                raise

    async def list_tool_specs(self) -> list:
        """Return the server's tools as ``name`` / ``description`` / ``parameters`` dicts.

        ``parameters`` holds each tool's ``inputSchema``.

        Raises:
            RuntimeError: if the client is not connected.
        """
        if not self.session:
            raise RuntimeError('MCP client is not connected.')
        result = await self.session.list_tools()
        # TODO: handle outputSchema if needed
        return [
            {'name': tool.name, 'description': tool.description, 'parameters': tool.inputSchema}
            for tool in result.tools
        ]

    async def call_tool(self, function_name: str, function_args: dict) -> list:
        """Invoke ``function_name`` with ``function_args`` and return its content.

        The result is dumped with ``model_dump(mode='json')``, so the returned
        content consists of JSON-compatible Python values.

        Raises:
            RuntimeError: if the client is not connected.
            Exception: if no result comes back, or the result is flagged as an
                error (the exception's argument is the result content).
        """
        if not self.session:
            raise RuntimeError('MCP client is not connected.')
        result = await self.session.call_tool(function_name, function_args)
        if not result:
            raise Exception('No result returned from MCP tool call.')
        result_dict = result.model_dump(mode='json')
        # Tool-call content is a list of content items, so default to an empty list.
        result_content = result_dict.get('content', [])
        if result.isError:
            raise Exception(result_content)
        return result_content

    async def list_resources(self, cursor: Optional[str] = None) -> list:
        """Return one page of the server's resources, starting at ``cursor``.

        NOTE(review): unlike call_tool(), this dumps in pydantic's default
        (python) mode, so URL-like fields may remain pydantic objects —
        confirm callers do not JSON-serialize the result directly.

        Raises:
            RuntimeError: if the client is not connected.
            Exception: if no result is returned.
        """
        if not self.session:
            raise RuntimeError('MCP client is not connected.')
        result = await self.session.list_resources(cursor=cursor)
        if not result:
            raise Exception('No result returned from MCP list_resources call.')
        result_dict = result.model_dump()
        return result_dict.get('resources', [])

    async def read_resource(self, uri: str) -> dict:
        """Read the resource at ``uri`` and return the dumped result as a dict.

        Raises:
            RuntimeError: if the client is not connected.
            Exception: if no result is returned.
        """
        if not self.session:
            raise RuntimeError('MCP client is not connected.')
        result = await self.session.read_resource(uri)
        if not result:
            raise Exception('No result returned from MCP read_resource call.')
        return result.model_dump()

    async def disconnect(self):
        """Close the transport and session opened by :meth:`connect`.

        Idempotent: calling it repeatedly, or on a client that never
        connected, returns without doing anything. State is cleared before
        closing, so a concurrent caller does not close the same stack twice.
        ``Exception`` subclasses raised while closing are logged and
        suppressed.

        No timeout and no cancellation shield are applied (see the comment in
        the body), so nothing bounds how long ``aclose()`` may take.
        """
        exit_stack = self.exit_stack
        if exit_stack is None:
            return
        # Prevent double-close from concurrent callers
        self.exit_stack = None
        self.session = None
        try:
            # IMPORTANT: Do NOT use asyncio.shield() or asyncio.wait_for()
            # because they create a new asyncio task, which violates the MCP SDK's
            # requirement that its TaskGroup be exited in the exact same task.
            # ALSO do NOT use anyio.CancelScope(shield=True) or anyio.fail_after(),
            # because they push a new cancel scope onto the task, violating LIFO
            # order when aclose() attempts to exit the inner TaskGroup.
            # We simply call aclose() directly. If the task is cancelled, the
            # sockets will eventually be cleaned up by garbage collection.
            await exit_stack.aclose()
        except TimeoutError as exc:
            log.warning('MCPClient.disconnect() timed out: %s', exc)
        except RuntimeError as exc:
            log.debug('MCPClient.disconnect() suppressed RuntimeError: %s', exc)
        except Exception as exc:
            log.debug('MCPClient.disconnect() error: %s', exc)

    async def __aenter__(self):
        """Return the client itself; the connection is opened by connect()."""
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Close any open connection when leaving an ``async with`` block.

        Exception info from the block is forwarded to the stored exit stack.
        This method returns None, so the block's exception is never
        suppressed.
        """
        exit_stack = self.exit_stack
        try:
            if exit_stack is not None:
                await exit_stack.__aexit__(exc_type, exc_value, traceback)
        finally:
            # Always clear client state, even if closing the stack raised; the
            # stack's callbacks have already been consumed by __aexit__.
            await self.disconnect()
|