""" Test dataloader normalization behavior in utils.py. This module tests that: 1. Dataloader transforms properly normalize data to have means near 0 2. CIFAR datasets load without errors and produce expected tensor shapes 3. Normalization statistics match expected behavior 4. Transform pipelines work correctly for each dataset """ import pytest import torch import numpy as np from mithridatium.utils import dataloader_for, get_preprocess_config class TestDataloaderNormalization: """Test that dataloader normalization works correctly.""" @pytest.fixture def small_batch_size(self): """Use small batch size for faster tests.""" return 32 def test_cifar10_dataloader_creation(self, small_batch_size): """Test that CIFAR-10 dataloader creates successfully.""" # Test both train and test splits for split in ["train", "test"]: dataloader, config = dataloader_for("cifar10", split, batch_size=small_batch_size) # Check dataloader properties assert dataloader.batch_size == small_batch_size assert isinstance(dataloader, torch.utils.data.DataLoader) # Check config assert config.get_dataset() == "cifar10" assert config.get_input_size() == (3, 32, 32) def test_cifar100_dataloader_creation(self, small_batch_size): """Test that CIFAR-100 dataloader creates successfully.""" # Test both train and test splits for split in ["train", "test"]: dataloader, config = dataloader_for("cifar100", split, batch_size=small_batch_size) # Check dataloader properties assert dataloader.batch_size == small_batch_size assert isinstance(dataloader, torch.utils.data.DataLoader) # Check config assert config.get_dataset() == "cifar100" assert config.get_input_size() == (3, 32, 32) def test_cifar10_tensor_shapes(self, small_batch_size): """Test that CIFAR-10 produces correct tensor shapes.""" dataloader, _ = dataloader_for("cifar10", "test", batch_size=small_batch_size) # Get first batch batch_iter = iter(dataloader) images, labels = next(batch_iter) # Check shapes assert images.shape == (small_batch_size, 3, 32, 32), f"Expected {(small_batch_size, 3, 32, 32)}, got {images.shape}" assert labels.shape == (small_batch_size,), f"Expected {(small_batch_size,)}, got {labels.shape}" # Check data types assert images.dtype == torch.float32 assert labels.dtype == torch.long # CIFAR uses long integers for class labels def test_cifar100_tensor_shapes(self, small_batch_size): """Test that CIFAR-100 produces correct tensor shapes.""" dataloader, _ = dataloader_for("cifar100", "test", batch_size=small_batch_size) # Get first batch batch_iter = iter(dataloader) images, labels = next(batch_iter) # Check shapes assert images.shape == (small_batch_size, 3, 32, 32), f"Expected {(small_batch_size, 3, 32, 32)}, got {images.shape}" assert labels.shape == (small_batch_size,), f"Expected {(small_batch_size,)}, got {labels.shape}" # Check data types assert images.dtype == torch.float32 assert labels.dtype == torch.long def test_cifar10_normalization_behavior(self, small_batch_size): """Test that CIFAR-10 normalization produces data with means near 0.""" dataloader, config = dataloader_for("cifar10", "test", batch_size=small_batch_size) # Collect several batches to get good statistics all_images = [] batch_count = 0 for images, _ in dataloader: all_images.append(images) batch_count += 1 if batch_count >= 10: # Use 10 batches for statistics break # Concatenate all images all_images = torch.cat(all_images, dim=0) # Calculate per-channel means and stds # Shape: (N, C, H, W) -> calculate over N, H, W dimensions channel_means = torch.mean(all_images, dim=(0, 2, 3)) # Shape: (3,) channel_stds = torch.std(all_images, dim=(0, 2, 3)) # Shape: (3,) # Print actual values for debugging/validation print(f"CIFAR-10 normalized stats - Means: {channel_means.tolist()}, Stds: {channel_stds.tolist()}") # After normalization, means should be close to 0 # The mean centering should be very effective for i, mean_val in enumerate(channel_means): assert abs(mean_val.item()) < 0.1, f"Channel {i} mean {mean_val.item()} not near 0" # Standard deviations should be reasonably close to 1 # Note: Due to finite sampling and dataset characteristics, exact std=1.0 is not expected # We verify the normalization is working (values roughly in expected range) for i, std_val in enumerate(channel_stds): assert 0.6 <= std_val.item() <= 1.4, f"Channel {i} std {std_val.item()} outside reasonable range [0.6, 1.4]" def test_cifar100_normalization_behavior(self, small_batch_size): """Test that CIFAR-100 normalization produces data with means near 0.""" dataloader, config = dataloader_for("cifar100", "test", batch_size=small_batch_size) # Collect several batches to get good statistics all_images = [] batch_count = 0 for images, _ in dataloader: all_images.append(images) batch_count += 1 if batch_count >= 10: # Use 10 batches for statistics break # Concatenate all images all_images = torch.cat(all_images, dim=0) # Calculate per-channel means and stds channel_means = torch.mean(all_images, dim=(0, 2, 3)) channel_stds = torch.std(all_images, dim=(0, 2, 3)) # Print actual values for debugging/validation print(f"CIFAR-100 normalized stats - Means: {channel_means.tolist()}, Stds: {channel_stds.tolist()}") # After normalization, means should be close to 0 for i, mean_val in enumerate(channel_means): assert abs(mean_val.item()) < 0.1, f"Channel {i} mean {mean_val.item()} not near 0" # Standard deviations should be reasonably close to 1 for i, std_val in enumerate(channel_stds): assert 0.6 <= std_val.item() <= 1.4, f"Channel {i} std {std_val.item()} outside reasonable range [0.6, 1.4]" def test_unnormalized_data_range(self, small_batch_size): """Test data range before and after normalization by manually checking transforms.""" # This test verifies the transform pipeline is working correctly from torchvision import datasets, transforms # Create CIFAR-10 dataset without normalization unnormalized_transform = transforms.Compose([ transforms.ToTensor() # Only convert to tensor, no normalization ]) unnormalized_ds = datasets.CIFAR10( root="data", train=False, download=True, transform=unnormalized_transform ) unnormalized_loader = torch.utils.data.DataLoader( unnormalized_ds, batch_size=small_batch_size, shuffle=False ) # Get normalized dataloader normalized_loader, config = dataloader_for("cifar10", "test", batch_size=small_batch_size) # Get first batch from each unnorm_batch = next(iter(unnormalized_loader))[0] # Just images norm_batch = next(iter(normalized_loader))[0] # Just images # Unnormalized data should be in [0, 1] range assert unnorm_batch.min().item() >= 0.0, f"Unnormalized min {unnorm_batch.min().item()} < 0" assert unnorm_batch.max().item() <= 1.0, f"Unnormalized max {unnorm_batch.max().item()} > 1" # Normalized data should extend beyond [0, 1] range due to normalization # (some values will be negative after subtracting mean) assert norm_batch.min().item() < 0.0, f"Normalized data should have negative values, min={norm_batch.min().item()}" assert norm_batch.max().item() > 1.0, f"Normalized data should exceed 1, max={norm_batch.max().item()}" def test_different_batch_sizes(self): """Test that different batch sizes work correctly.""" for batch_size in [1, 8, 16, 64]: dataloader, _ = dataloader_for("cifar10", "test", batch_size=batch_size) # Get first batch batch_iter = iter(dataloader) images, labels = next(batch_iter) # Check batch size (last batch might be smaller) assert images.shape[0] <= batch_size assert labels.shape[0] <= batch_size assert images.shape[0] == labels.shape[0] def test_train_vs_test_shuffle(self): """Test that train loader shuffles but test loader doesn't.""" batch_size = 16 # Get train and test loaders train_loader, _ = dataloader_for("cifar10", "train", batch_size=batch_size) test_loader, _ = dataloader_for("cifar10", "test", batch_size=batch_size) # For train loader, shuffle should be True (can't directly test randomness easily) # But we can at least verify the loaders work train_batch = next(iter(train_loader)) test_batch = next(iter(test_loader)) assert train_batch[0].shape == (batch_size, 3, 32, 32) assert test_batch[0].shape == (batch_size, 3, 32, 32) class TestDataloaderErrorHandling: """Test error handling in dataloader_for function.""" def test_invalid_dataset_error(self): """Test that invalid datasets raise ValueError.""" with pytest.raises(ValueError) as exc_info: dataloader_for("mnist", "test", batch_size=32) error_msg = str(exc_info.value) assert "Unsupported dataset" in error_msg assert "mnist" in error_msg def test_invalid_split_error(self): """Test that invalid splits raise ValueError.""" with pytest.raises(ValueError) as exc_info: dataloader_for("cifar10", "validation", batch_size=32) error_msg = str(exc_info.value) assert "Invalid split" in error_msg assert "validation" in error_msg assert "train" in error_msg assert "test" in error_msg def test_case_insensitive_inputs(self): """Test that dataset and split names are case-insensitive.""" # These should all work without errors for dataset in ["CIFAR10", "Cifar10", "cifar10"]: for split in ["TRAIN", "Train", "train", "TEST", "Test", "test"]: dataloader, config = dataloader_for(dataset, split, batch_size=8) assert config.get_dataset() == "cifar10" class TestTransformPipelines: """Test that transform pipelines are correctly structured.""" def test_cifar_transform_efficiency(self): """Test that CIFAR transforms don't include unnecessary resize operations.""" # This is more of a design verification test # CIFAR images are already 32x32, so no resize should be needed dataloader, config = dataloader_for("cifar10", "test", batch_size=16) # Get a batch to ensure transforms work batch = next(iter(dataloader)) images, labels = batch # Verify final shape is correct (transforms worked) assert images.shape == (16, 3, 32, 32) # Verify data is normalized (not in [0,1] range) assert images.min().item() < 0 or images.max().item() > 1 def test_imagenet_transform_structure(self): """Test ImageNet transforms would include proper resize operations.""" # Note: This test may fail if ImageNet dataset isn't available # In that case, we verify the error message is helpful try: train_loader, config = dataloader_for("imagenet", "train", batch_size=8) test_loader, config = dataloader_for("imagenet", "test", batch_size=8) # If ImageNet is available, verify config assert config.get_input_size() == (3, 224, 224) except ValueError as e: # Should get helpful error about manual ImageNet setup error_msg = str(e) assert "ImageNet dataset not found" in error_msg assert "data/imagenet" in error_msg def test_pin_memory_enabled(self): """Test that dataloaders have pin_memory enabled for GPU performance.""" dataloader, _ = dataloader_for("cifar10", "test", batch_size=16) # Check that pin_memory is True (improves GPU transfer performance) assert dataloader.pin_memory is True def test_num_workers_set(self): """Test that dataloaders use multiple workers for performance.""" dataloader, _ = dataloader_for("cifar10", "test", batch_size=16) # Check that num_workers > 0 for parallel data loading assert dataloader.num_workers >= 2 class TestNormalizationMath: """Test the mathematical correctness of normalization.""" def test_normalization_formula_correctness(self): """Test that normalization follows the correct formula: (x - mean) / std.""" # Create simple test data test_tensor = torch.tensor([[[ [0.4914, 0.6000], # First channel values [0.3000, 0.8000] ]]], dtype=torch.float32) # Shape: (1, 1, 2, 2) # CIFAR-10 stats for red channel mean = 0.4914 std = 0.2023 # Apply normalization manually normalized_manual = (test_tensor - mean) / std # Apply normalization using torchvision transform from torchvision import transforms normalize_transform = transforms.Normalize(mean=(mean,), std=(std,)) normalized_torch = normalize_transform(test_tensor) # Results should be identical (within floating point precision) torch.testing.assert_close(normalized_manual, normalized_torch, rtol=1e-6, atol=1e-6) def test_inverse_normalization_possible(self): """Test that normalization can be inverted to recover original values.""" dataloader, config = dataloader_for("cifar10", "test", batch_size=4) # Get normalized batch normalized_batch = next(iter(dataloader))[0] # Apply inverse normalization: x_orig = (x_norm * std) + mean mean = torch.tensor(config.get_mean()).view(1, 3, 1, 1) # Shape: (1, 3, 1, 1) std = torch.tensor(config.get_std()).view(1, 3, 1, 1) # Shape: (1, 3, 1, 1) denormalized_batch = (normalized_batch * std) + mean # Denormalized values should be approximately in [0, 1] range # (not exactly due to discretization and floating point precision) assert denormalized_batch.min().item() >= -0.1, f"Denormalized min {denormalized_batch.min().item()} too low" assert denormalized_batch.max().item() <= 1.1, f"Denormalized max {denormalized_batch.max().item()} too high"