shelfgot committed on
Commit
d5a48ff
·
verified ·
1 Parent(s): be1bc6c

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +32 -4
train.py CHANGED
@@ -117,10 +117,33 @@ def train_model(training_data_text: str):
117
  if len(all_texts) == 0:
118
  raise ValueError("No training data provided")
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  # Stratify ensures the split has a similar distribution of labels
121
- train_texts, test_texts, train_labels, test_labels = train_test_split(
122
- all_texts, all_labels, test_size=0.2, random_state=42, stratify=all_labels
123
- )
 
 
 
 
 
 
 
124
 
125
  print(f"\nTotal samples: {len(all_texts)}")
126
  print(f"Training set size: {len(train_texts)} (80%)")
@@ -185,7 +208,8 @@ def train_model(training_data_text: str):
185
  print(f"Validation Accuracy for Fold {fold+1}: {accuracy:.2f}%")
186
 
187
  # Save the best model found across all folds
188
- if accuracy > best_val_accuracy:
 
189
  best_val_accuracy = accuracy
190
  best_model_state = copy.deepcopy(model.state_dict())
191
 
@@ -193,6 +217,10 @@ def train_model(training_data_text: str):
193
  print(f"Fold Accuracies: {[f'{acc:.2f}%' for acc in fold_results]}")
194
  print(f"Average CV Accuracy: {np.mean(fold_results):.2f}%")
195
 
 
 
 
 
196
  # Final Evaluation on the Held-Out Test Set
197
  print("\n----- Final Evaluation on Test Set -----")
198
  final_model = TalmudClassifierLSTM(len(word_to_idx), EMBEDDING_DIM, HIDDEN_DIM, num_classes)
 
117
  if len(all_texts) == 0:
118
  raise ValueError("No training data provided")
119
 
120
+ # Check for sufficient data and multiple classes
121
+ unique_labels = set(all_labels)
122
+ num_classes = len(unique_labels)
123
+
124
+ if num_classes < 2:
125
+ raise ValueError(f"Training data must contain at least 2 different classes. Found {num_classes} class(es).")
126
+
127
+ if len(all_texts) < 10:
128
+ raise ValueError(f"Training data must contain at least 10 samples. Found {len(all_texts)} samples.")
129
+
130
+ # Check if we have enough samples per class for stratification
131
+ # Stratification requires at least 2 samples per class for a 80/20 split
132
+ min_samples_per_class = min(all_labels.count(label) for label in unique_labels)
133
+ if min_samples_per_class < 2:
134
+ raise ValueError(f"Each class must have at least 2 samples for train/test split. Minimum samples per class: {min_samples_per_class}")
135
+
136
  # Stratify ensures the split has a similar distribution of labels
137
+ # Only use stratify if we have multiple classes and sufficient samples
138
+ try:
139
+ train_texts, test_texts, train_labels, test_labels = train_test_split(
140
+ all_texts, all_labels, test_size=0.2, random_state=42, stratify=all_labels
141
+ )
142
+ except ValueError as e:
143
+ # If stratification fails (e.g., insufficient samples per class), fall back to non-stratified split
144
+ if "least 2 samples" in str(e) or "class" in str(e).lower():
145
+ raise ValueError(f"Stratification failed: {str(e)}. Ensure each class has at least 2 samples.")
146
+ raise
147
 
148
  print(f"\nTotal samples: {len(all_texts)}")
149
  print(f"Training set size: {len(train_texts)} (80%)")
 
208
  print(f"Validation Accuracy for Fold {fold+1}: {accuracy:.2f}%")
209
 
210
  # Save the best model found across all folds
211
+ # Always save at least the first fold's model, or if this fold is better
212
+ if best_model_state is None or accuracy >= best_val_accuracy:
213
  best_val_accuracy = accuracy
214
  best_model_state = copy.deepcopy(model.state_dict())
215
 
 
217
  print(f"Fold Accuracies: {[f'{acc:.2f}%' for acc in fold_results]}")
218
  print(f"Average CV Accuracy: {np.mean(fold_results):.2f}%")
219
 
220
+ # Verify that we have a model state to load
221
+ if best_model_state is None:
222
+ raise RuntimeError("No model state was saved during cross-validation. This should not happen.")
223
+
224
  # Final Evaluation on the Held-Out Test Set
225
  print("\n----- Final Evaluation on Test Set -----")
226
  final_model = TalmudClassifierLSTM(len(word_to_idx), EMBEDDING_DIM, HIDDEN_DIM, num_classes)