# Provenance (Hugging Face Spaces commit header): author shelfgot,
# commit 4b4ea56 "persistent prediction model" (verified).
"""
FastAPI webhook server for Hugging Face Spaces
Handles training requests from Vercel and sends results back via callback
"""
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import requests
import os
import asyncio
import threading
import logging
from typing import Optional
from pydantic import BaseModel
import torch
import pickle
from train import train_model, TalmudClassifierLSTM, MAX_LEN, EMBEDDING_DIM, HIDDEN_DIM
from predict import generate_predictions_for_daf
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Talmud Language Classifier Training")

# Production origin of the Vercel frontend; None when the env var is unset.
VERCEL_BASE_URL = os.getenv('VERCEL_BASE_URL')
# Origins of Vercel preview deployments (9-character deployment id).
# Anchored with \Z so that a host which merely *starts* with a preview URL
# (e.g. '...vercel.app.example.com') is rejected even if the CORS middleware
# applies the pattern as a prefix match (re.match) instead of re.fullmatch.
VERCEL_PREVIEW_URL_REGEX = r'https://talmud-annotation-platform-[a-zA-Z0-9]{9}-shmans-projects\.vercel\.app\Z'
# Enable CORS so browsers on the Vercel frontend (the production origin or a
# preview deployment matching VERCEL_PREVIEW_URL_REGEX) can call this API.
# Server-to-server callbacks made with `requests` are not subject to CORS.
#
# Browser Origin headers never carry a trailing slash, so strip one from the
# configured URL; when VERCEL_BASE_URL is unset, allow no explicit origin
# instead of passing [None] to the middleware.
_cors_allowed_origins = [VERCEL_BASE_URL.rstrip('/')] if VERCEL_BASE_URL else []
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allowed_origins,
    allow_origin_regex=VERCEL_PREVIEW_URL_REGEX,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Module-level training state shared between the request handlers
# (/, /train, /status) and the background training thread.
training_lock = threading.Lock()  # serializes check-and-set of training_in_progress
training_in_progress: bool = False  # True while a training run is active
training_result: Optional[dict] = None  # outcome dict of the most recent run
training_error: Optional[str] = None  # error message of the most recent run, if any
class TrainingRequest(BaseModel):
    """Request body for POST /train."""
    training_data: str  # passed unchanged to train_model()
    callback_url: str  # URL that run_training_async POSTs the results to
    callback_auth_token: str  # echoed back in the callback payload under 'auth_token'
    timestamp: Optional[str] = None  # accepted but not read by any code in this file
class PredictionRequest(BaseModel):
    """Request body for POST /predict."""
    daf_text: str  # daf text to classify; /predict rejects empty or whitespace-only text
    auth_token: str  # must equal the TRAINING_CALLBACK_TOKEN environment variable
def _send_stats_callback(callback_url: str, callback_auth_token: str, stats: dict):
    """
    POST the test-set stats to ``callback_url`` and classify the reply.

    Returns:
        A ``(training_result, training_error)`` pair. ``training_error`` is
        None only when the callback answered HTTP 200 with a JSON object
        whose 'success' field is truthy.

    Raises:
        KeyError/TypeError if ``stats`` lacks 'accuracy', 'loss' or a dict
        'f1_scores' (they are formatted into a log line).
    """
    # Only the test-set stats are sent; predictions for all dafim are not generated.
    callback_payload = {
        'stats': stats,
        'auth_token': callback_auth_token
    }
    logger.info(f"Sending callback to {callback_url}...")
    logger.info(f"Callback payload stats: accuracy={stats['accuracy']:.4f}, loss={stats['loss']:.4f}, f1_scores={list(stats['f1_scores'].keys())}")
    try:
        response = requests.post(
            callback_url,
            json=callback_payload,
            timeout=60,  # small payload: stats only, no predictions
            headers={'Content-Type': 'application/json'}
        )
    except requests.exceptions.RequestException as e:
        logger.error(f"Callback request exception: {str(e)}", exc_info=True)
        error = f"Callback request failed: {str(e)}"
        return {'success': False, 'error': error}, error

    logger.info(f"Callback response status: {response.status_code}")
    logger.info(f"Callback response headers: {dict(response.headers)}")
    logger.info(f"Callback response body: {response.text}")

    if response.status_code != 200:
        logger.error(f"Callback failed with status {response.status_code}: {response.text}")
        error = f"Callback failed: {response.status_code}"
        return {'success': False, 'error': error}, error

    try:
        response_json = response.json()
        logger.info(f"Callback response JSON: {response_json}")
        acknowledged = response_json.get('success')
    except Exception as e:
        # Non-JSON body, or JSON that is not an object (no .get): success
        # cannot be confirmed, so report the run as uncertain.
        logger.warning(f"Callback returned 200 but couldn't parse JSON: {e}, body: {response.text}")
        error = f"Callback returned 200 but response couldn't be parsed: {str(e)}"
        return {'success': False, 'error': error, 'warning': 'Training completed but callback response unclear'}, error

    if acknowledged:
        logger.info("Callback sent successfully and processed by Vercel")
        return {'success': True, 'message': 'Training completed'}, None
    logger.warning(f"Callback returned 200 but success=false: {response_json}")
    error = f"Callback returned success=false: {response_json}"
    return {'success': False, 'error': error}, error


