File size: 833 Bytes
b80226e
44c8400
4624e77
44c8400
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import pytest
from mithridatium.utils import get_preprocess_config

def test_get_preprocess_config():
    """Check the preprocessing config returned for the CIFAR-10 dataset.

    Pins the expected input shape, channel layout, input value range,
    per-channel normalization statistics, and extra preprocessing ops
    returned by ``get_preprocess_config("cifar10")``.
    """
    # Use a known dataset for the test.
    dataset_name = "cifar10"

    config = get_preprocess_config(dataset_name)

    # CIFAR-10 images are 32x32 with 3 color channels, listed channel-first
    # (C, H, W), so channels_first must be set.
    assert config.input_size == (3, 32, 32)
    assert config.channels_first is True
    # NOTE(review): presumably the expected range of raw input pixel values
    # (before mean/std normalization) -- confirm against get_preprocess_config.
    assert config.value_range == pytest.approx((0.0, 1.0))
    # Per-channel mean/std statistics. Compare floats with a tolerance rather
    # than exact ``==`` so the test does not depend on the exact float
    # representation, or on whether the config stores a tuple or a list.
    assert config.mean == pytest.approx((0.4914, 0.4822, 0.4465))
    assert config.std == pytest.approx((0.2023, 0.1994, 0.2010))
    # No additional preprocessing operations are expected for CIFAR-10.
    assert config.ops == []