|
|
import os |
|
|
import logging |
|
|
import random |
|
|
import numpy as np |
|
|
import webdataset as wds |
|
|
|
|
|
import torch |
|
|
import torchvision |
|
|
import torchvision.transforms as transforms |
|
|
|
|
|
from conf import complete_data_dir_path |
|
|
from datasets.corruptions_datasets import create_cifarc_dataset, create_imagenetc_dataset |
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
def identity(x):
    """Return *x* unchanged; used as a no-op target transform for webdataset."""
    return x
|
|
|
|
|
|
|
|
def get_transform(dataset_name: str, preprocess=None):
    """
    Get the transformation pipeline
    Note that the data normalization is done within the model
    Input:
        dataset_name: Name of the dataset
        preprocess: Optional pre-defined transformation; used as a fallback
                    for datasets without a dedicated pipeline
    Returns:
        transform: The data pre-processing (and augmentation), or None for
                   datasets that handle the transformation themselves
    """
    # cifar10/cifar100 and imagenet_c/ccc all use a plain tensor conversion;
    # the branches were identical, so they are merged here.
    if dataset_name in ["cifar10", "cifar100", "imagenet_c", "ccc"]:
        transform = transforms.Compose([transforms.ToTensor()])
    elif dataset_name in ["cifar10_c", "cifar100_c"]:
        # No transform here — presumably handled inside create_cifarc_dataset;
        # TODO(review): confirm against the dataset implementation.
        transform = None
    else:
        if preprocess:
            transform = preprocess
        else:
            # Standard ImageNet-style evaluation pipeline (no normalization,
            # which is applied within the model).
            transform = transforms.Compose([transforms.Resize(256),
                                            transforms.CenterCrop(224),
                                            transforms.ToTensor()])
    return transform
|
|
|
|
|
|
|
|
def get_test_loader(setting: str, dataset_name : str, data_root_dir: str, domain_name: str, domain_names_all: list,
                    severity: int, num_examples: int, rng_seed: int, batch_size: int = 128, shuffle: bool = False,
                    workers: int = 4, preprocess=None):
    """
    Create the test data loader
    Input:
        setting: Name of the considered setting
        dataset_name: Name of the dataset
        data_root_dir: Path of the data root directory
        domain_name: Name of the current domain; "none" selects the clean source validation data
        domain_names_all: List containing all domains
        severity: Severity level in case of corrupted data
        num_examples: Number of test samples for the current domain (-1 keeps all samples)
        rng_seed: A seed number
        batch_size: The number of samples to process in each iteration
        shuffle: Whether to shuffle the data. Will destroy pre-defined settings
        workers: Number of workers used for data loading
        preprocess: Optional transformation forwarded to get_transform
    Returns:
        test_loader: The test data loader
    """

    data_dir = complete_data_dir_path(data_root_dir, dataset_name)
    transform = get_transform(dataset_name, preprocess)

    if domain_name == "none":
        # Clean data: reuse the source validation split; only the dataset is kept,
        # the returned loader is discarded.
        test_dataset, _ = get_source_loader(dataset_name,
                                            data_root_dir, batch_size,
                                            train_split=False, workers=workers)
    else:
        if dataset_name in ["cifar10_c", "cifar100_c"]:
            test_dataset = create_cifarc_dataset(dataset_name=dataset_name,
                                                 severity=severity,
                                                 data_dir=data_dir,
                                                 corruption=domain_name,
                                                 corruptions_seq=domain_names_all,
                                                 transform=transform,
                                                 setting=setting)

            # Subsample the file list in place to at most num_examples items.
            # NOTE(review): this runs BEFORE random.seed(rng_seed) below, so the
            # subset is not reproducible via rng_seed — confirm whether intended.
            if num_examples != -1:
                num_samples_orig = len(test_dataset)
                test_dataset.samples = random.sample(test_dataset.samples, k=min(num_examples, num_samples_orig))

        elif dataset_name == "imagenet_c":
            # Subsampling (n_examples) is handled inside the dataset constructor here.
            test_dataset = create_imagenetc_dataset(n_examples=num_examples,
                                                    severity=severity,
                                                    data_dir=data_dir,
                                                    corruption=domain_name,
                                                    corruptions_seq=domain_names_all,
                                                    transform=transform,
                                                    setting=setting)

        elif dataset_name == "ccc":
            logger.info(f"Using the following data transformation:\n{transform}")
            # Streaming webdataset shards; workers forced to 1 — presumably to keep
            # the sequential stream order — TODO(review): confirm.
            workers = 1
            # Brace pattern is webdataset shard notation covering all serial_*.tar files.
            url = os.path.join(data_root_dir, "CCC", domain_name,"serial_{00000..99999}.tar")

            # Decode images to PIL, pair each input image with its class label,
            # and apply the transform to the image only (identity on the label).
            test_dataset = (wds.WebDataset(url)
                            .decode("pil")
                            .to_tuple("input.jpg", "output.cls")
                            .map_tuple(transform, identity)
                            )

        else:
            raise ValueError(f"Dataset '{dataset_name}' is not supported!")

    try:
        # Deterministically shuffle the pre-defined file order (independent of the
        # DataLoader's `shuffle` argument, which is applied per-epoch below).
        random.seed(rng_seed)
        np.random.seed(rng_seed)
        random.shuffle(test_dataset.samples)

        # "continual_cdc": interleave the domains into randomly sized runs of
        # whole batches, so the domain changes at batch boundaries.
        if "continual_cdc" in setting:
            new_sample_sequence = []
            # x[2] appears to be the sample's domain name — TODO(review): confirm
            # against the corruption dataset's sample tuple layout.
            remaining_samples = {domain: [x for x in test_dataset.samples if x[2] == domain] for domain in domain_names_all}
            # "+ 1" over-counts by one batch when a domain size is an exact multiple
            # of batch_size; the slicing below tolerates the overshoot.
            remaining_batches = {domain : len(samples)//batch_size + 1 for domain, samples in remaining_samples.items()}

            while remaining_samples:
                # Pick a random domain, then a random number of its remaining batches.
                selected_domains = np.random.choice(list(remaining_samples.keys()), 1)[0]
                num_selected_batches = np.random.choice(list(range(1, remaining_batches[selected_domains]+1)))
                num_selected_samples = num_selected_batches*batch_size
                new_sample_sequence += remaining_samples[selected_domains][:num_selected_samples]
                remaining_samples[selected_domains] = remaining_samples[selected_domains][num_selected_samples:]
                remaining_batches[selected_domains] -= num_selected_batches
                # Drop exhausted domains so the loop terminates.
                if len(remaining_samples[selected_domains]) == 0:
                    del remaining_samples[selected_domains]
                    del remaining_batches[selected_domains]

            test_dataset.samples = new_sample_sequence

    except AttributeError:
        # Datasets without a `.samples` list (e.g. the webdataset stream or CIFAR,
        # which stores arrays in `.data`) cannot be re-ordered this way.
        logger.warning("Attribute 'samples' is missing. Continuing without shuffling, sorting or subsampling the files...")

    return torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=workers, drop_last=False)