def run_training_async(training_data: str, callback_url: str, callback_auth_token: str):
    """
    Train the model and report the test-set (ground truth) results via callback.

    Used as the target of the background thread started by /train, which
    sets ``training_in_progress`` before starting it. This function always
    clears that flag when it finishes.

    The outcome is stored in the module-level ``training_result`` and
    ``training_error`` globals, which /status reports. Predictions for all
    dafim are NOT generated here.

    Args:
        training_data: Raw training data, passed unchanged to ``train_model``.
        callback_url: URL that receives the stats (or the error) as JSON.
        callback_auth_token: Token included as 'auth_token' in every callback payload.
    """
    global training_in_progress, training_result, training_error
    try:
        # Reset results for the new run (training_in_progress was already
        # set to True by the endpoint).
        training_result = None
        training_error = None
        logger.info("Starting model training...")
        result = train_model(training_data)
        stats = result['stats']
        logger.info(f"Training completed. Accuracy: {stats['accuracy']:.4f}")
        logger.info(f"Test set results - Accuracy: {stats['accuracy']:.4f}, Loss: {stats['loss']:.4f}")
        logger.info(f"F1 Scores: {stats['f1_scores']}")
        training_result, training_error = _send_stats_callback(
            callback_url, callback_auth_token, stats
        )
    except Exception as e:
        logger.error(f"Training error: {str(e)}", exc_info=True)
        training_error = str(e)
        training_result = {'success': False, 'error': str(e)}
        # Best-effort error callback. Training already failed, so a callback
        # failure must not mask that error, but it is logged rather than
        # silently discarded.
        try:
            requests.post(
                callback_url,
                json={
                    'error': str(e),
                    'auth_token': callback_auth_token
                },
                timeout=10
            )
        except Exception as callback_exc:
            logger.warning(f"Failed to send error callback to {callback_url}: {callback_exc}")
    finally:
        # Clear the flag under the same lock /train uses for its check-and-set.
        with training_lock:
            training_in_progress = False
@app.get("/")
async def root():
    """Liveness probe that also reports whether a training run is active."""
    payload = {
        "status": "running",
        "service": "Talmud Language Classifier Training",
    }
    payload["training_in_progress"] = training_in_progress
    return payload
