from __future__ import annotations import asyncio import json import logging from typing import Any import redis.asyncio as redis from fastapi import WebSocket from backend.core.config import settings logger = logging.getLogger(__name__) class ConnectionManager: """Manages active WebSocket connections and broadcasts Redis Pub/Sub events.""" def __init__(self) -> None: self.active_connections: dict[str, list[WebSocket]] = {} self._redis_url = settings.redis_url self._pubsub_task: asyncio.Task[None] | None = None async def connect(self, websocket: WebSocket, scene_id: str) -> None: await websocket.accept() if scene_id not in self.active_connections: self.active_connections[scene_id] = [] self.active_connections[scene_id].append(websocket) # Start pubsub listener if not already running if not self._pubsub_task: self._pubsub_task = asyncio.create_task(self._listen_to_redis()) def disconnect(self, websocket: WebSocket, scene_id: str) -> None: if scene_id in self.active_connections: if websocket in self.active_connections[scene_id]: self.active_connections[scene_id].remove(websocket) if not self.active_connections[scene_id]: del self.active_connections[scene_id] async def _listen_to_redis(self) -> None: """Background task to listen to Redis 'manim_agent:events' and broadcast.""" logger.info("Starting Redis Pub/Sub listener for WebSockets") r = redis.from_url(self._redis_url, decode_responses=True) # type: ignore pubsub = r.pubsub() await pubsub.subscribe("manim_agent:events") try: async for message in pubsub.listen(): if message["type"] == "message": data = message["data"] try: payload = json.loads(data) scene_id = payload.get("scene_id") # Debug log to see what's coming in logger.info( f"WS Listener received: scene_id={scene_id}, " f"active={list(self.active_connections.keys())}" ) if scene_id and scene_id in self.active_connections: logger.info( f"Broadcasting to scene {scene_id}: {payload.get('message')}" ) await self.broadcast_to_scene(scene_id, payload) except Exception: logger.exception("Failed to broadcast WebSocket message") finally: await pubsub.unsubscribe("manim_agent:events") await r.close() async def broadcast_to_scene(self, scene_id: str, message: Any) -> None: if scene_id in self.active_connections: # Create tasks for all sends to avoid blocking dead_connections = [] for connection in self.active_connections[scene_id]: try: await connection.send_json(message) except Exception: dead_connections.append(connection) # Cleanup dead connections for dead in dead_connections: self.active_connections[scene_id].remove(dead) # Global manager instance manager = ConnectionManager()