Spaces:
Sleeping
Sleeping
| """ | |
| Training Pipeline for MNIST CNN | |
| This module provides utilities for training and evaluating CNN models: | |
| - train_epoch: Single epoch training | |
| - validate: Validation/test evaluation | |
| - train_model: Complete training loop with early stopping | |
| - evaluate_model: Comprehensive evaluation with per-class metrics | |
| Supports MLflow experiment tracking for reproducibility. | |
| Usage: | |
| from scripts.train import train_model | |
| from scripts.models import BaselineCNN | |
| model = BaselineCNN() | |
| history = train_model( | |
| model, train_loader, val_loader, | |
| num_epochs=20, learning_rate=0.001 | |
| ) | |
| """ | |
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import classification_report, confusion_matrix
def train_epoch(
    model: nn.Module,
    train_loader: torch.utils.data.DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: str
) -> Dict[str, float]:
    """
    Run one full training pass over the training set.

    Args:
        model: PyTorch model to optimize (updated in place)
        train_loader: Training data loader yielding (images, labels) batches
        criterion: Loss function
        optimizer: Optimizer stepping the model parameters
        device: Device to train on ('cpu' or 'cuda')

    Returns:
        Dictionary with mean per-batch 'loss' and percentage 'accuracy'
    """
    model.train()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    for batch_images, batch_labels in train_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Forward pass (gradients cleared first)
        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)

        # Backward pass and parameter update
        batch_loss.backward()
        optimizer.step()

        # Accumulate epoch-level metrics
        running_loss += batch_loss.item()
        predictions = logits.argmax(dim=1)
        n_correct += (predictions == batch_labels).sum().item()
        n_seen += batch_labels.size(0)

    return {
        'loss': running_loss / len(train_loader),
        'accuracy': 100.0 * n_correct / n_seen
    }
def validate(
    model: nn.Module,
    val_loader: torch.utils.data.DataLoader,
    criterion: nn.Module,
    device: str
) -> Dict[str, float]:
    """
    Score the model on a validation or test set without updating weights.

    Args:
        model: PyTorch model
        val_loader: Validation data loader yielding (images, labels) batches
        criterion: Loss function
        device: Device to evaluate on

    Returns:
        Dictionary with mean per-batch 'loss' and percentage 'accuracy'
    """
    model.eval()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    # Inference only: no gradient bookkeeping needed
    with torch.no_grad():
        for batch_images, batch_labels in val_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            running_loss += criterion(logits, batch_labels).item()

            predictions = logits.argmax(dim=1)
            n_correct += (predictions == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

    return {
        'loss': running_loss / len(val_loader),
        'accuracy': 100.0 * n_correct / n_seen
    }
def train_model(
    model: nn.Module,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    num_epochs: int = 20,
    learning_rate: float = 0.001,
    patience: int = 5,
    checkpoint_dir: str = 'models',
    device: Optional[str] = None,
    use_scheduler: bool = True,
    verbose: bool = True
) -> Dict[str, List[float]]:
    """
    Train model with early stopping and checkpointing.

    Saves the best model (lowest validation loss) to
    ``<checkpoint_dir>/best_model.pt`` and the final state to
    ``<checkpoint_dir>/last_model.pt``.

    Args:
        model: PyTorch model
        train_loader: Training data loader
        val_loader: Validation data loader
        num_epochs: Maximum number of epochs
        learning_rate: Initial learning rate for Adam
        patience: Early stopping patience (epochs without improvement)
        checkpoint_dir: Directory to save model checkpoints
        device: Device to train on (auto-detect if None)
        use_scheduler: Whether to use a ReduceLROnPlateau learning rate scheduler
        verbose: Print training progress

    Returns:
        Dictionary with per-epoch training history: 'train_loss',
        'train_accuracy', 'val_loss', 'val_accuracy', 'learning_rate'
    """
    # Setup device (auto-detect when not specified)
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    if verbose:
        print(f"Training on device: {device}")
        print(f"Model: {model.__class__.__name__}")
        print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
        print()

    # Setup training components
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Learning rate scheduler: halve LR after 3 epochs without val-loss
    # improvement. FIX: the constructor's `verbose` kwarg was deprecated in
    # PyTorch 2.2 and removed in later releases (it raised TypeError here);
    # the current LR is already reported in the per-epoch progress print.
    scheduler = None
    if use_scheduler:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', patience=3, factor=0.5
        )

    # Setup checkpointing
    checkpoint_path = Path(checkpoint_dir)
    checkpoint_path.mkdir(parents=True, exist_ok=True)
    best_model_path = checkpoint_path / 'best_model.pt'
    last_model_path = checkpoint_path / 'last_model.pt'

    # Training history
    history: Dict[str, List[float]] = {
        'train_loss': [],
        'train_accuracy': [],
        'val_loss': [],
        'val_accuracy': [],
        'learning_rate': []
    }

    # Early stopping setup
    best_val_loss = float('inf')
    epochs_without_improvement = 0
    # FIX: track the last epoch's metrics explicitly so the final checkpoint
    # save cannot hit a NameError when the loop body never runs
    # (num_epochs <= 0).
    val_metrics = None
    epoch = -1

    # Training loop
    for epoch in range(num_epochs):
        train_metrics = train_epoch(model, train_loader, criterion, optimizer, device)
        val_metrics = validate(model, val_loader, criterion, device)

        # Update history
        history['train_loss'].append(train_metrics['loss'])
        history['train_accuracy'].append(train_metrics['accuracy'])
        history['val_loss'].append(val_metrics['loss'])
        history['val_accuracy'].append(val_metrics['accuracy'])
        history['learning_rate'].append(optimizer.param_groups[0]['lr'])

        # Print progress
        if verbose:
            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"  Train Loss: {train_metrics['loss']:.4f}, "
                  f"Train Acc: {train_metrics['accuracy']:.2f}%")
            print(f"  Val Loss: {val_metrics['loss']:.4f}, "
                  f"Val Acc: {val_metrics['accuracy']:.2f}%")
            print(f"  LR: {optimizer.param_groups[0]['lr']:.6f}")
            print()

        # Scheduler steps on validation loss (mode='min')
        if scheduler is not None:
            scheduler.step(val_metrics['loss'])

        # Save best model whenever validation loss improves
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            epochs_without_improvement = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_metrics['loss'],
                'val_accuracy': val_metrics['accuracy']
            }, best_model_path)
            if verbose:
                print(f"  ✓ Best model saved (val_loss: {best_val_loss:.4f})")
                print()
        else:
            epochs_without_improvement += 1

        # Early stopping
        if epochs_without_improvement >= patience:
            if verbose:
                print(f"Early stopping triggered after {epoch+1} epochs")
                print(f"Best validation loss: {best_val_loss:.4f}")
            break

    # Save last model (skipped when no epoch ran at all)
    if val_metrics is not None:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_metrics['loss'],
            'val_accuracy': val_metrics['accuracy']
        }, last_model_path)

    if verbose:
        print("Training complete!")
        print(f"Best validation loss: {best_val_loss:.4f}")
        # FIX: guard against IndexError when history is empty
        if history['val_accuracy']:
            print(f"Final validation accuracy: {history['val_accuracy'][-1]:.2f}%")

    return history
