Spaces:
Runtime error
Runtime error
| from typing import Dict, Any, Optional, List | |
| from datetime import datetime | |
| import logging | |
| import traceback | |
| import httpx | |
| from app.core.config import Config | |
| logger = logging.getLogger(__name__) | |
| API_BASE_URL = Config.API_BASE_URL | |
def format_task_completion_message(
    task_id: str,
    task_type: str,
    status: str,
    result: Optional[Dict[str, Any]] = None,
    error: Optional[str] = None
) -> Dict[str, Any]:
    """
    Build the WebSocket payload announcing that a task has finished.

    Args:
        task_id: Identifier of the task the message refers to.
        task_type: Kind of task (e.g., 'training', 'prediction', 'validation').
        status: Final task status ('completed', 'failed', etc.).
        result: Result data; attached only when truthy, so ``None`` and
            ``{}`` are both left out of the payload.
        error: Error text; attached only when truthy, so ``None`` and
            ``""`` are both left out of the payload.

    Returns:
        Dictionary with "type", "task_id", "task_type", "status" and
        "timestamp" (naive UTC time in ISO 8601 form), followed by
        "result" and/or "error" when they were supplied.
    """
    payload: Dict[str, Any] = dict(
        type="task_completion",
        task_id=task_id,
        task_type=task_type,
        status=status,
        timestamp=datetime.utcnow().isoformat(),
    )
    # Optional fields are appended in a fixed order and skipped when falsy.
    for field_name, field_value in (("result", result), ("error", error)):
        if field_value:
            payload[field_name] = field_value
    return payload
def format_progress_update(
    task_id: str,
    task_type: str,
    progress: float,
    message: Optional[str] = None
) -> Dict[str, Any]:
    """
    Build the WebSocket payload reporting how far a task has progressed.

    Args:
        task_id: Identifier of the task being reported on.
        task_type: Kind of task.
        progress: Progress percentage (expected 0-100; not validated here).
        message: Human-readable progress note; omitted from the payload
            when ``None`` or empty.

    Returns:
        Dictionary with "type", "task_id", "task_type", "progress" and
        "timestamp" (naive UTC time in ISO 8601 form), plus "message"
        when one was supplied.
    """
    payload: Dict[str, Any] = {
        "type": "progress_update",
        "task_id": task_id,
        "task_type": task_type,
        "progress": progress,
        "timestamp": datetime.utcnow().isoformat(),
    }
    if not message:
        return payload
    payload["message"] = message
    return payload
def format_training_result(
    model_id: str,
    metrics: Optional[Dict[str, Any]] = None,
    training_time: Optional[float] = None
) -> Dict[str, Any]:
    """
    Format training task result data.

    Args:
        model_id: ID of the trained model; always present in the result.
        metrics: Optional training metrics; left out when ``None`` or empty.
        training_time: Optional training duration in seconds; left out only
            when ``None``, so a measured duration of ``0`` is still reported.

    Returns:
        Dictionary with "model_id", plus "metrics" and/or "training_time"
        when they were supplied.
    """
    result: Dict[str, Any] = {"model_id": model_id}
    if metrics:
        result["metrics"] = metrics
    # Compare against None instead of relying on truthiness: 0 / 0.0 is a
    # legitimate duration and must not be silently dropped.
    if training_time is not None:
        result["training_time"] = training_time
    return result
def format_prediction_result(
    predictions: List[Dict[str, Any]],
    total_predictions: int,
    average_probability: float,
    processing_time: float,
    predictions_per_second: float,
    results_s3_key: Optional[str] = None
) -> Dict[str, Any]:
    """
    Assemble the result section of a prediction task notification.

    Args:
        predictions: Prediction dictionaries with their details.
        total_predictions: Number of predictions that were made.
        average_probability: Mean probability over all predictions.
        processing_time: Seconds spent producing the predictions.
        predictions_per_second: Throughput figure.
        results_s3_key: S3 key where the results are stored; omitted from
            the output when ``None`` or empty.

    Returns:
        Dictionary holding the five mandatory fields under their parameter
        names, plus "results_s3_key" when one was supplied.
    """
    summary: Dict[str, Any] = dict(
        predictions=predictions,
        total_predictions=total_predictions,
        average_probability=average_probability,
        processing_time=processing_time,
        predictions_per_second=predictions_per_second,
    )
    if not results_s3_key:
        return summary
    summary["results_s3_key"] = results_s3_key
    return summary
