# Mithridatium / tests / test_utils_configs.py
# Gustavo Lucca
# Implemented ImageNet and normalized data transforms, plus testing scripts
# 78c140d
"""
Test canonical dataset configurations in utils.py.
This module tests that:
1. DATASET_CONFIGS contains correct canonical values for supported datasets
2. get_preprocess_config() returns proper PreprocessConfig objects
3. Unsupported datasets raise appropriate errors
4. Configuration values match published literature standards
"""
import pytest
from mithridatium.utils import get_preprocess_config, DATASET_CONFIGS, PreprocessConfig
class TestCanonicalConfigs:
    """Test canonical dataset configuration values.

    Every test reads ``DATASET_CONFIGS`` and/or calls
    ``get_preprocess_config`` and compares the result against constants
    written out in this file.
    """

    @staticmethod
    def _assert_canonical_stats(dataset, expected_size, expected_mean, expected_std):
        """Assert that the mapping entry and the config object both match.

        Not collected by pytest (name does not start with ``test``); shared by
        the per-dataset tests below so the checks stay identical across them.

        Args:
            dataset: Lower-case key into ``DATASET_CONFIGS``; also the value
                the returned config is expected to report as its dataset.
            expected_size: Expected ``input_size`` tuple.
            expected_mean: Expected per-channel mean tuple.
            expected_std: Expected per-channel std tuple.
        """
        # Check DATASET_CONFIGS mapping
        config_data = DATASET_CONFIGS[dataset]
        assert config_data["input_size"] == expected_size
        assert config_data["mean"] == expected_mean
        assert config_data["std"] == expected_std
        assert config_data["normalize"] is True

        # Check PreprocessConfig object
        config = get_preprocess_config(dataset)
        assert config.get_input_size() == expected_size
        assert config.get_mean() == expected_mean
        assert config.get_std() == expected_std
        assert config.get_normalize() is True
        assert config.get_dataset() == dataset

    def test_cifar10_canonical_stats(self):
        """Test CIFAR-10 has correct canonical normalization statistics."""
        # CIFAR-10 canonical values from literature
        self._assert_canonical_stats(
            "cifar10",
            expected_size=(3, 32, 32),
            expected_mean=(0.4914, 0.4822, 0.4465),
            expected_std=(0.2023, 0.1994, 0.2010),
        )

    def test_cifar100_canonical_stats(self):
        """Test CIFAR-100 has correct canonical normalization statistics."""
        # CIFAR-100 canonical values from literature
        self._assert_canonical_stats(
            "cifar100",
            expected_size=(3, 32, 32),
            expected_mean=(0.5071, 0.4867, 0.4408),
            expected_std=(0.2675, 0.2565, 0.2761),
        )

    def test_imagenet_canonical_stats(self):
        """Test ImageNet has correct canonical normalization statistics."""
        # ImageNet canonical values from torchvision/literature
        self._assert_canonical_stats(
            "imagenet",
            expected_size=(3, 224, 224),
            expected_mean=(0.485, 0.456, 0.406),
            expected_std=(0.229, 0.224, 0.225),
        )

    def test_case_insensitive_dataset_names(self):
        """Test that dataset names are case-insensitive."""
        # Canonical name -> spellings that must all resolve to it.
        spellings = {
            "cifar10": ["CIFAR10", "Cifar10", "cifar10", "CiFaR10"],
            "cifar100": ["CIFAR100", "Cifar100", "cifar100", "CiFaR100"],
            "imagenet": ["IMAGENET", "ImageNet", "imagenet", "ImAgEnEt"],
        }
        for canonical, variants in spellings.items():
            for dataset_name in variants:
                config = get_preprocess_config(dataset_name)
                assert config.get_dataset() == canonical

    def test_whitespace_handling(self):
        """Test that dataset names handle whitespace correctly."""
        # Test with leading/trailing whitespace
        config = get_preprocess_config(" cifar10 ")
        assert config.get_dataset() == "cifar10"

        config = get_preprocess_config("\tcifar100\n")
        assert config.get_dataset() == "cifar100"

    def test_unsupported_dataset_error(self):
        """Test that unsupported datasets raise ValueError with helpful message."""
        with pytest.raises(ValueError) as exc_info:
            get_preprocess_config("mnist")

        error_msg = str(exc_info.value)
        assert "mnist" in error_msg
        assert "Unsupported dataset" in error_msg
        assert "cifar10" in error_msg  # Should list supported datasets
        assert "cifar100" in error_msg
        assert "imagenet" in error_msg

    def test_preprocess_config_default_values(self):
        """Test that PreprocessConfig has correct default values."""
        for dataset in ["cifar10", "cifar100", "imagenet"]:
            config = get_preprocess_config(dataset)
            # Common defaults across all datasets
            assert config.get_channels_first() is True
            assert config.get_value_range() == (0.0, 1.0)
            assert config.get_normalize() is True
            assert config.get_ops() == []

    def test_all_supported_datasets_in_mapping(self):
        """Test that all datasets mentioned in error messages are in DATASET_CONFIGS."""
        # pytest.raises fails the test when no ValueError is raised; a plain
        # try/except would let that case pass without checking anything.
        with pytest.raises(ValueError) as exc_info:
            get_preprocess_config("invalid_dataset")
        error_msg = str(exc_info.value)

        # Extract supported datasets from error message
        # Message format: "Supported datasets: cifar10, cifar100, imagenet"
        marker = "Supported datasets:"
        if marker not in error_msg:
            # Report the un-checked case instead of passing silently.
            pytest.skip("error message does not contain 'Supported datasets:'")

        supported_part = error_msg.split(marker)[1].strip()
        mentioned_datasets = [ds.strip() for ds in supported_part.split(",")]

        # Verify all mentioned datasets exist in DATASET_CONFIGS
        for dataset in mentioned_datasets:
            assert dataset in DATASET_CONFIGS, f"Dataset {dataset} mentioned in error but not in DATASET_CONFIGS"
