Spaces:
Build error
Build error
| """WebSocket connection manager and endpoints.""" | |
| import json | |
| from typing import Dict, Any | |
| from uuid import UUID | |
| import structlog | |
| from fastapi import APIRouter, WebSocket, WebSocketDisconnect | |
# Router that the WebSocket endpoint(s) of this module are attached to.
router = APIRouter()

# Structured logger named after this module.
logger = structlog.get_logger(__name__)
class ConnectionManager:
    """Manages WebSocket connections for real-time updates.

    Connections are grouped by generation id, so several tabs/users can
    watch the same generation at once.
    """

    def __init__(self):
        """Initialize the connection manager with an empty registry."""
        # Map generation_id to the list of websockets currently watching it.
        self.active_connections: Dict[str, list[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, generation_id: str):
        """Accept the connection and register it under ``generation_id``.

        Args:
            websocket: The incoming socket; it is accepted here.
            generation_id: Key of the generation the client wants to watch.
        """
        await websocket.accept()
        self.active_connections.setdefault(generation_id, []).append(websocket)
        logger.info("websocket_connected", generation_id=generation_id)

    def disconnect(self, websocket: WebSocket, generation_id: str):
        """Remove a connection from the registry.

        Safe to call for a socket or generation id that is not registered;
        in that case only the log line is emitted.

        Args:
            websocket: The socket to forget.
            generation_id: Key the socket was registered under.
        """
        connections = self.active_connections.get(generation_id)
        if connections is not None:
            if websocket in connections:
                connections.remove(websocket)
            # Drop the key once nobody is watching so the dict does not
            # accumulate empty lists for finished generations.
            if not connections:
                del self.active_connections[generation_id]
        logger.info("websocket_disconnected", generation_id=generation_id)

    async def broadcast(self, generation_id: str, message: Dict[str, Any]):
        """Send ``message`` as JSON to all clients watching a generation.

        Delivery is best-effort: a failed send is logged and the remaining
        clients are still served. Nothing happens if no client is registered
        for ``generation_id``.

        Args:
            generation_id: Key of the generation whose watchers are notified.
            message: Payload passed to ``send_json`` of every socket.
        """
        # Iterate over a snapshot: every send awaits, so connect()/disconnect()
        # may run in between and mutate the live list. Mutating a list while
        # iterating it would silently skip clients.
        recipients = list(self.active_connections.get(generation_id, ()))
        for connection in recipients:
            try:
                await connection.send_json(message)
            except Exception as e:
                # Failed sockets are not removed here; the endpoint that owns
                # the socket is expected to call disconnect() when it ends.
                logger.warning("websocket_send_failed", error=str(e))
# Module-wide shared instance: one registry of connections for this module.
manager = ConnectionManager()
async def generation_websocket(websocket: WebSocket, generation_id: str):
    """
    WebSocket endpoint for generation updates.

    Client connects to: /api/v1/ws/generations/{id}
    Server sends: { "status": "processing", "stage": "music_generation", "progress": 50 }

    The socket is registered with the module-level ``manager`` and is always
    deregistered when this coroutine ends, whether the client disconnected,
    an unexpected error occurred, or the task was cancelled.

    Args:
        websocket: The client connection; accepted via ``manager.connect``.
        generation_id: Identifier of the generation the client watches.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view of the file - confirm it is registered on `router` elsewhere.
    await manager.connect(websocket, generation_id)
    try:
        while True:
            # We mostly push from the server, but must keep awaiting receive
            # so the socket stays open and a client close is noticed.
            # Any text the client sends is discarded.
            await websocket.receive_text()
    except WebSocketDisconnect:
        # Normal client-side close; cleanup happens in `finally`.
        pass
    except Exception as e:
        logger.error("websocket_error", exc_info=e)
    finally:
        # `finally` rather than per-branch cleanup: cancellation is a
        # BaseException and bypasses both handlers above, which would
        # otherwise leave the socket registered.
        manager.disconnect(websocket, generation_id)