| """ | |
| Training module for Talmud language classifier | |
| Adapted from talmud_language_classifier.py for Hugging Face Spaces integration | |
| Optimized for class imbalance and better performance | |
| """ | |
| import copy | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler | |
| from collections import Counter | |
| from sklearn.model_selection import train_test_split, KFold | |
| from sklearn.preprocessing import LabelEncoder | |
| from sklearn.metrics import f1_score, classification_report | |
| import numpy as np | |
| import io | |
| import os | |
| import pickle | |
| # --- Configuration --- | |
| MAX_LEN = 100 | |
| VOCAB_SIZE = 10000 | |
| EMBEDDING_DIM = 128 | |
| HIDDEN_DIM = 256 # Increased for better capacity | |
| NUM_EPOCHS = 30 # Increased epochs with early stopping | |
| BATCH_SIZE = 16 | |
| N_SPLITS = 5 # Number of folds for cross-validation | |
| EARLY_STOPPING_PATIENCE = 5 # Stop if no improvement for 5 epochs | |
| LEARNING_RATE = 0.001 | |
| WEIGHT_DECAY = 1e-5 # L2 regularization | |
| GRADIENT_CLIP = 1.0 # Gradient clipping | |
| # --- 1. Load and Parse Data --- | |
| def load_and_parse_data_from_string(training_data_text: str): | |
| """Reads training data from string and separates text from labels.""" | |
| texts = [] | |
| labels = [] | |
| print("Loading training data from string...") | |
| for line in training_data_text.strip().split('\n'): | |
| if '::' in line: | |
| label, text = line.strip().split('::', 1) | |
| labels.append(label.strip()) | |
| # Simple tokenization by splitting on spaces | |
| texts.append(text.strip().split()) | |
| print(f"Loaded {len(texts)} samples.") | |
| return texts, labels | |
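
# Expected input format (hypothetical example lines; any line without '::'
# is silently skipped):
#
#   aramaic::token1 token2 token3
#   hebrew::token4 token5
#
# The lines above would yield texts = [['token1', 'token2', 'token3'],
# ['token4', 'token5']] and labels = ['aramaic', 'hebrew'].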

# --- 2. Preprocessing & Vocabulary Building ---
def build_vocab(texts, vocab_size):
    """Builds a vocabulary from the text."""
    word_counts = Counter(word for text in texts for word in text)
    most_common_words = [word for word, count in word_counts.most_common(vocab_size - 2)]
    word_to_idx = {word: i + 2 for i, word in enumerate(most_common_words)}
    word_to_idx['<PAD>'] = 0
    word_to_idx['<UNK>'] = 1
    print(f"Vocabulary size: {len(word_to_idx)}")
    return word_to_idx
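
# Indices 0 and 1 are reserved for '<PAD>' and '<UNK>', so only the
# vocab_size - 2 most frequent words receive their own ids. For instance
# (a hypothetical two-text corpus), build_vocab([['a', 'b', 'a'], ['a']], 10)
# returns {'a': 2, 'b': 3, '<PAD>': 0, '<UNK>': 1}.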

# --- 3. Custom PyTorch Dataset ---
class TalmudDataset(Dataset):
    def __init__(self, texts, labels, word_to_idx, label_encoder, max_len):
        self.word_to_idx = word_to_idx
        self.max_len = max_len
        self.texts = texts
        self.labels = labels
        self.label_encoder = label_encoder

    def text_to_sequence(self, text):
        return [self.word_to_idx.get(word, self.word_to_idx['<UNK>']) for word in text]

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label_str = self.labels[idx]
        seq = self.text_to_sequence(text)
        label = self.label_encoder.transform([label_str])[0]
        if len(seq) > self.max_len:
            seq = seq[:self.max_len]
        else:
            seq = seq + [self.word_to_idx['<PAD>']] * (self.max_len - len(seq))
        return torch.tensor(seq, dtype=torch.long), torch.tensor(label, dtype=torch.long)

# --- 4. Model Definition ---
class TalmudClassifierLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, num_layers=2):
        super(TalmudClassifierLSTM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        # Bidirectional LSTM - uses both forward and backward contexts
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim // 2,  # Divide by 2 because bidirectional doubles the output
            batch_first=True,
            dropout=0.3 if num_layers > 1 else 0,
            num_layers=num_layers,
            bidirectional=True
        )
        self.dropout1 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.relu = nn.ReLU()
        self.dropout2 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(hidden_dim // 2, output_dim)

    def forward(self, text):
        embedded = self.embedding(text)
        # Get LSTM output - use both forward and backward hidden states
        lstm_out, (hidden, _) = self.lstm(embedded)
        # Concatenate forward and backward hidden states from the last layer
        # hidden shape: (num_layers * num_directions, batch, hidden_size)
        if self.lstm.bidirectional:
            hidden_forward = hidden[-2]
            hidden_backward = hidden[-1]
            hidden = torch.cat([hidden_forward, hidden_backward], dim=1)
        else:
            hidden = hidden[-1]
        hidden = self.dropout1(hidden)
        out = self.fc1(hidden)
        out = self.relu(out)
        out = self.dropout2(out)
        out = self.fc2(out)
        return out
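
# Shape walkthrough for the forward pass above (using the module defaults
# EMBEDDING_DIM=128, HIDDEN_DIM=256, num_layers=2, batch size B, MAX_LEN=100):
#   text:     (B, 100)        token ids
#   embedded: (B, 100, 128)   after nn.Embedding
#   hidden:   (4, B, 128)     num_layers * 2 directions, each of size hidden_dim // 2
#   cat:      (B, 256)        last layer's forward + backward states concatenated
#   fc1/fc2:  (B, 256) -> (B, 128) -> (B, output_dim)  raw logits for CrossEntropyLoss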

# --- 4.5. Helper Functions ---
def calculate_class_weights(labels, label_encoder):
    """Calculate class weights for the weighted loss function."""
    # Count occurrences of each class
    label_counts = Counter(labels)
    total_samples = len(labels)
    num_classes = len(label_encoder.classes_)
    # Calculate weights: inverse frequency, normalized
    weights = np.ones(num_classes)
    for i, class_name in enumerate(label_encoder.classes_):
        count = label_counts.get(class_name, 1)  # Avoid division by zero
        # Weight is inversely proportional to frequency
        weights[i] = total_samples / (num_classes * count)
    # Normalize weights to sum to num_classes
    weights = weights / weights.sum() * num_classes
    return torch.FloatTensor(weights)
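
# Worked example (hypothetical counts): with 100 samples split 80 'A' / 20 'B',
# the raw weights are 100 / (2 * 80) = 0.625 and 100 / (2 * 20) = 2.5; after
# normalizing to sum to num_classes they become 0.4 and 1.6, so each 'B'
# sample contributes 4x as much to the loss as each 'A' sample.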

def create_weighted_sampler(labels, label_encoder):
    """Create a weighted sampler for balanced batch sampling."""
    # Convert string labels to encoded labels
    encoded_labels = label_encoder.transform(labels)
    # Calculate weights for each sample
    label_counts = Counter(encoded_labels)
    total_samples = len(encoded_labels)
    num_classes = len(label_encoder.classes_)
    sample_weights = np.ones(total_samples)
    for i, label in enumerate(encoded_labels):
        count = label_counts[label]
        # Weight inversely proportional to class frequency
        sample_weights[i] = total_samples / (num_classes * count)
    return WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )
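
# The per-sample weights use the same inverse-frequency formula as
# calculate_class_weights above, so (continuing the hypothetical 80/20
# example) each 'A' sample gets weight 0.625 and each 'B' sample 2.5.
# Each class then carries equal total probability mass (80 * 0.625 ==
# 20 * 2.5), so batches drawn with replacement are roughly class-balanced
# in expectation.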

def evaluate_model(model, data_loader, criterion, label_encoder, device='cpu'):
    """Evaluate model and return metrics."""
    model.eval()
    all_predicted = []
    all_labels = []
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for sequences, labels in data_loader:
            sequences = sequences.to(device)
            labels = labels.to(device)
            outputs = model(sequences)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            num_batches += 1
            _, predicted = torch.max(outputs, 1)
            all_predicted.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    accuracy = 100 * np.mean(np.array(all_predicted) == np.array(all_labels))
    # Calculate per-class F1 scores
    label_names = label_encoder.classes_
    f1_scores_dict = {}
    for i, label_name in enumerate(label_names):
        binary_true = np.array(all_labels) == i
        binary_pred = np.array(all_predicted) == i
        f1 = f1_score(binary_true, binary_pred, zero_division=0)
        f1_scores_dict[label_name] = float(f1)
    # Calculate macro-averaged F1 score
    macro_f1 = np.mean(list(f1_scores_dict.values()))
    return {
        'accuracy': accuracy,
        'loss': avg_loss,
        'f1_scores': f1_scores_dict,
        'macro_f1': macro_f1,
        'predictions': all_predicted,
        'labels': all_labels
    }
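
# The macro F1 returned here is the unweighted mean of the per-class
# (one-vs-rest) F1 scores, e.g. per-class scores of 0.9 and 0.3 average to
# 0.6 regardless of class sizes. That is why it, rather than accuracy,
# drives the LR scheduling, early stopping, and model selection below.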

# --- 5. Training Function ---
def train_model(training_data_text: str):
    """
    Train the model on the provided training data string.
    Returns training stats including accuracy, loss, and F1 scores.
    """
    # Parse training data from string
    all_texts, all_labels = load_and_parse_data_from_string(training_data_text)
    if len(all_texts) == 0:
        raise ValueError("No training data provided")
    # Check for sufficient data and multiple classes
    unique_labels = set(all_labels)
    num_classes = len(unique_labels)
    if num_classes < 2:
        raise ValueError(f"Training data must contain at least 2 different classes. Found {num_classes} class(es).")
    if len(all_texts) < 10:
        raise ValueError(f"Training data must contain at least 10 samples. Found {len(all_texts)} samples.")
    # Check that we have enough samples per class for stratification:
    # an 80/20 stratified split requires at least 2 samples per class
    min_samples_per_class = min(all_labels.count(label) for label in unique_labels)
    if min_samples_per_class < 2:
        raise ValueError(f"Each class must have at least 2 samples for train/test split. Minimum samples per class: {min_samples_per_class}")
    # Stratify ensures the split has a similar distribution of labels;
    # the guards above make stratification safe to request here
    try:
        train_texts, test_texts, train_labels, test_labels = train_test_split(
            all_texts, all_labels, test_size=0.2, random_state=42, stratify=all_labels
        )
    except ValueError as e:
        # If stratification still fails (e.g., insufficient samples per class),
        # re-raise with a clearer, user-facing message
        if "least 2 samples" in str(e) or "class" in str(e).lower():
            raise ValueError(f"Stratification failed: {str(e)}. Ensure each class has at least 2 samples.")
        raise
| print(f"\nTotal samples: {len(all_texts)}") | |
| print(f"Training set size: {len(train_texts)} (80%)") | |
| print(f"Test set size: {len(test_texts)} (20%)") | |
| # Print class distribution | |
| train_label_counts = Counter(train_labels) | |
| print("\nTraining set class distribution:") | |
| for label, count in sorted(train_label_counts.items()): | |
| print(f" {label}: {count} ({100*count/len(train_labels):.1f}%)") | |
| # Build vocabulary and label encoder ONLY on the training data | |
| word_to_idx = build_vocab(train_texts, VOCAB_SIZE) | |
| label_encoder = LabelEncoder() | |
| label_encoder.fit(train_labels) | |
| num_classes = len(label_encoder.classes_) | |
| # Calculate class weights for weighted loss | |
| class_weights = calculate_class_weights(train_labels, label_encoder) | |
| print(f"\nClass weights: {dict(zip(label_encoder.classes_, class_weights.numpy()))}") | |
| # Set device | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| print(f"Using device: {device}") | |
| # Set up K-Fold Cross-Validation | |
| kfold = KFold(n_splits=N_SPLITS, shuffle=True, random_state=42) | |
| best_val_macro_f1 = 0.0 | |
| best_model_state = None | |
| fold_results = [] | |
| print(f"\nStarting {N_SPLITS}-Fold Cross-Validation...") | |
| # Create the full training dataset once | |
| full_train_dataset = TalmudDataset(train_texts, train_labels, word_to_idx, label_encoder, MAX_LEN) | |
| for fold, (train_ids, val_ids) in enumerate(kfold.split(full_train_dataset)): | |
| print(f"\n----- FOLD {fold+1}/{N_SPLITS} -----") | |
| # Create data subsets for the current fold | |
| train_subset_texts = [train_texts[i] for i in train_ids] | |
| train_subset_labels = [train_labels[i] for i in train_ids] | |
| val_subset_texts = [train_texts[i] for i in val_ids] | |
| val_subset_labels = [train_labels[i] for i in val_ids] | |
| # Create datasets for this fold | |
| train_dataset = TalmudDataset(train_subset_texts, train_subset_labels, word_to_idx, label_encoder, MAX_LEN) | |
| val_dataset = TalmudDataset(val_subset_texts, val_subset_labels, word_to_idx, label_encoder, MAX_LEN) | |
| # Create weighted sampler for balanced training | |
| weighted_sampler = create_weighted_sampler(train_subset_labels, label_encoder) | |
| train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=weighted_sampler) | |
| val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False) | |
| # Initialize a new model for each fold | |
| model = TalmudClassifierLSTM(len(word_to_idx), EMBEDDING_DIM, HIDDEN_DIM, num_classes) | |
| model = model.to(device) | |
| # Use weighted loss to handle class imbalance | |
| class_weights_device = class_weights.to(device) | |
| criterion = nn.CrossEntropyLoss(weight=class_weights_device) | |
| # Optimizer with weight decay for regularization | |
| optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY) | |
| # Learning rate scheduler - reduce LR on plateau | |
| scheduler = optim.lr_scheduler.ReduceLROnPlateau( | |
| optimizer, mode='max', factor=0.5, patience=3 | |
| ) | |
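        # mode='max' because the scheduler is stepped with validation macro F1
        # (higher is better); the LR is halved after 3 epochs without improvement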
        # Early stopping variables
        best_fold_macro_f1 = 0.0
        best_fold_model_state = None
        patience_counter = 0
        # Training loop with early stopping
        for epoch in range(NUM_EPOCHS):
            model.train()
            epoch_loss = 0.0
            num_batches = 0
            for sequences, labels in train_loader:
                sequences = sequences.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                outputs = model(sequences)
                loss = criterion(outputs, labels)
                loss.backward()
                # Gradient clipping to prevent exploding gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
                optimizer.step()
                epoch_loss += loss.item()
                num_batches += 1
            avg_epoch_loss = epoch_loss / num_batches if num_batches > 0 else 0.0
            # Evaluate on validation set
            val_metrics = evaluate_model(model, val_loader, criterion, label_encoder, device)
            # Update learning rate based on validation macro F1
            scheduler.step(val_metrics['macro_f1'])
            # Print progress
            print(f"Epoch {epoch + 1}/{NUM_EPOCHS} - Loss: {avg_epoch_loss:.4f}, "
                  f"Val Acc: {val_metrics['accuracy']:.2f}%, "
                  f"Val Macro F1: {val_metrics['macro_f1']:.4f}")
            print(f"  Per-class F1: {', '.join([f'{k}: {v:.3f}' for k, v in val_metrics['f1_scores'].items()])}")
            # Early stopping based on macro F1 score
            if val_metrics['macro_f1'] > best_fold_macro_f1:
                best_fold_macro_f1 = val_metrics['macro_f1']
                best_fold_model_state = copy.deepcopy(model.state_dict())
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= EARLY_STOPPING_PATIENCE:
                    print(f"Early stopping triggered at epoch {epoch + 1}")
                    break
        # Load best model for this fold
        if best_fold_model_state is not None:
            model.load_state_dict(best_fold_model_state)
        # Final evaluation on validation set
        val_metrics = evaluate_model(model, val_loader, criterion, label_encoder, device)
        fold_results.append({
            'accuracy': val_metrics['accuracy'],
            'macro_f1': val_metrics['macro_f1'],
            'f1_scores': val_metrics['f1_scores']
        })
        print(f"\nFold {fold + 1} Results:")
        print(f"  Validation Accuracy: {val_metrics['accuracy']:.2f}%")
        print(f"  Validation Macro F1: {val_metrics['macro_f1']:.4f}")
        for label, f1 in val_metrics['f1_scores'].items():
            print(f"  {label} F1: {f1:.4f}")
        # Save the best model found across all folds (based on macro F1)
        if best_model_state is None or val_metrics['macro_f1'] >= best_val_macro_f1:
            best_val_macro_f1 = val_metrics['macro_f1']
            best_model_state = copy.deepcopy(model.state_dict())
    print("\n----- Cross-Validation Summary -----")
    acc_strs = [f"{r['accuracy']:.2f}%" for r in fold_results]
    f1_strs = [f"{r['macro_f1']:.4f}" for r in fold_results]
    print(f"Fold Accuracies: {acc_strs}")
    print(f"Fold Macro F1s: {f1_strs}")
    print(f"Average CV Accuracy: {np.mean([r['accuracy'] for r in fold_results]):.2f}%")
    print(f"Average CV Macro F1: {np.mean([r['macro_f1'] for r in fold_results]):.4f}")
    # Verify that we have a model state to load
    if best_model_state is None:
        raise RuntimeError("No model state was saved during cross-validation. This should not happen.")
    # Final Evaluation on the Held-Out Test Set
    print("\n----- Final Evaluation on Test Set -----")
    final_model = TalmudClassifierLSTM(len(word_to_idx), EMBEDDING_DIM, HIDDEN_DIM, num_classes)
    final_model.load_state_dict(best_model_state)
    final_model = final_model.to(device)
    test_dataset = TalmudDataset(test_texts, test_labels, word_to_idx, label_encoder, MAX_LEN)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    # Use weighted loss for evaluation too
    class_weights_device = class_weights.to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights_device)
    # Evaluate on test set
    test_metrics = evaluate_model(final_model, test_loader, criterion, label_encoder, device)
    test_accuracy = test_metrics['accuracy']
    avg_loss = test_metrics['loss']
    f1_scores_dict = test_metrics['f1_scores']
    macro_f1 = test_metrics['macro_f1']
    print(f"Accuracy on the unseen test set: {test_accuracy:.2f}%")
    print(f"Average loss: {avg_loss:.4f}")
    print(f"Macro-averaged F1 score: {macro_f1:.4f}")
    print("\nPer-class F1 scores:")
    for label_name, f1 in f1_scores_dict.items():
        print(f"  {label_name}: {f1:.4f}")
    # Print detailed classification report
    print("\nClassification Report:")
    print(classification_report(
        test_metrics['labels'],
        test_metrics['predictions'],
        target_names=label_encoder.classes_,
        zero_division=0
    ))
    # Convert accuracy to 0-1 range for callback
    accuracy_normalized = test_accuracy / 100.0
    # Save model artifacts to /workspace for persistent storage
    # /workspace is persistent across Space restarts, unlike /tmp
    try:
        # Create /workspace directory if it doesn't exist
        workspace_dir = '/workspace'
        os.makedirs(workspace_dir, exist_ok=True)
        model_path = os.path.join(workspace_dir, 'latest_model.pt')
        word_to_idx_path = os.path.join(workspace_dir, 'word_to_idx.pt')
        label_encoder_path = os.path.join(workspace_dir, 'label_encoder.pkl')
        # Move model to CPU for saving (to ensure compatibility)
        final_model_cpu = final_model.cpu()
        # Save model state dict
        torch.save(final_model_cpu.state_dict(), model_path)
        print(f"Saved model to {model_path}")
        # Save word_to_idx dictionary
        torch.save(word_to_idx, word_to_idx_path)
        print(f"Saved word_to_idx to {word_to_idx_path}")
        # Save label_encoder
        with open(label_encoder_path, 'wb') as f:
            pickle.dump(label_encoder, f)
        print(f"Saved label_encoder to {label_encoder_path}")
        # Move model back to device for return
        final_model = final_model.to(device)
        print(f"Model artifacts saved to persistent storage in {workspace_dir}")
    except Exception as e:
        print(f"Warning: Failed to save model artifacts to /workspace: {e}")
        # Continue even if saving fails - model is still returned in result
    # Return model and stats
    return {
        'model': final_model,
        'word_to_idx': word_to_idx,
        'label_encoder': label_encoder,
        'stats': {
            'accuracy': accuracy_normalized,
            'loss': float(avg_loss),
            'f1_scores': f1_scores_dict,
            'macro_f1': float(macro_f1),
            'model_path': '/workspace/latest_model.pt'  # Path to saved model
        }
    }
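
# A minimal smoke test for running this module directly (hypothetical
# synthetic data; a real training corpus would be far larger, and the
# label names here are placeholders). Guarded by __main__ so importing
# this module from the Space app never triggers training.
if __name__ == '__main__':
    demo_lines = []
    # 8 samples per class: enough for the 80/20 split, the 2-samples-per-class
    # check, and 5-fold cross-validation on the training portion
    for i in range(8):
        demo_lines.append(f"classA::alpha beta gamma delta sample{i}")
        demo_lines.append(f"classB::one two three four sample{i}")
    result = train_model('\n'.join(demo_lines))
    print(f"Smoke test accuracy: {result['stats']['accuracy']:.3f}")
    print(f"Smoke test macro F1: {result['stats']['macro_f1']:.4f}")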