# Source: mnist-digit-classifier/scripts/train_with_mlflow.py
# Author: faizan
# Commit e77a25a — "fix: resolve all 468 ruff linting errors (code quality enforcement complete)"
"""
MLflow-Enabled Training Script for MNIST CNN
Full training script with comprehensive MLflow tracking:
- Hyperparameters and model architecture
- Per-epoch metrics (loss, accuracy, learning rate)
- System information and environment
- Model artifacts and checkpoints
- Training visualizations
- Confusion matrix and classification report
Usage:
python scripts/train_with_mlflow.py --epochs 20 --lr 0.001 --augment
python scripts/train_with_mlflow.py --help
"""
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from pathlib import Path
import json
import sys
import numpy as np
import mlflow
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from scripts.models import BaselineCNN, count_parameters
from scripts.preprocessing import MnistDataset, create_dataloaders, split_train_val
from scripts.train import train_epoch, validate, evaluate_model, save_training_history
from scripts.data_loader import MnistDataloader
from scripts.augmentation import get_train_augmentation
from scripts.mlflow_setup import (
setup_mlflow, log_model_params, log_training_config,
log_data_info, log_system_info, log_metrics_epoch,
log_artifact_path
)
def train_with_mlflow(
    model: nn.Module,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    test_loader: torch.utils.data.DataLoader,
    config: dict,
    run_name: str = None
) -> dict:
    """
    Train model with full MLflow tracking.

    Runs the train/validate loop with ReduceLROnPlateau scheduling and
    early stopping, checkpoints the best model (lowest validation loss),
    evaluates on the test set, and logs hyperparameters, per-epoch
    metrics, and artifacts to MLflow.

    Args:
        model: PyTorch model to train
        train_loader: Training data loader
        val_loader: Validation data loader
        test_loader: Test data loader
        config: Training configuration dictionary; must contain the keys
            'device', 'num_epochs', and 'learning_rate'
        run_name: Optional name for MLflow run

    Returns:
        Training history dictionary with per-epoch 'train_loss',
        'train_accuracy', 'val_loss', 'val_accuracy', 'learning_rate' lists
    """
    device = config['device']
    num_epochs = config['num_epochs']
    learning_rate = config['learning_rate']

    # Point MLflow at the project experiment before opening the run
    setup_mlflow("mnist-digit-classification")

    # Everything below is logged against this single MLflow run
    with mlflow.start_run(run_name=run_name):
        print("\n" + "="*70)
        print(f"MLflow Run ID: {mlflow.active_run().info.run_id}")
        print("="*70 + "\n")

        # Log configuration, data sizes, and host environment up front
        print("Logging configuration to MLflow...")
        log_training_config(config)
        log_model_params(model)
        log_data_info(
            train_size=len(train_loader.dataset),
            val_size=len(val_loader.dataset),
            test_size=len(test_loader.dataset),
            num_classes=10,
            augmentation=config.get('augmentation', False)
        )
        log_system_info()

        # Log model architecture as a text artifact
        total_params, trainable_params = count_parameters(model)
        model_summary = f"""
Model: {model.__class__.__name__}
Total Parameters: {total_params:,}
Trainable Parameters: {trainable_params:,}
Device: {device}
Architecture:
{str(model)}
"""
        mlflow.log_text(model_summary, "model_architecture.txt")

        # Loss, optimizer, and LR scheduler.
        # NOTE: the 'verbose' kwarg of ReduceLROnPlateau was deprecated and
        # removed in recent PyTorch releases; LR is logged to MLflow each
        # epoch instead, so nothing is lost by dropping it.
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', patience=3, factor=0.5
        )

        # Per-epoch training history (also saved as a JSON artifact below)
        history = {
            'train_loss': [],
            'train_accuracy': [],
            'val_loss': [],
            'val_accuracy': [],
            'learning_rate': []
        }
        best_val_loss = float('inf')
        best_epoch = 0  # initialized here so the summary below never hits a NameError
        patience = 5
        patience_counter = 0

        print(f"\nStarting training for {num_epochs} epochs...")
        print(f"Device: {device}")
        total_p, _ = count_parameters(model)
        print(f"Model: {model.__class__.__name__} ({total_p:,} parameters)")
        print("-" * 70)

        # Ensure the checkpoint directory exists before the first torch.save
        # (results/ gets a mkdir later, but models/ previously did not)
        checkpoint_path = project_root / 'models' / 'best_model_mlflow.pt'
        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

        for epoch in range(num_epochs):
            # One pass over the training set
            train_metrics = train_epoch(
                model, train_loader, criterion, optimizer, device
            )
            # Validation pass (no gradient updates)
            val_metrics = validate(model, val_loader, criterion, device)

            # Read the LR before scheduler.step() so the value logged is the
            # one actually used for this epoch
            current_lr = optimizer.param_groups[0]['lr']
            scheduler.step(val_metrics['loss'])

            # Accumulate local history
            history['train_loss'].append(train_metrics['loss'])
            history['train_accuracy'].append(train_metrics['accuracy'])
            history['val_loss'].append(val_metrics['loss'])
            history['val_accuracy'].append(val_metrics['accuracy'])
            history['learning_rate'].append(current_lr)

            # Mirror the same metrics into MLflow, stepped by epoch index
            mlflow_metrics = {
                'train_loss': train_metrics['loss'],
                'train_accuracy': train_metrics['accuracy'],
                'val_loss': val_metrics['loss'],
                'val_accuracy': val_metrics['accuracy'],
                'learning_rate': current_lr,
                'epoch': epoch + 1
            }
            log_metrics_epoch(mlflow_metrics, step=epoch)

            print(
                f"Epoch {epoch+1}/{num_epochs} | "
                f"Train Loss: {train_metrics['loss']:.4f} "
                f"({train_metrics['accuracy']:.2f}%) | "
                f"Val Loss: {val_metrics['loss']:.4f} "
                f"({val_metrics['accuracy']:.2f}%) | "
                f"LR: {current_lr:.6f}"
            )

            # Checkpoint whenever validation loss improves
            if val_metrics['loss'] < best_val_loss:
                best_val_loss = val_metrics['loss']
                best_epoch = epoch + 1
                patience_counter = 0
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_loss': train_metrics['loss'],
                    'val_loss': val_metrics['loss'],
                    'val_accuracy': val_metrics['accuracy'],
                }, checkpoint_path)
                print(f" → New best model! (Val Loss: {best_val_loss:.4f})")
                # NOTE(review): this registers a new model version on every
                # improvement epoch; acceptable for this project's scale, but
                # a single post-training registration would be cheaper
                mlflow.pytorch.log_model(
                    model,
                    "model",
                    registered_model_name="mnist-cnn-baseline"
                )
            else:
                patience_counter += 1

            # Early stopping once validation loss has stalled for `patience` epochs
            if patience_counter >= patience:
                print(f"\nEarly stopping triggered after {epoch+1} epochs")
                mlflow.log_param("early_stopped", True)
                mlflow.log_param("early_stop_epoch", epoch + 1)
                break

        print("-" * 70)
        print("\nTraining complete!")
        print(f"Best epoch: {best_epoch} (Val Loss: {best_val_loss:.4f})")

        # Summary metrics for the whole run
        mlflow.log_metrics({
            'best_epoch': best_epoch,
            'best_val_loss': best_val_loss,
            'final_train_loss': history['train_loss'][-1],
            'final_val_loss': history['val_loss'][-1]
        })

        # Final held-out evaluation (note: evaluates the model's current
        # weights, i.e. the last epoch, not the checkpointed best)
        print("\nEvaluating on test set...")
        test_metrics = evaluate_model(model, test_loader, device)
        test_accuracy = test_metrics['accuracy']
        test_report = test_metrics['classification_report']

        # Macro averages summarize per-class precision/recall/F1
        test_precision = test_report['macro avg']['precision']
        test_recall = test_report['macro avg']['recall']
        test_f1_score = test_report['macro avg']['f1-score']

        print(f"Test Accuracy: {test_accuracy:.2f}%")
        print(f"Test Precision: {test_precision:.4f}")
        print(f"Test Recall: {test_recall:.4f}")
        print(f"Test F1-Score: {test_f1_score:.4f}")

        mlflow.log_metrics({
            'test_accuracy': test_accuracy,
            'test_precision': test_precision,
            'test_recall': test_recall,
            'test_f1_score': test_f1_score
        })

        # Persist artifacts locally and attach them to the MLflow run
        print("\nSaving artifacts...")
        history_path = project_root / 'results' / 'mlflow_training_history.json'
        history_path.parent.mkdir(exist_ok=True)
        save_training_history(history, history_path)
        log_artifact_path(str(history_path))

        # Test metrics, including full report and confusion matrix
        metrics_to_save = {
            'test_accuracy': test_accuracy,
            'test_precision': test_precision,
            'test_recall': test_recall,
            'test_f1_score': test_f1_score,
            'classification_report': test_report,
            'confusion_matrix': test_metrics['confusion_matrix'].tolist()
        }
        metrics_path = project_root / 'results' / 'mlflow_test_metrics.json'
        with open(metrics_path, 'w', encoding='utf-8') as f:
            json.dump(metrics_to_save, f, indent=2)
        log_artifact_path(str(metrics_path))

        # Best-model checkpoint written during the loop above
        log_artifact_path(str(checkpoint_path))

        # Confusion matrix as a row-keyed dict (easier to eyeball in the UI)
        conf_matrix_dict = {
            f"row_{i}": test_metrics['confusion_matrix'][i].tolist()
            for i in range(len(test_metrics['confusion_matrix']))
        }
        mlflow.log_dict(conf_matrix_dict, "confusion_matrix.json")
        mlflow.log_dict(test_report, "classification_report.json")

        print("\n✓ All artifacts logged to MLflow")
        print("View results: mlflow ui --backend-store-uri file:./mlruns")

        return history
