# ALM-2 / backend/websocket/manager.py
# (repository page residue: uploaded by ACA050, "Upload 520 files", commit 2ed8996, verified)
"""
WebSocket Manager for AegisLM Real-time Updates.
Manages WebSocket connections for real-time evaluation status updates,
progress tracking, and live notifications.
"""
import asyncio
import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Set

from fastapi import WebSocket, WebSocketDisconnect
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class ConnectionManager:
    """
    Manages WebSocket connections for real-time updates.

    Bookkeeping:

    * ``active_connections``: user_id -> set of that user's open sockets
      (a single user may have several sockets open at once).
    * ``connection_metadata``: socket -> ``{'user_id', 'connected_at',
      'subscriptions'}``, where ``subscriptions`` is the set of evaluation
      ids that this particular socket subscribed to.
    * ``evaluation_subscribers``: evaluation_id -> user_ids that have at
      least one open socket subscribed to that evaluation.

    Evaluation updates are delivered per user: every open socket of a
    subscribed user receives them.
    """

    def __init__(self):
        # Store active connections by user_id
        self.active_connections: Dict[int, Set[WebSocket]] = {}
        # Store connection metadata, keyed by socket
        self.connection_metadata: Dict[WebSocket, Dict[str, Any]] = {}
        # Store evaluation subscriptions: evaluation_id -> subscribed user_ids
        self.evaluation_subscribers: Dict[str, Set[int]] = {}

    def _total_connections(self) -> int:
        """Return the number of open sockets across all users."""
        return sum(len(connections) for connections in self.active_connections.values())

    def _user_still_subscribed(self, user_id: int, evaluation_id: str, exclude: WebSocket) -> bool:
        """Return True if an open socket of ``user_id`` other than ``exclude`` subscribes to ``evaluation_id``."""
        for connection in self.active_connections.get(user_id, ()):
            if connection is exclude:
                continue
            metadata = self.connection_metadata.get(connection)
            if metadata and evaluation_id in metadata['subscriptions']:
                return True
        return False

    def _drop_user_subscription(self, user_id: int, evaluation_id: str, websocket: WebSocket) -> None:
        """Remove ``user_id`` from an evaluation's subscribers unless another of the user's sockets still wants it."""
        subscribers = self.evaluation_subscribers.get(evaluation_id)
        if subscribers is None:
            return
        # Subscribers are tracked per user, but subscriptions are made per
        # socket: only drop the user once no other socket of theirs remains
        # subscribed, otherwise those sockets would silently stop receiving updates.
        if self._user_still_subscribed(user_id, evaluation_id, exclude=websocket):
            return
        subscribers.discard(user_id)
        if not subscribers:
            del self.evaluation_subscribers[evaluation_id]

    async def connect(self, websocket: WebSocket, user_id: int):
        """Accept ``websocket`` and register it as a connection of ``user_id``."""
        await websocket.accept()
        self.active_connections.setdefault(user_id, set()).add(websocket)
        # NOTE: datetime.utcnow() is naive (and deprecated since Python 3.12);
        # kept to preserve the existing timestamp format.
        self.connection_metadata[websocket] = {
            'user_id': user_id,
            'connected_at': datetime.utcnow(),
            'subscriptions': set()
        }
        logger.info(
            "WebSocket connected for user %s. Total connections: %d",
            user_id, self._total_connections(),
        )

    def disconnect(self, websocket: WebSocket):
        """
        Unregister ``websocket`` and clean up its subscriptions.

        Unknown sockets are ignored, so calling this twice is harmless.
        """
        metadata = self.connection_metadata.pop(websocket, None)
        if not metadata:
            return

        user_id = metadata['user_id']

        # Remove from active connections
        connections = self.active_connections.get(user_id)
        if connections is not None:
            connections.discard(websocket)
            if not connections:
                del self.active_connections[user_id]

        # Remove from evaluation subscriptions (per-socket aware)
        for evaluation_id in metadata.get('subscriptions', set()):
            self._drop_user_subscription(user_id, evaluation_id, websocket)

        logger.info(
            "WebSocket disconnected for user %s. Total connections: %d",
            user_id, self._total_connections(),
        )

    async def send_personal_message(self, message: str, user_id: int):
        """
        Send ``message`` to every open socket of ``user_id``.

        Sockets whose send fails are logged and disconnected. Users with no
        open sockets are skipped silently.
        """
        connections = self.active_connections.get(user_id)
        if not connections:
            return

        failed = []
        # Iterate a snapshot: send_text() awaits, and connect()/disconnect()
        # running meanwhile would otherwise mutate the set mid-iteration.
        for connection in list(connections):
            try:
                await connection.send_text(message)
            except Exception as e:
                logger.error("Failed to send message to user %s: %s", user_id, e)
                failed.append(connection)

        # Clean up disconnected connections
        for connection in failed:
            self.disconnect(connection)

    async def send_evaluation_update(self, evaluation_id: str, update_data: Dict[str, Any]):
        """
        Send an ``evaluation_update`` envelope to all users subscribed to ``evaluation_id``.

        The JSON envelope carries ``type``, ``evaluation_id``, ``data``
        (``update_data``) and a UTC ISO-8601 ``timestamp``.
        """
        subscribers = self.evaluation_subscribers.get(evaluation_id)
        if not subscribers:
            return

        message = {
            'type': 'evaluation_update',
            'evaluation_id': evaluation_id,
            'data': update_data,
            'timestamp': datetime.utcnow().isoformat()
        }
        message_str = json.dumps(message)

        # Snapshot: failed sends disconnect sockets, which can shrink this set.
        for user_id in list(subscribers):
            await self.send_personal_message(message_str, user_id)

    async def broadcast_message(self, message: Any):
        """
        Broadcast a message to all connected users.

        ``message`` may be any JSON-serialisable value (for example, the
        dict built by send_system_notification); it is placed under the
        ``data`` key of a ``broadcast`` envelope.
        """
        message_data = {
            'type': 'broadcast',
            'data': message,
            'timestamp': datetime.utcnow().isoformat()
        }
        message_str = json.dumps(message_data)

        # Snapshot the user ids: sends may disconnect users mid-loop.
        for user_id in list(self.active_connections):
            await self.send_personal_message(message_str, user_id)

    def subscribe_to_evaluation(self, websocket: WebSocket, evaluation_id: str):
        """Subscribe ``websocket`` (and thus its user) to updates for ``evaluation_id``."""
        metadata = self.connection_metadata.get(websocket)
        if not metadata:
            return

        user_id = metadata['user_id']

        # Add to evaluation subscribers
        self.evaluation_subscribers.setdefault(evaluation_id, set()).add(user_id)

        # Add to connection metadata
        metadata['subscriptions'].add(evaluation_id)

        logger.info("User %s subscribed to evaluation %s", user_id, evaluation_id)

    def unsubscribe_from_evaluation(self, websocket: WebSocket, evaluation_id: str):
        """
        Unsubscribe ``websocket`` from updates for ``evaluation_id``.

        The user keeps receiving updates if another of their sockets is
        still subscribed to the same evaluation.
        """
        metadata = self.connection_metadata.get(websocket)
        if not metadata:
            return

        user_id = metadata['user_id']

        # Remove from this socket's metadata first, then drop the user-level
        # subscription only if no other socket of theirs still wants it.
        metadata['subscriptions'].discard(evaluation_id)
        self._drop_user_subscription(user_id, evaluation_id, websocket)

        logger.info("User %s unsubscribed from evaluation %s", user_id, evaluation_id)

    def get_connection_stats(self) -> Dict[str, Any]:
        """Return counts of sockets, users and evaluation subscriptions."""
        return {
            'total_connections': self._total_connections(),
            'unique_users': len(self.active_connections),
            'evaluation_subscriptions': len(self.evaluation_subscribers),
            'active_evaluations': list(self.evaluation_subscribers.keys())
        }
