"""
Full Training Script for MNIST CNN
Trains the baseline CNN on the complete MNIST dataset with:
- Full train/val/test split (51k/9k/10k)
- Optional data augmentation
- Early stopping and checkpointing
- Comprehensive evaluation and metrics
- Training history visualization
Usage:
python scripts/train_baseline.py [--augment] [--epochs 20] [--lr 0.001]
"""
import argparse
import json
import sys
from pathlib import Path

# Add the project root to sys.path so the local scripts package is
# importable when this file is run directly.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

import torch
import matplotlib.pyplot as plt
import numpy as np

from scripts.data_loader import MnistDataloader
from scripts.preprocessing import MnistDataset, create_dataloaders, split_train_val
from scripts.augmentation import get_train_augmentation
from scripts.models import BaselineCNN, get_model_summary
from scripts.train import train_model, evaluate_model, save_training_history
def plot_training_history(history: dict, save_path: str):
    """Plot and save training history curves."""
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    epochs = range(1, len(history['train_loss']) + 1)

    # Loss curves
    axes[0, 0].plot(epochs, history['train_loss'], 'b-', label='Train Loss')
    axes[0, 0].plot(epochs, history['val_loss'], 'r-', label='Val Loss')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title('Training and Validation Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Accuracy curves
    axes[0, 1].plot(epochs, history['train_accuracy'], 'b-', label='Train Acc')
    axes[0, 1].plot(epochs, history['val_accuracy'], 'r-', label='Val Acc')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy (%)')
    axes[0, 1].set_title('Training and Validation Accuracy')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Learning rate schedule
    axes[1, 0].plot(epochs, history['learning_rate'], 'g-')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Learning Rate')
    axes[1, 0].set_title('Learning Rate Schedule')
    axes[1, 0].set_yscale('log')
    axes[1, 0].grid(True, alpha=0.3)

    # Loss difference (overfitting indicator)
    loss_diff = np.array(history['val_loss']) - np.array(history['train_loss'])
    axes[1, 1].plot(epochs, loss_diff, 'm-')
    axes[1, 1].axhline(y=0, color='k', linestyle='--', alpha=0.3)
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('Val Loss - Train Loss')
    axes[1, 1].set_title('Overfitting Indicator (positive = overfitting)')
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Training curves saved to {save_path}")
    plt.close()
def plot_confusion_matrix(conf_matrix: np.ndarray, save_path: str):
    """Plot and save the confusion matrix."""
    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(conf_matrix, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(im, ax=ax)

    # Axis labels and title
    classes = list(range(10))
    ax.set(xticks=np.arange(conf_matrix.shape[1]),
           yticks=np.arange(conf_matrix.shape[0]),
           xticklabels=classes, yticklabels=classes,
           title='Confusion Matrix - MNIST Digit Classification',
           ylabel='True Label',
           xlabel='Predicted Label')

    # Rotate the x tick labels
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Annotate each cell with its count
    thresh = conf_matrix.max() / 2.
    for i in range(conf_matrix.shape[0]):
        for j in range(conf_matrix.shape[1]):
            ax.text(j, i, format(conf_matrix[i, j], 'd'),
                    ha="center", va="center",
                    color="white" if conf_matrix[i, j] > thresh else "black")

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Confusion matrix saved to {save_path}")
    plt.close()
def main():
    parser = argparse.ArgumentParser(
        description='Train baseline CNN on MNIST'
    )
    parser.add_argument(
        '--augment', action='store_true', help='Use data augmentation'
    )
    parser.add_argument(
        '--epochs', type=int, default=20,
        help='Number of epochs (default: 20)'
    )
    parser.add_argument(
        '--lr', type=float, default=0.001,
        help='Learning rate (default: 0.001)'
    )
    parser.add_argument(
        '--batch-size', type=int, default=64,
        help='Batch size (default: 64)'
    )
    parser.add_argument(
        '--patience', type=int, default=5,
        help='Early stopping patience (default: 5)'
    )
    args = parser.parse_args()
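    # Example invocations (all flags are optional; defaults shown above):
    #   python scripts/train_baseline.py
    #   python scripts/train_baseline.py --augment --epochs 30 --lr 0.0005 --batch-size 128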
print("=" * 60)
print("MNIST CNN Training - Baseline Model")
print("=" * 60)
print("Configuration:")
print(f" Epochs: {args.epochs}")
print(f" Learning Rate: {args.lr}")
print(f" Batch Size: {args.batch_size}")
print(f" Augmentation: {'Yes' if args.augment else 'No'}")
print(f" Early Stopping Patience: {args.patience}")
print()
    # 1. Load data
    print("1. Loading MNIST dataset...")
    data_path = project_root / "data" / "raw"
    loader = MnistDataloader(
        str(data_path / "train-images.idx3-ubyte"),
        str(data_path / "train-labels.idx1-ubyte"),
        str(data_path / "t10k-images.idx3-ubyte"),
        str(data_path / "t10k-labels.idx1-ubyte")
    )
    (x_train, y_train), (x_test, y_test) = loader.load_data()
    print(f"✓ Loaded {len(x_train):,} training samples")
    print(f"✓ Loaded {len(x_test):,} test samples")
    print()
    # 2. Train/val split
    print("2. Creating train/validation split...")
    (x_train_split, y_train_split), (x_val, y_val) = split_train_val(
        x_train, y_train, val_split=0.15, random_seed=42
    )
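    # 15% of the 60,000 MNIST training images -> 51,000 train / 9,000 val,
    # which is the 51k/9k/10k split described in the module docstring.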
print(f"✓ Train: {len(x_train_split):,} samples")
print(f"✓ Validation: {len(x_val):,} samples")
print(f"✓ Test: {len(x_test):,} samples")
print()
    # 3. Create datasets with optional augmentation
    print("3. Creating datasets...")
    augmentation = get_train_augmentation() if args.augment else None
    train_dataset = MnistDataset(x_train_split, y_train_split, transform=augmentation)
    val_dataset = MnistDataset(x_val, y_val, transform=None)
    test_dataset = MnistDataset(x_test, y_test, transform=None)
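    # Only the training set is augmented; validation and test stay
    # untransformed (and the test loader unshuffled) so evaluation is
    # deterministic across runs.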
    train_loader, val_loader = create_dataloaders(
        train_dataset, val_dataset, batch_size=args.batch_size, num_workers=2
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2
    )
    print(f"✓ Train batches: {len(train_loader)}")
    print(f"✓ Val batches: {len(val_loader)}")
    print(f"✓ Test batches: {len(test_loader)}")
    print()
    # 4. Create model
    print("4. Creating model...")
    model = BaselineCNN()
    print(get_model_summary(model))
    print()
    # 5. Train model
    print("5. Training model...")
    print("-" * 60)
    history = train_model(
        model,
        train_loader,
        val_loader,
        num_epochs=args.epochs,
        learning_rate=args.lr,
        patience=args.patience,
        checkpoint_dir='models',
        device=None,  # Auto-detect CUDA/CPU
        use_scheduler=True,
        verbose=True
    )
    print("-" * 60)
    print()
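    # train_model checkpoints the best-performing epoch to models/best_model.pt
    # (via checkpoint_dir above), so reload those weights rather than
    # evaluating the final-epoch model.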
    # 6. Load best model and evaluate
    print("6. Evaluating best model on test set...")
    checkpoint = torch.load('models/best_model.pt', map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    results = evaluate_model(model, test_loader, device=device)
    print(f"✓ Test Accuracy: {results['accuracy']:.2f}%")
    print()
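    # The report dict appears to follow sklearn's
    # classification_report(output_dict=True) layout: per-class keys '0'-'9'
    # plus 'accuracy' and 'macro avg' entries.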
    # 7. Print detailed metrics
    print("7. Per-class metrics:")
    print("-" * 60)
    report = results['classification_report']
    print(
        f"{'Digit':<10} {'Precision':<12} {'Recall':<12} "
        f"{'F1-Score':<12} {'Support':<10}"
    )
    print("-" * 60)
    for digit in range(10):
        if str(digit) in report:
            metrics = report[str(digit)]
            print(
                f"{digit:<10} {metrics['precision']:<12.3f} "
                f"{metrics['recall']:<12.3f} "
                f"{metrics['f1-score']:<12.3f} {metrics['support']:<10}"
            )
    print("-" * 60)
    # Overall accuracy sits in the F1 column, matching sklearn's text report.
    acc_line = (
        f"{'Accuracy':<10} {' ':<12} {' ':<12} "
        f"{report['accuracy']:<12.3f} "
        f"{report['macro avg']['support']:<10}"
    )
    print(acc_line)
    macro_line = (
        f"{'Macro Avg':<10} {report['macro avg']['precision']:<12.3f} "
        f"{report['macro avg']['recall']:<12.3f} "
        f"{report['macro avg']['f1-score']:<12.3f} "
        f"{report['macro avg']['support']:<10}"
    )
    print(macro_line)
    print()
    # 8. Save results
    print("8. Saving results...")
    results_dir = project_root / "results"
    results_dir.mkdir(exist_ok=True)

    # Save history
    history_path = results_dir / "baseline_training_history.json"
    save_training_history(history, str(history_path))

    # Plot training curves
    curves_path = results_dir / "baseline_training_curves.png"
    plot_training_history(history, str(curves_path))

    # Plot confusion matrix
    conf_matrix_path = results_dir / "baseline_confusion_matrix.png"
    plot_confusion_matrix(results['confusion_matrix'], str(conf_matrix_path))

    # Save evaluation metrics
    metrics_path = results_dir / "baseline_metrics.json"
    # Convert numpy arrays to lists for JSON serialization
    metrics_data = {
        'test_accuracy': float(results['accuracy']),
        'classification_report': report,
        'confusion_matrix': results['confusion_matrix'].tolist(),
        'best_epoch': int(checkpoint['epoch']),
        'best_val_loss': float(checkpoint['val_loss']),
        'best_val_accuracy': float(checkpoint['val_accuracy']),
        'final_train_accuracy': float(history['train_accuracy'][-1]),
        'final_val_accuracy': float(history['val_accuracy'][-1]),
        'config': {
            'epochs': args.epochs,
            'learning_rate': args.lr,
            'batch_size': args.batch_size,
            'augmentation': args.augment,
            'patience': args.patience
        }
    }
    with open(metrics_path, 'w') as f:
        json.dump(metrics_data, f, indent=2)
    print(f"Evaluation metrics saved to {metrics_path}")
    print()
    # 9. Summary
    print("=" * 60)
    print("✅ TRAINING COMPLETE")
    print("=" * 60)
    print(f"Best Epoch: {checkpoint['epoch'] + 1}")
    print(f"Best Val Loss: {checkpoint['val_loss']:.4f}")
    print(f"Best Val Accuracy: {checkpoint['val_accuracy']:.2f}%")
    print(f"Test Accuracy: {results['accuracy']:.2f}%")
    print()
    print("Saved artifacts:")
    print(" - Best model: models/best_model.pt")
    print(f" - Training history: {history_path}")
    print(f" - Training curves: {curves_path}")
    print(f" - Confusion matrix: {conf_matrix_path}")
    print(f" - Metrics: {metrics_path}")
    return 0
if __name__ == "__main__":
    sys.exit(main())