Spaces:
Sleeping
Sleeping
| """ | |
| Progress Broadcaster - Manages WebSocket broadcasting of knowledge base progress | |
| """ | |
| import asyncio | |
| from typing import Optional | |
| from fastapi import WebSocket | |
| from deeptutor.logging import get_logger | |
# Module-wide logger; ProgressBroadcaster logs connection changes and
# per-client send failures through it (all at debug level).
logger = get_logger("ProgressBroadcaster")
class ProgressBroadcaster:
    """Manages WebSocket broadcasting of knowledge base progress.

    Connections are grouped by knowledge-base name. Every change to the
    connection registry, and every broadcast, runs under a single
    ``asyncio.Lock``. Because broadcasts are serialized, each client receives
    progress messages in the order ``broadcast`` was called.

    Obtain the shared instance with :meth:`get_instance`. Note that the
    registry is a class attribute, so any directly constructed instance shares
    it as well.
    """

    # Singleton instance, created lazily by get_instance().
    _instance: Optional["ProgressBroadcaster"] = None
    # kb_name -> set of connected WebSockets. This is class-level state shared
    # by every instance of the class.
    _connections: dict[str, set[WebSocket]] = {}
    # Guards _connections.
    # NOTE(review): this lock is created at import time. On Python < 3.10 an
    # asyncio.Lock binds to the event loop that is current at construction,
    # which may not be the loop the server later runs. That would raise
    # "attached to a different loop" under contention. Confirm the deployed
    # Python version.
    _lock = asyncio.Lock()

    @classmethod
    def get_instance(cls) -> "ProgressBroadcaster":
        """Return the shared instance, creating it on the first call."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    async def connect(self, kb_name: str, websocket: WebSocket) -> None:
        """Register ``websocket`` to receive progress updates for ``kb_name``.

        Args:
            kb_name: Knowledge-base name used to group connections.
            websocket: The client connection to add.
        """
        async with self._lock:
            connections = self._connections.setdefault(kb_name, set())
            connections.add(websocket)
            logger.debug(
                f"Connected WebSocket for KB '{kb_name}' (total: {len(connections)})"
            )

    async def disconnect(self, kb_name: str, websocket: WebSocket) -> None:
        """Unregister ``websocket`` from ``kb_name``.

        Unknown KBs or sockets are ignored. The KB entry is removed once it has
        no connections left, so the registry does not accumulate empty sets.
        """
        async with self._lock:
            if kb_name in self._connections:
                self._connections[kb_name].discard(websocket)
                if not self._connections[kb_name]:
                    del self._connections[kb_name]
            logger.debug(f"Disconnected WebSocket for KB '{kb_name}'")

    async def broadcast(self, kb_name: str, progress: dict) -> None:
        """Send a progress update to every WebSocket registered for ``kb_name``.

        Each client receives ``{"type": "progress", "data": progress}``.

        Sends to all clients run concurrently, so the lock is held for roughly
        the slowest single send rather than the sum of all sends. The lock is
        still held across the sends, which keeps per-client message order
        consistent between successive broadcasts.

        A client whose send raises an ``Exception`` (typically a closed
        connection) is dropped from the registry. The KB entry is deleted once
        it has no clients left.

        Args:
            kb_name: Knowledge-base name whose subscribers should be notified.
            progress: Payload placed under the ``"data"`` key.
        """
        message = {"type": "progress", "data": progress}
        async with self._lock:
            connections = self._connections.get(kb_name)
            if not connections:
                return

            # Snapshot the set; it is mutated below once the sends finish.
            targets = list(connections)
            results = await asyncio.gather(
                *(ws.send_json(message) for ws in targets),
                return_exceptions=True,
            )

            # Drop connections that failed (closed connection or other error).
            for ws, result in zip(targets, results):
                if isinstance(result, Exception):
                    logger.debug(
                        f"Error sending to WebSocket for KB '{kb_name}': {result}"
                    )
                    connections.discard(ws)

            if not connections:
                del self._connections[kb_name]

    def get_connection_count(self, kb_name: str) -> int:
        """Return how many WebSockets are registered for ``kb_name`` (0 if none)."""
        return len(self._connections.get(kb_name, set()))