Spaces:
Running
Running
"""
Test canonical dataset configurations in utils.py.

This module tests that:
1. DATASET_CONFIGS contains correct canonical values for supported datasets
2. get_preprocess_config() returns proper PreprocessConfig objects
3. Unsupported datasets raise appropriate errors
4. Configuration values match published literature standards
"""
| import pytest | |
| from mithridatium.utils import get_preprocess_config, DATASET_CONFIGS, PreprocessConfig | |
class TestCanonicalConfigs:
    """Test canonical dataset configuration values.

    The tests check the raw ``DATASET_CONFIGS`` mapping and the objects
    returned by ``get_preprocess_config()`` for the supported datasets.
    """

    @staticmethod
    def _assert_canonical(dataset, expected_size, expected_mean, expected_std):
        """Assert that *dataset* carries the expected canonical statistics.

        Both the raw ``DATASET_CONFIGS`` entry and the object returned by
        ``get_preprocess_config(dataset)`` are checked against the same
        expected values.

        Args:
            dataset: Canonical (lower-case) dataset name, e.g. ``"cifar10"``.
            expected_size: Expected ``input_size`` tuple.
            expected_mean: Expected per-channel mean tuple.
            expected_std: Expected per-channel std tuple.
        """
        # Check DATASET_CONFIGS mapping
        config_data = DATASET_CONFIGS[dataset]
        assert config_data["input_size"] == expected_size
        assert config_data["mean"] == expected_mean
        assert config_data["std"] == expected_std
        assert config_data["normalize"] is True

        # Check PreprocessConfig object
        config = get_preprocess_config(dataset)
        assert config.get_input_size() == expected_size
        assert config.get_mean() == expected_mean
        assert config.get_std() == expected_std
        assert config.get_normalize() is True
        assert config.get_dataset() == dataset

    def test_cifar10_canonical_stats(self):
        """Test CIFAR-10 has correct canonical normalization statistics."""
        # CIFAR-10 canonical values from literature
        self._assert_canonical(
            "cifar10",
            expected_size=(3, 32, 32),
            expected_mean=(0.4914, 0.4822, 0.4465),
            expected_std=(0.2023, 0.1994, 0.2010),
        )

    def test_cifar100_canonical_stats(self):
        """Test CIFAR-100 has correct canonical normalization statistics."""
        # CIFAR-100 canonical values from literature
        self._assert_canonical(
            "cifar100",
            expected_size=(3, 32, 32),
            expected_mean=(0.5071, 0.4867, 0.4408),
            expected_std=(0.2675, 0.2565, 0.2761),
        )

    def test_imagenet_canonical_stats(self):
        """Test ImageNet has correct canonical normalization statistics."""
        # ImageNet canonical values from torchvision/literature
        self._assert_canonical(
            "imagenet",
            expected_size=(3, 224, 224),
            expected_mean=(0.485, 0.456, 0.406),
            expected_std=(0.229, 0.224, 0.225),
        )

    def test_case_insensitive_dataset_names(self):
        """Test that dataset names are case-insensitive."""
        # Canonical name -> spellings that must all resolve to it
        variants_by_name = {
            "cifar10": ["CIFAR10", "Cifar10", "cifar10", "CiFaR10"],
            "cifar100": ["CIFAR100", "Cifar100", "cifar100", "CiFaR100"],
            "imagenet": ["IMAGENET", "ImageNet", "imagenet", "ImAgEnEt"],
        }
        for canonical_name, variants in variants_by_name.items():
            for dataset_name in variants:
                config = get_preprocess_config(dataset_name)
                assert config.get_dataset() == canonical_name

    def test_whitespace_handling(self):
        """Test that dataset names handle whitespace correctly."""
        # Test with leading/trailing whitespace
        config = get_preprocess_config(" cifar10 ")
        assert config.get_dataset() == "cifar10"

        config = get_preprocess_config("\tcifar100\n")
        assert config.get_dataset() == "cifar100"

    def test_unsupported_dataset_error(self):
        """Test that unsupported datasets raise ValueError with helpful message."""
        with pytest.raises(ValueError) as exc_info:
            get_preprocess_config("mnist")

        error_msg = str(exc_info.value)
        assert "mnist" in error_msg
        assert "Unsupported dataset" in error_msg
        # The message should list the supported datasets
        for supported in ("cifar10", "cifar100", "imagenet"):
            assert supported in error_msg

    def test_preprocess_config_default_values(self):
        """Test that PreprocessConfig has correct default values."""
        for dataset in ["cifar10", "cifar100", "imagenet"]:
            config = get_preprocess_config(dataset)
            # Common defaults across all datasets
            assert config.get_channels_first() is True
            assert config.get_value_range() == (0.0, 1.0)
            assert config.get_normalize() is True
            assert config.get_ops() == []

    def test_all_supported_datasets_in_mapping(self):
        """Test that all datasets mentioned in error messages are in DATASET_CONFIGS."""
        # pytest.raises fails the test when no ValueError is raised; a bare
        # try/except would pass without checking anything in that case.
        with pytest.raises(ValueError) as exc_info:
            get_preprocess_config("invalid_dataset")
        error_msg = str(exc_info.value)

        # Extract supported datasets from error message
        # Message format: "Supported datasets: cifar10, cifar100, imagenet"
        marker = "Supported datasets:"
        if marker not in error_msg:
            # Report the unchecked case as a skip instead of a silent pass.
            pytest.skip(f"error message has no {marker!r} section: {error_msg!r}")

        supported_part = error_msg.split(marker)[1].strip()
        mentioned_datasets = [ds.strip() for ds in supported_part.split(",")]

        # Verify all mentioned datasets exist in DATASET_CONFIGS
        for dataset in mentioned_datasets:
            assert dataset in DATASET_CONFIGS, f"Dataset {dataset} mentioned in error but not in DATASET_CONFIGS"
