liumaolin commited on
Commit
b115e26
·
1 Parent(s): 29766c6

Enhance WebSocket handling for connection management and reliability

Browse files

- Add message broadcasting support to `WebSocketConnectionManager`.
- Improve connection cleanup for disconnected clients.
- Refine WebSocket endpoint to handle client disconnects and queue tasks more gracefully.

src/voice_dialogue/api/routes/websocket_routes.py CHANGED
@@ -1,10 +1,8 @@
1
  import asyncio
2
  from contextlib import asynccontextmanager
3
- from queue import Empty
4
  from typing import Set, Dict
5
 
6
  from fastapi import APIRouter, WebSocket, WebSocketDisconnect
7
- from fastapi.websockets import WebSocketState
8
 
9
  from voice_dialogue.core.constants import websocket_message_queue, session_manager
10
  from voice_dialogue.utils.logger import logger
@@ -82,6 +80,26 @@ class WebSocketConnectionManager:
82
  for connection in disconnected_connections:
83
  await self.disconnect(connection)
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  @property
86
  def connection_count(self) -> int:
87
  """获取当前连接数"""
@@ -118,18 +136,43 @@ async def websocket_connection_context(websocket: WebSocket):
118
  @ws.websocket("/api/v1/ws")
119
  async def websocket_endpoint(websocket: WebSocket):
120
  """WebSocket连接端点"""
121
- async with websocket_connection_context(websocket):
 
 
 
122
  try:
123
  # 保持连接活跃
124
- while websocket.client_state == WebSocketState.CONNECTED:
125
- try:
126
- message = await websocket_message_queue.get()
127
- except Empty:
128
- continue
129
-
130
- await connection_manager.send_to_session(message.session_id, message.model_dump())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  except WebSocketDisconnect:
133
- logger.info("WebSocket连接断开")
134
  except Exception as e:
135
  logger.error(f"WebSocket连接异常: {e}")
 
 
 
 
 
 
1
  import asyncio
2
  from contextlib import asynccontextmanager
 
3
  from typing import Set, Dict
4
 
5
  from fastapi import APIRouter, WebSocket, WebSocketDisconnect
 
6
 
7
  from voice_dialogue.core.constants import websocket_message_queue, session_manager
8
  from voice_dialogue.utils.logger import logger
 
80
  for connection in disconnected_connections:
81
  await self.disconnect(connection)
82
 
83
+ async def broadcast(self, message: dict):
84
+ """向所有活跃连接广播消息"""
85
+ async with self._lock:
86
+ if not self._connections:
87
+ return
88
+
89
+ connections = list(self._connections)
90
+ disconnected_connections = []
91
+
92
+ for connection in connections:
93
+ try:
94
+ await connection.send_json(message)
95
+ except Exception as e:
96
+ logger.warning(f"广播消息失败,标记连接为断开: {e}")
97
+ disconnected_connections.append(connection)
98
+
99
+ # 清理断开的连接
100
+ for connection in disconnected_connections:
101
+ await self.disconnect(connection)
102
+
103
  @property
104
  def connection_count(self) -> int:
105
  """获取当前连接数"""
 
136
  @ws.websocket("/api/v1/ws")
137
  async def websocket_endpoint(websocket: WebSocket):
138
  """WebSocket连接端点"""
139
+ async with websocket_connection_context(websocket) as websocket_connection:
140
+ disconnect_task = asyncio.create_task(websocket_connection.receive_text())
141
+ get_message_task = None
142
+
143
  try:
144
  # 保持连接活跃
145
+ while True:
146
+ get_message_task = asyncio.create_task(websocket_message_queue.get())
147
+
148
+ done, pending = await asyncio.wait(
149
+ {disconnect_task, get_message_task},
150
+ return_when=asyncio.FIRST_COMPLETED
151
+ )
152
+
153
+ # 如果是 disconnect_task 完成,说明客户端已断开连接。
154
+ if disconnect_task in done:
155
+ # 这将重新引发 WebSocketDisconnect 异常并跳出循环。
156
+ if get_message_task in done:
157
+ # 如果有,将消息放回队列,然后让连接断开
158
+ message = get_message_task.result()
159
+ websocket_message_queue.put_nowait(message)
160
+ logger.info("连接已关闭,将消息重新入队。")
161
+
162
+ disconnect_task.result()
163
+
164
+ # 如果是 queue_task 完成,说明有可用的消息。
165
+ if get_message_task in done:
166
+ message = get_message_task.result()
167
+ await connection_manager.send_to_session(message.session_id, message.model_dump())
168
+ # queue_task 现已完成,我们将在循环的下一次迭代中创建一个新的。
169
 
170
  except WebSocketDisconnect:
171
+ logger.info("WebSocket连接已断开")
172
  except Exception as e:
173
  logger.error(f"WebSocket连接异常: {e}")
174
+ finally:
175
+ # 确保如果循环因其他原因退出,disconnect_task 会被取消。
176
+ disconnect_task.cancel()
177
+ if get_message_task and not get_message_task.done():
178
+ get_message_task.cancel()