|
|
import asyncio
from contextlib import asynccontextmanager
from typing import Dict, Optional, Set

from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from voice_dialogue.core.constants import websocket_message_queue, session_manager
from voice_dialogue.utils.logger import logger
|
|
|
|
|
# Router carrying this module's WebSocket route ("/api/v1/ws"); presumably
# included into the FastAPI application elsewhere (not visible here).
ws = APIRouter()
|
|
|
|
|
|
|
|
class WebSocketConnectionManager:
    """WebSocket connection manager - tracks all active connections.

    Every registered connection lives in a global set and, when a session id
    is supplied on connect, also in a per-session set. All methods that touch
    these registries serialize through ``self._lock``. ``asyncio.Lock`` is not
    reentrant, so code that already holds the lock must use
    ``_discard_locked`` instead of the public ``disconnect`` coroutine.
    """

    def __init__(self):
        # Every registered connection, regardless of session.
        self._connections: Set[WebSocket] = set()
        # session_id -> connections registered under that session.
        # Empty buckets are removed as soon as they become empty.
        self._session_connections: Dict[str, Set[WebSocket]] = {}
        self._lock = asyncio.Lock()

    def _discard_locked(self, websocket: WebSocket) -> None:
        """Remove *websocket* from every registry. Caller must hold ``self._lock``."""
        self._connections.discard(websocket)
        for session_id, connections in list(self._session_connections.items()):
            connections.discard(websocket)
            if not connections:
                del self._session_connections[session_id]
        logger.info(f"WebSocket连接已断开,当前活跃连接数: {len(self._connections)}")

    async def connect(self, websocket: WebSocket, session_id: Optional[str] = None):
        """Accept *websocket* and register it.

        Args:
            websocket: The incoming connection; its ``accept()`` is awaited here.
            session_id: When truthy, the connection is also registered under
                this session so ``send_to_session`` can reach it.
        """
        async with self._lock:
            await websocket.accept()
            self._connections.add(websocket)
            if session_id:
                self._session_connections.setdefault(session_id, set()).add(websocket)
            logger.info(f"WebSocket连接已建立,当前活跃连接数: {len(self._connections)}")

    async def disconnect(self, websocket: WebSocket):
        """Unregister *websocket*. Safe to call for unknown or already-removed sockets."""
        async with self._lock:
            self._discard_locked(websocket)

    async def close_session_connections(self, session_id: str):
        """Close and unregister every connection registered under *session_id*.

        Errors raised while closing a socket are logged and swallowed so one
        broken socket does not stop the rest from being closed.
        """
        async with self._lock:
            connections_to_close = self._session_connections.pop(session_id, None)
            if not connections_to_close:
                return
            for connection in list(connections_to_close):
                try:
                    await connection.close()
                    logger.info(f"已关闭会话 {session_id} 的一个连接")
                except Exception as e:
                    logger.warning(f"关闭连接时出错: {e}")
                # Also drop it from the global registry so it is not counted or
                # broadcast to while its endpoint task is still winding down.
                self._connections.discard(connection)

    async def send_to_session(self, session_id: str, message: dict):
        """Send *message* as JSON to every connection registered under *session_id*.

        Connections whose send raises are unregistered.
        """
        async with self._lock:
            connections = list(self._session_connections.get(session_id, ()))
            disconnected_connections = []
            for connection in connections:
                try:
                    await connection.send_json(message)
                except Exception as e:
                    logger.warning(f"发送消息失败,标记连接为断开: {e}")
                    disconnected_connections.append(connection)
            # The (non-reentrant) lock is still held: awaiting self.disconnect()
            # here would wait on it forever, so use the lock-free helper.
            for connection in disconnected_connections:
                self._discard_locked(connection)

    async def broadcast(self, message: dict):
        """Send *message* as JSON to every registered connection.

        Connections whose send raises are unregistered.
        """
        async with self._lock:
            if not self._connections:
                return
            disconnected_connections = []
            for connection in list(self._connections):
                try:
                    await connection.send_json(message)
                except Exception as e:
                    logger.warning(f"广播消息失败,标记连接为断开: {e}")
                    disconnected_connections.append(connection)
            # Same reasoning as send_to_session: never re-acquire the held lock.
            for connection in disconnected_connections:
                self._discard_locked(connection)

    @property
    def connection_count(self) -> int:
        """Number of currently registered connections."""
        return len(self._connections)

    def get_session_connection_count(self, session_id: str) -> int:
        """Number of connections registered under *session_id* (0 if none)."""
        return len(self._session_connections.get(session_id, set()))
