import os
import logging
import random
import numpy as np
import webdataset as wds

import torch
import torchvision
import torchvision.transforms as transforms

from conf import complete_data_dir_path
from datasets.corruptions_datasets import create_cifarc_dataset, create_imagenetc_dataset


logger = logging.getLogger(__name__)


def identity(x):
    # defined at module level (instead of a lambda) so it stays picklable
    # when used with multi-process data loading
    return x


def get_transform(dataset_name: str, preprocess=None):
    """
    Get the transformation pipeline
    Note that the data normalization is done within the model
    Input:
        dataset_name: Name of the dataset
        adaptation: Name of the adaptation method
    Returns:
        transforms: The data pre-processing (and augmentation)
    """
    if dataset_name in ["cifar10", "cifar100"]:
        transform = transforms.Compose([transforms.ToTensor()])
    elif dataset_name in ["cifar10_c", "cifar100_c"]:
        transform = None
    elif dataset_name in ["imagenet_c", "ccc"]:
        # note that ImageNet-C and CCC are already resized and centre cropped (to size 224)
        # for ResNet-50, normalization is handled by a pre-normalization layer in the model
        transform = transforms.Compose([transforms.ToTensor()])
    else:
        if preprocess:
            # set transform to the corresponding input transformation of the restored model
            transform = preprocess
        else:
            # use classical ImageNet transformation procedure
            transform = transforms.Compose([transforms.Resize(256),
                                            transforms.CenterCrop(224),
                                            transforms.ToTensor()])
    return transform
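

# A minimal usage sketch (hypothetical image path): the returned pipeline maps a
# PIL image to an un-normalized tensor, since normalization happens in the model.
#
#   from PIL import Image
#   transform = get_transform("imagenet")
#   img = Image.open("example.jpg").convert("RGB")
#   x = transform(img)  # float tensor of shape [3, 224, 224], values in [0, 1]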


def get_test_loader(setting: str, dataset_name : str, data_root_dir: str, domain_name: str, domain_names_all: list, 
                    severity: int, num_examples: int, rng_seed: int, batch_size: int = 128, shuffle: bool = False, 
                    workers: int = 4, preprocess=None):
    """
    Create the test data loader
    Input:
        setting: Name of the considered setting
        dataset_name: Name of the dataset
        data_root_dir: Path of the data root directory
        domain_name: Name of the current domain
        domain_names_all: List containing all domains
        severity: Severity level in case of corrupted data
        num_examples: Number of test samples for the current domain
        rng_seed: A seed number
        batch_size: The number of samples to process in each iteration
        shuffle: Whether to shuffle the data. Will destroy pre-defined settings
        workers: Number of workers used for data loading
    Returns:
        test_loader: The test data loader
    """

    data_dir = complete_data_dir_path(data_root_dir, dataset_name)
    transform = get_transform(dataset_name, preprocess)

    # create the test dataset
    if domain_name == "none":
        test_dataset, _ = get_source_loader(dataset_name,
                                            data_root_dir, batch_size,
                                            train_split=False, workers=workers)
    else:
        if dataset_name in ["cifar10_c", "cifar100_c"]:
            test_dataset = create_cifarc_dataset(dataset_name=dataset_name,
                                                 severity=severity,
                                                 data_dir=data_dir,
                                                 corruption=domain_name,
                                                 corruptions_seq=domain_names_all,
                                                 transform=transform,
                                                 setting=setting)
            
            # randomly subsample the dataset if num_examples is specified
            if num_examples != -1:
                num_samples_orig = len(test_dataset)
                # logger.info(f"Changing the number of test samples from {num_samples_orig} to {num_examples}...")
                test_dataset.samples = random.sample(test_dataset.samples, k=min(num_examples, num_samples_orig))

        elif dataset_name == "imagenet_c":
            test_dataset = create_imagenetc_dataset(n_examples=num_examples,
                                                    severity=severity,
                                                    data_dir=data_dir,
                                                    corruption=domain_name,
                                                    corruptions_seq=domain_names_all,
                                                    transform=transform,
                                                    setting=setting)

        elif dataset_name == "ccc":
            logger.info(f"Using the following data transformation:\n{transform}")
            workers = 1  # use a single worker for the streamed WebDataset
            # use a local copy of CCC
            url = os.path.join(data_root_dir, "CCC", domain_name, "serial_{00000..99999}.tar")
            # alternatively, stream CCC from the remote server; choose domain_name from:
            # baseline_<0/20/40>_transition+speed_<1000/2000/5000>_seed_<43/44/45>
            # domain_name = "baseline_20_transition+speed_1000_seed_44"
            # url = f'https://mlcloud.uni-tuebingen.de:7443/datasets/CCC/{domain_name}/serial_{{00000..99999}}.tar'

            # stream (image, label) tuples: decode the stored JPEGs to PIL images
            # and apply the data transformation to the image only
            test_dataset = (wds.WebDataset(url)
                    .decode("pil")
                    .to_tuple("input.jpg", "output.cls")
                    .map_tuple(transform, identity)
            )

        else:
            raise ValueError(f"Dataset '{dataset_name}' is not supported!")

    try:
        # shuffle the test sequence; deterministic behavior for a fixed random seed
        random.seed(rng_seed)
        np.random.seed(rng_seed)
        random.shuffle(test_dataset.samples)

        if "continual_cdc" in setting:
            new_sample_sequence = []
            remaining_samples = {domain: [x for x in test_dataset.samples if x[2] == domain] for domain in domain_names_all}
            remaining_batches = {domain : len(samples)//batch_size + 1 for domain, samples in remaining_samples.items()}

            while remaining_samples:
                selected_domains = np.random.choice(list(remaining_samples.keys()), 1)[0]
                num_selected_batches = np.random.choice(list(range(1, remaining_batches[selected_domains]+1)))
                num_selected_samples = num_selected_batches*batch_size
                new_sample_sequence += remaining_samples[selected_domains][:num_selected_samples]
                remaining_samples[selected_domains] = remaining_samples[selected_domains][num_selected_samples:]
                remaining_batches[selected_domains] -= num_selected_batches
                if len(remaining_samples[selected_domains]) == 0:
                    del remaining_samples[selected_domains]
                    del remaining_batches[selected_domains]

            test_dataset.samples = new_sample_sequence

    except AttributeError:
        logger.warning("Attribute 'samples' is missing. Continuing without shuffling, sorting or subsampling the files...")

    return torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=workers, drop_last=False)
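

# A minimal usage sketch (illustrative values, hypothetical paths): create a
# loader for one corruption domain of CIFAR-10-C and iterate over its batches.
#
#   test_loader = get_test_loader(setting="continual", dataset_name="cifar10_c",
#                                 data_root_dir="./data", domain_name="gaussian_noise",
#                                 domain_names_all=["gaussian_noise", "shot_noise"],
#                                 severity=5, num_examples=-1, rng_seed=1)
#   for batch in test_loader:
#       ...  # adapt and evaluate on each test batch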


def get_source_loader(dataset_name: str, data_root_dir: str, batch_size: int, train_split: bool = True, 
                      num_samples: int = -1, percentage: float = 1.0, workers: int = 4, preprocess=None):
    """
    Create the source data loader
    Input:
        dataset_name: Name of the dataset
        data_root_dir: Path of the data root directory
        batch_size: The number of samples to process in each iteration
        train_split: Whether to use the training or validation split
        num_samples: Number of source samples used during training
        percentage: (0, 1] Percentage of source samples used during training
        workers: Number of workers used for data loading
    Returns:
        source_dataset: The source dataset
        source_loader: The source data loader
    """

    # create the correct source dataset name
    src_dataset_name = dataset_name.split("_")[0] if dataset_name != "ccc" else "imagenet"

    # complete the data root path to the full dataset path
    data_dir = complete_data_dir_path(data_root_dir, dataset_name=src_dataset_name)

    # get the data transformation
    transform = get_transform(src_dataset_name, preprocess)

    # create the source dataset
    if dataset_name in ["cifar10", "cifar10_c"]:
        source_dataset = torchvision.datasets.CIFAR10(root=data_root_dir,
                                                      train=train_split,
                                                      download=True,
                                                      transform=transform)
    elif dataset_name in ["cifar100", "cifar100_c"]:
        source_dataset = torchvision.datasets.CIFAR100(root=data_root_dir,
                                                       train=train_split,
                                                       download=True,
                                                       transform=transform)
    elif dataset_name in ["imagenet", "imagenet_c", "imagenet_k", "ccc"]:
        split = "train" if train_split else "val"
        source_dataset = torchvision.datasets.ImageNet(root=data_dir,
                                                       split=split,
                                                       transform=transform)
    else:
        raise ValueError(f"Dataset '{dataset_name}' is not supported!")

    if percentage < 1.0 or num_samples > 0:    # reduce the number of source samples
        assert 0.0 < percentage <= 1.0, "The percentage of source samples has to be in range 0.0 < percentage <= 1.0"
        if src_dataset_name in ["cifar10", "cifar100"]:
            nr_src_samples = source_dataset.data.shape[0]
            nr_reduced = min(num_samples, nr_src_samples) if num_samples > 0 else int(np.ceil(nr_src_samples * percentage))
            inds = random.sample(range(0, nr_src_samples), nr_reduced)
            source_dataset.data = source_dataset.data[inds]
            source_dataset.targets = [source_dataset.targets[k] for k in inds]
        else:
            nr_src_samples = len(source_dataset.samples)
            nr_reduced = min(num_samples, nr_src_samples) if num_samples > 0 else int(np.ceil(nr_src_samples * percentage))
            source_dataset.samples = random.sample(source_dataset.samples, nr_reduced)

        logger.info(f"Number of images in source loader: {nr_reduced}/{nr_src_samples} \t Reduction factor = {nr_reduced / nr_src_samples:.4f}")

    # create the source data loader
    source_loader = torch.utils.data.DataLoader(source_dataset,
                                                batch_size=batch_size,
                                                shuffle=True,
                                                num_workers=workers,
                                                drop_last=False)
    logger.info(f"Number of images and batches in source loader: #img = {len(source_dataset)} #batches = {len(source_loader)}")
    return source_dataset, source_loader
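

# Minimal smoke-test sketch: assumes CIFAR-10 can be downloaded to "./data" and
# uses illustrative values; it builds a reduced source loader and prints the
# shape of the first batch.
if __name__ == "__main__":
    src_dataset, src_loader = get_source_loader(dataset_name="cifar10",
                                                data_root_dir="./data",
                                                batch_size=64,
                                                train_split=True,
                                                num_samples=1024)
    imgs, labels = next(iter(src_loader))
    print(f"batch shape: {tuple(imgs.shape)}, first labels: {labels[:8].tolist()}")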