# prediqai / app/utils/websocket_utils.py
# Last commit by ganesh-vilje: "Deploy to Hugging Face Main" (f8f02c0)
from typing import Dict, Any, Optional, List
from datetime import datetime
import logging
import traceback
import httpx
from app.core.config import Config
# Module-level logger, named after this module's dotted import path.
logger: logging.Logger = logging.getLogger(__name__)

# Base URL of the API service, read from Config once at import time.
API_BASE_URL = Config.API_BASE_URL
def format_task_completion_message(
    task_id: str,
    task_type: str,
    status: str,
    result: Optional[Dict[str, Any]] = None,
    error: Optional[str] = None
) -> Dict[str, Any]:
    """
    Build the ``task_completion`` payload pushed to WebSocket clients.

    Args:
        task_id: Identifier of the finished task
        task_type: Kind of task (e.g. 'training', 'prediction', 'validation')
        status: Final task status (e.g. 'completed', 'failed')
        result: Result data; attached only when truthy
        error: Error text; attached only when truthy

    Returns:
        Dictionary with keys type, task_id, task_type, status, timestamp and,
        when supplied, result and error. The timestamp comes from
        ``datetime.utcnow()``, so it is UTC but carries no offset suffix.
    """
    payload: Dict[str, Any] = dict(
        type="task_completion",
        task_id=task_id,
        task_type=task_type,
        status=status,
        timestamp=datetime.utcnow().isoformat(),
    )
    # Empty/None optional fields are omitted rather than sent as null.
    optional_fields = (("result", result), ("error", error))
    payload.update((key, value) for key, value in optional_fields if value)
    return payload
def format_progress_update(
    task_id: str,
    task_type: str,
    progress: float,
    message: Optional[str] = None
) -> Dict[str, Any]:
    """
    Build the ``progress_update`` payload pushed to WebSocket clients.

    Args:
        task_id: Identifier of the running task
        task_type: Kind of task
        progress: Progress percentage (0-100)
        message: Human-readable status text; attached only when truthy

    Returns:
        Dictionary with keys type, task_id, task_type, progress, timestamp
        and, when supplied, message.
    """
    base: Dict[str, Any] = {
        "type": "progress_update",
        "task_id": task_id,
        "task_type": task_type,
        "progress": progress,
        "timestamp": datetime.utcnow().isoformat(),
    }
    return {**base, "message": message} if message else base
def format_training_result(
    model_id: Optional[str],
    metrics: Optional[Dict[str, Any]] = None,
    training_time: Optional[float] = None
) -> Dict[str, Any]:
    """
    Format training task result data.

    Args:
        model_id: ID of the trained model (may be None when only metrics
            are known)
        metrics: Optional training metrics; omitted when empty or None
        training_time: Optional training duration in seconds; omitted only
            when None (a duration of 0.0 is kept)

    Returns:
        Dictionary with ``model_id`` and, when supplied, ``metrics`` and
        ``training_time``.
    """
    result: Dict[str, Any] = {"model_id": model_id}
    if metrics:
        result["metrics"] = metrics
    # Compare against None instead of relying on truthiness: 0.0 is a valid
    # duration and must not be silently dropped from the payload.
    if training_time is not None:
        result["training_time"] = training_time
    return result
def format_prediction_result(
    predictions: List[Dict[str, Any]],
    total_predictions: int,
    average_probability: float,
    processing_time: float,
    predictions_per_second: float,
    results_s3_key: Optional[str] = None
) -> Dict[str, Any]:
    """
    Assemble the ``result`` payload for a finished prediction task.

    Args:
        predictions: Per-item prediction dictionaries, passed through as-is
        total_predictions: Total number of predictions made
        average_probability: Average probability across all predictions
        processing_time: Time taken for predictions in seconds
        predictions_per_second: Throughput metric
        results_s3_key: S3 key where results are stored; omitted when falsy

    Returns:
        Dictionary holding the values above, with ``results_s3_key`` last
        when it was supplied.
    """
    storage_info = {"results_s3_key": results_s3_key} if results_s3_key else {}
    return {
        "predictions": predictions,
        "total_predictions": total_predictions,
        "average_probability": average_probability,
        "processing_time": processing_time,
        "predictions_per_second": predictions_per_second,
        **storage_info,
    }
