# mnist-digit-classifier / scripts / test_preprocessing.py
# Author: faizan
# Commit e77a25a — fix: resolve all 468 ruff linting errors (code quality enforcement complete)
"""
Test preprocessing pipeline with MNIST data.
Verifies:
- Dataset initialization
- Normalization to [0, 1]
- Correct tensor shapes
- DataLoader batching
- Train/val split
"""
import sys
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
import torch
from scripts.data_loader import MnistDataloader
from scripts.preprocessing import (
MnistDataset,
create_dataloaders,
create_test_dataloader,
split_train_val,
get_dataset_statistics
)
def test_dataset():
    """Test MnistDataset initialization, indexing, and normalization.

    Verifies that a sample has shape (1, 28, 28), dtype float32, values
    normalized into [0, 1], and a long-dtype label.

    Returns:
        The 1000-sample MnistDataset, reused by the DataLoader test.
    """
    print("=" * 60)
    print("TEST 1: MnistDataset Initialization and Indexing")
    print("=" * 60)
    # Load raw MNIST idx files from the project's data directory.
    data_path = project_root / "data" / "raw"
    loader = MnistDataloader(
        str(data_path / "train-images.idx3-ubyte"),
        str(data_path / "train-labels.idx1-ubyte"),
        str(data_path / "t10k-images.idx3-ubyte"),
        str(data_path / "t10k-labels.idx1-ubyte")
    )
    # Test split is not needed here; discard it instead of binding unused names.
    (x_train, y_train), _ = loader.load_data()
    # Small subset keeps this test fast.
    dataset = MnistDataset(x_train[:1000], y_train[:1000])
    print(f"βœ“ Dataset created with {len(dataset)} samples")
    # Exercise __getitem__ on a single sample.
    image, label = dataset[0]
    print("βœ“ Retrieved sample 0")
    print(f" Image shape: {image.shape}")
    print(f" Image dtype: {image.dtype}")
    print(f" Image range: [{image.min():.4f}, {image.max():.4f}]")
    print(f" Label: {label.item()} (dtype: {label.dtype})")
    # Verify shape, dtype, and [0, 1] normalization.
    assert image.shape == (1, 28, 28), f"Wrong shape: {image.shape}"
    assert image.dtype == torch.float32, f"Wrong dtype: {image.dtype}"
    assert 0 <= image.min() <= 1, f"Values not in [0, 1]: min={image.min()}"
    assert 0 <= image.max() <= 1, f"Values not in [0, 1]: max={image.max()}"
    assert label.dtype == torch.long, f"Label wrong dtype: {label.dtype}"
    print("βœ“ All assertions passed")
    print()
    return dataset
def test_dataloader(dataset):
"""Test DataLoader batching."""
print("=" * 60)
print("TEST 2: DataLoader Batching")
print("=" * 60)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
print("βœ“ DataLoader created (batch_size=32)")
# Get first batch
images, labels = next(iter(loader))
print("βœ“ Retrieved first batch")
print(f" Batch images shape: {images.shape}")
print(f" Batch labels shape: {labels.shape}")
print(f" Images dtype: {images.dtype}")
print(f" Labels dtype: {labels.dtype}")
# Verify batch dimensions
assert images.shape == (32, 1, 28, 28), f"Wrong batch shape: {images.shape}"
assert labels.shape == (32,), f"Wrong labels shape: {labels.shape}"
assert images.dtype == torch.float32
assert labels.dtype == torch.long
print("βœ“ All assertions passed")
print()
def test_train_val_split():
    """Verify the stratified train/validation split of the MNIST training set."""
    banner = "=" * 60
    print(banner)
    print("TEST 3: Train/Validation Split")
    print(banner)
    # Load the raw MNIST training data (test split is discarded).
    data_path = project_root / "data" / "raw"
    mnist_loader = MnistDataloader(
        str(data_path / "train-images.idx3-ubyte"),
        str(data_path / "train-labels.idx1-ubyte"),
        str(data_path / "t10k-images.idx3-ubyte"),
        str(data_path / "t10k-labels.idx1-ubyte")
    )
    (x_train, y_train), _ = mnist_loader.load_data()
    # Perform the 85/15 split with a fixed seed for reproducibility.
    (x_train_split, y_train_split), (x_val, y_val) = split_train_val(
        x_train, y_train, val_split=0.15, random_seed=42
    )
    print("βœ“ Split completed")
    total = len(x_train)
    print(f" Original training: {len(x_train):,} samples")
    train_pct = len(x_train_split) / total * 100
    print(f" New training: {len(x_train_split):,} samples ({train_pct:.1f}%)")
    val_pct = len(x_val) / total * 100
    print(f" Validation: {len(x_val):,} samples ({val_pct:.1f}%)")
    # The validation set should hold roughly 15% of samples with nothing lost.
    expected_val_size = int(total * 0.15)
    assert abs(len(x_val) - expected_val_size) < 100, "Split ratio incorrect"
    assert len(x_train_split) + len(x_val) == total, "Data loss during split"
    print("βœ“ Split ratio correct")
    # Check stratification: count per-digit frequencies in each partition.
    from collections import Counter
    train_counts = Counter(y_train_split)
    val_counts = Counter(y_val)
    print("\n Class distribution in training set:")
    for digit in range(10):
        print(f" Digit {digit}: {train_counts[digit]:>5,} samples")
    print("\n Class distribution in validation set:")
    for digit in range(10):
        print(f" Digit {digit}: {val_counts[digit]:>4,} samples")
    # Every digit class must appear in both partitions.
    assert all(train_counts[i] > 0 for i in range(10)), "Missing class in train"
    assert all(val_counts[i] > 0 for i in range(10)), "Missing class in validation"
    print("\nβœ“ All classes present in both sets")
    print()
