liumaolin
commited on
Commit
·
b115e26
1
Parent(s):
29766c6
Enhance WebSocket handling for connection management and reliability
Browse files- Add message broadcasting support to `WebSocketConnectionManager`.
- Improve connection cleanup for disconnected clients.
- Refine WebSocket endpoint to handle client disconnects and queue tasks more gracefully.
src/voice_dialogue/api/routes/websocket_routes.py
CHANGED
|
@@ -1,10 +1,8 @@
|
|
| 1 |
import asyncio
|
| 2 |
from contextlib import asynccontextmanager
|
| 3 |
-
from queue import Empty
|
| 4 |
from typing import Set, Dict
|
| 5 |
|
| 6 |
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
| 7 |
-
from fastapi.websockets import WebSocketState
|
| 8 |
|
| 9 |
from voice_dialogue.core.constants import websocket_message_queue, session_manager
|
| 10 |
from voice_dialogue.utils.logger import logger
|
|
@@ -82,6 +80,26 @@ class WebSocketConnectionManager:
|
|
| 82 |
for connection in disconnected_connections:
|
| 83 |
await self.disconnect(connection)
|
| 84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
@property
|
| 86 |
def connection_count(self) -> int:
|
| 87 |
"""获取当前连接数"""
|
|
@@ -118,18 +136,43 @@ async def websocket_connection_context(websocket: WebSocket):
|
|
| 118 |
@ws.websocket("/api/v1/ws")
|
| 119 |
async def websocket_endpoint(websocket: WebSocket):
|
| 120 |
"""WebSocket连接端点"""
|
| 121 |
-
async with websocket_connection_context(websocket):
|
|
|
|
|
|
|
|
|
|
| 122 |
try:
|
| 123 |
# 保持连接活跃
|
| 124 |
-
while
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
except WebSocketDisconnect:
|
| 133 |
-
logger.info("WebSocket
|
| 134 |
except Exception as e:
|
| 135 |
logger.error(f"WebSocket连接异常: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import asyncio
|
| 2 |
from contextlib import asynccontextmanager
|
|
|
|
| 3 |
from typing import Set, Dict
|
| 4 |
|
| 5 |
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|
|
|
| 6 |
|
| 7 |
from voice_dialogue.core.constants import websocket_message_queue, session_manager
|
| 8 |
from voice_dialogue.utils.logger import logger
|
|
|
|
| 80 |
for connection in disconnected_connections:
|
| 81 |
await self.disconnect(connection)
|
| 82 |
|
| 83 |
+
async def broadcast(self, message: dict):
|
| 84 |
+
"""向所有活跃连接广播消息"""
|
| 85 |
+
async with self._lock:
|
| 86 |
+
if not self._connections:
|
| 87 |
+
return
|
| 88 |
+
|
| 89 |
+
connections = list(self._connections)
|
| 90 |
+
disconnected_connections = []
|
| 91 |
+
|
| 92 |
+
for connection in connections:
|
| 93 |
+
try:
|
| 94 |
+
await connection.send_json(message)
|
| 95 |
+
except Exception as e:
|
| 96 |
+
logger.warning(f"广播消息失败,标记连接为断开: {e}")
|
| 97 |
+
disconnected_connections.append(connection)
|
| 98 |
+
|
| 99 |
+
# 清理断开的连接
|
| 100 |
+
for connection in disconnected_connections:
|
| 101 |
+
await self.disconnect(connection)
|
| 102 |
+
|
| 103 |
@property
|
| 104 |
def connection_count(self) -> int:
|
| 105 |
"""获取当前连接数"""
|
|
|
|
| 136 |
@ws.websocket("/api/v1/ws")
|
| 137 |
async def websocket_endpoint(websocket: WebSocket):
|
| 138 |
"""WebSocket连接端点"""
|
| 139 |
+
async with websocket_connection_context(websocket) as websocket_connection:
|
| 140 |
+
disconnect_task = asyncio.create_task(websocket_connection.receive_text())
|
| 141 |
+
get_message_task = None
|
| 142 |
+
|
| 143 |
try:
|
| 144 |
# 保持连接活跃
|
| 145 |
+
while True:
|
| 146 |
+
get_message_task = asyncio.create_task(websocket_message_queue.get())
|
| 147 |
+
|
| 148 |
+
done, pending = await asyncio.wait(
|
| 149 |
+
{disconnect_task, get_message_task},
|
| 150 |
+
return_when=asyncio.FIRST_COMPLETED
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# 如果是 disconnect_task 完成,说明客户端已断开连接。
|
| 154 |
+
if disconnect_task in done:
|
| 155 |
+
# 这将重新引发 WebSocketDisconnect 异常并跳出循环。
|
| 156 |
+
if get_message_task in done:
|
| 157 |
+
# 如果有,将消息放回队列,然后让连接断开
|
| 158 |
+
message = get_message_task.result()
|
| 159 |
+
websocket_message_queue.put_nowait(message)
|
| 160 |
+
logger.info("连接已关闭,将消息重新入队。")
|
| 161 |
+
|
| 162 |
+
disconnect_task.result()
|
| 163 |
+
|
| 164 |
+
# 如果是 queue_task 完成,说明有可用的消息。
|
| 165 |
+
if get_message_task in done:
|
| 166 |
+
message = get_message_task.result()
|
| 167 |
+
await connection_manager.send_to_session(message.session_id, message.model_dump())
|
| 168 |
+
# queue_task 现已完成,我们将在循环的下一次迭代中创建一个新的。
|
| 169 |
|
| 170 |
except WebSocketDisconnect:
|
| 171 |
+
logger.info("WebSocket连接已断开")
|
| 172 |
except Exception as e:
|
| 173 |
logger.error(f"WebSocket连接异常: {e}")
|
| 174 |
+
finally:
|
| 175 |
+
# 确保如果循环因其他原因退出,disconnect_task 会被取消。
|
| 176 |
+
disconnect_task.cancel()
|
| 177 |
+
if get_message_task and not get_message_task.done():
|
| 178 |
+
get_message_task.cancel()
|