def format_validation_result(
    metrics: Dict[str, Any],
    detailed_predictions: List[Dict[str, Any]],
    accuracy: float,
    total_samples: int,
    average_similarity_score: float,
    results_s3_key: Optional[str] = None
) -> Dict[str, Any]:
    """
    Assemble the ``result`` payload for a finished validation task.

    Args:
        metrics: Validation metrics dictionary, passed through as-is
        detailed_predictions: Per-sample prediction details with reasoning
        accuracy: Overall accuracy score
        total_samples: Total number of samples validated
        average_similarity_score: Average similarity score
        results_s3_key: S3 key where results are stored; omitted when falsy

    Returns:
        Dictionary holding the values above, with ``results_s3_key`` last
        when it was supplied.
    """
    storage_info = {"results_s3_key": results_s3_key} if results_s3_key else {}
    return {
        "metrics": metrics,
        "detailed_predictions": detailed_predictions,
        "accuracy": accuracy,
        "total_samples": total_samples,
        "average_similarity_score": average_similarity_score,
        **storage_info,
    }
def format_connection_message(task_id: str) -> Dict[str, Any]:
    """
    Build the acknowledgement sent when a client subscribes to a task.

    Args:
        task_id: The task ID the client connected to

    Returns:
        Dictionary with keys type, message, task_id and timestamp.
    """
    greeting = f"Connected to task {task_id}"
    return dict(
        type="connection",
        message=greeting,
        task_id=task_id,
        timestamp=datetime.utcnow().isoformat(),
    )
def format_error_message(
    task_id: str,
    error: str,
    error_type: Optional[str] = None
) -> Dict[str, Any]:
    """
    Build an ``error`` payload for WebSocket clients.

    Args:
        task_id: Identifier of the task the error relates to
        error: Error text
        error_type: Optional classification; attached only when truthy

    Returns:
        Dictionary with keys type, task_id, error, timestamp and, when
        supplied, error_type.
    """
    payload: Dict[str, Any] = dict(
        type="error",
        task_id=task_id,
        error=error,
        timestamp=datetime.utcnow().isoformat(),
    )
    if not error_type:
        return payload
    payload["error_type"] = error_type
    return payload
def format_pong_message() -> Dict[str, Any]:
    """
    Build the reply to a client ping.

    Returns:
        Dictionary with keys type ("pong") and timestamp.
    """
    stamp = datetime.utcnow().isoformat()
    return dict(type="pong", timestamp=stamp)
