|
|
import os
|
|
|
from pathlib import Path
|
|
|
from typing import Callable, Dict, Optional, Sequence, Set, Tuple
|
|
|
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
import torch.utils.data as data
|
|
|
import torchvision.datasets as datasets
|
|
|
import torchvision.transforms as transforms
|
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
from robustbench.model_zoo.enums import BenchmarkDataset
|
|
|
from robustbench.zenodo_download import DownloadError, zenodo_download
|
|
|
from robustbench.loaders import CustomImageFolder, CustomCifarDataset
|
|
|
|
|
|
|
|
|
# Named test-time preprocessing pipelines. The loaders in this module pick one
# of them through their `prepr` argument (the dictionary key).
PREPROCESSINGS = {
    # Resize(256) -> CenterCrop(224) -> tensor.
    'Res256Crop224': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]),
    # CenterCrop(288) -> tensor.
    'Crop288': transforms.Compose([
        transforms.CenterCrop(288),
        transforms.ToTensor(),
    ]),
    # Only the tensor conversion, no resizing or cropping.
    'none': transforms.Compose([
        transforms.ToTensor(),
    ]),
}
|
|
|
|
|
|
|
|
|
def _load_dataset(
        dataset: Dataset,
        n_examples: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Collect the leading samples of ``dataset`` into two tensors.

    The dataset is read sequentially (no shuffling, no worker processes) in
    batches of 100 samples, each of which must unpack into an ``(x, y)`` pair.

    Args:
        dataset: Map-style dataset to read from.
        n_examples: Number of samples to return. ``None`` returns the whole
            dataset.

    Returns:
        ``(x, y)`` with the inputs and the labels concatenated along the
        first dimension and truncated to ``n_examples`` when it is given.
    """
    batch_size = 100
    test_loader = data.DataLoader(dataset,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=0)

    x_test, y_test = [], []
    for i, (x, y) in enumerate(test_loader):
        x_test.append(x)
        y_test.append(y)
        # At this point (i + 1) batches have been collected, so stop as soon
        # as they cover the request; comparing with `batch_size * i` would
        # always read (and later throw away) one batch too many.
        if n_examples is not None and batch_size * (i + 1) >= n_examples:
            break
    x_test_tensor = torch.cat(x_test)
    y_test_tensor = torch.cat(y_test)

    if n_examples is not None:
        # The last collected batch may overshoot the requested count.
        x_test_tensor = x_test_tensor[:n_examples]
        y_test_tensor = y_test_tensor[:n_examples]

    return x_test_tensor, y_test_tensor
|
|
|
|
|
|
|
|
|
def load_cifar10(
        n_examples: Optional[int] = None,
        data_dir: str = './data',
        prepr: Optional[str] = 'none') -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the leading samples of the CIFAR-10 test split as tensors.

    Args:
        n_examples: Number of samples to return; ``None`` returns all of them.
        data_dir: Root directory handed to torchvision (``download=True``).
        prepr: Key into ``PREPROCESSINGS`` selecting the transform pipeline;
            an unknown key raises ``KeyError``.

    Returns:
        The ``(x, y)`` tensors produced by ``_load_dataset``.
    """
    cifar10_test = datasets.CIFAR10(
        root=data_dir,
        train=False,
        transform=PREPROCESSINGS[prepr],
        download=True,
    )
    return _load_dataset(cifar10_test, n_examples)
|
|
|
|
|
|
|
|
|
def load_cifar100(
        n_examples: Optional[int] = None,
        data_dir: str = './data',
        prepr: Optional[str] = 'none') -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the leading samples of the CIFAR-100 test split as tensors.

    Args:
        n_examples: Number of samples to return; ``None`` returns all of them.
        data_dir: Root directory handed to torchvision (``download=True``).
        prepr: Key into ``PREPROCESSINGS`` selecting the transform pipeline;
            an unknown key raises ``KeyError``.

    Returns:
        The ``(x, y)`` tensors produced by ``_load_dataset``.
    """
    cifar100_test = datasets.CIFAR100(
        root=data_dir,
        train=False,
        transform=PREPROCESSINGS[prepr],
        download=True,
    )
    return _load_dataset(cifar100_test, n_examples)
|
|
|
|
|
|
|
|
|
def load_imagenet(
        n_examples: Optional[int] = 5000,
        data_dir: str = './data',
        prepr: str = 'Res256Crop224') -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the first ``n_examples`` images found under ``<data_dir>/val``.

    The images are read through ``CustomImageFolder`` and returned as one
    single batch, in dataset order (no shuffling).

    Args:
        n_examples: Number of samples to return. ``None`` returns the whole
            folder as one batch (this can be very memory hungry).
        data_dir: Directory that contains the ``val`` sub-folder.
        prepr: Key into ``PREPROCESSINGS`` selecting the transform pipeline;
            an unknown key raises ``KeyError``.

    Returns:
        ``(x, y)``: the batch of inputs and the batch of labels. Fewer than
        ``n_examples`` samples are returned if the folder is smaller.
    """
    transforms_test = PREPROCESSINGS[prepr]
    imagenet = CustomImageFolder(data_dir + '/val', transforms_test)

    # `batch_size=None` would turn off the DataLoader's automatic batching
    # and hand back one single un-batched sample, so "no limit" is spelled
    # out as "one batch as large as the dataset".
    batch_size = len(imagenet) if n_examples is None else n_examples
    test_loader = data.DataLoader(imagenet, batch_size=batch_size,
                                  shuffle=False, num_workers=4)

    # Batches are triples; the third element is not needed by callers.
    x_test, y_test, _ = next(iter(test_loader))

    return x_test, y_test
|
|
|
|
|
|
|
|
|
# Call signature shared by the clean-data loaders:
#     (n_examples, data_dir, prepr) -> (x, y)
# The preprocessing key is part of the signature because `load_clean_dataset`
# passes it positionally to whichever loader it looks up in the registry.
CleanDatasetLoader = Callable[[Optional[int], str, str],
                              Tuple[torch.Tensor, torch.Tensor]]

# Registry mapping each benchmark dataset to its clean-data loader.
_clean_dataset_loaders: Dict[BenchmarkDataset, CleanDatasetLoader] = {
    BenchmarkDataset.cifar_10: load_cifar10,
    BenchmarkDataset.cifar_100: load_cifar100,
    BenchmarkDataset.imagenet: load_imagenet,
}
|
|
|
|
|
|
|
|
|
def load_clean_dataset(dataset: BenchmarkDataset, n_examples: Optional[int],
                       data_dir: str, prepr: Optional[str] = 'none') -> Tuple[torch.Tensor, torch.Tensor]:
    """Load clean test data with the loader registered for ``dataset``.

    Args:
        dataset: Benchmark dataset whose loader should be used; a dataset
            without a registered loader raises ``KeyError``.
        n_examples: Forwarded to the loader as its first argument.
        data_dir: Forwarded to the loader as its second argument.
        prepr: Forwarded to the loader as its third argument.

    Returns:
        Whatever the selected loader returns, i.e. an ``(x, y)`` tensor pair.
    """
    loader = _clean_dataset_loaders[dataset]
    return loader(n_examples, data_dir, prepr)
|
|
|
|
|
|
|
|
|
# Corruption names used as the default `corruptions` argument of the
# corrupted-data loaders; each name is also the stem of a `<name>.npy` file
# (CIFAR loaders) or a directory name (ImageNet loader).
CORRUPTIONS = (
    "shot_noise",
    "motion_blur",
    "snow",
    "pixelate",
    "gaussian_noise",
    "defocus_blur",
    "brightness",
    "fog",
    "zoom_blur",
    "frost",
    "glass_blur",
    "impulse_noise",
    "contrast",
    "jpeg_compression",
    "elastic_transform",
)
|
|
|
|
|
|
# Positional arguments unpacked into `zenodo_download` for each dataset:
# an identifier string and the set of archive file names.
# NOTE(review): the identifier is presumably the Zenodo record id - confirm
# against `zenodo_download`. There is no entry for ImageNet.
ZENODO_CORRUPTIONS_LINKS: Dict[BenchmarkDataset, Tuple[str, Set[str]]] = {
    BenchmarkDataset.cifar_10: (
        "2535967",
        {"CIFAR-10-C.tar"},
    ),
    BenchmarkDataset.cifar_100: (
        "3555552",
        {"CIFAR-100-C.tar"},
    ),
}
|
|
|
|
|
|
# Name of the sub-directory of `data_dir` that holds the corrupted version of
# each benchmark dataset.
CORRUPTIONS_DIR_NAMES: Dict[BenchmarkDataset, str] = {
    BenchmarkDataset.cifar_10: "CIFAR-10-C",
    BenchmarkDataset.cifar_100: "CIFAR-100-C",
    BenchmarkDataset.imagenet: "ImageNet-C",
}
|
|
|
|
|
|
|
|
|
def load_cifar10c(
        n_examples: int = 10000,
        severity: int = 5,
        data_dir: str = './data',
        shuffle: bool = False,
        corruptions: Sequence[str] = CORRUPTIONS,
        prepr: Optional[str] = 'none'
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load corrupted CIFAR-10 data via ``load_corruptions_cifar``.

    Args:
        n_examples: Total number of samples to return.
        severity: Corruption severity level.
        data_dir: Directory that holds (or will receive) the data.
        shuffle: Whether the selected samples are shuffled.
        corruptions: Names of the corruptions to draw samples from.
        prepr: Accepted but not used by this loader.

    Returns:
        The ``(x, y)`` tensors produced by ``load_corruptions_cifar``.
    """
    return load_corruptions_cifar(
        dataset=BenchmarkDataset.cifar_10,
        n_examples=n_examples,
        severity=severity,
        data_dir=data_dir,
        corruptions=corruptions,
        shuffle=shuffle,
    )
|
|
|
|
|
|
|
|
|
def load_cifar100c(
        n_examples: int = 10000,
        severity: int = 5,
        data_dir: str = './data',
        shuffle: bool = False,
        corruptions: Sequence[str] = CORRUPTIONS,
        prepr: Optional[str] = 'none'
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load corrupted CIFAR-100 data via ``load_corruptions_cifar``.

    Args:
        n_examples: Total number of samples to return.
        severity: Corruption severity level.
        data_dir: Directory that holds (or will receive) the data.
        shuffle: Whether the selected samples are shuffled.
        corruptions: Names of the corruptions to draw samples from.
        prepr: Accepted but not used by this loader.

    Returns:
        The ``(x, y)`` tensors produced by ``load_corruptions_cifar``.
    """
    return load_corruptions_cifar(
        dataset=BenchmarkDataset.cifar_100,
        n_examples=n_examples,
        severity=severity,
        data_dir=data_dir,
        corruptions=corruptions,
        shuffle=shuffle,
    )
|
|
|
|
|
|
|
|
|
def load_imagenetc(
        n_examples: Optional[int] = 5000,
        severity: int = 5,
        data_dir: str = './data',
        shuffle: bool = False,
        corruptions: Sequence[str] = CORRUPTIONS,
        prepr: str = 'Res256Crop224'
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return one batch of corrupted ImageNet images for a single corruption.

    Images are read through ``CustomImageFolder`` from
    ``<data_dir>/ImageNet-C/<corruption>/<severity>``.

    Args:
        n_examples: Number of samples to return. ``None`` returns the whole
            folder as one batch (this can be very memory hungry).
        severity: Used verbatim as the name of the severity sub-directory.
        data_dir: Directory that contains the ``ImageNet-C`` folder.
        shuffle: Whether the DataLoader shuffles before drawing the batch.
        corruptions: Must contain exactly one corruption name. The default
            (all corruptions) therefore has to be overridden by the caller.
        prepr: Key into ``PREPROCESSINGS`` selecting the transform pipeline;
            an unknown key raises ``KeyError``.

    Returns:
        ``(x, y)``: the batch of inputs and the batch of labels. Fewer than
        ``n_examples`` samples are returned if the folder is smaller.

    Raises:
        AssertionError: If ``corruptions`` does not hold exactly one name.
    """
    transforms_test = PREPROCESSINGS[prepr]

    assert len(corruptions) == 1, "so far only one corruption is supported (that's how this function is called in eval.py"

    data_folder_path = Path(data_dir) / CORRUPTIONS_DIR_NAMES[BenchmarkDataset.imagenet] / corruptions[0] / str(severity)
    imagenet = CustomImageFolder(data_folder_path, transforms_test)

    # `batch_size=None` would turn off the DataLoader's automatic batching
    # and hand back one single un-batched sample, so "no limit" is spelled
    # out as "one batch as large as the dataset".
    batch_size = len(imagenet) if n_examples is None else n_examples
    test_loader = data.DataLoader(imagenet, batch_size=batch_size,
                                  shuffle=shuffle, num_workers=2)

    # Batches are triples; the third element is not needed by callers.
    x_test, y_test, _ = next(iter(test_loader))

    return x_test, y_test
|
|
|
|
|
|
|
|
|
# Call signature of a corrupted-data loader:
#     (n_examples, severity, data_dir, shuffle, corruptions) -> (x, y)
# NOTE(review): the registered loaders also take a trailing `prepr` argument
# with a default value, which this alias does not list - confirm whether
# callers pass it.
CorruptDatasetLoader = Callable[
    [int, int, str, bool, Sequence[str]],
    Tuple[torch.Tensor, torch.Tensor],
]

# Registry mapping each benchmark dataset to its corrupted-data loader.
CORRUPTION_DATASET_LOADERS: Dict[BenchmarkDataset, CorruptDatasetLoader] = {
    BenchmarkDataset.cifar_10: load_cifar10c,
    BenchmarkDataset.cifar_100: load_cifar100c,
    BenchmarkDataset.imagenet: load_imagenetc,
}
|
|
|
|
|
|
|
|
|
def load_corruptions_cifar(
        dataset: BenchmarkDataset,
        n_examples: int = 10000,
        severity: int = 5,
        data_dir: str = './data',
        corruptions: Sequence[str] = CORRUPTIONS,
        shuffle: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load corrupted CIFAR images and labels from ``.npy`` files.

    The files are expected under ``<data_dir>/<CORRUPTIONS_DIR_NAMES[dataset]>``;
    if that directory does not exist, the archive is fetched with
    ``zenodo_download`` first. The requested ``n_examples`` are split evenly
    (rounding up) over ``corruptions`` and taken from the start of each
    corruption's block for the given severity.

    Args:
        dataset: Benchmark dataset whose corrupted variant is loaded.
        n_examples: Total number of samples to return.
        severity: Severity level, from 1 to 5.
        data_dir: Directory that holds (or will receive) the data; created
            if missing.
        corruptions: Names of the corruptions; each needs a ``<name>.npy``
            file.
        shuffle: If true, the selected samples are permuted with
            ``np.random.permutation`` before the final truncation.

    Returns:
        ``(x, y)``: images as ``float32`` scaled by 1/255 with the last axis
        moved to position 1, and the matching labels, both truncated to
        ``n_examples``.

    Raises:
        AssertionError: If ``severity`` is outside ``1..5``.
        DownloadError: If ``labels.npy`` or a corruption file is missing.
        ValueError: If ``corruptions`` is empty.
    """
    assert 1 <= severity <= 5
    # Every severity level occupies one block of this many images in a file.
    n_total_cifar = 10000

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(data_dir, exist_ok=True)

    data_dir = Path(data_dir)
    data_root_dir = data_dir / CORRUPTIONS_DIR_NAMES[dataset]

    if not data_root_dir.exists():
        zenodo_download(*ZENODO_CORRUPTIONS_LINKS[dataset], save_dir=data_dir)

    labels_path = data_root_dir / 'labels.npy'
    if not labels_path.is_file():
        raise DownloadError("Labels are missing, try to re-download them.")
    labels = np.load(labels_path)

    n_pert = len(corruptions)
    if n_pert == 0:
        # np.concatenate would otherwise fail with an unrelated message.
        raise ValueError("At least one corruption must be specified.")

    # Loop invariants: per-corruption quota and the severity block bounds.
    n_img = int(np.ceil(n_examples / n_pert))
    block_start = (severity - 1) * n_total_cifar
    block_stop = severity * n_total_cifar

    x_test_list, y_test_list = [], []
    for corruption in corruptions:
        corruption_file_path = data_root_dir / (corruption + '.npy')
        if not corruption_file_path.is_file():
            raise DownloadError(
                f"{corruption} file is missing, try to re-download it.")

        images_all = np.load(corruption_file_path)
        images = images_all[block_start:block_stop][:n_img]
        x_test_list.append(images)
        # Take exactly as many labels as images were selected: the image
        # slice is capped at one severity block, so slicing the labels with
        # n_img could yield more labels than images for large requests.
        # NOTE(review): labels are read from the start of labels.npy for
        # every severity, i.e. the label order is assumed to be the same in
        # each severity block - confirm against the dataset layout.
        y_test_list.append(labels[:len(images)])

    x_test, y_test = np.concatenate(x_test_list), np.concatenate(y_test_list)
    if shuffle:
        rand_idx = np.random.permutation(np.arange(len(x_test)))
        x_test, y_test = x_test[rand_idx], y_test[rand_idx]

    # Truncate before the float conversion so no work is spent on samples
    # that are dropped anyway (rounding up n_img can overshoot n_examples).
    x_test, y_test = x_test[:n_examples], y_test[:n_examples]

    # Move the last axis to position 1 and rescale by 1/255 as float32.
    x_test = np.transpose(x_test, (0, 3, 1, 2))
    x_test = x_test.astype(np.float32) / 255

    x_test = torch.tensor(x_test)
    y_test = torch.tensor(y_test)

    return x_test, y_test
|
|
|
|