File size: 14,239 Bytes
1f71502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172b660
 
1f71502
172b660
 
1f71502
 
 
 
 
 
4b4ea56
a4d498a
08a55bd
 
1f71502
 
 
a4d498a
08a55bd
1f71502
 
 
 
4b4ea56
1f71502
 
 
 
3319ff1
1f71502
 
 
 
 
 
 
172b660
 
 
 
1f71502
 
 
172b660
 
1f71502
 
 
 
3319ff1
 
1f71502
 
 
 
 
 
 
 
 
 
172b660
 
1f71502
172b660
1f71502
 
 
 
 
 
 
3883b9d
1f71502
3883b9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3319ff1
3883b9d
 
3319ff1
 
3883b9d
 
3319ff1
 
 
3883b9d
 
 
 
 
 
 
1f71502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3319ff1
 
 
 
 
 
 
 
 
 
1f71502
3319ff1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f71502
 
 
 
3319ff1
1f71502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172b660
 
4b4ea56
 
172b660
 
4b4ea56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172b660
 
 
 
 
 
 
 
 
3319ff1
 
 
 
 
 
 
 
172b660
 
 
 
3319ff1
 
 
 
172b660
 
 
3319ff1
 
 
172b660
 
 
3319ff1
 
 
 
 
 
 
 
 
 
 
 
 
 
172b660
 
 
 
4b4ea56
172b660
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f71502
 
 
4b4ea56
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
"""
FastAPI webhook server for Hugging Face Spaces
Handles training requests from Vercel and sends results back via callback
"""

import asyncio
import logging
import os
import pickle
import secrets
import threading
from typing import Optional

import requests
import torch
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from train import train_model, TalmudClassifierLSTM, MAX_LEN, EMBEDDING_DIM, HIDDEN_DIM
from predict import generate_predictions_for_daf

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Talmud Language Classifier Training")

# Base URL of the production Vercel deployment; may be unset in local dev.
VERCEL_BASE_URL = os.getenv('VERCEL_BASE_URL')
# Matches Vercel preview-deployment URLs (9-character deployment hash).
VERCEL_PREVIEW_URL_REGEX = r'https://talmud-annotation-platform-[a-zA-Z0-9]{9}-shmans-projects\.vercel\.app'

# Enable CORS for Vercel callbacks.
# Guard against VERCEL_BASE_URL being unset: passing [None] as an origin list
# is invalid; preview deployments are still allowed via the regex.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[VERCEL_BASE_URL] if VERCEL_BASE_URL else [],
    allow_origin_regex=VERCEL_PREVIEW_URL_REGEX,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Training state shared between the request handlers and the worker thread.
training_in_progress = False      # True while a training thread is running
training_result = None            # outcome dict of the last run, or None
training_error = None             # error string from the last run, or None
training_lock = threading.Lock()  # Lock to prevent race conditions

class TrainingRequest(BaseModel):
    """Payload for the /train webhook sent by the Vercel app."""
    # Serialized training data; passed through verbatim to train.train_model
    training_data: str
    # Vercel endpoint that receives the training results via POST
    callback_url: str
    # Token echoed back in the callback so the receiver can authenticate it
    callback_auth_token: str
    # Optional client-side timestamp; not used by the server
    timestamp: Optional[str] = None

class PredictionRequest(BaseModel):
    """Payload for the /predict endpoint."""
    # Raw daf text to generate predictions for
    daf_text: str
    # Must match the TRAINING_CALLBACK_TOKEN environment variable
    auth_token: str