@app.post("/train")
async def train_endpoint(request: TrainingRequest):
    """
    Training webhook endpoint.

    Accepts training data plus a callback URL and token. Starts training in a
    background daemon thread and returns immediately.

    Raises:
        HTTPException(429): a training run is already in progress.
        HTTPException(400): training_data, callback_url or callback_auth_token is empty.
    """
    global training_in_progress
    # NOTE(review): unlike /predict, this endpoint does not check
    # TRAINING_CALLBACK_TOKEN. Any caller can start a run and choose the
    # callback_url that the results (including callback_auth_token) are
    # POSTed to. Confirm this is intended.

    # Check, validate and claim the flag in one critical section. An invalid
    # request never sets the flag, so it cannot cause a spurious 429 for a
    # concurrent valid request. 429 still takes precedence over 400.
    with training_lock:
        if training_in_progress:
            raise HTTPException(
                status_code=429,
                detail="Training already in progress. Please wait for current training to complete."
            )
        if not request.training_data or not request.callback_url:
            raise HTTPException(
                status_code=400,
                detail="Missing training_data or callback_url"
            )
        if not request.callback_auth_token:
            raise HTTPException(
                status_code=400,
                detail="Missing callback_auth_token"
            )
        training_in_progress = True

    logger.info("Received training request")
    training_thread = threading.Thread(
        target=run_training_async,
        args=(
            request.training_data,
            request.callback_url,
            request.callback_auth_token
        ),
        daemon=True
    )
    try:
        training_thread.start()
    except Exception:
        # Thread.start() raises (e.g. RuntimeError) when the thread cannot be
        # created. run_training_async's finally-block would then never run, so
        # release the flag here; otherwise every later /train call returns 429.
        logger.error("Failed to start training thread", exc_info=True)
        with training_lock:
            training_in_progress = False
        raise

    # Return immediately; results are reported via the callback and /status.
    return {
        "success": True,
        "message": "Training started",
        "status": "processing"
    }
@app.get("/status")
async def get_status():
    """Report the in-progress flag together with the latest run's result and error."""
    return dict(
        training_in_progress=training_in_progress,
        result=training_result,
        error=training_error,
    )
@app.get("/health")
async def health_check():
    """Minimal liveness check; always reports a healthy status."""
    status = "healthy"
    return {"status": status}
def load_model_artifacts():
    """
    Load the latest trained model together with its vocabulary and label encoder.

    Looks in /workspace (persistent storage) first. If /workspace has no
    latest_model.pt, falls back to /tmp for backward compatibility. All three
    artifacts are then read from the chosen directory:
      - latest_model.pt    state dict for TalmudClassifierLSTM
      - word_to_idx.pt     vocabulary dict; must contain '<UNK>' and '<PAD>'
      - label_encoder.pkl  pickled object with a ``classes_`` attribute
                           (at least 2 classes)

    Returns:
        (model, word_to_idx, label_encoder), with the model in eval mode on
        CPU, or (None, None, None) if any artifact is missing or fails to
        load or validate. The reason is logged in both failure cases.
    """
    workspace_dir = '/workspace'
    if os.path.exists(os.path.join(workspace_dir, 'latest_model.pt')):
        storage_location = workspace_dir
    else:
        # Fallback to /tmp for backward compatibility.
        logger.info("Model not found in /workspace, trying /tmp...")
        storage_location = '/tmp'
    model_path = os.path.join(storage_location, 'latest_model.pt')
    word_to_idx_path = os.path.join(storage_location, 'word_to_idx.pt')
    label_encoder_path = os.path.join(storage_location, 'label_encoder.pkl')

    try:
        # All three artifacts must come from the same directory; say which
        # are missing instead of failing silently.
        missing = [
            path for path in (model_path, word_to_idx_path, label_encoder_path)
            if not os.path.exists(path)
        ]
        if missing:
            logger.warning(f"Model artifacts incomplete in {storage_location}; missing: {missing}")
            return None, None, None

        # Load word_to_idx on CPU, consistent with the model state dict below.
        word_to_idx = torch.load(word_to_idx_path, map_location='cpu')
        if not isinstance(word_to_idx, dict):
            raise ValueError(f"word_to_idx must be a dictionary, got {type(word_to_idx)}")
        if '<UNK>' not in word_to_idx:
            raise ValueError("Vocabulary must contain '<UNK>' key")
        if '<PAD>' not in word_to_idx:
            raise ValueError("Vocabulary must contain '<PAD>' key")

        # NOTE(review): pickle.load can execute arbitrary code embedded in the
        # file. This is only safe while label_encoder.pkl is written exclusively
        # by this service (presumably by train_model). Confirm nothing else can
        # write to this directory.
        with open(label_encoder_path, 'rb') as f:
            label_encoder = pickle.load(f)
        if not hasattr(label_encoder, 'classes_'):
            raise ValueError("label_encoder must have 'classes_' attribute")

        # The classifier's output size comes from the label encoder.
        num_classes = len(label_encoder.classes_)
        if num_classes < 2:
            raise ValueError(f"Model must have at least 2 classes, got {num_classes}")

        # Build the model and load the state dict explicitly onto CPU
        # (HF Spaces typically use CPU).
        model = TalmudClassifierLSTM(len(word_to_idx), EMBEDDING_DIM, HIDDEN_DIM, num_classes)
        state_dict = torch.load(model_path, map_location='cpu')

        # Report exactly which keys differ before the strict load fails.
        model_keys = set(model.state_dict().keys())
        state_dict_keys = set(state_dict.keys())
        if model_keys != state_dict_keys:
            missing_keys = model_keys - state_dict_keys
            unexpected_keys = state_dict_keys - model_keys
            error_msg = f"Model architecture mismatch. Missing keys: {missing_keys}, Unexpected keys: {unexpected_keys}"
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        model.load_state_dict(state_dict, strict=True)
        model.eval()
        model = model.cpu()
        logger.info(f"Successfully loaded model artifacts from {storage_location}")
        return model, word_to_idx, label_encoder
    except Exception as e:
        logger.error(f"Error loading model artifacts: {e}", exc_info=True)
        return None, None, None
