# VCS residue preserved as a comment (was pasted raw above the module docstring,
# which would raise a NameError at import time):
# faizan — "fix: resolve all 468 ruff linting errors (code quality enforcement
# complete)" — commit e77a25a
"""
Training Pipeline for MNIST CNN
This module provides utilities for training and evaluating CNN models:
- train_epoch: Single epoch training
- validate: Validation/test evaluation
- train_model: Complete training loop with early stopping
- evaluate_model: Comprehensive evaluation with per-class metrics
Supports MLflow experiment tracking for reproducibility.
Usage:
from scripts.train import train_model
from scripts.models import BaselineCNN
model = BaselineCNN()
history = train_model(
model, train_loader, val_loader,
num_epochs=20, learning_rate=0.001
)
"""
import torch
import torch.nn as nn
import torch.optim as optim
from typing import Dict, List, Tuple, Optional
from pathlib import Path
import json
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
def train_epoch(
    model: nn.Module,
    train_loader: torch.utils.data.DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: str
) -> Dict[str, float]:
    """
    Run a single training pass over ``train_loader``.

    Args:
        model: PyTorch model (switched to train mode here)
        train_loader: Training data loader yielding (images, labels)
        criterion: Loss function
        optimizer: Optimizer updating ``model``'s parameters
        device: Device to train on ('cpu' or 'cuda')

    Returns:
        Dictionary with 'loss' (mean per-batch loss) and 'accuracy'
        (percentage of correctly classified samples).
    """
    model.train()
    running_loss = 0.0
    num_correct = 0
    num_seen = 0
    for batch_images, batch_labels in train_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        # Standard step: zero grads, forward, backward, update.
        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()
        # Accumulate loss and hit counts for epoch-level metrics.
        running_loss += batch_loss.item()
        _, batch_preds = logits.max(1)
        num_correct += (batch_preds == batch_labels).sum().item()
        num_seen += batch_labels.size(0)
    return {
        'loss': running_loss / len(train_loader),
        'accuracy': 100.0 * num_correct / num_seen
    }
def validate(
    model: nn.Module,
    val_loader: torch.utils.data.DataLoader,
    criterion: nn.Module,
    device: str
) -> Dict[str, float]:
    """
    Evaluate ``model`` on a validation/test loader without updating weights.

    Args:
        model: PyTorch model (switched to eval mode here)
        val_loader: Validation data loader yielding (images, labels)
        criterion: Loss function
        device: Device to evaluate on ('cpu' or 'cuda')

    Returns:
        Dictionary with 'loss' (mean per-batch loss) and 'accuracy'
        (percentage of correctly classified samples).
    """
    model.eval()
    running_loss = 0.0
    num_correct = 0
    num_seen = 0
    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        for batch_images, batch_labels in val_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_images)
            running_loss += criterion(logits, batch_labels).item()
            _, batch_preds = logits.max(1)
            num_correct += (batch_preds == batch_labels).sum().item()
            num_seen += batch_labels.size(0)
    return {
        'loss': running_loss / len(val_loader),
        'accuracy': 100.0 * num_correct / num_seen
    }