class TestDatasetConfigsCompleteness:
    """Check that every DATASET_CONFIGS entry is complete and well-formed."""

    def test_dataset_configs_structure(self):
        """Each entry has the required keys, with tuple/bool values of the right shape."""
        required_keys = {"input_size", "mean", "std", "normalize"}

        for dataset_name, config in DATASET_CONFIGS.items():
            # All required keys must be present before inspecting values.
            present_keys = config.keys()
            assert required_keys.issubset(present_keys), f"Missing keys in {dataset_name} config"

            # input_size: 3-tuple of positive ints, (C, H, W)
            input_size = config["input_size"]
            assert isinstance(input_size, tuple)
            assert len(input_size) == 3
            for dim in input_size:
                assert isinstance(dim, int) and dim > 0

            # mean: 3-tuple of floats within [0, 1], (R, G, B)
            mean = config["mean"]
            assert isinstance(mean, tuple)
            assert len(mean) == 3
            for value in mean:
                assert isinstance(value, float) and 0 <= value <= 1

            # std: 3-tuple of strictly positive floats, (R, G, B)
            std = config["std"]
            assert isinstance(std, tuple)
            assert len(std) == 3
            for value in std:
                assert isinstance(value, float) and value > 0

            normalize_flag = config["normalize"]
            assert isinstance(normalize_flag, bool)

    def test_cifar_datasets_have_32x32_size(self):
        """Both CIFAR entries declare a 3x32x32 input size."""
        cifar_shape = (3, 32, 32)
        for dataset in ("cifar10", "cifar100"):
            entry = DATASET_CONFIGS[dataset]
            assert entry["input_size"] == cifar_shape, f"{dataset} should be 3x32x32"

    def test_imagenet_has_224x224_size(self):
        """The ImageNet entry declares a 3x224x224 input size."""
        imagenet_shape = (3, 224, 224)
        entry = DATASET_CONFIGS["imagenet"]
        assert entry["input_size"] == imagenet_shape, "ImageNet should be 3x224x224"

    def test_normalization_stats_reasonable_ranges(self):
        """Per-channel mean lies in [0, 1] and per-channel std in [0.05, 0.5]."""
        for dataset_name, config in DATASET_CONFIGS.items():
            # Mean of image data scaled to [0, 1] must itself lie in [0, 1].
            channel_means = config["mean"]
            for channel_mean in channel_means:
                assert 0.0 <= channel_mean <= 1.0, f"{dataset_name} mean {channel_mean} out of range [0,1]"

            # Std must be positive and within the band this test treats as reasonable.
            channel_stds = config["std"]
            for channel_std in channel_stds:
                assert 0.05 <= channel_std <= 0.5, f"{dataset_name} std {channel_std} out of reasonable range [0.05,0.5]"
class TestPreprocessConfigMethods:
    """Exercise the getter and setter methods of the config returned for CIFAR-10."""

    def test_preprocess_config_getters(self):
        """Every getter reports the value expected for a fresh "cifar10" config."""
        cfg = get_preprocess_config("cifar10")

        # (getter name, expected value), checked in this order.
        expectations = [
            ("get_input_size", (3, 32, 32)),
            ("get_channels_first", True),
            ("get_value_range", (0.0, 1.0)),
            ("get_mean", (0.4914, 0.4822, 0.4465)),
            ("get_std", (0.2023, 0.1994, 0.2010)),
            ("get_normalize", True),
            ("get_ops", []),
            ("get_dataset", "cifar10"),
        ]
        for getter_name, expected in expectations:
            observed = getattr(cfg, getter_name)()
            if isinstance(expected, bool):
                # Flags must be the actual bool singleton, not merely truthy.
                assert observed is expected
            else:
                assert observed == expected

    def test_preprocess_config_setters(self):
        """Each setter stores a value that the matching getter then returns."""
        cfg = get_preprocess_config("cifar10")

        # (field suffix, replacement value), applied in this order.
        replacements = [
            ("input_size", (3, 64, 64)),
            ("channels_first", False),
            ("value_range", (-1.0, 1.0)),
            ("mean", (0.5, 0.5, 0.5)),
            ("std", (0.25, 0.25, 0.25)),
            ("normalize", False),
            ("ops", ["resize:64", "crop:32"]),
            ("dataset", "custom"),
        ]
        for field, new_value in replacements:
            # Lists are handed over as a copy so the expectation below is a
            # separate object from whatever the config ends up holding.
            argument = list(new_value) if isinstance(new_value, list) else new_value
            getattr(cfg, f"set_{field}")(argument)

            observed = getattr(cfg, f"get_{field}")()
            if isinstance(new_value, bool):
                # Flags must be the actual bool singleton, not merely falsy.
                assert observed is new_value
            else:
                assert observed == new_value