import numpy as np
import pytest

from .utils import FakeDataset, build_fake_dataset

# Number of samples requested (``nums``) for every fake dataset built here.
CIFAR_TESTSET_NUM = 1000


def build_cifar_fake_dataset(dataset_type: str, **kwargs) -> FakeDataset:
    """Build a fake dataset that stands in for a CIFAR dataset variant.

    The ``CIFAR10``/``CIFAR100`` suffix of ``dataset_type`` is replaced by
    ``FakeDataset`` to obtain the type name passed to ``build_fake_dataset``,
    and it also selects the label range: ``(0, 99)`` for names ending in
    ``CIFAR100`` and ``(0, 9)`` otherwise.

    Args:
        dataset_type: Dataset type name, e.g. ``'IndexCIFAR10'``.
        **kwargs: Extra config entries forwarded unchanged in the dataset
            config.

    Returns:
        The dataset produced by ``build_fake_dataset``.
    """
    # CIFAR100 must be checked first: 'CIFAR10' is a prefix of 'CIFAR100'.
    if dataset_type.endswith('CIFAR100'):
        y_range = (0, 99)
        dataset_type = dataset_type.replace('CIFAR100', 'FakeDataset')
    else:
        y_range = (0, 9)
        dataset_type = dataset_type.replace('CIFAR10', 'FakeDataset')

    dataset_cfg = dict(type=dataset_type,
                       x_shape=(3, 32, 32),
                       y_range=y_range,
                       nums=CIFAR_TESTSET_NUM,
                       **kwargs)
    return build_fake_dataset(dataset_cfg)


@pytest.mark.parametrize('dataset_type', ['CIFAR10', 'CIFAR100'])
def test_xy(dataset_type: str) -> None:
    """Check that ``get_xy``/``set_xy`` round-trip data and targets."""
    cifar = build_cifar_fake_dataset(dataset_type)
    xy = cifar.get_xy()
    x, y = xy
    assert len(x) == len(y)
    assert isinstance(y[0], int)

    # Writing back what was just read must leave the dataset unchanged.
    old_x = x.copy()
    old_y = y.copy()
    cifar.set_xy(xy)
    assert all(np.array_equal(nx, ox) for nx, ox in zip(cifar.data, old_x))
    assert cifar.targets == old_y

    # Setting a truncated subset must be reflected in data, targets and
    # the reported number of classes.
    x = x[:cifar.num_classes]
    y = y[:cifar.num_classes]
    cifar.set_xy((x, y))
    assert all(np.array_equal(nx, ox) for nx, ox in zip(cifar.data, x))
    assert cifar.targets == y
    assert cifar.num_classes == len(set(y))
    assert len(cifar.data.shape) == 4


@pytest.mark.parametrize(['start_idx', 'end_idx', 'dataset_type'],
                         [(0, 9, 'IndexCIFAR10'), (-10, 8, 'IndexCIFAR10'),
                          (2, 12, 'IndexCIFAR10'), (4, 4, 'IndexCIFAR10'),
                          (4, 3, 'IndexCIFAR10'), (0, 99, 'IndexCIFAR100'),
                          (-10, 8, 'IndexCIFAR100'),
                          (40, 50, 'IndexCIFAR100')])
def test_index(start_idx: int, end_idx: int, dataset_type: str) -> None:
    """Check label filtering by the inclusive ``[start_idx, end_idx]`` range.

    A range with ``start_idx > end_idx`` is expected to raise ``ValueError``.
    """
    kwargs = dict(start_idx=start_idx, end_idx=end_idx)
    if start_idx > end_idx:
        with pytest.raises(ValueError):
            _ = build_cifar_fake_dataset(dataset_type, **kwargs)
        return

    cifar = build_cifar_fake_dataset(dataset_type, **kwargs)
    assert cifar.start_idx == start_idx
    assert cifar.end_idx == end_idx
    for y in cifar.targets:
        assert start_idx <= y <= end_idx
    # The requested range is clipped to the labels that actually exist.
    assert cifar.num_classes == min(
        cifar.end_idx, cifar.raw_num_classes - 1) - max(cifar.start_idx,
                                                        0) + 1
    assert len(cifar.data.shape) == 4


@pytest.mark.parametrize(['ratio', 'dataset_type'],
                         [(-1, 'RatioCIFAR10'), (0, 'RatioCIFAR10'),
                          (0.1, 'RatioCIFAR10'), (0.5, 'RatioCIFAR10'),
                          (1, 'RatioCIFAR10'), (2, 'RatioCIFAR10'),
                          (0.4, 'RatioCIFAR100')])
