Spaces:
Sleeping
Sleeping
| """ | |
| Test preprocessing pipeline with MNIST data. | |
| Verifies: | |
| - Dataset initialization | |
| - Normalization to [0, 1] | |
| - Correct tensor shapes | |
| - DataLoader batching | |
| - Train/val split | |
| """ | |
| import sys | |
| from pathlib import Path | |
| # Add project root to path | |
| project_root = Path(__file__).parent.parent | |
| sys.path.insert(0, str(project_root)) | |
| import torch | |
| from scripts.data_loader import MnistDataloader | |
| from scripts.preprocessing import ( | |
| MnistDataset, | |
| create_dataloaders, | |
| create_test_dataloader, | |
| split_train_val, | |
| get_dataset_statistics | |
| ) | |
def test_dataset():
    """Test MnistDataset class."""
    banner = "=" * 60
    print(banner)
    print("TEST 1: MnistDataset Initialization and Indexing")
    print(banner)

    # Load the raw IDX files from disk.
    data_path = project_root / "data" / "raw"
    mnist = MnistDataloader(
        str(data_path / "train-images.idx3-ubyte"),
        str(data_path / "train-labels.idx1-ubyte"),
        str(data_path / "t10k-images.idx3-ubyte"),
        str(data_path / "t10k-labels.idx1-ubyte"),
    )
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    # Small subset keeps the test fast.
    dataset = MnistDataset(x_train[:1000], y_train[:1000])
    print(f"β Dataset created with {len(dataset)} samples")

    # Exercise __getitem__ on the first sample.
    image, label = dataset[0]
    print("β Retrieved sample 0")
    print(f" Image shape: {image.shape}")
    print(f" Image dtype: {image.dtype}")
    print(f" Image range: [{image.min():.4f}, {image.max():.4f}]")
    print(f" Label: {label.item()} (dtype: {label.dtype})")

    # Images must be normalized float32 (1, 28, 28) tensors; labels long.
    assert image.shape == (1, 28, 28), f"Wrong shape: {image.shape}"
    assert image.dtype == torch.float32, f"Wrong dtype: {image.dtype}"
    assert 0 <= image.min() <= 1, f"Values not in [0, 1]: min={image.min()}"
    assert 0 <= image.max() <= 1, f"Values not in [0, 1]: max={image.max()}"
    assert label.dtype == torch.long, f"Label wrong dtype: {label.dtype}"
    print("β All assertions passed")
    print()
    return dataset
| def test_dataloader(dataset): | |
| """Test DataLoader batching.""" | |
| print("=" * 60) | |
| print("TEST 2: DataLoader Batching") | |
| print("=" * 60) | |
| loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True) | |
| print("β DataLoader created (batch_size=32)") | |
| # Get first batch | |
| images, labels = next(iter(loader)) | |
| print("β Retrieved first batch") | |
| print(f" Batch images shape: {images.shape}") | |
| print(f" Batch labels shape: {labels.shape}") | |
| print(f" Images dtype: {images.dtype}") | |
| print(f" Labels dtype: {labels.dtype}") | |
| # Verify batch dimensions | |
| assert images.shape == (32, 1, 28, 28), f"Wrong batch shape: {images.shape}" | |
| assert labels.shape == (32,), f"Wrong labels shape: {labels.shape}" | |
| assert images.dtype == torch.float32 | |
| assert labels.dtype == torch.long | |
| print("β All assertions passed") | |
| print() | |
def test_train_val_split():
    """Test train/validation split."""
    banner = "=" * 60
    print(banner)
    print("TEST 3: Train/Validation Split")
    print(banner)

    # Load the full training set from disk (test split is unused here).
    data_path = project_root / "data" / "raw"
    mnist = MnistDataloader(
        str(data_path / "train-images.idx3-ubyte"),
        str(data_path / "train-labels.idx1-ubyte"),
        str(data_path / "t10k-images.idx3-ubyte"),
        str(data_path / "t10k-labels.idx1-ubyte"),
    )
    (x_train, y_train), _ = mnist.load_data()

    # Fixed seed so the split is reproducible run-to-run.
    (x_tr, y_tr), (x_val, y_val) = split_train_val(
        x_train, y_train, val_split=0.15, random_seed=42
    )
    print("β Split completed")
    total = len(x_train)
    print(f" Original training: {total:,} samples")
    train_pct = len(x_tr) / total * 100
    print(f" New training: {len(x_tr):,} samples ({train_pct:.1f}%)")
    val_pct = len(x_val) / total * 100
    print(f" Validation: {len(x_val):,} samples ({val_pct:.1f}%)")

    # Roughly 15% validation, and no samples dropped by the split.
    expected_val_size = int(total * 0.15)
    assert abs(len(x_val) - expected_val_size) < 100, "Split ratio incorrect"
    assert len(x_tr) + len(x_val) == total, "Data loss during split"
    print("β Split ratio correct")

    # Check class balance (stratification) across both partitions.
    from collections import Counter
    train_counts = Counter(y_tr)
    val_counts = Counter(y_val)
    print("\n Class distribution in training set:")
    for digit in range(10):
        print(f" Digit {digit}: {train_counts[digit]:>5,} samples")
    print("\n Class distribution in validation set:")
    for digit in range(10):
        print(f" Digit {digit}: {val_counts[digit]:>4,} samples")

    # Every digit class must appear in both partitions.
    assert all(train_counts[d] > 0 for d in range(10)), "Missing class in train"
    assert all(val_counts[d] > 0 for d in range(10)), "Missing class in validation"
    print("\nβ All classes present in both sets")
    print()
