# NOTE(review): the three lines that were here ("Spaces: Sleeping Sleeping")
# are Hugging Face Space status-page text captured by extraction, not source.
| """ | |
| FastAPI webhook server for Hugging Face Spaces | |
| Handles training requests from Vercel and sends results back via callback | |
| """ | |
import asyncio
import hmac
import logging
import os
import pickle
import threading
from typing import Optional

import requests
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from predict import generate_predictions_for_daf
from train import train_model, TalmudClassifierLSTM, MAX_LEN, EMBEDDING_DIM, HIDDEN_DIM
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Talmud Language Classifier Training")

# Base URL of the production Vercel deployment (allowed CORS origin).
VERCEL_BASE_URL = os.getenv('VERCEL_BASE_URL')
# Matches Vercel preview-deployment URLs for this project (9-char deploy hash).
VERCEL_PREVIEW_URL_REGEX = r'https://talmud-annotation-platform-[a-zA-Z0-9]{9}-shmans-projects\.vercel\.app'

# Enable CORS for Vercel callbacks.
# Guard against VERCEL_BASE_URL being unset: os.getenv returns None, and the
# original passed [None] as the origin list in that case.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[VERCEL_BASE_URL] if VERCEL_BASE_URL else [],
    allow_origin_regex=VERCEL_PREVIEW_URL_REGEX,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Training state shared between the endpoint and the worker thread.
training_in_progress = False      # True while a training thread is running
training_result = None            # dict describing the last run's outcome
training_error = None             # error string from the last failed run, if any
training_lock = threading.Lock()  # Lock to prevent race conditions
class TrainingRequest(BaseModel):
    """Request body sent by Vercel to start a training run."""
    training_data: str               # raw training data; passed through to train.train_model
    callback_url: str                # Vercel endpoint that will receive the results
    callback_auth_token: str         # token echoed back so Vercel can authenticate the callback
    timestamp: Optional[str] = None  # optional client-supplied request timestamp
class PredictionRequest(BaseModel):
    """Request body for on-demand prediction on a single daf."""
    daf_text: str    # text of the daf to generate predictions for
    auth_token: str  # must match TRAINING_CALLBACK_TOKEN (verified in predict_endpoint)
def run_training_async(training_data: str, callback_url: str, callback_auth_token: str):
    """
    Run a full training cycle in a worker thread and report the outcome to Vercel.

    Trains the model on ``training_data`` and POSTs the test-set statistics to
    ``callback_url`` (with ``callback_auth_token`` in the body so Vercel can
    authenticate the callback). Does not generate predictions for all dafim.

    Outcomes are recorded in the module-level ``training_result`` /
    ``training_error``; ``training_in_progress`` is always cleared on exit.
    """
    global training_in_progress, training_result, training_error
    try:
        # Note: training_in_progress is already set to True by the endpoint.
        # Reset results for the new training run.
        training_result = None
        training_error = None
        logger.info("Starting model training...")

        # Train the model
        result = train_model(training_data)
        stats = result['stats']
        logger.info(f"Training completed. Accuracy: {stats['accuracy']:.4f}")
        logger.info(f"Test set results - Accuracy: {stats['accuracy']:.4f}, Loss: {stats['loss']:.4f}")
        logger.info(f"F1 Scores: {stats['f1_scores']}")

        # Callback payload carries only the test-set stats plus the auth token.
        callback_payload = {
            'stats': stats,
            'auth_token': callback_auth_token
        }

        # Send callback to Vercel
        logger.info(f"Sending callback to {callback_url}...")
        logger.info(f"Callback payload stats: accuracy={stats['accuracy']:.4f}, loss={stats['loss']:.4f}, f1_scores={list(stats['f1_scores'].keys())}")
        try:
            response = requests.post(
                callback_url,
                json=callback_payload,
                timeout=60,  # Reduced timeout since we're not generating predictions
                headers={'Content-Type': 'application/json'}
            )
            logger.info(f"Callback response status: {response.status_code}")
            logger.info(f"Callback response headers: {dict(response.headers)}")
            logger.info(f"Callback response body: {response.text}")
            if response.status_code == 200:
                try:
                    response_json = response.json()
                    logger.info(f"Callback response JSON: {response_json}")
                    if response_json.get('success'):
                        logger.info("Callback sent successfully and processed by Vercel")
                        training_result = {'success': True, 'message': 'Training completed'}
                    else:
                        logger.warning(f"Callback returned 200 but success=false: {response_json}")
                        training_error = f"Callback returned success=false: {response_json}"
                        training_result = {'success': False, 'error': training_error}
                except Exception as e:
                    logger.warning(f"Callback returned 200 but couldn't parse JSON: {e}, body: {response.text}")
                    # Can't confirm Vercel processed the callback, so mark as uncertain.
                    training_error = f"Callback returned 200 but response couldn't be parsed: {str(e)}"
                    training_result = {'success': False, 'error': training_error, 'warning': 'Training completed but callback response unclear'}
            else:
                logger.error(f"Callback failed with status {response.status_code}: {response.text}")
                training_error = f"Callback failed: {response.status_code}"
                training_result = {'success': False, 'error': training_error}
        except requests.exceptions.RequestException as e:
            logger.error(f"Callback request exception: {str(e)}", exc_info=True)
            training_error = f"Callback request failed: {str(e)}"
            training_result = {'success': False, 'error': training_error}
    except Exception as e:
        logger.error(f"Training error: {str(e)}", exc_info=True)
        training_error = str(e)
        training_result = {'success': False, 'error': str(e)}
        # Best-effort error callback so Vercel is not left waiting forever.
        try:
            callback_payload = {
                'error': str(e),
                'auth_token': callback_auth_token
            }
            requests.post(
                callback_url,
                json=callback_payload,
                timeout=10
            )
        except Exception:
            # Was a bare `except: pass`, which also swallowed SystemExit /
            # KeyboardInterrupt. Still best-effort, but narrowed and logged.
            logger.warning("Error callback could not be delivered", exc_info=True)
    finally:
        # Clear the flag under the same lock the endpoint uses to set it,
        # keeping all flag transitions consistent.
        with training_lock:
            training_in_progress = False
