| """ |
| WebSocket Manager for AegisLM Real-time Updates. |
| |
| Manages WebSocket connections for real-time evaluation status updates, |
| progress tracking, and live notifications. |
| """ |
|
|
import asyncio
import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Set

from fastapi import WebSocket, WebSocketDisconnect
|
|
# Logger for this module, named after the module itself.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
class ConnectionManager:
    """
    Manages WebSocket connections for real-time updates.

    State is kept in three structures:

    * ``active_connections`` maps a user id to that user's open sockets
      (one user may hold several sockets at once).
    * ``connection_metadata`` maps each socket to its ``user_id``, its
      ``connected_at`` time, and the evaluation ids *that socket* subscribed to.
    * ``evaluation_subscribers`` maps an evaluation id to the user ids that
      receive its updates. An update is delivered to every open socket of a
      subscribed user.

    Subscriptions are recorded per socket but delivered per user. A user
    therefore stays in ``evaluation_subscribers`` for as long as at least one
    of their sockets still holds the subscription.
    """

    def __init__(self):
        # user_id -> open sockets belonging to that user
        self.active_connections: Dict[int, Set[WebSocket]] = {}
        # socket -> {'user_id', 'connected_at', 'subscriptions'}
        self.connection_metadata: Dict[WebSocket, Dict[str, Any]] = {}
        # evaluation_id -> user_ids subscribed to it
        self.evaluation_subscribers: Dict[str, Set[int]] = {}

    def _total_connections(self) -> int:
        """Return the number of open sockets across all users."""
        return sum(len(sockets) for sockets in self.active_connections.values())

    def _release_subscription(self, websocket: WebSocket, user_id: int, evaluation_id: str) -> None:
        """Drop ``user_id`` from an evaluation unless another of the user's sockets still subscribes to it."""
        for other in self.active_connections.get(user_id, ()):
            if other is websocket:
                continue
            other_meta = self.connection_metadata.get(other)
            if other_meta and evaluation_id in other_meta.get('subscriptions', ()):
                # A sibling socket (e.g. another tab) still wants these updates.
                return

        subscribers = self.evaluation_subscribers.get(evaluation_id)
        if subscribers is None:
            return
        subscribers.discard(user_id)
        if not subscribers:
            del self.evaluation_subscribers[evaluation_id]

    async def connect(self, websocket: WebSocket, user_id: int):
        """Accept ``websocket`` and register it under ``user_id`` with no subscriptions."""
        await websocket.accept()

        self.active_connections.setdefault(user_id, set()).add(websocket)
        self.connection_metadata[websocket] = {
            'user_id': user_id,
            'connected_at': datetime.utcnow(),
            'subscriptions': set(),
        }

        logger.info(f"WebSocket connected for user {user_id}. Total connections: {self._total_connections()}")

    def disconnect(self, websocket: WebSocket):
        """Unregister ``websocket`` and release its subscriptions.

        Unknown (or already disconnected) sockets are ignored.
        """
        metadata = self.connection_metadata.pop(websocket, None)
        if metadata is None:
            return

        user_id = metadata['user_id']

        sockets = self.active_connections.get(user_id)
        if sockets is not None:
            sockets.discard(websocket)
            if not sockets:
                del self.active_connections[user_id]

        for evaluation_id in metadata.get('subscriptions', set()):
            self._release_subscription(websocket, user_id, evaluation_id)

        logger.info(f"WebSocket disconnected for user {user_id}. Total connections: {self._total_connections()}")

    async def send_personal_message(self, message: str, user_id: int):
        """Send ``message`` to every open socket of ``user_id``.

        A socket whose send raises is logged and then disconnected.
        """
        sockets = self.active_connections.get(user_id)
        if not sockets:
            return

        failed = []
        # Iterate over a snapshot. Each ``await`` yields to the event loop, and
        # a concurrent connect/disconnect would otherwise mutate the set
        # mid-iteration ("Set changed size during iteration").
        for connection in list(sockets):
            if connection not in self.connection_metadata:
                # Disconnected by another task while we were awaiting.
                continue
            try:
                await connection.send_text(message)
            except Exception as e:
                logger.error(f"Failed to send message to user {user_id}: {str(e)}")
                failed.append(connection)

        for connection in failed:
            self.disconnect(connection)

    async def send_evaluation_update(self, evaluation_id: str, update_data: Dict[str, Any]):
        """Wrap ``update_data`` in an ``evaluation_update`` envelope and send it to all subscribed users."""
        subscribers = self.evaluation_subscribers.get(evaluation_id)
        if not subscribers:
            return

        message_str = json.dumps({
            'type': 'evaluation_update',
            'evaluation_id': evaluation_id,
            'data': update_data,
            'timestamp': datetime.utcnow().isoformat(),
        })

        # Copy: sending may disconnect sockets and shrink this set.
        for user_id in subscribers.copy():
            await self.send_personal_message(message_str, user_id)

    async def broadcast_message(self, message: Any):
        """Send ``message`` to every connected user inside a ``broadcast`` envelope.

        ``message`` can be any JSON-serialisable value (a string or a dict, for
        example). It is placed under the envelope's ``'data'`` key.
        """
        message_str = json.dumps({
            'type': 'broadcast',
            'data': message,
            'timestamp': datetime.utcnow().isoformat(),
        })

        # Snapshot of the keys: failed sends remove users from the dict.
        for user_id in list(self.active_connections):
            await self.send_personal_message(message_str, user_id)

    def subscribe_to_evaluation(self, websocket: WebSocket, evaluation_id: str):
        """Subscribe ``websocket`` (and therefore its user) to ``evaluation_id``; unknown sockets are ignored."""
        metadata = self.connection_metadata.get(websocket)
        if metadata is None:
            return

        user_id = metadata['user_id']
        self.evaluation_subscribers.setdefault(evaluation_id, set()).add(user_id)
        metadata['subscriptions'].add(evaluation_id)

        logger.info(f"User {user_id} subscribed to evaluation {evaluation_id}")

    def unsubscribe_from_evaluation(self, websocket: WebSocket, evaluation_id: str):
        """Remove ``websocket``'s subscription to ``evaluation_id``.

        The user keeps receiving updates while any other socket of theirs is
        still subscribed. Unknown sockets are ignored.
        """
        metadata = self.connection_metadata.get(websocket)
        if metadata is None:
            return

        user_id = metadata['user_id']
        self._release_subscription(websocket, user_id, evaluation_id)
        metadata['subscriptions'].discard(evaluation_id)

        logger.info(f"User {user_id} unsubscribed from evaluation {evaluation_id}")

    def get_connection_stats(self) -> Dict[str, Any]:
        """Return socket/user counts and the ids of evaluations that currently have subscribers."""
        return {
            'total_connections': self._total_connections(),
            'unique_users': len(self.active_connections),
            'evaluation_subscriptions': len(self.evaluation_subscribers),
            'active_evaluations': list(self.evaluation_subscribers.keys()),
        }
