""" Test training pipeline with small subset of data. Verifies: - Training loop runs without errors - Validation metrics computed correctly - Model checkpoints saved - History tracking works """ import sys from pathlib import Path # Add project root to path project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) import torch from scripts.data_loader import MnistDataloader from scripts.preprocessing import MnistDataset, create_dataloaders, split_train_val from scripts.models import BaselineCNN from scripts.train import train_model, evaluate_model, save_training_history def main(): """Test training pipeline with small dataset.""" print("=" * 60) print("Testing Training Pipeline") print("=" * 60) print() # Load data (small subset for quick test) print("1. Loading MNIST data...") data_path = project_root / "data" / "raw" loader = MnistDataloader( str(data_path / "train-images.idx3-ubyte"), str(data_path / "train-labels.idx1-ubyte"), str(data_path / "t10k-images.idx3-ubyte"), str(data_path / "t10k-labels.idx1-ubyte") ) (x_train, y_train), (x_test, y_test) = loader.load_data() print(f"✓ Loaded {len(x_train):,} training samples") print() # Use small subset for quick test (1000 samples) print("2. Creating train/val split...") (x_train_split, y_train_split), (x_val, y_val) = split_train_val( x_train[:1000], y_train[:1000], val_split=0.2, random_seed=42 ) print(f"✓ Train: {len(x_train_split)} samples") print(f"✓ Val: {len(x_val)} samples") print() # Create datasets and loaders print("3. Creating datasets and loaders...") train_dataset = MnistDataset(x_train_split, y_train_split, transform=None) val_dataset = MnistDataset(x_val, y_val, transform=None) # Small test set for quick validation test_dataset = MnistDataset(x_test[:200], y_test[:200], transform=None) train_loader, val_loader = create_dataloaders( train_dataset, val_dataset, batch_size=32, num_workers=0 ) test_loader = torch.utils.data.DataLoader( test_dataset, batch_size=32, shuffle=False, num_workers=0 ) print(f"✓ Train batches: {len(train_loader)}") print(f"✓ Val batches: {len(val_loader)}") print(f"✓ Test batches: {len(test_loader)}") print() # Create model print("4. Creating model...") model = BaselineCNN() device = 'cuda' if torch.cuda.is_available() else 'cpu' print(f"✓ Model: {model.__class__.__name__}") print(f"✓ Device: {device}") print() # Train model (short run for testing) print("5. Training model (3 epochs for testing)...") print("-" * 60) history = train_model( model, train_loader, val_loader, num_epochs=3, learning_rate=0.001, patience=10, # Don't trigger early stopping in test checkpoint_dir='models', device=device, use_scheduler=True, verbose=True ) print("-" * 60) print() # Check checkpoints exist print("6. Verifying checkpoints...") best_model_path = project_root / "models" / "best_model.pt" last_model_path = project_root / "models" / "last_model.pt" assert best_model_path.exists(), "Best model checkpoint not found" assert last_model_path.exists(), "Last model checkpoint not found" print("✓ Best model saved") print("✓ Last model saved") print() # Save history print("7. Saving training history...") history_path = project_root / "experiments" / "test_training_history.json" save_training_history(history, str(history_path)) print() # Evaluate on test set print("8. Evaluating on test set...") results = evaluate_model(model, test_loader, device=device) print(f"✓ Test Accuracy: {results['accuracy']:.2f}%") print() # Print per-class metrics print("Per-class metrics:") report = results['classification_report'] for digit in range(10): if str(digit) in report: metrics = report[str(digit)] print(f" Digit {digit}: " f"Precision={metrics['precision']:.3f}, " f"Recall={metrics['recall']:.3f}, " f"F1={metrics['f1-score']:.3f}") print() # Summary print("=" * 60) print("✅ ALL TESTS PASSED") print("=" * 60) print("\nTraining pipeline is working correctly!") print(f"Final validation accuracy: {history['val_accuracy'][-1]:.2f}%") print(f"Test accuracy: {results['accuracy']:.2f}%") print("\nNote: These are quick test results with limited data.") print("For full training, use complete dataset and more epochs.") return 0 if __name__ == "__main__": sys.exit(main())