shelfgot commited on
Commit
238f03e
·
verified ·
1 Parent(s): d5a48ff

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +265 -71
train.py CHANGED
@@ -1,17 +1,18 @@
1
  """
2
  Training module for Talmud language classifier
3
  Adapted from talmud_language_classifier.py for Hugging Face Spaces integration
 
4
  """
5
 
6
  import copy
7
  import torch
8
  import torch.nn as nn
9
  import torch.optim as optim
10
- from torch.utils.data import Dataset, DataLoader
11
  from collections import Counter
12
  from sklearn.model_selection import train_test_split, KFold
13
  from sklearn.preprocessing import LabelEncoder
14
- from sklearn.metrics import f1_score
15
  import numpy as np
16
  import io
17
  import os
@@ -21,10 +22,14 @@ import pickle
21
  MAX_LEN = 100
22
  VOCAB_SIZE = 10000
23
  EMBEDDING_DIM = 128
24
- HIDDEN_DIM = 128
25
- NUM_EPOCHS = 10 # Epochs per fold
26
  BATCH_SIZE = 16
27
  N_SPLITS = 5 # Number of folds for cross-validation
 
 
 
 
28
 
29
  # --- 1. Load and Parse Data ---
30
  def load_and_parse_data_from_string(training_data_text: str):
@@ -87,24 +92,132 @@ class TalmudDataset(Dataset):
87
 
88
  # --- 4. Model Definition ---
89
  class TalmudClassifierLSTM(nn.Module):
90
- def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):
91
  super(TalmudClassifierLSTM, self).__init__()
92
  self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
93
- self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, dropout=0.3, num_layers=2)
94
- self.dropout = nn.Dropout(0.5)
95
- self.fc1 = nn.Linear(hidden_dim, 64)
 
 
 
 
 
 
 
 
96
  self.relu = nn.ReLU()
97
- self.fc2 = nn.Linear(64, output_dim)
 
98
 
99
  def forward(self, text):
100
  embedded = self.embedding(text)
101
- _, (hidden, _) = self.lstm(embedded)
102
- hidden = self.dropout(hidden[-1])
 
 
 
 
 
 
 
 
 
 
103
  out = self.fc1(hidden)
104
  out = self.relu(out)
 
105
  out = self.fc2(out)
106
  return out
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  # --- 5. Training Function ---
109
  def train_model(training_data_text: str):
