File size: 2,268 Bytes
fed4f1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""WebSocket connection manager for real-time communication."""

import logging
from typing import Any

from fastapi import WebSocket

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class ConnectionManager:
    """Registry of active WebSocket connections, keyed by session id.

    At most one WebSocket is tracked per session: connecting again with the
    same ``session_id`` replaces the previously registered socket.
    """

    def __init__(self) -> None:
        # session_id -> WebSocket
        self.active_connections: dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, session_id: str) -> None:
        """Accept a WebSocket connection and register it under ``session_id``.

        Args:
            websocket: The incoming connection; ``accept()`` is awaited on it.
            session_id: Key under which the connection is stored.
        """
        logger.info("Attempting to accept WebSocket for session %s", session_id)
        await websocket.accept()
        previous = self.active_connections.get(session_id)
        if previous is not None and previous is not websocket:
            # The old socket is only dropped from the registry; it is not
            # closed here.
            logger.warning(
                "Replacing existing WebSocket for session %s", session_id
            )
        self.active_connections[session_id] = websocket
        logger.info(
            "WebSocket connected and registered for session %s", session_id
        )

    def disconnect(
        self, session_id: str, websocket: WebSocket | None = None
    ) -> None:
        """Remove the WebSocket registered for ``session_id``.

        Args:
            session_id: Session whose connection should be forgotten.
            websocket: Optional guard. When given, the entry is removed only
                if this exact socket is the one currently registered, so a
                late disconnect for an old socket cannot evict a newer
                connection of the same session. When omitted, the entry is
                removed unconditionally.
        """
        if (
            websocket is not None
            and self.active_connections.get(session_id) is not websocket
        ):
            logger.debug(
                "Ignoring stale disconnect for session %s", session_id
            )
            return

        # pop() with a default is a no-op for unknown sessions.
        self.active_connections.pop(session_id, None)
        logger.info("WebSocket disconnected for session %s", session_id)

    async def send_event(
        self, session_id: str, event_type: str, data: dict[str, Any] | None = None
    ) -> None:
        """Send an event to a specific session's WebSocket.

        The payload is ``{"event_type": event_type}`` plus a ``"data"`` key
        when ``data`` is not ``None``. If no connection is registered the call
        logs a warning and returns. If sending raises, the error is logged
        and the failing socket is unregistered; the exception is not
        propagated.

        Args:
            session_id: Target session.
            event_type: Value stored under the ``"event_type"`` key.
            data: Optional payload stored under the ``"data"`` key.
        """
        # Capture the socket before awaiting so the error path below acts on
        # the socket that actually failed, not on whatever is registered later.
        websocket = self.active_connections.get(session_id)
        if websocket is None:
            logger.warning("No active connection for session %s", session_id)
            return

        message: dict[str, Any] = {"event_type": event_type}
        if data is not None:
            message["data"] = data

        try:
            await websocket.send_json(message)
        except Exception as e:
            # Deliberately broad: any send failure is treated as a dead
            # connection rather than an error for the caller.
            logger.error("Error sending to session %s: %s", session_id, e)
            self.disconnect(session_id, websocket)

    async def broadcast(
        self, event_type: str, data: dict[str, Any] | None = None
    ) -> None:
        """Broadcast an event to all connected sessions.

        Sessions are sent to one after another via ``send_event``.

        Args:
            event_type: Value stored under the ``"event_type"`` key.
            data: Optional payload stored under the ``"data"`` key.
        """
        # Iterate over a snapshot: send_event may unregister a session on
        # failure, which would otherwise mutate the dict during iteration.
        for session_id in list(self.active_connections):
            await self.send_event(session_id, event_type, data)

    def is_connected(self, session_id: str) -> bool:
        """Return whether ``session_id`` has a registered WebSocket."""
        return session_id in self.active_connections


# Global connection manager instance, created once when this module is
# imported; every importer of ``manager`` gets this same object.
manager = ConnectionManager()