# Global connection manager instance.
# Module-level singleton: the helper coroutines below operate on it, and
# get_connection_manager() returns it.
manager = ConnectionManager()
async def handle_websocket_message(websocket: WebSocket, user_id: int, message: str):
    """
    Handle one incoming WebSocket message from ``user_id``.

    The message must be a JSON object with a ``type`` key:

    * ``subscribe_evaluation``: subscribe this socket to ``evaluation_id``
      and immediately send it an ``evaluation_status`` message.
    * ``unsubscribe_evaluation``: unsubscribe this socket from ``evaluation_id``.
    * ``ping``: reply with a ``pong`` carrying a UTC timestamp.

    Unknown types, non-object payloads and missing or non-string evaluation
    ids are logged and ignored. Errors are logged rather than raised, so a
    malformed client message does not propagate to the caller.
    """
    try:
        data = json.loads(message)
        if not isinstance(data, dict):
            # Client input is untrusted: a JSON array/scalar has no .get().
            logger.warning("Ignoring non-object WebSocket message from user %s", user_id)
            return

        message_type = data.get('type')

        if message_type in ('subscribe_evaluation', 'unsubscribe_evaluation'):
            evaluation_id = data.get('evaluation_id')
            # Evaluation ids are dict/set keys typed as str throughout the
            # manager; lists/dicts would raise TypeError and other scalars
            # would register keys that server-side string ids never match.
            if not isinstance(evaluation_id, str) or not evaluation_id:
                logger.warning(
                    "Ignoring %s from user %s: missing or invalid evaluation_id",
                    message_type, user_id,
                )
                return
            if message_type == 'subscribe_evaluation':
                manager.subscribe_to_evaluation(websocket, evaluation_id)
                # Send current status if available
                await send_evaluation_status(websocket, evaluation_id)
            else:
                manager.unsubscribe_from_evaluation(websocket, evaluation_id)

        elif message_type == 'ping':
            # Respond with pong for keep-alive
            await websocket.send_text(json.dumps({
                'type': 'pong',
                'timestamp': datetime.utcnow().isoformat()
            }))

        else:
            logger.warning("Unknown WebSocket message type: %s", message_type)

    except json.JSONDecodeError:
        # Truncate: the payload is client-controlled and may be arbitrarily large.
        logger.error("Invalid JSON message from user %s: %s", user_id, message[:200])
    except Exception:
        # logger.exception keeps the traceback for debugging.
        logger.exception("Error handling WebSocket message from user %s", user_id)
