File size: 5,875 Bytes
998bb30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os
from torchvision.datasets import CIFAR10, CIFAR100
import torch
import torchvision
from torchvision import transforms
from attacks.UnivIntruder.dataset import TinyImageNet, ImageNet, Caltech101, ImageNet100
from attacks.UnivIntruder.utils_.text_templates import imagenet_templates


def reverse_dict(label2class):
    """Invert a mapping, returning ``{value: key}`` for every item of *label2class*.

    When several keys share the same value, the key seen last in iteration
    order is the one kept.
    """
    keys = label2class.keys()
    values = label2class.values()
    return dict(zip(values, keys))


class GetDatasetMeta():
    """Look up per-dataset metadata: label names, preprocessing, datasets and models.

    Supported ``dataset_name`` values are ``"CIFAR10"``, ``"CIFAR100"``,
    ``"TinyImageNet"``, ``"ImageNet"``, ``"ImageNet100"`` and ``"Caltech101"``.
    ``get_clean_model`` supports only the first four.

    Args:
        root: Directory forwarded as ``root`` to the dataset constructors.
        dataset_name: Name selecting which dataset the methods describe.
    """

    # Datasets whose label names come from the shared ImageNet class-map file.
    _IMAGENET_FAMILY = ("TinyImageNet", "ImageNet", "ImageNet100")
    # Resolved relative to the current working directory, so the process is
    # expected to be started from the repository root.
    _IMAGENET_LABEL_FILE = 'attacks/UnivIntruder/utils_/map_clsloc.txt'
    # ImageNet per-channel mean/std, also reused for Caltech101.
    _IMAGENET_NORMALIZE = [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]]
    # Number of output classes for each supported dataset.
    _N_CLASSES = {
        "CIFAR10": 10,
        "CIFAR100": 100,
        # NOTE(review): the standard Tiny ImageNet release has 200 classes;
        # 500 presumably matches this project's TinyImageNet variant — verify.
        "TinyImageNet": 500,
        "ImageNet": 1000,
        "ImageNet100": 100,
        "Caltech101": 101,
    }

    def __init__(self, root, dataset_name) -> None:
        self.dataset_name = dataset_name
        self.root = root

    def _unsupported(self, what):
        """Build the ``ValueError`` raised when ``dataset_name`` has no *what*."""
        return ValueError(f"Unsupported dataset_name {self.dataset_name!r} for {what}")

    def get_dataset_label_names(self):
        """Return a ``{class_index: class_name}`` dict, or ``None`` for an unknown dataset.

        CIFAR names come from the torchvision dataset's ``class_to_idx``. The
        ImageNet family reads the comma-separated class-map file: column 1 is
        the integer index and column 2 the class name. Caltech101 builds the
        dataset and reads its ``.dataset.categories`` list.
        """
        if self.dataset_name == "CIFAR10":
            return reverse_dict(CIFAR10(root=self.root).class_to_idx)

        if self.dataset_name == "CIFAR100":
            return reverse_dict(CIFAR100(root=self.root).class_to_idx)

        if self.dataset_name in self._IMAGENET_FAMILY:
            label_dict = {}
            with open(self._IMAGENET_LABEL_FILE, 'r', encoding='utf-8') as file:
                for line in file:
                    stripped = line.strip()
                    if not stripped:
                        # Skip blank lines (e.g. a trailing newline), which
                        # previously caused an IndexError.
                        continue
                    parts = stripped.split(',')
                    label_dict[int(parts[1])] = parts[2]
            return label_dict

        if self.dataset_name == "Caltech101":
            label_list = self.get_dataset().dataset.categories
            return {i: label_list[i] for i in range(self._N_CLASSES["Caltech101"])}

        return None

    def get_transformation(self):
        """Return a ``Resize`` + ``Normalize`` pipeline, or ``None`` for an unknown dataset.

        Images are resized to a square (32 px for CIFAR, 224 px otherwise) and
        normalised with per-channel mean/std. ``Normalize`` works on tensors,
        so callers presumably pass tensors already scaled to [0, 1] — TODO confirm.
        """
        if self.dataset_name == "CIFAR10":
            size = 32
            normalize = [[0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]]
        elif self.dataset_name == "CIFAR100":
            size = 32
            normalize = [[0.5070751592371323, 0.48654887331495095, 0.4409178433670343],
                         [0.2673342858792401, 0.2564384629170883, 0.27615047132568404]]
        elif self.dataset_name in self._IMAGENET_FAMILY or self.dataset_name == "Caltech101":
            size = 224
            # Alternative TinyImageNet-specific statistics, kept for reference:
            # mean [0.4802, 0.4481, 0.3975], std [0.2302, 0.2265, 0.2262].
            normalize = self._IMAGENET_NORMALIZE
        else:
            return None
        return transforms.Compose([
            transforms.Resize((size, size)),
            transforms.Normalize(mean=normalize[0], std=normalize[1]),
        ])

    def get_template(self):
        """Return the text prompt templates (the same ImageNet templates for every dataset)."""
        return imagenet_templates

    def get_dataset(self, train=False, download=False, transform=None, **kwargs):
        """Instantiate the dataset split selected by *train*.

        With ``train=False``, TinyImageNet uses its ``'test'`` split and
        ImageNet/ImageNet100 use ``'val'``. Caltech101 receives *train* directly
        and is not passed *download*. Extra ``**kwargs`` are accepted but unused.

        Raises:
            ValueError: if ``dataset_name`` is not supported.
        """
        if self.dataset_name == "CIFAR10":
            return CIFAR10(root=self.root, train=train, download=download, transform=transform)
        if self.dataset_name == "CIFAR100":
            return CIFAR100(root=self.root, train=train, download=download, transform=transform)
        if self.dataset_name == "TinyImageNet":
            return TinyImageNet(root=self.root, split='train' if train else 'test', download=download,
                                transform=transform)
        if self.dataset_name == "ImageNet":
            return ImageNet(root=self.root, split='train' if train else 'val', download=download,
                            transform=transform)
        if self.dataset_name == "ImageNet100":
            return ImageNet100(root=self.root, split='train' if train else 'val', download=download,
                               transform=transform)
        if self.dataset_name == "Caltech101":
            return Caltech101(root=self.root, transform=transform, train=train)
        raise self._unsupported("get_dataset")

    def get_clean_model(self):
        """Load a pretrained classifier for the dataset.

        CIFAR models are loaded through ``torch.hub`` from
        ``chenyaofo/pytorch-cifar-models``. TinyImageNet and ImageNet both get
        torchvision's pretrained ResNet-50.

        Raises:
            ValueError: for any other ``dataset_name``, including ImageNet100
                and Caltech101, which have no model here.
        """
        if self.dataset_name == "CIFAR10":
            return torch.hub.load("chenyaofo/pytorch-cifar-models", 'cifar10_resnet44', pretrained=True)
        if self.dataset_name == "CIFAR100":
            return torch.hub.load("chenyaofo/pytorch-cifar-models", 'cifar100_resnet44', pretrained=True)
        if self.dataset_name in ("TinyImageNet", "ImageNet"):
            # NOTE(review): torchvision's pretrained ResNet-50 has 1000 outputs,
            # while n_classes() reports 500 for TinyImageNet — confirm intended.
            # `pretrained=True` is deprecated in newer torchvision in favour of
            # `weights=`; kept here for compatibility with older versions.
            return torchvision.models.resnet50(pretrained=True)
        raise self._unsupported("get_clean_model")

    def n_classes(self):
        """Return the number of classes for the dataset.

        Raises:
            ValueError: if ``dataset_name`` is not supported.
        """
        try:
            return self._N_CLASSES[self.dataset_name]
        except KeyError:
            raise self._unsupported("n_classes") from None


