""" Test preprocessing pipeline with MNIST data. Verifies: - Dataset initialization - Normalization to [0, 1] - Correct tensor shapes - DataLoader batching - Train/val split """ import sys from pathlib import Path # Add project root to path project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) import torch from scripts.data_loader import MnistDataloader from scripts.preprocessing import ( MnistDataset, create_dataloaders, create_test_dataloader, split_train_val, get_dataset_statistics ) def test_dataset(): """Test MnistDataset class.""" print("=" * 60) print("TEST 1: MnistDataset Initialization and Indexing") print("=" * 60) # Load data data_path = project_root / "data" / "raw" loader = MnistDataloader( str(data_path / "train-images.idx3-ubyte"), str(data_path / "train-labels.idx1-ubyte"), str(data_path / "t10k-images.idx3-ubyte"), str(data_path / "t10k-labels.idx1-ubyte") ) (x_train, y_train), (x_test, y_test) = loader.load_data() # Create dataset (small subset for testing) dataset = MnistDataset(x_train[:1000], y_train[:1000]) print(f"✓ Dataset created with {len(dataset)} samples") # Test __getitem__ image, label = dataset[0] print("✓ Retrieved sample 0") print(f" Image shape: {image.shape}") print(f" Image dtype: {image.dtype}") print(f" Image range: [{image.min():.4f}, {image.max():.4f}]") print(f" Label: {label.item()} (dtype: {label.dtype})") # Verify normalization assert image.shape == (1, 28, 28), f"Wrong shape: {image.shape}" assert image.dtype == torch.float32, f"Wrong dtype: {image.dtype}" assert 0 <= image.min() <= 1, f"Values not in [0, 1]: min={image.min()}" assert 0 <= image.max() <= 1, f"Values not in [0, 1]: max={image.max()}" assert label.dtype == torch.long, f"Label wrong dtype: {label.dtype}" print("✓ All assertions passed") print() return dataset def test_dataloader(dataset): """Test DataLoader batching.""" print("=" * 60) print("TEST 2: DataLoader Batching") print("=" * 60) loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True) print("✓ DataLoader created (batch_size=32)") # Get first batch images, labels = next(iter(loader)) print("✓ Retrieved first batch") print(f" Batch images shape: {images.shape}") print(f" Batch labels shape: {labels.shape}") print(f" Images dtype: {images.dtype}") print(f" Labels dtype: {labels.dtype}") # Verify batch dimensions assert images.shape == (32, 1, 28, 28), f"Wrong batch shape: {images.shape}" assert labels.shape == (32,), f"Wrong labels shape: {labels.shape}" assert images.dtype == torch.float32 assert labels.dtype == torch.long print("✓ All assertions passed") print() def test_train_val_split(): """Test train/validation split.""" print("=" * 60) print("TEST 3: Train/Validation Split") print("=" * 60) # Load data data_path = project_root / "data" / "raw" loader = MnistDataloader( str(data_path / "train-images.idx3-ubyte"), str(data_path / "train-labels.idx1-ubyte"), str(data_path / "t10k-images.idx3-ubyte"), str(data_path / "t10k-labels.idx1-ubyte") ) (x_train, y_train), _ = loader.load_data() # Split (x_train_split, y_train_split), (x_val, y_val) = split_train_val( x_train, y_train, val_split=0.15, random_seed=42 ) print("✓ Split completed") print(f" Original training: {len(x_train):,} samples") train_pct = len(x_train_split) / len(x_train) * 100 print(f" New training: {len(x_train_split):,} samples ({train_pct:.1f}%)") val_pct = len(x_val) / len(x_train) * 100 print(f" Validation: {len(x_val):,} samples ({val_pct:.1f}%)") # Verify split ratio expected_val_size = int(len(x_train) * 0.15) assert abs(len(x_val) - expected_val_size) < 100, "Split ratio incorrect" assert len(x_train_split) + len(x_val) == len(x_train), "Data loss during split" print("✓ Split ratio correct") # Check stratification (class balance) from collections import Counter train_counts = Counter(y_train_split) val_counts = Counter(y_val) print("\n Class distribution in training set:") for digit in range(10): print(f" Digit {digit}: {train_counts[digit]:>5,} samples") print("\n Class distribution in validation set:") for digit in range(10): print(f" Digit {digit}: {val_counts[digit]:>4,} samples") # Verify each class is present in both sets assert all(train_counts[i] > 0 for i in range(10)), "Missing class in train" assert all(val_counts[i] > 0 for i in range(10)), "Missing class in validation" print("\n✓ All classes present in both sets") print() def test_full_pipeline(): """Test complete pipeline from data loading to batching.""" print("=" * 60) print("TEST 4: Full Pipeline") print("=" * 60) # Load data data_path = project_root / "data" / "raw" loader = MnistDataloader( str(data_path / "train-images.idx3-ubyte"), str(data_path / "train-labels.idx1-ubyte"), str(data_path / "t10k-images.idx3-ubyte"), str(data_path / "t10k-labels.idx1-ubyte") ) (x_train, y_train), (x_test, y_test) = loader.load_data() print("✓ Data loaded") # Split train/val (x_train_split, y_train_split), (x_val, y_val) = split_train_val( x_train, y_train, val_split=0.15 ) print("✓ Train/val split completed") # Create datasets train_dataset = MnistDataset(x_train_split, y_train_split) val_dataset = MnistDataset(x_val, y_val) test_dataset = MnistDataset(x_test, y_test) print("✓ Datasets created") # Get statistics train_stats = get_dataset_statistics(train_dataset) print("\n Training dataset statistics:") print(f" Samples: {train_stats['num_samples']:,}") print(f" Image shape: {train_stats['sample_image_shape']}") print(f" Image dtype: {train_stats['sample_image_dtype']}") print(f" Image range: {train_stats['sample_image_range']}") print(f" Label dtype: {train_stats['sample_label_dtype']}") # Create dataloaders train_loader, val_loader = create_dataloaders( train_dataset, val_dataset, batch_size=64, num_workers=0 ) test_loader = create_test_dataloader(test_dataset, batch_size=64, num_workers=0) print("\n✓ DataLoaders created") print(f" Training batches: {len(train_loader)}") print(f" Validation batches: {len(val_loader)}") print(f" Test batches: {len(test_loader)}") # Test iteration train_batch = next(iter(train_loader)) val_batch = next(iter(val_loader)) test_batch = next(iter(test_loader)) print("\n✓ Successfully iterated through all loaders") print(f" Train batch shapes: {train_batch[0].shape}, {train_batch[1].shape}") print(f" Val batch shapes: {val_batch[0].shape}, {val_batch[1].shape}") print(f" Test batch shapes: {test_batch[0].shape}, {test_batch[1].shape}") print() def main(): """Run all tests.""" print("\nTesting MNIST Preprocessing Pipeline") print() try: dataset = test_dataset() test_dataloader(dataset) test_train_val_split() test_full_pipeline() print("=" * 60) print("✅ ALL TESTS PASSED") print("=" * 60) print("\nPreprocessing pipeline is ready for model training!") return 0 except Exception as e: print(f"\n❌ TEST FAILED: {e}") import traceback traceback.print_exc() return 1 if __name__ == "__main__": sys.exit(main())