"""
Connection management for WebSocket communications.

This module manages WebSocket connections, tracks connection status, and
provides utilities for connection lifecycle management.

Connections are held in an in-memory cache guarded by a lock and mirrored to
Redis. Redis failures are logged and tolerated: the in-memory state is still
updated.
"""

import json
import logging
from collections import Counter
from datetime import datetime, timedelta, timezone
from threading import Lock
from typing import Any, Dict, List, Optional, Set

import redis

logger = logging.getLogger(__name__)


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Equivalent to the deprecated ``datetime.utcnow()``: the result carries no
    tzinfo, so ``isoformat()`` output has the same format as before.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _parse_timestamp(value: Any) -> datetime:
    """Parse an ISO-8601 timestamp into a naive UTC ``datetime``.

    Aware timestamps are converted to UTC and stripped of tzinfo so they can be
    compared with ``_utcnow()``. Naive timestamps are assumed to already be UTC
    (this module itself writes naive UTC timestamps).

    Raises:
        ValueError: If ``value`` is not a valid ISO-8601 string.
        TypeError: If ``value`` is not a string.
    """
    parsed = datetime.fromisoformat(value)
    if parsed.tzinfo is not None:
        parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
    return parsed


class ConnectionManagerError(Exception):
    """Base exception for connection manager errors."""
    pass