class TestDatasetConfigsCompleteness:
    """Test that DATASET_CONFIGS mapping is complete and well-formed."""

    def test_dataset_configs_structure(self):
        """Every entry has the required keys with the expected types and lengths."""
        required_keys = {"input_size", "mean", "std", "normalize"}

        for name, entry in DATASET_CONFIGS.items():
            # All required keys must be present before the values are inspected
            assert required_keys.issubset(entry.keys()), f"Missing keys in {name} config"

            # input_size: three positive ints, (C, H, W)
            size = entry["input_size"]
            assert isinstance(size, tuple)
            assert len(size) == 3
            for dim in size:
                assert isinstance(dim, int) and dim > 0

            # mean: three floats in [0, 1], (R, G, B)
            mean = entry["mean"]
            assert isinstance(mean, tuple)
            assert len(mean) == 3
            for value in mean:
                assert isinstance(value, float) and 0 <= value <= 1

            # std: three strictly positive floats, (R, G, B)
            std = entry["std"]
            assert isinstance(std, tuple)
            assert len(std) == 3
            for value in std:
                assert isinstance(value, float) and value > 0

            assert isinstance(entry["normalize"], bool)

    def test_cifar_datasets_have_32x32_size(self):
        """Both CIFAR entries declare a 3x32x32 input size."""
        cifar_size = (3, 32, 32)
        for dataset in ("cifar10", "cifar100"):
            actual_size = DATASET_CONFIGS[dataset]["input_size"]
            assert actual_size == cifar_size, f"{dataset} should be 3x32x32"

    def test_imagenet_has_224x224_size(self):
        """The ImageNet entry declares a 3x224x224 input size."""
        actual_size = DATASET_CONFIGS["imagenet"]["input_size"]
        assert actual_size == (3, 224, 224), "ImageNet should be 3x224x224"

    def test_normalization_stats_reasonable_ranges(self):
        """Per-channel mean and std values fall inside plausible bounds."""
        for name, entry in DATASET_CONFIGS.items():
            # Means must lie in [0, 1]
            for mean_value in entry["mean"]:
                assert 0.0 <= mean_value <= 1.0, f"{name} mean {mean_value} out of range [0,1]"

            # Stds must lie in [0.05, 0.5]
            for std_value in entry["std"]:
                assert 0.05 <= std_value <= 0.5, f"{name} std {std_value} out of reasonable range [0.05,0.5]"
class TestPreprocessConfigMethods:
    """Test PreprocessConfig class methods and functionality."""

    def test_preprocess_config_getters(self):
        """Each getter on the cifar10 config returns the expected value."""
        cfg = get_preprocess_config("cifar10")

        # Expected values for the equality-checked fields
        cifar10_size = (3, 32, 32)
        cifar10_mean = (0.4914, 0.4822, 0.4465)
        cifar10_std = (0.2023, 0.1994, 0.2010)
        unit_range = (0.0, 1.0)

        assert cfg.get_input_size() == cifar10_size
        assert cfg.get_channels_first() is True
        assert cfg.get_value_range() == unit_range
        assert cfg.get_mean() == cifar10_mean
        assert cfg.get_std() == cifar10_std
        assert cfg.get_normalize() is True
        assert cfg.get_ops() == []
        assert cfg.get_dataset() == "cifar10"

    def test_preprocess_config_setters(self):
        """Each setter's new value is returned by the matching getter."""
        cfg = get_preprocess_config("cifar10")

        # Set one field at a time and read it straight back
        new_size = (3, 64, 64)
        cfg.set_input_size(new_size)
        assert cfg.get_input_size() == (3, 64, 64)

        cfg.set_channels_first(False)
        assert cfg.get_channels_first() is False

        new_range = (-1.0, 1.0)
        cfg.set_value_range(new_range)
        assert cfg.get_value_range() == (-1.0, 1.0)

        new_mean = (0.5, 0.5, 0.5)
        cfg.set_mean(new_mean)
        assert cfg.get_mean() == (0.5, 0.5, 0.5)

        new_std = (0.25, 0.25, 0.25)
        cfg.set_std(new_std)
        assert cfg.get_std() == (0.25, 0.25, 0.25)

        cfg.set_normalize(False)
        assert cfg.get_normalize() is False

        # Compare against a fresh list literal, not the object that was passed in
        cfg.set_ops(["resize:64", "crop:32"])
        assert cfg.get_ops() == ["resize:64", "crop:32"]

        cfg.set_dataset("custom")
        assert cfg.get_dataset() == "custom"