"""
WebSocket Manager for AegisLM Real-time Updates.

Manages WebSocket connections for real-time evaluation status updates,
progress tracking, and live notifications.
"""

import json
import logging
from typing import Dict, List, Set, Any, Optional
from fastapi import WebSocket, WebSocketDisconnect
from datetime import datetime, timezone
import asyncio

logger = logging.getLogger(__name__)


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Produces the same value (and the same ``isoformat()`` text, with no
    offset suffix) as the deprecated ``datetime.utcnow()``, so timestamps
    stored and sent to clients keep their existing format.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class ConnectionManager:
    """
    Manages WebSocket connections for real-time updates.

    Connections are grouped by ``user_id``. Evaluation subscriptions are
    recorded per connection (``connection_metadata[ws]['subscriptions']``)
    and, for fan-out, per evaluation as a set of user ids
    (``evaluation_subscribers``).
    """

    def __init__(self):
        # user_id -> set of that user's open WebSocket connections
        self.active_connections: Dict[int, Set[WebSocket]] = {}
        # WebSocket -> {'user_id', 'connected_at', 'subscriptions'}
        self.connection_metadata: Dict[WebSocket, Dict[str, Any]] = {}
        # evaluation_id -> user_ids with at least one subscribed connection
        self.evaluation_subscribers: Dict[str, Set[int]] = {}

    def _total_connections(self) -> int:
        """Return the number of open connections across all users."""
        return sum(len(conns) for conns in self.active_connections.values())

    def _user_has_subscription(
        self,
        user_id: int,
        evaluation_id: str,
        exclude: Optional[WebSocket] = None,
    ) -> bool:
        """Return True if a connection of ``user_id`` other than ``exclude``
        is still subscribed to ``evaluation_id``."""
        for connection in self.active_connections.get(user_id, ()):
            if connection is exclude:
                continue
            metadata = self.connection_metadata.get(connection)
            if metadata and evaluation_id in metadata['subscriptions']:
                return True
        return False

    def _release_subscription(
        self, user_id: int, evaluation_id: str, websocket: WebSocket
    ) -> None:
        """Remove ``user_id`` from an evaluation's subscribers unless another
        of the user's connections (besides ``websocket``) still subscribes."""
        subscribers = self.evaluation_subscribers.get(evaluation_id)
        if subscribers is None:
            return
        if self._user_has_subscription(user_id, evaluation_id, exclude=websocket):
            # Subscriptions are fanned out per user, so keep the user listed
            # while any of their other connections still wants updates.
            return
        subscribers.discard(user_id)
        if not subscribers:
            del self.evaluation_subscribers[evaluation_id]

    async def connect(self, websocket: WebSocket, user_id: int):
        """Accept ``websocket`` and register it under ``user_id``."""
        await websocket.accept()
        self.active_connections.setdefault(user_id, set()).add(websocket)
        self.connection_metadata[websocket] = {
            'user_id': user_id,
            'connected_at': _utcnow(),
            'subscriptions': set(),
        }
        logger.info(
            "WebSocket connected for user %s. Total connections: %d",
            user_id,
            self._total_connections(),
        )

    def disconnect(self, websocket: WebSocket):
        """Unregister ``websocket`` and release its evaluation subscriptions.

        Unknown (or already disconnected) sockets are ignored, so calling this
        twice is harmless. The socket itself is not closed here.
        """
        metadata = self.connection_metadata.pop(websocket, None)
        if metadata is None:
            return

        user_id = metadata['user_id']

        connections = self.active_connections.get(user_id)
        if connections is not None:
            connections.discard(websocket)
            if not connections:
                del self.active_connections[user_id]

        for evaluation_id in metadata.get('subscriptions', set()):
            self._release_subscription(user_id, evaluation_id, websocket)

        logger.info(
            "WebSocket disconnected for user %s. Total connections: %d",
            user_id,
            self._total_connections(),
        )

    async def send_personal_message(self, message: str, user_id: int):
        """Send ``message`` as text to every open connection of ``user_id``.

        Connections whose send raises are logged and unregistered via
        :meth:`disconnect`.
        """
        connections = self.active_connections.get(user_id)
        if not connections:
            return

        disconnected = []
        # Iterate a snapshot: the live set can change while we await a send.
        for connection in list(connections):
            try:
                await connection.send_text(message)
            except Exception as e:
                logger.error("Failed to send message to user %s: %s", user_id, e)
                disconnected.append(connection)

        for connection in disconnected:
            self.disconnect(connection)

    async def send_evaluation_update(self, evaluation_id: str, update_data: Dict[str, Any]):
        """Send an ``evaluation_update`` message to users subscribed to ``evaluation_id``.

        NOTE: delivery is per user, so every open connection of a subscribed
        user receives the update, not only the connection that subscribed.
        """
        subscribers = self.evaluation_subscribers.get(evaluation_id)
        if not subscribers:
            return

        message_str = json.dumps({
            'type': 'evaluation_update',
            'evaluation_id': evaluation_id,
            'data': update_data,
            'timestamp': _utcnow().isoformat(),
        })

        # Snapshot: failed sends may remove users from this set mid-loop.
        for user_id in list(subscribers):
            await self.send_personal_message(message_str, user_id)

    async def broadcast_message(self, message: Any):
        """Broadcast ``message`` to all connected users.

        ``message`` may be any JSON-serialisable value; it is wrapped as the
        ``data`` field of a ``broadcast`` envelope.
        """
        message_str = json.dumps({
            'type': 'broadcast',
            'data': message,
            'timestamp': _utcnow().isoformat(),
        })

        for user_id in list(self.active_connections):
            await self.send_personal_message(message_str, user_id)

    def subscribe_to_evaluation(self, websocket: WebSocket, evaluation_id: str):
        """Subscribe a registered connection to updates for ``evaluation_id``.

        Unregistered sockets are ignored.
        """
        metadata = self.connection_metadata.get(websocket)
        if metadata is None:
            return

        user_id = metadata['user_id']
        self.evaluation_subscribers.setdefault(evaluation_id, set()).add(user_id)
        metadata['subscriptions'].add(evaluation_id)

        logger.info("User %s subscribed to evaluation %s", user_id, evaluation_id)

    def unsubscribe_from_evaluation(self, websocket: WebSocket, evaluation_id: str):
        """Unsubscribe a connection from updates for ``evaluation_id``.

        The user keeps receiving updates if another of their connections is
        still subscribed. Unregistered sockets are ignored.
        """
        metadata = self.connection_metadata.get(websocket)
        if metadata is None:
            return

        user_id = metadata['user_id']
        metadata['subscriptions'].discard(evaluation_id)
        self._release_subscription(user_id, evaluation_id, websocket)

        logger.info("User %s unsubscribed from evaluation %s", user_id, evaluation_id)

    def get_connection_stats(self) -> Dict[str, Any]:
        """Return connection and subscription counts for monitoring."""
        return {
            'total_connections': self._total_connections(),
            'unique_users': len(self.active_connections),
            'evaluation_subscriptions': len(self.evaluation_subscribers),
            'active_evaluations': list(self.evaluation_subscribers.keys()),
        }


