|
|
""" |
|
|
WebSocket 路由 |
|
|
提供实时通信功能 |
|
|
""" |
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect |
|
|
from typing import Set |
|
|
|
|
|
|
|
|
# APIRouter on which the WebSocket endpoint in this module is registered.
# NOTE(review): presumably included into the main FastAPI app elsewhere — confirm.
router = APIRouter()


# Registry of currently accepted WebSocket connections. websocket_endpoint
# adds and removes entries; broadcast_message sends to every entry and drops
# the ones whose send fails. Process-local module state.
active_connections: Set[WebSocket] = set()
|
|
|
|
|
|
|
|
@router.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
):
    """WebSocket endpoint served at ``/ws``.

    Accepts the connection and registers it in ``active_connections`` so that
    ``broadcast_message`` can push to it. It then keeps reading incoming
    text frames (their content is currently ignored) until the client
    disconnects or an error occurs.

    The connection is always unregistered on exit, in ``finally`` with
    ``set.discard``. Using ``discard`` matters because ``broadcast_message``
    may already have dropped this socket after a failed send, so
    ``set.remove`` could raise ``KeyError``. Using ``finally`` matters because
    task cancellation (a ``BaseException``) must not leave a stale entry
    behind.

    Args:
        websocket: The incoming WebSocket connection.
    """
    await websocket.accept()
    active_connections.add(websocket)

    try:
        while True:
            # Frames are read only to keep the connection serviced and to
            # detect disconnects; the payload is not used yet.
            await websocket.receive_text()
    except WebSocketDisconnect:
        # Normal client disconnect; cleanup happens in ``finally``.
        pass
    except Exception as e:
        print(f"WebSocket 错误: {e}")
    finally:
        active_connections.discard(websocket)
|
|
|
|
|
|
|
|
async def broadcast_message(message_type: str, data: dict):
    """Send a JSON message to every currently connected client.

    The payload sent is ``{"type": message_type, "data": data}``. Delivery
    is best-effort: any client whose ``send_json`` raises is removed from
    ``active_connections``, and the broadcast continues with the rest.

    Args:
        message_type: Value placed under the ``"type"`` key.
        data: Value placed under the ``"data"`` key.
    """
    message = {
        "type": message_type,
        "data": data,
    }

    # Iterate over a snapshot. Each ``await`` yields to the event loop, where
    # websocket_endpoint may add or discard connections, and mutating a set
    # while iterating over it raises ``RuntimeError``.
    for connection in list(active_connections):
        try:
            await connection.send_json(message)
        except Exception:
            # Forget clients we could not reach; discard is a no-op if the
            # endpoint already unregistered this connection.
            # NOTE(review): an error while serializing ``message`` also lands
            # here and would drop every client; consider validating ``data``
            # before broadcasting.
            active_connections.discard(connection)
|
|
|