| | """WebSocket connection manager for real-time communication."""
|
| |
|
| | import logging
|
| | from typing import Any
|
| |
|
| | from fastapi import WebSocket
|
| |
|
# Module-level logger, named after this module via __name__.
| | logger = logging.getLogger(__name__)
|
| |
|
| |
|
class ConnectionManager:
    """Manage WebSocket connections for multiple sessions.

    Connections are kept in a registry keyed by session ID, so at most one
    WebSocket is registered per session at any time. Registering a new
    socket under an existing session ID replaces the previous entry.
    """

    def __init__(self) -> None:
        """Create a manager with an empty connection registry."""
        # Maps session ID -> the WebSocket currently registered for it.
        self.active_connections: dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, session_id: str) -> None:
        """Accept a WebSocket connection and register it.

        Args:
            websocket: The incoming connection; ``accept()`` is awaited on it.
            session_id: Key under which the connection is stored. An existing
                entry for the same key is overwritten (the previous socket is
                not closed by this method).
        """
        logger.info(f"Attempting to accept WebSocket for session {session_id}")
        await websocket.accept()
        self.active_connections[session_id] = websocket
        logger.info(f"WebSocket connected and registered for session {session_id}")

    def disconnect(self, session_id: str) -> None:
        """Remove a session's WebSocket from the registry.

        Does nothing when the session is not registered. The socket itself
        is only unregistered here, not closed.
        """
        try:
            del self.active_connections[session_id]
        except KeyError:
            return
        logger.info(f"WebSocket disconnected for session {session_id}")

    async def send_event(
        self, session_id: str, event_type: str, data: dict[str, Any] | None = None
    ) -> None:
        """Send an event to a specific session's WebSocket.

        The message is a JSON object with an ``"event_type"`` key and, when
        ``data`` is not ``None``, a ``"data"`` key.

        Args:
            session_id: Session whose registered socket receives the message.
                If the session has no registered socket, a warning is logged
                and nothing is sent.
            event_type: Value stored under ``"event_type"``.
            data: Optional payload stored under ``"data"``.

        Send failures are logged rather than raised, and the failing socket
        is removed from the registry.
        """
        websocket = self.active_connections.get(session_id)
        if websocket is None:
            logger.warning(f"No active connection for session {session_id}")
            return

        # Annotated explicitly: the payload value is a dict, not a str.
        message: dict[str, Any] = {"event_type": event_type}
        if data is not None:
            message["data"] = data

        try:
            await websocket.send_json(message)
        except Exception as e:
            # Best-effort delivery: a broken socket must not propagate into
            # the caller (e.g. abort a broadcast to the remaining sessions).
            logger.error(f"Error sending to session {session_id}: {e}")
            # The await above yields control, so the session may have been
            # re-registered with a fresh socket in the meantime. Only drop
            # the entry if it still refers to the socket that just failed.
            if self.active_connections.get(session_id) is websocket:
                self.disconnect(session_id)

    async def broadcast(
        self, event_type: str, data: dict[str, Any] | None = None
    ) -> None:
        """Broadcast an event to all connected sessions.

        Sessions are sent to one after another via ``send_event``, so a
        failure for one session is logged and does not stop the rest.
        """
        # Iterate over a snapshot of the keys: send_event may remove entries
        # from the registry when a send fails.
        for session_id in list(self.active_connections):
            await self.send_event(session_id, event_type, data)

    def is_connected(self, session_id: str) -> bool:
        """Return True if the session has a registered WebSocket."""
        return session_id in self.active_connections
|
| |
|
| |
|
| |
|
# Module-level ConnectionManager instance, created once when this module is imported.
| | manager = ConnectionManager()
|
| |
|