Spaces:
Sleeping
Sleeping
"""
Full Training Script for MNIST CNN
Trains the baseline CNN on the complete MNIST dataset with:
- Full train/val/test split (51k/9k/10k)
- Optional data augmentation
- Early stopping and checkpointing
- Comprehensive evaluation and metrics
- Training history visualization
Usage:
    python scripts/train_baseline.py [--augment] [--epochs 20] [--lr 0.001]
"""
import sys
from pathlib import Path
import argparse
import json

# Add project root to path so the `scripts.*` package imports below resolve
# when this file is executed directly (not as a module).
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

import torch
import matplotlib.pyplot as plt
import numpy as np

# Project-local helpers: data loading, preprocessing, augmentation,
# model definition, and the training/evaluation loop.
from scripts.data_loader import MnistDataloader
from scripts.preprocessing import MnistDataset, create_dataloaders, split_train_val
from scripts.augmentation import get_train_augmentation
from scripts.models import BaselineCNN, get_model_summary
from scripts.train import train_model, evaluate_model, save_training_history
def plot_training_history(history: dict, save_path: str):
    """Render a 2x2 grid of training curves and write it to *save_path*.

    Panels: train/val loss, train/val accuracy, the learning-rate
    schedule (log y-axis), and the val-minus-train loss gap as an
    overfitting indicator.
    """
    fig, ((ax_loss, ax_acc), (ax_lr, ax_gap)) = plt.subplots(2, 2, figsize=(12, 10))
    x = range(1, len(history['train_loss']) + 1)

    # Train vs. validation loss.
    ax_loss.plot(x, history['train_loss'], 'b-', label='Train Loss')
    ax_loss.plot(x, history['val_loss'], 'r-', label='Val Loss')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.set_title('Training and Validation Loss')
    ax_loss.legend()
    ax_loss.grid(True, alpha=0.3)

    # Train vs. validation accuracy.
    ax_acc.plot(x, history['train_accuracy'], 'b-', label='Train Acc')
    ax_acc.plot(x, history['val_accuracy'], 'r-', label='Val Acc')
    ax_acc.set_xlabel('Epoch')
    ax_acc.set_ylabel('Accuracy (%)')
    ax_acc.set_title('Training and Validation Accuracy')
    ax_acc.legend()
    ax_acc.grid(True, alpha=0.3)

    # Learning-rate schedule, log-scaled so decays remain visible.
    ax_lr.plot(x, history['learning_rate'], 'g-')
    ax_lr.set_xlabel('Epoch')
    ax_lr.set_ylabel('Learning Rate')
    ax_lr.set_title('Learning Rate Schedule')
    ax_lr.set_yscale('log')
    ax_lr.grid(True, alpha=0.3)

    # Val-minus-train loss: positive values suggest overfitting.
    gap = np.array(history['val_loss']) - np.array(history['train_loss'])
    ax_gap.plot(x, gap, 'm-')
    ax_gap.axhline(y=0, color='k', linestyle='--', alpha=0.3)
    ax_gap.set_xlabel('Epoch')
    ax_gap.set_ylabel('Val Loss - Train Loss')
    ax_gap.set_title('Overfitting Indicator (positive = overfitting)')
    ax_gap.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Training curves saved to {save_path}")
    plt.close()
