# SAE/attacks/UnivIntruder/dataloader.py
# (Web-viewer residue from the original upload page, kept as comments:
#  "Ttius's picture", "Upload 192 files", commit 998bb30, verified.)
import os
from torchvision.datasets import CIFAR10, CIFAR100
import torch
import torchvision
from torchvision import transforms
from attacks.UnivIntruder.dataset import TinyImageNet, ImageNet, Caltech101, ImageNet100
from attacks.UnivIntruder.utils_.text_templates import imagenet_templates
def reverse_dict(label2class):
    """Return a new dict mapping each value of *label2class* back to its key.

    If several keys share the same value, the key that appears last in
    iteration order wins, because later assignments overwrite earlier ones.
    """
    class2label = {}
    for key, value in label2class.items():
        class2label[value] = key
    return class2label
class GetDatasetMeta():
    """Look up dataset-specific metadata and objects from a dataset name.

    The name is compared against the literal strings "CIFAR10", "CIFAR100",
    "TinyImageNet", "ImageNet", "ImageNet100" and "Caltech101". Methods that
    must return a concrete object (dataset, model, class count) raise
    ``ValueError`` for a name they do not support; the two lookup helpers
    ``get_dataset_label_names`` and ``get_transformation`` return ``None``
    instead.
    """

    # Datasets that share the ImageNet label map and preprocessing settings.
    _IMAGENET_FAMILY = ("TinyImageNet", "ImageNet", "ImageNet100")

    # Number of classes reported by n_classes() for each supported name.
    # NOTE(review): 500 for TinyImageNet is kept exactly as in the original
    # code; confirm it matches the project's TinyImageNet dataset class.
    _CLASS_COUNTS = {
        "CIFAR10": 10,
        "CIFAR100": 100,
        "TinyImageNet": 500,
        "ImageNet": 1000,
        "ImageNet100": 100,
        "Caltech101": 101,
    }

    def __init__(self, root, dataset_name) -> None:
        """Store the dataset root directory and the dataset name."""
        self.dataset_name = dataset_name
        self.root = root

    def _unsupported(self, what):
        """Build the ValueError raised when *what* cannot be provided."""
        return ValueError(
            "{} is not available for dataset {!r}".format(what, self.dataset_name)
        )

    def get_dataset_label_names(self):
        """Return a dict mapping integer label -> class name, or None.

        * CIFAR10 / CIFAR100: the inverse of the dataset's ``class_to_idx``.
        * TinyImageNet / ImageNet / ImageNet100: parsed from the
          comma-separated file ``attacks/UnivIntruder/utils_/map_clsloc.txt``
          (path is relative to the current working directory); field 1 is
          used as the integer label and field 2 as the name.
        * Caltech101: the first 101 entries of the wrapped dataset's
          ``categories`` list.
        * Any other name: ``None``.
        """
        name = self.dataset_name
        if name == "CIFAR10":
            return reverse_dict(CIFAR10(root=self.root).class_to_idx)
        if name == "CIFAR100":
            return reverse_dict(CIFAR100(root=self.root).class_to_idx)
        if name in self._IMAGENET_FAMILY:
            label_dict = {}
            with open('attacks/UnivIntruder/utils_/map_clsloc.txt', 'r') as file:
                for line in file:
                    line = line.strip()
                    if not line:
                        # Blank (e.g. trailing) lines carry no mapping and
                        # would otherwise fail with IndexError below.
                        continue
                    parts = line.split(',')
                    label_dict[int(parts[1])] = parts[2]
            return label_dict
        if name == "Caltech101":
            # The project's Caltech101 wrapper exposes the underlying dataset
            # as ``.dataset``; its ``categories`` list is indexed by label.
            label_list = self.get_dataset().dataset.categories
            return {i: label_list[i] for i in range(101)}
        return None

    def get_transformation(self):
        """Return a Resize + Normalize pipeline for the dataset, or None.

        The pipeline resizes to a square of the dataset's working size and
        normalizes with per-channel mean/std. It contains no tensor
        conversion step, so it presumably expects tensor input - TODO confirm
        against callers. Unknown dataset names yield ``None``.
        """
        name = self.dataset_name
        if name == "CIFAR10":
            size = 32
            mean = [0.4914, 0.4822, 0.4465]
            std = [0.2023, 0.1994, 0.2010]
        elif name == "CIFAR100":
            size = 32
            mean = [0.5070751592371323, 0.48654887331495095, 0.4409178433670343]
            std = [0.2673342858792401, 0.2564384629170883, 0.27615047132568404]
        elif name in self._IMAGENET_FAMILY or name == "Caltech101":
            size = 224
            # An alternative set of statistics was left commented out in the
            # original for the ImageNet family:
            #   mean [0.4802, 0.4481, 0.3975], std [0.2302, 0.2265, 0.2262]
            mean = [0.485, 0.456, 0.406]
            std = [0.229, 0.224, 0.225]
        else:
            return None

        return transforms.Compose([
            transforms.Resize((size, size)),
            transforms.Normalize(mean=mean, std=std),
        ])

    def get_template(self):
        """Return the text templates; the same list is used for every dataset."""
        return imagenet_templates

    def get_dataset(self, train=False, download=False, transform=None, **kwargs):
        """Instantiate and return the dataset selected by ``dataset_name``.

        Args:
            train: select the training split when true; otherwise the
                evaluation split ('test' for TinyImageNet, 'val' for
                ImageNet / ImageNet100, ``train=False`` for the others).
            download: forwarded to every dataset except Caltech101.
            transform: forwarded to the dataset constructor.
            **kwargs: accepted for call-site compatibility and ignored.

        Raises:
            ValueError: if ``dataset_name`` is not a supported dataset.
        """
        name = self.dataset_name
        if name == "CIFAR10":
            return CIFAR10(root=self.root, train=train, download=download, transform=transform)
        if name == "CIFAR100":
            return CIFAR100(root=self.root, train=train, download=download, transform=transform)
        if name == "TinyImageNet":
            return TinyImageNet(root=self.root, split='train' if train else 'test',
                                download=download, transform=transform)
        if name == "ImageNet":
            return ImageNet(root=self.root, split='train' if train else 'val',
                            download=download, transform=transform)
        if name == "ImageNet100":
            return ImageNet100(root=self.root, split='train' if train else 'val',
                               download=download, transform=transform)
        if name == "Caltech101":
            return Caltech101(root=self.root, transform=transform, train=train)
        # Previously this fell through to an UnboundLocalError at `return`.
        raise self._unsupported("A dataset")

    def get_clean_model(self):
        """Load and return a pretrained classifier for the dataset.

        CIFAR10 / CIFAR100 load a ResNet-44 through ``torch.hub`` from
        ``chenyaofo/pytorch-cifar-models``; TinyImageNet / ImageNet use
        torchvision's pretrained ResNet-50.

        Raises:
            ValueError: if no clean model is defined for ``dataset_name``
                (this includes "ImageNet100" and "Caltech101").
        """
        name = self.dataset_name
        if name == "CIFAR10":
            return torch.hub.load("chenyaofo/pytorch-cifar-models", 'cifar10_resnet44', pretrained=True)
        if name == "CIFAR100":
            return torch.hub.load("chenyaofo/pytorch-cifar-models", 'cifar100_resnet44', pretrained=True)
        if name in ("TinyImageNet", "ImageNet"):
            return torchvision.models.resnet50(pretrained=True)
        # Previously this fell through to an UnboundLocalError at `return`.
        raise self._unsupported("A clean model")

    def n_classes(self):
        """Return the number of classes for the dataset.

        Raises:
            ValueError: if ``dataset_name`` is not a supported dataset.
        """
        try:
            return self._CLASS_COUNTS[self.dataset_name]
        except (KeyError, TypeError):
            # TypeError covers an unhashable dataset_name; `from None` is
            # omitted so the original lookup failure stays in the traceback.
            raise self._unsupported("A class count")