def format_validation_result(
    metrics: Dict[str, Any],
    detailed_predictions: List[Dict[str, Any]],
    accuracy: float,
    total_samples: int,
    average_similarity_score: float,
    results_s3_key: Optional[str] = None
) -> Dict[str, Any]:
    """
    Assemble the result section of a validation task notification.

    Args:
        metrics: Validation metrics dictionary.
        detailed_predictions: Per-sample prediction details with reasoning.
        accuracy: Overall accuracy score.
        total_samples: Number of samples that were validated.
        average_similarity_score: Mean similarity score.
        results_s3_key: S3 key where the results are stored; omitted from
            the output when ``None`` or empty.

    Returns:
        Dictionary holding the five mandatory fields under their parameter
        names, plus "results_s3_key" when one was supplied.
    """
    field_names = (
        "metrics",
        "detailed_predictions",
        "accuracy",
        "total_samples",
        "average_similarity_score",
    )
    field_values = (
        metrics,
        detailed_predictions,
        accuracy,
        total_samples,
        average_similarity_score,
    )
    # Pair names with values positionally; the output keeps this key order.
    summary: Dict[str, Any] = dict(zip(field_names, field_values))
    if results_s3_key:
        summary["results_s3_key"] = results_s3_key
    return summary
def format_connection_message(task_id: str) -> Dict[str, Any]:
    """
    Build the payload confirming that a connection was established.

    Args:
        task_id: Identifier of the task the client connected to.

    Returns:
        Dictionary with "type" set to "connection", a human-readable
        "message", the "task_id", and a "timestamp" (naive UTC time in
        ISO 8601 form).
    """
    greeting = f"Connected to task {task_id}"
    issued_at = datetime.utcnow().isoformat()
    return dict(
        type="connection",
        message=greeting,
        task_id=task_id,
        timestamp=issued_at,
    )
def format_error_message(
    task_id: str,
    error: str,
    error_type: Optional[str] = None
) -> Dict[str, Any]:
    """
    Build the WebSocket payload reporting an error for a task.

    Args:
        task_id: Identifier of the task the error belongs to.
        error: Error text.
        error_type: Error classification; omitted from the payload when
            ``None`` or empty.

    Returns:
        Dictionary with "type" set to "error", the "task_id", the "error"
        text and a "timestamp" (naive UTC time in ISO 8601 form), plus
        "error_type" when one was supplied.
    """
    payload: Dict[str, Any] = dict(
        type="error",
        task_id=task_id,
        error=error,
        timestamp=datetime.utcnow().isoformat(),
    )
    if not error_type:
        return payload
    payload["error_type"] = error_type
    return payload
def format_pong_message() -> Dict[str, Any]:
    """
    Build the reply payload for a ping.

    Returns:
        Dictionary with "type" set to "pong" and a "timestamp" (naive UTC
        time in ISO 8601 form).
    """
    pong: Dict[str, Any] = {"type": "pong"}
    pong["timestamp"] = datetime.utcnow().isoformat()
    return pong