110
  """
@@ -148,17 +261,31 @@ def train_model(training_data_text: str):
148
  print(f"\nTotal samples: {len(all_texts)}")
149
  print(f"Training set size: {len(train_texts)} (80%)")
150
  print(f"Test set size: {len(test_texts)} (20%)")
 
 
 
 
 
 
151
 
152
  # Build vocabulary and label encoder ONLY on the training data
153
  word_to_idx = build_vocab(train_texts, VOCAB_SIZE)
154
  label_encoder = LabelEncoder()
155
  label_encoder.fit(train_labels)
156
  num_classes = len(label_encoder.classes_)
 
 
 
 
 
 
 
 
157
 
158
  # Set up K-Fold Cross-Validation
159
  kfold = KFold(n_splits=N_SPLITS, shuffle=True, random_state=42)
160
 
161
- best_val_accuracy = 0.0
162
  best_model_state = None
163
  fold_results = []
164
 
@@ -171,51 +298,118 @@ def train_model(training_data_text: str):
171
  print(f"\n----- FOLD {fold+1}/{N_SPLITS} -----")
172
 
173
  # Create data subsets for the current fold
174
- train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
175
- val_subsampler = torch.utils.data.SubsetRandomSampler(val_ids)
176
-
177
- train_loader = DataLoader(full_train_dataset, batch_size=BATCH_SIZE, sampler=train_subsampler)
178
- val_loader = DataLoader(full_train_dataset, batch_size=BATCH_SIZE, sampler=val_subsampler)
 
 
 
 
 
 
 
 
179
 
180
  # Initialize a new model for each fold
181
  model = TalmudClassifierLSTM(len(word_to_idx), EMBEDDING_DIM, HIDDEN_DIM, num_classes)
182
- criterion = nn.CrossEntropyLoss()
183
- optimizer = optim.Adam(model.parameters(), lr=0.001)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
 
185
  for epoch in range(NUM_EPOCHS):
186
  model.train()
 
 
 
187
  for sequences, labels in train_loader:
 
 
 
188
  optimizer.zero_grad()
189
  outputs = model(sequences)
190
  loss = criterion(outputs, labels)
191
  loss.backward()
 
 
 
 
192
  optimizer.step()
193
-
194
- # Evaluate the model on the validation set for this fold
195
- model.eval()
196
- all_predicted = []
197
- all_labels_val = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
- with torch.no_grad():
200
- for sequences, labels in val_loader:
201
- outputs = model(sequences)
202
- _, predicted = torch.max(outputs.data, 1)
203
- all_predicted.extend(predicted.cpu().numpy())
204
- all_labels_val.extend(labels.cpu().numpy())
205
-
206
- accuracy = 100 * np.mean(np.array(all_predicted) == np.array(all_labels_val))
207
- fold_results.append(accuracy)
208
- print(f"Validation Accuracy for Fold {fold+1}: {accuracy:.2f}%")
 
 
 
 
 
 
 
209
 
210
- # Save the best model found across all folds
211
- # Always save at least the first fold's model, or if this fold is better
212
- if best_model_state is None or accuracy >= best_val_accuracy:
213
- best_val_accuracy = accuracy
214
  best_model_state = copy.deepcopy(model.state_dict())
215
 
216
  print("\n----- Cross-Validation Summary -----")
217
- print(f"Fold Accuracies: {[f'{acc:.2f}%' for acc in fold_results]}")
218
- print(f"Average CV Accuracy: {np.mean(fold_results):.2f}%")
 
 
 
 
219
 
220
  # Verify that we have a model state to load
221
  if best_model_state is None:
@@ -225,45 +419,38 @@ def train_model(training_data_text: str):
225
  print("\n----- Final Evaluation on Test Set -----")
226
  final_model = TalmudClassifierLSTM(len(word_to_idx), EMBEDDING_DIM, HIDDEN_DIM, num_classes)
227
  final_model.load_state_dict(best_model_state)
228
- final_model.eval()
229
 
230
  test_dataset = TalmudDataset(test_texts, test_labels, word_to_idx, label_encoder, MAX_LEN)
231
- test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
232
-
233
- all_test_predicted = []
234
- all_test_labels = []
235
- test_losses = []
236
 
237
- criterion = nn.CrossEntropyLoss()
 
 
238
 
239
- with torch.no_grad():
240
- for sequences, labels in test_loader:
241
- outputs = final_model(sequences)
242
- loss = criterion(outputs, labels)
243
- test_losses.append(loss.item())
244
- _, predicted = torch.max(outputs.data, 1)
245
- all_test_predicted.extend(predicted.cpu().numpy())
246
- all_test_labels.extend(labels.cpu().numpy())
247
 
248
- test_accuracy = 100 * np.mean(np.array(all_test_predicted) == np.array(all_test_labels))
249
- avg_loss = np.mean(test_losses)
 
 
250
 
251
  print(f"Accuracy on the unseen test set: {test_accuracy:.2f}%")
252
  print(f"Average loss: {avg_loss:.4f}")
 
 
 
 
253
 
254
- # Calculate F1 scores per category
255
- label_names = label_encoder.classes_
256
- f1_scores_dict = {}
257
-
258
- for i, label_name in enumerate(label_names):
259
- # Create binary labels for this category
260
- binary_true = np.array(all_test_labels) == i
261
- binary_pred = np.array(all_test_predicted) == i
262
-
263
- # Calculate F1 score
264
- f1 = f1_score(binary_true, binary_pred, zero_division=0)
265
- f1_scores_dict[label_name] = float(f1)
266
- print(f"F1 Score for {label_name}: {f1:.4f}")
267
 
268
  # Convert accuracy to 0-1 range for callback
269
  accuracy_normalized = test_accuracy / 100.0
@@ -274,8 +461,11 @@ def train_model(training_data_text: str):
274
  word_to_idx_path = '/tmp/word_to_idx.pt'
275
  label_encoder_path = '/tmp/label_encoder.pkl'
276
 
 
 
 
277
  # Save model state dict
278
- torch.save(final_model.state_dict(), model_path)
279
  print(f"Saved model to {model_path}")
280
 
281
  # Save word_to_idx dictionary
@@ -287,6 +477,9 @@ def train_model(training_data_text: str):
287
  pickle.dump(label_encoder, f)
288
  print(f"Saved label_encoder to {label_encoder_path}")
289
 
 
 
 
290
  except Exception as e:
291
  print(f"Warning: Failed to save model artifacts to /tmp: {e}")
292
  # Continue even if saving fails - model is still returned in result
@@ -300,6 +493,7 @@ def train_model(training_data_text: str):
300
  'accuracy': accuracy_normalized,
301
  'loss': float(avg_loss),
302
  'f1_scores': f1_scores_dict,
 
303
  'model_path': '/tmp/latest_model.pt' # Path to saved model
304
  }
305
  }
 
1
  """
