Spaces:
Sleeping
Sleeping
File size: 3,483 Bytes
9bed109 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 | from __future__ import annotations
import asyncio
import json
import logging
from typing import Any
import redis.asyncio as redis
from fastapi import WebSocket
from backend.core.config import settings
# Module-level logger, named after this module's import path (``__name__``).
logger = logging.getLogger(__name__)
class ConnectionManager:
    """Manages active WebSocket connections and broadcasts Redis Pub/Sub events.

    Connections are grouped by ``scene_id``. A single background task, started
    lazily by ``connect``, subscribes to the Redis channel
    ``manim_agent:events``. It forwards each JSON event to the WebSockets
    registered under that event's ``scene_id``.
    """

    # Redis Pub/Sub channel the background listener subscribes to.
    _EVENTS_CHANNEL = "manim_agent:events"

    def __init__(self) -> None:
        # scene_id -> WebSockets currently registered for that scene.
        self.active_connections: dict[str, list[WebSocket]] = {}
        self._redis_url = settings.redis_url
        # Handle to the background listener. It is kept so we can tell whether
        # the listener is still running (and so the task is referenced).
        self._pubsub_task: asyncio.Task[None] | None = None

    async def connect(self, websocket: WebSocket, scene_id: str) -> None:
        """Accept *websocket* and register it under *scene_id*.

        Also starts the Redis listener if it was never started or has since
        finished (e.g. after a Redis error). A finished ``Task`` object is
        still truthy, so checking ``.done()`` is what allows a restart.
        """
        await websocket.accept()
        self.active_connections.setdefault(scene_id, []).append(websocket)
        if self._pubsub_task is None or self._pubsub_task.done():
            self._pubsub_task = asyncio.create_task(self._listen_to_redis())

    def disconnect(self, websocket: WebSocket, scene_id: str) -> None:
        """Unregister *websocket* from *scene_id*; unknown pairs are ignored.

        The scene entry is removed once its last connection is gone.
        """
        connections = self.active_connections.get(scene_id)
        if connections is None:
            return
        if websocket in connections:
            connections.remove(websocket)
        if not connections:
            del self.active_connections[scene_id]

    async def _listen_to_redis(self) -> None:
        """Background task: forward Redis 'manim_agent:events' to WebSockets.

        Runs until cancelled or until the Redis connection fails. Failures are
        logged rather than raised, so the task ends cleanly and the next
        ``connect`` call starts a fresh listener.
        """
        logger.info("Starting Redis Pub/Sub listener for WebSockets")
        r = redis.from_url(self._redis_url, decode_responses=True)  # type: ignore
        pubsub = r.pubsub()
        try:
            # Subscribe inside the try so that a failed subscribe still
            # reaches the cleanup below.
            await pubsub.subscribe(self._EVENTS_CHANNEL)
            async for message in pubsub.listen():
                if message["type"] == "message":
                    await self._handle_event(message["data"])
        except Exception:
            logger.exception("Redis Pub/Sub listener stopped unexpectedly")
        finally:
            try:
                await pubsub.unsubscribe(self._EVENTS_CHANNEL)
            except Exception:
                # Best effort: the connection may already be broken. Always
                # fall through to closing the client.
                logger.warning(
                    "Failed to unsubscribe from %s",
                    self._EVENTS_CHANNEL,
                    exc_info=True,
                )
            finally:
                await r.close()

    async def _handle_event(self, data: Any) -> None:
        """Decode one Pub/Sub payload and broadcast it to its scene.

        Any error is logged and swallowed, so one bad event cannot stop the
        listener loop.
        """
        try:
            payload = json.loads(data)
            if not isinstance(payload, dict):
                logger.warning("Ignoring non-object WebSocket event: %r", payload)
                return
            scene_id = payload.get("scene_id")
            logger.debug(
                "WS Listener received: scene_id=%s, active=%s",
                scene_id,
                list(self.active_connections),
            )
            if scene_id and scene_id in self.active_connections:
                logger.info(
                    "Broadcasting to scene %s: %s", scene_id, payload.get("message")
                )
                await self.broadcast_to_scene(scene_id, payload)
        except Exception:
            logger.exception("Failed to broadcast WebSocket message")

    async def broadcast_to_scene(self, scene_id: str, message: Any) -> None:
        """Send *message* as JSON to every WebSocket registered for *scene_id*.

        Sends are awaited one connection at a time. A connection whose send
        raises is treated as dead and unregistered via ``disconnect``.
        Unknown or empty scenes are a no-op.
        """
        connections = self.active_connections.get(scene_id)
        if not connections:
            return
        dead: list[WebSocket] = []
        # Iterate over a snapshot: each ``await`` yields to the event loop,
        # where ``connect``/``disconnect`` may mutate the live list.
        for connection in list(connections):
            try:
                await connection.send_json(message)
            except Exception:
                dead.append(connection)
        for connection in dead:
            # ``disconnect`` tolerates entries that were already removed
            # meanwhile, and drops the scene key once it becomes empty.
            self.disconnect(connection, scene_id)
# Module-level shared ConnectionManager instance. Constructing it reads
# ``settings.redis_url`` at import time. Presumably the WebSocket route(s)
# import this object rather than creating their own — confirm with callers.
manager = ConnectionManager()
|