| # ============================================================================= | |
| # Notification Helper Functions | |
| # ============================================================================= | |
async def send_task_completion_notification(
    websocket_manager,
    task_id: str,
    task_type: str,
    status: str,
    result: Optional[Dict[str, Any]] = None,
    error: Optional[str] = None
):
    """
    Send task completion notification to frontend via WebSocket.

    Every call attempts delivery over two channels, in this order:

    1. Directly through ``websocket_manager.send_message`` - effective when
       the caller runs in the same process as the WebSocket manager
       (e.g. a FastAPI route).
    2. An HTTP POST of the same message to
       ``{API_BASE_URL}/v1/websocket/notify`` - so that callers living in a
       different process (e.g. Celery tasks) can reach the FastAPI process.

    The channels are independent: a failure of the direct send is logged and
    does not prevent the HTTP callback. This function is best-effort and
    never raises; all failures are reported through the module logger.

    Args:
        websocket_manager: WebSocketManager instance
        task_id: Unique task identifier (training_id, prediction_id, validation_id)
        task_type: Type of task ('training', 'prediction', 'validation')
        status: Task status ('completed', 'failed', 'success')
        result: Optional dictionary containing task results
        error: Optional error message if task failed
    """
    try:
        message = format_task_completion_message(
            task_id=task_id,
            task_type=task_type,
            status=status,
            result=result,
            error=error
        )

        # Channel 1: direct send. Isolated in its own try block so that a
        # failure here can no longer skip the HTTP callback below.
        try:
            await websocket_manager.send_message(task_id, message)
            logger.info(f"[WEBSOCKET] Sent {status} notification for {task_type} task: {task_id}")
        except Exception as ws_error:
            logger.error(f"[WEBSOCKET] Failed to send notification for task {task_id}: {str(ws_error)}")

        # Channel 2: HTTP callback to the FastAPI process.
        callback_url = f"{API_BASE_URL}/v1/websocket/notify"
        logger.debug(f"[WEBSOCKET HTTP CALLBACK] Sending POST to {callback_url} for task: {task_id}")
        try:
            async with httpx.AsyncClient(timeout=5.0) as client:
                response = await client.post(
                    callback_url,
                    json=message,
                    headers={"X-Internal-Call": "true"}
                )
                # A 4xx/5xx reply means the callback was NOT accepted; turn it
                # into an exception so it is reported as a failure, not success.
                response.raise_for_status()
            logger.info(f"[WEBSOCKET HTTP CALLBACK] SUCCESS! Response: {response.status_code}, task: {task_id}")
        except Exception as http_error:
            logger.error(f"[WEBSOCKET HTTP CALLBACK] FAILED for task {task_id}: {type(http_error).__name__}: {str(http_error)}")
            logger.error(f"[WEBSOCKET HTTP CALLBACK] Traceback: {traceback.format_exc()}")
    except Exception as e:
        # Last-resort guard (e.g. message formatting failed): stay best-effort.
        logger.error(f"[WEBSOCKET] Failed to send notification for task {task_id}: {str(e)}")
async def send_training_completion(
    websocket_manager,
    training_id: str,
    status: str,
    model_id: Optional[str] = None,
    metrics: Optional[Dict[str, Any]] = None,
    error: Optional[str] = None
):
    """
    Notify the frontend that a training task has finished.

    A result section is built with ``format_training_result`` only when a
    model id or metrics were provided; otherwise the notification carries no
    result. Delivery is delegated to ``send_task_completion_notification``
    with the task type fixed to "training".

    Args:
        websocket_manager: WebSocketManager instance
        training_id: Training task identifier
        status: Training status ('completed', 'failed')
        model_id: Optional model identifier
        metrics: Optional training metrics
        error: Optional error message
    """
    training_result = (
        format_training_result(model_id=model_id, metrics=metrics)
        if (model_id or metrics)
        else None
    )
    await send_task_completion_notification(
        websocket_manager=websocket_manager,
        task_id=training_id,
        task_type="training",
        status=status,
        result=training_result,
        error=error,
    )
