liumaolin commited on
Commit
300d567
·
1 Parent(s): 351fe7b

Refactor WebSocket handling with connection manager

Browse files

- Introduce `WebSocketConnectionManager` for centralized connection tracking and session management.
- Add support for session-specific connection handling and message broadcasting.
- Implement `websocket_connection_context` to streamline connection lifecycle management.
- Update WebSocket endpoint to utilize the connection manager for improved reliability and efficiency.

src/voice_dialogue/api/routes/websocket_routes.py CHANGED
@@ -1,6 +1,10 @@
 
 
1
  from queue import Empty
 
2
 
3
  from fastapi import APIRouter, WebSocket, WebSocketDisconnect
 
4
 
5
  from voice_dialogue.core.constants import websocket_message_queue, session_manager
6
  from voice_dialogue.utils.logger import logger
@@ -8,25 +12,124 @@ from voice_dialogue.utils.logger import logger
8
  ws = APIRouter()
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  @ws.websocket("/api/v1/ws")
12
  async def websocket_endpoint(websocket: WebSocket):
13
  """WebSocket连接端点"""
14
- try:
15
- # 建立连接
16
- await websocket.accept()
17
- # 保持连接活跃
18
- while True:
19
- try:
20
- message = await websocket_message_queue.get()
21
- except Empty:
22
- continue
23
-
24
- if message.session_id != session_manager.current_id:
25
- continue
26
-
27
- await websocket.send_json(message.model_dump())
28
-
29
- except WebSocketDisconnect:
30
- pass
31
- except Exception as e:
32
- logger.error(f"WebSocket连接异常: {e}")
 
1
+ import asyncio
2
+ from contextlib import asynccontextmanager
3
  from queue import Empty
4
+ from typing import Set, Dict
5
 
6
  from fastapi import APIRouter, WebSocket, WebSocketDisconnect
7
+ from fastapi.websockets import WebSocketState
8
 
9
  from voice_dialogue.core.constants import websocket_message_queue, session_manager
10
  from voice_dialogue.utils.logger import logger
 
12
  ws = APIRouter()
13
 
14
 
15
+ class WebSocketConnectionManager:
16
+ """WebSocket 连接管理器 - 管理所有活跃连接"""
17
+
18
+ def __init__(self):
19
+ # 使用 WeakSet 避免内存泄漏
20
+ self._connections: Set[WebSocket] = set()
21
+ # 会话ID到连接的映射
22
+ self._session_connections: Dict[str, Set[WebSocket]] = {}
23
+ self._lock = asyncio.Lock()
24
+
25
+ async def connect(self, websocket: WebSocket, session_id: str = None):
26
+ """建立新连接"""
27
+ async with self._lock:
28
+ await websocket.accept()
29
+ self._connections.add(websocket)
30
+
31
+ # 如果指定了会话ID,建立映射关系
32
+ if session_id:
33
+ if session_id not in self._session_connections:
34
+ self._session_connections[session_id] = set()
35
+ self._session_connections[session_id].add(websocket)
36
+
37
+ logger.info(f"WebSocket连接已建立,当前活跃连接数: {len(self._connections)}")
38
+
39
+ async def disconnect(self, websocket: WebSocket):
40
+ """断开连接"""
41
+ async with self._lock:
42
+ self._connections.discard(websocket)
43
+
44
+ # 从会话映射中移除
45
+ for session_id, connections in list(self._session_connections.items()):
46
+ connections.discard(websocket)
47
+ if not connections: # 如果该会话没有活跃连接,清理映射
48
+ del self._session_connections[session_id]
49
+
50
+ logger.info(f"WebSocket连接已断开,当前活跃连接数: {len(self._connections)}")
51
+
52
+ async def close_session_connections(self, session_id: str):
53
+ """关闭指定会话的所有连接"""
54
+ async with self._lock:
55
+ if session_id in self._session_connections:
56
+ connections_to_close = list(self._session_connections[session_id])
57
+ for connection in connections_to_close:
58
+ try:
59
+ await connection.close()
60
+ logger.info(f"已关闭会话 {session_id} 的一个连接")
61
+ except Exception as e:
62
+ logger.warning(f"关闭连接时出错: {e}")
63
+
64
+ # 清理映射
65
+ del self._session_connections[session_id]
66
+
67
+ async def send_to_session(self, session_id: str, message: dict):
68
+ """向指定会话的所有连接发送消息"""
69
+ async with self._lock:
70
+ if session_id in self._session_connections:
71
+ connections = list(self._session_connections[session_id])
72
+ disconnected_connections = []
73
+
74
+ for connection in connections:
75
+ try:
76
+ await connection.send_json(message)
77
+ except Exception as e:
78
+ logger.warning(f"发送消息失败,标记连接为断开: {e}")
79
+ disconnected_connections.append(connection)
80
+
81
+ # 清理断开的连接
82
+ for connection in disconnected_connections:
83
+ await self.disconnect(connection)
84
+
85
+ @property
86
+ def connection_count(self) -> int:
87
+ """获取当前连接数"""
88
+ return len(self._connections)
89
+
90
+ def get_session_connection_count(self, session_id: str) -> int:
91
+ """获取指定会话的连接数"""
92
+ return len(self._session_connections.get(session_id, set()))
93
+
94
+
95
+ # 全局连接管理器实例
96
+ connection_manager = WebSocketConnectionManager()
97
+
98
+
99
+ @asynccontextmanager
100
+ async def websocket_connection_context(websocket: WebSocket):
101
+ """WebSocket连接上下文管理器"""
102
+ current_session_id = session_manager.current_id
103
+
104
+ # 关闭同一会话的旧连接
105
+ if connection_manager.get_session_connection_count(current_session_id) > 0:
106
+ logger.info(f"检测到会话 {current_session_id} 已有连接,关闭旧连接")
107
+ await connection_manager.close_session_connections(current_session_id)
108
+
109
+ try:
110
+ # 建立新连接
111
+ await connection_manager.connect(websocket, current_session_id)
112
+ yield websocket
113
+ finally:
114
+ # 确保连接被正确清理
115
+ await connection_manager.disconnect(websocket)
116
+
117
+
118
  @ws.websocket("/api/v1/ws")
119
  async def websocket_endpoint(websocket: WebSocket):
120
  """WebSocket连接端点"""
121
+ async with websocket_connection_context(websocket):
122
+ try:
123
+ # 保持连接活跃
124
+ while websocket.client_state == WebSocketState.CONNECTED:
125
+ try:
126
+ message = await websocket_message_queue.get()
127
+ except Empty:
128
+ continue
129
+
130
+ await connection_manager.send_to_session(message.session_id, message.model_dump())
131
+
132
+ except WebSocketDisconnect:
133
+ logger.info("WebSocket连接断开")
134
+ except Exception as e:
135
+ logger.error(f"WebSocket连接异常: {e}")