# Global connection manager instance
manager = ConnectionManager()


async def handle_websocket_message(websocket: WebSocket, user_id: int, message: str):
    """Handle one incoming WebSocket text message.

    Expects a JSON object with a ``type`` field. The handled types are
    ``subscribe_evaluation`` and ``unsubscribe_evaluation`` (both need an
    ``evaluation_id``) and ``ping``, which is answered with ``pong``.
    Malformed or unknown messages are logged and ignored.
    """
    try:
        data = json.loads(message)
    except json.JSONDecodeError:
        logger.error("Invalid JSON message from user %s: %s", user_id, message)
        return

    if not isinstance(data, dict):
        # Valid JSON but not an object (e.g. a list or number): nothing to dispatch.
        logger.warning("Ignoring non-object WebSocket message from user %s", user_id)
        return

    message_type = data.get('type')
    evaluation_id = data.get('evaluation_id')

    try:
        if message_type == 'subscribe_evaluation':
            if evaluation_id:
                manager.subscribe_to_evaluation(websocket, evaluation_id)
                # Send current status if available
                await send_evaluation_status(websocket, evaluation_id)

        elif message_type == 'unsubscribe_evaluation':
            if evaluation_id:
                manager.unsubscribe_from_evaluation(websocket, evaluation_id)

        elif message_type == 'ping':
            # Respond with pong for keep-alive
            await websocket.send_text(json.dumps({
                'type': 'pong',
                'timestamp': _utcnow().isoformat(),
            }))

        else:
            logger.warning("Unknown WebSocket message type: %s", message_type)

    except Exception as e:
        logger.exception("Error handling WebSocket message from user %s: %s", user_id, e)


async def send_evaluation_status(websocket: WebSocket, evaluation_id: str):
    """Send the current status of ``evaluation_id`` to a single client.

    Currently sends a placeholder ``loading`` status; send failures are
    logged, not raised.
    """
    # TODO: fetch the real status from the database instead of a placeholder.
    status_message = {
        'type': 'evaluation_status',
        'evaluation_id': evaluation_id,
        'status': 'loading',
        'message': 'Fetching current status...',
    }
    try:
        await websocket.send_text(json.dumps(status_message))
    except Exception as e:
        logger.error("Failed to send evaluation status: %s", e)


# Utility functions for sending specific update types

async def send_evaluation_started(evaluation_id: str, user_id: int):
    """Notify subscribers that an evaluation started.

    ``user_id`` is currently unused; it is kept for API compatibility.
    """
    await manager.send_evaluation_update(evaluation_id, {
        'status': 'running',
        'message': 'Evaluation started',
        'progress': 0.0,
    })


async def send_evaluation_progress(evaluation_id: str, progress: float, current_phase: Optional[str] = None):
    """Send a progress update (``progress`` is forwarded unchanged)."""
    await manager.send_evaluation_update(evaluation_id, {
        'progress': progress,
        'current_phase': current_phase,
    })


async def send_evaluation_completed(evaluation_id: str, results: Dict[str, Any]):
    """Notify subscribers that an evaluation completed, including ``results``."""
    await manager.send_evaluation_update(evaluation_id, {
        'status': 'completed',
        'message': 'Evaluation completed successfully',
        'progress': 1.0,
        'results': results,
    })


async def send_evaluation_failed(evaluation_id: str, error_message: str):
    """Notify subscribers that an evaluation failed with ``error_message``."""
    await manager.send_evaluation_update(evaluation_id, {
        'status': 'failed',
        'message': 'Evaluation failed',
        'error': error_message,
    })


async def send_system_notification(message: str, level: str = 'info'):
    """Broadcast a system-wide notification to every connected user."""
    await manager.broadcast_message({
        'type': 'system_notification',
        'message': message,
        'level': level,
    })


def get_connection_manager() -> ConnectionManager:
    """Get the global connection manager instance."""
    return manager