# =============================================================================
# Notification Helper Functions
# =============================================================================
async def send_task_completion_notification(
    websocket_manager,
    task_id: str,
    task_type: str,
    status: str,
    result: Optional[Dict[str, Any]] = None,
    error: Optional[str] = None
):
    """
    Send task completion notification to frontend via WebSocket.

    The payload is built with ``format_task_completion_message`` and then
    delivered through two independent channels:

    1. Direct send via ``websocket_manager.send_message`` - effective when
       called from FastAPI routes (same process as the WebSocket manager).
       Skipped when ``websocket_manager`` is None.
    2. HTTP POST of the same payload to ``{API_BASE_URL}/v1/websocket/notify``
       with header ``X-Internal-Call: true`` - the path for callers running
       in a different process (e.g. Celery tasks).

    Each channel has its own error handling, so a failure in the direct send
    does not prevent the HTTP callback from being attempted. A non-2xx
    callback response is treated as a failure. All errors are logged; this
    function does not raise them.

    NOTE(review): both channels always run, so a caller inside the API
    process may deliver the message twice unless the notify endpoint or the
    manager de-duplicates - confirm.

    Args:
        websocket_manager: WebSocketManager instance, or None when no local
            manager is available (only the HTTP callback is attempted then)
        task_id: Unique task identifier (training_id, prediction_id, validation_id)
        task_type: Type of task ('training', 'prediction', 'validation')
        status: Task status ('completed', 'failed', 'success')
        result: Optional dictionary containing task results
        error: Optional error message if task failed
    """
    message = format_task_completion_message(
        task_id=task_id,
        task_type=task_type,
        status=status,
        result=result,
        error=error
    )

    # Channel 1: direct send. It has its own try so that an exception raised
    # here (e.g. by a manager in a worker process) cannot skip channel 2.
    if websocket_manager is not None:
        try:
            await websocket_manager.send_message(task_id, message)
            logger.info(f"[WEBSOCKET] Sent {status} notification for {task_type} task: {task_id}")
        except Exception as e:
            logger.error(f"[WEBSOCKET] Failed to send notification for task {task_id}: {str(e)}")

    # Channel 2: HTTP callback to the FastAPI process.
    try:
        # str() tolerates non-str URL config objects; rstrip avoids a "//" path.
        callback_url = f"{str(API_BASE_URL).rstrip('/')}/v1/websocket/notify"
        logger.debug(f"[WEBSOCKET HTTP CALLBACK] Sending POST to {callback_url} for task: {task_id}")
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.post(
                callback_url,
                json=message,
                headers={"X-Internal-Call": "true"}
            )
        # Without this check an error response would be reported as delivered.
        response.raise_for_status()
        logger.info(
            f"[WEBSOCKET HTTP CALLBACK] Delivered (status {response.status_code}) for task: {task_id}"
        )
    except Exception as http_error:
        # logger.exception records the traceback at ERROR level.
        logger.exception(
            f"[WEBSOCKET HTTP CALLBACK] FAILED for task {task_id}: "
            f"{type(http_error).__name__}: {str(http_error)}"
        )
async def send_training_completion(
    websocket_manager,
    training_id: str,
    status: str,
    model_id: Optional[str] = None,
    metrics: Optional[Dict[str, Any]] = None,
    error: Optional[str] = None,
    training_time: Optional[float] = None
):
    """
    Send training completion notification.

    A ``result`` payload (built by ``format_training_result``) is attached
    only when at least one of ``model_id``, ``metrics`` or ``training_time``
    is supplied; otherwise the notification carries no result.

    Args:
        websocket_manager: WebSocketManager instance
        training_id: Training task identifier
        status: Training status ('completed', 'failed')
        model_id: Optional model identifier
        metrics: Optional training metrics
        error: Optional error message
        training_time: Optional training duration in seconds; added as a
            trailing keyword so existing callers are unaffected
    """
    result = None
    if model_id or metrics or training_time is not None:
        result = format_training_result(
            model_id=model_id,
            metrics=metrics,
            training_time=training_time
        )
    await send_task_completion_notification(
        websocket_manager=websocket_manager,
        task_id=training_id,
        task_type="training",
        status=status,
        result=result,
        error=error
    )
