"""
FastAPI webhook server for Hugging Face Spaces
Handles training requests from Vercel and sends results back via callback
"""
# Standard library
import asyncio
import hmac
import logging
import os
import pickle
import threading
from typing import Optional

# Third-party
import requests
import torch
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# Local
from train import train_model, TalmudClassifierLSTM, MAX_LEN, EMBEDDING_DIM, HIDDEN_DIM
from predict import generate_predictions_for_daf
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Talmud Language Classifier Training")

# Base URL of the production Vercel deployment; may be unset in local dev.
VERCEL_BASE_URL = os.getenv('VERCEL_BASE_URL')
# Matches ephemeral Vercel preview-deployment URLs for this project.
VERCEL_PREVIEW_URL_REGEX = r'https://talmud-annotation-platform-[a-zA-Z0-9]{9}-shmans-projects\.vercel\.app'

# Enable CORS for Vercel callbacks.
# Guard against an unset VERCEL_BASE_URL: os.getenv returns None when the
# variable is missing, and passing [None] as allow_origins is meaningless.
# Fall back to an empty explicit-origin list and rely on the preview regex.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[VERCEL_BASE_URL] if VERCEL_BASE_URL else [],
    allow_origin_regex=VERCEL_PREVIEW_URL_REGEX,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Training state (module-level; mutated by the endpoint and the worker thread)
training_in_progress = False  # True while a background training thread runs
training_result = None        # Outcome dict of the most recent run, or None
training_error = None         # Error message of the most recent run, or None
training_lock = threading.Lock()  # Lock to prevent race conditions
class TrainingRequest(BaseModel):
    """Request body for POST /train: training data plus callback routing."""
    # Serialized training data consumed by train.train_model().
    training_data: str
    # Vercel endpoint that receives the training results.
    callback_url: str
    # Shared secret echoed back in the callback so the receiver can verify it.
    callback_auth_token: str
    # Optional client-supplied timestamp; not read by this server.
    timestamp: Optional[str] = None
class PredictionRequest(BaseModel):
    """Request body for POST /predict: daf text plus the auth token."""
    # Raw text of the daf to generate predictions for.
    daf_text: str
    # Must match the TRAINING_CALLBACK_TOKEN environment variable.
    auth_token: str
def _interpret_callback_response(response):
    """
    Interpret the HTTP response from the Vercel callback.

    Returns:
        (result_dict, error_message): error_message is None only when the
        callback returned 200 with a JSON body containing success=true.
    """
    logger.info(f"Callback response status: {response.status_code}")
    logger.info(f"Callback response headers: {dict(response.headers)}")
    logger.info(f"Callback response body: {response.text}")
    if response.status_code != 200:
        logger.error(f"Callback failed with status {response.status_code}: {response.text}")
        error = f"Callback failed: {response.status_code}"
        return {'success': False, 'error': error}, error
    try:
        response_json = response.json()
    except Exception as e:
        logger.warning(f"Callback returned 200 but couldn't parse JSON: {e}, body: {response.text}")
        # If we can't parse the response, we can't confirm success, so mark as uncertain
        error = f"Callback returned 200 but response couldn't be parsed: {str(e)}"
        return {
            'success': False,
            'error': error,
            'warning': 'Training completed but callback response unclear'
        }, error
    logger.info(f"Callback response JSON: {response_json}")
    if response_json.get('success'):
        logger.info("Callback sent successfully and processed by Vercel")
        return {'success': True, 'message': 'Training completed'}, None
    logger.warning(f"Callback returned 200 but success=false: {response_json}")
    error = f"Callback returned success=false: {response_json}"
    return {'success': False, 'error': error}, error


