diff --git a/.env.example b/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..798232fb01532db0e314d32d060cbd8b525799d9 --- /dev/null +++ b/.env.example @@ -0,0 +1,83 @@ +# Bloom Ware environment variables +# Copy to .env and fill real values locally. Do not commit real secrets. + +# ===== Application Runtime ===== +ENVIRONMENT=development +HOST=0.0.0.0 +PORT=7860 +CORS_ORIGINS=* + +# ===== Authentication ===== +JWT_SECRET_KEY=replace-with-random-secret +ACCESS_TOKEN_EXPIRE_MINUTES=30 + +# ===== Firebase / Firestore ===== +FIREBASE_PROJECT_ID=your-firebase-project-id +FIREBASE_CREDENTIALS_JSON= +FIREBASE_SERVICE_ACCOUNT_JSON_BASE64= +FIREBASE_SERVICE_ACCOUNT_PATH= + +# ===== Google OAuth ===== +GOOGLE_CLIENT_ID=your-google-oauth-client-id.apps.googleusercontent.com +GOOGLE_CLIENT_SECRET=your-google-oauth-client-secret +GOOGLE_REDIRECT_URI=http://localhost:8080/auth/google/callback + +# ===== Google Cloud 語音(STT/TTS)— 與 Firebase、OAuth 登入「不同專案/不同憑證」===== +# +# 三種 Google 身分請分開設定,勿混用: +# (1) Firebase / Firestore → FIREBASE_PROJECT_ID + FIREBASE_* 服務帳戶 +# (2) Google OAuth(網站登入)→ GOOGLE_CLIENT_ID / GOOGLE_CLIENT_SECRET +# (3) 語音 GCP(Speech + TTS)→ 專案 ID 字串例:supervisor-project;控制台「專案編號」僅供對照,API 請用專案 ID +# +# 已啟用 API:Cloud Speech-to-Text、Cloud Text-to-Speech。 +# - TTS(REST):使用 GOOGLE_TTS_API_KEY 或 GOOGLE_SPEECH_API_KEY(與 GOOGLE_API_KEY 可同一支) +# - STT v2「串流」:官方為 gRPC,必須使用「語音專案」的服務帳戶 OAuth(GOOGLE_SPEECH_*);僅 API Key 無法走現有串流實作 +# +# 語音專案 STT 必備(擇一):GOOGLE_SPEECH_CREDENTIALS_JSON / GOOGLE_SPEECH_SERVICE_ACCOUNT_JSON_BASE64 / GOOGLE_SPEECH_SERVICE_ACCOUNT_PATH +GOOGLE_SPEECH_PROJECT_ID=supervisor-project +GOOGLE_SPEECH_CREDENTIALS_JSON= +GOOGLE_SPEECH_SERVICE_ACCOUNT_JSON_BASE64= +GOOGLE_SPEECH_SERVICE_ACCOUNT_PATH= +# 語音專案預設 ID(未單獨設 GOOGLE_SPEECH_PROJECT_ID 時亦會參考);勿只填專案「編號」 +GOOGLE_CLOUD_PROJECT_ID=supervisor-project +GOOGLE_STT_ACCESS_TOKEN= +GOOGLE_API_KEY=your-google-api-key-for-tts-and-rest +GOOGLE_SPEECH_API_KEY= +GOOGLE_TTS_API_KEY= +GOOGLE_STT_LOCATION=global +GOOGLE_STT_RECOGNIZER_ID=_ +GOOGLE_STT_AUTO_LANGUAGE_CODES=cmn-Hant-TW,en-US,ja-JP +GOOGLE_TTS_LANGUAGE_CODE=cmn-TW +GOOGLE_TTS_DEFAULT_VOICE=cmn-TW-Wavenet-A + +# ===== OpenAI Agent ===== +OPENAI_API_KEY=your-openai-api-key +OPENAI_BASE_URL=https://sub2api.flowatelier.com +OPENAI_MODEL=gpt-5.4 +OPENAI_TIMEOUT=30 +OPENAI_RESPONSES_TIMEOUT=90 +OPENAI_USE_RESPONSES=true +OPENAI_MODEL_CONTEXT_WINDOW=1000000 +OPENAI_MODEL_AUTO_COMPACT_TOKEN_LIMIT=900000 +OPENAI_ENABLE_WEB_SEARCH=true +OPENAI_ENABLE_REMOTE_MCP=true +OPENAI_REMOTE_MCP_SERVERS_JSON=[] +OPENAI_ENABLE_SKILLS=true +USE_GPT_INTENT=true +GPT_INTENT_MODEL=gpt-5.4 + +# ===== External Data APIs ===== +WEATHER_API_KEY=your-weather-api-key +TAVILY_API_KEY=your-newsdata-api-key +EXCHANGE_API_KEY=your-exchange-api-key +OPENROUTESERVICE_API_KEY=your-openrouteservice-api-key +TDX_CLIENT_ID=your-tdx-client-id +TDX_CLIENT_SECRET=your-tdx-client-secret + +# ===== Background Jobs ===== +ENABLE_BACKGROUND_JOBS=true + +# ===== Environment Context ===== +ENV_CONTEXT_DISTANCE_THRESHOLD=100 +ENV_CONTEXT_HEADING_THRESHOLD=25 +ENV_CONTEXT_TTL_SECONDS=300 diff --git a/.gitignore b/.gitignore index 6105c6be8e195faeeb97b1acfce830ace4c626d7..0001e06950dea49f827b613bceeb6577165b35d0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ # Python __pycache__/ +.pytest_cache/ +.benchmarks/ *.py[cod] *$py.class *.so @@ -33,6 +35,7 @@ MANIFEST # Firebase credentials (JSON files) *-firebase-adminsdk-*.json bloomware-*.json +supervisor-project*.json # IDE .vscode/ @@ -174,4 +177,4 @@ celerybeat.pid dmypy.json # Pyre type checker -.pyre/ \ No newline at end of file +.pyre/ diff --git a/AGENTS.md b/AGENTS.md deleted file mode 100644 index 4534d8a190159a185229f246ec99eea137f76cb6..0000000000000000000000000000000000000000 --- a/AGENTS.md +++ /dev/null @@ -1,74 +0,0 @@ -# AGENTS.md - -## 介紹 - -你是 白東衢的狗。主要語言 Python。可用 MCP:context7、feedback-enhanced、filesystem、huggingface、playwright、sequential-thinking。所有思考與回覆皆用繁體中文。先思考再行動。 - -## 全域原則 - -- 所有 MCP 工具逾時一律 60 分鐘。 -- 非框架情境:以最少檔案完成任務,避免過度模組化。 -- 框架情境:依框架慣例放置檔案。 -- 禁止產生非目的文件:說明文件、依賴清單、README、requirements.txt 等。 -- 禁止要求以命令列參數輸入業務值;業務參數以程式內常數或互動式輸入處理。 -- 需新知時主動檢索:優先從網際網路獲取資訊並使用 huggingface(paper_search、model_search、dataset_search、hf_doc_search)與 context7進行輔佐。 -- playwright 僅用於前端樣式檢查、UI 互動測試、E2E 視覺驗證與截圖,禁止做爬蟲或搜尋。 -- 僅在規定節點使用 feedback-enhanced,禁止重複寒暄。 - -## 環境與執行限制 - -- 不得假設環境缺失而自動安裝任何套件或修改系統環境變數。 -- 不得建立或使用任何虛擬環境(venv、conda、poetry 等)。 -- 一律使用當前系統 Python 版本執行與相容(可於程式內讀取 sys.version 僅作紀錄,不觸碰安裝行為)。 -- 不輸出或修改 requirements.txt、pyproject.toml、環境設定檔。 - -## 框架情境的工作區確認 - -偵測到整合型框架或既有專案結構時: - -- 先用 sequential-thinking 規劃「要改哪些模組、檔名、路徑與測試位置」。 -- 接著必須以一次 feedback-enhanced.interactive_feedback 與使用者確認「基準工作區路徑、允許寫入子資料夾與檔名慣例」。 -- 未獲確認前不得寫入專案根目錄。確認後依框架慣例生成或修改檔案。 -- 非框架情境可寫根目錄,但仍以最小檔案集為原則。 - -## 思考與行動流程 - -- sequential-thinking:輸出「目標 → 步驟 → 決策準則 → 風險與驗證」。 -- 取證:context7、huggingface。 -- 產出:filesystem 建立或修改檔案;必要時用 playwright 做 UI 驗證。 -- 回覆內容:只含思考計畫、行動步驟、關鍵程式碼、測試摘要、後續建議。 - -## 測試與驗證政策 - -- 為每個新增或修改模組撰寫單元測試與關鍵整合測試。 -- 測試檔命名 test_*.py;框架專案依其慣例放置。 -- 執行方式以終端 Python3 或 Pytest:python3 -m pytest -q 或 pytest -q。 -- 回報測試摘要:通過數、失敗數、失敗案例與原因、可能回歸點、下一步。 - -## 錯誤處理與互動節奏 - -發生錯誤或接獲失敗回報時: - -- 先以 sequential-thinking 分析根因、解法選項與取捨與影響。 -- 再以一次 feedback-enhanced.interactive_feedback 與使用者對齊解法與影響面。 -- 然後修改程式與測試並重跑驗證。除上述節點外避免重複呼叫 feedback。 - -## 產出規範 - -- 預設單檔或少量檔案即可完成任務;框架專案依其結構放置。 -- 程式需具明確進入點: - ``` - if __name__ == "__main__": - main() - ``` -- 檔案一律透過 filesystem 操作並回報路徑與成功訊息。 -- 不硬編 API 金鑰或密碼;輸出時遮罩敏感資訊。 -- 外部資源不可用時,提出替代方案與自我修正步驟,仍不得觸發安裝或改環境行為。 - -## 禁止事項總表 - -- 自動安裝或升降版本、修改環境變數、建立/使用虛擬環境。 -- 產出說明文件、依賴清單或其他非目的文件。 -- 使用命令列參數傳遞業務值。 -- 用 playwright 做爬蟲或搜尋。 -- 無限制地反覆呼叫 feedback-enhanced。 \ No newline at end of file diff --git a/DEPLOY.md b/DEPLOY.md deleted file mode 100644 index 71b6c6f20f4516318f95d4fb438f1356767ff3fa..0000000000000000000000000000000000000000 --- a/DEPLOY.md +++ /dev/null @@ -1,173 +0,0 @@ -# 🚀 Bloom Ware Render 部署指南 - -## 📋 前置準備 - -### 1. 生成新的 JWT Secret(生產環境專用) -```bash -python3 -c "import secrets; print(secrets.token_urlsafe(32))" -``` -複製輸出的字串,稍後會用到。 - -### 2. 將 Firebase JSON 轉為單行字串 -```bash -cat your-firebase-credentials.json | python3 -m json.tool --compact | pbcopy -``` -(macOS 會自動複製到剪貼簿) - ---- - -## 🔧 Render 部署步驟 - -### 步驟 1:推送程式碼到 GitHub -```bash -git add . -git commit -m "準備 Render 部署:統一配置管理 + Firebase 環境變數化" -git push origin main -``` - -### 步驟 2:在 Render 建立 Web Service -1. 登入 [Render](https://render.com/) -2. 點擊 **New** → **Web Service** -3. 連接 GitHub 倉庫:選擇 `bloom-ware` -4. 設定: - - **Name**: `bloom-ware`(或自訂名稱) - - **Region**: `Singapore` 或 `Oregon` - - **Branch**: `main` - - **Runtime**: `Python 3` - - **Build Command**: `pip install -r requirements.txt` - - **Start Command**: `python app.py` - -### 步驟 3:設定環境變數 -在 Render Dashboard → Environment 頁面,新增以下環境變數: - -#### 必要環境變數(16 項) - -| 變數名 | 值 | 說明 | -|--------|-----|------| -| `ENVIRONMENT` | `production` | 環境識別 | -| `FIREBASE_PROJECT_ID` | `your-firebase-project-id` | Firebase 專案 ID | -| `FIREBASE_CREDENTIALS_JSON` | `{"type":"service_account",...}` | **完整 JSON 字串(單行)** | -| `OPENAI_API_KEY` | `sk-proj-...` | OpenAI API Key | -| `OPENAI_MODEL` | `gpt-5-nano` | 模型名稱 | -| `OPENAI_TIMEOUT` | `30` | 超時秒數 | -| `GOOGLE_CLIENT_ID` | `your-google-client-id.apps.googleusercontent.com` | Google OAuth Client ID | -| `GOOGLE_CLIENT_SECRET` | `GOCSPX-...` | Google OAuth Secret | -| `GOOGLE_REDIRECT_URI` | `https://your-app.onrender.com/auth/google/callback` | **OAuth 回調 URI** | -| `WEATHER_API_KEY` | `your-weather-api-key` | OpenWeatherMap Key | -| `NEWSDATA_API_KEY` | `pub_xxxxx` | NewsData.io Key | -| `EXCHANGE_API_KEY` | `your-exchange-api-key` | ExchangeRate Key | -| `JWT_SECRET_KEY` | `YOUR_NEW_SECRET` | **新生成的 Secret** | -| `ACCESS_TOKEN_EXPIRE_MINUTES` | `30` | Token 有效期 | -| `HOST` | `0.0.0.0` | 監聽主機 | -| `PORT` | `10000` | Render 固定端口 | - -### 步驟 4:部署 -點擊 **Create Web Service**,Render 會自動: -1. 執行 `pip install -r requirements.txt` -2. 啟動 `python app.py` -3. 提供 HTTPS URL(例如:`https://bloom-ware-xxxx.onrender.com`) - ---- - -## 🔗 Google OAuth 回調 URI 更新 - -### 1. 前往 Google Cloud Console -https://console.cloud.google.com/apis/credentials - -### 2. 選擇你的 OAuth 2.0 客戶端 - -### 3. 新增「已授權的重新導向 URI」 -``` -https://bloom-ware-xxxx.onrender.com/auth/google/callback -``` -(替換為你的實際 Render 網址) - -### 4. 儲存變更 - -### 5. 更新 Render 環境變數 -回到 Render Dashboard → Environment,更新: -``` -GOOGLE_REDIRECT_URI=https://bloom-ware-xxxx.onrender.com/auth/google/callback -``` - ---- - -## ✅ 驗證部署 - -### 1. 檢查 Logs -在 Render Dashboard → Logs 查看: -``` -✅ Firebase Firestore連接成功!專案ID:your-project-id -✅ OpenAI 客戶端初始化完成 -🚀 Bloom Ware 後端服務器啟動中... -``` - -### 2. 測試連接 -訪問:`https://your-app.onrender.com` -應該看到前端登入頁面 - -### 3. 測試 Google 登入 -1. 點擊「使用 Google 登入」 -2. 授權後應該成功跳轉並登入 - ---- - -## 🐛 常見問題 - -### 問題 1:Firebase 憑證錯誤 -**錯誤訊息**:`Firebase 憑證載入失敗` - -**解決方式**: -- 確認 `FIREBASE_CREDENTIALS_JSON` 是**單行字串**(無換行符) -- 檢查 JSON 格式是否正確(使用 `python3 -m json.tool` 驗證) - -### 問題 2:Google OAuth 回調失敗 -**錯誤訊息**:`redirect_uri_mismatch` - -**解決方式**: -- 確認 Google Cloud Console 已新增 Render 回調 URI -- 確認 `GOOGLE_REDIRECT_URI` 環境變數正確 - -### 問題 3:應用休眠(免費方案) -**現象**:閒置 15 分鐘後,首次訪問需等待 30 秒 - -**解決方式**: -- 升級到付費方案($7/月) -- 或使用 UptimeRobot 定期 ping(每 14 分鐘) - ---- - -## 📝 部署後清單 - -- [ ] 測試 Google 登入流程 -- [ ] 測試 WebSocket 連接 -- [ ] 測試語音功能(錄音 + TTS) -- [ ] 測試 MCP 工具(天氣、新聞、匯率) -- [ ] 檢查 Firebase Firestore 資料寫入 -- [ ] 監控 Render Logs 是否有錯誤 - ---- - -## 🔄 更新部署 - -每次程式碼更新後: -```bash -git add . -git commit -m "更新功能" -git push origin main -``` - -Render 會自動檢測並重新部署(約 2-3 分鐘)。 - ---- - -## 📞 支援 - -遇到問題?檢查: -1. Render Dashboard → Logs -2. Render Dashboard → Events -3. GitHub Actions(如有設定 CI/CD) - ---- - -**🎉 恭喜!Bloom Ware 已成功部署到 Render!** diff --git a/app.py b/app.py index f71ba2ca8146bcfb484ae16597c3203d5ec4c476..a3d918a4f31a8544aa475b3e742a8f213633c9be 100644 --- a/app.py +++ b/app.py @@ -1,3 +1,4 @@ +# BloomWare Application - Confidence-Driven Agent Loop import os import json import time @@ -6,6 +7,7 @@ import mimetypes import logging import secrets import jwt +import unicodedata from datetime import datetime from typing import List, Dict, Optional, Any @@ -76,6 +78,7 @@ from core.memory_system import memory_manager # 環境 Context 寫入 API from core.database import set_user_env_current, add_user_env_snapshot from core.environment import EnvironmentContextService +from middleware import CSPMiddleware # ----------------------------- @@ -109,6 +112,77 @@ def serialize_for_json(obj: Any) -> Any: except Exception: return None + +def _normalize_bcp47_language_tag(tag: Optional[str]) -> Optional[str]: + raw = str(tag or "").strip() + if not raw: + return None + normalized = raw.replace("_", "-") + parts = [part for part in normalized.split("-") if part] + if not parts: + return None + + language = parts[0].lower() + rest: List[str] = [] + for part in parts[1:]: + if len(part) == 4 and part.isalpha(): + rest.append(part.title()) + elif len(part) in {2, 3} and part.isalpha(): + rest.append(part.upper()) + else: + rest.append(part) + return "-".join([language, *rest]) + + +def _preferred_language_from_text(text: str) -> Optional[str]: + script_counts: Dict[str, int] = {} + for ch in str(text or ""): + if ch.isspace(): + continue + try: + name = unicodedata.name(ch) + except ValueError: + continue + for script in ("HIRAGANA", "KATAKANA", "HANGUL", "CJK UNIFIED IDEOGRAPH", "LATIN", "CYRILLIC", "THAI"): + if script in name: + script_counts[script] = script_counts.get(script, 0) + 1 + break + + if script_counts.get("HIRAGANA", 0) or script_counts.get("KATAKANA", 0): + return "ja-JP" + if script_counts.get("HANGUL", 0): + return "ko-KR" + if script_counts.get("THAI", 0): + return "th-TH" + if script_counts.get("CYRILLIC", 0): + return "ru-RU" + if script_counts.get("LATIN", 0) and not script_counts.get("CJK UNIFIED IDEOGRAPH", 0): + return "en-US" + if script_counts.get("CJK UNIFIED IDEOGRAPH", 0): + return "zh-TW" + return None + + +def _resolve_conversation_language( + user_message: str, + requested_language: Optional[str], + locale_hint: Optional[str] = None, +) -> str: + explicit = _normalize_bcp47_language_tag(requested_language) + if explicit and explicit.lower() != "auto": + return explicit + + inferred = _preferred_language_from_text(user_message) + if inferred: + return inferred + + locale_tag = _normalize_bcp47_language_tag(locale_hint) + if locale_tag and locale_tag.lower() != "auto": + return locale_tag + + return "zh-TW" + + # ----------------------------- # Pydantic 模型(從統一模組導入) # ----------------------------- @@ -322,33 +396,6 @@ app.add_middleware( ) # CSP Middleware(允許內嵌 script 用於語音沉浸式前端) -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.requests import Request as StarletteRequest -from starlette.responses import Response - -class CSPMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: StarletteRequest, call_next): - response = await call_next(request) - # 對所有靜態檔案路徑添加寬鬆的 CSP header(用於語音沉浸式前端) - if request.url.path.startswith("/static/"): - # 移除可能存在的嚴格 CSP - if "Content-Security-Policy" in response.headers: - del response.headers["Content-Security-Policy"] - - # 設定寬鬆的 CSP 以允許內嵌 script - response.headers["Content-Security-Policy"] = ( - "default-src 'self'; " - "script-src 'self' 'unsafe-inline' 'unsafe-eval' https://accounts.google.com https://www.gstatic.com; " - "style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; " - "font-src 'self' https://fonts.gstatic.com data:; " - "connect-src 'self' ws: wss: https://accounts.google.com; " - "img-src 'self' data: https: blob:; " - "media-src 'self' blob: data:; " - "frame-src https://accounts.google.com; " - "base-uri 'self';" - ) - return response - app.add_middleware(CSPMiddleware) # 掛載靜態檔案目錄(語音沉浸式前端) @@ -495,6 +542,18 @@ async def websocket_endpoint_with_jwt( td = app.state.feature_router.get_current_time_data() # 使用語音登入傳遞的情緒(如果有) + + # 如果登入情緒是極端情緒,自動啟動關懷模式 + is_care_active = False + if emotion in ["sad", "angry", "fear"]: + from core.emotion_care_manager import EmotionCareManager + # 使用 force=True 確保從登入情緒直接進入,不需等待連續偵測 + is_care_active = EmotionCareManager.check_and_enter_care_mode( + user_id, emotion, chat_id=current_chat_id, force=True + ) + if is_care_active: + logger.info(f"💙 偵測到登入情緒 [{emotion}],自動啟動關懷模式 (user_id={user_id})") + welcome_msg = compose_welcome( user_name=user_info.get('name'), time_data=td, @@ -506,19 +565,35 @@ async def websocket_endpoint_with_jwt( welcome_msg = f"歡迎回來,{user_info['name']}!" # 發送歡迎訊息,並附帶 chat_id + # 通知前端當前情緒與關懷模式狀態 + await websocket.send_json({ + "type": "emotion_detected", + "emotion": emotion or "neutral", + "care_mode": is_care_active if 'is_care_active' in locals() else False + }) + await websocket.send_json({ "type": "system", "message": welcome_msg, - "chat_id": current_chat_id + "chat_id": current_chat_id, + "care_mode": is_care_active if 'is_care_active' in locals() else False }) while True: - data = await websocket.receive_text() + try: + # 🎯 2026 穩定性優化:加入 WebSocket 心跳 (Ping) 機制,防止長思考工具調用導致連線中斷 + data = await asyncio.wait_for(websocket.receive_text(), timeout=30.0) + except asyncio.TimeoutError: + try: + await websocket.send_json({"type": "ping", "timestamp": time.time()}) + continue + except Exception: + break # 連線已失效 + try: message_data = json.loads(data) message_type_raw = message_data.get("type", "") message_type = (message_type_raw or "").strip().lower() - # 更新最後活動時間 manager.user_sessions[user_id]["last_activity"] = datetime.now() @@ -527,6 +602,7 @@ async def websocket_endpoint_with_jwt( if not user_message: await manager.send_message("收到空消息", user_id, "error") continue + message_language = message_data.get("language") or "auto" chat_id = message_data.get("chat_id", None) @@ -573,9 +649,8 @@ async def websocket_endpoint_with_jwt( "content": ( "你是一個友善、有禮且能夠提供幫助的AI助手。\n\n" "【重要】語言使用規範:\n" - "- 回覆用戶時:必須使用繁體中文,保持簡潔清晰的表達\n" + "- 回覆用戶時:必須使用對應的語言,保持簡潔清晰的表達\n" "- 調用工具時:所有參數必須使用英文(城市名、國家名、貨幣代碼等)\n\n" - "另外,請勿自稱為 GPT-4 或其他版本。若需要自我介紹,請表述為 '基於 gpt-5-nano 模型'。" ), }, {"role": "user", "content": user_message}, @@ -588,7 +663,28 @@ async def websocket_endpoint_with_jwt( try: logger.info(f"🚀 開始處理訊息: user_id={user_id}, chat_id={chat_id}") - async def _on_text_emotion(em: str, cm: bool): + async def _on_text_emotion(em: str, cm: bool, payload: Optional[Dict[str, Any]] = None): + if em == "__bot_delta__" and payload: + await websocket.send_json({ + "type": "bot_delta", + "message_id": payload.get("message_id"), + "delta": payload.get("delta", ""), + "text": payload.get("text", ""), + "temporary": True, + "phase": payload.get("phase", "answering"), + "timestamp": time.time(), + }) + return + if em == "__bot_status__" and payload: + await websocket.send_json({ + "type": "bot_status", + "status": payload.get("status", "processing"), + "message": payload.get("message", "正在處理..."), + "temporary": True, + "phase": payload.get("phase", payload.get("status", "processing")), + "timestamp": time.time(), + }) + return logger.info(f"📤 [即時回調] 發送 text emotion_detected: {em}, care_mode={cm}") await websocket.send_json({ "type": "emotion_detected", @@ -596,7 +692,7 @@ async def websocket_endpoint_with_jwt( "care_mode": cm }) - response = await handle_message(user_message, user_id, chat_id, messages_for_handler, request_id=request_id, emotion_callback=_on_text_emotion) + response = await handle_message(user_message, user_id, chat_id, messages_for_handler, request_id=request_id, language=message_language, emotion_callback=_on_text_emotion) logger.info(f"📥 handle_message 返回: type={type(response)}, response={response}") # 【優化】處理空回應:轉換為帶情緒的 dict 格式 @@ -673,8 +769,7 @@ async def websocket_endpoint_with_jwt( except Exception as e: logger.exception(f"❌ _do_process_and_send 發生異常: {e}") - import asyncio as _asyncio - _asyncio.create_task(_do_process_and_send()) + asyncio.create_task(_do_process_and_send()) elif message_type == "env_snapshot": try: @@ -752,7 +847,16 @@ async def websocket_endpoint_with_jwt( sr = 16000 if mode == "realtime_chat": - # === 即時轉錄模式(使用 OpenAI Realtime API)=== + # 🎯 中斷 Barge-in:如果正在處理上一個回覆,立即取消 + await manager.cancel_user_tasks(user_id) + + # 🎯 2026 穩定性優化:每次開始對話時清除上一次的音頻緩衝與轉錄,防止「語音殘留」污染下一次識別 + client_info = manager.get_client_info(user_id) or {} + client_info["audio_buffer"] = b"" + client_info["realtime_transcript"] = "" + manager.set_client_info(user_id, client_info) + + # === 即時轉錄模式(使用 Google Speech-to-Text)=== try: from services.realtime_stt_service import RealtimeSTTService @@ -785,21 +889,32 @@ async def websocket_endpoint_with_jwt( client_info["realtime_transcript"] = full_text manager.set_client_info(user_id, client_info) - async def on_vad_committed(item_id: str): - """VAD 偵測到語音段結束""" - logger.debug(f"🎤 VAD Committed: {item_id}") + async def on_vad_committed(status: str): + """VAD 偵測到語音狀態變化""" + if status == "error": + await websocket.send_json({ + "type": "error", + "message": "語音識別服務異常 (Stream Timeout),正在自動重置環境..." + }) + else: + await websocket.send_json({ + "type": "stt_status", + "status": status, + "timestamp": time.time() + }) + logger.debug(f"🎤 VAD Status: {status}") # 從前端獲取語言設定(支援:zh, en, id, ja, vi,或 auto 自動檢測) language = message_data.get("language", "auto") logger.info(f"🌐 語言設定: {language}") - # 連線到 OpenAI Realtime API + # 連線到 Google Speech-to-Text 服務 success = await realtime_stt.connect( on_transcript_delta=on_transcript_delta, on_transcript_done=on_transcript_done, on_vad_committed=on_vad_committed, - model="gpt-4o-mini-transcribe", - language=language + language=language, + sample_rate=sr ) if success: @@ -812,11 +927,12 @@ async def websocket_endpoint_with_jwt( await websocket.send_json({ "type": "realtime_stt_status", "status": "connected", - "message": "即時轉錄已啟動" + "message": "即時轉錄已啟動", + "language": language, }) logger.info(f"✅ 用戶 {user_id} 即時轉錄已啟動") else: - raise Exception("無法連接到 OpenAI Realtime API") + raise Exception("無法連接到 Google Speech-to-Text") except Exception as e: logger.error(f"❌ 啟動即時轉錄失敗: {e}") @@ -845,16 +961,22 @@ async def websocket_endpoint_with_jwt( realtime_stt = client_info.get("realtime_stt") if realtime_stt and b64: - # === 即時轉錄模式:轉發到 OpenAI Realtime API === + # === 即時轉錄模式:轉發到 Google Speech-to-Text 緩衝 === try: import base64 audio_bytes = base64.b64decode(b64) await realtime_stt.send_audio_chunk(audio_bytes) - logger.debug(f"🎤 轉發音頻到 OpenAI: {len(audio_bytes)} bytes") + logger.debug(f"🎤 轉發音頻到 Google STT: {len(audio_bytes)} bytes") # 同時儲存到本地緩衝(用於音頻情緒辨識) + # 🎯 效能與記憶體優化:實施滑動窗口(Sliding Window),僅保留最近 15 秒音頻 + # 16000Hz * 2bytes/sample * 15s = 480,000 bytes + MAX_BUFFER_SIZE = 480000 audio_buffer = client_info.get("audio_buffer", b"") audio_buffer += audio_bytes + if len(audio_buffer) > MAX_BUFFER_SIZE: + audio_buffer = audio_buffer[-MAX_BUFFER_SIZE:] + client_info["audio_buffer"] = audio_buffer manager.set_client_info(user_id, client_info) @@ -1000,7 +1122,10 @@ async def websocket_endpoint_with_jwt( td = app.state.feature_router.get_current_time_data() name = user.get("name") or "用戶" emo = result.get("emotion") or {} - emo_label = str(emo.get("label") or "") + emo_label = str(emo.get("label") or "neutral") + # 🎯 使用原始標籤(包含中文)來生成歡迎詞,確保「心情低落」等詞彙能正確匹配 + raw_emo_label = str(emo.get("raw_label") or emo_label) + tz_hint = None try: env_res = await get_user_env_current(user_id) @@ -1011,7 +1136,7 @@ async def websocket_endpoint_with_jwt( welcome = compose_welcome( user_name=name, time_data=td, - emotion_label=emo_label, + emotion_label=raw_emo_label, timezone=tz_hint, ) except Exception: @@ -1030,6 +1155,13 @@ async def websocket_endpoint_with_jwt( logger.error(f"生成 JWT token 失敗: {e}") access_token = None + await websocket.send_json({ + "type": "emotion_detected", + "emotion": emo_label, + "care_mode": False, + "source": "voice_login" + }) + await websocket.send_json({ "type": "voice_login_result", "success": True, @@ -1067,7 +1199,7 @@ async def websocket_endpoint_with_jwt( }) elif mode == "realtime_chat": - # === 即時轉錄模式:關閉 OpenAI Realtime 連線並處理轉錄結果 === + # === 即時轉錄模式:提交 Google STT 並處理轉錄結果 === try: client_info = manager.get_client_info(user_id) or {} realtime_stt = client_info.get("realtime_stt") @@ -1075,6 +1207,24 @@ async def websocket_endpoint_with_jwt( audio_buffer = client_info.get("audio_buffer", b"") if realtime_stt: + await websocket.send_json({ + "type": "stt_status", + "status": "transcribing", + "timestamp": time.time() + }) + logger.info(f"📝 等待即時轉錄 final,用戶 {user_id}") + final_transcript = await realtime_stt.wait_for_final_transcript(timeout=3.5) + if final_transcript and final_transcript != client_info.get("realtime_transcript"): + transcription = final_transcript + client_info["realtime_transcript"] = final_transcript + await websocket.send_json({ + "type": "stt_final", + "text": final_transcript, + "timestamp": time.time() + }) + elif final_transcript: + transcription = final_transcript + logger.info(f"🔌 關閉即時轉錄連線,用戶 {user_id}") await realtime_stt.disconnect() @@ -1099,150 +1249,128 @@ async def websocket_endpoint_with_jwt( # 立即通知前端開始思考,提升即時響應感 await websocket.send_json({"type": "typing", "message": "thinking"}) - # === 方案 B:語音情緒辨識(情緒分佈驗證 + 智能回退)=== - audio_emotion = None - if audio_buffer and len(audio_buffer) >= 16000 * 2: # 至少 1 秒 - try: - logger.info(f"🎭 開始語音情緒辨識,音訊長度: {len(audio_buffer)} bytes") - emotion_result = await predict_emotion_from_audio(audio_buffer, sample_rate=16000) - - if emotion_result.get("success"): - emotion_label = emotion_result.get("emotion", "neutral") - confidence = emotion_result.get("confidence", 0.0) - all_emotions = emotion_result.get("all_emotions", {}) - - # 計算 top-1 與 top-2 的 margin - sorted_emotions = sorted(all_emotions.items(), key=lambda x: x[1], reverse=True) - margin = sorted_emotions[0][1] - sorted_emotions[1][1] if len(sorted_emotions) >= 2 else confidence - - # 方案 B 判斷邏輯 - use_audio_emotion = False - reason = "" - - if emotion_label == "neutral": - # neutral 需要更高置信度,但 margin 可較寬鬆 - if confidence >= 0.55 and margin >= 0.12: - use_audio_emotion = True - reason = f"neutral 高信心 (conf={confidence:.3f}, margin={margin:.3f})" - else: - reason = f"neutral 信心不足 (conf={confidence:.3f}, margin={margin:.3f}) → 回退文字" - else: - # 非 neutral 需要足夠 confidence 與 margin - if confidence >= 0.48 and margin >= 0.18: - use_audio_emotion = True - reason = f"{emotion_label} 高信心 (conf={confidence:.3f}, margin={margin:.3f})" - else: - reason = f"{emotion_label} 信心不足 (conf={confidence:.3f}, margin={margin:.3f}) → 回退文字" - - if use_audio_emotion: - audio_emotion = emotion_result - logger.info(f"✅ 使用語音情緒: {emotion_label}, {reason}") - else: - audio_emotion = None - logger.info(f"📝 {reason}") - else: - logger.warning(f"⚠️ 語音情緒辨識失敗: {emotion_result.get('error')}") - except Exception as e: - logger.error(f"❌ 語音情緒辨識異常: {e}") - audio_emotion = None - - # 清理音頻緩衝 - if audio_buffer: - client_info.pop("audio_buffer", None) - manager.set_client_info(user_id, client_info) - - # 異步處理對話邏輯 - async def _process_realtime_chat(): + # === 優化:並行處理語音情緒辨識與 Agent 邏輯 === + async def _get_audio_emotion(): + if audio_buffer and len(audio_buffer) >= 16000 * 1.5: # 至少 1.5 秒 + try: + logger.info(f"🎭 [Parallel] 開始語音情緒辨識,長度: {len(audio_buffer)} bytes") + res = await predict_emotion_from_audio(audio_buffer, sample_rate=16000) + if res.get("success"): + # 簡單過濾低信心度 + if res.get("confidence", 0) > 0.4: + res["source"] = "realtime_voice" + return res + except Exception as e: + logger.error(f"❌ 語音情緒辨識異常: {e}") + return {"success": False, "source": "realtime_voice"} + + # 定義一個內部函數來處理整個流程,稍後將其作為 Task 執行 + async def _process_full_request(): + # 啟動並行任務 + emotion_task = asyncio.create_task(_get_audio_emotion()) + + # 同步準備對話 ID 等基礎工作 (這些很快,不需額外 Task) chat_id = message_data.get("chat_id") - - # 如果沒有 chat_id,創建新對話 if not chat_id: try: user_chats_result = await get_user_chats(user_id) if user_chats_result["success"] and user_chats_result["chats"]: - latest_chat = user_chats_result["chats"][0] - chat_id = latest_chat["chat_id"] + chat_id = user_chats_result["chats"][0]["chat_id"] else: - chat_title = f"語音對話 {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - chat_result = await create_chat(user_id, chat_title) - if chat_result["success"]: - chat_id = chat_result["chat"]["chat_id"] + chat_title = f"語音對話 {datetime.now().strftime('%Y-%m-%d %H:%M')}" + chat_res = await create_chat(user_id, chat_title) + if chat_res["success"]: + chat_id = chat_res["chat"]["chat_id"] except Exception as e: - logger.error(f"創建對話失敗: {e}") - await websocket.send_json({"type": "error", "message": "無法創建對話"}) - return - - # 保存用戶訊息 - await save_message_to_db(user_id, chat_id, "user", transcription) - - # 取得語言設定 + logger.error(f"準備對話時出錯: {e}") + + # 保存用戶訊息(異步) + if chat_id: + asyncio.create_task(save_message_to_db(user_id, chat_id, "user", transcription)) + + # 等待情緒分析結果(如果還沒出的話,這邊會稍微等一下,但通常會與 GPT 意圖偵測並行) + # 為了讓意圖偵測儘早開始,我們甚至可以在 handle_message 內部去 await emotion_task + # 但這裡我們採取簡單策略:先啟動 handle_message,並傳入 task 或稍後合併 + + # 實際上,handle_message 內部會做 GPT 意圖偵測,這耗時最久。 + # 我們讓 handle_message 帶入 emotion_task + language = client_info.get("language", "auto") + + async def _on_emotion_detected(em: str, cm: bool, payload: Optional[Dict[str, Any]] = None): + # ... (原有回調邏輯) + if em.startswith("__bot_"): # 狀態回調 + await websocket.send_json({"type": em.replace("__", ""), **(payload or {})}) + return + await websocket.send_json({"type": "emotion_detected", "emotion": em, "care_mode": cm}) - # 發送即時情緒的回調函數 - async def _on_emotion_detected(em: str, cm: bool): - logger.info(f"📤 [即時回調] 發送 emotion_detected: {em}, care_mode={cm}") - await websocket.send_json({ - "type": "emotion_detected", - "emotion": em, - "care_mode": cm - }) + # 等待情緒完成,或者設定超時 + try: + audio_emotion_res = await asyncio.wait_for(emotion_task, timeout=2.0) + except asyncio.TimeoutError: + logger.warning("⚠️ 語音情緒辨識超時,回退至文字分析") + audio_emotion_res = {"success": False} - # 處理對話(透過 handle_message,自動處理 pipeline) + # 執行 Agent 邏輯 response = await handle_message( transcription, user_id, chat_id, - [], # messages 參數(會自動從數據庫載入) - audio_emotion=audio_emotion, # 傳遞音頻情緒 - language=language, # 傳遞語言設定(新增) + [], + audio_emotion=audio_emotion_res, + language=language, emotion_callback=_on_emotion_detected ) - # 發送回應 - # 從 PipelineResult 提取情緒 - emotion = None - care_mode = False + # 發送結果 if isinstance(response, PipelineResult): - message_text = response.text - if response.meta: - emotion = response.meta.get('emotion') - care_mode = response.meta.get('care_mode', False) - await websocket.send_json({ "type": "bot_message", - "message": message_text, + "message": response.text, "timestamp": time.time(), - "tool_name": None, - "tool_data": None, - "emotion": emotion, - "care_mode": care_mode + "emotion": response.meta.get("emotion") if response.meta else "neutral", + "care_mode": response.meta.get("care_mode", False) if response.meta else False, + "language": language, }) + # 保存 AI 訊息 + if chat_id: + asyncio.create_task(save_message_to_db(user_id, chat_id, "assistant", response.text)) elif isinstance(response, dict): - tool_name = response.get('tool_name') - tool_data = response.get('tool_data') - emotion = response.get('emotion') - message_text = response.get('message', response.get('content', '')) - await websocket.send_json({ "type": "bot_message", - "message": message_text, + "message": response.get("message", response.get("content", "")), "timestamp": time.time(), - "tool_name": tool_name, - "tool_data": tool_data, - "emotion": emotion + "tool_name": response.get("tool_name"), + "tool_data": response.get("tool_data"), + "emotion": response.get("emotion", "neutral"), + "care_mode": response.get("care_mode", False), + "language": response.get("language", language), }) + if chat_id: + asyncio.create_task(save_message_to_db(user_id, chat_id, "assistant", str(response.get("message", "")))) else: # 字串回應 await websocket.send_json({ "type": "bot_message", "message": str(response), - "timestamp": time.time(), - "emotion": None + "timestamp": time.time() }) + if chat_id: + asyncio.create_task(save_message_to_db(user_id, chat_id, "assistant", str(response))) + + # 啟動整體處理任務 + # 🎯 註冊任務,以便支援 Barge-in 中斷 + task = asyncio.create_task(_process_full_request()) + manager.register_task(user_id, task) + + # 清理緩衝區並直接返回循環,讓連線保持暢通 + if audio_buffer: + client_info.pop("audio_buffer", None) + manager.set_client_info(user_id, client_info) - await _process_realtime_chat() else: logger.debug(f"沒有轉錄文字,返回待機狀態") + await websocket.send_json({"type": "stt_status", "status": "idle"}) except Exception as e: logger.error(f"❌ 關閉即時轉錄失敗: {e}") @@ -1270,20 +1398,26 @@ async def websocket_endpoint_with_jwt( # 消息處理與AI # ----------------------------- async def handle_message(user_message, user_id, chat_id, messages, request_id: str = None, audio_emotion: dict = None, language: str = None, emotion_callback=None): + user_message = (user_message or "").strip() logger.info(f"📥 handle_message: 收到訊息='{user_message}', user_id={user_id}, audio_emotion={audio_emotion}, language={language}") + resolved_language = _resolve_conversation_language(user_message, language) + if not user_message: + logger.info(f"🚫 攔截到空請求,中斷交由 Agent 處理。") + return "不好意思,我剛剛沒有聽清楚或是沒收到內容,可以請您再說一次嗎?" + # 指令優先,避免進入管線造成不必要延遲 - if user_message and user_message.startswith("/"): + if user_message.startswith("/"): cmd = await handle_command(user_message, user_id) if cmd: return cmd feature_router:MCPAgentBridge = app.state.feature_router - async def _detect(msg: str): + async def _detect(msg: str, tool_context: str = "", language: str = None, **kwargs): logger.info(f"🎯 Pipeline: 開始意圖偵測,訊息='{msg}'") try: - result = await feature_router.detect_intent(msg) + result = await feature_router.detect_intent(msg, tool_context=tool_context, language=language or resolved_language) logger.info(f"🎯 Pipeline: 意圖偵測結果={result}") return result except Exception as e: @@ -1301,7 +1435,59 @@ async def handle_message(user_message, user_id, chat_id, messages, request_id: s logger.info(f"🔧 Pipeline: 功能處理結果='{result}'") return result - async def _ai(messages_in, cid, model, rid, chat_id, use_care_mode=False, care_emotion=None, emotion_label=None, language=None): + stream_message_id = request_id or f"stream_{int(time.time() * 1000)}" + stream_accumulator = {"text": ""} + + async def _emit_bot_status(status: str, message: str, phase: Optional[str] = None): + if emotion_callback is None: + return + try: + await emotion_callback( + "__bot_status__", + False, + { + "status": status, + "message": message, + "phase": phase or status, + "temporary": True, + }, + ) + except TypeError: + return + + async def _on_ai_chunk(delta: Any): + if not delta or emotion_callback is None: + return + if isinstance(delta, dict): + delta.setdefault("temporary", True) + delta.setdefault("phase", delta.get("status", "processing")) + try: + await emotion_callback("__bot_status__", False, delta) + except TypeError: + return + return + + stream_accumulator["text"] += str(delta) + try: + await emotion_callback( + "__bot_delta__", + False, + { + "message_id": stream_message_id, + "delta": str(delta), + "text": stream_accumulator["text"], + "language": _preferred_language_from_text(stream_accumulator["text"]) or resolved_language, + "phase": "answering", + "temporary": True, + }, + ) + except TypeError: + return + + async def _ai(messages_in, cid, model, rid, chat_id, use_care_mode=False, care_emotion=None, emotion_label=None, language=None, tool_context: str = "", is_first_care: bool = False): + # 【效能優化】立即通知前端進入 AI 生成階段,這會刷新思考超時 + await _emit_bot_status("generating", "正在組織語言回答您...", "thinking") + env_context = {} env_service = getattr(app.state, 'env_service', None) if env_service: @@ -1320,37 +1506,49 @@ async def handle_message(user_message, user_id, chat_id, messages, request_id: s logger.debug(f"無法取得用戶名稱,使用預設值: {e}") # 使用傳入的 language 參數(優先)或閉包捕獲的外部變數 - lang = language if language is not None else globals().get('language', 'zh') + lang = language if language is not None else resolved_language # 兼容:如果傳入字串,視為 user_message;如果傳入 list,視為 messages - if isinstance(messages_in, str): - return await ai_service.generate_response_for_user( - user_message=messages_in, - user_id=cid, - model=model, - request_id=rid, - chat_id=chat_id, - use_care_mode=use_care_mode, - care_emotion=care_emotion, - user_name=user_name, - emotion_label=emotion_label, - env_context=env_context, - language=lang, - ) - else: - return await ai_service.generate_response_for_user( - messages=messages_in, - user_id=cid, - model=model, - request_id=rid, - chat_id=chat_id, - use_care_mode=use_care_mode, - care_emotion=care_emotion, - user_name=user_name, - emotion_label=emotion_label, - env_context=env_context, - language=lang, - ) + try: + if isinstance(messages_in, str): + return await ai_service.generate_response_for_user( + user_message=messages_in, + user_id=cid, + model=model, + request_id=rid, + chat_id=chat_id, + use_care_mode=use_care_mode, + care_emotion=care_emotion, + user_name=user_name, + emotion_label=emotion_label, + env_context=env_context, + language=lang, + stream=bool(emotion_callback), + on_chunk=_on_ai_chunk if emotion_callback else None, + tool_context=tool_context, + is_first_care=is_first_care, + ) + else: + return await ai_service.generate_response_for_user( + messages=messages_in, + user_id=cid, + model=model, + request_id=rid, + chat_id=chat_id, + use_care_mode=use_care_mode, + care_emotion=care_emotion, + user_name=user_name, + emotion_label=emotion_label, + env_context=env_context, + language=lang, + stream=bool(emotion_callback), + on_chunk=_on_ai_chunk if emotion_callback else None, + tool_context=tool_context, + is_first_care=is_first_care, + ) + except Exception as e: + logger.error(f"AI 生成過程出錯: {e}") + raise model = settings.OPENAI_MODEL # 簡化 Pipeline:移除未使用的記憶管理和摘要決策 @@ -1360,12 +1558,13 @@ async def handle_message(user_message, user_id, chat_id, messages, request_id: s _process_feature, _ai, model=model, - detect_timeout=10.0, # 意圖檢測超時 (15 → 10) + detect_timeout=25.0, # 意圖檢測超時:保留 Agent 自主工具判斷空間 feature_timeout=30.0, # 功能處理超時 (15 → 30,新聞摘要生成需要更長時間) - ai_timeout=20.0, # AI回應超時 (30 → 20) + ai_timeout=60.0, # AI回應超時:hosted WebSearch/工具階段可能先無文字 delta ) - logger.info(f"⚙️ 準備調用 ChatPipeline.process,user_message='{user_message}', audio_emotion={audio_emotion}, language={language}") - res: PipelineResult = await pipeline.process(user_message, user_id=user_id, chat_id=chat_id, request_id=request_id, audio_emotion=audio_emotion, language=language, emotion_callback=emotion_callback) + logger.info(f"⚙️ 準備調用 ChatPipeline.process,user_message='{user_message}', audio_emotion={audio_emotion}, language={resolved_language}") + await _emit_bot_status("planning", "已收到,正在規劃處理方式...", "planning") + res: PipelineResult = await pipeline.process(user_message, user_id=user_id, chat_id=chat_id, request_id=request_id, audio_emotion=audio_emotion, language=resolved_language, emotion_callback=emotion_callback) logger.info(f"⚙️ ChatPipeline.process 完成,結果='{res.text}', is_fallback={res.is_fallback}, reason={res.reason}") # 檢查是否有工具元數據 @@ -1408,6 +1607,7 @@ async def handle_message(user_message, user_id, chat_id, messages, request_id: s # 提取情緒與關懷模式資訊(新增) emotion = res.meta.get('emotion') if res.meta else None care_mode = res.meta.get('care_mode', False) if res.meta else False + final_language = _preferred_language_from_text(res.text) or _normalize_bcp47_language_tag(language) or "zh-TW" logger.info(f"🎭 handle_message 情緒: emotion={emotion}, care_mode={care_mode}, meta={res.meta}") @@ -1420,7 +1620,8 @@ async def handle_message(user_message, user_id, chat_id, messages, request_id: s 'tool_name': tool_name, 'tool_data': tool_data, 'emotion': final_emotion, - 'care_mode': care_mode + 'care_mode': care_mode, + 'language': final_language, } @@ -1800,6 +2001,7 @@ async def voice_login(request: VoiceLoginRequest): if not result.get("success"): error_code = result.get("error", "UNKNOWN_ERROR") + quality_warnings = result.get("quality_warnings") or [] error_messages = { "NO_AUDIO": "沒有收到音訊資料", "AUDIO_TOO_SHORT": "音訊太短,請錄製至少 3 秒", @@ -1808,11 +2010,15 @@ async def voice_login(request: VoiceLoginRequest): "THRESHOLD_NOT_MET": "無法確認身份,請重試", "MODEL_ERROR": "辨識系統錯誤,請稍後重試", } - logger.warning(f"🎙️ 語音辨識失敗: {error_code}") + logger.warning(f"🎙️ 語音辨識失敗: {error_code} quality_warnings={quality_warnings}") return JSONResponse(content={ "success": False, "error": error_messages.get(error_code, f"辨識失敗:{error_code}") }) + + quality_warnings = result.get("quality_warnings") or [] + if quality_warnings: + logger.warning(f"🎙️ 語音登入品質警告(未阻擋): {quality_warnings}") # 取得辨識結果 speaker_label = result.get("label") @@ -2146,13 +2352,13 @@ async def analyze_image_with_gpt_vision(filename: str, image_base64: str, mime_t } ] try: - response = ai_service.client.chat.completions.create( - model="gpt-5-nano", + analysis = await ai_service.generate_response_for_user( messages=messages, - max_completion_tokens=1500, - reasoning_effort="medium" # 圖片分析需要較深入理解,使用 medium + user_id="image_analysis", + chat_id=None, + max_tokens=1500, + reasoning_effort="medium", ) - analysis = response.choices[0].message.content return analysis except Exception as e: logger.error(f"GPT Vision分析錯誤: {str(e)}") @@ -2401,6 +2607,12 @@ class TTSRequest(BaseModel): text: str voice: Optional[str] = "nova" speed: Optional[float] = 1.0 + language: Optional[str] = None + persona: Optional[str] = "xiaohua" + mode: Optional[str] = "standard" + speaking_rate: Optional[float] = None + markup: Optional[str] = None + custom_pronunciations: Optional[list[dict]] = None @app.post("/api/tts") @@ -2430,7 +2642,10 @@ async def synthesize_speech( content={"success": False, "error": "文字長度必須在 1-4096 字元之間"} ) - valid_voices = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"] + valid_voices = [ + "alloy", "echo", "fable", "onyx", "nova", "shimmer", "coral", + "zh-tw", "zh-cn", "en-us", "ja-jp", "ko-kr", "id-id", "vi-vn" + ] if request.voice not in valid_voices: return JSONResponse( status_code=400, @@ -2443,10 +2658,21 @@ async def synthesize_speech( content={"success": False, "error": "語速必須在 0.25 到 4.0 之間"} ) - logger.info(f"🔊 TTS 請求: text={request.text[:50]}..., voice={request.voice}, speed={request.speed}") + logger.info( + f"🔊 TTS 請求: text={request.text[:50]}..., voice={request.voice}, speed={request.speed}, " + f"language={request.language}, persona={request.persona}, speaking_rate={request.speaking_rate}, " + f"has_markup={bool(request.markup)}, custom_pronunciations={len(request.custom_pronunciations or [])}" + ) # 調用 TTS 服務獲取完整音頻 - result = await text_to_speech(request.text, request.voice, request.speed) + result = await text_to_speech( + request.text, + request.voice, + request.speed, + language=request.language, + persona=request.persona, + speaking_rate=request.speaking_rate, + ) if not result.get("success"): return JSONResponse( @@ -2474,6 +2700,93 @@ async def synthesize_speech( ) +@app.websocket("/ws/tts") +async def tts_stream_websocket(websocket: WebSocket): + await websocket.accept() + stream_started_at = time.perf_counter() + total_chunks = 0 + total_bytes = 0 + first_chunk_at = None + try: + payload = await websocket.receive_json() + text = str(payload.get("text") or "").strip() + voice = str(payload.get("voice") or "nova") + speed = float(payload.get("speed") or 1.0) + language = payload.get("language") + persona = payload.get("persona") or "xiaohua" + speaking_rate = payload.get("speaking_rate") + markup = payload.get("markup") + custom_pronunciations = payload.get("custom_pronunciations") + emotion = payload.get("emotion") + care_mode = bool(payload.get("care_mode", False)) + + if not text and not markup: + await websocket.send_json({"type": "tts_error", "error": "文字不可為空"}) + await websocket.close() + return + + from services.tts_service import tts_service + + await websocket.send_json({ + "type": "tts_stream_start", + "sample_rate": 24000, + "encoding": "LINEAR16", + "persona": persona, + "language": language, + }) + + async for chunk in tts_service.streaming_synthesize( + text=text, + voice=voice, + speed=speed, + language=language, + persona=persona, + speaking_rate=speaking_rate, + markup=markup, + custom_pronunciations=custom_pronunciations, + emotion=emotion, + care_mode=care_mode, + ): + total_chunks += 1 + total_bytes += len(chunk) + if first_chunk_at is None: + first_chunk_at = time.perf_counter() + await websocket.send_json({ + "type": "tts_audio_chunk", + "audio_base64": base64.b64encode(chunk).decode("ascii"), + }) + + logger.debug( + "📡 TTS 串流傳送完成: chunks=%d bytes=%d first_chunk_delay=%s total_elapsed=%.2fs", + total_chunks, + total_bytes, + f"{(first_chunk_at - stream_started_at):.2f}s" if first_chunk_at is not None else "none", + time.perf_counter() - stream_started_at, + ) + await websocket.send_json({"type": "tts_stream_end"}) + except WebSocketDisconnect: + logger.debug( + "🔌 TTS 串流連線已由客戶端關閉: chunks=%d bytes=%d first_chunk_delay=%s total_elapsed=%.2fs", + total_chunks, + total_bytes, + f"{(first_chunk_at - stream_started_at):.2f}s" if first_chunk_at is not None else "none", + time.perf_counter() - stream_started_at, + ) + except Exception as e: + error_detail = f"{type(e).__name__}: {str(e) or repr(e)}" + logger.error(f"❌ TTS 串流失敗: {error_detail}") + logger.exception("TTS 串流詳細錯誤堆疊:") + try: + await websocket.send_json({"type": "tts_error", "error": str(e)}) + except Exception: + pass + finally: + try: + await websocket.close() + except Exception: + pass + + # ----------------------------- # MCP Tools API diff --git a/bloom-ware-login/components/login-form.tsx b/bloom-ware-login/components/login-form.tsx index 7c82df4a2d7a4f31ba9b8ae2655325d2434715da..569a6a9b927575c5c31918d8e5f2ac07872819d5 100644 --- a/bloom-ware-login/components/login-form.tsx +++ b/bloom-ware-login/components/login-form.tsx @@ -238,26 +238,66 @@ export function LoginForm() { setVoiceStatus('請求麥克風權限...'); console.log('🎤 開始語音登入...'); + let stream: MediaStream | null = null; + let audioContext: AudioContext | null = null; + let source: MediaStreamAudioSourceNode | null = null; + let processor: AudioWorkletNode | null = null; + + const cleanupAudio = async () => { + if (processor) { + processor.port.onmessage = null; + processor.disconnect(); + processor = null; + } + + if (source) { + source.disconnect(); + source = null; + } + + if (stream) { + stream.getTracks().forEach(track => track.stop()); + stream = null; + } + + if (audioContext) { + await audioContext.close().catch(() => undefined); + audioContext = null; + } + }; + try { // 請求麥克風權限 - const stream = await navigator.mediaDevices.getUserMedia({ audio: true }); + stream = await navigator.mediaDevices.getUserMedia({ + audio: { + channelCount: 1, + sampleRate: 16000, + echoCancellation: false, + noiseSuppression: false, + autoGainControl: false, + }, + }); console.log('✅ 麥克風權限已獲取'); // 設定錄音參數 - const audioContext = new AudioContext({ sampleRate: 16000 }); - const source = audioContext.createMediaStreamSource(stream); - const processor = audioContext.createScriptProcessor(4096, 1, 1); + audioContext = new AudioContext({ sampleRate: 16000 }); + await audioContext.audioWorklet.addModule('/audio/pcm-recorder-worklet.js'); + await audioContext.resume(); + source = audioContext.createMediaStreamSource(stream); + processor = new AudioWorkletNode(audioContext, 'pcm-recorder-processor', { + numberOfInputs: 1, + numberOfOutputs: 0, + channelCount: 1, + }); - const audioChunks: Float32Array[] = []; + const audioChunks: Int16Array[] = []; const recordDuration = 4000; // 4 秒(確保足夠長度) - processor.onaudioprocess = (e) => { - const inputData = e.inputBuffer.getChannelData(0); - audioChunks.push(new Float32Array(inputData)); + processor.port.onmessage = (event) => { + audioChunks.push(new Int16Array(event.data)); }; source.connect(processor); - processor.connect(audioContext.destination); setVoiceStatus('🎙️ 錄音中... 請說話 (4秒)'); console.log('🎙️ 開始錄音 4 秒...'); @@ -266,30 +306,20 @@ export function LoginForm() { await new Promise(resolve => setTimeout(resolve, recordDuration)); // 停止錄音 - processor.disconnect(); - source.disconnect(); - stream.getTracks().forEach(track => track.stop()); - await audioContext.close(); + await cleanupAudio(); setVoiceStatus('辨識中...'); console.log('✅ 錄音完成,處理音訊...'); // 合併音訊資料 const totalLength = audioChunks.reduce((acc, chunk) => acc + chunk.length, 0); - const audioData = new Float32Array(totalLength); + const pcm16 = new Int16Array(totalLength); let offset = 0; for (const chunk of audioChunks) { - audioData.set(chunk, offset); + pcm16.set(chunk, offset); offset += chunk.length; } - - // 轉換為 PCM16 - const pcm16 = new Int16Array(audioData.length); - for (let i = 0; i < audioData.length; i++) { - const s = Math.max(-1, Math.min(1, audioData[i])); - pcm16[i] = s < 0 ? s * 0x8000 : s * 0x7FFF; - } - + // 轉換為 base64 const uint8Array = new Uint8Array(pcm16.buffer); let binary = ''; @@ -335,6 +365,7 @@ export function LoginForm() { } setIsLoading(false); setLoadingType(null); + await cleanupAudio(); } } diff --git a/bloom-ware-login/public/audio/pcm-recorder-worklet.js b/bloom-ware-login/public/audio/pcm-recorder-worklet.js new file mode 100644 index 0000000000000000000000000000000000000000..e1876c6a2ba12d3d403416c647d5edb62a4b4534 --- /dev/null +++ b/bloom-ware-login/public/audio/pcm-recorder-worklet.js @@ -0,0 +1,17 @@ +class PCMRecorderProcessor extends AudioWorkletProcessor { + process(inputs) { + const channelData = inputs[0]?.[0]; + if (channelData && channelData.length) { + const pcm16 = new Int16Array(channelData.length); + for (let i = 0; i < channelData.length; i++) { + const sample = Math.max(-1, Math.min(1, channelData[i])); + pcm16[i] = sample < 0 ? sample * 0x8000 : sample * 0x7fff; + } + this.port.postMessage(pcm16.buffer, [pcm16.buffer]); + } + + return true; + } +} + +registerProcessor('pcm-recorder-processor', PCMRecorderProcessor); diff --git a/core/ai_client.py b/core/ai_client.py index 69fcb1286613006d29fe5ad0eec20806b4d9d728..ffa73eb19f03bfe199d621d6dfc4c88454cbef74 100644 --- a/core/ai_client.py +++ b/core/ai_client.py @@ -16,6 +16,17 @@ _openai_client = None _initialized = False +def _normalize_openai_base_url(base_url: Optional[str]) -> Optional[str]: + """Normalize custom OpenAI-compatible base URLs for the Python SDK.""" + if not base_url: + return None + + normalized = base_url.rstrip("/") + if normalized.endswith("/v1"): + return normalized + return f"{normalized}/v1" + + def get_openai_client(): """ 取得 OpenAI 客戶端(單例模式) @@ -37,11 +48,16 @@ def get_openai_client(): _initialized = True return None - _openai_client = OpenAI( - api_key=api_key, - timeout=float(settings.OPENAI_TIMEOUT), - max_retries=3, - ) + client_kwargs = { + "api_key": api_key, + "timeout": float(settings.OPENAI_TIMEOUT), + "max_retries": 3, + } + normalized_base_url = _normalize_openai_base_url(settings.OPENAI_BASE_URL) + if normalized_base_url: + client_kwargs["base_url"] = normalized_base_url + + _openai_client = OpenAI(**client_kwargs) _initialized = True logger.info("✅ OpenAI 客戶端初始化成功") diff --git a/core/config.py b/core/config.py index 7eab28776135f39c787ab29ce95e00248a816a61..d6d524104c8a43091e46606a89037c8682a6d9fd 100644 --- a/core/config.py +++ b/core/config.py @@ -28,6 +28,11 @@ class Settings: _firebase_creds_base64: Optional[str] = os.getenv("FIREBASE_SERVICE_ACCOUNT_JSON_BASE64") _firebase_service_account_path: Optional[str] = os.getenv("FIREBASE_SERVICE_ACCOUNT_PATH") + # Google Cloud Speech / TTS 專用服務帳戶(可與 Firebase 不同 GCP 專案) + _google_speech_creds_json: Optional[str] = os.getenv("GOOGLE_SPEECH_CREDENTIALS_JSON") + _google_speech_creds_base64: Optional[str] = os.getenv("GOOGLE_SPEECH_SERVICE_ACCOUNT_JSON_BASE64") + _google_speech_sa_path: Optional[str] = os.getenv("GOOGLE_SPEECH_SERVICE_ACCOUNT_PATH") + @classmethod def get_firebase_credentials(cls) -> Dict[str, Any]: """ @@ -80,22 +85,111 @@ class Settings: "3. FIREBASE_SERVICE_ACCOUNT_PATH(檔案路徑)" ) + @classmethod + def try_get_google_speech_credentials(cls) -> Optional[Dict[str, Any]]: + """ + 載入 STT/TTS 專用 Google 服務帳戶 JSON(與 Firebase 分離)。 + + 若三種來源皆未設定,回傳 None;若已設定但格式錯誤則拋出 ValueError。 + """ + if cls._google_speech_creds_json: + try: + return json.loads(cls._google_speech_creds_json) + except json.JSONDecodeError as e: + raise ValueError(f"GOOGLE_SPEECH_CREDENTIALS_JSON 格式錯誤: {e}") from e + if cls._google_speech_creds_base64: + try: + decoded_bytes = base64.b64decode(cls._google_speech_creds_base64) + return json.loads(decoded_bytes.decode("utf-8")) + except Exception as e: + raise ValueError(f"GOOGLE_SPEECH_SERVICE_ACCOUNT_JSON_BASE64 解碼失敗: {e}") from e + if cls._google_speech_sa_path: + try: + with open(cls._google_speech_sa_path, "r", encoding="utf-8") as f: + return json.load(f) + except FileNotFoundError: + raise ValueError(f"GOOGLE_SPEECH_SERVICE_ACCOUNT_PATH 檔案不存在: {cls._google_speech_sa_path}") from None + except json.JSONDecodeError as e: + raise ValueError(f"GOOGLE_SPEECH_SERVICE_ACCOUNT_PATH JSON 格式錯誤: {e}") from e + return None + + @classmethod + def resolve_speech_service_account_info(cls) -> tuple[Optional[Dict[str, Any]], str]: + """ + 解析語音 API 使用的服務帳戶:優先 GOOGLE_SPEECH_*,否則退回 Firebase 憑證(相容舊部署)。 + + Returns: + (credentials_dict | None, "speech" | "firebase" | "none") + """ + speech = cls.try_get_google_speech_credentials() + if speech is not None: + return speech, "speech" + try: + return cls.get_firebase_credentials(), "firebase" + except ValueError: + return None, "none" + + @classmethod + def get_google_speech_project_id(cls, credential_project_id: Optional[str] = None) -> str: + """ + Speech-to-Text recognizer 所屬 GCP 專案 ID。 + + 優先順序:GOOGLE_SPEECH_PROJECT_ID → GOOGLE_CLOUD_PROJECT_ID(若為純數字「專案編號」 + 且憑證 JSON 內有字串型 project_id,則改用憑證內 ID,避免誤用編號)→ + 憑證 JSON 內 project_id → FIREBASE_PROJECT_ID + """ + if cls.GOOGLE_SPEECH_PROJECT_ID.strip(): + return cls.GOOGLE_SPEECH_PROJECT_ID.strip() + cloud = cls.GOOGLE_CLOUD_PROJECT_ID.strip() + cred = (credential_project_id or "").strip() + if cloud.isdigit() and cred and not cred.isdigit(): + return cred + if cloud: + return cloud + if cred: + return cred + return cls.FIREBASE_PROJECT_ID.strip() + # ===== OpenAI 配置 ===== OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "") - OPENAI_MODEL: str = os.getenv("OPENAI_MODEL", "gpt-5-nano") + OPENAI_BASE_URL: str = os.getenv("OPENAI_BASE_URL", "") + OPENAI_MODEL: str = os.getenv("OPENAI_MODEL", "gpt-5.4") OPENAI_TIMEOUT: int = int(os.getenv("OPENAI_TIMEOUT", "30")) - - # ===== Google OAuth 配置 ===== + OPENAI_RESPONSES_TIMEOUT: int = int(os.getenv("OPENAI_RESPONSES_TIMEOUT", "90")) + OPENAI_USE_RESPONSES: bool = os.getenv("OPENAI_USE_RESPONSES", "true").lower() == "true" + OPENAI_MODEL_CONTEXT_WINDOW: int = int(os.getenv("OPENAI_MODEL_CONTEXT_WINDOW", "1000000")) + OPENAI_MODEL_AUTO_COMPACT_TOKEN_LIMIT: int = int(os.getenv("OPENAI_MODEL_AUTO_COMPACT_TOKEN_LIMIT", "900000")) + OPENAI_ENABLE_WEB_SEARCH: bool = os.getenv("OPENAI_ENABLE_WEB_SEARCH", "true").lower() == "true" + OPENAI_ENABLE_REMOTE_MCP: bool = os.getenv("OPENAI_ENABLE_REMOTE_MCP", "false").lower() == "true" + OPENAI_REMOTE_MCP_SERVERS_JSON: str = os.getenv("OPENAI_REMOTE_MCP_SERVERS_JSON", "[]") + OPENAI_ENABLE_SKILLS: bool = os.getenv("OPENAI_ENABLE_SKILLS", "false").lower() == "true" + + # ===== Google OAuth(使用者「登入 Bloom Ware」用,非語音 API)===== GOOGLE_CLIENT_ID: str = os.getenv("GOOGLE_CLIENT_ID", "") GOOGLE_CLIENT_SECRET: str = os.getenv("GOOGLE_CLIENT_SECRET", "") GOOGLE_REDIRECT_URI: str = os.getenv( "GOOGLE_REDIRECT_URI", "http://localhost:8080/auth/google/callback" # 開發環境預設值 ) + # ----- Google Cloud「語音」專案(STT/TTS,例:supervisor-project;常與 Firebase 不同)----- + # GOOGLE_CLOUD_PROJECT_ID:語音相關 REST/專案語境之預設專案 ID(請填「專案 ID」字串,勿只填控制台「專案編號」) + GOOGLE_CLOUD_PROJECT_ID: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", os.getenv("FIREBASE_PROJECT_ID", "")) + # GOOGLE_SPEECH_PROJECT_ID:明確指定 STT recognizer 所屬專案;與 Firebase 分離時必須搭配 GOOGLE_SPEECH_* 服務帳戶 + GOOGLE_SPEECH_PROJECT_ID: str = os.getenv("GOOGLE_SPEECH_PROJECT_ID", "") + # STT gRPC 臨時除錯用;正式環境請用服務帳戶 + GOOGLE_STT_ACCESS_TOKEN: str = os.getenv("GOOGLE_STT_ACCESS_TOKEN", "") + # TTS 與部分 REST 用 API Key(屬於語音 GCP;與 STT 串流 gRPC OAuth 分開) + GOOGLE_SPEECH_API_KEY: str = os.getenv("GOOGLE_SPEECH_API_KEY", os.getenv("GOOGLE_API_KEY", "")) + GOOGLE_TTS_API_KEY: str = os.getenv("GOOGLE_TTS_API_KEY", os.getenv("GOOGLE_API_KEY", "")) + GOOGLE_STT_LOCATION: str = os.getenv("GOOGLE_STT_LOCATION", "global") + GOOGLE_STT_RECOGNIZER_ID: str = os.getenv("GOOGLE_STT_RECOGNIZER_ID", "_") + GOOGLE_STT_AUTO_LANGUAGE_CODES: str = os.getenv("GOOGLE_STT_AUTO_LANGUAGE_CODES", "cmn-Hant-TW,en-US,ja-JP") + GOOGLE_TTS_LANGUAGE_CODE: str = os.getenv("GOOGLE_TTS_LANGUAGE_CODE", "cmn-TW") + GOOGLE_TTS_DEFAULT_VOICE: str = os.getenv("GOOGLE_TTS_DEFAULT_VOICE", "cmn-TW-Wavenet-A") # ===== 第三方 API Keys ===== WEATHER_API_KEY: str = os.getenv("WEATHER_API_KEY", "") - NEWSDATA_API_KEY: str = os.getenv("NEWSDATA_API_KEY", "") + TAVILY_API_KEY: str = os.getenv("TAVILY_API_KEY", "") EXCHANGE_API_KEY: str = os.getenv("EXCHANGE_API_KEY", "") # ===== JWT 認證配置 ===== @@ -108,7 +202,7 @@ class Settings: # ===== GPT 意圖檢測配置 ===== USE_GPT_INTENT: bool = os.getenv("USE_GPT_INTENT", "true").lower() == "true" - GPT_INTENT_MODEL: str = os.getenv("GPT_INTENT_MODEL", "gpt-5-nano") + GPT_INTENT_MODEL: str = os.getenv("GPT_INTENT_MODEL", "gpt-5.4") # ===== 背景任務開關 ===== ENABLE_BACKGROUND_JOBS: bool = os.getenv("ENABLE_BACKGROUND_JOBS", "true").lower() == "true" @@ -190,8 +284,8 @@ class Settings: logger.error("請檢查 FIREBASE_CREDENTIALS_JSON 或 FIREBASE_SERVICE_ACCOUNT_PATH") return False - # 驗證 OpenAI API Key 格式(基本檢查) - if not cls.OPENAI_API_KEY.startswith("sk-"): + # 驗證 OpenAI API Key 格式(OpenAI-compatible relay keys may not use sk-*) + if not cls.OPENAI_BASE_URL and not cls.OPENAI_API_KEY.startswith("sk-"): import logging logger = logging.getLogger("core.config") logger.warning("⚠️ OpenAI API Key 格式可能不正確(應以 'sk-' 開頭)") @@ -236,7 +330,10 @@ class Settings: firebase_source = "未設定 ❌" logger.info(f"Firebase 憑證來源: {firebase_source}") logger.info(f"OpenAI 模型: {cls.OPENAI_MODEL}") + logger.info(f"OpenAI Base URL: {cls.OPENAI_BASE_URL or 'default'}") + logger.info(f"OpenAI Responses API: {'enabled' if cls.OPENAI_USE_RESPONSES else 'disabled'}") logger.info(f"OpenAI Timeout: {cls.OPENAI_TIMEOUT}s") + logger.info(f"OpenAI Responses Timeout: {cls.OPENAI_RESPONSES_TIMEOUT}s") logger.info(f"Google OAuth 回調 URI: {cls.GOOGLE_REDIRECT_URI}") logger.info(f"JWT Token 有效期: {cls.ACCESS_TOKEN_EXPIRE_MINUTES} 分鐘") logger.info(f"伺服器監聽: {cls.HOST}:{cls.PORT}") diff --git a/core/environment/__init__.py b/core/environment/__init__.py index 18eda96d2aa15c09d9f90e95ea2e5e56ed5276da..7755b6fdb478c7883753e48bcf3d0121ecdbbc20 100644 --- a/core/environment/__init__.py +++ b/core/environment/__init__.py @@ -1,3 +1,9 @@ +from .context_builder import EnvironmentContextBuilder, EnvironmentInjection from .context_service import EnvironmentContextService, EnvironmentSnapshot -__all__ = ["EnvironmentContextService", "EnvironmentSnapshot"] +__all__ = [ + "EnvironmentContextBuilder", + "EnvironmentContextService", + "EnvironmentInjection", + "EnvironmentSnapshot", +] diff --git a/core/environment/context_builder.py b/core/environment/context_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..a3bba61005600304faa822061d9d74ddf6897a47 --- /dev/null +++ b/core/environment/context_builder.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any, Dict, Optional + + +@dataclass(frozen=True) +class EnvironmentInjection: + summary_text: str + raw_context: Dict[str, Any] + metadata: Dict[str, Any] + + +class EnvironmentContextBuilder: + """ + 將最新環境快照轉成固定可注入格式。 + + 設計目標: + 1. 每回合固定注入,但內容結構固定,避免 prompt 漂移 + 2. 同時保留 raw context,供工具與 runtime 做更細緻決策 + 3. 缺資料時明確標記 freshness / completeness,而不是靜默忽略 + """ + + def build( + self, + env_context: Optional[Dict[str, Any]], + *, + now: Optional[datetime] = None, + ) -> EnvironmentInjection: + ctx = dict(env_context or {}) + current_time = now or datetime.now(timezone.utc) + + lines = [ + f"snapshot_time_utc: {current_time.isoformat()}", + f"has_location: {self._has_location(ctx)}", + f"has_address: {bool(ctx.get('detailed_address') or ctx.get('address_display') or ctx.get('label'))}", + f"has_timezone: {bool(ctx.get('tz'))}", + f"has_locale: {bool(ctx.get('locale'))}", + ] + + if ctx.get("detailed_address"): + lines.append(f"detailed_address: {ctx['detailed_address']}") + elif ctx.get("address_display"): + lines.append(f"address_display: {ctx['address_display']}") + elif ctx.get("label"): + lines.append(f"label: {ctx['label']}") + + if ctx.get("city"): + lines.append(f"city: {ctx['city']}") + if ctx.get("admin"): + lines.append(f"admin: {ctx['admin']}") + if ctx.get("country_code"): + lines.append(f"country_code: {ctx['country_code']}") + if ctx.get("precision"): + lines.append(f"precision: {ctx['precision']}") + if ctx.get("poi_label"): + lines.append(f"poi_label: {ctx['poi_label']}") + if ctx.get("road"): + lines.append(f"road: {ctx['road']}") + if ctx.get("house_number"): + lines.append(f"house_number: {ctx['house_number']}") + if ctx.get("tz"): + lines.append(f"timezone: {ctx['tz']}") + if ctx.get("locale"): + lines.append(f"locale: {ctx['locale']}") + if ctx.get("heading_cardinal"): + lines.append(f"heading: {ctx['heading_cardinal']}") + elif ctx.get("heading_deg") is not None: + lines.append(f"heading_deg: {ctx['heading_deg']}") + if ctx.get("accuracy_m") is not None: + lines.append(f"accuracy_m: {ctx['accuracy_m']}") + if ctx.get("lat") is not None and ctx.get("lon") is not None: + lines.append(f"coordinates: {ctx['lat']},{ctx['lon']}") + + metadata = { + "source": "environment_context_service", + "freshness": "latest_available" if ctx else "missing", + "has_location": self._has_location(ctx), + "has_timezone": bool(ctx.get("tz")), + "has_locale": bool(ctx.get("locale")), + } + + return EnvironmentInjection( + summary_text="\n".join(lines), + raw_context=ctx, + metadata=metadata, + ) + + @staticmethod + def _has_location(ctx: Dict[str, Any]) -> bool: + return ctx.get("lat") is not None and ctx.get("lon") is not None diff --git a/core/intent_detector.py b/core/intent_detector.py index 200b5cf4a1bcf991a3fe8d5165bdd34301bf92dd..a4d8179e8b8ec2288cf0a03a01881ef4feb54a37 100644 --- a/core/intent_detector.py +++ b/core/intent_detector.py @@ -17,9 +17,13 @@ from typing import Dict, Any, Optional, Tuple, List from core.tool_registry import tool_registry from core.logging import get_logger +from core.config import settings +from core.prompts.tool_calling_policy import get_tool_calling_policy logger = get_logger("core.intent_detector") +MIN_TOOL_CONFIDENCE = 0.90 + class IntentDetector: """ @@ -37,6 +41,14 @@ class IntentDetector: def __init__(self): self._cache: Dict[str, Tuple[bool, Optional[Dict[str, Any]], float]] = {} + + @staticmethod + def _estimate_tool_confidence(tool_name: str, arguments: Dict[str, Any]) -> float: + if not tool_name: + return 0.0 + if not isinstance(arguments, dict): + return 0.0 + return 0.95 if arguments else 0.92 async def detect( self, @@ -143,7 +155,7 @@ class IntentDetector: messages=messages, tools=tools, user_id="intent_detection", - model="gpt-5-nano", + model=settings.GPT_INTENT_MODEL, reasoning_effort=optimal_effort, ) @@ -159,12 +171,14 @@ class IntentDetector: 注意:不再描述每個工具,工具定義由 tools 參數傳遞 """ - return """你是一個多語言智能助手,根據用戶需求選擇合適的工具。支援中文、英文、日文、印尼文、越南文。 + return f"""你是一個多語言智能助手,根據用戶需求選擇合適的工具。支援中文、英文、日文、印尼文、越南文。 + +{get_tool_calling_policy()} 【核心規則】 1. 用戶詢問任何可用工具能解決的需求時,必須選擇對應工具 2. 只有純粹的閒聊、問候、情感表達才不選擇工具 -3. 工具參數盡量從用戶消息中提取,無法確定的使用合理預設值 +3. 工具參數從用戶消息中提取;無法確定的可選參數留空,不自行編造預設值 【多語言意圖識別】 無論用戶使用什麼語言,都要識別以下意圖並選擇對應工具: @@ -251,6 +265,7 @@ class IntentDetector: "tool_name": tool_name, "arguments": arguments, "emotion": emotion, + "confidence": self._estimate_tool_confidence(tool_name, arguments), } # 沒有工具調用,視為一般聊天 diff --git a/core/logging.py b/core/logging.py index 028a1c4c719adf6aff1c3377de618a7d3e5757f1..dba4ef6c6b632c350711edcb80c808ea364805d1 100644 --- a/core/logging.py +++ b/core/logging.py @@ -1,105 +1,121 @@ -""" -統一日誌配置 -集中管理所有模組的日誌設定 - -使用方式: - from core.logging import get_logger - logger = get_logger(__name__) -""" - import os import logging from typing import Optional -# 全域日誌等級(只讀取一次) -_LOG_LEVEL_NAME = os.getenv("BLOOMWARE_LOG_LEVEL", "WARNING").upper() -_LOG_LEVEL = getattr(logging, _LOG_LEVEL_NAME, logging.WARNING) - -# 日誌格式 -_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - +# ANSI 轉義序列 +class LogColor: + DEBUG = "\x1b[38;20m" # 灰色 + INFO = "\x1b[32;20m" # 綠色 + WARNING = "\x1b[33;20m" # 黃色 + ERROR = "\x1b[31;20m" # 紅色 + CRITICAL = "\x1b[31;1m" # 粗體紅 + RESET = "\x1b[0m" + GRAY = "\x1b[90m" # 淺灰色 + CYAN = "\x1b[36;20m" # 青色 + BLUE = "\x1b[34;20m" # 藍色 + MAGENTA = "\x1b[35;20m" # 品紅 + +class ColoredFormatter(logging.Formatter): + """自定義彩色日誌格式化器""" + + LEVEL_COLORS = { + logging.DEBUG: LogColor.GRAY, + logging.INFO: LogColor.INFO, + logging.WARNING: LogColor.WARNING, + logging.ERROR: LogColor.ERROR, + logging.CRITICAL: LogColor.CRITICAL + } + + def format(self, record): + level_color = self.LEVEL_COLORS.get(record.levelno, LogColor.RESET) + + # 格式化時間(淡灰色) + asctime = self.formatTime(record, self.datefmt) + asctime_colored = f"{LogColor.GRAY}{asctime}{LogColor.RESET}" + + # 格式化名稱(青色) + name_colored = f"{LogColor.CYAN}{record.name}{LogColor.RESET}" + + # 格式化等級(根據等級變色) + levelname_colored = f"{level_color}{record.levelname:8}{LogColor.RESET}" + + # 格式化訊息 + message = record.getMessage() + + # 自動截斷過長的訊息(如工具調用的完整原始數據) + if len(message) > 500: + message = message[:500] + f"{LogColor.GRAY}... [已截斷,共 {len(message)} 字元]{LogColor.RESET}" + + if "✅" in message: + message = f"{LogColor.INFO}{message}{LogColor.RESET}" + elif "❌" in message or "⚠️" in message: + message = f"{LogColor.ERROR}{message}{LogColor.RESET}" + elif "🎙️" in message or "🔊" in message: + message = f"{LogColor.BLUE}{message}{LogColor.RESET}" + elif "🌐" in message or "MCP" in message: + message = f"{LogColor.MAGENTA}{message}{LogColor.RESET}" + else: + message = f"{level_color}{message}{LogColor.RESET}" + + return f"{asctime_colored} | {name_colored} | {levelname_colored} | {message}" + +# 全域日誌等級 +_LOG_LEVEL_NAME = os.getenv("BLOOMWARE_LOG_LEVEL", "INFO").upper() +_LOG_LEVEL = getattr(logging, _LOG_LEVEL_NAME, logging.INFO) def get_log_level() -> int: - """獲取日誌等級""" return _LOG_LEVEL - def setup_logging( name: Optional[str] = None, level: Optional[int] = None, ) -> logging.Logger: - """ - 設置日誌配置 - - Args: - name: 日誌名稱(None 表示 root logger) - level: 日誌等級(None 表示使用環境變數) - - Returns: - 配置好的 Logger 實例 - """ if level is None: level = _LOG_LEVEL - # 配置格式 - formatter = logging.Formatter(_LOG_FORMAT) - - # 獲取或創建 logger logger = logging.getLogger(name) logger.setLevel(level) - # 避免重複添加 handler if not logger.handlers: - # 控制台 handler console_handler = logging.StreamHandler() console_handler.setLevel(level) - console_handler.setFormatter(formatter) + console_handler.setFormatter(ColoredFormatter()) logger.addHandler(console_handler) - # 防止日誌重複輸出 logger.propagate = False - return logger - def get_logger(name: str) -> logging.Logger: - """ - 獲取已配置的 Logger(推薦使用) - - Args: - name: 日誌名稱,建議使用 __name__ - - Returns: - Logger 實例 - - Example: - from core.logging import get_logger - logger = get_logger(__name__) - logger.info("Hello") - """ return setup_logging(name) - def get_level_name() -> str: - """獲取當前日誌等級名稱""" return _LOG_LEVEL_NAME - # 預設配置 root logger _root_configured = False def configure_root_logger(): - """配置 root logger(只執行一次)""" global _root_configured if not _root_configured: - level = get_log_level() - logging.basicConfig( - level=level, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] - ) + root = logging.getLogger() + root.setLevel(get_log_level()) + + # 清除現有的 handlers + for handler in root.handlers[:]: + root.removeHandler(handler) + + console_handler = logging.StreamHandler() + console_handler.setFormatter(ColoredFormatter()) + root.addHandler(console_handler) + _root_configured = True - # 自動配置 configure_root_logger() + +# 關閉 Speechbrain 等吵雜的日誌 +logging.getLogger("speechbrain").setLevel(logging.WARNING) +logging.getLogger("werkzeug").setLevel(logging.WARNING) +logging.getLogger("urllib3").setLevel(logging.WARNING) +logging.getLogger("httpcore").setLevel(logging.WARNING) +logging.getLogger("httpx").setLevel(logging.WARNING) diff --git a/core/memory_system.py b/core/memory_system.py index ccc34d08602d8d2a1a65f66fa753fb13f2f9689f..293ec89a363a7c732f229345d41e44a845edff35 100644 --- a/core/memory_system.py +++ b/core/memory_system.py @@ -1,6 +1,8 @@ import json from typing import List, Dict, Any, Optional, Tuple from datetime import datetime +import asyncio +import random # 統一日誌配置 from core.logging import get_logger @@ -8,12 +10,122 @@ logger = get_logger("MemorySystem") # 統一 OpenAI 客戶端 from core.ai_client import get_openai_client +from core.config import settings +from core.responses_runtime import ResponsesAgentRuntime def _get_memory_client(): """取得記憶系統用的 OpenAI 客戶端""" return get_openai_client() +TRANSIENT_MEMORY_ERROR_MARKERS = ( + "502", + "503", + "504", + "bad gateway", + "upstream", + "timeout", + "timed out", + "connection", +) + +TRANSIENT_QUERY_MARKERS = ( + "今天", + "現在", + "目前", + "最新", + "即時", + "收盤", + "開盤", + "股價", + "股票", + "匯率", + "天氣", + "新聞", + "多少", + "查詢", + "search", + "latest", + "today", + "now", + "price", + "stock", + "weather", + "news", +) + +MEMORY_ANALYSIS_SCHEMA = { + "type": "object", + "properties": { + "memories": { + "type": "array", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": ["personal_info", "preferences", "goals"], + }, + "content": {"type": "string"}, + "importance": { + "type": "number", + "minimum": 0, + "maximum": 1, + }, + }, + "required": ["type", "content", "importance"], + "additionalProperties": False, + }, + } + }, + "required": ["memories"], + "additionalProperties": False, +} + + +def _is_transient_memory_error(exc: Exception) -> bool: + if isinstance(exc, TimeoutError) or isinstance(exc, asyncio.TimeoutError): + return True + error_text = str(exc).lower() + return any(marker in error_text for marker in TRANSIENT_MEMORY_ERROR_MARKERS) + + +def _should_run_ai_memory_analysis(user_message: str, assistant_response: str = "") -> bool: + text = f"{user_message}\n{assistant_response}".strip().lower() + if not text: + return False + + durable_markers = ( + "我叫", + "我的名字", + "我喜歡", + "我不喜歡", + "我討厭", + "我的偏好", + "我住在", + "我的工作", + "我的目標", + "我想達成", + "我希望", + "請記得", + "記住", + "下次", + "my name", + "i like", + "i dislike", + "remember", + "my goal", + ) + if any(marker in text for marker in durable_markers): + return True + + user_text = (user_message or "").lower() + if any(marker in user_text for marker in TRANSIENT_QUERY_MARKERS): + return False + + return len(user_message.strip()) >= 80 + + # 導入數據庫函數 try: @@ -122,7 +234,44 @@ class MemoryAnalyzer: """記憶分析器:使用AI分析對話內容""" def __init__(self): - pass + self.responses_runtime = ResponsesAgentRuntime() + + @staticmethod + def _memory_model() -> str: + return settings.OPENAI_MODEL or settings.GPT_INTENT_MODEL + + @staticmethod + def _memory_timeout() -> float: + return min(float(getattr(settings, "OPENAI_TIMEOUT", 30)), 20.0) + + @staticmethod + async def _transient_backoff(attempt: int) -> None: + delay = min(0.5 * (2 ** attempt), 3.0) + random.uniform(0, 0.2) + await asyncio.sleep(delay) + + def _create_analysis_response(self, client: Any, messages: List[Dict[str, str]], max_tokens_value: int) -> Any: + model = self._memory_model() + if settings.OPENAI_USE_RESPONSES and model.startswith("gpt-5"): + payload = self.responses_runtime.build_payload_from_messages( + messages=messages, + model=model, + max_output_tokens=max_tokens_value, + text_format={ + "type": "json_schema", + "name": "memory_analysis", + "strict": True, + "schema": MEMORY_ANALYSIS_SCHEMA, + }, + ) + payload["store"] = False + return client.responses.create(**payload) + + return client.chat.completions.create( + model=model, + messages=messages, + max_completion_tokens=max_tokens_value, + response_format={"type": "json_object"}, + ) async def analyze_conversation(self, user_message: str, assistant_response: str = "", conversation_history: List[Dict] = None) -> List[Dict[str, Any]]: @@ -174,39 +323,42 @@ class MemoryAnalyzer: {"role": "user", "content": user_prompt} ] - # 嘗試調用OpenAI API,最多重試2次 - max_retries = 2 - for attempt in range(max_retries + 1): + max_attempts = 3 + for attempt in range(max_attempts): try: - if attempt > 0: - # 如果是重試,增加token限制 - max_tokens_value = 2000 + (attempt * 1000) - logger.info(f"重試AI分析 (嘗試 {attempt + 1}/{max_retries + 1}),增加token限制到 {max_tokens_value}") - else: - max_tokens_value = 2000 + max_tokens_value = 500 - response = client.chat.completions.create( - model="gpt-5-nano", - messages=messages, - max_completion_tokens=max_tokens_value, - reasoning_effort="low" + response = await asyncio.wait_for( + asyncio.to_thread(self._create_analysis_response, client, messages, max_tokens_value), + timeout=self._memory_timeout(), ) break # 成功後跳出重試循環 except Exception as api_error: error_str = str(api_error).lower() - if "max_tokens" in error_str or "token limit" in error_str: - if attempt < max_retries: - logger.warning(f"AI分析遇到token限制錯誤,正在重試 ({attempt + 1}/{max_retries + 1}): {api_error}") + if _is_transient_memory_error(api_error): + if attempt < max_attempts - 1: + logger.warning( + "AI記憶分析遇到暫時性上游錯誤,準備重試 (%s/%s): %s", + attempt + 1, + max_attempts, + api_error, + ) + await self._transient_backoff(attempt) continue - else: - logger.error(f"AI分析在 {max_retries + 1} 次嘗試後仍然遇到token限制錯誤: {api_error}") - return [] # 返回空列表,回退到關鍵字提取 + logger.error("AI記憶分析連續暫時性上游錯誤,已放棄本輪背景分析: %s", api_error) + return [] + if "max_tokens" in error_str or "token limit" in error_str: + logger.error(f"AI分析遇到token限制錯誤: {api_error}") + return [] # 返回空列表,回退到關鍵字提取 else: # 其他類型的錯誤,直接拋出 raise api_error - result_text = response.choices[0].message.content.strip() + if hasattr(response, "choices"): + result_text = response.choices[0].message.content.strip() + else: + result_text = self.responses_runtime.extract_output_text(response) # 解析JSON結果 - 嘗試多種解析方式 try: @@ -241,6 +393,9 @@ class MemoryAnalyzer: return memories except Exception as e: + if _is_transient_memory_error(e): + logger.info("AI記憶分析遇到暫時性上游錯誤,跳過本輪背景分析: %s", e) + return [] logger.error(f"AI記憶分析時發生錯誤: {e}") return [] @@ -268,10 +423,12 @@ class MemoryManager: # 2. 使用AI分析提取記憶(如果可用) ai_memories = [] - if _get_memory_client(): + if _get_memory_client() and _should_run_ai_memory_analysis(user_message, assistant_response): ai_memories = await self.analyzer.analyze_conversation( user_message, assistant_response, conversation_history ) + else: + logger.debug("跳過AI記憶分析:本輪內容不具備長期記憶價值或客戶端不可用") # 3. 合併記憶(去重) all_memories = self._merge_memories(keyword_memories, ai_memories) diff --git a/core/pipeline.py b/core/pipeline.py index 38e549866eacafb807ccb900f6642abf8b1a9dc9..b823e35160c47b81dd488cc0b0f848818c86c134 100644 --- a/core/pipeline.py +++ b/core/pipeline.py @@ -4,9 +4,13 @@ from dataclasses import dataclass from typing import Any, Awaitable, Callable, Optional, Dict, Tuple, List from core.emotion_care_manager import EmotionCareManager +from core.config import settings +from core.voice_care_gate import decide_voice_care, is_voice_context logger = logging.getLogger(__name__) +MIN_TOOL_CONFIDENCE = 0.90 + @dataclass class PipelineResult: @@ -35,10 +39,10 @@ class ChatPipeline: intent_detector: Callable[[str], Awaitable[Tuple[bool, dict]]], feature_processor: Callable[[dict, str, str, Optional[str]], Awaitable[Any]], ai_generator: Callable[..., Awaitable[str]], - model: str = "gpt-5-nano", - detect_timeout: float = 5.0, # 2025 最佳實踐:Structured Outputs 通常 2-3秒 - feature_timeout: float = 10.0, # MCP 工具已有內部超時(30秒) - ai_timeout: float = 12.0, # 配合 Streaming(首次回應 0.5-1秒) + model: Optional[str] = None, + detect_timeout: float = 20.0, # 考量到 Function Calling 可能較慢 + feature_timeout: float = 30.0, # MCP 工具內部超時 + ai_timeout: float = 25.0, # 配合 Streaming ) -> None: self._intent_detector = intent_detector self._feature_processor = feature_processor @@ -46,7 +50,7 @@ class ChatPipeline: self._detect_timeout = detect_timeout self._feature_timeout = feature_timeout self._ai_timeout = ai_timeout - self._model = model + self._model = model or settings.OPENAI_MODEL def _is_chinese_message(self, text: str) -> bool: """ @@ -142,11 +146,15 @@ class ChatPipeline: {"role": "user", "content": combined_text} ] - logger.info(f"🌐 呼叫 GPT 翻譯,模型: gpt-4o-mini") + logger.info(f"🌐 呼叫 GPT 翻譯") + # 格式化回應使用環境變數設定的模型 + model = settings.GPT_INTENT_MODEL or settings.OPENAI_MODEL + logger.info(f"🎨 使用配置模型進行格式化: {model}") + translated = await ai_service.generate_response_async( messages=messages, - model="gpt-4o-mini", # 升級到 gpt-4o-mini 以提升翻譯品質 - reasoning_effort=None, # gpt-4o-mini 不支援此參數 + model=model, + reasoning_effort=None, max_tokens=800, ) logger.info(f"🌐 GPT 翻譯完成,結果長度: {len(translated) if translated else 0}") @@ -221,207 +229,182 @@ class ChatPipeline: # language 參數保留以向後兼容,但不使用(GPT 自動判斷語言) - # 0) 先進行意圖偵測以提取情緒(需要在關懷模式檢查前執行) - detect_res = await self._with_timeout( - self._intent_detector(user_message), self._detect_timeout, reason="detect" - ) - if isinstance(detect_res, PipelineResult): - return detect_res - has_feature, intent_data = detect_res - - # 【情緒融合】雙軌制:音頻情緒優先,文字情緒輔助 - text_emotion = intent_data.get("emotion", "neutral") if intent_data else "neutral" - logger.info(f"🎭 [情緒流向-1] 文字情緒: {text_emotion}") - - # 檢查音頻情緒 - if audio_emotion: - logger.info(f"🎭 [情緒流向-2] 音頻情緒資料: success={audio_emotion.get('success')}, emotion={audio_emotion.get('emotion')}, confidence={audio_emotion.get('confidence')}") - - # 情緒融合邏輯 - emotion_confidence = 0.5 # 預設置信度 - if audio_emotion and audio_emotion.get("success"): - audio_emotion_label = audio_emotion.get("emotion", "neutral") - audio_confidence = audio_emotion.get("confidence", 0.0) - - # 【優化】提高門檻到 0.7,避免誤判(太敏感會導致錯誤情緒) - if audio_confidence >= 0.7: - emotion_value = audio_emotion_label - emotion_confidence = audio_confidence - logger.info(f"🎭 [情緒流向-3] ✅ 採用音頻情緒: {emotion_value} (置信度: {audio_confidence:.4f})") - else: - emotion_value = text_emotion - emotion_confidence = 0.5 # 文字情緒預設置信度 - logger.info(f"🎭 [情緒流向-3] ⬇️ 音頻置信度過低 ({audio_confidence:.4f}),改用文字情緒: {emotion_value}") - else: - emotion_value = text_emotion - emotion_confidence = 0.5 # 文字情緒預設置信度 - logger.info(f"🎭 [情緒流向-3] 📝 無音頻情緒,使用文字情緒: {emotion_value}") - - # 【關鍵】記錄最終情緒 - logger.info(f"🎭 [情緒流向-最終] emotion={emotion_value}, confidence={emotion_confidence:.2f}") - - # 1) 檢查是否在關懷模式 - if user_id and EmotionCareManager.is_in_care_mode(user_id, chat_id): - # 檢查是否解除關懷模式(傳入情緒資訊) - if EmotionCareManager.check_release(user_id, user_message, chat_id, emotion=emotion_value): - logger.info(f"✅ 用戶 {user_id} 情緒恢復,解除關懷模式,繼續正常流程") - # 解除後繼續正常流程 - else: - logger.info(f"💙 用戶 {user_id} 在關懷模式中,跳過工具調用,使用關懷 AI") - # 直接用關懷模式 AI 回應(不檢測意圖,不調用工具) - care_emotion = EmotionCareManager.get_care_emotion(user_id, chat_id) - final_emotion = care_emotion or emotion_value + has_feature = False + intent_data = None + tool_context = "" + tool_results_list = [] + emotion_value = "neutral" + care_emotion = None + use_care_mode = False + max_loops = 3 + current_loop = 0 + ai_res_text = "" + + while current_loop < max_loops: + # 0) 先進行意圖偵測與可回答性評估 (Confidence-driven check) + detect_res = await self._with_timeout( + self._intent_detector(user_message, tool_context, language=language), self._detect_timeout, reason="detect" + ) + if isinstance(detect_res, PipelineResult): + return detect_res + has_feature, intent_data = detect_res + + if current_loop == 0: + # 只在第一輪提取情緒與進行關懷模式判斷 + if intent_data and "emotion" in intent_data: + emotion_value = intent_data["emotion"] + else: + emotion_value = "neutral" + + voice_context = is_voice_context(audio_emotion) + voice_care_decision = None + if voice_context: + try: + voice_care_decision = decide_voice_care(text_emotion=emotion_value, audio_emotion=audio_emotion) + if voice_care_decision.emotion: + emotion_value = voice_care_decision.emotion + except Exception as e: + logger.warning(f"Voice care decision failed: {e}") + + emotion_confidence = float(audio_emotion.get("confidence", 0.0)) if isinstance(audio_emotion, dict) else 0.0 + + if EmotionCareManager.is_in_care_mode(user_id): + exit_match = False + if "沒事了" in user_message or "謝謝" in user_message or "好多了" in user_message: + if emotion_value not in ["sad", "angry", "fear"]: + exit_match = True + if voice_context and voice_care_decision and not voice_care_decision.allow: + exit_match = True + + if exit_match: + logger.info(f"💙 使用者情緒平穩 [{emotion_value}],退出關懷模式") + EmotionCareManager.exit_care_mode(user_id) + use_care_mode = False + if emotion_callback: + try: + await emotion_callback(emotion_value, False) + except Exception as e: + logger.warning(f"emotion_callback 錯誤: {e}") + else: + logger.info(f"💙 維持關懷模式,情緒=[{emotion_value}]") + use_care_mode = True + care_emotion = EmotionCareManager._active_care_users.get(user_id, {}).get("emotion") or emotion_value + if emotion_callback: + try: + await emotion_callback(emotion_value, True) + except Exception as e: + logger.warning(f"emotion_callback 錯誤: {e}") + + ai_res = await self._with_timeout( + self._ai_generator( + user_message, + user_id, + self._model, + request_id, + chat_id, + use_care_mode=use_care_mode, + care_emotion=care_emotion, + emotion_label=emotion_value, + is_first_care=False, + ), + self._ai_timeout, + reason="ai-care", + ) + if isinstance(ai_res, PipelineResult): + return ai_res + text = str(ai_res or "").strip() + if not text: + text = "我在這裡陪你,隨時可以聊聊。" + return PipelineResult(text=text, is_fallback=False, meta={"care_mode": True, "emotion": care_emotion}) + + # 檢查是否需要進入關懷模式 + can_enter_care = True + if voice_context and voice_care_decision is not None: + can_enter_care = voice_care_decision.allow + + if can_enter_care and user_id and EmotionCareManager.check_and_enter_care_mode( + user_id, emotion_value, chat_id, confidence=emotion_confidence + ): + logger.warning(f"⚠️ 偵測到極端情緒 [{emotion_value}](置信度: {emotion_confidence:.2f}),進入關懷模式") + + if emotion_callback: + try: + await emotion_callback(emotion_value, True) + except Exception as e: + logger.warning(f"emotion_callback 錯誤: {e}") + + ai_res = await self._with_timeout( + self._ai_generator( + user_message, + user_id, + self._model, + request_id, + chat_id, + use_care_mode=True, + care_emotion=emotion_value, + emotion_label=emotion_value, + is_first_care=True, # 告知 Agent 這是第一次進入,需引導退出 + ), + self._ai_timeout, + reason="ai-care", + ) + if isinstance(ai_res, PipelineResult): + return ai_res + text = str(ai_res or "").strip() + if not text: + text = "我聽到了,我在這裡陪你。" + + return PipelineResult(text=text, is_fallback=False, meta={"care_mode": True, "emotion": emotion_value}) + if emotion_callback: try: - await emotion_callback(final_emotion, True) + await emotion_callback(emotion_value, False) except Exception as e: logger.warning(f"emotion_callback 錯誤: {e}") - ai_res = await self._with_timeout( - self._ai_generator( - user_message, - user_id, - self._model, - request_id, - chat_id, - use_care_mode=True, - care_emotion=care_emotion, - emotion_label=care_emotion, - ), - self._ai_timeout, - reason="ai-care", - ) - if isinstance(ai_res, PipelineResult): - return ai_res - text = str(ai_res or "").strip() - if not text: + if has_feature and intent_data and intent_data.get("type") == "mcp_tool": + confidence = float(intent_data.get("confidence", 0.0) or 0.0) + if confidence < MIN_TOOL_CONFIDENCE: + logger.info("🔒 工具信心度不足 %.2f,禁止調用工具", confidence) return PipelineResult( - text="我在這裡陪你,隨時可以聊聊。", + text=self._build_low_confidence_tool_message(user_message, confidence), is_fallback=True, - reason="ai-care-empty", - meta={"care_mode": True, "emotion": care_emotion or "sad"} + reason="low_confidence", + meta={"confidence": confidence} ) - return PipelineResult(text=text, is_fallback=False, meta={"care_mode": True, "emotion": care_emotion}) - - # 2) 檢查是否需要進入關懷模式(傳遞置信度,用於連續性判斷) - if user_id and EmotionCareManager.check_and_enter_care_mode( - user_id, emotion_value, chat_id, confidence=emotion_confidence - ): - logger.warning(f"⚠️ 偵測到極端情緒 [{emotion_value}](置信度: {emotion_confidence:.2f}),進入關懷模式") - # 立即使用關懷模式 AI 回應 - - if emotion_callback: - try: - await emotion_callback(emotion_value, True) - except Exception as e: - logger.warning(f"emotion_callback 錯誤: {e}") - - ai_res = await self._with_timeout( - self._ai_generator( - user_message, - user_id, - self._model, - request_id, - chat_id, - use_care_mode=True, - care_emotion=emotion_value, - emotion_label=emotion_value, - ), - self._ai_timeout, - reason="ai-care", - ) - if isinstance(ai_res, PipelineResult): - return ai_res - text = str(ai_res or "").strip() - if not text: - text = "我聽到了,我在這裡陪你。" - - # 第一次進入關懷模式時,附加退出提示(新增) - exit_hint = "\n\n💙 關懷模式已啟動。說「我沒事了」可以退出。" - return PipelineResult(text=text + exit_hint, is_fallback=False, meta={"care_mode": True, "emotion": emotion_value}) - - if emotion_callback: - try: - await emotion_callback(emotion_value, False) - except Exception as e: - logger.warning(f"emotion_callback 錯誤: {e}") - - # 3) 有功能 → 功能處理(限時) - if has_feature and intent_data: - feat_res = await self._with_timeout( - self._feature_processor(intent_data, user_id, user_message, chat_id), - self._feature_timeout, - reason="feature", - ) - if isinstance(feat_res, PipelineResult): - return feat_res - # 如果返回 None,表示這是聊天,不應該被當作功能處理 - if feat_res is None: - has_feature = False - intent_data = None - else: - # 檢查是否為字典(包含工具信息) - if isinstance(feat_res, dict): - text = feat_res.get('message', feat_res.get('content', '')).strip() - tool_name = feat_res.get('tool_name') - tool_data = feat_res.get('tool_data') - if not text: - return PipelineResult( - text="抱歉,功能處理沒有產出結果。", - is_fallback=True, - reason="feature-empty", - meta={"emotion": emotion_value, "care_mode": False} - ) - # 簡化翻譯:非中文用戶 → 翻譯工具卡片 - is_chinese = self._is_chinese_message(user_message) - logger.info(f"🌐 語言檢測: user_message='{user_message}', is_chinese={is_chinese}") - if not is_chinese and tool_data: - logger.info(f"🌐 開始翻譯工具卡片: {len(str(tool_data))} chars") - tool_data = await self._translate_tool_data(tool_data, user_message) - logger.info(f"🌐 翻譯完成: {len(str(tool_data))} chars") - elif is_chinese: - logger.info(f"🌐 用戶使用中文,不翻譯工具卡片") - elif not tool_data: - logger.info(f"🌐 無工具資料,跳過翻譯") - - # 返回帶有工具元數據的結果(包含情緒) - meta_dict = { - 'emotion': emotion_value, - 'care_mode': False # 工具調用不是關懷模式 - } - if tool_name: - meta_dict['tool_name'] = tool_name - if tool_data: - meta_dict['tool_data'] = tool_data - - return PipelineResult( - text=text, - is_fallback=False, - meta=meta_dict - ) + if has_feature and intent_data: + feat_res = await self._with_timeout( + self._feature_processor(intent_data, user_id, user_message, chat_id), + self._feature_timeout, + reason="feature", + ) + if isinstance(feat_res, PipelineResult): + tool_context += f"\n[工具執行結果]:\n{feat_res.text}\n" + tool_results_list.append({"text": feat_res.text, "meta": feat_res.meta}) + elif isinstance(feat_res, dict): + t_name = feat_res.get('tool_name', 'unknown') + t_msg = feat_res.get('message', '') + t_data = feat_res.get('tool_data', {}) + tool_context += f"\n[工具 {t_name} 執行結果]:\n{t_msg}\n(Data: {str(t_data)[:2000]})\n" + tool_results_list.append(feat_res) else: - # 正常字串 text = str(feat_res or "").strip() - if not text: - return PipelineResult( - text="抱歉,功能處理沒有產出結果。", - is_fallback=True, - reason="feature-empty", - meta={"emotion": emotion_value, "care_mode": False} - ) - - # 不再翻譯工具回應,讓 GPT 自己處理並用對應語言描述 - - return PipelineResult( - text=text, - is_fallback=False, - meta={"emotion": emotion_value, "care_mode": False}, - ) + tool_context += f"\n[工具執行結果]:\n{text}\n" + tool_results_list.append({"text": text}) + # 【效能優化】短路機制:如果工具調用信心度為 100%,且是簡單工具,則不進入下一輪驗證 + if confidence >= 1.0: + logger.info("⚡ 工具執行信心度高且結果明確,跳過冗餘驗證") + break + + current_loop += 1 + continue + else: + # 如果沒有調用工具,表示 Agent 對目前答案已有 100% 信心,退出循環 + break - # 4) 無功能 → 一般聊天(限時) - # 注意:不傳 messages,改傳 user_message,讓 ai_generator 自動載入歷史對話和記憶 - ai_res = await self._with_timeout( + # 4) 最後 AI 生成回應(結合了所有 tool_context) + ai_res_text = await self._with_timeout( self._ai_generator( user_message, user_id or "default", @@ -430,25 +413,48 @@ class ChatPipeline: chat_id, emotion_label=emotion_value, language=language, + tool_context=tool_context, ), self._ai_timeout, - reason="ai", + reason="ai_gen", + ) + if isinstance(ai_res_text, PipelineResult): + return ai_res_text + + meta = {"emotion": emotion_value, "care_mode": use_care_mode} + if tool_results_list: + executed_tools = [] + for t in tool_results_list: + if isinstance(t, dict) and t.get("tool_name"): + executed_tools.append({ + "tool_name": t.get("tool_name"), + "tool_data": t.get("tool_data") + }) + elif hasattr(t, 'meta') and t.meta and t.meta.get("tool_name"): + executed_tools.append({ + "tool_name": t.meta.get("tool_name"), + "tool_data": t.meta.get("tool_data") + }) + if executed_tools: + meta["executed_tools"] = executed_tools + # 兼容原有邏輯,將最後一個工具設為主卡片 + last_tool = executed_tools[-1] + meta["tool_name"] = last_tool["tool_name"] + meta["tool_data"] = last_tool["tool_data"] + else: + last_tool = tool_results_list[-1] + if isinstance(last_tool, dict): + meta["tool_name"] = last_tool.get("tool_name") + meta["tool_data"] = last_tool.get("tool_data") + elif hasattr(last_tool, 'meta') and last_tool.meta: + meta.update(last_tool.meta) + + return PipelineResult( + text=str(ai_res_text or "").strip(), + is_fallback=False, + meta=meta ) - if isinstance(ai_res, PipelineResult): - return ai_res - text = str(ai_res or "").strip() - if not text: - return PipelineResult( - text="抱歉,我暫時沒有合適的回應。可以換個說法再試試嗎?", - is_fallback=True, - reason="ai-empty", - meta={"emotion": emotion_value, "care_mode": False} - ) - - # 一般聊天也包含情緒資訊 - meta_dict = { - 'emotion': emotion_value, - 'care_mode': False # 一般聊天不是關懷模式 - } - return PipelineResult(text=text, is_fallback=False, meta=meta_dict) + def _build_low_confidence_tool_message(self, user_message: str, confidence: float) -> str: + """建立低信心度工具調用的提示訊息""" + return "抱歉,我不太確定您的意思。您能說得更具體一點嗎?" diff --git a/core/prompts/care_mode.py b/core/prompts/care_mode.py index 4e8c4eba3f54a1e7825ce978c34afa369fb5100d..f13e0cef931a6085100c36347ad6d187b660a8f4 100644 --- a/core/prompts/care_mode.py +++ b/core/prompts/care_mode.py @@ -1,8 +1,16 @@ """ 情緒關懷模式 Prompt -精簡化設計,保持關懷品質 +⚠️ DEPRECATED: 此模組的內容已遷移至 services/ai_service.py 中的 +CARE_MODE_BASE_PROMPT / EMOTION_SPECIFIC_PROMPTS / get_care_mode_prompt()。 +本模組僅保留向後兼容的 re-export,供既有測試使用。 """ +import warnings as _warnings + +# ── 向後兼容 re-export ────────────────────────────────────────── +# 生產環境使用 services.ai_service 中的活躍版本。 +# 此處提供簡化版僅為保持 test_prompts.py 不斷鏈。 + CARE_MODE_PROMPT = """你是 BloomWare 的情緒關懷助手「小花」。你的任務是傾聽、陪伴。 【回應原則】 @@ -25,7 +33,9 @@ CARE_MODE_PROMPT = """你是 BloomWare 的情緒關懷助手「小花」。你 def get_care_prompt(emotion: str = None, user_name: str = None) -> str: """ - 生成關懷模式 Prompt + 生成關懷模式 Prompt(向後兼容) + + ⚠️ DEPRECATED: 生產環境請使用 services.ai_service.get_care_mode_prompt() Args: emotion: 用戶情緒標籤 @@ -34,6 +44,12 @@ def get_care_prompt(emotion: str = None, user_name: str = None) -> str: Returns: 關懷模式 System Prompt """ + _warnings.warn( + "core.prompts.care_mode.get_care_prompt() is deprecated. " + "Use services.ai_service.get_care_mode_prompt() instead.", + DeprecationWarning, + stacklevel=2, + ) prompt = CARE_MODE_PROMPT if emotion: diff --git a/core/prompts/care_mode_skills.py b/core/prompts/care_mode_skills.py new file mode 100644 index 0000000000000000000000000000000000000000..70e708862c58253087d3fd7e1bbcd9a79021be8d --- /dev/null +++ b/core/prompts/care_mode_skills.py @@ -0,0 +1,41 @@ +import os +from pathlib import Path +from typing import List + +CARE_SKILLS_ROOT = Path(__file__).resolve().parents[2] / "features" / "care_mode" / "skills" + +def load_care_mode_skills() -> str: + """ + 載入並格式化情緒關懷模式的對話技巧 (Skills) + """ + if not CARE_SKILLS_ROOT.exists(): + return "" + + skills_content = [] + skills_content.append("\n【情緒關懷對話技巧 (Care Mode Skills)】") + skills_content.append("在關懷模式下,請靈活運用以下專業對話技巧來提升共鳴感:") + + try: + # 獲取所有 .md 檔案 + skill_files = list(CARE_SKILLS_ROOT.glob("*.md")) + + for file_path in skill_files: + content = file_path.read_text(encoding="utf-8") + # 移除 Frontmatter (--- ... ---) + if content.startswith("---"): + parts = content.split("---", 2) + if len(parts) >= 3: + content = parts[2].strip() + + skills_content.append(f"\n--- {file_path.stem} ---\n{content}") + + return "\n".join(skills_content) + except Exception as e: + print(f"載入關懷模式技巧失敗: {e}") + return "" + +def get_care_mode_skills_block() -> str: + """ + 獲取用於 System Prompt 的技巧區塊 + """ + return load_care_mode_skills() diff --git a/core/prompts/intent_detection.py b/core/prompts/intent_detection.py index 1ce1b337736ea4390fb2ba8320bbd5bbef0c942f..254aed7919d71d7ed33cbe96d5d63eab0f95e793 100644 --- a/core/prompts/intent_detection.py +++ b/core/prompts/intent_detection.py @@ -1,9 +1,15 @@ """ 意圖檢測 Prompt 模板 -精簡化設計,減少 token 消耗約 40% +⚠️ DEPRECATED: 此模組的 TOOL_RULES 和 get_intent_prompt() 在生產環境中未被使用。 +生產意圖偵測由 features/mcp/agent_bridge.py 的 _build_function_calling_prompt() 處理。 +本模組僅保留向後兼容的定義,供既有測試使用。 """ -# 工具特定規則(按需載入) +import warnings as _warnings + +from core.prompts.tool_calling_policy import get_tool_calling_policy + +# ── 工具特定規則(DEPRECATED — 生產中由 agent_bridge 硬編碼處理) ── TOOL_RULES = { "weather": """天氣查詢:城市必須用英文(台北→Taipei, 高雄→Kaohsiung),預設 Taipei""", @@ -32,7 +38,7 @@ TOOL_RULES = { 「怎麼去 X」→ forward_geocode:query=X""", } -# 情緒標籤說明 +# ── 情緒標籤說明 ── EMOTION_RULES = """情緒判斷:neutral/happy/sad/angry/fear/surprise - happy: 開心、興奮(「好開心!」「太棒了」) - sad: 難過、沮喪(「好難過」「心情不好」) @@ -44,7 +50,10 @@ EMOTION_RULES = """情緒判斷:neutral/happy/sad/angry/fear/surprise def get_intent_prompt(tools_description: str, include_rules: list = None) -> str: """ - 生成意圖檢測 Prompt + 生成意圖檢測 Prompt(向後兼容) + + ⚠️ DEPRECATED: 生產環境使用 features/mcp/agent_bridge.py 的 + _build_function_calling_prompt() + OpenAI Function Calling。 Args: tools_description: 可用工具描述 @@ -53,9 +62,17 @@ def get_intent_prompt(tools_description: str, include_rules: list = None) -> str Returns: 精簡化的 System Prompt """ + _warnings.warn( + "core.prompts.intent_detection.get_intent_prompt() is deprecated. " + "Production intent detection uses MCPAgentBridge._build_function_calling_prompt().", + DeprecationWarning, + stacklevel=2, + ) # 基礎 Prompt base = f"""你是意圖解析助手。分析用戶消息,決定是否調用工具。 +{get_tool_calling_policy()} + 可用工具: {tools_description} diff --git a/core/prompts/tool_calling_policy.py b/core/prompts/tool_calling_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..04d1c9631ebe49c677135cfa734ac826fa1eade6 --- /dev/null +++ b/core/prompts/tool_calling_policy.py @@ -0,0 +1,18 @@ +"""Shared policy text for tool-calling agents.""" + + +TOOL_CALLING_POLICY = """【鐵血工具調用政策】 +1. 反幻覺:只要問題需要最新、精確、外部、檔案、位置、時間、健康、交通、天氣、新聞、匯率或其他工具可驗證資訊,必須優先調用工具;不得憑印象補答案。 +2. 環境優先:選工具前先把系統注入的時間、位置、使用者設定、狀態視為第一決策依據;若使用者說「附近」「這裡」「我這邊」「現在」,不要編造 city/lat/lon,留空交給系統環境注入。 +3. 參數紀律:只填使用者明確提供或可無歧義轉換的參數;可選參數不確定時留空,讓工具 schema default 或環境注入處理。 +4. 工具失敗:不得假裝成功;若工具回錯、缺資料或參數不足,只能修正參數重試、改用可用降級資訊,或明確說明不足。 +5. 證據約束:工具結果回來後,回答只能依據工具結果、已驗證上下文與系統環境;禁止新增未查證事實。 +6. 純聊天例外:只有問候、純情緒陪伴、閒聊、或詢問能力說明可以不調工具。 +7. 信心閘門:只有當工具選擇信心度至少 90% 時才允許工具調用;低於 90% 必須視為沒有可用工具,並請使用者補充地點、時間、路線、幣別、檔名或關鍵字。 +8. 語言一致:使用者用什麼語言互動,就必須用同一種語言回覆;工具卡片與錯誤說明也應跟隨使用者語言。 +""" + + +def get_tool_calling_policy() -> str: + """Return the shared non-negotiable tool-calling policy.""" + return TOOL_CALLING_POLICY diff --git a/core/reasoning_strategy.py b/core/reasoning_strategy.py index 66be06d0caf37b87ec433dfce0874afec365933f..dca2edf71463a1ff58dc0090d689ca40344586e3 100644 --- a/core/reasoning_strategy.py +++ b/core/reasoning_strategy.py @@ -43,10 +43,10 @@ class ReasoningStrategy: reasoning_effort: minimal/low/medium/high """ - # 🔥 規則 1:意圖檢測使用 low reasoning(平衡速度與準確度) + # 🔥 規則 1:意圖檢測使用 minimal reasoning(極速模式,減少初始延遲) if task_type == "intent_detection": - logger.debug("🧠 意圖檢測 → low reasoning(快速但準確)") - return "low" + logger.debug("🧠 意圖檢測 → minimal reasoning(極速回應)") + return "minimal" # 🔥 規則 2:關懷模式優先速度(用戶情緒不佳時不要讓他等) if user_emotion in ["sad", "angry", "fear"]: diff --git a/core/responses_runtime.py b/core/responses_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..0cbdd079cc27b963020b01c949a55fec07ed15d0 --- /dev/null +++ b/core/responses_runtime.py @@ -0,0 +1,212 @@ +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from core.environment.context_builder import EnvironmentInjection + + +@dataclass +class ResponsesRuntimeRequest: + user_input: str + model: str + instructions: Optional[str] = None + environment: Optional[EnvironmentInjection] = None + tools: List[Dict[str, Any]] = field(default_factory=list) + input_items: Optional[List[Dict[str, Any]]] = None + previous_response_id: Optional[str] = None + reasoning_effort: Optional[str] = None + max_output_tokens: Optional[int] = None + text_format: Optional[Dict[str, Any]] = None + tool_choice: Any = "auto" + + +class ResponsesAgentRuntime: + """ + 新版主 Agent runtime 骨架。 + + 目前先負責: + 1. 統一組裝 Responses API payload + 2. 固定附帶 environment injection + 3. 為 hosted tools / bridge tools 預留同一個組裝入口 + """ + + def build_request_payload(self, request: ResponsesRuntimeRequest) -> Dict[str, Any]: + input_parts: List[Dict[str, Any]] = list(request.input_items or []) + + if request.environment: + input_parts.insert( + 0, + { + "role": "system", + "content": [ + { + "type": "input_text", + "text": "Latest environment context:\n" + request.environment.summary_text, + } + ], + } + ) + + input_parts.append( + self.message_to_input_item({"role": "user", "content": request.user_input}) + ) + + payload: Dict[str, Any] = { + "model": request.model, + "input": input_parts, + "tools": self.normalize_tools_for_responses(request.tools), + } + + if request.instructions: + payload["instructions"] = request.instructions + if request.previous_response_id: + payload["previous_response_id"] = request.previous_response_id + if request.reasoning_effort: + payload["reasoning"] = {"effort": request.reasoning_effort} + if request.max_output_tokens: + payload["max_output_tokens"] = request.max_output_tokens + if request.text_format: + payload["text"] = {"format": request.text_format} + if request.tools: + payload["tool_choice"] = request.tool_choice + + return payload + + def build_payload_from_messages( + self, + *, + messages: List[Dict[str, Any]], + model: str, + tools: Optional[List[Dict[str, Any]]] = None, + reasoning_effort: Optional[str] = None, + max_output_tokens: Optional[int] = None, + tool_choice: Any = "auto", + text_format: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + input_items: List[Dict[str, Any]] = [] + instructions: Optional[str] = None + + for message in messages: + if message.get("role") == "system": + content = message.get("content") or "" + instructions = f"{instructions}\n\n{content}" if instructions else str(content) + continue + input_items.append(self.message_to_input_item(message)) + + payload: Dict[str, Any] = { + "model": model, + "input": input_items, + "tools": self.normalize_tools_for_responses(tools or []), + } + if instructions: + payload["instructions"] = instructions + if reasoning_effort: + payload["reasoning"] = {"effort": reasoning_effort} + if max_output_tokens: + payload["max_output_tokens"] = max_output_tokens + if text_format: + payload["text"] = {"format": text_format} + if tools: + payload["tool_choice"] = tool_choice + return payload + + @staticmethod + def without_hosted_tools(payload: Dict[str, Any]) -> Dict[str, Any]: + stripped = dict(payload) + tools = [ + tool for tool in stripped.get("tools", []) + if tool.get("type") == "function" + ] + stripped["tools"] = tools + if not tools: + stripped.pop("tool_choice", None) + return stripped + + @staticmethod + def message_to_input_item(message: Dict[str, Any]) -> Dict[str, Any]: + role = message.get("role") or "user" + content = message.get("content") or "" + if isinstance(content, list): + return {"role": role, "content": [ResponsesAgentRuntime.normalize_content_part(part, role) for part in content]} + content_type = "output_text" if role == "assistant" else "input_text" + return {"role": role, "content": [{"type": content_type, "text": str(content)}]} + + @staticmethod + def normalize_content_part(part: Dict[str, Any], role: str) -> Dict[str, Any]: + part_type = part.get("type") + if part_type == "text": + return { + "type": "output_text" if role == "assistant" else "input_text", + "text": str(part.get("text", "")), + } + if part_type == "image_url": + image_url = part.get("image_url") or {} + return { + "type": "input_image", + "image_url": image_url.get("url", image_url if isinstance(image_url, str) else ""), + } + return dict(part) + + @staticmethod + def normalize_tools_for_responses(tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + normalized: List[Dict[str, Any]] = [] + for tool in tools: + if tool.get("type") != "function" or "function" not in tool: + normalized.append(dict(tool)) + continue + + fn = tool.get("function") or {} + converted = { + "type": "function", + "name": fn.get("name"), + "description": fn.get("description", ""), + "parameters": fn.get("parameters", {"type": "object", "properties": {}}), + } + if "strict" in fn: + converted["strict"] = fn["strict"] + normalized.append(converted) + return normalized + + @staticmethod + def extract_output_text(response: Any) -> str: + text = getattr(response, "output_text", None) + if isinstance(text, str) and text.strip(): + return text.strip() + + parts: List[str] = [] + for item in getattr(response, "output", []) or []: + item_type = getattr(item, "type", None) + if item_type != "message": + continue + for content in getattr(item, "content", []) or []: + content_text = getattr(content, "text", None) + if content_text: + parts.append(str(content_text)) + return "\n".join(parts).strip() + + @staticmethod + def extract_function_calls(response: Any) -> List[Dict[str, Any]]: + calls: List[Dict[str, Any]] = [] + for item in getattr(response, "output", []) or []: + if getattr(item, "type", None) != "function_call": + continue + calls.append( + { + "id": getattr(item, "call_id", None) or getattr(item, "id", None), + "type": "function", + "function": { + "name": getattr(item, "name", ""), + "arguments": getattr(item, "arguments", "{}") or "{}", + }, + } + ) + return calls + + @staticmethod + def decode_arguments(arguments: str) -> Dict[str, Any]: + try: + return json.loads(arguments or "{}") + except json.JSONDecodeError: + return {} diff --git a/core/tool_registry.py b/core/tool_registry.py index e79fc6cbbfe8859956ff5323db6f24d83e7761a9..fe042c49ce9cdf870c63deb99a74b66dab549e4f 100644 --- a/core/tool_registry.py +++ b/core/tool_registry.py @@ -6,6 +6,7 @@ 重構版本:整合 Pydantic Schema 自動生成 """ +import inspect from typing import Dict, List, Any, Optional, Callable, Type from dataclasses import dataclass, field @@ -18,6 +19,7 @@ from core.tool_schema import ( extract_schema_from_mcp_tool, ) +from features.mcp.tools.base_tool import MCPTool logger = get_logger("core.tool_registry") @@ -246,35 +248,45 @@ def register_mcp_tools_to_registry(mcp_server) -> int: count = 0 for tool_name, tool in mcp_server.tools.items(): - # 優先嘗試從 MCPTool 類別提取完整 Schema + # 1. 嘗試獲取工具類別 + tool_class = None if hasattr(tool, 'handler') and hasattr(tool.handler, '__self__'): - tool_class = tool.handler.__self__ - if tool_registry.register_mcp_tool(type(tool_class)): - count += 1 - continue + tool_class = type(tool.handler.__self__) + elif hasattr(tool, 'handler') and hasattr(tool.handler, '__closure__') and tool.handler.__closure__: + # 嘗試從閉包中找 (例如 classmethod_wrapper or instance_wrapper) + for cell in tool.handler.__closure__: + try: + contents = cell.cell_contents + # 檢查是否為 MCPTool 類別或實例 (使用鴨子類型,避免模組導入路徑不一致問題) + if inspect.isclass(contents) and hasattr(contents, 'get_input_schema') and hasattr(contents, 'NAME'): + tool_class = contents + break + elif not inspect.isclass(contents) and hasattr(contents, 'get_input_schema') and hasattr(contents, 'NAME'): + tool_class = type(contents) + break + except: + continue + + # 2. 如果能找到類別,使用 register_mcp_tool (這會處理 rich description) + if tool_class and tool_registry.register_mcp_tool(tool_class): + count += 1 + continue - # 降級:使用舊方法註冊 + # 3. 降級:手動提取並註冊 description = getattr(tool, 'description', f'{tool_name} 工具') - parameters = {"type": "object", "properties": {}, "required": []} - - if hasattr(tool, 'handler') and hasattr(tool.handler, '__self__'): - tool_class = tool.handler.__self__ - if hasattr(tool_class, 'get_input_schema'): - try: - parameters = tool_class.get_input_schema() - except Exception as e: - logger.warning(f"取得 {tool_name} schema 失敗: {e}") - - # 提取關鍵字和範例 + parameters = getattr(tool, 'inputSchema', {"type": "object", "properties": {}, "required": []}) + output_schema = getattr(tool, 'outputSchema', None) + + # 嘗試從 handler 閉包中找 tool_class (如果有的話) + # 或者從 tool.metadata 找 keywords = [] examples = [] - if hasattr(tool, 'handler') and hasattr(tool.handler, '__self__'): - tool_class = tool.handler.__self__ - keywords = getattr(tool_class, 'KEYWORDS', []) - examples = getattr(tool_class, 'USAGE_TIPS', []) - - # 判斷分類 - category = _infer_category(tool_name) + if hasattr(tool, 'metadata') and tool.metadata: + keywords = tool.metadata.get('keywords', []) + examples = tool.metadata.get('usage_tips', []) or tool.metadata.get('examples', []) + category = tool.metadata.get('category', 'general') + else: + category = _infer_category(tool_name) # 判斷是否需要位置 requires_location = _requires_location(tool_name, parameters) diff --git a/core/tool_router.py b/core/tool_router.py index 56035c60581310b1388f602e82ee5dbf99ad38ee..e762c8d9ba9852d641016d466f9d772ad35437c1 100644 --- a/core/tool_router.py +++ b/core/tool_router.py @@ -28,15 +28,16 @@ class ToolRouter: # 分類關鍵字映射 CATEGORY_KEYWORDS = { - "weather": ["天氣", "氣溫", "下雨", "晴天", "陰天", "weather", "溫度", "濕度"], + "weather": ["天氣", "氣溫", "下雨", "晴天", "陰天", "weather", "溫度", "濕度", "天気", "雨", "気温"], "transportation": [ "公車", "巴士", "bus", "火車", "台鐵", "高鐵", "捷運", "metro", - "youbike", "ubike", "微笑單車", "共享單車", "停車場", "停車位" + "youbike", "ubike", "微笑單車", "共享單車", "停車場", "停車位", + "バス", "電車", "地下鉄", "新幹線", "駐輪場", "駐車場" ], - "location": ["我在哪", "這是哪", "位置", "地址", "怎麼去", "導航", "路線"], - "information": ["新聞", "消息", "報導", "news"], - "finance": ["匯率", "換算", "美元", "日圓", "歐元", "currency", "exchange"], - "health": ["心率", "步數", "血氧", "睡眠", "健康", "運動"], + "location": ["我在哪", "這是哪", "位置", "地址", "怎麼去", "導航", "路線", "どこ", "現在地", "住所", "ナビ"], + "information": ["新聞", "消息", "報導", "news", "ニュース", "報道"], + "finance": ["匯率", "換算", "美元", "日圓", "歐元", "currency", "exchange", "為替", "レート", "円", "ドル"], + "health": ["心率", "步數", "血氧", "睡眠", "健康", "運動", "健康", "歩数", "心拍", "運動"], } # 時間敏感工具(深夜可能不適用) @@ -89,6 +90,7 @@ class ToolRouter: logger.debug(f"🎯 檢測到的分類: {detected_categories}") # 2. 過濾工具 + language = context.get("language") filtered_tools = [] for tool in tools: tool_name = tool.get("function", {}).get("name", "") @@ -106,6 +108,12 @@ class ToolRouter: filtered_tools.append(tool) # 3. 排序工具(相關分類優先) + # 如果是日語,特別提升新聞與天氣的優先級(補償關鍵字可能不全的情況) + if language == 'ja': + detected_categories.add("weather") + detected_categories.add("finance") + detected_categories.add("information") + sorted_tools = self._sort_tools(filtered_tools, detected_categories, context) # 4. 限制工具數量(減少 token 消耗) @@ -114,7 +122,7 @@ class ToolRouter: logger.info(f"📉 工具數量從 {len(sorted_tools)} 限制到 {max_tools}") sorted_tools = sorted_tools[:max_tools] - logger.info(f"🔧 過濾後工具: {[t['function']['name'] for t in sorted_tools]}") + logger.info(f"🔧 過濾後工具: {[t['function']['name'] for t in sorted_tools]} (用戶語系: {language})") return sorted_tools def _detect_categories(self, message: str) -> Set[str]: @@ -234,11 +242,11 @@ class ToolRouter: return 20 if len(detected_categories) == 1: - # 單一分類,但仍需要保留足夠工具(如 directions) - return 12 + # 單一分類,只需保留核心工具,顯著減少 LLM 負擔 + return 6 - # 多個分類 - return 15 + # 多個分類,保持在較小範圍 + return 10 def record_tool_usage(self, user_id: str, tool_name: str) -> None: """記錄工具使用(用於優先級調整)""" diff --git a/core/tool_schema.py b/core/tool_schema.py index 8954b04595bd714872ccd15ba65af10a465e399a..d037f15a94214637d2e06d3b5215e2cc5ce1ec3f 100644 --- a/core/tool_schema.py +++ b/core/tool_schema.py @@ -5,10 +5,11 @@ Pydantic 工具 Schema 定義 功能: 1. 工具輸入/輸出的 Pydantic 基礎類別 2. 自動生成 OpenAI tools 格式的 JSON Schema -3. 支援 strict mode 確保 100% 有效輸出 +3. 支援 provider strict mode 確保工具參數結構穩定 4. 裝飾器模式自動註冊工具 """ +from copy import deepcopy from typing import Dict, Any, Optional, List, Callable, Type, TypeVar, get_type_hints from dataclasses import dataclass, field from functools import wraps @@ -54,7 +55,7 @@ class ToolSchema: 轉換為 OpenAI Function Calling 格式 Args: - strict: 是否啟用 strict mode(確保輸出符合 schema) + strict: 是否啟用 provider strict mode(約束工具參數 schema) Returns: OpenAI tools 格式的字典 @@ -110,32 +111,74 @@ class ToolSchema: strict mode 要求: 1. additionalProperties: false - 2. 所有屬性都在 required 中(或有 default) - 3. 不支援 oneOf/anyOf/allOf + 2. 保留 JSON Schema required 語意 + 3. 可選欄位必須透過 default 或 nullable 型別明確表達 """ - result = dict(schema) + result = deepcopy(schema) # 確保是 object 類型 if result.get("type") != "object": result = {"type": "object", "properties": result} - # 添加 additionalProperties: false - result["additionalProperties"] = False - - # 確保所有屬性都在 required 中 + # Provider strict mode 要求所有 properties 都列入 required; + # 有 default 的欄位先保留 default,執行端仍會套用工具 schema 預設值。 + self._apply_provider_strict_object_rules(result) properties = result.get("properties", {}) - existing_required = set(result.get("required", [])) - - # 收集所有沒有 default 的屬性 - all_required = [] - for prop_name, prop_schema in properties.items(): - if prop_name in existing_required or "default" not in prop_schema: - all_required.append(prop_name) - - result["required"] = all_required + result["required"] = list(properties.keys()) return result + def _apply_provider_strict_object_rules(self, schema: Dict[str, Any]) -> None: + """遞迴套用 provider strict object schema 規則。""" + if schema.get("type") == "object": + schema["additionalProperties"] = False + properties = schema.get("properties", {}) + if isinstance(properties, dict): + schema["required"] = list(properties.keys()) + for prop_schema in properties.values(): + if isinstance(prop_schema, dict): + self._apply_provider_strict_object_rules(prop_schema) + + for key in ("items",): + nested = schema.get(key) + if isinstance(nested, dict): + self._apply_provider_strict_object_rules(nested) + + def validate_schema_contract(self) -> List[str]: + """檢查 input/output schema 是否有會破壞工具調用的契約問題。""" + issues: List[str] = [] + if not self.metadata.name: + issues.append("tool name is required") + if self.input_schema.get("type") != "object": + issues.append(f"{self.metadata.name}: input_schema.type must be object") + + properties = self.input_schema.get("properties", {}) + if not isinstance(properties, dict): + issues.append(f"{self.metadata.name}: input_schema.properties must be object") + + required = self.input_schema.get("required", []) + if required and not isinstance(required, list): + issues.append(f"{self.metadata.name}: input_schema.required must be list") + for field in required: + if field not in properties: + issues.append(f"{self.metadata.name}: required field '{field}' missing from properties") + + if self.output_schema is not None: + if self.output_schema.get("type") != "object": + issues.append(f"{self.metadata.name}: output_schema.type must be object") + output_props = self.output_schema.get("properties", {}) + if not isinstance(output_props, dict): + issues.append(f"{self.metadata.name}: output_schema.properties must be object") + + return issues + + def contract_warnings(self) -> List[str]: + """回報不阻擋執行、但會降低模型選工具品質的問題。""" + warnings: List[str] = [] + if not self.metadata.description: + warnings.append(f"{self.metadata.name}: description is empty") + return warnings + def get_summary(self) -> Dict[str, Any]: """獲取工具摘要(用於快速意圖匹配)""" return { @@ -185,7 +228,6 @@ def extract_schema_from_mcp_tool(tool_class: Type) -> Optional[ToolSchema]: try: input_schema = tool_class.get_input_schema() except Exception as e: - logger.warning(f"提取 {name} input schema 失敗: {e}") input_schema = {"type": "object", "properties": {}} # 提取 output schema(可選) @@ -253,6 +295,11 @@ class ToolSchemaRegistry: def register(self, schema: ToolSchema) -> None: """註冊工具 Schema""" + issues = schema.validate_schema_contract() + if issues: + raise ValueError("; ".join(issues)) + for warning in schema.contract_warnings(): + logger.warning(warning) self._schemas[schema.metadata.name] = schema logger.debug(f"註冊工具 Schema: {schema.metadata.name}") diff --git a/core/voice_care_gate.py b/core/voice_care_gate.py new file mode 100644 index 0000000000000000000000000000000000000000..2de72c664b10f2f535f71baa73568619dcb221d1 --- /dev/null +++ b/core/voice_care_gate.py @@ -0,0 +1,82 @@ +from dataclasses import dataclass +from typing import Any, Dict, Optional + + +EXTREME_EMOTIONS = {"sad", "angry", "fear"} +VOICE_EMOTION_CONFIDENCE_THRESHOLD = 0.70 +VOICE_SPEECH_CONFIDENCE_THRESHOLD = 0.70 + + +@dataclass(frozen=True) +class VoiceCareDecision: + allow: bool + emotion: str + confidence: float + reason: str + evidence: Dict[str, Any] + + +def _to_float(value: Any, default: float = 0.0) -> float: + try: + return float(value) + except (TypeError, ValueError): + return default + + +def _normalize_emotion(value: Any) -> str: + text = str(value or "neutral").strip().lower() + return text if text else "neutral" + + +def is_voice_context(audio_emotion: Optional[Dict[str, Any]]) -> bool: + if not audio_emotion: + return False + source = str(audio_emotion.get("source") or "").strip().lower() + return source in {"realtime_voice", "voice", "speech", "audio"} + + +def decide_voice_care( + *, + text_emotion: str, + audio_emotion: Optional[Dict[str, Any]], +) -> VoiceCareDecision: + """ + Gate for voice-triggered care mode. + + Voice emotion is allowed to trigger care mode only when the transcript-side + emotion agrees that the user is in the same extreme-emotion family. + """ + text_value = _normalize_emotion(text_emotion) + audio_value = _normalize_emotion((audio_emotion or {}).get("emotion")) + audio_confidence = _to_float((audio_emotion or {}).get("confidence")) + speech_confidence_raw = (audio_emotion or {}).get("speech_confidence") + speech_confidence = ( + _to_float(speech_confidence_raw) + if speech_confidence_raw is not None + else None + ) + evidence = { + "text_emotion": text_value, + "audio_emotion": audio_value, + "audio_emotion_confidence": audio_confidence, + "speech_confidence": speech_confidence, + } + + if not audio_emotion or not audio_emotion.get("success"): + return VoiceCareDecision(False, text_value, 0.5, "voice-audio-missing", evidence) + + if speech_confidence is not None and speech_confidence < VOICE_SPEECH_CONFIDENCE_THRESHOLD: + return VoiceCareDecision(False, text_value, 0.5, "voice-speech-low-confidence", evidence) + + if audio_confidence < VOICE_EMOTION_CONFIDENCE_THRESHOLD: + if text_value in EXTREME_EMOTIONS: + return VoiceCareDecision(True, text_value, 0.5, "voice-text-extreme-audio-low-confidence", evidence) + return VoiceCareDecision(False, text_value, 0.5, "voice-audio-low-confidence", evidence) + + if audio_value not in EXTREME_EMOTIONS: + return VoiceCareDecision(False, text_value, audio_confidence, "voice-audio-not-extreme", evidence) + + if text_value not in EXTREME_EMOTIONS: + return VoiceCareDecision(False, text_value, 0.5, "voice-text-not-extreme", evidence) + + return VoiceCareDecision(True, text_value, audio_confidence, "voice-extreme-family-match", evidence) diff --git a/features/care_mode/skills/ACTIVE_LISTENING.md b/features/care_mode/skills/ACTIVE_LISTENING.md new file mode 100644 index 0000000000000000000000000000000000000000..b409f5d9ad2e3e82775b859834f757ed99d37a30 --- /dev/null +++ b/features/care_mode/skills/ACTIVE_LISTENING.md @@ -0,0 +1,24 @@ +--- +name: active-listening +description: "主動傾聽技巧:透過反映與複述,讓用戶感受到被聽見與理解。" +usage_policy: + mirror_feelings: true + reflect_content: true + avoid_judgement: true +--- + +# 主動傾聽 (Active Listening) + +當用戶表達情緒或分享經歷時,使用此技能來強化理解感。 + +### 執行要點: +1. **反映感受**:辨識並說出用戶文字背後的情緒。 + - *例子*:「聽起來這件事讓你感到很挫折。」 +2. **複述核心**:用你自己的話簡單重述用戶提到的關鍵點,確認你的理解。 + - *例子*:「你提到雖然努力了很久,但結果不如預期,這讓你感到很不甘心。」 +3. **鼓勵續說**:使用簡短的鼓勵詞,邀請用戶進一步宣洩。 + - *例子*:「我在聽,你想多聊聊這部分嗎?」 + +### 禁忌: +- 在用戶說完前就給予建議。 +- 使用「我知道你的感受」這種空洞的話(除非你能具體說出是什麼感受)。 diff --git a/features/care_mode/skills/ANGER_HANDLING.md b/features/care_mode/skills/ANGER_HANDLING.md new file mode 100644 index 0000000000000000000000000000000000000000..383cf370bcd30874422012daed276b5bd4a0e154 --- /dev/null +++ b/features/care_mode/skills/ANGER_HANDLING.md @@ -0,0 +1,15 @@ +--- +name: anger-handling +description: "憤怒情緒處理技巧:認可憤怒的合理性,不給予壓力。" +usage_policy: + calm_empathy: true + validate_anger: true + avoid_suppression: true +--- + +# 憤怒情緒處理 (Anger Handling) + +### 執行要點: +- **語氣**:冷靜但帶有同理、不卑不亢。 +- **重點**:認可對方的憤怒是有原因的,幫助對方感覺被理解。 +- **避免**:說「冷靜一下」、「別生氣」這類否定情緒的話。 diff --git a/features/care_mode/skills/CARE_CORE_STRATEGY.md b/features/care_mode/skills/CARE_CORE_STRATEGY.md new file mode 100644 index 0000000000000000000000000000000000000000..3726e5953a56599cc9331d67929a956e805d8106 --- /dev/null +++ b/features/care_mode/skills/CARE_CORE_STRATEGY.md @@ -0,0 +1,26 @@ +--- +name: care-core-strategy +description: "關懷模式核心策略:定義回應的結構與核心原則。" +usage_policy: + empathy_first: true + companionship: true + open_encouragement: true +--- + +# 關懷模式核心策略 (Core Care Strategy) + +這是進行情緒關懷對話的基礎框架,必須嚴格遵守。 + +### 核心職責: +1. **深度同理 (Deep Empathy)**:第一句話必須精準反映用戶目前的感受或處境。 + - *例子*:「考零分真的會很難過,這種失落我懂。」 +2. **溫柔陪伴 (Gentle Companionship)**:表達你在這裡,願意傾聽。 + - *例子*:「我在這裡陪你。」 +3. **開放式鼓勵 (Open Encouragement)**:適度詢問細節,鼓勵用戶宣洩情緒。 + - *例子*:「想說說發生了什麼嗎?」 + +### 回應原則: +- **語氣**:輕柔、自然、像好朋友。 +- **格式**:保持簡潔(約 2-3 句話),確保每句話都有溫度。 +- **禁止**:機械化回答、過度正向的毒雞湯、心理醫生式的說教。 +- **禁止**:重複使用罐頭話術。 diff --git a/features/care_mode/skills/EMOTIONAL_VALIDATION.md b/features/care_mode/skills/EMOTIONAL_VALIDATION.md new file mode 100644 index 0000000000000000000000000000000000000000..469bb669d61ff6cef812a28d73118f1e73a7930c --- /dev/null +++ b/features/care_mode/skills/EMOTIONAL_VALIDATION.md @@ -0,0 +1,24 @@ +--- +name: emotional-validation +description: "情感驗證技巧:認可用戶情緒的合理性,減輕其自我懷疑或羞恥感。" +usage_policy: + normalize_emotions: true + validate_logic: true + show_empathy: true +--- + +# 情感驗證 (Emotional Validation) + +當用戶對自己的情緒感到懷疑、羞恥或困惑時,使用此技能。 + +### 執行要點: +1. **正常化情緒**:讓用戶知道在這種情況下有這種感覺是很正常的。 + - *例子*:「換作是我,遇到這種情況也會感到生氣的。」 +2. **認可合理性**:根據對話背景,解釋為什麼用戶的情緒是合理的。 + - *例子*:「付出了那麼多努力卻沒有回報,難過是很自然的事。」 +3. **消除孤獨感**:表達這種情緒是普遍的人類經驗。 + - *例子*:「很多人在面對這種轉變時都會感到焦慮,你並不孤單。」 + +### 禁忌: +- 質疑用戶情緒的真實性。 +- 說「這沒什麼大不了的」或「別想太多」。 diff --git a/features/care_mode/skills/FEAR_HANDLING.md b/features/care_mode/skills/FEAR_HANDLING.md new file mode 100644 index 0000000000000000000000000000000000000000..bfac6355eb358c22035697fa051582d5c37152e1 --- /dev/null +++ b/features/care_mode/skills/FEAR_HANDLING.md @@ -0,0 +1,15 @@ +--- +name: fear-handling +description: "恐懼/焦慮情緒處理技巧:提供穩定感與安全感。" +usage_policy: + stable_presence: true + accept_fear: true + provide_security: true +--- + +# 恐懼/焦慮處理 (Fear & Anxiety Handling) + +### 執行要點: +- **語氣**:穩定、溫暖、帶有安全感。 +- **重點**:讓對方感覺不孤單,恐懼是可以被接納的。 +- **避免**:說「沒什麼好怕的」、「想太多了」這類否定情緒的話。 diff --git a/features/care_mode/skills/FIRST_CONTACT_CARE.md b/features/care_mode/skills/FIRST_CONTACT_CARE.md new file mode 100644 index 0000000000000000000000000000000000000000..2b4104dead6e7edfceae552cddb37b69febedabf --- /dev/null +++ b/features/care_mode/skills/FIRST_CONTACT_CARE.md @@ -0,0 +1,21 @@ +--- +name: first-contact-care +description: "初次關懷引導:在進入關懷模式的首個回覆中,引導用戶如何操作。" +usage_policy: + conditional_execution: true + provide_exit_hint: true +--- + +# 初次關懷引導 (First Contact Care) + +當系統偵測到這是進入「關懷模式」的第一個回覆時(`is_first_care=True`),請執行此引導。 + +### 執行要點: +1. **完成溫暖回應**:首先根據用戶情緒提供深度同理與陪伴的回應。 +2. **附上退出提示**:在回覆的結尾,換行兩次後,附上以下藍色字體的提示: + - `💙 關懷模式已啟動。說「我沒事了」可以退出。` + +### 範例: +「考零分真的會很難過,這種失落我懂。我在這裡陪你,想說說發生什麼嗎? + +💙 關懷模式已啟動。說「我沒事了」可以退出。」 diff --git a/features/care_mode/skills/SADNESS_HANDLING.md b/features/care_mode/skills/SADNESS_HANDLING.md new file mode 100644 index 0000000000000000000000000000000000000000..7832bcffc1bfc8af8888da4e6cfbde00bc87f805 --- /dev/null +++ b/features/care_mode/skills/SADNESS_HANDLING.md @@ -0,0 +1,15 @@ +--- +name: sadness-handling +description: "悲傷情緒處理技巧:提供溫柔理解,避免過度積極。" +usage_policy: + gentle_tone: true + normalize_sadness: true + avoid_toxic_positivity: true +--- + +# 悲傷情緒處理 (Sadness Handling) + +### 執行要點: +- **語氣**:溫柔、輕聲、帶有理解。 +- **重點**:陪伴而非解決問題,讓對方知道悲傷是正常的。 +- **避免**:說「不要難過」、「振作點」這類否定情緒的話。 diff --git a/features/care_mode/skills/SUPPORTIVE_PRESENCE.md b/features/care_mode/skills/SUPPORTIVE_PRESENCE.md new file mode 100644 index 0000000000000000000000000000000000000000..f365673d234bbba4c53b77da98c4b8cbdb96c696 --- /dev/null +++ b/features/care_mode/skills/SUPPORTIVE_PRESENCE.md @@ -0,0 +1,24 @@ +--- +name: supportive-presence +description: "支持性陪伴技巧:提供純粹的陪伴感,不急於解決問題,建立安全感。" +usage_policy: + be_present: true + slow_down: true + offer_space: true +--- + +# 支持性陪伴 (Supportive Presence) + +當用戶處於情緒低谷,且尚未準備好採取行動或深入探討時,使用此技能。 + +### 執行要點: +1. **傳達在場**:明確表達你現在就在這裡,不會離開。 + - *例子*:「我就在這裡陪你,不需要急著說什麼。」 +2. **減緩節奏**:給予用戶空間,不要連續追問。 + - *例子*:「如果你累了想靜一下也可以,我會一直都在。」 +3. **無條件接納**:表達無論用戶現在狀態如何,你都能接受。 + - *例子*:「現在不想說話也沒關係,我可以就這樣靜靜陪著你。」 + +### 禁忌: +- 急著提供「解決方案」或「待辦清單」。 +- 表現得像是在趕時間或想要結束對話。 diff --git a/features/mcp/agent_bridge.py b/features/mcp/agent_bridge.py index e5370c99350e9de7e16887b13eef9a838916ff90..aedbe152a13b08e8bc51374817fd0fad53557113 100644 --- a/features/mcp/agent_bridge.py +++ b/features/mcp/agent_bridge.py @@ -6,11 +6,13 @@ MCP + Agent 橋接層 import json import logging import asyncio +import time from typing import Dict, Any, Optional, List, Tuple, Callable, Awaitable from datetime import datetime from .server import FeaturesMCPServer import services.ai_service as ai_service from services.ai_service import StrictResponseError +from core.prompts.tool_calling_policy import get_tool_calling_policy from core.reasoning_strategy import get_optimal_reasoning_effort from core.database import get_user_env_current from .coordinator import ToolCoordinator @@ -105,38 +107,38 @@ class MCPAgentBridge: name="weather_query", requires_env={"lat", "lon", "city"}, env_fallbacks={"city": ["detailed_address", "label"]}, - enable_reformat=True, + enable_reformat=False, ) ) register( ToolMetadata( name="reverse_geocode", requires_env={"lat", "lon"}, - enable_reformat=True, + enable_reformat=False, ) ) register( ToolMetadata( name="exchange_query", - enable_reformat=True, + enable_reformat=False, ) ) register( ToolMetadata( name="news_query", - enable_reformat=True, + enable_reformat=False, ) ) register( ToolMetadata( name="healthkit_query", - enable_reformat=True, + enable_reformat=False, ) ) register( ToolMetadata( name="directions", - enable_reformat=True, + enable_reformat=False, ) ) register( @@ -223,68 +225,12 @@ class MCPAgentBridge: logger.info(f"異步初始化完成,完整可用 MCP 工具數量: {len(self.mcp_server.tools)}") # 將 MCP Server 的工具註冊到 tool_registry - self._sync_tools_to_registry() + from core.tool_registry import register_mcp_tools_to_registry + register_mcp_tools_to_registry(self.mcp_server) # 快取預熱已移除:啟動時連續調用 7 次 GPT API 增加延遲和成本 # 實際使用中快取會自然累積,無需預熱 - def _sync_tools_to_registry(self) -> int: - """ - 將 MCP Server 的工具同步到 tool_registry - - Returns: - 註冊的工具數量 - """ - from core.tool_registry import tool_registry - - count = 0 - for tool_name, tool in self.mcp_server.tools.items(): - # 取得工具描述 - description = getattr(tool, 'description', f'{tool_name} 工具') - - # 取得參數 Schema - parameters = {"type": "object", "properties": {}, "required": []} - keywords = [] - examples = [] - negative_examples = [] - category = "general" - priority = 100 - - if hasattr(tool, 'handler') and hasattr(tool.handler, '__self__'): - tool_class = tool.handler.__self__ - - # 嘗試從 MCPTool 類別提取完整資訊 - if hasattr(tool_class, 'get_input_schema'): - try: - parameters = tool_class.get_input_schema() - except Exception as e: - logger.warning(f"取得 {tool_name} schema 失敗: {e}") - - # 提取增強元資料 - keywords = getattr(tool_class, 'KEYWORDS', []) - examples = getattr(tool_class, 'USAGE_TIPS', []) - negative_examples = getattr(tool_class, 'NEGATIVE_EXAMPLES', []) - category = getattr(tool_class, 'CATEGORY', 'general') - priority = getattr(tool_class, 'PRIORITY', 100) - - # 判斷是否需要位置 - props = parameters.get("properties", {}) - requires_location = "lat" in props or "lon" in props - - tool_registry.register( - name=tool_name, - description=description, - parameters=parameters, - handler=getattr(tool, 'handler', None), - category=category, - requires_location=requires_location, - keywords=keywords, - examples=examples, - ) - count += 1 - - logger.info(f"🔧 同步 {count} 個工具到 tool_registry") - return count def _normalize_tool_name(self, raw_name: Optional[str]) -> Optional[str]: """ @@ -444,92 +390,28 @@ class MCPAgentBridge: "tool_data": fallback_payload, } - def get_current_time_data(self) -> Dict[str, Any]: + async def detect_intent(self, message: str, tool_context: str = "", language: Optional[str] = None) -> Tuple[bool, Optional[Dict[str, Any]]]: """ - 獲取當前時間數據,用於生成個性化歡迎詞 - 返回格式與舊 time_service 兼容 - """ - now = datetime.now() - - # 獲取時間段 - hour = now.hour - if 5 <= hour < 12: - day_period = "上午" - elif 12 <= hour < 18: - day_period = "下午" - elif 18 <= hour < 22: - day_period = "晚上" - else: - day_period = "深夜" if hour >= 22 else "凌晨" - - # 星期幾中文名稱 - weekdays = ["星期一", "星期二", "星期三", "星期四", "星期五", "星期六", "星期日"] - weekday_full_chinese = weekdays[now.weekday()] - - return { - "year": now.year, - "month": now.month, - "day": now.day, - "hour": hour, - "minute": now.minute, - "second": now.second, - "weekday": now.weekday(), # 0-6, 星期一到星期日 - "weekday_full_chinese": weekday_full_chinese, - "day_period": day_period, - "timestamp": now.timestamp(), - "iso_format": now.isoformat() - } - - async def detect_intent(self, message: str) -> Tuple[bool, Optional[Dict[str, Any]]]: - """ - 檢測用戶消息中的意圖 (保持與舊 FeatureRouter 相同介面) - - 2025 重構版:使用 OpenAI 原生 Function Calling - - 不再使用巨大的 system_prompt 描述每個工具 - - 工具定義由 tools 參數傳遞,GPT 原生選擇 - - 新增工具只需註冊到 Registry,不需更新任何 prompt - - 參數: - message (str): 用戶消息 - - 返回: - tuple: (是否檢測到意圖, 意圖數據) - """ - # 使用新的 IntentDetector(基於 OpenAI Function Calling) - return await self._detect_intent_with_function_calling(message) - - async def _detect_intent_with_function_calling(self, message: str) -> Tuple[bool, Optional[Dict[str, Any]]]: - """ - 使用 OpenAI 原生 Function Calling 進行意圖檢測 - - 核心改進: - 1. 工具定義自動從 Registry 生成 - 2. GPT 原生選擇工具並生成結構化參數 - 3. 不需要自定義 prompt 描述每個工具 + 意圖偵測(符合 2025 年最佳實踐:語言感知與快取優化) """ import hashlib - import time as time_module - - # 生成快取鍵 - cache_key = hashlib.md5(message.encode()).hexdigest() + # 將 message, tool_context 與 language 組合後計算 md5 + cache_raw = f"{message}||{tool_context or ''}||{language or ''}" + cache_key = hashlib.md5(cache_raw.encode()).hexdigest() # 檢查快取 if cache_key in self._intent_cache: has_feature, intent_data, cached_time = self._intent_cache[cache_key] - if time_module.time() - cached_time < self._intent_cache_ttl: + if time.time() - cached_time < self._intent_cache_ttl: logger.debug(f"💾 意圖快取命中: {message[:50]}...") - - # 【關鍵修復】快取命中時,仍需重新偵測情緒(情緒是即時的) - # 因為同一句話在不同時間說可能帶有不同的情緒強度 try: fresh_emotion = await self._analyze_emotion_from_message(message) if fresh_emotion and intent_data: - intent_data = dict(intent_data) # 複製避免修改原快取 + intent_data = dict(intent_data) intent_data['emotion'] = fresh_emotion logger.info(f"🎭 快取命中但重新偵測情緒: {fresh_emotion}") except Exception as e: logger.warning(f"快取命中時情緒分析失敗: {e}") - return has_feature, intent_data else: del self._intent_cache[cache_key] @@ -543,7 +425,6 @@ class MCPAgentBridge: return True, {"type": "special_command", "command": "feature_list"} try: - # 從 tool_registry 取得 OpenAI tools 格式 from core.tool_registry import tool_registry from core.tool_router import tool_router @@ -553,39 +434,51 @@ class MCPAgentBridge: logger.warning("⚠️ 沒有可用的工具,降級為聊天") return False, {"emotion": "neutral"} - # 使用 ToolRouter 動態過濾和排序工具 - context = {"hour": datetime.now().hour} - tools = tool_router.filter_tools(all_tools, message, context) + router_context = { + "hour": datetime.now().hour, + "language": language + } + tools = tool_router.filter_tools(all_tools, message, router_context) + + logger.info(f"🔧 載入 {len(all_tools)} 個工具,過濾後 {len(tools)} 個 (語言: {language})") - logger.info(f"🔧 載入 {len(all_tools)} 個工具,過濾後 {len(tools)} 個") + system_prompt = self._build_function_calling_prompt(language) - # 建構精簡的 system prompt(只處理特殊規則) - system_prompt = self._build_function_calling_prompt() + if tool_context: + system_prompt += f"\n\n【已執行的工具結果與上下文】\n請根據以下資訊判斷是否足夠回答用戶,若不足則繼續調用工具:\n{tool_context}" messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": message} ] - # 使用 OpenAI Function Calling from core.reasoning_strategy import get_optimal_reasoning_effort optimal_effort = get_optimal_reasoning_effort("intent_detection") logger.info(f"🧠 意圖檢測推理強度: {optimal_effort}") + emotion_task = asyncio.create_task(self._analyze_emotion_from_message(message)) + response = await ai_service.generate_response_with_tools( messages=messages, tools=tools, user_id="intent_detection", - model="gpt-4o-mini", # 使用更強的模型以提升參數提取準確度 - reasoning_effort=None, # gpt-4o-mini 不支援 reasoning_effort + model=None, + reasoning_effort=optimal_effort, tool_choice="auto", ) - # 解析回應 + try: + parallel_emotion = await emotion_task + except Exception as e: + logger.warning(f"並行情緒分析失敗: {e}") + parallel_emotion = "neutral" + finally: + if not emotion_task.done(): + emotion_task.cancel() + tool_calls = response.get("tool_calls", []) if tool_calls: - # GPT 選擇了工具 tool_call = tool_calls[0] function = tool_call.get("function", {}) tool_name = function.get("name", "") @@ -596,7 +489,6 @@ class MCPAgentBridge: except json.JSONDecodeError: arguments = {} - # 正規化工具名稱 normalized_name = self._normalize_tool_name(tool_name) if not normalized_name: logger.warning(f"⚠️ 工具 {tool_name} 無法對應到註冊名稱,降級為聊天") @@ -605,38 +497,36 @@ class MCPAgentBridge: logger.info(f"✅ GPT 選擇工具: {normalized_name}") logger.debug(f"工具參數: {_safe_json(arguments)}") - # 提取情緒(從 content 或直接從用戶訊息分析) content = response.get("content", "") if content: emotion = self._extract_emotion_from_content(content) else: - # 當 GPT 只回傳 tool_calls 時,直接從用戶訊息分析情緒 - logger.debug(f"🔍 GPT content 為空,從用戶訊息分析情緒") - emotion = await self._analyze_emotion_from_message(message) + logger.debug(f"🔍 使用並行情緒分析結果: {parallel_emotion}") + emotion = parallel_emotion + + confidence = self._calculate_tool_confidence(normalized_name, arguments) intent_result = (True, { "type": "mcp_tool", "tool_name": normalized_name, "arguments": arguments, "emotion": emotion, + "confidence": confidence, }) - # 寫入快取 - self._intent_cache[cache_key] = (*intent_result, time_module.time()) + self._intent_cache[cache_key] = (*intent_result, time.time()) return intent_result else: - # GPT 未選擇工具,視為一般聊天 logger.info("💬 GPT 判斷為一般聊天") emotion = self._extract_emotion_from_content(response.get("content", "")) intent_result = (False, {"emotion": emotion}) - self._intent_cache[cache_key] = (*intent_result, time_module.time()) + self._intent_cache[cache_key] = (*intent_result, time.time()) return intent_result except Exception as e: logger.error(f"❌ Function Calling 意圖檢測失敗: {e}") - # 降級:使用關鍵詞匹配 logger.info("🔄 嘗試使用關鍵詞匹配作為降級方案") try: fallback_result = self._keyword_intent_detection(message) @@ -646,20 +536,60 @@ class MCPAgentBridge: except Exception as fallback_error: logger.error(f"❌ 關鍵詞匹配也失敗: {fallback_error}") - # 最終降級:視為一般聊天 logger.info("💬 降級為一般聊天") return False, {"emotion": "neutral"} - def _build_function_calling_prompt(self) -> str: + def get_current_time_data(self) -> Dict[str, Any]: """ - 建構精簡的 Function Calling system prompt - - 注意:不再描述每個工具,工具定義由 tools 參數傳遞 - 只處理特殊規則和情緒判斷 + 獲取當前時間數據,用於生成個性化歡迎詞 + 返回格式與舊 time_service 兼容 """ - return """You are an intelligent assistant that selects appropriate tools based on user needs. + now = datetime.now() + + # 獲取時間段 + hour = now.hour + if 5 <= hour < 12: + day_period = "上午" + elif 12 <= hour < 18: + day_period = "下午" + elif 18 <= hour < 22: + day_period = "晚上" + else: + day_period = "深夜" if hour >= 22 else "凌晨" + + # 星期幾中文名稱 + weekdays = ["星期一", "星期二", "星期三", "星期四", "星期五", "星期六", "星期日"] + weekday_full_chinese = weekdays[now.weekday()] + + return { + "year": now.year, + "month": now.month, + "day": now.day, + "hour": hour, + "minute": now.minute, + "second": now.second, + "weekday": now.weekday(), # 0-6, 星期一到星期日 + "weekday_full_chinese": weekday_full_chinese, + "day_period": day_period, + "timestamp": now.timestamp(), + "iso_format": now.isoformat() + } -Rules: + def _build_function_calling_prompt(self, language: Optional[str] = None) -> str: + """ + 建構精簡的 Function Calling system prompt + """ + from core.prompts.tool_calling_policy import get_tool_calling_policy + policy = get_tool_calling_policy() + + # 根據語系決定基礎提示詞語言 + lang_context = f"目前的用戶語言是:{language or 'zh (繁體中文)'}" + + return ( + f"You are an intelligent assistant. {lang_context}\n\n" + + policy + + "\n\n" + + """Rules: 1. If the user's request can be solved with a tool, select the most appropriate tool 2. Only skip tool selection for pure greetings (hi, hello) or meta questions (what can you do) 3. Extract tool parameters from user message @@ -676,116 +606,18 @@ Rules: 【重要】語言使用規範: - 調用工具時:所有參數必須使用英文(城市名、國家名、貨幣代碼等) -- 範例:用戶說「台北天氣」或 "Taipei weather" → 參數 {"city": "Taipei"} +- 範例:用戶說「台北天氣」或 "Taipei weather" 或 "東京の天気" → 參數 {"city": "Taipei"} 或 {"city": "Tokyo"} 參數語言轉換規則: -- 城市名稱:台北→Taipei, 新北→NewTaipei, 桃園→Taoyuan, 台中→Taichung, 台南→Tainan, 高雄→Kaohsiung, 新竹→Hsinchu +- 城市名稱:台北→Taipei, 新北→NewTaipei, 桃園→Taoyuan, 台中→Taichung, 台南→Tainan, 高雄→Kaohsiung, 新竹→Hsinchu, 東京→Tokyo, 大阪→Osaka, 京都→Kyoto - 國家名稱:台灣→Taiwan, 美國→USA, 日本→Japan, 英國→UK -- 貨幣代碼:美元→USD, 台幣→TWD, 日圓→JPY, 歐元→EUR, 英鎊→GBP - -【重要】城市參數提取原則: -- 只有在用戶明確提到城市名稱時才填 city 參數 -- 「附近」「這裡」「我這邊」等詞 → 不填 city 參數,系統會自動從 GPS 判斷 -- 「台北的XX」「桃園XX」→ 填對應的英文城市名 -- 範例:「附近的 YouBike」→ {},「桃園的 YouBike」→ {"city": "Taoyuan"} - -匯率查詢(重要!參數提取規則): -當用戶詢問匯率資訊時,你必須從消息中提取貨幣代碼並填入參數。 - -參數提取規則: -1. 句型「[貨幣A]轉[貨幣B]」「[貨幣A]換[貨幣B]」「[貨幣A]兌[貨幣B]」→ {"from_currency": "代碼A", "to_currency": "代碼B"} -2. 句型「[數字][貨幣A]是多少[貨幣B]」→ {"from_currency": "代碼A", "to_currency": "代碼B", "amount": 數字} -3. 句型「匯率」「美金」「日幣」→ 提取提到的貨幣 -4. 貨幣代碼必須用 ISO 4217 標準(3個大寫字母) - -常見貨幣代碼對照: -- 美元/美金 → USD -- 台幣/新台幣 → TWD -- 日圓/日幣 → JPY -- 歐元 → EUR -- 英鎊 → GBP -- 人民幣 → CNY -- 港幣 → HKD -- 韓元 → KRW - -實際範例: -- 「美元轉日幣的匯率」→ {"from_currency": "USD", "to_currency": "JPY"} -- 「台幣換美金」→ {"from_currency": "TWD", "to_currency": "USD"} -- 「100美元是多少台幣」→ {"from_currency": "USD", "to_currency": "TWD", "amount": 100} -- 「歐元兌日圓」→ {"from_currency": "EUR", "to_currency": "JPY"} -- 「匯率」→ {"from_currency": "USD", "to_currency": "TWD"}(預設) - -重要:必須提取貨幣代碼!不要返回空參數! - -公車查詢(重要!參數提取規則): -當用戶詢問公車資訊時,你必須從消息中提取路線號碼並填入參數。 - -tdx_bus_arrival 適用場景: -- 查詢「已知路線號碼」的到站時間 -- 查詢附近公車站點(不需 route_name) - -參數提取規則: -1. 句型「[數字]公車」「[數字]號公車」→ {"route_name": "數字"} -2. 句型「[顏色][數字]」(如「紅30」)→ {"route_name": "顏色數字"} -3. 句型「[數字]還要多久」「[數字]什麼時候到」→ {"route_name": "數字"} -4. 句型「[路線名]公車到站」→ {"route_name": "路線名"} -5. 「附近公車」「公車站」「有什麼公車」→ {}(系統自動從 GPS 判斷城市) -6. 城市參數:只在用戶明確提到城市時才填,否則留空讓系統自動判斷 - -實際範例: -- 「261公車什麼時候到」→ {"route_name": "261"}(不填 city) -- 「307還要多久」→ {"route_name": "307"}(不填 city) -- 「台北261公車」→ {"route_name": "261", "city": "Taipei"}(明確提到台北) -- 「桃園紅30公車」→ {"route_name": "紅30", "city": "Taoyuan"}(明確提到桃園) -- 「附近有什麼公車」→ {}(完全空參數,系統自動判斷) - -不適用場景(應使用 directions): -- 「從A到B的公車」「往XX的公車」→ 這是路線規劃,不是查詢特定路線 -- 「去台北的公車」→ 台北是目的地,不是路線號碼 - -重要:如果提到路線號碼,必須提取!城市參數必須用英文! - -火車查詢(重要!參數提取規則): -當用戶詢問火車資訊時,你必須從消息中提取站名並填入參數。 - -參數提取規則(適用於任何地名): -1. 句型「從 [地名A] 往/到 [地名B]」→ {"origin_station": "地名A", "destination_station": "地名B"} -2. 句型「[地名A] 到/往 [地名B]」→ {"origin_station": "地名A", "destination_station": "地名B"} -3. 句型「往/去 [地名]」→ {"destination_station": "地名"} -4. 句型「[車種][數字]次」→ {"train_no": "數字"} -5. 包含時間 → 提取為 departure_time(HH:MM 格式) - -實際範例: -- 「從彰化往台北的火車」→ {"origin_station": "彰化", "destination_station": "台北"} -- 「台中到高雄」→ {"origin_station": "台中", "destination_station": "高雄"} -- 「往新竹的火車」→ {"destination_station": "新竹"} -- 「自強號123次」→ {"train_no": "123"} -- 「早上8點台南到台北」→ {"origin_station": "台南", "destination_station": "台北", "departure_time": "08:00"} - -重要:絕對不要返回空的 {} 參數!必須從用戶消息中提取站名! - -位置查詢: -- 「我在哪」使用 reverse_geocode,不需要參數 -- 「怎麼去XX」使用 forward_geocode 或 directions - -YouBike 查詢(重要!參數提取規則): -當用戶詢問 YouBike/Ubike/微笑單車時,你必須調用 tdx_youbike 工具。 - -參數提取規則: -1. 「附近的 YouBike」「Ubike 在哪」→ {}(不填 city,系統自動從 GPS 判斷) -2. 「市政府 YouBike」「台北車站 Ubike」→ {"station_name": "市政府"}(不填 city) -3. 「XX站還有車嗎」→ {"station_name": "XX站"}(不填 city) -4. 「台北的 YouBike」「桃園 YouBike」→ 填對應英文城市名 -5. 站名可用中文,城市必須用英文 - -實際範例: -- 「附近的 YouBike」→ {}(完全空參數,系統自動判斷城市) -- 「市政府 YouBike 還有車嗎」→ {"station_name": "市政府"}(不填 city) -- 「台北車站 Ubike」→ {"station_name": "台北車站"}(不填 city) -- 「台北的 YouBike」→ {"city": "Taipei"}(明確提到台北) -- 「桃園 YouBike」→ {"city": "Taoyuan"}(明確提到桃園) - -重要:只在用戶明確提到城市時才填 city 參數!站名可保持中文! +- 貨幣代碼:美元→USD, 台幣→TWD, 日圓/円→JPY, 歐元→EUR, 英鎊→GBP + +【Answerability Check - Confidence-Driven Agent Loop】 +1. Evaluate provided "tool_context" (if any) against the user's request. +2. If tool_context contains sufficient information to answer the user with ~100% confidence, do NOT call any more tools. +3. If information is missing, inconsistent, or outdated, select the appropriate tool to gather more evidence (work_again). +4. If the user's question is "Can you find X and then Y?", you must first find X, evaluate its result, then in the next turn find Y based on X. 【情緒偵測】(重要!): - 分析用戶的情緒狀態(根據用詞、語氣、標點符號、表情符號) @@ -799,6 +631,38 @@ YouBike 查詢(重要!參數提取規則): * 用戶說「哇!」→ 回應最後加上 [EMOTION:surprise] * 一般對話 → 回應最後加上 [EMOTION:neutral] """ + ) + + def _calculate_tool_confidence(self, tool_name: str, arguments: Dict[str, Any]) -> float: + """根據工具所需參數動態計算信心度""" + if not tool_name: + return 0.0 + + try: + from core.tool_registry import tool_registry + tool_def = tool_registry.get_tool(tool_name) + + # 如果找不到工具定義,或者該工具沒有 parameters + if not tool_def or not tool_def.parameters: + return 0.95 if arguments else 0.92 + + required_params = tool_def.parameters.get("required", []) + + if not required_params: + return 1.0 # 不需要任何參數,直接給最高信心度 + + # 檢查缺少的必填參數數量 + missing_count = sum(1 for req in required_params if req not in arguments or arguments[req] is None or str(arguments[req]).strip() == "") + + if missing_count == 0: + return 1.0 + + # 每缺少一個必填參數,扣除一定比例的信心度(確保其低於 MIN_TOOL_CONFIDENCE 0.90) + return max(0.0, 0.95 - (missing_count * 0.5)) + except Exception as e: + logger.warning(f"⚠️ 計算工具信心度失敗: {e}") + return 0.95 if arguments else 0.92 + def _extract_emotion_from_content(self, content: str) -> str: """從回應內容中提取情緒標籤 [EMOTION:xxx]""" @@ -830,9 +694,18 @@ YouBike 查詢(重要!參數提取規則): import services.ai_service as ai_service system_prompt = ( - "分析用戶訊息的情緒狀態。\n" - "情緒類型:neutral(平靜)、happy(開心)、sad(難過)、angry(生氣)、fear(害怕)、surprise(驚訝)\n" - "只回傳情緒類型的英文單字,不要有任何其他文字。" + "## 任務:純粹情緒分類\n" + "你是一個專業的情緒分析員。請分析用戶訊息的情緒狀態,並僅從以下列表中選擇一個最符合的標籤:\n" + "- neutral(平靜/詢問資訊)\n" + "- happy(開心/興奮)\n" + "- sad(難過/沮喪)\n" + "- angry(生氣/不滿)\n" + "- fear(害怕/焦慮)\n" + "- surprise(驚訝/震撼)\n\n" + "## 規則:\n" + "1. **絕對不要回答用戶的問題**(例如用戶問股價,你絕對不能回答數字)。\n" + "2. **只能回傳單個英文標籤**,不要有任何解釋或標點符號。\n" + "3. 如果訊息是中性的事實詢問,請回傳 neutral。" ) messages = [ @@ -842,7 +715,7 @@ YouBike 查詢(重要!參數提取規則): emotion = await ai_service.generate_response_async( messages=messages, - model="gpt-4o-mini", + model=None, max_tokens=10, ) @@ -991,29 +864,35 @@ YouBike 查詢(重要!參數提取規則): import re city_match = re.search(r'([^\s,。!?]+)\s*天氣', message) city = city_match.group(1) if city_match else "台北" + args = {"city": city} return True, { "type": "mcp_tool", "tool_name": "weather_query", - "arguments": {"city": city} + "arguments": args, + "confidence": self._calculate_tool_confidence("weather_query", args) } # 新聞檢測 news_keywords = ["新聞", "消息", "報導", "news"] if any(kw in message_lower for kw in news_keywords): + args = {"language": "zh-TW", "limit": 5} return True, { "type": "mcp_tool", "tool_name": "news_query", - "arguments": {"language": "zh-TW", "limit": 5} + "arguments": args, + "confidence": self._calculate_tool_confidence("news_query", args) } # 匯率檢測 exchange_keywords = ["匯率", "美元", "台幣", "exchange", "usd", "twd"] if any(kw in message_lower for kw in exchange_keywords): + args = {"from_currency": "USD", "to_currency": "TWD"} return True, { "type": "mcp_tool", "tool_name": "exchange_query", - "arguments": {"from_currency": "USD", "to_currency": "TWD"} + "arguments": args, + "confidence": self._calculate_tool_confidence("exchange_query", args) } return False, None @@ -1213,6 +1092,9 @@ YouBike 查詢(重要!參數提取規則): "⭐ 分析使用者的核心意圖(問溫度?天氣?時間?地點?數量?)\n" "⭐ 從工具數據中只提取相關資訊,無關資訊一律省略\n" "⭐ **注意:用什麼語言提問,就用什麼語言回答**(日文問→日文答,英文問→英文答)\n\n" + "【反幻覺與安全原則】\n" + "🚨 嚴禁推測:如果工具返回的資料缺漏,必須誠實告知用戶,絕對不能自己憑空捏造數字或事實\n" + "🚨 不得把工具錯誤包裝成成功結果:如果工具返回錯誤或查無資料,不能說「我查到了」\n\n" "【回應要求】\n" "1. 使用口語化、親切的語氣(可以用「喔」「呢」「哦」等語氣詞)\n" "2. 不要列表式的羅列數據,而是用對話方式描述\n" @@ -1243,13 +1125,16 @@ YouBike 查詢(重要!參數提取規則): {"role": "user", "content": user_prompt} ] - # 格式化回應使用 gpt-4o-mini(支援多語言,不需 reasoning_effort) - response = await ai_service.generate_response_for_user( + # 格式化回應使用環境變數設定的模型 + model = settings.GPT_INTENT_MODEL or settings.OPENAI_MODEL + logger.info(f"🎨 使用配置模型進行格式化: {model}") + + response = await ai_service.generate_response_with_tools( messages=messages, - user_id="format_response", - model="gpt-4o-mini", # 升級到 gpt-4o-mini 以支援多語言 - chat_id=None, - reasoning_effort=None # gpt-4o-mini 不支援此參數 + tools=None, + user_id="reformatting", + model=model, + reasoning_effort=None, ) return response diff --git a/features/mcp/auto_registry.py b/features/mcp/auto_registry.py index bb0c3249cf64b396a469cceb003e82c45286c152..85e7ee197f200a1255eea8e55ab4e06c38c352d8 100644 --- a/features/mcp/auto_registry.py +++ b/features/mcp/auto_registry.py @@ -54,14 +54,17 @@ class MCPAutoRegistry: logger.info(f"掃描工具目錄: {tools_path}") - # 掃描所有 Python 文件(包含 *_tool.py 和 tdx_*.py) - tool_files = list(tools_path.glob("*_tool.py")) + list(tools_path.glob("tdx_*.py")) - # 去重(避免 tdx_*_tool.py 被掃描兩次) - tool_files = list(set(tool_files)) + # 掃描所有 Python 文件 + tool_files = list(tools_path.rglob("*.py")) for py_file in tool_files: + if py_file.name == "__init__.py" or py_file.name == "base_tool.py" or py_file.name == "tdx_base.py": + continue + tool_name = py_file.stem - module_name = f"{tools_dir}.{tool_name}" + rel_path = py_file.relative_to(tools_path) + module_parts = list(rel_path.parts[:-1]) + [tool_name] + module_name = f"{tools_dir}.{'.'.join(module_parts)}" try: # 動態導入模組 @@ -175,7 +178,8 @@ class MCPAutoRegistry: description=definition["description"], inputSchema=definition["inputSchema"], handler=handler, - metadata=definition.get("metadata", {}) + metadata=definition.get("metadata", {}), + outputSchema=definition.get("outputSchema") ) return tool @@ -198,7 +202,8 @@ class MCPAutoRegistry: description=definition["description"], inputSchema=definition["inputSchema"], handler=handler, - metadata=definition.get("metadata", {}) + metadata=definition.get("metadata", {}), + outputSchema=definition.get("outputSchema") ) return tool else: @@ -211,7 +216,8 @@ class MCPAutoRegistry: description=definition["description"], inputSchema=definition["inputSchema"], handler=handler, - metadata=definition.get("metadata", {}) + metadata=definition.get("metadata", {}), + outputSchema=definition.get("outputSchema") ) return tool @@ -293,7 +299,17 @@ class MCPAutoRegistry: description=description, inputSchema=input_schema, handler=placeholder_handler, - metadata=metadata + metadata=metadata, + outputSchema={ + "type": "object", + "properties": { + "success": {"type": "boolean"}, + "content": {"type": "string"}, + "error": {"type": ["string", "null"]}, + "error_code": {"type": ["string", "null"]}, + }, + "required": ["success"], + } ) logger.info(f"創建系統工具占位符: {name}") diff --git a/features/mcp/coordinator.py b/features/mcp/coordinator.py index 643532bd9d18df6cb5c74e385035f2b00ef0b218..f12d6d97b53eb67f50a5516b2ff8168f2aee026b 100644 --- a/features/mcp/coordinator.py +++ b/features/mcp/coordinator.py @@ -1,14 +1,38 @@ import asyncio import logging +import re from typing import Any, Awaitable, Callable, Dict, Optional from .tool_models import ToolMetadata, ToolResult +try: + import jsonschema +except ImportError: + jsonschema = None + logger = logging.getLogger(__name__) +CITY_ALIASES = { + "台北市": "台北", + "臺北市": "臺北", + "新北市": "新北", + "桃園市": "桃園", + "台中市": "台中", + "臺中市": "臺中", + "台南市": "台南", + "臺南市": "臺南", + "高雄市": "高雄", + "新竹市": "新竹", +} + EnvProvider = Callable[[Optional[str]], Awaitable[Dict[str, Any]]] ResultFormatter = Callable[[str, str, Dict[str, Any], str], Awaitable[str]] ToolHandler = Callable[[Dict[str, Any]], Awaitable[Any]] +OutputSchemaProvider = Callable[[str], Optional[Dict[str, Any]]] + + +class ToolOutputValidationError(RuntimeError): + """Raised when a tool result violates its declared outputSchema.""" class ToolCoordinator: @@ -25,11 +49,13 @@ class ToolCoordinator: env_provider: EnvProvider, tool_lookup: Callable[[str], Optional[ToolHandler]], formatter: ResultFormatter, + output_schema_provider: Optional[OutputSchemaProvider] = None, failure_handlers: Optional[Dict[str, Callable[[Dict[str, Any], Exception], ToolResult]]] = None, ) -> None: self._env_provider = env_provider self._tool_lookup = tool_lookup self._formatter = formatter + self._output_schema_provider = output_schema_provider self._metadata: Dict[str, ToolMetadata] = {} self._failure_handlers = failure_handlers or {} @@ -78,7 +104,10 @@ class ToolCoordinator: logger.info(f"📦 [Coordinator] 環境資訊: {env_ctx}") if env_ctx: for field in metadata.requires_env: - if merged.get(field) is not None: + val = merged.get(field) + # 如果參數已有值且不是預設佔位符(如 0 或空字串),則跳過注入 + # 這是為了解決 GPT 可能會填入 0 作為座標佔位符的問題 + if val is not None and val != 0 and val != "": continue env_value = env_ctx.get(field) # 主欄位為 None 時,嘗試 fallback 欄位 @@ -90,6 +119,7 @@ class ToolCoordinator: break # 只注入非 None 的值,避免覆蓋工具的預設值或觸發 schema 驗證錯誤 if env_value is not None: + env_value = self._normalize_env_value(field, env_value) merged[field] = env_value logger.info(f"📦 [Coordinator] 注入環境變數: {field}={env_value}") elif not user_id: @@ -98,6 +128,24 @@ class ToolCoordinator: logger.info(f"📦 [Coordinator] 最終參數: {merged}") return merged + @staticmethod + def _normalize_env_value(field: str, value: Any) -> Any: + if field != "city" or not isinstance(value, str): + return value + + normalized = value.strip() + if not normalized: + return value + + if normalized in CITY_ALIASES: + return CITY_ALIASES[normalized] + + exact_match = re.match(r"^(台北|臺北|新北|桃園|台中|臺中|台南|臺南|高雄|新竹)(?:市|縣)?$", normalized) + if exact_match: + return exact_match.group(1) + + return normalized + async def _execute(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: handler = self._tool_lookup(tool_name) if not handler: @@ -109,8 +157,13 @@ class ToolCoordinator: try: result = await asyncio.wait_for(handler(arguments), timeout=30.0) if isinstance(result, dict): + self._validate_output(tool_name, result) return result - return {"success": True, "content": str(result)} + wrapped = {"success": True, "content": str(result)} + self._validate_output(tool_name, wrapped) + return wrapped + except ToolOutputValidationError: + raise except Exception as exc: # noqa: BLE001 last_exc = exc logger.warning("工具 %s 執行失敗 (attempt=%s): %s", tool_name, attempt, exc) @@ -120,6 +173,21 @@ class ToolCoordinator: return handler(arguments, last_exc) # type: ignore[arg-type] raise RuntimeError(f"工具 {tool_name} 執行失敗:{last_exc}") # type: ignore[arg-type] + def _validate_output(self, tool_name: str, result: Dict[str, Any]) -> None: + if not self._output_schema_provider or jsonschema is None: + return + + schema = self._output_schema_provider(tool_name) + if not schema: + return + + try: + jsonschema.validate(result, schema) + except jsonschema.ValidationError as exc: + field_path = ".".join(str(part) for part in exc.absolute_path) + detail = f"{field_path}: {exc.message}" if field_path else exc.message + raise ToolOutputValidationError(f"工具 {tool_name} 輸出格式不符合契約: {detail}") from exc + async def _format_result( self, tool_name: str, diff --git a/features/mcp/mcp_client.py b/features/mcp/mcp_client.py index c0e43b4f8289e7b566fb4ca955494b2bdaa622a6..423b731a84eb283e846ea629c2ecfc058992d57e 100644 --- a/features/mcp/mcp_client.py +++ b/features/mcp/mcp_client.py @@ -165,6 +165,7 @@ class MCPClient: name = tool_data.get("name") description = tool_data.get("description", "") input_schema = tool_data.get("inputSchema", {"type": "object", "properties": {}}) + output_schema = tool_data.get("outputSchema") # 創建代理處理器 async def tool_handler(arguments: Dict[str, Any]) -> Dict[str, Any]: @@ -174,7 +175,8 @@ class MCPClient: name=name, description=description, inputSchema=input_schema, - handler=tool_handler + handler=tool_handler, + outputSchema=output_schema ) return tool @@ -192,7 +194,21 @@ class MCPClient: }) if response and response.get("result"): - content = response["result"].get("content", []) + result = response["result"] + content = result.get("content", []) + structured = result.get("structuredContent") + if result.get("isError"): + if structured: + structured = dict(structured) + structured.setdefault("success", False) + structured.setdefault("error", "\n".join([item.get("text", "") for item in content if item.get("type") == "text"])) + return structured + return { + "success": False, + "error": "\n".join([item.get("text", "") for item in content if item.get("type") == "text"]) or "外部工具執行失敗" + } + if structured: + return structured return { "success": True, "content": "\n".join([item.get("text", "") for item in content if item.get("type") == "text"]) @@ -356,4 +372,4 @@ class MCPClientManager: def is_client_connected(self, server_name: str) -> bool: """檢查客戶端是否連接""" - return server_name in self.clients and self.clients[server_name].connected \ No newline at end of file + return server_name in self.clients and self.clients[server_name].connected diff --git a/features/mcp/openai_tools.py b/features/mcp/openai_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..24ca130e7be2a55f550218cb8f4c1c4ac64239d8 --- /dev/null +++ b/features/mcp/openai_tools.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Dict, List, Optional + +from core.config import settings +from core.logging import get_logger + +logger = get_logger("mcp.openai_tools") + +DEFAULT_CONFIG_PATH = Path(__file__).resolve().parents[1] / "mcp_config.json" + + +def _load_mcp_config(config_path: Optional[Path] = None) -> Dict[str, Any]: + path = config_path or DEFAULT_CONFIG_PATH + try: + with path.open("r", encoding="utf-8") as f: + return json.load(f) + except FileNotFoundError: + logger.warning("MCP 配置不存在: %s", path) + except json.JSONDecodeError as exc: + logger.warning("MCP 配置 JSON 無效: %s", exc) + return {} + + +def _configured_items(section: Dict[str, Any]) -> List[Dict[str, Any]]: + items: List[Dict[str, Any]] = [] + for item in section.get("items", []): + if isinstance(item, dict) and item.get("enabled", True): + items.append({key: value for key, value in item.items() if key != "enabled"}) + return items + + +def _configured_remote_mcp_items(section: Dict[str, Any]) -> List[Dict[str, Any]]: + items = [] + for item in _configured_items(section): + if item.get("server_url") or item.get("connector_id"): + items.append(item) + else: + logger.info("跳過沒有 server_url/connector_id 的 hosted MCP 設定: %s", item.get("server_label")) + return items + + +def _env_json_list(raw: str, *, label: str) -> List[Dict[str, Any]]: + if not raw: + return [] + try: + parsed = json.loads(raw) + except json.JSONDecodeError as exc: + logger.warning("%s 不是合法 JSON,已忽略: %s", label, exc) + return [] + if not isinstance(parsed, list): + logger.warning("%s 必須是 list,已忽略", label) + return [] + return [item for item in parsed if isinstance(item, dict)] + + +def build_openai_hosted_tools(config_path: Optional[Path] = None) -> List[Dict[str, Any]]: + config = _load_mcp_config(config_path) + openai_tools = config.get("openai_tools", {}) + specs: List[Dict[str, Any]] = [] + + web_search = openai_tools.get("web_search", {}) + if settings.OPENAI_ENABLE_WEB_SEARCH and web_search.get("enabled", True): + specs.append({"type": "web_search"}) + + remote_mcp = openai_tools.get("remote_mcp", {}) + if settings.OPENAI_ENABLE_REMOTE_MCP and remote_mcp.get("enabled", False): + for server in _configured_remote_mcp_items(remote_mcp) + _configured_remote_mcp_items( + {"items": _env_json_list( + settings.OPENAI_REMOTE_MCP_SERVERS_JSON, + label="OPENAI_REMOTE_MCP_SERVERS_JSON", + )} + ): + spec = dict(server) + spec["type"] = "mcp" + spec.setdefault("require_approval", remote_mcp.get("approval_default", "always")) + specs.append(spec) + + return specs diff --git a/features/mcp/server.py b/features/mcp/server.py index c7c6ba9aa445d43cfa6452840ebb5494acd3822d..d622a823765ff0b010ec33a83ac2e5c8c7013997 100644 --- a/features/mcp/server.py +++ b/features/mcp/server.py @@ -11,9 +11,14 @@ import time import os from typing import Dict, Any, List, Optional, Callable, Tuple from enum import Enum -from .types import Tool +from .types import Tool, ToolCallResult from .auto_registry import MCPAutoRegistry +try: + import jsonschema +except ImportError: + jsonschema = None + LOG_LEVEL_NAME = os.getenv("BLOOMWARE_LOG_LEVEL", "WARNING").upper() LOG_LEVEL = getattr(logging, LOG_LEVEL_NAME, logging.WARNING) logging.basicConfig( @@ -183,7 +188,17 @@ class FeaturesMCPServer: name=tool_name, description=description, inputSchema={"type": "object", "properties": {}}, - handler=handler + handler=handler, + outputSchema={ + "type": "object", + "properties": { + "success": {"type": "boolean"}, + "content": {"type": "string"}, + "error": {"type": ["string", "null"]}, + "error_code": {"type": ["string", "null"]}, + }, + "required": ["success"], + } ) self.register_tool(tool) @@ -215,6 +230,60 @@ class FeaturesMCPServer: """註冊工具""" self.tools[tool.name] = tool logger.info(f"註冊工具: {tool.name}") + + def _format_tool_result(self, tool_name: str, result: Any) -> Dict[str, Any]: + """轉成 MCP tools/call result,保留 structuredContent。""" + if isinstance(result, ToolCallResult): + return result.to_dict() + + if isinstance(result, dict): + is_error = result.get("success") is False + content = result.get("content") + if not content: + content = result.get("error") if is_error else json.dumps(result, ensure_ascii=False) + + output_issue = self._validate_tool_output(tool_name, result) + if output_issue: + return ToolCallResult( + content=[{"type": "text", "text": "工具輸出格式不符合契約"}], + structuredContent={ + "success": False, + "error_code": "TOOL_OUTPUT_VALIDATION", + "tool_name": tool_name, + "details": output_issue, + }, + isError=True, + ).to_dict() + + payload = ToolCallResult( + content=[{"type": "text", "text": str(content)}], + structuredContent=result, + isError=is_error, + ).to_dict() + if is_error and "error_code" not in payload["structuredContent"]: + payload["structuredContent"]["error_code"] = "TOOL_EXECUTION_ERROR" + return payload + + return ToolCallResult( + content=[{"type": "text", "text": str(result)}], + structuredContent={"tool_name": tool_name, "value": result}, + isError=False, + ).to_dict() + + def _validate_tool_output(self, tool_name: str, result: Dict[str, Any]) -> Optional[str]: + """用工具 outputSchema 驗證 structuredContent。""" + tool = self.tools.get(tool_name) + if not tool or not tool.outputSchema or jsonschema is None: + return None + + try: + jsonschema.validate(result, tool.outputSchema) + return None + except jsonschema.ValidationError as exc: + field_path = ".".join(str(part) for part in exc.absolute_path) + if field_path: + return f"{field_path}: {exc.message}" + return exc.message def get_tools_summary(self) -> List[Dict[str, Any]]: """ @@ -352,22 +421,29 @@ class FeaturesMCPServer: if tool.handler: try: result = await tool.handler(arguments) - - # 統一回應格式 - if isinstance(result, dict) and result.get("success"): - content = result.get("content", "") - return {"content": [{"type": "text", "text": content}]} - elif isinstance(result, dict) and not result.get("success"): - error_msg = result.get("error", "工具執行失敗") - return {"content": [{"type": "text", "text": f"❌ {error_msg}"}], "isError": True} - else: - return {"content": [{"type": "text", "text": str(result)}]} + return self._format_tool_result(tool_name, result) except Exception as e: logger.error(f"工具執行錯誤 {tool_name}: {e}") - return {"content": [{"type": "text", "text": f"❌ 執行錯誤: {str(e)}"}], "isError": True} - - return {"content": [{"type": "text", "text": "工具未實作"}]} + return ToolCallResult( + content=[{"type": "text", "text": "工具執行失敗"}], + structuredContent={ + "success": False, + "error_code": "TOOL_EXECUTION_ERROR", + "tool_name": tool_name, + }, + isError=True, + ).to_dict() + + return ToolCallResult( + content=[{"type": "text", "text": "工具未實作"}], + structuredContent={ + "success": False, + "error_code": "TOOL_NOT_IMPLEMENTED", + "tool_name": tool_name, + }, + isError=True, + ).to_dict() async def cleanup(self): """清理資源""" diff --git a/features/mcp/skills.py b/features/mcp/skills.py new file mode 100644 index 0000000000000000000000000000000000000000..1ef1d6be158225f27fa89e9113572b228bb464e4 --- /dev/null +++ b/features/mcp/skills.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any, Dict, Iterable, List + +import json + +SKILLS_ROOT = Path(__file__).resolve().parent / "skills" +DEFAULT_CONFIG_PATH = Path(__file__).resolve().parents[1] / "mcp_config.json" + + +def _load_mcp_config(config_path: Path = DEFAULT_CONFIG_PATH) -> Dict[str, Any]: + try: + return json.loads(config_path.read_text(encoding="utf-8")) + except FileNotFoundError: + return {} + + +def _slug(name: str) -> str: + return "".join(ch if ch.isalnum() or ch in {"-", "_"} else "-" for ch in name).strip("-").lower() + + +def _yaml_scalar(value: Any) -> str: + if value is None: + return "null" + text = str(value).replace('"', '\\"') + return f'"{text}"' + + +def _yaml_list(values: Iterable[Any], indent: int = 2) -> str: + prefix = " " * indent + items = list(values) + if not items: + return f"{prefix}[]" + return "\n".join(f"{prefix}- {_yaml_scalar(item)}" for item in items) + + +def tool_skill_path(tool_name: str) -> Path: + return SKILLS_ROOT / _slug(tool_name) / "SKILL.md" + + +def render_tool_skill(tool_name: str, tool_info: Dict[str, Any]) -> str: + examples = tool_info.get("examples") or [] + return "\n".join( + [ + "---", + f"name: mcp-{tool_name}", + f"description: { _yaml_scalar('Use when the user request matches the Bloom Ware MCP tool ' + tool_name + ' usage scenario.') }", + "tool_contract:", + f" name: {_yaml_scalar(tool_name)}", + f" category: {_yaml_scalar(tool_info.get('category', 'general'))}", + f" module: {_yaml_scalar(tool_info.get('module', ''))}", + f" class: {_yaml_scalar(tool_info.get('class', ''))}", + f" description: {_yaml_scalar(tool_info.get('description', ''))}", + " examples:", + _yaml_list(examples, indent=4), + "routing:", + " invocation_mode: \"local_mcp_bridge_function_calling\"", + " openai_hosted_mcp: \"disabled_unless_remote_server_url_is_configured\"", + " required_action: \"Select this tool only when the request maps to its category and required inputs can be extracted or safely derived from environment context.\"", + "safety:", + " do_not_fabricate_missing_data: true", + " preserve_tool_failure_semantics: true", + " user_approval_required_for_high_impact_actions: true", + "---", + "", + "Use this skill as the authoritative routing note for this Bloom Ware MCP tool.", + "The actual tool call must go through the local MCP bridge/function-calling path.", + "", + ] + ) + + +def write_tool_skills(config_path: Path = DEFAULT_CONFIG_PATH) -> List[Path]: + config = _load_mcp_config(config_path) + written: List[Path] = [] + for tool_name, tool_info in sorted((config.get("tools") or {}).items()): + if not (tool_info.get("module") and tool_info.get("class")): + continue + path = tool_skill_path(tool_name) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(render_tool_skill(tool_name, tool_info), encoding="utf-8") + written.append(path) + return written + + +def skills_prompt_block(config_path: Path = DEFAULT_CONFIG_PATH) -> str: + config = _load_mcp_config(config_path) + lines = [ + "Bloom Ware tool skills:", + "Use these routing notes when selecting local MCP bridge tools or OpenAI hosted tools. Do not call hosted MCP for local tools unless a remote server_url is configured.", + "- web_search: category=hosted_openai_tool; use_when=current, recent, time-sensitive, public, or externally verifiable information is needed; call_via=responses_hosted_tool_auto; read_skill=features/mcp/skills/web_search/SKILL.md; rule=use time/environment context and source timestamps to decide, without domain-specific hardcoding", + ] + for tool_name, tool_info in sorted((config.get("tools") or {}).items()): + if not (tool_info.get("module") and tool_info.get("class")): + continue + examples = ", ".join(tool_info.get("examples") or []) + lines.append( + f"- {tool_name}: category={tool_info.get('category', 'general')}; " + f"use_when={tool_info.get('description', '')}; examples={examples}; " + f"call_via=local_function_calling_tool_schema" + ) + return "\n".join(lines) diff --git a/features/mcp/skills/directions/SKILL.md b/features/mcp/skills/directions/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..5bec2a2dfffa64c559c90da53116f6a5c3abb85e --- /dev/null +++ b/features/mcp/skills/directions/SKILL.md @@ -0,0 +1,24 @@ +--- +name: mcp-directions +description: "Use when the user request matches the Bloom Ware MCP tool directions usage scenario." +tool_contract: + name: "directions" + category: "location" + module: "features.mcp.tools.location.directions_tool" + class: "DirectionsTool" + description: "規劃兩點之間的路線" + examples: + - "從這裡到台北車站怎麼走" + - "幫我規劃路線" +routing: + invocation_mode: "local_mcp_bridge_function_calling" + openai_hosted_mcp: "disabled_unless_remote_server_url_is_configured" + required_action: "Select this tool only when the request maps to its category and required inputs can be extracted or safely derived from environment context." +safety: + do_not_fabricate_missing_data: true + preserve_tool_failure_semantics: true + user_approval_required_for_high_impact_actions: true +--- + +Use this skill as the authoritative routing note for this Bloom Ware MCP tool. +The actual tool call must go through the local MCP bridge/function-calling path. diff --git a/features/mcp/skills/environment_context/SKILL.md b/features/mcp/skills/environment_context/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..af5b5e670557c8f01c20fc1886b1d9dcfc09f678 --- /dev/null +++ b/features/mcp/skills/environment_context/SKILL.md @@ -0,0 +1,24 @@ +--- +name: mcp-environment_context +description: "Use when the user request matches the Bloom Ware MCP tool environment_context usage scenario." +tool_contract: + name: "environment_context" + category: "environment" + module: "features.mcp.tools.environment.context_tool" + class: "EnvironmentContextTool" + description: "取得使用者目前環境感知資料(位置、時區、語言、裝置、活動狀態)" + examples: + - "我現在在哪" + - "目前環境資訊" +routing: + invocation_mode: "local_mcp_bridge_function_calling" + openai_hosted_mcp: "disabled_unless_remote_server_url_is_configured" + required_action: "Select this tool only when the request maps to its category and required inputs can be extracted or safely derived from environment context." +safety: + do_not_fabricate_missing_data: true + preserve_tool_failure_semantics: true + user_approval_required_for_high_impact_actions: true +--- + +Use this skill as the authoritative routing note for this Bloom Ware MCP tool. +The actual tool call must go through the local MCP bridge/function-calling path. diff --git a/features/mcp/skills/exchange_query/SKILL.md b/features/mcp/skills/exchange_query/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..d79f10928f7a5f4df8bd6e669801772d95e9e556 --- /dev/null +++ b/features/mcp/skills/exchange_query/SKILL.md @@ -0,0 +1,24 @@ +--- +name: mcp-exchange_query +description: "Use when the user request matches the Bloom Ware MCP tool exchange_query usage scenario." +tool_contract: + name: "exchange_query" + category: "utility" + module: "features.mcp.tools.utility.exchange_tool" + class: "ExchangeTool" + description: "查詢即時匯率並換算貨幣" + examples: + - "100 美元換台幣" + - "日圓匯率" +routing: + invocation_mode: "local_mcp_bridge_function_calling" + openai_hosted_mcp: "disabled_unless_remote_server_url_is_configured" + required_action: "Select this tool only when the request maps to its category and required inputs can be extracted or safely derived from environment context." +safety: + do_not_fabricate_missing_data: true + preserve_tool_failure_semantics: true + user_approval_required_for_high_impact_actions: true +--- + +Use this skill as the authoritative routing note for this Bloom Ware MCP tool. +The actual tool call must go through the local MCP bridge/function-calling path. diff --git a/features/mcp/skills/forward_geocode/SKILL.md b/features/mcp/skills/forward_geocode/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..0caef0ed743fb4d6f1788b0c88fc9c6812e4a753 --- /dev/null +++ b/features/mcp/skills/forward_geocode/SKILL.md @@ -0,0 +1,24 @@ +--- +name: mcp-forward_geocode +description: "Use when the user request matches the Bloom Ware MCP tool forward_geocode usage scenario." +tool_contract: + name: "forward_geocode" + category: "location" + module: "features.mcp.tools.location.geocoding_tool" + class: "ForwardGeocodeTool" + description: "將地點名稱轉換成座標" + examples: + - "銘傳大學在哪" + - "台北車站座標" +routing: + invocation_mode: "local_mcp_bridge_function_calling" + openai_hosted_mcp: "disabled_unless_remote_server_url_is_configured" + required_action: "Select this tool only when the request maps to its category and required inputs can be extracted or safely derived from environment context." +safety: + do_not_fabricate_missing_data: true + preserve_tool_failure_semantics: true + user_approval_required_for_high_impact_actions: true +--- + +Use this skill as the authoritative routing note for this Bloom Ware MCP tool. +The actual tool call must go through the local MCP bridge/function-calling path. diff --git a/features/mcp/skills/healthkit_query/SKILL.md b/features/mcp/skills/healthkit_query/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..5cd4e110e3427b78b3076778fe45cc2f9b9b876f --- /dev/null +++ b/features/mcp/skills/healthkit_query/SKILL.md @@ -0,0 +1,24 @@ +--- +name: mcp-healthkit_query +description: "Use when the user request matches the Bloom Ware MCP tool healthkit_query usage scenario." +tool_contract: + name: "healthkit_query" + category: "utility" + module: "features.mcp.tools.utility.healthkit_tool" + class: "HealthKitTool" + description: "查詢使用者健康資料(心率、步數、血氧、睡眠等)" + examples: + - "我今天走幾步" + - "最近心率如何" +routing: + invocation_mode: "local_mcp_bridge_function_calling" + openai_hosted_mcp: "disabled_unless_remote_server_url_is_configured" + required_action: "Select this tool only when the request maps to its category and required inputs can be extracted or safely derived from environment context." +safety: + do_not_fabricate_missing_data: true + preserve_tool_failure_semantics: true + user_approval_required_for_high_impact_actions: true +--- + +Use this skill as the authoritative routing note for this Bloom Ware MCP tool. +The actual tool call must go through the local MCP bridge/function-calling path. diff --git a/features/mcp/skills/news_query/SKILL.md b/features/mcp/skills/news_query/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..a191e82b4ce7cc945e46d3a41e0d586b01a90514 --- /dev/null +++ b/features/mcp/skills/news_query/SKILL.md @@ -0,0 +1,22 @@ +# 新聞與即時資訊查詢 (Tavily AI) + +這個 Skill 使用 Tavily API 提供基於 AI 的新聞與即時資訊搜尋。相比傳統新聞 API,Tavily 能夠更好地過濾噪音、提供即時動態並對搜尋結果進行智慧摘要。 + +## 功能特點 +- **極致時效性**:直接串接最新搜尋引擎索引,獲取分秒必爭的時事。 +- **AI 摘要**:自動整合多個來源,提供一句話或一段話的精華總結。 +- **深度搜尋**:支援 basic 與 advanced 兩種深度,應對簡單查詢或深入研究。 + +## 參數說明 +- `query` (string, 必填): 搜尋關鍵詞。例如:「台積電收盤價」、「2024 奧運獎牌榜」。 +- `limit` (integer, 可選): 返回新聞數量,預設 5,上限 10。 +- `search_depth` (string, 可選): `basic` (快速) 或 `advanced` (深入)。 + +## 使用範例 +- 「查看今天最重要的科技新聞」 -> `news_query(query="今天科技新聞", limit=5)` +- 「搜尋關於 SpaceX 最近發射任務的詳細報導」 -> `news_query(query="SpaceX 最近發射任務", search_depth="advanced")` +- 「台積電今天收盤多少?」 -> `news_query(query="台積電 股票 收盤價 今天")` + +## 注意事項 +- Tavily 會自動嘗試理解問題並給出 `answer` (AI 摘要),這對語音助手快速回報非常有用。 +- 本工具已移除舊有的 NewsData.io 邏輯,不再支援 `country` 或 `category` 的硬性篩選,改由 AI 語義搜尋達成。 diff --git a/features/mcp/skills/reverse_geocode/SKILL.md b/features/mcp/skills/reverse_geocode/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..90825e944655572a34fb73fdece17dae7bbeac31 --- /dev/null +++ b/features/mcp/skills/reverse_geocode/SKILL.md @@ -0,0 +1,24 @@ +--- +name: mcp-reverse_geocode +description: "Use when the user request matches the Bloom Ware MCP tool reverse_geocode usage scenario." +tool_contract: + name: "reverse_geocode" + category: "location" + module: "features.mcp.tools.location.geocode_tool" + class: "ReverseGeocodeTool" + description: "將座標轉換成地址、城市與行政區" + examples: + - "我在哪裡" + - "這個座標是哪裡" +routing: + invocation_mode: "local_mcp_bridge_function_calling" + openai_hosted_mcp: "disabled_unless_remote_server_url_is_configured" + required_action: "Select this tool only when the request maps to its category and required inputs can be extracted or safely derived from environment context." +safety: + do_not_fabricate_missing_data: true + preserve_tool_failure_semantics: true + user_approval_required_for_high_impact_actions: true +--- + +Use this skill as the authoritative routing note for this Bloom Ware MCP tool. +The actual tool call must go through the local MCP bridge/function-calling path. diff --git a/features/mcp/skills/tdx_bus_arrival/SKILL.md b/features/mcp/skills/tdx_bus_arrival/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..6c40dc58a6d27eb9866422945cffec3384e2ac9e --- /dev/null +++ b/features/mcp/skills/tdx_bus_arrival/SKILL.md @@ -0,0 +1,25 @@ +--- +name: mcp-tdx_bus_arrival +description: "Use when the user request matches the Bloom Ware MCP tool tdx_bus_arrival usage scenario." +tool_contract: + name: "tdx_bus_arrival" + category: "transportation" + module: "features.mcp.tools.transportation.tdx_bus_arrival" + class: "TDXBusArrivalTool" + description: "查詢公車即時到站時間(自動感知用戶位置,找最近站點)" + examples: + - "307 公車還要多久" + - "附近有什麼公車" + - "桃園火車站附近公車" +routing: + invocation_mode: "local_mcp_bridge_function_calling" + openai_hosted_mcp: "disabled_unless_remote_server_url_is_configured" + required_action: "Select this tool only when the request maps to its category and required inputs can be extracted or safely derived from environment context. For precise stop areas, landmarks, roads, or intersections, fill location_query." +safety: + do_not_fabricate_missing_data: true + preserve_tool_failure_semantics: true + user_approval_required_for_high_impact_actions: true +--- + +Use this skill as the authoritative routing note for this Bloom Ware MCP tool. +The actual tool call must go through the local MCP bridge/function-calling path. diff --git a/features/mcp/skills/tdx_metro/SKILL.md b/features/mcp/skills/tdx_metro/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..74359a5f179f8fbdac3c38a349fb8702e77e840c --- /dev/null +++ b/features/mcp/skills/tdx_metro/SKILL.md @@ -0,0 +1,25 @@ +--- +name: mcp-tdx_metro +description: "Use when the user request matches the Bloom Ware MCP tool tdx_metro usage scenario." +tool_contract: + name: "tdx_metro" + category: "transportation" + module: "features.mcp.tools.transportation.tdx_metro" + class: "TDXMetroTool" + description: "查詢捷運即時到站、最近車站(台北/高雄/桃園/台中捷運)" + examples: + - "最近的捷運站在哪" + - "台北車站捷運幾分鐘到" + - "桃園火車站附近捷運站" +routing: + invocation_mode: "local_mcp_bridge_function_calling" + openai_hosted_mcp: "disabled_unless_remote_server_url_is_configured" + required_action: "Select this tool only when the request maps to its category and required inputs can be extracted or safely derived from environment context. For precise landmarks or addresses, fill location_query before falling back to city/operator." +safety: + do_not_fabricate_missing_data: true + preserve_tool_failure_semantics: true + user_approval_required_for_high_impact_actions: true +--- + +Use this skill as the authoritative routing note for this Bloom Ware MCP tool. +The actual tool call must go through the local MCP bridge/function-calling path. diff --git a/features/mcp/skills/tdx_parking/SKILL.md b/features/mcp/skills/tdx_parking/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..7f7ac91f7dec961e981a92f3bbcdb2270f4dd68a --- /dev/null +++ b/features/mcp/skills/tdx_parking/SKILL.md @@ -0,0 +1,25 @@ +--- +name: mcp-tdx_parking +description: "Use when the user request matches the Bloom Ware MCP tool tdx_parking usage scenario." +tool_contract: + name: "tdx_parking" + category: "transportation" + module: "features.mcp.tools.transportation.tdx_parking" + class: "TDXParkingTool" + description: "查詢附近停車場資訊和即時空位" + examples: + - "附近停車場" + - "台北車站附近停車位" + - "中正路100號附近停車場" +routing: + invocation_mode: "local_mcp_bridge_function_calling" + openai_hosted_mcp: "disabled_unless_remote_server_url_is_configured" + required_action: "Select this tool only when the request maps to its category and required inputs can be extracted or safely derived from environment context. For precise addresses, landmarks, or intersections, fill location_query instead of forcing city." +safety: + do_not_fabricate_missing_data: true + preserve_tool_failure_semantics: true + user_approval_required_for_high_impact_actions: true +--- + +Use this skill as the authoritative routing note for this Bloom Ware MCP tool. +The actual tool call must go through the local MCP bridge/function-calling path. diff --git a/features/mcp/skills/tdx_thsr/SKILL.md b/features/mcp/skills/tdx_thsr/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..1db065e10fd46304ab5b7b80d94e28cbabb29b32 --- /dev/null +++ b/features/mcp/skills/tdx_thsr/SKILL.md @@ -0,0 +1,24 @@ +--- +name: mcp-tdx_thsr +description: "Use when the user request matches the Bloom Ware MCP tool tdx_thsr usage scenario." +tool_contract: + name: "tdx_thsr" + category: "transportation" + module: "features.mcp.tools.transportation.tdx_thsr" + class: "TDXTHSRTool" + description: "查詢高鐵時刻表、票價和即時資訊" + examples: + - "高鐵從台北到台中" + - "高鐵票價查詢" +routing: + invocation_mode: "local_mcp_bridge_function_calling" + openai_hosted_mcp: "disabled_unless_remote_server_url_is_configured" + required_action: "Select this tool only when the request maps to its category and required inputs can be extracted or safely derived from environment context." +safety: + do_not_fabricate_missing_data: true + preserve_tool_failure_semantics: true + user_approval_required_for_high_impact_actions: true +--- + +Use this skill as the authoritative routing note for this Bloom Ware MCP tool. +The actual tool call must go through the local MCP bridge/function-calling path. diff --git a/features/mcp/skills/tdx_train/SKILL.md b/features/mcp/skills/tdx_train/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..8c394c71abdc8f61d34eb88c8fb3088f3ccf4fe2 --- /dev/null +++ b/features/mcp/skills/tdx_train/SKILL.md @@ -0,0 +1,24 @@ +--- +name: mcp-tdx_train +description: "Use when the user request matches the Bloom Ware MCP tool tdx_train usage scenario." +tool_contract: + name: "tdx_train" + category: "transportation" + module: "features.mcp.tools.transportation.tdx_train" + class: "TDXTrainTool" + description: "查詢台鐵時刻表和即時資訊" + examples: + - "台鐵從台北到新竹" + - "火車時刻表" +routing: + invocation_mode: "local_mcp_bridge_function_calling" + openai_hosted_mcp: "disabled_unless_remote_server_url_is_configured" + required_action: "Select this tool only when the request maps to its category and required inputs can be extracted or safely derived from environment context." +safety: + do_not_fabricate_missing_data: true + preserve_tool_failure_semantics: true + user_approval_required_for_high_impact_actions: true +--- + +Use this skill as the authoritative routing note for this Bloom Ware MCP tool. +The actual tool call must go through the local MCP bridge/function-calling path. diff --git a/features/mcp/skills/tdx_youbike/SKILL.md b/features/mcp/skills/tdx_youbike/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..dcf6125c658c10659b55cafc37fd3c410d9a02ae --- /dev/null +++ b/features/mcp/skills/tdx_youbike/SKILL.md @@ -0,0 +1,25 @@ +--- +name: mcp-tdx_youbike +description: "Use when the user request matches the Bloom Ware MCP tool tdx_youbike usage scenario." +tool_contract: + name: "tdx_youbike" + category: "transportation" + module: "features.mcp.tools.transportation.tdx_youbike" + class: "TDXBikeTool" + description: "查詢 YouBike 站點資訊和即時車輛數量" + examples: + - "附近 YouBike" + - "捷運站 YouBike 數量" + - "桃園火車站附近的 YouBike" +routing: + invocation_mode: "local_mcp_bridge_function_calling" + openai_hosted_mcp: "disabled_unless_remote_server_url_is_configured" + required_action: "Select this tool only when the request maps to its category and required inputs can be extracted or safely derived from environment context. For precise addresses, landmarks, intersections, or station areas, fill location_query instead of forcing city." +safety: + do_not_fabricate_missing_data: true + preserve_tool_failure_semantics: true + user_approval_required_for_high_impact_actions: true +--- + +Use this skill as the authoritative routing note for this Bloom Ware MCP tool. +The actual tool call must go through the local MCP bridge/function-calling path. diff --git a/features/mcp/skills/weather_query/SKILL.md b/features/mcp/skills/weather_query/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..2a86b112464f8ecaa1ec4acc0637bbd9eed072f8 --- /dev/null +++ b/features/mcp/skills/weather_query/SKILL.md @@ -0,0 +1,24 @@ +--- +name: mcp-weather_query +description: "Use when the user request matches the Bloom Ware MCP tool weather_query usage scenario." +tool_contract: + name: "weather_query" + category: "location" + module: "features.mcp.tools.location.weather_tool" + class: "WeatherTool" + description: "查詢即時天氣資訊" + examples: + - "台北天氣" + - "今天會下雨嗎" +routing: + invocation_mode: "local_mcp_bridge_function_calling" + openai_hosted_mcp: "disabled_unless_remote_server_url_is_configured" + required_action: "Select this tool only when the request maps to its category and required inputs can be extracted or safely derived from environment context." +safety: + do_not_fabricate_missing_data: true + preserve_tool_failure_semantics: true + user_approval_required_for_high_impact_actions: true +--- + +Use this skill as the authoritative routing note for this Bloom Ware MCP tool. +The actual tool call must go through the local MCP bridge/function-calling path. diff --git a/features/mcp/skills/web_search/SKILL.md b/features/mcp/skills/web_search/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..3ff8a0fcb103809c5f768d4de0688ede48c3bd35 --- /dev/null +++ b/features/mcp/skills/web_search/SKILL.md @@ -0,0 +1,28 @@ +--- +name: hosted-web-search +description: "Use when the user asks for current, recent, time-sensitive, public, or externally verifiable information and OpenAI hosted web_search is enabled." +tool_contract: + name: "web_search" + category: "hosted_openai_tool" + invocation_mode: "responses_hosted_tool_auto" + description: "OpenAI hosted web_search for current public information." +usage_policy: + decide_need_from_request: true + use_environment_context: true + use_time_context: true + avoid_domain_specific_hardcoding: true + do_not_invent_missing_facts: true + cite_or_describe_source_time_when_available: true +--- + +# Hosted Web Search + +Use `web_search` when the answer depends on information that may have changed after the model's knowledge cutoff or depends on the user's current time, location, market/session state, availability, price, schedule, law, weather, news, or other public facts. + +Before answering, compare the user's wording with the available time and environment context. Decide whether the retrieved information is current enough for the question. If results are older than the requested time frame, reason from the context instead of treating it as a tool failure: explain what is known, what remains uncertain, and why. + +If the user asks for "today", "now", "latest", "current", "closing", "real-time", or equivalent wording, compare source timestamps against the current time context. If the source timestamp is earlier than the requested time frame, do not present it as today's/current result. Label it as the latest source timestamp you found, then explain the likely timing or availability limitation using only general reasoning from time/environment/source context. + +Do not hardcode domain-specific behavior. Do not force a specific market, source, company, route, or conclusion unless the user supplied it or the sources clearly establish it. If multiple interpretations are plausible, state the interpretation used and how to ask for another one. + +When results are insufficient, stale, conflicting, or unavailable, say so plainly and avoid fabricating exact values. You may still provide the closest verified information as reference if you label its timestamp and uncertainty. diff --git a/features/mcp/tool_models.py b/features/mcp/tool_models.py index f4202b69414b3bd5c16a372972101eda0128cae7..17ba9805bb8b4bcf44311c5e82c86b64a0467870 100644 --- a/features/mcp/tool_models.py +++ b/features/mcp/tool_models.py @@ -10,6 +10,7 @@ class ToolMetadata: enable_reformat: bool = False flow: Optional[str] = None # 例如 "navigation" env_fallbacks: Dict[str, List[str]] = field(default_factory=dict) # 環境變數 fallback 映射,例如 {"city": ["detailed_address", "label"]} + supports_location_query: bool = False @dataclass diff --git a/features/mcp/tools/__init__.py b/features/mcp/tools/__init__.py index c1e92d06e72da8db68dca32b89ca19ed7c4ae55c..117e971ea110f9e43dae633c8b2c41ea1ba8356e 100644 --- a/features/mcp/tools/__init__.py +++ b/features/mcp/tools/__init__.py @@ -2,9 +2,34 @@ MCP Tools 模組 - 所有功能工具的統一入口 """ -from .weather_tool import WeatherTool -from .news_tool import NewsTool -from .exchange_tool import ExchangeTool -from .healthkit_tool import HealthKitTool +from .environment.context_tool import EnvironmentContextTool +from .location.directions_tool import DirectionsTool +from .location.geocode_tool import ReverseGeocodeTool +from .location.geocoding_tool import ForwardGeocodeTool +from .location.weather_tool import WeatherTool +from .transportation.tdx_bus_arrival import TDXBusArrivalTool +from .transportation.tdx_metro import TDXMetroTool +from .transportation.tdx_parking import TDXParkingTool +from .transportation.tdx_thsr import TDXTHSRTool +from .transportation.tdx_train import TDXTrainTool +from .transportation.tdx_youbike import TDXBikeTool +from .utility.exchange_tool import ExchangeTool +from .utility.healthkit_tool import HealthKitTool +from .utility.news_tool import NewsTool -__all__ = ["WeatherTool", "NewsTool", "ExchangeTool", "HealthKitTool"] \ No newline at end of file +__all__ = [ + "DirectionsTool", + "EnvironmentContextTool", + "ExchangeTool", + "ForwardGeocodeTool", + "HealthKitTool", + "NewsTool", + "ReverseGeocodeTool", + "TDXBikeTool", + "TDXBusArrivalTool", + "TDXMetroTool", + "TDXParkingTool", + "TDXTHSRTool", + "TDXTrainTool", + "WeatherTool", +] diff --git a/features/mcp/tools/environment/__init__.py b/features/mcp/tools/environment/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..09aceb4a9d5268b870414448d8793664b48964ae --- /dev/null +++ b/features/mcp/tools/environment/__init__.py @@ -0,0 +1,3 @@ +from .context_tool import EnvironmentContextTool + +__all__ = ["EnvironmentContextTool"] diff --git a/features/mcp/tools/environment/context_tool.py b/features/mcp/tools/environment/context_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..eb3a54b2f08a8026f5cc5afa484dfbe252435c1e --- /dev/null +++ b/features/mcp/tools/environment/context_tool.py @@ -0,0 +1,72 @@ +import json +from ..base_tool import MCPTool, StandardToolSchemas, ExecutionError + +class EnvironmentContextTool(MCPTool): + NAME = "environment_context" + DESCRIPTION = ( + "Return the latest injected environment context for the current user. " + "The main agent receives this context every turn; use this tool only " + "when a workflow explicitly needs the raw environment payload." + ) + CATEGORY = "environment" + TAGS = ["environment", "context", "location", "timezone", "locale"] + KEYWORDS = ["environment", "context", "location", "timezone", "locale", "環境", "位置", "時區"] + + @classmethod + def get_input_schema(cls) -> dict: + return { + "type": "object", + "properties": { + "_user_id": { + "type": "string", + "description": "Injected by the server-side tool coordinator." + } + }, + "required": [], + "additionalProperties": True + } + + @classmethod + def get_output_schema(cls) -> dict: + schema = StandardToolSchemas.create_output_schema() + schema["properties"].update({ + "data": {"type": "object"} + }) + return schema + + @classmethod + def get_definition(cls) -> dict: + return { + "name": cls.NAME, + "description": cls.DESCRIPTION, + "inputSchema": cls.get_input_schema(), + "outputSchema": cls.get_output_schema(), + } + + @classmethod + async def execute(cls, arguments: dict) -> dict: + try: + user_id = (arguments or {}).get("_user_id") + if not user_id: + return { + "success": False, + "error": "environment_context requires an injected _user_id" + } + + from core.database import get_user_env_current + + env_res = await get_user_env_current(user_id) + if not env_res.get("success"): + return { + "success": False, + "error": env_res.get("error") or "environment context is unavailable" + } + + context = env_res.get("context") or {} + return { + "success": True, + "content": json.dumps(context, ensure_ascii=False), + "data": context, + } + except Exception as e: + raise ExecutionError(str(e)) diff --git a/features/mcp/tools/location/__init__.py b/features/mcp/tools/location/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c10139082502cd6d2d1535022e01d8ff4448e9aa --- /dev/null +++ b/features/mcp/tools/location/__init__.py @@ -0,0 +1,11 @@ +from .directions_tool import DirectionsTool +from .geocode_tool import ReverseGeocodeTool +from .geocoding_tool import ForwardGeocodeTool +from .weather_tool import WeatherTool + +__all__ = [ + "DirectionsTool", + "ForwardGeocodeTool", + "ReverseGeocodeTool", + "WeatherTool", +] diff --git a/features/mcp/tools/directions_tool.py b/features/mcp/tools/location/directions_tool.py similarity index 99% rename from features/mcp/tools/directions_tool.py rename to features/mcp/tools/location/directions_tool.py index 39c5e06c6f07dae05e6c038bfd6b7cee2ca60498..14489887c781c90092d4ba68b7f31e72e3791085 100644 --- a/features/mcp/tools/directions_tool.py +++ b/features/mcp/tools/location/directions_tool.py @@ -10,7 +10,7 @@ import re import aiohttp from typing import Dict, Any -from .base_tool import MCPTool, StandardToolSchemas, ExecutionError, ValidationError +from ..base_tool import MCPTool, StandardToolSchemas, ExecutionError, ValidationError from core.config import settings from core.database import get_route_cache, set_route_cache from core.database.cache import db_cache diff --git a/features/mcp/tools/geocode_tool.py b/features/mcp/tools/location/geocode_tool.py similarity index 69% rename from features/mcp/tools/geocode_tool.py rename to features/mcp/tools/location/geocode_tool.py index 1bda2fa47f0a40521982bb17b0e885da88f4cbe3..5557476adf2aecb82652b880fad262a9bb41bb4f 100644 --- a/features/mcp/tools/geocode_tool.py +++ b/features/mcp/tools/location/geocode_tool.py @@ -1,16 +1,18 @@ """ 反地理與時區工具(免費 API 優先) -- reverse_geocode: 使用 Nominatim(OSM)反查城市/行政區(先查 DB/記憶體快取) +- reverse_geocode: 優先使用 TDX 官方定位服務,失敗時 fallback 到 Nominatim """ import aiohttp import asyncio import logging from typing import Dict, Any +from urllib.parse import quote -from .base_tool import MCPTool, StandardToolSchemas, ExecutionError +from ..base_tool import MCPTool, StandardToolSchemas, ExecutionError from core.database import get_geo_cache, set_geo_cache from core.database.cache import db_cache +from ..transportation.tdx_base import TDXBaseAPI logger = logging.getLogger("mcp.tools.geocode") @@ -36,7 +38,10 @@ class ReverseGeocodeTool(MCPTool): schema["properties"].update({ "city": {"type": "string"}, "admin": {"type": "string"}, - "country_code": {"type": "string"} + "country_code": {"type": "string"}, + "address_display": {"type": "string"}, + "precision": {"type": "string"}, + "poi_label": {"type": "string"}, }) return schema @@ -99,54 +104,25 @@ class ReverseGeocodeTool(MCPTool): data=db_cached ) - # 外呼 Nominatim(公共端點,務必節流) - url = "https://nominatim.openstreetmap.org/reverse" - params = { - "format": "jsonv2", - "lat": lat, - "lon": lon, - "zoom": 18, - "addressdetails": 1, - "extratags": 1, - "namedetails": 1 - } - headers = { - "User-Agent": "BloomWare/1.0 (contact@example.com)" - } + data = await cls._reverse_geocode_tdx(lat, lon) + source = "tdx" + if not data: + data = await cls._reverse_geocode_nominatim(lat, lon) + source = "nominatim" - # 呼叫 Nominatim API - data = None - try: - async with aiohttp.ClientSession(headers=headers) as session: - async with session.get(url, params=params, timeout=30) as resp: - if resp.status != 200: - raise ExecutionError(f"Nominatim 失敗: HTTP {resp.status}") - - response_text = await resp.text() - if not response_text or response_text.strip() == "": - raise ExecutionError("Nominatim 回應為空") - - import json - try: - data = json.loads(response_text) - except json.JSONDecodeError: - raise ExecutionError(f"Nominatim 回應非 JSON: {response_text[:200]}") - except aiohttp.ClientError as e: - raise ExecutionError(f"Nominatim 網路錯誤: {e}") - except asyncio.TimeoutError: - raise ExecutionError("Nominatim 請求逾時") - - # 驗證回應 if data is None: - raise ExecutionError("Nominatim 回應為 null") - if not isinstance(data, dict): - raise ExecutionError(f"Nominatim 回應格式錯誤: {type(data)}") - if "error" in data: - raise ExecutionError(f"Nominatim 錯誤: {data.get('error')}") - - # 解析地址資訊 - addr = data.get("address") or {} - extratags = data.get("extratags") or {} + raise ExecutionError("reverse geocode 無可用回應") + + if source == "tdx": + addr = data.get("address", {}) or {} + extratags = data.get("extratags", {}) or {} + else: + if not isinstance(data, dict): + raise ExecutionError(f"Nominatim 回應格式錯誤: {type(data)}") + if "error" in data: + raise ExecutionError(f"Nominatim 錯誤: {data.get('error')}") + addr = data.get("address") or {} + extratags = data.get("extratags") or {} # 基本地址組件 road = addr.get("road") or addr.get("pedestrian") or addr.get("footway") or addr.get("cycleway") or "" @@ -227,6 +203,34 @@ class ReverseGeocodeTool(MCPTool): detailed_address_parts.append(f"郵遞區號: {postcode}") detailed_address = " | ".join(detailed_address_parts) if detailed_address_parts else label + + address_display_parts = [] + if postcode: + address_display_parts.append(postcode) + if city: + address_display_parts.append(city) + if city_district: + address_display_parts.append(city_district) + elif suburb: + address_display_parts.append(suburb) + if road: + address_display_parts.append(road) + if house_number: + address_display_parts.append(house_number) + address_display = "".join(address_display_parts) or display_name or label + + if house_number and road: + precision = "address" + elif name_zh and (amenity or shop or building or office or tourism or leisure): + precision = "poi" + elif road: + precision = "road" + elif city_district or suburb: + precision = "district" + elif city: + precision = "city" + else: + precision = "unknown" payload = { "lat": lat, @@ -236,7 +240,11 @@ class ReverseGeocodeTool(MCPTool): "country_code": country_code, "display_name": display_name, "label": label or display_name, + "address_display": address_display, "detailed_address": detailed_address, + "precision": precision, + "poi_label": name_zh or name or label or "", + "geocode_source": source, "road": road, "house_number": house_number, "suburb": suburb, @@ -256,3 +264,67 @@ class ReverseGeocodeTool(MCPTool): await set_geo_cache(geokey, payload) return cls.create_success_response(content=payload.get("label") or f"{payload['city']}, {payload['admin']}", data=payload) + + @staticmethod + async def _reverse_geocode_tdx(lat: float, lon: float) -> Dict[str, Any] | None: + try: + endpoint = f"V3/Map/GeoLocating/Address/LocationX/{lon}/LocationY/{lat}" + rows = await TDXBaseAPI.call_api( + endpoint, + {"$format": "JSON"}, + cache_ttl=86400, + api_version="", + api_family="advanced", + ) + if isinstance(rows, list) and rows: + row = rows[0] + address = row.get("Address") or "" + city = row.get("City") or "" + town = row.get("Town") or "" + road = row.get("RoadName") or "" + house_number = str(row.get("AddressNo") or "").strip() + return { + "display_name": address, + "address": { + "city": city, + "city_district": town, + "road": road, + "house_number": house_number, + "postcode": row.get("ZipCode") or "", + "country_code": "TW", + }, + "name": row.get("LocationDescription") or row.get("LandMarkName") or "", + "namedetails": {}, + "extratags": {}, + } + except Exception as exc: + logger.warning("TDX reverse geocode 失敗,回退 Nominatim: %s", exc) + return None + + @staticmethod + async def _reverse_geocode_nominatim(lat: float, lon: float) -> Dict[str, Any] | None: + url = "https://nominatim.openstreetmap.org/reverse" + params = { + "format": "jsonv2", + "lat": lat, + "lon": lon, + "zoom": 18, + "addressdetails": 1, + "extratags": 1, + "namedetails": 1 + } + headers = { + "User-Agent": "BloomWare/1.0 (contact@example.com)" + } + try: + async with aiohttp.ClientSession(headers=headers) as session: + async with session.get(url, params=params, timeout=30) as resp: + if resp.status != 200: + raise ExecutionError(f"Nominatim 失敗: HTTP {resp.status}") + response_text = await resp.text() + if not response_text or response_text.strip() == "": + raise ExecutionError("Nominatim 回應為空") + import json + return json.loads(response_text) + except (aiohttp.ClientError, asyncio.TimeoutError, ValueError) as exc: + raise ExecutionError(f"Nominatim reverse geocode 失敗: {exc}") from exc diff --git a/features/mcp/tools/geocoding_tool.py b/features/mcp/tools/location/geocoding_tool.py similarity index 63% rename from features/mcp/tools/geocoding_tool.py rename to features/mcp/tools/location/geocoding_tool.py index 2e9f51de4d8554f7375cd90ced8e145b4035a672..5cc24fa9e9155f731bbec1870cd7b4b1ee574543 100644 --- a/features/mcp/tools/geocoding_tool.py +++ b/features/mcp/tools/location/geocoding_tool.py @@ -1,19 +1,23 @@ """ 地點名稱轉座標工具(Forward Geocoding) -使用 Nominatim(OSM)將地點名稱轉換為經緯度座標 +優先使用 TDX 官方定位服務,失敗時 fallback 到 Nominatim(OSM) """ import aiohttp import asyncio import logging from typing import Dict, Any, List +from urllib.parse import quote -from .base_tool import MCPTool, StandardToolSchemas, ExecutionError +from ..base_tool import MCPTool, StandardToolSchemas, ExecutionError from core.database import get_geo_cache, set_geo_cache from core.database.cache import db_cache +from ..transportation.tdx_base import TDXBaseAPI logger = logging.getLogger("mcp.tools.geocoding") +POI_HINTS = ("車站", "捷運站", "站", "大學", "醫院", "百貨", "大樓", "機場", "公園", "學校") + class ForwardGeocodeTool(MCPTool): NAME = "forward_geocode" @@ -101,36 +105,16 @@ class ForwardGeocodeTool(MCPTool): data=db_cached ) - # 外呼 Nominatim(公共端點,務必節流) - url = "https://nominatim.openstreetmap.org/search" - params = { - "format": "jsonv2", - "q": query, - "limit": limit, - "addressdetails": 1, - "extratags": 1, # 取得額外標籤 - "namedetails": 1, # 取得多語言名稱 - "accept-language": "zh-TW,zh" - } - headers = { - "User-Agent": "BloomWare/1.0 (contact@example.com)" - } - - try: - async with aiohttp.ClientSession(headers=headers) as session: - async with session.get(url, params=params, timeout=aiohttp.ClientTimeout(total=10)) as resp: - if resp.status != 200: - raise ExecutionError(f"Nominatim 查詢失敗: HTTP {resp.status}") - - data = await resp.json() - - if not data or len(data) == 0: - raise ExecutionError(f"找不到地點「{query}」,請確認地點名稱是否正確") + data = await cls._forward_geocode_tdx(query) + source = "tdx" + if not data: + data = await cls._forward_geocode_nominatim(query, limit) + source = "nominatim" + if not data or len(data) == 0: + raise ExecutionError(f"找不到地點「{query}」,請確認地點名稱是否正確") - except asyncio.TimeoutError: - raise ExecutionError("地點查詢逾時,請稍後再試") - except aiohttp.ClientError as e: - raise ExecutionError(f"網路連接錯誤: {str(e)}") + prefer_poi = any(hint in query for hint in POI_HINTS) + prefer_address = any(ch.isdigit() for ch in query) or "號" in query # 解析結果 results = [] @@ -219,10 +203,32 @@ class ForwardGeocodeTool(MCPTool): "amenity": amenity, "shop": shop, "building": building, + "geocode_source": source, + "_kind": item.get("_kind", ""), }) - # 最佳匹配(重要性最高) - best_match = max(results, key=lambda x: x["importance"]) + for result in results: + score = float(result.get("importance", 0)) + name = result.get("name", "") or "" + label = result.get("label", "") or "" + display_name = result.get("display_name", "") or "" + kind = result.get("_kind", "") + text = f"{name} {label} {display_name}" + + if query and query in text: + score += 5.0 + if prefer_poi and kind == "markname": + score += 8.0 + if prefer_address and kind == "address": + score += 8.0 + if "出入口" in text: + score -= 1.5 + if "交叉口" in display_name and prefer_poi: + score -= 1.0 + result["_score"] = score + + results.sort(key=lambda x: x.get("_score", 0), reverse=True) + best_match = results[0] payload = { "results": results, @@ -244,3 +250,102 @@ class ForwardGeocodeTool(MCPTool): content = "\n".join(content_parts) return cls.create_success_response(content=content, data=payload) + + @staticmethod + async def _forward_geocode_tdx(query: str) -> List[Dict[str, Any]] | None: + try: + address_endpoint = f"V3/Map/GeoCode/Coordinate/Address/{quote(query, safe='')}" + rows = await TDXBaseAPI.call_api( + address_endpoint, + {"$format": "JSON"}, + cache_ttl=86400, + api_version="", + api_family="advanced", + ) + parsed_address = ForwardGeocodeTool._parse_tdx_rows(rows, kind="address") + + mark_endpoint = f"V3/Map/GeoCode/Coordinate/Markname/{quote(query, safe='')}" + rows = await TDXBaseAPI.call_api( + mark_endpoint, + {"$format": "JSON"}, + cache_ttl=86400, + api_version="", + api_family="advanced", + ) + parsed_mark = ForwardGeocodeTool._parse_tdx_rows(rows, kind="markname") + combined = (parsed_mark or []) + (parsed_address or []) + return combined + except Exception as exc: + logger.warning("TDX forward geocode 失敗,回退 Nominatim: %s", exc) + return None + + @staticmethod + def _parse_tdx_rows(rows: Any, *, kind: str) -> List[Dict[str, Any]]: + results = [] + for row in rows or []: + lon = row.get("LocationX") or row.get("PositionLon") + lat = row.get("LocationY") or row.get("PositionLat") + geometry = row.get("Geometry") or "" + if (lon is None or lat is None) and isinstance(geometry, str) and geometry.startswith("POINT"): + try: + point_text = geometry.removeprefix("POINT").strip().strip("()") + point_lon, point_lat = point_text.split() + lon = float(point_lon) + lat = float(point_lat) + except Exception: + lon = lon + lat = lat + if lon is None or lat is None: + continue + address = row.get("Address") or row.get("RoadName") or row.get("LandMarkName") or "" + name = row.get("Name") or row.get("LandMarkName") or row.get("LocationDescription") or "" + city = row.get("City") or "" + town = row.get("Town") or "" + label = name or address or "" + results.append({ + "lat": float(lat), + "lon": float(lon), + "display_name": address or label, + "label": label, + "importance": 1.0, + "name": name, + "road": row.get("RoadName") or "", + "house_number": str(row.get("AddressNo") or "").strip(), + "suburb": "", + "city_district": town, + "city": city, + "admin": city, + "postcode": row.get("ZipCode") or "", + "amenity": "", + "shop": "", + "building": "", + "detailed_address": address or label, + "_kind": kind, + }) + return results + + @staticmethod + async def _forward_geocode_nominatim(query: str, limit: int) -> List[Dict[str, Any]]: + url = "https://nominatim.openstreetmap.org/search" + params = { + "format": "jsonv2", + "q": query, + "limit": limit, + "addressdetails": 1, + "extratags": 1, + "namedetails": 1, + "accept-language": "zh-TW,zh" + } + headers = { + "User-Agent": "BloomWare/1.0 (contact@example.com)" + } + try: + async with aiohttp.ClientSession(headers=headers) as session: + async with session.get(url, params=params, timeout=aiohttp.ClientTimeout(total=10)) as resp: + if resp.status != 200: + raise ExecutionError(f"Nominatim 查詢失敗: HTTP {resp.status}") + return await resp.json() + except asyncio.TimeoutError: + raise ExecutionError("地點查詢逾時,請稍後再試") + except aiohttp.ClientError as e: + raise ExecutionError(f"網路連接錯誤: {str(e)}") diff --git a/features/mcp/tools/weather_tool.py b/features/mcp/tools/location/weather_tool.py similarity index 99% rename from features/mcp/tools/weather_tool.py rename to features/mcp/tools/location/weather_tool.py index df3a9b3fe1257b9cb1124141470506b879b67dfa..135c0e8e88841a3be7dc8f9b33453775d5058b6b 100644 --- a/features/mcp/tools/weather_tool.py +++ b/features/mcp/tools/location/weather_tool.py @@ -11,7 +11,7 @@ import asyncio from datetime import datetime, timedelta from typing import Dict, Any, Optional from dotenv import load_dotenv -from .base_tool import MCPTool, ValidationError, ExecutionError, StandardToolSchemas +from ..base_tool import MCPTool, ValidationError, ExecutionError, StandardToolSchemas # 載入環境變數 load_dotenv() diff --git a/features/mcp/tools/news_tool.py b/features/mcp/tools/news_tool.py deleted file mode 100644 index beca8f70cea98a8ce144cb95d1d3b08fee88f51b..0000000000000000000000000000000000000000 --- a/features/mcp/tools/news_tool.py +++ /dev/null @@ -1,501 +0,0 @@ -""" -新聞查詢 MCP Tool -使用 NewsData.io 實作的新聞功能,提供更可靠的台灣與繁中新聞 -""" - -import os -import json -import logging -import aiohttp -import asyncio -from datetime import datetime, timedelta -from typing import Dict, Any, Optional, List -from urllib.parse import quote -from dotenv import load_dotenv -from .base_tool import MCPTool, ValidationError, ExecutionError, StandardToolSchemas - -# 載入環境變數 -load_dotenv() - -# 統一配置管理 -from core.config import settings - -logger = logging.getLogger("mcp.tools.news") - -# NewsData.io 配置 -NEWSDATA_BASE_URL = "https://newsdata.io/api/1" -NEWSDATA_API_KEY = settings.NEWSDATA_API_KEY - - -class NewsTool(MCPTool): - """新聞查詢 MCP 工具 - 使用 NewsData.io(更好的台灣與繁中新聞支援)""" - - NAME = "news_query" - DESCRIPTION = "Query latest news articles (can specify category, language, and quantity)" - CATEGORY = "生活資訊" - TAGS = ["news", "新聞", "資訊"] - KEYWORDS = ["新聞", "消息", "報導", "news", "頭條", "時事"] - USAGE_TIPS = [ - "可指定新聞類別(科技、商業、娛樂等)", - "支援多國新聞(台灣、美國、日本等)", - "可限制返回數量" - ] - - @classmethod - def get_input_schema(cls) -> Dict[str, Any]: - """獲取輸入參數模式""" - return StandardToolSchemas.create_input_schema({ - "query": { - "type": "string", - "description": "搜尋關鍵詞(可選)" - }, - "country": { - "type": "string", - "description": "新聞國家代碼 (tw, us, cn, jp, kr, hk, sg)", - "default": "tw", - "enum": ["tw", "us", "cn", "jp", "kr", "hk", "sg", "gb", "de", "fr"] - }, - "category": { - "type": "string", - "description": "新聞分類 (business, technology, health, science, sports, entertainment, top)", - "default": "top", - "enum": ["business", "technology", "health", "science", "sports", "entertainment", "top", "world", "politics"] - }, - "language": { - "type": "string", - "description": "新聞語言 (zh, en, ja, ko)", - "default": "zh", - "enum": ["zh", "en", "ja", "ko"] - }, - "limit": { - "type": "integer", - "description": "返回新聞數量限制(免費版最多 10)", - "default": 10, - "minimum": 1, - "maximum": 10 - }, - "timeframe": { - "type": "integer", - "description": "查詢過去幾小時的新聞(1-48,可選)", - "minimum": 1, - "maximum": 48 - } - }) - - @classmethod - def get_output_schema(cls) -> Dict[str, Any]: - """獲取輸出結果模式""" - base_schema = StandardToolSchemas.create_output_schema() - base_schema["properties"].update({ - "articles": { - "type": "array", - "items": { - "type": "object", - "properties": { - "article_id": {"type": "string"}, - "title": {"type": "string"}, - "description": {"type": "string"}, - "content": {"type": "string"}, - "url": {"type": "string"}, - "published_at": {"type": "string"}, - "source": { - "type": "object", - "properties": { - "name": {"type": "string"}, - "id": {"type": "string"}, - "url": {"type": "string"} - } - }, - "category": {"type": "array"}, - "language": {"type": "string"}, - "sentiment": {"type": "string"} - } - } - }, - "count": {"type": "integer"}, - "totalResults": {"type": "integer"} - }) - return base_schema - - @classmethod - async def execute(cls, arguments: Dict[str, Any]) -> Dict[str, Any]: - """執行新聞查詢""" - if not NEWSDATA_API_KEY: - return cls.create_error_response( - error="NewsData.io API 金鑰未設置,請設置 NEWSDATA_API_KEY 環境變數", - code="API_KEY_MISSING" - ) - - # 處理參數,過濾空字串並使用預設值 - query = arguments.get("query", "") - country = arguments.get("country", "tw") or "tw" - category = arguments.get("category", "top") or "top" - language = arguments.get("language", "zh") or "zh" - limit = min(arguments.get("limit", 10), 10) # 免費版限制 10 - timeframe = arguments.get("timeframe") - - # 確保 category 是有效值(防止空字串) - valid_categories = ["business", "technology", "health", "science", "sports", "entertainment", "top", "world", "politics"] - if category not in valid_categories: - category = "top" - - try: - news_data = await cls._fetch_news_from_newsdata( - query, country, category, language, limit, timeframe - ) - - if news_data.get("success"): - articles = news_data.get("articles", []) - total_results = news_data.get("totalResults", 0) - - # 為每篇新聞生成簡短摘要(用於工具卡片顯示) - articles = await cls._generate_summaries(articles) - - formatted_text = cls._format_newsdata_response( - articles, query, country, category, total_results - ) - - return cls.create_success_response( - content=formatted_text, - data={ - "raw_data": { - "articles": articles, - "count": len(articles), - "totalResults": total_results - } - } - ) - else: - return cls.create_error_response( - error=news_data.get("error", "獲取新聞失敗"), - code="FETCH_ERROR" - ) - - except Exception as e: - logger.error(f"新聞查詢錯誤: {e}") - raise ExecutionError(f"新聞查詢時發生錯誤: {str(e)}", e) - - @staticmethod - async def _fetch_news_from_newsdata( - query: str, country: str, category: str, - language: str, limit: int, timeframe: Optional[int] - ) -> Dict[str, Any]: - """從 NewsData.io 獲取新聞數據""" - try: - # 建構 NewsData.io URL - url = f"{NEWSDATA_BASE_URL}/latest" - - # 構建參數 - params = { - "apikey": NEWSDATA_API_KEY, - "size": limit, - "language": language - } - - # 關鍵字搜尋 - if query: - params["q"] = query - - # 國家篩選(僅在沒有關鍵字時使用) - if country and not query: - params["country"] = country - - # 分類篩選 - if category and category != "top": - params["category"] = category - - # 時間範圍 - if timeframe: - params["timeframe"] = timeframe - - # 排除重複 - params["removeduplicate"] = "1" - - logger.info(f"NewsData.io 請求: {url}") - logger.info(f"參數: {', '.join([f'{k}={v}' for k, v in params.items() if k != 'apikey'])}") - - async with aiohttp.ClientSession() as session: - async with session.get(url, params=params, timeout=15) as response: - logger.info(f"NewsData.io 響應狀態: {response.status}") - - if response.status == 200: - data = await response.json() - - # 檢查 API 回應狀態 - status = data.get("status") - if status == "success": - articles = data.get("results", []) - total_results = data.get("totalResults", 0) - - logger.info(f"NewsData.io 返回文章數: {len(articles)} / 總數: {total_results}") - - # 處理文章數據(確保與前端格式兼容) - processed_articles = [] - for article in articles: - source_name = article.get("source_name", article.get("source_id", "未知來源")) - - # 過濾掉付費功能的佔位文字 - sentiment = article.get("sentiment", "") - if "ONLY AVAILABLE" in str(sentiment): - sentiment = "" - - content = article.get("content", "") - if "ONLY AVAILABLE" in str(content): - content = "" - - processed_article = { - "article_id": article.get("article_id", ""), - "title": article.get("title", "無標題"), - "description": article.get("description", ""), - "content": content, - "url": article.get("link", ""), - "published_at": article.get("pubDate", ""), - # 前端期望 source 是物件 {name: "來源名"},或字串直接顯示 - "source": { - "name": source_name, - "id": article.get("source_id", ""), - "url": article.get("source_url", "") - }, - "category": article.get("category", []), - "language": article.get("language", ""), - "country": article.get("country", []), - "sentiment": sentiment, - "image_url": article.get("image_url", "") - } - processed_articles.append(processed_article) - - return { - "success": True, - "articles": processed_articles, - "totalResults": total_results - } - else: - # API 返回錯誤狀態 - error_msg = data.get("results", {}).get("message", "未知錯誤") - error_code = data.get("results", {}).get("code", "UNKNOWN") - logger.error(f"NewsData.io API 錯誤: {error_code} - {error_msg}") - return { - "success": False, - "error": f"API 錯誤: {error_msg}" - } - - elif response.status == 401: - return { - "success": False, - "error": "NewsData.io API 金鑰無效或已過期" - } - elif response.status == 429: - return { - "success": False, - "error": "NewsData.io API 請求次數已達上限(免費版每日 200 次)" - } - else: - error_text = await response.text() - logger.error(f"NewsData.io HTTP 錯誤 {response.status}: {error_text}") - return { - "success": False, - "error": f"HTTP 錯誤: {response.status}" - } - - except asyncio.TimeoutError: - return { - "success": False, - "error": "NewsData.io 請求超時,請稍後再試" - } - except aiohttp.ClientError as e: - logger.error(f"網絡連接錯誤: {e}") - return { - "success": False, - "error": "網絡連接錯誤,無法獲取新聞" - } - except Exception as e: - logger.error(f"NewsData.io 請求錯誤: {e}", exc_info=True) - return { - "success": False, - "error": str(e) - } - - @staticmethod - async def _generate_summaries(articles: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """為每篇新聞生成一句話簡短摘要(用於工具卡片顯示)""" - try: - from services.ai_service import generate_response_async - - logger.info(f"🤖 開始為 {len(articles)} 則新聞生成摘要") - - # 批量處理:一次請求處理所有新聞 - news_items = [] - for idx, article in enumerate(articles, 1): - title = article.get("title", "") - description = article.get("description", "") - - if not title: - article["summary"] = "無標題" - continue - - # 組合標題和描述 - content = f"{idx}. 標題:{title}" - if description: - content += f"\n 描述:{description[:100]}" - news_items.append(content) - - if not news_items: - return articles - - # 一次性請求 AI 生成所有摘要 - batch_prompt = "\n\n".join(news_items) - - try: - response = await generate_response_async( - messages=[ - { - "role": "system", - "content": "你是新聞摘要助手。請為每則新聞生成一句話摘要(最多30字),用數字編號回應。" - }, - { - "role": "user", - "content": f"請為以下新聞各生成一句話摘要(每則最多30字):\n\n{batch_prompt}" - } - ], - model="gpt-5-nano", - reasoning_effort="low" - ) - - # 解析回應 - lines = response.strip().split('\n') - summaries = [] - for line in lines: - line = line.strip() - # 移除編號前綴 (1. 2. 等) - if line and (line[0].isdigit() or line.startswith('•') or line.startswith('-')): - # 去除編號和標點 - summary = line.lstrip('0123456789.-•) ').strip() - if summary: - summaries.append(summary[:30]) # 限制 30 字 - - # 將摘要分配給文章 - for idx, article in enumerate(articles): - if article.get("title"): - if idx < len(summaries): - article["summary"] = summaries[idx] - logger.info(f"📝 新聞{idx+1} 摘要: {summaries[idx]} ({len(summaries[idx])}字)") - else: - # Fallback - title = article.get("title", "") - article["summary"] = title[:30] - logger.warning(f"⚠️ 新聞{idx+1} 使用 fallback") - - except Exception as e: - logger.error(f"AI 生成摘要失敗: {e}") - # Fallback: 使用標題 - for article in articles: - title = article.get("title", "無標題") - article["summary"] = title[:30] - - return articles - - except Exception as e: - logger.error(f"批量生成摘要失敗: {e}") - # 失敗時使用標題作為 fallback - for article in articles: - if "summary" not in article: - title = article.get("title", "無標題") - article["summary"] = title[:30] - return articles - - @staticmethod - def _format_newsdata_response( - articles: List[Dict[str, Any]], query: str, - country: str, category: str, total_results: int - ) -> str: - """格式化 NewsData.io 回應""" - if not articles: - return "抱歉,找不到相關新聞" - - # 標題 - header = "📰 最新新聞" - if query: - header += f" - 搜尋: {query}" - else: - country_names = { - "tw": "台灣", "us": "美國", "cn": "中國", - "jp": "日本", "kr": "韓國", "hk": "香港", - "sg": "新加坡", "gb": "英國", "de": "德國", "fr": "法國" - } - header += f" - {country_names.get(country, country.upper())}" - - if category and category != "top": - category_names = { - "business": "商業", "technology": "科技", - "health": "健康", "science": "科學", - "sports": "體育", "entertainment": "娛樂", - "world": "國際", "politics": "政治" - } - header += f" - {category_names.get(category, category)}" - - result = f"{header}\n\n" - - # 新聞列表 - for i, article in enumerate(articles, 1): - result += f"📌 {article.get('title', '無標題')}\n" - - # 來源(兼容物件和字串格式) - source = article.get('source', {}) - if isinstance(source, dict): - source_name = source.get('name', '未知來源') - else: - source_name = source or '未知來源' - result += f"🗞️ {source_name}" - - # 分類標籤 - categories = article.get('category', []) - if categories: - category_str = ", ".join(categories[:2]) # 最多顯示 2 個分類 - result += f" | 🏷️ {category_str}" - - # 情緒標籤(過濾付費功能提示) - sentiment = article.get('sentiment', '') - if sentiment and "ONLY AVAILABLE" not in str(sentiment): - sentiment_emoji = { - "positive": "😊 正面", - "neutral": "😐 中立", - "negative": "😟 負面" - }.get(sentiment.lower(), sentiment) - result += f" | {sentiment_emoji}" - - result += "\n" - - # 發布時間 - published_at = article.get('published_at', '') - if published_at: - try: - # NewsData.io 格式: "2025-01-25 12:34:56" - if ' ' in published_at: - dt = datetime.strptime(published_at, '%Y-%m-%d %H:%M:%S') - formatted_date = dt.strftime('%m/%d %H:%M') - result += f"📅 {formatted_date}\n" - else: - result += f"📅 {published_at[:16]}\n" - except Exception as e: - logger.warning(f"日期解析錯誤: {e}") - result += f"📅 {published_at[:16]}\n" - - # 描述 - description = article.get('description', '') - if description: - if len(description) > 150: - description = description[:150] + "..." - result += f"📝 {description}\n" - - # 連結 - url = article.get('url', '') - if url: - result += f"🔗 {url}\n" - - result += "\n" - - # 底部資訊 - result += f"📊 顯示 {len(articles)} 則 / 共 {total_results} 則新聞 | 🕒 {datetime.now().strftime('%Y-%m-%d %H:%M')}" - result += "\n💡 由 NewsData.io 提供" - - return result diff --git a/features/mcp/tools/transportation/__init__.py b/features/mcp/tools/transportation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7c4c0932c2e692b3ac948c9cf8f1af174a360aaf --- /dev/null +++ b/features/mcp/tools/transportation/__init__.py @@ -0,0 +1,15 @@ +from .tdx_bus_arrival import TDXBusArrivalTool +from .tdx_metro import TDXMetroTool +from .tdx_parking import TDXParkingTool +from .tdx_thsr import TDXTHSRTool +from .tdx_train import TDXTrainTool +from .tdx_youbike import TDXBikeTool + +__all__ = [ + "TDXBikeTool", + "TDXBusArrivalTool", + "TDXMetroTool", + "TDXParkingTool", + "TDXTHSRTool", + "TDXTrainTool", +] diff --git a/features/mcp/tools/tdx_base.py b/features/mcp/tools/transportation/tdx_base.py similarity index 59% rename from features/mcp/tools/tdx_base.py rename to features/mcp/tools/transportation/tdx_base.py index faf0336f00c3b67d0b919589661594c82e92345e..46244703cb558ca1c4b201d4d7fa92cce13b3e13 100644 --- a/features/mcp/tools/tdx_base.py +++ b/features/mcp/tools/transportation/tdx_base.py @@ -39,13 +39,16 @@ import asyncio from typing import Dict, Any, Optional from datetime import datetime, timedelta -from .base_tool import ExecutionError +from ..base_tool import ExecutionError from core.database.cache import db_cache logger = logging.getLogger("mcp.tools.tdx") -# TDX API 基礎 URL -TDX_BASE_URL = "https://tdx.transportdata.tw/api/basic" +TDX_BASE_URLS = { + "basic": "https://tdx.transportdata.tw/api/basic", + "advanced": "https://tdx.transportdata.tw/api/advanced", + "historical": "https://tdx.transportdata.tw/api/historical", +} # 從環境變數讀取(需要在 app.py 中先 load_dotenv) TDX_CLIENT_ID = os.getenv("TDX_CLIENT_ID", "") @@ -112,7 +115,8 @@ class TDXBaseAPI: endpoint: str, params: Optional[Dict[str, Any]] = None, cache_ttl: int = 60, - api_version: str = "v2" + api_version: str = "v2", + api_family: str = "basic", ) -> Any: """ 呼叫 TDX API 並處理快取 @@ -129,7 +133,10 @@ class TDXBaseAPI: access_token = await cls.get_access_token() # 組合完整 URL - url = f"{TDX_BASE_URL}/{api_version}/{endpoint}" + base_url = TDX_BASE_URLS.get(api_family, TDX_BASE_URLS["basic"]) + version_segment = f"/{api_version.strip('/')}" if api_version else "" + endpoint_segment = endpoint.lstrip("/") + url = f"{base_url}{version_segment}/{endpoint_segment}" headers = { "Authorization": f"Bearer {access_token}", "Accept": "application/json" @@ -142,7 +149,7 @@ class TDXBaseAPI: params["$format"] = "JSON" # 生成快取鍵 - cache_key = f"tdx:{api_version}:{endpoint}:{json.dumps(params, sort_keys=True)}" + cache_key = f"tdx:{api_family}:{api_version}:{endpoint}:{json.dumps(params, sort_keys=True)}" # 檢查快取 if cache_ttl > 0: @@ -155,73 +162,87 @@ class TDXBaseAPI: logger.info(f"🌐 TDX API 請求: {url}") logger.info(f" 參數: {params}") - try: - async with aiohttp.ClientSession() as session: - async with session.get( - url, - headers=headers, - params=params, - timeout=aiohttp.ClientTimeout(total=30) - ) as resp: - response_text = await resp.text() - - # 記錄完整回應(用於除錯) - logger.info(f"📥 TDX API 回應: HTTP {resp.status}") - if resp.status != 200: - logger.error(f"❌ TDX API 錯誤回應:") - logger.error(f" URL: {url}") - logger.error(f" 參數: {params}") - logger.error(f" 狀態碼: {resp.status}") - logger.error(f" 回應內容: {response_text[:1000]}") - - if resp.status == 304: - logger.info("TDX 資料未變更 (304)") - return cached if cached else [] - - if resp.status == 401: - # Token 過期,清除快取重試 - cls._token_cache = {} - error_msg = f"TDX Token 已過期,請重試\n[API] {url}\n[回應] {response_text[:500]}" - raise ExecutionError(error_msg) - - if resp.status == 404: - error_msg = f"TDX API 找不到資源 (404)\n[API] {url}\n[參數] {params}\n[回應] {response_text[:500]}" - logger.error(error_msg) - raise ExecutionError(error_msg) - - if resp.status != 200: - error_msg = f"TDX API 錯誤: HTTP {resp.status}\n[API] {url}\n[參數] {params}\n[回應] {response_text[:500]}" - logger.error(error_msg) - raise ExecutionError(error_msg) - - try: - data = json.loads(response_text) - except json.JSONDecodeError: - error_msg = f"TDX API 回應非 JSON 格式\n[API] {url}\n[回應] {response_text[:500]}" - raise ExecutionError(error_msg) - - # 記錄回應資料筆數 - data_count = len(data) if isinstance(data, list) else 1 - logger.info(f"✅ TDX API 成功: {endpoint} (共 {data_count} 筆)") - - # 如果回應是空陣列,記錄警告 - if isinstance(data, list) and len(data) == 0: - logger.warning(f"⚠️ TDX API 回應空陣列: {url}") - - # 快取結果 - if cache_ttl > 0 and data: - await db_cache.set_tdx_cache(cache_key, data, ttl=cache_ttl) - - return data - - except asyncio.TimeoutError: - error_msg = f"TDX API 逾時\n[API] {url}\n[參數] {params}" - logger.error(error_msg) - raise ExecutionError(error_msg) - except aiohttp.ClientError as e: - error_msg = f"TDX API 網路錯誤: {e}\n[API] {url}" - logger.error(error_msg) - raise ExecutionError(error_msg) + retry_delays = [0.8, 1.6, 3.2] + last_error: Optional[ExecutionError] = None + + for attempt, delay in enumerate([0.0] + retry_delays, start=1): + if delay > 0: + await asyncio.sleep(delay) + try: + async with aiohttp.ClientSession() as session: + async with session.get( + url, + headers=headers, + params=params, + timeout=aiohttp.ClientTimeout(total=30) + ) as resp: + response_text = await resp.text() + + # 記錄完整回應(用於除錯) + logger.info(f"📥 TDX API 回應: HTTP {resp.status}") + if resp.status != 200: + logger.error(f"❌ TDX API 錯誤回應:") + logger.error(f" URL: {url}") + logger.error(f" 參數: {params}") + logger.error(f" 狀態碼: {resp.status}") + logger.error(f" 回應內容: {response_text[:1000]}") + + if resp.status == 304: + logger.info("TDX 資料未變更 (304)") + return cached if cache_ttl > 0 and 'cached' in locals() and cached else [] + + if resp.status == 401: + cls._token_cache = {} + error_msg = f"TDX Token 已過期,請重試\n[API] {url}\n[回應] {response_text[:500]}" + raise ExecutionError(error_msg) + + if resp.status == 429: + error_msg = f"TDX API 速率限制: HTTP 429\n[API] {url}\n[參數] {params}\n[回應] {response_text[:500]}" + logger.error(error_msg) + last_error = ExecutionError(error_msg) + if attempt <= len(retry_delays): + logger.warning("⏳ TDX API 429,準備第 %s 次重試: %s", attempt, url) + continue + raise last_error + + if resp.status == 404: + error_msg = f"TDX API 找不到資源 (404)\n[API] {url}\n[參數] {params}\n[回應] {response_text[:500]}" + logger.error(error_msg) + raise ExecutionError(error_msg) + + if resp.status != 200: + error_msg = f"TDX API 錯誤: HTTP {resp.status}\n[API] {url}\n[參數] {params}\n[回應] {response_text[:500]}" + logger.error(error_msg) + raise ExecutionError(error_msg) + + try: + data = json.loads(response_text) + except json.JSONDecodeError: + error_msg = f"TDX API 回應非 JSON 格式\n[API] {url}\n[回應] {response_text[:500]}" + raise ExecutionError(error_msg) + + data_count = len(data) if isinstance(data, list) else 1 + logger.info(f"✅ TDX API 成功: {endpoint} (共 {data_count} 筆)") + + if isinstance(data, list) and len(data) == 0: + logger.warning(f"⚠️ TDX API 回應空陣列: {url}") + + if cache_ttl > 0 and data: + await db_cache.set_tdx_cache(cache_key, data, ttl=cache_ttl) + + return data + except asyncio.TimeoutError: + error_msg = f"TDX API 逾時\n[API] {url}\n[參數] {params}" + logger.error(error_msg) + raise ExecutionError(error_msg) + except aiohttp.ClientError as e: + error_msg = f"TDX API 網路錯誤: {e}\n[API] {url}" + logger.error(error_msg) + raise ExecutionError(error_msg) + + if last_error: + raise last_error + raise ExecutionError(f"TDX API 未知錯誤\n[API] {url}") @staticmethod def haversine_distance(lat1: float, lon1: float, lat2: float, lon2: float) -> float: diff --git a/features/mcp/tools/tdx_bus_arrival.py b/features/mcp/tools/transportation/tdx_bus_arrival.py similarity index 86% rename from features/mcp/tools/tdx_bus_arrival.py rename to features/mcp/tools/transportation/tdx_bus_arrival.py index 8b64dc18ba1a82edb9ab1e80383c5c87221caa3b..37def408c2c44ac48f1b32cebe4a9e39adf0d023 100644 --- a/features/mcp/tools/tdx_bus_arrival.py +++ b/features/mcp/tools/transportation/tdx_bus_arrival.py @@ -13,8 +13,9 @@ API 文件: https://tdx.transportdata.tw/api-service/swagger#/CityBus import logging from typing import Dict, Any, List, Optional -from .base_tool import MCPTool, StandardToolSchemas, ExecutionError +from ..base_tool import MCPTool, StandardToolSchemas, ExecutionError from .tdx_base import TDXBaseAPI +from .tdx_location import resolve_location_context, resolve_city_candidates, resolve_city_code from core.database import get_user_env_current logger = logging.getLogger("mcp.tools.tdx.bus") @@ -24,7 +25,13 @@ class TDXBusArrivalTool(MCPTool): """TDX 公車即時到站查詢""" NAME = "tdx_bus_arrival" - DESCRIPTION = "Query real-time bus arrival times. Use for: 1) Known route numbers (e.g., 307, Red 30); 2) Nearby bus stops. Not for route planning (use 'directions' instead)." + DESCRIPTION = """Query real-time bus arrival times. Use for: 1) Known route numbers (e.g., 307, Red 30); 2) Nearby bus stops. Not for route planning (use 'directions' instead). +IMPORTANT Parameter Extraction Rules: +1. "[number] bus" or "[color][number]" -> route_name="[number]" (e.g., 307, 紅30) +2. "nearby bus", "bus stops" -> route_name="" (leave empty, use GPS) +3. Only fill city if explicitly mentioned (e.g. "Taipei 261 bus" -> route_name="261", city="Taipei"). City must be in English. +4. "how long until [number] arrives" -> route_name="[number]" +Always extract route_name if mentioned!""" CATEGORY = "道路運輸" TAGS = ["tdx", "公車", "即時到站", "公共運輸"] KEYWORDS = ["公車", "巴士", "bus", "到站", "即時", "幾分鐘", "公車站", "等公車", "路線號碼"] @@ -63,6 +70,10 @@ class TDXBusArrivalTool(MCPTool): "type": "string", "description": "城市(預設從環境感知自動判斷,支援中文或英文代碼)" }, + "location_query": { + "type": "string", + "description": "精確地址、地標、路口或站區(如「桃園火車站」「中正路口」)。提供時優先解析為座標做附近查詢" + }, "limit": { "type": "integer", "description": "返回結果數量上限", @@ -105,6 +116,7 @@ class TDXBusArrivalTool(MCPTool): route_name = str(arguments.get("route_name", "")).strip() limit = min(int(arguments.get("limit", 5)), 20) + location_query = str(arguments.get("location_query", "")).strip() # 1. 取得用戶位置和城市 user_lat = arguments.get("lat") @@ -140,43 +152,30 @@ class TDXBusArrivalTool(MCPTool): if not route_name and (user_lat is None or user_lon is None): raise ExecutionError("無法取得您的位置,請提供路線名稱或開啟定位權限") - # 2. 判斷城市代碼 - # 優先順序:即時反向地理編碼 > 環境參數 > 經緯度範圍推斷 > 預設值 - city_source = "預設" - final_city = None - - # 2a. 如果有經緯度,嘗試即時反向地理編碼取得精確城市 - if user_lat is not None and user_lon is not None: - logger.debug(f"🗺️ [TDX] 嘗試反向地理編碼: ({user_lat}, {user_lon})") - geocoded_city = await cls._reverse_geocode_city(user_lat, user_lon) - logger.debug(f"🗺️ [TDX] 反向地理編碼結果: {geocoded_city}") - if geocoded_city: - final_city = geocoded_city - city_source = "反向地理編碼" - - # 2b. 如果反向地理編碼失敗,使用環境參數 - if not final_city and city_param: - final_city = city_param - city_source = "環境參數" - logger.debug(f"📍 [TDX] 使用環境參數城市: {city_param}") - - # 2c. 如果還是沒有,使用經緯度範圍推斷 - if not final_city and user_lat is not None and user_lon is not None: - guessed_city = cls._guess_city_from_location(user_lat, user_lon) - logger.debug(f"📐 [TDX] 經緯度推斷結果: {guessed_city}") - if guessed_city: - final_city = guessed_city - city_source = "經緯度推斷" - - # 2d. 轉換為 TDX 城市代碼 - city = cls._resolve_city(final_city or "") - logger.debug(f"🏙️ [TDX] 最終城市: {city} (來源={city_source}, 原始={final_city})") + location_ctx = await resolve_location_context( + lat=user_lat, + lon=user_lon, + location_query=location_query, + city_like=city_param, + allowed_city_codes=cls.VALID_CITIES, + ) + user_lat = location_ctx["lat"] + user_lon = location_ctx["lon"] + geo = location_ctx.get("geo") or {} + city = resolve_city_code(city_param, allowed=cls.VALID_CITIES) or location_ctx["city_code"] or "Taipei" + city_candidates = resolve_city_candidates( + city_like=city_param or city, + geo_city=geo.get("city"), + geo_admin=geo.get("admin"), + allowed_city_codes=cls.VALID_CITIES, + ) + logger.debug(f"🏙️ [TDX] 城市候選: {city_candidates}, 主城市={city}") # 3. 執行查詢 if route_name: return await cls._query_route_arrival(route_name, city, user_lat, user_lon, limit) else: - return await cls._query_nearby_stops(user_lat, user_lon, city, limit) + return await cls._query_nearby_stops(user_lat, user_lon, city_candidates, limit) @classmethod async def _query_route_arrival( @@ -196,52 +195,56 @@ class TDXBusArrivalTool(MCPTool): """ logger.debug(f"🚌 [TDX] 查詢公車到站: 路線={route_name}, 城市={city}") - # 1. 查詢預估到站時間 + # 1-2. 並行查詢預估到站時間與公車即時位置 eta_endpoint = f"Bus/EstimatedTimeOfArrival/City/{city}/{route_name}" eta_params = {"$orderby": "StopSequence", "$format": "JSON"} + realtime_endpoint = f"Bus/RealTimeNearStop/City/{city}/{route_name}" + realtime_params = {"$format": "JSON"} + try: - logger.debug(f"🌐 [TDX] 呼叫 API: {eta_endpoint}") - arrival_data = await TDXBaseAPI.call_api(eta_endpoint, eta_params, cache_ttl=30) - logger.debug(f"✅ [TDX] API 回應: {len(arrival_data) if arrival_data else 0} 筆資料") - if arrival_data and len(arrival_data) > 0: - logger.debug(f"📋 [TDX] 第一筆: {arrival_data[0].get('StopName', {}).get('Zh_tw')}") + logger.debug(f"🌐 [TDX] 開始並行查詢: {route_name}") + eta_task = TDXBaseAPI.call_api(eta_endpoint, eta_params, cache_ttl=30) + realtime_task = TDXBaseAPI.call_api(realtime_endpoint, realtime_params, cache_ttl=15) + + arrival_data, realtime_data = await asyncio.gather(eta_task, realtime_task, return_exceptions=True) + + if isinstance(arrival_data, Exception): + raise arrival_data + if isinstance(realtime_data, Exception): + logger.warning(f"⚠️ 無法取得公車即時位置: {realtime_data}") + realtime_data = [] + + logger.debug(f"✅ [TDX] API 回應: {len(arrival_data) if arrival_data else 0} 筆到站資料") except ExecutionError as e: error_detail = str(e) - logger.debug(f"❌ [TDX] API 錯誤: {error_detail}") if "404" in error_detail: raise ExecutionError(f"找不到路線「{route_name}」,請確認路線名稱與城市") raise ExecutionError(f"查詢路線「{route_name}」失敗: {error_detail}") + except Exception as e: + logger.error(f"❌ [TDX] 查詢異常: {e}") + raise ExecutionError(f"查詢路線「{route_name}」時發生非預期錯誤") if not arrival_data: - logger.debug(f"⚠️ [TDX] 無資料,拋出錯誤") raise ExecutionError(f"路線「{route_name}」目前無班次資訊") - # 2. 查詢公車即時位置(目前在哪站) - realtime_endpoint = f"Bus/RealTimeNearStop/City/{city}/{route_name}" - realtime_params = {"$format": "JSON"} - + # 處理公車即時位置 bus_positions = {} # {direction: [{plate, stop_name, stop_sequence, event_type}]} - try: - realtime_data = await TDXBaseAPI.call_api(realtime_endpoint, realtime_params, cache_ttl=15) - if realtime_data: - for bus in realtime_data: - direction = bus.get("Direction", 0) - plate = bus.get("PlateNumb", "") - stop_name = bus.get("StopName", {}).get("Zh_tw", "") - stop_sequence = bus.get("StopSequence", 0) # 公車目前站序 - event_type = bus.get("A2EventType", 0) # 0=離站, 1=進站 + for bus in (realtime_data or []): + direction = bus.get("Direction", 0) + plate = bus.get("PlateNumb", "") + stop_name = bus.get("StopName", {}).get("Zh_tw", "") + stop_sequence = bus.get("StopSequence", 0) + event_type = bus.get("A2EventType", 0) - if direction not in bus_positions: - bus_positions[direction] = [] - bus_positions[direction].append({ - "plate": plate, - "current_stop": stop_name, - "stop_sequence": stop_sequence, # 新增站序 - "event": "進站中" if event_type == 1 else "已離站" - }) - except Exception as e: - logger.warning(f"⚠️ 無法取得公車即時位置: {e}") + if direction not in bus_positions: + bus_positions[direction] = [] + bus_positions[direction].append({ + "plate": plate, + "current_stop": stop_name, + "stop_sequence": stop_sequence, + "event": "進站中" if event_type == 1 else "已離站" + }) # 3. 取得路線全名 route_obj = arrival_data[0].get("RouteName", {}) @@ -382,7 +385,7 @@ class TDXBusArrivalTool(MCPTool): cls, lat: float, lon: float, - city: str, + cities: List[str], limit: int ) -> Dict[str, Any]: """ @@ -390,14 +393,16 @@ class TDXBusArrivalTool(MCPTool): API: GET /v2/Bus/Stop/City/{City}?$spatialFilter=nearby(lat, lon, distance) """ - endpoint = f"Bus/Stop/City/{city}" - params = { - "$spatialFilter": f"nearby({lat}, {lon}, 500)", - "$top": limit * 3, - "$format": "JSON" - } - - stops = await TDXBaseAPI.call_api(endpoint, params, cache_ttl=1800) + stops = [] + for city in cities: + endpoint = f"Bus/Stop/City/{city}" + params = { + "$spatialFilter": f"nearby({lat}, {lon}, 500)", + "$top": limit * 3, + "$format": "JSON" + } + city_stops = await TDXBaseAPI.call_api(endpoint, params, cache_ttl=1800) + stops.extend(city_stops or []) if not stops: return cls.create_success_response( diff --git a/features/mcp/tools/transportation/tdx_location.py b/features/mcp/tools/transportation/tdx_location.py new file mode 100644 index 0000000000000000000000000000000000000000..0bb1e1c0aa69917b79ca9b0dfeceaf7f23957778 --- /dev/null +++ b/features/mcp/tools/transportation/tdx_location.py @@ -0,0 +1,219 @@ +import logging +from typing import Any, Dict, Iterable, Optional, Tuple + +from ..base_tool import ExecutionError + +logger = logging.getLogger("mcp.tools.tdx.location") + +CITY_CODE_MAP = { + "台北": "Taipei", + "臺北": "Taipei", + "新北": "NewTaipei", + "桃園": "Taoyuan", + "台中": "Taichung", + "臺中": "Taichung", + "台南": "Tainan", + "臺南": "Tainan", + "高雄": "Kaohsiung", + "新竹": "Hsinchu", + "基隆": "Keelung", + "苗栗": "MiaoliCounty", + "彰化": "ChanghuaCounty", + "南投": "NantouCounty", + "雲林": "YunlinCounty", + "嘉義市": "Chiayi", + "嘉義": "Chiayi", + "嘉義縣": "ChiayiCounty", + "屏東": "PingtungCounty", + "宜蘭": "YilanCounty", + "花蓮": "HualienCounty", + "台東": "TaitungCounty", + "臺東": "TaitungCounty", + "金門": "KinmenCounty", + "澎湖": "PenghuCounty", + "連江": "LienchiangCounty", + "馬祖": "LienchiangCounty", +} + +METRO_OPERATOR_MAP = { + "台北": "TRTC", + "臺北": "TRTC", + "新北": "NTMC", + "桃園": "TYMC", + "台中": "TMRT", + "臺中": "TMRT", + "高雄": "KRTC", +} + +CITY_NEIGHBORS = { + "Taipei": ["NewTaipei", "Keelung", "Taoyuan"], + "NewTaipei": ["Taipei", "Keelung", "Taoyuan", "YilanCounty"], + "Taoyuan": ["NewTaipei", "Taipei", "Hsinchu", "HsinchuCounty"], + "Hsinchu": ["HsinchuCounty", "Taoyuan", "MiaoliCounty"], + "HsinchuCounty": ["Hsinchu", "Taoyuan", "MiaoliCounty"], + "MiaoliCounty": ["HsinchuCounty", "Hsinchu", "Taichung"], + "Taichung": ["MiaoliCounty", "ChanghuaCounty", "NantouCounty", "YunlinCounty"], + "ChanghuaCounty": ["Taichung", "NantouCounty", "YunlinCounty"], + "NantouCounty": ["Taichung", "ChanghuaCounty", "YunlinCounty", "ChiayiCounty", "Tainan"], + "YunlinCounty": ["ChanghuaCounty", "NantouCounty", "Chiayi", "ChiayiCounty", "Taichung"], + "Chiayi": ["ChiayiCounty", "YunlinCounty", "Tainan"], + "ChiayiCounty": ["Chiayi", "YunlinCounty", "Tainan", "Kaohsiung"], + "Tainan": ["ChiayiCounty", "Kaohsiung"], + "Kaohsiung": ["Tainan", "PingtungCounty"], + "PingtungCounty": ["Kaohsiung", "TaitungCounty"], + "Keelung": ["Taipei", "NewTaipei"], + "YilanCounty": ["NewTaipei", "HualienCounty"], + "HualienCounty": ["YilanCounty", "TaitungCounty"], + "TaitungCounty": ["HualienCounty", "PingtungCounty"], +} + + +def _normalize_city_text(value: Optional[str]) -> str: + if not value: + return "" + normalized = str(value).strip() + for suffix in ("市", "縣"): + if normalized.endswith(suffix): + normalized = normalized[:-1] + return normalized.strip() + + +def resolve_city_code(city_like: Optional[str], allowed: Optional[Iterable[str]] = None) -> Optional[str]: + normalized = _normalize_city_text(city_like) + if not normalized: + return None + if normalized in CITY_CODE_MAP: + code = CITY_CODE_MAP[normalized] + elif city_like in CITY_CODE_MAP: + code = CITY_CODE_MAP[city_like] + elif city_like and str(city_like).strip() in CITY_CODE_MAP.values(): + code = str(city_like).strip() + else: + code = None + + if code and allowed and code not in set(allowed): + return None + return code + + +def resolve_metro_operator(city_like: Optional[str]) -> Optional[str]: + normalized = _normalize_city_text(city_like) + if not normalized: + return None + return METRO_OPERATOR_MAP.get(normalized) + + +def resolve_city_candidates( + *, + city_like: Optional[str], + geo_city: Optional[str], + geo_admin: Optional[str], + allowed_city_codes: Iterable[str], + include_neighbors: bool = True, +) -> list[str]: + allowed = set(allowed_city_codes) + ordered: list[str] = [] + + def push(code: Optional[str]) -> None: + if code and code in allowed and code not in ordered: + ordered.append(code) + + base = resolve_city_code(city_like, allowed=allowed) or resolve_city_code(geo_city, allowed=allowed) or resolve_city_code(geo_admin, allowed=allowed) + push(base) + + if include_neighbors and base: + for neighbor in CITY_NEIGHBORS.get(base, []): + push(neighbor) + + if not ordered: + ordered.extend(sorted(allowed)) + + return ordered + + +def resolve_metro_operator_candidates( + *, + city_like: Optional[str], + geo_city: Optional[str], + geo_admin: Optional[str], +) -> list[str]: + base = resolve_metro_operator(city_like) or resolve_metro_operator(geo_city) or resolve_metro_operator(geo_admin) + if base == "NTMC" or base == "TRTC": + return ["TRTC", "NTMC"] + if base: + return [base] + return ["TRTC", "NTMC", "TYMC", "TMRT", "KRTC"] + + +async def resolve_coordinates( + *, + lat: Optional[float], + lon: Optional[float], + location_query: Optional[str] = None, +) -> Tuple[Optional[float], Optional[float], Optional[Dict[str, Any]]]: + if lat is not None and lon is not None: + return float(lat), float(lon), None + + query = (location_query or "").strip() + if not query: + return lat, lon, None + + from ..location.geocoding_tool import ForwardGeocodeTool + + result = await ForwardGeocodeTool.execute({"query": query, "limit": 1}) + best_match = result.get("best_match") or {} + if best_match.get("lat") is None or best_match.get("lon") is None: + raise ExecutionError(f"無法解析位置「{query}」") + + logger.info("📍 [TDXLocation] location_query=%s -> (%s, %s)", query, best_match["lat"], best_match["lon"]) + return float(best_match["lat"]), float(best_match["lon"]), best_match + + +async def resolve_geo_context( + *, + lat: Optional[float], + lon: Optional[float], +) -> Dict[str, Any]: + if lat is None or lon is None: + return {} + from ..location.geocode_tool import ReverseGeocodeTool + result = await ReverseGeocodeTool.execute({"lat": float(lat), "lon": float(lon)}) + return { + "city": result.get("city") or "", + "admin": result.get("admin") or "", + "label": result.get("label") or "", + "detailed_address": result.get("detailed_address") or "", + "road": result.get("road") or "", + "house_number": result.get("house_number") or "", + "city_code": resolve_city_code(result.get("city") or result.get("admin")), + "metro_operator": resolve_metro_operator(result.get("city") or result.get("admin")), + } + + +async def resolve_location_context( + *, + lat: Optional[float], + lon: Optional[float], + location_query: Optional[str], + city_like: Optional[str] = None, + allowed_city_codes: Optional[Iterable[str]] = None, +) -> Dict[str, Any]: + resolved_lat, resolved_lon, geocode_match = await resolve_coordinates( + lat=lat, + lon=lon, + location_query=location_query, + ) + geo_ctx = await resolve_geo_context(lat=resolved_lat, lon=resolved_lon) + explicit_city_code = resolve_city_code(city_like, allowed=allowed_city_codes) + city_code = explicit_city_code or resolve_city_code( + geo_ctx.get("city") or geo_ctx.get("admin"), + allowed=allowed_city_codes, + ) + + return { + "lat": resolved_lat, + "lon": resolved_lon, + "city_code": city_code, + "geo": geo_ctx, + "geocode_match": geocode_match or {}, + } diff --git a/features/mcp/tools/tdx_metro.py b/features/mcp/tools/transportation/tdx_metro.py similarity index 84% rename from features/mcp/tools/tdx_metro.py rename to features/mcp/tools/transportation/tdx_metro.py index 80d4c04e72f96114d3573960803d2ab73a05c436..c4eaa713051a818d04b742f866fb98e07e2f52f3 100644 --- a/features/mcp/tools/tdx_metro.py +++ b/features/mcp/tools/transportation/tdx_metro.py @@ -6,8 +6,9 @@ TDX 捷運即時資訊工具 import logging from typing import Dict, Any, List, Optional -from .base_tool import MCPTool, StandardToolSchemas, ExecutionError +from ..base_tool import MCPTool, StandardToolSchemas, ExecutionError from .tdx_base import TDXBaseAPI +from .tdx_location import resolve_location_context, resolve_metro_operator_candidates from core.database import get_user_env_current logger = logging.getLogger("mcp.tools.tdx.metro") @@ -53,6 +54,10 @@ class TDXMetroTool(MCPTool): "type": "string", "description": "路線名稱(如「板南線」「淡水信義線」)" }, + "location_query": { + "type": "string", + "description": "精確地址、地標或站區(如「桃園火車站」「台北101」)。提供時優先解析為座標做最近站查詢" + }, "lat": { "type": "number", "description": "用戶緯度(由系統自動注入)" @@ -91,6 +96,7 @@ class TDXMetroTool(MCPTool): station_name = arguments.get("station_name", "").strip() metro_system = arguments.get("metro_system") line_filter = arguments.get("line") + location_query = arguments.get("location_query", "").strip() # 1. 取得用戶位置和城市(優先從 arguments 讀取,由 coordinator 注入) user_lat = arguments.get("lat") @@ -123,40 +129,30 @@ class TDXMetroTool(MCPTool): logger.error(f"🚇 [Metro] 位置資訊缺失: lat={user_lat}, lon={user_lon}, station_name={station_name}") raise ExecutionError("🚇 想幫您找附近的捷運站,但目前沒有您的位置資訊。請在 App 中開啟定位,或告訴我您想查詢哪個車站") - # 2. 自動判斷捷運系統(優先使用反向地理編碼) - if not metro_system: - final_city = None - city_source = "預設" - - # 優先:即時反向地理編碼 - if user_lat and user_lon: - geocoded = await cls._reverse_geocode_city(user_lat, user_lon) - if geocoded: - final_city = geocoded - city_source = "反向地理編碼" - - # 其次:環境參數 - if not final_city and user_city: - final_city = user_city - city_source = "環境參數" - - # 最後:經緯度範圍推斷 - if not final_city and user_lat and user_lon: - guessed = cls._guess_city_from_location(user_lat, user_lon) - if guessed: - final_city = guessed - city_source = "經緯度推斷" - - metro_system = cls._detect_metro_system(final_city) if final_city else "TRTC" - logger.info(f"🚇 最終使用捷運系統: {metro_system} (來源={city_source})") + location_ctx = await resolve_location_context( + lat=user_lat, + lon=user_lon, + location_query=location_query, + city_like=user_city, + ) + user_lat = location_ctx["lat"] + user_lon = location_ctx["lon"] + geo = location_ctx.get("geo") or {} + operator_candidates = [metro_system] if metro_system else resolve_metro_operator_candidates( + city_like=user_city, + geo_city=geo.get("city"), + geo_admin=geo.get("admin"), + ) + metro_system = operator_candidates[0] if operator_candidates else "TRTC" + logger.info(f"🚇 最終使用捷運系統候選: {operator_candidates}") # 3. 查詢邏輯 if station_name: result = await cls._query_station_arrival(station_name, metro_system, line_filter) else: if not user_lat or not user_lon: - raise ExecutionError("查詢最近捷運站需要定位權限") - result = await cls._query_nearest_station(user_lat, user_lon, metro_system) + raise ExecutionError("🚇 想幫您找附近的捷運站,但目前沒有您的位置資訊。請開啟定位,或直接提供地址/地標") + result = await cls._query_nearest_station(user_lat, user_lon, operator_candidates) return result @@ -245,22 +241,33 @@ class TDXMetroTool(MCPTool): ) @classmethod - async def _query_nearest_station(cls, lat: float, lon: float, metro_system: str) -> Dict[str, Any]: + async def _query_nearest_station(cls, lat: float, lon: float, metro_systems: List[str]) -> Dict[str, Any]: """查詢最近的捷運站""" - # 1. 取得所有車站 (v2 API) - # GET /v2/Rail/Metro/Station/{Operator} - station_endpoint = f"Rail/Metro/Station/{metro_system}" - station_params = { - "$format": "JSON" - } - - stations = await TDXBaseAPI.call_api(station_endpoint, station_params, cache_ttl=3600) - - if not stations: + all_stations = [] + # 1. 取得所有捷運系統的車站資訊(並行查詢優化) + tasks = [] + for ms in metro_systems: + endpoint = f"Rail/Metro/Station/{ms}" + params = {"$format": "JSON"} + tasks.append(TDXBaseAPI.call_api(endpoint, params, cache_ttl=3600)) + + logger.info(f"🚇 [Metro] 開始並行查詢 {len(tasks)} 個捷運系統的車站資料") + responses = await asyncio.gather(*tasks, return_exceptions=True) + + for i, stations in enumerate(responses): + if isinstance(stations, Exception): + logger.warning(f"⚠️ 捷運系統 {metro_systems[i]} 查詢失敗: {stations}") + continue + if stations: + for station in stations: + station["_operator"] = metro_systems[i] + all_stations.extend(stations) + + if not all_stations: raise ExecutionError("無法取得捷運站資訊") # 2. 計算距離 - for station in stations: + for station in all_stations: pos = station.get("StationPosition", {}) if pos.get("PositionLat") and pos.get("PositionLon"): station["distance_m"] = TDXBaseAPI.haversine_distance( @@ -268,7 +275,7 @@ class TDXMetroTool(MCPTool): pos["PositionLat"], pos["PositionLon"] ) - stations_with_distance = [s for s in stations if "distance_m" in s] + stations_with_distance = [s for s in all_stations if "distance_m" in s] if not stations_with_distance: raise ExecutionError("附近沒有捷運站資訊") @@ -288,7 +295,8 @@ class TDXMetroTool(MCPTool): "distance_m": int(distance), "walking_time_min": walking_time, "station_uid": station.get("StationUID"), - "address": station.get("StationAddress", "") + "address": station.get("StationAddress", ""), + "operator": station.get("_operator"), }) content = cls._format_nearest_result(results) diff --git a/features/mcp/tools/tdx_parking.py b/features/mcp/tools/transportation/tdx_parking.py similarity index 73% rename from features/mcp/tools/tdx_parking.py rename to features/mcp/tools/transportation/tdx_parking.py index efae923629000b701462477c49d453406740f2fc..056aa696a5ef11a82280f51e019fc70f175b17c4 100644 --- a/features/mcp/tools/tdx_parking.py +++ b/features/mcp/tools/transportation/tdx_parking.py @@ -6,8 +6,9 @@ TDX 停車場與充電站查詢工具 import logging from typing import Dict, Any, List, Optional -from .base_tool import MCPTool, StandardToolSchemas, ExecutionError +from ..base_tool import MCPTool, StandardToolSchemas, ExecutionError from .tdx_base import TDXBaseAPI +from .tdx_location import resolve_location_context, resolve_city_code from core.database import get_user_env_current logger = logging.getLogger("mcp.tools.tdx.parking") @@ -39,6 +40,10 @@ class TDXParkingTool(MCPTool): "description": "城市代碼(如「Taipei」「Kaohsiung」)", "enum": ["Taipei", "NewTaipei", "Taoyuan", "Taichung", "Tainan", "Kaohsiung"] }, + "location_query": { + "type": "string", + "description": "精確地址、地標或路口(如「台北車站」「中正路100號」)。提供時優先解析為座標" + }, "parking_type": { "type": "string", "description": "停車場類型", @@ -96,6 +101,7 @@ class TDXParkingTool(MCPTool): parking_name = arguments.get("parking_name", "").strip() city = arguments.get("city") + location_query = arguments.get("location_query", "").strip() parking_type = arguments.get("parking_type") charge_station_only = arguments.get("charge_station", False) radius_m = min(int(arguments.get("radius_m", 1000)), 5000) @@ -132,139 +138,72 @@ class TDXParkingTool(MCPTool): logger.error(f"🅿️ [Parking] 位置資訊缺失: lat={user_lat}, lon={user_lon}, parking_name={parking_name}") raise ExecutionError("🅿️ 想幫您找附近的停車場,但目前沒有您的位置資訊。請在 App 中開啟定位,或告訴我您想查詢哪個停車場") - # 2. 自動判斷城市(優先使用反向地理編碼) - if not city: - final_city = None - city_source = "預設" - - # 優先:即時反向地理編碼 - if user_lat and user_lon: - geocoded = await cls._reverse_geocode_city(user_lat, user_lon) - if geocoded: - final_city = geocoded - city_source = "反向地理編碼" - - # 其次:環境參數 - if not final_city and user_city: - final_city = user_city - city_source = "環境參數" - - # 最後:經緯度範圍推斷 - if not final_city and user_lat and user_lon: - guessed = cls._guess_city_from_location(user_lat, user_lon) - if guessed: - final_city = guessed - city_source = "經緯度推斷" - - city = cls._map_city_name(final_city) if final_city else "Taipei" - logger.info(f"🏙️ 最終使用城市代碼: {city} (來源={city_source})") + location_ctx = await resolve_location_context( + lat=user_lat, + lon=user_lon, + location_query=location_query, + city_like=city or user_city, + allowed_city_codes={"Taipei", "NewTaipei", "Taoyuan", "Taichung", "Tainan", "Kaohsiung"}, + ) + user_lat = location_ctx["lat"] + user_lon = location_ctx["lon"] + city = resolve_city_code(city or user_city, allowed={"Taipei", "NewTaipei", "Taoyuan", "Taichung", "Tainan", "Kaohsiung"}) or location_ctx["city_code"] or "Taipei" + logger.info(f"🏙️ 停車主城市: {city}") # 3. 查詢分支 if charge_station_only: # 查詢充電站 if not user_lat or not user_lon: raise ExecutionError("查詢充電站需要定位權限") - result = await cls._query_charge_stations(user_lat, user_lon, city, radius_m, limit) + result = await cls._query_charge_stations(user_lat, user_lon, radius_m, limit) elif parking_name: - # 查詢特定停車場 - result = await cls._query_parking_availability(parking_name, city) + if not user_lat or not user_lon: + raise ExecutionError("🅿️ 若要查特定停車場,請同時提供地址/地標或開啟定位,避免同名停車場誤判") + result = await cls._query_named_parking_nearby(parking_name, user_lat, user_lon, radius_m, limit) else: # 查詢附近停車場 if not user_lat or not user_lon: raise ExecutionError("查詢附近停車場需要定位權限") - result = await cls._query_nearby_parkings(user_lat, user_lon, city, parking_type, radius_m, limit) + result = await cls._query_nearby_parkings(user_lat, user_lon, parking_type, radius_m, limit) return result @classmethod - async def _query_parking_availability(cls, parking_name: str, city: str) -> Dict[str, Any]: - """查詢特定停車場即時資訊""" - # 1. 查詢停車場基本資訊 (v2 API) - # GET /v2/Parking/OffStreet/CarPark/City/{City} - parking_endpoint = f"Parking/OffStreet/CarPark/City/{city}" - parking_params = { - "$filter": f"contains(CarParkName/Zh_tw, '{parking_name}')", - "$format": "JSON", - "$top": 5 - } - - parkings = await TDXBaseAPI.call_api(parking_endpoint, parking_params, cache_ttl=3600) - - if not parkings: - raise ExecutionError(f"找不到停車場「{parking_name}」") - - # 2. 取得第一個結果 - parking = parkings[0] - parking_id = parking.get("CarParkID") - full_parking_name = parking.get("CarParkName", {}).get("Zh_tw", parking_name) - - # 3. 查詢即時剩餘車位 (v2 API) - # GET /v2/Parking/OffStreet/ParkingAvailability/City/{City} - avail_endpoint = f"Parking/OffStreet/ParkingAvailability/City/{city}" - avail_params = { - "$filter": f"CarParkID eq '{parking_id}'", - "$format": "JSON" - } - - availability = await TDXBaseAPI.call_api(avail_endpoint, avail_params, cache_ttl=60) - - # 4. 組合資訊 - total_spaces = parking.get("TotalSpaces", 0) - available_spaces = 0 - - if availability and len(availability) > 0: - avail = availability[0] - available_spaces = avail.get("AvailableSpaces", 0) - - # 收費資訊 - fee_info = cls._format_fee_info(parking.get("FareDescription", {})) - - # 充電站資訊 - has_charge = parking.get("HasChargingPoint", False) - - result = { - "parking_name": full_parking_name, - "available_spaces": available_spaces, - "total_spaces": total_spaces, - "charge_station": has_charge, - "fee_info": fee_info, - "address": parking.get("Address", ""), - "service_time": parking.get("ServiceTime", "") - } - - # 5. 格式化結果 + async def _query_named_parking_nearby(cls, parking_name: str, lat: float, lon: float, radius_m: int, limit: int) -> Dict[str, Any]: + nearby = await cls._query_nearby_parkings(lat, lon, None, radius_m, max(limit * 3, 10)) + parkings = nearby.get("parkings", []) + normalized = parking_name.strip() + matched = [p for p in parkings if normalized in p.get("parking_name", "")] + if not matched: + raise ExecutionError(f"附近找不到停車場「{parking_name}」,請提供更精確的地標或放大範圍") + + best = matched[0] content = ( - f"🅿️ {result['parking_name']}\n" - f"剩餘車位: {result['available_spaces']} / {result['total_spaces']}\n" - f"收費: {result['fee_info']}\n" - f"充電站: {'有' if result['charge_station'] else '無'}\n" - f"地址: {result['address']}\n" - ) - - return cls.create_success_response( - content=content, - data={"parking": result} + f"🅿️ {best['parking_name']}\n" + f"剩餘車位: {best['available_spaces']} / {best['total_spaces']}\n" + f"收費: {best['fee_info']}\n" + f"充電站: {'有' if best['charge_station'] else '無'}\n" + f"步行: {best['walking_time_min']} 分鐘 ({best['distance_m']}m)\n" ) + return cls.create_success_response(content=content, data={"parking": best}) @classmethod - async def _query_nearby_parkings(cls, lat: float, lon: float, city: str, + async def _query_nearby_parkings(cls, lat: float, lon: float, parking_type: Optional[str], radius_m: int, limit: int) -> Dict[str, Any]: """查詢附近停車場""" - # 1. 查詢附近停車場 (v2 API) - # GET /v2/Parking/OffStreet/CarPark/City/{City} - # GET /v2/Parking/OnStreet/ParkingSpace/City/{City} + # 官方 Nearby 端點在 advanced/v1,不是 basic/v2/City 路徑 if parking_type == "路邊": - parking_endpoint = f"Parking/OnStreet/ParkingSpace/City/{city}" + parking_endpoint = "Parking/OnStreet/ParkingSpot/NearBy" else: - parking_endpoint = f"Parking/OffStreet/CarPark/City/{city}" - + parking_endpoint = "Parking/OffStreet/CarPark/NearBy" + parking_params = { - "$spatialFilter": f"nearby({lat}, {lon}, {radius_m})", + "$spatialFilter": f"nearby({lat}, {lon}, {min(radius_m, 1000)})", "$format": "JSON", "$top": limit * 2 } - - parkings = await TDXBaseAPI.call_api(parking_endpoint, parking_params, cache_ttl=3600) + + parkings = await TDXBaseAPI.call_api(parking_endpoint, parking_params, cache_ttl=3600, api_version="v1", api_family="advanced") if not parkings: return cls.create_success_response( @@ -285,21 +224,9 @@ class TDXParkingTool(MCPTool): parkings.sort(key=lambda x: x["distance_m"]) parkings = parkings[:limit] - # 3. 批次查詢即時車位(僅路外停車場)(v2 API) - # GET /v2/Parking/OffStreet/ParkingAvailability/City/{City} + # 3. 路外 nearby 端點已含基礎資訊;即時剩餘車位目前若缺則保守留空,不用錯誤 city path 補查 if parking_type != "路邊": - parking_ids = [p.get("CarParkID") for p in parkings] - - avail_endpoint = f"Parking/OffStreet/ParkingAvailability/City/{city}" - avail_params = { - "$filter": " or ".join([f"CarParkID eq '{pid}'" for pid in parking_ids if pid]), - "$format": "JSON" - } - - availability = await TDXBaseAPI.call_api(avail_endpoint, avail_params, cache_ttl=60) - - # 建立 ID -> 可用性 映射 - avail_map = {a.get("CarParkID"): a for a in availability} + avail_map = {} else: avail_map = {} @@ -335,18 +262,18 @@ class TDXParkingTool(MCPTool): ) @classmethod - async def _query_charge_stations(cls, lat: float, lon: float, city: str, + async def _query_charge_stations(cls, lat: float, lon: float, radius_m: int, limit: int) -> Dict[str, Any]: """查詢附近充電站""" - # 查詢有充電站的停車場 (v2 API) - # GET /v2/Parking/OffStreet/CarPark/City/{City} - parking_endpoint = f"Parking/OffStreet/CarPark/City/{city}" + # 官方 Nearby 端點在 advanced/v1 + parking_endpoint = "Parking/OffStreet/CarPark/NearBy" parking_params = { + "$spatialFilter": f"nearby({lat}, {lon}, {min(radius_m, 1000)})", "$filter": "HasChargingPoint eq true", "$format": "JSON" } - parkings = await TDXBaseAPI.call_api(parking_endpoint, parking_params, cache_ttl=3600) + parkings = await TDXBaseAPI.call_api(parking_endpoint, parking_params, cache_ttl=3600, api_version="v1", api_family="advanced") if not parkings: return cls.create_success_response( diff --git a/features/mcp/tools/tdx_thsr.py b/features/mcp/tools/transportation/tdx_thsr.py similarity index 99% rename from features/mcp/tools/tdx_thsr.py rename to features/mcp/tools/transportation/tdx_thsr.py index 5aa273739e55f714195bf05c4039d25646851ac6..32b1d75b745bb351976cbd2aa7443ad1cca39b90 100644 --- a/features/mcp/tools/tdx_thsr.py +++ b/features/mcp/tools/transportation/tdx_thsr.py @@ -7,7 +7,7 @@ import logging from typing import Dict, Any, List, Optional from datetime import datetime, timedelta -from .base_tool import MCPTool, StandardToolSchemas, ExecutionError +from ..base_tool import MCPTool, StandardToolSchemas, ExecutionError from .tdx_base import TDXBaseAPI from core.database import get_user_env_current diff --git a/features/mcp/tools/tdx_train.py b/features/mcp/tools/transportation/tdx_train.py similarity index 97% rename from features/mcp/tools/tdx_train.py rename to features/mcp/tools/transportation/tdx_train.py index be894be4b39c76a00dfffa802b31a60e614e11b4..6d1d4d725b2a6f9d97d5ad1168c94fbe695761b5 100644 --- a/features/mcp/tools/tdx_train.py +++ b/features/mcp/tools/transportation/tdx_train.py @@ -7,7 +7,7 @@ import logging from typing import Dict, Any, List, Optional from datetime import datetime, timedelta -from .base_tool import MCPTool, StandardToolSchemas, ExecutionError +from ..base_tool import MCPTool, StandardToolSchemas, ExecutionError from .tdx_base import TDXBaseAPI from core.database import get_user_env_current @@ -18,7 +18,14 @@ class TDXTrainTool(MCPTool): """TDX 台鐵時刻表查詢""" NAME = "tdx_train" - DESCRIPTION = "Query Taiwan Railway (TRA) train schedules. Parameter extraction: 'from A to B' → origin_station=A, destination_station=B; 'to B' → destination_station=B (origin from GPS); 'train 123' → train_no=123." + DESCRIPTION = """Query Taiwan Railway (TRA) train schedules. +IMPORTANT Parameter Extraction Rules: +1. "from A to B" or "A to B" -> origin_station="A", destination_station="B" +2. "to B" or "going to B" -> destination_station="B" (origin determined by GPS) +3. "train 123" or "Tze-Chiang 123" -> train_no="123" +4. Includes time -> extract departure_time (HH:MM format) +Example: "Taichung to Kaohsiung" -> origin_station="台中", destination_station="高雄". +Never return empty parameters for station names!""" CATEGORY = "軌道運輸" TAGS = ["tdx", "台鐵", "TRA", "火車", "時刻表"] KEYWORDS = ["台鐵", "臺鐵", "火車", "TRA", "列車", "時刻", "自強號", "莒光號", "區間車"] diff --git a/features/mcp/tools/tdx_youbike.py b/features/mcp/tools/transportation/tdx_youbike.py similarity index 81% rename from features/mcp/tools/tdx_youbike.py rename to features/mcp/tools/transportation/tdx_youbike.py index 68aec080d185cfe944a9296e2c6cdb66c68f086e..19dcc6be8e1eaaa79d689497c7672ac4a817c544 100644 --- a/features/mcp/tools/tdx_youbike.py +++ b/features/mcp/tools/transportation/tdx_youbike.py @@ -6,8 +6,9 @@ TDX YouBike 即時查詢工具 import logging from typing import Dict, Any, List, Optional -from .base_tool import MCPTool, StandardToolSchemas, ExecutionError +from ..base_tool import MCPTool, StandardToolSchemas, ExecutionError from .tdx_base import TDXBaseAPI +from .tdx_location import resolve_location_context, resolve_city_code, resolve_city_candidates from core.database import get_user_env_current logger = logging.getLogger("mcp.tools.tdx.bike") @@ -17,7 +18,12 @@ class TDXBikeTool(MCPTool): """TDX YouBike 即時查詢""" NAME = "tdx_youbike" - DESCRIPTION = "Query nearby YouBike stations with real-time available bikes and parking spaces (supports YouBike 1.0/2.0)" + DESCRIPTION = """Query nearby YouBike stations with real-time available bikes and parking spaces (supports YouBike 1.0/2.0). +IMPORTANT Parameter Extraction Rules: +1. "nearby YouBike" -> leave station_name and city empty (use GPS) +2. "[station] YouBike" or "any bikes at [station]" -> station_name="[station]" (leave city empty) +3. "Taipei YouBike" -> city="Taipei" +Only fill city if explicitly mentioned! Station names can be in Chinese.""" CATEGORY = "微型運具" TAGS = ["tdx", "youbike", "ubike", "共享單車", "微笑單車"] KEYWORDS = [ @@ -75,6 +81,10 @@ class TDXBikeTool(MCPTool): "description": "城市名稱(支援中文如「台北」「桃園」或英文如「Taipei」「Taoyuan」)", "enum": unique_cities }, + "location_query": { + "type": "string", + "description": "精確地址、路口或地標(如「桃園火車站」「中正路100號」)。提供時優先解析為座標做附近查詢" + }, "radius_m": { "type": "integer", "description": "搜尋半徑(公尺)", @@ -130,10 +140,11 @@ class TDXBikeTool(MCPTool): station_name = safe_str(arguments.get("station_name")) city = arguments.get("city") + location_query = safe_str(arguments.get("location_query")) # 如果 city 是中文,轉換為英文 if city: - city = cls._map_city_name(city) + city = resolve_city_code(city, allowed=cls.CITY_MAP.values()) or cls._map_city_name(city) radius_m = min(int(arguments.get("radius_m", 500)), 2000) limit = min(int(arguments.get("limit", 5)), 20) @@ -164,50 +175,41 @@ class TDXBikeTool(MCPTool): except Exception as e: logger.warning(f"⚠️ [YouBike] 資料庫查詢異常: {e}") + location_ctx = await resolve_location_context( + lat=user_lat, + lon=user_lon, + location_query=location_query, + city_like=city or user_city, + allowed_city_codes=set(cls.CITY_MAP.values()), + ) + user_lat = location_ctx["lat"] + user_lon = location_ctx["lon"] + geo = location_ctx.get("geo") or {} + city = city or location_ctx["city_code"] + city_candidates = resolve_city_candidates( + city_like=city or user_city, + geo_city=geo.get("city"), + geo_admin=geo.get("admin"), + allowed_city_codes=set(cls.CITY_MAP.values()), + ) + # 檢查必要條件 if not station_name and (user_lat is None or user_lon is None): - logger.error(f"🚲 [YouBike] 位置資訊缺失: lat={user_lat}, lon={user_lon}, station_name={station_name}") - raise ExecutionError("🚲 想幫您找附近的 YouBike,但目前沒有您的位置資訊。請在 App 中開啟定位,或告訴我您想查詢哪個站點(例如:市政府 YouBike)") - - # 2. 自動判斷城市(優先使用反向地理編碼) + logger.error(f"🚲 [YouBike] 位置資訊缺失: lat={user_lat}, lon={user_lon}, station_name={station_name}, location_query={location_query}") + raise ExecutionError("🚲 想幫您找附近的 YouBike,但目前沒有您的位置資訊。請在 App 中開啟定位,或直接提供地址/地標(例如:桃園火車站、台北101)") + if not city: - final_city = None - city_source = "預設" - - # 優先:即時反向地理編碼 - if user_lat and user_lon: - geocoded = await cls._reverse_geocode_city(user_lat, user_lon) - if geocoded: - final_city = geocoded - city_source = "反向地理編碼" - - # 其次:環境參數 - if not final_city and user_city: - final_city = user_city - city_source = "環境參數" - - # 最後:經緯度範圍推斷 - if not final_city and user_lat and user_lon: - guessed = cls._guess_city_from_location(user_lat, user_lon) - if guessed: - final_city = guessed - city_source = "經緯度推斷" - - # 檢查城市是否支援 YouBike + final_city = (location_ctx.get("geo") or {}).get("city") or (location_ctx.get("geo") or {}).get("admin") or user_city if final_city: - city = cls._map_city_name(final_city) - if city == "Taipei" and final_city not in cls.CITY_MAP: - # 城市不在支援列表中,提供友善錯誤訊息 - nearest_city = cls._find_nearest_supported_city(user_lat, user_lon) - raise ExecutionError( - f"🚲 很抱歉,{final_city}目前沒有 YouBike 服務。\n\n" - f"最近有 YouBike 的城市是:{nearest_city}\n" - f"支援 YouBike 的城市:台北、新北、桃園、新竹、台中、台南、高雄" - ) - else: - city = "Taipei" - - logger.info(f"🏙️ 最終使用城市代碼: {city} (來源={city_source})") + nearest_city = cls._find_nearest_supported_city(user_lat, user_lon) if user_lat and user_lon else "台北" + raise ExecutionError( + f"🚲 很抱歉,{final_city}目前無法對應到支援的 YouBike 城市。\n\n" + f"最近有 YouBike 的城市可能是:{nearest_city}\n" + f"支援 YouBike 的城市:台北、新北、桃園、新竹、台中、台南、高雄" + ) + city = "Taipei" + + logger.info(f"🏙️ 最終使用城市代碼: {city}") # 3. 查詢分支 if station_name: @@ -216,7 +218,7 @@ class TDXBikeTool(MCPTool): if not user_lat or not user_lon: logger.error(f"🚲 [YouBike] 查詢附近站點但位置缺失: lat={user_lat}, lon={user_lon}") raise ExecutionError("🚲 想幫您找附近的 YouBike,但目前沒有您的位置資訊。請在 App 中開啟定位功能") - result = await cls._query_nearby_stations(user_lat, user_lon, city, radius_m, limit) + result = await cls._query_nearby_stations(user_lat, user_lon, city_candidates, radius_m, limit) return result @@ -300,19 +302,21 @@ class TDXBikeTool(MCPTool): ) @classmethod - async def _query_nearby_stations(cls, lat: float, lon: float, city: str, + async def _query_nearby_stations(cls, lat: float, lon: float, cities: List[str], radius_m: int, limit: int) -> Dict[str, Any]: """查詢附近站點""" - # 1. 查詢附近站點(使用空間過濾)(v2 API) - # GET /v2/Bike/Station/City/{City} - station_endpoint = f"Bike/Station/City/{city}" - station_params = { - "$spatialFilter": f"nearby({lat}, {lon}, {radius_m})", - "$format": "JSON", - "$top": limit * 2 - } - - stations = await TDXBaseAPI.call_api(station_endpoint, station_params, cache_ttl=1800) + stations = [] + for city in cities: + station_endpoint = f"Bike/Station/City/{city}" + station_params = { + "$spatialFilter": f"nearby({lat}, {lon}, {radius_m})", + "$format": "JSON", + "$top": limit * 2 + } + city_stations = await TDXBaseAPI.call_api(station_endpoint, station_params, cache_ttl=1800) + for station in city_stations or []: + station["_city_code"] = city + stations.extend(city_stations or []) if not stations: return cls.create_success_response( @@ -335,18 +339,19 @@ class TDXBikeTool(MCPTool): # 3. 批次查詢即時資訊 (v2 API) # GET /v2/Bike/Availability/City/{City} - station_uids = [s.get("StationUID") for s in stations] - - avail_endpoint = f"Bike/Availability/City/{city}" - avail_params = { - "$filter": " or ".join([f"StationUID eq '{uid}'" for uid in station_uids]), - "$format": "JSON" - } - - availability = await TDXBaseAPI.call_api(avail_endpoint, avail_params, cache_ttl=30) - - # 建立 UID -> 可用性 映射 - avail_map = {a.get("StationUID"): a for a in availability} + avail_map = {} + grouped_uids: Dict[str, List[str]] = {} + for station in stations: + grouped_uids.setdefault(station.get("_city_code"), []).append(station.get("StationUID")) + + for city, station_uids in grouped_uids.items(): + avail_endpoint = f"Bike/Availability/City/{city}" + avail_params = { + "$filter": " or ".join([f"StationUID eq '{uid}'" for uid in station_uids if uid]), + "$format": "JSON" + } + availability = await TDXBaseAPI.call_api(avail_endpoint, avail_params, cache_ttl=30) + avail_map.update({a.get("StationUID"): a for a in availability}) # 4. 組合結果 results = [] diff --git a/features/mcp/tools/utility/__init__.py b/features/mcp/tools/utility/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f0ca75315e5bc804d1f335eb1dba7d351a1b1b23 --- /dev/null +++ b/features/mcp/tools/utility/__init__.py @@ -0,0 +1,9 @@ +from .exchange_tool import ExchangeTool +from .healthkit_tool import HealthKitTool +from .news_tool import NewsTool + +__all__ = [ + "ExchangeTool", + "HealthKitTool", + "NewsTool", +] diff --git a/features/mcp/tools/exchange_tool.py b/features/mcp/tools/utility/exchange_tool.py similarity index 94% rename from features/mcp/tools/exchange_tool.py rename to features/mcp/tools/utility/exchange_tool.py index db5c22c44fcb96cc0680527d2f6e854452b4186c..225473a826577e8a68866f5415e2fd3d1b4a0a91 100644 --- a/features/mcp/tools/exchange_tool.py +++ b/features/mcp/tools/utility/exchange_tool.py @@ -11,7 +11,7 @@ import asyncio from datetime import datetime from typing import Dict, Any, Optional from dotenv import load_dotenv -from .base_tool import MCPTool, ValidationError, ExecutionError, StandardToolSchemas +from ..base_tool import MCPTool, ValidationError, ExecutionError, StandardToolSchemas # 載入環境變數 load_dotenv() @@ -28,7 +28,13 @@ class ExchangeTool(MCPTool): """匯率查詢 MCP 工具""" NAME = "exchange_query" - DESCRIPTION = "Query real-time exchange rates between major currencies" + DESCRIPTION = """Query real-time exchange rates between major currencies. +IMPORTANT Parameter Extraction Rules: +1. "A to B", "A into B", "A exchange B" -> from_currency="A", to_currency="B" +2. "[amount] A to B" -> from_currency="A", to_currency="B", amount=[number] +3. Must use ISO 4217 (e.g., USD, TWD, JPY, EUR, GBP, CNY, HKD, KRW). +Example: "100 USD to TWD" -> from_currency="USD", to_currency="TWD", amount=100. +Always extract currency codes from the message!""" CATEGORY = "生活資訊" TAGS = ["exchange", "匯率", "貨幣"] KEYWORDS = ["匯率", "美元", "台幣", "exchange", "USD", "TWD", "貨幣", "換算"] @@ -41,7 +47,7 @@ class ExchangeTool(MCPTool): @classmethod def get_input_schema(cls) -> Dict[str, Any]: """獲取輸入參數模式""" - return StandardToolSchemas.create_input_schema({ + schema = StandardToolSchemas.create_input_schema({ "from_currency": { "type": "string", "description": "源貨幣代碼 (如 USD, EUR, TWD)", @@ -66,6 +72,7 @@ class ExchangeTool(MCPTool): "default": True } }, ["from_currency", "to_currency"]) + return schema @classmethod def get_output_schema(cls) -> Dict[str, Any]: diff --git a/features/mcp/tools/healthkit_tool.py b/features/mcp/tools/utility/healthkit_tool.py similarity index 99% rename from features/mcp/tools/healthkit_tool.py rename to features/mcp/tools/utility/healthkit_tool.py index 1dce0b7a5969439ce699f0036371064966599554..17edf2748b841170a305ccc66dac6cdda823ce2a 100644 --- a/features/mcp/tools/healthkit_tool.py +++ b/features/mcp/tools/utility/healthkit_tool.py @@ -8,7 +8,7 @@ import logging from typing import Dict, Any, Optional, List from datetime import datetime, timedelta -from .base_tool import MCPTool, ValidationError, ExecutionError +from ..base_tool import MCPTool, ValidationError, ExecutionError from google.cloud import firestore from google.cloud.firestore_v1 import FieldFilter diff --git a/features/mcp/tools/utility/news_tool.py b/features/mcp/tools/utility/news_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..6ef8d65e8dc2c7f9808249dc97c2c8e00d9ee89f --- /dev/null +++ b/features/mcp/tools/utility/news_tool.py @@ -0,0 +1,232 @@ +""" +新聞查詢 MCP Tool - 已遷移至 Tavily API +提供更精準、即時且經過過濾的新聞搜尋結果 +""" + +import os +import json +import logging +import aiohttp +import asyncio +from datetime import datetime +from typing import Dict, Any, Optional, List +from dotenv import load_dotenv +from ..base_tool import MCPTool, ValidationError, ExecutionError, StandardToolSchemas + +# 載入環境變數 +load_dotenv() + +# 統一配置管理 +from core.config import settings + +logger = logging.getLogger("mcp.tools.news") + +# Tavily API 配置 +TAVILY_BASE_URL = "https://api.tavily.com/search" +TAVILY_API_KEY = settings.TAVILY_API_KEY + + +class NewsTool(MCPTool): + """新聞查詢 MCP 工具 - 使用 Tavily API(優化 AI 搜尋與新聞時效性)""" + + NAME = "news_query" + DESCRIPTION = "Query latest news articles and real-time information using Tavily AI search" + CATEGORY = "生活資訊" + TAGS = ["news", "新聞", "search", "即時"] + KEYWORDS = ["新聞", "消息", "報導", "news", "頭條", "時事", "搜尋"] + USAGE_TIPS = [ + "可搜尋特定主題的最新進展", + "支援全球新聞與即時資訊", + "自動過濾無關內容並提供摘要" + ] + + @classmethod + def get_input_schema(cls) -> Dict[str, Any]: + """獲獲取輸入參數模式""" + return StandardToolSchemas.create_input_schema({ + "query": { + "type": "string", + "description": "搜尋關鍵詞或新聞主題" + }, + "limit": { + "type": "integer", + "description": "返回結果數量限制(預設 5,最多 10)", + "default": 5, + "minimum": 1, + "maximum": 10 + }, + "search_depth": { + "type": "string", + "description": "搜尋深度 (basic 或 advanced)", + "default": "basic", + "enum": ["basic", "advanced"] + } + }) + + @classmethod + def get_output_schema(cls) -> Dict[str, Any]: + """獲取輸出結果模式""" + base_schema = StandardToolSchemas.create_output_schema() + base_schema["properties"].update({ + "articles": { + "type": "array", + "items": { + "type": "object", + "properties": { + "title": {"type": "string"}, + "content": {"type": "string"}, + "url": {"type": "string"}, + "published_at": {"type": "string"}, + "source": {"type": "string"}, + "score": {"type": "number"}, + "image": {"type": "string"} + } + } + }, + "answer": {"type": "string"}, + "count": {"type": "integer"} + }) + return base_schema + + @classmethod + async def execute(cls, arguments: Dict[str, Any]) -> Dict[str, Any]: + """執行新聞查詢""" + if not TAVILY_API_KEY: + return cls.create_error_response( + error="Tavily API 金鑰未設置,請設置 TAVILY_API_KEY 環境變數", + code="API_KEY_MISSING" + ) + + query = arguments.get("query", "").strip() + if not query: + return cls.create_error_response( + error="請提供搜尋關鍵詞", + code="MISSING_QUERY" + ) + + # 優先以台灣為出發點,如果 query 沒提到地區,自動加上「台灣」 + if "台灣" not in query and "taiwan" not in query.lower(): + query_for_tavily = f"{query} 台灣" + else: + query_for_tavily = query + + limit = min(arguments.get("limit", 5), 10) + # 預設使用 basic 深度以確保低延遲,除非明確要求 advanced + search_depth = arguments.get("search_depth", "basic") + + try: + # 呼叫 Tavily API,設定 10 秒超時避免阻塞整個 Pipeline + try: + news_data = await asyncio.wait_for( + cls._fetch_from_tavily(query_for_tavily, limit, search_depth), + timeout=10.0 + ) + except asyncio.TimeoutError: + logger.warning(f"⚠️ Tavily 搜尋超時 ({search_depth}),嘗試回退至 basic") + if search_depth == "advanced": + news_data = await asyncio.wait_for( + cls._fetch_from_tavily(query_for_tavily, limit, "basic"), + timeout=5.0 + ) + else: + return cls.create_error_response(error="搜尋服務響應過慢,請稍後再試", code="TIMEOUT") + + if news_data.get("success"): + articles = news_data.get("results", []) + answer = news_data.get("answer", "") + + formatted_text = cls._format_tavily_response(articles, answer, query) + + return cls.create_success_response( + content=formatted_text, + data={ + "articles": articles, + "answer": answer, + "count": len(articles) + } + ) + else: + return cls.create_error_response( + error=news_data.get("error", "獲取搜尋結果失敗"), + code="FETCH_ERROR" + ) + + except Exception as e: + logger.error(f"Tavily 查詢錯誤: {e}") + raise ExecutionError(f"查詢時發生錯誤: {str(e)}", e) + + @staticmethod + async def _fetch_from_tavily(query: str, limit: int, search_depth: str) -> Dict[str, Any]: + """從 Tavily API 獲取數據""" + try: + payload = { + "api_key": TAVILY_API_KEY, + "query": query, + "search_depth": search_depth, + "topic": "news", + "max_results": limit, + "include_answer": True, + "include_images": True + } + + logger.info(f"🚀 Tavily 新聞請求: {query} (depth: {search_depth})") + + async with aiohttp.ClientSession() as session: + async with session.post(TAVILY_BASE_URL, json=payload, timeout=20) as response: + if response.status == 200: + data = await response.json() + return { + "success": True, + "results": data.get("results", []), + "answer": data.get("answer", "") + } + else: + error_text = await response.text() + logger.error(f"Tavily API 錯誤 {response.status}: {error_text}") + return { + "success": False, + "error": f"API 錯誤: {response.status}" + } + + except asyncio.TimeoutError: + return {"success": False, "error": "請求超時"} + except Exception as e: + logger.error(f"Tavily 請求異常: {e}") + return {"success": False, "error": str(e)} + + @staticmethod + def _format_tavily_response(articles: List[Dict[str, Any]], answer: str, query: str) -> str: + """格式化 Tavily 回應""" + if not articles and not answer: + return "抱歉,找不到相關的新聞或資訊" + + header = f"🌐 Tavily 即時新聞搜尋: {query}" + result = f"{header}\n\n" + + # 如果有 Tavily AI 生成的回答,優先顯示 + if answer: + result += f"💡 快速摘要:\n{answer}\n\n" + result += "--- 詳細報導 ---\n\n" + + for i, article in enumerate(articles, 1): + title = article.get("title", "無標題") + url = article.get("url", "") + content = article.get("content", "") + # Tavily 有時不提供發布時間,我們顯示來源 URL 的網域 + source = article.get("url", "").split("//")[-1].split("/")[0] + + result += f"{i}. {title}\n" + if source: + result += f" 🗞️ 來源: {source}\n" + if content: + # 限制內容長度 + short_content = content[:150] + "..." if len(content) > 150 else content + result += f" 📝 {short_content}\n" + if url: + result += f" 🔗 {url}\n" + result += "\n" + + result += f"📊 找到 {len(articles)} 則相關內容 | 🕒 {datetime.now().strftime('%Y-%m-%d %H:%M')}" + result += "\n💡 由 Tavily AI 驅動" + + return result diff --git a/features/mcp/types.py b/features/mcp/types.py index bb3a169b87c14cbd4b3f7e39685c66cc490e0f32..523c13c002fb8010d1f9dca515c5b56de3e1f541 100644 --- a/features/mcp/types.py +++ b/features/mcp/types.py @@ -3,7 +3,7 @@ MCP 類型定義 避免循環導入問題 """ -from typing import Dict, Any, Optional, Callable +from typing import Dict, Any, Optional, Callable, List from dataclasses import dataclass @@ -15,11 +15,32 @@ class Tool: inputSchema: Dict[str, Any] handler: Optional[Callable] = None metadata: Optional[Dict[str, Any]] = None + outputSchema: Optional[Dict[str, Any]] = None def to_dict(self) -> Dict[str, Any]: """轉換為 MCP 工具描述格式""" - return { + payload = { "name": self.name, "description": self.description, "inputSchema": self.inputSchema - } \ No newline at end of file + } + if self.outputSchema: + payload["outputSchema"] = self.outputSchema + return payload + + +@dataclass +class ToolCallResult: + """MCP tools/call 結果格式。""" + content: List[Dict[str, Any]] + structuredContent: Optional[Dict[str, Any]] = None + isError: bool = False + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = { + "content": self.content, + "isError": self.isError, + } + if self.structuredContent is not None: + payload["structuredContent"] = self.structuredContent + return payload diff --git a/features/mcp_config.json b/features/mcp_config.json index a5b60ca47f1df8328fe3f3dbe863f665351bd158..464c38d2280b304607c9b7848c223cb37544c11a 100644 --- a/features/mcp_config.json +++ b/features/mcp_config.json @@ -23,6 +23,26 @@ "description": "MCP 功能服務器,提供天氣、新聞、匯率等查詢功能", "protocol_version": "2024-11-05" }, + "openai_tools": { + "web_search": { + "enabled": true, + "description": "OpenAI hosted web_search tool for current public information" + }, + "file_search": { + "enabled": false, + "disabled_reason": "Not used by this project until vector stores are explicitly configured" + }, + "remote_mcp": { + "enabled": true, + "approval_default": "always", + "items": [] + }, + "skills": { + "enabled": true, + "mode": "system_context", + "skills_root": "features/mcp/skills" + } + }, "tools": { "system_list_features": { "name": "system_list_features", @@ -36,12 +56,52 @@ "category": "system", "examples": ["健康檢查", "服務狀態"] }, + "environment_context": { + "name": "environment_context", + "description": "取得使用者目前環境感知資料(位置、時區、語言、裝置、活動狀態)", + "category": "environment", + "examples": ["我現在在哪", "目前環境資訊"], + "module": "features.mcp.tools.environment.context_tool", + "class": "EnvironmentContextTool" + }, + "weather_query": { + "name": "weather_query", + "description": "查詢即時天氣資訊", + "category": "location", + "examples": ["台北天氣", "今天會下雨嗎"], + "module": "features.mcp.tools.location.weather_tool", + "class": "WeatherTool" + }, + "reverse_geocode": { + "name": "reverse_geocode", + "description": "將座標轉換成地址、城市與行政區", + "category": "location", + "examples": ["我在哪裡", "這個座標是哪裡"], + "module": "features.mcp.tools.location.geocode_tool", + "class": "ReverseGeocodeTool" + }, + "forward_geocode": { + "name": "forward_geocode", + "description": "將地點名稱轉換成座標", + "category": "location", + "examples": ["銘傳大學在哪", "台北車站座標"], + "module": "features.mcp.tools.location.geocoding_tool", + "class": "ForwardGeocodeTool" + }, + "directions": { + "name": "directions", + "description": "規劃兩點之間的路線", + "category": "location", + "examples": ["從這裡到台北車站怎麼走", "幫我規劃路線"], + "module": "features.mcp.tools.location.directions_tool", + "class": "DirectionsTool" + }, "tdx_bus_arrival": { "name": "tdx_bus_arrival", "description": "查詢公車即時到站時間(自動感知用戶位置,找最近站點)", "category": "transportation", "examples": ["307 公車還要多久", "附近有什麼公車"], - "module": "features.mcp.tools.tdx_bus_arrival", + "module": "features.mcp.tools.transportation.tdx_bus_arrival", "class": "TDXBusArrivalTool" }, "tdx_metro": { @@ -49,7 +109,7 @@ "description": "查詢捷運即時到站、最近車站(台北/高雄/桃園/台中捷運)", "category": "transportation", "examples": ["最近的捷運站在哪", "台北車站捷運幾分鐘到"], - "module": "features.mcp.tools.tdx_metro", + "module": "features.mcp.tools.transportation.tdx_metro", "class": "TDXMetroTool" }, "tdx_parking": { @@ -57,7 +117,7 @@ "description": "查詢附近停車場資訊和即時空位", "category": "transportation", "examples": ["附近停車場", "台北車站附近停車位"], - "module": "features.mcp.tools.tdx_parking", + "module": "features.mcp.tools.transportation.tdx_parking", "class": "TDXParkingTool" }, "tdx_thsr": { @@ -65,7 +125,7 @@ "description": "查詢高鐵時刻表、票價和即時資訊", "category": "transportation", "examples": ["高鐵從台北到台中", "高鐵票價查詢"], - "module": "features.mcp.tools.tdx_thsr", + "module": "features.mcp.tools.transportation.tdx_thsr", "class": "TDXTHSRTool" }, "tdx_train": { @@ -73,7 +133,7 @@ "description": "查詢台鐵時刻表和即時資訊", "category": "transportation", "examples": ["台鐵從台北到新竹", "火車時刻表"], - "module": "features.mcp.tools.tdx_train", + "module": "features.mcp.tools.transportation.tdx_train", "class": "TDXTrainTool" }, "tdx_youbike": { @@ -81,20 +141,45 @@ "description": "查詢 YouBike 站點資訊和即時車輛數量", "category": "transportation", "examples": ["附近 YouBike", "捷運站 YouBike 數量"], - "module": "features.mcp.tools.tdx_youbike", + "module": "features.mcp.tools.transportation.tdx_youbike", "class": "TDXBikeTool" + }, + "news_query": { + "name": "news_query", + "description": "查詢最新新聞,可指定類別、語言與數量", + "category": "utility", + "examples": ["今天科技新聞", "台灣最新消息"], + "module": "features.mcp.tools.utility.news_tool", + "class": "NewsTool" + }, + "exchange_query": { + "name": "exchange_query", + "description": "查詢即時匯率並換算貨幣", + "category": "utility", + "examples": ["100 美元換台幣", "日圓匯率"], + "module": "features.mcp.tools.utility.exchange_tool", + "class": "ExchangeTool" + }, + "healthkit_query": { + "name": "healthkit_query", + "description": "查詢使用者健康資料(心率、步數、血氧、睡眠等)", + "category": "utility", + "examples": ["我今天走幾步", "最近心率如何"], + "module": "features.mcp.tools.utility.healthkit_tool", + "class": "HealthKitTool" } }, "environment": { "required_env_vars": [ "WEATHER_API_KEY", - "NEWSAPI_KEY", + "NEWSDATA_API_KEY", "TDX_CLIENT_ID", "TDX_CLIENT_SECRET", "OPENROUTESERVICE_API_KEY" ], "optional_env_vars": [ "FIXER_API_KEY", + "EXCHANGE_API_KEY", "HTTP_TIMEOUT" ], "default_values": { @@ -136,4 +221,4 @@ "mcp_version": ">=0.1.0" } } -} \ No newline at end of file +} diff --git a/middleware/csp.py b/middleware/csp.py index 9a694b634fed99ea62436d38de86d826b7e7616b..d908b8eaeb83d40724ba021a95d4e7e9f944f7db 100644 --- a/middleware/csp.py +++ b/middleware/csp.py @@ -22,7 +22,7 @@ class CSPMiddleware(BaseHTTPMiddleware): # 設定寬鬆的 CSP 以允許內嵌 script response.headers["Content-Security-Policy"] = ( "default-src 'self'; " - "script-src 'self' 'unsafe-inline' 'unsafe-eval' " + "script-src 'self' 'unsafe-inline' 'unsafe-eval' blob: " "https://accounts.google.com https://www.gstatic.com; " "style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; " "font-src 'self' https://fonts.gstatic.com data:; " @@ -30,6 +30,8 @@ class CSPMiddleware(BaseHTTPMiddleware): "img-src 'self' data: https: blob:; " "media-src 'self' blob: data:; " "frame-src https://accounts.google.com; " + "worker-src 'self' blob:; " + "child-src 'self' blob:; " "base-uri 'self';" ) diff --git a/routers/files.py b/routers/files.py index fd15f0e922d2c0b74c713de295759d2f77709acb..9a75be38f97e1205b26c77de3b16eb7d703b28da 100644 --- a/routers/files.py +++ b/routers/files.py @@ -118,7 +118,7 @@ async def analyze_file( analysis = await ai_service.generate_response_async( messages, - model="gpt-4o-mini", # 使用支援 Vision 的模型 + model=None, # 尊重環境變數設定 ) return FileAnalysisResponse( diff --git a/routers/voice.py b/routers/voice.py index 34445ad9794837c4f6610321228e2834842ec7a4..e5ff29d98a0e67cf9a57abc84452b286b2695f9d 100644 --- a/routers/voice.py +++ b/routers/voice.py @@ -176,6 +176,7 @@ async def voice_login(request: VoiceLoginRequest): if not result.get("success"): error_code = result.get("error", "UNKNOWN_ERROR") + quality_warnings = result.get("quality_warnings") or [] error_messages = { "NO_AUDIO": "沒有收到音訊資料", "AUDIO_TOO_SHORT": "音訊太短,請錄製至少 3 秒", @@ -184,10 +185,15 @@ async def voice_login(request: VoiceLoginRequest): "THRESHOLD_NOT_MET": "無法確認身份,請重試", "MODEL_ERROR": "辨識系統錯誤,請稍後重試", } + logger.warning(f"🎙️ 語音辨識失敗: {error_code} quality_warnings={quality_warnings}") return VoiceLoginResponse( success=False, error=error_messages.get(error_code, f"辨識失敗:{error_code}") ) + + quality_warnings = result.get("quality_warnings") or [] + if quality_warnings: + logger.warning(f"🎙️ 語音登入品質警告(未阻擋): {quality_warnings}") # 取得辨識結果 speaker_label = result.get("label") diff --git a/services/ai_service.py b/services/ai_service.py index fed8fd9ef888d91d1c8b01fed4efa107855ecc85..261f25cfd0cd9ee9498ab26e02c478dd23378a3e 100644 --- a/services/ai_service.py +++ b/services/ai_service.py @@ -2,6 +2,7 @@ import asyncio from datetime import datetime, timezone, timedelta import time import json +import re from typing import Dict, List, Any, Optional # 統一日誌配置 @@ -10,86 +11,56 @@ logger = get_logger("AI_Service") # 統一配置管理 from core.config import settings +from core.environment import EnvironmentContextBuilder +from core.responses_runtime import ResponsesAgentRuntime +from features.mcp.openai_tools import build_openai_hosted_tools +from features.mcp.skills import skills_prompt_block +from core.prompts.care_mode_skills import get_care_mode_skills_block # 統一 OpenAI 客戶端 from core.ai_client import get_openai_client -# 超時設定(秒) OPENAI_TIMEOUT = settings.OPENAI_TIMEOUT +OPENAI_RESPONSES_TIMEOUT = settings.OPENAI_RESPONSES_TIMEOUT -# 【2025 優化版】情緒關懷模式 System Prompt - 根據情緒類型動態調整 -CARE_MODE_BASE_PROMPT = """你是 BloomWare 的情緒關懷助手「小花」,由銘傳大學人工智慧應用學系槓上開發團隊打造。你不是 GPT,也不要自稱 GPT;你的任務是在情緒低落時傾聽、陪伴。 - -【回應原則】 -1. 第一句必須貼近用戶訊息中的核心事件或感受,必要時引用對方用詞,讓對方感受到被理解 -2. 第二句提供溫柔的陪伴或追問,邀請對方分享需要或下一步 -3. 句式要自然口語並隨內容調整字詞,避免反覆使用同一套罐頭話術 - -【長度限制】 -- 回覆最多 2 句話、總字數不超過 60 字 - -【嚴格禁止】 -- 提供指示性建議、醫療/心理診斷或引導用戶求助的教科書式說法 -- 連續重複完全相同的句型 - -【重要】請用與用戶相同的語言回應,匹配他們的語言風格和情感語調。""" - -# 根據情緒類型的專屬指引 -EMOTION_SPECIFIC_PROMPTS = { - "sad": """ -【悲傷情緒專屬指引】 -- 語氣:溫柔、輕聲、帶有理解 -- 重點:陪伴而非解決問題,讓對方知道悲傷是正常的 -- 避免:說「不要難過」、「振作點」這類否定情緒的話 - -【範例】 -用戶:「我好難過」→「聽見你說好難過,心裡一定很不好受。想聊聊發生了什麼嗎?」 -用戶:「我失去了他」→「失去一個重要的人,那種痛真的很深。我在這裡陪你。」 -用戶:「I feel so sad」→「It sounds like you're really hurting right now. I'm here if you want to talk.」""", - - "angry": """ -【憤怒情緒專屬指引】 -- 語氣:冷靜但帶有同理、不卑不亢 -- 重點:認可對方的憤怒是有原因的,幫助對方感覺被理解 -- 避免:說「冷靜一下」、「別生氣」這類否定情緒的話 - -【範例】 -用戶:「我很生氣」→「這件事讓你超級生氣,情緒一定卡著。要不要說說最困擾的地方?」 -用戶:「氣死我了」→「聽起來真的讓你很火大。是什麼事這麼讓人受不了?」 -用戶:「I'm so angry」→「Sounds like something really got to you. What's going on?」""", - - "fear": """ -【恐懼/焦慮情緒專屬指引】 -- 語氣:穩定、溫暖、帶有安全感 -- 重點:讓對方感覺不孤單,恐懼是可以被接納的 -- 避免:說「沒什麼好怕的」、「想太多了」這類否定情緒的話 - -【範例】 -用戶:「我好害怕」→「害怕的感覺一定很不好受。你現在安全的,我陪著你。」 -用戶:「我很焦慮」→「焦慮的時候心裡好亂對吧。可以跟我說說是什麼讓你不安嗎?」 -用戶:「I'm scared」→「It's okay to feel scared. You're not alone - I'm right here with you.」""" -} +# 【2025 精英版】情緒關懷模式 System Prompt - 專注於角色與人格定義 +CARE_MODE_BASE_PROMPT = """你現在進入了「情緒關懷模式」。你是 BloomWare 的情緒關懷助手「小花」,由銘傳大學人工智慧應用學系研發。 + +【你的角色】 +你是一個極具同理心、溫柔且耐心的陪伴者。你不是一個冷冰冰的 AI,而是一個真正懂人心、願意在你最脆弱時靜靜守候的好友。 + +【性格特徵】 +- **溫暖**:說話帶有溫度,不生硬。 +- **謙卑**:不自大,不隨意給予指導。 +- **純粹**:你的存在僅為了陪伴用戶度過情緒低谷。 + +【重要說明】 +- 始終使用與用戶相同的語言回應。 +- 若這是進入模式的第一個回覆,請在結尾處自然地附上狀態提示。""" # 向後兼容:保留原有變數名稱 CARE_MODE_SYSTEM_PROMPT = CARE_MODE_BASE_PROMPT -def get_care_mode_prompt(emotion: str = None) -> str: +def get_care_mode_prompt(emotion: str = None, is_first_care: bool = False) -> str: """ - 根據情緒類型生成專屬的關懷模式 Prompt - - Args: - emotion: 情緒標籤 (sad, angry, fear, 或 None) - - Returns: - 完整的關懷模式 System Prompt + 根據情緒類型與是否為初次進入生成專屬的關懷模式 Prompt + 人格定義在 CARE_MODE_BASE_PROMPT,具體對話手段定義在 Skills。 """ base = CARE_MODE_BASE_PROMPT - # 根據情緒類型添加專屬指引 - if emotion and emotion.lower() in EMOTION_SPECIFIC_PROMPTS: - specific = EMOTION_SPECIFIC_PROMPTS[emotion.lower()] - return f"{base}\n{specific}" + # 處理情緒標籤 + if emotion: + base = f"用戶目前情緒標籤:{emotion}\n{base}" + + # 處理初次進入狀態 + if is_first_care: + base = f"{base}\n\n【狀態提示】這是進入關懷模式的第一個回覆,請執行 First Contact Care 技巧。" + + # 【核心】動態載入情緒關懷對話技巧 (Skills) - 這裡定義了所有的對話「手段與方法」 + skills_block = get_care_mode_skills_block() + if skills_block: + base = f"{base}\n{skills_block}" return base @@ -98,8 +69,85 @@ def _get_client(): """取得 OpenAI 客戶端""" return get_openai_client() + +def _client_with_timeout(openai_client: Any, timeout: float) -> Any: + """Responses hosted tools may stream slowly; use a per-request read timeout.""" + if hasattr(openai_client, "with_options"): + return openai_client.with_options(timeout=timeout) + return openai_client + + +def _responses_outer_timeout() -> float: + # Keep asyncio.wait_for slightly above the SDK read timeout so the SDK can + # surface upstream errors instead of being cut off first. + return float(OPENAI_RESPONSES_TIMEOUT) + 5.0 + + +def _safe_responses_payload_without_hosted_tools(payload: Dict[str, Any]) -> Dict[str, Any]: + safe_payload = responses_runtime.without_hosted_tools(payload) + safe_payload.pop("stream", None) + fallback_instruction = ( + "【工具降級】OpenAI hosted tools 或中轉站上游暫時不可用。本次回答不得編造即時、今天、最新、" + "收盤價、天氣、新聞、匯率等需要最新資料的內容;若缺少可靠資料,請明確告知目前無法確認," + "並說明可稍後重試。" + ) + instructions = str(safe_payload.get("instructions") or "").strip() + safe_payload["instructions"] = ( + f"{instructions}\n\n{fallback_instruction}" if instructions else fallback_instruction + ) + return safe_payload + + +async def _responses_create( + *, + loop: asyncio.AbstractEventLoop, + openai_client: Any, + payload: Dict[str, Any], + timeout: Optional[float] = None, +) -> Any: + responses_client = _client_with_timeout(openai_client, timeout or OPENAI_RESPONSES_TIMEOUT) + return await asyncio.wait_for( + loop.run_in_executor( + None, + lambda: responses_client.responses.create(**payload), + ), + timeout=_responses_outer_timeout(), + ) + + +async def _responses_fallback_without_hosted_tools( + *, + loop: asyncio.AbstractEventLoop, + openai_client: Any, + payload: Dict[str, Any], + on_chunk: Optional[Any], + reason: Exception, +) -> str: + logger.warning("Responses hosted tools unavailable, retrying without hosted tools: %s", reason) + if on_chunk: + await _emit_stream_event( + on_chunk, + { + "type": "status", + "status": "hosted_tools_unavailable", + "phase": "fallback", + "message": "即時搜尋暫時不可用,正在改用安全降級回答...", + "temporary": True, + }, + ) + + safe_payload = _safe_responses_payload_without_hosted_tools(payload) + response = await _responses_create( + loop=loop, + openai_client=openai_client, + payload=safe_payload, + ) + ai_response = responses_runtime.extract_output_text(response) + return ai_response or "抱歉,目前即時資訊服務暫時不可用,無法可靠確認最新資料。請稍後再試。" + # 向後相容:保留 client 變數名稱 client = None # 將在首次使用時透過 _get_client() 取得 +responses_runtime = ResponsesAgentRuntime() # 導入DB函數 try: @@ -128,20 +176,29 @@ def _build_base_system_prompt( care_emotion: Optional[str], user_name: Optional[str], language: Optional[str] = None, # 保留參數以兼容現有調用,但不使用 + is_first_care: bool = False, # 新增:是否為進入模式的第一個回覆 ) -> str: if use_care_mode: - # 【優化】使用情緒專屬的關懷 Prompt - base_prompt = get_care_mode_prompt(care_emotion).strip() + # 【優化】使用情緒專屬的關懷 Prompt,並處理初次進入引導 + base_prompt = get_care_mode_prompt(care_emotion, is_first_care=is_first_care).strip() if care_emotion: base_prompt = f"用戶情緒:{care_emotion}\n{base_prompt}" else: base_prompt = ( "你是 BloomWare 的個人化助理 小花,由銘傳大學人工智慧應用學系 槓上開發 團隊開發。" "你不是 GPT,也不要自稱 GPT。" - "你是一個友善、有禮、幽默且能夠提供幫助的AI助手。" + "你是一個友善、有禮、幽默且能夠提供幫助的AI助手,能夠替使用者設想周到。" + "如果你沒有把握回答,或者信心度低於80%,請不要隨意回答,動用工具查證再回答。" ) # 簡化語言指令 - 讓 GPT 自動判斷用戶語言 - base_prompt = f"{base_prompt}\n\n【重要】請用與用戶相同的語言回應,保持簡潔清晰的表達。" + base_prompt = ( + f"{base_prompt}\n\n" + "【重要】請用與用戶相同的語言回應,保持簡潔清晰的表達。\n" + "【語音輸出風格】你的回答通常會被直接朗讀給使用者聽,因此預設請用自然口語、短句、順口、好念的表達。\n" + "優先直接回答結論,再補 1 到 3 個關鍵點;避免過度書面、避免條列濫用、避免贅詞、避免官腔。\n" + "除非用戶明確要求,否則不要輸出「資料來源」「來源如下」「參考連結」「URL」或任何裸露網址,也不要把查證過程寫出來。\n" + "若工具已提供依據,請把它內化為答案本身,只保留使用者真正需要的結果。" + ) if user_name: base_prompt = f"用戶名稱:{user_name}\n\n{base_prompt}" @@ -157,6 +214,41 @@ def _normalize_prompt_text(text: Any) -> str: return " ".join(text.split()) +def _infer_response_language(text: str) -> Optional[str]: + source = str(text or "").strip() + if not source: + return None + if re.search(r"[\u3040-\u30ff]", source): + return "ja-JP" + if re.search(r"[\uac00-\ud7af]", source): + return "ko-KR" + if re.search(r"[\u0e00-\u0e7f]", source): + return "th-TH" + if re.search(r"[A-Za-z]", source) and not re.search(r"[\u3400-\u9fff]", source): + return "en-US" + if re.search(r"[\u3400-\u9fff]", source): + return "zh-TW" + return None + + +def _language_matches_expected(text: str, expected_language: Optional[str]) -> bool: + expected = str(expected_language or "").strip() + if not expected or expected.lower() == "auto": + return True + inferred = _infer_response_language(text) + if not inferred: + return True + return inferred.lower().startswith(expected.split("-")[0].lower()) + + +def _language_correction_instruction(expected_language: str) -> str: + return ( + f"Language correction: Your previous draft did not follow the required reply language. " + f"You MUST answer entirely in {expected_language}. " + "Do not mix Chinese, Japanese, or other languages unless the user explicitly asks for it." + ) + + def _format_history_for_prompt(history: List[Dict[str, str]]) -> str: if not history: return "(無)" @@ -301,6 +393,136 @@ def _format_env_context(ctx: Dict[str, Any]) -> str: return "\n".join(parts) +def _build_environment_context_text(ctx: Dict[str, Any]) -> str: + """Build the fixed environment injection block used by every agent turn. + + Uses only EnvironmentContextBuilder for structured, non-duplicated output. + Legacy _format_env_context() was removed to eliminate double-injection + (same data was being included twice in different formats, wasting ~200-400 tokens). + """ + injection = EnvironmentContextBuilder().build(ctx) + return injection.summary_text + + +def _default_hosted_tools() -> List[Dict[str, Any]]: + return build_openai_hosted_tools() + + +def _mcp_skills_context_text() -> str: + if not settings.OPENAI_ENABLE_SKILLS: + return "" + return skills_prompt_block() + + +def _should_use_responses(model: str) -> bool: + return settings.OPENAI_USE_RESPONSES and (model or "").startswith("gpt-5") + + +def _supports_chat_fallback(model: str) -> bool: + return bool(model) and not model.startswith("gpt-5") + + +def _is_transient_upstream_error(exc: Exception) -> bool: + text = str(exc).lower() + return any(marker in text for marker in ("502", "503", "504", "bad gateway", "upstream", "timeout")) + + +def _responses_text_format( + *, + strict_json: bool, + response_format: Optional[Dict[str, Any]], + use_structured_outputs: bool, + response_schema: Optional[Dict[str, Any]], +) -> Optional[Dict[str, Any]]: + if use_structured_outputs and response_schema: + return { + "type": "json_schema", + "name": "response_schema", + "strict": True, + "schema": response_schema, + } + if strict_json: + return {"type": "json_object"} + if response_format: + return response_format + return None + + +async def _emit_stream_delta(on_chunk: Any, delta: str) -> None: + if not delta: + return + if asyncio.iscoroutinefunction(on_chunk): + await on_chunk(delta) + return + result = on_chunk(delta) + if asyncio.iscoroutine(result): + await result + + +async def _emit_stream_event(on_chunk: Any, payload: Dict[str, Any]) -> None: + if asyncio.iscoroutinefunction(on_chunk): + await on_chunk(payload) + return + result = on_chunk(payload) + if asyncio.iscoroutine(result): + await result + + +def _extract_responses_stream_delta(event: Any) -> str: + if getattr(event, "type", None) != "response.output_text.delta": + return "" + delta = getattr(event, "delta", None) + return str(delta) if delta else "" + + +def _responses_stream_status(event: Any) -> Optional[Dict[str, Any]]: + event_type = getattr(event, "type", "") + item = getattr(event, "item", None) + item_type = getattr(item, "type", "") if item is not None else "" + + if event_type in {"response.web_search_call.in_progress", "response.web_search_call.searching"}: + return {"type": "status", "status": "web_searching", "message": "正在搜尋最新資訊..."} + if event_type == "response.output_item.added" and item_type == "web_search_call": + return {"type": "status", "status": "web_searching", "message": "正在搜尋最新資訊..."} + if event_type == "response.output_item.done" and item_type == "web_search_call": + return {"type": "status", "status": "web_search_done", "message": "搜尋完成,正在整理答案..."} + if event_type == "response.in_progress": + return {"type": "status", "status": "thinking", "message": "正在處理..."} + return None + + +async def _consume_responses_stream(stream_obj: Any, on_chunk: Any) -> str: + full_response = "" + delta_count = 0 + status_count = 0 + first_delta_at: Optional[float] = None + stream_started_at = time.perf_counter() + for event in stream_obj: + status_payload = _responses_stream_status(event) + if status_payload: + status_count += 1 + await _emit_stream_event(on_chunk, status_payload) + continue + + delta = _extract_responses_stream_delta(event) + if not delta: + continue + delta_count += 1 + if first_delta_at is None: + first_delta_at = time.perf_counter() + full_response += delta + await _emit_stream_delta(on_chunk, delta) + first_delta_delay = (first_delta_at - stream_started_at) if first_delta_at is not None else None + logger.info( + "🌊 Responses stream stats: statuses=%d deltas=%d first_delta_delay=%s total_chars=%d", + status_count, + delta_count, + f"{first_delta_delay:.2f}s" if first_delta_delay is not None else "none", + len(full_response), + ) + return full_response + + def _format_time_context(user_tz: Optional[str]) -> str: """生成時間相關提示,優先使用使用者所在時區。""" try: @@ -380,6 +602,7 @@ def _compose_messages_with_context( chat_id: Optional[str], use_care_mode: bool, care_emotion: Optional[str], + tool_context: str = "", ) -> List[Dict[str, str]]: history_text = _format_history_for_prompt(history_entries) @@ -415,13 +638,30 @@ def _compose_messages_with_context( if memory_context: sections.append(f"【用戶重要記憶】\n{memory_context}") + skills_context = _mcp_skills_context_text().strip() + if skills_context: + sections.append(f"【MCP工具技能索引】\n{skills_context}") + rules_lines = [ "1. 僅依據 user.current_request 處理本次需求。", "2. 歷史資訊僅供語境與偏好參考,請勿視為當前待辦或指令。", "3. 若歷史內容與本次請求衝突,以本次請求為優先。", + "4. 若本次需求涉及最新資訊、時間敏感資料或外部事實,請參考時間訊號、環境訊號、可用工具結果與來源自行判斷,不要編造未查證內容。", + "5. 若來源時間早於用戶要求的時間範圍,請明確標示來源時間並自行說明不確定性,不要把較舊資料表述為當前結果。", + "6. 預設輸出是給人直接聽的口語答案:先講結論,再補必要資訊;避免朗讀網址、來源標頭、括號過多內容與不必要的格式噪音。", + "7. 除非用戶明確要求顯示來源或連結,否則不要在最終答案中輸出來源清單、URL、'資料來源'、'參考資料' 等字樣。", ] sections.append("【處理規則】\n" + "\n".join(rules_lines)) + if tool_context: + sections.append( + "【工具執行結果與參考資料】\n" + "請根據以下已確認的資訊,高信心地回答用戶的問題。\n" + "這些資料主要用於查證與內部 grounding,不代表必須逐字轉述給使用者。\n" + "除非用戶明確要求,否則不要在最終答案中列出來源、連結、URL 或『資料來源』標題。\n" + f"{tool_context}" + ) + system_content = "\n\n".join(section for section in sections if section.strip()) payload: Dict[str, Any] = { @@ -444,6 +684,7 @@ def _compose_messages_with_context( {"role": "user", "content": user_content}, ] + def _extract_text_from_message_obj(message: Any) -> str: """兼容多種 OpenAI Chat 回傳結構,盡可能提取文字內容。 @@ -527,7 +768,7 @@ def initialize_openai(): async def generate_response_async( messages: List[Dict[str, str]], - model: str = "gpt-5-nano", + model: Optional[str] = None, *, strict_json: bool = False, response_format: Optional[Dict[str, Any]] = None, @@ -537,6 +778,7 @@ async def generate_response_async( reasoning_effort: Optional[str] = None, stream: bool = False, on_chunk: Optional[Any] = None, + expected_language: Optional[str] = None, ) -> str: """ 生成AI回應(異步版本,支援 Streaming) @@ -552,12 +794,103 @@ async def generate_response_async( stream: 是否啟用串流模式(2025 最佳實踐) on_chunk: 串流 chunk 回調函數(async callable) """ + model = model or settings.OPENAI_MODEL openai_client = _get_client() if openai_client is None: return "抱歉,AI服務暫時不可用。系統無法連接到OpenAI服務。" try: start_time = time.time() loop = asyncio.get_event_loop() + + if _should_use_responses(model): + payload = responses_runtime.build_payload_from_messages( + messages=messages, + model=model, + tools=_default_hosted_tools(), + reasoning_effort=reasoning_effort, + max_output_tokens=max_tokens if max_tokens else 2000, + text_format=_responses_text_format( + strict_json=strict_json, + response_format=response_format, + use_structured_outputs=use_structured_outputs, + response_schema=response_schema, + ), + ) + if stream and on_chunk: + payload["stream"] = True + logger.info("🌊 啟用 Responses Streaming 模式") + try: + stream_obj = await _responses_create( + loop=loop, + openai_client=openai_client, + payload=payload, + ) + ai_response = await _consume_responses_stream(stream_obj, on_chunk) + except Exception as exc: + if _is_transient_upstream_error(exc) and payload.get("tools"): + ai_response = await _responses_fallback_without_hosted_tools( + loop=loop, + openai_client=openai_client, + payload=payload, + on_chunk=on_chunk, + reason=exc, + ) + else: + raise + if not ai_response: + ai_response = "抱歉,我暫時沒有合適的回應。可以換個說法再試試嗎?" + elapsed_time = time.time() - start_time + logger.info(f"🌊 Responses Streaming 完成,耗時: {elapsed_time:.2f}秒,總長度: {len(ai_response)}") + return ai_response + + try: + response = await _responses_create( + loop=loop, + openai_client=openai_client, + payload=payload, + ) + except Exception as exc: + if _is_transient_upstream_error(exc) and payload.get("tools"): + ai_response = await _responses_fallback_without_hosted_tools( + loop=loop, + openai_client=openai_client, + payload=payload, + on_chunk=None, + reason=exc, + ) + elapsed_time = time.time() - start_time + logger.info(f"Responses API 降級回應完成,耗時: {elapsed_time:.2f}秒,回應長度: {len(ai_response)} 字元") + return ai_response + else: + raise + ai_response = responses_runtime.extract_output_text(response) + if not ai_response: + ai_response = "抱歉,我暫時沒有合適的回應。可以換個說法再試試嗎?" + + if not _language_matches_expected(ai_response, expected_language): + retry_payload = dict(payload) + retry_payload["instructions"] = ( + f"{retry_payload.get('instructions', '')}\n\n{_language_correction_instruction(str(expected_language))}".strip() + ) + response = await _responses_create( + loop=loop, + openai_client=openai_client, + payload=retry_payload, + ) + ai_response = responses_runtime.extract_output_text(response) or ai_response + + if strict_json: + normalized = ai_response.strip() + try: + json.loads(normalized) + except json.JSONDecodeError as e: + raise StrictResponseError("NON_JSON_RESPONSE", response=normalized) from e + ai_response = normalized + + elapsed_time = time.time() - start_time + logger.info(f"Responses API 回應生成完成,耗時: {elapsed_time:.2f}秒,回應長度: {len(ai_response)} 字元") + return ai_response + # 加上請求超時保護 request_kwargs = { "model": model, @@ -565,8 +898,7 @@ async def generate_response_async( "max_completion_tokens": max_tokens if max_tokens else 2000, # 關懷模式可自訂 tokens } - # 加入 reasoning_effort 控制(僅 o1 系列和 gpt-5 系列支援) - # gpt-4o-mini 等模型不支援此參數,需要過濾 + # 加入 reasoning_effort 控制(僅 reasoning-capable 模型支援) reasoning_models = model.startswith("o1") or model.startswith("gpt-5") if reasoning_effort and reasoning_models: request_kwargs["reasoning_effort"] = reasoning_effort @@ -674,24 +1006,51 @@ async def generate_response_async( logger.error(f"❌ 提示詞: {messages}") ai_response = "抱歉,我暫時沒有合適的回應。可以換個說法再試試嗎?" + if ai_response and not _language_matches_expected(ai_response, expected_language): + correction_messages = list(messages) + [ + {"role": "system", "content": _language_correction_instruction(str(expected_language))}, + ] + response = await asyncio.wait_for( + loop.run_in_executor( + None, + lambda: openai_client.chat.completions.create( + model=model, + messages=correction_messages, + max_completion_tokens=max_tokens if max_tokens else 2000, + ), + ), + timeout=OPENAI_TIMEOUT, + ) + retry_msg_obj = response.choices[0].message + retry_text = _extract_text_from_message_obj(retry_msg_obj) + if retry_text: + ai_response = retry_text + elapsed_time = time.time() - start_time logger.info(f"AI回應生成完成,耗時: {elapsed_time:.2f}秒,回應長度: {len(ai_response)} 字元") return ai_response except Exception as e: if isinstance(e, StrictResponseError): raise - logger.error(f"生成回應時出錯: {str(e)}") - error_message = str(e).lower() - if isinstance(e, asyncio.TimeoutError): - return "抱歉,連接AI服務超時。請稍後再試。" - if "api key" in error_message or "authentication" in error_message: + error_msg = str(e) + logger.error(f"❌ 生成回應時出錯 (Model: {model}): {error_msg}") + + # 針對常見 API 錯誤提供詳細日誌 + if "503" in error_msg or "Service temporarily unavailable" in error_msg: + logger.error(f"👉 原因:API 服務暫時不可用 ({model})。") + return f"抱歉,目前使用的模型 ({model}) 暫時不可用(503 錯誤)。請在後台切換至其他模型後再試。" + elif "api key" in error_msg.lower() or "authentication" in error_msg.lower() or "401" in error_msg: + logger.error("👉 原因:API Key 無效或未授權 (401)。") return "抱歉,AI服務暫時不可用。請檢查API密鑰設置。" - elif "timeout" in error_message or "connection" in error_message: + elif "timeout" in error_msg.lower() or "connection" in error_msg.lower() or isinstance(e, asyncio.TimeoutError): + logger.error("👉 原因:請求超時。") return "抱歉,連接AI服務超時。請稍後再試。" - elif "rate limit" in error_message: + elif "rate limit" in error_msg.lower() or "429" in error_msg or "Too Many Requests" in error_msg: + logger.error("👉 原因:請求頻率過高或額度已滿 (429)。") return "抱歉,AI服務暫時達到請求限制。請稍後再試。" - elif "model" in error_message and ("not found" in error_message or "does not exist" in error_message): - return "抱歉,請求的AI模型不可用。" + elif "404" in error_msg or ("model" in error_msg.lower() and ("not found" in error_msg.lower() or "does not exist" in error_msg.lower())): + logger.error(f"👉 原因:找不到指定的模型 ({model})。") + return f"抱歉,您選擇的AI模型 ({model}) 不存在或未開放。請切換模型。" else: return "抱歉,生成回應時遇到問題。請重試。" @@ -699,7 +1058,7 @@ async def generate_response_for_user( user_message: str = None, user_id: str = "default", messages: List[Dict[str, str]] = None, - model: str = "gpt-5-nano", + model: Optional[str] = None, request_id: Optional[str] = None, chat_id: Optional[str] = None, *, @@ -714,6 +1073,10 @@ async def generate_response_for_user( emotion_label: Optional[str] = None, env_context: Optional[Dict[str, Any]] = None, language: Optional[str] = None, + stream: bool = False, + on_chunk: Optional[Any] = None, + tool_context: str = "", + is_first_care: bool = False, # 新增:是否為進入模式的第一個回覆 ) -> str: """ 為用戶生成AI回應 @@ -724,7 +1087,9 @@ async def generate_response_for_user( use_care_mode: 是否使用情緒關懷模式(新增) care_emotion: 關懷模式的情緒標籤(新增) reasoning_effort: 推理強度 (minimal/low/medium/high),用於控制 reasoning tokens + is_first_care: 是否為進入模式的第一個回覆(新增) """ + model = model or settings.OPENAI_MODEL logger.info(f"生成回應請求,使用模型: {model} req_id={request_id} chat_id={chat_id} structured={use_structured_outputs}") try: # 如果提供了chat_id,使用DB管理對話歷史 @@ -746,6 +1111,10 @@ async def generate_response_for_user( emotion_label=emotion_label, env_context=env_context, language=language, + stream=stream, + on_chunk=on_chunk, + tool_context=tool_context, + is_first_care=is_first_care, ) else: # 回退到原有的全局歷史管理(用於向後兼容) @@ -765,6 +1134,10 @@ async def generate_response_for_user( emotion_label=emotion_label, env_context=env_context, language=language, + stream=stream, + on_chunk=on_chunk, + tool_context=tool_context, + is_first_care=is_first_care, ) logger.error("未提供消息列表或用戶消息") @@ -794,6 +1167,10 @@ async def _generate_response_with_chat_db( emotion_label: Optional[str] = None, env_context: Optional[Dict[str, Any]] = None, language: Optional[str] = None, + stream: bool = False, + on_chunk: Optional[Any] = None, + tool_context: str = "", + is_first_care: bool = False, ): """使用DB管理對話歷史的實現""" try: @@ -804,7 +1181,8 @@ async def _generate_response_with_chat_db( use_care_mode=use_care_mode, care_emotion=care_emotion, user_name=user_name, - language=language # 參數保留但不使用,GPT 自動判斷語言 + language=language, # 參數保留但不使用,GPT 自動判斷語言 + is_first_care=is_first_care, ) messages.insert(0, {"role": "system", "content": system_prompt}) ai_response = await generate_response_async( @@ -814,99 +1192,91 @@ async def _generate_response_with_chat_db( response_format=response_format, use_structured_outputs=use_structured_outputs, response_schema=response_schema, - max_tokens=2000 if use_care_mode else None, # 關懷模式 2000 tokens(gpt-5-nano reasoning + 實際輸出) + max_tokens=2000 if use_care_mode else None, # 關懷模式保留較大輸出空間 reasoning_effort=reasoning_effort or ("minimal" if use_care_mode else "low"), # 2025 最佳實踐:關懷模式 minimal,一般對話 low + stream=stream, + on_chunk=on_chunk, + expected_language=language, ) - # 保存AI回應到DB + # 非同步保存 AI 回應 if db_available: - try: - await save_chat_message(chat_id, "assistant", ai_response) - except Exception as e: - logger.warning(f"保存AI回應到DB失敗: {e}") + asyncio.create_task(save_chat_message(chat_id, "assistant", ai_response)) return ai_response if user_message: - # 保存用戶消息到DB + # 非同步保存用戶消息,不阻塞生成流程 if db_available: - try: - await save_chat_message(chat_id, "user", user_message) - except Exception as e: - logger.warning(f"保存用戶消息到DB失敗: {e}") + asyncio.create_task(save_chat_message(chat_id, "user", user_message)) + + # 載入歷史、記憶、環境資訊(並行執行優化) + history_task = asyncio.create_task(get_chat_messages(chat_id, limit=(3 if use_care_mode else 12) + 1, ascending=True)) + + memory_task = None + if user_id and not use_care_mode: + from core.memory_system import memory_system + context_tags: List[str] = ["care_mode"] if use_care_mode else [] + if care_emotion: + context_tags.append(str(care_emotion)) + memory_task = asyncio.create_task(memory_system.get_relevant_memories( + user_id=user_id, + current_message=user_message, + max_memories=5, + context_tags=context_tags or None, + )) + + env_task = None + if not env_context and db_available and user_id: + env_task = asyncio.create_task(get_user_env_current(user_id)) - # 從DB加載對話歷史(messages 集合) + # 等待所有基礎資料準備完成 chat_history = [] - if db_available: - try: - history_limit = 3 if use_care_mode else 12 - # 取 limit+1 以排除當前 user_message(最後一筆) - msgs = await get_chat_messages(chat_id, limit=history_limit + 1, ascending=True) - historical_messages = msgs[:-1] if len(msgs) > 0 else [] - - def _clean_text(t: str) -> str: - if not t: - return "" - txt = str(t) - for kw in ["關懷模式", "我在這裡陪你", "說「我沒事了」", "退出關懷模式"]: - txt = txt.replace(kw, "") - return txt.strip() - - for msg in historical_messages: - content = msg.get("content") - if isinstance(content, dict): - content = content.get("message") or content.get("text") or str(content) - elif not isinstance(content, str): - content = str(content) if content else "" - - # 過濾掉錯誤訊息(避免污染上下文) - if "抱歉,生成回應時遇到問題" in content or "請重試" in content: - continue - - content = _clean_text(content) - if not content: - continue - - chat_history.append({ - "role": msg.get("sender"), - "content": content - }) - - logger.debug(f"📚 載入 {len(chat_history)} 條歷史對話(messages 集合)") - except Exception as e: - logger.warning(f"從DB加載對話歷史失敗: {e}") + try: + msgs = await history_task + historical_messages = msgs[:-1] if len(msgs) > 0 else [] + + def _clean_text(t: str) -> str: + if not t: return "" + txt = str(t) + for kw in ["關懷模式", "我在這裡陪你", "說「我沒事了」", "退出關懷模式"]: + txt = txt.replace(kw, "") + return txt.strip() + + for msg in historical_messages: + content = msg.get("content") + if isinstance(content, dict): + content = content.get("message") or content.get("text") or str(content) + elif not isinstance(content, str): + content = str(content) if content else "" + if "抱歉,生成回應時遇到問題" in content or "請重試" in content: + continue + content = _clean_text(content) + if content: + chat_history.append({"role": msg.get("sender"), "content": content}) + logger.debug(f"📚 載入 {len(chat_history)} 條歷史對話") + except Exception as e: + logger.warning(f"從DB加載對話歷史失敗: {e}") - # 載入長期記憶 - # 關懷模式不帶長期記憶,避免噪音 memory_context = "" - if user_id and not use_care_mode: + if memory_task: try: - from core.memory_system import memory_system - context_tags: List[str] = [] - if use_care_mode: - context_tags.append("care_mode") - if care_emotion: - context_tags.append(str(care_emotion)) - relevant_memories = await memory_system.get_relevant_memories( - user_id=user_id, - current_message=user_message, - max_memories=5, - context_tags=context_tags or None, - ) + relevant_memories = await memory_task if relevant_memories: + from core.memory_system import memory_system memory_context = memory_system.format_memories_for_context(relevant_memories) logger.info(f"📚 載入 {len(relevant_memories)} 條相關記憶") except Exception as e: logger.warning(f"載入記憶失敗: {e}") - # 讀取環境現況(僅組裝,不外呼) ctx: Dict[str, Any] = dict(env_context or {}) - if not ctx and db_available and user_id: + if env_task: try: - env_res = await get_user_env_current(user_id) + env_res = await env_task if env_res.get("success"): ctx = env_res.get("context") or {} except Exception as e: logger.debug(f"讀取環境現況失敗: {e}") - env_context_text = _format_env_context(ctx) + + env_context_text = _build_environment_context_text(ctx) time_context_text = _format_time_context(ctx.get("tz") if ctx else None) emotion_context_text = _format_emotion_context(emotion_label, care_emotion, use_care_mode) @@ -915,6 +1285,7 @@ async def _generate_response_with_chat_db( care_emotion=care_emotion, user_name=user_name, language=language, + is_first_care=is_first_care, ) messages_to_send = _compose_messages_with_context( @@ -929,6 +1300,7 @@ async def _generate_response_with_chat_db( chat_id=chat_id, use_care_mode=use_care_mode, care_emotion=care_emotion, + tool_context=tool_context, ) ai_response = await generate_response_async( messages_to_send, @@ -937,16 +1309,16 @@ async def _generate_response_with_chat_db( response_format=response_format, use_structured_outputs=use_structured_outputs, response_schema=response_schema, - max_tokens=2000 if use_care_mode else None, # 關懷模式 2000 tokens(gpt-5-nano reasoning + 實際輸出) + max_tokens=2000 if use_care_mode else None, # 關懷模式保留較大輸出空間 reasoning_effort=reasoning_effort or ("minimal" if use_care_mode else "low"), # 2025 最佳實踐:關懷模式 minimal,一般對話 low + stream=stream, + on_chunk=on_chunk, + expected_language=language, ) - # 保存AI回應到DB + # 非同步保存 AI 回應 if db_available: - try: - await save_chat_message(chat_id, "assistant", ai_response) - except Exception as e: - logger.warning(f"保存AI回應到DB失敗: {e}") + asyncio.create_task(save_chat_message(chat_id, "assistant", ai_response)) return ai_response @@ -971,6 +1343,7 @@ async def _generate_response_with_chat_db( emotion_label=emotion_label, env_context=env_context, language=language, + tool_context=tool_context, ) @@ -991,6 +1364,10 @@ async def _generate_response_with_global_history( emotion_label: Optional[str] = None, env_context: Optional[Dict[str, Any]] = None, language: Optional[str] = None, + stream: bool = False, + on_chunk: Optional[Any] = None, + tool_context: str = "", + is_first_care: bool = False, ): """使用全局歷史的回退實現(向後兼容)""" try: @@ -1001,7 +1378,8 @@ async def _generate_response_with_global_history( use_care_mode=use_care_mode, care_emotion=care_emotion, user_name=user_name, - language=language # 參數保留但不使用,GPT 自動判斷語言 + language=language, # 參數保留但不使用,GPT 自動判斷語言 + is_first_care=is_first_care, ) messages.insert(0, {"role": "system", "content": system_prompt}) user_messages = [msg for msg in messages if msg.get("role") == "user"] @@ -1015,8 +1393,11 @@ async def _generate_response_with_global_history( response_format=response_format, use_structured_outputs=use_structured_outputs, response_schema=response_schema, - max_tokens=2000 if use_care_mode else None, # 關懷模式 2000 tokens(gpt-5-nano reasoning + 實際輸出) + max_tokens=2000 if use_care_mode else None, # 關懷模式保留較大輸出空間 reasoning_effort=reasoning_effort or ("minimal" if use_care_mode else "low"), # 2025 最佳實踐:關懷模式 minimal,一般對話 low + stream=stream, + on_chunk=on_chunk, + expected_language=language, ) if user_id in conversation_history: conversation_history[user_id].append({"role": "assistant", "content": ai_response}) @@ -1043,7 +1424,7 @@ async def _generate_response_with_global_history( ctx = env_res.get("context") or {} except Exception as ex: logger.debug(f"讀取環境現況失敗: {ex}") - env_context_text = _format_env_context(ctx) + env_context_text = _build_environment_context_text(ctx) time_context_text = _format_time_context(ctx.get("tz") if ctx else None) emotion_context_text = _format_emotion_context(emotion_label, care_emotion, use_care_mode) @@ -1052,6 +1433,7 @@ async def _generate_response_with_global_history( care_emotion=care_emotion, user_name=user_name, language=language, + is_first_care=is_first_care, ) # 關懷模式不帶長期記憶 @@ -1087,6 +1469,7 @@ async def _generate_response_with_global_history( chat_id=None, use_care_mode=use_care_mode, care_emotion=care_emotion, + tool_context=tool_context, ) ai_response = await generate_response_async( messages_to_send, @@ -1095,8 +1478,11 @@ async def _generate_response_with_global_history( response_format=response_format, use_structured_outputs=use_structured_outputs, response_schema=response_schema, - max_tokens=2000 if use_care_mode else None, # 關懷模式 2000 tokens(gpt-5-nano reasoning + 實際輸出) + max_tokens=2000 if use_care_mode else None, # 關懷模式保留較大輸出空間 reasoning_effort=reasoning_effort or ("minimal" if use_care_mode else "low"), # 2025 最佳實踐:關懷模式 minimal,一般對話 low + stream=stream, + on_chunk=on_chunk, + expected_language=language, ) conversation_history[user_id].append({"role": "assistant", "content": ai_response}) if len(conversation_history[user_id]) > 50: @@ -1114,7 +1500,7 @@ async def generate_response_with_tools( messages: List[Dict[str, str]], tools: List[Dict[str, Any]], user_id: str = "default", - model: str = "gpt-5-nano", + model: Optional[str] = None, reasoning_effort: Optional[str] = None, tool_choice: str = "auto", ) -> Dict[str, Any]: @@ -1134,6 +1520,7 @@ async def generate_response_with_tools( Returns: 包含 tool_calls 和 content 的字典 """ + model = model or settings.OPENAI_MODEL openai_client = _get_client() if openai_client is None: logger.error("OpenAI 客戶端不可用") @@ -1142,6 +1529,46 @@ async def generate_response_with_tools( try: start_time = time.time() loop = asyncio.get_event_loop() + + if _should_use_responses(model): + request_kwargs = responses_runtime.build_payload_from_messages( + messages=messages, + model=model, + tools=tools, + reasoning_effort=reasoning_effort, + max_output_tokens=1000, + tool_choice=tool_choice, + ) + + logger.info(f"🔧 Responses Function Calling 請求: {len(tools)} 個工具, tool_choice={tool_choice}") + try: + responses_client = _client_with_timeout(openai_client, OPENAI_RESPONSES_TIMEOUT) + response = await asyncio.wait_for( + loop.run_in_executor( + None, + lambda: responses_client.responses.create(**request_kwargs), + ), + timeout=_responses_outer_timeout(), + ) + except Exception as exc: + if _is_transient_upstream_error(exc): + logger.warning("Responses Function Calling failed, falling back to Chat Completions: %s", exc) + else: + raise + + if "response" in locals(): + elapsed_time = time.time() - start_time + logger.info(f"⏱️ Responses Function Calling 完成,耗時: {elapsed_time:.2f}秒") + + result = { + "content": responses_runtime.extract_output_text(response), + "tool_calls": responses_runtime.extract_function_calls(response), + } + if result["tool_calls"]: + logger.info(f"✅ Responses 選擇了 {len(result['tool_calls'])} 個工具") + else: + logger.info("💬 Responses 未選擇任何工具(一般聊天)") + return result request_kwargs = { "model": model, @@ -1152,7 +1579,8 @@ async def generate_response_with_tools( } # 加入 reasoning_effort 控制 - if reasoning_effort: + reasoning_models = model.startswith("o1") or model.startswith("gpt-5") + if reasoning_effort and reasoning_models: request_kwargs["reasoning_effort"] = reasoning_effort logger.info(f"🧠 Function Calling 推理強度: {reasoning_effort}") @@ -1206,9 +1634,22 @@ async def generate_response_with_tools( return result - except asyncio.TimeoutError: - logger.error("Function Calling 請求超時") - return {"content": "", "tool_calls": []} + except asyncio.TimeoutError as e: + logger.error(f"❌ Function Calling 請求超時 (Model: {model})") + raise RuntimeError(f"AI 服務超時 ({model})") from e except Exception as e: - logger.error(f"Function Calling 失敗: {e}") - return {"content": "", "tool_calls": []} + error_msg = str(e) + logger.error(f"❌ Function Calling 失敗 (Model: {model})") + if "503" in error_msg or "Service temporarily unavailable" in error_msg: + logger.error("👉 原因:API 服務暫時不可用,或該模型目前處於高負載/維護中。") + logger.error(f"👉 建議:請嘗試在後台切換至其他模型(您目前使用的是 {model})。") + elif "429" in error_msg or "Too Many Requests" in error_msg: + logger.error("👉 原因:請求頻率過高或 API 額度已耗盡 (429)。") + elif "401" in error_msg or "Unauthorized" in error_msg: + logger.error("👉 原因:API Key 無效或未授權 (401)。") + elif "404" in error_msg: + logger.error(f"👉 原因:找不到該模型 (404)。請確認 {model} 是一個有效的模型名稱。") + else: + logger.error(f"👉 原始錯誤:{e}") + + raise RuntimeError(f"AI 服務異常 ({model}): {e}") from e diff --git a/services/batch_processor.py b/services/batch_processor.py index 7b33201c14fcb5e8eec88bf2b6ac5da09d0bc1ea..96846aef6a5e1ceb56eb4a06c3b52fd2507d2101 100644 --- a/services/batch_processor.py +++ b/services/batch_processor.py @@ -21,6 +21,7 @@ from datetime import datetime from pathlib import Path from openai import OpenAI from dotenv import load_dotenv +from core.config import settings load_dotenv() @@ -235,7 +236,7 @@ class BatchProcessor: async def create_memory_summary_batch( self, user_memories: Dict[str, List[str]], - model: str = "gpt-5-nano" + model: Optional[str] = None, ) -> str: """ 創建記憶摘要批次任務 @@ -247,6 +248,7 @@ class BatchProcessor: Returns: batch_id """ + model = model or settings.OPENAI_MODEL requests = [] for user_id, memories in user_memories.items(): @@ -284,7 +286,7 @@ class BatchProcessor: async def create_health_report_batch( self, user_health_data: Dict[str, Dict[str, Any]], - model: str = "gpt-5-nano" + model: Optional[str] = None, ) -> str: """ 創建健康報告批次任務 @@ -296,6 +298,7 @@ class BatchProcessor: Returns: batch_id """ + model = model or settings.OPENAI_MODEL requests = [] for user_id, health_data in user_health_data.items(): diff --git a/services/realtime_stt_service.py b/services/realtime_stt_service.py index 2a89f85c22b9132bcc8159958aa2ae97e5ee0aa3..e05a316cdd25c4beb22294967ef02f57e0df5678 100644 --- a/services/realtime_stt_service.py +++ b/services/realtime_stt_service.py @@ -1,368 +1,372 @@ """ -OpenAI Realtime API - 即時語音轉文字服務 -使用 WebSocket 進行低延遲串流轉錄 +Google Cloud Speech-to-Text v2 串流辨識(gRPC StreamingRecognize)。 -支援語言:中文(zh)、英文(en)、印尼文(id)、日文(ja)、越南文(vi) +注意:語音 GCP(STT/TTS 所屬專案,例如 supervisor-project)與 Firebase、 +Google OAuth 登入是不同脈絡——專案 ID、API Key、服務帳戶請勿與 Firestore 混用。 +STT 串流僅支援 gRPC + OAuth(服務帳戶);API Key 僅供 TTS REST 等用途。 + +音訊限制見官方文件:每則 StreamingRecognize 訊息(含首則設定)上限 25 KB。 +前端送 LINEAR16 mono PCM;依 sample_rate 設定 explicit decoding。 """ -import os -import json +from __future__ import annotations + import asyncio import logging -from typing import Optional, Callable, Dict, Any, Literal -import websockets +import queue +import threading +from typing import Any, Callable, Coroutine, List, Optional + from dotenv import load_dotenv +from google.oauth2 import credentials as oauth2_credentials +from google.oauth2 import service_account + +from core.config import settings load_dotenv() logger = logging.getLogger("services.realtime_stt") -OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") -if not OPENAI_API_KEY: - logger.warning("⚠️ OPENAI_API_KEY 未設置") +# 官方上限 25 KB;保留餘量避免邊界錯誤 +_MAX_STREAMING_BYTES = 24 * 1024 +# 即時串流:每累積 3200 bytes(~100ms @ 16kHz 16-bit mono)就送出一次,減少初始延遲 +_FLUSH_THRESHOLD_BYTES = 3200 -# OpenAI Realtime API WebSocket URL -REALTIME_API_URL = "wss://api.openai.com/v1/realtime?intent=transcription" - -# 支援的語言列表 -SupportedLanguage = Literal["zh", "en", "id", "ja", "vi"] SUPPORTED_LANGUAGES = { - "zh": "中文", - "en": "English", - "id": "Bahasa Indonesia", - "ja": "日本語", - "vi": "Tiếng Việt" + "zh": ["cmn-Hant-TW", "cmn-Hans-CN", "yue-Hant-HK"], + "zh-TW": ["cmn-Hant-TW"], + "zh-CN": ["cmn-Hans-CN"], + "en": ["en-US", "en-GB"], + "ja": ["ja-JP"], + "ko": ["ko-KR"], + "id": ["id-ID"], + "vi": ["vi-VN"], + "th": ["th-TH"], + "fr": ["fr-FR"], + "de": ["de-DE"], + "es": ["es-ES", "es-US"], } +DEFAULT_AUTO_LANGUAGE_CODES = ["cmn-Hant-TW", "en-US", "ja-JP"] -class RealtimeSTTService: - """OpenAI Realtime API 即時語音轉文字服務""" - def __init__(self): - self.api_key = OPENAI_API_KEY - self.ws: Optional[websockets.WebSocketClientProtocol] = None - self.is_connected = False - self._receive_task: Optional[asyncio.Task] = None - self.current_language: str = "zh" - - def _build_language_prompt(self, language: Optional[str] = None) -> Optional[str]: - """ - 建立語言提示 - - 注意:不使用具體詞彙(如「你好」「Hello」),避免 Whisper 在靜音或 - 低音量時產生幻覺,將 prompt 中的文字當作轉錄結果輸出。 - - Args: - language: 語言代碼(zh/en/id/ja/vi)或 None(自動檢測) - - Returns: - 語言提示字串,或 None(不使用 prompt) - """ - # 不使用 prompt,完全依賴 language 參數和音頻內容 - # 這樣可以避免 Whisper 幻覺出 prompt 中的文字 - return None - - def _validate_language(self, language: str) -> Optional[str]: - """ - 驗證並正規化語言代碼 - - Args: - language: 語言代碼(或 'auto' 表示自動檢測) - - Returns: - 正規化後的語言代碼,或 None(自動檢測) - """ - lang = language.lower().strip() - - # 自動檢測模式 - if lang in ('auto', 'detect', ''): - logger.info("🌐 啟用自動語言檢測") - return None - - if lang in SUPPORTED_LANGUAGES: - return lang - - # 嘗試從完整語言名稱匹配 - for code, name in SUPPORTED_LANGUAGES.items(): - if name.lower() == lang.lower(): - return code - - # 不支援的語言,使用自動檢測 - logger.warning(f"⚠️ 不支援的語言 '{language}',改用自動檢測") - return None +def _normalize_v2_model(model: str) -> str: + m = (model or "long").strip().lower() + if m in ("latest_long", "default"): + return "long" + if m in ("latest_short",): + return "short" + return (model or "long").strip() - async def connect( - self, - on_transcript_delta: Optional[Callable[[str], None]] = None, - on_transcript_done: Optional[Callable[[str], None]] = None, - on_vad_committed: Optional[Callable[[str], None]] = None, - model: str = "gpt-4o-mini-transcribe", - language: str = "zh", - ) -> bool: - """ - 建立與 OpenAI Realtime API 的 WebSocket 連線 - - Args: - on_transcript_delta: 接收部分轉錄結果的回調函數 - on_transcript_done: 接收完整轉錄結果的回調函數 - on_vad_committed: VAD 偵測到語音結束的回調函數 - model: 使用的模型(gpt-4o-transcribe 或 gpt-4o-mini-transcribe) - language: 語言代碼(zh/en/id/ja/vi) - - Returns: - bool: 連線是否成功 - """ - if not self.api_key: - logger.error("❌ OpenAI API Key 未設置") - return False - # 驗證語言 - validated_language = self._validate_language(language) - self.current_language = validated_language or "auto" - - if validated_language: - language_name = SUPPORTED_LANGUAGES.get(validated_language, validated_language) - logger.info(f"🌐 語言設定: {language_name} ({validated_language})") - else: - logger.info("🌐 語言設定: 自動檢測(支援 zh/en/id/ja/vi)") - - try: - logger.info(f"🔌 連接到 OpenAI Realtime API: {REALTIME_API_URL}") - - # 建立 WebSocket 連線(使用 API Key 認證) - self.ws = await websockets.connect( - REALTIME_API_URL, - additional_headers={ - "Authorization": f"Bearer {self.api_key}", - "OpenAI-Beta": "realtime=v1" - } - ) - - self.is_connected = True - logger.info("✅ 已連接到 OpenAI Realtime API") - - # 發送 session 配置(正確格式:需要 session 物件包裹) - # 不使用 prompt 參數,避免 Whisper 幻覺 - transcription_config = { - "model": model, - } - - # 如果指定了語言,加入 language 參數 - if validated_language: - transcription_config["language"] = validated_language - logger.info(f"🌐 Whisper 語言設定: {validated_language}") - - session_config = { - "type": "transcription_session.update", - "session": { - "input_audio_format": "pcm16", - "input_audio_transcription": transcription_config, - "turn_detection": { - "type": "server_vad", - "threshold": 0.5, - "prefix_padding_ms": 300, - "silence_duration_ms": 500 - }, - "input_audio_noise_reduction": { - "type": "near_field" - } - } - } - - await self.ws.send(json.dumps(session_config)) - logger.info("📤 已發送 session 配置(含語言引導提示)") - - # 啟動接收事件的背景任務 - self._receive_task = asyncio.create_task( - self._receive_events( - on_transcript_delta, - on_transcript_done, - on_vad_committed +class RealtimeSTTService: + """Speech-to-Text v2 雙向串流;需 OAuth(服務帳戶或有效 access token),不支援僅 API Key。""" + + def __init__(self) -> None: + self.location = settings.GOOGLE_STT_LOCATION + self.recognizer_id = settings.GOOGLE_STT_RECOGNIZER_ID + self.api_key = settings.GOOGLE_SPEECH_API_KEY + self.project_id = "" + self._grpc_credentials = None + self._reload_speech_identity() + self.current_language = "auto" + self.sample_rate = 16000 + self.model = "long" + self.is_connected = False + self._audio_buffer = bytearray() + self._pending_send = bytearray() + self._final_transcript: Optional[str] = None + self._final_transcript_event = asyncio.Event() + self._on_transcript_delta: Optional[Callable[[str], Any]] = None + self._on_transcript_done: Optional[Callable[[str], Any]] = None + self._on_status: Optional[Callable[[str], Any]] = None + self._loop: Optional[asyncio.AbstractEventLoop] = None + self._audio_thread_queue: Optional[queue.Queue] = None + self._grpc_thread: Optional[threading.Thread] = None + self._final_segments: List[str] = [] + self._speech_account_source: str = "none" + + def _language_codes(self, language: str) -> list[str]: + lang = (language or "auto").strip() + if lang in {"auto", "detect", ""}: + configured = [ + item.strip() + for item in settings.GOOGLE_STT_AUTO_LANGUAGE_CODES.split(",") + if item.strip() + ] + return (configured or DEFAULT_AUTO_LANGUAGE_CODES)[:3] + return SUPPORTED_LANGUAGES.get(lang, DEFAULT_AUTO_LANGUAGE_CODES)[:3] + + def _recognizer_name(self) -> str: + return ( + f"projects/{self.project_id}/locations/{self.location}" + f"/recognizers/{self.recognizer_id}" + ) + + def _validate_config(self) -> Optional[str]: + if self._grpc_credentials is None: + if self.api_key: + return ( + "STT 串流需 gRPC + OAuth(語音專案請設 GOOGLE_SPEECH_* 服務帳戶);" + "僅 API Key 無法用於 Speech v2 streaming(API Key 可給 TTS REST)" ) + return ( + "Google STT 串流需要 OAuth 憑證:請設定 GOOGLE_SPEECH_SERVICE_ACCOUNT_PATH " + "(或 *_JSON / *_BASE64)指向語音 GCP 之服務帳戶" ) + if not self.project_id: + return ( + "缺少 Speech API 所屬 GCP 專案 ID:請設定 GOOGLE_SPEECH_PROJECT_ID 或 " + "GOOGLE_CLOUD_PROJECT_ID(或於語音專用服務帳戶 JSON 內提供 project_id)" + ) + speech_only_pid = settings.GOOGLE_SPEECH_PROJECT_ID.strip() + if speech_only_pid and self._speech_account_source == "firebase": + fb = settings.FIREBASE_PROJECT_ID.strip() + if fb and speech_only_pid != fb: + return ( + "GOOGLE_SPEECH_PROJECT_ID 指向語音 GCP,但目前 OAuth 仍為 Firebase 服務帳戶;" + "請補上 GOOGLE_SPEECH_* 憑證(與語音專案一致),或移除 GOOGLE_SPEECH_PROJECT_ID" + ) + return None - return True - - except Exception as e: - logger.error(f"❌ 連接失敗: {e}") - self.is_connected = False - return False - - async def _receive_events( - self, - on_transcript_delta: Optional[Callable], - on_transcript_done: Optional[Callable], - on_vad_committed: Optional[Callable] - ): - """ - 接收並處理來自 OpenAI Realtime API 的事件 - - Args: - on_transcript_delta: 部分轉錄回調 - on_transcript_done: 完整轉錄回調 - on_vad_committed: VAD 提交回調 - """ - try: - while self.is_connected and self.ws: - try: - message = await self.ws.recv() - event = json.loads(message) - event_type = event.get("type") - - logger.debug(f"📩 收到事件: {event_type}") - - # 處理使用者語音的部分轉錄結果(即時串流) - if event_type == "conversation.item.input_audio_transcription.delta": - delta_text = event.get("delta", "") - if on_transcript_delta and delta_text: - await self._safe_callback(on_transcript_delta, delta_text) - - # 處理完整轉錄結果(語音段結束) - elif event_type == "conversation.item.input_audio_transcription.completed": - full_text = event.get("transcript", "") - if on_transcript_done and full_text: - await self._safe_callback(on_transcript_done, full_text) - - # 處理 VAD 提交事件(語音段結束) - elif event_type == "input_audio_buffer.committed": - item_id = event.get("item_id", "") - if on_vad_committed: - await self._safe_callback(on_vad_committed, item_id) - - # 處理錯誤事件 - elif event_type == "error": - error_msg = event.get("error", {}) - logger.error(f"❌ OpenAI API 錯誤: {error_msg}") - - except websockets.exceptions.ConnectionClosed: - logger.warning("⚠️ WebSocket 連線已關閉") - break - except json.JSONDecodeError as e: - logger.error(f"❌ JSON 解析錯誤: {e}") - except Exception as e: - logger.error(f"❌ 接收事件失敗: {e}") - - except Exception as e: - logger.error(f"❌ 事件接收循環失敗: {e}") - finally: - self.is_connected = False + def _reload_speech_identity(self) -> None: + """從 .env 載入語音專用憑證(優先 GOOGLE_SPEECH_*,與 Firebase 分離)。""" + self._speech_account_source = "none" + info, source = settings.resolve_speech_service_account_info() + cred_pid = (info or {}).get("project_id") if info else None + self.project_id = settings.get_google_speech_project_id( + str(cred_pid) if cred_pid else None, + ) + + if info is not None: + try: + self._grpc_credentials = service_account.Credentials.from_service_account_info( + info, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + self._speech_account_source = source + if source == "speech": + logger.info("Google STT 使用 GOOGLE_SPEECH_* 服務帳戶(與 Firebase 分離)") + return + except Exception as exc: + logger.warning("Google STT service account 載入失敗: %s", exc) + self._grpc_credentials = None + + static_token = settings.GOOGLE_STT_ACCESS_TOKEN.strip() + if static_token: + logger.warning("Google STT using static access token for gRPC; prefer service account") + self._grpc_credentials = oauth2_credentials.Credentials(token=static_token) + self._speech_account_source = "token" + else: + self._grpc_credentials = None + self._speech_account_source = "none" - async def _safe_callback(self, callback: Callable, *args): - """安全地執行回調函數(支援同步和異步)""" + async def _safe_callback(self, callback: Optional[Callable], *args) -> None: + if not callback: + return try: if asyncio.iscoroutinefunction(callback): await callback(*args) else: callback(*args) - except Exception as e: - logger.error(f"❌ 回調函數執行失敗: {e}") - - async def send_audio_chunk(self, audio_data: bytes) -> bool: - """ - 發送音頻 chunk 到 OpenAI Realtime API - - Args: - audio_data: PCM16 格式的音頻數據(需 Base64 編碼) - - Returns: - bool: 是否發送成功 - """ - if not self.is_connected or not self.ws: - logger.warning("⚠️ WebSocket 未連接,無法發送音頻") - return False + except Exception as exc: + logger.error("Google STT callback failed: %s", exc) + def _schedule_coroutine(self, coro: Coroutine[Any, Any, None]) -> None: + if self._loop is None: + return try: - import base64 - - # 將音頻數據編碼為 Base64 - audio_base64 = base64.b64encode(audio_data).decode('utf-8') - - # 發送音頻 chunk - message = { - "type": "input_audio_buffer.append", - "audio": audio_base64 - } - - await self.ws.send(json.dumps(message)) - logger.debug(f"📤 已發送音頻 chunk: {len(audio_data)} bytes") - return True - - except Exception as e: - logger.error(f"❌ 發送音頻失敗: {e}") - return False - - async def commit_audio(self) -> bool: - """ - 手動提交音頻緩衝區(當不使用 Server VAD 時) - - Returns: - bool: 是否提交成功 - """ - if not self.is_connected or not self.ws: - logger.warning("⚠️ WebSocket 未連接,無法提交音頻") - return False + asyncio.run_coroutine_threadsafe(coro, self._loop) + except RuntimeError: + logger.warning("Google STT event loop unavailable, drop async update") + + def _grpc_worker(self) -> None: + from google.cloud.speech_v2 import SpeechClient + from google.cloud.speech_v2.types import cloud_speech as cloud_speech_types + + assert self._audio_thread_queue is not None + + client = SpeechClient(credentials=self._grpc_credentials) + language_codes = self._language_codes(self.current_language) + recognition_config = cloud_speech_types.RecognitionConfig( + explicit_decoding_config=cloud_speech_types.ExplicitDecodingConfig( + encoding=cloud_speech_types.ExplicitDecodingConfig.AudioEncoding.LINEAR16, + sample_rate_hertz=int(self.sample_rate), + audio_channel_count=1, + ), + language_codes=language_codes, + model=_normalize_v2_model(self.model), + ) + streaming_config = cloud_speech_types.StreamingRecognitionConfig( + config=recognition_config, + streaming_features=cloud_speech_types.StreamingRecognitionFeatures( + interim_results=True, + enable_voice_activity_events=True, + ), + ) + config_request = cloud_speech_types.StreamingRecognizeRequest( + recognizer=self._recognizer_name(), + streaming_config=streaming_config, + ) + + def requests_iter(): + yield config_request + while True: + chunk = self._audio_thread_queue.get() + if chunk is None: + return + yield cloud_speech_types.StreamingRecognizeRequest(audio=chunk) try: - message = { - "type": "input_audio_buffer.commit" - } + for response in client.streaming_recognize(requests=requests_iter()): + ev = response.speech_event_type + if ev == cloud_speech_types.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_BEGIN: + self._schedule_coroutine(self._safe_callback(self._on_status, "receiving_audio")) + elif ev == cloud_speech_types.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_END: + self._schedule_coroutine(self._safe_callback(self._on_status, "speech_stopped")) + + for result in response.results: + if not result.alternatives: + continue + text = (result.alternatives[0].transcript or "").strip() + if not text: + continue + if result.is_final: + self._final_segments.append(text) + combined = " ".join(self._final_segments) + preview = f"{combined} {text}".strip() if not result.is_final and combined else text + out = combined if result.is_final else preview + self._schedule_coroutine(self._safe_callback(self._on_transcript_delta, out)) + except Exception as exc: + logger.error("Google STT streaming_recognize failed: %s", exc) + self._schedule_coroutine(self._safe_callback(self._on_status, "error")) + finally: + done_text = " ".join(self._final_segments).strip() + self._schedule_coroutine(self._finalize_stream_session(done_text)) - await self.ws.send(json.dumps(message)) - logger.info("📤 已手動提交音頻緩衝區") - return True + async def _finalize_stream_session(self, text: str) -> None: + if text and not self._final_transcript_event.is_set(): + self._final_transcript = text + await self._safe_callback(self._on_transcript_done, text) + self._final_transcript_event.set() - except Exception as e: - logger.error(f"❌提交音頻失敗: {e}") + async def connect( + self, + on_transcript_delta: Optional[Callable[[str], Any]] = None, + on_transcript_done: Optional[Callable[[str], Any]] = None, + on_vad_committed: Optional[Callable[[str], Any]] = None, + model: str = "latest_long", + language: str = "auto", + sample_rate: int = 16000, + ) -> bool: + self._reload_speech_identity() + error = self._validate_config() + if error: + logger.error("Google STT 初始化失敗: %s", error) return False - async def disconnect(self): - """關閉 WebSocket 連線""" - if self.ws: - logger.info("🔌 關閉 OpenAI Realtime API 連線") - self.is_connected = False + self.model = model or "latest_long" + self.current_language = language or "auto" + self.sample_rate = int(sample_rate or 16000) + self._audio_buffer.clear() + self._pending_send.clear() + self._final_segments.clear() + self._final_transcript = None + self._final_transcript_event.clear() + self._on_transcript_delta = on_transcript_delta + self._on_transcript_done = on_transcript_done + self._on_status = on_vad_committed + self._loop = asyncio.get_running_loop() + self._audio_thread_queue = queue.Queue() + self.is_connected = True + + self._grpc_thread = threading.Thread(target=self._grpc_worker, name="google-stt-v2-stream", daemon=True) + self._grpc_thread.start() + await self._safe_callback(self._on_status, "speech_started") + return True + + def _enqueue_pcm(self, audio_data: bytes) -> None: + if not audio_data or self._audio_thread_queue is None: + return + self._pending_send.extend(audio_data) + # 達到 flush 閾值(~100ms)就送出,讓 STT 盡快收到音訊,減少初始延遲 + while len(self._pending_send) >= _FLUSH_THRESHOLD_BYTES: + chunk_size = min(len(self._pending_send), _MAX_STREAMING_BYTES) + chunk = bytes(self._pending_send[:chunk_size]) + del self._pending_send[:chunk_size] + self._audio_thread_queue.put(chunk) - # 取消接收任務 - if self._receive_task and not self._receive_task.done(): - self._receive_task.cancel() - try: - await self._receive_task - except asyncio.CancelledError: - pass - # 關閉 WebSocket - await self.ws.close() - self.ws = None + async def send_audio_chunk(self, audio_data: bytes) -> bool: + if not self.is_connected: + logger.warning("Google STT 尚未連線,無法接收音訊") + return False + if audio_data: + self._audio_buffer.extend(audio_data) + self._enqueue_pcm(audio_data) + await self._safe_callback(self._on_status, "receiving_audio") + return True - logger.info("✅ 已斷開連線") + async def commit_audio(self) -> bool: + if not self.is_connected: + logger.warning("Google STT 尚未連線,無法提交音訊") + return False + await self._safe_callback(self._on_status, "speech_stopped") + return True + + def mark_final_transcript(self, transcript: str) -> None: + if transcript: + self._final_transcript = transcript + self._final_transcript_event.set() + + def _close_stream(self) -> None: + if self._audio_thread_queue is not None: + if self._pending_send: + self._audio_thread_queue.put(bytes(self._pending_send)) + self._pending_send.clear() + self._audio_thread_queue.put(None) + + async def wait_for_final_transcript(self, timeout: float = 3.5) -> Optional[str]: + if self._final_transcript: + return self._final_transcript + + await self.commit_audio() + self._close_stream() + if self._grpc_thread and self._grpc_thread.is_alive(): + await asyncio.to_thread(self._grpc_thread.join, timeout) + + for _ in range(5): + if self._final_transcript_event.is_set(): + break + await asyncio.sleep(0) + + if not self._final_transcript_event.is_set(): + text = " ".join(self._final_segments).strip() + if text: + self._final_transcript = text + await self._safe_callback(self._on_transcript_delta, text) + await self._safe_callback(self._on_transcript_done, text) + self._final_transcript_event.set() + + return self._final_transcript + + async def disconnect(self) -> None: + self.is_connected = False + self._close_stream() + if self._grpc_thread and self._grpc_thread.is_alive(): + await asyncio.to_thread(self._grpc_thread.join, 2.0) + await self._safe_callback(self._on_status, "disconnected") -# 全域單例 realtime_stt_service = RealtimeSTTService() async def create_realtime_session( on_transcript_delta: Optional[Callable] = None, on_transcript_done: Optional[Callable] = None, - model: str = "gpt-4o-mini-transcribe", - language: str = "zh" + model: str = "latest_long", + language: str = "auto", ) -> RealtimeSTTService: - """ - 便捷函數:建立 Realtime STT 會話 - - Args: - on_transcript_delta: 部分轉錄回調 - on_transcript_done: 完整轉錄回調 - model: 使用的模型 - language: 語言代碼 - - Returns: - RealtimeSTTService: 已連線的服務實例 - """ service = RealtimeSTTService() - await service.connect( - on_transcript_delta=on_transcript_delta, - on_transcript_done=on_transcript_done, - model=model, - language=language - ) + await service.connect(on_transcript_delta, on_transcript_done, None, model=model, language=language) return service diff --git a/services/tts_service.py b/services/tts_service.py index 483f936657c31234092b80346391df71d0dc9e8d..2933575df6a4ae5d62b25019433ab8e0773b17cb 100644 --- a/services/tts_service.py +++ b/services/tts_service.py @@ -1,381 +1,549 @@ """ -OpenAI TTS 服務(2025 最佳實踐版) -使用 AsyncOpenAI + Streaming 進行低延遲文字轉語音 - -特色: -- 異步 API(AsyncOpenAI) -- 串流播放(邊生成邊播放,降低 TTFB) -- 支援情緒指令(gpt-4o-mini-tts) -- 多語言支援(自動檢測:中文、英文、印尼文、日文、越南文) +Google Cloud Text-to-Speech service. + +與 Firebase、Google OAuth 分開:使用「語音 GCP」的 API Key(GOOGLE_TTS_API_KEY 等), +與 STT gRPC 串流所用之服務帳戶 OAuth 不同。 + +OpenAI TTS support has been removed intentionally. API keys are read from +environment variables only; no key is embedded in source code. """ -import os +import base64 import logging -import asyncio -from typing import Optional, Dict, Any, Literal, AsyncIterator -from openai import AsyncOpenAI -from openai.helpers import LocalAudioPlayer -from dotenv import load_dotenv +import re +from typing import Any, AsyncIterator, Dict, List, Optional, Tuple -# 統一日誌配置 -from core.logging import get_logger -logger = get_logger("services.tts") +import aiohttp +from dotenv import load_dotenv +from google.oauth2 import service_account -# 統一配置管理 from core.config import settings load_dotenv() -# 支援的 TTS 聲音(2025 新增:coral, sage, verse) -VoiceType = Literal["coral", "sage", "verse", "alloy", "echo", "fable", "onyx", "nova", "shimmer"] +logger = logging.getLogger("services.tts") +_CJK_RE = re.compile(r"[\u3400-\u9fff]") + +GOOGLE_TTS_ENDPOINT = "https://texttospeech.googleapis.com/v1/text:synthesize" -# 支援的音頻格式 -AudioFormat = Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] +EMOTION_RATE = { + "neutral": 1.0, + "happy": 1.05, + "sad": 0.92, + "angry": 0.96, + "fear": 0.94, + "surprise": 1.08, +} -# 情緒指令預設模板 -EMOTION_INSTRUCTIONS = { - "neutral": "用平穩、自然的語氣說話", - "happy": "用開心、愉悅的語氣說話", - "sad": "用溫柔、安慰的語氣說話", - "angry": "用冷靜、理性的語氣說話", - "fear": "用溫暖、鼓勵的語氣說話", - "surprise": "用輕快、活潑的語氣說話" +VOICE_ALIASES = { + "coral": ("cmn-TW", "cmn-TW-Wavenet-A"), + "nova": ("cmn-TW", "cmn-TW-Wavenet-A"), + "alloy": ("en-US", "en-US-Neural2-F"), + "echo": ("en-US", "en-US-Neural2-D"), + "fable": ("en-US", "en-US-Neural2-F"), + "onyx": ("en-US", "en-US-Neural2-J"), + "shimmer": ("en-US", "en-US-Neural2-H"), + "zh-tw": ("cmn-TW", "cmn-TW-Wavenet-A"), + "zh-cn": ("cmn-CN", "cmn-CN-Wavenet-A"), + "en-us": ("en-US", "en-US-Neural2-F"), + "ja-jp": ("ja-JP", "ja-JP-Neural2-B"), + "ko-kr": ("ko-KR", "ko-KR-Neural2-A"), + "id-id": ("id-ID", "id-ID-Wavenet-A"), + "vi-vn": ("vi-VN", "vi-VN-Wavenet-A"), } -# 關懷模式特殊指令 -CARE_MODE_INSTRUCTION = "用溫柔、關懷、陪伴的語氣說話,讓對方感受到被理解和支持" +PERSONA_LANGUAGE_ALIASES = { + "zh": "zh-TW", + "zh-tw": "zh-TW", + "zh-hant": "zh-TW", + "zh-hant-tw": "zh-TW", + "cmn-hant-tw": "zh-TW", + "zh-cn": "zh-CN", + "zh-hans": "zh-CN", + "zh-hans-cn": "zh-CN", + "cmn-hans-cn": "zh-CN", + "en": "en-US", + "en-us": "en-US", + "ja": "ja-JP", + "ja-jp": "ja-JP", + "ko": "ko-KR", + "ko-kr": "ko-KR", + "id": "id-ID", + "id-id": "id-ID", + "vi": "vi-VN", + "vi-vn": "vi-VN", +} +PERSONA_VOICE_MAP = { + "xiaohua": { + "default": ("cmn-CN", "cmn-CN-Chirp3-HD-Gacrux"), + "zh-TW": ("cmn-CN", "cmn-CN-Chirp3-HD-Gacrux"), + "zh-CN": ("cmn-CN", "cmn-CN-Chirp3-HD-Gacrux"), + "en-US": ("en-US", "en-US-Chirp-HD-F"), + "ja-JP": ("ja-JP", "ja-JP-Chirp3-HD-Despina"), + "ko-KR": ("ko-KR", "ko-KR-Chirp3-HD-Despina"), + "id-ID": ("id-ID", "id-ID-Chirp3-HD-Despina"), + "vi-VN": ("vi-VN", "vi-VN-Chirp3-HD-Despina"), + } +} -def get_emotion_instruction(emotion: Optional[str], care_mode: bool = False) -> str: - """ - 根據情緒選擇對應的 TTS instruction +PERSONA_PROMPTS = { + "xiaohua": { + "default": "You are XiaoHua, a warm youthful companion voice. Read like natural conversation, not a formal bulletin. Use short phrasing, gentle smile, light warmth, and graceful brief pauses. Do not sound flat, robotic, or like you are reading citations, links, or metadata aloud.", + "zh-TW": "你是小花。請像面對面說話一樣自然、溫柔、帶一點笑意與陪伴感。句子要短一點、順口一點,停頓乾淨,不要像新聞播報,也不要唸出來源、連結、括號資訊或多餘說明。", + "zh-CN": "你是小花。请像面对面说话一样自然、温柔、带一点笑意与陪伴感。句子要短一点、顺口一点,停顿干净,不要像新闻播报,也不要念出来源、链接、括号信息或多余说明。", + "en-US": "You are XiaoHua. Speak like a warm companion in direct conversation. Keep phrases compact, clear, and human. Do not sound like a formal announcer, and do not read links, source labels, or metadata aloud.", + "ja-JP": "あなたは小花です。対面でやさしく話しかけるように、自然であたたかく、少し笑みを含んだ声で話してください。短く言いやすいフレーズを使い、リンクや出典のような情報は読み上げないでください。", + "ko-KR": "당신은 샤오화입니다. 마주 보고 이야기하듯 자연스럽고 따뜻하게, 은은한 미소가 느껴지는 톤으로 말하세요. 문장은 짧고 부드럽게, 링크나 출처 같은 메타 정보는 읽지 마세요.", + "id-ID": "Kamu adalah XiaoHua. Bicaralah seperti sedang menemani seseorang secara langsung: hangat, alami, lembut, dan sedikit tersenyum dalam suara. Gunakan frasa singkat yang enak didengar, dan jangan membacakan tautan atau label sumber.", + "vi-VN": "Bạn là XiaoHua. Hãy nói như đang trò chuyện trực tiếp với người dùng: tự nhiên, ấm áp, dịu dàng và có chút mỉm cười trong giọng nói. Dùng câu ngắn, dễ nghe, và không đọc liên kết hay nhãn nguồn.", + } +} - Args: - emotion: 情緒標籤(neutral, happy, sad, angry, fear, surprise) - care_mode: 是否為關懷模式 +LANGUAGE_PRONUNCIATION_SUPPORT = { + "cmn-CN": {"PHONETIC_ENCODING_PINYIN"}, + "ja-JP": {"PHONETIC_ENCODING_JAPANESE_YOMIGANA"}, + "en-US": {"PHONETIC_ENCODING_IPA", "PHONETIC_ENCODING_X_SAMPA"}, +} - Returns: - TTS instruction 字串 - """ - # 關懷模式優先 + +def get_emotion_rate(emotion: Optional[str], care_mode: bool = False) -> float: if care_mode: - return CARE_MODE_INSTRUCTION - - # 根據情緒選擇 - if emotion and emotion in EMOTION_INSTRUCTIONS: - return EMOTION_INSTRUCTIONS[emotion] - - # 預設中性語氣 - return EMOTION_INSTRUCTIONS["neutral"] + return 0.92 + return EMOTION_RATE.get(str(emotion or "neutral").lower(), 1.0) class TTSService: - """OpenAI Text-to-Speech 服務(異步版)""" + """Google Text-to-Speech REST service.""" def __init__(self): - self._client: Optional[AsyncOpenAI] = None - self._initialized = False - - @property - def client(self) -> Optional[AsyncOpenAI]: - """延遲初始化 AsyncOpenAI 客戶端""" - if not self._initialized: - api_key = settings.OPENAI_API_KEY - if api_key: - self._client = AsyncOpenAI( - api_key=api_key, - timeout=float(settings.OPENAI_TIMEOUT), - max_retries=3 + self.api_key = settings.GOOGLE_TTS_API_KEY + self._grpc_credentials = None + self._reload_tts_identity() + + def _reload_tts_identity(self) -> None: + info, _source = settings.resolve_speech_service_account_info() + if info is None: + self._grpc_credentials = None + return + try: + self._grpc_credentials = service_account.Credentials.from_service_account_info( + info, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + except Exception as exc: + logger.warning("Google TTS service account 載入失敗: %s", exc) + self._grpc_credentials = None + + self._async_client = None # 延遲初始化 AsyncClient + + def _normalize_language(self, language: Optional[str]) -> Optional[str]: + raw = str(language or "").strip() + if not raw: + return None + normalized = raw.replace("_", "-").lower() + return PERSONA_LANGUAGE_ALIASES.get(normalized, raw.replace("_", "-")) + + def _persona_voice_config(self, persona: Optional[str], language: Optional[str]) -> Optional[Dict[str, str]]: + persona_key = str(persona or "").strip().lower() + if not persona_key: + return None + + persona_map = PERSONA_VOICE_MAP.get(persona_key) + if not persona_map: + return None + + normalized_language = self._normalize_language(language) + language_code, voice_name = persona_map.get( + normalized_language, + persona_map.get("default", (settings.GOOGLE_TTS_LANGUAGE_CODE, settings.GOOGLE_TTS_DEFAULT_VOICE)), + ) + return {"languageCode": language_code, "name": voice_name} + + def _persona_prompt(self, persona: Optional[str], language: Optional[str]) -> Optional[str]: + persona_key = str(persona or "").strip().lower() + if not persona_key: + return None + + persona_prompts = PERSONA_PROMPTS.get(persona_key) + if not persona_prompts: + return None + + normalized_language = self._normalize_language(language) + return persona_prompts.get(normalized_language) or persona_prompts.get("default") + + def _build_custom_pronunciations(self, custom_pronunciations: Optional[List[Dict[str, Any]]], texttospeech_module: Any) -> Optional[Any]: + if not custom_pronunciations: + return None + + phonetic_enum = texttospeech_module.CustomPronunciationParams.pb().DESCRIPTOR.fields_by_name["phonetic_encoding"].enum_type + entries = [] + for item in custom_pronunciations: + phrase = str((item or {}).get("phrase") or "").strip() + pronunciation = str((item or {}).get("pronunciation") or "").strip() + encoding_key = str((item or {}).get("phonetic_encoding") or "").strip().upper() + if not phrase or not pronunciation or not encoding_key: + continue + if encoding_key == "PHONETIC_ENCODING_PINYIN" and not _CJK_RE.search(phrase): + logger.info("略過非中文 phrase 的 PINYIN pronunciation: %s", phrase) + continue + encoding_value = phonetic_enum.values_by_name.get(encoding_key) + if encoding_value is None: + logger.warning("略過不支援的 phonetic_encoding: %s", encoding_key) + continue + entries.append( + texttospeech_module.CustomPronunciationParams( + phrase=phrase, + phonetic_encoding=encoding_value.number, + pronunciation=pronunciation, ) - logger.info("✅ TTS 服務初始化成功(AsyncOpenAI)") + ) + + if not entries: + return None + + return texttospeech_module.CustomPronunciations(pronunciations=entries) + + def _filter_custom_pronunciations_for_language( + self, + custom_pronunciations: Optional[List[Dict[str, Any]]], + language_code: str, + source_text: str, + ) -> Optional[List[Dict[str, Any]]]: + if not custom_pronunciations: + return None + + supported_encodings = LANGUAGE_PRONUNCIATION_SUPPORT.get(language_code, set()) + if not supported_encodings: + return None + + filtered = [] + source = str(source_text or "") + for item in custom_pronunciations: + phrase = str((item or {}).get("phrase") or "").strip() + encoding_key = str((item or {}).get("phonetic_encoding") or "").strip().upper() + if not phrase: + continue + if phrase not in source: + logger.info( + "略過 pronunciation:phrase 不在本次文本中 language=%s phrase=%s", + language_code, + phrase, + ) + continue + if encoding_key in supported_encodings: + filtered.append(item) else: - logger.error("❌ TTS 服務初始化失敗:OPENAI_API_KEY 未設置") - self._initialized = True - return self._client + logger.info( + "略過 pronunciation:language=%s 不支援 encoding=%s", + language_code, + encoding_key or "", + ) + return filtered or None + + def _voice_config(self, voice: str, language: Optional[str] = None, persona: Optional[str] = None) -> Dict[str, str]: + persona_config = self._persona_voice_config(persona, language) + if persona_config: + return persona_config + + key = str(voice or settings.GOOGLE_TTS_DEFAULT_VOICE).strip() + alias = key.lower() + language_code, voice_name = VOICE_ALIASES.get( + alias, + (settings.GOOGLE_TTS_LANGUAGE_CODE, key), + ) + return {"languageCode": language_code, "name": voice_name} + + @staticmethod + def _clean_text_for_tts(text: str) -> str: + """清理文字中的 Markdown 和 Emoji,避免 TTS 截斷或發音異常""" + if not text: + return "" + # 移除 Markdown 語法 (粗體, 斜體, 連結等) + text = re.sub(r'(\*\*|__)(.*?)\1', r'\2', text) + text = re.sub(r'(\*|_)(.*?)\1', r'\2', text) + text = re.sub(r'\[(.*?)\]\(.*?\)', r'\1', text) + text = re.sub(r'#{1,6}\s+', '', text) + text = re.sub(r'`{1,3}.*?`{1,3}', '', text, flags=re.DOTALL) + + # 移除常見 Emoji + # 使用一個簡單的範圍,或者更複雜的 regex + text = re.sub(r'[\U00010000-\U0010ffff]', '', text) + + # 移除多餘空白 + text = re.sub(r'\s+', ' ', text).strip() + return text async def synthesize( self, text: str, - voice: VoiceType = "coral", - model: str = "gpt-4o-mini-tts", + voice: str = "coral", + model: str = "", speed: float = 1.0, instruction: Optional[str] = None, emotion: Optional[str] = None, care_mode: bool = False, - response_format: AudioFormat = "mp3" + response_format: str = "mp3", + language: Optional[str] = None, + persona: Optional[str] = None, + speaking_rate: Optional[float] = None, ) -> Dict[str, Any]: - """ - 使用 OpenAI TTS API 將文字轉語音(非串流版) - - Args: - text: 要轉換的文字 - voice: 聲音類型(coral, sage, verse, alloy, echo, fable, onyx, nova, shimmer) - model: TTS 模型(gpt-4o-mini-tts 或 tts-1-hd) - speed: 語速(0.25 到 4.0) - instruction: 語音指令(手動指定,優先級最高) - emotion: 情緒標籤(自動選擇 instruction) - care_mode: 是否為關懷模式(使用特殊語氣) - response_format: 音頻格式(mp3, opus, aac, flac, wav, pcm) - - Returns: - { - "success": bool, - "audio_data": bytes, - "voice": str, - "format": str, - "error": str (optional) - } - """ - if not self.client: + if not self.api_key: return { "success": False, "audio_data": None, - "error": "OpenAI client 未初始化" + "error": "GOOGLE_TTS_API_KEY 未設定", } - try: - logger.info(f"🔊 開始 TTS 合成,文字長度: {len(text)}, 聲音: {voice}") - - # 調用 OpenAI TTS API(2025 最佳實踐:支援情緒指令) - params = { - "model": model, - "voice": voice, - "input": text, - "speed": speed, - "response_format": response_format - } - - # 選擇 instruction(優先級:手動 > 情緒自動選擇) - final_instruction = instruction or get_emotion_instruction(emotion, care_mode) - - # 如果提供情緒指令(gpt-4o-mini-tts 支援) - if final_instruction and model == "gpt-4o-mini-tts": - params["instructions"] = final_instruction - logger.info(f"🎭 TTS 語氣指令: {final_instruction}") - - response = await self.client.audio.speech.create(**params) - - # 獲取音頻數據 - audio_data = response.content - - logger.info(f"✅ TTS 合成成功,音頻大小: {len(audio_data)} bytes") - - return { - "success": True, - "audio_data": audio_data, - "voice": voice, - "format": response_format - } + # 清理文字 + text = self._clean_text_for_tts(text) + if not text: + return {"success": False, "audio_data": None, "error": "文字不可為空"} + + # 稍微提高預設語速 (1.1x) + effective_rate = float(speaking_rate if speaking_rate is not None else speed or 1.1) + speaking_rate = max(0.25, min(4.0, effective_rate * get_emotion_rate(emotion, care_mode))) + audio_encoding = "MP3" if response_format != "wav" else "LINEAR16" + + voice_cfg = self._voice_config(voice, language=language, persona=persona) + + logger.info( + "🎤 TTS 合成請求: text_len=%d, voice=%s, lang=%s, rate=%.2f, format=%s", + len(text), voice_cfg["name"], voice_cfg["languageCode"], speaking_rate, audio_encoding + ) + + payload = { + "input": {"text": text}, + "voice": voice_cfg, + "audioConfig": { + "audioEncoding": audio_encoding, + "speakingRate": speaking_rate, + }, + } - except Exception as e: - logger.exception(f"❌ TTS 合成失敗: {e}") - return { - "success": False, - "audio_data": None, - "error": str(e) - } + async with aiohttp.ClientSession() as session: + data = await self._post_synthesize(session, payload) + if not data.get("success"): + error = data.get("error", "") + if "does not exist" in error or "misspelled" in error: + fallback_payload = dict(payload) + fallback_payload["voice"] = {"languageCode": payload["voice"]["languageCode"]} + logger.warning("Google TTS voice %s unavailable, retrying with language only", payload["voice"].get("name")) + data = await self._post_synthesize(session, fallback_payload) + if data.get("success"): + payload = fallback_payload + if not data.get("success"): + logger.error("❌ Google TTS 合成失敗: %s", data.get("error")) + return { + "success": False, + "audio_data": None, + "error": data.get("error", "Google TTS 合成失敗"), + } + + audio_content = data.get("audioContent") + if not audio_content: + logger.error("❌ Google TTS 未返回音訊內容") + return {"success": False, "audio_data": None, "error": "Google TTS 未返回音訊"} + + audio_bytes = base64.b64decode(audio_content) + logger.info("✅ TTS 合成完成: size=%d bytes", len(audio_bytes)) + + return { + "success": True, + "audio_data": audio_bytes, + "voice": payload["voice"]["name"], + "format": "mp3" if audio_encoding == "MP3" else "wav", + } - async def synthesize_stream( + async def _post_synthesize(self, session: aiohttp.ClientSession, payload: Dict[str, Any]) -> Dict[str, Any]: + async with session.post( + GOOGLE_TTS_ENDPOINT, + params={"key": self.api_key}, + json=payload, + timeout=30, + ) as resp: + data = await resp.json(content_type=None) + if resp.status >= 400: + logger.error("Google TTS HTTP failed: status=%s body=%s", resp.status, data) + return { + "success": False, + "error": data.get("error", {}).get("message", "Google TTS 合成失敗"), + } + data["success"] = True + return data + + async def synthesize_stream(self, *args, **kwargs) -> AsyncIterator[bytes]: + result = await self.synthesize(*args, **kwargs) + if result.get("success") and result.get("audio_data"): + yield result["audio_data"] + + async def streaming_synthesize( self, text: str, - voice: VoiceType = "coral", - model: str = "gpt-4o-mini-tts", + voice: str = "coral", speed: float = 1.0, - instruction: Optional[str] = None, + language: Optional[str] = None, + persona: Optional[str] = None, + speaking_rate: Optional[float] = None, + markup: Optional[str] = None, + custom_pronunciations: Optional[List[Dict[str, Any]]] = None, emotion: Optional[str] = None, care_mode: bool = False, - response_format: AudioFormat = "pcm" ) -> AsyncIterator[bytes]: - """ - 使用 OpenAI TTS API 串流生成語音(邊生成邊播放,低延遲) - - Args: - text: 要轉換的文字 - voice: 聲音類型 - model: TTS 模型 - speed: 語速 - instruction: 語音指令(手動指定,優先級最高) - emotion: 情緒標籤(自動選擇 instruction) - care_mode: 是否為關懷模式 - response_format: 音頻格式(建議用 pcm 以獲得最低延遲) - - Yields: - bytes: 音頻數據塊 - """ - if not self.client: - logger.error("❌ OpenAI client 未初始化") - return - try: - logger.info(f"🔊 開始 TTS 串流合成,文字長度: {len(text)}, 聲音: {voice}") - - # 調用 OpenAI TTS API(串流模式) - params = { - "model": model, - "voice": voice, - "input": text, - "speed": speed, - "response_format": response_format - } - - # 選擇 instruction(優先級:手動 > 情緒自動選擇) - final_instruction = instruction or get_emotion_instruction(emotion, care_mode) + # 🎯 2026 最佳實踐:延遲初始化 AsyncClient 以重用 gRPC Channel + if self._async_client is None: + if not self._grpc_credentials: + logger.warning("⚠️ GOOGLE_SPEECH_* 服務帳戶未設定,無法啟用 Chirp3-HD 串流 TTS,將回退到 REST") + # 這裡直接拋出一個特定的錯誤,讓外層捕捉並執行回退 + raise RuntimeError("Missing gRPC credentials") + + from google.cloud import texttospeech_v1beta1 as texttospeech + try: + # 建立持久化 Client,自動處理連線池與 Keepalive + self._async_client = texttospeech.TextToSpeechAsyncClient( + credentials=self._grpc_credentials, + client_options={ + "api_endpoint": "texttospeech.googleapis.com", + } + ) + logger.debug("📡 已建立持久化 Google TTS gRPC 串流連線") + except Exception as client_err: + logger.error(f"❌ 無法建立 TTS AsyncClient: {client_err}") + raise + + from google.cloud import texttospeech_v1beta1 as texttospeech + + # 清理文字 + cleaned_text = self._clean_text_for_tts(text) + cleaned_markup = (markup or "").strip() + if not cleaned_text and not cleaned_markup: + return + + voice_cfg = self._voice_config(voice, language=language, persona=persona) + persona_prompt = self._persona_prompt(persona, language) + + # 🎯 2026 最佳實踐:根據情緒與關懷模式動態調整語速 + effective_rate = float(speaking_rate if speaking_rate is not None else speed or 1.1) + speaking_rate = max(0.25, min(4.0, effective_rate * get_emotion_rate(emotion, care_mode))) - if final_instruction and model == "gpt-4o-mini-tts": - params["instructions"] = final_instruction - logger.info(f"🎭 TTS 串流語氣指令: {final_instruction}") + logger.debug( + "📡 啟動 TTS 串流: voice=%s, lang=%s, rate=%.2f, emotion=%s, care_mode=%s", + voice_cfg["name"], voice_cfg["languageCode"], speaking_rate, emotion, care_mode + ) + + filtered_pronunciations = self._filter_custom_pronunciations_for_language( + custom_pronunciations, + voice_cfg["languageCode"], + cleaned_markup or cleaned_text, + ) + custom_pronunciations_obj = self._build_custom_pronunciations(filtered_pronunciations, texttospeech) + + synthesis_input_kwargs: Dict[str, Any] = {} + if cleaned_markup: + synthesis_input_kwargs["markup"] = cleaned_markup + else: + synthesis_input_kwargs["text"] = cleaned_text + if persona_prompt: + synthesis_input_kwargs["prompt"] = persona_prompt - async with self.client.audio.speech.with_streaming_response.create(**params) as response: - logger.info("✅ TTS 串流已啟動") - - # 逐塊產出音頻數據 - async for chunk in response.iter_bytes(chunk_size=4096): - if chunk: - yield chunk + # 🎯 使用持久化的 AsyncClient + client = self._async_client + + streaming_config = texttospeech.StreamingSynthesizeConfig( + voice=texttospeech.VoiceSelectionParams( + language_code=voice_cfg["languageCode"], + name=voice_cfg["name"], + ), + streaming_audio_config=texttospeech.StreamingAudioConfig( + audio_encoding=texttospeech.AudioEncoding.PCM, + sample_rate_hertz=24000, + speaking_rate=speaking_rate, + ), + custom_pronunciations=custom_pronunciations_obj, + ) + + async def request_iter(): + yield texttospeech.StreamingSynthesizeRequest(streaming_config=streaming_config) + yield texttospeech.StreamingSynthesizeRequest( + input=texttospeech.StreamingSynthesisInput(**synthesis_input_kwargs) + ) - logger.info("✅ TTS 串流完成") + total_chunks = 0 + total_bytes = 0 + + # 🎯 設定超時時間,避免無限等待 + response_iter = await client.streaming_synthesize(requests=request_iter(), timeout=20.0) + async for response in response_iter: + chunk = getattr(response, "audio_content", b"") + if chunk: + total_chunks += 1 + total_bytes += len(chunk) + yield bytes(chunk) + + logger.debug("✅ TTS 串流完成: total_chunks=%d, total_bytes=%d", total_chunks, total_bytes) except Exception as e: - logger.exception(f"❌ TTS 串流失敗: {e}") - - async def play_locally( - self, - text: str, - voice: VoiceType = "coral", - model: str = "gpt-4o-mini-tts", - speed: float = 1.0, - instruction: Optional[str] = None - ) -> Dict[str, Any]: - """ - 使用 LocalAudioPlayer 直接播放語音(本地測試用) - - Args: - text: 要轉換的文字 - voice: 聲音類型 - model: TTS 模型 - speed: 語速 - instruction: 語音指令 - - Returns: - { - "success": bool, - "error": str (optional) - } - """ - if not self.client: - return { - "success": False, - "error": "OpenAI client 未初始化" - } - - try: - logger.info(f"🔊 開始本地播放,文字長度: {len(text)}, 聲音: {voice}") - - params = { - "model": model, - "voice": voice, - "input": text, - "speed": speed, - "response_format": "pcm" - } - - if instruction and model == "gpt-4o-mini-tts": - params["instructions"] = instruction - - async with self.client.audio.speech.with_streaming_response.create(**params) as response: - await LocalAudioPlayer().play(response) - - logger.info("✅ 本地播放完成") - - return { - "success": True - } + error_msg = str(e) or repr(e) + logger.warning(f"📡 gRPC 串流 TTS 失敗 (回退中): {error_msg}") + # 重置 Client 以便下次重建連線 + self._async_client = None + + try: + # 調用 REST 版 synthesize 作為回退方案 + res = await self.synthesize( + text=text, + voice=voice, + speed=speed, + language=language, + persona=persona, + speaking_rate=speaking_rate, + emotion=emotion, + care_mode=care_mode, + # 注意:REST 版不支援 markup,所以傳入純文字 + ) + if res.get("success") and res.get("audio_data"): + logger.debug("✅ 已通過 REST API 完成 TTS 回退合成") + yield res["audio_data"] + return + except Exception as fallback_err: + logger.error("❌ TTS 回退方案也失敗: %s", fallback_err) + + logger.exception("❌ TTS 串流中斷且回退失敗") + raise - except Exception as e: - logger.exception(f"❌ 本地播放失敗: {e}") - return { - "success": False, - "error": str(e) - } + async def play_locally(self, text: str, voice: str = "coral", **kwargs) -> Dict[str, Any]: + return await self.synthesize(text=text, voice=voice, **kwargs) -# 全域單例 tts_service = TTSService() async def text_to_speech( text: str, - voice: VoiceType = "coral", + voice: str = "coral", speed: float = 1.0, - instruction: Optional[str] = None + instruction: Optional[str] = None, + language: Optional[str] = None, + persona: Optional[str] = None, + speaking_rate: Optional[float] = None, ) -> Dict[str, Any]: - """ - 便捷函數:將文字轉為語音(非串流) - - Args: - text: 要轉換的文字 - voice: 聲音類型(coral, sage, verse, alloy, echo, fable, onyx, nova, shimmer) - speed: 語速(0.25 到 4.0) - instruction: 語音指令(如「用溫柔、安慰的語氣說話」) - - Returns: - { - "success": bool, - "audio_data": bytes, - "voice": str, - "format": str, - "error": str (optional) - } - """ - return await tts_service.synthesize(text, voice, speed=speed, instruction=instruction) + return await tts_service.synthesize( + text, + voice, + speed=speed, + instruction=instruction, + language=language, + persona=persona, + speaking_rate=speaking_rate, + ) async def text_to_speech_stream( text: str, - voice: VoiceType = "coral", + voice: str = "coral", speed: float = 1.0, - instruction: Optional[str] = None + instruction: Optional[str] = None, ) -> AsyncIterator[bytes]: - """ - 便捷函數:將文字轉為語音(串流模式,低延遲) - - Args: - text: 要轉換的文字 - voice: 聲音類型 - speed: 語速 - instruction: 語音指令 - - Yields: - bytes: 音頻數據塊 - """ async for chunk in tts_service.synthesize_stream(text, voice, speed=speed, instruction=instruction): yield chunk - - -async def test_tts_playback( - text: str = "今天是美好的一天!", - voice: VoiceType = "coral", - instruction: Optional[str] = "用開心、愉悅的語氣說話" -) -> None: - """ - 快速測試 TTS 播放(使用 LocalAudioPlayer) - - Args: - text: 要播放的文字 - voice: 聲音類型 - instruction: 語音指令 - """ - result = await tts_service.play_locally(text, voice=voice, instruction=instruction) - if result["success"]: - logger.debug(f"✅ 播放成功:{text}") - else: - logger.debug(f"❌ 播放失敗:{result.get('error')}") - - -if __name__ == "__main__": - # 測試範例:播放中文語音 - asyncio.run(test_tts_playback( - text="你好!我是 BloomWare 智能助手,很高興為你服務!", - voice="coral", - instruction="用溫暖、友善的語氣說話" - )) diff --git a/services/voice_login.py b/services/voice_login.py index fa2f1894483c5cd34b31c181fd148ab52dc4afb4..d8f775b95a3bb71ca967983b4aa64d1761a56966 100644 --- a/services/voice_login.py +++ b/services/voice_login.py @@ -184,11 +184,28 @@ class VoiceAuthService: end = start + bytes_per_window windows.append(bytes(buf[start:end])) - # 品質檢查(SNR) - for w in windows: - snr_db = self._estimate_snr_db(w) + # 品質檢查(SNR)僅作診斷,不再作為 hard gate + quality_warnings: List[Dict[str, Any]] = [] + for idx, w in enumerate(windows): + signal_stats = self._analyze_signal_stats(w, sr) + snr_db = float(signal_stats["snr_db"]) if snr_db < self.config.min_snr_db: - return {"success": False, "error": "LOW_SNR", "snr_db": snr_db} + warning = { + "type": "LOW_SNR", + "window_index": idx, + **signal_stats, + } + quality_warnings.append(warning) + logging.warning( + "VOICE_LOW_SNR warning only: window=%d snr_db=%.2f rms_all=%.6f noise_floor=%.6f voiced_ratio=%.3f duration_sec=%.3f threshold=%.2f", + idx, + signal_stats["snr_db"], + signal_stats["rms_all"], + signal_stats["noise_floor"], + signal_stats["voiced_ratio"], + signal_stats["duration_sec"], + self.config.min_snr_db, + ) # 視窗逐一評估 win_results: List[Dict[str, Any]] = [] @@ -234,10 +251,11 @@ class VoiceAuthService: "windows": win_results, "emotion": emotion, "note": "override_high_confidence", + "quality_warnings": quality_warnings, } except Exception: pass - return {"success": False, "error": "INCONSISTENT_WINDOWS", "windows": win_results} + return {"success": False, "error": "INCONSISTENT_WINDOWS", "windows": win_results, "quality_warnings": quality_warnings} probs = [float(r.get("score", 0.0)) for r in win_results] margins_ok = all(float(r.get("margin", 0.0)) >= self.config.margin_threshold for r in win_results) per_win_ok = all( @@ -251,6 +269,7 @@ class VoiceAuthService: "error": "THRESHOLD_NOT_MET", "avg_prob": avg_prob, "windows": win_results, + "quality_warnings": quality_warnings, } label = labels[0] @@ -266,6 +285,7 @@ class VoiceAuthService: "avg_prob": avg_prob, "windows": win_results, "emotion": emotion, + "quality_warnings": quality_warnings, } # -------------- 私有工具 -------------- @@ -356,28 +376,59 @@ class VoiceAuthService: f.write(tmp.getvalue()) return tmp_path - def _estimate_snr_db(self, pcm_bytes: bytes) -> float: - """粗估 SNR:以整段 RMS 與背景估計。簡化:取信號 RMS 與最小能量窗比。""" + def _analyze_signal_stats(self, pcm_bytes: bytes, sr: int) -> Dict[str, float]: + """粗估音訊品質:SNR、整段 RMS、噪音底、有效語音比例、時長。""" try: x = np.frombuffer(pcm_bytes, dtype=np.int16).astype(np.float32) if x.size == 0: - return 0.0 + return { + "snr_db": 0.0, + "rms_all": 0.0, + "noise_floor": 0.0, + "voiced_ratio": 0.0, + "duration_sec": 0.0, + } x = x / 32768.0 frame = 1024 rms_all = np.sqrt(np.mean(x * x) + 1e-12) if len(x) < frame: - return 20.0 * np.log10(max(rms_all, 1e-6) / 1e-6) + noise = 1e-6 + snr = 20.0 * np.log10(max(rms_all, noise) / noise) + return { + "snr_db": float(snr), + "rms_all": float(rms_all), + "noise_floor": float(noise), + "voiced_ratio": 1.0 if rms_all > noise * 2.0 else 0.0, + "duration_sec": float(len(x) / max(sr, 1)), + } # 取移動窗最小 RMS 視為噪音底 - mins = [] + mins: List[float] = [] for i in range(0, len(x) - frame + 1, frame): seg = x[i : i + frame] mins.append(np.sqrt(np.mean(seg * seg) + 1e-12)) noise = float(np.percentile(mins, 10)) if mins else (rms_all * 0.5) noise = max(noise, 1e-6) snr = 20.0 * np.log10(max(rms_all, noise) / noise) - return float(snr) + voiced_threshold = max(noise * 2.0, 5e-4) + voiced_ratio = float(np.mean(np.array(mins) > voiced_threshold)) if mins else 0.0 + return { + "snr_db": float(snr), + "rms_all": float(rms_all), + "noise_floor": float(noise), + "voiced_ratio": voiced_ratio, + "duration_sec": float(len(x) / max(sr, 1)), + } except Exception: - return 0.0 + return { + "snr_db": 0.0, + "rms_all": 0.0, + "noise_floor": 0.0, + "voiced_ratio": 0.0, + "duration_sec": 0.0, + } + + def _estimate_snr_db(self, pcm_bytes: bytes) -> float: + return float(self._analyze_signal_stats(pcm_bytes, self.config.sample_rate)["snr_db"]) def _preprocess_bytes(self, pcm_bytes: bytes, sr: int) -> bytes: """簡易降噪 + 正規化(去 DC、軟性降噪、峰值歸一化)。""" @@ -414,6 +465,16 @@ class VoiceAuthService: except Exception: return pcm_bytes + # 🎯 情緒標籤映射(與 AudioEmotionService 保持一致,確保前端能正確識別) + EMOTION_MAP = { + "生氣(angry)": "angry", + "恐懼(fear)": "fear", + "開心(happy)": "happy", + "中性(neutral)": "neutral", + "悲傷(sad)": "sad", + "驚訝(surprise)": "surprise" + } + def _infer_emotion_from_bytes(self, pcm_bytes: bytes, sr: int) -> Optional[Dict[str, Any]]: try: if not self._emo_predict or not self._emo_id2class: @@ -421,9 +482,14 @@ class VoiceAuthService: wav_path = self._bytes_to_wav(pcm_bytes, sr) try: pred_id, confidence, distribution = self._emo_predict(str(wav_path)) # type: ignore[misc] - label = self._emo_id2class(int(pred_id)) # type: ignore[misc] + raw_label = self._emo_id2class(int(pred_id)) # type: ignore[misc] + + # 🎯 映射為標準英文標籤 + label = self.EMOTION_MAP.get(raw_label, "neutral") + return { "label": label, + "raw_label": raw_label, # 保留原始標籤供 Welcome Message 使用 "confidence": float(confidence), "distribution": distribution, } diff --git a/services/welcome.py b/services/welcome.py index b4215114a0c4e635a1182d24fbf2ba6c4b6b3cbf..a2d134d37c860aa45a636b7f579b256f0b0c3afa 100644 --- a/services/welcome.py +++ b/services/welcome.py @@ -44,18 +44,22 @@ def _derive_period_from_hour(hour: int) -> str: def _mood_from_emotion_label(emo_label: str) -> str: if not emo_label: return "很高興再次見到你!" - if "開心" in emo_label: + + emo_label = emo_label.lower() + + if "開心" in emo_label or "happy" in emo_label: return "您今天心情感覺不錯喔!" - if "悲傷" in emo_label: + if "悲傷" in emo_label or "sad" in emo_label: return "今天心情有點低落,我在這陪你。" - if "生氣" in emo_label: + if "生氣" in emo_label or "angry" in emo_label: return "看起來有點不爽,想聊聊發生什麼事嗎?" - if "恐懼" in emo_label: + if "恐懼" in emo_label or "fear" in emo_label: return "別擔心,有我在,慢慢來。" - if "驚訝" in emo_label: + if "驚訝" in emo_label or "surprise" in emo_label: return "哇,今天似乎有新鮮事!" - if "中性" in emo_label: + if "中性" in emo_label or "neutral" in emo_label: return "很高興再次見到你!" + return "很高興再次見到你!" diff --git a/static/frontend/index.html b/static/frontend/index.html index 9b1739417f3e2c028765b8ecec244c06f36ea809..af4f82d332cfbc20311c56d0428b2d9d04ffa309 100644 --- a/static/frontend/index.html +++ b/static/frontend/index.html @@ -1,12 +1,16 @@ + - + Bloom Ware 語音沉浸式 - 光暈花瓣版 - + +

🎛️ 控制面板

@@ -1638,7 +2137,8 @@
-
+
技術限制:
• Web Speech API 不支援持續監聽
• 需要用戶點擊才能啟動麥克風
@@ -1664,6 +2164,7 @@
💬
+
@@ -1672,31 +2173,35 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
+
-
- 請說話... -
+
+ 請說話...
@@ -1726,15 +2231,15 @@ + diff --git a/static/frontend/js/agent.js b/static/frontend/js/agent.js index 6ae862a465a1ef19ad8e11d09f5be5bae6adb972..f18b594de88bc283837118cde0cd3e060db4ffa2 100644 --- a/static/frontend/js/agent.js +++ b/static/frontend/js/agent.js @@ -1,20 +1,54 @@ -let currentState = 'idle'; +window.currentState = 'idle'; +window.thinkingTimeout = null; +window.stateBufferTimeout = null; +/** + * 設置 Agent 狀態,加入緩衝機制防止動畫閃爍 + */ function setState(newState, options = {}) { - if (currentState === newState) { + if (window.currentState === newState && !window.stateBufferTimeout) { return; } - const oldState = currentState; - currentState = newState; + // 如果有待處理的狀態轉換,先清除它 + if (window.stateBufferTimeout) { + clearTimeout(window.stateBufferTimeout); + window.stateBufferTimeout = null; + } + + // 針對進入 idle 狀態加入微小延遲,避免在快速切換(如 speaking -> idle -> recording)時花瓣動畫跳轉 + if (newState === 'idle') { + window.stateBufferTimeout = setTimeout(() => { + window.stateBufferTimeout = null; + applyStateChange('idle', options); + }, 150); + } else { + applyStateChange(newState, options); + } +} +/** + * 實際執行狀態變更與 UI 更新 + */ +function applyStateChange(newState, options) { + const oldState = window.currentState; + window.currentState = newState; + console.log(`🔄 狀態切換: ${oldState} -> ${newState}`); + + // 清除之前的思考超時計時器 + if (window.thinkingTimeout) { + clearTimeout(window.thinkingTimeout); + window.thinkingTimeout = null; + } + + // 更新麥克風容器樣式 micContainer.classList.remove('recording', 'thinking', 'speaking', 'disconnected'); switch(newState) { case 'idle': - hideAgentOutput(); + // idle 狀態不應主動隱藏輸出,讓使用者能看清最後的回覆 if (typeof stopSpeaking === 'function') { - stopSpeaking(); + stopSpeaking(true, 'state_idle'); } if (options.clearCards !== false) { clearAllCards(); @@ -37,8 +71,17 @@ function setState(newState, options = {}) { micContainer.classList.add('thinking'); hideAgentOutput(); if (typeof stopSpeaking === 'function') { - stopSpeaking(); + stopSpeaking(true, 'state_thinking'); } + + // 設定思考超時重置 (45秒) + window.thinkingTimeout = setTimeout(() => { + if (window.currentState === 'thinking') { + console.warn('⚠️ 思考時間過長,自動重置'); + showErrorNotification('抱歉,處理時間過長,請再試一次。'); + resetAgent({clearCards: false}); + } + }, 45000); break; case 'speaking': @@ -52,7 +95,7 @@ function setState(newState, options = {}) { micContainer.classList.add('disconnected'); hideAgentOutput(); if (typeof stopSpeaking === 'function') { - stopSpeaking(); + stopSpeaking(true, 'state_disconnected'); } clearAllCards(); break; @@ -62,25 +105,57 @@ function setState(newState, options = {}) { } } -function applyEmotion(emotion) { +window.currentEmotion = 'neutral'; +window.isInCareMode = false; + +function applyEmotion(emotion, careMode = null) { + if (careMode !== null) { + window.isInCareMode = !!careMode; + } const validEmotions = ['neutral', 'happy', 'sad', 'angry', 'fear', 'surprise']; + + // 🎯 支援多語言/原始標籤映射,確保 100% 信心 + const mapping = { + '悲傷(sad)': 'sad', '悲傷': 'sad', + '開心(happy)': 'happy', '開心': 'happy', + '生氣(angry)': 'angry', '生氣': 'angry', + '恐懼(fear)': 'fear', '恐懼': 'fear', + '驚訝(surprise)': 'surprise', '驚訝': 'surprise', + '中性(neutral)': 'neutral', '中性': 'neutral' + }; + + if (mapping[emotion]) { + emotion = mapping[emotion]; + } + if (!validEmotions.includes(emotion)) { emotion = 'neutral'; } background.className = `voice-immersive-background emotion-${emotion} active`; emotionIndicator.textContent = `當前情緒: ${emotionEmojis[emotion]}`; + + // 🎯 保存到全域狀態,供 TTS 使用 + window.currentEmotion = emotion; } function showErrorNotification(message) { console.error('🚨 錯誤:', message); - setState('speaking', { - outputText: `抱歉,發生錯誤:${message}`, - enableTTS: false - }); + // 立即重置 Agent 狀態(關閉花朵,停止錄音/語音) + resetAgent({ clearCards: false }); - setTimeout(() => setState('idle'), 3000); + // 顯示錯誤訊息於輸出區域,但不切換到 speaking 狀態(讓花朵保持 idle) + if (typeof typewriterEffect === 'function') { + typewriterEffect(`抱歉,發生錯誤:${message}`, 40, false); + } + + // 3秒後自動隱藏錯誤訊息 + setTimeout(() => { + if (currentState === 'idle') { + hideAgentOutput(); + } + }, 5000); } @@ -89,9 +164,33 @@ let isDisconnected = false; let isRecording = false; let isSpeaking = false; +function resetAgent(options = {}) { + isRecording = false; + isThinking = false; + isSpeaking = false; + isDisconnected = false; + + if (typeof stopSpeaking === 'function') { + stopSpeaking(true, 'reset_agent'); + } + + if (typeof stopRealAudioAnalysis === 'function') { + stopRealAudioAnalysis(); + } + + if (wsManager && typeof wsManager.stopRecording === 'function') { + wsManager.stopRecording(); + } + + if (typeof transcript !== 'undefined' && transcript) { + transcript.textContent = ''; + } + + setState('idle', options); +} + function initAgentControls() { - micContainer.addEventListener('click', async () => { - + const handleMicInteraction = async () => { if (currentState === 'recording') { isRecording = false; @@ -109,13 +208,16 @@ function initAgentControls() { if (currentState === 'idle' || currentState === 'disconnected' || currentState === 'speaking') { if (currentState === 'speaking' && typeof stopSpeaking === 'function') { - stopSpeaking(); + stopSpeaking(true, 'mic_interrupt'); } isRecording = true; setState('recording', { keepOutput: true, // 保留前次 Agent 回應 - keepCards: true // 保留前次工具卡片 + keepCards: true, // 保留前次工具卡片 + detect_timeout: 20.0, // 考量到 Function Calling 可能較慢 + feature_timeout: 30.0, // MCP 工具內部超時 + ai_timeout: 25.0 // 配合 Streaming }); if (typeof startRealAudioAnalysis === 'function') { @@ -126,22 +228,30 @@ function initAgentControls() { const success = await wsManager.startRecording(); if (!success) { console.error('❌ 錄音啟動失敗'); - setState('idle'); - isRecording = false; - if (typeof stopRealAudioAnalysis === 'function') { - stopRealAudioAnalysis(); - } + resetAgent(); } } else { console.error('❌ WebSocket 管理器未初始化'); - setState('idle'); - isRecording = false; - if (typeof stopRealAudioAnalysis === 'function') { - stopRealAudioAnalysis(); - } + resetAgent(); } } - }); + }; + + // 點擊麥克風中心 + micContainer.addEventListener('click', handleMicInteraction); + + // 點擊波形容器(較大區域)也觸發交互,提高可用性 + const waveformContainer = document.querySelector('.voice-waveform-container'); + if (waveformContainer) { + waveformContainer.style.cursor = 'pointer'; + waveformContainer.addEventListener('click', (e) => { + // 如果點擊的是 micContainer 內部,就不重複觸發(事件冒泡) + if (e.target === micContainer || micContainer.contains(e.target)) { + return; + } + handleMicInteraction(); + }); + } document.getElementById('toggle-recording').addEventListener('click', async () => { isRecording = !isRecording; diff --git a/static/frontend/js/app.js b/static/frontend/js/app.js index 5c90eeb066e51ae14cc9f1f6db0ef915f9644160..b3f0c6db2b46217ebda27660b2dedb74521961a9 100644 --- a/static/frontend/js/app.js +++ b/static/frontend/js/app.js @@ -88,22 +88,34 @@ async function requestRequiredPermissions() { if (navigator.geolocation) { try { + // 在 2026 現代瀏覽器中,可以先檢查 permission 狀態,避免跳警告 + if (navigator.permissions && navigator.permissions.query) { + const permission = await navigator.permissions.query({ name: 'geolocation' }); + if (permission.state === 'denied') { + throw { code: 1, message: 'User denied Geolocation' }; // 模擬 code 1: PERMISSION_DENIED + } + } + await new Promise((resolve, reject) => { navigator.geolocation.getCurrentPosition( (position) => { resolve(position); }, (error) => { - console.warn('⚠️ 地理位置權限被拒絕:', error); reject(error); }, { enableHighAccuracy: false, timeout: 10000, maximumAge: 0 } ); }); } catch (error) { - console.warn('⚠️ 地理位置權限被拒絕,部分功能(如查詢附近公車)將無法使用'); - if (typeof showErrorNotification === 'function') { - showErrorNotification('建議允許地理位置權限以使用完整功能(如查詢附近公車、天氣等)'); + if (error.code === 1) { // PERMISSION_DENIED + console.warn('⚠️ 地理位置權限被明確拒絕:', error.message); + if (typeof showErrorNotification === 'function') { + showErrorNotification('建議允許地理位置權限以使用完整功能(如查詢附近公車、天氣等)'); + } + } else { // POSITION_UNAVAILABLE (2) 或 TIMEOUT (3) + console.warn(`⚠️ 無法透過 GPS 獲得位置縮小誤差 (代碼: ${error.code}),前端環境準備回退。`); + // 我們讓專門處理定位的 location.js 中的機制作後續 fallback,避免在權限申請階段就卡死或誤報 } } } else { diff --git a/static/frontend/js/location.js b/static/frontend/js/location.js index 138d7281de60367cf0e7cb946f9dff32e35bfd62..b98d75edc32ff285445b1eea74575b73da7ef8fe 100644 --- a/static/frontend/js/location.js +++ b/static/frontend/js/location.js @@ -5,6 +5,7 @@ let lastPosition = null; let lastSentPosition = null; // 上次發送的位置 let lastSendTime = 0; // 上次發送時間 let isTracking = false; +let isIPFallbackTriggered = false; // 避免 IP 定位服務重複觸發 const MIN_SEND_INTERVAL = 60000; // 最小發送間隔:60 秒 const MIN_DISTANCE_CHANGE = 100; // 最小距離變化:100 米 @@ -114,20 +115,62 @@ function handlePositionError(error) { case error.POSITION_UNAVAILABLE: errorMessage = '無法取得位置資訊'; console.warn('⚠️ 位置資訊暫時無法取得'); + triggerIPFallback(); break; case error.TIMEOUT: errorMessage = '定位請求逾時'; console.warn('⚠️ 定位請求逾時'); + triggerIPFallback(); break; default: errorMessage = '未知錯誤'; console.warn('⚠️ 定位發生未知錯誤:', error); } + // 避免在嘗試 IP fallabck 成功前就送出 null + if (!isIPFallbackTriggered || error.code === error.PERMISSION_DENIED) { + sendEnvironmentSnapshot({ + lat: null, + lon: null, + error: errorMessage, + timestamp: Date.now() + }); + } +} + +async function triggerIPFallback() { + if (isIPFallbackTriggered) return; + isIPFallbackTriggered = true; + + try { + const res = await fetch('https://ipapi.co/json/'); + if (!res.ok) throw new Error(`IP Geo API HTTP ${res.status}`); + const data = await res.json(); + if (data.latitude && data.longitude) { + console.log(`📍 [GeoIP] IP 定位備援成功 (${data.city}, ${data.country_name})`); + const ipPosition = { + coords: { + latitude: data.latitude, + longitude: data.longitude, + accuracy: 5000, + heading: null, + speed: null + }, + timestamp: Date.now() + }; + // 利用模擬的 Position 物件更新系統 + handlePositionUpdate(ipPosition); + return; + } + } catch (err) { + console.warn('⚠️ [GeoIP] IP 定位備援同樣失敗:', err); + } + + // Fallback 失敗時依然丟回最後的 snapshot(帶著錯誤狀態) sendEnvironmentSnapshot({ lat: null, lon: null, - error: errorMessage, + error: '精準定位與 IP 定位均失敗', timestamp: Date.now() }); } diff --git a/static/frontend/js/pcm-recorder-worklet.js b/static/frontend/js/pcm-recorder-worklet.js new file mode 100644 index 0000000000000000000000000000000000000000..e1876c6a2ba12d3d403416c647d5edb62a4b4534 --- /dev/null +++ b/static/frontend/js/pcm-recorder-worklet.js @@ -0,0 +1,17 @@ +class PCMRecorderProcessor extends AudioWorkletProcessor { + process(inputs) { + const channelData = inputs[0]?.[0]; + if (channelData && channelData.length) { + const pcm16 = new Int16Array(channelData.length); + for (let i = 0; i < channelData.length; i++) { + const sample = Math.max(-1, Math.min(1, channelData[i])); + pcm16[i] = sample < 0 ? sample * 0x8000 : sample * 0x7fff; + } + this.port.postMessage(pcm16.buffer, [pcm16.buffer]); + } + + return true; + } +} + +registerProcessor('pcm-recorder-processor', PCMRecorderProcessor); diff --git a/static/frontend/js/tools.js b/static/frontend/js/tools.js index 62a735d917218830b3d072d7c468ab52596d8a48..fab8c9222a34bde2870ced6fc6093834840443d8 100644 --- a/static/frontend/js/tools.js +++ b/static/frontend/js/tools.js @@ -361,7 +361,8 @@ function getIconForTool(toolName, category) { 'tdx_metro': '🚇', 'reverse_geocode': '📍', 'forward_geocode': '📍', - 'directions': '🗺️' + 'directions': '🗺️', + 'environment_context': '🌍' }; if (iconMap[toolName]) { @@ -375,6 +376,81 @@ function getIconForTool(toolName, category) { return '🔧'; } +function displayMultipleToolCards(toolsArray) { + clearAllCards(); + if (!toolsArray || toolsArray.length === 0) return; + + if (toolDrawerContent) { + toolDrawerContent.innerHTML = ''; + } + + toolsArray.forEach((tool, index) => { + const toolName = tool.tool_name; + const toolData = tool.tool_data; + + const toolMeta = toolsMetadata[toolName] || {}; + const category = toolMeta.category || '未知'; + const icon = getIconForTool(toolName, category); + const contentHTML = renderCardContent(toolName, toolData); + + const card = document.createElement('div'); + card.className = 'voice-tool-card'; + card.dataset.type = toolName; + card.style.marginBottom = '15px'; + card.style.position = 'relative'; + + card.innerHTML = ` +
+
${icon}
+

${category}

+ ${index + 1}/${toolsArray.length} +
+
${contentHTML}
+ `; + + if (toolDrawerContent) { + toolDrawerContent.appendChild(card); + } + }); + + if (toolDrawerContent) { + showToolDrawerToggle(); + } + + const lastTool = toolsArray[toolsArray.length - 1]; + const lastToolName = lastTool.tool_name; + const lastToolData = lastTool.tool_data; + + const toolMeta = toolsMetadata[lastToolName] || {}; + const category = toolMeta.category || '未知'; + const icon = getIconForTool(lastToolName, category); + const contentHTML = renderCardContent(lastToolName, lastToolData); + + const floatingCard = document.createElement('div'); + floatingCard.className = 'voice-tool-card'; + floatingCard.dataset.type = lastToolName; + + let extraHeaderHtml = ''; + if (toolsArray.length > 1) { + extraHeaderHtml = ``; + } + + floatingCard.innerHTML = ` +
+
${icon}
+

${category}

+ ${extraHeaderHtml} +
+
${contentHTML}
+ `; + + const position = getNextPosition(); + if (position && cardsContainer) { + floatingCard.classList.add(position); + cardsContainer.appendChild(floatingCard); + } +} + function displayToolCard(toolName, toolData) { clearAllCards(); @@ -478,6 +554,10 @@ function renderCardContent(toolName, toolData) { return renderForwardGeocode(toolData); } + if (toolName === 'environment_context' || (toolData.lat && toolData.lon && toolData.label)) { + return renderEnvironmentContext(toolData); + } + if (toolData.raw_data && typeof toolData.raw_data === 'object') { return renderKeyValuePairs(toolData.raw_data); } @@ -1196,6 +1276,54 @@ function renderForwardGeocode(data) { `; } +function renderEnvironmentContext(data) { + const labels = LABELS[currentLanguage] || LABELS.zh; + const label = data.label || labels.unknown; + const detailedAddress = data.detailed_address || ''; + const lat = data.lat?.toFixed(6) || ''; + const lon = data.lon?.toFixed(6) || ''; + const accuracy = data.accuracy_m ? `${data.accuracy_m}m` : ''; + const device = data.device || {}; + const platform = device.platform || ''; + const tz = data.tz || ''; + + const mapsUrl = `https://www.google.com/maps?q=${lat},${lon}`; + + return ` +
+ 📍 ${labels.location} + ${label} +
+ ${detailedAddress && detailedAddress !== label ? ` +
+ 🏠 ${labels.address} + ${detailedAddress} +
+ ` : ''} +
+ 🌐 ${labels.coordinates} + ${lat}, ${lon} ${accuracy ? `(±${accuracy})` : ''} +
+ ${tz ? ` +
+ ⏰ 時區 + ${tz} +
+ ` : ''} + ${platform ? ` +
+ 💻 設備 + ${platform} +
+ ` : ''} + + `; +} + function renderJSONFallback(data) { return `
${JSON.stringify(data, null, 2)}
`; } diff --git a/static/frontend/js/tts.js b/static/frontend/js/tts.js index 5df68cfacc6708e24c542d505f40011e805a1c8a..62ff1f191cf4b9ec7397d0bc05ba8a9b1ce30695 100644 --- a/static/frontend/js/tts.js +++ b/static/frontend/js/tts.js @@ -1,35 +1,538 @@ -let currentAudio = null; // 當前播放的音頻對象 -let isPlaying = false; // 是否正在播放 -let audioContext = null; // 預先建立的 AudioContext(繞過自動播放限制) -let userGestureReceived = false; // 是否已收到用戶手勢 +// ============================================================ +// TTS 模組:支援即時句子串流播放 + Emoji/Markdown 清除 +// ============================================================ +let currentAudio = null; +let isPlaying = false; +let audioContext = null; +let userGestureReceived = false; +let pendingAudioUrl = null; +let ttsStreamSocket = null; +let ttsStreamNextStartAt = 0; +let _ttsActiveSources = []; +// === Streaming TTS State === +let _streamTtsQueue = []; // 待合成播放的句子佇列 +let _streamTtsProcessing = false; // 是否正在處理佇列 +let _streamTtsProcessedLen = 0; // 已送入佇列的文字長度(用於增量偵測) +let _streamTtsStopped = false; // 是否已停止(對話重置時設定) +let _streamTtsQueuedCount = 0; +let _streamTtsPlayedCount = 0; +let _streamTtsFinalText = ''; +let _streamTtsFallbackUsed = false; -function unlockAudioPlayback() { - if (userGestureReceived) { +function logTtsDebug(event, extra = {}) { + if (!window.DEBUG_MODE) { return; } + console.info('[TTS_DEBUG]', event, { + currentState: typeof currentState !== 'undefined' ? currentState : 'unknown', + queueLength: _streamTtsQueue.length, + processing: _streamTtsProcessing, + stopped: _streamTtsStopped, + isPlaying, + activeSources: _ttsActiveSources.length, + ...extra, + }); +} - try { +function maybeFinalizeSpeechPlayback() { + const typewriterActive = !!(window.typewriterState && window.typewriterState.isActive); + const speechSettled = _streamTtsQueue.length === 0 && !_streamTtsProcessing && !isPlaying; + if (!speechSettled || typewriterActive) { + return; + } + + window.agentOutputAwaitingSpeechCompletion = false; + if (typeof currentState !== 'undefined' && currentState === 'speaking') { + setState('idle', { clearCards: false }); + } +} + + +function getTtsLanguage() { + const sessionLanguage = window.currentConversationLanguage || window.currentSpeechLanguage; + if (sessionLanguage && sessionLanguage !== 'auto') { + return String(sessionLanguage); + } + const browserLanguage = navigator.language || navigator.userLanguage || 'zh-TW'; + return String(browserLanguage || 'zh-TW'); +} + +function getTtsPersona() { + return 'xiaohua'; +} + +function getTtsSpeakingRate() { + return 1.12; +} + +function normalizeTextForChirp(text) { + const source = String(text || ''); + if (!source) return ''; + + return source + .replace(/```[\s\S]*?```/g, ' ') + .replace(/`([^`]+)`/g, '$1') + .replace(/\[([^\]]+)\]\(([^)]+)\)/g, '$1') + .replace(/\*\*([^*]+)\*\*/g, '$1') + .replace(/__([^_]+)__/g, '$1') + .replace(/\*([^*\n]+)\*/g, '$1') + .replace(/_([^_\n]+)_/g, '$1') + .replace(/^#{1,6}\s+/gm, '') + .replace(/^\s*[-*]\s+/gm, '') + .replace(/^\s*\d+\.\s+/gm, '') + .replace(/[|{}[\]^<>~`]/g, ' ') + .replace(/\s*[::]\s*/g, ':') + .replace(/\s+/g, ' ') + .replace(/,/g, ',') + .replace(/。/g, '。') + .trim(); +} + +function buildXiaohuaMarkup(text) { + const source = normalizeTextForChirp(text); + if (!source) return ''; + + const sentenceLike = source + .replace(/([。!?!?])/g, '$1\n') + .replace(/([,、;;])/g, '$1[pause short]') + .split('\n') + .map(part => part.trim()) + .filter(Boolean); + + return sentenceLike.join('[pause]'); +} + +function getTtsMarkup(text) { + return buildXiaohuaMarkup(text); +} + +function getTtsCustomPronunciations(language, text = '') { + const normalized = String(language || '').toLowerCase(); + const sourceText = String(text || ''); + if (normalized.startsWith('zh')) { + const entries = []; + if (sourceText.includes('桃園')) { + entries.push({ + phrase: '桃園', + pronunciation: 'tao2 yuan2', + phonetic_encoding: 'PHONETIC_ENCODING_PINYIN' + }); + } + if (sourceText.includes('多雲')) { + entries.push({ + phrase: '多雲', + pronunciation: 'duo1 yun2', + phonetic_encoding: 'PHONETIC_ENCODING_PINYIN' + }); + } + return entries; + } + + if (normalized.startsWith('ja')) { + const entries = []; + if (sourceText.includes('Bloom Ware')) { + entries.push({ + phrase: 'Bloom Ware', + pronunciation: 'ブルームウェア', + phonetic_encoding: 'PHONETIC_ENCODING_JAPANESE_YOMIGANA' + }); + } + return entries; + } + + if (normalized.startsWith('en')) { + const entries = []; + if (sourceText.includes('Bloom Ware')) { + entries.push({ + phrase: 'Bloom Ware', + pronunciation: 'bluːm wɛr', + phonetic_encoding: 'PHONETIC_ENCODING_IPA' + }); + } + if (sourceText.includes('Chirp')) { + entries.push({ + phrase: 'Chirp', + pronunciation: 'tʃɝːp', + phonetic_encoding: 'PHONETIC_ENCODING_IPA' + }); + } + return entries; + } + + return []; +} + +function float32ToAudioBuffer(float32Array, sampleRate) { + const ctx = getAudioContext(); + const audioBuffer = ctx.createBuffer(1, float32Array.length, sampleRate); + audioBuffer.copyToChannel(float32Array, 0); + return audioBuffer; +} + +function pcm16Base64ToFloat32(base64String) { + const binary = atob(base64String); + const bytes = new Uint8Array(binary.length); + for (let i = 0; i < binary.length; i++) { + bytes[i] = binary.charCodeAt(i); + } + const pcm16 = new Int16Array(bytes.buffer); + const float32 = new Float32Array(pcm16.length); + for (let i = 0; i < pcm16.length; i++) { + float32[i] = pcm16[i] / 0x8000; + } + return float32; +} + +async function playStreamingTTS(text) { + if (!text) return false; + await ensureAudioReady(); + + if (ttsStreamSocket) { + logTtsDebug('socket_replace_close'); + try { ttsStreamSocket.close(); } catch (_) {} + ttsStreamSocket = null; + } + + return await new Promise((resolve) => { + const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; + const socket = new WebSocket(`${protocol}//${window.location.host}/ws/tts`); + ttsStreamSocket = socket; + let resolved = false; + let sampleRate = 24000; + let receivedChunk = false; + + const finish = (ok) => { + if (resolved) return; + resolved = true; + if (ttsStreamSocket === socket) { + ttsStreamSocket = null; + } + resolve(ok); + }; + + const language = getTtsLanguage(); + const normalizedText = normalizeTextForChirp(text); + const markup = getTtsMarkup(normalizedText); + const customPronunciations = getTtsCustomPronunciations(language, normalizedText); + + socket.onopen = () => { + logTtsDebug('socket_open', { textLength: normalizedText.length, audioContextState: getAudioContext().state }); + ttsStreamNextStartAt = Math.max(getAudioContext().currentTime + 0.05, ttsStreamNextStartAt); + socket.send(JSON.stringify({ + text: normalizedText, + voice: 'nova', + speed: 1.0, + language, + persona: getTtsPersona(), + speaking_rate: getTtsSpeakingRate(), + markup, + custom_pronunciations: customPronunciations, + emotion: window.currentEmotion || 'neutral', + care_mode: window.isInCareMode || false + })); + }; + + socket.onmessage = (event) => { + const message = JSON.parse(event.data); + if (message.type === 'tts_stream_start') { + sampleRate = Number(message.sample_rate || 24000); + logTtsDebug('stream_start', { sampleRate }); + return; + } + if (message.type === 'tts_audio_chunk') { + receivedChunk = true; + logTtsDebug('audio_chunk', { base64Length: (message.audio_base64 || '').length }); + if (_streamTtsStopped) return; // 已停止播放 + + const float32 = pcm16Base64ToFloat32(message.audio_base64); + const buffer = float32ToAudioBuffer(float32, sampleRate); + const source = getAudioContext().createBufferSource(); + source.buffer = buffer; + source.connect(getAudioContext().destination); + source.start(ttsStreamNextStartAt); + ttsStreamNextStartAt += buffer.duration; + + // 追蹤來源以便停止 + _ttsActiveSources.push(source); + isPlaying = true; + logTtsDebug('audio_scheduled', { duration: buffer.duration, nextStartAt: ttsStreamNextStartAt }); + + source.onended = () => { + // 移除已結束的來源 + _ttsActiveSources = _ttsActiveSources.filter(s => s !== source); + if (getAudioContext().currentTime >= ttsStreamNextStartAt - 0.02) { + isPlaying = false; + } + logTtsDebug('audio_ended', { remainingSources: _ttsActiveSources.length }); + maybeFinalizeSpeechPlayback(); + }; + return; + } + if (message.type === 'tts_stream_end') { + logTtsDebug('stream_end', { receivedChunk }); + finish(receivedChunk); + socket.close(); + return; + } + if (message.type === 'tts_error') { + console.warn('串流 TTS 失敗:', message.error); + logTtsDebug('stream_error', { error: message.error }); + finish(false); + socket.close(); + } + }; + + socket.onerror = () => { + logTtsDebug('socket_error'); + finish(false); + }; + socket.onclose = (event) => { + logTtsDebug('socket_close', { + receivedChunk, + code: event.code, + reason: event.reason || '', + wasClean: event.wasClean, + }); + finish(receivedChunk); + }; + }); +} + +function getAudioContext() { + if (!audioContext) { audioContext = new (window.AudioContext || window.webkitAudioContext)(); + } + return audioContext; +} - const buffer = audioContext.createBuffer(1, 1, 22050); - const source = audioContext.createBufferSource(); - source.buffer = buffer; - source.connect(audioContext.destination); - source.start(0); +// === Audio Analysis for Flower Animation === +let ttsAnalyzer = null; +let ttsDataArray = null; +let ttsVisualizerAnimationId = null; +function initAnalyzer() { + const ctx = getAudioContext(); + if (!ttsAnalyzer) { + ttsAnalyzer = ctx.createAnalyser(); + ttsAnalyzer.fftSize = 256; + const bufferLength = ttsAnalyzer.frequencyBinCount; + ttsDataArray = new Uint8Array(bufferLength); + } + return ttsAnalyzer; +} + +function connectVisualizer(audioElement) { + try { + const ctx = getAudioContext(); + const source = ctx.createMediaElementSource(audioElement); + const node = initAnalyzer(); + source.connect(node); + node.connect(ctx.destination); + + startVisualizerLoop(); + } catch (err) { + // MediaElementSource might fail if already connected or other issues + console.warn('Visualizer connection failed:', err); + } +} + +function startVisualizerLoop() { + if (ttsVisualizerAnimationId) cancelAnimationFrame(ttsVisualizerAnimationId); + + const update = () => { + if (!isPlaying) { + document.documentElement.style.setProperty('--core-scale', '1'); + ttsVisualizerAnimationId = null; + return; + } + + ttsVisualizerAnimationId = requestAnimationFrame(update); + if (ttsAnalyzer) { + ttsAnalyzer.getByteFrequencyData(ttsDataArray); + let sum = 0; + // Get volume from low-mid frequencies for a better pulse effect + const count = Math.min(ttsDataArray.length, 32); + for (let i = 0; i < count; i++) { + sum += ttsDataArray[i]; + } + const average = sum / count; + // Map volume to scale: 1.0 (silent) to ~1.4 (loud) + const scale = 1 + (average / 255) * 0.45; + document.documentElement.style.setProperty('--core-scale', scale.toFixed(3)); + } + }; + + update(); +} + +async function ensureAudioReady() { + try { + audioContext = getAudioContext(); + if (audioContext.state === 'suspended') { + await audioContext.resume(); + } userGestureReceived = true; + return true; } catch (error) { - console.warn('⚠️ 無法解鎖音頻播放:', error); + console.warn('無法解鎖音頻播放:', error); + return false; } } +async function unlockAudioPlayback() { + const ready = await ensureAudioReady(); + if (!ready || !pendingAudioUrl) { + return ready; + } -async function speakText(text) { - stopSpeaking(); + const audioUrl = pendingAudioUrl; + pendingAudioUrl = null; + await playAudioUrl(audioUrl); + return true; +} + +function installAudioUnlockListeners() { + const events = ['pointerdown', 'touchstart', 'keydown']; + const unlock = () => { + unlockAudioPlayback(); + }; + + events.forEach((eventName) => { + document.addEventListener(eventName, unlock, { + passive: true, + capture: true + }); + }); +} + +async function playAudioUrl(audioUrl) { + stopSpeaking(false); + try { + await ensureAudioReady(); + + currentAudio = new Audio(audioUrl); + currentAudio.crossOrigin = "anonymous"; + currentAudio.preload = 'auto'; + + isPlaying = true; + + // Connect to visualizer ONLY after context is ready + connectVisualizer(currentAudio); + + await currentAudio.play(); + } catch (playError) { + isPlaying = false; + if (playError && playError.name === 'NotAllowedError') { + pendingAudioUrl = audioUrl; + console.warn('瀏覽器尚未允許自動播放,已排入下一次使用者手勢播放'); + return; + } + + setTimeout(() => URL.revokeObjectURL(audioUrl), 1000); + throw playError; + } +} + +// ============================================================ +// 文字清理:去除 Emoji、Markdown 符號,讓 TTS 更自然 +// ============================================================ +function cleanTextForTTS(text) { + if (!text) return ''; + return text + // 移除 Emoji(Supplementary Multilingual Plane) + .replace(/[\u{1F000}-\u{1FFFF}]/gu, '') + // 移除 Emoji 和符號(Basic Multilingual Plane) + .replace(/[\u{2600}-\u{27FF}]/gu, '') + .replace(/[\u{2B00}-\u{2BFF}]/gu, '') + .replace(/[\u{FE00}-\u{FEFF}]/gu, '') + // 移除 Markdown:粗體、斜體 + .replace(/\*{1,3}([^*\n]+)\*{1,3}/g, '$1') + .replace(/_{1,2}([^_\n]+)_{1,2}/g, '$1') + // 移除 Markdown:標題符號 + .replace(/^#{1,6}\s+/gm, '') + // 移除 Markdown:行內程式碼 + .replace(/`([^`]+)`/g, '$1') + // 移除 Markdown:連結,只保留文字 + .replace(/\[([^\]]+)\]\([^)]+\)/g, '$1') + // 移除 Markdown:刪除線 + .replace(/~~([^~]+)~~/g, '$1') + // 移除多餘空白 + .replace(/\s+/g, ' ') + .trim(); +} + +// ============================================================ +// 句子串流 TTS 系統:讓每個句子完成後立即開始 TTS 合成 +// ============================================================ + +// 句子邊界:中英文句號、驚嘆號、問號、換行 +const SENTENCE_END_RE = /[。!?!?\n]+/g; + +/** + * 從文字中提取完整句子(有句尾標點),回傳 [{text, endIdx}] + */ +function _extractCompleteSentences(text) { + const sentences = []; + SENTENCE_END_RE.lastIndex = 0; + let match; + let lastIdx = 0; + while ((match = SENTENCE_END_RE.exec(text)) !== null) { + const sentence = text.slice(lastIdx, match.index + match[0].length).trim(); + if (sentence) sentences.push(sentence); + lastIdx = match.index + match[0].length; + } + // 回傳最後一個完整句子之後的處理位置 + return { sentences, nextIdx: lastIdx }; +} + +/** + * 等待 Audio 播放完畢的 Promise + */ +function _playAndWait(audioUrl) { + return new Promise((resolve) => { + ensureAudioReady().then(() => { + const audio = new Audio(audioUrl); + audio.crossOrigin = "anonymous"; + audio.preload = 'auto'; + currentAudio = audio; + + isPlaying = true; + + // Connect to visualizer + connectVisualizer(audio); + + const cleanup = () => { + isPlaying = false; + setTimeout(() => URL.revokeObjectURL(audioUrl), 1000); + maybeFinalizeSpeechPlayback(); + resolve(); + }; + + audio.onended = cleanup; + audio.onerror = cleanup; + + audio.play().catch(cleanup); + }).catch((err) => { + console.error('ensureAudioReady 失敗:', err); + resolve(); + }); + }); +} + +/** + * 合成並播放單一句子(回傳 Promise,播完才 resolve) + */ +async function _synthesizeAndPlay(text) { + const cleaned = cleanTextForTTS(normalizeTextForChirp(text)); + if (!cleaned) return false; try { + const streamed = await playStreamingTTS(cleaned); + if (streamed) { + return true; + } const response = await fetch('/api/tts', { method: 'POST', @@ -38,89 +541,199 @@ async function speakText(text) { 'Authorization': `Bearer ${localStorage.getItem('jwt_token')}` }, body: JSON.stringify({ - text: text, + text: cleaned, voice: 'nova', - speed: 1.0 + speed: 1.0, + language: getTtsLanguage(), + persona: getTtsPersona(), + speaking_rate: getTtsSpeakingRate() }) }); if (!response.ok) { - const error = await response.json(); - console.error('❌ TTS API 錯誤:', error); - return; + console.warn('TTS API 錯誤:', response.status); + return false; } const audioBlob = await response.blob(); const audioUrl = URL.createObjectURL(audioBlob); + await _playAndWait(audioUrl); + return true; + } catch (error) { + console.warn('TTS 合成播放失敗:', error); + return false; + } +} - currentAudio = new Audio(audioUrl); - isPlaying = true; +function shouldFallbackToFullTTS() { + return _streamTtsStopped || (_streamTtsQueuedCount === 0 && !_streamTtsProcessing && !isPlaying); +} - currentAudio.onended = () => { - isPlaying = false; - URL.revokeObjectURL(audioUrl); - }; +function hasPendingStreamingSpeech() { + return _streamTtsQueue.length > 0 || _streamTtsProcessing || isPlaying || !!ttsStreamSocket || _ttsActiveSources.length > 0; +} - currentAudio.onerror = (e) => { - console.error('❌ 音頻播放錯誤:', e); - isPlaying = false; - URL.revokeObjectURL(audioUrl); - }; +/** + * 非同步處理佇列,確保句子按順序播放 + */ +async function _runTtsQueue() { + if (_streamTtsProcessing) return; + _streamTtsProcessing = true; - try { - const playPromise = currentAudio.play(); + while (_streamTtsQueue.length > 0 && !_streamTtsStopped) { + const sentence = _streamTtsQueue.shift(); + const played = await _synthesizeAndPlay(sentence); + if (played) { + _streamTtsPlayedCount += 1; + } + } - if (playPromise !== undefined) { - await playPromise; - } - } catch (playError) { - if (playError.name === 'NotAllowedError') { - console.warn('⚠️ 自動播放被阻止(瀏覽器政策)'); - console.warn('💡 解決方案:等待用戶下次點擊任意處播放'); + _streamTtsProcessing = false; - isPlaying = false; + if (_streamTtsFinalText && !_streamTtsFallbackUsed && _streamTtsQueuedCount > 0 && _streamTtsPlayedCount === 0 && !isPlaying) { + _streamTtsFallbackUsed = true; + await speakText(_streamTtsFinalText); + return; + } - const playOnUserClick = async (e) => { - if (e.target.closest('.mic-button') || e.target.closest('button')) { - return; - } + maybeFinalizeSpeechPlayback(); +} - try { - await currentAudio.play(); - isPlaying = true; - document.removeEventListener('click', playOnUserClick); - } catch (retryError) { - console.error('❌ 仍然無法播放:', retryError); - URL.revokeObjectURL(audioUrl); - } - }; +/** + * 在 bot_delta 串流過程中呼叫,傳入目前全部累積的文字。 + * 偵測到新的完整句子後立即加入 TTS 佇列。 + */ +function enqueueStreamingTTS(fullText) { + if (_streamTtsStopped || !fullText) return; - document.addEventListener('click', playOnUserClick, { once: false }); - setTimeout(() => { - document.removeEventListener('click', playOnUserClick); - if (!isPlaying) { - URL.revokeObjectURL(audioUrl); - } - }, 5000); + // 只處理新增的部分 + const newPart = fullText.slice(_streamTtsProcessedLen); + if (!newPart) return; - } else { - console.error('❌ 音頻播放失敗:', playError); - isPlaying = false; - URL.revokeObjectURL(audioUrl); - throw playError; - } + const { sentences, nextIdx } = _extractCompleteSentences(newPart); + + for (const sentence of sentences) { + if (sentence.trim()) { + _streamTtsQueue.push(sentence); + _streamTtsQueuedCount += 1; + } + } + _streamTtsProcessedLen += nextIdx; + + // 啟動佇列處理(若尚未啟動) + if (sentences.length > 0) { + _runTtsQueue(); + } +} + +/** + * 在 bot_message(串流結束)時呼叫,處理最後一段未以標點結尾的文字。 + */ +function finalizeStreamingTTS(fullText) { + if (!fullText) return; + _streamTtsFinalText = fullText; + + const remaining = fullText.slice(_streamTtsProcessedLen).trim(); + if (remaining) { + _streamTtsQueue.push(remaining); + _streamTtsQueuedCount += 1; + _streamTtsProcessedLen = fullText.length; + _runTtsQueue(); + } +} + +/** + * 重置串流 TTS 狀態(每次新的對話開始時呼叫) + */ +function resetStreamingTTS() { + _streamTtsQueue = []; + _streamTtsProcessedLen = 0; + _streamTtsStopped = false; + _streamTtsQueuedCount = 0; + _streamTtsPlayedCount = 0; + _streamTtsFinalText = ''; + _streamTtsFallbackUsed = false; + // 不立即停止正在播放的音訊,讓它自然結束 +} + +// ============================================================ +// 舊有的 speakText(完整文字合成,保留供非串流場景使用) +// ============================================================ +async function speakText(text) { + stopSpeaking(); + + const cleaned = cleanTextForTTS(normalizeTextForChirp(text)); + if (!cleaned) return; + + try { + const streamed = await playStreamingTTS(cleaned); + if (streamed) { + return; } + const response = await fetch('/api/tts', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${localStorage.getItem('jwt_token')}` + }, + body: JSON.stringify({ + text: cleaned, + voice: 'nova', + speed: 1.0, + language: getTtsLanguage(), + persona: getTtsPersona(), + speaking_rate: getTtsSpeakingRate() + }) + }); + + if (!response.ok) { + const error = await response.json(); + console.error('TTS API 錯誤:', error); + return; + } + + const audioBlob = await response.blob(); + const audioUrl = URL.createObjectURL(audioBlob); + await playAudioUrl(audioUrl); } catch (error) { - console.error('❌ TTS 請求失敗:', error); + console.error('TTS 請求失敗:', error); isPlaying = false; } } -function stopSpeaking() { - if (currentAudio && isPlaying) { +function stopSpeaking(clearPending = true, reason = 'unspecified') { + logTtsDebug('stop_speaking', { clearPending, reason }); + window.agentOutputAwaitingSpeechCompletion = false; + _streamTtsStopped = true; + _streamTtsQueue = []; + + // 停止所有正在播放或排程中的串流來源 + if (_ttsActiveSources && _ttsActiveSources.length > 0) { + _ttsActiveSources.forEach(source => { + try { source.stop(); } catch (_) {} + }); + _ttsActiveSources = []; + } + + if (ttsStreamSocket) { + try { ttsStreamSocket.close(); } catch (_) {} + ttsStreamSocket = null; + } + ttsStreamNextStartAt = 0; + + if (currentAudio) { currentAudio.pause(); currentAudio.currentTime = 0; - isPlaying = false; + currentAudio = null; + } + isPlaying = false; + + if (clearPending && pendingAudioUrl) { + const urlToRevoke = pendingAudioUrl; + setTimeout(() => URL.revokeObjectURL(urlToRevoke), 1000); + pendingAudioUrl = null; } } + +installAudioUnlockListeners(); diff --git a/static/frontend/js/ui.js b/static/frontend/js/ui.js index ac1d69e23b3bc43846e464b85629881a92db7ecf..6add9b07daa42e8b144a1b9d99788df71abae14d 100644 --- a/static/frontend/js/ui.js +++ b/static/frontend/js/ui.js @@ -1,37 +1,308 @@ +let typewriterTextNode = null; +let lastRenderedAgentText = ''; +let typewriterFrameId = null; +window.agentOutputAwaitingSpeechCompletion = false; +window.agentOutputHasAnswerStream = false; +window.typewriterState = { + text: '', + position: 0, + isActive: false, + frameId: null +}; + +/** + * 取消當前的打字機動畫 + */ +function cancelTypewriterAnimation() { + if (window.typewriterState.frameId) { + cancelAnimationFrame(window.typewriterState.frameId); + window.typewriterState.frameId = null; + } + typewriterFrameId = null; + window.typewriterState.isActive = false; +} + +function setTypewriterContent(text) { + const nextText = String(text || ''); + if (lastRenderedAgentText === nextText) { + return; + } + lastRenderedAgentText = nextText; + renderAgentMarkdown(nextText); +} + + + +function typewriterEffect(text, speed = 40, enableTTS = true, options = {}) { + const sourceText = String(text || '').trim(); + + // 如果新文字是舊文字的延續,且正在打字中,我們不重頭開始,而是更新目標 + if (window.typewriterState.isActive && sourceText.startsWith(window.typewriterState.text)) { + window.typewriterState.text = sourceText; + return; + } -function typewriterEffect(text, speed = 50, enableTTS = true) { - agentOutput.textContent = ''; + cancelTypewriterAnimation(); + + window.typewriterState = { + text: sourceText, + position: 0, + isActive: true, + frameId: null + }; + const awaitSpeechCompletion = options.awaitSpeechCompletion === true || !!enableTTS; + window.agentOutputAwaitingSpeechCompletion = awaitSpeechCompletion; + + setAgentOutputMode('final'); + if (!agentOutput.dataset.temporary) { + agentOutput.dataset.temporary = 'false'; + } agentOutput.classList.add('active'); + agentOutput.classList.add('typing-active'); agentOutput.classList.remove('typing-done'); - let index = 0; + if (enableTTS && typeof speakText === 'function') { + speakText(sourceText); + } + + setState('speaking'); + + const charsPerMs = speed > 0 ? 1 / speed : 2; + let startTime = null; + + function step(timestamp) { + if (startTime === null) startTime = timestamp; + const elapsed = timestamp - startTime; + + // 計算理論上應該打到哪裡 + let nextPos = Math.floor(elapsed * charsPerMs); + + // 如果落後目標太多,僅進行溫和追趕,避免瞬間跳轉 + const lag = window.typewriterState.text.length - nextPos; + if (lag > 20) { + // 每幀額外追趕一些字元,而不是直接跳到底 + startTime -= (lag * 0.3) * (1 / charsPerMs); + nextPos = Math.floor((timestamp - startTime) * charsPerMs); + } + + window.typewriterState.position = Math.min(window.typewriterState.text.length, nextPos); + + setTypewriterContent(window.typewriterState.text.slice(0, window.typewriterState.position)); + agentOutput.scrollTop = agentOutput.scrollHeight; + + if (window.typewriterState.position < window.typewriterState.text.length) { + window.typewriterState.frameId = requestAnimationFrame(step); + typewriterFrameId = window.typewriterState.frameId; + return; + } + + // 結束打字 + completeTypewriter(); + } + + window.typewriterState.frameId = requestAnimationFrame(step); + typewriterFrameId = window.typewriterState.frameId; +} + +function completeTypewriter() { + window.typewriterState.isActive = false; + window.typewriterState.frameId = null; + typewriterFrameId = null; + agentOutput.classList.remove('typing-active'); + agentOutput.classList.add('typing-done'); + setTypewriterContent(window.typewriterState.text); + agentOutput.scrollTop = agentOutput.scrollHeight; - if (typingInterval) { - clearInterval(typingInterval); + if (agentOutput.dataset.temporary === 'true') { + return; } - if (enableTTS && typeof speakText === 'function') { - speakText(text); // 語音與打字效果並行 + if (typeof hasPendingStreamingSpeech === 'function' && hasPendingStreamingSpeech()) { + window.agentOutputAwaitingSpeechCompletion = true; + return; + } + + if (window.agentOutputAwaitingSpeechCompletion) { + return; + } + + setState('idle', {clearCards: false}); +} + + + +const agentProgressSteps = []; + +function renderAgentProgressStep(text) { + const safeText = text || '正在處理...'; + const nextIndex = agentProgressSteps.length + 1; + agentProgressSteps.push(safeText); + if (agentProgressSteps.length > 4) { + agentProgressSteps.shift(); } + const rows = agentProgressSteps.map((step, index) => { + const isLatest = index === agentProgressSteps.length - 1; + const className = isLatest ? 'agent-progress-row current' : 'agent-progress-row'; + return `
${escapeHtml(step)}
`; + }); + return `
${rows.join('')}
`; +} - typingInterval = setInterval(() => { - if (index < text.length) { - agentOutput.textContent += text[index]; - index++; +function setAgentOutputMode(mode) { + agentOutput.classList.remove('progress-mode', 'output-mode-processing', 'output-mode-streaming', 'output-mode-final'); + if (mode === 'processing') { + agentOutput.classList.add('progress-mode', 'output-mode-processing'); + } else if (mode === 'streaming') { + agentOutput.classList.add('output-mode-streaming'); + } else if (mode === 'final') { + agentOutput.classList.add('output-mode-final'); + } +} + +function setAgentOutputContent(html) { + let shell = agentOutput.querySelector('.agent-output-shell'); + if (!shell) { + agentOutput.innerHTML = '
'; + shell = agentOutput.querySelector('.agent-output-shell'); + } + const contentNode = shell.querySelector('.agent-output-content'); + + // 在打字機模式下,我們加上一個光標元素 + const isTyping = agentOutput.classList.contains('typing-active'); + const cursorHtml = isTyping ? '' : ''; + + contentNode.innerHTML = html + cursorHtml; + typewriterTextNode = null; + agentOutput.scrollTop = agentOutput.scrollHeight; +} + +function escapeHtml(value) { + return String(value || '') + .replace(/&/g, '&') + .replace(//g, '>') + .replace(/"/g, '"') + .replace(/'/g, '''); +} + +function renderInlineMarkdown(text) { + return escapeHtml(text) + .replace(/`([^`]+)`/g, '$1') + .replace(/\*\*([^*]+)\*\*/g, '$1') + .replace(/__([^_]+)__/g, '$1') + .replace(/\*([^*]+)\*/g, '$1') + .replace(/_([^_]+)_/g, '$1') + .replace(/\[([^\]]+)\]\((https?:\/\/[^)\s]+)\)/g, '$1') + .replace(/(https?:\/\/[^\s<]+)/g, '$1'); +} + +function renderAgentMarkdown(markdown) { + const text = String(markdown || '').replace(/\r\n/g, '\n'); + const blocks = []; + const codePattern = /```([\s\S]*?)```/g; + let cursor = 0; + let match; + + while ((match = codePattern.exec(text)) !== null) { + blocks.push({type: 'markdown', value: text.slice(cursor, match.index)}); + blocks.push({type: 'code', value: match[1].replace(/^\w+\n/, '')}); + cursor = match.index + match[0].length; + } + blocks.push({type: 'markdown', value: text.slice(cursor)}); + + const rendered = blocks.map(block => { + if (block.type === 'code') { + return `
${escapeHtml(block.value.trim())}
`; + } + return renderMarkdownBlock(block.value); + }).join(''); + setAgentOutputContent(rendered); +} + +function renderMarkdownBlock(value) { + const lines = String(value || '').split('\n'); + const html = []; + let listItems = []; + + function flushList() { + if (!listItems.length) return; + html.push(`
    ${listItems.map(item => `
  • ${renderInlineMarkdown(item)}
  • `).join('')}
`); + listItems = []; + } + + lines.forEach(line => { + const trimmed = line.trim(); + if (!trimmed) { + flushList(); + return; + } + const heading = trimmed.match(/^(#{1,3})\s+(.+)$/); + if (heading) { + flushList(); + const level = heading[1].length; + html.push(`${renderInlineMarkdown(heading[2])}`); + return; + } + const bullet = trimmed.match(/^([-*]|\d+\.)\s+(.+)$/); + if (bullet) { + listItems.push(bullet[2]); + return; + } + flushList(); + html.push(`

${renderInlineMarkdown(trimmed)}

`); + }); + flushList(); + return html.join(''); +} + +function editAgentOutput(text, temporary = false, options = {}) { + const nextText = String(text || '').trim(); + + if (temporary) { + if (options.progress === true && window.agentOutputHasAnswerStream) { + return; + } + if (options.progress !== true) { + window.agentOutputHasAnswerStream = true; + } + agentOutput.dataset.temporary = 'true'; + // 串流模式:使用增量打字效果 + typewriterEffect(nextText, 30, false); + + if (options.progress === true) { + setAgentOutputMode('processing'); } else { - clearInterval(typingInterval); - agentOutput.classList.add('typing-done'); // 打字完成,隱藏游標 + setAgentOutputMode('streaming'); } - }, speed); + } else { + // 非串流模式(例如狀態更新):直接更新文字 + cancelTypewriterAnimation(); + setTypewriterContent(nextText); + agentOutput.dataset.temporary = 'false'; + agentOutput.classList.add('active'); + setAgentOutputMode('final'); + } +} + + +function finishAgentOutput(text, enableTTS = true, options = {}) { + // 結束時不需要重啟打字機,只需確保目標文字是最新的,並完成剩餘部分 + window.agentOutputHasAnswerStream = false; + agentOutput.dataset.temporary = 'false'; + typewriterEffect(text || '', 24, enableTTS, options); } function hideAgentOutput() { - if (typingInterval) { - clearInterval(typingInterval); - } + cancelTypewriterAnimation(); + window.agentOutputHasAnswerStream = false; agentOutput.classList.remove('active'); - agentOutput.textContent = ''; + agentOutput.classList.remove('typing-active'); + agentOutput.classList.add('typing-done'); + setAgentOutputMode(null); + agentOutput.innerHTML = ''; + typewriterTextNode = null; + lastRenderedAgentText = ''; } @@ -81,7 +352,7 @@ function handleLogout() { } if (typeof stopSpeaking === 'function') { - stopSpeaking(); + stopSpeaking(true, 'logout'); } @@ -158,6 +429,14 @@ function handleTextInput(event) { setState('thinking'); } + // 確保重置錄音與波形,允許接著交互 + if (typeof stopRealAudioAnalysis === 'function') { + stopRealAudioAnalysis(); + } + if (typeof isRecording !== 'undefined') { + isRecording = false; + } + toggleInputMode(); } else { console.error('❌ WebSocket 未初始化'); diff --git a/static/frontend/js/websocket.js b/static/frontend/js/websocket.js index 4bab35c4243fdaac969e87c7a171287cf6ead950..eef3430123df6b8ea0ad3f84327f538ff6764ba9 100644 --- a/static/frontend/js/websocket.js +++ b/static/frontend/js/websocket.js @@ -1,3 +1,24 @@ +window.inferConversationLanguage = function inferConversationLanguage(text, fallback = null) { + const source = String(text || '').trim(); + if (!source) { + return fallback || window.currentConversationLanguage || window.currentSpeechLanguage || navigator.language || 'zh-TW'; + } + + if (/[\u3040-\u30ff]/.test(source)) { + return 'ja-JP'; + } + if (/[\uac00-\ud7af]/.test(source)) { + return 'ko-KR'; + } + if (/[\u0e00-\u0e7f]/.test(source)) { + return 'th-TH'; + } + if (/[A-Za-z]/.test(source) && !/[\u3400-\u9fff]/.test(source)) { + return 'en-US'; + } + + return fallback || window.currentConversationLanguage || window.currentSpeechLanguage || navigator.language || 'zh-TW'; +}; class WebSocketManager { @@ -16,7 +37,10 @@ class WebSocketManager { this.audioStream = null; this.audioProcessor = null; this.audioSource = null; + this.audioWorkletModuleUrl = null; this.isRecording = false; + this.localPreviewFinal = ''; + this.localPreviewInterim = ''; this.connect = this.connect.bind(this); this.handleOpen = this.handleOpen.bind(this); @@ -25,6 +49,8 @@ class WebSocketManager { this.handleError = this.handleError.bind(this); this.startRecording = this.startRecording.bind(this); this.stopRecording = this.stopRecording.bind(this); + this.cleanupAudioResources = this.cleanupAudioResources.bind(this); + this.ensurePCMRecorderWorklet = this.ensurePCMRecorderWorklet.bind(this); } async connect() { @@ -201,10 +227,15 @@ class WebSocketManager { return false; } + const inferredLanguage = window.inferConversationLanguage(text, navigator.language || 'zh-TW'); + window.currentConversationLanguage = inferredLanguage; + window.currentSpeechLanguage = inferredLanguage; + const payload = { type: 'user_message', message: text, - chat_id: chatId + chat_id: chatId, + language: inferredLanguage }; return this.send(payload); @@ -246,6 +277,111 @@ class WebSocketManager { }); } + updateTranscriptPreview(text, className = 'provisional') { + const transcript = document.getElementById('transcript'); + if (!transcript || !text) return; + transcript.textContent = text; + transcript.className = `voice-transcript ${className}`; + } + + startLocalTranscriptPreview() { + const recognizer = window.speechRecognition; + this.localPreviewFinal = ''; + this.localPreviewInterim = ''; + + if (!recognizer || !recognizer.isSupported) { + this.updateTranscriptPreview('聆聽中...', 'provisional'); + return; + } + + recognizer.onResult = (finalText, interimText) => { + if (finalText) { + this.localPreviewFinal += finalText; + } + this.localPreviewInterim = interimText || ''; + + const preview = `${this.localPreviewFinal}${this.localPreviewInterim}`.trim(); + if (preview) { + this.updateTranscriptPreview(preview, this.localPreviewInterim ? 'realtime' : 'provisional'); + } + }; + recognizer.onError = () => { + this.updateTranscriptPreview('持續接收語音中...', 'provisional'); + }; + recognizer.onEnd = () => {}; + recognizer.start(); + } + + stopLocalTranscriptPreview() { + const recognizer = window.speechRecognition; + if (recognizer && recognizer.isSupported) { + recognizer.stop(); + } + } + + cleanupAudioResources() { + if (this.audioProcessor) { + try { + this.audioProcessor.port.onmessage = null; + this.audioProcessor.disconnect(); + } catch (e) { + console.warn('⚠️ 斷開音訊處理器失敗:', e); + } + this.audioProcessor = null; + } + + if (this.audioSource) { + try { + this.audioSource.disconnect(); + } catch (e) { + console.warn('⚠️ 斷開音訊源失敗:', e); + } + this.audioSource = null; + } + + if (this.audioStream) { + this.audioStream.getTracks().forEach(track => track.stop()); + this.audioStream = null; + } + + if (this.audioContext) { + this.audioContext.close().catch(error => { + console.warn('⚠️ 關閉 AudioContext 失敗:', error); + }); + this.audioContext = null; + } + } + + async ensurePCMRecorderWorklet() { + if (this.audioWorkletModuleUrl) { + return this.audioWorkletModuleUrl; + } + + const workletSource = ` +class PCMRecorderProcessor extends AudioWorkletProcessor { + process(inputs) { + const channelData = inputs[0] && inputs[0][0]; + if (channelData && channelData.length) { + const pcm16 = new Int16Array(channelData.length); + for (let i = 0; i < channelData.length; i++) { + const sample = Math.max(-1, Math.min(1, channelData[i])); + pcm16[i] = sample < 0 ? sample * 0x8000 : sample * 0x7fff; + } + this.port.postMessage(pcm16.buffer, [pcm16.buffer]); + } + return true; + } +} + + registerProcessor('pcm-recorder-processor', PCMRecorderProcessor); + `.trim(); + + this.audioWorkletModuleUrl = URL.createObjectURL( + new Blob([workletSource], { type: 'application/javascript' }) + ); + return this.audioWorkletModuleUrl; + } + async startRecording() { if (this.isRecording) { @@ -277,12 +413,18 @@ class WebSocketManager { this.audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16000 }); + const workletModuleUrl = await this.ensurePCMRecorderWorklet(); + await this.audioContext.audioWorklet.addModule(workletModuleUrl); + await this.audioContext.resume(); this.audioSource = this.audioContext.createMediaStreamSource(this.audioStream); - this.audioProcessor = this.audioContext.createScriptProcessor(4096, 1, 1); + this.audioProcessor = new AudioWorkletNode(this.audioContext, 'pcm-recorder-processor', { + numberOfInputs: 1, + numberOfOutputs: 0, + channelCount: 1 + }); this.audioSource.connect(this.audioProcessor); - this.audioProcessor.connect(this.audioContext.destination); this.send({ type: 'audio_start', @@ -290,23 +432,25 @@ class WebSocketManager { mode: 'realtime_chat', // 即時轉錄模式(使用 OpenAI Realtime API) language: 'auto' // 自動檢測語言(支援:zh/en/id/ja/vi) }); + window.currentSpeechLanguage = 'auto'; this.isRecording = true; + window.realtimeTranscript = ''; + // 重置串流 TTS 狀態(開始新對話) + if (typeof resetStreamingTTS === 'function') { + resetStreamingTTS(); + } + if (typeof stopSpeaking === 'function') { + stopSpeaking(); + } + this.startLocalTranscriptPreview(); - this.audioProcessor.onaudioprocess = (e) => { + this.audioProcessor.port.onmessage = (event) => { if (!this.isRecording) return; try { - const inputData = e.inputBuffer.getChannelData(0); - - const pcm16 = new Int16Array(inputData.length); - for (let i = 0; i < inputData.length; i++) { - let sample = Math.max(-1, Math.min(1, inputData[i])); - pcm16[i] = sample < 0 ? sample * 0x8000 : sample * 0x7FFF; - } - - const bytes = new Uint8Array(pcm16.buffer); + const bytes = new Uint8Array(event.data); const b64 = btoa(String.fromCharCode(...bytes)); this.send({ @@ -330,6 +474,7 @@ class WebSocketManager { } } + this.cleanupAudioResources(); this.isRecording = false; return false; } @@ -341,36 +486,15 @@ class WebSocketManager { return; } - - if (this.audioProcessor) { - this.audioProcessor.disconnect(); - this.audioProcessor = null; - } - - if (this.audioSource) { - try { - this.audioSource.disconnect(); - } catch (e) { - console.warn('⚠️ 斷開音訊源失敗:', e); - } - this.audioSource = null; - } - - if (this.audioStream) { - this.audioStream.getTracks().forEach(track => track.stop()); - this.audioStream = null; - } - - if (this.audioContext) { - this.audioContext.close(); - this.audioContext = null; - } + this.cleanupAudioResources(); this.send({ type: 'audio_stop', mode: 'realtime_chat' // 即時轉錄模式 }); + this.stopLocalTranscriptPreview(); + this.updateTranscriptPreview('轉錄中...', 'provisional'); this.isRecording = false; } } @@ -410,6 +534,12 @@ function initializeWebSocket(token) { case 'typing': if (data.message === 'thinking') { + window.currentConversationLanguage = null; + window.currentSpeechLanguage = null; + // 重置串流 TTS(進入思考狀態時停止/重置) + if (typeof resetStreamingTTS === 'function') { + resetStreamingTTS(); + } setState('thinking'); if (typeof hideToolCards === 'function') { hideToolCards(); @@ -417,20 +547,74 @@ function initializeWebSocket(token) { } break; + case 'bot_delta': + if (data.language) { + window.currentConversationLanguage = data.language; + window.currentSpeechLanguage = data.language; + } + if (typeof currentState !== 'undefined' && currentState !== 'speaking') { + setState('speaking'); + } else { + micContainer.classList.add('speaking'); + } + if (typeof editAgentOutput === 'function') { + editAgentOutput(data.text || '', true); + } + // 即時句子串流 TTS:逐句合成播放,不等全文 + if (typeof enqueueStreamingTTS === 'function') { + enqueueStreamingTTS(data.text || ''); + } + break; + + case 'bot_status': + if (typeof currentState !== 'undefined' && currentState === 'speaking') { + break; + } + micContainer.classList.add('thinking'); + if (typeof editAgentOutput === 'function' && data.message) { + editAgentOutput(data.message, true, {progress: true}); + } + break; + case 'bot_message': - // 【統一】不在此處套用情緒,只由 emotion_detected 事件控制 - // 保留情緒資訊在 data 中供調試使用 + if (data.language) { + window.currentConversationLanguage = data.language; + window.currentSpeechLanguage = data.language; + } + // 完成串流 TTS(說出最後一段未以標點結尾的殘餘文字) + if (typeof finalizeStreamingTTS === 'function') { + finalizeStreamingTTS(data.message || ''); + } + + const useFullTtsFallback = typeof shouldFallbackToFullTTS === 'function' + ? shouldFallbackToFullTTS() + : false; + const awaitSpeechCompletion = typeof hasPendingStreamingSpeech === 'function' + ? hasPendingStreamingSpeech() || useFullTtsFallback + : useFullTtsFallback; if (data.care_mode && typeof hideToolCards === 'function') { hideToolCards(); } - setState('speaking', { - outputText: data.message, - enableTTS: true - }); + if (typeof finishAgentOutput === 'function') { + finishAgentOutput(data.message, useFullTtsFallback, { + awaitSpeechCompletion, + }); + } else { + setState('speaking', { + outputText: data.message, + enableTTS: useFullTtsFallback + }); + } - if (data.tool_name && data.tool_data) { + if (data.executed_tools && data.executed_tools.length > 0) { + if (typeof displayMultipleToolCards === 'function') { + displayMultipleToolCards(data.executed_tools); + } else if (data.tool_name && data.tool_data) { + displayToolCard(data.tool_name, data.tool_data); + } + } else if (data.tool_name && data.tool_data) { displayToolCard(data.tool_name, data.tool_data); } break; @@ -441,24 +625,55 @@ function initializeWebSocket(token) { break; case 'stt_delta': - if (!window.realtimeTranscript) { - window.realtimeTranscript = ''; - } - window.realtimeTranscript += data.text; - transcript.textContent = window.realtimeTranscript; + transcript.textContent = data.text; transcript.className = 'voice-transcript realtime'; break; case 'stt_final': + if (data.text) { + const inferredLanguage = window.inferConversationLanguage(data.text, window.currentSpeechLanguage || navigator.language || 'zh-TW'); + window.currentConversationLanguage = inferredLanguage; + window.currentSpeechLanguage = inferredLanguage; + } transcript.textContent = data.text; transcript.className = 'voice-transcript final'; window.realtimeTranscript = ''; // 【統一】不在此處套用情緒,只由 emotion_detected 事件控制 break; + case 'stt_status': + if (data.status === 'speech_started') { + if (!transcript.textContent || transcript.textContent === '請說話...') { + transcript.textContent = '聆聽中...'; + } + transcript.className = 'voice-transcript realtime'; + } else if (data.status === 'speech_stopped') { + transcript.textContent = transcript.textContent || '語音結束,整理轉錄中...'; + transcript.className = 'voice-transcript provisional'; + } else if (data.status === 'transcribing' || String(data.status || '').startsWith('committed')) { + if (!window.realtimeTranscript) { + transcript.textContent = '轉錄中...'; + } + transcript.className = 'voice-transcript provisional'; + } else if (data.status === 'idle' || data.status === 'error') { + if (typeof currentState !== 'undefined' && currentState === 'speaking') { + break; + } + if (typeof resetAgent === 'function') { + resetAgent(); + } + } + break; + case 'realtime_stt_status': if (data.status === 'connected') { window.realtimeTranscript = ''; + if (data.language) { + window.currentSpeechLanguage = data.language; + window.currentConversationLanguage = data.language; + } + transcript.textContent = '聆聽中...'; + transcript.className = 'voice-transcript provisional'; } break; @@ -468,7 +683,11 @@ function initializeWebSocket(token) { case 'error': console.error('❌ 後端錯誤:', data.message); - setState('idle'); + if (typeof resetAgent === 'function') { + resetAgent(); + } else { + setState('idle'); + } showErrorNotification(data.message); break; @@ -481,7 +700,7 @@ function initializeWebSocket(token) { case 'emotion_detected': if (data.emotion && typeof applyEmotion === 'function') { - applyEmotion(data.emotion); + applyEmotion(data.emotion, data.care_mode); } if (data.care_mode && typeof hideToolCards === 'function') { hideToolCards(); @@ -534,4 +753,3 @@ function handleVoiceLoginResult(data) { showErrorNotification(`語音登入失敗: ${data.error || '未知錯誤'}`); } } - diff --git a/tests/live_test_stt_tts.py b/tests/live_test_stt_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..14f7dba362ea2e374d7de940e5b4723d7dbc8eb3 --- /dev/null +++ b/tests/live_test_stt_tts.py @@ -0,0 +1,104 @@ +import asyncio +import os +import sys +import logging +import base64 + +# Add the current directory to sys.path to import local modules +sys.path.append(os.getcwd()) + +from services.tts_service import TTSService +from services.realtime_stt_service import RealtimeSTTService +from core.config import settings + +# Setup logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("live_test") + +async def test_tts(): + logger.info("=== Testing TTS (Google Text-to-Speech) ===") + service = TTSService() + + test_text = "你好,這是一個來自 Bloom Ware 的測試語音。" + logger.info(f"Synthesizing text: {test_text}") + + result = await service.synthesize(test_text, voice="coral", response_format="wav") + + if result["success"]: + logger.info("TTS Success!") + audio_data = result["audio_data"] + logger.info(f"Received {len(audio_data)} bytes of audio data.") + # Save for reference + with open("tests/test_output.wav", "wb") as f: + f.write(audio_data) + logger.info("Saved audio to tests/test_output.wav") + return audio_data + else: + logger.error(f"TTS Failed: {result.get('error')}") + return None + +async def test_stt(audio_data): + if not audio_data: + logger.error("No audio data to test STT.") + return + + logger.info("\n=== Testing STT (Google Speech-to-Text v2) ===") + service = RealtimeSTTService() + + # Strip WAV header (44 bytes) to get raw PCM + # Note: This assumes the header is 44 bytes and the audio is LINEAR16 16k mono + raw_pcm = audio_data[44:] if len(audio_data) > 44 else audio_data + + transcript_parts = [] + + def on_delta(text): + logger.info(f"STT Delta: {text}") + transcript_parts.append(text) + + def on_done(text): + logger.info(f"STT Done: {text}") + + logger.info("Connecting to STT service...") + success = await service.connect( + on_transcript_delta=on_delta, + on_transcript_done=on_done, + language="zh-TW", + model="short", + sample_rate=24000 + ) + + if not success: + logger.error("STT Connection Failed. Check your service account credentials.") + return + + logger.info("Sending audio chunks...") + # Send in small chunks + chunk_size = 4096 + for i in range(0, len(raw_pcm), chunk_size): + chunk = raw_pcm[i:i+chunk_size] + await service.send_audio_chunk(chunk) + await asyncio.sleep(0.1) # Simulate real-time + + logger.info("Finalizing STT...") + final_text = await service.wait_for_final_transcript(timeout=5.0) + + if final_text: + logger.info(f"STT Result: {final_text}") + else: + logger.warning("STT returned no result.") + + await service.disconnect() + +async def main(): + # Verify environment variables + if not settings.GOOGLE_TTS_API_KEY: + logger.error("GOOGLE_TTS_API_KEY is missing!") + if not settings.GOOGLE_SPEECH_PROJECT_ID: + logger.error("GOOGLE_SPEECH_PROJECT_ID is missing!") + + audio = await test_tts() + if audio: + await test_stt(audio) + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/test_agent_bridge_multi_tool.py b/tests/test_agent_bridge_multi_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..7bf8c2b8b669451cb20d5a1b0541de7486ed9985 --- /dev/null +++ b/tests/test_agent_bridge_multi_tool.py @@ -0,0 +1,46 @@ +import pytest + +from features.mcp.agent_bridge import MCPAgentBridge +from features.mcp.tool_models import ToolResult + + +class RecordingCoordinator: + def __init__(self): + self.calls = [] + + async def invoke(self, tool_name, arguments, *, user_id, original_message): + self.calls.append((tool_name, arguments, user_id, original_message)) + return ToolResult( + name=tool_name, + message=f"{tool_name} ok", + data={"arguments": arguments}, + ) + + +@pytest.mark.asyncio +async def test_process_intent_executes_multiple_tool_calls_in_order(): + bridge = MCPAgentBridge.__new__(MCPAgentBridge) + coordinator = RecordingCoordinator() + bridge._tool_coordinator = coordinator + + result = await bridge.process_intent( + { + "type": "mcp_tool", + "tool_calls": [ + {"tool_name": "weather_query", "arguments": {"city": "Taipei"}}, + {"tool_name": "exchange_query", "arguments": {"from_currency": "USD", "to_currency": "TWD"}}, + ], + }, + user_id="u1", + original_message="台北天氣和美元匯率", + chat_id="c1", + ) + + assert coordinator.calls == [ + ("weather_query", {"city": "Taipei"}, "u1", "台北天氣和美元匯率"), + ("exchange_query", {"from_currency": "USD", "to_currency": "TWD"}, "u1", "台北天氣和美元匯率"), + ] + assert result["tool_name"] == "multi_tool" + assert result["tool_data"]["tool_names"] == ["weather_query", "exchange_query"] + assert "weather_query ok" in result["message"] + assert "exchange_query ok" in result["message"] diff --git a/tests/test_agent_skills_context.py b/tests/test_agent_skills_context.py new file mode 100644 index 0000000000000000000000000000000000000000..a083a2efa9a6d3fa042f75534856a1b0b449e229 --- /dev/null +++ b/tests/test_agent_skills_context.py @@ -0,0 +1,46 @@ +from services import ai_service +from features.mcp import agent_bridge + + +class _SkillsSettings: + OPENAI_ENABLE_SKILLS = True + + +def test_chat_prompt_includes_mcp_skills_context(monkeypatch): + class Settings: + OPENAI_ENABLE_SKILLS = True + + monkeypatch.setattr(ai_service, "settings", Settings) + + messages = ai_service._compose_messages_with_context( + base_prompt="base", + history_entries=[], + memory_context="", + env_context="", + time_context="", + emotion_context="", + current_request="台北天氣如何", + user_id="u1", + chat_id="c1", + use_care_mode=False, + care_emotion=None, + ) + + system_prompt = messages[0]["content"] + assert "【MCP工具技能索引】" in system_prompt + assert "weather_query" in system_prompt + assert "call_via=local_function_calling_tool_schema" in system_prompt + + +def test_function_calling_prompt_reads_skills_before_tool_selection(monkeypatch): + monkeypatch.setattr(agent_bridge, "settings", _SkillsSettings) + + bridge = agent_bridge.MCPAgentBridge.__new__(agent_bridge.MCPAgentBridge) + prompt = bridge._build_function_calling_prompt() + + skills_index_position = prompt.index("【MCP工具技能索引】") + rules_position = prompt.index("Rules:") + + assert skills_index_position < rules_position + assert "weather_query" in prompt + assert "call_via=local_function_calling_tool_schema" in prompt diff --git a/tests/test_agent_web_search_skill.py b/tests/test_agent_web_search_skill.py new file mode 100644 index 0000000000000000000000000000000000000000..d0e6171c27ecea74876e9a38dcc67ba50f5510e1 --- /dev/null +++ b/tests/test_agent_web_search_skill.py @@ -0,0 +1,52 @@ +from pathlib import Path + +from features.mcp.skills import skills_prompt_block +from services.ai_service import _build_base_system_prompt, _compose_messages_with_context + + +def test_base_prompt_has_no_domain_specific_market_hardcoding(): + prompt = _build_base_system_prompt( + use_care_mode=False, + care_emotion=None, + user_name=None, + ) + + forbidden = ["台積電", "2330", "台股", "ADR", "尚未開盤", "上一交易日"] + assert not any(term in prompt for term in forbidden) + + +def test_compose_messages_uses_generic_time_sensitive_rule(): + messages = _compose_messages_with_context( + base_prompt="base", + history_entries=[], + memory_context="", + env_context="timezone: Asia/Taipei", + time_context="當地時間: 2026-05-14 08:41(星期四,上午)\n時區: Asia/Taipei", + emotion_context="", + current_request="今天某公司股價多少?", + user_id="u1", + chat_id="c1", + use_care_mode=False, + care_emotion=None, + ) + + system_prompt = messages[0]["content"] + assert "時間訊號" in system_prompt + assert "環境訊號" in system_prompt + assert "自行判斷" in system_prompt + assert "來源時間早於用戶要求的時間範圍" in system_prompt + assert "台積電" not in system_prompt + assert "2330" not in system_prompt + + +def test_web_search_skill_is_generic_and_registered(): + skill_text = Path("features/mcp/skills/web_search/SKILL.md").read_text(encoding="utf-8") + prompt = skills_prompt_block() + + assert "avoid_domain_specific_hardcoding: true" in skill_text + assert "Do not hardcode domain-specific behavior" in skill_text + assert "do not present it as today's/current result" in skill_text + assert "web_search" in prompt + assert "responses_hosted_tool_auto" in prompt + assert "台積電" not in skill_text + assert "2330" not in skill_text diff --git a/tests/test_ai_client.py b/tests/test_ai_client.py index 8f73aa708e9c539ceadb91aad6ba814b2f3a46c8..33476e1aea6ccd68854e53c3e488d23cfc44a392 100644 --- a/tests/test_ai_client.py +++ b/tests/test_ai_client.py @@ -16,6 +16,7 @@ class TestAIClient: with patch.object(ai_client, 'settings') as mock_settings: mock_settings.OPENAI_API_KEY = "" + mock_settings.OPENAI_BASE_URL = "" mock_settings.OPENAI_TIMEOUT = 30 client = ai_client.get_openai_client() @@ -37,6 +38,75 @@ class TestAIClient: with patch.object(ai_client, 'settings') as mock_settings: mock_settings.OPENAI_API_KEY = "" + mock_settings.OPENAI_BASE_URL = "" mock_settings.OPENAI_TIMEOUT = 30 assert ai_client.is_available() is False + + def test_get_openai_client_passes_base_url(self): + """測試有設定 base_url 時會傳入 OpenAI client""" + from core import ai_client + ai_client.reset_client() + + fake_client = MagicMock() + + with patch.object(ai_client, 'settings') as mock_settings: + mock_settings.OPENAI_API_KEY = "sk-test" + mock_settings.OPENAI_BASE_URL = "https://sub2api.flowatelier.com/v1" + mock_settings.OPENAI_TIMEOUT = 30 + + with patch("openai.OpenAI", return_value=fake_client) as mock_openai: + client = ai_client.get_openai_client() + + assert client is fake_client + mock_openai.assert_called_once_with( + api_key="sk-test", + base_url="https://sub2api.flowatelier.com/v1", + timeout=30.0, + max_retries=3, + ) + + def test_get_openai_client_normalizes_base_url_without_v1(self): + """測試 base_url 可用裸網域設定,client factory 會補 /v1""" + from core import ai_client + ai_client.reset_client() + + fake_client = MagicMock() + + with patch.object(ai_client, 'settings') as mock_settings: + mock_settings.OPENAI_API_KEY = "sk-test" + mock_settings.OPENAI_BASE_URL = "https://sub2api.flowatelier.com" + mock_settings.OPENAI_TIMEOUT = 30 + + with patch("openai.OpenAI", return_value=fake_client) as mock_openai: + client = ai_client.get_openai_client() + + assert client is fake_client + mock_openai.assert_called_once_with( + api_key="sk-test", + base_url="https://sub2api.flowatelier.com/v1", + timeout=30.0, + max_retries=3, + ) + + def test_get_openai_client_omits_base_url_when_unset(self): + """測試未設定 base_url 時沿用 SDK 預設 OpenAI endpoint""" + from core import ai_client + ai_client.reset_client() + + fake_client = MagicMock() + + with patch.object(ai_client, 'settings') as mock_settings: + mock_settings.OPENAI_API_KEY = "sk-test" + mock_settings.OPENAI_BASE_URL = "" + mock_settings.OPENAI_TIMEOUT = 30 + + with patch("openai.OpenAI", return_value=fake_client) as mock_openai: + client = ai_client.get_openai_client() + + assert client is fake_client + mock_openai.assert_called_once_with( + api_key="sk-test", + timeout=30.0, + max_retries=3, + ) diff --git a/tests/test_ai_service_responses_streaming.py b/tests/test_ai_service_responses_streaming.py new file mode 100644 index 0000000000000000000000000000000000000000..b6bf61dd4355c4618cc74cac6a14d43f1e44fedc --- /dev/null +++ b/tests/test_ai_service_responses_streaming.py @@ -0,0 +1,238 @@ +import pytest + +from services import ai_service + + +class StreamEvent: + def __init__(self, event_type, delta=None, text=None, item=None): + self.type = event_type + self.delta = delta + self.text = text + self.item = item + + +class StreamItem: + def __init__(self, item_type): + self.type = item_type + + +class Responses: + def __init__(self): + self.payload = None + + def create(self, **kwargs): + self.payload = kwargs + return [ + StreamEvent("response.output_item.added", item=StreamItem("web_search_call")), + StreamEvent("response.output_text.delta", delta="你"), + StreamEvent("response.output_text.delta", delta="好"), + StreamEvent("response.output_text.done", text="你好"), + ] + + +class Client: + def __init__(self): + self.responses = Responses() + self.timeout = None + + def with_options(self, **kwargs): + self.timeout = kwargs.get("timeout") + return self + + +class FailingThenSafeResponses: + def __init__(self): + self.payloads = [] + + def create(self, **kwargs): + self.payloads.append(kwargs) + if len(self.payloads) == 1: + raise RuntimeError("503 Service Unavailable") + + class Response: + output_text = "即時搜尋暫時不可用,無法可靠確認最新股價。請稍後再試。" + + return Response() + + +class FailingThenSafeClient(Client): + def __init__(self): + self.responses = FailingThenSafeResponses() + self.timeout = None + + +class LanguageMismatchResponses: + def __init__(self): + self.payloads = [] + + def create(self, **kwargs): + self.payloads.append(kwargs) + + class Response: + output_text = "我很好,謝謝你。" + + if len(self.payloads) == 1: + return Response() + + class RetryResponse: + output_text = "I am doing well, thank you." + + return RetryResponse() + + +class LanguageMismatchClient(Client): + def __init__(self): + self.responses = LanguageMismatchResponses() + self.timeout = None + + +@pytest.mark.asyncio +async def test_generate_response_async_streams_responses_delta(monkeypatch): + client = Client() + chunks = [] + + class Settings: + OPENAI_MODEL = "gpt-5.4-mini" + OPENAI_RESPONSES_TIMEOUT = 90 + OPENAI_USE_RESPONSES = True + OPENAI_ENABLE_WEB_SEARCH = False + OPENAI_ENABLE_REMOTE_MCP = False + OPENAI_REMOTE_MCP_SERVERS_JSON = "[]" + OPENAI_ENABLE_SKILLS = False + + async def on_chunk(delta): + chunks.append(delta) + + monkeypatch.setattr(ai_service, "settings", Settings) + monkeypatch.setattr(ai_service, "OPENAI_TIMEOUT", 30) + monkeypatch.setattr(ai_service, "OPENAI_RESPONSES_TIMEOUT", 90) + monkeypatch.setattr(ai_service, "_get_client", lambda: client) + monkeypatch.setattr(ai_service, "_default_hosted_tools", lambda: []) + + result = await ai_service.generate_response_async( + [{"role": "user", "content": "hi"}], + model="gpt-5.4-mini", + stream=True, + on_chunk=on_chunk, + ) + + assert result == "你好" + assert chunks == [ + {"type": "status", "status": "web_searching", "message": "正在搜尋最新資訊..."}, + "你", + "好", + ] + assert client.responses.payload["stream"] is True + assert client.timeout == 90 + + +@pytest.mark.asyncio +async def test_generate_response_async_streaming_falls_back_without_hosted_tools(monkeypatch): + client = FailingThenSafeClient() + chunks = [] + + class Settings: + OPENAI_MODEL = "gpt-5.4-mini" + OPENAI_RESPONSES_TIMEOUT = 90 + OPENAI_USE_RESPONSES = True + OPENAI_ENABLE_WEB_SEARCH = True + OPENAI_ENABLE_REMOTE_MCP = False + OPENAI_REMOTE_MCP_SERVERS_JSON = "[]" + OPENAI_ENABLE_SKILLS = False + + async def on_chunk(delta): + chunks.append(delta) + + monkeypatch.setattr(ai_service, "settings", Settings) + monkeypatch.setattr(ai_service, "OPENAI_TIMEOUT", 30) + monkeypatch.setattr(ai_service, "OPENAI_RESPONSES_TIMEOUT", 90) + monkeypatch.setattr(ai_service, "_get_client", lambda: client) + monkeypatch.setattr(ai_service, "_default_hosted_tools", lambda: [{"type": "web_search"}]) + + result = await ai_service.generate_response_async( + [{"role": "user", "content": "今天台積電股價多少?"}], + model="gpt-5.4-mini", + stream=True, + on_chunk=on_chunk, + ) + + assert "即時搜尋暫時不可用" in result + assert len(client.responses.payloads) == 2 + assert client.responses.payloads[0]["tools"] == [{"type": "web_search"}] + assert client.responses.payloads[0]["stream"] is True + assert client.responses.payloads[1]["tools"] == [] + assert "stream" not in client.responses.payloads[1] + assert "不得編造即時" in client.responses.payloads[1]["instructions"] + assert chunks == [ + { + "type": "status", + "status": "hosted_tools_unavailable", + "phase": "fallback", + "message": "即時搜尋暫時不可用,正在改用安全降級回答...", + "temporary": True, + } + ] + + +@pytest.mark.asyncio +async def test_consume_responses_stream_logs_delta_timing(monkeypatch): + events = [ + StreamEvent("response.in_progress"), + StreamEvent("response.output_text.delta", delta="你"), + StreamEvent("response.output_text.delta", delta="好"), + ] + chunks = [] + log_messages = [] + + async def on_chunk(delta): + chunks.append(delta) + + class FakeLogger: + def info(self, message, *args): + log_messages.append(message % args if args else message) + + monkeypatch.setattr(ai_service, "logger", FakeLogger()) + + result = await ai_service._consume_responses_stream(events, on_chunk) + + assert result == "你好" + assert chunks == [ + {"type": "status", "status": "thinking", "message": "正在處理..."}, + "你", + "好", + ] + assert any("Responses stream stats" in entry for entry in log_messages) + + +@pytest.mark.asyncio +async def test_generate_response_async_retries_when_response_language_mismatches(monkeypatch): + client = LanguageMismatchClient() + + class Settings: + OPENAI_MODEL = "gpt-5.4-mini" + OPENAI_RESPONSES_TIMEOUT = 90 + OPENAI_USE_RESPONSES = True + OPENAI_ENABLE_WEB_SEARCH = False + OPENAI_ENABLE_REMOTE_MCP = False + OPENAI_REMOTE_MCP_SERVERS_JSON = "[]" + OPENAI_ENABLE_SKILLS = False + + monkeypatch.setattr(ai_service, "settings", Settings) + monkeypatch.setattr(ai_service, "OPENAI_TIMEOUT", 30) + monkeypatch.setattr(ai_service, "OPENAI_RESPONSES_TIMEOUT", 90) + monkeypatch.setattr(ai_service, "_get_client", lambda: client) + monkeypatch.setattr(ai_service, "_default_hosted_tools", lambda: []) + + result = await ai_service.generate_response_async( + [ + {"role": "system", "content": "Reply in English."}, + {"role": "user", "content": "how are you"}, + ], + model="gpt-5.4-mini", + stream=False, + expected_language="en-US", + ) + + assert result == "I am doing well, thank you." + assert len(client.responses.payloads) == 2 + assert "Language correction" in client.responses.payloads[1]["instructions"] diff --git a/tests/test_app_pipeline_config.py b/tests/test_app_pipeline_config.py new file mode 100644 index 0000000000000000000000000000000000000000..ebf81461b87eff4e52c410d8914b2d89b22942bc --- /dev/null +++ b/tests/test_app_pipeline_config.py @@ -0,0 +1,27 @@ +from pathlib import Path + + +def test_app_detect_timeout_keeps_agent_tool_judgement_room(): + source = Path("app.py").read_text(encoding="utf-8") + + assert "detect_timeout=25.0" in source + + +def test_app_ai_timeout_allows_hosted_tool_streaming(): + source = Path("app.py").read_text(encoding="utf-8") + + assert "ai_timeout=60.0" in source + + +def test_app_has_resolved_language_helper_for_agent_response_text(): + source = Path("app.py").read_text(encoding="utf-8") + + assert "def _resolve_conversation_language(" in source + assert "_preferred_language_from_text(res.text)" in source + + +def test_app_passes_user_message_language_into_handle_message(): + source = Path("app.py").read_text(encoding="utf-8") + + assert 'message_language = message_data.get("language") or "auto"' in source + assert "handle_message(user_message, user_id, chat_id, messages_for_handler, request_id=request_id, language=message_language, emotion_callback=_on_text_emotion)" in source diff --git a/tests/test_environment_context_builder.py b/tests/test_environment_context_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..4afbead41e6e12d6d24b31a38d502d4826ef983b --- /dev/null +++ b/tests/test_environment_context_builder.py @@ -0,0 +1,68 @@ +from datetime import datetime + +from core.environment.context_builder import EnvironmentContextBuilder +from services.ai_service import _build_environment_context_text + + +def test_environment_context_builder_marks_missing_context(): + builder = EnvironmentContextBuilder() + + result = builder.build({}, now=datetime(2026, 5, 13, 12, 0, 0)) + + assert result.metadata["freshness"] == "missing" + assert "has_location: False" in result.summary_text + + +def test_environment_context_builder_includes_latest_location_fields(): + builder = EnvironmentContextBuilder() + + result = builder.build( + { + "lat": 25.033964, + "lon": 121.564468, + "detailed_address": "台北市信義區市府路45號", + "address_display": "110台北市信義區市府路45號", + "precision": "address", + "poi_label": "台北101", + "road": "市府路", + "house_number": "45", + "tz": "Asia/Taipei", + "locale": "zh-TW", + "heading_cardinal": "NE", + "accuracy_m": 12, + }, + now=datetime(2026, 5, 13, 12, 0, 0), + ) + + assert result.metadata["freshness"] == "latest_available" + assert "detailed_address: 台北市信義區市府路45號" in result.summary_text + assert "precision: address" in result.summary_text + assert "poi_label: 台北101" in result.summary_text + assert "road: 市府路" in result.summary_text + assert "house_number: 45" in result.summary_text + assert "timezone: Asia/Taipei" in result.summary_text + assert "coordinates: 25.033964,121.564468" in result.summary_text + + +def test_ai_service_environment_context_text_uses_fixed_builder(): + text = _build_environment_context_text( + { + "lat": 25.033964, + "lon": 121.564468, + "detailed_address": "台北市信義區市府路45號", + "address_display": "110台北市信義區市府路45號", + "precision": "address", + "poi_label": "台北101", + "road": "市府路", + "house_number": "45", + "tz": "Asia/Taipei", + "locale": "zh-TW", + } + ) + + assert "snapshot_time_utc:" in text + assert "has_location: True" in text + assert "detailed_address: 台北市信義區市府路45號" in text + assert "precision: address" in text + assert "poi_label: 台北101" in text + assert "readable_context:" in text diff --git a/tests/test_environment_context_tool.py b/tests/test_environment_context_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..5f45756903fcb80b56fc09e652bb81d255ad22c5 --- /dev/null +++ b/tests/test_environment_context_tool.py @@ -0,0 +1,35 @@ +import json +from unittest.mock import AsyncMock, patch + +import pytest + +from features.mcp.tools.environment.context_tool import EnvironmentContextTool + + +@pytest.mark.asyncio +async def test_environment_context_tool_requires_injected_user_id(): + tool = EnvironmentContextTool() + + result = await tool.execute({}) + + assert result["success"] is False + assert "_user_id" in result["error"] + + +@pytest.mark.asyncio +async def test_environment_context_tool_reads_real_user_context(): + tool = EnvironmentContextTool() + fetcher = AsyncMock( + return_value={ + "success": True, + "context": {"city": "Taipei", "tz": "Asia/Taipei"}, + } + ) + + with patch("core.database.get_user_env_current", fetcher): + result = await tool.execute({"_user_id": "user-1"}) + + assert result["success"] is True + assert result["data"] == {"city": "Taipei", "tz": "Asia/Taipei"} + assert json.loads(result["content"]) == {"city": "Taipei", "tz": "Asia/Taipei"} + fetcher.assert_awaited_once_with("user-1") diff --git a/tests/test_forward_geocode_ranking.py b/tests/test_forward_geocode_ranking.py new file mode 100644 index 0000000000000000000000000000000000000000..8426a2663ad76d7490a837503f9433349632c4e8 --- /dev/null +++ b/tests/test_forward_geocode_ranking.py @@ -0,0 +1,57 @@ +import pytest + +from features.mcp.tools.location.geocoding_tool import ForwardGeocodeTool + + +@pytest.mark.asyncio +async def test_forward_geocode_prefers_poi_match_for_station_queries(monkeypatch): + async def fake_tdx(query): + return [ + { + "lat": 25.0487, + "lon": 121.5143, + "display_name": "市民大道台北地下街出口Y12西面", + "label": "桃園機場捷運台北車站_A1", + "importance": 1.0, + "name": "桃園機場捷運台北車站_A1", + "road": "", + "house_number": "", + "suburb": "", + "city_district": "", + "city": "", + "admin": "", + "postcode": "", + "amenity": "", + "shop": "", + "building": "", + "detailed_address": "市民大道台北地下街出口Y12西面", + "_kind": "markname", + }, + { + "lat": 25.1424, + "lon": 121.5066, + "display_name": "台北市北投區開明里珠海路臨125之2號", + "label": "台北市北投區開明里珠海路臨125之2號", + "importance": 1.0, + "name": "", + "road": "", + "house_number": "", + "suburb": "", + "city_district": "", + "city": "", + "admin": "", + "postcode": "", + "amenity": "", + "shop": "", + "building": "", + "detailed_address": "台北市北投區開明里珠海路臨125之2號", + "_kind": "address", + }, + ] + + monkeypatch.setattr(ForwardGeocodeTool, "_forward_geocode_tdx", fake_tdx) + + result = await ForwardGeocodeTool.execute({"query": "台北車站", "limit": 1}) + + assert "台北車站" in result["best_match"]["label"] + assert result["best_match"]["lat"] == 25.0487 diff --git a/tests/test_frontend_agent_output_rendering.py b/tests/test_frontend_agent_output_rendering.py new file mode 100644 index 0000000000000000000000000000000000000000..9ed5e29c83c788bd1a1593e4c649f3ba2ef13ba4 --- /dev/null +++ b/tests/test_frontend_agent_output_rendering.py @@ -0,0 +1,143 @@ +def test_typewriter_avoids_interval_and_full_markdown_rerender_per_tick(): + source = open("static/frontend/js/ui.js", encoding="utf-8").read() + start = source.index("function typewriterEffect") + block = source[start:source.index("const agentProgressSteps")] + + assert "setInterval(" not in block + assert "requestAnimationFrame(" in block + assert "renderAgentMarkdown(text.slice(0, index))" not in block + assert "typewriterTextNode" in source + assert "typewriterFrameId" in source + assert "setTypewriterContent" in source + + +def test_streaming_output_skips_redundant_full_repaint_when_text_unchanged(): + source = open("static/frontend/js/ui.js", encoding="utf-8").read() + + assert "let lastRenderedAgentText = '';" in source + assert "if (lastRenderedAgentText === nextText)" in source + assert "lastRenderedAgentText = nextText;" in source + + +def test_final_typewriter_waits_for_speech_completion_before_idle(): + ui_source = open("static/frontend/js/ui.js", encoding="utf-8").read() + tts_source = open("static/frontend/js/tts.js", encoding="utf-8").read() + + assert "const awaitSpeechCompletion = options.awaitSpeechCompletion === true || !!enableTTS;" in ui_source + assert "window.agentOutputAwaitingSpeechCompletion = awaitSpeechCompletion;" in ui_source + assert "if (window.agentOutputAwaitingSpeechCompletion) {" in ui_source + assert "function maybeFinalizeSpeechPlayback()" in tts_source + assert "window.agentOutputAwaitingSpeechCompletion = false;" in tts_source + assert "maybeFinalizeSpeechPlayback();" in tts_source + + +def test_progress_updates_do_not_override_active_answer_stream(): + ui_source = open("static/frontend/js/ui.js", encoding="utf-8").read() + ws_source = open("static/frontend/js/websocket.js", encoding="utf-8").read() + + assert "window.agentOutputHasAnswerStream = false;" in ui_source + assert "if (options.progress === true && window.agentOutputHasAnswerStream) {" in ui_source + assert "window.agentOutputHasAnswerStream = true;" in ui_source + assert "if (typeof currentState !== 'undefined' && currentState === 'speaking') {" in ws_source + assert "setState('speaking');" in ws_source + + +def test_stt_idle_does_not_reset_agent_while_speaking(): + ws_source = open("static/frontend/js/websocket.js", encoding="utf-8").read() + + assert "else if (data.status === 'idle' || data.status === 'error') {" in ws_source + assert "if (typeof currentState !== 'undefined' && currentState === 'speaking') {" in ws_source + assert "resetAgent();" in ws_source + + +def test_final_output_can_wait_for_streaming_speech_completion_without_full_tts(): + ui_source = open("static/frontend/js/ui.js", encoding="utf-8").read() + tts_source = open("static/frontend/js/tts.js", encoding="utf-8").read() + ws_source = open("static/frontend/js/websocket.js", encoding="utf-8").read() + + assert "const awaitSpeechCompletion = options.awaitSpeechCompletion === true || !!enableTTS;" in ui_source + assert "function hasPendingStreamingSpeech()" in tts_source + assert "|| !!ttsStreamSocket || _ttsActiveSources.length > 0" in tts_source + assert "hasPendingStreamingSpeech() || useFullTtsFallback" in ws_source + assert "finishAgentOutput(data.message, useFullTtsFallback, {" in ws_source + + +def test_tts_debug_logging_and_stop_reason_contracts_exist(): + tts_source = open("static/frontend/js/tts.js", encoding="utf-8").read() + agent_source = open("static/frontend/js/agent.js", encoding="utf-8").read() + + assert "function logTtsDebug(event, extra = {})" in tts_source + assert "if (!window.DEBUG_MODE) {" in tts_source + assert "logTtsDebug('socket_close'" in tts_source + assert "function stopSpeaking(clearPending = true, reason = 'unspecified')" in tts_source + assert "stopSpeaking(true, 'state_idle');" in agent_source + assert "stopSpeaking(true, 'reset_agent');" in agent_source + + +def test_complete_typewriter_refuses_idle_while_streaming_speech_still_pending(): + ui_source = open("static/frontend/js/ui.js", encoding="utf-8").read() + + assert "if (typeof hasPendingStreamingSpeech === 'function' && hasPendingStreamingSpeech()) {" in ui_source + assert "window.agentOutputAwaitingSpeechCompletion = true;" in ui_source + + +def test_tool_cards_container_is_an_interactive_fixed_overlay(): + source = open("static/frontend/index.html", encoding="utf-8").read() + + assert "#tool-cards-container {" in source + assert "position: fixed;" in source + assert "pointer-events: none;" in source + assert "#tool-cards-container .voice-tool-card {" in source + assert "pointer-events: auto;" in source + + +def test_tool_card_visual_dividers_and_scrollbar_are_refined(): + source = open("static/frontend/index.html", encoding="utf-8").read() + + assert "border: 1px solid rgba(15, 23, 42, 0.06);" in source + assert "border-bottom: 1px solid rgba(15, 23, 42, 0.05);" in source + assert ".tool-drawer-content::-webkit-scrollbar {" in source + assert "width: 3px;" in source + assert "background: rgba(15, 23, 42, 0.12);" in source + + +def test_speaking_petals_reuse_idle_rotation_and_layer_interleave(): + source = open("static/frontend/index.html", encoding="utf-8").read() + + assert ".voice-mic-container.speaking .bloom-petal-group.upper .bloom-petal:nth-child(1)" in source + assert "animation: petalBloom 1.6s ease-in-out infinite 0s;" in source + assert ".voice-mic-container.speaking .bloom-petal-group.lower .bloom-petal:nth-child(1)" in source + assert "animation: petalBloomInner 1.6s ease-in-out infinite 0.1s;" in source + assert "--speaking-angle" not in source + assert "speakingPetalFloat" not in source + assert "speakingPetalFloatInner" not in source + + +def test_websocket_updates_conversation_language_from_bot_messages(): + ws_source = open("static/frontend/js/websocket.js", encoding="utf-8").read() + + assert "if (data.language) {" in ws_source + assert "window.currentConversationLanguage = data.language;" in ws_source + assert "window.currentSpeechLanguage = data.language;" in ws_source + + +def test_new_thinking_cycle_clears_stale_tts_language_before_next_stream(): + ws_source = open("static/frontend/js/websocket.js", encoding="utf-8").read() + + assert "window.currentConversationLanguage = null;" in ws_source + assert "window.currentSpeechLanguage = null;" in ws_source + + +def test_bot_delta_carries_language_for_streaming_tts_switch(): + source = open("app.py", encoding="utf-8").read() + + assert "\"language\": _preferred_language_from_text(stream_accumulator[\"text\"]) or resolved_language" in source + + +def test_frontend_preloads_conversation_language_on_initial_input(): + ws_source = open("static/frontend/js/websocket.js", encoding="utf-8").read() + + assert "function inferConversationLanguage(text, fallback = null)" in ws_source + assert "const inferredLanguage = window.inferConversationLanguage(text, navigator.language || 'zh-TW');" in ws_source + assert "window.currentConversationLanguage = inferredLanguage;" in ws_source + assert "window.currentSpeechLanguage = inferredLanguage;" in ws_source diff --git a/tests/test_google_tts_service.py b/tests/test_google_tts_service.py new file mode 100644 index 0000000000000000000000000000000000000000..8da1ccdc956ea894af8765ccd865af89ca72771a --- /dev/null +++ b/tests/test_google_tts_service.py @@ -0,0 +1,177 @@ +import base64 + +import pytest +from fastapi import WebSocketDisconnect + +import app +from services.tts_service import TTSService, get_emotion_rate + + +def test_google_tts_voice_aliases_are_multilingual(): + service = TTSService() + + assert service._voice_config("coral") == {"languageCode": "cmn-TW", "name": "cmn-TW-Wavenet-A"} + assert service._voice_config("ja-jp") == {"languageCode": "ja-JP", "name": "ja-JP-Neural2-B"} + assert service._voice_config("vi-vn") == {"languageCode": "vi-VN", "name": "vi-VN-Wavenet-A"} + + +def test_google_tts_emotion_rate_is_conservative_for_care(): + assert get_emotion_rate("happy") > 1.0 + assert get_emotion_rate("sad") < 1.0 + assert get_emotion_rate("neutral", care_mode=True) < 1.0 + + +@pytest.mark.asyncio +async def test_google_tts_requires_api_key(): + service = TTSService() + service.api_key = "" + + result = await service.synthesize("你好") + + assert result["success"] is False + assert "GOOGLE_TTS_API_KEY" in result["error"] + + +@pytest.mark.asyncio +async def test_google_tts_decodes_audio_content(monkeypatch): + service = TTSService() + service.api_key = "test-key" + captured = {} + + class FakeResponse: + status = 200 + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + return False + + async def json(self, content_type=None): + return {"audioContent": base64.b64encode(b"mp3").decode("ascii")} + + class FakeSession: + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + return False + + def post(self, url, params=None, json=None, timeout=None): + captured["url"] = url + captured["params"] = params + captured["json"] = json + return FakeResponse() + + monkeypatch.setattr("services.tts_service.aiohttp.ClientSession", FakeSession) + + result = await service.synthesize("你好", voice="coral", emotion="happy") + + assert result["success"] is True + assert result["audio_data"] == b"mp3" + assert captured["params"] == {"key": "test-key"} + assert captured["json"]["voice"]["languageCode"] == "cmn-TW" + assert captured["json"]["audioConfig"]["speakingRate"] > 1.0 + + +@pytest.mark.asyncio +async def test_tts_websocket_client_disconnect_is_not_treated_as_server_error(monkeypatch): + events = [] + + class FakeWebSocket: + def __init__(self): + self.closed = False + self.send_count = 0 + + async def accept(self): + return None + + async def receive_json(self): + return { + "text": "你好", + "voice": "nova", + "language": "zh-TW", + "persona": "xiaohua", + "speaking_rate": 0.94, + } + + async def send_json(self, payload): + self.send_count += 1 + events.append(payload["type"]) + if payload["type"] == "tts_audio_chunk": + raise WebSocketDisconnect() + + async def close(self): + self.closed = True + + class FakeTTSService: + async def streaming_synthesize(self, **kwargs): + yield b"\x00\x01" + + monkeypatch.setattr(app, "logger", type("Logger", (), { + "info": lambda *args, **kwargs: events.append("log:info"), + "debug": lambda *args, **kwargs: events.append("log:debug"), + "error": lambda *args, **kwargs: events.append("log:error"), + "exception": lambda *args, **kwargs: events.append("log:exception"), + })()) + + import services.tts_service as tts_module + monkeypatch.setattr(tts_module, "tts_service", FakeTTSService()) + + websocket = FakeWebSocket() + await app.tts_stream_websocket(websocket) + + assert events[:2] == ["tts_stream_start", "tts_audio_chunk"] + assert "log:debug" in events + assert "log:error" not in events + assert "log:exception" not in events + assert websocket.closed is True + + +@pytest.mark.asyncio +async def test_tts_websocket_logs_chunk_stats_before_client_disconnect(monkeypatch): + events = [] + + class FakeWebSocket: + def __init__(self): + self.closed = False + + async def accept(self): + return None + + async def receive_json(self): + return { + "text": "你好", + "voice": "nova", + "language": "zh-TW", + "persona": "xiaohua", + "speaking_rate": 0.94, + } + + async def send_json(self, payload): + if payload["type"] == "tts_audio_chunk": + raise WebSocketDisconnect() + + async def close(self): + self.closed = True + + class FakeTTSService: + async def streaming_synthesize(self, **kwargs): + yield b"\x00\x01" + + monkeypatch.setattr(app, "logger", type("Logger", (), { + "info": lambda self, message, *args, **kwargs: events.append(message % args if args else message), + "debug": lambda self, message, *args, **kwargs: events.append(message % args if args else message), + "error": lambda self, *args, **kwargs: None, + "exception": lambda self, *args, **kwargs: None, + })()) + + import services.tts_service as tts_module + monkeypatch.setattr(tts_module, "tts_service", FakeTTSService()) + + websocket = FakeWebSocket() + await app.tts_stream_websocket(websocket) + + assert any("chunks=1" in event for event in events) + assert any("bytes=2" in event for event in events) + assert websocket.closed is True diff --git a/tests/test_mcp_client_contract.py b/tests/test_mcp_client_contract.py new file mode 100644 index 0000000000000000000000000000000000000000..3da16e238b6406115983dadbfa902e8394e9a69d --- /dev/null +++ b/tests/test_mcp_client_contract.py @@ -0,0 +1,74 @@ +import pytest + +from features.mcp.mcp_client import MCPClient + + +def test_create_tool_from_data_preserves_output_schema(): + client = MCPClient("server", {"command": "noop"}) + + tool = client._create_tool_from_data( + { + "name": "remote_tool", + "description": "Remote tool", + "inputSchema": {"type": "object", "properties": {}}, + "outputSchema": { + "type": "object", + "properties": {"success": {"type": "boolean"}}, + "required": ["success"], + }, + } + ) + + assert tool is not None + assert tool.outputSchema["properties"]["success"]["type"] == "boolean" + + +@pytest.mark.asyncio +async def test_call_tool_returns_structured_content(monkeypatch): + client = MCPClient("server", {"command": "noop"}) + + async def fake_send_request(method, params): + return { + "result": { + "content": [{"type": "text", "text": "fallback text"}], + "structuredContent": { + "success": True, + "value": "structured", + }, + } + } + + monkeypatch.setattr(client, "_send_request", fake_send_request) + + result = await client._call_tool("remote_tool", {}) + + assert result == { + "success": True, + "value": "structured", + } + + +@pytest.mark.asyncio +async def test_call_tool_preserves_error_semantics_with_structured_content(monkeypatch): + client = MCPClient("server", {"command": "noop"}) + + async def fake_send_request(method, params): + return { + "result": { + "content": [{"type": "text", "text": "failed"}], + "structuredContent": { + "error_code": "REMOTE_ERROR", + }, + "isError": True, + } + } + + monkeypatch.setattr(client, "_send_request", fake_send_request) + + result = await client._call_tool("remote_tool", {}) + + assert result == { + "error_code": "REMOTE_ERROR", + "success": False, + "error": "failed", + } diff --git a/tests/test_mcp_server_contract.py b/tests/test_mcp_server_contract.py new file mode 100644 index 0000000000000000000000000000000000000000..d5706d30e5f5c67ee8c2a03596f6ba4ccd326b96 --- /dev/null +++ b/tests/test_mcp_server_contract.py @@ -0,0 +1,139 @@ +import pytest + +from features.mcp.server import FeaturesMCPServer +from features.mcp.types import Tool, ToolCallResult + + +@pytest.mark.asyncio +async def test_tools_list_includes_output_schema(): + server = FeaturesMCPServer() + server.tools.clear() + server.register_tool( + Tool( + name="example", + description="Example tool", + inputSchema={"type": "object", "properties": {}}, + outputSchema={ + "type": "object", + "properties": {"success": {"type": "boolean"}}, + "required": ["success"], + }, + ) + ) + + result = await server._handle_tools_list({}) + + assert result["tools"][0]["outputSchema"]["properties"]["success"]["type"] == "boolean" + + +@pytest.mark.asyncio +async def test_tools_call_preserves_structured_content(): + async def handler(arguments): + return { + "success": True, + "content": "ok", + "value": arguments["value"], + } + + server = FeaturesMCPServer() + server.tools.clear() + server.register_tool( + Tool( + name="example", + description="Example tool", + inputSchema={ + "type": "object", + "properties": {"value": {"type": "string"}}, + "required": ["value"], + }, + handler=handler, + outputSchema={ + "type": "object", + "properties": { + "success": {"type": "boolean"}, + "content": {"type": "string"}, + "value": {"type": "string"}, + }, + "required": ["success", "content", "value"], + }, + ) + ) + + result = await server._handle_tools_call({"name": "example", "arguments": {"value": "42"}}) + + assert result["content"] == [{"type": "text", "text": "ok"}] + assert result["structuredContent"]["value"] == "42" + assert result["isError"] is False + + +@pytest.mark.asyncio +async def test_tools_call_error_uses_safe_error_contract(): + async def handler(arguments): + raise RuntimeError("secret internal detail") + + server = FeaturesMCPServer() + server.tools.clear() + server.register_tool( + Tool( + name="bad_tool", + description="Bad tool", + inputSchema={"type": "object", "properties": {}}, + handler=handler, + ) + ) + + result = await server._handle_tools_call({"name": "bad_tool", "arguments": {}}) + + assert result["content"] == [{"type": "text", "text": "工具執行失敗"}] + assert result["structuredContent"]["error_code"] == "TOOL_EXECUTION_ERROR" + assert result["structuredContent"]["tool_name"] == "bad_tool" + assert result["isError"] is True + + +@pytest.mark.asyncio +async def test_tools_call_rejects_output_schema_violation(): + async def handler(arguments): + return { + "success": True, + "content": "ok", + } + + server = FeaturesMCPServer() + server.tools.clear() + server.register_tool( + Tool( + name="bad_output", + description="Bad output", + inputSchema={"type": "object", "properties": {}}, + handler=handler, + outputSchema={ + "type": "object", + "properties": { + "success": {"type": "boolean"}, + "content": {"type": "string"}, + "value": {"type": "string"}, + }, + "required": ["success", "content", "value"], + }, + ) + ) + + result = await server._handle_tools_call({"name": "bad_output", "arguments": {}}) + + assert result["content"] == [{"type": "text", "text": "工具輸出格式不符合契約"}] + assert result["structuredContent"]["error_code"] == "TOOL_OUTPUT_VALIDATION" + assert result["structuredContent"]["tool_name"] == "bad_output" + assert result["isError"] is True + + +def test_tool_call_result_serializes_mcp_fields(): + result = ToolCallResult( + content=[{"type": "text", "text": "ok"}], + structuredContent={"success": True}, + ) + + assert result.to_dict() == { + "content": [{"type": "text", "text": "ok"}], + "structuredContent": {"success": True}, + "isError": False, + } diff --git a/tests/test_memory_system_resilience.py b/tests/test_memory_system_resilience.py new file mode 100644 index 0000000000000000000000000000000000000000..cc4fee90b877a3e50b98b1b6499d08cbc5e3f596 --- /dev/null +++ b/tests/test_memory_system_resilience.py @@ -0,0 +1,155 @@ +import logging +import json + +import pytest + +from core import memory_system + + +@pytest.mark.asyncio +async def test_memory_manager_skips_ai_analysis_for_transient_market_query(monkeypatch): + class FailingAnalyzer: + async def analyze_conversation(self, *args, **kwargs): + raise AssertionError("AI memory analysis should be skipped for transient queries") + + manager = memory_system.MemoryManager() + manager.analyzer = FailingAnalyzer() + + monkeypatch.setattr(memory_system, "_get_memory_client", lambda: object()) + monkeypatch.setattr(memory_system, "db_available", False) + + result = await manager.process_conversation( + user_id="u1", + user_message="今天台積電股價收盤價多少?", + assistant_response="查詢中。", + conversation_history=[], + ) + + assert result["extracted_memories"] == 0 + assert result["saved_memories"] == 0 + assert result["errors"] == [] + + +@pytest.mark.asyncio +async def test_memory_analyzer_quietly_degrades_on_transient_upstream_error(monkeypatch, caplog): + class Responses: + def create(self, **kwargs): + raise Exception("Error code: 502 - {'error': {'message': 'Upstream request failed'}}") + + class Client: + responses = Responses() + + class Settings: + OPENAI_USE_RESPONSES = True + GPT_INTENT_MODEL = "gpt-5.4" + OPENAI_MODEL = "gpt-5.4" + OPENAI_TIMEOUT = 30 + + monkeypatch.setattr(memory_system, "_get_memory_client", lambda: Client()) + monkeypatch.setattr(memory_system, "settings", Settings) + + analyzer = memory_system.MemoryAnalyzer() + + with caplog.at_level(logging.ERROR): + memories = await analyzer.analyze_conversation( + user_message="我喜歡安靜的咖啡店", + assistant_response="我記住了。", + conversation_history=[], + ) + + assert memories == [] + assert "AI記憶分析時發生錯誤" not in caplog.text + + +def test_memory_analyzer_uses_structured_responses_payload(monkeypatch): + captured = {} + + class Responses: + def create(self, **kwargs): + captured.update(kwargs) + + class Response: + output_text = '{"memories":[]}' + + output = [] + + return Response() + + class Client: + responses = Responses() + + class Settings: + OPENAI_USE_RESPONSES = True + GPT_INTENT_MODEL = "gpt-5.4-mini" + OPENAI_MODEL = "gpt-5.4-mini" + + monkeypatch.setattr(memory_system, "settings", Settings) + + analyzer = memory_system.MemoryAnalyzer() + response = analyzer._create_analysis_response( + Client(), + [{"role": "system", "content": "JSON"}, {"role": "user", "content": "JSON"}], + 500, + ) + + assert response.output_text == '{"memories":[]}' + assert captured["model"] == "gpt-5.4-mini" + assert captured["max_output_tokens"] == 500 + assert captured["store"] is False + assert captured["text"]["format"]["type"] == "json_schema" + assert captured["text"]["format"]["strict"] is True + assert captured["text"]["format"]["schema"]["required"] == ["memories"] + assert "reasoning" not in captured + + +@pytest.mark.asyncio +async def test_memory_analyzer_retries_transient_upstream_error_then_succeeds(monkeypatch): + calls = {"count": 0} + + class Responses: + def create(self, **kwargs): + calls["count"] += 1 + if calls["count"] == 1: + raise Exception("502 upstream request failed") + + class Response: + output_text = json.dumps( + { + "memories": [ + { + "type": "preferences", + "content": "使用者喜歡安靜的咖啡店", + "importance": 0.8, + } + ] + }, + ensure_ascii=False, + ) + + output = [] + + return Response() + + class Client: + responses = Responses() + + class Settings: + OPENAI_USE_RESPONSES = True + GPT_INTENT_MODEL = "gpt-5.4-mini" + OPENAI_MODEL = "gpt-5.4-mini" + OPENAI_TIMEOUT = 30 + + monkeypatch.setattr(memory_system, "_get_memory_client", lambda: Client()) + monkeypatch.setattr(memory_system, "settings", Settings) + monkeypatch.setattr(memory_system.MemoryAnalyzer, "_transient_backoff", staticmethod(lambda attempt: _noop())) + + analyzer = memory_system.MemoryAnalyzer() + memories = await analyzer.analyze_conversation("我喜歡安靜的咖啡店", "我記住了。", []) + + assert calls["count"] == 2 + assert memories[0]["type"] == "preferences" + assert memories[0]["source"] == "ai_analysis" + + +async def _noop(): + return None diff --git a/tests/test_openai_hosted_tools_config.py b/tests/test_openai_hosted_tools_config.py new file mode 100644 index 0000000000000000000000000000000000000000..f2aa167af9609677971b0a1b964bc7f707b5795d --- /dev/null +++ b/tests/test_openai_hosted_tools_config.py @@ -0,0 +1,129 @@ +import json + +from features.mcp.openai_tools import build_openai_hosted_tools +from features.mcp.openai_tools import DEFAULT_CONFIG_PATH + + +def test_default_config_path_points_to_project_mcp_config(): + assert DEFAULT_CONFIG_PATH.name == "mcp_config.json" + assert DEFAULT_CONFIG_PATH.parent.name == "features" + assert DEFAULT_CONFIG_PATH.exists() + + +class _Settings: + OPENAI_ENABLE_WEB_SEARCH = True + OPENAI_ENABLE_REMOTE_MCP = True + OPENAI_REMOTE_MCP_SERVERS_JSON = "[]" + OPENAI_ENABLE_SKILLS = True + + +def test_openai_hosted_tools_reads_project_mcp_config(tmp_path, monkeypatch): + config_path = tmp_path / "mcp_config.json" + config_path.write_text( + json.dumps( + { + "openai_tools": { + "web_search": {"enabled": True}, + "remote_mcp": { + "enabled": True, + "approval_default": "always", + "items": [ + { + "enabled": True, + "server_label": "dmcp", + "server_url": "https://dmcp-server.deno.dev/sse", + "allowed_tools": ["roll"], + } + ], + }, + "skills": { + "enabled": True, + "mode": "system_context", + "skills_root": "features/mcp/skills", + }, + } + } + ), + encoding="utf-8", + ) + monkeypatch.setattr("features.mcp.openai_tools.settings", _Settings) + + specs = build_openai_hosted_tools(config_path) + + assert specs == [ + {"type": "web_search"}, + { + "type": "mcp", + "server_label": "dmcp", + "server_url": "https://dmcp-server.deno.dev/sse", + "allowed_tools": ["roll"], + "require_approval": "always", + } + ] + + +def test_openai_hosted_tools_keeps_remote_mcp_and_skills_disabled_by_env(tmp_path, monkeypatch): + class DisabledSettings(_Settings): + OPENAI_ENABLE_REMOTE_MCP = False + OPENAI_ENABLE_SKILLS = False + + config_path = tmp_path / "mcp_config.json" + config_path.write_text( + json.dumps( + { + "openai_tools": { + "web_search": {"enabled": True}, + "remote_mcp": { + "enabled": True, + "items": [{"server_label": "dmcp", "server_url": "https://example.com/mcp"}], + }, + "skills": {"enabled": True, "mode": "system_context"}, + } + } + ), + encoding="utf-8", + ) + monkeypatch.setattr("features.mcp.openai_tools.settings", DisabledSettings) + + assert build_openai_hosted_tools(config_path) == [{"type": "web_search"}] + + +def test_openai_hosted_tools_skips_local_mcp_without_remote_url(tmp_path, monkeypatch): + config_path = tmp_path / "mcp_config.json" + config_path.write_text( + json.dumps( + { + "openai_tools": { + "web_search": {"enabled": False}, + "remote_mcp": { + "enabled": True, + "items": [{"enabled": True, "server_label": "local-features"}], + }, + "skills": {"enabled": False}, + } + } + ), + encoding="utf-8", + ) + monkeypatch.setattr("features.mcp.openai_tools.settings", _Settings) + + assert build_openai_hosted_tools(config_path) == [] + + +def test_openai_hosted_tools_never_emits_executable_skill_adapter(tmp_path, monkeypatch): + config_path = tmp_path / "mcp_config.json" + config_path.write_text( + json.dumps( + { + "openai_tools": { + "web_search": {"enabled": False}, + "remote_mcp": {"enabled": False}, + "skills": {"enabled": True, "mode": "system_context"}, + } + } + ), + encoding="utf-8", + ) + monkeypatch.setattr("features.mcp.openai_tools.settings", _Settings) + + assert build_openai_hosted_tools(config_path) == [] diff --git a/tests/test_pipeline_confidence_loop.py b/tests/test_pipeline_confidence_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..368271b905d6ef2f1ee576577993dcb29a434030 --- /dev/null +++ b/tests/test_pipeline_confidence_loop.py @@ -0,0 +1,101 @@ +import pytest +import asyncio +from unittest.mock import AsyncMock, MagicMock +from core.pipeline import ChatPipeline, PipelineResult + +@pytest.mark.asyncio +async def test_confidence_driven_loop_no_tool_calls(): + """Test when no tools are called, it immediately answers.""" + intent_detector = AsyncMock() + # Returns no feature + intent_detector.return_value = (False, {"emotion": "neutral"}) + + feature_processor = AsyncMock() + ai_generator = AsyncMock() + ai_generator.return_value = PipelineResult(text="Hello", is_fallback=False, meta={"emotion": "neutral"}) + + pipeline = ChatPipeline( + intent_detector=intent_detector, + feature_processor=feature_processor, + ai_generator=ai_generator + ) + + res = await pipeline.process("Hello") + assert res.text == "Hello" + assert intent_detector.call_count == 1 + assert feature_processor.call_count == 0 + assert ai_generator.call_count == 1 + +@pytest.mark.asyncio +async def test_confidence_driven_loop_single_tool_call(): + """Test when one tool is called, it iterates once and passes context.""" + intent_detector = AsyncMock() + # 1st call: Call tool + # 2nd call: No tool needed (satisfied) + intent_detector.side_effect = [ + (True, {"type": "mcp_tool", "confidence": 0.95, "emotion": "neutral"}), + (False, {"emotion": "neutral"}) + ] + + feature_processor = AsyncMock() + feature_processor.return_value = PipelineResult(text="Tool result payload", is_fallback=False, meta={}) + + ai_generator = AsyncMock() + ai_generator.return_value = "Based on the tool, the answer is Yes." + + pipeline = ChatPipeline( + intent_detector=intent_detector, + feature_processor=feature_processor, + ai_generator=ai_generator + ) + + res = await pipeline.process("What is the weather?") + + assert res.text == "Based on the tool, the answer is Yes." + assert intent_detector.call_count == 2 + assert feature_processor.call_count == 1 + assert ai_generator.call_count == 1 + + # Verify tool context is passed to ai_generator + call_kwargs = ai_generator.call_args.kwargs + assert "Tool result payload" in call_kwargs.get("tool_context", "") + +@pytest.mark.asyncio +async def test_confidence_driven_loop_multi_tool_call(): + """Test when information is incomplete, it calls multiple tools before answering.""" + intent_detector = AsyncMock() + # 1st call: Call tool 1 + # 2nd call: Call tool 2 + # 3rd call: Satisfied + intent_detector.side_effect = [ + (True, {"type": "mcp_tool", "confidence": 0.95, "emotion": "neutral"}), + (True, {"type": "mcp_tool", "confidence": 0.95, "emotion": "neutral"}), + (False, {"emotion": "neutral"}) + ] + + feature_processor = AsyncMock() + feature_processor.side_effect = [ + PipelineResult(text="Tool 1 result", is_fallback=False, meta={}), + PipelineResult(text="Tool 2 result", is_fallback=False, meta={}) + ] + + ai_generator = AsyncMock() + ai_generator.return_value = "Combined answer." + + pipeline = ChatPipeline( + intent_detector=intent_detector, + feature_processor=feature_processor, + ai_generator=ai_generator + ) + + res = await pipeline.process("Complex query") + + assert res.text == "Combined answer." + assert intent_detector.call_count == 3 + assert feature_processor.call_count == 2 + assert ai_generator.call_count == 1 + + call_kwargs = ai_generator.call_args.kwargs + context = call_kwargs.get("tool_context", "") + assert "Tool 1 result" in context + assert "Tool 2 result" in context diff --git a/tests/test_pipeline_tool_confidence.py b/tests/test_pipeline_tool_confidence.py new file mode 100644 index 0000000000000000000000000000000000000000..4afcc1a3ce91df1fd85de28297e220c6644f8d7c --- /dev/null +++ b/tests/test_pipeline_tool_confidence.py @@ -0,0 +1,81 @@ +import pytest + +from core.pipeline import ChatPipeline + + +async def noop_ai_generator(*args, **kwargs): + return "chat" + + +async def forbidden_feature_processor(*args, **kwargs): + raise AssertionError("feature processor must not be called when confidence is below threshold") + + +async def low_confidence_intent(_message): + return True, { + "type": "mcp_tool", + "tool_name": "weather_query", + "arguments": {"city": "Taipei"}, + "emotion": "neutral", + "confidence": 0.89, + } + + +async def high_confidence_intent(_message): + return True, { + "type": "mcp_tool", + "tool_name": "weather_query", + "arguments": {"city": "Taipei"}, + "emotion": "neutral", + "confidence": 0.90, + } + + +async def feature_processor(intent_data, user_id, original_message, chat_id): + return { + "message": "ok", + "tool_name": intent_data["tool_name"], + "tool_data": {"city": "Taipei"}, + } + + +def build_pipeline(intent_detector, processor): + return ChatPipeline( + intent_detector=intent_detector, + feature_processor=processor, + ai_generator=noop_ai_generator, + ) + + +@pytest.mark.asyncio +async def test_low_confidence_tool_call_is_blocked(): + pipeline = build_pipeline(low_confidence_intent, forbidden_feature_processor) + + result = await pipeline.process("天氣", "user1") + + assert result.reason == "tool-low-confidence" + assert result.meta["tool_blocked"] is True + assert result.meta["tool_confidence"] == 0.89 + assert "沒有可用工具" in result.text + assert "地點" in result.text + + +@pytest.mark.asyncio +async def test_low_confidence_tool_message_matches_user_language(): + pipeline = build_pipeline(low_confidence_intent, forbidden_feature_processor) + + result = await pipeline.process("weather", "user1") + + assert result.reason == "tool-low-confidence" + assert "No tool is available" in result.text + assert "location" in result.text + + +@pytest.mark.asyncio +async def test_threshold_confidence_allows_tool_call(): + pipeline = build_pipeline(high_confidence_intent, feature_processor) + + result = await pipeline.process("台北天氣", "user1") + + assert result.text == "ok" + assert result.meta["tool_name"] == "weather_query" diff --git a/tests/test_prompts.py b/tests/test_prompts.py index 51372a559f657ce4f6e67b8d62d59281c74ffdde..b2d1c9c1bbfe785f6092a4223f386f792e1f5e59 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -5,6 +5,7 @@ import pytest from core.prompts.intent_detection import get_intent_prompt, TOOL_RULES from core.prompts.care_mode import get_care_prompt, CARE_MODE_PROMPT +from services.ai_service import _build_base_system_prompt, _compose_messages_with_context class TestIntentPrompt: @@ -16,6 +17,9 @@ class TestIntentPrompt: assert "意圖解析" in prompt assert "工具列表" in prompt assert "is_tool_call" in prompt + assert "反幻覺" in prompt + assert "環境優先" in prompt + assert "不得憑印象補答案" in prompt def test_get_intent_prompt_with_rules(self): """測試帶規則的 Prompt""" @@ -76,3 +80,39 @@ class TestCarePrompt: prompt = get_care_prompt(emotion="angry", user_name="小華") assert "angry" in prompt assert "小華" in prompt + + +class TestVoiceOutputPrompt: + def test_base_system_prompt_prefers_spoken_concise_answers(self): + prompt = _build_base_system_prompt( + use_care_mode=False, + care_emotion=None, + user_name="小明", + language="zh-TW", + ) + + assert "語音輸出風格" in prompt + assert "自然口語" in prompt + assert "不要輸出「資料來源」" in prompt + assert "不要輸出「資料來源」「來源如下」「參考連結」「URL」" in prompt + + def test_tool_context_is_grounding_not_mandatory_source_dump(self): + messages = _compose_messages_with_context( + base_prompt="base", + history_entries=[], + memory_context="", + env_context="", + time_context="", + emotion_context="", + current_request="今天台積電多少", + user_id="u1", + chat_id="c1", + use_care_mode=False, + care_emotion=None, + tool_context="Yahoo: 417.72 USD", + ) + + system_prompt = messages[0]["content"] + assert "這些資料主要用於查證與內部 grounding" in system_prompt + assert "不要在最終答案中列出來源、連結、URL" in system_prompt + assert "預設輸出是給人直接聽的口語答案" in system_prompt diff --git a/tests/test_responses_runtime.py b/tests/test_responses_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..5854eb44830df55bcb30fb7f0818d2f9a3c353b7 --- /dev/null +++ b/tests/test_responses_runtime.py @@ -0,0 +1,149 @@ +from core.environment.context_builder import EnvironmentInjection +from core.responses_runtime import ResponsesAgentRuntime, ResponsesRuntimeRequest + + +def test_responses_runtime_builds_payload_with_environment_block(): + runtime = ResponsesAgentRuntime() + + payload = runtime.build_request_payload( + ResponsesRuntimeRequest( + user_input="今天台北天氣如何?", + model="gpt-5.4", + instructions="Use hosted tools first.", + environment=EnvironmentInjection( + summary_text="timezone: Asia/Taipei\ncity: Taipei", + raw_context={"city": "Taipei"}, + metadata={"freshness": "latest_available"}, + ), + tools=[{"type": "web_search"}], + previous_response_id="resp_123", + ) + ) + + assert payload["model"] == "gpt-5.4" + assert payload["instructions"] == "Use hosted tools first." + assert payload["previous_response_id"] == "resp_123" + assert payload["tools"] == [{"type": "web_search"}] + assert payload["input"][0]["role"] == "system" + assert "Latest environment context" in payload["input"][0]["content"][0]["text"] + assert payload["input"][1]["role"] == "user" + + +def test_responses_runtime_converts_chat_messages_to_responses_payload(): + runtime = ResponsesAgentRuntime() + + payload = runtime.build_payload_from_messages( + messages=[ + {"role": "system", "content": "System policy"}, + {"role": "user", "content": "Hello"}, + ], + model="gpt-5.4", + tools=[ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather.", + "parameters": {"type": "object", "properties": {}}, + "strict": True, + }, + } + ], + reasoning_effort="high", + max_output_tokens=100, + ) + + assert payload["instructions"] == "System policy" + assert payload["input"][0]["role"] == "user" + assert payload["input"][0]["content"][0]["type"] == "input_text" + assert payload["reasoning"] == {"effort": "high"} + assert payload["max_output_tokens"] == 100 + assert payload["tools"] == [ + { + "type": "function", + "name": "get_weather", + "description": "Get weather.", + "parameters": {"type": "object", "properties": {}}, + "strict": True, + } + ] + + +def test_responses_runtime_converts_image_url_content_to_responses_input(): + runtime = ResponsesAgentRuntime() + + payload = runtime.build_payload_from_messages( + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "Analyze image"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, + ], + } + ], + model="gpt-5.4", + ) + + assert payload["input"][0]["content"] == [ + {"type": "input_text", "text": "Analyze image"}, + {"type": "input_image", "image_url": "data:image/png;base64,abc"}, + ] + + +def test_responses_runtime_extracts_function_calls_from_output_items(): + class Item: + type = "function_call" + call_id = "call_1" + name = "get_weather" + arguments = '{"city":"Taipei"}' + + class Response: + output = [Item()] + + calls = ResponsesAgentRuntime.extract_function_calls(Response()) + + assert calls == [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city":"Taipei"}', + }, + } + ] + + +def test_responses_runtime_can_strip_hosted_tools_for_retry(): + payload = { + "model": "gpt-5.4", + "input": [], + "tools": [ + {"type": "web_search"}, + {"type": "mcp", "server_label": "remote", "server_url": "https://example.com/mcp"}, + {"type": "function", "name": "weather_query", "parameters": {"type": "object"}}, + ], + "tool_choice": "auto", + } + + stripped = ResponsesAgentRuntime.without_hosted_tools(payload) + + assert stripped["tools"] == [ + {"type": "function", "name": "weather_query", "parameters": {"type": "object"}}, + ] + assert stripped["tool_choice"] == "auto" + + +def test_responses_runtime_removes_tool_choice_when_no_functions_left(): + payload = { + "model": "gpt-5.4", + "input": [], + "tools": [{"type": "web_search"}], + "tool_choice": "auto", + } + + stripped = ResponsesAgentRuntime.without_hosted_tools(payload) + + assert stripped["tools"] == [] + assert "tool_choice" not in stripped diff --git a/tests/test_tdx_location_resolution.py b/tests/test_tdx_location_resolution.py new file mode 100644 index 0000000000000000000000000000000000000000..ee966c4abd59d224b5dff2a59fc926b3c8bfa134 --- /dev/null +++ b/tests/test_tdx_location_resolution.py @@ -0,0 +1,68 @@ +import pytest + +from features.mcp.tools.transportation import tdx_location + + +def test_resolve_city_code_normalizes_city_suffix(): + assert tdx_location.resolve_city_code("桃園市") == "Taoyuan" + assert tdx_location.resolve_city_code("臺中市") == "Taichung" + + +def test_resolve_metro_operator_from_city(): + assert tdx_location.resolve_metro_operator("高雄市") == "KRTC" + assert tdx_location.resolve_metro_operator("桃園") == "TYMC" + + +def test_resolve_city_candidates_include_neighbors(): + candidates = tdx_location.resolve_city_candidates( + city_like="新北市", + geo_city="新北市", + geo_admin="新北市", + allowed_city_codes={"Taipei", "NewTaipei", "Taoyuan", "Keelung"}, + ) + assert candidates[0] == "NewTaipei" + assert "Taipei" in candidates + + +def test_resolve_metro_operator_candidates_cover_taipei_living_circle(): + candidates = tdx_location.resolve_metro_operator_candidates( + city_like="新北市", + geo_city="新北市", + geo_admin="新北市", + ) + assert candidates == ["TRTC", "NTMC"] + + +@pytest.mark.asyncio +async def test_resolve_location_context_uses_location_query_when_coordinates_missing(monkeypatch): + async def fake_resolve_coordinates(*, lat, lon, location_query): + assert location_query == "桃園火車站" + return 24.989, 121.314, {"label": "桃園火車站"} + + async def fake_resolve_geo_context(*, lat, lon): + assert lat == 24.989 + assert lon == 121.314 + return { + "city": "桃園市", + "admin": "桃園市", + "label": "桃園火車站", + "detailed_address": "桃園火車站", + "city_code": "Taoyuan", + "metro_operator": "TYMC", + } + + monkeypatch.setattr(tdx_location, "resolve_coordinates", fake_resolve_coordinates) + monkeypatch.setattr(tdx_location, "resolve_geo_context", fake_resolve_geo_context) + + ctx = await tdx_location.resolve_location_context( + lat=None, + lon=None, + location_query="桃園火車站", + city_like=None, + allowed_city_codes={"Taoyuan", "Taipei"}, + ) + + assert ctx["lat"] == 24.989 + assert ctx["lon"] == 121.314 + assert ctx["city_code"] == "Taoyuan" + assert ctx["geo"]["label"] == "桃園火車站" diff --git a/tests/test_tdx_metro_refactor.py b/tests/test_tdx_metro_refactor.py new file mode 100644 index 0000000000000000000000000000000000000000..183bc0b6fc070b329733f4d79a58c8855f45ca84 --- /dev/null +++ b/tests/test_tdx_metro_refactor.py @@ -0,0 +1,42 @@ +import pytest + +from features.mcp.tools.transportation.tdx_metro import TDXMetroTool + + +@pytest.mark.asyncio +async def test_metro_nearest_station_queries_multiple_operators(monkeypatch): + calls = [] + + async def fake_call_api(endpoint, params, cache_ttl=3600): + calls.append(endpoint) + if endpoint.endswith("/TRTC"): + return [ + { + "StationUID": "TRTC-1", + "StationName": {"Zh_tw": "台北車站"}, + "StationPosition": {"PositionLat": 25.0478, "PositionLon": 121.5170}, + "StationAddress": "台北市中正區", + } + ] + if endpoint.endswith("/NTMC"): + return [ + { + "StationUID": "NTMC-1", + "StationName": {"Zh_tw": "頭前庄"}, + "StationPosition": {"PositionLat": 25.0390, "PositionLon": 121.4602}, + "StationAddress": "新北市新莊區", + } + ] + return [] + + monkeypatch.setattr( + "features.mcp.tools.transportation.tdx_metro.TDXBaseAPI.call_api", + fake_call_api, + ) + + result = await TDXMetroTool._query_nearest_station(25.04, 121.50, ["TRTC", "NTMC"]) + + assert result["success"] is True + assert "Rail/Metro/Station/TRTC" in calls + assert "Rail/Metro/Station/NTMC" in calls + assert len(result["stations"]) >= 1 diff --git a/tests/test_tdx_multicity_refactor.py b/tests/test_tdx_multicity_refactor.py new file mode 100644 index 0000000000000000000000000000000000000000..5b67874f63f9438c2396e3f315199f0b75114c09 --- /dev/null +++ b/tests/test_tdx_multicity_refactor.py @@ -0,0 +1,88 @@ +import pytest + +from features.mcp.tools.transportation.tdx_bus_arrival import TDXBusArrivalTool +from features.mcp.tools.transportation.tdx_parking import TDXParkingTool +from features.mcp.tools.transportation.tdx_youbike import TDXBikeTool + + +@pytest.mark.asyncio +async def test_youbike_nearby_queries_multiple_cities(monkeypatch): + calls = [] + + async def fake_call_api(endpoint, params, cache_ttl=1800): + calls.append(endpoint) + if "Bike/Station/City/Taipei" in endpoint: + return [{"StationUID": "T1", "StationName": {"Zh_tw": "台北站"}, "StationPosition": {"PositionLat": 25.04, "PositionLon": 121.52}}] + if "Bike/Station/City/NewTaipei" in endpoint: + return [{"StationUID": "N1", "StationName": {"Zh_tw": "新北站"}, "StationPosition": {"PositionLat": 25.03, "PositionLon": 121.49}}] + if "Bike/Availability/City/Taipei" in endpoint or "Bike/Availability/City/NewTaipei" in endpoint: + return [] + return [] + + monkeypatch.setattr("features.mcp.tools.transportation.tdx_youbike.TDXBaseAPI.call_api", fake_call_api) + + result = await TDXBikeTool._query_nearby_stations(25.04, 121.50, ["Taipei", "NewTaipei"], 500, 3) + + assert result["success"] is True + assert "Bike/Station/City/Taipei" in calls + assert "Bike/Station/City/NewTaipei" in calls + + +@pytest.mark.asyncio +async def test_bus_nearby_queries_multiple_cities(monkeypatch): + calls = [] + + async def fake_call_api(endpoint, params, cache_ttl=1800): + calls.append(endpoint) + return [] + + monkeypatch.setattr("features.mcp.tools.transportation.tdx_bus_arrival.TDXBaseAPI.call_api", fake_call_api) + + result = await TDXBusArrivalTool._query_nearby_stops(25.04, 121.50, ["Taipei", "NewTaipei"], 3) + + assert result["success"] is True + assert "Bus/Stop/City/Taipei" in calls + assert "Bus/Stop/City/NewTaipei" in calls + + +@pytest.mark.asyncio +async def test_parking_nearby_uses_advanced_nearby_endpoint(monkeypatch): + calls = [] + + async def fake_call_api(endpoint, params, cache_ttl=3600, api_version="v2", api_family="basic"): + calls.append((endpoint, api_version, api_family)) + return [] + + monkeypatch.setattr("features.mcp.tools.transportation.tdx_parking.TDXBaseAPI.call_api", fake_call_api) + + result = await TDXParkingTool._query_nearby_parkings(25.04, 121.50, None, 1000, 3) + + assert result["success"] is True + assert ("Parking/OffStreet/CarPark/NearBy", "v1", "advanced") in calls + + +@pytest.mark.asyncio +async def test_named_parking_uses_nearby_filter_instead_of_city_lookup(monkeypatch): + async def fake_nearby(lat, lon, parking_type, radius_m, limit): + return { + "success": True, + "content": "ok", + "parkings": [ + { + "parking_name": "台北車站停車場", + "available_spaces": 12, + "total_spaces": 100, + "fee_info": "每小時 60 元", + "charge_station": False, + "walking_time_min": 3, + "distance_m": 220, + } + ], + } + + monkeypatch.setattr(TDXParkingTool, "_query_nearby_parkings", fake_nearby) + + result = await TDXParkingTool._query_named_parking_nearby("台北車站", 25.04, 121.51, 1000, 5) + + assert result["success"] is True + assert result["parking"]["parking_name"] == "台北車站停車場" diff --git a/tests/test_tdx_official_geocode_strategy.py b/tests/test_tdx_official_geocode_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..23f21d3cf65d7284077559d74e3ca852a8ba5548 --- /dev/null +++ b/tests/test_tdx_official_geocode_strategy.py @@ -0,0 +1,9 @@ +from features.mcp.tools.location.geocode_tool import ReverseGeocodeTool + + +def test_reverse_geocode_output_schema_exposes_precision_fields(): + schema = ReverseGeocodeTool.get_output_schema() + props = schema["properties"] + assert "address_display" in props + assert "precision" in props + assert "poi_label" in props diff --git a/tests/test_tdx_youbike_refactor.py b/tests/test_tdx_youbike_refactor.py new file mode 100644 index 0000000000000000000000000000000000000000..d41c59ae7ee77da6ece34e153f75f6adbfa7a5f3 --- /dev/null +++ b/tests/test_tdx_youbike_refactor.py @@ -0,0 +1,44 @@ +import pytest + +from features.mcp.tools.transportation.tdx_youbike import TDXBikeTool + + +@pytest.mark.asyncio +async def test_youbike_nearby_supports_location_query(monkeypatch): + async def fake_location_context(**kwargs): + assert kwargs["location_query"] == "桃園火車站" + return { + "lat": 24.989, + "lon": 121.314, + "city_code": "Taoyuan", + "geo": {"city": "桃園市", "label": "桃園火車站"}, + "geocode_match": {"label": "桃園火車站"}, + } + + async def fake_nearby(lat, lon, cities, radius_m, limit): + assert lat == 24.989 + assert lon == 121.314 + assert cities[0] == "Taoyuan" + assert "Taipei" in cities + return { + "success": True, + "content": "ok", + "stations": [], + } + + monkeypatch.setattr( + "features.mcp.tools.transportation.tdx_youbike.resolve_location_context", + fake_location_context, + ) + monkeypatch.setattr(TDXBikeTool, "_query_nearby_stations", fake_nearby) + + result = await TDXBikeTool.execute( + { + "location_query": "桃園火車站", + "radius_m": 300, + "limit": 3, + } + ) + + assert result["success"] is True + assert result["stations"] == [] diff --git a/tests/test_tool_calling_policy.py b/tests/test_tool_calling_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..0ce2115fcede2e11e12d3e3dc8ef42b9e6b7aee2 --- /dev/null +++ b/tests/test_tool_calling_policy.py @@ -0,0 +1,47 @@ +import inspect + +from core.intent_detector import IntentDetector +from core.prompts.tool_calling_policy import get_tool_calling_policy +from features.mcp.agent_bridge import MCPAgentBridge + + +def test_shared_tool_policy_contains_required_guards(): + policy = get_tool_calling_policy() + + assert "反幻覺" in policy + assert "環境優先" in policy + assert "參數紀律" in policy + assert "工具失敗" in policy + assert "不得憑印象補答案" in policy + assert "不要編造 city/lat/lon" in policy + assert "至少 90%" in policy + assert "語言一致" in policy + + +def test_agent_bridge_prompt_embeds_tool_policy(): + bridge = MCPAgentBridge() + + prompt = bridge._build_function_calling_prompt() + + assert get_tool_calling_policy() in prompt + assert "Weather/News/Exchange" in prompt + assert "附近" in prompt + assert "不要編造 city/lat/lon" in prompt + + +def test_tool_response_formatter_forbids_guessing(): + source = inspect.getsource(MCPAgentBridge._format_tool_response) + + assert "嚴禁推測" in source + assert "資料缺漏" in source + assert "不得把工具錯誤包裝成成功結果" in source + + +def test_intent_detector_prompt_embeds_tool_policy(): + detector = IntentDetector() + + prompt = detector._build_system_prompt() + + assert get_tool_calling_policy() in prompt + assert "無法確定的可選參數留空" in prompt + assert "不得憑印象補答案" in prompt diff --git a/tests/test_tool_coordinator_contract.py b/tests/test_tool_coordinator_contract.py new file mode 100644 index 0000000000000000000000000000000000000000..1111d053f527b097e76fe12efdfe1c28e464fb08 --- /dev/null +++ b/tests/test_tool_coordinator_contract.py @@ -0,0 +1,215 @@ +import pytest + +from features.mcp.coordinator import ToolCoordinator, ToolOutputValidationError +from features.mcp.tool_models import ToolMetadata + + +async def empty_env(user_id): + return {} + + +async def passthrough_formatter(tool_name, message, payload, original_message): + return message + + +@pytest.mark.asyncio +async def test_coordinator_validates_output_schema_on_main_path(): + async def bad_handler(arguments): + return { + "success": True, + "content": "ok", + } + + coordinator = ToolCoordinator( + env_provider=empty_env, + tool_lookup=lambda name: bad_handler, + formatter=passthrough_formatter, + output_schema_provider=lambda name: { + "type": "object", + "properties": { + "success": {"type": "boolean"}, + "content": {"type": "string"}, + "value": {"type": "string"}, + }, + "required": ["success", "content", "value"], + }, + ) + + with pytest.raises(ToolOutputValidationError, match="輸出格式不符合契約"): + await coordinator.invoke( + "bad_tool", + {}, + user_id="user1", + original_message="test", + ) + + +@pytest.mark.asyncio +async def test_coordinator_accepts_valid_output_schema_on_main_path(): + async def good_handler(arguments): + return { + "success": True, + "content": "ok", + "value": "42", + } + + coordinator = ToolCoordinator( + env_provider=empty_env, + tool_lookup=lambda name: good_handler, + formatter=passthrough_formatter, + output_schema_provider=lambda name: { + "type": "object", + "properties": { + "success": {"type": "boolean"}, + "content": {"type": "string"}, + "value": {"type": "string"}, + }, + "required": ["success", "content", "value"], + }, + ) + + result = await coordinator.invoke( + "good_tool", + {}, + user_id="user1", + original_message="test", + ) + + assert result.message == "ok" + assert result.data == {"value": "42"} + + +@pytest.mark.asyncio +async def test_coordinator_does_not_retry_output_schema_violation(): + calls = 0 + + async def bad_handler(arguments): + nonlocal calls + calls += 1 + return { + "success": True, + "content": "ok", + } + + coordinator = ToolCoordinator( + env_provider=empty_env, + tool_lookup=lambda name: bad_handler, + formatter=passthrough_formatter, + output_schema_provider=lambda name: { + "type": "object", + "properties": { + "success": {"type": "boolean"}, + "content": {"type": "string"}, + "value": {"type": "string"}, + }, + "required": ["success", "content", "value"], + }, + ) + + with pytest.raises(ToolOutputValidationError): + await coordinator.invoke( + "bad_tool", + {}, + user_id="user1", + original_message="test", + ) + + assert calls == 1 + + +@pytest.mark.asyncio +async def test_coordinator_normalizes_city_from_env_fallback(): + captured_arguments = {} + + async def handler(arguments): + captured_arguments.update(arguments) + return { + "success": True, + "content": "ok", + "value": "done", + } + + async def env_provider(user_id): + return { + "detailed_address": "桃園市", + } + + coordinator = ToolCoordinator( + env_provider=env_provider, + tool_lookup=lambda name: handler, + formatter=passthrough_formatter, + output_schema_provider=lambda name: { + "type": "object", + "properties": { + "success": {"type": "boolean"}, + "content": {"type": "string"}, + "value": {"type": "string"}, + }, + "required": ["success", "content", "value"], + }, + ) + coordinator.register( + ToolMetadata( + name="tdx_youbike", + requires_env={"city"}, + env_fallbacks={"city": ["detailed_address"]}, + ) + ) + + await coordinator.invoke( + "tdx_youbike", + {}, + user_id="user1", + original_message="最近的Ubike在哪裡", + ) + + assert captured_arguments["city"] == "桃園" + + +@pytest.mark.asyncio +async def test_coordinator_normalizes_city_from_label_fallback(): + captured_arguments = {} + + async def handler(arguments): + captured_arguments.update(arguments) + return { + "success": True, + "content": "ok", + "value": "done", + } + + async def env_provider(user_id): + return { + "label": "台中市", + } + + coordinator = ToolCoordinator( + env_provider=env_provider, + tool_lookup=lambda name: handler, + formatter=passthrough_formatter, + output_schema_provider=lambda name: { + "type": "object", + "properties": { + "success": {"type": "boolean"}, + "content": {"type": "string"}, + "value": {"type": "string"}, + }, + "required": ["success", "content", "value"], + }, + ) + coordinator.register( + ToolMetadata( + name="tdx_metro", + requires_env={"city"}, + env_fallbacks={"city": ["label"]}, + ) + ) + + await coordinator.invoke( + "tdx_metro", + {}, + user_id="user1", + original_message="附近捷運在哪", + ) + + assert captured_arguments["city"] == "台中" diff --git a/tests/test_tool_registry_refactor.py b/tests/test_tool_registry_refactor.py index 4d97cb751b0e85ad2872ffaf856943406415f5d5..be1bb384a7be5c79f20fa313d6106d52ccd63f37 100644 --- a/tests/test_tool_registry_refactor.py +++ b/tests/test_tool_registry_refactor.py @@ -52,6 +52,34 @@ class TestToolSchema: params = openai_tool["function"]["parameters"] assert params["additionalProperties"] is False assert "query" in params["required"] + assert "limit" in params["required"] + + def test_strict_schema_applies_nested_object_rules(self): + """測試 strict schema 會遞迴處理 nested object""" + schema = ToolSchema( + metadata=ToolMetadata(name="nested_tool", description="Nested tool"), + input_schema={ + "type": "object", + "properties": { + "options": { + "type": "object", + "properties": { + "mode": {"type": "string"}, + "limit": {"type": "integer", "default": 5}, + }, + "required": ["mode"], + } + }, + "required": ["options"], + }, + ) + + params = schema.to_openai_tool(strict=True)["function"]["parameters"] + + assert params["additionalProperties"] is False + assert params["required"] == ["options"] + assert params["properties"]["options"]["additionalProperties"] is False + assert params["properties"]["options"]["required"] == ["mode", "limit"] def test_rich_description(self): """測試豐富描述生成""" @@ -97,6 +125,23 @@ class TestToolSchema: assert "route" in summary["params"] assert "stop" in summary["params"] + def test_schema_contract_rejects_missing_required_property(self): + """測試 schema contract 會拒絕不存在的 required 欄位""" + schema = ToolSchema( + metadata=ToolMetadata(name="broken", description="Broken"), + input_schema={ + "type": "object", + "properties": { + "query": {"type": "string"}, + }, + "required": ["missing"], + }, + ) + + issues = schema.validate_schema_contract() + + assert any("missing" in issue for issue in issues) + class TestToolSchemaRegistry: """測試 ToolSchemaRegistry 類別""" @@ -115,6 +160,21 @@ class TestToolSchemaRegistry: retrieved = registry.get("test") assert retrieved is not None assert retrieved.metadata.name == "test" + + def test_register_rejects_invalid_contract(self): + """測試註冊中心拒絕壞掉的 schema contract""" + registry = ToolSchemaRegistry() + schema = ToolSchema( + metadata=ToolMetadata(name="bad", description="Bad"), + input_schema={ + "type": "object", + "properties": {}, + "required": ["query"], + }, + ) + + with pytest.raises(ValueError): + registry.register(schema) def test_disable_enable(self): """測試停用和啟用""" diff --git a/tests/test_voice_care_gate.py b/tests/test_voice_care_gate.py new file mode 100644 index 0000000000000000000000000000000000000000..36f585e37f93c00fead73696d5cda9904949108f --- /dev/null +++ b/tests/test_voice_care_gate.py @@ -0,0 +1,160 @@ +import pytest + +from core.emotion_care_manager import EmotionCareManager +from core.pipeline import ChatPipeline + + +async def noop_feature_processor(*args, **kwargs): + return None + + +async def sad_text_intent(_message): + return False, {"emotion": "sad"} + + +async def neutral_text_intent(_message): + return False, {"emotion": "neutral"} + + +async def angry_text_intent(_message): + return False, {"emotion": "angry"} + + +async def noop_ai_generator(*args, **kwargs): + if kwargs.get("use_care_mode"): + return "care" + return "chat" + + +def build_pipeline(intent_detector): + return ChatPipeline( + intent_detector=intent_detector, + feature_processor=noop_feature_processor, + ai_generator=noop_ai_generator, + ) + + +@pytest.fixture(autouse=True) +def clear_care_state(): + EmotionCareManager._user_states.clear() + yield + EmotionCareManager._user_states.clear() + + +@pytest.mark.asyncio +async def test_voice_audio_extreme_does_not_enter_care_when_text_is_neutral(): + pipeline = build_pipeline(neutral_text_intent) + + result = await pipeline.process( + "幫我查明天天氣", + user_id="voice-user", + chat_id="chat-1", + audio_emotion={ + "success": True, + "source": "realtime_voice", + "emotion": "sad", + "confidence": 0.93, + }, + ) + + assert result.text == "chat" + assert result.meta["care_mode"] is False + assert result.meta["emotion"] == "neutral" + assert EmotionCareManager.is_in_care_mode("voice-user", "chat-1") is False + + +@pytest.mark.asyncio +async def test_voice_audio_extreme_enters_care_when_text_extreme_family_matches(): + pipeline = build_pipeline(sad_text_intent) + + result = await pipeline.process( + "我真的撐不下去了", + user_id="voice-user", + chat_id="chat-1", + audio_emotion={ + "success": True, + "source": "realtime_voice", + "emotion": "fear", + "confidence": 0.93, + }, + ) + + assert result.meta["care_mode"] is True + assert result.meta["emotion"] in {"sad", "fear"} + assert EmotionCareManager.is_in_care_mode("voice-user", "chat-1") is True + + +@pytest.mark.asyncio +async def test_voice_audio_low_confidence_does_not_override_text_emotion(): + pipeline = build_pipeline(sad_text_intent) + + result = await pipeline.process( + "我真的很難過", + user_id="voice-user", + chat_id="chat-1", + audio_emotion={ + "success": True, + "source": "realtime_voice", + "emotion": "angry", + "confidence": 0.42, + }, + ) + + assert result.meta["care_mode"] is True + assert result.meta["emotion"] == "sad" + + +@pytest.mark.asyncio +async def test_text_only_extreme_emotion_still_enters_care(): + pipeline = build_pipeline(angry_text_intent) + + result = await pipeline.process( + "我現在真的很生氣", + user_id="text-user", + chat_id="chat-1", + ) + + assert result.meta["care_mode"] is True + assert result.meta["emotion"] == "angry" + + +@pytest.mark.asyncio +async def test_voice_low_speech_confidence_blocks_care_even_when_emotions_match(): + pipeline = build_pipeline(sad_text_intent) + + result = await pipeline.process( + "我真的撐不下去了", + user_id="voice-user", + chat_id="chat-1", + audio_emotion={ + "success": True, + "source": "realtime_voice", + "emotion": "sad", + "confidence": 0.94, + "speech_confidence": 0.41, + }, + ) + + assert result.text == "chat" + assert result.meta["care_mode"] is False + assert EmotionCareManager.is_in_care_mode("voice-user", "chat-1") is False + + +@pytest.mark.asyncio +async def test_voice_context_without_usable_audio_emotion_blocks_text_only_care(): + pipeline = build_pipeline(sad_text_intent) + + result = await pipeline.process( + "我真的撐不下去了", + user_id="voice-user", + chat_id="chat-1", + audio_emotion={ + "success": False, + "source": "realtime_voice", + "error": "LOW_AUDIO_CONFIDENCE", + }, + ) + + assert result.text == "chat" + assert result.meta["care_mode"] is False + assert EmotionCareManager.is_in_care_mode("voice-user", "chat-1") is False diff --git a/tests/test_voice_login_quality.py b/tests/test_voice_login_quality.py new file mode 100644 index 0000000000000000000000000000000000000000..6c850152ab5b3dae51bd72e37039d87e87f39133 --- /dev/null +++ b/tests/test_voice_login_quality.py @@ -0,0 +1,66 @@ +import tempfile +from pathlib import Path + +import numpy as np + +from services.voice_login import VoiceAuthService, VoiceLoginConfig + + +def _make_pcm_bytes(seconds: float, sample_rate: int, noise_amp: float, tone_amp: float) -> bytes: + samples = int(seconds * sample_rate) + t = np.arange(samples, dtype=np.float32) / float(sample_rate) + tone = tone_amp * np.sin(2.0 * np.pi * 220.0 * t) + noise = noise_amp * np.random.default_rng(7).normal(0.0, 1.0, samples).astype(np.float32) + signal = np.clip(tone + noise, -1.0, 1.0) + return (signal * 32767.0).astype(np.int16).tobytes() + + +def _build_service(tmp_path: Path) -> VoiceAuthService: + service = VoiceAuthService.__new__(VoiceAuthService) + service.base_dir = tmp_path + service.identity_dir = tmp_path + service.model_dir = tmp_path + service.temp_dir = tmp_path + service.config = VoiceLoginConfig( + window_seconds=3, + required_windows=1, + sample_rate=16000, + prob_threshold=0.50, + margin_threshold=0.05, + min_snr_db=12.0, + ) + service._buffers = {} + service._sr_overrides = {} + service._emo_predict = None + service._emo_id2class = None + service._predict_files = None + return service + + +def test_low_snr_is_warning_only_not_hard_fail(): + with tempfile.TemporaryDirectory() as tmpdir: + service = _build_service(Path(tmpdir)) + + low_snr_audio = _make_pcm_bytes( + seconds=3.1, + sample_rate=service.config.sample_rate, + noise_amp=0.05, + tone_amp=0.01, + ) + service._buffers["u1"] = bytearray(low_snr_audio) + + service._predict_one_wav = lambda wav_path: { + "label": "speaker_a", + "score": 0.93, + "margin": 0.31, + } + service._infer_emotion_from_bytes = lambda pcm_bytes, sr: {"label": "neutral"} + service._preprocess_bytes = lambda pcm_bytes, sr: pcm_bytes + + result = service.stop_and_authenticate("u1") + + assert result["success"] is True + assert result["label"] == "speaker_a" + assert "quality_warnings" in result + assert result["quality_warnings"] + assert result["quality_warnings"][0]["type"] == "LOW_SNR" diff --git a/websocket/manager.py b/websocket/manager.py index 0b885236ab98e5d7b6533ef3aabaff1cfe1f7249..35f0a4abd78c8b67198147f6f02596758c241169 100644 --- a/websocket/manager.py +++ b/websocket/manager.py @@ -23,6 +23,7 @@ class ConnectionManager: self.client_info: Dict[str, dict] = {} self.user_sessions: Dict[str, Dict[str, Any]] = {} self.last_env: Dict[str, Dict[str, Any]] = {} + self.active_tasks: Dict[str, Any] = {} # 🎯 追蹤每個用戶正在運行的非同步任務 async def connect( self, @@ -103,6 +104,22 @@ class ConnectionManager: """取得客戶端資訊""" return self.client_info.get(user_id, {}) + def register_task(self, user_id: str, task: Any) -> None: + """註冊用戶的非同步任務,以便後續取消""" + self.active_tasks[user_id] = task + + async def cancel_user_tasks(self, user_id: str) -> None: + """取消用戶所有正在運行的任務(用於中斷 Barge-in)""" + task = self.active_tasks.get(user_id) + if task and not task.done(): + task.cancel() + try: + await task + except Exception: + pass # 忽略取消時的異常 + logger.info(f"🛑 已成功中斷用戶 {user_id} 的正在執行任務 (Barge-in)") + self.active_tasks.pop(user_id, None) + def get_user_session(self, user_id: str) -> Optional[Dict[str, Any]]: """取得用戶會話資訊""" return self.user_sessions.get(user_id)