Spaces:
Sleeping
Sleeping
"""
Test training pipeline with small subset of data.
Verifies:
- Training loop runs without errors
- Validation metrics computed correctly
- Model checkpoints saved
- History tracking works
"""
import sys
from pathlib import Path

# Add project root to path so the `scripts.*` imports below resolve when
# this file is executed directly (it lives one level under the project root).
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

import torch

# Project-local modules (resolvable only after the sys.path insert above).
from scripts.data_loader import MnistDataloader
from scripts.preprocessing import MnistDataset, create_dataloaders, split_train_val
from scripts.models import BaselineCNN
from scripts.train import train_model, evaluate_model, save_training_history
def main():
    """Run a quick end-to-end smoke test of the training pipeline.

    Uses a 1,000-sample subset of MNIST (and 200 test samples) so the
    whole script finishes quickly while still exercising: data loading,
    train/val splitting, DataLoader creation, the training loop,
    checkpointing, history serialization, and test-set evaluation.

    Returns:
        int: 0 on success; assertions fire if checkpoints are missing.
    """
    # NOTE(review): the "✓" glyphs below were mojibake'd to "β" in the
    # scraped copy of this file; restored to the intended check marks.
    print("=" * 60)
    print("Testing Training Pipeline")
    print("=" * 60)
    print()

    # Load data (the full IDX files are read; we subsample below)
    print("1. Loading MNIST data...")
    data_path = project_root / "data" / "raw"
    loader = MnistDataloader(
        str(data_path / "train-images.idx3-ubyte"),
        str(data_path / "train-labels.idx1-ubyte"),
        str(data_path / "t10k-images.idx3-ubyte"),
        str(data_path / "t10k-labels.idx1-ubyte")
    )
    (x_train, y_train), (x_test, y_test) = loader.load_data()
    print(f"✓ Loaded {len(x_train):,} training samples")
    print()

    # Use small subset for quick test (1000 samples)
    print("2. Creating train/val split...")
    (x_train_split, y_train_split), (x_val, y_val) = split_train_val(
        x_train[:1000], y_train[:1000], val_split=0.2, random_seed=42
    )
    print(f"✓ Train: {len(x_train_split)} samples")
    print(f"✓ Val: {len(x_val)} samples")
    print()

    # Create datasets and loaders
    print("3. Creating datasets and loaders...")
    train_dataset = MnistDataset(x_train_split, y_train_split, transform=None)
    val_dataset = MnistDataset(x_val, y_val, transform=None)
    # Small test set for quick validation
    test_dataset = MnistDataset(x_test[:200], y_test[:200], transform=None)
    train_loader, val_loader = create_dataloaders(
        train_dataset, val_dataset, batch_size=32, num_workers=0
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=32, shuffle=False, num_workers=0
    )
    print(f"✓ Train batches: {len(train_loader)}")
    print(f"✓ Val batches: {len(val_loader)}")
    print(f"✓ Test batches: {len(test_loader)}")
    print()

    # Create model
    print("4. Creating model...")
    model = BaselineCNN()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"✓ Model: {model.__class__.__name__}")
    print(f"✓ Device: {device}")
    print()

    # Train model (short run for testing)
    print("5. Training model (3 epochs for testing)...")
    print("-" * 60)
    # BUG FIX: checkpoint_dir was the cwd-relative string 'models', but the
    # existence assertions below check project_root / "models". Those only
    # agreed when the script was launched from the project root. Passing the
    # absolute path makes the checkpoints land where the assertions look,
    # regardless of the working directory.
    history = train_model(
        model,
        train_loader,
        val_loader,
        num_epochs=3,
        learning_rate=0.001,
        patience=10,  # Don't trigger early stopping in test
        checkpoint_dir=str(project_root / "models"),
        device=device,
        use_scheduler=True,
        verbose=True
    )
    print("-" * 60)
    print()

    # Check checkpoints exist
    print("6. Verifying checkpoints...")
    best_model_path = project_root / "models" / "best_model.pt"
    last_model_path = project_root / "models" / "last_model.pt"
    assert best_model_path.exists(), "Best model checkpoint not found"
    assert last_model_path.exists(), "Last model checkpoint not found"
    print("✓ Best model saved")
    print("✓ Last model saved")
    print()

    # Save history
    print("7. Saving training history...")
    history_path = project_root / "experiments" / "test_training_history.json"
    # Ensure the target directory exists — save_training_history may not
    # create it (TODO confirm), and a fresh checkout has no experiments/.
    history_path.parent.mkdir(parents=True, exist_ok=True)
    save_training_history(history, str(history_path))
    print()

    # Evaluate on test set
    print("8. Evaluating on test set...")
    results = evaluate_model(model, test_loader, device=device)
    print(f"✓ Test Accuracy: {results['accuracy']:.2f}%")
    print()

    # Per-class metrics come from a sklearn-style classification report
    # keyed by the digit as a string (presumably; verify against
    # evaluate_model's implementation).
    print("Per-class metrics:")
    report = results['classification_report']
    for digit in range(10):
        if str(digit) in report:
            metrics = report[str(digit)]
            print(f"  Digit {digit}: "
                  f"Precision={metrics['precision']:.3f}, "
                  f"Recall={metrics['recall']:.3f}, "
                  f"F1={metrics['f1-score']:.3f}")
    print()

    # Summary
    print("=" * 60)
    print("✓ ALL TESTS PASSED")
    print("=" * 60)
    print("\nTraining pipeline is working correctly!")
    print(f"Final validation accuracy: {history['val_accuracy'][-1]:.2f}%")
    print(f"Test accuracy: {results['accuracy']:.2f}%")
    print("\nNote: These are quick test results with limited data.")
    print("For full training, use complete dataset and more epochs.")
    return 0


if __name__ == "__main__":
    sys.exit(main())