| """ |
| WebSocket Handler for AegisLM Real-time Updates. |
| |
| Provides real-time evaluation status updates, progress tracking, |
| and live result streaming to connected frontend clients. |
| """ |
|
|
| import json |
| import asyncio |
| import logging |
| from typing import Dict, Set, Optional, Any |
| from datetime import datetime |
| from fastapi import WebSocket, WebSocketDisconnect |
| from sqlalchemy.ext.asyncio import AsyncSession |
| from sqlalchemy import select |
|
|
| from db_models.evaluation import Evaluation, EvaluationStatus |
| from core.security import verify_websocket_token |
| from services.evaluation_service import EvaluationService |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class WebSocketManager:
    """Manages WebSocket connections and room-based broadcasting.

    Connections are registered under caller-chosen ``connection_id`` strings.
    Each connection belongs to at most one room at a time; ``rooms`` and
    ``connection_rooms`` are kept as mirror indexes of that membership.
    """

    def __init__(self):
        # connection_id -> accepted WebSocket
        self.active_connections: Dict[str, WebSocket] = {}
        # room_id -> ids of the connections currently in that room
        self.rooms: Dict[str, Set[str]] = {}
        # connection_id -> the room it is currently in (reverse of ``rooms``)
        self.connection_rooms: Dict[str, str] = {}

    async def connect(self, websocket: WebSocket, connection_id: str):
        """Accept the WebSocket handshake and register it as ``connection_id``."""
        await websocket.accept()
        self.active_connections[connection_id] = websocket
        logger.info(f"✅ WebSocket connected: {connection_id}")

    async def disconnect(self, connection_id: str):
        """Forget a connection and drop it from its room.

        Room membership is cleaned up even when the id is no longer in
        ``active_connections``. Otherwise a stale id found during a broadcast
        could never be removed from its room. The socket itself is not closed
        here.
        """
        await self.leave_room(connection_id)
        if connection_id in self.active_connections:
            del self.active_connections[connection_id]
            logger.info(f"❌ WebSocket disconnected: {connection_id}")

    async def join_room(self, connection_id: str, room_id: str):
        """Move a connection into ``room_id``, leaving any room it was in."""
        # leave_room also deletes the previous room once it is empty, so
        # abandoned rooms do not accumulate.
        await self.leave_room(connection_id)
        self.rooms.setdefault(room_id, set()).add(connection_id)
        self.connection_rooms[connection_id] = room_id
        logger.info(f"📡 Connection {connection_id} joined room {room_id}")

    async def leave_room(self, connection_id: str):
        """Remove a connection from its current room; delete the room if it empties."""
        room_id = self.connection_rooms.pop(connection_id, None)
        if room_id is None:
            return
        members = self.rooms.get(room_id)
        if members is not None:
            members.discard(connection_id)
            if not members:
                del self.rooms[room_id]

    async def send_to_connection(self, connection_id: str, message: dict):
        """Send ``message`` as JSON text to one connection.

        Unknown ids are ignored. If the send fails, the connection is
        unregistered. A message that cannot be JSON-encoded is logged and
        dropped without disconnecting the client, since that is a server-side
        bug, not a dead socket.
        """
        websocket = self.active_connections.get(connection_id)
        if websocket is None:
            return
        try:
            payload = json.dumps(message)
        except (TypeError, ValueError) as e:
            logger.error(f"❌ Failed to serialize message for {connection_id}: {e}")
            return
        try:
            await websocket.send_text(payload)
        except Exception as e:
            logger.error(f"❌ Failed to send to {connection_id}: {e}")
            await self.disconnect(connection_id)

    async def broadcast_to_room(self, room_id: str, message: dict):
        """Send ``message`` as JSON text to every connection in ``room_id``.

        Members whose send fails, or that are no longer registered, are
        disconnected after the broadcast. An unencodable message is logged and
        dropped without touching any connection.
        """
        members = self.rooms.get(room_id)
        if not members:
            return
        try:
            # Encode once for the whole room rather than once per recipient.
            payload = json.dumps(message)
        except (TypeError, ValueError) as e:
            logger.error(f"❌ Failed to serialize message for room {room_id}: {e}")
            return

        stale = []
        # Iterate over a snapshot. Each send awaits, and other coroutines may
        # join or leave this room meanwhile, which would otherwise raise
        # "Set changed size during iteration".
        for connection_id in list(members):
            websocket = self.active_connections.get(connection_id)
            if websocket is None:
                stale.append(connection_id)
                continue
            try:
                await websocket.send_text(payload)
            except Exception as e:
                logger.error(f"❌ Failed to send to {connection_id}: {e}")
                stale.append(connection_id)

        for connection_id in stale:
            await self.disconnect(connection_id)