def main():
    """Parse CLI flags, seed RNGs, build the MNIST data pipeline, and train."""
    arg_parser = argparse.ArgumentParser(
        description='Train MNIST CNN with MLflow tracking'
    )
    arg_parser.add_argument(
        '--epochs', type=int, default=20,
        help='Number of epochs (default: 20)'
    )
    arg_parser.add_argument(
        '--lr', type=float, default=0.001,
        help='Learning rate (default: 0.001)'
    )
    arg_parser.add_argument(
        '--batch-size', type=int, default=64,
        help='Batch size (default: 64)'
    )
    arg_parser.add_argument(
        '--augment', action='store_true',
        help='Use data augmentation'
    )
    arg_parser.add_argument(
        '--run-name', type=str, default=None,
        help='MLflow run name'
    )
    arg_parser.add_argument(
        '--seed', type=int, default=42,
        help='Random seed (default: 42)'
    )
    cli_args = arg_parser.parse_args()

    # Seed every RNG in play so runs are reproducible
    torch.manual_seed(cli_args.seed)
    np.random.seed(cli_args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(cli_args.seed)

    # Assemble the run configuration passed through to train_with_mlflow
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    config = {
        'num_epochs': cli_args.epochs,
        'learning_rate': cli_args.lr,
        'batch_size': cli_args.batch_size,
        'augmentation': cli_args.augment,
        'random_seed': cli_args.seed,
        'device': device_name,
        'optimizer': 'Adam',
        'scheduler': 'ReduceLROnPlateau',
        'early_stopping_patience': 5
    }
    print("Training Configuration:")
    print(json.dumps(config, indent=2))

    # Read the raw IDX files from data/raw
    print("\nLoading MNIST data...")
    raw_dir = project_root / 'data' / 'raw'
    mnist_loader = MnistDataloader(
        training_images_filepath=str(raw_dir / 'train-images.idx3-ubyte'),
        training_labels_filepath=str(raw_dir / 'train-labels.idx1-ubyte'),
        test_images_filepath=str(raw_dir / 't10k-images.idx3-ubyte'),
        test_labels_filepath=str(raw_dir / 't10k-labels.idx1-ubyte')
    )
    (x_train, y_train), (x_test, y_test) = mnist_loader.load_data()

    # Carve a validation split out of the training set
    (x_tr, y_tr), (x_val, y_val) = split_train_val(
        x_train, y_train, val_split=0.15, random_seed=cli_args.seed
    )

    # Augmentation (if requested) applies to the training set only
    train_transform = get_train_augmentation() if cli_args.augment else None
    train_dataset = MnistDataset(x_tr, y_tr, transform=train_transform)
    val_dataset = MnistDataset(x_val, y_val, transform=None)
    test_dataset = MnistDataset(x_test, y_test, transform=None)

    # Wrap datasets in loaders; test loader keeps natural order for evaluation
    train_loader, val_loader = create_dataloaders(
        train_dataset, val_dataset, batch_size=cli_args.batch_size, num_workers=2
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=cli_args.batch_size, shuffle=False, num_workers=2
    )
    print(f"Train: {len(train_loader.dataset)} samples")
    print(f"Val: {len(val_loader.dataset)} samples")
    print(f"Test: {len(test_loader.dataset)} samples")

    # Instantiate the model on the selected device and run tracked training
    model = BaselineCNN().to(config['device'])
    train_with_mlflow(
        model, train_loader, val_loader, test_loader,
        config, run_name=cli_args.run_name
    )

    print("\n" + "="*70)
    print("Training complete! View MLflow dashboard:")
    print(" ./scripts/launch_mlflow_ui.sh")
    print(" or: mlflow ui --backend-store-uri file:./mlruns")
    print("="*70)
# Script entry point: only run training when executed directly, not on import.
if __name__ == '__main__':
    main()