def test_full_pipeline():
    """Run the end-to-end pipeline: load, split, wrap in datasets, batch."""
    divider = "=" * 60
    print(divider)
    print("TEST 4: Full Pipeline")
    print(divider)
    # Step 1: load raw MNIST data from disk.
    data_path = project_root / "data" / "raw"
    mnist_loader = MnistDataloader(
        str(data_path / "train-images.idx3-ubyte"),
        str(data_path / "train-labels.idx1-ubyte"),
        str(data_path / "t10k-images.idx3-ubyte"),
        str(data_path / "t10k-labels.idx1-ubyte")
    )
    (x_train, y_train), (x_test, y_test) = mnist_loader.load_data()
    print("βœ“ Data loaded")
    # Step 2: split off a validation set.
    (x_train_split, y_train_split), (x_val, y_val) = split_train_val(
        x_train, y_train, val_split=0.15
    )
    print("βœ“ Train/val split completed")
    # Step 3: wrap each partition in an MnistDataset.
    train_dataset = MnistDataset(x_train_split, y_train_split)
    val_dataset = MnistDataset(x_val, y_val)
    test_dataset = MnistDataset(x_test, y_test)
    print("βœ“ Datasets created")
    # Step 4: report training-set statistics.
    train_stats = get_dataset_statistics(train_dataset)
    print("\n Training dataset statistics:")
    print(f" Samples: {train_stats['num_samples']:,}")
    print(f" Image shape: {train_stats['sample_image_shape']}")
    print(f" Image dtype: {train_stats['sample_image_dtype']}")
    print(f" Image range: {train_stats['sample_image_range']}")
    print(f" Label dtype: {train_stats['sample_label_dtype']}")
    # Step 5: build the three DataLoaders.
    train_loader, val_loader = create_dataloaders(
        train_dataset, val_dataset, batch_size=64, num_workers=0
    )
    test_loader = create_test_dataloader(test_dataset, batch_size=64, num_workers=0)
    print("\nβœ“ DataLoaders created")
    print(f" Training batches: {len(train_loader)}")
    print(f" Validation batches: {len(val_loader)}")
    print(f" Test batches: {len(test_loader)}")
    # Step 6: confirm each loader yields one batch.
    train_batch = next(iter(train_loader))
    val_batch = next(iter(val_loader))
    test_batch = next(iter(test_loader))
    print("\nβœ“ Successfully iterated through all loaders")
    print(f" Train batch shapes: {train_batch[0].shape}, {train_batch[1].shape}")
    print(f" Val batch shapes: {val_batch[0].shape}, {val_batch[1].shape}")
    print(f" Test batch shapes: {test_batch[0].shape}, {test_batch[1].shape}")
    print()
def main():
    """Run every preprocessing test; return 0 on success, 1 on failure."""
    print("\nTesting MNIST Preprocessing Pipeline")
    print()
    try:
        # TEST 1 returns a dataset reused by TEST 2.
        shared_dataset = test_dataset()
        test_dataloader(shared_dataset)
        test_train_val_split()
        test_full_pipeline()
        separator = "=" * 60
        print(separator)
        print("βœ… ALL TESTS PASSED")
        print(separator)
        print("\nPreprocessing pipeline is ready for model training!")
        return 0
    except Exception as exc:
        # Surface the failure and full traceback, then signal error via exit code.
        print(f"\n❌ TEST FAILED: {exc}")
        import traceback
        traceback.print_exc()
        return 1
# Run the full test suite when executed as a script; the process exit code
# is main()'s return value (0 = all tests passed, 1 = a test failed).
if __name__ == "__main__":
    sys.exit(main())