2
  Training module for Talmud language classifier
3
  Adapted from talmud_language_classifier.py for Hugging Face Spaces integration
4
+ Optimized for class imbalance and better performance
5
  """
6
 
7
  import copy
8
  import torch
9
  import torch.nn as nn
10
  import torch.optim as optim
11
+ from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
12
  from collections import Counter
13
  from sklearn.model_selection import train_test_split, KFold
14
  from sklearn.preprocessing import LabelEncoder
15
+ from sklearn.metrics import f1_score, classification_report
16
  import numpy as np
17
  import io
18
  import os
 
22
  MAX_LEN = 100
23
  VOCAB_SIZE = 10000
24
  EMBEDDING_DIM = 128
25
+ HIDDEN_DIM = 256 # Increased for better capacity
26
+ NUM_EPOCHS = 30 # Increased epochs with early stopping
27
  BATCH_SIZE = 16
28
  N_SPLITS = 5 # Number of folds for cross-validation
29
+ EARLY_STOPPING_PATIENCE = 5 # Stop if no improvement for 5 epochs
30
+ LEARNING_RATE = 0.001
31
+ WEIGHT_DECAY = 1e-5 # L2 regularization
32
+ GRADIENT_CLIP = 1.0 # Gradient clipping
33
 
34
  # --- 1. Load and Parse Data ---
35
  def load_and_parse_data_from_string(training_data_text: str):
 
92
 
93
  # --- 4. Model Definition ---
94
  class TalmudClassifierLSTM(nn.Module):
95
+ def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, num_layers=2):
96
  super(TalmudClassifierLSTM, self).__init__()
97
  self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
98
+ # Bidirectional LSTM - uses both forward and backward contexts
99
+ self.lstm = nn.LSTM(
100
+ embedding_dim,
101
+ hidden_dim // 2, # Divide by 2 because bidirectional doubles the output
102
+ batch_first=True,
103
+ dropout=0.3 if num_layers > 1 else 0,
104
+ num_layers=num_layers,
105
+ bidirectional=True
106
+ )
107
+ self.dropout1 = nn.Dropout(0.5)
108
+ self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
109
  self.relu = nn.ReLU()
110
+ self.dropout2 = nn.Dropout(0.3)
111
+ self.fc2 = nn.Linear(hidden_dim // 2, output_dim)
112
 
113
  def forward(self, text):
114
  embedded = self.embedding(text)
115
+ # Get LSTM output - use both forward and backward hidden states
116
+ lstm_out, (hidden, _) = self.lstm(embedded)
117
+ # Concatenate forward and backward hidden states from last layer
118
+ # hidden shape: (num_layers * num_directions, batch, hidden_size)
119
+ if self.lstm.bidirectional:
120
+ hidden_forward = hidden[-2]
121
+ hidden_backward = hidden[-1]
122
+ hidden = torch.cat([hidden_forward, hidden_backward], dim=1)
123
+ else:
124
+ hidden = hidden[-1]
125
+
126
+ hidden = self.dropout1(hidden)
127
  out = self.fc1(hidden)
128
  out = self.relu(out)
129
+ out = self.dropout2(out)
130
  out = self.fc2(out)
131
  return out
132
 
133
+ # --- 4.5. Helper Functions ---
134
+ def calculate_class_weights(labels, label_encoder):
135
+ """Calculate class weights for weighted loss function."""
136
+ # Count occurrences of each class
137
+ label_counts = Counter(labels)
138
+ total_samples = len(labels)
139
+ num_classes = len(label_encoder.classes_)
140
+
141
+ # Calculate weights: inverse frequency, normalized
142
+ weights = np.ones(num_classes)
143
+ for i, class_name in enumerate(label_encoder.classes_):
144
+ count = label_counts.get(class_name, 1) # Avoid division by zero
145
+ # Weight is inversely proportional to frequency
146
+ weights[i] = total_samples / (num_classes * count)
147
+
148
+ # Normalize weights to sum to num_classes
149
+ weights = weights / weights.sum() * num_classes
150
+ return torch.FloatTensor(weights)
151
+
152
+ def create_weighted_sampler(labels, label_encoder):
153
+ """Create a weighted sampler for balanced batch sampling."""
154
+ # Convert string labels to encoded labels
155
+ encoded_labels = label_encoder.transform(labels)
156
+
157
+ # Calculate weights for each sample
158
+ label_counts = Counter(encoded_labels)
159
+ total_samples = len(encoded_labels)
160
+ num_classes = len(label_encoder.classes_)
161
+
162
+ sample_weights = np.ones(total_samples)
163
+ for i, label in enumerate(encoded_labels):
164
+ count = label_counts[label]
165
+ # Weight inversely proportional to class frequency
166
+ sample_weights[i] = total_samples / (num_classes * count)
167
+
168
+ return WeightedRandomSampler(
169
+ weights=sample_weights,
170
+ num_samples=len(sample_weights),
171
+ replacement=True
172
+ )
173
+
174
+ def evaluate_model(model, data_loader, criterion, label_encoder, device='cpu'):
175
+ """Evaluate model and return metrics."""
176
+ model.eval()
177
+ all_predicted = []
178
+ all_labels = []
179
+ total_loss = 0.0
180
+ num_batches = 0
181
+
182
+ with torch.no_grad():
183
+ for sequences, labels in data_loader:
184
+ sequences = sequences.to(device)
185
+ labels = labels.to(device)
186
+
187
+ outputs = model(sequences)
188
+ loss = criterion(outputs, labels)
189
+
190
+ total_loss += loss.item()
191
+ num_batches += 1
192
+
193
+ _, predicted = torch.max(outputs.data, 1)
194
+ all_predicted.extend(predicted.cpu().numpy())
195
+ all_labels.extend(labels.cpu().numpy())
196
+
197
+ avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
198
+ accuracy = 100 * np.mean(np.array(all_predicted) == np.array(all_labels))
199
+
200
+ # Calculate per-class F1 scores
201
+ label_names = label_encoder.classes_
202
+ f1_scores_dict = {}
203
+ for i, label_name in enumerate(label_names):
204
+ binary_true = np.array(all_labels) == i
205
+ binary_pred = np.array(all_predicted) == i
206
+ f1 = f1_score(binary_true, binary_pred, zero_division=0)
207
+ f1_scores_dict[label_name] = float(f1)
208
+
209
+ # Calculate macro-averaged F1 score
210
+ macro_f1 = np.mean(list(f1_scores_dict.values()))
211
+
212
+ return {
213
+ 'accuracy': accuracy,
214
+ 'loss': avg_loss,
215
+ 'f1_scores': f1_scores_dict,
216
+ 'macro_f1': macro_f1,
217
+ 'predictions': all_predicted,
218
+ 'labels': all_labels
219
+ }
220
+
221
  # --- 5. Training Function ---
222
  def train_model(training_data_text: str):
223
  """
 