def test_ratio(ratio: float, dataset_type: str) -> None:
    """Check sub-sampling by ``ratio``.

    Ratios outside ``(0, 1]`` are expected to raise ``ValueError``.
    """
    kwargs = dict(ratio=ratio)
    if ratio <= 0 or ratio > 1:
        with pytest.raises(ValueError):
            _ = build_cifar_fake_dataset(dataset_type, **kwargs)
        return

    cifar = build_cifar_fake_dataset(dataset_type, **kwargs)
    assert cifar.num_classes == cifar.raw_num_classes
    # Expected size: the per-class sample count is scaled by ``ratio`` and
    # truncated, then multiplied by the number of classes.
    assert len(cifar.targets) == \
        int(CIFAR_TESTSET_NUM / cifar.num_classes * ratio) * cifar.num_classes
    assert len(cifar.data.shape) == 4


@pytest.mark.parametrize('range_ratio', [(-1, 0.2), (0, 2), (0.1, 0.1),
                                         (0.5, 0.2), (0.1, 0.5), (0, 1)])
@pytest.mark.parametrize('dataset_type',
                         ['RangeRatioCIFAR10', 'RangeRatioCIFAR100'])
def test_range_ratio(range_ratio: tuple[float, float],
                     dataset_type: str) -> None:
    """Check selection of the ``[start, end)`` fraction of the dataset.

    Ranges that do not satisfy ``0 <= start < end <= 1`` are expected to
    raise ``ValueError``.
    """
    kwargs = dict(range_ratio=range_ratio)
    start_ratio, end_ratio = range_ratio
    if not (0 <= start_ratio < end_ratio <= 1):
        with pytest.raises(ValueError):
            _ = build_cifar_fake_dataset(dataset_type, **kwargs)
        return

    cifar = build_cifar_fake_dataset(dataset_type, **kwargs)
    assert cifar.num_classes == cifar.raw_num_classes
    assert len(cifar.targets) == \
        round(CIFAR_TESTSET_NUM * (end_ratio - start_ratio))
    assert len(cifar.data.shape) == 4


@pytest.mark.parametrize(['range_ratio1', 'range_ratio2'],
                         [((0, 0.5), (0.5, 1)), ((0, 0.6), (0.4, 1)),
                          ((0, 0.7), (0.3, 1)), ((0, 0.5), (0, 1))])
@pytest.mark.parametrize('dataset_type',
                         ['RangeRatioCIFAR10', 'RangeRatioCIFAR100'])
def test_range_ratio_intersection(range_ratio1: tuple[float, float],
                                  range_ratio2: tuple[float, float],
                                  dataset_type: str) -> None:
    """Check that two range-ratio datasets share the expected sample count.

    The number of duplicated samples across both datasets must match the
    overlap of the two ratio ranges.
    """
    cifar1 = build_cifar_fake_dataset(dataset_type=dataset_type,
                                      range_ratio=range_ratio1,
                                      cache_xy=True)
    cifar2 = build_cifar_fake_dataset(dataset_type=dataset_type,
                                      range_ratio=range_ratio2,
                                      cache_xy=True)

    # Samples present in both datasets collapse under ``np.unique``, so the
    # size difference is the number of shared samples.
    cat_x = np.concatenate([cifar1.data, cifar2.data], axis=0)
    unique_x = np.unique(cat_x, axis=0)
    intersection_number = cat_x.shape[0] - unique_x.shape[0]

    # General interval overlap: valid regardless of how the two ranges are
    # ordered or nested.
    overlap_start = max(range_ratio1[0], range_ratio2[0])
    overlap_end = min(range_ratio1[1], range_ratio2[1])
    intersection_ratio = max(0, overlap_end - overlap_start)
    assert round(intersection_ratio * CIFAR_TESTSET_NUM) == intersection_number


@pytest.mark.parametrize('dataset_type',
                         ['IndexRatioCIFAR10', 'IndexRatioCIFAR100'])
@pytest.mark.parametrize(['start_idx', 'end_idx', 'ratio'], [(4, 3, 0.5),
                                                             (3, 4, 0),
                                                             (3, 4, 2),
                                                             (1, 4, 0.1)])
