File size: 8,498 Bytes
7d7bc87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
921a78a
 
 
 
 
 
7d7bc87
 
 
921a78a
7d7bc87
 
 
 
921a78a
9746ac7
 
 
 
 
 
 
921a78a
 
 
 
 
 
 
 
7d7bc87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import asyncio
import logging
from typing import Any, Awaitable, Callable, Dict, Optional

from .tool_models import ToolMetadata, ToolResult

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Async tool handler: awaited with the prepared argument dict, returns the raw tool output.
ToolHandler = Callable[[Dict[str, Any]], Awaitable[Any]]

# Async environment lookup: awaited with a user id, returns that user's context dict.
EnvProvider = Callable[[Optional[str]], Awaitable[Dict[str, Any]]]

# Async message re-formatter, called as (tool_name, message, payload, original_message)
# and returning the rewritten message text.
ResultFormatter = Callable[[str, str, Dict[str, Any], str], Awaitable[str]]


class ToolCoordinator:
    """
    Central dispatcher for MCP tool calls:
    - injects environment values / defaults according to ToolMetadata
    - runs special flows (navigation)
    - normalizes every outcome into a ToolResult
    """

    # Back-off (seconds) after each failed attempt; one attempt per entry.
    # No sleep follows the final attempt. Subclasses may override.
    _RETRY_DELAYS = (1, 2, 5)
    # Per-attempt timeout (seconds) passed to asyncio.wait_for.
    _TOOL_TIMEOUT = 30.0

    def __init__(
        self,
        *,
        env_provider: EnvProvider,
        tool_lookup: Callable[[str], Optional[ToolHandler]],
        formatter: ResultFormatter,
        failure_handlers: Optional[Dict[str, Callable[[Dict[str, Any], Exception], ToolResult]]] = None,
    ) -> None:
        """
        Args:
            env_provider: async callable returning the environment context for a user id.
            tool_lookup: maps a tool name to its async handler, or None if unknown.
            formatter: async callable that may rewrite a successful result message.
            failure_handlers: optional per-tool fallbacks, called with
                (arguments, last_exception) once every retry has failed.
        """
        self._env_provider = env_provider
        self._tool_lookup = tool_lookup
        self._formatter = formatter
        self._metadata: Dict[str, ToolMetadata] = {}
        self._failure_handlers = failure_handlers or {}

    # ------------------------------------------------------------------ #
    def register(self, metadata: ToolMetadata) -> None:
        """Register (or replace) the metadata for ``metadata.name``."""
        self._metadata[metadata.name] = metadata

    def get_metadata(self, name: str) -> Optional[ToolMetadata]:
        """Return the registered metadata for *name*, or None."""
        return self._metadata.get(name)

    def _metadata_for(self, name: str) -> ToolMetadata:
        """Return the registered metadata for *name*, or a bare ToolMetadata default."""
        metadata = self._metadata.get(name)
        return metadata if metadata is not None else ToolMetadata(name=name)

    # ------------------------------------------------------------------ #
    async def invoke(
        self,
        tool_name: str,
        arguments: Dict[str, Any],
        *,
        user_id: Optional[str],
        original_message: str,
    ) -> ToolResult:
        """
        Run *tool_name* and return a normalized ToolResult.

        Tools whose metadata has ``flow == "navigation"`` go through the
        navigation flow. All others get defaults/env injected, are executed
        with retries, and have their result formatted.

        Raises:
            RuntimeError: no handler exists, every retry failed and no failure
                handler is registered, or the tool reported a falsy ``success``.
            Exceptions raised by ``env_provider`` propagate unchanged.
        """
        metadata = self._metadata_for(tool_name)

        if metadata.flow == "navigation":
            return await self._handle_navigation(arguments, user_id, original_message, metadata)

        prepared_args = await self._prepare_arguments(arguments, metadata, user_id)
        raw_result = await self._execute(tool_name, prepared_args)
        return await self._format_result(tool_name, raw_result, metadata, original_message)

    async def _prepare_arguments(
        self,
        arguments: Dict[str, Any],
        metadata: ToolMetadata,
        user_id: Optional[str],
    ) -> Dict[str, Any]:
        """
        Build the final argument dict for a tool call.

        Explicit ``arguments`` override ``metadata.defaults``. Each field in
        ``metadata.requires_env`` is filled from the user's environment context
        only if it is still missing/None after that merge. A truthy
        ``user_id`` is also passed through as ``_user_id``.
        """
        merged = dict(metadata.defaults)
        merged.update(arguments or {})

        # Expose user_id so the tool can read it from its arguments.
        if user_id:
            merged["_user_id"] = user_id

        logger.info(
            "📦 [Coordinator] 準備參數: tool=%s, user_id=%s, requires_env=%s",
            metadata.name,
            user_id,
            metadata.requires_env,
        )

        if metadata.requires_env and user_id:
            env_ctx = await self._env_provider(user_id)
            logger.info("📦 [Coordinator] 環境資訊: %s", env_ctx)
            if env_ctx:
                for field in metadata.requires_env:
                    if merged.get(field) is not None:
                        continue  # caller or defaults already supplied a value
                    env_value = self._lookup_env_value(field, env_ctx, metadata)
                    # Inject only non-None values, so the tool's own defaults are
                    # not overridden and schema validation is not tripped.
                    if env_value is not None:
                        merged[field] = env_value
                        logger.info("📦 [Coordinator] 注入環境變數: %s=%s", field, env_value)
        elif metadata.requires_env:
            # Env injection was requested, but there is no user to look it up for.
            logger.warning("⚠️ [Coordinator] user_id 為 None,無法注入環境變數")

        logger.info("📦 [Coordinator] 最終參數: %s", merged)
        return merged

    @staticmethod
    def _lookup_env_value(field: str, env_ctx: Dict[str, Any], metadata: ToolMetadata) -> Any:
        """Return env_ctx[field]; if that is None, try the keys in
        ``metadata.env_fallbacks[field]`` in order. Returns None if nothing is set."""
        env_value = env_ctx.get(field)
        if env_value is None and metadata.env_fallbacks.get(field):
            for fallback_key in metadata.env_fallbacks[field]:
                env_value = env_ctx.get(fallback_key)
                if env_value is not None:
                    logger.info(
                        "📦 [Coordinator] 使用 fallback 注入: %s%s=%s",
                        field,
                        fallback_key,
                        env_value,
                    )
                    break
        return env_value

    async def _execute(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """
        Call the tool handler with a per-attempt timeout and retries.

        Returns:
            The handler's dict result as-is. Any other result is wrapped as
            ``{"success": True, "content": str(result)}``. If every attempt
            fails and a failure handler is registered for the tool, that
            handler's return value (typed as ToolResult) is returned instead.

        Raises:
            RuntimeError: no handler is found, or all attempts failed and no
                failure handler is registered. The last error is chained as the cause.
        """
        tool_handler = self._tool_lookup(tool_name)
        if not tool_handler:
            raise RuntimeError(f"工具 {tool_name} 無可用 handler")

        delays = tuple(self._RETRY_DELAYS) or (0,)  # always make at least one attempt
        last_exc: Optional[Exception] = None
        for attempt, delay in enumerate(delays, start=1):
            try:
                result = await asyncio.wait_for(tool_handler(arguments), timeout=self._TOOL_TIMEOUT)
                if isinstance(result, dict):
                    return result
                return {"success": True, "content": str(result)}
            except Exception as exc:  # noqa: BLE001 - any tool error is treated as retryable
                last_exc = exc
                logger.warning("工具 %s 執行失敗 (attempt=%s): %s", tool_name, attempt, exc)
                # Back off only when another attempt follows. Sleeping after the
                # final failure would just delay the fallback or error.
                if attempt < len(delays):
                    await asyncio.sleep(delay)

        failure_handler = self._failure_handlers.get(tool_name)
        if failure_handler and last_exc:
            return failure_handler(arguments, last_exc)
        raise RuntimeError(f"工具 {tool_name} 執行失敗:{last_exc}") from last_exc

    async def _format_result(
        self,
        tool_name: str,
        result: Any,
        metadata: ToolMetadata,
        original_message: str,
    ) -> ToolResult:
        """
        Convert a raw tool result into a ToolResult.

        A ToolResult (e.g. from a failure handler) is returned unchanged. For a
        dict, ``content`` becomes the message, and every key other than
        ``success``/``content``/``error`` goes into ``data``. When
        ``metadata.enable_reformat`` is set, the message is passed through the
        formatter; formatter errors are logged and the raw message is kept.

        Raises:
            RuntimeError: the dict's ``success`` is falsy (message taken from ``error``).
        """
        if isinstance(result, ToolResult):
            return result

        if not result.get("success"):
            raise RuntimeError(result.get("error") or f"{tool_name} 執行失敗")

        content = result.get("content")
        message = str(content) if content else "操作完成,但無額外內容。"

        payload = {k: v for k, v in result.items() if k not in {"success", "content", "error"}}

        if metadata.enable_reformat:
            try:
                message = await self._formatter(tool_name, message, payload, original_message)
            except Exception as exc:  # noqa: BLE001 - formatting is best-effort
                logger.warning("AI 格式化失敗,改用原訊息:%s", exc)

        return ToolResult(
            name=tool_name,
            message=message,
            data=payload or None,
            raw=result,
        )

    # ------------------------------------------------------------------ #
    async def _handle_navigation(
        self,
        arguments: Dict[str, Any],
        user_id: Optional[str],
        original_message: str,
        metadata: ToolMetadata,
    ) -> ToolResult:
        """
        Geocode the destination with the navigation tool, then chain into the
        ``directions`` tool when the user's current position is known.

        Returns the geocode result alone in two cases: the best match has no
        coordinates, or the user's origin coordinates are unavailable.

        Raises:
            RuntimeError: the geocode step reported a falsy ``success``, or any
                error raised by ``_execute``/``_format_result``.
        """
        arguments = arguments or {}

        # NOTE(review): the geocode call bypasses _prepare_arguments, so this
        # tool's metadata defaults and env injection are not applied. Confirm intended.
        geo_result = await self._execute(metadata.name, arguments)
        if isinstance(geo_result, ToolResult):
            # A registered failure handler already produced the final answer.
            return geo_result
        if not geo_result.get("success"):
            raise RuntimeError(geo_result.get("error") or "地點查詢失敗")

        data = geo_result.get("data") or {}
        best_match = data.get("best_match") or {}
        dest_lat = best_match.get("lat")
        dest_lon = best_match.get("lon")
        if dest_lat is None or dest_lon is None:
            return ToolResult(
                name=metadata.name,
                message=str(geo_result.get("content") or "找不到合適的目的地"),
                data=data,
                raw=geo_result,
            )

        # env_provider may yield a falsy context (_prepare_arguments guards for this too).
        env_ctx: Dict[str, Any] = {}
        if user_id:
            env_ctx = await self._env_provider(user_id) or {}
        origin_lat = env_ctx.get("lat")
        origin_lon = env_ctx.get("lon")
        if origin_lat is None or origin_lon is None:
            return ToolResult(
                name=metadata.name,
                message=str(geo_result.get("content") or "取得目的地座標成功"),
                data=data,
                raw=geo_result,
                metadata={"note": "缺少目前位置,僅返回地點資訊"},
            )

        directions_args = {
            "origin_lat": float(origin_lat),
            "origin_lon": float(origin_lon),
            "dest_lat": float(dest_lat),
            "dest_lon": float(dest_lon),
            "origin_label": env_ctx.get("label") or env_ctx.get("address_display") or "目前位置",
            "dest_label": best_match.get("label") or arguments.get("query"),
            "mode": "foot-walking",
        }

        directions_meta = self._metadata_for("directions")
        prepared = await self._prepare_arguments(directions_args, directions_meta, user_id)
        directions_result = await self._execute("directions", prepared)
        return await self._format_result("directions", directions_result, directions_meta, original_message)