async def send_evaluation_status(websocket: WebSocket, evaluation_id: str):
    """
    Push a status snapshot for one evaluation to a single client socket.

    Placeholder implementation: nothing is looked up yet, so the payload
    always reports status ``'loading'``. Any failure while serialising or
    sending is logged and swallowed.
    """
    try:
        # TODO: fetch the current status (e.g. from the database) instead of
        # sending this placeholder.
        await websocket.send_text(json.dumps(dict(
            type='evaluation_status',
            evaluation_id=evaluation_id,
            status='loading',
            message='Fetching current status...',
        )))
    except Exception as e:
        logger.error(f"Failed to send evaluation status: {str(e)}")
# Utility functions for sending specific update types
async def send_evaluation_started(evaluation_id: str, user_id: int):
    """
    Tell all subscribers of ``evaluation_id`` that it has started running.

    ``user_id`` is accepted for API compatibility but not used: the update
    is delivered to every subscriber of the evaluation.
    """
    started_payload = dict(status='running', message='Evaluation started', progress=0.0)
    await manager.send_evaluation_update(evaluation_id, started_payload)
async def send_evaluation_progress(evaluation_id: str, progress: float, current_phase: Optional[str] = None):
    """
    Push a progress update for ``evaluation_id`` to its subscribers.

    Args:
        evaluation_id: Evaluation whose subscribers should be notified.
        progress: Progress value forwarded as-is; the other helpers in this
            module use 0.0 at start and 1.0 at completion.
        current_phase: Optional label for the phase currently running;
            serialised as JSON ``null`` when omitted.
    """
    await manager.send_evaluation_update(evaluation_id, {
        'progress': progress,
        'current_phase': current_phase,
    })
async def send_evaluation_completed(evaluation_id: str, results: Dict[str, Any]):
    """
    Announce successful completion of ``evaluation_id`` to its subscribers.

    ``results`` is forwarded unchanged under the ``results`` key, together
    with status ``'completed'`` and progress ``1.0``.
    """
    payload: Dict[str, Any] = {}
    payload['status'] = 'completed'
    payload['message'] = 'Evaluation completed successfully'
    payload['progress'] = 1.0
    payload['results'] = results
    await manager.send_evaluation_update(evaluation_id, payload)
async def send_evaluation_failed(evaluation_id: str, error_message: str):
    """
    Announce that ``evaluation_id`` failed, including ``error_message``.

    No progress value is sent with a failure notice.
    """
    failure = dict(status='failed', message='Evaluation failed', error=error_message)
    await manager.send_evaluation_update(evaluation_id, failure)
async def send_system_notification(message: str, level: str = 'info'):
    """
    Broadcast a system-wide notification to every connected user.

    The notification dict is handed to ``manager.broadcast_message``, which
    wraps it under the ``data`` key of a ``broadcast`` envelope.
    """
    notification = dict(type='system_notification', message=message, level=level)
    await manager.broadcast_message(notification)
def get_connection_manager() -> ConnectionManager:
    """
    Accessor for the shared module-level ``ConnectionManager``.

    Every call returns the same object, created when this module was imported.
    """
    shared = manager
    return shared