File size: 2,996 Bytes
6423ff2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""WebSocket connection manager and endpoints."""

import json
from typing import Dict, Any
from uuid import UUID
import structlog
from fastapi import APIRouter, WebSocket, WebSocketDisconnect

# Module-level structlog logger; the calls in this module pass an event name
# plus keyword context (e.g. generation_id).
logger = structlog.get_logger(__name__)
# Router carrying the WebSocket endpoint(s) defined below; its mount prefix is
# configured wherever this router is included in the app.
router = APIRouter()


class ConnectionManager:
    """Track open WebSockets per generation and fan messages out to them.

    Several sockets (e.g. multiple tabs or users) may watch the same
    generation, so each ``generation_id`` maps to a list of sockets. A
    generation's entry is removed as soon as its last socket is unregistered.
    """

    def __init__(self):
        """Initialize the connection manager with an empty registry."""
        # generation_id -> sockets currently watching that generation
        self.active_connections: Dict[str, list[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, generation_id: str):
        """Accept the WebSocket handshake and register the socket.

        The socket is registered only after ``accept()`` returns, so a failed
        handshake leaves the registry untouched.
        """
        await websocket.accept()
        self.active_connections.setdefault(generation_id, []).append(websocket)
        logger.info("websocket_connected", generation_id=generation_id)

    def disconnect(self, websocket: WebSocket, generation_id: str):
        """Unregister a socket; calling it again for the same socket is harmless.

        Drops the ``generation_id`` entry once its last socket is removed.
        """
        connections = self.active_connections.get(generation_id)
        if connections is not None:
            if websocket in connections:
                connections.remove(websocket)
            if not connections:
                del self.active_connections[generation_id]
        logger.info("websocket_disconnected", generation_id=generation_id)

    async def broadcast(self, generation_id: str, message: Dict[str, Any]):
        """Send ``message`` as JSON to every socket watching ``generation_id``.

        Sockets whose send raises are logged and unregistered, so later
        broadcasts do not keep retrying a dead connection. Unknown
        generations are a no-op.
        """
        connections = self.active_connections.get(generation_id)
        if not connections:
            return

        failed: list[WebSocket] = []
        # Iterate over a snapshot: while suspended in send_json(), other
        # coroutines may call connect()/disconnect() and mutate the live list.
        for connection in list(connections):
            try:
                await connection.send_json(message)
            except Exception as e:
                logger.warning(
                    "websocket_send_failed",
                    generation_id=generation_id,
                    error=str(e),
                )
                failed.append(connection)

        # Remove dead sockets after the loop; disconnect() tolerates sockets
        # that the endpoint's own handler has already removed.
        for connection in failed:
            self.disconnect(connection, generation_id)


# Singleton instance shared by the endpoint below and any other importer of
# this module. Its registry is an in-memory dict, so it only knows about sockets
# connected to this process; with multiple worker processes each has its own.
manager = ConnectionManager()


@router.websocket("/generations/{generation_id}")
async def generation_websocket(websocket: WebSocket, generation_id: str):
    """
    WebSocket endpoint for generation updates.

    Client connects to: /api/v1/ws/generations/{id}
    (assuming this router is mounted under /api/v1/ws — the prefix is set
    where the router is included).
    Server sends: { "status": "processing", "stage": "music_generation", "progress": 50 }

    Updates are pushed via ``manager.broadcast``. This handler only keeps the
    socket open and guarantees it is unregistered however the loop ends.
    """
    await manager.connect(websocket, generation_id)
    try:
        while True:
            # We mostly push from the server, but awaiting a receive keeps the
            # socket open and surfaces client disconnects as exceptions.
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass  # Normal client-initiated close; cleanup happens in finally.
    except Exception as e:
        logger.error("websocket_error", exc_info=e)
    finally:
        # finally (not per-except cleanup) also covers task cancellation,
        # e.g. on server shutdown: asyncio.CancelledError is a BaseException
        # and would bypass both handlers above, leaking the registry entry.
        manager.disconnect(websocket, generation_id)