""" Full Training Script for MNIST CNN Trains the baseline CNN on the complete MNIST dataset with: - Full train/val/test split (51k/9k/10k) - Optional data augmentation - Early stopping and checkpointing - Comprehensive evaluation and metrics - Training history visualization Usage: python scripts/train_baseline.py [--augment] [--epochs 20] [--lr 0.001] """ import sys from pathlib import Path import argparse import json # Add project root to path project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) import torch import matplotlib.pyplot as plt import numpy as np from scripts.data_loader import MnistDataloader from scripts.preprocessing import MnistDataset, create_dataloaders, split_train_val from scripts.augmentation import get_train_augmentation from scripts.models import BaselineCNN, get_model_summary from scripts.train import train_model, evaluate_model, save_training_history def plot_training_history(history: dict, save_path: str): """Plot and save training history curves.""" fig, axes = plt.subplots(2, 2, figsize=(12, 10)) epochs = range(1, len(history['train_loss']) + 1) # Loss curves axes[0, 0].plot(epochs, history['train_loss'], 'b-', label='Train Loss') axes[0, 0].plot(epochs, history['val_loss'], 'r-', label='Val Loss') axes[0, 0].set_xlabel('Epoch') axes[0, 0].set_ylabel('Loss') axes[0, 0].set_title('Training and Validation Loss') axes[0, 0].legend() axes[0, 0].grid(True, alpha=0.3) # Accuracy curves axes[0, 1].plot(epochs, history['train_accuracy'], 'b-', label='Train Acc') axes[0, 1].plot(epochs, history['val_accuracy'], 'r-', label='Val Acc') axes[0, 1].set_xlabel('Epoch') axes[0, 1].set_ylabel('Accuracy (%)') axes[0, 1].set_title('Training and Validation Accuracy') axes[0, 1].legend() axes[0, 1].grid(True, alpha=0.3) # Learning rate axes[1, 0].plot(epochs, history['learning_rate'], 'g-') axes[1, 0].set_xlabel('Epoch') axes[1, 0].set_ylabel('Learning Rate') axes[1, 0].set_title('Learning Rate Schedule') axes[1, 0].set_yscale('log') axes[1, 0].grid(True, alpha=0.3) # Loss difference (overfitting indicator) loss_diff = np.array(history['val_loss']) - np.array(history['train_loss']) axes[1, 1].plot(epochs, loss_diff, 'm-') axes[1, 1].axhline(y=0, color='k', linestyle='--', alpha=0.3) axes[1, 1].set_xlabel('Epoch') axes[1, 1].set_ylabel('Val Loss - Train Loss') axes[1, 1].set_title('Overfitting Indicator (positive = overfitting)') axes[1, 1].grid(True, alpha=0.3) plt.tight_layout() plt.savefig(save_path, dpi=300, bbox_inches='tight') print(f"Training curves saved to {save_path}") plt.close() def plot_confusion_matrix(conf_matrix: np.ndarray, save_path: str): """Plot and save confusion matrix.""" fig, ax = plt.subplots(figsize=(10, 8)) im = ax.imshow(conf_matrix, interpolation='nearest', cmap=plt.cm.Blues) ax.figure.colorbar(im, ax=ax) # Labels classes = list(range(10)) ax.set(xticks=np.arange(conf_matrix.shape[1]), yticks=np.arange(conf_matrix.shape[0]), xticklabels=classes, yticklabels=classes, title='Confusion Matrix - MNIST Digit Classification', ylabel='True Label', xlabel='Predicted Label') # Rotate the tick labels plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") # Add text annotations thresh = conf_matrix.max() / 2. for i in range(conf_matrix.shape[0]): for j in range(conf_matrix.shape[1]): ax.text(j, i, format(conf_matrix[i, j], 'd'), ha="center", va="center", color="white" if conf_matrix[i, j] > thresh else "black") plt.tight_layout() plt.savefig(save_path, dpi=300, bbox_inches='tight') print(f"Confusion matrix saved to {save_path}") plt.close() def main(): parser = argparse.ArgumentParser( description='Train baseline CNN on MNIST' ) parser.add_argument( '--augment', action='store_true', help='Use data augmentation' ) parser.add_argument( '--epochs', type=int, default=20, help='Number of epochs (default: 20)' ) parser.add_argument( '--lr', type=float, default=0.001, help='Learning rate (default: 0.001)' ) parser.add_argument( '--batch-size', type=int, default=64, help='Batch size (default: 64)' ) parser.add_argument( '--patience', type=int, default=5, help='Early stopping patience (default: 5)' ) args = parser.parse_args() print("=" * 60) print("MNIST CNN Training - Baseline Model") print("=" * 60) print("Configuration:") print(f" Epochs: {args.epochs}") print(f" Learning Rate: {args.lr}") print(f" Batch Size: {args.batch_size}") print(f" Augmentation: {'Yes' if args.augment else 'No'}") print(f" Early Stopping Patience: {args.patience}") print() # 1. Load data print("1. Loading MNIST dataset...") data_path = project_root / "data" / "raw" loader = MnistDataloader( str(data_path / "train-images.idx3-ubyte"), str(data_path / "train-labels.idx1-ubyte"), str(data_path / "t10k-images.idx3-ubyte"), str(data_path / "t10k-labels.idx1-ubyte") ) (x_train, y_train), (x_test, y_test) = loader.load_data() print(f"✓ Loaded {len(x_train):,} training samples") print(f"✓ Loaded {len(x_test):,} test samples") print() # 2. Train/val split print("2. Creating train/validation split...") (x_train_split, y_train_split), (x_val, y_val) = split_train_val( x_train, y_train, val_split=0.15, random_seed=42 ) print(f"✓ Train: {len(x_train_split):,} samples") print(f"✓ Validation: {len(x_val):,} samples") print(f"✓ Test: {len(x_test):,} samples") print() # 3. Create datasets with optional augmentation print("3. Creating datasets...") augmentation = get_train_augmentation() if args.augment else None train_dataset = MnistDataset(x_train_split, y_train_split, transform=augmentation) val_dataset = MnistDataset(x_val, y_val, transform=None) test_dataset = MnistDataset(x_test, y_test, transform=None) train_loader, val_loader = create_dataloaders( train_dataset, val_dataset, batch_size=args.batch_size, num_workers=2 ) test_loader = torch.utils.data.DataLoader( test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2 ) print(f"✓ Train batches: {len(train_loader)}") print(f"✓ Val batches: {len(val_loader)}") print(f"✓ Test batches: {len(test_loader)}") print() # 4. Create model print("4. Creating model...") model = BaselineCNN() print(get_model_summary(model)) print() # 5. Train model print("5. Training model...") print("-" * 60) history = train_model( model, train_loader, val_loader, num_epochs=args.epochs, learning_rate=args.lr, patience=args.patience, checkpoint_dir='models', device=None, # Auto-detect use_scheduler=True, verbose=True ) print("-" * 60) print() # 6. Load best model and evaluate print("6. Evaluating best model on test set...") checkpoint = torch.load('models/best_model.pt', map_location='cpu') model.load_state_dict(checkpoint['model_state_dict']) device = 'cuda' if torch.cuda.is_available() else 'cpu' results = evaluate_model(model, test_loader, device=device) print(f"✓ Test Accuracy: {results['accuracy']:.2f}%") print() # 7. Print detailed metrics print("7. Per-class metrics:") print("-" * 60) report = results['classification_report'] print( f"{'Digit':<8} {'Precision':<12} {'Recall':<12} " f"{'F1-Score':<12} {'Support':<10}" ) print("-" * 60) for digit in range(10): if str(digit) in report: metrics = report[str(digit)] print( f"{digit:<8} {metrics['precision']:<12.3f} " f"{metrics['recall']:<12.3f} " f"{metrics['f1-score']:<12.3f} {metrics['support']:<10}" ) print("-" * 60) acc_line = ( f"{'Accuracy':<8} {' ':<12} {' ':<12} " f"{report['accuracy']:<12.3f} " f"{report['macro avg']['support']:<10}" ) print(acc_line) macro_line = ( f"{'Macro Avg':<8} {report['macro avg']['precision']:<12.3f} " f"{report['macro avg']['recall']:<12.3f} " f"{report['macro avg']['f1-score']:<12.3f} " f"{report['macro avg']['support']:<10}" ) print(macro_line) print() # 8. Save results print("8. Saving results...") results_dir = project_root / "results" results_dir.mkdir(exist_ok=True) # Save history history_path = results_dir / "baseline_training_history.json" save_training_history(history, str(history_path)) # Plot training curves curves_path = results_dir / "baseline_training_curves.png" plot_training_history(history, str(curves_path)) # Plot confusion matrix conf_matrix_path = results_dir / "baseline_confusion_matrix.png" plot_confusion_matrix(results['confusion_matrix'], str(conf_matrix_path)) # Save evaluation metrics metrics_path = results_dir / "baseline_metrics.json" # Convert numpy arrays to lists for JSON serialization metrics_data = { 'test_accuracy': float(results['accuracy']), 'classification_report': report, 'confusion_matrix': results['confusion_matrix'].tolist(), 'best_epoch': int(checkpoint['epoch']), 'best_val_loss': float(checkpoint['val_loss']), 'best_val_accuracy': float(checkpoint['val_accuracy']), 'final_train_accuracy': float(history['train_accuracy'][-1]), 'final_val_accuracy': float(history['val_accuracy'][-1]), 'config': { 'epochs': args.epochs, 'learning_rate': args.lr, 'batch_size': args.batch_size, 'augmentation': args.augment, 'patience': args.patience } } with open(metrics_path, 'w') as f: json.dump(metrics_data, f, indent=2) print(f"Evaluation metrics saved to {metrics_path}") print() # 9. Summary print("=" * 60) print("✅ TRAINING COMPLETE") print("=" * 60) print(f"Best Epoch: {checkpoint['epoch'] + 1}") print(f"Best Val Loss: {checkpoint['val_loss']:.4f}") print(f"Best Val Accuracy: {checkpoint['val_accuracy']:.2f}%") print(f"Test Accuracy: {results['accuracy']:.2f}%") print() print("Saved artifacts:") print(" - Best model: models/best_model.pt") print(f" - Training history: {history_path}") print(f" - Training curves: {curves_path}") print(f" - Confusion matrix: {conf_matrix_path}") print(f" - Metrics: {metrics_path}") return 0 if __name__ == "__main__": sys.exit(main())