File size: 5,610 Bytes
69fb140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""
WebSocket 連線管理器
統一管理 WebSocket 連線、會話狀態和訊息發送
"""

import logging
import time
from datetime import datetime
from typing import Dict, Any, Optional

from fastapi import WebSocket

from core.logging import get_logger

# Module-level logger for this manager, obtained from the project's
# core.logging helper under the "websocket.manager" name.
logger = get_logger("websocket.manager")


class ConnectionManager:
    """Track live WebSocket connections and per-user session state.

    All state is keyed by ``user_id``. At most one connection is tracked per
    user; a second :meth:`connect` for the same id replaces the first.
    """

    def __init__(self):
        # user_id -> the currently registered WebSocket for that user.
        self.active_connections: Dict[str, WebSocket] = {}
        # user_id -> free-form client metadata set via set_client_info().
        self.client_info: Dict[str, dict] = {}
        # user_id -> session dict passed to connect(); update_last_activity()
        # writes a "last_activity" datetime into it.
        self.user_sessions: Dict[str, Dict[str, Any]] = {}
        # Not read or written by any method of this class; presumably
        # maintained by callers — TODO confirm.
        self.last_env: Dict[str, Dict[str, Any]] = {}

    async def connect(
        self,
        websocket: WebSocket,
        user_id: str,
        user_info: Dict[str, Any]
    ) -> None:
        """Accept *websocket* and register it as *user_id*'s connection.

        Args:
            websocket: The incoming WebSocket; ``accept()`` is awaited here.
            user_id: Key under which the connection and session are stored.
            user_info: Session dict, stored as-is (not copied);
                :meth:`update_last_activity` later mutates it in place.
        """
        await websocket.accept()

        previous = self.active_connections.get(user_id)
        if previous is not None and previous is not websocket:
            # The older socket is no longer reachable through this manager.
            # Its handler should call disconnect(user_id, websocket=...) so
            # its teardown does not evict this newer connection.
            logger.warning(f"用戶 {user_id} 已有連線,以新的 WebSocket 取代")

        self.active_connections[user_id] = websocket
        self.user_sessions[user_id] = user_info
        logger.info(f"新的 WebSocket 連接: {user_id}")

    def disconnect(
        self,
        user_id: str,
        websocket: Optional[WebSocket] = None
    ) -> None:
        """Forget the connection, session and client info held for *user_id*.

        This does not close the socket itself. Calling it for an unknown
        user is a no-op apart from the log line.

        Args:
            user_id: The user whose state should be dropped.
            websocket: Optional. When given, state is only dropped if this
                socket is still the one registered for *user_id*. A handler
                for a connection that was later replaced can pass its own
                socket so that its teardown leaves the newer connection
                intact. Defaults to ``None``, which removes state
                unconditionally (the original behaviour).
        """
        if websocket is not None and self.active_connections.get(user_id) is not websocket:
            return

        self.active_connections.pop(user_id, None)
        self.user_sessions.pop(user_id, None)
        self.client_info.pop(user_id, None)
        logger.info(f"WebSocket 連接關閉: {user_id}")

    async def send_message(
        self,
        message: str,
        user_id: str,
        message_type: str = "bot_message"
    ) -> bool:
        """Send ``{"type", "message", "timestamp"}`` JSON to *user_id*.

        Args:
            message: Message body, placed under the ``"message"`` key.
            user_id: Recipient.
            message_type: Value for the ``"type"`` key.

        Returns:
            ``True`` if the send succeeded. ``False`` if the user is not
            connected or the send raised; the error is logged and the
            connection is left registered.
        """
        if user_id not in self.active_connections:
            logger.warning(f"用戶 {user_id} 不在線,無法發送訊息")
            return False

        try:
            payload = {
                "type": message_type,
                "message": message,
                "timestamp": time.time()
            }
            await self.active_connections[user_id].send_json(payload)

            # Log a single-line preview, truncated to 120 characters.
            preview = str(message).strip().replace("\n", " ")
            if len(preview) > 120:
                preview = preview[:120] + "..."
            logger.debug(
                f"WebSocket 已發送 → client={user_id} "
                f"type={message_type} preview=\"{preview}\""
            )
            return True

        except Exception as e:
            logger.error(f"發送訊息到客戶端 {user_id} 時出錯: {e}")
            return False

    async def send_json(
        self,
        data: Dict[str, Any],
        user_id: str
    ) -> bool:
        """Send *data* as JSON to *user_id*.

        Returns:
            ``True`` on success. ``False`` if the user is not connected
            (silently) or the send raised (logged).
        """
        if user_id not in self.active_connections:
            return False

        try:
            await self.active_connections[user_id].send_json(data)
            return True
        except Exception as e:
            logger.error(f"發送 JSON 到客戶端 {user_id} 時出錯: {e}")
            return False

    def set_client_info(self, user_id: str, info: dict) -> None:
        """Store (replace) the client metadata for *user_id*."""
        self.client_info[user_id] = info

    def get_client_info(self, user_id: str) -> dict:
        """Return the client metadata for *user_id*, or a fresh empty dict."""
        return self.client_info.get(user_id, {})

    def get_user_session(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Return the session dict for *user_id*, or ``None`` if absent."""
        return self.user_sessions.get(user_id)

    def update_last_activity(self, user_id: str) -> None:
        """Stamp ``last_activity = datetime.now()`` on *user_id*'s session.

        Does nothing if the user has no session.
        """
        if user_id in self.user_sessions:
            self.user_sessions[user_id]["last_activity"] = datetime.now()

    def is_connected(self, user_id: str) -> bool:
        """Return whether *user_id* has a registered connection."""
        return user_id in self.active_connections

    def get_active_user_count(self) -> int:
        """Return the number of registered connections."""
        return len(self.active_connections)

    async def cleanup_expired_sessions(self, timeout_seconds: Optional[int] = None) -> int:
        """Drop and close sessions idle for longer than *timeout_seconds*.

        A session's idle time is measured from its ``"last_activity"``
        value. Sessions without one are treated as just active and are
        never expired here.

        Args:
            timeout_seconds: Idle threshold in seconds. Defaults to
                ``settings.WEBSOCKET_SESSION_TIMEOUT`` from ``core.config``.

        Returns:
            The number of sessions removed.
        """
        if timeout_seconds is None:
            from core.config import settings
            timeout_seconds = settings.WEBSOCKET_SESSION_TIMEOUT
        current_time = datetime.now()

        expired_users = []
        for user_id, session_info in self.user_sessions.items():
            last_activity = session_info.get("last_activity", current_time)
            if (current_time - last_activity).total_seconds() > timeout_seconds:
                expired_users.append(user_id)

        # Detach every expired user first, with no await in between, so that
        # a reconnect cannot slip in and then be evicted by this sweep.
        stale_sockets = []
        for user_id in expired_users:
            logger.info(f"清理過期會話: {user_id}")
            websocket = self.active_connections.get(user_id)
            if websocket is not None:
                stale_sockets.append((user_id, websocket))
            self.disconnect(user_id)

        # Close the detached sockets. Otherwise they would stay open while no
        # longer tracked (and unreachable for sends). Closing is best-effort:
        # the peer may already be gone.
        for user_id, websocket in stale_sockets:
            try:
                await websocket.close()
            except Exception as e:
                logger.debug(f"關閉過期連線 {user_id} 時出錯: {e}")

        return len(expired_users)

    async def broadcast(
        self,
        message: str,
        message_type: str = "system",
        exclude_users: Optional[list] = None
    ) -> int:
        """Send *message* to every connected user except *exclude_users*.

        Sends are awaited one after another via :meth:`send_message`.

        Args:
            message: Message body.
            message_type: Value for the payload's ``"type"`` key.
            exclude_users: User ids to skip.

        Returns:
            The number of users the message was successfully sent to.
        """
        # Build the exclusion set once; a list would cost O(n) per lookup.
        excluded = set(exclude_users or ())
        success_count = 0

        # Snapshot the keys: the dict may change while we await sends.
        for user_id in list(self.active_connections):
            if user_id in excluded:
                continue
            if await self.send_message(message, user_id, message_type):
                success_count += 1

        return success_count


# Global singleton: a single shared ConnectionManager instance created at
# import time.
manager = ConnectionManager()