XiaoBai1221 commited on
Commit
2473698
·
1 Parent(s): fe476cf

Update agent bridge, directions tool, and add tests

Browse files
features/mcp/agent_bridge.py CHANGED
@@ -140,6 +140,141 @@ class MCPAgentBridge:
140
 
141
  return enriched
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  def get_current_time_data(self) -> Dict[str, Any]:
144
  """
145
  獲取當前時間數據,用於生成個性化歡迎詞
@@ -598,6 +733,9 @@ class MCPAgentBridge:
598
  return f"⚠️ 工具 {tool_name} 尚未實作,請稍後再試"
599
 
600
  arguments = await self._enrich_arguments_with_env(tool_name, arguments, user_id)
 
 
 
601
 
602
  logger.info(f"🔧 調用 MCP 工具: {tool_name}")
603
  logger.debug("📋 調用參數: %s", _safe_json(arguments))
@@ -639,6 +777,14 @@ class MCPAgentBridge:
639
 
640
  logger.debug(f"📊 提取的 tool_data: {type(tool_data)} = {tool_data if tool_data is None or isinstance(tool_data, (str, int, bool)) else '<dict/list>'}")
641
 
 
 
 
 
 
 
 
 
642
  if self._should_reformat(tool_name, content):
643
  logger.info(f"🎨 啟用 AI 格式化: {tool_name}")
644
  try:
 
140
 
141
  return enriched
142
 
143
+ async def _resolve_coordinate_label(self, lat: Any, lon: Any) -> Optional[str]:
144
+ """透過 reverse_geocode 將座標轉換為可朗讀的地點名稱。"""
145
+ try:
146
+ lat_f = float(lat)
147
+ lon_f = float(lon)
148
+ except (TypeError, ValueError):
149
+ return None
150
+
151
+ reverse_tool = self.mcp_server.tools.get("reverse_geocode")
152
+ if not reverse_tool or not reverse_tool.handler:
153
+ return None
154
+
155
+ try:
156
+ res = await reverse_tool.handler({"lat": lat_f, "lon": lon_f})
157
+ except Exception as ge:
158
+ logger.debug(f"reverse_geocode 失敗: {ge}")
159
+ return None
160
+
161
+ if not isinstance(res, dict):
162
+ return None
163
+ if not res.get("success"):
164
+ return None
165
+
166
+ payload = res.get("data") or res
167
+ label = (
168
+ payload.get("label")
169
+ or payload.get("display_name")
170
+ or ", ".join(
171
+ part for part in [payload.get("city"), payload.get("admin")] if part
172
+ )
173
+ )
174
+ return label.strip() if label else None
175
+
176
+ async def _prepare_route_arguments(self, arguments: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, str]]:
177
+ """為 directions 工具補齊可讀地點名稱並正規化座標。"""
178
+ prepared = dict(arguments or {})
179
+ labels: Dict[str, str] = {}
180
+
181
+ def _normalize_coord(value: Any) -> Optional[float]:
182
+ try:
183
+ if value is None:
184
+ return None
185
+ return float(value)
186
+ except (TypeError, ValueError):
187
+ return None
188
+
189
+ for prefix, default_label in (("origin", "起點"), ("dest", "目的地")):
190
+ lat_key = f"{prefix}_lat"
191
+ lon_key = f"{prefix}_lon"
192
+ label_key = f"{prefix}_label"
193
+
194
+ lat_val = _normalize_coord(prepared.get(lat_key))
195
+ lon_val = _normalize_coord(prepared.get(lon_key))
196
+ if lat_val is not None:
197
+ prepared[lat_key] = lat_val
198
+ if lon_val is not None:
199
+ prepared[lon_key] = lon_val
200
+
201
+ label_val = str(prepared.get(label_key) or "").strip()
202
+ if not label_val and lat_val is not None and lon_val is not None:
203
+ label_val = await self._resolve_coordinate_label(lat_val, lon_val) or ""
204
+
205
+ if not label_val:
206
+ label_val = default_label
207
+
208
+ prepared[label_key] = label_val
209
+ labels[label_key] = label_val
210
+
211
+ return prepared, labels
212
+
213
+ @staticmethod
214
+ def _format_distance(distance_m: Optional[float]) -> str:
215
+ """將距離換算為人類可讀格式。"""
216
+ if distance_m is None:
217
+ return "未知距離"
218
+ try:
219
+ distance = float(distance_m)
220
+ except (TypeError, ValueError):
221
+ return "未知距離"
222
+
223
+ if distance >= 1000:
224
+ return f"{distance / 1000:.1f} 公里"
225
+ return f"{round(distance)} 公尺"
226
+
227
+ @staticmethod
228
+ def _format_duration(duration_s: Optional[float]) -> str:
229
+ """將秒數換算為人類可讀格式。"""
230
+ if duration_s is None:
231
+ return "未知時間"
232
+ try:
233
+ total_seconds = int(round(float(duration_s)))
234
+ except (TypeError, ValueError):
235
+ return "未知時間"
236
+
237
+ minutes = total_seconds // 60
238
+ if minutes < 1:
239
+ return "不到 1 分鐘"
240
+
241
+ hours = minutes // 60
242
+ remaining_minutes = minutes % 60
243
+
244
+ if hours and remaining_minutes:
245
+ return f"{hours} 小時 {remaining_minutes} 分"
246
+ if hours:
247
+ return f"{hours} 小時"
248
+ return f"{minutes} 分鐘"
249
+
250
+ def _build_directions_message(
251
+ self,
252
+ tool_data: Dict[str, Any],
253
+ labels: Dict[str, str],
254
+ ) -> Tuple[str, Dict[str, Any]]:
255
+ """依據 directions 工具回傳資料,產出友善訊息與乾淨的 tool_data。"""
256
+ origin_label = labels.get("origin_label") or tool_data.get("origin_label") or "起點"
257
+ dest_label = labels.get("dest_label") or tool_data.get("dest_label") or "目的地"
258
+
259
+ distance_m = tool_data.get("distance_m")
260
+ duration_s = tool_data.get("duration_s")
261
+
262
+ distance_str = self._format_distance(distance_m)
263
+ duration_str = self._format_duration(duration_s)
264
+
265
+ polite_message = (
266
+ f"從 {origin_label} 前往 {dest_label} 大約需要 {duration_str},"
267
+ f"總距離約 {distance_str}。"
268
+ )
269
+
270
+ sanitized_tool_data = dict(tool_data or {})
271
+ sanitized_tool_data["origin_label"] = origin_label
272
+ sanitized_tool_data["dest_label"] = dest_label
273
+ sanitized_tool_data["distance_readable"] = distance_str
274
+ sanitized_tool_data["duration_readable"] = duration_str
275
+
276
+ return polite_message, sanitized_tool_data
277
+
278
  def get_current_time_data(self) -> Dict[str, Any]:
279
  """