def run_training_async(training_data: str, callback_url: str, callback_auth_token: str):
    """
    Run training in a background thread so the /train request can return at once.

    Trains the model on ``training_data`` and POSTs the resulting test-set
    stats (on the ground-truth test split) to ``callback_url``.  Does not
    generate predictions for all dafim.  The outcome is recorded in the
    module-level ``training_result`` / ``training_error`` globals (read by
    /status), and ``training_in_progress`` is cleared when the thread exits.

    Args:
        training_data: Serialized training data, passed through to train_model.
        callback_url: Vercel endpoint that receives the results.
        callback_auth_token: Token included in the callback payload so the
            receiver can authenticate it.
    """
    global training_in_progress, training_result, training_error

    try:
        # Note: training_in_progress was already set to True by the endpoint.
        # Reset results for the new training run.
        training_result = None
        training_error = None

        logger.info("Starting model training...")

        # Train the model; train_model returns a dict with a 'stats' entry.
        result = train_model(training_data)
        stats = result['stats']

        logger.info(f"Training completed. Accuracy: {stats['accuracy']:.4f}")
        logger.info(f"Test set results - Accuracy: {stats['accuracy']:.4f}, Loss: {stats['loss']:.4f}")
        logger.info(f"F1 Scores: {stats['f1_scores']}")

        # Callback payload carries only the test-set stats plus the auth token.
        callback_payload = {
            'stats': stats,
            'auth_token': callback_auth_token
        }

        logger.info(f"Sending callback to {callback_url}...")
        logger.info(f"Callback payload stats: accuracy={stats['accuracy']:.4f}, loss={stats['loss']:.4f}, f1_scores={list(stats['f1_scores'].keys())}")

        try:
            response = requests.post(
                callback_url,
                json=callback_payload,
                timeout=60,  # Reduced timeout since we're not generating predictions
                headers={'Content-Type': 'application/json'}
            )

            logger.info(f"Callback response status: {response.status_code}")
            logger.info(f"Callback response headers: {dict(response.headers)}")
            logger.info(f"Callback response body: {response.text}")

            if response.status_code == 200:
                try:
                    response_json = response.json()
                    logger.info(f"Callback response JSON: {response_json}")
                    if response_json.get('success'):
                        logger.info("Callback sent successfully and processed by Vercel")
                        training_result = {'success': True, 'message': 'Training completed'}
                    else:
                        logger.warning(f"Callback returned 200 but success=false: {response_json}")
                        training_error = f"Callback returned success=false: {response_json}"
                        training_result = {'success': False, 'error': training_error}
                except Exception as e:
                    logger.warning(f"Callback returned 200 but couldn't parse JSON: {e}, body: {response.text}")
                    # Can't confirm the receiver processed the stats, so flag it.
                    training_error = f"Callback returned 200 but response couldn't be parsed: {str(e)}"
                    training_result = {'success': False, 'error': training_error, 'warning': 'Training completed but callback response unclear'}
            else:
                logger.error(f"Callback failed with status {response.status_code}: {response.text}")
                training_error = f"Callback failed: {response.status_code}"
                training_result = {'success': False, 'error': training_error}
        except requests.exceptions.RequestException as e:
            logger.error(f"Callback request exception: {str(e)}", exc_info=True)
            training_error = f"Callback request failed: {str(e)}"
            training_result = {'success': False, 'error': training_error}

    except Exception as e:
        logger.error(f"Training error: {str(e)}", exc_info=True)
        training_error = str(e)
        training_result = {'success': False, 'error': str(e)}

        # Best-effort error callback so the Vercel side isn't left waiting.
        try:
            callback_payload = {
                'error': str(e),
                'auth_token': callback_auth_token
            }
            requests.post(
                callback_url,
                json=callback_payload,
                timeout=10
            )
        except Exception as callback_exc:
            # Fix: was a bare `except:` that also swallowed SystemExit and
            # KeyboardInterrupt; keep the best-effort semantics but log it.
            logger.warning(f"Error callback failed: {callback_exc}")

    finally:
        # Clear the flag under the same lock /train uses for its check-and-set.
        with training_lock:
            training_in_progress = False

@app.get("/")
async def root():
    """Health check endpoint"""
    return {
        "status": "running",
        "service": "Talmud Language Classifier Training",
        "training_in_progress": training_in_progress
    }

@app.post("/train")
async def train_endpoint(request: TrainingRequest):
    """
    Training webhook endpoint.
    Accepts training data and callback URL, runs training in background.
    """
    global training_in_progress
    
    # Use lock to prevent race condition - check and set atomically
    with training_lock:
        if training_in_progress:
            raise HTTPException(
                status_code=429,
                detail="Training already in progress. Please wait for current training to complete."
            )
        
        # Set flag immediately to prevent concurrent requests
        training_in_progress = True
    
    # Validate request after acquiring lock
    # Use try-finally to ensure flag is reset if validation fails
    try:
        if not request.training_data or not request.callback_url:
            raise HTTPException(
                status_code=400,
                detail="Missing training_data or callback_url"
            )
        
        if not request.callback_auth_token:
            raise HTTPException(
                status_code=400,
                detail="Missing callback_auth_token"
            )
    except HTTPException:
        # Reset flag if validation fails
        with training_lock:
            training_in_progress = False
        raise
    
    logger.info("Received training request")
    
    # Start training in background thread
    # Note: training_in_progress is already set to True above
    training_thread = threading.Thread(
        target=run_training_async,
        args=(
            request.training_data,
            request.callback_url,
            request.callback_auth_token
        ),
        daemon=True
    )
    training_thread.start()
    
    # Return immediately
    return {
        "success": True,
        "message": "Training started",
        "status": "processing"
    }

@app.get("/status")
async def get_status():
    """Get current training status"""
    return {
        "training_in_progress": training_in_progress,
        "result": training_result,
        "error": training_error
    }

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy"}