261
  print(f"\nTotal samples: {len(all_texts)}")
262
  print(f"Training set size: {len(train_texts)} (80%)")
263
  print(f"Test set size: {len(test_texts)} (20%)")
264
+
265
+ # Print class distribution
266
+ train_label_counts = Counter(train_labels)
267
+ print("\nTraining set class distribution:")
268
+ for label, count in sorted(train_label_counts.items()):
269
+ print(f" {label}: {count} ({100*count/len(train_labels):.1f}%)")
270
 
271
  # Build vocabulary and label encoder ONLY on the training data
272
  word_to_idx = build_vocab(train_texts, VOCAB_SIZE)
273
  label_encoder = LabelEncoder()
274
  label_encoder.fit(train_labels)
275
  num_classes = len(label_encoder.classes_)
276
+
277
+ # Calculate class weights for weighted loss
278
+ class_weights = calculate_class_weights(train_labels, label_encoder)
279
+ print(f"\nClass weights: {dict(zip(label_encoder.classes_, class_weights.numpy()))}")
280
+
281
+ # Set device
282
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
283
+ print(f"Using device: {device}")
284
 
285
  # Set up K-Fold Cross-Validation
286
  kfold = KFold(n_splits=N_SPLITS, shuffle=True, random_state=42)
287
 
288
+ best_val_macro_f1 = 0.0
289
  best_model_state = None
290
  fold_results = []
