File size: 3,483 Bytes
9bed109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from __future__ import annotations

import asyncio
import json
import logging
from typing import Any

import redis.asyncio as redis
from fastapi import WebSocket

from backend.core.config import settings

logger = logging.getLogger(__name__)


class ConnectionManager:
    """Manages active WebSocket connections and broadcasts Redis Pub/Sub events."""

    def __init__(self) -> None:
        self.active_connections: dict[str, list[WebSocket]] = {}
        self._redis_url = settings.redis_url
        self._pubsub_task: asyncio.Task[None] | None = None

    async def connect(self, websocket: WebSocket, scene_id: str) -> None:
        await websocket.accept()
        if scene_id not in self.active_connections:
            self.active_connections[scene_id] = []
        self.active_connections[scene_id].append(websocket)

        # Start pubsub listener if not already running
        if not self._pubsub_task:
            self._pubsub_task = asyncio.create_task(self._listen_to_redis())

    def disconnect(self, websocket: WebSocket, scene_id: str) -> None:
        if scene_id in self.active_connections:
            if websocket in self.active_connections[scene_id]:
                self.active_connections[scene_id].remove(websocket)
            if not self.active_connections[scene_id]:
                del self.active_connections[scene_id]

    async def _listen_to_redis(self) -> None:
        """Background task to listen to Redis 'manim_agent:events' and broadcast."""
        logger.info("Starting Redis Pub/Sub listener for WebSockets")
        r = redis.from_url(self._redis_url, decode_responses=True)  # type: ignore
        pubsub = r.pubsub()
        await pubsub.subscribe("manim_agent:events")

        try:
            async for message in pubsub.listen():
                if message["type"] == "message":
                    data = message["data"]
                    try:
                        payload = json.loads(data)
                        scene_id = payload.get("scene_id")
                        # Debug log to see what's coming in
                        logger.info(
                            f"WS Listener received: scene_id={scene_id}, "
                            f"active={list(self.active_connections.keys())}"
                        )

                        if scene_id and scene_id in self.active_connections:
                            logger.info(
                                f"Broadcasting to scene {scene_id}: {payload.get('message')}"
                            )
                            await self.broadcast_to_scene(scene_id, payload)
                    except Exception:
                        logger.exception("Failed to broadcast WebSocket message")
        finally:
            await pubsub.unsubscribe("manim_agent:events")
            await r.close()

    async def broadcast_to_scene(self, scene_id: str, message: Any) -> None:
        if scene_id in self.active_connections:
            # Create tasks for all sends to avoid blocking
            dead_connections = []
            for connection in self.active_connections[scene_id]:
                try:
                    await connection.send_json(message)
                except Exception:
                    dead_connections.append(connection)

            # Cleanup dead connections
            for dead in dead_connections:
                self.active_connections[scene_id].remove(dead)


# Global manager instance
manager = ConnectionManager()