280
  獲取當前時間數據,用於生成個性化歡迎詞
 
733
  return f"⚠️ 工具 {tool_name} 尚未實作,請稍後再試"
734
 
735
  arguments = await self._enrich_arguments_with_env(tool_name, arguments, user_id)
736
+ route_labels: Dict[str, str] = {}
737
+ if tool_name == "directions":
738
+ arguments, route_labels = await self._prepare_route_arguments(arguments)
739
 
740
  logger.info(f"🔧 調用 MCP 工具: {tool_name}")
741
  logger.debug("📋 調用參數: %s", _safe_json(arguments))
 
777
 
778
  logger.debug(f"📊 提取的 tool_data: {type(tool_data)} = {tool_data if tool_data is None or isinstance(tool_data, (str, int, bool)) else '<dict/list>'}")
779
 
780
+ if tool_name == "directions":
781
+ message, sanitized_tool_data = self._build_directions_message(
782
+ tool_data if isinstance(tool_data, dict) else {},
783
+ route_labels,
784
+ )
785
+ content = message
786
+ tool_data = sanitized_tool_data
787
+
788
  if self._should_reformat(tool_name, content):
789
  logger.info(f"🎨 啟用 AI 格式化: {tool_name}")
790
  try:
features/mcp/tools/directions_tool.py CHANGED
@@ -40,7 +40,9 @@ class DirectionsTool(MCPTool):
40
  "origin_lon": {"type": "number"},