291
 
 
298
  print(f"\n----- FOLD {fold+1}/{N_SPLITS} -----")
299
 
300
  # Create data subsets for the current fold
301
+ train_subset_texts = [train_texts[i] for i in train_ids]
302
+ train_subset_labels = [train_labels[i] for i in train_ids]
303
+ val_subset_texts = [train_texts[i] for i in val_ids]
304
+ val_subset_labels = [train_labels[i] for i in val_ids]
305
+
306
+ # Create datasets for this fold
307
+ train_dataset = TalmudDataset(train_subset_texts, train_subset_labels, word_to_idx, label_encoder, MAX_LEN)
308
+ val_dataset = TalmudDataset(val_subset_texts, val_subset_labels, word_to_idx, label_encoder, MAX_LEN)
309
+
310
+ # Create weighted sampler for balanced training
311
+ weighted_sampler = create_weighted_sampler(train_subset_labels, label_encoder)
312
+ train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=weighted_sampler)
313
+ val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
314
 
315
  # Initialize a new model for each fold
316
  model = TalmudClassifierLSTM(len(word_to_idx), EMBEDDING_DIM, HIDDEN_DIM, num_classes)
317
+ model = model.to(device)
318
+
319
+ # Use weighted loss to handle class imbalance
320
+ class_weights_device = class_weights.to(device)
321
+ criterion = nn.CrossEntropyLoss(weight=class_weights_device)
322
+
323
+ # Optimizer with weight decay for regularization
324
+ optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
325
+
326
+ # Learning rate scheduler - reduce LR on plateau
327
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(
328
+ optimizer, mode='max', factor=0.5, patience=3
329
+ )
330
+
331
+ # Early stopping variables
332
+ best_fold_macro_f1 = 0.0
333
+ best_fold_model_state = None
334
+ patience_counter = 0
335
 
336
+ # Training loop with early stopping
337
  for epoch in range(NUM_EPOCHS):
338
  model.train()
339
+ epoch_loss = 0.0
340
+ num_batches = 0
341
+
342
  for sequences, labels in train_loader:
343
+ sequences = sequences.to(device)
344
+ labels = labels.to(device)
345
+
346
  optimizer.zero_grad()
347
  outputs = model(sequences)
348
  loss = criterion(outputs, labels)
349
  loss.backward()
350
+
351
+ # Gradient clipping to prevent exploding gradients
352
+ torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
353
+
354
  optimizer.step()
355
+ epoch_loss += loss.item()
356
+ num_batches += 1
357
+
358
+ avg_epoch_loss = epoch_loss / num_batches if num_batches > 0 else 0.0
359
+
360
+ # Evaluate on validation set
361
+ val_metrics = evaluate_model(model, val_loader, criterion, label_encoder, device)
362
+
363
+ # Update learning rate based on validation macro F1
364
+ scheduler.step(val_metrics['macro_f1'])
365
+
366
+ # Print progress
367
+ print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Loss: {avg_epoch_loss:.4f}, "
368
+ f"Val Acc: {val_metrics['accuracy']:.2f}%, "
369
+ f"Val Macro F1: {val_metrics['macro_f1']:.4f}")
370
+ print(f" Per-class F1: {', '.join([f'{k}: {v:.3f}' for k, v in val_metrics['f1_scores'].items()])}")
371
+
372
+ # Early stopping based on macro F1 score
373
+ if val_metrics['macro_f1'] > best_fold_macro_f1:
374
+ best_fold_macro_f1 = val_metrics['macro_f1']
375
+ best_fold_model_state = copy.deepcopy(model.state_dict())
376
+ patience_counter = 0
377
+ else:
378
+ patience_counter += 1
379
+ if patience_counter >= EARLY_STOPPING_PATIENCE:
380
+ print(f"Early stopping triggered at epoch {epoch+1}")
381
+ break
382
 
383
+ # Load best model for this fold
384
+ if best_fold_model_state is not None:
385
+ model.load_state_dict(best_fold_model_state)
386
+
387
+ # Final evaluation on validation set
388
+ val_metrics = evaluate_model(model, val_loader, criterion, label_encoder, device)
389
+ fold_results.append({
390
+ 'accuracy': val_metrics['accuracy'],
391
+ 'macro_f1': val_metrics['macro_f1'],
392
+ 'f1_scores': val_metrics['f1_scores']
393
+ })
394
+
395
+ print(f"\nFold {fold+1} Results:")
396
+ print(f" Validation Accuracy: {val_metrics['accuracy']:.2f}%")
397
+ print(f" Validation Macro F1: {val_metrics['macro_f1']:.4f}")
398
+ for label, f1 in val_metrics['f1_scores'].items():
399
+ print(f" {label} F1: {f1:.4f}")
400
 