def plot_confusion_matrix(conf_matrix: np.ndarray, save_path: str):
    """Draw an annotated confusion-matrix heatmap and save it to *save_path*.

    Each cell shows its integer count; the text colour flips to white
    on dark (high-count) cells so it stays readable.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    image = ax.imshow(conf_matrix, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(image, ax=ax)

    # Tick labels are the ten digit classes 0-9.
    digit_labels = list(range(10))
    n_rows, n_cols = conf_matrix.shape
    ax.set(xticks=np.arange(n_cols),
           yticks=np.arange(n_rows),
           xticklabels=digit_labels, yticklabels=digit_labels,
           title='Confusion Matrix - MNIST Digit Classification',
           ylabel='True Label',
           xlabel='Predicted Label')

    # Angle the x tick labels so adjacent ones do not collide.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Annotate every cell, picking a contrasting text colour.
    cutoff = conf_matrix.max() / 2.
    for row in range(n_rows):
        for col in range(n_cols):
            count = conf_matrix[row, col]
            ax.text(col, row, format(count, 'd'),
                    ha="center", va="center",
                    color="white" if count > cutoff else "black")

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Confusion matrix saved to {save_path}")
    plt.close()
def main() -> int:
    """Entry point: parse CLI flags, train the baseline CNN on MNIST,
    evaluate the best checkpoint on the test set, and save all artifacts
    (history JSON, curve/confusion-matrix plots, metrics JSON).

    Returns:
        0 on success, used as the process exit code.
    """
    parser = argparse.ArgumentParser(
        description='Train baseline CNN on MNIST'
    )
    parser.add_argument(
        '--augment', action='store_true', help='Use data augmentation'
    )
    parser.add_argument(
        '--epochs', type=int, default=20,
        help='Number of epochs (default: 20)'
    )
    parser.add_argument(
        '--lr', type=float, default=0.001,
        help='Learning rate (default: 0.001)'
    )
    parser.add_argument(
        '--batch-size', type=int, default=64,
        help='Batch size (default: 64)'
    )
    parser.add_argument(
        '--patience', type=int, default=5,
        help='Early stopping patience (default: 5)'
    )
    args = parser.parse_args()

    # Echo the run configuration before any heavy work starts.
    print("=" * 60)
    print("MNIST CNN Training - Baseline Model")
    print("=" * 60)
    print("Configuration:")
    print(f"  Epochs: {args.epochs}")
    print(f"  Learning Rate: {args.lr}")
    print(f"  Batch Size: {args.batch_size}")
    print(f"  Augmentation: {'Yes' if args.augment else 'No'}")
    print(f"  Early Stopping Patience: {args.patience}")
    print()

    # 1. Load data — raw IDX-format MNIST files under <project>/data/raw.
    print("1. Loading MNIST dataset...")
    data_path = project_root / "data" / "raw"
    loader = MnistDataloader(
        str(data_path / "train-images.idx3-ubyte"),
        str(data_path / "train-labels.idx1-ubyte"),
        str(data_path / "t10k-images.idx3-ubyte"),
        str(data_path / "t10k-labels.idx1-ubyte")
    )
    (x_train, y_train), (x_test, y_test) = loader.load_data()
    print(f"✓ Loaded {len(x_train):,} training samples")
    print(f"✓ Loaded {len(x_test):,} test samples")
    print()

    # 2. Train/val split — 15% held out for validation; fixed seed
    # keeps the split reproducible across runs.
    print("2. Creating train/validation split...")
    (x_train_split, y_train_split), (x_val, y_val) = split_train_val(
        x_train, y_train, val_split=0.15, random_seed=42
    )
    print(f"✓ Train: {len(x_train_split):,} samples")
    print(f"✓ Validation: {len(x_val):,} samples")
    print(f"✓ Test: {len(x_test):,} samples")
    print()

    # 3. Create datasets; augmentation is applied to the training set only.
    print("3. Creating datasets...")
    augmentation = get_train_augmentation() if args.augment else None
    train_dataset = MnistDataset(x_train_split, y_train_split, transform=augmentation)
    val_dataset = MnistDataset(x_val, y_val, transform=None)
    test_dataset = MnistDataset(x_test, y_test, transform=None)
    train_loader, val_loader = create_dataloaders(
        train_dataset, val_dataset, batch_size=args.batch_size, num_workers=2
    )
    # Test loader is built directly; no shuffling so evaluation order is stable.
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2
    )
    print(f"✓ Train batches: {len(train_loader)}")
    print(f"✓ Val batches: {len(val_loader)}")
    print(f"✓ Test batches: {len(test_loader)}")
    print()

    # 4. Create model
    print("4. Creating model...")
    model = BaselineCNN()
    print(get_model_summary(model))
    print()

    # 5. Train model — train_model handles early stopping, LR scheduling,
    # and writes the best checkpoint under checkpoint_dir.
    print("5. Training model...")
    print("-" * 60)
    history = train_model(
        model,
        train_loader,
        val_loader,
        num_epochs=args.epochs,
        learning_rate=args.lr,
        patience=args.patience,
        checkpoint_dir='models',
        device=None,  # Auto-detect
        use_scheduler=True,
        verbose=True
    )
    print("-" * 60)
    print()

    # 6. Load best model and evaluate on the held-out test set.
    # NOTE(review): checkpoint path is relative to the CWD, matching
    # checkpoint_dir='models' above — both assume the script runs from
    # the project root.
    print("6. Evaluating best model on test set...")
    checkpoint = torch.load('models/best_model.pt', map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    results = evaluate_model(model, test_loader, device=device)
    print(f"✓ Test Accuracy: {results['accuracy']:.2f}%")
    print()

    # 7. Print detailed per-class metrics as a fixed-width table.
    # NOTE(review): assumes `report` follows sklearn's output_dict
    # layout (digit keys as strings plus 'accuracy'/'macro avg') —
    # confirm against evaluate_model in scripts/train.
    print("7. Per-class metrics:")
    print("-" * 60)
    report = results['classification_report']
    print(
        f"{'Digit':<8} {'Precision':<12} {'Recall':<12} "
        f"{'F1-Score':<12} {'Support':<10}"
    )
    print("-" * 60)
    for digit in range(10):
        if str(digit) in report:
            metrics = report[str(digit)]
            print(
                f"{digit:<8} {metrics['precision']:<12.3f} "
                f"{metrics['recall']:<12.3f} "
                f"{metrics['f1-score']:<12.3f} {metrics['support']:<10}"
            )
    print("-" * 60)
    acc_line = (
        f"{'Accuracy':<8} {' ':<12} {' ':<12} "
        f"{report['accuracy']:<12.3f} "
        f"{report['macro avg']['support']:<10}"
    )
    print(acc_line)
    macro_line = (
        f"{'Macro Avg':<8} {report['macro avg']['precision']:<12.3f} "
        f"{report['macro avg']['recall']:<12.3f} "
        f"{report['macro avg']['f1-score']:<12.3f} "
        f"{report['macro avg']['support']:<10}"
    )
    print(macro_line)
    print()

    # 8. Save results under <project>/results.
    print("8. Saving results...")
    results_dir = project_root / "results"
    results_dir.mkdir(exist_ok=True)
    # Save history
    history_path = results_dir / "baseline_training_history.json"
    save_training_history(history, str(history_path))
    # Plot training curves
    curves_path = results_dir / "baseline_training_curves.png"
    plot_training_history(history, str(curves_path))
    # Plot confusion matrix
    conf_matrix_path = results_dir / "baseline_confusion_matrix.png"
    plot_confusion_matrix(results['confusion_matrix'], str(conf_matrix_path))
    # Save evaluation metrics
    metrics_path = results_dir / "baseline_metrics.json"
    # Convert numpy arrays to lists for JSON serialization
    metrics_data = {
        'test_accuracy': float(results['accuracy']),
        'classification_report': report,
        'confusion_matrix': results['confusion_matrix'].tolist(),
        # 'best_epoch' is stored zero-based here; the summary below
        # prints it one-based (epoch + 1).
        'best_epoch': int(checkpoint['epoch']),
        'best_val_loss': float(checkpoint['val_loss']),
        'best_val_accuracy': float(checkpoint['val_accuracy']),
        'final_train_accuracy': float(history['train_accuracy'][-1]),
        'final_val_accuracy': float(history['val_accuracy'][-1]),
        'config': {
            'epochs': args.epochs,
            'learning_rate': args.lr,
            'batch_size': args.batch_size,
            'augmentation': args.augment,
            'patience': args.patience
        }
    }
    with open(metrics_path, 'w') as f:
        json.dump(metrics_data, f, indent=2)
    print(f"Evaluation metrics saved to {metrics_path}")
    print()

    # 9. Summary of the run and where every artifact landed.
    print("=" * 60)
    print("✅ TRAINING COMPLETE")
    print("=" * 60)
    print(f"Best Epoch: {checkpoint['epoch'] + 1}")
    print(f"Best Val Loss: {checkpoint['val_loss']:.4f}")
    print(f"Best Val Accuracy: {checkpoint['val_accuracy']:.2f}%")
    print(f"Test Accuracy: {results['accuracy']:.2f}%")
    print()
    print("Saved artifacts:")
    print("  - Best model: models/best_model.pt")
    print(f"  - Training history: {history_path}")
    print(f"  - Training curves: {curves_path}")
    print(f"  - Confusion matrix: {conf_matrix_path}")
    print(f"  - Metrics: {metrics_path}")
    return 0
# Run only when executed as a script; main()'s return value becomes the exit code.
if __name__ == "__main__":
    sys.exit(main())