shelfgot commited on
Commit
e268d40
·
verified ·
1 Parent(s): 8176e08

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +26 -1
train.py CHANGED
@@ -14,6 +14,8 @@ from sklearn.preprocessing import LabelEncoder
14
  from sklearn.metrics import f1_score
15
  import numpy as np
16
  import io
 
 
17
 
18
  # --- Configuration ---
19
  MAX_LEN = 100
@@ -238,6 +240,29 @@ def train_model(training_data_text: str):
238
  # Convert accuracy to 0-1 range for callback
239
  accuracy_normalized = test_accuracy / 100.0
240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  # Return model and stats
242
  return {
243
  'model': final_model,
@@ -247,7 +272,7 @@ def train_model(training_data_text: str):
247
  'accuracy': accuracy_normalized,
248
  'loss': float(avg_loss),
249
  'f1_scores': f1_scores_dict,
250
- 'model_path': 'trained_model' # Not saved to disk in HF Space
251
  }
252
  }
253
 
 
14
  from sklearn.metrics import f1_score
15
  import numpy as np
16
  import io
17
+ import os
18
+ import pickle
19
 
20
  # --- Configuration ---
21
  MAX_LEN = 100
 
240
  # Convert accuracy to 0-1 range for callback
241
  accuracy_normalized = test_accuracy / 100.0
242
 
243
+ # Save model artifacts to /tmp for on-demand predictions
244
+ try:
245
+ model_path = '/tmp/latest_model.pt'
246
+ word_to_idx_path = '/tmp/word_to_idx.pt'
247
+ label_encoder_path = '/tmp/label_encoder.pkl'
248
+
249
+ # Save model state dict
250
+ torch.save(final_model.state_dict(), model_path)
251
+ print(f"Saved model to {model_path}")
252
+
253
+ # Save word_to_idx dictionary
254
+ torch.save(word_to_idx, word_to_idx_path)
255
+ print(f"Saved word_to_idx to {word_to_idx_path}")
256
+
257
+ # Save label_encoder
258
+ with open(label_encoder_path, 'wb') as f:
259
+ pickle.dump(label_encoder, f)
260
+ print(f"Saved label_encoder to {label_encoder_path}")
261
+
262
+ except Exception as e:
263
+ print(f"Warning: Failed to save model artifacts to /tmp: {e}")
264
+ # Continue even if saving fails - model is still returned in result
265
+
266
  # Return model and stats
267
  return {
268
  'model': final_model,
 
272
  'accuracy': accuracy_normalized,
273
  'loss': float(avg_loss),
274
  'f1_scores': f1_scores_dict,
275
+ 'model_path': '/tmp/latest_model.pt' # Path to saved model
276
  }
277
  }
278