File size: 15,584 Bytes
78c140d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
"""
Test dataloader normalization behavior in utils.py.

This module tests that:
1. Dataloader transforms properly normalize data to have means near 0
2. CIFAR datasets load without errors and produce expected tensor shapes
3. Normalization statistics match expected behavior
4. Transform pipelines work correctly for each dataset
"""

import pytest
import torch
import numpy as np
from mithridatium.utils import dataloader_for, get_preprocess_config


class TestDataloaderNormalization:
    """Test that dataloader normalization works correctly."""

    @pytest.fixture
    def small_batch_size(self):
        """Use small batch size for faster tests."""
        return 32

    def _collect_channel_stats(self, dataloader, max_batches=10):
        """Return per-channel (means, stds) over up to ``max_batches`` batches.

        Batches are expected to be (N, C, H, W) image tensors; statistics are
        reduced over the N, H and W axes so each returned tensor has shape (C,).
        """
        batches = []
        for batch_idx, (images, _) in enumerate(dataloader, start=1):
            batches.append(images)
            if batch_idx >= max_batches:  # enough samples for stable statistics
                break
        stacked = torch.cat(batches, dim=0)
        means = torch.mean(stacked, dim=(0, 2, 3))
        stds = torch.std(stacked, dim=(0, 2, 3))
        return means, stds

    def _assert_normalized_stats(self, channel_means, channel_stds):
        """Assert per-channel means are near 0 and stds roughly near 1."""
        # After mean-centering, each channel mean should sit very close to 0.
        for i, mean_val in enumerate(channel_means):
            assert abs(mean_val.item()) < 0.1, f"Channel {i} mean {mean_val.item()} not near 0"
        # Finite sampling and dataset characteristics mean std won't be exactly
        # 1.0; a loose band still verifies the scaling was applied.
        for i, std_val in enumerate(channel_stds):
            assert 0.6 <= std_val.item() <= 1.4, f"Channel {i} std {std_val.item()} outside reasonable range [0.6, 1.4]"

    def test_cifar10_dataloader_creation(self, small_batch_size):
        """Test that CIFAR-10 dataloader creates successfully."""
        # Both splits must build a working DataLoader with the right config.
        for split in ["train", "test"]:
            dataloader, config = dataloader_for("cifar10", split, batch_size=small_batch_size)

            # Dataloader properties
            assert dataloader.batch_size == small_batch_size
            assert isinstance(dataloader, torch.utils.data.DataLoader)

            # Preprocess config
            assert config.get_dataset() == "cifar10"
            assert config.get_input_size() == (3, 32, 32)

    def test_cifar100_dataloader_creation(self, small_batch_size):
        """Test that CIFAR-100 dataloader creates successfully."""
        # Both splits must build a working DataLoader with the right config.
        for split in ["train", "test"]:
            dataloader, config = dataloader_for("cifar100", split, batch_size=small_batch_size)

            # Dataloader properties
            assert dataloader.batch_size == small_batch_size
            assert isinstance(dataloader, torch.utils.data.DataLoader)

            # Preprocess config
            assert config.get_dataset() == "cifar100"
            assert config.get_input_size() == (3, 32, 32)

    def test_cifar10_tensor_shapes(self, small_batch_size):
        """Test that CIFAR-10 produces correct tensor shapes."""
        dataloader, _ = dataloader_for("cifar10", "test", batch_size=small_batch_size)

        # First batch is always full-sized (dataset >> batch size).
        images, labels = next(iter(dataloader))

        # Shapes
        assert images.shape == (small_batch_size, 3, 32, 32), f"Expected {(small_batch_size, 3, 32, 32)}, got {images.shape}"
        assert labels.shape == (small_batch_size,), f"Expected {(small_batch_size,)}, got {labels.shape}"

        # Dtypes
        assert images.dtype == torch.float32
        assert labels.dtype == torch.long  # CIFAR uses long integers for class labels

    def test_cifar100_tensor_shapes(self, small_batch_size):
        """Test that CIFAR-100 produces correct tensor shapes."""
        dataloader, _ = dataloader_for("cifar100", "test", batch_size=small_batch_size)

        # First batch is always full-sized (dataset >> batch size).
        images, labels = next(iter(dataloader))

        # Shapes
        assert images.shape == (small_batch_size, 3, 32, 32), f"Expected {(small_batch_size, 3, 32, 32)}, got {images.shape}"
        assert labels.shape == (small_batch_size,), f"Expected {(small_batch_size,)}, got {labels.shape}"

        # Dtypes
        assert images.dtype == torch.float32
        assert labels.dtype == torch.long

    def test_cifar10_normalization_behavior(self, small_batch_size):
        """Test that CIFAR-10 normalization produces data with means near 0."""
        dataloader, config = dataloader_for("cifar10", "test", batch_size=small_batch_size)

        channel_means, channel_stds = self._collect_channel_stats(dataloader)

        # Print actual values for debugging/validation
        print(f"CIFAR-10 normalized stats - Means: {channel_means.tolist()}, Stds: {channel_stds.tolist()}")

        self._assert_normalized_stats(channel_means, channel_stds)

    def test_cifar100_normalization_behavior(self, small_batch_size):
        """Test that CIFAR-100 normalization produces data with means near 0."""
        dataloader, config = dataloader_for("cifar100", "test", batch_size=small_batch_size)

        channel_means, channel_stds = self._collect_channel_stats(dataloader)

        # Print actual values for debugging/validation
        print(f"CIFAR-100 normalized stats - Means: {channel_means.tolist()}, Stds: {channel_stds.tolist()}")

        self._assert_normalized_stats(channel_means, channel_stds)

    def test_unnormalized_data_range(self, small_batch_size):
        """Test data range before and after normalization by manually checking transforms."""
        # Verifies the transform pipeline is working: raw ToTensor output
        # stays in [0, 1] while the normalized pipeline leaves that range.
        from torchvision import datasets, transforms

        # CIFAR-10 dataset without normalization (only tensor conversion).
        unnormalized_transform = transforms.Compose([
            transforms.ToTensor()
        ])

        unnormalized_ds = datasets.CIFAR10(
            root="data",
            train=False,
            download=True,
            transform=unnormalized_transform
        )

        unnormalized_loader = torch.utils.data.DataLoader(
            unnormalized_ds,
            batch_size=small_batch_size,
            shuffle=False
        )

        # Normalized dataloader under test.
        normalized_loader, config = dataloader_for("cifar10", "test", batch_size=small_batch_size)

        # First batch of images from each pipeline.
        unnorm_batch = next(iter(unnormalized_loader))[0]
        norm_batch = next(iter(normalized_loader))[0]

        # ToTensor output is always within [0, 1].
        assert unnorm_batch.min().item() >= 0.0, f"Unnormalized min {unnorm_batch.min().item()} < 0"
        assert unnorm_batch.max().item() <= 1.0, f"Unnormalized max {unnorm_batch.max().item()} > 1"

        # Mean subtraction / std division pushes values outside [0, 1].
        assert norm_batch.min().item() < 0.0, f"Normalized data should have negative values, min={norm_batch.min().item()}"
        assert norm_batch.max().item() > 1.0, f"Normalized data should exceed 1, max={norm_batch.max().item()}"

    def test_different_batch_sizes(self):
        """Test that different batch sizes work correctly."""
        for batch_size in [1, 8, 16, 64]:
            dataloader, _ = dataloader_for("cifar10", "test", batch_size=batch_size)

            images, labels = next(iter(dataloader))

            # Batch may be short only at the end of the epoch; the first
            # batch should never exceed the requested size.
            assert images.shape[0] <= batch_size
            assert labels.shape[0] <= batch_size
            assert images.shape[0] == labels.shape[0]

    def test_train_vs_test_shuffle(self):
        """Test that train loader shuffles but test loader doesn't."""
        batch_size = 16

        train_loader, _ = dataloader_for("cifar10", "train", batch_size=batch_size)
        test_loader, _ = dataloader_for("cifar10", "test", batch_size=batch_size)

        # DataLoader(shuffle=True) installs a RandomSampler and
        # shuffle=False a SequentialSampler, so the sampler type directly
        # reflects the shuffle setting without relying on randomness.
        assert isinstance(train_loader.sampler, torch.utils.data.RandomSampler)
        assert isinstance(test_loader.sampler, torch.utils.data.SequentialSampler)

        # Both loaders still yield correctly shaped batches.
        train_batch = next(iter(train_loader))
        test_batch = next(iter(test_loader))

        assert train_batch[0].shape == (batch_size, 3, 32, 32)
        assert test_batch[0].shape == (batch_size, 3, 32, 32)


