Spaces:
Running
Running
| """ | |
| 工具註冊中心 | |
| 統一管理 MCP 工具的 OpenAI Function Calling Schema | |
| 2025 最佳實踐:讓 GPT 原生選擇工具,不需要自定義意圖檢測 Prompt | |
| 重構版本:整合 Pydantic Schema 自動生成 | |
| """ | |
| from typing import Dict, List, Any, Optional, Callable, Type | |
| from dataclasses import dataclass, field | |
| from core.logging import get_logger | |
| from core.tool_schema import ( | |
| ToolSchema, | |
| ToolMetadata, | |
| ToolSchemaRegistry, | |
| tool_schema_registry, | |
| extract_schema_from_mcp_tool, | |
| ) | |
| logger = get_logger("core.tool_registry") | |
| class ToolDefinition: | |
| """工具定義(向後兼容)""" | |
| name: str | |
| description: str | |
| parameters: Dict[str, Any] | |
| handler: Optional[Callable] = None | |
| category: str = "general" | |
| requires_auth: bool = False | |
| requires_location: bool = False | |
| keywords: List[str] = field(default_factory=list) | |
| examples: List[str] = field(default_factory=list) | |
| class ToolRegistry: | |
| """ | |
| 工具註冊中心(重構版) | |
| 功能: | |
| 1. 統一註冊所有 MCP 工具 | |
| 2. 自動從 MCPTool 類別生成 OpenAI Function Calling Schema | |
| 3. 支援工具分類和過濾 | |
| 4. 動態啟用/停用工具 | |
| 5. 整合 ToolSchemaRegistry 提供 Pydantic 支援 | |
| """ | |
| def __init__(self): | |
| self._tools: Dict[str, ToolDefinition] = {} | |
| self._disabled_tools: set = set() | |
| # 整合新的 Schema Registry | |
| self._schema_registry = tool_schema_registry | |
| def register( | |
| self, | |
| name: str, | |
| description: str, | |
| parameters: Dict[str, Any], | |
| handler: Optional[Callable] = None, | |
| category: str = "general", | |
| requires_auth: bool = False, | |
| requires_location: bool = False, | |
| keywords: Optional[List[str]] = None, | |
| examples: Optional[List[str]] = None, | |
| ) -> None: | |
| """註冊工具(向後兼容 + 自動同步到 Schema Registry)""" | |
| self._tools[name] = ToolDefinition( | |
| name=name, | |
| description=description, | |
| parameters=parameters, | |
| handler=handler, | |
| category=category, | |
| requires_auth=requires_auth, | |
| requires_location=requires_location, | |
| keywords=keywords or [], | |
| examples=examples or [], | |
| ) | |
| # 同步到 Schema Registry | |
| schema = ToolSchema( | |
| metadata=ToolMetadata( | |
| name=name, | |
| description=description, | |
| category=category, | |
| keywords=keywords or [], | |
| examples=examples or [], | |
| requires_location=requires_location, | |
| requires_auth=requires_auth, | |
| ), | |
| input_schema=parameters, | |
| handler=handler, | |
| ) | |
| self._schema_registry.register(schema) | |
| logger.debug(f"註冊工具: {name}") | |
| def register_mcp_tool(self, tool_class: Type) -> bool: | |
| """ | |
| 從 MCPTool 類別自動註冊工具 | |
| Args: | |
| tool_class: MCPTool 子類別 | |
| Returns: | |
| 是否註冊成功 | |
| """ | |
| schema = extract_schema_from_mcp_tool(tool_class) | |
| if not schema: | |
| return False | |
| # 註冊到 Schema Registry | |
| self._schema_registry.register(schema) | |
| # 同步到舊的 _tools(向後兼容) | |
| self._tools[schema.metadata.name] = ToolDefinition( | |
| name=schema.metadata.name, | |
| description=schema.metadata.description, | |
| parameters=schema.input_schema, | |
| handler=schema.handler, | |
| category=schema.metadata.category, | |
| requires_auth=schema.metadata.requires_auth, | |
| requires_location=schema.metadata.requires_location, | |
| keywords=schema.metadata.keywords, | |
| examples=schema.metadata.examples, | |
| ) | |
| logger.debug(f"從 MCPTool 註冊工具: {schema.metadata.name}") | |
| return True | |
| def unregister(self, name: str) -> bool: | |
| """取消註冊工具""" | |
| if name in self._tools: | |
| del self._tools[name] | |
| self._schema_registry.unregister(name) | |
| return True | |
| return False | |
| def disable(self, name: str) -> None: | |
| """停用工具""" | |
| self._disabled_tools.add(name) | |
| self._schema_registry.disable(name) | |
| def enable(self, name: str) -> None: | |
| """啟用工具""" | |
| self._disabled_tools.discard(name) | |
| self._schema_registry.enable(name) | |
| def get_tool(self, name: str) -> Optional[ToolDefinition]: | |
| """取得工具定義""" | |
| if name in self._disabled_tools: | |
| return None | |
| return self._tools.get(name) | |
| def get_openai_tools( | |
| self, | |
| categories: Optional[List[str]] = None, | |
| include_location_tools: bool = True, | |
| strict: bool = True, | |
| ) -> List[Dict[str, Any]]: | |
| """ | |
| 生成 OpenAI Function Calling 格式的工具列表 | |
| Args: | |
| categories: 只包含指定分類的工具 | |
| include_location_tools: 是否包含需要位置的工具 | |
| strict: 是否啟用 strict mode(確保輸出符合 schema) | |
| Returns: | |
| OpenAI tools 格式的列表 | |
| """ | |
| # 優先使用 Schema Registry(支援 strict mode) | |
| return self._schema_registry.get_openai_tools( | |
| categories=categories, | |
| include_location_tools=include_location_tools, | |
| strict=strict, | |
| ) | |
| def get_openai_tools_legacy( | |
| self, | |
| categories: Optional[List[str]] = None, | |
| include_location_tools: bool = True, | |
| ) -> List[Dict[str, Any]]: | |
| """ | |
| 生成 OpenAI Function Calling 格式的工具列表(舊版,不支援 strict mode) | |
| """ | |
| tools = [] | |
| for name, tool in self._tools.items(): | |
| # 跳過停用的工具 | |
| if name in self._disabled_tools: | |
| continue | |
| # 分類過濾 | |
| if categories and tool.category not in categories: | |
| continue | |
| # 位置過濾 | |
| if not include_location_tools and tool.requires_location: | |
| continue | |
| tools.append({ | |
| "type": "function", | |
| "function": { | |
| "name": tool.name, | |
| "description": tool.description, | |
| "parameters": tool.parameters, | |
| } | |
| }) | |
| return tools | |
| def get_tool_names(self) -> List[str]: | |
| """取得所有已註冊的工具名稱""" | |
| return [ | |
| name for name in self._tools.keys() | |
| if name not in self._disabled_tools | |
| ] | |
| def get_stats(self) -> Dict[str, Any]: | |
| """取得統計資訊""" | |
| return self._schema_registry.get_stats() | |
| def get_summaries(self) -> List[Dict[str, Any]]: | |
| """取得所有工具摘要(用於快速意圖匹配)""" | |
| return self._schema_registry.get_summaries() | |
| # 全域單例 | |
| tool_registry = ToolRegistry() | |
| def register_mcp_tools_to_registry(mcp_server) -> int: | |
| """ | |
| 從 MCP Server 自動註冊工具到 Registry | |
| 2025 重構版:優先使用 MCPTool 類別自動提取 Schema | |
| Args: | |
| mcp_server: MCPServer 實例 | |
| Returns: | |
| 註冊的工具數量 | |
| """ | |
| count = 0 | |
| for tool_name, tool in mcp_server.tools.items(): | |
| # 優先嘗試從 MCPTool 類別提取完整 Schema | |
| if hasattr(tool, 'handler') and hasattr(tool.handler, '__self__'): | |
| tool_class = tool.handler.__self__ | |
| if tool_registry.register_mcp_tool(type(tool_class)): | |
| count += 1 | |
| continue | |
| # 降級:使用舊方法註冊 | |
| description = getattr(tool, 'description', f'{tool_name} 工具') | |
| parameters = {"type": "object", "properties": {}, "required": []} | |
| if hasattr(tool, 'handler') and hasattr(tool.handler, '__self__'): | |
| tool_class = tool.handler.__self__ | |
| if hasattr(tool_class, 'get_input_schema'): | |
| try: | |
| parameters = tool_class.get_input_schema() | |
| except Exception as e: | |
| logger.warning(f"取得 {tool_name} schema 失敗: {e}") | |
| # 提取關鍵字和範例 | |
| keywords = [] | |
| examples = [] | |
| if hasattr(tool, 'handler') and hasattr(tool.handler, '__self__'): | |
| tool_class = tool.handler.__self__ | |
| keywords = getattr(tool_class, 'KEYWORDS', []) | |
| examples = getattr(tool_class, 'USAGE_TIPS', []) | |
| # 判斷分類 | |
| category = _infer_category(tool_name) | |
| # 判斷是否需要位置 | |
| requires_location = _requires_location(tool_name, parameters) | |
| tool_registry.register( | |
| name=tool_name, | |
| description=description, | |
| parameters=parameters, | |
| handler=getattr(tool, 'handler', None), | |
| category=category, | |
| requires_location=requires_location, | |
| keywords=keywords, | |
| examples=examples, | |
| ) | |
| count += 1 | |
| logger.info(f"從 MCP Server 註冊了 {count} 個工具") | |
| return count | |
| def _infer_category(tool_name: str) -> str: | |
| """推斷工具分類""" | |
| name_lower = tool_name.lower() | |
| if any(k in name_lower for k in ['weather', 'forecast']): | |
| return "weather" | |
| if any(k in name_lower for k in ['bus', 'train', 'metro', 'thsr', 'youbike', 'parking']): | |
| return "transportation" | |
| if any(k in name_lower for k in ['geocode', 'directions', 'location']): | |
| return "location" | |
| if any(k in name_lower for k in ['news']): | |
| return "information" | |
| if any(k in name_lower for k in ['exchange', 'currency']): | |
| return "finance" | |
| if any(k in name_lower for k in ['health', 'heart', 'sleep', 'step']): | |
| return "health" | |
| return "general" | |
| def _requires_location(tool_name: str, parameters: Dict) -> bool: | |
| """判斷工具是否需要位置資訊""" | |
| # 檢查參數中是否有 lat/lon | |
| props = parameters.get("properties", {}) | |
| if "lat" in props or "lon" in props or "latitude" in props or "longitude" in props: | |
| return True | |
| # 檢查工具名稱 | |
| location_tools = [ | |
| 'reverse_geocode', 'directions', 'tdx_bus_arrival', | |
| 'tdx_youbike', 'tdx_metro', 'tdx_parking', 'tdx_train', 'tdx_thsr' | |
| ] | |
| return tool_name in location_tools | |