# NOTE(review): no @app route decorator is visible here (presumably
# @app.get("/") was lost in extraction) — confirm the route registration.
async def root():
    """Health-check / service-info endpoint."""
    payload = {
        "status": "running",
        "service": "Talmud Language Classifier Training",
        "training_in_progress": training_in_progress,
    }
    return payload
async def train_endpoint(request: TrainingRequest):
    """
    Training webhook endpoint.

    Accepts training data and a callback URL, then runs training in a
    background daemon thread; results are delivered asynchronously to the
    callback URL.

    Raises:
        HTTPException 429 if a training run is already in progress.
        HTTPException 400 if required fields are missing.
    """
    # NOTE(review): no @app route decorator is visible here (presumably
    # @app.post(...) was lost in extraction) — confirm the route registration.
    global training_in_progress

    # Use the lock to check-and-set atomically so two concurrent requests
    # cannot both start training.
    with training_lock:
        if training_in_progress:
            raise HTTPException(
                status_code=429,
                detail="Training already in progress. Please wait for current training to complete."
            )
        # Set flag immediately to prevent concurrent requests
        training_in_progress = True

    # Once the flag is set, ANY failure before the worker thread starts must
    # reset it, or the service would refuse new requests forever. (The
    # original reset only on HTTPException, so a thread-start failure would
    # have leaked the flag.)
    try:
        if not request.training_data or not request.callback_url:
            raise HTTPException(
                status_code=400,
                detail="Missing training_data or callback_url"
            )
        if not request.callback_auth_token:
            raise HTTPException(
                status_code=400,
                detail="Missing callback_auth_token"
            )

        logger.info("Received training request")

        # Start training in a background thread.
        # Note: training_in_progress is already set to True above; the worker
        # clears it in its own finally block.
        training_thread = threading.Thread(
            target=run_training_async,
            args=(
                request.training_data,
                request.callback_url,
                request.callback_auth_token
            ),
            daemon=True
        )
        training_thread.start()
    except Exception:
        # Reset the flag on validation or thread-start failure, then re-raise.
        with training_lock:
            training_in_progress = False
        raise

    # Return immediately; clients poll /status or wait for the callback.
    return {
        "success": True,
        "message": "Training started",
        "status": "processing"
    }
async def get_status():
    """Report the current training state: in-progress flag, last result, last error."""
    status = {"training_in_progress": training_in_progress}
    status["result"] = training_result
    status["error"] = training_error
    return status
# NOTE(review): no @app route decorator is visible — confirm registration.
async def health_check():
    """Liveness probe: always reports healthy."""
    return dict(status="healthy")
