import logging
import traceback
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import httpx

from app.core.config import Config

logger = logging.getLogger(__name__)

API_BASE_URL = Config.API_BASE_URL

# Timeout, in seconds, for the internal HTTP relay to the API process.
_CALLBACK_TIMEOUT_SECONDS = 5.0


def _utc_timestamp() -> str:
    """
    Return the current UTC time as a naive ISO-8601 string.

    Produces the same text as ``datetime.utcnow().isoformat()`` (which is
    deprecated since Python 3.12): UTC wall-clock time with no ``+00:00``
    offset, so the timestamp format of every message stays unchanged.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


def format_task_completion_message(
    task_id: str,
    task_type: str,
    status: str,
    result: Optional[Dict[str, Any]] = None,
    error: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Format a task completion message for WebSocket notifications.

    Args:
        task_id: The ID of the task
        task_type: Type of task (e.g., 'training', 'prediction', 'validation')
        status: Status of task ('completed', 'failed', etc.)
        result: Optional result data; omitted from the message when empty
        error: Optional error message; omitted from the message when empty

    Returns:
        Message dictionary whose ``type`` is ``"task_completion"``
    """
    message: Dict[str, Any] = {
        "type": "task_completion",
        "task_id": task_id,
        "task_type": task_type,
        "status": status,
        "timestamp": _utc_timestamp(),
    }
    if result:
        message["result"] = result
    if error:
        message["error"] = error
    return message