|
|
|
|
|
|
|
|
def get_source_loader(dataset_name: str, data_root_dir: str, batch_size: int, train_split: bool = True,
                      num_samples: int = -1, percentage: float = 1.0, workers: int = 4, preprocess=None):
    """
    Create the source data loader
    Input:
        dataset_name: Name of the dataset
        data_root_dir: Path of the data root directory
        batch_size: The number of samples to process in each iteration
        train_split: Whether to use the training or validation split
        num_samples: Number of source samples used during training (-1 disables the absolute limit)
        percentage: (0, 1] Percentage of source samples used during training
        workers: Number of workers used for data loading
        preprocess: Optional transformation forwarded to get_transform
    Returns:
        source_dataset: The source dataset
        source_loader: The source data loader
    """

    # Corrupted variants (e.g. "cifar10_c") share the clean source data of their
    # base dataset; "ccc" is based on ImageNet.
    src_dataset_name = dataset_name.split("_")[0] if dataset_name != "ccc" else "imagenet"

    data_dir = complete_data_dir_path(data_root_dir, dataset_name=src_dataset_name)

    transform = get_transform(src_dataset_name, preprocess)

    if dataset_name in ["cifar10", "cifar10_c"]:
        source_dataset = torchvision.datasets.CIFAR10(root=data_root_dir,
                                                      train=train_split,
                                                      download=True,
                                                      transform=transform)
    elif dataset_name in ["cifar100", "cifar100_c"]:
        source_dataset = torchvision.datasets.CIFAR100(root=data_root_dir,
                                                       train=train_split,
                                                       download=True,
                                                       transform=transform)
    elif dataset_name in ["imagenet", "imagenet_c", "imagenet_k", "ccc"]:
        split = "train" if train_split else "val"
        source_dataset = torchvision.datasets.ImageNet(root=data_dir,
                                                       split=split,
                                                       transform=transform)
    else:
        raise ValueError("Dataset not supported.")

    # Optionally restrict the source data to a subset: an absolute count
    # (num_samples) takes precedence over a fraction (percentage).
    if percentage < 1.0 or num_samples >= 0:
        assert percentage > 0.0, "The percentage of source samples has to be in range 0.0 < percentage <= 1.0"
        # BUGFIX: validate num_samples only when it is actually supplied. Previously
        # the assert was unconditional, so a percentage-only reduction with the
        # default num_samples=-1 failed even though the code below supports it.
        if num_samples >= 0:
            assert num_samples > 0, "The number of source samples has to be at least 1"
        if src_dataset_name in ["cifar10", "cifar100"]:
            # CIFAR keeps images in a numpy array (.data) and labels in a list (.targets)
            nr_src_samples = source_dataset.data.shape[0]
            nr_reduced = min(num_samples, nr_src_samples) if num_samples > 0 else int(np.ceil(nr_src_samples * percentage))
            inds = random.sample(range(0, nr_src_samples), nr_reduced)
            source_dataset.data = source_dataset.data[inds]
            source_dataset.targets = [source_dataset.targets[k] for k in inds]
        else:
            # ImageNet-style datasets keep their file list in .samples
            nr_src_samples = len(source_dataset.samples)
            nr_reduced = min(num_samples, nr_src_samples) if num_samples > 0 else int(np.ceil(nr_src_samples * percentage))
            source_dataset.samples = random.sample(source_dataset.samples, nr_reduced)

        logger.info(f"Number of images in source loader: {nr_reduced}/{nr_src_samples} \t Reduction factor = {nr_reduced / nr_src_samples:.4f}")

    source_loader = torch.utils.data.DataLoader(source_dataset,
                                                batch_size=batch_size,
                                                shuffle=True,
                                                num_workers=workers,
                                                drop_last=False)
    logger.info(f"Number of images and batches in source loader: #img = {len(source_dataset)} #batches = {len(source_loader)}")
    return source_dataset, source_loader
|
|
|