401
+ # Save the best model found across all folds (based on macro F1)
402
+ if best_model_state is None or val_metrics['macro_f1'] >= best_val_macro_f1:
403
+ best_val_macro_f1 = val_metrics['macro_f1']
 
404
  best_model_state = copy.deepcopy(model.state_dict())
405
 
406
  print("\n----- Cross-Validation Summary -----")
407
+ acc_strs = [f"{r['accuracy']:.2f}%" for r in fold_results]
408
+ f1_strs = [f"{r['macro_f1']:.4f}" for r in fold_results]
409
+ print(f"Fold Accuracies: {acc_strs}")
410
+ print(f"Fold Macro F1s: {f1_strs}")
411
+ print(f"Average CV Accuracy: {np.mean([r['accuracy'] for r in fold_results]):.2f}%")
412
+ print(f"Average CV Macro F1: {np.mean([r['macro_f1'] for r in fold_results]):.4f}")
413
 
414
  # Verify that we have a model state to load
415
  if best_model_state is None:
 
419
  print("\n----- Final Evaluation on Test Set -----")
420
  final_model = TalmudClassifierLSTM(len(word_to_idx), EMBEDDING_DIM, HIDDEN_DIM, num_classes)
421
  final_model.load_state_dict(best_model_state)
422
+ final_model = final_model.to(device)
423
 
424
  test_dataset = TalmudDataset(test_texts, test_labels, word_to_idx, label_encoder, MAX_LEN)
425
+ test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
 
 
 
 
426
 
427
+ # Use weighted loss for evaluation too
428
+ class_weights_device = class_weights.to(device)
429
+ criterion = nn.CrossEntropyLoss(weight=class_weights_device)
430
 
431
+ # Evaluate on test set
432
+ test_metrics = evaluate_model(final_model, test_loader, criterion, label_encoder, device)
 
 
 
 
 
 
433
 
434
+ test_accuracy = test_metrics['accuracy']
435
+ avg_loss = test_metrics['loss']
436
+ f1_scores_dict = test_metrics['f1_scores']
437
+ macro_f1 = test_metrics['macro_f1']
438
 
439
  print(f"Accuracy on the unseen test set: {test_accuracy:.2f}%")
440
  print(f"Average loss: {avg_loss:.4f}")
441
+ print(f"Macro-averaged F1 score: {macro_f1:.4f}")
442
+ print("\nPer-class F1 scores:")
443
+ for label_name, f1 in f1_scores_dict.items():
444
+ print(f" {label_name}: {f1:.4f}")
445
 
446
+ # Print detailed classification report
447
+ print("\nClassification Report:")
448
+ print(classification_report(
449
+ test_metrics['labels'],
450
+ test_metrics['predictions'],
451
+ target_names=label_encoder.classes_,
452
+ zero_division=0
453
+ ))
 
 
 
 
 
454
 
455
  # Convert accuracy to 0-1 range for callback
456
  accuracy_normalized = test_accuracy / 100.0
 
461
  word_to_idx_path = '/tmp/word_to_idx.pt'
462
  label_encoder_path = '/tmp/label_encoder.pkl'
463
 
464
+ # Move model to CPU for saving (to ensure compatibility)
465
+ final_model_cpu = final_model.cpu()
466
+
467
  # Save model state dict
468
+ torch.save(final_model_cpu.state_dict(), model_path)
469
  print(f"Saved model to {model_path}")
470
 
471
  # Save word_to_idx dictionary
 
477
  pickle.dump(label_encoder, f)
478
  print(f"Saved label_encoder to {label_encoder_path}")
479
 
480
+ # Move model back to device for return
481
+ final_model = final_model.to(device)
482
+
483
  except Exception as e:
484
  print(f"Warning: Failed to save model artifacts to /tmp: {e}")
485
  # Continue even if saving fails - model is still returned in result
 
493
  'accuracy': accuracy_normalized,
494
  'loss': float(avg_loss),
495
  'f1_scores': f1_scores_dict,
496
+ 'macro_f1': float(macro_f1),
497
  'model_path': '/tmp/latest_model.pt' # Path to saved model
498
  }
499
  }