File size: 6,561 Bytes
edcd2ef | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 | """
WebSocket Connection Manager - Phase 5
Manages real-time WebSocket connections for multi-client sync
"""
import asyncio
import json
import time
from typing import Dict, Optional, Set
from uuid import UUID

from fastapi import WebSocket, WebSocketDisconnect

from src.utils.logger import get_logger
# Module-level logger. Call sites below pass structured keyword fields
# (e.g. user_id=..., error=...), so get_logger presumably returns a
# structlog-style logger -- NOTE(review): confirm in src.utils.logger.
logger = get_logger(__name__)
class ConnectionManager:
    """
    Manages WebSocket connections for real-time updates.

    Features:
    - Track active connections per user
    - Broadcast updates to a specific user's connections
    - Handle connection/disconnection gracefully (disconnect is idempotent)
    - Support multiple devices per user

    State is kept in plain dicts/sets with no locking. Because every send
    ``await``s, the per-user connection sets can change while a broadcast is
    suspended, so broadcasts iterate over snapshots of those sets.
    """

    def __init__(self):
        # user_id -> set of WebSocket connections
        self.active_connections: Dict[str, Set[WebSocket]] = {}
        # WebSocket -> user_id mapping (reverse lookup used by disconnect)
        self.connection_to_user: Dict[WebSocket, str] = {}

    async def connect(self, websocket: WebSocket, user_id: str):
        """
        Accept a new WebSocket connection, track it, and send a welcome message.

        Args:
            websocket: The WebSocket connection
            user_id: ID of the user connecting

        Raises:
            Any exception from ``websocket.accept()`` or the welcome
            ``websocket.send_json()``. If the welcome send fails, the socket
            is unregistered before the exception propagates.
        """
        await websocket.accept()

        # Add to the user's connection set, creating it on first connection.
        self.active_connections.setdefault(user_id, set()).add(websocket)
        self.connection_to_user[websocket] = user_id

        logger.info(
            "websocket_connected",
            user_id=user_id,
            total_connections_for_user=len(self.active_connections[user_id]),
            total_users=len(self.active_connections),
        )

        # Send welcome message
        try:
            await websocket.send_json({
                "type": "connected",
                "message": "Real-time sync activated",
                "user_id": user_id,
            })
        except Exception:
            # Don't leave a dead socket registered: callers typically only
            # install their disconnect handling after connect() returns.
            await self.disconnect(websocket)
            raise

    async def disconnect(self, websocket: WebSocket):
        """
        Stop tracking a WebSocket connection.

        Safe to call repeatedly for the same socket, or for a socket that was
        never registered; unknown sockets are ignored. The socket itself is
        not closed here.

        Args:
            websocket: The WebSocket connection to remove
        """
        # pop() instead of del: the same socket can be disconnected twice
        # (e.g. by a failed send in send_personal_message and again by the
        # endpoint's own disconnect handling).
        user_id = self.connection_to_user.pop(websocket, None)
        if user_id is None:
            return

        connections = self.active_connections.get(user_id)
        if connections is not None:
            connections.discard(websocket)
            # Clean up empty user entries
            if not connections:
                del self.active_connections[user_id]

        logger.info(
            "websocket_disconnected",
            user_id=user_id,
            remaining_connections=len(self.active_connections.get(user_id, [])),
        )

    async def send_personal_message(self, message: dict, user_id: str):
        """
        Send a message to all connections for a specific user.

        Connections whose send raises are logged and unregistered via
        ``disconnect``; the exception is not propagated.

        Args:
            message: The message to send (passed to ``send_json``)
            user_id: ID of the user to send to
        """
        connections = self.active_connections.get(user_id)
        if not connections:
            logger.debug("No active connections for user", user_id=user_id)
            return

        delivered = 0
        failed = []
        # Iterate a snapshot: each send awaits, and while suspended another
        # coroutine may connect/disconnect this user and mutate the live set
        # ("RuntimeError: Set changed size during iteration").
        for connection in list(connections):
            try:
                await connection.send_json(message)
            except Exception as e:
                logger.warning(
                    "failed_to_send_to_connection",
                    user_id=user_id,
                    error=str(e),
                )
                failed.append(connection)
            else:
                delivered += 1

        # Clean up sockets that failed to receive
        for connection in failed:
            await self.disconnect(connection)

        logger.info(
            "message_broadcast_to_user",
            user_id=user_id,
            recipient_count=delivered,
            message_type=message.get("type"),
        )

    async def broadcast_to_all(self, message: dict):
        """
        Broadcast a message to all connected users.

        Args:
            message: The message to broadcast
        """
        # Snapshot the user ids: sends may disconnect users mid-broadcast.
        all_users = list(self.active_connections)
        for user_id in all_users:
            await self.send_personal_message(message, user_id)

        logger.info(
            "message_broadcast_to_all",
            total_users=len(all_users),
            message_type=message.get("type"),
        )

    @staticmethod
    def _timestamp() -> float:
        """Return the current wall-clock time as Unix epoch seconds."""
        # Not loop.time(): that clock is monotonic with an arbitrary
        # reference point, so clients (or other server processes) could not
        # interpret or compare it.
        return time.time()

    async def broadcast_task_update(
        self,
        user_id: str,
        update_type: str,
        task_data: dict,
    ):
        """
        Broadcast a task update to all of a user's connected devices.

        The payload's ``timestamp`` is Unix epoch seconds (float).

        Args:
            user_id: ID of the user who owns the task
            update_type: Type of update (created, updated, completed, deleted)
            task_data: The task data
        """
        message = {
            "type": "task_update",
            "update_type": update_type,
            "data": task_data,
            "timestamp": self._timestamp(),
        }
        await self.send_personal_message(message, user_id)

    async def broadcast_reminder_created(
        self,
        user_id: str,
        reminder_data: dict,
    ):
        """
        Broadcast a new reminder to all of a user's connected devices.

        The payload's ``timestamp`` is Unix epoch seconds (float).

        Args:
            user_id: ID of the user who owns the reminder
            reminder_data: The reminder data
        """
        message = {
            "type": "reminder_created",
            "data": reminder_data,
            "timestamp": self._timestamp(),
        }
        await self.send_personal_message(message, user_id)

    def get_connection_count(self, user_id: Optional[str] = None) -> int:
        """
        Get the number of active connections.

        Args:
            user_id: If provided (including ``""``), count only that user's
                connections; if None, count connections across all users.

        Returns:
            Number of active connections
        """
        if user_id is not None:
            return len(self.active_connections.get(user_id, ()))
        return sum(len(conns) for conns in self.active_connections.values())

    def get_connected_users(self) -> list[str]:
        """
        Get list of all connected user IDs.

        Returns:
            List of user IDs with at least one active connection
        """
        return list(self.active_connections.keys())
# Process-wide ConnectionManager; created lazily by get_websocket_manager().
manager: Optional[ConnectionManager] = None


def get_websocket_manager() -> ConnectionManager:
    """
    Return the process-wide WebSocket connection manager.

    The instance is built on the first call and stored in the module-level
    ``manager`` variable; every later call returns that same object.
    """
    global manager
    if manager is not None:
        return manager
    manager = ConnectionManager()
    return manager
|