|
|
""" |
|
|
WebSocket Connection Manager - Phase 5 |
|
|
Manages real-time WebSocket connections for multi-client sync |
|
|
""" |
|
|
|
|
|
import asyncio
import json
import time
from typing import Dict, Set, Optional
from uuid import UUID

from fastapi import WebSocket, WebSocketDisconnect

from src.utils.logger import get_logger
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
|
|
|
class ConnectionManager:
    """
    Manages WebSocket connections for real-time updates.

    Features:
    - Track active connections per user
    - Broadcast updates to specific user's connections
    - Handle connection/disconnection gracefully
    - Support multiple devices per user

    Two maps are kept in sync: ``active_connections`` (user_id -> set of
    sockets) and ``connection_to_user`` (socket -> user_id, the reverse index
    used by ``disconnect``).
    """

    def __init__(self):
        # user_id -> every open socket for that user (one per device/tab).
        self.active_connections: Dict[str, Set[WebSocket]] = {}
        # Reverse index so disconnect() can find the owner from the socket alone.
        self.connection_to_user: Dict[WebSocket, str] = {}

    async def connect(self, websocket: WebSocket, user_id: str):
        """
        Accept a new WebSocket connection, track it, and send a greeting.

        Args:
            websocket: The WebSocket connection
            user_id: ID of the user connecting
        """
        await websocket.accept()

        self.active_connections.setdefault(user_id, set()).add(websocket)
        self.connection_to_user[websocket] = user_id

        logger.info(
            "websocket_connected",
            user_id=user_id,
            total_connections_for_user=len(self.active_connections[user_id]),
            total_users=len(self.active_connections),
        )

        await websocket.send_json({
            "type": "connected",
            "message": "Real-time sync activated",
            "user_id": user_id,
        })

    async def disconnect(self, websocket: WebSocket):
        """
        Remove a WebSocket connection from tracking.

        Safe to call more than once for the same socket. An unknown or
        already-removed socket is ignored without logging. This happens, for
        example, when a socket is dropped after a failed send and the endpoint
        handler then disconnects it again.

        Args:
            websocket: The WebSocket connection to remove
        """
        # Pop unconditionally so the reverse index can never retain a stale
        # entry. Compare against None (not truthiness) so that a falsy user_id
        # such as "" is still cleaned up.
        user_id = self.connection_to_user.pop(websocket, None)
        if user_id is None:
            return

        connections = self.active_connections.get(user_id)
        if connections is not None:
            connections.discard(websocket)
            # Drop empty sets so active_connections only lists connected users.
            if not connections:
                del self.active_connections[user_id]

        logger.info(
            "websocket_disconnected",
            user_id=user_id,
            remaining_connections=len(self.active_connections.get(user_id, ())),
        )

    async def send_personal_message(self, message: dict, user_id: str):
        """
        Send a message to all connections for a specific user.

        Any connection whose send raises is logged and then removed via
        ``disconnect``.

        Args:
            message: The message to send (will be JSON serialized)
            user_id: ID of the user to send to
        """
        connections = self.active_connections.get(user_id)
        if not connections:
            logger.debug("No active connections for user", user_id=user_id)
            return

        disconnected = []
        # Iterate over a snapshot. Each ``await`` yields to the event loop, and
        # a concurrent connect()/disconnect() for this user would otherwise
        # mutate the live set mid-iteration and raise RuntimeError.
        for connection in list(connections):
            try:
                await connection.send_json(message)
            except Exception as e:
                # Best-effort fan-out: one broken socket must not stop
                # delivery to the user's other devices.
                logger.warning(
                    "failed_to_send_to_connection",
                    user_id=user_id,
                    error=str(e),
                )
                disconnected.append(connection)

        for connection in disconnected:
            await self.disconnect(connection)

        logger.info(
            "message_broadcast_to_user",
            user_id=user_id,
            recipient_count=len(self.active_connections.get(user_id, ())),
            message_type=message.get("type"),
        )

    async def broadcast_to_all(self, message: dict):
        """
        Broadcast a message to all connected users.

        Args:
            message: The message to broadcast
        """
        # Snapshot the keys: send_personal_message may delete users whose
        # last connection fails while we are iterating.
        all_users = list(self.active_connections.keys())

        for user_id in all_users:
            await self.send_personal_message(message, user_id)

        logger.info(
            "message_broadcast_to_all",
            total_users=len(all_users),
            message_type=message.get("type"),
        )

    async def broadcast_task_update(
        self,
        user_id: str,
        update_type: str,
        task_data: dict
    ):
        """
        Broadcast a task update to all of a user's connected devices.

        Args:
            user_id: ID of the user who owns the task
            update_type: Type of update (created, updated, completed, deleted)
            task_data: The task data
        """
        message = {
            "type": "task_update",
            "update_type": update_type,
            "data": task_data,
            # Wall-clock seconds since the Unix epoch. The event loop's
            # monotonic clock has an arbitrary per-process origin, which makes
            # it meaningless to clients on other devices.
            "timestamp": time.time(),
        }

        await self.send_personal_message(message, user_id)

    async def broadcast_reminder_created(
        self,
        user_id: str,
        reminder_data: dict
    ):
        """
        Broadcast a new reminder to all of a user's connected devices.

        Args:
            user_id: ID of the user who owns the reminder
            reminder_data: The reminder data
        """
        message = {
            "type": "reminder_created",
            "data": reminder_data,
            # Wall-clock seconds since the Unix epoch (see broadcast_task_update).
            "timestamp": time.time(),
        }

        await self.send_personal_message(message, user_id)

    def get_connection_count(self, user_id: Optional[str] = None) -> int:
        """
        Get the number of active connections.

        Args:
            user_id: If provided (including an empty string), get the count
                for that user only; if None, count all users' connections.

        Returns:
            Number of active connections
        """
        if user_id is not None:
            return len(self.active_connections.get(user_id, ()))
        return sum(len(conns) for conns in self.active_connections.values())

    def get_connected_users(self) -> list[str]:
        """
        Get list of all connected user IDs.

        Returns:
            List of user IDs with active connections
        """
        return list(self.active_connections.keys())
|
|
|
|
|
|
|
|
|
|
|
# Module-level singleton; stays None until get_websocket_manager() first runs.
manager: Optional[ConnectionManager] = None


def get_websocket_manager() -> ConnectionManager:
    """
    Return the shared ConnectionManager, creating it lazily on first use.

    Every call after the first returns the same instance stored in the
    module-level ``manager`` name.
    """
    global manager
    if manager is not None:
        return manager
    manager = ConnectionManager()
    return manager
|
|
|