|
|
|
|
|
|
|
|
|
|
|
# Single module-level manager shared by every WebSocket connection handled in
# this module (used by websocket_connection_context and websocket_endpoint).
connection_manager = WebSocketConnectionManager()
|
|
|
|
|
|
|
|
@asynccontextmanager
async def websocket_connection_context(websocket: WebSocket):
    """Keep *websocket* registered with the connection manager for the block's duration.

    The session id is taken from ``session_manager.current_id``. Any
    connections already registered under that session are closed first, so
    the new socket replaces them. On entry the socket is accepted and
    registered; on exit it is always unregistered, even if accepting or the
    body raised.
    """
    current_session_id = session_manager.current_id

    # Evict existing connections for this session before accepting the new one.
    existing = connection_manager.get_session_connection_count(current_session_id)
    if existing > 0:
        logger.info(f"检测到会话 {current_session_id} 已有连接,关闭旧连接")
        await connection_manager.close_session_connections(current_session_id)

    try:
        await connection_manager.connect(websocket, current_session_id)
        yield websocket
    finally:
        # disconnect() only discards entries, so it is harmless if connect() failed.
        await connection_manager.disconnect(websocket)
|
|
|
|
|
|
|
|
@ws.websocket("/api/v1/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint that pushes queued server messages to clients.

    Two waits run concurrently: ``receive_text()`` on the socket, used only to
    notice the client going away, and ``websocket_message_queue.get()``. Each
    dequeued message is delivered with ``connection_manager.send_to_session``
    to the connections registered under ``message.session_id``. If the socket
    failed in the same wake-up as a message arrived, the message is put back
    on the queue instead of being sent.
    """
    async with websocket_connection_context(websocket) as websocket_connection:
        receive_task = asyncio.create_task(websocket_connection.receive_text())
        get_message_task = None

        try:
            while True:
                # Start a new queue read only once the previous one was consumed;
                # an abandoned get() task would otherwise swallow a message.
                if get_message_task is None:
                    get_message_task = asyncio.create_task(websocket_message_queue.get())

                done, _ = await asyncio.wait(
                    {receive_task, get_message_task},
                    return_when=asyncio.FIRST_COMPLETED,
                )

                # True when the receive side raised (e.g. WebSocketDisconnect).
                connection_failed = receive_task in done and receive_task.exception() is not None

                if get_message_task in done:
                    message = get_message_task.result()
                    get_message_task = None
                    if connection_failed:
                        websocket_message_queue.put_nowait(message)
                        logger.info("连接已关闭,将消息重新入队。")
                    else:
                        await connection_manager.send_to_session(message.session_id, message.model_dump())

                if receive_task in done:
                    # Re-raises the receive error (handled below) if the client is gone.
                    receive_task.result()
                    # The client sent a text frame. Inbound messages are not processed
                    # here, so drop it and listen again; reusing the finished task
                    # would make asyncio.wait return immediately on every iteration.
                    receive_task = asyncio.create_task(websocket_connection.receive_text())

        except WebSocketDisconnect:
            logger.info("WebSocket连接已断开")
        except Exception as e:
            logger.error(f"WebSocket连接异常: {e}")
        finally:
            receive_task.cancel()
            if get_message_task is not None and not get_message_task.done():
                get_message_task.cancel()
|
|
|