class TestDataloaderErrorHandling:
    """Test error handling in dataloader_for function."""

    def test_invalid_dataset_error(self):
        """Test that invalid datasets raise ValueError."""
        with pytest.raises(ValueError) as exc_info:
            dataloader_for("mnist", "test", batch_size=32)

        # The message should identify both the problem and the offending name.
        message = str(exc_info.value)
        for fragment in ("Unsupported dataset", "mnist"):
            assert fragment in message

    def test_invalid_split_error(self):
        """Test that invalid splits raise ValueError."""
        with pytest.raises(ValueError) as exc_info:
            dataloader_for("cifar10", "validation", batch_size=32)

        # The message should name the bad split and list the valid options.
        message = str(exc_info.value)
        for fragment in ("Invalid split", "validation", "train", "test"):
            assert fragment in message

    def test_case_insensitive_inputs(self):
        """Test that dataset and split names are case-insensitive."""
        dataset_spellings = ["CIFAR10", "Cifar10", "cifar10"]
        split_spellings = ["TRAIN", "Train", "train", "TEST", "Test", "test"]

        # Every spelling combination should succeed and normalize the name.
        for dataset in dataset_spellings:
            for split in split_spellings:
                dataloader, config = dataloader_for(dataset, split, batch_size=8)
                assert config.get_dataset() == "cifar10"