def load_model_artifacts():
    """
    Load model artifacts from /workspace (persistent storage), falling back
    to /tmp for backward compatibility.

    Validates the vocabulary (must be a dict containing '<UNK>' and '<PAD>'),
    the label encoder (must expose 'classes_' with at least 2 classes), and
    that the saved state dict matches the TalmudClassifierLSTM architecture.
    The model is loaded and kept on CPU (HF Spaces typically run CPU-only).

    Returns:
        (model, word_to_idx, label_encoder) on success, or
        (None, None, None) if the artifacts are missing or fail validation.
    """
    # Try /workspace first (persistent storage).
    workspace_dir = '/workspace'
    workspace_model_path = os.path.join(workspace_dir, 'latest_model.pt')

    # Fallback to /tmp for backward compatibility with older deployments.
    if os.path.exists(workspace_model_path):
        model_path = workspace_model_path
        word_to_idx_path = os.path.join(workspace_dir, 'word_to_idx.pt')
        label_encoder_path = os.path.join(workspace_dir, 'label_encoder.pkl')
        storage_location = '/workspace'
    else:
        logger.info("Model not found in /workspace, trying /tmp...")
        model_path = '/tmp/latest_model.pt'
        word_to_idx_path = '/tmp/word_to_idx.pt'
        label_encoder_path = '/tmp/label_encoder.pkl'
        storage_location = '/tmp'

    try:
        # All three artifacts must be present.
        if not os.path.exists(model_path) or not os.path.exists(word_to_idx_path) or not os.path.exists(label_encoder_path):
            return None, None, None

        # Load word_to_idx.
        # Fix: map_location='cpu' was missing here (the weights load below
        # already had it) — without it, an artifact saved on a GPU machine
        # would fail to load on a CPU-only Space.
        word_to_idx = torch.load(word_to_idx_path, map_location='cpu')

        # Validate vocabulary has the required special tokens.
        if not isinstance(word_to_idx, dict):
            raise ValueError(f"word_to_idx must be a dictionary, got {type(word_to_idx)}")
        if '<UNK>' not in word_to_idx:
            raise ValueError("Vocabulary must contain '<UNK>' key")
        if '<PAD>' not in word_to_idx:
            raise ValueError("Vocabulary must contain '<PAD>' key")

        # Load label_encoder.
        # NOTE: pickle on a file this server wrote itself — trusted input;
        # do not point this at externally supplied files.
        with open(label_encoder_path, 'rb') as f:
            label_encoder = pickle.load(f)

        # Validate label_encoder shape (sklearn LabelEncoder-like interface).
        if not hasattr(label_encoder, 'classes_'):
            raise ValueError("label_encoder must have 'classes_' attribute")

        # Number of output classes comes from the label encoder.
        num_classes = len(label_encoder.classes_)

        if num_classes < 2:
            raise ValueError(f"Model must have at least 2 classes, got {num_classes}")

        # Create model and load state dict, explicitly on CPU.
        model = TalmudClassifierLSTM(len(word_to_idx), EMBEDDING_DIM, HIDDEN_DIM, num_classes)
        state_dict = torch.load(model_path, map_location='cpu')

        # Verify the saved state dict matches the current architecture so a
        # mismatch surfaces as a clear error rather than a cryptic one.
        model_keys = set(model.state_dict().keys())
        state_dict_keys = set(state_dict.keys())

        if model_keys != state_dict_keys:
            missing_keys = model_keys - state_dict_keys
            unexpected_keys = state_dict_keys - model_keys
            error_msg = f"Model architecture mismatch. Missing keys: {missing_keys}, Unexpected keys: {unexpected_keys}"
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        model.load_state_dict(state_dict, strict=True)
        model.eval()
        # Ensure model is on CPU.
        model = model.cpu()

        logger.info(f"Successfully loaded model artifacts from {storage_location}")
        return model, word_to_idx, label_encoder

    except Exception as e:
        logger.error(f"Error loading model artifacts: {e}", exc_info=True)
        return None, None, None

@app.post("/predict")
async def predict_endpoint(request: PredictionRequest):
    """
    On-demand prediction endpoint.
    Accepts daf text and generates predictions using the latest trained model.
    
    Authentication: Requires TRAINING_CALLBACK_TOKEN to be set in environment variables.
    The token must match the auth_token sent in the request body.
    """
    # Verify authentication token
    # Security: Always require authentication token to match TRAINING_CALLBACK_TOKEN
    expected_token = os.getenv('TRAINING_CALLBACK_TOKEN')
    if not expected_token:
        logger.error("TRAINING_CALLBACK_TOKEN not set in environment - prediction endpoint is insecure!")
        raise HTTPException(
            status_code=500,
            detail="Server configuration error: TRAINING_CALLBACK_TOKEN not configured"
        )
    
    if not request.auth_token or request.auth_token != expected_token:
        raise HTTPException(
            status_code=401,
            detail="Unauthorized: Invalid authentication token"
        )
    
    if not request.daf_text or not request.daf_text.strip():
        raise HTTPException(
            status_code=400,
            detail="Missing or empty daf_text"
        )
    
    # Load model artifacts
    model, word_to_idx, label_encoder = load_model_artifacts()
    
    if model is None or word_to_idx is None or label_encoder is None:
        raise HTTPException(
            status_code=404,
            detail="Model not found. Please train a model first by triggering training from your Vercel app."
        )
    
    try:
        # Generate predictions
        logger.info("Generating predictions for daf text...")
        ranges = generate_predictions_for_daf(
            model, request.daf_text, word_to_idx, label_encoder
        )
        
        logger.info(f"Generated {len(ranges)} prediction ranges")
        
        return {
            "success": True,
            "ranges": ranges
        }
        
    except Exception as e:
        logger.error(f"Error generating predictions: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Error generating predictions: {str(e)}"
        )

if __name__ == "__main__":
    import uvicorn
    port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)