""" Training Pipeline for MNIST CNN This module provides utilities for training and evaluating CNN models: - train_epoch: Single epoch training - validate: Validation/test evaluation - train_model: Complete training loop with early stopping - evaluate_model: Comprehensive evaluation with per-class metrics Supports MLflow experiment tracking for reproducibility. Usage: from scripts.train import train_model from scripts.models import BaselineCNN model = BaselineCNN() history = train_model( model, train_loader, val_loader, num_epochs=20, learning_rate=0.001 ) """ import torch import torch.nn as nn import torch.optim as optim from typing import Dict, List, Tuple, Optional from pathlib import Path import json import numpy as np from sklearn.metrics import classification_report, confusion_matrix def train_epoch( model: nn.Module, train_loader: torch.utils.data.DataLoader, criterion: nn.Module, optimizer: torch.optim.Optimizer, device: str ) -> Dict[str, float]: """ Train model for one epoch. Args: model: PyTorch model train_loader: Training data loader criterion: Loss function optimizer: Optimizer device: Device to train on ('cpu' or 'cuda') Returns: Dictionary with 'loss' and 'accuracy' metrics """ model.train() total_loss = 0.0 correct = 0 total = 0 for images, labels in train_loader: images, labels = images.to(device), labels.to(device) # Forward pass optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, labels) # Backward pass loss.backward() optimizer.step() # Track metrics total_loss += loss.item() _, predicted = outputs.max(1) correct += predicted.eq(labels).sum().item() total += labels.size(0) return { 'loss': total_loss / len(train_loader), 'accuracy': 100.0 * correct / total } def validate( model: nn.Module, val_loader: torch.utils.data.DataLoader, criterion: nn.Module, device: str ) -> Dict[str, float]: """ Evaluate model on validation/test set. Args: model: PyTorch model val_loader: Validation data loader criterion: Loss function device: Device to evaluate on Returns: Dictionary with 'loss' and 'accuracy' metrics """ model.eval() total_loss = 0.0 correct = 0 total = 0 with torch.no_grad(): for images, labels in val_loader: images, labels = images.to(device), labels.to(device) # Forward pass outputs = model(images) loss = criterion(outputs, labels) # Track metrics total_loss += loss.item() _, predicted = outputs.max(1) correct += predicted.eq(labels).sum().item() total += labels.size(0) return { 'loss': total_loss / len(val_loader), 'accuracy': 100.0 * correct / total } def train_model( model: nn.Module, train_loader: torch.utils.data.DataLoader, val_loader: torch.utils.data.DataLoader, num_epochs: int = 20, learning_rate: float = 0.001, patience: int = 5, checkpoint_dir: str = 'models', device: Optional[str] = None, use_scheduler: bool = True, verbose: bool = True ) -> Dict[str, List[float]]: """ Train model with early stopping and checkpointing. Args: model: PyTorch model train_loader: Training data loader val_loader: Validation data loader num_epochs: Maximum number of epochs learning_rate: Initial learning rate patience: Early stopping patience (epochs without improvement) checkpoint_dir: Directory to save model checkpoints device: Device to train on (auto-detect if None) use_scheduler: Whether to use learning rate scheduler verbose: Print training progress Returns: Dictionary with training history (losses and accuracies) """ # Setup device if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' model = model.to(device) if verbose: print(f"Training on device: {device}") print(f"Model: {model.__class__.__name__}") print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}") print() # Setup training components criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=learning_rate) # Learning rate scheduler scheduler = None if use_scheduler: scheduler = optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode='min', patience=3, factor=0.5, verbose=verbose ) # Setup checkpointing checkpoint_path = Path(checkpoint_dir) checkpoint_path.mkdir(parents=True, exist_ok=True) best_model_path = checkpoint_path / 'best_model.pt' last_model_path = checkpoint_path / 'last_model.pt' # Training history history = { 'train_loss': [], 'train_accuracy': [], 'val_loss': [], 'val_accuracy': [], 'learning_rate': [] } # Early stopping setup best_val_loss = float('inf') epochs_without_improvement = 0 # Training loop for epoch in range(num_epochs): # Train train_metrics = train_epoch(model, train_loader, criterion, optimizer, device) # Validate val_metrics = validate(model, val_loader, criterion, device) # Update history history['train_loss'].append(train_metrics['loss']) history['train_accuracy'].append(train_metrics['accuracy']) history['val_loss'].append(val_metrics['loss']) history['val_accuracy'].append(val_metrics['accuracy']) history['learning_rate'].append(optimizer.param_groups[0]['lr']) # Print progress if verbose: print(f"Epoch {epoch+1}/{num_epochs}") print(f" Train Loss: {train_metrics['loss']:.4f}, " f"Train Acc: {train_metrics['accuracy']:.2f}%") print(f" Val Loss: {val_metrics['loss']:.4f}, " f"Val Acc: {val_metrics['accuracy']:.2f}%") print(f" LR: {optimizer.param_groups[0]['lr']:.6f}") print() # Learning rate scheduling if scheduler is not None: scheduler.step(val_metrics['loss']) # Save best model if val_metrics['loss'] < best_val_loss: best_val_loss = val_metrics['loss'] epochs_without_improvement = 0 torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'val_loss': val_metrics['loss'], 'val_accuracy': val_metrics['accuracy'] }, best_model_path) if verbose: print(f" ✓ Best model saved (val_loss: {best_val_loss:.4f})") print() else: epochs_without_improvement += 1 # Early stopping if epochs_without_improvement >= patience: if verbose: print(f"Early stopping triggered after {epoch+1} epochs") print(f"Best validation loss: {best_val_loss:.4f}") break # Save last model torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'val_loss': val_metrics['loss'], 'val_accuracy': val_metrics['accuracy'] }, last_model_path) if verbose: print("Training complete!") print(f"Best validation loss: {best_val_loss:.4f}") print(f"Final validation accuracy: {history['val_accuracy'][-1]:.2f}%") return history def evaluate_model( model: nn.Module, test_loader: torch.utils.data.DataLoader, device: Optional[str] = None, class_names: Optional[List[str]] = None ) -> Dict: """ Comprehensive model evaluation with per-class metrics. Args: model: Trained PyTorch model test_loader: Test data loader device: Device to evaluate on class_names: List of class names (default: digits 0-9) Returns: Dictionary with metrics, predictions, and confusion matrix """ if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' model = model.to(device) model.eval() if class_names is None: class_names = [str(i) for i in range(10)] all_preds = [] all_labels = [] all_probs = [] with torch.no_grad(): for images, labels in test_loader: images = images.to(device) outputs = model(images) probs = torch.softmax(outputs, dim=1) _, predicted = outputs.max(1) all_preds.extend(predicted.cpu().numpy()) all_labels.extend(labels.numpy()) all_probs.extend(probs.cpu().numpy()) all_preds = np.array(all_preds) all_labels = np.array(all_labels) all_probs = np.array(all_probs) # Overall metrics accuracy = 100.0 * (all_preds == all_labels).sum() / len(all_labels) # Classification report report = classification_report( all_labels, all_preds, target_names=class_names, output_dict=True ) # Confusion matrix conf_matrix = confusion_matrix(all_labels, all_preds) return { 'accuracy': accuracy, 'classification_report': report, 'confusion_matrix': conf_matrix, 'predictions': all_preds, 'labels': all_labels, 'probabilities': all_probs } def save_training_history(history: Dict, filepath: str) -> None: """ Save training history to JSON file. Args: history: Training history dictionary filepath: Path to save JSON file """ Path(filepath).parent.mkdir(parents=True, exist_ok=True) with open(filepath, 'w') as f: json.dump(history, f, indent=2) print(f"Training history saved to {filepath}") def load_checkpoint(checkpoint_path: str, model: nn.Module) -> Tuple[nn.Module, Dict]: """ Load model from checkpoint. Args: checkpoint_path: Path to checkpoint file model: Model instance (for loading state dict) Returns: Tuple of (loaded_model, checkpoint_dict) """ checkpoint = torch.load(checkpoint_path, map_location='cpu') model.load_state_dict(checkpoint['model_state_dict']) return model, checkpoint