class InMemoryDataset(torch.utils.data.Dataset):
    """Dataset view over an already-materialized, indexable collection.

    Length and item access are delegated directly to ``data_list``; items
    are returned exactly as stored.
    """

    def __init__(self, data_list):
        """Keep a reference to *data_list* (no copy is made)."""
        self.data_list = data_list

    def __len__(self):
        """Return the number of stored items."""
        stored = self.data_list
        return len(stored)

    def __getitem__(self, idx):
        """Return the item stored at position *idx*."""
        stored = self.data_list
        return stored[idx]
class TransformedDataset(torch.utils.data.Dataset):
    """Wrap a dataset and optionally apply a transform to each image.

    ``args_cl['approach']`` selects the record layout of the wrapped
    dataset: when it equals 0 each record is ``(image, label)``; otherwise
    each record is a 3-tuple whose first element is discarded and whose
    remaining elements are ``(image, label)``.
    """

    def __init__(self, args_cl, original_dataset, transform=None):
        """Store the configuration mapping, wrapped dataset and transform."""
        self.args_cl = args_cl
        self.transform = transform
        self.original_dataset = original_dataset

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.original_dataset)

    def __getitem__(self, index):
        """Return ``(image, label)`` for *index*, transformed if configured."""
        # Read the layout flag before touching the dataset, as the original does.
        two_field_records = self.args_cl['approach'] == 0
        record = self.original_dataset[index]
        if two_field_records:
            image, label = record
        else:
            _, image, label = record

        # A falsy transform (e.g. None) means the image is returned untouched.
        if not self.transform:
            return image, label
        return self.transform(image), label