def test_index_ratio(start_idx: int, end_idx: int, ratio: float,
                     dataset_type: str) -> None:
    """Check label-range filtering combined with ratio sub-sampling.

    An invalid ratio (outside ``(0, 1]``) or ``start_idx > end_idx`` is
    expected to raise ``ValueError``.
    """
    kwargs = dict(start_idx=start_idx, end_idx=end_idx, ratio=ratio)
    if ratio <= 0 or ratio > 1 or start_idx > end_idx:
        with pytest.raises(ValueError):
            _ = build_cifar_fake_dataset(dataset_type, **kwargs)
        return

    cifar = build_cifar_fake_dataset(dataset_type, **kwargs)
    assert cifar.start_idx == start_idx
    assert cifar.end_idx == end_idx
    for y in cifar.targets:
        assert start_idx <= y <= end_idx
    # The expected size is a float product; compare approximately so that
    # floating-point rounding cannot cause a spurious failure.
    expected_len = (cifar.num_classes / cifar.raw_num_classes *
                    CIFAR_TESTSET_NUM * ratio)
    assert len(cifar.targets) == pytest.approx(expected_len)
    assert len(cifar.data.shape) == 4


@pytest.mark.parametrize('dataset_type',
                         ['PoisonLabelCIFAR10', 'PoisonLabelCIFAR100'])
@pytest.mark.parametrize('poison_label', [-1, 5, 101])
def test_poison_label(poison_label: int, dataset_type: str) -> None:
    """Check that every target is replaced by ``poison_label``.

    Labels outside ``[0, 100)`` are expected to raise ``ValueError``.
    """
    kwargs = dict(poison_label=poison_label)
    if poison_label < 0 or poison_label >= 100:
        with pytest.raises(ValueError):
            _ = build_cifar_fake_dataset(dataset_type, **kwargs)
        return

    cifar = build_cifar_fake_dataset(dataset_type, **kwargs)
    assert cifar.poison_label == poison_label
    assert cifar.num_classes == 1
    assert all(target == poison_label for target in cifar.targets)
    assert len(cifar.data.shape) == 4


@pytest.mark.parametrize(
    'dataset_type', ['RatioPoisonLabelCIFAR10', 'RatioPoisonLabelCIFAR100'])
@pytest.mark.parametrize('poison_label', [-1, 5, 101])
@pytest.mark.parametrize('ratio', [0, 0.2, 1, 1.2])
def test_ratio_poison_label(ratio: float, poison_label: int,
                            dataset_type: str) -> None:
    """Check ratio sub-sampling combined with label poisoning.

    An invalid label (outside ``[0, 100)``) or ratio (outside ``(0, 1]``)
    is expected to raise ``ValueError``.
    """
    kwargs = dict(ratio=ratio, poison_label=poison_label)
    if (poison_label < 0 or poison_label >= 100) or \
            (ratio <= 0 or ratio > 1):
        with pytest.raises(ValueError):
            _ = build_cifar_fake_dataset(dataset_type, **kwargs)
        return

    cifar = build_cifar_fake_dataset(dataset_type, **kwargs)
    assert cifar.poison_label == poison_label
    assert len(cifar) == round(CIFAR_TESTSET_NUM * ratio)
    assert cifar.num_classes == 1
    assert all(target == poison_label for target in cifar.targets)
    assert len(cifar.data.shape) == 4


@pytest.mark.parametrize(
    'dataset_type',
    ['RangeRatioPoisonLabelCIFAR10', 'RangeRatioPoisonLabelCIFAR100'])
@pytest.mark.parametrize('poison_label', [-1, 1, 101])
@pytest.mark.parametrize('range_ratio', [(0, 0.2), (0.2, 0.5), (0.5, 1),
                                         (0.5, 0.2)])
def test_range_ratio_poison_label(range_ratio: tuple[float, float],
                                  poison_label: int,
                                  dataset_type: str) -> None:
    """Check range-ratio selection combined with label poisoning.

    An invalid label (outside ``[0, 100)``) or a range that does not satisfy
    ``0 <= start < end <= 1`` is expected to raise ``ValueError``.
    """
    kwargs = dict(range_ratio=range_ratio, poison_label=poison_label)
    if poison_label < 0 or poison_label >= 100 or \
            not (0 <= range_ratio[0] < range_ratio[1] <= 1):
        with pytest.raises(ValueError):
            _ = build_cifar_fake_dataset(dataset_type, **kwargs)
        return

    cifar = build_cifar_fake_dataset(dataset_type, **kwargs)
    assert cifar.poison_label == poison_label
    assert len(cifar) == round(CIFAR_TESTSET_NUM *
                               (range_ratio[1] - range_ratio[0]))
    assert cifar.num_classes == 1
    assert all(target == poison_label for target in cifar.targets)
    assert len(cifar.data.shape) == 4