Spaces:
Sleeping
Sleeping
File size: 2,767 Bytes
5df8a73 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 | """
Progress Broadcaster - Manages WebSocket broadcasting of knowledge base progress
"""
import asyncio
from typing import Optional
from fastapi import WebSocket
from deeptutor.logging import get_logger
# Module-level logger for ProgressBroadcaster; used for debug-level
# connect/disconnect and failed-send diagnostics.
logger = get_logger("ProgressBroadcaster")
class ProgressBroadcaster:
    """Manages WebSocket broadcasting of knowledge base progress.

    Clients subscribe per knowledge base name via ``connect``. Each
    ``broadcast`` for that name sends ``{"type": "progress", "data": ...}``
    to all of its subscribers. Sockets whose send raises are unregistered
    automatically.

    Intended to be used as a process-wide singleton via ``get_instance``.
    ``_connections`` and ``_lock`` are class attributes, so they are shared
    even by instances constructed directly.
    """

    _instance: Optional["ProgressBroadcaster"] = None
    # kb_name -> set of subscribed WebSockets. Class-level, so shared by every
    # instance of this class.
    _connections: dict[str, set[WebSocket]] = {}
    # Guards every read-modify-write of _connections.
    # NOTE(review): on Python < 3.10, asyncio.Lock() binds to the loop returned
    # by get_event_loop() at construction time (here: import time). If the
    # server later runs a different loop, contended acquires can fail. Confirm
    # the minimum supported Python version.
    _lock = asyncio.Lock()

    @classmethod
    def get_instance(cls) -> "ProgressBroadcaster":
        """Return the shared instance, creating it on first call."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    async def connect(self, kb_name: str, websocket: WebSocket):
        """Register ``websocket`` to receive progress updates for ``kb_name``.

        This method does not accept the WebSocket handshake itself; presumably
        the caller has already done so (verify against the route handler).
        """
        async with self._lock:
            subscribers = self._connections.setdefault(kb_name, set())
            subscribers.add(websocket)
            logger.debug(
                f"Connected WebSocket for KB '{kb_name}' (total: {len(subscribers)})"
            )

    async def disconnect(self, kb_name: str, websocket: WebSocket):
        """Unregister ``websocket`` from ``kb_name``.

        This is a no-op if the socket is not registered. The ``kb_name`` entry
        is dropped once its last subscriber leaves.
        """
        async with self._lock:
            self._discard_locked(kb_name, (websocket,))
            logger.debug(f"Disconnected WebSocket for KB '{kb_name}'")

    async def broadcast(self, kb_name: str, progress: dict):
        """Send a progress update to every WebSocket subscribed to ``kb_name``.

        Sends are issued concurrently, so one slow client no longer delays
        delivery to the others. The lock is held until all sends finish. This
        keeps successive broadcasts to the same client in order and prevents
        the subscriber set from changing mid-broadcast.

        Any socket whose send raises an ``Exception`` is logged at debug level
        and unregistered. Non-``Exception`` errors (e.g. cancellation)
        propagate, as before.

        Args:
            kb_name: Knowledge base whose subscribers should be notified.
            progress: Payload placed under the ``"data"`` key of the message.
        """
        async with self._lock:
            # Snapshot so results can be zipped back to their sockets.
            subscribers = list(self._connections.get(kb_name, ()))
            if not subscribers:
                return
            delivered = await asyncio.gather(
                *(self._send_progress(kb_name, ws, progress) for ws in subscribers)
            )
            failed = [ws for ws, ok in zip(subscribers, delivered) if not ok]
            if failed:
                self._discard_locked(kb_name, failed)

    async def _send_progress(
        self, kb_name: str, websocket: WebSocket, progress: dict
    ) -> bool:
        """Send one progress message; return False (after logging) if it raised."""
        try:
            await websocket.send_json({"type": "progress", "data": progress})
        except Exception as e:
            # Connection closed or errored: the caller will unregister it.
            logger.debug(f"Error sending to WebSocket for KB '{kb_name}': {e}")
            return False
        return True

    def _discard_locked(self, kb_name: str, websockets) -> None:
        """Remove ``websockets`` from ``kb_name``'s subscribers; caller holds ``_lock``.

        Deletes the ``kb_name`` entry once no subscribers remain, so empty
        sets never linger in ``_connections``.
        """
        subscribers = self._connections.get(kb_name)
        if subscribers is None:
            return
        subscribers.difference_update(websockets)
        if not subscribers:
            del self._connections[kb_name]

    def get_connection_count(self, kb_name: str) -> int:
        """Return how many WebSockets are currently subscribed to ``kb_name``."""
        return len(self._connections.get(kb_name, set()))
|