def run_training_async(training_data: str, callback_url: str, callback_auth_token: str):
    """
    Run training in a separate thread to avoid blocking the request.

    Trains the model on the provided training data and POSTs the test-set
    stats (results on the ground-truth test set) to ``callback_url``.
    Does not generate predictions for all dafim.

    Args:
        training_data: Serialized training data passed to train_model().
        callback_url: Vercel endpoint that receives the results.
        callback_auth_token: Token included in the callback payload so the
            receiver can authenticate it.
    """
    global training_in_progress, training_result, training_error
    try:
        # Note: training_in_progress is already set to True by the endpoint.
        # Reset results for the new training run.
        training_result = None
        training_error = None
        logger.info("Starting model training...")
        result = train_model(training_data)
        stats = result['stats']
        logger.info(f"Training completed. Accuracy: {stats['accuracy']:.4f}")
        logger.info(f"Test set results - Accuracy: {stats['accuracy']:.4f}, Loss: {stats['loss']:.4f}")
        logger.info(f"F1 Scores: {stats['f1_scores']}")
        # Only the stats (test results on ground truth) go back to Vercel.
        callback_payload = {
            'stats': stats,
            'auth_token': callback_auth_token
        }
        logger.info(f"Sending callback to {callback_url}...")
        logger.info(f"Callback payload stats: accuracy={stats['accuracy']:.4f}, loss={stats['loss']:.4f}, f1_scores={list(stats['f1_scores'].keys())}")
        try:
            response = requests.post(
                callback_url,
                json=callback_payload,
                timeout=60,  # Reduced timeout since we're not generating predictions
                headers={'Content-Type': 'application/json'}
            )
            training_result, training_error = _interpret_callback_response(response)
        except requests.exceptions.RequestException as e:
            logger.error(f"Callback request exception: {str(e)}", exc_info=True)
            training_error = f"Callback request failed: {str(e)}"
            training_result = {'success': False, 'error': training_error}
    except Exception as e:
        logger.error(f"Training error: {str(e)}", exc_info=True)
        training_error = str(e)
        training_result = {'success': False, 'error': str(e)}
        # Best-effort error callback; failures here are deliberately ignored
        # because the primary failure has already been recorded above.
        # (was a bare `except:` — narrowed to Exception so that
        # KeyboardInterrupt/SystemExit are not swallowed)
        try:
            requests.post(
                callback_url,
                json={'error': str(e), 'auth_token': callback_auth_token},
                timeout=10
            )
        except Exception:
            pass
    finally:
        # Reset under the lock, consistent with the /train endpoint's
        # check-and-set of the same flag.
        with training_lock:
            training_in_progress = False
@app.get("/")
async def root():
    """Health-check / service-info endpoint for the root path."""
    status_payload = {
        "status": "running",
        "service": "Talmud Language Classifier Training",
        "training_in_progress": training_in_progress,
    }
    return status_payload
@app.post("/train")
async def train_endpoint(request: TrainingRequest):
    """
    Training webhook endpoint.

    Accepts training data and a callback URL, starts training in a
    background thread, and returns immediately.

    Raises:
        HTTPException 400: A required field is missing or empty.
        HTTPException 429: A training run is already in progress.
    """
    global training_in_progress
    # Validate BEFORE claiming the training slot, so invalid requests never
    # toggle the in-progress flag. (The previous version set the flag first
    # and then had to roll it back inside a try/except on validation failure.)
    if not request.training_data or not request.callback_url:
        raise HTTPException(
            status_code=400,
            detail="Missing training_data or callback_url"
        )
    if not request.callback_auth_token:
        raise HTTPException(
            status_code=400,
            detail="Missing callback_auth_token"
        )
    # Check-and-set atomically under the lock to prevent concurrent runs.
    with training_lock:
        if training_in_progress:
            raise HTTPException(
                status_code=429,
                detail="Training already in progress. Please wait for current training to complete."
            )
        training_in_progress = True
    logger.info("Received training request")
    # Start training in a background thread.
    # run_training_async resets training_in_progress in its finally block.
    training_thread = threading.Thread(
        target=run_training_async,
        args=(
            request.training_data,
            request.callback_url,
            request.callback_auth_token
        ),
        daemon=True
    )
    training_thread.start()
    # Return immediately; clients poll /status or wait for the callback.
    return {
        "success": True,
        "message": "Training started",
        "status": "processing"
    }
@app.get("/status")
async def get_status():
    """Report whether training is running plus the most recent outcome."""
    snapshot = {
        "training_in_progress": training_in_progress,
        "result": training_result,
        "error": training_error,
    }
    return snapshot
@app.get("/health")
async def health_check():
    """Liveness probe: reports healthy whenever the process is serving."""
    payload = {"status": "healthy"}
    return payload