def format_progress_update(
    task_id: str,
    task_type: str,
    progress: float,
    message: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Format a progress update message for WebSocket notifications.

    Args:
        task_id: The ID of the task
        task_type: Type of task
        progress: Progress percentage (0-100); passed through unchanged
        message: Optional progress message; omitted when empty

    Returns:
        Message dictionary whose ``type`` is ``"progress_update"``
    """
    update: Dict[str, Any] = {
        "type": "progress_update",
        "task_id": task_id,
        "task_type": task_type,
        "progress": progress,
        "timestamp": _utc_timestamp(),
    }
    if message:
        update["message"] = message
    return update


def format_training_result(
    model_id: Optional[str],
    metrics: Optional[Dict[str, Any]] = None,
    training_time: Optional[float] = None,
) -> Dict[str, Any]:
    """
    Format training task result data.

    Args:
        model_id: ID of the trained model (may be None when only metrics or
            timing are known)
        metrics: Optional training metrics; omitted when empty
        training_time: Optional training duration in seconds

    Returns:
        Training result dictionary, always containing ``model_id``
    """
    result: Dict[str, Any] = {"model_id": model_id}
    if metrics:
        result["metrics"] = metrics
    # Compare against None so a legitimate 0.0 duration is still reported.
    if training_time is not None:
        result["training_time"] = training_time
    return result


def format_prediction_result(
    predictions: List[Dict[str, Any]],
    total_predictions: int,
    average_probability: float,
    processing_time: float,
    predictions_per_second: float,
    results_s3_key: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Format prediction task result data.

    Args:
        predictions: List of prediction dictionaries with details
        total_predictions: Total number of predictions made
        average_probability: Average probability across all predictions
        processing_time: Time taken for predictions in seconds
        predictions_per_second: Throughput metric
        results_s3_key: Optional S3 key where results are stored; omitted
            when empty

    Returns:
        Prediction result dictionary
    """
    result: Dict[str, Any] = {
        "predictions": predictions,
        "total_predictions": total_predictions,
        "average_probability": average_probability,
        "processing_time": processing_time,
        "predictions_per_second": predictions_per_second,
    }
    if results_s3_key:
        result["results_s3_key"] = results_s3_key
    return result


def format_validation_result(
    metrics: Dict[str, Any],
    detailed_predictions: List[Dict[str, Any]],
    accuracy: float,
    total_samples: int,
    average_similarity_score: float,
    results_s3_key: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Format validation task result data.

    Args:
        metrics: Validation metrics dictionary
        detailed_predictions: List of prediction details with reasoning
        accuracy: Overall accuracy score
        total_samples: Total number of samples validated
        average_similarity_score: Average similarity score
        results_s3_key: Optional S3 key where results are stored; omitted
            when empty

    Returns:
        Validation result dictionary
    """
    result: Dict[str, Any] = {
        "metrics": metrics,
        "detailed_predictions": detailed_predictions,
        "accuracy": accuracy,
        "total_samples": total_samples,
        "average_similarity_score": average_similarity_score,
    }
    if results_s3_key:
        result["results_s3_key"] = results_s3_key
    return result


def format_connection_message(task_id: str) -> Dict[str, Any]:
    """
    Format a connection established message.

    Args:
        task_id: The task ID that was connected to

    Returns:
        Message dictionary whose ``type`` is ``"connection"``
    """
    return {
        "type": "connection",
        "message": f"Connected to task {task_id}",
        "task_id": task_id,
        "timestamp": _utc_timestamp(),
    }


def format_error_message(
    task_id: str,
    error: str,
    error_type: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Format an error message for WebSocket notifications.

    Args:
        task_id: The ID of the task
        error: Error message
        error_type: Optional error type classification; omitted when empty

    Returns:
        Message dictionary whose ``type`` is ``"error"``
    """
    message: Dict[str, Any] = {
        "type": "error",
        "task_id": task_id,
        "error": error,
        "timestamp": _utc_timestamp(),
    }
    if error_type:
        message["error_type"] = error_type
    return message


def format_pong_message() -> Dict[str, Any]:
    """
    Format a pong response message.

    Returns:
        Message dictionary whose ``type`` is ``"pong"``
    """
    return {
        "type": "pong",
        "timestamp": _utc_timestamp(),
    }


# =============================================================================
# Notification Helper Functions
# =============================================================================


async def _post_notification_callback(task_id: str, message: Dict[str, Any]) -> None:
    """
    Relay ``message`` to the API process's ``/v1/websocket/notify`` endpoint.

    Best effort: transport failures and non-2xx responses are logged with a
    traceback and swallowed, never raised.

    Args:
        task_id: Task identifier, used only for logging
        message: Fully formatted notification to POST as JSON
    """
    callback_url = f"{API_BASE_URL}/v1/websocket/notify"
    logger.debug("[WEBSOCKET DEBUG] Sending POST to %s", callback_url)
    try:
        async with httpx.AsyncClient(timeout=_CALLBACK_TIMEOUT_SECONDS) as client:
            response = await client.post(
                callback_url,
                json=message,
                headers={"X-Internal-Call": "true"},
            )
            # Treat 4xx/5xx as failures instead of logging them as success.
            response.raise_for_status()
    except Exception as http_error:
        logger.exception(
            "[WEBSOCKET HTTP CALLBACK] FAILED for task %s: %s: %s",
            task_id,
            type(http_error).__name__,
            http_error,
        )
        return
    logger.info(
        "[WEBSOCKET HTTP CALLBACK] SUCCESS! Response: %s, task: %s",
        response.status_code,
        task_id,
    )


async def send_task_completion_notification(
    websocket_manager: Any,
    task_id: str,
    task_type: str,
    status: str,
    result: Optional[Dict[str, Any]] = None,
    error: Optional[str] = None,
) -> None:
    """
    Send a task completion notification to the frontend via WebSocket.

    Two delivery paths are attempted, independently of each other:

    1. A direct send through ``websocket_manager``. This reaches clients
       connected to the current process, e.g. when called from a FastAPI route.
    2. An HTTP POST of the same message to ``{API_BASE_URL}/v1/websocket/notify``.
       This is intended for callers in a different process, e.g. Celery tasks.

    A failure of the direct send does not prevent the HTTP relay. Both paths
    are best effort: errors are logged, never raised.

    NOTE(review): when called inside the API process both paths run, so a
    client may receive the message twice unless the notify endpoint or the
    manager de-duplicates. Confirm.

    Args:
        websocket_manager: WebSocketManager instance exposing an async
            ``send_message(task_id, message)``
        task_id: Unique task identifier (training_id, prediction_id, validation_id)
        task_type: Type of task ('training', 'prediction', 'validation')
        status: Task status ('completed', 'failed', 'success')
        result: Optional dictionary containing task results
        error: Optional error message if task failed
    """
    message = format_task_completion_message(
        task_id=task_id,
        task_type=task_type,
        status=status,
        result=result,
        error=error,
    )

    try:
        await websocket_manager.send_message(task_id, message)
        logger.info(
            "[WEBSOCKET] Sent %s notification for %s task: %s",
            status,
            task_type,
            task_id,
        )
    except Exception as e:
        logger.exception(
            "[WEBSOCKET] Failed to send notification for task %s: %s", task_id, e
        )

    # Runs regardless of the outcome above: this is the only route from a
    # worker process to the clients connected to the API process.
    await _post_notification_callback(task_id, message)


async def send_training_completion(
    websocket_manager: Any,
    training_id: str,
    status: str,
    model_id: Optional[str] = None,
    metrics: Optional[Dict[str, Any]] = None,
    error: Optional[str] = None,
    training_time: Optional[float] = None,
) -> None:
    """
    Send training completion notification.

    Args:
        websocket_manager: WebSocketManager instance
        training_id: Training task identifier
        status: Training status ('completed', 'failed')
        model_id: Optional model identifier
        metrics: Optional training metrics
        error: Optional error message
        training_time: Optional training duration in seconds
    """
    result = None
    if model_id or metrics or training_time is not None:
        result = format_training_result(
            model_id=model_id,
            metrics=metrics,
            training_time=training_time,
        )

    await send_task_completion_notification(
        websocket_manager=websocket_manager,
        task_id=training_id,
        task_type="training",
        status=status,
        result=result,
        error=error,
    )


async def send_prediction_completion(
    websocket_manager: Any,
    prediction_id: str,
    status: str,
    total_predictions: Optional[int] = None,
    average_probability: Optional[float] = None,
    results_s3_key: Optional[str] = None,
    predictions: Optional[list] = None,
    processing_time: Optional[float] = None,
    predictions_per_second: Optional[float] = None,
    error: Optional[str] = None,
) -> None:
    """
    Send prediction completion notification with complete results.

    A result payload is attached only if at least one result field was
    supplied. Missing numeric fields then default to 0 / 0.0, and missing
    predictions default to an empty list.

    Args:
        websocket_manager: WebSocketManager instance
        prediction_id: Prediction task identifier
        status: Prediction status ('completed', 'failed')
        total_predictions: Total number of predictions made
        average_probability: Average prediction probability
        results_s3_key: S3 key for results
        predictions: Detailed predictions with reasoning
        processing_time: Time taken to process predictions
        predictions_per_second: Throughput metric
        error: Optional error message
    """
    result = None
    if any((
        total_predictions is not None,
        average_probability is not None,
        results_s3_key,
        predictions,
        processing_time is not None,
        predictions_per_second is not None,
    )):
        result = format_prediction_result(
            predictions=predictions or [],
            total_predictions=total_predictions or 0,
            average_probability=average_probability or 0.0,
            processing_time=processing_time or 0.0,
            predictions_per_second=predictions_per_second or 0.0,
            results_s3_key=results_s3_key,
        )

    await send_task_completion_notification(
        websocket_manager=websocket_manager,
        task_id=prediction_id,
        task_type="prediction",
        status=status,
        result=result,
        error=error,
    )


async def send_validation_completion(
    websocket_manager: Any,
    validation_id: str,
    status: str,
    metrics: Optional[Dict[str, Any]] = None,
    results_s3_key: Optional[str] = None,
    accuracy: Optional[float] = None,
    total_samples: Optional[int] = None,
    average_similarity_score: Optional[float] = None,
    detailed_predictions: Optional[list] = None,
    error: Optional[str] = None,
) -> None:
    """
    Send validation completion notification with complete results.

    A result payload is attached only if at least one result field was
    supplied. Missing fields then default to empty containers or zero.

    Args:
        websocket_manager: WebSocketManager instance
        validation_id: Validation task identifier
        status: Validation status ('completed', 'failed')
        metrics: Optional validation metrics (classification report)
        results_s3_key: Optional S3 key for results
        accuracy: Overall accuracy
        total_samples: Total number of samples validated
        average_similarity_score: Average similarity score for reasoning
        detailed_predictions: Detailed predictions with reasoning
        error: Optional error message
    """
    result = None
    if any((
        metrics,
        accuracy is not None,
        total_samples is not None,
        average_similarity_score is not None,
        detailed_predictions,
        results_s3_key,
    )):
        result = format_validation_result(
            metrics=metrics or {},
            detailed_predictions=detailed_predictions or [],
            accuracy=accuracy or 0.0,
            total_samples=total_samples or 0,
            average_similarity_score=average_similarity_score or 0.0,
            results_s3_key=results_s3_key,
        )

    await send_task_completion_notification(
        websocket_manager=websocket_manager,
        task_id=validation_id,
        task_type="validation",
        status=status,
        result=result,
        error=error,
    )


async def send_progress_update(
    websocket_manager: Any,
    task_id: str,
    task_type: str,
    progress: float,
    message: Optional[str] = None,
) -> None:
    """
    Send task progress update through ``websocket_manager`` (best effort).

    Unlike completion notifications, no HTTP relay is attempted. Errors are
    logged and swallowed.

    Args:
        websocket_manager: WebSocketManager instance
        task_id: Task identifier
        task_type: Type of task ('training', 'prediction', 'validation')
        progress: Progress percentage (0-100)
        message: Optional progress message
    """
    try:
        # Reuse the shared builder so the payload shape cannot drift.
        notification = format_progress_update(
            task_id=task_id,
            task_type=task_type,
            progress=progress,
            message=message,
        )
        await websocket_manager.send_message(task_id, notification)
        logger.debug(
            "[WEBSOCKET] Sent progress update for task %s: %s%%", task_id, progress
        )
    except Exception as e:
        logger.error(
            "[WEBSOCKET] Failed to send progress update for task %s: %s", task_id, e
        )