Spaces:
Running
Running
| """ | |
| WebSocket 連線管理器 | |
| 統一管理 WebSocket 連線、會話狀態和訊息發送 | |
| """ | |
| import logging | |
| import time | |
| from datetime import datetime | |
| from typing import Dict, Any, Optional | |
| from fastapi import WebSocket | |
| from core.logging import get_logger | |
# Logger for this module, obtained from the project's core.logging.get_logger
# under the name "websocket.manager"; used by ConnectionManager below.
logger = get_logger("websocket.manager")
class ConnectionManager:
    """Track WebSocket connections and per-user session state, and send messages.

    All state lives in plain dicts keyed by ``user_id``:

    * ``active_connections`` - the accepted ``WebSocket`` for each user.
    * ``user_sessions`` - the session dict registered in :meth:`connect`;
      ``last_activity`` (a naive local ``datetime``) is maintained here and
      drives :meth:`cleanup_expired_sessions`.
    * ``client_info`` - arbitrary client metadata set via :meth:`set_client_info`.
    * ``last_env`` - not read or written by this class itself.
    """

    def __init__(self):
        self.active_connections: Dict[str, WebSocket] = {}
        self.client_info: Dict[str, dict] = {}
        self.user_sessions: Dict[str, Dict[str, Any]] = {}
        # NOTE(review): never touched by this class (not even in disconnect);
        # presumably keyed by user_id and managed by callers - confirm whether
        # it should also be cleared when a user disconnects.
        self.last_env: Dict[str, Dict[str, Any]] = {}

    async def connect(
        self,
        websocket: WebSocket,
        user_id: str,
        user_info: Dict[str, Any]
    ) -> None:
        """Accept ``websocket`` and register it, with its session, under ``user_id``.

        Args:
            websocket: The incoming connection; it is accepted here.
            user_id: Key under which the connection and session are stored.
                An existing connection for the same id is replaced (the old
                socket is not closed by this method).
            user_info: Initial session data. A shallow copy is stored so the
                caller's dict is never mutated by this manager.
        """
        await websocket.accept()
        session = dict(user_info or {})
        # Seed the activity timestamp: without it cleanup_expired_sessions
        # treats a session that never called update_last_activity as active
        # "now" on every sweep, so idle connections would never expire.
        session.setdefault("last_activity", datetime.now())
        self.active_connections[user_id] = websocket
        self.user_sessions[user_id] = session
        logger.info(f"新的 WebSocket 連接: {user_id}")

    def disconnect(self, user_id: str) -> None:
        """Forget the connection, session and client info held for ``user_id``.

        Safe to call for unknown ids. The socket itself is not closed here.
        """
        self.active_connections.pop(user_id, None)
        self.user_sessions.pop(user_id, None)
        self.client_info.pop(user_id, None)
        logger.info(f"WebSocket 連接關閉: {user_id}")

    async def send_message(
        self,
        message: str,
        user_id: str,
        message_type: str = "bot_message"
    ) -> bool:
        """Send ``message`` to ``user_id`` wrapped in a typed, timestamped payload.

        The JSON sent is ``{"type": message_type, "message": message,
        "timestamp": time.time()}``.

        Returns:
            True if the payload was sent; False if the user is not connected
            or sending raised (the error is logged, not propagated).
        """
        websocket = self.active_connections.get(user_id)
        if websocket is None:
            logger.warning(f"用戶 {user_id} 不在線,無法發送訊息")
            return False
        payload = {
            "type": message_type,
            "message": message,
            "timestamp": time.time()
        }
        try:
            await websocket.send_json(payload)
        except Exception as e:
            logger.error(f"發送訊息到客戶端 {user_id} 時出錯: {e}")
            return False
        # Single-line preview, truncated to 120 chars, for the debug log.
        preview = str(message).strip().replace("\n", " ")
        if len(preview) > 120:
            preview = preview[:120] + "..."
        logger.debug(
            f"WebSocket 已發送 → client={user_id} "
            f"type={message_type} preview=\"{preview}\""
        )
        return True

    async def send_json(
        self,
        data: Dict[str, Any],
        user_id: str
    ) -> bool:
        """Send ``data`` as-is to ``user_id``.

        Returns:
            True on success; False if the user is not connected or sending
            raised (the error is logged, not propagated).
        """
        websocket = self.active_connections.get(user_id)
        if websocket is None:
            return False
        try:
            await websocket.send_json(data)
        except Exception as e:
            logger.error(f"發送 JSON 到客戶端 {user_id} 時出錯: {e}")
            return False
        return True

    def set_client_info(self, user_id: str, info: dict) -> None:
        """Store (replace) the client metadata for ``user_id``."""
        self.client_info[user_id] = info

    def get_client_info(self, user_id: str) -> dict:
        """Return the client metadata for ``user_id``, or a new empty dict."""
        return self.client_info.get(user_id, {})

    def get_user_session(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored session dict for ``user_id``, or None if absent."""
        return self.user_sessions.get(user_id)

    def update_last_activity(self, user_id: str) -> None:
        """Stamp ``last_activity`` with the current local time (no-op if unknown)."""
        if user_id in self.user_sessions:
            self.user_sessions[user_id]["last_activity"] = datetime.now()

    def is_connected(self, user_id: str) -> bool:
        """Return True if a connection is registered for ``user_id``."""
        return user_id in self.active_connections

    def get_active_user_count(self) -> int:
        """Return the number of registered connections."""
        return len(self.active_connections)

    async def cleanup_expired_sessions(self, timeout_seconds: Optional[int] = None) -> int:
        """Disconnect and close sessions idle for longer than ``timeout_seconds``.

        Args:
            timeout_seconds: Idle timeout in seconds; when None, the value of
                ``core.config.settings.WEBSOCKET_SESSION_TIMEOUT`` is used.

        Returns:
            The number of sessions removed.
        """
        if timeout_seconds is None:
            # Imported lazily, only when the configured default is needed.
            from core.config import settings
            timeout_seconds = settings.WEBSOCKET_SESSION_TIMEOUT

        now = datetime.now()
        expired_users = [
            user_id
            for user_id, session_info in self.user_sessions.items()
            if (now - session_info.get("last_activity", now)).total_seconds() > timeout_seconds
        ]

        # Phase 1: drop all bookkeeping without awaiting, so no other coroutine
        # can interleave (e.g. a reconnect) while state is being removed.
        stale_sockets = []
        for user_id in expired_users:
            logger.info(f"清理過期會話: {user_id}")
            websocket = self.active_connections.get(user_id)
            self.disconnect(user_id)
            if websocket is not None:
                stale_sockets.append((user_id, websocket))

        # Phase 2: actually close the sockets; otherwise the client keeps an
        # open connection the server no longer tracks or can send to.
        for user_id, websocket in stale_sockets:
            try:
                await websocket.close()
            except Exception as e:
                logger.warning(f"關閉過期連線 {user_id} 時出錯: {e}")

        return len(expired_users)

    async def broadcast(
        self,
        message: str,
        message_type: str = "system",
        exclude_users: Optional[list] = None
    ) -> int:
        """Send ``message`` to every connected user not in ``exclude_users``.

        Args:
            message: Message content.
            message_type: Value placed in the payload's ``type`` field.
            exclude_users: User ids to skip.

        Returns:
            How many sends succeeded.
        """
        excluded = set(exclude_users or ())
        success_count = 0
        # Snapshot the ids: each send awaits, and connections may be added or
        # removed meanwhile.
        for user_id in list(self.active_connections):
            if user_id in excluded:
                continue
            if await self.send_message(message, user_id, message_type):
                success_count += 1
        return success_count
# Global singleton: a single ConnectionManager created at import time.
# NOTE(review): presumably imported by the WebSocket endpoints so they all
# share one connection/session registry - confirm against callers.
manager = ConnectionManager()