""" FastAPI webhook server for Hugging Face Spaces Handles training requests from Vercel and sends results back via callback """ from fastapi import FastAPI, Request, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse import requests import os import asyncio import threading import logging from typing import Optional from pydantic import BaseModel import torch import pickle from train import train_model, TalmudClassifierLSTM, MAX_LEN, EMBEDDING_DIM, HIDDEN_DIM from predict import generate_predictions_for_daf # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) app = FastAPI(title="Talmud Language Classifier Training") VERCEL_BASE_URL = os.getenv('VERCEL_BASE_URL') VERCEL_PREVIEW_URL_REGEX = r'https://talmud-annotation-platform-[a-zA-Z0-9]{9}-shmans-projects\.vercel\.app' # Enable CORS for Vercel callbacks app.add_middleware( CORSMiddleware, allow_origins=[VERCEL_BASE_URL], allow_origin_regex=VERCEL_PREVIEW_URL_REGEX, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Training state training_in_progress = False training_result = None training_error = None training_lock = threading.Lock() # Lock to prevent race conditions class TrainingRequest(BaseModel): training_data: str callback_url: str callback_auth_token: str timestamp: Optional[str] = None class PredictionRequest(BaseModel): daf_text: str auth_token: str def run_training_async(training_data: str, callback_url: str, callback_auth_token: str): """ Run training in a separate thread to avoid blocking the request. Trains the model on the provided training data and returns test results on the ground truth (test set). Does not generate predictions for all dafim. """ global training_in_progress, training_result, training_error try: # Note: training_in_progress is already set to True by the endpoint # Reset results for new training run training_result = None training_error = None logger.info("Starting model training...") # Train the model result = train_model(training_data) stats = result['stats'] logger.info(f"Training completed. Accuracy: {stats['accuracy']:.4f}") logger.info(f"Test set results - Accuracy: {stats['accuracy']:.4f}, Loss: {stats['loss']:.4f}") logger.info(f"F1 Scores: {stats['f1_scores']}") # Prepare callback payload with only stats (test results on ground truth) callback_payload = { 'stats': stats, 'auth_token': callback_auth_token } # Send callback to Vercel logger.info(f"Sending callback to {callback_url}...") logger.info(f"Callback payload stats: accuracy={stats['accuracy']:.4f}, loss={stats['loss']:.4f}, f1_scores={list(stats['f1_scores'].keys())}") try: response = requests.post( callback_url, json=callback_payload, timeout=60, # Reduced timeout since we're not generating predictions headers={'Content-Type': 'application/json'} ) logger.info(f"Callback response status: {response.status_code}") logger.info(f"Callback response headers: {dict(response.headers)}") logger.info(f"Callback response body: {response.text}") if response.status_code == 200: try: response_json = response.json() logger.info(f"Callback response JSON: {response_json}") if response_json.get('success'): logger.info("Callback sent successfully and processed by Vercel") training_result = {'success': True, 'message': 'Training completed'} else: logger.warning(f"Callback returned 200 but success=false: {response_json}") training_error = f"Callback returned success=false: {response_json}" training_result = {'success': False, 'error': training_error} except Exception as e: logger.warning(f"Callback returned 200 but couldn't parse JSON: {e}, body: {response.text}") # If we can't parse the response, we can't confirm success, so mark as uncertain training_error = f"Callback returned 200 but response couldn't be parsed: {str(e)}" training_result = {'success': False, 'error': training_error, 'warning': 'Training completed but callback response unclear'} else: logger.error(f"Callback failed with status {response.status_code}: {response.text}") training_error = f"Callback failed: {response.status_code}" training_result = {'success': False, 'error': training_error} except requests.exceptions.RequestException as e: logger.error(f"Callback request exception: {str(e)}", exc_info=True) training_error = f"Callback request failed: {str(e)}" training_result = {'success': False, 'error': training_error} except Exception as e: logger.error(f"Training error: {str(e)}", exc_info=True) training_error = str(e) training_result = {'success': False, 'error': str(e)} # Try to send error callback try: callback_payload = { 'error': str(e), 'auth_token': callback_auth_token } requests.post( callback_url, json=callback_payload, timeout=10 ) except: pass # Ignore callback errors if main training failed finally: training_in_progress = False @app.get("/") async def root(): """Health check endpoint""" return { "status": "running", "service": "Talmud Language Classifier Training", "training_in_progress": training_in_progress } @app.post("/train") async def train_endpoint(request: TrainingRequest): """ Training webhook endpoint. Accepts training data and callback URL, runs training in background. """ global training_in_progress # Use lock to prevent race condition - check and set atomically with training_lock: if training_in_progress: raise HTTPException( status_code=429, detail="Training already in progress. Please wait for current training to complete." ) # Set flag immediately to prevent concurrent requests training_in_progress = True # Validate request after acquiring lock # Use try-finally to ensure flag is reset if validation fails try: if not request.training_data or not request.callback_url: raise HTTPException( status_code=400, detail="Missing training_data or callback_url" ) if not request.callback_auth_token: raise HTTPException( status_code=400, detail="Missing callback_auth_token" ) except HTTPException: # Reset flag if validation fails with training_lock: training_in_progress = False raise logger.info("Received training request") # Start training in background thread # Note: training_in_progress is already set to True above training_thread = threading.Thread( target=run_training_async, args=( request.training_data, request.callback_url, request.callback_auth_token ), daemon=True ) training_thread.start() # Return immediately return { "success": True, "message": "Training started", "status": "processing" } @app.get("/status") async def get_status(): """Get current training status""" return { "training_in_progress": training_in_progress, "result": training_result, "error": training_error } @app.get("/health") async def health_check(): """Health check endpoint""" return {"status": "healthy"} def load_model_artifacts(): """ Load model artifacts from /workspace directory (persistent storage). Falls back to /tmp for backward compatibility. Returns (model, word_to_idx, label_encoder) or (None, None, None) if not found. """ # Try /workspace first (persistent storage) workspace_dir = '/workspace' workspace_model_path = os.path.join(workspace_dir, 'latest_model.pt') # Fallback to /tmp for backward compatibility if os.path.exists(workspace_model_path): model_path = workspace_model_path word_to_idx_path = os.path.join(workspace_dir, 'word_to_idx.pt') label_encoder_path = os.path.join(workspace_dir, 'label_encoder.pkl') storage_location = '/workspace' else: logger.info("Model not found in /workspace, trying /tmp...") model_path = '/tmp/latest_model.pt' word_to_idx_path = '/tmp/word_to_idx.pt' label_encoder_path = '/tmp/label_encoder.pkl' storage_location = '/tmp' try: # Check if all files exist if not os.path.exists(model_path) or not os.path.exists(word_to_idx_path) or not os.path.exists(label_encoder_path): return None, None, None # Load word_to_idx word_to_idx = torch.load(word_to_idx_path) # Validate vocabulary has required keys if not isinstance(word_to_idx, dict): raise ValueError(f"word_to_idx must be a dictionary, got {type(word_to_idx)}") if '' not in word_to_idx: raise ValueError("Vocabulary must contain '' key") if '' not in word_to_idx: raise ValueError("Vocabulary must contain '' key") # Load label_encoder with open(label_encoder_path, 'rb') as f: label_encoder = pickle.load(f) # Validate label_encoder if not hasattr(label_encoder, 'classes_'): raise ValueError("label_encoder must have 'classes_' attribute") # Determine number of classes from label_encoder num_classes = len(label_encoder.classes_) if num_classes < 2: raise ValueError(f"Model must have at least 2 classes, got {num_classes}") # Create model and load state dict # Explicitly load on CPU (HF Spaces typically use CPU) model = TalmudClassifierLSTM(len(word_to_idx), EMBEDDING_DIM, HIDDEN_DIM, num_classes) state_dict = torch.load(model_path, map_location='cpu') # Check if state dict keys match model architecture model_keys = set(model.state_dict().keys()) state_dict_keys = set(state_dict.keys()) if model_keys != state_dict_keys: missing_keys = model_keys - state_dict_keys unexpected_keys = state_dict_keys - model_keys error_msg = f"Model architecture mismatch. Missing keys: {missing_keys}, Unexpected keys: {unexpected_keys}" logger.error(error_msg) raise RuntimeError(error_msg) model.load_state_dict(state_dict, strict=True) model.eval() # Ensure model is on CPU model = model.cpu() logger.info(f"Successfully loaded model artifacts from {storage_location}") return model, word_to_idx, label_encoder except Exception as e: logger.error(f"Error loading model artifacts: {e}", exc_info=True) return None, None, None @app.post("/predict") async def predict_endpoint(request: PredictionRequest): """ On-demand prediction endpoint. Accepts daf text and generates predictions using the latest trained model. Authentication: Requires TRAINING_CALLBACK_TOKEN to be set in environment variables. The token must match the auth_token sent in the request body. """ # Verify authentication token # Security: Always require authentication token to match TRAINING_CALLBACK_TOKEN expected_token = os.getenv('TRAINING_CALLBACK_TOKEN') if not expected_token: logger.error("TRAINING_CALLBACK_TOKEN not set in environment - prediction endpoint is insecure!") raise HTTPException( status_code=500, detail="Server configuration error: TRAINING_CALLBACK_TOKEN not configured" ) if not request.auth_token or request.auth_token != expected_token: raise HTTPException( status_code=401, detail="Unauthorized: Invalid authentication token" ) if not request.daf_text or not request.daf_text.strip(): raise HTTPException( status_code=400, detail="Missing or empty daf_text" ) # Load model artifacts model, word_to_idx, label_encoder = load_model_artifacts() if model is None or word_to_idx is None or label_encoder is None: raise HTTPException( status_code=404, detail="Model not found. Please train a model first by triggering training from your Vercel app." ) try: # Generate predictions logger.info("Generating predictions for daf text...") ranges = generate_predictions_for_daf( model, request.daf_text, word_to_idx, label_encoder ) logger.info(f"Generated {len(ranges)} prediction ranges") return { "success": True, "ranges": ranges } except Exception as e: logger.error(f"Error generating predictions: {e}", exc_info=True) raise HTTPException( status_code=500, detail=f"Error generating predictions: {str(e)}" ) if __name__ == "__main__": import uvicorn port = int(os.getenv("PORT", 7860)) uvicorn.run(app, host="0.0.0.0", port=port)