# Scraping artifact from the hosting page (Hugging Face Spaces header):
# "Spaces: Sleeping / Sleeping" -- kept as a comment so the file parses.
"""
MLflow-Enabled Training Script for MNIST CNN

Full training script with comprehensive MLflow tracking:
- Hyperparameters and model architecture
- Per-epoch metrics (loss, accuracy, learning rate)
- System information and environment
- Model artifacts and checkpoints
- Training visualizations
- Confusion matrix and classification report

Usage:
    python scripts/train_with_mlflow.py --epochs 20 --lr 0.001 --augment
    python scripts/train_with_mlflow.py --help
"""
import argparse
import json
import sys
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

import mlflow

# Add project root to path so the `scripts.*` imports below resolve when
# this file is run directly (must happen before those imports).
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from scripts.models import BaselineCNN, count_parameters
from scripts.preprocessing import MnistDataset, create_dataloaders, split_train_val
from scripts.train import train_epoch, validate, evaluate_model, save_training_history
from scripts.data_loader import MnistDataloader
from scripts.augmentation import get_train_augmentation
from scripts.mlflow_setup import (
    setup_mlflow, log_model_params, log_training_config,
    log_data_info, log_system_info, log_metrics_epoch,
    log_artifact_path
)
def train_with_mlflow(
    model: nn.Module,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    test_loader: torch.utils.data.DataLoader,
    config: dict,
    run_name: str = None
) -> dict:
    """
    Train model with full MLflow tracking.

    Logs configuration, per-epoch metrics, the best checkpoint, and
    test-set evaluation artifacts into a single MLflow run.

    Args:
        model: PyTorch model to train
        train_loader: Training data loader
        val_loader: Validation data loader
        test_loader: Test data loader
        config: Training configuration dictionary. Must contain 'device',
            'num_epochs' and 'learning_rate'; may contain 'augmentation'
            (bool, default False) and 'early_stopping_patience'
            (int, default 5).
        run_name: Optional name for MLflow run

    Returns:
        Training history dictionary with per-epoch lists:
        'train_loss', 'train_accuracy', 'val_loss', 'val_accuracy',
        'learning_rate'.
    """
    device = config['device']
    num_epochs = config['num_epochs']
    learning_rate = config['learning_rate']

    # Setup MLflow experiment, then open the run all logging goes into.
    setup_mlflow("mnist-digit-classification")

    with mlflow.start_run(run_name=run_name):
        print("\n" + "="*70)
        print(f"MLflow Run ID: {mlflow.active_run().info.run_id}")
        print("="*70 + "\n")

        # Log all configuration up front so the run is identifiable even
        # if training is interrupted.
        print("Logging configuration to MLflow...")
        log_training_config(config)
        log_model_params(model)
        log_data_info(
            train_size=len(train_loader.dataset),
            val_size=len(val_loader.dataset),
            test_size=len(test_loader.dataset),
            num_classes=10,
            augmentation=config.get('augmentation', False)
        )
        log_system_info()

        # Log model architecture as text (computed once and reused for the
        # progress banner below).
        total_params, trainable_params = count_parameters(model)
        model_summary = f"""
Model: {model.__class__.__name__}
Total Parameters: {total_params:,}
Trainable Parameters: {trainable_params:,}
Device: {device}

Architecture:
{str(model)}
"""
        mlflow.log_text(model_summary, "model_architecture.txt")

        # Setup training.
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        # NOTE: `verbose=True` dropped -- deprecated in PyTorch 2.x and
        # removed from ReduceLROnPlateau in recent releases.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', patience=3, factor=0.5
        )

        # Training history (returned to the caller and saved as an artifact).
        history = {
            'train_loss': [],
            'train_accuracy': [],
            'val_loss': [],
            'val_accuracy': [],
            'learning_rate': []
        }
        best_val_loss = float('inf')
        best_epoch = 0  # stays 0 if the loop never runs (num_epochs == 0)
        # Keep early stopping consistent with the logged config instead of
        # re-hard-coding the constant here.
        patience = config.get('early_stopping_patience', 5)
        patience_counter = 0

        print(f"\nStarting training for {num_epochs} epochs...")
        print(f"Device: {device}")
        print(f"Model: {model.__class__.__name__} ({total_params:,} parameters)")
        print("-" * 70)

        for epoch in range(num_epochs):
            # Train one epoch, then validate.
            train_metrics = train_epoch(
                model, train_loader, criterion, optimizer, device
            )
            val_metrics = validate(model, val_loader, criterion, device)

            # Capture the LR *before* the scheduler may reduce it, so the
            # logged value is the one this epoch actually trained with.
            current_lr = optimizer.param_groups[0]['lr']
            scheduler.step(val_metrics['loss'])

            # Save history.
            history['train_loss'].append(train_metrics['loss'])
            history['train_accuracy'].append(train_metrics['accuracy'])
            history['val_loss'].append(val_metrics['loss'])
            history['val_accuracy'].append(val_metrics['accuracy'])
            history['learning_rate'].append(current_lr)

            # Log per-epoch metrics to MLflow.
            mlflow_metrics = {
                'train_loss': train_metrics['loss'],
                'train_accuracy': train_metrics['accuracy'],
                'val_loss': val_metrics['loss'],
                'val_accuracy': val_metrics['accuracy'],
                'learning_rate': current_lr,
                'epoch': epoch + 1
            }
            log_metrics_epoch(mlflow_metrics, step=epoch)

            # Print progress.
            print(
                f"Epoch {epoch+1}/{num_epochs} | "
                f"Train Loss: {train_metrics['loss']:.4f} "
                f"({train_metrics['accuracy']:.2f}%) | "
                f"Val Loss: {val_metrics['loss']:.4f} "
                f"({val_metrics['accuracy']:.2f}%) | "
                f"LR: {current_lr:.6f}"
            )

            # Track the best model by validation loss.
            if val_metrics['loss'] < best_val_loss:
                best_val_loss = val_metrics['loss']
                best_epoch = epoch + 1
                patience_counter = 0

                # Save checkpoint (create models/ if it does not exist yet,
                # otherwise torch.save fails with FileNotFoundError).
                checkpoint_path = project_root / 'models' / 'best_model_mlflow.pt'
                checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_loss': train_metrics['loss'],
                    'val_loss': val_metrics['loss'],
                    'val_accuracy': val_metrics['accuracy'],
                }, checkpoint_path)
                print(f" → New best model! (Val Loss: {best_val_loss:.4f})")

                # Log model to MLflow and the model registry.
                mlflow.pytorch.log_model(
                    model,
                    "model",
                    registered_model_name="mnist-cnn-baseline"
                )
            else:
                patience_counter += 1

            # Early stopping.
            if patience_counter >= patience:
                print(f"\nEarly stopping triggered after {epoch+1} epochs")
                mlflow.log_param("early_stopped", True)
                mlflow.log_param("early_stop_epoch", epoch + 1)
                break

        print("-" * 70)
        print("\nTraining complete!")
        print(f"Best epoch: {best_epoch} (Val Loss: {best_val_loss:.4f})")

        # Log best metrics (guard: history is empty when num_epochs == 0).
        if history['train_loss']:
            mlflow.log_metrics({
                'best_epoch': best_epoch,
                'best_val_loss': best_val_loss,
                'final_train_loss': history['train_loss'][-1],
                'final_val_loss': history['val_loss'][-1]
            })

        # Evaluate on test set.
        print("\nEvaluating on test set...")
        test_metrics = evaluate_model(model, test_loader, device)
        test_accuracy = test_metrics['accuracy']
        test_report = test_metrics['classification_report']

        # Extract macro average metrics from the sklearn-style report dict.
        test_precision = test_report['macro avg']['precision']
        test_recall = test_report['macro avg']['recall']
        test_f1_score = test_report['macro avg']['f1-score']

        print(f"Test Accuracy: {test_accuracy:.2f}%")
        print(f"Test Precision: {test_precision:.4f}")
        print(f"Test Recall: {test_recall:.4f}")
        print(f"Test F1-Score: {test_f1_score:.4f}")

        # Log test metrics to MLflow.
        mlflow.log_metrics({
            'test_accuracy': test_accuracy,
            'test_precision': test_precision,
            'test_recall': test_recall,
            'test_f1_score': test_f1_score
        })

        # Save and log artifacts.
        print("\nSaving artifacts...")

        # Save history.
        history_path = project_root / 'results' / 'mlflow_training_history.json'
        history_path.parent.mkdir(exist_ok=True)
        save_training_history(history, history_path)
        log_artifact_path(str(history_path))

        # Save test metrics.
        metrics_to_save = {
            'test_accuracy': test_accuracy,
            'test_precision': test_precision,
            'test_recall': test_recall,
            'test_f1_score': test_f1_score,
            'classification_report': test_report,
            'confusion_matrix': test_metrics['confusion_matrix'].tolist()
        }
        metrics_path = project_root / 'results' / 'mlflow_test_metrics.json'
        with open(metrics_path, 'w') as f:
            json.dump(metrics_to_save, f, indent=2)
        log_artifact_path(str(metrics_path))

        # Save model checkpoint.
        log_artifact_path(str(project_root / 'models' / 'best_model_mlflow.pt'))

        # Log confusion matrix as JSON (one key per true-class row).
        conf_matrix_dict = {
            f"row_{i}": test_metrics['confusion_matrix'][i].tolist()
            for i in range(len(test_metrics['confusion_matrix']))
        }
        mlflow.log_dict(conf_matrix_dict, "confusion_matrix.json")

        # Log classification report.
        mlflow.log_dict(test_report, "classification_report.json")

        print("\n✓ All artifacts logged to MLflow")
        print("View results: mlflow ui --backend-store-uri file:./mlruns")

        return history