def evaluate_model(
    model: nn.Module,
    test_loader: torch.utils.data.DataLoader,
    device: Optional[str] = None,
    class_names: Optional[List[str]] = None
) -> Dict:
    """
    Evaluate a trained model on a test set, producing per-class metrics.

    Args:
        model: Trained PyTorch model
        test_loader: Test data loader yielding (images, labels) batches
        device: Device to evaluate on (auto-detected when None)
        class_names: Display names for the classes (default: digits 0-9)

    Returns:
        Dictionary with overall 'accuracy', sklearn 'classification_report'
        (as a dict), 'confusion_matrix', raw 'predictions', 'labels', and
        softmax 'probabilities'
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    model.eval()

    if class_names is None:
        class_names = [str(digit) for digit in range(10)]

    predicted_classes = []
    true_classes = []
    class_probs = []

    # Inference pass: collect predictions, labels and softmax probabilities
    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            logits = model(batch_images.to(device))
            class_probs.extend(torch.softmax(logits, dim=1).cpu().numpy())
            predicted_classes.extend(logits.argmax(dim=1).cpu().numpy())
            # NOTE(review): labels are used directly from the loader —
            # assumes the loader yields CPU tensors
            true_classes.extend(batch_labels.numpy())

    predicted_classes = np.array(predicted_classes)
    true_classes = np.array(true_classes)
    class_probs = np.array(class_probs)

    # Overall accuracy as a percentage
    accuracy = 100.0 * (predicted_classes == true_classes).sum() / len(true_classes)

    # Per-class precision / recall / F1 from sklearn
    report = classification_report(
        true_classes, predicted_classes,
        target_names=class_names,
        output_dict=True
    )
    conf_matrix = confusion_matrix(true_classes, predicted_classes)

    return {
        'accuracy': accuracy,
        'classification_report': report,
        'confusion_matrix': conf_matrix,
        'predictions': predicted_classes,
        'labels': true_classes,
        'probabilities': class_probs
    }
def save_training_history(history: Dict, filepath: str) -> None:
    """
    Persist a training-history dictionary as pretty-printed JSON.

    Args:
        history: Training history dictionary (must be JSON-serializable)
        filepath: Destination path; parent directories are created as needed
    """
    target = Path(filepath)
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, 'w') as handle:
        json.dump(history, handle, indent=2)
    print(f"Training history saved to {filepath}")
def load_checkpoint(checkpoint_path: str, model: nn.Module) -> Tuple[nn.Module, Dict]:
    """
    Restore model weights from a saved training checkpoint.

    Args:
        checkpoint_path: Path to a checkpoint file containing a
            'model_state_dict' entry (as written by train_model)
        model: Freshly constructed model instance to receive the weights

    Returns:
        Tuple of (model with weights loaded, full checkpoint dictionary)
    """
    # Always materialize on CPU so checkpoints load regardless of the
    # device they were trained on.
    state = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state['model_state_dict'])
    return model, state