41
  "dest_lat": {"type": "number"},
42
  "dest_lon": {"type": "number"},
43
- "mode": {"type": "string", "enum": ["driving-car", "foot-walking", "cycling-regular"], "default": "foot-walking"}
 
 
44
  }, required=["origin_lat", "origin_lon", "dest_lat", "dest_lon"])
45
 
46
  @classmethod
@@ -49,7 +51,9 @@ class DirectionsTool(MCPTool):
49
  schema["properties"].update({
50
  "distance_m": {"type": "number"},
51
  "duration_s": {"type": "number"},
52
- "polyline": {"type": "string"}
 
 
53
  })
54
  return schema
55
 
@@ -109,6 +113,16 @@ class DirectionsTool(MCPTool):
109
  d_lat = float(arguments.get("dest_lat"))
110
  d_lon = float(arguments.get("dest_lon"))
111
  mode = arguments.get("mode", "foot-walking")
 
 
 
 
 
 
 
 
 
 
112
 
113
  # 快取鍵(geohash 簡化)
114
  try:
@@ -119,12 +133,20 @@ class DirectionsTool(MCPTool):
119
 
120
  cached = await db_cache.get_route_cached(key)
121
  if cached:
122
- return cls.create_success_response(content=f"距離 {int(cached['distance_m'])}m,約 {int(cached['duration_s']/60)} 分鐘", data=cached)
 
 
 
 
123
 
124
  db_cached = await get_route_cache(key)
125
  if db_cached:
126
  await db_cache.set_route_cache(key, db_cached)
127
- return cls.create_success_response(content=f"距離 {int(db_cached['distance_m'])}m,約 {int(db_cached['duration_s']/60)} 分鐘", data=db_cached)
 
 
 
 
128
 
129
  # 呼叫 ORS Directions
130
  url = f"https://api.openrouteservice.org/v2/directions/{mode}"
@@ -148,12 +170,19 @@ class DirectionsTool(MCPTool):
148
  distance_m = float(summary["distance"]) # meters
149
  duration_s = float(summary["duration"]) # seconds
150
  polyline = feat["geometry"]["coordinates"] # LineString 座標
151
- payload = {"distance_m": distance_m, "duration_s": duration_s, "polyline": json.dumps(polyline)}
 
 
 
 
152
  except Exception as e:
153
  raise ExecutionError(f"解析 ORS 回應失敗: {e}")
154
 
155
  # 回寫快取
156
- await db_cache.set_route_cache(key, payload)
157
- await set_route_cache(key, payload)
158
 
159
- return cls.create_success_response(content=f"距離 {int(distance_m)}m,約 {int(duration_s/60)} 分鐘", data=payload)
 
 
 
 
40
  "origin_lon": {"type": "number"},
41
  "dest_lat": {"type": "number"},
42
  "dest_lon": {"type": "number"},
43
+ "mode": {"type": "string", "enum": ["driving-car", "foot-walking", "cycling-regular"], "default": "foot-walking"},
44
+ "origin_label": {"type": "string"},
45
+ "dest_label": {"type": "string"},
46
  }, required=["origin_lat", "origin_lon", "dest_lat", "dest_lon"])
47
 
48
  @classmethod
 
51
  schema["properties"].update({
52
  "distance_m": {"type": "number"},
53
  "duration_s": {"type": "number"},
54
+ "polyline": {"type": "string"},
55
+ "origin_label": {"type": "string"},
56
+ "dest_label": {"type": "string"},
57
  })
58
  return schema
59
 
 
113
  d_lat = float(arguments.get("dest_lat"))
114
  d_lon = float(arguments.get("dest_lon"))
115
  mode = arguments.get("mode", "foot-walking")