|
|
|
|
| |
# Module-level ConnectionManager shared by the helper functions in this module
# and returned by get_connection_manager().
manager: ConnectionManager = ConnectionManager()
|
|
|
|
async def handle_websocket_message(websocket: WebSocket, user_id: int, message: str):
    """Parse one text frame from a client and act on it.

    The frame must be a JSON object whose ``type`` selects the action:

    * ``'subscribe_evaluation'`` with ``evaluation_id``: subscribe ``websocket``
      through the module ``manager``, then send the evaluation's status back.
    * ``'unsubscribe_evaluation'`` with ``evaluation_id``: remove that
      subscription.
    * ``'ping'``: reply with ``{'type': 'pong', 'timestamp': ...}``.

    Unknown types, valid JSON that is not an object, invalid JSON and errors
    raised while handling are all logged instead of being raised to the
    caller.

    Args:
        websocket: Connection the frame arrived on; replies are sent on it.
        user_id: Owner of the connection; here it is only used in log messages.
        message: Raw text of the frame.
    """
    try:
        data = json.loads(message)
        if not isinstance(data, dict):
            # e.g. a bare list, string or number: valid JSON, but not an envelope.
            logger.warning(f"Ignoring non-object WebSocket message from user {user_id}")
            return

        message_type = data.get('type')

        if message_type == 'subscribe_evaluation':
            evaluation_id = data.get('evaluation_id')
            if evaluation_id:
                # NOTE(review): nothing here checks that user_id may view this
                # evaluation. Any client can subscribe to any id and then receive
                # its updates, including results. Confirm ownership is enforced
                # elsewhere.
                manager.subscribe_to_evaluation(websocket, evaluation_id)
                await send_evaluation_status(websocket, evaluation_id)

        elif message_type == 'unsubscribe_evaluation':
            evaluation_id = data.get('evaluation_id')
            if evaluation_id:
                manager.unsubscribe_from_evaluation(websocket, evaluation_id)

        elif message_type == 'ping':
            await websocket.send_text(json.dumps({
                'type': 'pong',
                'timestamp': datetime.utcnow().isoformat()
            }))

        else:
            logger.warning(f"Unknown WebSocket message type: {message_type}")

    except json.JSONDecodeError:
        logger.error(f"Invalid JSON message from user {user_id}: {message}")
    except Exception as e:
        # logger.exception keeps the traceback, which the bare message dropped.
        logger.exception(f"Error handling WebSocket message from user {user_id}: {str(e)}")
