Spaces:
Sleeping
Sleeping
Update train.py
Browse files
train.py
CHANGED
|
@@ -117,10 +117,33 @@ def train_model(training_data_text: str):
|
|
| 117 |
if len(all_texts) == 0:
|
| 118 |
raise ValueError("No training data provided")
|
| 119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
# Stratify ensures the split has a similar distribution of labels
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
print(f"\nTotal samples: {len(all_texts)}")
|
| 126 |
print(f"Training set size: {len(train_texts)} (80%)")
|
|
@@ -185,7 +208,8 @@ def train_model(training_data_text: str):
|
|
| 185 |
print(f"Validation Accuracy for Fold {fold+1}: {accuracy:.2f}%")
|
| 186 |
|
| 187 |
# Save the best model found across all folds
|
| 188 |
-
if
|
|
|
|
| 189 |
best_val_accuracy = accuracy
|
| 190 |
best_model_state = copy.deepcopy(model.state_dict())
|
| 191 |
|
|
@@ -193,6 +217,10 @@ def train_model(training_data_text: str):
|
|
| 193 |
print(f"Fold Accuracies: {[f'{acc:.2f}%' for acc in fold_results]}")
|
| 194 |
print(f"Average CV Accuracy: {np.mean(fold_results):.2f}%")
|
| 195 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
# Final Evaluation on the Held-Out Test Set
|
| 197 |
print("\n----- Final Evaluation on Test Set -----")
|
| 198 |
final_model = TalmudClassifierLSTM(len(word_to_idx), EMBEDDING_DIM, HIDDEN_DIM, num_classes)
|
|
|
|
| 117 |
if len(all_texts) == 0:
|
| 118 |
raise ValueError("No training data provided")
|
| 119 |
|
| 120 |
+
# Check for sufficient data and multiple classes
|
| 121 |
+
unique_labels = set(all_labels)
|
| 122 |
+
num_classes = len(unique_labels)
|
| 123 |
+
|
| 124 |
+
if num_classes < 2:
|
| 125 |
+
raise ValueError(f"Training data must contain at least 2 different classes. Found {num_classes} class(es).")
|
| 126 |
+
|
| 127 |
+
if len(all_texts) < 10:
|
| 128 |
+
raise ValueError(f"Training data must contain at least 10 samples. Found {len(all_texts)} samples.")
|
| 129 |
+
|
| 130 |
+
# Check if we have enough samples per class for stratification
|
| 131 |
+
# Stratification requires at least 2 samples per class for a 80/20 split
|
| 132 |
+
min_samples_per_class = min(all_labels.count(label) for label in unique_labels)
|
| 133 |
+
if min_samples_per_class < 2:
|
| 134 |
+
raise ValueError(f"Each class must have at least 2 samples for train/test split. Minimum samples per class: {min_samples_per_class}")
|
| 135 |
+
|
| 136 |
# Stratify ensures the split has a similar distribution of labels
|
| 137 |
+
# Only use stratify if we have multiple classes and sufficient samples
|
| 138 |
+
try:
|
| 139 |
+
train_texts, test_texts, train_labels, test_labels = train_test_split(
|
| 140 |
+
all_texts, all_labels, test_size=0.2, random_state=42, stratify=all_labels
|
| 141 |
+
)
|
| 142 |
+
except ValueError as e:
|
| 143 |
+
# If stratification fails (e.g., insufficient samples per class), fall back to non-stratified split
|
| 144 |
+
if "least 2 samples" in str(e) or "class" in str(e).lower():
|
| 145 |
+
raise ValueError(f"Stratification failed: {str(e)}. Ensure each class has at least 2 samples.")
|
| 146 |
+
raise
|
| 147 |
|
| 148 |
print(f"\nTotal samples: {len(all_texts)}")
|
| 149 |
print(f"Training set size: {len(train_texts)} (80%)")
|
|
|
|
| 208 |
print(f"Validation Accuracy for Fold {fold+1}: {accuracy:.2f}%")
|
| 209 |
|
| 210 |
# Save the best model found across all folds
|
| 211 |
+
# Always save at least the first fold's model, or if this fold is better
|
| 212 |
+
if best_model_state is None or accuracy >= best_val_accuracy:
|
| 213 |
best_val_accuracy = accuracy
|
| 214 |
best_model_state = copy.deepcopy(model.state_dict())
|
| 215 |
|
|
|
|
| 217 |
print(f"Fold Accuracies: {[f'{acc:.2f}%' for acc in fold_results]}")
|
| 218 |
print(f"Average CV Accuracy: {np.mean(fold_results):.2f}%")
|
| 219 |
|
| 220 |
+
# Verify that we have a model state to load
|
| 221 |
+
if best_model_state is None:
|
| 222 |
+
raise RuntimeError("No model state was saved during cross-validation. This should not happen.")
|
| 223 |
+
|
| 224 |
# Final Evaluation on the Held-Out Test Set
|
| 225 |
print("\n----- Final Evaluation on Test Set -----")
|
| 226 |
final_model = TalmudClassifierLSTM(len(word_to_idx), EMBEDDING_DIM, HIDDEN_DIM, num_classes)
|