def main():
    """Parse CLI options, prepare the MNIST datasets, and launch training."""
    parser = argparse.ArgumentParser(
        description='Train MNIST CNN with MLflow tracking'
    )
    parser.add_argument('--epochs', type=int, default=20,
                        help='Number of epochs (default: 20)')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Learning rate (default: 0.001)')
    parser.add_argument('--batch-size', type=int, default=64,
                        help='Batch size (default: 64)')
    parser.add_argument('--augment', action='store_true',
                        help='Use data augmentation')
    parser.add_argument('--run-name', type=str, default=None,
                        help='MLflow run name')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed (default: 42)')
    opts = parser.parse_args()

    # Seed every RNG in play for reproducibility.
    torch.manual_seed(opts.seed)
    np.random.seed(opts.seed)
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.manual_seed(opts.seed)

    # Assemble the run configuration that gets logged to MLflow.
    config = {
        'num_epochs': opts.epochs,
        'learning_rate': opts.lr,
        'batch_size': opts.batch_size,
        'augmentation': opts.augment,
        'random_seed': opts.seed,
        'device': 'cuda' if use_cuda else 'cpu',
        'optimizer': 'Adam',
        'scheduler': 'ReduceLROnPlateau',
        'early_stopping_patience': 5
    }
    print("Training Configuration:")
    print(json.dumps(config, indent=2))

    # Read the raw IDX files shipped under data/raw.
    print("\nLoading MNIST data...")
    raw_dir = project_root / 'data' / 'raw'
    mnist = MnistDataloader(
        training_images_filepath=str(raw_dir / 'train-images.idx3-ubyte'),
        training_labels_filepath=str(raw_dir / 'train-labels.idx1-ubyte'),
        test_images_filepath=str(raw_dir / 't10k-images.idx3-ubyte'),
        test_labels_filepath=str(raw_dir / 't10k-labels.idx1-ubyte')
    )
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    # Carve a validation split out of the training set.
    (x_tr, y_tr), (x_val, y_val) = split_train_val(
        x_train, y_train, val_split=0.15, random_seed=opts.seed
    )

    # Augmentation applies to the training split only.
    train_tf = get_train_augmentation() if opts.augment else None
    train_ds = MnistDataset(x_tr, y_tr, transform=train_tf)
    val_ds = MnistDataset(x_val, y_val, transform=None)
    test_ds = MnistDataset(x_test, y_test, transform=None)

    # Build the three data loaders.
    train_loader, val_loader = create_dataloaders(
        train_ds, val_ds, batch_size=opts.batch_size, num_workers=2
    )
    test_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=opts.batch_size, shuffle=False, num_workers=2
    )
    print(f"Train: {len(train_loader.dataset)} samples")
    print(f"Val: {len(val_loader.dataset)} samples")
    print(f"Test: {len(test_loader.dataset)} samples")

    # Instantiate the model on the chosen device and run tracked training.
    model = BaselineCNN().to(config['device'])
    train_with_mlflow(
        model, train_loader, val_loader, test_loader,
        config, run_name=opts.run_name
    )

    print("\n" + "="*70)
    print("Training complete! View MLflow dashboard:")
    print(" ./scripts/launch_mlflow_ui.sh")
    print(" or: mlflow ui --backend-store-uri file:./mlruns")
    print("="*70)


if __name__ == '__main__':
    main()