def _save_checkpoint(
    path: Path,
    epoch: int,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    val_metrics: Dict[str, float]
) -> None:
    """Serialize model/optimizer state plus validation metrics to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_metrics['loss'],
        'val_accuracy': val_metrics['accuracy']
    }, path)


def train_model(
    model: nn.Module,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    num_epochs: int = 20,
    learning_rate: float = 0.001,
    patience: int = 5,
    checkpoint_dir: str = 'models',
    device: Optional[str] = None,
    use_scheduler: bool = True,
    verbose: bool = True
) -> Dict[str, List[float]]:
    """
    Train model with early stopping and checkpointing.

    Args:
        model: PyTorch model
        train_loader: Training data loader
        val_loader: Validation data loader
        num_epochs: Maximum number of epochs
        learning_rate: Initial learning rate (Adam)
        patience: Early stopping patience (epochs without val-loss improvement)
        checkpoint_dir: Directory to save model checkpoints
        device: Device to train on ('cpu'/'cuda'; auto-detect if None)
        use_scheduler: Whether to use a ReduceLROnPlateau LR scheduler
        verbose: Print training progress

    Returns:
        Dictionary with training history lists: 'train_loss',
        'train_accuracy', 'val_loss', 'val_accuracy', 'learning_rate'.
    """
    # Auto-detect device when not supplied.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    if verbose:
        print(f"Training on device: {device}")
        print(f"Model: {model.__class__.__name__}")
        print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
        print()
    # Standard classification setup: cross-entropy loss + Adam.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Halve the LR after 3 epochs without val-loss improvement.
    # FIX: the scheduler's `verbose` kwarg was deprecated in PyTorch 2.2 and
    # removed in later releases (it raised TypeError there), so it is no
    # longer passed; the per-epoch LR print below covers the same info.
    scheduler = None
    if use_scheduler:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', patience=3, factor=0.5
        )
    # Checkpoint locations (directory created if needed).
    checkpoint_path = Path(checkpoint_dir)
    checkpoint_path.mkdir(parents=True, exist_ok=True)
    best_model_path = checkpoint_path / 'best_model.pt'
    last_model_path = checkpoint_path / 'last_model.pt'
    history: Dict[str, List[float]] = {
        'train_loss': [],
        'train_accuracy': [],
        'val_loss': [],
        'val_accuracy': [],
        'learning_rate': []
    }
    # Early stopping state.
    best_val_loss = float('inf')
    epochs_without_improvement = 0
    # FIX: initialize so that num_epochs == 0 no longer raises NameError at
    # the last-checkpoint save below.
    val_metrics: Optional[Dict[str, float]] = None
    epoch = -1
    for epoch in range(num_epochs):
        train_metrics = train_epoch(model, train_loader, criterion, optimizer, device)
        val_metrics = validate(model, val_loader, criterion, device)
        # Record per-epoch metrics and the current learning rate.
        history['train_loss'].append(train_metrics['loss'])
        history['train_accuracy'].append(train_metrics['accuracy'])
        history['val_loss'].append(val_metrics['loss'])
        history['val_accuracy'].append(val_metrics['accuracy'])
        history['learning_rate'].append(optimizer.param_groups[0]['lr'])
        if verbose:
            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f" Train Loss: {train_metrics['loss']:.4f}, "
                  f"Train Acc: {train_metrics['accuracy']:.2f}%")
            print(f" Val Loss: {val_metrics['loss']:.4f}, "
                  f"Val Acc: {val_metrics['accuracy']:.2f}%")
            print(f" LR: {optimizer.param_groups[0]['lr']:.6f}")
            print()
        # Scheduler monitors the validation loss.
        if scheduler is not None:
            scheduler.step(val_metrics['loss'])
        # Checkpoint whenever validation loss improves; otherwise count
        # toward early stopping.
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            epochs_without_improvement = 0
            _save_checkpoint(best_model_path, epoch, model, optimizer, val_metrics)
            if verbose:
                print(f" ✓ Best model saved (val_loss: {best_val_loss:.4f})")
                print()
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            if verbose:
                print(f"Early stopping triggered after {epoch+1} epochs")
                print(f"Best validation loss: {best_val_loss:.4f}")
            break
    # Save the final state only if at least one epoch actually ran.
    if val_metrics is not None:
        _save_checkpoint(last_model_path, epoch, model, optimizer, val_metrics)
    if verbose:
        print("Training complete!")
        print(f"Best validation loss: {best_val_loss:.4f}")
        if history['val_accuracy']:
            print(f"Final validation accuracy: {history['val_accuracy'][-1]:.2f}%")
    return history
def evaluate_model(
    model: nn.Module,
    test_loader: torch.utils.data.DataLoader,
    device: Optional[str] = None,
    class_names: Optional[List[str]] = None
) -> Dict:
    """
    Comprehensive model evaluation with per-class metrics.

    Args:
        model: Trained PyTorch model (switched to eval mode here)
        test_loader: Test data loader yielding (images, labels)
        device: Device to evaluate on ('cpu'/'cuda'; auto-detect if None)
        class_names: List of class names (default: digits '0'-'9')

    Returns:
        Dictionary with overall 'accuracy' (percentage), sklearn
        'classification_report' (dict form), 'confusion_matrix',
        'predictions', 'labels' and softmax 'probabilities' arrays.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    model.eval()
    if class_names is None:
        class_names = [str(digit) for digit in range(10)]
    pred_list: List = []
    label_list: List = []
    prob_list: List = []
    # Inference only: no gradient tracking.
    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            logits = model(batch_images.to(device))
            batch_probs = torch.softmax(logits, dim=1)
            _, batch_preds = logits.max(1)
            pred_list.extend(batch_preds.cpu().numpy())
            label_list.extend(batch_labels.numpy())
            prob_list.extend(batch_probs.cpu().numpy())
    preds_arr = np.array(pred_list)
    labels_arr = np.array(label_list)
    probs_arr = np.array(prob_list)
    # Overall accuracy as a percentage.
    accuracy = 100.0 * (preds_arr == labels_arr).sum() / len(labels_arr)
    # Per-class precision/recall/F1 via sklearn, returned as a dict.
    report = classification_report(
        labels_arr, preds_arr,
        target_names=class_names,
        output_dict=True
    )
    conf_matrix = confusion_matrix(labels_arr, preds_arr)
    return {
        'accuracy': accuracy,
        'classification_report': report,
        'confusion_matrix': conf_matrix,
        'predictions': preds_arr,
        'labels': labels_arr,
        'probabilities': probs_arr
    }
def save_training_history(history: Dict, filepath: str) -> None:
    """
    Save training history to a JSON file.

    Creates parent directories as needed.

    Args:
        history: Training history dictionary (JSON-serializable values only)
        filepath: Path to save JSON file
    """
    Path(filepath).parent.mkdir(parents=True, exist_ok=True)
    # FIX: explicit encoding — the platform-default encoding is not portable
    # (ruff PLW1514 / W1514).
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(history, f, indent=2)
    print(f"Training history saved to {filepath}")
def load_checkpoint(checkpoint_path: str, model: nn.Module) -> Tuple[nn.Module, Dict]:
    """
    Load model weights from a training checkpoint.

    Args:
        checkpoint_path: Path to checkpoint file (as written by train_model)
        model: Model instance whose architecture matches the checkpoint

    Returns:
        Tuple of (loaded_model, checkpoint_dict); the checkpoint dict also
        carries 'epoch', 'optimizer_state_dict', 'val_loss', 'val_accuracy'.
    """
    # FIX (security): weights_only=True restricts unpickling to tensors and
    # primitive containers, preventing arbitrary code execution from an
    # untrusted checkpoint file. It is the default in PyTorch >= 2.6, and the
    # checkpoints written by this module contain only such data.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model, checkpoint