Spaces:
Sleeping
Sleeping
File size: 4,128 Bytes
676582c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
"""
MCP Tool Registry
Manages registration and execution of MCP tools with user context injection.
Security: user_id is injected by the backend, never trusted from LLM output.
"""
import inspect
import logging
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
# Module-scoped logger named after this module, so its level and handlers can
# be configured independently through the standard logging hierarchy.
logger: logging.Logger = logging.getLogger(__name__)
@dataclass
class ToolDefinition:
    """Metadata describing a single MCP tool for LLM function calling.

    The field order below is also the positional constructor order, so it
    must not be rearranged.
    """

    # Unique tool identifier (e.g. "add_task").
    name: str
    # Human-readable explanation of the tool, shown to the LLM.
    description: str
    # JSON schema describing the arguments the tool accepts.
    parameters: Dict[str, Any]
@dataclass
class ToolExecutionResult:
    """Outcome of running an MCP tool.

    Only ``success`` is required; every other field defaults to ``None``.
    The field order is also the positional constructor order.
    """

    # Whether the tool call succeeded.
    success: bool
    # Optional structured payload produced by the tool.
    data: Optional[Dict[str, Any]] = None
    # Optional human-readable message about the outcome.
    message: Optional[str] = None
    # Optional error description, typically set when ``success`` is False.
    error: Optional[str] = None
class MCPToolRegistry:
    """
    Registry for MCP tools with user context injection.

    This class manages tool registration and execution, ensuring that
    ``user_id`` is always injected by the backend for security: any
    ``user_id`` key present in the LLM-supplied arguments is overwritten by
    the value the backend passes to ``execute_tool``.
    """

    def __init__(self):
        # Tool name -> handler invoked by execute_tool.
        self._tools: Dict[str, Callable] = {}
        # Tool name -> metadata exposed to the LLM via get_tool_definitions.
        self._tool_definitions: Dict[str, ToolDefinition] = {}

    def register_tool(
        self,
        name: str,
        description: str,
        parameters: Dict[str, Any],
        handler: Callable
    ) -> None:
        """
        Register an MCP tool with its handler function.

        Registering a name that is already present replaces the previous
        handler and definition; a warning is logged so accidental
        double-registration is visible.

        Args:
            name: Tool name (e.g., "add_task")
            description: Tool description for LLM
            parameters: JSON schema for tool parameters
            handler: Function that executes the tool. It is called with the
                LLM arguments as keyword arguments plus ``user_id``. It may be
                an async function or a plain function.
        """
        if name in self._tools:
            logger.warning("Overwriting already-registered MCP tool: %s", name)
        self._tools[name] = handler
        self._tool_definitions[name] = ToolDefinition(
            name=name,
            description=description,
            parameters=parameters,
        )
        logger.info("Registered MCP tool: %s", name)

    def get_tool_definitions(self) -> List[Dict[str, Any]]:
        """
        Get tool definitions in format suitable for LLM function calling.

        Returns:
            List of dicts with "name", "description" and "parameters" keys,
            in registration order. The "parameters" value is the same object
            that was passed to register_tool (not a copy).
        """
        return [
            {
                "name": tool_def.name,
                "description": tool_def.description,
                "parameters": tool_def.parameters,
            }
            for tool_def in self._tool_definitions.values()
        ]

    async def execute_tool(
        self,
        tool_name: str,
        arguments: Optional[Dict[str, Any]],
        user_id: int
    ) -> ToolExecutionResult:
        """
        Execute an MCP tool with user context injection.

        SECURITY: user_id is injected by the backend, never from LLM output.
        An LLM-supplied "user_id" argument is ignored (and logged).

        Args:
            tool_name: Name of the tool to execute
            arguments: Tool arguments from LLM; None is treated as no arguments
            user_id: User ID (injected by backend, not from LLM)

        Returns:
            Whatever the handler returns, unchanged (presumably a
            ToolExecutionResult — verify against registered handlers). A
            ToolExecutionResult with success=False is returned when the tool
            is unknown or the handler raises.
        """
        if tool_name not in self._tools:
            logger.error("Tool not found: %s", tool_name)
            return ToolExecutionResult(
                success=False,
                error=f"Tool '{tool_name}' not found"
            )

        try:
            # LLMs may emit null arguments for parameterless tools.
            llm_arguments = arguments or {}
            if "user_id" in llm_arguments:
                # Never trust an LLM-supplied identity; it is overwritten below.
                logger.warning(
                    "Ignoring LLM-supplied user_id for tool: %s", tool_name
                )
            # user_id is placed last so it always overrides any LLM value.
            arguments_with_context = {**llm_arguments, "user_id": user_id}
            logger.info("Executing tool: %s for user: %s", tool_name, user_id)

            handler = self._tools[tool_name]
            result = handler(**arguments_with_context)
            # Support both coroutine handlers and plain synchronous handlers.
            if inspect.isawaitable(result):
                result = await result
        except Exception as e:
            # Tool boundary: report failure to the caller instead of crashing
            # the conversation, but keep the traceback in the logs.
            logger.exception("Tool execution failed: %s - %s", tool_name, e)
            return ToolExecutionResult(
                success=False,
                error=f"Tool execution failed: {str(e)}"
            )

        # Handlers may report a failure without raising; don't log it as success.
        if getattr(result, "success", True):
            logger.info("Tool execution successful: %s", tool_name)
        else:
            logger.warning(
                "Tool reported failure: %s - %s",
                tool_name,
                getattr(result, "error", None),
            )
        return result

    def list_tools(self) -> List[str]:
        """Get list of registered tool names, in registration order."""
        return list(self._tools.keys())

    def has_tool(self, tool_name: str) -> bool:
        """Check if a tool is registered."""
        return tool_name in self._tools
# Module-level shared registry: every importer of this module sees the same
# MCPToolRegistry instance and therefore the same set of registered tools.
tool_registry: MCPToolRegistry = MCPToolRegistry()
|