File size: 7,284 Bytes
300d567 2534744 511ff0c 851495c 2534744 300d567 b115e26 300d567 2534744 b115e26 300d567 b115e26 300d567 b115e26 300d567 b115e26 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 | import asyncio
from contextlib import asynccontextmanager
from typing import Set, Dict
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from voice_dialogue.core.constants import websocket_message_queue, session_manager
from voice_dialogue.utils.logger import logger
ws = APIRouter()
class WebSocketConnectionManager:
"""WebSocket 连接管理器 - 管理所有活跃连接"""
def __init__(self):
# 使用 WeakSet 避免内存泄漏
self._connections: Set[WebSocket] = set()
# 会话ID到连接的映射
self._session_connections: Dict[str, Set[WebSocket]] = {}
self._lock = asyncio.Lock()
async def connect(self, websocket: WebSocket, session_id: str = None):
"""建立新连接"""
async with self._lock:
await websocket.accept()
self._connections.add(websocket)
# 如果指定了会话ID,建立映射关系
if session_id:
if session_id not in self._session_connections:
self._session_connections[session_id] = set()
self._session_connections[session_id].add(websocket)
logger.info(f"WebSocket连接已建立,当前活跃连接数: {len(self._connections)}")
async def disconnect(self, websocket: WebSocket):
"""断开连接"""
async with self._lock:
self._connections.discard(websocket)
# 从会话映射中移除
for session_id, connections in list(self._session_connections.items()):
connections.discard(websocket)
if not connections: # 如果该会话没有活跃连接,清理映射
del self._session_connections[session_id]
logger.info(f"WebSocket连接已断开,当前活跃连接数: {len(self._connections)}")
async def close_session_connections(self, session_id: str):
"""关闭指定会话的所有连接"""
async with self._lock:
if session_id in self._session_connections:
connections_to_close = list(self._session_connections[session_id])
for connection in connections_to_close:
try:
await connection.close()
logger.info(f"已关闭会话 {session_id} 的一个连接")
except Exception as e:
logger.warning(f"关闭连接时出错: {e}")
# 清理映射
del self._session_connections[session_id]
async def send_to_session(self, session_id: str, message: dict):
"""向指定会话的所有连接发送消息"""
async with self._lock:
if session_id in self._session_connections:
connections = list(self._session_connections[session_id])
disconnected_connections = []
for connection in connections:
try:
await connection.send_json(message)
except Exception as e:
logger.warning(f"发送消息失败,标记连接为断开: {e}")
disconnected_connections.append(connection)
# 清理断开的连接
for connection in disconnected_connections:
await self.disconnect(connection)
async def broadcast(self, message: dict):
"""向所有活跃连接广播消息"""
async with self._lock:
if not self._connections:
return
connections = list(self._connections)
disconnected_connections = []
for connection in connections:
try:
await connection.send_json(message)
except Exception as e:
logger.warning(f"广播消息失败,标记连接为断开: {e}")
disconnected_connections.append(connection)
# 清理断开的连接
for connection in disconnected_connections:
await self.disconnect(connection)
@property
def connection_count(self) -> int:
"""获取当前连接数"""
return len(self._connections)
def get_session_connection_count(self, session_id: str) -> int:
"""获取指定会话的连接数"""
return len(self._session_connections.get(session_id, set()))
# 全局连接管理器实例
connection_manager = WebSocketConnectionManager()
@asynccontextmanager
async def websocket_connection_context(websocket: WebSocket):
"""WebSocket连接上下文管理器"""
current_session_id = session_manager.current_id
# 关闭同一会话的旧连接
if connection_manager.get_session_connection_count(current_session_id) > 0:
logger.info(f"检测到会话 {current_session_id} 已有连接,关闭旧连接")
await connection_manager.close_session_connections(current_session_id)
try:
# 建立新连接
await connection_manager.connect(websocket, current_session_id)
yield websocket
finally:
# 确保连接被正确清理
await connection_manager.disconnect(websocket)
@ws.websocket("/api/v1/ws")
async def websocket_endpoint(websocket: WebSocket):
"""WebSocket连接端点"""
async with websocket_connection_context(websocket) as websocket_connection:
disconnect_task = asyncio.create_task(websocket_connection.receive_text())
get_message_task = None
try:
# 保持连接活跃
while True:
get_message_task = asyncio.create_task(websocket_message_queue.get())
done, pending = await asyncio.wait(
{disconnect_task, get_message_task},
return_when=asyncio.FIRST_COMPLETED
)
# 如果是 disconnect_task 完成,说明客户端已断开连接。
if disconnect_task in done:
# 这将重新引发 WebSocketDisconnect 异常并跳出循环。
if get_message_task in done:
# 如果有,将消息放回队列,然后让连接断开
message = get_message_task.result()
websocket_message_queue.put_nowait(message)
logger.info("连接已关闭,将消息重新入队。")
disconnect_task.result()
# 如果是 queue_task 完成,说明有可用的消息。
if get_message_task in done:
message = get_message_task.result()
await connection_manager.send_to_session(message.session_id, message.model_dump())
# queue_task 现已完成,我们将在循环的下一次迭代中创建一个新的。
except WebSocketDisconnect:
logger.info("WebSocket连接已断开")
except Exception as e:
logger.error(f"WebSocket连接异常: {e}")
finally:
# 确保如果循环因其他原因退出,disconnect_task 会被取消。
disconnect_task.cancel()
if get_message_task and not get_message_task.done():
get_message_task.cancel()
|