class TestTransformPipelines:
    """Test that transform pipelines are correctly structured."""

    def test_cifar_transform_efficiency(self):
        """Test that CIFAR transforms don't include unnecessary resize operations."""
        # Design check: CIFAR images are natively 32x32, so the pipeline
        # should emit them at that size with no resize step in between.
        dataloader, config = dataloader_for("cifar10", "test", batch_size=16)

        images, labels = next(iter(dataloader))

        # Transforms ran and preserved the native resolution.
        assert images.shape == (16, 3, 32, 32)

        # Normalized data falls outside the raw [0, 1] range.
        outside_unit_range = images.min().item() < 0 or images.max().item() > 1
        assert outside_unit_range

    def test_imagenet_transform_structure(self):
        """Test ImageNet transforms would include proper resize operations."""
        # ImageNet requires a manual download; when it is absent we instead
        # verify that the error message points the user at the setup steps.
        try:
            train_loader, config = dataloader_for("imagenet", "train", batch_size=8)
            test_loader, config = dataloader_for("imagenet", "test", batch_size=8)
        except ValueError as e:
            message = str(e)
            assert "ImageNet dataset not found" in message
            assert "data/imagenet" in message
        else:
            # Dataset available: config should advertise 224x224 RGB inputs.
            assert config.get_input_size() == (3, 224, 224)

    def test_pin_memory_enabled(self):
        """Test that dataloaders have pin_memory enabled for GPU performance."""
        loader, _ = dataloader_for("cifar10", "test", batch_size=16)

        # pin_memory speeds up host-to-GPU transfers.
        assert loader.pin_memory is True

    def test_num_workers_set(self):
        """Test that dataloaders use multiple workers for performance."""
        loader, _ = dataloader_for("cifar10", "test", batch_size=16)

        # At least two workers so data loading overlaps with compute.
        assert loader.num_workers >= 2


class TestNormalizationMath:
    """Test the mathematical correctness of normalization."""

    def test_normalization_formula_correctness(self):
        """Test that normalization follows the correct formula: (x - mean) / std."""
        # Single-channel 2x2 sample, shape (1, 1, 2, 2).
        sample = torch.tensor(
            [[[[0.4914, 0.6000],
               [0.3000, 0.8000]]]],
            dtype=torch.float32,
        )

        # CIFAR-10 stats for red channel
        mean = 0.4914
        std = 0.2023

        # Hand-computed normalization.
        expected = (sample - mean) / std

        # The torchvision transform must agree within float precision.
        from torchvision import transforms
        actual = transforms.Normalize(mean=(mean,), std=(std,))(sample)

        torch.testing.assert_close(expected, actual, rtol=1e-6, atol=1e-6)

    def test_inverse_normalization_possible(self):
        """Test that normalization can be inverted to recover original values."""
        dataloader, config = dataloader_for("cifar10", "test", batch_size=4)

        normalized = next(iter(dataloader))[0]

        # Undo the channel-wise normalization: x_orig = x_norm * std + mean.
        # Stats are broadcast to shape (1, 3, 1, 1) to match (N, C, H, W).
        mean = torch.tensor(config.get_mean()).view(1, 3, 1, 1)
        std = torch.tensor(config.get_std()).view(1, 3, 1, 1)
        recovered = (normalized * std) + mean

        # Recovered values should be roughly in [0, 1]; a small slack covers
        # discretization and floating point rounding.
        low = recovered.min().item()
        high = recovered.max().item()
        assert low >= -0.1, f"Denormalized min {low} too low"
        assert high <= 1.1, f"Denormalized max {high} too high"