def load_model_artifacts():
    """
    Load the latest trained model plus its vocabulary and label encoder.

    Searches /workspace (persistent storage) first, then /tmp (backward
    compatibility). A directory is only chosen if it holds ALL THREE
    artifacts — the original selected the directory from the model file
    alone, so a partial /workspace set shadowed a complete /tmp set and the
    load failed even though usable artifacts existed.

    Returns:
        (model, word_to_idx, label_encoder), or (None, None, None) if no
        complete artifact set exists or loading/validation fails.
    """
    candidate_dirs = ['/workspace', '/tmp']
    model_path = word_to_idx_path = label_encoder_path = None
    storage_location = None
    for directory in candidate_dirs:
        paths = (
            os.path.join(directory, 'latest_model.pt'),
            os.path.join(directory, 'word_to_idx.pt'),
            os.path.join(directory, 'label_encoder.pkl'),
        )
        if all(os.path.exists(p) for p in paths):
            model_path, word_to_idx_path, label_encoder_path = paths
            storage_location = directory
            break
    if storage_location is None:
        # No complete artifact set anywhere.
        return None, None, None

    try:
        # Load and validate the vocabulary.
        word_to_idx = torch.load(word_to_idx_path)
        if not isinstance(word_to_idx, dict):
            raise ValueError(f"word_to_idx must be a dictionary, got {type(word_to_idx)}")
        if '<UNK>' not in word_to_idx:
            raise ValueError("Vocabulary must contain '<UNK>' key")
        if '<PAD>' not in word_to_idx:
            raise ValueError("Vocabulary must contain '<PAD>' key")

        # Load and validate the label encoder.
        # NOTE: pickle.load executes arbitrary code on load; these files must
        # only ever come from this service's own training runs.
        with open(label_encoder_path, 'rb') as f:
            label_encoder = pickle.load(f)
        if not hasattr(label_encoder, 'classes_'):
            raise ValueError("label_encoder must have 'classes_' attribute")

        # Number of output classes comes from the label encoder.
        num_classes = len(label_encoder.classes_)
        if num_classes < 2:
            raise ValueError(f"Model must have at least 2 classes, got {num_classes}")

        # Rebuild the architecture and load weights explicitly on CPU
        # (HF Spaces typically have no GPU).
        model = TalmudClassifierLSTM(len(word_to_idx), EMBEDDING_DIM, HIDDEN_DIM, num_classes)
        state_dict = torch.load(model_path, map_location='cpu')

        # Fail with a clear message if the checkpoint does not match the
        # current architecture.
        model_keys = set(model.state_dict().keys())
        state_dict_keys = set(state_dict.keys())
        if model_keys != state_dict_keys:
            missing_keys = model_keys - state_dict_keys
            unexpected_keys = state_dict_keys - model_keys
            error_msg = f"Model architecture mismatch. Missing keys: {missing_keys}, Unexpected keys: {unexpected_keys}"
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        model.load_state_dict(state_dict, strict=True)
        model.eval()
        model = model.cpu()  # ensure inference runs on CPU
        logger.info(f"Successfully loaded model artifacts from {storage_location}")
        return model, word_to_idx, label_encoder
    except Exception as e:
        logger.error(f"Error loading model artifacts: {e}", exc_info=True)
        return None, None, None
async def predict_endpoint(request: PredictionRequest):
    """
    On-demand prediction endpoint.

    Accepts daf text and generates predictions using the latest trained model.

    Authentication: requires TRAINING_CALLBACK_TOKEN to be set in environment
    variables; the token must match the auth_token sent in the request body.

    Raises:
        HTTPException 500 if the server token is unconfigured or prediction fails.
        HTTPException 401 on a bad token, 400 on empty text, 404 if no model exists.
    """
    # NOTE(review): no @app route decorator is visible here (presumably
    # @app.post(...) was lost in extraction) — confirm the route registration.

    # Security: always require the request token to match TRAINING_CALLBACK_TOKEN.
    expected_token = os.getenv('TRAINING_CALLBACK_TOKEN')
    if not expected_token:
        logger.error("TRAINING_CALLBACK_TOKEN not set in environment - prediction endpoint is insecure!")
        raise HTTPException(
            status_code=500,
            detail="Server configuration error: TRAINING_CALLBACK_TOKEN not configured"
        )
    # Compare in constant time: a plain `!=` short-circuits on the first
    # differing character and can leak token prefixes through response timing.
    if not request.auth_token or not hmac.compare_digest(request.auth_token, expected_token):
        raise HTTPException(
            status_code=401,
            detail="Unauthorized: Invalid authentication token"
        )

    if not request.daf_text or not request.daf_text.strip():
        raise HTTPException(
            status_code=400,
            detail="Missing or empty daf_text"
        )

    # Load the latest trained model; 404 if none has been trained yet.
    model, word_to_idx, label_encoder = load_model_artifacts()
    if model is None or word_to_idx is None or label_encoder is None:
        raise HTTPException(
            status_code=404,
            detail="Model not found. Please train a model first by triggering training from your Vercel app."
        )

    try:
        # Generate predictions
        logger.info("Generating predictions for daf text...")
        ranges = generate_predictions_for_daf(
            model, request.daf_text, word_to_idx, label_encoder
        )
        logger.info(f"Generated {len(ranges)} prediction ranges")
        return {
            "success": True,
            "ranges": ranges
        }
    except Exception as e:
        logger.error(f"Error generating predictions: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Error generating predictions: {str(e)}"
        )
| if __name__ == "__main__": | |
| import uvicorn | |
| port = int(os.getenv("PORT", 7860)) | |
| uvicorn.run(app, host="0.0.0.0", port=port) |