Spaces:
Sleeping
Sleeping
File size: 8,498 Bytes
7d7bc87 921a78a 7d7bc87 921a78a 7d7bc87 921a78a 9746ac7 921a78a 7d7bc87 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 | import asyncio
import logging
from typing import Any, Awaitable, Callable, Dict, Optional
from .tool_models import ToolMetadata, ToolResult
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Private shorthand for the string-keyed dictionaries exchanged with tools.
_StrKeyedDict = Dict[str, Any]

# Async callable: user_id (possibly None) -> environment context dict for that user.
EnvProvider = Callable[[Optional[str]], Awaitable[_StrKeyedDict]]
# Async callable: (tool_name, message, payload, original_message) -> rewritten message.
ResultFormatter = Callable[[str, str, _StrKeyedDict, str], Awaitable[str]]
# Async callable invoked with a tool's argument dict; returns the tool's raw result.
ToolHandler = Callable[[_StrKeyedDict], Awaitable[Any]]
class ToolCoordinator:
    """
    Central coordinator for MCP tool invocations.

    - Injects defaults and user-environment values according to each tool's ToolMetadata.
    - Handles special flows (currently "navigation": place lookup followed by directions).
    - Normalizes raw tool output into a ToolResult.
    """

    # Seconds to wait between consecutive attempts; len(_RETRY_DELAYS) + 1 attempts in total.
    _RETRY_DELAYS = (1, 2)
    # Per-attempt timeout (seconds) applied to each tool handler call.
    _TIMEOUT = 30.0

    def __init__(
        self,
        *,
        env_provider: EnvProvider,
        tool_lookup: Callable[[str], Optional[ToolHandler]],
        formatter: ResultFormatter,
        failure_handlers: Optional[Dict[str, Callable[[Dict[str, Any], Exception], ToolResult]]] = None,
    ) -> None:
        """
        Args:
            env_provider: async callable returning the environment context for a user_id.
            tool_lookup: returns the async handler for a tool name, or None if unavailable.
            formatter: async callable used to rewrite a result message when the tool's
                metadata has ``enable_reformat`` set.
            failure_handlers: optional per-tool fallbacks, called with
                (arguments, last exception) after every attempt has failed.
        """
        self._env_provider = env_provider
        self._tool_lookup = tool_lookup
        self._formatter = formatter
        self._metadata: Dict[str, ToolMetadata] = {}
        self._failure_handlers = failure_handlers or {}

    # ------------------------------------------------------------------ #
    def register(self, metadata: ToolMetadata) -> None:
        """Register (or replace) the metadata for ``metadata.name``."""
        self._metadata[metadata.name] = metadata

    def get_metadata(self, name: str) -> Optional[ToolMetadata]:
        """Return the registered metadata for ``name``, or None."""
        return self._metadata.get(name)

    # ------------------------------------------------------------------ #
    async def invoke(
        self,
        tool_name: str,
        arguments: Dict[str, Any],
        *,
        user_id: Optional[str],
        original_message: str,
    ) -> ToolResult:
        """
        Run ``tool_name`` with ``arguments`` and return a formatted ToolResult.

        Tools without registered metadata get a bare ``ToolMetadata(name=tool_name)``.
        Tools whose metadata has ``flow == "navigation"`` go through the navigation
        flow instead of the generic prepare -> execute -> format path.

        Raises:
            RuntimeError: if the tool has no handler, keeps failing and has no failure
                handler, or returns a result whose ``success`` is falsy.
        """
        metadata = self._metadata.get(tool_name)
        if metadata is None:
            metadata = ToolMetadata(name=tool_name)
        if metadata.flow == "navigation":
            return await self._handle_navigation(arguments, user_id, original_message, metadata)
        prepared_args = await self._prepare_arguments(arguments, metadata, user_id)
        raw_result = await self._execute(tool_name, prepared_args)
        return await self._format_result(tool_name, raw_result, metadata, original_message)

    async def _prepare_arguments(
        self,
        arguments: Dict[str, Any],
        metadata: ToolMetadata,
        user_id: Optional[str],
    ) -> Dict[str, Any]:
        """
        Build the final argument dict for a tool call.

        ``metadata.defaults`` are overridden by the caller's ``arguments``. ``user_id``
        (when truthy) is exposed to the tool as ``_user_id``. Each field listed in
        ``metadata.requires_env`` that is still None/missing is then filled from the
        user's environment when a non-None value is available.
        """
        merged = dict(metadata.defaults)
        merged.update(arguments or {})
        # Let the tool read the caller's user_id from its arguments.
        if user_id:
            merged["_user_id"] = user_id
        logger.info(
            "📦 [Coordinator] 準備參數: tool=%s, user_id=%s, requires_env=%s",
            metadata.name,
            user_id,
            metadata.requires_env,
        )
        if metadata.requires_env:
            if user_id:
                await self._inject_env(merged, metadata, user_id)
            else:
                # Only worth a warning when the tool actually needed environment values.
                logger.warning("⚠️ [Coordinator] user_id 為 None,無法注入環境變數")
        logger.info("📦 [Coordinator] 最終參數: %s", merged)
        return merged

    async def _inject_env(
        self,
        merged: Dict[str, Any],
        metadata: ToolMetadata,
        user_id: str,
    ) -> None:
        """Fill still-empty ``requires_env`` fields of ``merged`` in place from the user's environment."""
        env_ctx = await self._env_provider(user_id)
        logger.info("📦 [Coordinator] 環境資訊: %s", env_ctx)
        if not env_ctx:
            return
        for field in metadata.requires_env:
            # Values supplied by the caller or by defaults take precedence.
            if merged.get(field) is not None:
                continue
            env_value = self._lookup_env_value(field, env_ctx, metadata)
            # Only inject non-None values, so tool defaults are not overridden and
            # schema validation is not tripped by explicit nulls.
            if env_value is not None:
                merged[field] = env_value
                logger.info("📦 [Coordinator] 注入環境變數: %s=%s", field, env_value)

    @staticmethod
    def _lookup_env_value(field: str, env_ctx: Dict[str, Any], metadata: ToolMetadata) -> Any:
        """Return ``env_ctx[field]``, else the first non-None value among its fallback keys, else None."""
        value = env_ctx.get(field)
        if value is not None:
            return value
        for fallback_key in metadata.env_fallbacks.get(field) or ():
            value = env_ctx.get(fallback_key)
            if value is not None:
                logger.info("📦 [Coordinator] 使用 fallback 注入: %s ← %s=%s", field, fallback_key, value)
                return value
        return None

    async def _execute(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """
        Call the tool's handler with a per-attempt timeout and retries.

        Returns:
            The handler's dict result as-is; a non-dict result wrapped as
            ``{"success": True, "content": str(result)}``; or, if every attempt
            failed and a failure handler is registered for the tool, that handler's
            return value (typed as ToolResult).

        Raises:
            RuntimeError: if no handler is found, or every attempt failed and no
                failure handler is registered (chained to the last exception).
        """
        handler = self._tool_lookup(tool_name)
        if not handler:
            raise RuntimeError(f"工具 {tool_name} 無可用 handler")

        delays = self._RETRY_DELAYS
        max_attempts = len(delays) + 1
        last_exc: Optional[Exception] = None
        for attempt in range(1, max_attempts + 1):
            try:
                result = await asyncio.wait_for(handler(arguments), timeout=self._TIMEOUT)
            except Exception as exc:  # noqa: BLE001 - any handler failure is retried
                last_exc = exc
                # %r keeps exceptions with an empty message (e.g. timeouts) identifiable.
                logger.warning("工具 %s 執行失敗 (attempt=%s): %r", tool_name, attempt, exc)
                # Back off only *between* attempts; sleeping after the last one only delays the error.
                if attempt < max_attempts:
                    await asyncio.sleep(delays[attempt - 1])
                continue
            if isinstance(result, dict):
                return result
            return {"success": True, "content": str(result)}

        failure_handler = self._failure_handlers.get(tool_name)
        if failure_handler and last_exc is not None:
            return failure_handler(arguments, last_exc)
        raise RuntimeError(f"工具 {tool_name} 執行失敗:{last_exc}") from last_exc

    async def _format_result(
        self,
        tool_name: str,
        result: Any,
        metadata: ToolMetadata,
        original_message: str,
    ) -> ToolResult:
        """
        Convert a raw tool result into a ToolResult.

        A ToolResult (e.g. produced by a failure handler) is returned unchanged.
        For a dict, ``content`` becomes the message (a generic completion message
        when it is empty). Every key other than success/content/error is kept as
        ``data``. When ``metadata.enable_reformat`` is set, the message is rewritten
        by the formatter. Formatter errors are logged and the unformatted message
        is kept.

        Raises:
            RuntimeError: if ``success`` is falsy (message taken from ``error`` when present).
        """
        if isinstance(result, ToolResult):
            return result
        if not result.get("success"):
            raise RuntimeError(result.get("error") or f"{tool_name} 執行失敗")
        content = result.get("content")
        message = str(content) if content else "操作完成,但無額外內容。"

        payload = {k: v for k, v in result.items() if k not in {"success", "content", "error"}}
        if metadata.enable_reformat:
            try:
                message = await self._formatter(tool_name, message, payload, original_message)
            except Exception as exc:  # noqa: BLE001 - formatting is best-effort
                logger.warning("AI 格式化失敗,改用原訊息:%s", exc)
        return ToolResult(
            name=tool_name,
            message=message,
            data=payload or None,
            raw=result,
        )

    # ------------------------------------------------------------------ #
    async def _handle_navigation(
        self,
        arguments: Dict[str, Any],
        user_id: Optional[str],
        original_message: str,
        metadata: ToolMetadata,
    ) -> ToolResult:
        """
        Navigation flow: look up the destination, then request directions to it.

        1. Run ``metadata.name`` with the raw arguments (no defaults/env injection)
           and read ``data.best_match.lat/lon`` from its result.
        2. If the destination has no coordinates, or the user's environment has no
           ``lat``/``lon``, return the lookup result as a ToolResult.
        3. Otherwise run the ``directions`` tool ("foot-walking" mode) from the
           current location to the match and return its formatted result.

        Raises:
            RuntimeError: if the lookup reports failure, or a tool call fails for good.
        """
        # Callers may pass None; normalize once so every later access is safe.
        arguments = arguments or {}
        geo_result = await self._execute(metadata.name, arguments)
        # A failure handler may already have produced the final ToolResult.
        if isinstance(geo_result, ToolResult):
            return geo_result
        if not geo_result.get("success"):
            raise RuntimeError(geo_result.get("error") or "地點查詢失敗")

        data = geo_result.get("data") or {}
        best_match = data.get("best_match") or {}
        dest_lat = best_match.get("lat")
        dest_lon = best_match.get("lon")
        if dest_lat is None or dest_lon is None:
            return ToolResult(
                name=metadata.name,
                message=str(geo_result.get("content") or "找不到合適的目的地"),
                data=data,
                raw=geo_result,
            )

        # Treat a None/empty environment the same as "current location unknown".
        env_ctx = (await self._env_provider(user_id) if user_id else None) or {}
        origin_lat = env_ctx.get("lat")
        origin_lon = env_ctx.get("lon")
        if origin_lat is None or origin_lon is None:
            return ToolResult(
                name=metadata.name,
                message=str(geo_result.get("content") or "取得目的地座標成功"),
                data=data,
                raw=geo_result,
                metadata={"note": "缺少目前位置,僅返回地點資訊"},
            )

        directions_args = {
            "origin_lat": float(origin_lat),
            "origin_lon": float(origin_lon),
            "dest_lat": float(dest_lat),
            "dest_lon": float(dest_lon),
            "origin_label": env_ctx.get("label") or env_ctx.get("address_display") or "目前位置",
            "dest_label": best_match.get("label") or arguments.get("query"),
            "mode": "foot-walking",
        }
        directions_meta = self._metadata.get("directions", ToolMetadata(name="directions"))
        prepared = await self._prepare_arguments(directions_args, directions_meta, user_id)
        directions_result = await self._execute("directions", prepared)
        return await self._format_result("directions", directions_result, directions_meta, original_message)
|