Ttius's picture
Upload 192 files
998bb30 verified
import numpy as np
import pytest
from .utils import FakeDataset, build_fake_dataset
# Sample count passed as ``nums`` to every fake dataset config built by
# ``build_cifar_fake_dataset``; the size assertions in the tests below are
# all expressed relative to this value.
CIFAR_TESTSET_NUM = 1000
def build_cifar_fake_dataset(dataset_type: str, **kwargs) -> FakeDataset:
    """Build a fake dataset standing in for a CIFAR10/CIFAR100 variant.

    The ``CIFAR10``/``CIFAR100`` part of ``dataset_type`` is replaced by
    ``FakeDataset`` and the label range is picked to match the name:
    ``(0, 99)`` when the name ends with ``CIFAR100``, ``(0, 9)`` otherwise.

    Args:
        dataset_type: Dataset type name containing ``CIFAR10`` or
            ``CIFAR100`` (e.g. ``'RatioCIFAR10'``).
        **kwargs: Extra entries forwarded unchanged into the dataset config.

    Returns:
        The object produced by ``build_fake_dataset`` for the assembled
        config.
    """
    is_cifar100 = dataset_type.endswith('CIFAR100')
    cifar_name = 'CIFAR100' if is_cifar100 else 'CIFAR10'
    label_range = (0, 99) if is_cifar100 else (0, 9)
    fake_type = dataset_type.replace(cifar_name, 'FakeDataset')

    # ``dict(...)`` (rather than a literal) keeps the TypeError raised when a
    # caller passes a keyword that collides with one of the fixed entries.
    config = dict(
        type=fake_type,
        x_shape=(3, 32, 32),
        y_range=label_range,
        nums=CIFAR_TESTSET_NUM,
        **kwargs,
    )
    return build_fake_dataset(config)
@pytest.mark.parametrize('dataset_type', ['CIFAR10', 'CIFAR100'])
def test_xy(dataset_type: str) -> None:
    """Check that ``get_xy``/``set_xy`` round-trip samples and labels.

    First the full ``(x, y)`` pair is written back unchanged, then a prefix
    of ``num_classes`` items is written; after each write ``data`` and
    ``targets`` must match what was passed in.
    """
    dataset = build_cifar_fake_dataset(dataset_type)
    pair = dataset.get_xy()
    samples, labels = pair

    assert len(samples) == len(labels)
    assert isinstance(labels[0], int)

    # Snapshot before writing back, so the comparison below does not rely
    # on objects that ``set_xy`` might have touched.
    samples_before = samples.copy()
    labels_before = labels.copy()

    dataset.set_xy(pair)
    assert all(
        np.array_equal(new, old)
        for new, old in zip(dataset.data, samples_before))
    assert dataset.targets == labels_before

    # Shrink to the first ``num_classes`` items and write that subset back.
    head = slice(dataset.num_classes)
    samples, labels = samples[head], labels[head]
    dataset.set_xy((samples, labels))

    assert all(
        np.array_equal(new, old)
        for new, old in zip(dataset.data, samples))
    assert dataset.targets == labels
    # The class count must reflect the labels that are actually present.
    assert dataset.num_classes == len(set(labels))
    assert len(dataset.data.shape) == 4
@pytest.mark.parametrize(
    ['start_idx', 'end_idx', 'dataset_type'],
    [
        (0, 9, 'IndexCIFAR10'),
        (-10, 8, 'IndexCIFAR10'),
        (2, 12, 'IndexCIFAR10'),
        (4, 4, 'IndexCIFAR10'),
        (4, 3, 'IndexCIFAR10'),
        (0, 99, 'IndexCIFAR100'),
        (-10, 8, 'IndexCIFAR100'),
        (40, 50, 'IndexCIFAR100'),
    ])
def test_index(start_idx: int, end_idx: int, dataset_type: str) -> None:
    """Check label-index filtering on the ``Index*`` dataset variants.

    A ``start_idx`` greater than ``end_idx`` must raise ``ValueError``.
    Otherwise every target must lie within ``[start_idx, end_idx]`` and the
    class count must equal the size of that range after clamping it to
    ``[0, raw_num_classes - 1]``.
    """

    def make_dataset():
        return build_cifar_fake_dataset(
            dataset_type, start_idx=start_idx, end_idx=end_idx)

    if start_idx > end_idx:
        with pytest.raises(ValueError):
            make_dataset()
        return

    dataset = make_dataset()
    assert dataset.start_idx == start_idx
    assert dataset.end_idx == end_idx
    assert all(start_idx <= label <= end_idx for label in dataset.targets)

    # Bounds outside the raw label range only count for the part that
    # overlaps it.
    highest = min(dataset.end_idx, dataset.raw_num_classes - 1)
    lowest = max(dataset.start_idx, 0)
    assert dataset.num_classes == highest - lowest + 1
    assert len(dataset.data.shape) == 4
