# Source: AudioForge/backend/app/api/v1/websockets.py (OnyxlMunkey, commit c618549)
"""WebSocket connection manager and endpoints."""
import json
from typing import Dict, Any
from uuid import UUID
import structlog
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
# Structured logger for this module; events below are logged as an event
# name plus keyword context (e.g. generation_id=...).
logger = structlog.get_logger(__name__)
# Router holding the WebSocket routes. Its URL prefix is applied where the
# router is included — presumably /api/v1/ws per the endpoint docstring (confirm).
router = APIRouter()
class ConnectionManager:
    """Manages WebSocket connections for real-time generation updates.

    Several sockets may watch the same generation (e.g. multiple browser
    tabs or users), so each generation_id maps to a list of sockets.
    """

    def __init__(self):
        """Initialize an empty connection registry."""
        # generation_id -> sockets currently watching that generation.
        self.active_connections: Dict[str, list[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, generation_id: str):
        """Accept the WebSocket handshake and register the socket.

        Args:
            websocket: Incoming connection to accept.
            generation_id: Generation the client wants updates for.
        """
        await websocket.accept()
        self.active_connections.setdefault(generation_id, []).append(websocket)
        logger.info("websocket_connected", generation_id=generation_id)

    def disconnect(self, websocket: WebSocket, generation_id: str):
        """Unregister a socket and drop the generation entry once it is empty.

        Safe to call more than once, or for a socket that was never
        registered. Both the endpoint and ``broadcast`` may call it for the
        same socket.

        Args:
            websocket: Connection to remove.
            generation_id: Generation the connection was registered under.
        """
        connections = self.active_connections.get(generation_id)
        if connections is None:
            return
        removed = websocket in connections
        if removed:
            connections.remove(websocket)
        if not connections:
            del self.active_connections[generation_id]
        # Log only a real removal so repeated cleanup calls don't double-log.
        if removed:
            logger.info("websocket_disconnected", generation_id=generation_id)

    async def broadcast(self, generation_id: str, message: Dict[str, Any]):
        """Send ``message`` as JSON to every client watching a generation.

        A send failure is logged and that socket is unregistered, so later
        broadcasts don't keep retrying a dead connection. Failures are never
        raised to the caller.

        Args:
            generation_id: Generation whose watchers should be notified.
            message: JSON-serializable payload to send.
        """
        # Iterate over a snapshot: every ``await`` yields to the event loop,
        # where the endpoint may call disconnect() and mutate the live list.
        connections = list(self.active_connections.get(generation_id, ()))
        dead: list[WebSocket] = []
        for connection in connections:
            try:
                await connection.send_json(message)
            except Exception as e:
                logger.warning(
                    "websocket_send_failed",
                    generation_id=generation_id,
                    error=str(e),
                )
                dead.append(connection)
        # Prune after the loop so the registry isn't mutated while sending.
        for connection in dead:
            self.disconnect(connection, generation_id)
# Module-level singleton: the endpoint below registers sockets here, and
# other code presumably imports it to call broadcast() (verify callers).
# NOTE(review): the registry is a plain in-process dict, so with several
# worker processes each worker only sees the sockets it accepted itself.
manager: ConnectionManager = ConnectionManager()
@router.websocket("/generations/{generation_id}")
async def generation_websocket(websocket: WebSocket, generation_id: str):
    """
    WebSocket endpoint for generation updates.

    Client connects to: /api/v1/ws/generations/{id}
    Server sends: { "status": "processing", "stage": "music_generation", "progress": 50 }

    Updates are pushed through ``manager.broadcast``. This handler only reads
    inbound frames so that a client disconnect is detected. It always
    unregisters the socket when it exits.
    """
    await manager.connect(websocket, generation_id)
    try:
        while True:
            # Inbound messages are ignored. Awaiting receive keeps the handler
            # alive and raises WebSocketDisconnect when the client goes away.
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.error("websocket_error", exc_info=e)
    finally:
        # Cleanup also runs on task cancellation. asyncio.CancelledError is a
        # BaseException and bypasses the handlers above, so without this
        # `finally` the socket would stay registered forever.
        manager.disconnect(websocket, generation_id)