Spaces:
Sleeping
Sleeping
Update train.py
Browse files
train.py
CHANGED
|
@@ -14,6 +14,8 @@ from sklearn.preprocessing import LabelEncoder
|
|
| 14 |
from sklearn.metrics import f1_score
|
| 15 |
import numpy as np
|
| 16 |
import io
|
|
|
|
|
|
|
| 17 |
|
| 18 |
# --- Configuration ---
|
| 19 |
MAX_LEN = 100
|
|
@@ -238,6 +240,29 @@ def train_model(training_data_text: str):
|
|
| 238 |
# Convert accuracy to 0-1 range for callback
|
| 239 |
accuracy_normalized = test_accuracy / 100.0
|
| 240 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
# Return model and stats
|
| 242 |
return {
|
| 243 |
'model': final_model,
|
|
@@ -247,7 +272,7 @@ def train_model(training_data_text: str):
|
|
| 247 |
'accuracy': accuracy_normalized,
|
| 248 |
'loss': float(avg_loss),
|
| 249 |
'f1_scores': f1_scores_dict,
|
| 250 |
-
'model_path': '
|
| 251 |
}
|
| 252 |
}
|
| 253 |
|
|
|
|
| 14 |
from sklearn.metrics import f1_score
|
| 15 |
import numpy as np
|
| 16 |
import io
|
| 17 |
+
import os
|
| 18 |
+
import pickle
|
| 19 |
|
| 20 |
# --- Configuration ---
|
| 21 |
MAX_LEN = 100
|
|
|
|
| 240 |
# Convert accuracy to 0-1 range for callback
|
| 241 |
accuracy_normalized = test_accuracy / 100.0
|
| 242 |
|
| 243 |
+
# Save model artifacts to /tmp for on-demand predictions
|
| 244 |
+
try:
|
| 245 |
+
model_path = '/tmp/latest_model.pt'
|
| 246 |
+
word_to_idx_path = '/tmp/word_to_idx.pt'
|
| 247 |
+
label_encoder_path = '/tmp/label_encoder.pkl'
|
| 248 |
+
|
| 249 |
+
# Save model state dict
|
| 250 |
+
torch.save(final_model.state_dict(), model_path)
|
| 251 |
+
print(f"Saved model to {model_path}")
|
| 252 |
+
|
| 253 |
+
# Save word_to_idx dictionary
|
| 254 |
+
torch.save(word_to_idx, word_to_idx_path)
|
| 255 |
+
print(f"Saved word_to_idx to {word_to_idx_path}")
|
| 256 |
+
|
| 257 |
+
# Save label_encoder
|
| 258 |
+
with open(label_encoder_path, 'wb') as f:
|
| 259 |
+
pickle.dump(label_encoder, f)
|
| 260 |
+
print(f"Saved label_encoder to {label_encoder_path}")
|
| 261 |
+
|
| 262 |
+
except Exception as e:
|
| 263 |
+
print(f"Warning: Failed to save model artifacts to /tmp: {e}")
|
| 264 |
+
# Continue even if saving fails - model is still returned in result
|
| 265 |
+
|
| 266 |
# Return model and stats
|
| 267 |
return {
|
| 268 |
'model': final_model,
|
|
|
|
| 272 |
'accuracy': accuracy_normalized,
|
| 273 |
'loss': float(avg_loss),
|
| 274 |
'f1_scores': f1_scores_dict,
|
| 275 |
+
'model_path': '/tmp/latest_model.pt' # Path to saved model
|
| 276 |
}
|
| 277 |
}
|
| 278 |
|