|
|
|
|
| |
# Module-level connection registry; the handler class and the broadcast
# helpers in this module all send through this single instance.
websocket_manager: WebSocketManager = WebSocketManager()
|
|
|
|
class EvaluationWebSocketHandler:
    """Handles WebSocket connections for evaluation updates.

    A handler wraps one database session. ``handle_connection`` does the
    following, in order:

    1. Authenticates the client.
    2. Checks that the requested evaluation belongs to that user.
    3. Registers the socket in the ``evaluation_{job_id}`` room of the shared
       ``websocket_manager``.
    4. Polls the evaluation status, pushing a message whenever it changes,
       until the evaluation reaches a terminal status or the client
       disconnects.
    """

    # Seconds to wait between status polls.
    STATUS_POLL_INTERVAL_SECONDS = 5

    # Status strings after which the monitoring loop stops.
    TERMINAL_STATUSES = ("completed", "failed", "cancelled")

    def __init__(self, db: AsyncSession):
        """Bind the handler, and an EvaluationService, to session ``db``."""
        self.db = db
        self.evaluation_service = EvaluationService(db)

    async def handle_connection(
        self,
        websocket: WebSocket,
        job_id: str,
        token: Optional[str] = None
    ):
        """Serve one client that wants live updates for ``job_id``.

        Closes the socket with one of these codes:

        - 4001 if authentication fails.
        - 4004 if the user has no evaluation with this job id.
        - 4000 on an unexpected error.

        The connection is always unregistered from ``websocket_manager`` on
        exit.

        Args:
            websocket: The client socket. It is accepted here, via
                ``websocket_manager.connect``, only after the auth and
                ownership checks pass.
            job_id: Evaluation job the client wants to follow.
            token: Auth token. If omitted, the ``token`` query parameter is
                used.
        """
        connection_id = f"eval_{job_id}_{datetime.utcnow().timestamp()}"

        try:
            user = await self._authenticate_websocket(websocket, token)
            if not user:
                await websocket.close(code=4001, reason="Authentication failed")
                return

            # Only the owning user may follow an evaluation.
            evaluation = await self._get_evaluation(job_id, user.id)
            if not evaluation:
                await websocket.close(code=4004, reason="Evaluation not found")
                return

            await websocket_manager.connect(websocket, connection_id)
            await websocket_manager.join_room(connection_id, f"evaluation_{job_id}")

            initial_status = await self._get_evaluation_status(job_id)
            await websocket_manager.send_to_connection(connection_id, {
                "type": "status_update",
                "data": initial_status
            })

            # Seed the loop with the status just sent, so its first poll does
            # not push an identical status_update a second time.
            await self._status_monitoring_loop(
                connection_id, job_id, initial_status=initial_status
            )

        except WebSocketDisconnect:
            logger.info(f"🔌 WebSocket disconnected for evaluation {job_id}")
        except Exception as e:
            logger.error(f"❌ WebSocket error for evaluation {job_id}: {e}")
            # close() can itself fail (e.g. the client is already gone). Don't
            # let that escape the handler or hide the original error.
            await self._close_quietly(websocket, code=4000, reason="Internal server error")
        finally:
            await websocket_manager.disconnect(connection_id)

    @staticmethod
    async def _close_quietly(websocket: WebSocket, code: int, reason: str):
        """Close ``websocket`` with ``code``/``reason``; log instead of raising on failure."""
        try:
            await websocket.close(code=code, reason=reason)
        except Exception as e:
            logger.warning(f"⚠️ Could not close WebSocket: {e}")

    async def _authenticate_websocket(self, websocket: WebSocket, token: Optional[str]) -> Optional[Any]:
        """Resolve the user behind ``token``.

        Falls back to the ``token`` query parameter when no token is passed.
        Returns whatever ``self.db.get(User, user_id)`` yields. Returns None if
        there is no token, if ``verify_websocket_token`` returns a falsy id, or
        if verification or the lookup raises.
        """
        if not token:
            token = websocket.query_params.get("token") if hasattr(websocket, 'query_params') else None

        if not token:
            return None

        try:
            from db_models.user import User
            user_id = verify_websocket_token(token)
            if user_id:
                user = await self.db.get(User, user_id)
                return user
        except Exception as e:
            logger.error(f"❌ WebSocket authentication failed: {e}")

        return None

    async def _get_evaluation(self, job_id: str, user_id: int) -> Optional[Evaluation]:
        """Return the evaluation with ``job_id`` owned by ``user_id``, or None.

        Query errors are logged and also reported as None.
        """
        try:
            result = await self.db.execute(
                select(Evaluation).where(
                    Evaluation.job_id == job_id,
                    Evaluation.user_id == user_id
                )
            )
            return result.scalar_one_or_none()
        except Exception as e:
            logger.error(f"❌ Failed to get evaluation {job_id}: {e}")
            return None

    async def _get_evaluation_status(self, job_id: str) -> dict:
        """Build the status payload sent to clients for ``job_id``.

        The result always has ``job_id``, ``status``, ``progress`` and
        ``error_message``. ``status`` is ``"not_found"`` when the service
        returns nothing, and ``"error"`` if building the payload raised.
        """
        try:
            # NOTE(review): user_id is passed as 0 here. If the service
            # filters by owner, this lookup may never find the evaluation, so
            # confirm what 0 means to get_evaluation_by_job_id.
            # NOTE(review): this is called repeatedly on the same session.
            # Confirm the service re-reads the row so that status changes
            # committed elsewhere are actually observed.
            evaluation = await self.evaluation_service.get_evaluation_by_job_id(job_id, 0)

            if not evaluation:
                return {
                    "job_id": job_id,
                    "status": "not_found",
                    "progress": 0.0,
                    "error_message": "Evaluation not found"
                }

            progress = 0.0
            if evaluation.status == EvaluationStatus.COMPLETED:
                progress = 100.0
            elif evaluation.status == EvaluationStatus.RUNNING:
                if evaluation.started_at and evaluation.execution_time_ms:
                    # Estimate from elapsed time versus execution_time_ms
                    # (presumably the expected run length; TODO confirm).
                    # Capped at 95% until the run is marked complete, and
                    # floored at 0 in case started_at is ahead of "now".
                    elapsed = (datetime.utcnow() - evaluation.started_at).total_seconds() * 1000
                    progress = max(0.0, min(95.0, (elapsed / evaluation.execution_time_ms) * 100))
                else:
                    progress = 25.0
            elif evaluation.status == EvaluationStatus.FAILED:
                progress = 0.0

            return {
                "job_id": evaluation.job_id,
                "status": evaluation.status.value,
                "progress": progress,
                "started_at": evaluation.started_at.isoformat() if evaluation.started_at else None,
                "completed_at": evaluation.completed_at.isoformat() if evaluation.completed_at else None,
                "error_message": evaluation.error_message,
                "total_attacks": evaluation.total_attacks,
                "successful_attacks": evaluation.successful_attacks,
                "current_phase": self._get_current_phase(evaluation.status)
            }
        except Exception as e:
            logger.error(f"❌ Failed to get evaluation status {job_id}: {e}")
            return {
                "job_id": job_id,
                "status": "error",
                "progress": 0.0,
                "error_message": str(e)
            }

    def _get_current_phase(self, status: EvaluationStatus) -> str:
        """Map an EvaluationStatus to a phase label; ``"unknown"`` if unmapped."""
        phase_mapping = {
            EvaluationStatus.PENDING: "queued",
            EvaluationStatus.RUNNING: "executing_attacks",
            EvaluationStatus.COMPLETED: "completed",
            EvaluationStatus.FAILED: "failed",
            EvaluationStatus.CANCELLED: "cancelled"
        }
        return phase_mapping.get(status, "unknown")

    async def _status_monitoring_loop(
        self,
        connection_id: str,
        job_id: str,
        initial_status: Optional[dict] = None
    ):
        """Poll ``job_id`` and push a ``status_update`` whenever the payload changes.

        Stops after sending ``evaluation_complete`` for a terminal status, when
        the client disconnects, when the connection is unregistered, or on any
        error.

        Args:
            connection_id: Registered connection to send to.
            job_id: Evaluation job to poll.
            initial_status: Payload the client already received. An unchanged
                status is not re-sent.
        """
        last_status = initial_status

        while connection_id in websocket_manager.active_connections:
            try:
                current_status = await self._get_evaluation_status(job_id)

                if current_status != last_status:
                    await websocket_manager.send_to_connection(connection_id, {
                        "type": "status_update",
                        "data": current_status
                    })
                    last_status = current_status

                if current_status["status"] in self.TERMINAL_STATUSES:
                    await websocket_manager.send_to_connection(connection_id, {
                        "type": "evaluation_complete",
                        "data": current_status
                    })
                    break

                if not await self._wait_for_next_poll(connection_id):
                    logger.info(f"🔌 WebSocket disconnected for evaluation {job_id}")
                    break

            except Exception as e:
                logger.error(f"❌ Status monitoring error for {job_id}: {e}")
                break

    async def _wait_for_next_poll(self, connection_id: str) -> bool:
        """Wait one poll interval while watching for the client to disconnect.

        Nothing else reads from the socket. Without this check, a departed
        client would only be noticed on the next failed send, and a job whose
        status stops changing would keep this loop polling the database
        indefinitely.

        Returns:
            True if the connection is still open once the interval elapses.
            False if the client disconnected or the id is no longer registered.
        """
        websocket = websocket_manager.active_connections.get(connection_id)
        if websocket is None:
            return False

        loop = asyncio.get_running_loop()
        deadline = loop.time() + self.STATUS_POLL_INTERVAL_SECONDS
        while True:
            remaining = deadline - loop.time()
            if remaining <= 0:
                return True
            try:
                message = await asyncio.wait_for(websocket.receive(), timeout=remaining)
            except asyncio.TimeoutError:
                return True
            if message.get("type") == "websocket.disconnect":
                return False
            # Client messages are not part of this protocol and are ignored.
            # Keep waiting out the rest of the interval, so a chatty client
            # cannot speed up the polling.
