Mithridatium / tests /test_preprocess_config.py
Pelumi Oluwategbe
updated test preprocess config and test strip scores
44c8400
raw
history blame contribute delete
833 Bytes
import pytest
from mithridatium.utils import get_preprocess_config
def test_get_preprocess_config():
    """Verify the preprocessing config returned for the ``cifar10`` dataset.

    Looks up the config through ``get_preprocess_config`` and compares each
    field against the values this project expects for CIFAR-10.
    """
    cfg = get_preprocess_config("cifar10")

    # Input shape: 3 channels, 32x32 pixels.
    assert cfg.input_size == (3, 32, 32)

    # Identity check on purpose: the flag must be the bool ``True`` itself,
    # not just any truthy value.
    assert cfg.channels_first is True

    # The remaining fields are plain equality checks, listed in the order
    # they are verified.
    expected_by_field = (
        ("value_range", (0.0, 1.0)),            # normalization range
        ("mean", (0.4914, 0.4822, 0.4465)),     # per-channel dataset mean
        ("std", (0.2023, 0.1994, 0.2010)),      # per-channel dataset std
        ("ops", []),                            # no extra operations expected
    )
    for field_name, expected in expected_by_field:
        assert getattr(cfg, field_name) == expected