|
|
|
|
async def send_evaluation_status(websocket: WebSocket, evaluation_id: str):
    """Send one client a placeholder status message for ``evaluation_id``.

    No status lookup happens here: the reply always reports ``'loading'``.
    Any exception raised while building or sending the reply is logged and
    swallowed.
    """
    try:
        reply = dict(
            type='evaluation_status',
            evaluation_id=evaluation_id,
            status='loading',
            message='Fetching current status...',
        )
        encoded = json.dumps(reply)
        await websocket.send_text(encoded)
    except Exception as e:
        logger.error(f"Failed to send evaluation status: {str(e)}")
|
|
|
|
| |
async def send_evaluation_started(evaluation_id: str, user_id: int):
    """Tell subscribers of ``evaluation_id`` that it is now running.

    The update has status ``'running'`` and progress ``0.0``. ``user_id`` is
    accepted but not used: delivery goes only to the evaluation's current
    subscribers.
    """
    started = dict(status='running', message='Evaluation started', progress=0.0)
    await manager.send_evaluation_update(evaluation_id, started)
|
|
|
|
async def send_evaluation_progress(evaluation_id: str, progress: float, current_phase: Optional[str] = None):
    """Publish a progress update for ``evaluation_id`` to its subscribers.

    Args:
        evaluation_id: Evaluation whose subscribers receive the update.
        progress: Progress value. The started/completed notifications in this
            module use 0.0 and 1.0, so this is presumably a fraction in
            [0, 1]; confirm with callers.
        current_phase: Optional label for the running phase. ``None`` (the
            default) is serialised as JSON ``null``.
    """
    # The annotation was ``str = None`` (implicit Optional, which PEP 484
    # disallows); ``Optional[str]`` states the accepted values honestly.
    await manager.send_evaluation_update(evaluation_id, {
        'progress': progress,
        'current_phase': current_phase
    })
|
|
|
|
async def send_evaluation_completed(evaluation_id: str, results: Dict[str, Any]):
    """Tell subscribers of ``evaluation_id`` that it finished successfully.

    The update reports status ``'completed'`` and progress ``1.0``, and
    carries ``results`` unchanged. ``results`` must be JSON-serialisable.
    """
    summary = dict(
        status='completed',
        message='Evaluation completed successfully',
        progress=1.0,
        results=results,
    )
    await manager.send_evaluation_update(evaluation_id, summary)
|
|
|
|
async def send_evaluation_failed(evaluation_id: str, error_message: str):
    """Tell subscribers of ``evaluation_id`` that it failed.

    The update has status ``'failed'`` and puts the caller's
    ``error_message`` under ``'error'``. It includes no progress value.
    """
    failure = {'status': 'failed', 'message': 'Evaluation failed'}
    failure['error'] = error_message
    await manager.send_evaluation_update(evaluation_id, failure)
|
|
|
|
async def send_system_notification(message: str, level: str = 'info'):
    """Broadcast a system-wide notification to every connected user.

    The notification is a dict (``type``, ``message``, ``level``) passed to
    ``manager.broadcast_message``, which places it under the ``'data'`` key
    of its broadcast envelope.
    """
    notification = dict(type='system_notification', message=message, level=level)
    await manager.broadcast_message(notification)
|
|
|
|
def get_connection_manager() -> ConnectionManager:
    """Return the shared module-level ``manager`` instance.

    Useful as a dependency provider, so callers do not import the global
    directly.
    """
    shared = manager
    return shared
|
|