|
|
import os |
|
|
from torchvision.datasets import CIFAR10, CIFAR100 |
|
|
import torch |
|
|
import torchvision |
|
|
from torchvision import transforms |
|
|
from attacks.UnivIntruder.dataset import TinyImageNet, ImageNet, Caltech101, ImageNet100 |
|
|
from attacks.UnivIntruder.utils_.text_templates import imagenet_templates |
|
|
|
|
|
|
|
|
def reverse_dict(label2class):
    """Return a new dict that maps each value of ``label2class`` to its key.

    The input is only read through ``.items()``. If several keys share the
    same value, the key visited last during iteration is the one retained.
    """
    inverted = {}
    for key, value in label2class.items():
        inverted[value] = key
    return inverted
|
|
|
|
|
|
|
|
class GetDatasetMeta():
    """Per-dataset metadata lookup keyed by a dataset name string.

    The instance stores a dataset root directory and a dataset name, and
    exposes helpers that return label names, a preprocessing pipeline, the
    prompt templates, a dataset object, a pretrained model, and the class
    count for that name.

    Names handled somewhere in this class: ``"CIFAR10"``, ``"CIFAR100"``,
    ``"TinyImageNet"``, ``"ImageNet"``, ``"ImageNet100"``, ``"Caltech101"``.
    """

    # Names whose label table is read from the class-location map file.
    _IMAGENET_FAMILY = ("TinyImageNet", "ImageNet", "ImageNet100")

    # Path is relative to the current working directory, as in the original.
    _MAP_CLSLOC_PATH = 'attacks/UnivIntruder/utils_/map_clsloc.txt'

    # dataset name -> (square resize edge, per-channel mean, per-channel std)
    _PREPROCESS = {
        "CIFAR10": (32,
                    (0.4914, 0.4822, 0.4465),
                    (0.2023, 0.1994, 0.2010)),
        "CIFAR100": (32,
                     (0.5070751592371323, 0.48654887331495095, 0.4409178433670343),
                     (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)),
        "TinyImageNet": (224, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        "ImageNet": (224, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        "ImageNet100": (224, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        "Caltech101": (224, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    }

    # dataset name -> number of classes reported by ``n_classes``.
    # NOTE(review): the public Tiny ImageNet release has 200 classes; 500 is
    # kept exactly as the original code had it -- confirm against the
    # project's TinyImageNet dataset before relying on it.
    _N_CLASSES = {
        "CIFAR10": 10,
        "CIFAR100": 100,
        "TinyImageNet": 500,
        "ImageNet": 1000,
        "ImageNet100": 100,
        "Caltech101": 101,
    }

    def __init__(self, root, dataset_name) -> None:
        """Store the dataset name and root directory; no validation is done."""
        self.dataset_name = dataset_name
        self.root = root

    def _unsupported(self, what):
        """Build the ValueError raised when ``what`` has no entry for this dataset."""
        return ValueError(
            f"{what} is not available for dataset {self.dataset_name!r}"
        )

    def get_dataset_label_names(self):
        """Return a ``{label index: class name}`` dict, or ``None``.

        * CIFAR10 / CIFAR100: the dataset's ``class_to_idx`` mapping inverted.
        * TinyImageNet / ImageNet / ImageNet100: parsed from the map file;
          every non-blank line is split on commas, the second field (as
          ``int``) is the key and the third field is the value.
        * Caltech101: the first 101 entries of the loaded dataset's
          ``.dataset.categories``.
        * Any other name: ``None``.
        """
        name = self.dataset_name

        if name == "CIFAR10":
            return reverse_dict(CIFAR10(root=self.root).class_to_idx)

        if name == "CIFAR100":
            return reverse_dict(CIFAR100(root=self.root).class_to_idx)

        if name in self._IMAGENET_FAMILY:
            label_dict = {}
            with open(self._MAP_CLSLOC_PATH, 'r') as file:
                for line in file:
                    line = line.strip()
                    if not line:
                        # A blank (e.g. trailing) line has no fields to parse;
                        # the original crashed with IndexError here.
                        continue
                    parts = line.split(',')
                    label_dict[int(parts[1])] = parts[2]
            return label_dict

        if name == "Caltech101":
            label_list = self.get_dataset().dataset.categories
            return {i: label_list[i] for i in range(101)}

        return None

    def get_transformation(self):
        """Return a ``Compose`` of square ``Resize`` + ``Normalize``, or ``None``.

        The resize edge and normalisation statistics are looked up by dataset
        name; an unrecognised name yields ``None``.
        """
        spec = self._PREPROCESS.get(self.dataset_name)
        if spec is None:
            return None

        size, mean, std = spec
        preprocess = transforms.Compose([
            transforms.Resize((size, size)),
            # Fresh lists per call so callers never share mutable state.
            transforms.Normalize(mean=list(mean), std=list(std)),
        ])
        return preprocess

    def get_template(self):
        """Return the imported ``imagenet_templates`` object for every dataset."""
        return imagenet_templates

    def get_dataset(self, train=False, download=False, transform=None, **kwargs):
        """Instantiate and return the dataset object for ``self.dataset_name``.

        Args:
            train: selects the training split when true. Otherwise the
                ``'test'`` split is used for TinyImageNet and ``'val'`` for
                ImageNet / ImageNet100; the CIFAR datasets and Caltech101
                receive the flag directly as ``train=``.
            download: forwarded as ``download=`` to every dataset except
                Caltech101, which is not given it.
            transform: forwarded as ``transform=``.
            **kwargs: accepted for call-site compatibility and ignored.

        Raises:
            ValueError: if the dataset name is not one of the supported names
                (previously this surfaced as an ``UnboundLocalError``).
        """
        name = self.dataset_name

        if name == "CIFAR10":
            target_dataset = CIFAR10(root=self.root, train=train,
                                     download=download, transform=transform)
        elif name == "CIFAR100":
            target_dataset = CIFAR100(root=self.root, train=train,
                                      download=download, transform=transform)
        elif name == "TinyImageNet":
            target_dataset = TinyImageNet(root=self.root,
                                          split='train' if train else 'test',
                                          download=download,
                                          transform=transform)
        elif name == "ImageNet":
            target_dataset = ImageNet(root=self.root,
                                      split='train' if train else 'val',
                                      download=download,
                                      transform=transform)
        elif name == "ImageNet100":
            target_dataset = ImageNet100(root=self.root,
                                         split='train' if train else 'val',
                                         download=download,
                                         transform=transform)
        elif name == "Caltech101":
            target_dataset = Caltech101(root=self.root, transform=transform,
                                        train=train)
        else:
            raise self._unsupported("A dataset loader")

        return target_dataset

    def get_clean_model(self):
        """Return a pretrained classifier for the dataset.

        CIFAR10 / CIFAR100 load ``cifar10_resnet44`` / ``cifar100_resnet44``
        through ``torch.hub`` from ``chenyaofo/pytorch-cifar-models``;
        TinyImageNet and ImageNet use ``torchvision.models.resnet50``.

        Raises:
            ValueError: for every other dataset name, including ImageNet100
                and Caltech101, which have no model configured here
                (previously this surfaced as an ``UnboundLocalError``).
        """
        name = self.dataset_name

        if name == "CIFAR10":
            model_visual = torch.hub.load("chenyaofo/pytorch-cifar-models",
                                          'cifar10_resnet44', pretrained=True)
        elif name == "CIFAR100":
            model_visual = torch.hub.load("chenyaofo/pytorch-cifar-models",
                                          'cifar100_resnet44', pretrained=True)
        elif name == "TinyImageNet" or name == "ImageNet":
            model_visual = torchvision.models.resnet50(pretrained=True)
        else:
            raise self._unsupported("A clean pretrained model")

        return model_visual

    def n_classes(self):
        """Return the class count configured for the dataset.

        Raises:
            ValueError: if the dataset name has no configured class count
                (previously this surfaced as an ``UnboundLocalError``).
        """
        try:
            return self._N_CLASSES[self.dataset_name]
        except (KeyError, TypeError):
            # TypeError covers an unhashable dataset_name.
            raise self._unsupported("A class count") from None
|
|
|
|
|
|
|
|
class InMemoryDataset(torch.utils.data.Dataset):
    """Dataset that serves samples straight from an indexable container."""

    def __init__(self, data_list):
        # The container is stored by reference; nothing is copied.
        self.data_list = data_list

    def __getitem__(self, idx):
        """Return the element stored at ``idx``, unchanged."""
        samples = self.data_list
        return samples[idx]

    def __len__(self):
        """Return how many elements the wrapped container holds."""
        samples = self.data_list
        return len(samples)
|
|
|
|
|
|
|
|
class TransformedDataset(torch.utils.data.Dataset):
    """Wrap a dataset and optionally apply a transform to each image.

    ``args_cl['approach']`` decides the sample layout expected from the
    wrapped dataset: ``0`` means ``(image, label)`` pairs, anything else
    means 3-tuples whose first element is discarded.
    """

    def __init__(self, args_cl, original_dataset, transform=None):
        self.args_cl = args_cl
        self.transform = transform
        self.original_dataset = original_dataset

    def __len__(self):
        """Return the length of the wrapped dataset."""
        wrapped = self.original_dataset
        return len(wrapped)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``, transformed if configured."""
        # Read the approach flag before touching the dataset so a missing
        # key fails first, matching the original evaluation order.
        is_pair_layout = self.args_cl['approach'] == 0
        sample = self.original_dataset[index]

        if is_pair_layout:
            image, label = sample
        else:
            _, image, label = sample

        transform = self.transform
        if not transform:
            return image, label
        return transform(image), label