Spaces:
Runtime error
Runtime error
| """ | |
| Connection management for WebSocket communications. | |
| This module manages WebSocket connections, tracks connection status, | |
| and provides utilities for connection lifecycle management. | |
| """ | |
| import logging | |
| from datetime import datetime, timedelta | |
| from typing import Dict, Any, Optional, List, Set | |
| from threading import Lock | |
| import json | |
| import redis | |
# Logger shared by every definition in this module, named after the module.
logger: logging.Logger = logging.getLogger(name=__name__)
class ConnectionManagerError(Exception):
    """Raised when the connection manager cannot complete an operation.

    Acts as the base class for connection-manager-specific errors.
    """
class ConnectionManager:
    """Manage WebSocket connections and their lifecycle.

    Each connection record is cached in an in-process dict for fast access
    and mirrored to Redis under three key families, each written with a TTL
    of ``connection_timeout`` seconds:

    * ``ws_connection:<client_id>`` -- the JSON-encoded connection info
    * ``session_connections:<session_id>`` -- set of client IDs in a session
    * ``user_connections:<user_id>`` -- set of client IDs belonging to a user

    Connection info dicts must contain ``session_id`` and ``user_id``; the
    optional ``connected_at`` / ``last_activity`` fields are ISO-8601 strings.

    The in-memory dict is guarded by a ``threading.Lock``, which is NOT
    reentrant: code holding it must never call a method that acquires it.
    """

    def __init__(self, redis_client: redis.Redis, connection_timeout: int = 300):
        """
        Initialize the connection manager.

        Args:
            redis_client: Redis client for connection persistence
            connection_timeout: TTL in seconds for Redis records, and the
                inactivity window used by cleanup (default: 5 minutes)
        """
        self.redis_client = redis_client
        self.connection_timeout = connection_timeout
        self.connections_prefix = "ws_connection:"
        self.session_connections_prefix = "session_connections:"
        self.user_connections_prefix = "user_connections:"
        # In-memory cache for active connections (faster access than Redis)
        self._active_connections: Dict[str, Dict[str, Any]] = {}
        # Non-reentrant: never call a lock-taking method while holding it.
        self._connections_lock = Lock()

    # ------------------------------------------------------------------
    # Redis key helpers
    # ------------------------------------------------------------------

    def _connection_key(self, client_id: str) -> str:
        """Return the Redis key holding the JSON record for ``client_id``."""
        return f"{self.connections_prefix}{client_id}"

    def _session_key(self, session_id: Any) -> str:
        """Return the Redis set key indexing the connections of a session."""
        return f"{self.session_connections_prefix}{session_id}"

    def _user_key(self, user_id: Any) -> str:
        """Return the Redis set key indexing the connections of a user."""
        return f"{self.user_connections_prefix}{user_id}"

    def _add_to_index(self, index_key: str, client_id: str) -> None:
        """Add ``client_id`` to an index set and (re)arm the set's TTL."""
        self.redis_client.sadd(index_key, client_id)
        self.redis_client.expire(index_key, self.connection_timeout)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def add_connection(self, client_id: str, connection_info: Dict[str, Any]) -> None:
        """
        Add a new WebSocket connection.

        The record is cached in memory and persisted to Redis. Redis failures
        are logged and tolerated (the in-memory entry is kept). Any other
        failure leaves the in-memory cache as it was before the call.

        Args:
            client_id: Unique client identifier (socket ID)
            connection_info: Connection information. It must contain
                ``session_id`` and ``user_id`` and be JSON-serializable.

        Raises:
            ConnectionManagerError: If the connection cannot be added
        """
        try:
            # Validate and serialize before touching any state, so a malformed
            # payload (missing keys, non-JSON values) is rejected cleanly.
            session_id = connection_info['session_id']
            user_id = connection_info['user_id']
            payload = json.dumps(connection_info)
            snapshot = connection_info.copy()
        except Exception as e:
            logger.error(f"Error adding connection {client_id}: {e}")
            raise ConnectionManagerError(f"Failed to add connection: {e}") from e

        with self._connections_lock:
            previous = self._active_connections.get(client_id)
            self._active_connections[client_id] = snapshot
            try:
                self.redis_client.setex(
                    self._connection_key(client_id),
                    self.connection_timeout,
                    payload,
                )
                self._add_to_index(self._session_key(session_id), client_id)
                self._add_to_index(self._user_key(user_id), client_id)
            except redis.RedisError as e:
                logger.error(f"Redis error adding connection {client_id}: {e}")
                # Keep in memory even if Redis fails
                return
            except Exception as e:
                # Undo the in-memory write so the cache does not advertise a
                # connection the caller is told could not be added.
                if previous is None:
                    self._active_connections.pop(client_id, None)
                else:
                    self._active_connections[client_id] = previous
                logger.error(f"Error adding connection {client_id}: {e}")
                raise ConnectionManagerError(f"Failed to add connection: {e}") from e

        logger.info(f"Added connection {client_id} for session {session_id}")

    def remove_connection(self, client_id: str) -> Optional[Dict[str, Any]]:
        """
        Remove a WebSocket connection from memory and Redis.

        If the connection is not cached in memory, its record is looked up in
        Redis so that the session/user index sets can still be cleaned.

        Args:
            client_id: Client identifier to remove

        Returns:
            Optional[Dict[str, Any]]: The removed connection info, or None if
            the connection was unknown or removal failed unexpectedly. If
            Redis fails while deleting, the info is still returned, because
            the in-memory entry has already been dropped.
        """
        connection_info: Optional[Dict[str, Any]] = None
        try:
            with self._connections_lock:
                connection_info = self._active_connections.pop(client_id, None)
                if not connection_info:
                    # Not cached locally; try the persisted record
                    connection_info = self._get_connection_from_redis(client_id)
                if not connection_info:
                    return None

                self.redis_client.delete(self._connection_key(client_id))
                # Tolerate records lacking an index id instead of aborting
                # halfway through the removal.
                session_id = connection_info.get('session_id')
                if session_id is not None:
                    self.redis_client.srem(self._session_key(session_id), client_id)
                user_id = connection_info.get('user_id')
                if user_id is not None:
                    self.redis_client.srem(self._user_key(user_id), client_id)

            logger.info(f"Removed connection {client_id}")
            return connection_info
        except redis.RedisError as e:
            logger.error(f"Redis error removing connection {client_id}: {e}")
            # The in-memory entry was already dropped above; report it.
            return connection_info
        except Exception as e:
            logger.error(f"Error removing connection {client_id}: {e}")
            return None

    def get_connection(self, client_id: str) -> Optional[Dict[str, Any]]:
        """
        Get connection information, falling back to Redis on a cache miss.

        A record found only in Redis is cached in memory for later calls.

        Args:
            client_id: Client identifier

        Returns:
            Optional[Dict[str, Any]]: A copy of the connection info if found,
            None otherwise
        """
        try:
            with self._connections_lock:
                cached = self._active_connections.get(client_id)
                if cached:
                    return cached.copy()

                stored = self._get_connection_from_redis(client_id)
                if not stored:
                    return None
                # Cache in memory
                self._active_connections[client_id] = stored.copy()
                return stored
        except Exception as e:
            logger.error(f"Error getting connection {client_id}: {e}")
            return None

    def update_connection(self, client_id: str, connection_info: Dict[str, Any]) -> bool:
        """
        Replace the stored info for a connection already in the memory cache.

        Only cached connections are updated. Calling get_connection() first
        pulls a Redis-only record into the cache. The Redis record's TTL is
        reset. The connection is also re-added to its session/user index
        sets with fresh TTLs, so long-lived connections stay discoverable
        through get_session_connections() / get_user_connections().

        Args:
            client_id: Client identifier
            connection_info: Updated connection information

        Returns:
            bool: True if the in-memory entry was updated (even if Redis
            failed). False if the connection is not cached or the info
            cannot be serialized.
        """
        try:
            # Serialize first so an unserializable payload never replaces the
            # cached copy.
            payload = json.dumps(connection_info)
            snapshot = connection_info.copy()
        except Exception as e:
            logger.error(f"Error updating connection {client_id}: {e}")
            return False

        with self._connections_lock:
            if client_id not in self._active_connections:
                return False
            self._active_connections[client_id] = snapshot
            try:
                self.redis_client.setex(
                    self._connection_key(client_id),
                    self.connection_timeout,
                    payload,
                )
                # sadd (not just expire): the set may already have expired.
                if 'session_id' in snapshot:
                    self._add_to_index(self._session_key(snapshot['session_id']), client_id)
                if 'user_id' in snapshot:
                    self._add_to_index(self._user_key(snapshot['user_id']), client_id)
            except redis.RedisError as e:
                logger.error(f"Redis error updating connection {client_id}: {e}")
                # Update in memory only
                return True
            except Exception as e:
                logger.error(f"Error updating connection {client_id}: {e}")
                return False

        logger.debug(f"Updated connection {client_id}")
        return True

    def update_connection_activity(self, client_id: str) -> bool:
        """
        Set the connection's ``last_activity`` to the current UTC time.

        Args:
            client_id: Client identifier

        Returns:
            bool: True if updated successfully, False otherwise
        """
        try:
            # get_connection() also pulls a Redis-only record into the cache,
            # which update_connection() requires.
            connection_info = self.get_connection(client_id)
            if connection_info:
                connection_info['last_activity'] = datetime.utcnow().isoformat()
                return self.update_connection(client_id, connection_info)
            return False
        except Exception as e:
            logger.error(f"Error updating connection activity {client_id}: {e}")
            return False

    def _live_index_members(self, index_key: str, label: str) -> List[str]:
        """
        Return the client IDs in an index set that still have a record.

        Members with no record are removed from the set as a side effect.

        Args:
            index_key: Redis key of a session or user index set
            label: "session" or "user"; used only in log messages

        Returns:
            List[str]: Live client IDs, or [] on error
        """
        try:
            live_ids: List[str] = []
            for member in self.redis_client.smembers(index_key):
                # smembers() yields bytes by default, but str when the client
                # was created with decode_responses=True; accept both.
                client_id = member.decode('utf-8') if isinstance(member, bytes) else member
                if self.get_connection(client_id):
                    live_ids.append(client_id)
                else:
                    # Clean up stale reference
                    self.redis_client.srem(index_key, client_id)
            return live_ids
        except redis.RedisError as e:
            logger.error(f"Redis error getting {label} connections: {e}")
            return []
        except Exception as e:
            logger.error(f"Error getting {label} connections: {e}")
            return []

    def get_session_connections(self, session_id: str) -> List[str]:
        """
        Get all connection IDs for a session.

        Args:
            session_id: Session identifier

        Returns:
            List[str]: List of client IDs connected to the session
        """
        return self._live_index_members(self._session_key(session_id), "session")

    def get_user_connections(self, user_id: str) -> List[str]:
        """
        Get all connection IDs for a user.

        Args:
            user_id: User identifier

        Returns:
            List[str]: List of client IDs connected for the user
        """
        return self._live_index_members(self._user_key(user_id), "user")

    def get_all_connections(self) -> Dict[str, Dict[str, Any]]:
        """
        Get all connections currently cached in memory.

        Returns:
            Dict[str, Dict[str, Any]]: client_id -> copy of connection_info.
            The inner dicts are copied so callers cannot mutate the cache.
        """
        try:
            with self._connections_lock:
                return {
                    client_id: info.copy()
                    for client_id, info in self._active_connections.items()
                }
        except Exception as e:
            logger.error(f"Error getting all connections: {e}")
            return {}

    @staticmethod
    def _is_expired(connection_info: Dict[str, Any], cutoff_time: datetime) -> bool:
        """
        Return True if the record is older than ``cutoff_time``.

        ``last_activity`` is compared when present; otherwise ``connected_at``
        is used. Missing or unparseable timestamps count as expired.
        Timezone-aware timestamps are converted to naive UTC, so they compare
        against the naive ``utcnow()``-based cutoff instead of raising.
        """
        reference = connection_info.get('last_activity') or connection_info.get('connected_at')
        if not reference:
            return True
        try:
            stamp = datetime.fromisoformat(reference)
        except (TypeError, ValueError):
            return True
        offset = stamp.utcoffset()
        if offset is not None:
            stamp = stamp.replace(tzinfo=None) - offset
        return stamp < cutoff_time

    def cleanup_expired_connections(self) -> int:
        """
        Remove cached connections inactive for over ``connection_timeout`` seconds.

        Returns:
            int: Number of connections actually removed (0 on error)
        """
        try:
            cutoff_time = datetime.utcnow() - timedelta(seconds=self.connection_timeout)
            # Collect candidates under the lock, then release it before
            # removing. remove_connection() takes the same non-reentrant Lock,
            # so calling it while holding the lock would deadlock.
            with self._connections_lock:
                expired_client_ids = [
                    client_id
                    for client_id, connection_info in self._active_connections.items()
                    if self._is_expired(connection_info, cutoff_time)
                ]

            cleaned_count = 0
            for client_id in expired_client_ids:
                # None means the connection was already gone or removal failed.
                if self.remove_connection(client_id) is not None:
                    cleaned_count += 1

            if cleaned_count > 0:
                logger.info(f"Cleaned up {cleaned_count} expired connections")
            return cleaned_count
        except Exception as e:
            logger.error(f"Error cleaning up expired connections: {e}")
            return 0

    def get_connection_stats(self) -> Dict[str, Any]:
        """
        Get statistics over the connections cached in memory.

        Returns:
            Dict[str, Any]: Total connections, unique sessions/users, a
            per-language count (missing language -> 'unknown'), and a UTC
            timestamp
        """
        try:
            with self._connections_lock:
                total_connections = len(self._active_connections)
                sessions = set()
                users = set()
                languages: Dict[str, int] = {}
                for connection_info in self._active_connections.values():
                    # .get(): one malformed record must not zero out all stats
                    session_id = connection_info.get('session_id')
                    if session_id is not None:
                        sessions.add(session_id)
                    user_id = connection_info.get('user_id')
                    if user_id is not None:
                        users.add(user_id)
                    language = connection_info.get('language', 'unknown')
                    languages[language] = languages.get(language, 0) + 1

            return {
                'total_connections': total_connections,
                'unique_sessions': len(sessions),
                'unique_users': len(users),
                'languages': languages,
                'timestamp': datetime.utcnow().isoformat()
            }
        except Exception as e:
            logger.error(f"Error getting connection stats: {e}")
            return {
                'total_connections': 0,
                'unique_sessions': 0,
                'unique_users': 0,
                'languages': {},
                'timestamp': datetime.utcnow().isoformat()
            }

    def _get_connection_from_redis(self, client_id: str) -> Optional[Dict[str, Any]]:
        """
        Fetch and decode a connection record from Redis.

        Returns None when the key is absent, Redis fails, or the payload is
        not valid UTF-8 JSON. It also returns None when the payload does not
        decode to a JSON object.
        """
        try:
            raw = self.redis_client.get(self._connection_key(client_id))
            if not raw:
                return None
            decoded = json.loads(raw)
        except (redis.RedisError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.warning(f"Error getting connection from Redis: {e}")
            return None
        return decoded if isinstance(decoded, dict) else None
def create_connection_manager(redis_client: redis.Redis,
                              connection_timeout: int = 300) -> ConnectionManager:
    """
    Build a ConnectionManager bound to the given Redis client.

    Args:
        redis_client: Redis client the manager persists connections to
        connection_timeout: Connection timeout in seconds, forwarded unchanged

    Returns:
        ConnectionManager: A newly constructed manager
    """
    manager = ConnectionManager(
        redis_client=redis_client,
        connection_timeout=connection_timeout,
    )
    return manager