import asyncio
import logging
from typing import Any, Awaitable, Callable, Dict, Optional, Sequence, Union

from .tool_models import ToolMetadata, ToolResult

logger = logging.getLogger(__name__)

# Async callback returning the environment context for a (possibly missing) user id.
EnvProvider = Callable[[Optional[str]], Awaitable[Dict[str, Any]]]
# Async callback (tool_name, message, payload, original_message) -> reformatted message.
ResultFormatter = Callable[[str, str, Dict[str, Any], str], Awaitable[str]]
# Async tool implementation; receives the prepared argument dict.
ToolHandler = Callable[[Dict[str, Any]], Awaitable[Any]]
# Synchronous fallback called with (arguments, last exception) once all attempts fail.
FailureHandler = Callable[[Dict[str, Any], Exception], ToolResult]

# Keys of a raw tool result that drive message/error handling and are therefore
# not copied into the ToolResult payload.
_RESERVED_RESULT_KEYS = frozenset({"success", "content", "error"})


class ToolCoordinator:
    """Single entry point for MCP tool invocations.

    Responsibilities:
    - inject environment values and defaults according to ``ToolMetadata``;
    - run special flows (currently navigation);
    - normalize every tool result into a ``ToolResult``.
    """

    def __init__(
        self,
        *,
        env_provider: EnvProvider,
        tool_lookup: Callable[[str], Optional[ToolHandler]],
        formatter: ResultFormatter,
        failure_handlers: Optional[Dict[str, FailureHandler]] = None,
        retry_delays: Sequence[float] = (1, 2),
        timeout: float = 30.0,
    ) -> None:
        """Create a coordinator.

        Args:
            env_provider: Async callback returning the environment context
                (e.g. ``lat``/``lon``/``label``) for a user id.
            tool_lookup: Returns the handler for a tool name, or ``None``.
            formatter: Async callback used to reformat result messages when
                ``metadata.enable_reformat`` is set.
            failure_handlers: Optional per-tool fallbacks invoked after all
                attempts of that tool have failed.
            retry_delays: Seconds to wait between consecutive attempts; the
                number of attempts is ``len(retry_delays) + 1``. The default
                gives 3 attempts with 1 s and 2 s pauses.
            timeout: Per-attempt timeout in seconds.
        """
        self._env_provider = env_provider
        self._tool_lookup = tool_lookup
        self._formatter = formatter
        self._metadata: Dict[str, ToolMetadata] = {}
        self._failure_handlers = failure_handlers or {}
        self._retry_delays = tuple(retry_delays)
        self._timeout = timeout

    # ------------------------------------------------------------------ #
    def register(self, metadata: ToolMetadata) -> None:
        """Register (or replace) the metadata for ``metadata.name``."""
        self._metadata[metadata.name] = metadata

    def get_metadata(self, name: str) -> Optional[ToolMetadata]:
        """Return the registered metadata for ``name``, or ``None``."""
        return self._metadata.get(name)

    # ------------------------------------------------------------------ #
    async def invoke(
        self,
        tool_name: str,
        arguments: Dict[str, Any],
        *,
        user_id: Optional[str],
        original_message: str,
    ) -> ToolResult:
        """Run ``tool_name`` with ``arguments`` and return a normalized result.

        Tools whose metadata has ``flow == "navigation"`` are routed to the
        navigation flow; all others get argument preparation, execution with
        retries, and result formatting.

        Raises:
            RuntimeError: if the tool has no handler, every attempt fails and no
                failure handler is registered, or the tool reports failure.
        """
        metadata = self._metadata.get(tool_name)
        if metadata is None:
            # Unregistered tools still run, with empty/default metadata.
            metadata = ToolMetadata(name=tool_name)

        if metadata.flow == "navigation":
            return await self._handle_navigation(arguments, user_id, original_message, metadata)

        prepared_args = await self._prepare_arguments(arguments, metadata, user_id)
        raw_result = await self._execute(tool_name, prepared_args)
        return await self._format_result(tool_name, raw_result, metadata, original_message)

    async def _prepare_arguments(
        self,
        arguments: Dict[str, Any],
        metadata: ToolMetadata,
        user_id: Optional[str],
    ) -> Dict[str, Any]:
        """Merge defaults, caller arguments, ``_user_id`` and environment values.

        Precedence: caller arguments override ``metadata.defaults``. Fields in
        ``metadata.requires_env`` that are still missing (or ``None``) are filled
        from the environment context, but only with non-``None`` values.
        """
        merged = dict(metadata.defaults or {})
        merged.update(arguments or {})
        # Expose the user id so tools can read it from their arguments.
        if user_id:
            merged["_user_id"] = user_id

        logger.info(
            "📦 [Coordinator] 準備參數: tool=%s, user_id=%s, requires_env=%s",
            metadata.name,
            user_id,
            metadata.requires_env,
        )

        # NOTE(review): env context and final args (which may include location
        # data) are logged at INFO level; consider DEBUG for privacy.
        if metadata.requires_env:
            if not user_id:
                # Only warn when the tool actually needs environment values.
                logger.warning("⚠️ [Coordinator] user_id 為 None,無法注入環境變數")
            else:
                env_ctx = await self._env_provider(user_id)
                logger.info("📦 [Coordinator] 環境資訊: %s", env_ctx)
                if env_ctx:
                    for field in metadata.requires_env:
                        if merged.get(field) is not None:
                            continue  # explicit caller/default value wins
                        env_value = env_ctx.get(field)
                        # Inject only non-None values so tool defaults are kept and
                        # schema validation is not tripped by explicit nulls.
                        if env_value is not None:
                            merged[field] = env_value
                            logger.info("📦 [Coordinator] 注入環境變數: %s=%s", field, env_value)

        logger.info("📦 [Coordinator] 最終參數: %s", merged)
        return merged

    async def _execute(
        self, tool_name: str, arguments: Dict[str, Any]
    ) -> Union[Dict[str, Any], ToolResult]:
        """Call the tool handler with timeout and retries.

        Returns the handler's dict result, or ``{"success": True, "content": str(result)}``
        for non-dict results. If every attempt fails and a failure handler is
        registered for the tool, its ``ToolResult`` is returned instead.

        Raises:
            RuntimeError: if no handler exists, or all attempts fail without a
                failure handler (chained to the last underlying exception).
        """
        handler = self._tool_lookup(tool_name)
        if not handler:
            raise RuntimeError(f"工具 {tool_name} 無可用 handler")

        delays = self._retry_delays
        attempts = len(delays) + 1
        last_exc: Optional[Exception] = None
        for attempt in range(1, attempts + 1):
            try:
                result = await asyncio.wait_for(handler(arguments), timeout=self._timeout)
                if isinstance(result, dict):
                    return result
                return {"success": True, "content": str(result)}
            except Exception as exc:  # noqa: BLE001 - any tool failure is retried
                last_exc = exc
                logger.warning("工具 %s 執行失敗 (attempt=%s): %s", tool_name, attempt, exc)
            # Back off only between attempts; never sleep after the final one.
            if attempt < attempts:
                await asyncio.sleep(delays[attempt - 1])

        fallback = self._failure_handlers.get(tool_name)
        if fallback is not None and last_exc is not None:
            return fallback(arguments, last_exc)
        raise RuntimeError(f"工具 {tool_name} 執行失敗:{last_exc}") from last_exc

    async def _format_result(
        self,
        tool_name: str,
        result: Union[Dict[str, Any], ToolResult],
        metadata: ToolMetadata,
        original_message: str,
    ) -> ToolResult:
        """Convert a raw tool result into a ``ToolResult``.

        A ``ToolResult`` (e.g. from a failure handler) is returned unchanged.
        Otherwise ``content`` becomes the message and the remaining keys become
        the payload. When ``metadata.enable_reformat`` is set the message is
        passed through the formatter, falling back to the original message if
        the formatter raises.

        Raises:
            RuntimeError: if the result is not marked ``success``.
        """
        if isinstance(result, ToolResult):
            return result

        if not result.get("success"):
            raise RuntimeError(result.get("error") or f"{tool_name} 執行失敗")
        content = result.get("content")
        message = str(content) if content else "操作完成,但無額外內容。"

        payload = {k: v for k, v in result.items() if k not in _RESERVED_RESULT_KEYS}
        if metadata.enable_reformat:
            try:
                message = await self._formatter(tool_name, message, payload, original_message)
            except Exception as exc:  # noqa: BLE001 - formatting is best-effort
                logger.warning("AI 格式化失敗,改用原訊息:%s", exc)

        return ToolResult(
            name=tool_name,
            message=message,
            data=payload or None,
            raw=result,
        )

    # ------------------------------------------------------------------ #
    async def _handle_navigation(
        self,
        arguments: Dict[str, Any],
        user_id: Optional[str],
        original_message: str,
        metadata: ToolMetadata,
    ) -> ToolResult:
        """Look up a destination, then request walking directions to it.

        1. Call the tool named by ``metadata`` to resolve the destination.
        2. If no coordinates come back, return the lookup's message.
        3. If the user's current position is unknown, return the lookup data
           with a note instead of directions.
        4. Otherwise call the ``directions`` tool (``foot-walking``) from the
           current position to the best match and format its result.

        Raises:
            RuntimeError: if the lookup reports failure, or the underlying
                tool execution / formatting raises.
        """
        args = arguments or {}
        # NOTE(review): unlike invoke(), the lookup call skips _prepare_arguments,
        # so metadata.defaults and _user_id are not applied here — confirm intended.
        geo_result = await self._execute(metadata.name, args)
        if isinstance(geo_result, ToolResult):
            # A failure handler already produced the final result.
            return geo_result
        if not geo_result.get("success"):
            raise RuntimeError(geo_result.get("error") or "地點查詢失敗")

        data = geo_result.get("data") or {}
        best_match = data.get("best_match") or {}
        dest_lat = best_match.get("lat")
        dest_lon = best_match.get("lon")
        if dest_lat is None or dest_lon is None:
            return ToolResult(
                name=metadata.name,
                message=str(geo_result.get("content") or "找不到合適的目的地"),
                data=data,
                raw=geo_result,
            )

        # A missing user id or an empty/None env context both mean "position unknown".
        env_ctx: Dict[str, Any] = {}
        if user_id:
            env_ctx = await self._env_provider(user_id) or {}
        origin_lat = env_ctx.get("lat")
        origin_lon = env_ctx.get("lon")
        if origin_lat is None or origin_lon is None:
            return ToolResult(
                name=metadata.name,
                message=str(geo_result.get("content") or "取得目的地座標成功"),
                data=data,
                raw=geo_result,
                metadata={"note": "缺少目前位置,僅返回地點資訊"},
            )

        directions_args = {
            "origin_lat": float(origin_lat),
            "origin_lon": float(origin_lon),
            "dest_lat": float(dest_lat),
            "dest_lon": float(dest_lon),
            "origin_label": env_ctx.get("label") or env_ctx.get("address_display") or "目前位置",
            "dest_label": best_match.get("label") or args.get("query"),
            "mode": "foot-walking",
        }
        directions_meta = self._metadata.get("directions")
        if directions_meta is None:
            directions_meta = ToolMetadata(name="directions")
        prepared = await self._prepare_arguments(directions_args, directions_meta, user_id)
        directions_result = await self._execute("directions", prepared)
        return await self._format_result("directions", directions_result, directions_meta, original_message)