async def send_prediction_completion(
    websocket_manager,
    prediction_id: str,
    status: str,
    total_predictions: Optional[int] = None,
    average_probability: Optional[float] = None,
    results_s3_key: Optional[str] = None,
    predictions: Optional[list] = None,
    processing_time: Optional[float] = None,
    predictions_per_second: Optional[float] = None,
    error: Optional[str] = None
):
    """
    Notify the frontend that a prediction task has finished.

    A result section is built with ``format_prediction_result`` as soon as
    at least one result field was provided (numeric fields count when they
    are not ``None``; ``results_s3_key`` and ``predictions`` count when
    truthy). Fields that were not provided fall back to ``[]``, ``0`` or
    ``0.0``. Delivery is delegated to ``send_task_completion_notification``
    with the task type fixed to "prediction".

    Args:
        websocket_manager: WebSocketManager instance
        prediction_id: Prediction task identifier
        status: Prediction status ('completed', 'failed')
        total_predictions: Total number of predictions made
        average_probability: Average prediction probability
        results_s3_key: S3 key for results
        predictions: Detailed predictions with reasoning
        processing_time: Time taken to process predictions
        predictions_per_second: Throughput metric
        error: Optional error message
    """
    has_result_data = (
        total_predictions is not None
        or average_probability is not None
        or bool(results_s3_key)
        or bool(predictions)
        or processing_time is not None
        or predictions_per_second is not None
    )

    prediction_result = None
    if has_result_data:
        prediction_result = format_prediction_result(
            predictions=predictions or [],
            total_predictions=total_predictions or 0,
            average_probability=average_probability or 0.0,
            processing_time=processing_time or 0.0,
            predictions_per_second=predictions_per_second or 0.0,
            results_s3_key=results_s3_key,
        )

    await send_task_completion_notification(
        websocket_manager=websocket_manager,
        task_id=prediction_id,
        task_type="prediction",
        status=status,
        result=prediction_result,
        error=error,
    )
async def send_validation_completion(
    websocket_manager,
    validation_id: str,
    status: str,
    metrics: Optional[Dict[str, Any]] = None,
    results_s3_key: Optional[str] = None,
    accuracy: Optional[float] = None,
    total_samples: Optional[int] = None,
    average_similarity_score: Optional[float] = None,
    detailed_predictions: Optional[list] = None,
    error: Optional[str] = None
):
    """
    Notify the frontend that a validation task has finished.

    A result section is built with ``format_validation_result`` as soon as
    at least one result field was provided (numeric fields count when they
    are not ``None``; ``metrics``, ``detailed_predictions`` and
    ``results_s3_key`` count when truthy). Fields that were not provided
    fall back to ``{}``, ``[]``, ``0`` or ``0.0``. Delivery is delegated to
    ``send_task_completion_notification`` with the task type fixed to
    "validation".

    Args:
        websocket_manager: WebSocketManager instance
        validation_id: Validation task identifier
        status: Validation status ('completed', 'failed')
        metrics: Optional validation metrics (classification report)
        results_s3_key: Optional S3 key for results
        accuracy: Overall accuracy
        total_samples: Total number of samples validated
        average_similarity_score: Average similarity score for reasoning
        detailed_predictions: Detailed predictions with reasoning
        error: Optional error message
    """
    has_result_data = (
        bool(metrics)
        or accuracy is not None
        or total_samples is not None
        or average_similarity_score is not None
        or bool(detailed_predictions)
        or bool(results_s3_key)
    )

    validation_result = None
    if has_result_data:
        validation_result = format_validation_result(
            metrics=metrics or {},
            detailed_predictions=detailed_predictions or [],
            accuracy=accuracy or 0.0,
            total_samples=total_samples or 0,
            average_similarity_score=average_similarity_score or 0.0,
            results_s3_key=results_s3_key,
        )

    await send_task_completion_notification(
        websocket_manager=websocket_manager,
        task_id=validation_id,
        task_type="validation",
        status=status,
        result=validation_result,
        error=error,
    )
async def send_progress_update(
    websocket_manager,
    task_id: str,
    task_type: str,
    progress: float,
    message: Optional[str] = None
):
    """
    Send task progress update.

    The payload is produced by ``format_progress_update`` so that progress
    messages have a single definition in this module, then handed to
    ``websocket_manager.send_message``. This function is best-effort: any
    exception is logged and swallowed rather than propagated to the caller.

    Args:
        websocket_manager: WebSocketManager instance
        task_id: Task identifier
        task_type: Type of task ('training', 'prediction', 'validation')
        progress: Progress percentage (0-100)
        message: Optional progress message
    """
    try:
        # Reuse the shared formatter instead of rebuilding the dict by hand.
        notification = format_progress_update(
            task_id=task_id,
            task_type=task_type,
            progress=progress,
            message=message
        )
        await websocket_manager.send_message(task_id, notification)
        logger.debug(f"[WEBSOCKET] Sent progress update for task {task_id}: {progress}%")
    except Exception as e:
        logger.error(f"[WEBSOCKET] Failed to send progress update for task {task_id}: {str(e)}")