def test_full_pipeline():
    """Test complete pipeline from data loading to batching."""
    banner = "=" * 60
    print(banner)
    print("TEST 4: Full Pipeline")
    print(banner)

    # Stage 1: raw data from disk.
    data_path = project_root / "data" / "raw"
    mnist = MnistDataloader(
        str(data_path / "train-images.idx3-ubyte"),
        str(data_path / "train-labels.idx1-ubyte"),
        str(data_path / "t10k-images.idx3-ubyte"),
        str(data_path / "t10k-labels.idx1-ubyte"),
    )
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    print("β Data loaded")

    # Stage 2: carve out a validation set.
    (x_tr, y_tr), (x_val, y_val) = split_train_val(
        x_train, y_train, val_split=0.15
    )
    print("β Train/val split completed")

    # Stage 3: wrap every partition in a Dataset.
    train_dataset = MnistDataset(x_tr, y_tr)
    val_dataset = MnistDataset(x_val, y_val)
    test_dataset = MnistDataset(x_test, y_test)
    print("β Datasets created")

    stats = get_dataset_statistics(train_dataset)
    print("\n Training dataset statistics:")
    print(f" Samples: {stats['num_samples']:,}")
    print(f" Image shape: {stats['sample_image_shape']}")
    print(f" Image dtype: {stats['sample_image_dtype']}")
    print(f" Image range: {stats['sample_image_range']}")
    print(f" Label dtype: {stats['sample_label_dtype']}")

    # Stage 4: loaders for all three partitions.
    train_loader, val_loader = create_dataloaders(
        train_dataset, val_dataset, batch_size=64, num_workers=0
    )
    test_loader = create_test_dataloader(test_dataset, batch_size=64, num_workers=0)
    print("\nβ DataLoaders created")
    print(f" Training batches: {len(train_loader)}")
    print(f" Validation batches: {len(val_loader)}")
    print(f" Test batches: {len(test_loader)}")

    # One batch from each loader proves end-to-end iteration works.
    train_batch = next(iter(train_loader))
    val_batch = next(iter(val_loader))
    test_batch = next(iter(test_loader))
    print("\nβ Successfully iterated through all loaders")
    print(f" Train batch shapes: {train_batch[0].shape}, {train_batch[1].shape}")
    print(f" Val batch shapes: {val_batch[0].shape}, {val_batch[1].shape}")
    print(f" Test batch shapes: {test_batch[0].shape}, {test_batch[1].shape}")
    print()
def main():
    """Run all tests; return 0 on success, 1 on any failure."""
    print("\nTesting MNIST Preprocessing Pipeline")
    print()
    try:
        # test_dataloader reuses the small dataset built by test_dataset.
        sample_dataset = test_dataset()
        test_dataloader(sample_dataset)
        test_train_val_split()
        test_full_pipeline()
        banner = "=" * 60
        print(banner)
        print("β ALL TESTS PASSED")
        print(banner)
        print("\nPreprocessing pipeline is ready for model training!")
        return 0
    except Exception as exc:
        print(f"\nβ TEST FAILED: {exc}")
        import traceback
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    sys.exit(main())