@app.post("/predict")
async def predict_endpoint(request: PredictionRequest):
    """
    On-demand prediction endpoint.

    Accepts daf text and generates predictions using the latest trained model.

    Authentication: TRAINING_CALLBACK_TOKEN must be set in the environment,
    and the auth_token in the request body must match it exactly.

    Raises:
        HTTPException(500): TRAINING_CALLBACK_TOKEN is not configured, or
            generating predictions failed.
        HTTPException(401): auth_token is missing or does not match.
        HTTPException(400): daf_text is missing or blank.
        HTTPException(404): no complete, valid set of model artifacts was found.
    """
    import hmac  # stdlib; constant-time token comparison

    # Security: always require the auth token to match TRAINING_CALLBACK_TOKEN.
    expected_token = os.getenv('TRAINING_CALLBACK_TOKEN')
    if not expected_token:
        logger.error("TRAINING_CALLBACK_TOKEN not set in environment - prediction endpoint is insecure!")
        raise HTTPException(
            status_code=500,
            detail="Server configuration error: TRAINING_CALLBACK_TOKEN not configured"
        )

    # hmac.compare_digest avoids content-based short-circuiting, so response
    # timing does not reveal how much of the token matched. Comparing bytes
    # (rather than str) means non-ASCII input cannot raise TypeError.
    provided_token = request.auth_token or ''
    if not provided_token or not hmac.compare_digest(
        provided_token.encode('utf-8'), expected_token.encode('utf-8')
    ):
        raise HTTPException(
            status_code=401,
            detail="Unauthorized: Invalid authentication token"
        )

    if not request.daf_text or not request.daf_text.strip():
        raise HTTPException(
            status_code=400,
            detail="Missing or empty daf_text"
        )

    # Loading artifacts (torch.load / pickle) and running inference are
    # blocking, CPU-bound calls. Run them in the default thread pool so this
    # async handler does not stall the event loop and, with it, every other
    # request (/status, /health, ...).
    loop = asyncio.get_running_loop()
    model, word_to_idx, label_encoder = await loop.run_in_executor(None, load_model_artifacts)
    if model is None or word_to_idx is None or label_encoder is None:
        raise HTTPException(
            status_code=404,
            detail="Model not found. Please train a model first by triggering training from your Vercel app."
        )

    try:
        logger.info("Generating predictions for daf text...")
        ranges = await loop.run_in_executor(
            None,
            generate_predictions_for_daf,
            model, request.daf_text, word_to_idx, label_encoder
        )
        logger.info(f"Generated {len(ranges)} prediction ranges")
        return {
            "success": True,
            "ranges": ranges
        }
    except Exception as e:
        logger.error(f"Error generating predictions: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Error generating predictions: {str(e)}"
        )
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces; PORT overrides the default of 7860.
    listen_port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)