Spaces:
Paused
Paused
| """ | |
| LiteLLM MCP Server Routes | |
| """ | |
| import asyncio | |
| from typing import Any, Dict, List, Optional, Union | |
| from anyio import BrokenResourceError | |
| from fastapi import APIRouter, Depends, HTTPException, Request | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import ConfigDict, ValidationError | |
| from litellm._logging import verbose_logger | |
| from litellm.constants import MCP_TOOL_NAME_PREFIX | |
| from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj | |
| from litellm.proxy._types import UserAPIKeyAuth | |
| from litellm.proxy.auth.user_api_key_auth import user_api_key_auth | |
| from litellm.types.mcp_server.mcp_server_manager import MCPInfo | |
| from litellm.types.utils import StandardLoggingMCPToolCall | |
| from litellm.utils import client | |
# Detect whether the optional "mcp" package can be imported.
# "mcp" needs Python 3.10+, but some litellm users still run Python 3.8,
# so a failed import is recorded in MCP_AVAILABLE instead of breaking
# the import of this module.
try:
    from mcp.server import Server
except ImportError as import_err:
    verbose_logger.debug(f"MCP module not found: {import_err}")
    MCP_AVAILABLE = False
else:
    MCP_AVAILABLE = True
# Router for all /mcp routes. It is defined outside the MCP_AVAILABLE guard,
# so `router` exists even when the "mcp" package could not be imported.
router = APIRouter(tags=["mcp"], prefix="/mcp")
| if MCP_AVAILABLE: | |
| from mcp.server import NotificationOptions, Server | |
| from mcp.server.models import InitializationOptions | |
| from mcp.types import EmbeddedResource as MCPEmbeddedResource | |
| from mcp.types import ImageContent as MCPImageContent | |
| from mcp.types import TextContent as MCPTextContent | |
| from mcp.types import Tool as MCPTool | |
| from .mcp_server_manager import global_mcp_server_manager | |
| from .sse_transport import SseServerTransport | |
| from .tool_registry import global_mcp_tool_registry | |
| ###################################################### | |
| ############ MCP Tools List REST API Response Object # | |
| # Defined here because we don't want to add `mcp` as a | |
| # required dependency for `litellm` pip package | |
| ###################################################### | |
class ListMCPToolsRestAPIResponseObject(MCPTool):
    """
    Object returned by the /tools/list REST API route.

    An MCP ``Tool`` extended with ``mcp_info``: metadata about the MCP
    server that the tool was listed from.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Info about the owning MCP server (defaults to None).
    mcp_info: Optional[MCPInfo] = None
| ######################################################## | |
| ############ Initialize the MCP Server ################# | |
| ######################################################## | |
# NOTE: `router` is already created at module level (outside the
# MCP_AVAILABLE guard) with the same prefix and tags. Re-creating it here
# would only rebind the name to an identical, still-empty router.

# The MCP server instance used by the SSE handlers in this module.
server: Server = Server("litellm-mcp-server")

# SSE transport. "/mcp/sse/messages" is the path handed to the transport for
# client message POSTs, which handle_messages forwards to
# sse.handle_post_message.
sse: SseServerTransport = SseServerTransport("/mcp/sse/messages")
| ######################################################## | |
| ############### MCP Server Routes ####################### | |
| ######################################################## | |
async def list_tools() -> list[MCPTool]:
    """
    Return every tool this MCP server exposes.

    Thin wrapper that delegates to ``_list_mcp_tools``.
    """
    available_tools = await _list_mcp_tools()
    return available_tools
async def _list_mcp_tools() -> List[MCPTool]:
    """
    Collect tools from the local tool registry and from managed MCP servers.

    Returns:
        List[MCPTool]: the local registry's tools, converted to ``MCPTool``,
        followed by whatever ``global_mcp_server_manager.list_tools()``
        returns (skipped if it returns None).
    """
    # Materialize the registry listing once. The same snapshot is then used
    # for the conversion and for the debug log, instead of querying the
    # registry a second time just to log it.
    registry_tools = list(global_mcp_tool_registry.list_tools())
    tools: List[MCPTool] = [
        MCPTool(
            name=tool.name,
            description=tool.description,
            inputSchema=tool.input_schema,
        )
        for tool in registry_tools
    ]
    verbose_logger.debug("GLOBAL MCP TOOLS: %s", registry_tools)

    sse_tools: Optional[List[MCPTool]] = await global_mcp_server_manager.list_tools()
    verbose_logger.debug("SSE TOOLS: %s", sse_tools)
    if sse_tools is not None:
        tools.extend(sse_tools)
    return tools
async def mcp_server_tool_call(
    name: str, arguments: Dict[str, Any] | None
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
    """
    Execute the MCP tool ``name`` with ``arguments``.

    All work, including argument checking and dispatch, is done by
    ``call_mcp_tool``.

    Args:
        name (str): Name of the tool to call.
        arguments (Dict[str, Any] | None): Arguments passed to the tool.

    Returns:
        List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
        Content produced by the tool.

    Raises:
        HTTPException: propagated from ``call_mcp_tool``, e.g. 400 when
        ``arguments`` is None, or 404 when a local tool is not found.
    """
    return await call_mcp_tool(name=name, arguments=arguments)
async def call_mcp_tool(
    name: str, arguments: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
    """
    Call the MCP tool ``name`` with ``arguments`` and return its content.

    Names present in ``global_mcp_server_manager.tool_name_to_mcp_server_name_mapping``
    run on their managed MCP server. Every other name is resolved through
    the local tool registry.

    Args:
        name: Tool name.
        arguments: Tool arguments; must not be None.
        **kwargs: Extra request data. Only ``litellm_logging_obj`` is read.
            When it is present, its ``model_call_details`` are annotated with
            the tool-call metadata, a ``model`` label and the
            ``custom_llm_provider`` (the MCP server name).
            NOTE(review): presumably injected by a wrapper such as
            ``litellm.utils.client``; confirm.

    Raises:
        HTTPException: 400 if ``arguments`` is None; 404 from the local
        registry path if the tool does not exist.
    """
    if arguments is None:
        raise HTTPException(
            status_code=400, detail="Request arguments are required"
        )

    tool_call_log: StandardLoggingMCPToolCall = _get_standard_logging_mcp_tool_call(
        name=name,
        arguments=arguments,
    )

    logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get("litellm_logging_obj")
    if logging_obj:
        call_details = logging_obj.model_call_details
        call_details["mcp_tool_call_metadata"] = tool_call_log
        call_details["model"] = (
            f"{MCP_TOOL_NAME_PREFIX}: {tool_call_log.get('name') or ''}"
        )
        call_details["custom_llm_provider"] = tool_call_log.get("mcp_server_name")

    # Managed-server tools take precedence; the local registry is the fallback.
    is_managed_tool = (
        name in global_mcp_server_manager.tool_name_to_mcp_server_name_mapping
    )
    if is_managed_tool:
        return await _handle_managed_mcp_tool(name, arguments)
    return await _handle_local_mcp_tool(name, arguments)
def _get_standard_logging_mcp_tool_call(
    name: str,
    arguments: Dict[str, Any],
) -> StandardLoggingMCPToolCall:
    """
    Build the standard logging payload for an MCP tool call.

    If ``name`` maps to a managed MCP server, the payload also carries that
    server's ``server_name`` and ``logo_url`` from its ``mcp_info``.
    """
    mcp_server = global_mcp_server_manager._get_mcp_server_from_tool_name(name)
    if not mcp_server:
        return StandardLoggingMCPToolCall(name=name, arguments=arguments)

    # mcp_info may be unset on the server; treat that as empty.
    server_info = mcp_server.mcp_info or {}
    return StandardLoggingMCPToolCall(
        name=name,
        arguments=arguments,
        mcp_server_name=server_info.get("server_name"),
        mcp_server_logo_url=server_info.get("logo_url"),
    )
async def _handle_managed_mcp_tool(
    name: str, arguments: Dict[str, Any]
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
    """
    Run ``name`` through ``global_mcp_server_manager.call_tool``.

    Returns the ``content`` attribute of the call result.
    """
    result = await global_mcp_server_manager.call_tool(name=name, arguments=arguments)
    verbose_logger.debug("CALL TOOL RESULT: %s", result)
    return result.content
async def _handle_local_mcp_tool(
    name: str, arguments: Dict[str, Any]
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
    """
    Execute a tool from the local tool registry.

    The tool's handler is called with ``arguments`` as keyword arguments.
    Handlers may be plain functions or return an awaitable (e.g. ``async def``
    handlers); an awaitable result is awaited before being stringified.

    Returns:
        A one-element list containing ``MCPTextContent`` with ``str(result)``.
        If the handler raises, the error is logged and returned as
        ``"Error: <message>"`` text instead of being propagated.

    Raises:
        HTTPException: 404 if no tool named ``name`` is registered.
    """
    import inspect

    tool = global_mcp_tool_registry.get_tool(name)
    if not tool:
        raise HTTPException(status_code=404, detail=f"Tool '{name}' not found")
    try:
        result = tool.handler(**arguments)
        # An async handler returns a coroutine. Without awaiting it, the tool
        # body never runs and the caller receives "<coroutine object ...>".
        if inspect.isawaitable(result):
            result = await result
        return [MCPTextContent(text=str(result), type="text")]
    except Exception as e:
        # Best effort: tool failures are reported to the caller as content,
        # but are logged so they are not silently lost.
        verbose_logger.exception("Error calling local MCP tool '%s': %s", name, e)
        return [MCPTextContent(text=f"Error: {str(e)}", type="text")]
async def handle_sse(request: Request):
    """
    Serve one MCP session over Server-Sent Events.

    Opens the SSE streams for ``request`` and runs the MCP ``server`` on them
    with the module-level ``options``. The request is closed afterwards.

    ``BrokenResourceError`` and ``asyncio.CancelledError`` (presumably a
    dropped client connection) and pydantic ``ValidationError`` end the
    session without raising. They are logged at debug level. Any other
    exception propagates.
    """
    verbose_logger.info("new incoming SSE connection established")
    async with sse.connect_sse(request) as streams:
        read_stream, write_stream = streams[0], streams[1]
        try:
            await server.run(read_stream, write_stream, options)
        except (BrokenResourceError, asyncio.CancelledError) as e:
            # NOTE(review): CancelledError is deliberately not re-raised here,
            # matching the previous best-effort behavior for closed sessions.
            verbose_logger.debug("SSE session closed: %r", e)
        except ValidationError as e:
            verbose_logger.debug("SSE session ended on invalid MCP message: %s", e)
    await request.close()
async def handle_messages(request: Request):
    """
    Hand an incoming SSE client message to ``sse.handle_post_message`` using
    the request's raw ASGI scope/receive/send, then close the request.
    """
    verbose_logger.info("incoming SSE message received")
    # NOTE(review): `_send` is a private Starlette attribute; the transport
    # needs the raw ASGI send callable.
    scope, receive, send = request.scope, request.receive, request._send
    await sse.handle_post_message(scope, receive, send)
    await request.close()
| ######################################################## | |
| ############ MCP Server REST API Routes ################# | |
| ######################################################## | |
async def list_tool_rest_api() -> List[ListMCPToolsRestAPIResponseObject]:
    """
    List all available tools with information about the server they belong to.

    Servers whose tools cannot be fetched are logged and skipped, so one
    failing server does not fail the whole listing.

    Example response:
        Tools:
        [
            {
                "name": "create_zap",
                "description": "Create a new zap",
                "inputSchema": "tool_input_schema",
                "mcp_info": {
                    "server_name": "zapier",
                    "logo_url": "https://www.zapier.com/logo.png"
                }
            },
            {
                "name": "fetch_data",
                "description": "Fetch data from a URL",
                "inputSchema": "tool_input_schema",
                "mcp_info": {
                    "server_name": "fetch",
                    "logo_url": "https://www.fetch.com/logo.png"
                }
            }
        ]
    """
    list_tools_result: List[ListMCPToolsRestAPIResponseObject] = []
    # Named `mcp_server` so it does not shadow the module-level MCP `server`.
    for mcp_server in global_mcp_server_manager.mcp_servers:
        try:
            tools = await global_mcp_server_manager._get_tools_from_server(mcp_server)
            for tool in tools:
                list_tools_result.append(
                    ListMCPToolsRestAPIResponseObject(
                        name=tool.name,
                        description=tool.description,
                        inputSchema=tool.inputSchema,
                        mcp_info=mcp_server.mcp_info,
                    )
                )
        except Exception as e:
            verbose_logger.exception(
                "Error getting tools from %s: %s", mcp_server.name, e
            )
            continue
    return list_tools_result
async def call_tool_rest_api(
    request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    REST API to call a specific MCP tool with the provided arguments.

    The JSON body must be an object containing at least ``name``;
    ``call_mcp_tool`` additionally requires a non-null ``arguments``. The
    body is enriched via ``add_litellm_data_to_request``, and its keys are
    passed to ``call_mcp_tool`` as keyword arguments.

    Raises:
        HTTPException: 400 if the body is not valid JSON, is not a JSON
        object, or lacks ``name``; otherwise whatever ``call_mcp_tool``
        raises.
    """
    from litellm.proxy.proxy_server import add_litellm_data_to_request, proxy_config

    try:
        data = await request.json()
    except ValueError as e:  # json.JSONDecodeError / UnicodeDecodeError
        raise HTTPException(status_code=400, detail=f"Invalid JSON body: {e}") from e

    # The body is splatted into call_mcp_tool(**data). Reject shapes that would
    # otherwise fail there with a TypeError and surface as a 500.
    if not isinstance(data, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )
    if "name" not in data:
        raise HTTPException(
            status_code=400, detail="Request body must include the tool 'name'"
        )

    data = await add_litellm_data_to_request(
        data=data,
        request=request,
        user_api_key_dict=user_api_key_dict,
        proxy_config=proxy_config,
    )
    return await call_mcp_tool(**data)
# Initialization options passed to server.run() for every SSE session
# (see handle_sse): server identity plus the capabilities reported by the
# MCP server, with default notification options and no experimental
# capabilities.
options = InitializationOptions(
    capabilities=server.get_capabilities(
        experimental_capabilities={},
        notification_options=NotificationOptions(),
    ),
    server_name="litellm-mcp-server",
    server_version="0.1.0",
)