116
+ origin_label = (arguments.get("origin_label") or "").strip()
117
+ dest_label = (arguments.get("dest_label") or "").strip()
118
+
119
+ def _attach_labels(data: Dict[str, Any]) -> Dict[str, Any]:
120
+ enriched = dict(data)
121
+ if origin_label:
122
+ enriched["origin_label"] = origin_label
123
+ if dest_label:
124
+ enriched["dest_label"] = dest_label
125
+ return enriched
126
 
127
  # 快取鍵(geohash 簡化)
128
  try:
 
133
 
134
  cached = await db_cache.get_route_cached(key)
135
  if cached:
136
+ cached_payload = _attach_labels(cached)
137
+ return cls.create_success_response(
138
+ content=f"距離 {int(cached['distance_m'])}m,約 {int(cached['duration_s']/60)} 分鐘",
139
+ data=cached_payload,
140
+ )
141
 
142
  db_cached = await get_route_cache(key)
143
  if db_cached:
144
  await db_cache.set_route_cache(key, db_cached)
145
+ db_payload = _attach_labels(db_cached)
146
+ return cls.create_success_response(
147
+ content=f"距離 {int(db_cached['distance_m'])}m,約 {int(db_cached['duration_s']/60)} 分鐘",
148
+ data=db_payload,
149
+ )
150
 
151
  # 呼叫 ORS Directions
152
  url = f"https://api.openrouteservice.org/v2/directions/{mode}"
 
170
  distance_m = float(summary["distance"]) # meters
171
  duration_s = float(summary["duration"]) # seconds
172
  polyline = feat["geometry"]["coordinates"] # LineString 座標
173
+ base_payload = {
174
+ "distance_m": distance_m,
175
+ "duration_s": duration_s,
176
+ "polyline": json.dumps(polyline),
177
+ }
178
  except Exception as e:
179
  raise ExecutionError(f"解析 ORS 回應失敗: {e}")
180
 
181
  # 回寫快取
182
+ await db_cache.set_route_cache(key, base_payload)
183
+ await set_route_cache(key, base_payload)
184
 
185
+ return cls.create_success_response(
186
+ content=f"距離 {int(distance_m)}m,約 {int(duration_s/60)} 分鐘",
187
+ data=_attach_labels(base_payload),
188
+ )
tests/features/mcp/test_agent_bridge_route_labels.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from pathlib import Path
3
+ import sys
4
+
5
+ ROOT_DIR = Path(__file__).resolve().parents[4]
6
+ if str(ROOT_DIR) not in sys.path:
7
+ sys.path.append(str(ROOT_DIR))
8
+
9
+ from features.mcp.agent_bridge import MCPAgentBridge # noqa: E402
10
+
11
+
12
+ def test_prepare_route_arguments_injects_labels():
13
+ bridge = MCPAgentBridge.__new__(MCPAgentBridge)
14
+
15
+ async def fake_resolve(_self, _lat, _lon):
16
+ return "測試地點"
17
+
18
+ bridge._resolve_coordinate_label = fake_resolve.__get__(bridge, MCPAgentBridge) # type: ignore[attr-defined]
19
+
20
+ prepared, labels = asyncio.run(
21
+ bridge._prepare_route_arguments(
22
+ {
23
+ "origin_lat": "24.9915",
24
+ "origin_lon": "121.3423",
25
+ "dest_lat": "24.9891",
26
+ "dest_lon": "121.3134",
27
+ # 未提供 label,應自動補上
28
+ }
29
+ )
30
+ )
31
+
32
+ assert prepared["origin_label"] == "測試地點"
33
+ assert prepared["dest_label"] == "測試地點"
34
+ assert isinstance(prepared["origin_lat"], float)
35
+ assert isinstance(prepared["dest_lon"], float)
36
+ assert labels["origin_label"] == "測試地點"
37
+ assert labels["dest_label"] == "測試地點"
38
+
39
+
40
+ def test_build_directions_message_returns_human_friendly_text():
41
+ bridge = MCPAgentBridge.__new__(MCPAgentBridge)
42
+
43
+ message, tool_data = bridge._build_directions_message(
44
+ {"distance_m": 1450.0, "duration_s": 840.0, "polyline": "[]"},
45
+ {"origin_label": "測試起點 A", "dest_label": "測試目的地 B"},
46
+ )
47
+
48
+ assert "測試起點 A" in message
49
+ assert "測試目的地 B" in message
50
+ assert "公里" in tool_data["distance_readable"] or "公尺" in tool_data["distance_readable"]
51
+ assert tool_data["duration_readable"].endswith("分") or tool_data["duration_readable"].endswith("分鐘")
52
+ assert "origin_lat" not in tool_data
53
+ assert "dest_lon" not in tool_data
tests/features/mcp/tools/test_directions_tool.py CHANGED
@@ -1,5 +1,7 @@
 