async def send_prediction_completion(
    websocket_manager,
    prediction_id: str,
    status: str,
    total_predictions: Optional[int] = None,
    average_probability: Optional[float] = None,
    results_s3_key: Optional[str] = None,
    predictions: Optional[list] = None,
    processing_time: Optional[float] = None,
    predictions_per_second: Optional[float] = None,
    error: Optional[str] = None
):
    """
    Notify clients that a prediction task finished.

    A ``result`` payload is attached when any result field was provided:
    numeric fields count when not None, while ``results_s3_key`` and
    ``predictions`` count when truthy. Missing fields are filled with
    empty/zero defaults before formatting.

    Args:
        websocket_manager: WebSocketManager instance
        prediction_id: Prediction task identifier
        status: Prediction status ('completed', 'failed')
        total_predictions: Total number of predictions made
        average_probability: Average prediction probability
        results_s3_key: S3 key for results
        predictions: Detailed predictions with reasoning
        processing_time: Time taken to process predictions
        predictions_per_second: Throughput metric
        error: Optional error message
    """
    has_payload = any((
        total_predictions is not None,
        average_probability is not None,
        results_s3_key,
        predictions,
        processing_time is not None,
        predictions_per_second is not None,
    ))
    result = (
        format_prediction_result(
            predictions=predictions or [],
            total_predictions=total_predictions or 0,
            average_probability=average_probability or 0.0,
            processing_time=processing_time or 0.0,
            predictions_per_second=predictions_per_second or 0.0,
            results_s3_key=results_s3_key
        )
        if has_payload
        else None
    )
    await send_task_completion_notification(
        websocket_manager=websocket_manager,
        task_id=prediction_id,
        task_type="prediction",
        status=status,
        result=result,
        error=error
    )
async def send_validation_completion(
    websocket_manager,
    validation_id: str,
    status: str,
    metrics: Optional[Dict[str, Any]] = None,
    results_s3_key: Optional[str] = None,
    accuracy: Optional[float] = None,
    total_samples: Optional[int] = None,
    average_similarity_score: Optional[float] = None,
    detailed_predictions: Optional[list] = None,
    error: Optional[str] = None
):
    """
    Notify clients that a validation task finished.

    A ``result`` payload is attached when any result field was provided:
    numeric fields count when not None, while ``metrics``,
    ``detailed_predictions`` and ``results_s3_key`` count when truthy.
    Missing fields are filled with empty/zero defaults before formatting.

    Args:
        websocket_manager: WebSocketManager instance
        validation_id: Validation task identifier
        status: Validation status ('completed', 'failed')
        metrics: Optional validation metrics (classification report)
        results_s3_key: Optional S3 key for results
        accuracy: Overall accuracy
        total_samples: Total number of samples validated
        average_similarity_score: Average similarity score for reasoning
        detailed_predictions: Detailed predictions with reasoning
        error: Optional error message
    """
    has_payload = any((
        metrics,
        accuracy is not None,
        total_samples is not None,
        average_similarity_score is not None,
        detailed_predictions,
        results_s3_key,
    ))
    result = (
        format_validation_result(
            metrics=metrics or {},
            detailed_predictions=detailed_predictions or [],
            accuracy=accuracy or 0.0,
            total_samples=total_samples or 0,
            average_similarity_score=average_similarity_score or 0.0,
            results_s3_key=results_s3_key
        )
        if has_payload
        else None
    )
    await send_task_completion_notification(
        websocket_manager=websocket_manager,
        task_id=validation_id,
        task_type="validation",
        status=status,
        result=result,
        error=error
    )
async def send_progress_update(
    websocket_manager,
    task_id: str,
    task_type: str,
    progress: float,
    message: Optional[str] = None
):
    """
    Send task progress update.

    The payload is built by this module's ``format_progress_update`` rather
    than a hand-duplicated dict, so the two cannot drift apart. It is sent
    through ``websocket_manager.send_message``. Errors are logged, not raised.

    NOTE(review): unlike the completion notifier, this does not POST to the
    ``/v1/websocket/notify`` callback, so updates emitted from another
    process (e.g. a Celery task) may never reach clients - confirm whether
    that is intended.

    Args:
        websocket_manager: WebSocketManager instance
        task_id: Task identifier
        task_type: Type of task ('training', 'prediction', 'validation')
        progress: Progress percentage (0-100)
        message: Optional progress message (included only when truthy)
    """
    try:
        notification = format_progress_update(
            task_id=task_id,
            task_type=task_type,
            progress=progress,
            message=message
        )
        await websocket_manager.send_message(task_id, notification)
        logger.debug(f"[WEBSOCKET] Sent progress update for task {task_id}: {progress}%")
    except Exception as e:
        logger.error(f"[WEBSOCKET] Failed to send progress update for task {task_id}: {str(e)}")