def load_model_artifacts():
    """
    Load model artifacts from the /workspace directory (persistent storage),
    falling back to /tmp for backward compatibility.

    Returns:
        (model, word_to_idx, label_encoder) on success, or
        (None, None, None) when artifacts are missing or fail validation.
    """
    # Prefer /workspace (persistent); fall back to the legacy /tmp location.
    workspace_dir = '/workspace'
    workspace_model = os.path.join(workspace_dir, 'latest_model.pt')
    if os.path.exists(workspace_model):
        storage_location = '/workspace'
        model_path = workspace_model
        word_to_idx_path = os.path.join(workspace_dir, 'word_to_idx.pt')
        label_encoder_path = os.path.join(workspace_dir, 'label_encoder.pkl')
    else:
        logger.info("Model not found in /workspace, trying /tmp...")
        storage_location = '/tmp'
        model_path = '/tmp/latest_model.pt'
        word_to_idx_path = '/tmp/word_to_idx.pt'
        label_encoder_path = '/tmp/label_encoder.pkl'
    try:
        # Bail out early unless every artifact file is present.
        required_files = (model_path, word_to_idx_path, label_encoder_path)
        if not all(os.path.exists(p) for p in required_files):
            return None, None, None

        # Vocabulary: must be a dict with both special tokens.
        word_to_idx = torch.load(word_to_idx_path)
        if not isinstance(word_to_idx, dict):
            raise ValueError(f"word_to_idx must be a dictionary, got {type(word_to_idx)}")
        if '<UNK>' not in word_to_idx:
            raise ValueError("Vocabulary must contain '<UNK>' key")
        if '<PAD>' not in word_to_idx:
            raise ValueError("Vocabulary must contain '<PAD>' key")

        # Label encoder: needs classes_ with at least two classes.
        with open(label_encoder_path, 'rb') as f:
            label_encoder = pickle.load(f)
        if not hasattr(label_encoder, 'classes_'):
            raise ValueError("label_encoder must have 'classes_' attribute")
        num_classes = len(label_encoder.classes_)
        if num_classes < 2:
            raise ValueError(f"Model must have at least 2 classes, got {num_classes}")

        # Rebuild the architecture and load weights explicitly on CPU
        # (HF Spaces typically use CPU).
        model = TalmudClassifierLSTM(len(word_to_idx), EMBEDDING_DIM, HIDDEN_DIM, num_classes)
        state_dict = torch.load(model_path, map_location='cpu')

        # Refuse to load a checkpoint whose keys don't match the architecture.
        expected_keys = set(model.state_dict().keys())
        loaded_keys = set(state_dict.keys())
        if expected_keys != loaded_keys:
            missing_keys = expected_keys - loaded_keys
            unexpected_keys = loaded_keys - expected_keys
            error_msg = f"Model architecture mismatch. Missing keys: {missing_keys}, Unexpected keys: {unexpected_keys}"
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        model.load_state_dict(state_dict, strict=True)
        model.eval()
        # Ensure the model itself lives on CPU as well.
        model = model.cpu()
        logger.info(f"Successfully loaded model artifacts from {storage_location}")
        return model, word_to_idx, label_encoder
    except Exception as e:
        logger.error(f"Error loading model artifacts: {e}", exc_info=True)
        return None, None, None
@app.post("/predict")
async def predict_endpoint(request: PredictionRequest):
    """
    On-demand prediction endpoint.

    Accepts daf text and generates predictions using the latest trained model.

    Authentication: requires TRAINING_CALLBACK_TOKEN to be set in the
    environment; the request's auth_token must match it.

    Raises:
        HTTPException 500: Token not configured, or prediction failed.
        HTTPException 401: Invalid authentication token.
        HTTPException 400: Missing or empty daf_text.
        HTTPException 404: No trained model artifacts available.
    """
    # Security: always require the auth token to match TRAINING_CALLBACK_TOKEN.
    expected_token = os.getenv('TRAINING_CALLBACK_TOKEN')
    if not expected_token:
        logger.error("TRAINING_CALLBACK_TOKEN not set in environment - prediction endpoint is insecure!")
        raise HTTPException(
            status_code=500,
            detail="Server configuration error: TRAINING_CALLBACK_TOKEN not configured"
        )
    # hmac.compare_digest is a constant-time comparison, avoiding the timing
    # side channel a plain `!=` check leaks on secret comparison.
    if not request.auth_token or not hmac.compare_digest(request.auth_token, expected_token):
        raise HTTPException(
            status_code=401,
            detail="Unauthorized: Invalid authentication token"
        )
    if not request.daf_text or not request.daf_text.strip():
        raise HTTPException(
            status_code=400,
            detail="Missing or empty daf_text"
        )
    # Load the most recently trained model artifacts (from /workspace or /tmp).
    model, word_to_idx, label_encoder = load_model_artifacts()
    if model is None or word_to_idx is None or label_encoder is None:
        raise HTTPException(
            status_code=404,
            detail="Model not found. Please train a model first by triggering training from your Vercel app."
        )
    try:
        logger.info("Generating predictions for daf text...")
        ranges = generate_predictions_for_daf(
            model, request.daf_text, word_to_idx, label_encoder
        )
        logger.info(f"Generated {len(ranges)} prediction ranges")
        return {
            "success": True,
            "ranges": ranges
        }
    except Exception as e:
        logger.error(f"Error generating predictions: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Error generating predictions: {str(e)}"
        )
if __name__ == "__main__":
    # Imported lazily so importing this module doesn't require uvicorn.
    import uvicorn
    # Use the PORT env var when set; default to 7860 otherwise.
    # (Removed stray trailing "|" scrape residue that broke the syntax.)
    port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)