|
|
|
|
| |
def get_websocket_handler(db: AsyncSession) -> EvaluationWebSocketHandler:
    """Create an EvaluationWebSocketHandler bound to session ``db``.

    A fresh handler is built on every call; nothing is cached.
    """
    handler = EvaluationWebSocketHandler(db)
    return handler
|
|
|
|
| |
async def broadcast_evaluation_status(job_id: str, status: str, progress: Optional[float] = None):
    """Broadcast a ``status_update`` message to every client watching ``job_id``.

    Args:
        job_id: Evaluation job. The message goes to the room
            ``evaluation_{job_id}``.
        status: Status string, placed in the payload unchanged.
        progress: Optional progress value. The ``progress`` key is included
            only when this is not None.
    """
    update_data = {
        "job_id": job_id,
        "status": status,
        "timestamp": datetime.utcnow().isoformat(),
    }

    if progress is not None:
        update_data["progress"] = progress

    room_id = f"evaluation_{job_id}"
    await websocket_manager.broadcast_to_room(room_id, {
        "type": "status_update",
        "data": update_data,
    })
|
|
|
|
async def broadcast_evaluation_completed(job_id: str, results: dict):
    """Push an ``evaluation_completed`` message with final results to ``job_id``'s room.

    NOTE(review): the per-connection monitoring loop in this module signals
    the end of an evaluation with type "evaluation_complete", while this uses
    "evaluation_completed". Clients must handle both; confirm this is
    intended.
    """
    target_room = f"evaluation_{job_id}"
    payload = dict(
        job_id=job_id,
        status="completed",
        results=results,
        timestamp=datetime.utcnow().isoformat(),
    )
    await websocket_manager.broadcast_to_room(
        target_room,
        {"type": "evaluation_completed", "data": payload},
    )
|
|