Spaces:
Build error
Build error
File size: 2,996 Bytes
"""WebSocket connection manager and endpoints."""
import json
from typing import Dict, Any
from uuid import UUID
import structlog
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
# Module-level structured logger, named after this module.
logger = structlog.get_logger(__name__)
# Router on which this module's WebSocket endpoint is registered.
router = APIRouter()
class ConnectionManager:
    """Manages WebSocket connections for real-time updates.

    Connections are grouped by ``generation_id`` so that several clients
    (multiple tabs or users) can watch the same generation.
    """

    def __init__(self) -> None:
        """Initialize the connection manager with an empty registry."""
        # Map generation_id to the list of active websockets,
        # allowing multiple tabs/users to watch the same generation.
        self.active_connections: Dict[str, list[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, generation_id: str) -> None:
        """Accept the connection and register it under ``generation_id``.

        Args:
            websocket: The incoming socket; it is accepted before being
                registered.
            generation_id: Key of the generation the client wants to watch.
        """
        await websocket.accept()
        self.active_connections.setdefault(generation_id, []).append(websocket)
        logger.info("websocket_connected", generation_id=generation_id)

    def disconnect(self, websocket: WebSocket, generation_id: str) -> None:
        """Remove a connection from the registry.

        Safe to call for a socket or generation that is not registered; the
        registry entry is dropped once its last socket is removed.

        Args:
            websocket: The socket to unregister.
            generation_id: Key the socket was registered under.
        """
        watchers = self.active_connections.get(generation_id)
        if watchers is not None:
            if websocket in watchers:
                watchers.remove(websocket)
            if not watchers:
                # Do not keep empty lists around for finished generations.
                del self.active_connections[generation_id]
        logger.info("websocket_disconnected", generation_id=generation_id)

    async def broadcast(self, generation_id: str, message: Dict[str, Any]) -> None:
        """Send ``message`` as JSON to all clients watching a generation.

        A failure to send to one client is logged and does not prevent
        delivery to the remaining clients. Does nothing when no client is
        registered for ``generation_id``.

        Args:
            generation_id: Key of the generation whose watchers are notified.
            message: Payload passed to each socket's ``send_json``.
        """
        # Iterate over a snapshot: every send awaits, and connect() or
        # disconnect() may mutate the live list while this coroutine is
        # suspended, which would make a plain list iteration skip clients.
        recipients = list(self.active_connections.get(generation_id, ()))
        for connection in recipients:
            try:
                await connection.send_json(message)
            except Exception as e:
                logger.warning("websocket_send_failed", error=str(e))
                # Failed connections are intentionally not pruned here: the
                # failure may be unrelated to the socket itself, and removal
                # is left to disconnect(), which the endpoint handler calls
                # when its receive loop ends.
# Singleton instance: one module-level manager used by the WebSocket endpoint
# in this module.
manager = ConnectionManager()
@router.websocket("/generations/{generation_id}")
async def generation_websocket(websocket: WebSocket, generation_id: str):
    """
    WebSocket endpoint for generation updates.

    Client connects to: /api/v1/ws/generations/{id}
    Server sends: { "status": "processing", "stage": "music_generation", "progress": 50 }

    The socket is registered with the module-level ``manager`` and is always
    unregistered when the handler exits, whether the client disconnects, an
    unexpected error occurs, or the handler task is cancelled.
    """
    await manager.connect(websocket, generation_id)
    try:
        while True:
            # Keep connection alive and listen for any client messages (optional).
            # We mostly push from server, but need to await receive to keep
            # the socket open; received text is discarded.
            await websocket.receive_text()
    except WebSocketDisconnect:
        # Normal client close; unregistering happens in ``finally``.
        pass
    except Exception as e:
        logger.error("websocket_error", generation_id=generation_id, exc_info=e)
    finally:
        # ``finally`` also covers task cancellation (a BaseException), so a
        # cancelled handler does not leave a stale socket in the registry.
        manager.disconnect(websocket, generation_id)
|