Spaces:
Sleeping
Sleeping
Update train.py
Browse files
train.py
CHANGED
|
@@ -1,17 +1,18 @@
|
|
| 1 |
"""
|
| 2 |
Training module for Talmud language classifier
|
| 3 |
Adapted from talmud_language_classifier.py for Hugging Face Spaces integration
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
import copy
|
| 7 |
import torch
|
| 8 |
import torch.nn as nn
|
| 9 |
import torch.optim as optim
|
| 10 |
-
from torch.utils.data import Dataset, DataLoader
|
| 11 |
from collections import Counter
|
| 12 |
from sklearn.model_selection import train_test_split, KFold
|
| 13 |
from sklearn.preprocessing import LabelEncoder
|
| 14 |
-
from sklearn.metrics import f1_score
|
| 15 |
import numpy as np
|
| 16 |
import io
|
| 17 |
import os
|
|
@@ -21,10 +22,14 @@ import pickle
|
|
| 21 |
MAX_LEN = 100
|
| 22 |
VOCAB_SIZE = 10000
|
| 23 |
EMBEDDING_DIM = 128
|
| 24 |
-
HIDDEN_DIM =
|
| 25 |
-
NUM_EPOCHS =
|
| 26 |
BATCH_SIZE = 16
|
| 27 |
N_SPLITS = 5 # Number of folds for cross-validation
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
# --- 1. Load and Parse Data ---
|
| 30 |
def load_and_parse_data_from_string(training_data_text: str):
|
|
@@ -87,24 +92,132 @@ class TalmudDataset(Dataset):
|
|
| 87 |
|
| 88 |
# --- 4. Model Definition ---
|
| 89 |
class TalmudClassifierLSTM(nn.Module):
|
| 90 |
-
def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):
|
| 91 |
super(TalmudClassifierLSTM, self).__init__()
|
| 92 |
self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
|
| 93 |
-
|
| 94 |
-
self.
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
self.relu = nn.ReLU()
|
| 97 |
-
self.
|
|
|
|
| 98 |
|
| 99 |
def forward(self, text):
|
| 100 |
embedded = self.embedding(text)
|
| 101 |
-
|
| 102 |
-
hidden = self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
out = self.fc1(hidden)
|
| 104 |
out = self.relu(out)
|
|
|
|
| 105 |
out = self.fc2(out)
|
| 106 |
return out
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
# --- 5. Training Function ---
|
| 109 |
def train_model(training_data_text: str):
|
| 110 |
"""
|
|
@@ -148,17 +261,31 @@ def train_model(training_data_text: str):
|
|
| 148 |
print(f"\nTotal samples: {len(all_texts)}")
|
| 149 |
print(f"Training set size: {len(train_texts)} (80%)")
|
| 150 |
print(f"Test set size: {len(test_texts)} (20%)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
# Build vocabulary and label encoder ONLY on the training data
|
| 153 |
word_to_idx = build_vocab(train_texts, VOCAB_SIZE)
|
| 154 |
label_encoder = LabelEncoder()
|
| 155 |
label_encoder.fit(train_labels)
|
| 156 |
num_classes = len(label_encoder.classes_)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
# Set up K-Fold Cross-Validation
|
| 159 |
kfold = KFold(n_splits=N_SPLITS, shuffle=True, random_state=42)
|
| 160 |
|
| 161 |
-
|
| 162 |
best_model_state = None
|
| 163 |
fold_results = []
|
| 164 |
|
|
@@ -171,51 +298,118 @@ def train_model(training_data_text: str):
|
|
| 171 |
print(f"\n----- FOLD {fold+1}/{N_SPLITS} -----")
|
| 172 |
|
| 173 |
# Create data subsets for the current fold
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
|
| 180 |
# Initialize a new model for each fold
|
| 181 |
model = TalmudClassifierLSTM(len(word_to_idx), EMBEDDING_DIM, HIDDEN_DIM, num_classes)
|
| 182 |
-
|
| 183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
|
|
|
|
| 185 |
for epoch in range(NUM_EPOCHS):
|
| 186 |
model.train()
|
|
|
|
|
|
|
|
|
|
| 187 |
for sequences, labels in train_loader:
|
|
|
|
|
|
|
|
|
|
| 188 |
optimizer.zero_grad()
|
| 189 |
outputs = model(sequences)
|
| 190 |
loss = criterion(outputs, labels)
|
| 191 |
loss.backward()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
optimizer.step()
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
-
# Save the best model found across all folds
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
best_val_accuracy = accuracy
|
| 214 |
best_model_state = copy.deepcopy(model.state_dict())
|
| 215 |
|
| 216 |
print("\n----- Cross-Validation Summary -----")
|
| 217 |
-
|
| 218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
# Verify that we have a model state to load
|
| 221 |
if best_model_state is None:
|
|
@@ -225,45 +419,38 @@ def train_model(training_data_text: str):
|
|
| 225 |
print("\n----- Final Evaluation on Test Set -----")
|
| 226 |
final_model = TalmudClassifierLSTM(len(word_to_idx), EMBEDDING_DIM, HIDDEN_DIM, num_classes)
|
| 227 |
final_model.load_state_dict(best_model_state)
|
| 228 |
-
final_model.
|
| 229 |
|
| 230 |
test_dataset = TalmudDataset(test_texts, test_labels, word_to_idx, label_encoder, MAX_LEN)
|
| 231 |
-
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
|
| 232 |
-
|
| 233 |
-
all_test_predicted = []
|
| 234 |
-
all_test_labels = []
|
| 235 |
-
test_losses = []
|
| 236 |
|
| 237 |
-
|
|
|
|
|
|
|
| 238 |
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
outputs = final_model(sequences)
|
| 242 |
-
loss = criterion(outputs, labels)
|
| 243 |
-
test_losses.append(loss.item())
|
| 244 |
-
_, predicted = torch.max(outputs.data, 1)
|
| 245 |
-
all_test_predicted.extend(predicted.cpu().numpy())
|
| 246 |
-
all_test_labels.extend(labels.cpu().numpy())
|
| 247 |
|
| 248 |
-
test_accuracy =
|
| 249 |
-
avg_loss =
|
|
|
|
|
|
|
| 250 |
|
| 251 |
print(f"Accuracy on the unseen test set: {test_accuracy:.2f}%")
|
| 252 |
print(f"Average loss: {avg_loss:.4f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
|
| 254 |
-
#
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
# Calculate F1 score
|
| 264 |
-
f1 = f1_score(binary_true, binary_pred, zero_division=0)
|
| 265 |
-
f1_scores_dict[label_name] = float(f1)
|
| 266 |
-
print(f"F1 Score for {label_name}: {f1:.4f}")
|
| 267 |
|
| 268 |
# Convert accuracy to 0-1 range for callback
|
| 269 |
accuracy_normalized = test_accuracy / 100.0
|
|
@@ -274,8 +461,11 @@ def train_model(training_data_text: str):
|
|
| 274 |
word_to_idx_path = '/tmp/word_to_idx.pt'
|
| 275 |
label_encoder_path = '/tmp/label_encoder.pkl'
|
| 276 |
|
|
|
|
|
|
|
|
|
|
| 277 |
# Save model state dict
|
| 278 |
-
torch.save(
|
| 279 |
print(f"Saved model to {model_path}")
|
| 280 |
|
| 281 |
# Save word_to_idx dictionary
|
|
@@ -287,6 +477,9 @@ def train_model(training_data_text: str):
|
|
| 287 |
pickle.dump(label_encoder, f)
|
| 288 |
print(f"Saved label_encoder to {label_encoder_path}")
|
| 289 |
|
|
|
|
|
|
|
|
|
|
| 290 |
except Exception as e:
|
| 291 |
print(f"Warning: Failed to save model artifacts to /tmp: {e}")
|
| 292 |
# Continue even if saving fails - model is still returned in result
|
|
@@ -300,6 +493,7 @@ def train_model(training_data_text: str):
|
|
| 300 |
'accuracy': accuracy_normalized,
|
| 301 |
'loss': float(avg_loss),
|
| 302 |
'f1_scores': f1_scores_dict,
|
|
|
|
| 303 |
'model_path': '/tmp/latest_model.pt' # Path to saved model
|
| 304 |
}
|
| 305 |
}
|
|
|
|
| 1 |
"""
|
| 2 |
Training module for Talmud language classifier
|
| 3 |
Adapted from talmud_language_classifier.py for Hugging Face Spaces integration
|
| 4 |
+
Optimized for class imbalance and better performance
|
| 5 |
"""
|
| 6 |
|
| 7 |
import copy
|
| 8 |
import torch
|
| 9 |
import torch.nn as nn
|
| 10 |
import torch.optim as optim
|
| 11 |
+
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
|
| 12 |
from collections import Counter
|
| 13 |
from sklearn.model_selection import train_test_split, KFold
|
| 14 |
from sklearn.preprocessing import LabelEncoder
|
| 15 |
+
from sklearn.metrics import f1_score, classification_report
|
| 16 |
import numpy as np
|
| 17 |
import io
|
| 18 |
import os
|
|
|
|
| 22 |
MAX_LEN = 100
|
| 23 |
VOCAB_SIZE = 10000
|
| 24 |
EMBEDDING_DIM = 128
|
| 25 |
+
HIDDEN_DIM = 256 # Increased for better capacity
|
| 26 |
+
NUM_EPOCHS = 30 # Increased epochs with early stopping
|
| 27 |
BATCH_SIZE = 16
|
| 28 |
N_SPLITS = 5 # Number of folds for cross-validation
|
| 29 |
+
EARLY_STOPPING_PATIENCE = 5 # Stop if no improvement for 5 epochs
|
| 30 |
+
LEARNING_RATE = 0.001
|
| 31 |
+
WEIGHT_DECAY = 1e-5 # L2 regularization
|
| 32 |
+
GRADIENT_CLIP = 1.0 # Gradient clipping
|
| 33 |
|
| 34 |
# --- 1. Load and Parse Data ---
|
| 35 |
def load_and_parse_data_from_string(training_data_text: str):
|
|
|
|
| 92 |
|
| 93 |
# --- 4. Model Definition ---
|
| 94 |
class TalmudClassifierLSTM(nn.Module):
|
| 95 |
+
def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, num_layers=2):
|
| 96 |
super(TalmudClassifierLSTM, self).__init__()
|
| 97 |
self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
|
| 98 |
+
# Bidirectional LSTM - uses both forward and backward contexts
|
| 99 |
+
self.lstm = nn.LSTM(
|
| 100 |
+
embedding_dim,
|
| 101 |
+
hidden_dim // 2, # Divide by 2 because bidirectional doubles the output
|
| 102 |
+
batch_first=True,
|
| 103 |
+
dropout=0.3 if num_layers > 1 else 0,
|
| 104 |
+
num_layers=num_layers,
|
| 105 |
+
bidirectional=True
|
| 106 |
+
)
|
| 107 |
+
self.dropout1 = nn.Dropout(0.5)
|
| 108 |
+
self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
|
| 109 |
self.relu = nn.ReLU()
|
| 110 |
+
self.dropout2 = nn.Dropout(0.3)
|
| 111 |
+
self.fc2 = nn.Linear(hidden_dim // 2, output_dim)
|
| 112 |
|
| 113 |
def forward(self, text):
|
| 114 |
embedded = self.embedding(text)
|
| 115 |
+
# Get LSTM output - use both forward and backward hidden states
|
| 116 |
+
lstm_out, (hidden, _) = self.lstm(embedded)
|
| 117 |
+
# Concatenate forward and backward hidden states from last layer
|
| 118 |
+
# hidden shape: (num_layers * num_directions, batch, hidden_size)
|
| 119 |
+
if self.lstm.bidirectional:
|
| 120 |
+
hidden_forward = hidden[-2]
|
| 121 |
+
hidden_backward = hidden[-1]
|
| 122 |
+
hidden = torch.cat([hidden_forward, hidden_backward], dim=1)
|
| 123 |
+
else:
|
| 124 |
+
hidden = hidden[-1]
|
| 125 |
+
|
| 126 |
+
hidden = self.dropout1(hidden)
|
| 127 |
out = self.fc1(hidden)
|
| 128 |
out = self.relu(out)
|
| 129 |
+
out = self.dropout2(out)
|
| 130 |
out = self.fc2(out)
|
| 131 |
return out
|
| 132 |
|
| 133 |
+
# --- 4.5. Helper Functions ---
|
| 134 |
+
def calculate_class_weights(labels, label_encoder):
|
| 135 |
+
"""Calculate class weights for weighted loss function."""
|
| 136 |
+
# Count occurrences of each class
|
| 137 |
+
label_counts = Counter(labels)
|
| 138 |
+
total_samples = len(labels)
|
| 139 |
+
num_classes = len(label_encoder.classes_)
|
| 140 |
+
|
| 141 |
+
# Calculate weights: inverse frequency, normalized
|
| 142 |
+
weights = np.ones(num_classes)
|
| 143 |
+
for i, class_name in enumerate(label_encoder.classes_):
|
| 144 |
+
count = label_counts.get(class_name, 1) # Avoid division by zero
|
| 145 |
+
# Weight is inversely proportional to frequency
|
| 146 |
+
weights[i] = total_samples / (num_classes * count)
|
| 147 |
+
|
| 148 |
+
# Normalize weights to sum to num_classes
|
| 149 |
+
weights = weights / weights.sum() * num_classes
|
| 150 |
+
return torch.FloatTensor(weights)
|
| 151 |
+
|
| 152 |
+
def create_weighted_sampler(labels, label_encoder):
|
| 153 |
+
"""Create a weighted sampler for balanced batch sampling."""
|
| 154 |
+
# Convert string labels to encoded labels
|
| 155 |
+
encoded_labels = label_encoder.transform(labels)
|
| 156 |
+
|
| 157 |
+
# Calculate weights for each sample
|
| 158 |
+
label_counts = Counter(encoded_labels)
|
| 159 |
+
total_samples = len(encoded_labels)
|
| 160 |
+
num_classes = len(label_encoder.classes_)
|
| 161 |
+
|
| 162 |
+
sample_weights = np.ones(total_samples)
|
| 163 |
+
for i, label in enumerate(encoded_labels):
|
| 164 |
+
count = label_counts[label]
|
| 165 |
+
# Weight inversely proportional to class frequency
|
| 166 |
+
sample_weights[i] = total_samples / (num_classes * count)
|
| 167 |
+
|
| 168 |
+
return WeightedRandomSampler(
|
| 169 |
+
weights=sample_weights,
|
| 170 |
+
num_samples=len(sample_weights),
|
| 171 |
+
replacement=True
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
def evaluate_model(model, data_loader, criterion, label_encoder, device='cpu'):
|
| 175 |
+
"""Evaluate model and return metrics."""
|
| 176 |
+
model.eval()
|
| 177 |
+
all_predicted = []
|
| 178 |
+
all_labels = []
|
| 179 |
+
total_loss = 0.0
|
| 180 |
+
num_batches = 0
|
| 181 |
+
|
| 182 |
+
with torch.no_grad():
|
| 183 |
+
for sequences, labels in data_loader:
|
| 184 |
+
sequences = sequences.to(device)
|
| 185 |
+
labels = labels.to(device)
|
| 186 |
+
|
| 187 |
+
outputs = model(sequences)
|
| 188 |
+
loss = criterion(outputs, labels)
|
| 189 |
+
|
| 190 |
+
total_loss += loss.item()
|
| 191 |
+
num_batches += 1
|
| 192 |
+
|
| 193 |
+
_, predicted = torch.max(outputs.data, 1)
|
| 194 |
+
all_predicted.extend(predicted.cpu().numpy())
|
| 195 |
+
all_labels.extend(labels.cpu().numpy())
|
| 196 |
+
|
| 197 |
+
avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
|
| 198 |
+
accuracy = 100 * np.mean(np.array(all_predicted) == np.array(all_labels))
|
| 199 |
+
|
| 200 |
+
# Calculate per-class F1 scores
|
| 201 |
+
label_names = label_encoder.classes_
|
| 202 |
+
f1_scores_dict = {}
|
| 203 |
+
for i, label_name in enumerate(label_names):
|
| 204 |
+
binary_true = np.array(all_labels) == i
|
| 205 |
+
binary_pred = np.array(all_predicted) == i
|
| 206 |
+
f1 = f1_score(binary_true, binary_pred, zero_division=0)
|
| 207 |
+
f1_scores_dict[label_name] = float(f1)
|
| 208 |
+
|
| 209 |
+
# Calculate macro-averaged F1 score
|
| 210 |
+
macro_f1 = np.mean(list(f1_scores_dict.values()))
|
| 211 |
+
|
| 212 |
+
return {
|
| 213 |
+
'accuracy': accuracy,
|
| 214 |
+
'loss': avg_loss,
|
| 215 |
+
'f1_scores': f1_scores_dict,
|
| 216 |
+
'macro_f1': macro_f1,
|
| 217 |
+
'predictions': all_predicted,
|
| 218 |
+
'labels': all_labels
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
# --- 5. Training Function ---
|
| 222 |
def train_model(training_data_text: str):
|
| 223 |
"""
|
|
|
|
| 261 |
print(f"\nTotal samples: {len(all_texts)}")
|
| 262 |
print(f"Training set size: {len(train_texts)} (80%)")
|
| 263 |
print(f"Test set size: {len(test_texts)} (20%)")
|
| 264 |
+
|
| 265 |
+
# Print class distribution
|
| 266 |
+
train_label_counts = Counter(train_labels)
|
| 267 |
+
print("\nTraining set class distribution:")
|
| 268 |
+
for label, count in sorted(train_label_counts.items()):
|
| 269 |
+
print(f" {label}: {count} ({100*count/len(train_labels):.1f}%)")
|
| 270 |
|
| 271 |
# Build vocabulary and label encoder ONLY on the training data
|
| 272 |
word_to_idx = build_vocab(train_texts, VOCAB_SIZE)
|
| 273 |
label_encoder = LabelEncoder()
|
| 274 |
label_encoder.fit(train_labels)
|
| 275 |
num_classes = len(label_encoder.classes_)
|
| 276 |
+
|
| 277 |
+
# Calculate class weights for weighted loss
|
| 278 |
+
class_weights = calculate_class_weights(train_labels, label_encoder)
|
| 279 |
+
print(f"\nClass weights: {dict(zip(label_encoder.classes_, class_weights.numpy()))}")
|
| 280 |
+
|
| 281 |
+
# Set device
|
| 282 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 283 |
+
print(f"Using device: {device}")
|
| 284 |
|
| 285 |
# Set up K-Fold Cross-Validation
|
| 286 |
kfold = KFold(n_splits=N_SPLITS, shuffle=True, random_state=42)
|
| 287 |
|
| 288 |
+
best_val_macro_f1 = 0.0
|
| 289 |
best_model_state = None
|
| 290 |
fold_results = []
|
| 291 |
|
|
|
|
| 298 |
print(f"\n----- FOLD {fold+1}/{N_SPLITS} -----")
|
| 299 |
|
| 300 |
# Create data subsets for the current fold
|
| 301 |
+
train_subset_texts = [train_texts[i] for i in train_ids]
|
| 302 |
+
train_subset_labels = [train_labels[i] for i in train_ids]
|
| 303 |
+
val_subset_texts = [train_texts[i] for i in val_ids]
|
| 304 |
+
val_subset_labels = [train_labels[i] for i in val_ids]
|
| 305 |
+
|
| 306 |
+
# Create datasets for this fold
|
| 307 |
+
train_dataset = TalmudDataset(train_subset_texts, train_subset_labels, word_to_idx, label_encoder, MAX_LEN)
|
| 308 |
+
val_dataset = TalmudDataset(val_subset_texts, val_subset_labels, word_to_idx, label_encoder, MAX_LEN)
|
| 309 |
+
|
| 310 |
+
# Create weighted sampler for balanced training
|
| 311 |
+
weighted_sampler = create_weighted_sampler(train_subset_labels, label_encoder)
|
| 312 |
+
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=weighted_sampler)
|
| 313 |
+
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
|
| 314 |
|
| 315 |
# Initialize a new model for each fold
|
| 316 |
model = TalmudClassifierLSTM(len(word_to_idx), EMBEDDING_DIM, HIDDEN_DIM, num_classes)
|
| 317 |
+
model = model.to(device)
|
| 318 |
+
|
| 319 |
+
# Use weighted loss to handle class imbalance
|
| 320 |
+
class_weights_device = class_weights.to(device)
|
| 321 |
+
criterion = nn.CrossEntropyLoss(weight=class_weights_device)
|
| 322 |
+
|
| 323 |
+
# Optimizer with weight decay for regularization
|
| 324 |
+
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
|
| 325 |
+
|
| 326 |
+
# Learning rate scheduler - reduce LR on plateau
|
| 327 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
| 328 |
+
optimizer, mode='max', factor=0.5, patience=3
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
# Early stopping variables
|
| 332 |
+
best_fold_macro_f1 = 0.0
|
| 333 |
+
best_fold_model_state = None
|
| 334 |
+
patience_counter = 0
|
| 335 |
|
| 336 |
+
# Training loop with early stopping
|
| 337 |
for epoch in range(NUM_EPOCHS):
|
| 338 |
model.train()
|
| 339 |
+
epoch_loss = 0.0
|
| 340 |
+
num_batches = 0
|
| 341 |
+
|
| 342 |
for sequences, labels in train_loader:
|
| 343 |
+
sequences = sequences.to(device)
|
| 344 |
+
labels = labels.to(device)
|
| 345 |
+
|
| 346 |
optimizer.zero_grad()
|
| 347 |
outputs = model(sequences)
|
| 348 |
loss = criterion(outputs, labels)
|
| 349 |
loss.backward()
|
| 350 |
+
|
| 351 |
+
# Gradient clipping to prevent exploding gradients
|
| 352 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
|
| 353 |
+
|
| 354 |
optimizer.step()
|
| 355 |
+
epoch_loss += loss.item()
|
| 356 |
+
num_batches += 1
|
| 357 |
+
|
| 358 |
+
avg_epoch_loss = epoch_loss / num_batches if num_batches > 0 else 0.0
|
| 359 |
+
|
| 360 |
+
# Evaluate on validation set
|
| 361 |
+
val_metrics = evaluate_model(model, val_loader, criterion, label_encoder, device)
|
| 362 |
+
|
| 363 |
+
# Update learning rate based on validation macro F1
|
| 364 |
+
scheduler.step(val_metrics['macro_f1'])
|
| 365 |
+
|
| 366 |
+
# Print progress
|
| 367 |
+
print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Loss: {avg_epoch_loss:.4f}, "
|
| 368 |
+
f"Val Acc: {val_metrics['accuracy']:.2f}%, "
|
| 369 |
+
f"Val Macro F1: {val_metrics['macro_f1']:.4f}")
|
| 370 |
+
print(f" Per-class F1: {', '.join([f'{k}: {v:.3f}' for k, v in val_metrics['f1_scores'].items()])}")
|
| 371 |
+
|
| 372 |
+
# Early stopping based on macro F1 score
|
| 373 |
+
if val_metrics['macro_f1'] > best_fold_macro_f1:
|
| 374 |
+
best_fold_macro_f1 = val_metrics['macro_f1']
|
| 375 |
+
best_fold_model_state = copy.deepcopy(model.state_dict())
|
| 376 |
+
patience_counter = 0
|
| 377 |
+
else:
|
| 378 |
+
patience_counter += 1
|
| 379 |
+
if patience_counter >= EARLY_STOPPING_PATIENCE:
|
| 380 |
+
print(f"Early stopping triggered at epoch {epoch+1}")
|
| 381 |
+
break
|
| 382 |
|
| 383 |
+
# Load best model for this fold
|
| 384 |
+
if best_fold_model_state is not None:
|
| 385 |
+
model.load_state_dict(best_fold_model_state)
|
| 386 |
+
|
| 387 |
+
# Final evaluation on validation set
|
| 388 |
+
val_metrics = evaluate_model(model, val_loader, criterion, label_encoder, device)
|
| 389 |
+
fold_results.append({
|
| 390 |
+
'accuracy': val_metrics['accuracy'],
|
| 391 |
+
'macro_f1': val_metrics['macro_f1'],
|
| 392 |
+
'f1_scores': val_metrics['f1_scores']
|
| 393 |
+
})
|
| 394 |
+
|
| 395 |
+
print(f"\nFold {fold+1} Results:")
|
| 396 |
+
print(f" Validation Accuracy: {val_metrics['accuracy']:.2f}%")
|
| 397 |
+
print(f" Validation Macro F1: {val_metrics['macro_f1']:.4f}")
|
| 398 |
+
for label, f1 in val_metrics['f1_scores'].items():
|
| 399 |
+
print(f" {label} F1: {f1:.4f}")
|
| 400 |
|
| 401 |
+
# Save the best model found across all folds (based on macro F1)
|
| 402 |
+
if best_model_state is None or val_metrics['macro_f1'] >= best_val_macro_f1:
|
| 403 |
+
best_val_macro_f1 = val_metrics['macro_f1']
|
|
|
|
| 404 |
best_model_state = copy.deepcopy(model.state_dict())
|
| 405 |
|
| 406 |
print("\n----- Cross-Validation Summary -----")
|
| 407 |
+
acc_strs = [f"{r['accuracy']:.2f}%" for r in fold_results]
|
| 408 |
+
f1_strs = [f"{r['macro_f1']:.4f}" for r in fold_results]
|
| 409 |
+
print(f"Fold Accuracies: {acc_strs}")
|
| 410 |
+
print(f"Fold Macro F1s: {f1_strs}")
|
| 411 |
+
print(f"Average CV Accuracy: {np.mean([r['accuracy'] for r in fold_results]):.2f}%")
|
| 412 |
+
print(f"Average CV Macro F1: {np.mean([r['macro_f1'] for r in fold_results]):.4f}")
|
| 413 |
|
| 414 |
# Verify that we have a model state to load
|
| 415 |
if best_model_state is None:
|
|
|
|
| 419 |
print("\n----- Final Evaluation on Test Set -----")
|
| 420 |
final_model = TalmudClassifierLSTM(len(word_to_idx), EMBEDDING_DIM, HIDDEN_DIM, num_classes)
|
| 421 |
final_model.load_state_dict(best_model_state)
|
| 422 |
+
final_model = final_model.to(device)
|
| 423 |
|
| 424 |
test_dataset = TalmudDataset(test_texts, test_labels, word_to_idx, label_encoder, MAX_LEN)
|
| 425 |
+
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 426 |
|
| 427 |
+
# Use weighted loss for evaluation too
|
| 428 |
+
class_weights_device = class_weights.to(device)
|
| 429 |
+
criterion = nn.CrossEntropyLoss(weight=class_weights_device)
|
| 430 |
|
| 431 |
+
# Evaluate on test set
|
| 432 |
+
test_metrics = evaluate_model(final_model, test_loader, criterion, label_encoder, device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
|
| 434 |
+
test_accuracy = test_metrics['accuracy']
|
| 435 |
+
avg_loss = test_metrics['loss']
|
| 436 |
+
f1_scores_dict = test_metrics['f1_scores']
|
| 437 |
+
macro_f1 = test_metrics['macro_f1']
|
| 438 |
|
| 439 |
print(f"Accuracy on the unseen test set: {test_accuracy:.2f}%")
|
| 440 |
print(f"Average loss: {avg_loss:.4f}")
|
| 441 |
+
print(f"Macro-averaged F1 score: {macro_f1:.4f}")
|
| 442 |
+
print("\nPer-class F1 scores:")
|
| 443 |
+
for label_name, f1 in f1_scores_dict.items():
|
| 444 |
+
print(f" {label_name}: {f1:.4f}")
|
| 445 |
|
| 446 |
+
# Print detailed classification report
|
| 447 |
+
print("\nClassification Report:")
|
| 448 |
+
print(classification_report(
|
| 449 |
+
test_metrics['labels'],
|
| 450 |
+
test_metrics['predictions'],
|
| 451 |
+
target_names=label_encoder.classes_,
|
| 452 |
+
zero_division=0
|
| 453 |
+
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
|
| 455 |
# Convert accuracy to 0-1 range for callback
|
| 456 |
accuracy_normalized = test_accuracy / 100.0
|
|
|
|
| 461 |
word_to_idx_path = '/tmp/word_to_idx.pt'
|
| 462 |
label_encoder_path = '/tmp/label_encoder.pkl'
|
| 463 |
|
| 464 |
+
# Move model to CPU for saving (to ensure compatibility)
|
| 465 |
+
final_model_cpu = final_model.cpu()
|
| 466 |
+
|
| 467 |
# Save model state dict
|
| 468 |
+
torch.save(final_model_cpu.state_dict(), model_path)
|
| 469 |
print(f"Saved model to {model_path}")
|
| 470 |
|
| 471 |
# Save word_to_idx dictionary
|
|
|
|
| 477 |
pickle.dump(label_encoder, f)
|
| 478 |
print(f"Saved label_encoder to {label_encoder_path}")
|
| 479 |
|
| 480 |
+
# Move model back to device for return
|
| 481 |
+
final_model = final_model.to(device)
|
| 482 |
+
|
| 483 |
except Exception as e:
|
| 484 |
print(f"Warning: Failed to save model artifacts to /tmp: {e}")
|
| 485 |
# Continue even if saving fails - model is still returned in result
|
|
|
|
| 493 |
'accuracy': accuracy_normalized,
|
| 494 |
'loss': float(avg_loss),
|
| 495 |
'f1_scores': f1_scores_dict,
|
| 496 |
+
'macro_f1': float(macro_f1),
|
| 497 |
'model_path': '/tmp/latest_model.pt' # Path to saved model
|
| 498 |
}
|
| 499 |
}
|