1
  from pathlib import Path
2
  import sys
 
3
  import pytest
4
 
5
  ROOT_DIR = Path(__file__).resolve().parents[4]
@@ -7,6 +9,7 @@ if str(ROOT_DIR) not in sys.path:
7
  sys.path.append(str(ROOT_DIR))
8
 
9
  from features.mcp.tools.base_tool import ValidationError # noqa: E402
 
10
  from features.mcp.tools.directions_tool import DirectionsTool # noqa: E402
11
 
12
 
@@ -40,3 +43,42 @@ def test_validate_input_missing_dest_lon_gives_clear_error():
40
  message = str(excinfo.value)
41
  assert "dest_lon" in message
42
  assert "經度" in message
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
  from pathlib import Path
3
  import sys
4
+
5
  import pytest
6
 
7
  ROOT_DIR = Path(__file__).resolve().parents[4]
 
9
  sys.path.append(str(ROOT_DIR))
10
 
11
  from features.mcp.tools.base_tool import ValidationError # noqa: E402
12
+ from features.mcp.tools import directions_tool # noqa: E402
13
  from features.mcp.tools.directions_tool import DirectionsTool # noqa: E402
14
 
15
 
 
43
  message = str(excinfo.value)
44
  assert "dest_lon" in message
45
  assert "經度" in message
46
+
47
+
48
+ def test_execute_returns_labels_when_cache_hit(monkeypatch):
49
+ monkeypatch.setattr(directions_tool, "ORS_API_KEY", "dummy-key", raising=False)
50
+
51
+ async def fake_get_route_cached(_key: str):
52
+ return {"distance_m": 1325.5, "duration_s": 780.0, "polyline": "[]"}
53
+
54
+ async def fake_get_route_cache(_key: str):
55
+ return None
56
+
57
+ async def noop_set_route_cache(*_args, **_kwargs):
58
+ return None
59
+
60
+ monkeypatch.setattr(directions_tool.db_cache, "get_route_cached", fake_get_route_cached)
61
+ monkeypatch.setattr(directions_tool, "get_route_cache", fake_get_route_cache)
62
+ monkeypatch.setattr(directions_tool.db_cache, "set_route_cache", noop_set_route_cache)
63
+ monkeypatch.setattr(directions_tool, "set_route_cache", noop_set_route_cache)
64
+
65
+ args = {
66
+ "origin_lat": 24.9915,
67
+ "origin_lon": 121.3423,
68
+ "dest_lat": 24.9891,
69
+ "dest_lon": 121.3134,
70
+ "origin_label": "測試起點 A",
71
+ "dest_label": "測試目的地 B",
72
+ }
73
+
74
+ result = asyncio.run(DirectionsTool.execute(args))
75
+
76
+ assert result["success"] is True
77
+ assert "距離" in result["content"]
78
+ assert result["origin_label"] == "測試起點 A"
79
+ assert result["dest_label"] == "測試目的地 B"
80
+ assert "distance_m" in result and "duration_s" in result
81
+ # 確保沒有把 label 寫入快取的原始資料
82
+ cached = asyncio.run(directions_tool.db_cache.get_route_cached("ignored")) # type: ignore[arg-type]
83
+ assert cached["distance_m"] == pytest.approx(1325.5)
84
+ assert "origin_label" not in cached