class ConnectionManager:
    """Manages WebSocket connections and their lifecycle.

    Each connection is kept in an in-memory dict and mirrored to Redis under:

    * ``ws_connection:<client_id>`` - JSON-encoded connection info (with TTL)
    * ``session_connections:<session_id>`` - set of client ids in a session
    * ``user_connections:<user_id>`` - set of client ids of a user

    ``connection_info`` dicts must contain ``session_id`` and ``user_id``.
    ``connected_at`` / ``last_activity`` (ISO-8601 strings) drive expiry, and
    ``language`` is used for statistics.
    """

    def __init__(self, redis_client: redis.Redis, connection_timeout: int = 300):
        """
        Initialize the connection manager.

        Args:
            redis_client: Redis client for connection persistence
            connection_timeout: Connection timeout in seconds (default: 5 minutes)
        """
        self.redis_client = redis_client
        self.connection_timeout = connection_timeout
        self.connections_prefix = "ws_connection:"
        self.session_connections_prefix = "session_connections:"
        self.user_connections_prefix = "user_connections:"

        # In-memory cache for active connections (faster access)
        self._active_connections: Dict[str, Dict[str, Any]] = {}
        # Non-reentrant: methods holding it must not call other locking methods.
        self._connections_lock = Lock()

    # ------------------------------------------------------------------
    # Redis key helpers
    # ------------------------------------------------------------------

    def _connection_key(self, client_id: str) -> str:
        """Return the Redis key holding the JSON record for ``client_id``."""
        return f"{self.connections_prefix}{client_id}"

    def _index_keys(self, connection_info: Dict[str, Any]) -> List[str]:
        """Return the session/user set keys this connection should belong to."""
        keys = []
        session_id = connection_info.get('session_id')
        if session_id is not None:
            keys.append(f"{self.session_connections_prefix}{session_id}")
        user_id = connection_info.get('user_id')
        if user_id is not None:
            keys.append(f"{self.user_connections_prefix}{user_id}")
        return keys

    def _write_to_redis(self, client_id: str, payload: str,
                        connection_info: Dict[str, Any]) -> None:
        """Persist the JSON record and (re)index it, refreshing every TTL.

        Re-adding to the index sets on every write keeps them alive for as long
        as the connection record itself is being refreshed.

        Raises:
            redis.RedisError: Propagated; callers decide how to degrade.
        """
        self.redis_client.setex(self._connection_key(client_id), self.connection_timeout, payload)
        for set_key in self._index_keys(connection_info):
            self.redis_client.sadd(set_key, client_id)
            self.redis_client.expire(set_key, self.connection_timeout)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def add_connection(self, client_id: str, connection_info: Dict[str, Any]) -> None:
        """
        Add a new WebSocket connection.

        The connection info is validated and serialized before any state is
        touched, so a rejected connection leaves memory and Redis unchanged.
        Redis errors after that point are logged and the connection is kept in
        memory.

        Args:
            client_id: Unique client identifier (socket ID)
            connection_info: Connection information dictionary; must contain
                ``session_id`` and ``user_id`` and be JSON-serializable

        Raises:
            ConnectionManagerError: If the connection info is not a dict, lacks
                ``session_id``/``user_id``, or cannot be JSON-serialized
        """
        if not isinstance(connection_info, dict):
            logger.error("Error adding connection %s: connection_info is not a dict", client_id)
            raise ConnectionManagerError("Failed to add connection: connection_info must be a dict")
        missing = [field for field in ('session_id', 'user_id') if field not in connection_info]
        if missing:
            logger.error("Error adding connection %s: missing %s", client_id, ", ".join(missing))
            raise ConnectionManagerError(f"Failed to add connection: missing {', '.join(missing)}")
        try:
            payload = json.dumps(connection_info)
        except (TypeError, ValueError) as e:
            logger.error("Error adding connection %s: %s", client_id, e)
            raise ConnectionManagerError(f"Failed to add connection: {e}") from e

        with self._connections_lock:
            self._active_connections[client_id] = connection_info.copy()
            try:
                self._write_to_redis(client_id, payload, connection_info)
            except redis.RedisError as e:
                logger.error("Redis error adding connection %s: %s", client_id, e)
                # Keep in memory even if Redis fails
                return

        logger.info("Added connection %s for session %s", client_id, connection_info['session_id'])

    def remove_connection(self, client_id: str) -> Optional[Dict[str, Any]]:
        """
        Remove a WebSocket connection.

        The in-memory entry is removed first, so the connection info is still
        returned if the Redis cleanup fails (the failure is logged). If the
        connection is not cached in memory it is looked up in Redis so its
        Redis keys can still be cleaned up.

        Args:
            client_id: Client identifier to remove

        Returns:
            Optional[Dict[str, Any]]: Connection info if found, None otherwise
        """
        try:
            with self._connections_lock:
                connection_info = self._active_connections.pop(client_id, None)
                if not connection_info:
                    connection_info = self._get_connection_from_redis(client_id)
                if not connection_info:
                    return None

                try:
                    self.redis_client.delete(self._connection_key(client_id))
                    for set_key in self._index_keys(connection_info):
                        self.redis_client.srem(set_key, client_id)
                except redis.RedisError as e:
                    logger.error("Redis error removing connection %s: %s", client_id, e)

            logger.info("Removed connection %s", client_id)
            return connection_info

        except Exception as e:
            logger.error("Error removing connection %s: %s", client_id, e)
            return None

    def get_connection(self, client_id: str) -> Optional[Dict[str, Any]]:
        """
        Get connection information.

        Checks the in-memory cache first, then Redis; a hit in Redis is cached
        in memory.

        Args:
            client_id: Client identifier

        Returns:
            Optional[Dict[str, Any]]: A copy of the connection info if found,
            None otherwise
        """
        try:
            with self._connections_lock:
                cached = self._active_connections.get(client_id)
                if cached:
                    return cached.copy()

                connection_info = self._get_connection_from_redis(client_id)
                if not connection_info:
                    return None
                # NOTE(review): connections registered by other processes get
                # cached here and then appear in get_all_connections(),
                # get_connection_stats() and cleanup - confirm this is intended.
                self._active_connections[client_id] = connection_info.copy()
                return connection_info

        except Exception as e:
            logger.error("Error getting connection %s: %s", client_id, e)
            return None

    def update_connection(self, client_id: str, connection_info: Dict[str, Any]) -> bool:
        """
        Update connection information.

        Only connections present in the in-memory cache are updated. The Redis
        record and the session/user index sets get fresh TTLs. Index entries
        for a session/user the connection no longer belongs to are removed.
        Redis errors are logged and the in-memory update is kept.

        Args:
            client_id: Client identifier
            connection_info: Updated connection information

        Returns:
            bool: True if updated successfully, False otherwise
        """
        if not isinstance(connection_info, dict):
            logger.error("Error updating connection %s: connection_info is not a dict", client_id)
            return False
        try:
            payload = json.dumps(connection_info)
        except (TypeError, ValueError) as e:
            logger.error("Error updating connection %s: %s", client_id, e)
            return False

        with self._connections_lock:
            previous = self._active_connections.get(client_id)
            if previous is None:
                return False
            self._active_connections[client_id] = connection_info.copy()

            try:
                stale_keys = set(self._index_keys(previous)) - set(self._index_keys(connection_info))
                for set_key in stale_keys:
                    self.redis_client.srem(set_key, client_id)
                self._write_to_redis(client_id, payload, connection_info)
            except redis.RedisError as e:
                # Update in memory only
                logger.error("Redis error updating connection %s: %s", client_id, e)

        logger.debug("Updated connection %s", client_id)
        return True

    def update_connection_activity(self, client_id: str) -> bool:
        """
        Update connection activity timestamp.

        Sets ``last_activity`` to the current naive-UTC ISO-8601 time, which
        also refreshes the connection's Redis TTLs via update_connection().

        Args:
            client_id: Client identifier

        Returns:
            bool: True if updated successfully, False otherwise
        """
        try:
            connection_info = self.get_connection(client_id)
            if not connection_info:
                return False
            connection_info['last_activity'] = _utcnow().isoformat()
            return self.update_connection(client_id, connection_info)
        except Exception as e:
            logger.error("Error updating connection activity %s: %s", client_id, e)
            return False

    def _get_live_members(self, set_key: str, kind: str) -> List[str]:
        """Return the client ids in Redis set ``set_key`` that still resolve to a
        connection, removing stale members from the set as a side effect.

        ``kind`` ("session" or "user") is only used in log messages.
        """
        try:
            active_client_ids = []
            for member in self.redis_client.smembers(set_key):
                # smembers yields bytes unless the client uses decode_responses=True.
                client_id = member.decode('utf-8') if isinstance(member, bytes) else str(member)
                if self.get_connection(client_id):
                    active_client_ids.append(client_id)
                else:
                    # Clean up stale reference
                    self.redis_client.srem(set_key, client_id)
            return active_client_ids

        except redis.RedisError as e:
            logger.error("Redis error getting %s connections: %s", kind, e)
            return []
        except Exception as e:
            logger.error("Error getting %s connections: %s", kind, e)
            return []

    def get_session_connections(self, session_id: str) -> List[str]:
        """
        Get all connection IDs for a session.

        Args:
            session_id: Session identifier

        Returns:
            List[str]: List of client IDs connected to the session
        """
        return self._get_live_members(f"{self.session_connections_prefix}{session_id}", "session")

    def get_user_connections(self, user_id: str) -> List[str]:
        """
        Get all connection IDs for a user.

        Args:
            user_id: User identifier

        Returns:
            List[str]: List of client IDs connected for the user
        """
        return self._get_live_members(f"{self.user_connections_prefix}{user_id}", "user")

    def get_all_connections(self) -> Dict[str, Dict[str, Any]]:
        """
        Get all active connections held in memory.

        Returns:
            Dict[str, Dict[str, Any]]: Dictionary of client_id -> copy of
            connection_info (mutating it does not affect the cache)
        """
        try:
            with self._connections_lock:
                return {client_id: info.copy() for client_id, info in self._active_connections.items()}
        except Exception as e:
            logger.error("Error getting all connections: %s", e)
            return {}

    @staticmethod
    def _is_expired(connection_info: Any, cutoff_time: datetime) -> bool:
        """Return True if the connection's ``last_activity`` (or, if absent,
        ``connected_at``) is older than ``cutoff_time``.

        Missing or unparseable timestamps count as expired.
        """
        if not isinstance(connection_info, dict):
            return True
        timestamp = connection_info.get('last_activity') or connection_info.get('connected_at')
        if not timestamp:
            return True
        try:
            return _parse_timestamp(timestamp) < cutoff_time
        except (TypeError, ValueError):
            # Invalid timestamp, mark for cleanup
            return True

    def cleanup_expired_connections(self) -> int:
        """
        Clean up expired in-memory connections.

        Returns:
            int: Number of connections cleaned up
        """
        cleaned_count = 0
        try:
            cutoff_time = _utcnow() - timedelta(seconds=self.connection_timeout)

            with self._connections_lock:
                expired_client_ids = [
                    client_id
                    for client_id, connection_info in self._active_connections.items()
                    if self._is_expired(connection_info, cutoff_time)
                ]

            # remove_connection() acquires the non-reentrant lock itself, so it
            # must only be called after the lock above has been released.
            for client_id in expired_client_ids:
                if self.remove_connection(client_id) is not None:
                    cleaned_count += 1

            if cleaned_count > 0:
                logger.info("Cleaned up %d expired connections", cleaned_count)

        except Exception as e:
            logger.error("Error cleaning up expired connections: %s", e)

        return cleaned_count

    def get_connection_stats(self) -> Dict[str, Any]:
        """
        Get connection statistics for the in-memory connections.

        Returns:
            Dict[str, Any]: Connection statistics with keys
            ``total_connections``, ``unique_sessions``, ``unique_users``,
            ``languages`` (language -> count) and ``timestamp``
        """
        try:
            with self._connections_lock:
                snapshot = list(self._active_connections.values())

            sessions = {info.get('session_id') for info in snapshot}
            users = {info.get('user_id') for info in snapshot}
            languages = Counter(info.get('language', 'unknown') for info in snapshot)

            return {
                'total_connections': len(snapshot),
                'unique_sessions': len(sessions),
                'unique_users': len(users),
                'languages': dict(languages),
                'timestamp': _utcnow().isoformat()
            }

        except Exception as e:
            logger.error("Error getting connection stats: %s", e)
            return {
                'total_connections': 0,
                'unique_sessions': 0,
                'unique_users': 0,
                'languages': {},
                'timestamp': _utcnow().isoformat()
            }

    def _get_connection_from_redis(self, client_id: str) -> Optional[Dict[str, Any]]:
        """Get connection info from Redis.

        Returns None if the key is missing, the payload is not a JSON object,
        or Redis/JSON decoding fails (failures are logged as warnings).
        """
        try:
            connection_data = self.redis_client.get(self._connection_key(client_id))
            if not connection_data:
                return None
            connection_info = json.loads(connection_data)
            return connection_info if isinstance(connection_info, dict) else None
        except (redis.RedisError, json.JSONDecodeError) as e:
            logger.warning("Error getting connection from Redis: %s", e)
            return None


def create_connection_manager(redis_client: redis.Redis,
                              connection_timeout: int = 300) -> ConnectionManager:
    """
    Factory function to create a ConnectionManager instance.

    Args:
        redis_client: Redis client instance
        connection_timeout: Connection timeout in seconds

    Returns:
        ConnectionManager: Configured connection manager
    """
    return ConnectionManager(redis_client, connection_timeout)