File size: 5,875 Bytes
998bb30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import os
from torchvision.datasets import CIFAR10, CIFAR100
import torch
import torchvision
from torchvision import transforms
from attacks.UnivIntruder.dataset import TinyImageNet, ImageNet, Caltech101, ImageNet100
from attacks.UnivIntruder.utils_.text_templates import imagenet_templates
def reverse_dict(label2class):
    """Return a new dict that maps each value of ``label2class`` to its key.

    If several keys share the same value, the key encountered last during
    iteration is the one kept in the result.
    """
    class2label = {}
    for key, value in label2class.items():
        class2label[value] = key
    return class2label
class GetDatasetMeta:
    """Look up per-dataset metadata and objects by dataset name.

    The names handled by at least one method are ``"CIFAR10"``,
    ``"CIFAR100"``, ``"TinyImageNet"``, ``"ImageNet"``, ``"ImageNet100"`` and
    ``"Caltech101"``.

    Args:
        root: Dataset root directory, forwarded to the dataset constructors.
        dataset_name: One of the names listed above.
    """

    # Datasets that share the class-index file and the 224px preprocessing.
    _IMAGENET_FAMILY = ("TinyImageNet", "ImageNet", "ImageNet100")

    # Resolved against the current working directory when opened.
    _CLSLOC_MAP_PATH = 'attacks/UnivIntruder/utils_/map_clsloc.txt'

    # dataset name -> (square resize edge, per-channel mean, per-channel std)
    # NOTE: the original ImageNet-family branch carried a commented-out
    # alternative (mean [0.4802, 0.4481, 0.3975], std [0.2302, 0.2265, 0.2262]);
    # the values below are the ones actually in effect.
    _PREPROCESS_SPECS = {
        "CIFAR10": (32, [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        "CIFAR100": (
            32,
            [0.5070751592371323, 0.48654887331495095, 0.4409178433670343],
            [0.2673342858792401, 0.2564384629170883, 0.27615047132568404],
        ),
        "TinyImageNet": (224, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        "ImageNet": (224, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        "ImageNet100": (224, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        "Caltech101": (224, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    }

    # NOTE(review): 500 for TinyImageNet is carried over unchanged from the
    # original code; confirm it matches the TinyImageNet variant in use.
    _NUM_CLASSES = {
        "CIFAR10": 10,
        "CIFAR100": 100,
        "TinyImageNet": 500,
        "ImageNet": 1000,
        "ImageNet100": 100,
        "Caltech101": 101,
    }

    # torch.hub entry points in the "chenyaofo/pytorch-cifar-models" repo.
    _CIFAR_HUB_MODELS = {
        "CIFAR10": 'cifar10_resnet44',
        "CIFAR100": 'cifar100_resnet44',
    }

    def __init__(self, root, dataset_name) -> None:
        self.dataset_name = dataset_name
        self.root = root

    def _unsupported(self, what):
        """Build the error raised when ``what`` is undefined for this dataset."""
        return ValueError(
            f"{what} is not available for dataset {self.dataset_name!r}"
        )

    def get_dataset_label_names(self):
        """Return a ``{class index: class name}`` dict, or ``None``.

        * CIFAR10 / CIFAR100: the inverse of the dataset's ``class_to_idx``.
        * TinyImageNet / ImageNet / ImageNet100: parsed from the class-index
          file; each non-blank line is split on ``,`` and field 1 (an int) is
          mapped to field 2.
        * Caltech101: the first 101 entries of ``dataset.categories`` of the
          object returned by :meth:`get_dataset`.
        * Any other name: ``None``.
        """
        name = self.dataset_name
        if name == "CIFAR10":
            return reverse_dict(CIFAR10(root=self.root).class_to_idx)
        if name == "CIFAR100":
            return reverse_dict(CIFAR100(root=self.root).class_to_idx)
        if name in self._IMAGENET_FAMILY:
            label_dict = {}
            with open(self._CLSLOC_MAP_PATH, 'r') as file:
                # Stream the file; blank lines (e.g. a trailing newline) are
                # skipped instead of failing with IndexError.
                for line in file:
                    line = line.strip()
                    if not line:
                        continue
                    parts = line.split(',')
                    label_dict[int(parts[1])] = parts[2]
            return label_dict
        if name == "Caltech101":
            label_list = self.get_dataset().dataset.categories
            return {i: label_list[i] for i in range(101)}
        return None

    def get_transformation(self):
        """Return the resize + normalize pipeline, or ``None`` if unknown.

        The pipeline is ``Resize((size, size))`` followed by ``Normalize``
        with the dataset's mean/std. No tensor conversion is included, so the
        input is presumably already a tensor -- TODO confirm against callers.
        """
        spec = self._PREPROCESS_SPECS.get(self.dataset_name)
        if spec is None:
            return None
        size, mean, std = spec
        return transforms.Compose([
            transforms.Resize((size, size)),
            # Copy so the shared class-level lists are never handed out.
            transforms.Normalize(mean=list(mean), std=list(std)),
        ])

    def get_template(self):
        """Return the imported ``imagenet_templates`` object for every dataset."""
        return imagenet_templates

    def get_dataset(self, train=False, download=False, transform=None, **kwargs):
        """Construct the dataset object for ``self.dataset_name``.

        Args:
            train: Select the training data. When false, TinyImageNet uses
                the ``'test'`` split and ImageNet / ImageNet100 use ``'val'``.
            download: Forwarded to every constructor except Caltech101.
            transform: Forwarded to the dataset constructor.
            **kwargs: Accepted for call-site compatibility; currently ignored.

        Raises:
            ValueError: If the dataset name is not recognised.
        """
        name = self.dataset_name
        if name == "CIFAR10":
            return CIFAR10(root=self.root, train=train, download=download,
                           transform=transform)
        if name == "CIFAR100":
            return CIFAR100(root=self.root, train=train, download=download,
                            transform=transform)
        if name == "TinyImageNet":
            return TinyImageNet(root=self.root,
                                split='train' if train else 'test',
                                download=download, transform=transform)
        if name == "ImageNet":
            return ImageNet(root=self.root,
                            split='train' if train else 'val',
                            download=download, transform=transform)
        if name == "ImageNet100":
            return ImageNet100(root=self.root,
                               split='train' if train else 'val',
                               download=download, transform=transform)
        if name == "Caltech101":
            return Caltech101(root=self.root, transform=transform, train=train)
        raise self._unsupported("A dataset")

    def get_clean_model(self):
        """Load a pretrained reference model for the dataset.

        CIFAR10 / CIFAR100 load a ResNet-44 through ``torch.hub``;
        TinyImageNet / ImageNet use ``torchvision.models.resnet50``.

        Raises:
            ValueError: If no model is defined for the dataset (this includes
                ImageNet100 and Caltech101).
        """
        name = self.dataset_name
        if name in self._CIFAR_HUB_MODELS:
            return torch.hub.load("chenyaofo/pytorch-cifar-models",
                                  self._CIFAR_HUB_MODELS[name],
                                  pretrained=True)
        if name in ("TinyImageNet", "ImageNet"):
            # NOTE(review): newer torchvision releases prefer ``weights=``
            # over ``pretrained=``; kept as-is because the installed version
            # is not visible from here.
            return torchvision.models.resnet50(pretrained=True)
        raise self._unsupported("A clean model")

    def n_classes(self):
        """Return the number of classes configured for the dataset.

        Raises:
            ValueError: If the dataset name is not recognised.
        """
        try:
            return self._NUM_CLASSES[self.dataset_name]
        except KeyError:
            raise self._unsupported("A class count") from None
class InMemoryDataset(torch.utils.data.Dataset):
    """Dataset backed by an indexable collection that is already in memory."""

    def __init__(self, data_list):
        # Only a reference is stored; the collection itself is not copied.
        self.data_list = data_list

    def __getitem__(self, idx):
        """Return whatever ``data_list`` holds at ``idx``."""
        sample = self.data_list[idx]
        return sample

    def __len__(self):
        """Return the number of entries in ``data_list``."""
        size = len(self.data_list)
        return size
class TransformedDataset(torch.utils.data.Dataset):
    """Wrap a dataset and apply an optional transform to each image.

    Args:
        args_cl: Mapping read with ``args_cl['approach']``. When the value is
            ``0`` the wrapped dataset's samples are unpacked as
            ``(image, label)``; otherwise as ``(_, image, label)`` with the
            first element discarded.
        original_dataset: The dataset being wrapped; must support indexing
            and ``len()``.
        transform: Optional callable applied to the image. ``None`` disables
            it.
    """

    def __init__(self, args_cl, original_dataset, transform=None):
        self.original_dataset = original_dataset
        self.transform = transform
        self.args_cl = args_cl

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``, with the image transformed."""
        if self.args_cl['approach'] == 0:
            image, label = self.original_dataset[index]
        else:
            _, image, label = self.original_dataset[index]
        # Compare against None explicitly: a callable that defines
        # __len__/__bool__ can be falsy and must still be applied.
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.original_dataset)