@pytest.mark.parametrize(
    ['ratio', 'dataset_type'],
    [
        (-1, 'RatioCIFAR10'),
        (0, 'RatioCIFAR10'),
        (0.1, 'RatioCIFAR10'),
        (0.5, 'RatioCIFAR10'),
        (1, 'RatioCIFAR10'),
        (2, 'RatioCIFAR10'),
        (0.4, 'RatioCIFAR100'),
    ])
def test_ratio(ratio: float, dataset_type: str) -> None:
    """Check sub-sampling by ``ratio`` on the ``Ratio*`` dataset variants.

    Ratios outside ``(0, 1]`` must raise ``ValueError``. Valid ratios keep
    all classes and leave ``int(per-class count * ratio)`` samples for each
    class.
    """

    def make_dataset():
        return build_cifar_fake_dataset(dataset_type, ratio=ratio)

    ratio_is_valid = 0 < ratio <= 1
    if not ratio_is_valid:
        with pytest.raises(ValueError):
            make_dataset()
        return

    dataset = make_dataset()
    assert dataset.num_classes == dataset.raw_num_classes

    # Truncation is applied per class, then multiplied back up.
    kept_per_class = int(CIFAR_TESTSET_NUM / dataset.num_classes * ratio)
    assert len(dataset.targets) == kept_per_class * dataset.num_classes
    assert len(dataset.data.shape) == 4
@pytest.mark.parametrize(
    'range_ratio',
    [(-1, 0.2), (0, 2), (0.1, 0.1), (0.5, 0.2), (0.1, 0.5), (0, 1)])
@pytest.mark.parametrize(
    'dataset_type', ['RangeRatioCIFAR10', 'RangeRatioCIFAR100'])
def test_range_ratio(range_ratio: tuple[float, float],
                     dataset_type: str) -> None:
    """Check slicing by ``range_ratio`` on the ``RangeRatio*`` variants.

    A range is valid only when ``0 <= start < end <= 1``; anything else must
    raise ``ValueError``. A valid range keeps all classes and
    ``round(total * (end - start))`` samples.
    """

    def make_dataset():
        return build_cifar_fake_dataset(dataset_type, range_ratio=range_ratio)

    start_ratio, end_ratio = range_ratio
    range_is_valid = 0 <= start_ratio < end_ratio <= 1
    if not range_is_valid:
        with pytest.raises(ValueError):
            make_dataset()
        return

    dataset = make_dataset()
    assert dataset.num_classes == dataset.raw_num_classes

    expected_size = round(CIFAR_TESTSET_NUM * (end_ratio - start_ratio))
    assert len(dataset.targets) == expected_size
    assert len(dataset.data.shape) == 4
@pytest.mark.parametrize(
    ['range_ratio1', 'range_ratio2'],
    [
        ((0, 0.5), (0.5, 1)),
        ((0, 0.6), (0.4, 1)),
        ((0, 0.7), (0.3, 1)),
        ((0, 0.5), (0, 1)),
    ])
@pytest.mark.parametrize(
    'dataset_type', ['RangeRatioCIFAR10', 'RangeRatioCIFAR100'])
def test_range_ratio_intersection(range_ratio1: tuple[float, float],
                                  range_ratio2: tuple[float, float],
                                  dataset_type: str) -> None:
    """Check that overlapping ratio ranges share the expected sample count.

    Two datasets are built with ``cache_xy=True`` and different ranges; the
    number of duplicated rows across their ``data`` must equal the overlap
    of the two ranges scaled by the total sample count.

    NOTE(review): this presumably relies on ``cache_xy=True`` making both
    datasets draw from the same underlying samples -- confirm against the
    fake dataset implementation.
    """
    first, second = (
        build_cifar_fake_dataset(
            dataset_type=dataset_type, range_ratio=ratio_range, cache_xy=True)
        for ratio_range in (range_ratio1, range_ratio2))

    stacked = np.concatenate([first.data, second.data], axis=0)
    distinct = np.unique(stacked, axis=0)
    # Every row lost to de-duplication is a sample present in both datasets.
    shared_count = len(stacked) - len(distinct)

    overlap_ratio = max(0, range_ratio1[1] - range_ratio2[0])
    assert round(overlap_ratio * CIFAR_TESTSET_NUM) == shared_count
@pytest.mark.parametrize('dataset_type',
                         ['IndexRatioCIFAR10', 'IndexRatioCIFAR100'])
@pytest.mark.parametrize(['start_idx', 'end_idx', 'ratio'], [(4, 3, 0.5),
                                                             (3, 4, 0),
                                                             (3, 4, 2),
                                                             (1, 4, 0.1)])
