File size: 4,543 Bytes
69fb140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""
WebSocket 心跳機制
保持連線穩定,自動檢測斷線
"""

import asyncio
import time
from typing import Dict, Optional, Callable, Awaitable

from core.logging import get_logger
from core.config import settings

logger = get_logger("websocket.heartbeat")

# 心跳間隔(秒)
HEARTBEAT_INTERVAL = 30

# 心跳超時(秒)
HEARTBEAT_TIMEOUT = 10

# 最大重連次數
MAX_RECONNECT_ATTEMPTS = 5


class HeartbeatManager:
    """
    WebSocket 心跳管理器

    功能:
    1. 定期發送心跳包
    2. 檢測連線狀態
    3. 觸發斷線回調
    """

    def __init__(self):
        # 用戶最後心跳時間
        self._last_heartbeat: Dict[str, float] = {}
        # 心跳任務
        self._heartbeat_tasks: Dict[str, asyncio.Task] = {}
        # 斷線回調
        self._disconnect_callback: Optional[Callable[[str], Awaitable[None]]] = None

    def set_disconnect_callback(
        self, callback: Callable[[str], Awaitable[None]]
    ) -> None:
        """設定斷線回調函數"""
        self._disconnect_callback = callback

    def record_heartbeat(self, user_id: str) -> None:
        """記錄用戶心跳"""
        self._last_heartbeat[user_id] = time.time()
        logger.debug(f"💓 收到心跳: {user_id}")

    def get_last_heartbeat(self, user_id: str) -> Optional[float]:
        """取得用戶最後心跳時間"""
        return self._last_heartbeat.get(user_id)

    def is_alive(self, user_id: str, timeout: float = HEARTBEAT_TIMEOUT * 3) -> bool:
        """檢查用戶連線是否存活"""
        last = self._last_heartbeat.get(user_id)
        if last is None:
            return False
        return (time.time() - last) < timeout

    async def start_heartbeat(
        self,
        user_id: str,
        send_func: Callable[[dict], Awaitable[bool]],
    ) -> None:
        """
        啟動心跳任務

        Args:
            user_id: 用戶 ID
            send_func: 發送訊息的函數
        """
        # 取消舊任務
        if user_id in self._heartbeat_tasks:
            self._heartbeat_tasks[user_id].cancel()

        # 記錄初始心跳
        self.record_heartbeat(user_id)

        # 建立新任務
        task = asyncio.create_task(
            self._heartbeat_loop(user_id, send_func)
        )
        self._heartbeat_tasks[user_id] = task
        logger.info(f"💓 啟動心跳任務: {user_id}")

    async def stop_heartbeat(self, user_id: str) -> None:
        """停止心跳任務"""
        if user_id in self._heartbeat_tasks:
            self._heartbeat_tasks[user_id].cancel()
            del self._heartbeat_tasks[user_id]
            logger.info(f"💔 停止心跳任務: {user_id}")

        if user_id in self._last_heartbeat:
            del self._last_heartbeat[user_id]

    async def _heartbeat_loop(
        self,
        user_id: str,
        send_func: Callable[[dict], Awaitable[bool]],
    ) -> None:
        """心跳循環"""
        missed_count = 0

        while True:
            try:
                await asyncio.sleep(HEARTBEAT_INTERVAL)

                # 發送心跳
                success = await send_func({
                    "type": "heartbeat",
                    "timestamp": time.time(),
                })

                if success:
                    missed_count = 0
                else:
                    missed_count += 1
                    logger.warning(f"⚠️ 心跳發送失敗: {user_id} (連續 {missed_count} 次)")

                # 檢查是否超時
                if not self.is_alive(user_id):
                    missed_count += 1
                    logger.warning(f"⚠️ 心跳超時: {user_id}")

                # 連續失敗超過閾值,觸發斷線
                if missed_count >= 3:
                    logger.error(f"❌ 連線失效: {user_id}")
                    if self._disconnect_callback:
                        await self._disconnect_callback(user_id)
                    break

            except asyncio.CancelledError:
                logger.debug(f"心跳任務被取消: {user_id}")
                break
            except Exception as e:
                logger.error(f"心跳任務錯誤: {user_id} - {e}")
                await asyncio.sleep(5)

    def get_stats(self) -> Dict[str, int]:
        """取得心跳統計"""
        return {
            "active_heartbeats": len(self._heartbeat_tasks),
            "tracked_users": len(self._last_heartbeat),
        }


# 全域單例
heartbeat_manager = HeartbeatManager()