shelfgot committed on
Commit
172b660
·
verified ·
1 Parent(s): 6441bee

no predictions, save model

Browse files
Files changed (1) hide show
  1. app.py +114 -25
app.py CHANGED
@@ -13,9 +13,11 @@ import threading
13
  import logging
14
  from typing import Optional
15
  from pydantic import BaseModel
 
 
16
 
17
- from train import train_model
18
- from predict import generate_all_predictions
19
 
20
  # Configure logging
21
  logging.basicConfig(level=logging.INFO)
@@ -47,10 +49,15 @@ class TrainingRequest(BaseModel):
47
  callback_auth_token: str
48
  timestamp: Optional[str] = None
49
 
 
 
 
 
50
  def run_training_async(training_data: str, callback_url: str, callback_auth_token: str):
51
  """
52
  Run training in a separate thread to avoid blocking the request.
53
- This function runs the full training and prediction pipeline.
 
54
  """
55
  global training_in_progress, training_result, training_error
56
 
@@ -63,32 +70,15 @@ def run_training_async(training_data: str, callback_url: str, callback_auth_toke
63
 
64
  # Train the model
65
  result = train_model(training_data)
66
- model = result['model']
67
- word_to_idx = result['word_to_idx']
68
- label_encoder = result['label_encoder']
69
  stats = result['stats']
70
 
71
  logger.info(f"Training completed. Accuracy: {stats['accuracy']:.4f}")
 
 
72
 
73
- # Get Vercel base URL from environment
74
- vercel_base_url = os.getenv('VERCEL_BASE_URL')
75
- if not vercel_base_url:
76
- raise ValueError("VERCEL_BASE_URL environment variable not set")
77
-
78
- logger.info("Generating predictions for all dafim...")
79
-
80
- # Generate predictions for all dafim
81
- # Use the callback_auth_token to authenticate requests to Vercel endpoints
82
- predictions = generate_all_predictions(
83
- model, word_to_idx, label_encoder, vercel_base_url, callback_auth_token
84
- )
85
-
86
- logger.info(f"Generated {len(predictions)} predictions")
87
-
88
- # Prepare callback payload
89
  callback_payload = {
90
  'stats': stats,
91
- 'predictions': predictions,
92
  'auth_token': callback_auth_token
93
  }
94
 
@@ -97,7 +87,7 @@ def run_training_async(training_data: str, callback_url: str, callback_auth_toke
97
  response = requests.post(
98
  callback_url,
99
  json=callback_payload,
100
- timeout=300, # 5 minute timeout
101
  headers={'Content-Type': 'application/json'}
102
  )
103
 
@@ -201,8 +191,107 @@ async def health_check():
201
  """Health check endpoint"""
202
  return {"status": "healthy"}
203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  if __name__ == "__main__":
205
  import uvicorn
206
  port = int(os.getenv("PORT", 7860))
207
  uvicorn.run(app, host="0.0.0.0", port=port)
208
-
 
13
  import logging
14
  from typing import Optional
15
  from pydantic import BaseModel
16
+ import torch
17
+ import pickle
18
 
19
+ from train import train_model, TalmudClassifierLSTM, MAX_LEN, EMBEDDING_DIM, HIDDEN_DIM
20
+ from predict import generate_predictions_for_daf
21
 
22
  # Configure logging
23
  logging.basicConfig(level=logging.INFO)
 
49
  callback_auth_token: str
50
  timestamp: Optional[str] = None
51
 
52
class PredictionRequest(BaseModel):
    """Request body for the on-demand /predict endpoint."""

    # Raw text of the daf to run predictions on.
    daf_text: str
    # Shared secret; must match TRAINING_CALLBACK_TOKEN on the server.
    auth_token: str

56
  def run_training_async(training_data: str, callback_url: str, callback_auth_token: str):
57
  """
58
  Run training in a separate thread to avoid blocking the request.
59
+ Trains the model on the provided training data and returns test results
60
+ on the ground truth (test set). Does not generate predictions for all dafim.
61
  """
62
  global training_in_progress, training_result, training_error
63
 
 
70
 
71
  # Train the model
72
  result = train_model(training_data)
 
 
 
73
  stats = result['stats']
74
 
75
  logger.info(f"Training completed. Accuracy: {stats['accuracy']:.4f}")
76
+ logger.info(f"Test set results - Accuracy: {stats['accuracy']:.4f}, Loss: {stats['loss']:.4f}")
77
+ logger.info(f"F1 Scores: {stats['f1_scores']}")
78
 
79
+ # Prepare callback payload with only stats (test results on ground truth)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  callback_payload = {
81
  'stats': stats,
 
82
  'auth_token': callback_auth_token
83
  }
84
 
 
87
  response = requests.post(
88
  callback_url,
89
  json=callback_payload,
90
+ timeout=60, # Reduced timeout since we're not generating predictions
91
  headers={'Content-Type': 'application/json'}
92
  )
93
 
 
191
  """Health check endpoint"""
192
  return {"status": "healthy"}
193
 
194
def load_model_artifacts():
    """
    Load model artifacts from the /tmp directory.

    The artifacts are written there by a previous training run:
      - /tmp/latest_model.pt   : model state_dict
      - /tmp/word_to_idx.pt    : vocabulary mapping (torch-serialized)
      - /tmp/label_encoder.pkl : fitted label encoder (pickled)

    Returns:
        (model, word_to_idx, label_encoder) on success, or
        (None, None, None) if any artifact is missing or fails to load.
    """
    model_path = '/tmp/latest_model.pt'
    word_to_idx_path = '/tmp/word_to_idx.pt'
    label_encoder_path = '/tmp/label_encoder.pkl'

    try:
        # All three artifacts are required; bail out early if any is absent.
        if not all(os.path.exists(p) for p in (model_path, word_to_idx_path, label_encoder_path)):
            return None, None, None

        # Load word_to_idx. map_location='cpu' keeps this robust even if the
        # artifact was serialized on a GPU host (consistent with the model
        # weights load below, which already forces CPU).
        word_to_idx = torch.load(word_to_idx_path, map_location='cpu')

        # Load label_encoder. NOTE: pickle is only acceptable here because the
        # file is produced by our own training run, not untrusted input.
        with open(label_encoder_path, 'rb') as f:
            label_encoder = pickle.load(f)

        # Number of output classes comes from the fitted encoder.
        num_classes = len(label_encoder.classes_)

        # Rebuild the architecture and load the trained weights.
        # Explicitly load on CPU (HF Spaces typically use CPU).
        model = TalmudClassifierLSTM(len(word_to_idx), EMBEDDING_DIM, HIDDEN_DIM, num_classes)
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
        model.eval()
        # Ensure model is on CPU
        model = model.cpu()

        logger.info("Successfully loaded model artifacts from /tmp")
        return model, word_to_idx, label_encoder

    except Exception as e:
        # Broad catch is deliberate: any load failure should degrade to
        # "no model available" rather than crash the caller.
        logger.error(f"Error loading model artifacts: {e}", exc_info=True)
        return None, None, None
232
+
233
@app.post("/predict")
async def predict_endpoint(request: PredictionRequest):
    """
    On-demand prediction endpoint.
    Accepts daf text and generates predictions using the latest trained model.

    Authentication: Requires TRAINING_CALLBACK_TOKEN to be set in environment variables.
    The token must match the auth_token sent in the request body.

    Raises:
        HTTPException 500: server misconfigured (token unset) or prediction failure.
        HTTPException 401: missing or invalid authentication token.
        HTTPException 400: missing or empty daf_text.
        HTTPException 404: no trained model artifacts available yet.
    """
    import secrets  # stdlib; local import keeps it scoped to the auth check

    # Verify authentication token
    # Security: Always require authentication token to match TRAINING_CALLBACK_TOKEN
    expected_token = os.getenv('TRAINING_CALLBACK_TOKEN')
    if not expected_token:
        logger.error("TRAINING_CALLBACK_TOKEN not set in environment - prediction endpoint is insecure!")
        raise HTTPException(
            status_code=500,
            detail="Server configuration error: TRAINING_CALLBACK_TOKEN not configured"
        )

    # Constant-time comparison (on encoded bytes) prevents timing attacks on
    # the shared secret; plain != short-circuits on the first differing byte.
    if not request.auth_token or not secrets.compare_digest(
        request.auth_token.encode('utf-8'), expected_token.encode('utf-8')
    ):
        raise HTTPException(
            status_code=401,
            detail="Unauthorized: Invalid authentication token"
        )

    if not request.daf_text or not request.daf_text.strip():
        raise HTTPException(
            status_code=400,
            detail="Missing or empty daf_text"
        )

    # Load model artifacts written by the most recent training run.
    model, word_to_idx, label_encoder = load_model_artifacts()

    if model is None or word_to_idx is None or label_encoder is None:
        raise HTTPException(
            status_code=404,
            detail="Model not found. Please train a model first by triggering training from your Vercel app."
        )

    try:
        # Generate predictions
        logger.info("Generating predictions for daf text...")
        ranges = generate_predictions_for_daf(
            model, request.daf_text, word_to_idx, label_encoder
        )

        logger.info(f"Generated {len(ranges)} prediction ranges")

        return {
            "success": True,
            "ranges": ranges
        }

    except Exception as e:
        logger.error(f"Error generating predictions: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Error generating predictions: {str(e)}"
        )
292
+ )
293
+
294
if __name__ == "__main__":
    # Launch the ASGI server directly when executed as a script.
    import uvicorn

    serve_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=serve_port)