"""WebSocket connection manager and endpoints."""

import json
from typing import Dict, Any
from uuid import UUID

import structlog
from fastapi import APIRouter, WebSocket, WebSocketDisconnect

logger = structlog.get_logger(__name__)

router = APIRouter()


class ConnectionManager:
    """Manages WebSocket connections for real-time updates."""

    def __init__(self):
        """Initialize connection manager."""
        # Map generation_id to the list of active websockets, allowing
        # multiple tabs/users to watch the same generation.
        self.active_connections: Dict[str, list[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, generation_id: str) -> None:
        """Accept the websocket handshake and register it under generation_id."""
        await websocket.accept()
        self.active_connections.setdefault(generation_id, []).append(websocket)
        logger.info("websocket_connected", generation_id=generation_id)

    def disconnect(self, websocket: WebSocket, generation_id: str) -> None:
        """Unregister websocket from generation_id.

        Safe to call more than once for the same socket (it may be called both
        by ``broadcast`` after a failed send and by the endpoint on exit); only
        an actual removal is logged. The generation's entry is dropped once its
        last watcher is gone so the mapping does not accumulate empty lists.
        """
        connections = self.active_connections.get(generation_id)
        if connections is None or websocket not in connections:
            return
        connections.remove(websocket)
        if not connections:
            del self.active_connections[generation_id]
        logger.info("websocket_disconnected", generation_id=generation_id)

    async def broadcast(self, generation_id: str, message: Dict[str, Any]) -> None:
        """Send message as JSON to all clients watching generation_id.

        A failed send is logged and that connection is unregistered; delivery
        to the remaining clients continues.
        """
        # Iterate over a snapshot: while send_json is awaited, another task
        # (e.g. the endpoint handling a client close) may call disconnect()
        # and mutate the live list, which would make the loop skip clients.
        connections = list(self.active_connections.get(generation_id, ()))
        dead: list[WebSocket] = []
        for connection in connections:
            try:
                await connection.send_json(message)
            except Exception as e:
                logger.warning(
                    "websocket_send_failed",
                    generation_id=generation_id,
                    error=str(e),
                )
                dead.append(connection)
        # Evict connections that could not be written to so later broadcasts
        # do not keep retrying them.
        for connection in dead:
            self.disconnect(connection, generation_id)


# Singleton instance
manager = ConnectionManager()


@router.websocket("/generations/{generation_id}")
async def generation_websocket(websocket: WebSocket, generation_id: str):
    """
    WebSocket endpoint for generation updates.

    Client connects to: /api/v1/ws/generations/{id}
    Server sends: { "status": "processing", "stage": "music_generation", "progress": 50 }
    """
    await manager.connect(websocket, generation_id)
    try:
        while True:
            # Updates are pushed from the server; awaiting a receive keeps the
            # socket open and lets us notice when the client goes away.
            # Incoming client messages are ignored.
            await websocket.receive_text()
    except WebSocketDisconnect:
        # Normal client-initiated close.
        pass
    except Exception as e:
        logger.error("websocket_error", generation_id=generation_id, exc_info=e)
    finally:
        # Runs on every exit path, including asyncio.CancelledError (a
        # BaseException, not caught above), so the registration never leaks.
        manager.disconnect(websocket, generation_id)