class InMemoryDataset(torch.utils.data.Dataset):
    """Map-style dataset over a sequence that is already held in memory.

    Items are returned exactly as stored; no copy or transformation is made.

    Args:
        data_list: Indexable, sized sequence of samples (kept by reference).
    """

    def __init__(self, data_list):
        self.data_list = data_list

    def __getitem__(self, idx):
        samples = self.data_list
        return samples[idx]

    def __len__(self):
        samples = self.data_list
        return len(samples)


class TransformedDataset(torch.utils.data.Dataset):
    """Wrap a dataset, reduce each sample to ``(image, label)`` and apply *transform*.

    Args:
        args_cl: Mapping whose ``'approach'`` entry selects the sample layout of
            *original_dataset*. ``0`` means items are ``(image, label)``; any
            other value means items are 3-tuples whose first element is dropped.
        original_dataset: Indexable, sized dataset to wrap.
        transform: Optional callable applied to each image; ``None`` disables it.
    """

    def __init__(self, args_cl, original_dataset, transform=None):
        self.original_dataset = original_dataset
        self.transform = transform
        self.args_cl = args_cl

    def __getitem__(self, index):
        """Return ``(image, label)`` for *index*, with *transform* applied to the image."""
        if self.args_cl['approach'] == 0:
            image, label = self.original_dataset[index]
        else:
            # Three-element samples: the leading field is not needed here.
            _, image, label = self.original_dataset[index]
        # Compare with None rather than relying on truthiness, so a transform
        # object that happens to be falsy (e.g. defines __len__ == 0) still runs.
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.original_dataset)