def test_index_ratio(start_idx: int, end_idx: int, ratio: float,
                     dataset_type: str) -> None:
    """Check combined index filtering and ratio sub-sampling.

    A ``ratio`` outside ``(0, 1]`` or a ``start_idx`` greater than
    ``end_idx`` must raise ``ValueError``. Otherwise all targets must lie in
    ``[start_idx, end_idx]`` and the number of samples must equal the
    selected classes' share of the full set scaled by ``ratio``.
    """
    kwargs = dict(start_idx=start_idx, end_idx=end_idx, ratio=ratio)
    if ratio <= 0 or ratio > 1 or start_idx > end_idx:
        with pytest.raises(ValueError):
            _ = build_cifar_fake_dataset(dataset_type, **kwargs)
        return

    cifar = build_cifar_fake_dataset(dataset_type, **kwargs)
    assert cifar.start_idx == start_idx
    assert cifar.end_idx == end_idx
    for y in cifar.targets:
        assert start_idx <= y <= end_idx

    # The expected size is a product of floats, so comparing it to the
    # integer length with exact ``==`` depends on rounding luck (e.g.
    # ``0.1 * 3`` is not ``0.3``). ``pytest.approx`` tolerates only
    # floating-point noise, so a genuinely wrong size still fails.
    expected_size = (cifar.num_classes / cifar.raw_num_classes *
                     CIFAR_TESTSET_NUM * ratio)
    assert len(cifar.targets) == pytest.approx(expected_size)
    assert len(cifar.data.shape) == 4
@pytest.mark.parametrize(
    'dataset_type', ['PoisonLabelCIFAR10', 'PoisonLabelCIFAR100'])
@pytest.mark.parametrize('poison_label', [-1, 5, 101])
def test_poison_label(poison_label: int, dataset_type: str) -> None:
    """Check that ``PoisonLabel*`` variants relabel every sample.

    A ``poison_label`` outside ``[0, 100)`` must raise ``ValueError``; this
    test applies the same bound to both the CIFAR10 and CIFAR100 variants.
    With a valid label, every target equals it and one class remains.
    """

    def make_dataset():
        return build_cifar_fake_dataset(
            dataset_type, poison_label=poison_label)

    label_is_valid = 0 <= poison_label < 100
    if not label_is_valid:
        with pytest.raises(ValueError):
            make_dataset()
        return

    dataset = make_dataset()
    assert dataset.poison_label == poison_label
    assert dataset.num_classes == 1
    assert all(label == poison_label for label in dataset.targets)
    assert len(dataset.data.shape) == 4
@pytest.mark.parametrize(
    'dataset_type', ['RatioPoisonLabelCIFAR10', 'RatioPoisonLabelCIFAR100'])
@pytest.mark.parametrize('poison_label', [-1, 5, 101])
@pytest.mark.parametrize('ratio', [0, 0.2, 1, 1.2])
def test_ratio_poison_label(ratio: float, poison_label: int,
                            dataset_type: str) -> None:
    """Check ratio sub-sampling combined with label poisoning.

    ``ValueError`` is expected when ``poison_label`` is outside ``[0, 100)``
    or ``ratio`` is outside ``(0, 1]``. Otherwise the dataset holds
    ``round(total * ratio)`` samples, all carrying ``poison_label``.
    """

    def make_dataset():
        return build_cifar_fake_dataset(
            dataset_type, ratio=ratio, poison_label=poison_label)

    label_is_valid = 0 <= poison_label < 100
    ratio_is_valid = 0 < ratio <= 1
    if not (label_is_valid and ratio_is_valid):
        with pytest.raises(ValueError):
            make_dataset()
        return

    dataset = make_dataset()
    assert dataset.poison_label == poison_label
    assert len(dataset) == round(CIFAR_TESTSET_NUM * ratio)
    assert dataset.num_classes == 1
    assert all(label == poison_label for label in dataset.targets)
    assert len(dataset.data.shape) == 4
@pytest.mark.parametrize(
    'dataset_type',
    ['RangeRatioPoisonLabelCIFAR10', 'RangeRatioPoisonLabelCIFAR100'])
@pytest.mark.parametrize('poison_label', [-1, 1, 101])
@pytest.mark.parametrize(
    'range_ratio', [(0, 0.2), (0.2, 0.5), (0.5, 1), (0.5, 0.2)])
def test_range_ratio_poison_label(range_ratio: tuple[float, float],
                                  poison_label: int,
                                  dataset_type: str) -> None:
    """Check ratio-range slicing combined with label poisoning.

    ``ValueError`` is expected when ``poison_label`` is outside ``[0, 100)``
    or the range does not satisfy ``0 <= start < end <= 1``. Otherwise the
    dataset holds ``round(total * (end - start))`` samples, all carrying
    ``poison_label``.
    """

    def make_dataset():
        return build_cifar_fake_dataset(
            dataset_type, range_ratio=range_ratio, poison_label=poison_label)

    start_ratio, end_ratio = range_ratio
    label_is_valid = 0 <= poison_label < 100
    range_is_valid = 0 <= start_ratio < end_ratio <= 1
    if not (label_is_valid and range_is_valid):
        with pytest.raises(ValueError):
            make_dataset()
        return

    dataset = make_dataset()
    assert dataset.poison_label == poison_label
    expected_size = round(CIFAR_TESTSET_NUM * (end_ratio - start_ratio))
    assert len(dataset) == expected_size
    assert dataset.num_classes == 1
    assert all(label == poison_label for label in dataset.targets)
    assert len(dataset.data.shape) == 4