|
|
import numpy as np |
|
|
from torchvision import datasets, transforms |
|
|
from utils.toolkit import split_images_labels |
|
|
import os |
|
|
import shutil |
|
|
import torch |
|
|
from PIL import Image |
|
|
import logging |
|
|
import json |
|
|
|
|
|
|
|
|
class iData(object):
    """Base descriptor for a dataset used in class-incremental learning.

    Subclasses fill in the transform lists and class ordering; the lists are
    presumably composed into full torchvision pipelines elsewhere — confirm
    against the consumer of these attributes.
    """

    # Augmentations applied only while training.
    train_trsf = []

    # Transforms applied only at evaluation time.
    test_trsf = []

    # Transforms shared by both phases (e.g. normalization).
    common_trsf = []

    # Order in which class labels are presented across incremental tasks.
    class_order = None
|
|
|
|
|
|
|
|
class iCIFAR10(iData):
    """CIFAR-10 descriptor: images held in memory as arrays, 10 classes."""

    # Samples are raw arrays, not file paths.
    use_path = False

    # Training augmentation: pad-and-crop, horizontal flip, brightness jitter.
    train_trsf = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=63 / 255),
        transforms.ToTensor(),
    ]
    # Evaluation: tensor conversion only.
    test_trsf = [transforms.ToTensor()]
    # Shared normalization with CIFAR-10 channel statistics.
    common_trsf = [
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)
        ),
    ]

    # Classes arrive in natural label order.
    class_order = list(range(10))

    def download_data(self):
        """Fetch CIFAR-10 into ./datasets and cache train/test arrays on self."""
        train_set = datasets.cifar.CIFAR10("./datasets", train=True, download=True)
        test_set = datasets.cifar.CIFAR10("./datasets", train=False, download=True)

        self.train_data = train_set.data
        self.train_targets = np.array(train_set.targets)
        self.test_data = test_set.data
        self.test_targets = np.array(test_set.targets)
|
|
|
|
|
|
|
|
class iCIFAR100(iData):
    """CIFAR-100 descriptor: images held in memory as arrays, 100 classes."""

    # Samples are raw arrays, not file paths.
    use_path = False

    # Training augmentation: pad-and-crop, horizontal flip, brightness jitter.
    train_trsf = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=63 / 255),
        transforms.ToTensor(),
    ]
    # Evaluation: tensor conversion only.
    test_trsf = [transforms.ToTensor()]
    # Shared normalization with CIFAR-100 channel statistics.
    common_trsf = [
        transforms.Normalize(
            mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)
        ),
    ]

    # Classes arrive in natural label order.
    class_order = list(range(100))

    def download_data(self):
        """Fetch CIFAR-100 into ./datasets and cache train/test arrays on self."""
        train_set = datasets.cifar.CIFAR100("./datasets", train=True, download=True)
        test_set = datasets.cifar.CIFAR100("./datasets", train=False, download=True)

        self.train_data = train_set.data
        self.train_targets = np.array(train_set.targets)
        self.test_data = test_set.data
        self.test_targets = np.array(test_set.targets)
|
|
|
|
|
|
|
|
class iImageNet1000(iData):
    """ImageNet-1000 descriptor: samples are referenced by file path."""

    # Samples are loaded lazily from disk via their paths.
    use_path = True

    # Training augmentation: random resized crop, flip, brightness jitter.
    train_trsf = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=63 / 255),
        transforms.ToTensor(),
    ]
    # Evaluation: standard resize-then-center-crop.
    test_trsf = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    # Shared normalization with ImageNet channel statistics.
    common_trsf = [
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    # Classes arrive in natural label order.
    class_order = list(range(1000))

    def download_data(self):
        """Index the on-disk ImageNet folders into path/label arrays on self."""
        # NOTE: placeholder locations — point these at the real ImageNet layout.
        train_dir = "you_path/Imagenet/train"
        test_dir = "you_path/Imagenet/val"

        train_folder = datasets.ImageFolder(train_dir)
        test_folder = datasets.ImageFolder(test_dir)

        self.train_data, self.train_targets = split_images_labels(train_folder.imgs)
        self.test_data, self.test_targets = split_images_labels(test_folder.imgs)
|
|
|
|
|
|
|
|
class iImageNet100(iData):
    """Descriptor for the 100-class ImageNet subset; samples referenced by path."""

    # Samples are loaded lazily from disk via their paths.
    use_path = True

    # Training augmentation: random resized crop and flip (no color jitter here).
    train_trsf = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    # Evaluation: standard resize-then-center-crop.
    test_trsf = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    # Shared normalization with ImageNet channel statistics.
    common_trsf = [
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    # NOTE(review): 1000 entries despite the 100-class name — presumably only a
    # prefix is consumed downstream; confirm against the caller before changing.
    class_order = list(range(1000))

    def download_data(self):
        """Index the on-disk ImageNet folders into path/label arrays on self."""
        # NOTE: placeholder locations — point these at the real ImageNet layout.
        train_dir = "you_path/Imagenet/train"
        test_dir = "you_path/Imagenet/val"

        train_folder = datasets.ImageFolder(train_dir)
        test_folder = datasets.ImageFolder(test_dir)

        self.train_data, self.train_targets = split_images_labels(train_folder.imgs)
        self.test_data, self.test_targets = split_images_labels(test_folder.imgs)
|
|
|
|
|
|
|
|
def save_target_images_above_threshold(asr_matrix, target_imgs, threshold, save_path, target_class, alpha,
                                       mode="above_threshold"):
    """Select target images by attack-success-rate (ASR) criteria and save them as PNGs.

    Args:
        asr_matrix: 2-D array of shape (num_tasks, num_images) holding per-task
            ASR values for each candidate image.
        target_imgs: indexable collection of images in CHW float format with
            values in [0, 1] — assumed from the *255/uint8 conversion below;
            confirm against the producer.
        threshold: minimum ASR an image must exceed to be selected.
        save_path: root directory under which the target-dataset folder is created.
        target_class: class label, used in the output folder name.
        alpha: coefficient recorded in the folder name for threshold-based modes.
        mode: selection strategy — 'above_threshold', 'top1',
            'top1_above_threshold', or 'top1_for_task0'.

    Raises:
        ValueError: if *mode* is not a supported strategy.

    BUGFIX: the default was the typo "aobve_threshold", which matched no branch
    and left `selected_indices` unbound, crashing the default call path with
    NameError. Unknown modes now raise ValueError explicitly instead.
    """
    if mode == "above_threshold":
        # Every image whose ASR averaged over tasks clears the threshold.
        avg_asr = asr_matrix.mean(axis=0)
        selected_indices = np.where(avg_asr > threshold)[0]
        mode = f"above_threshold_alpha{alpha}"  # folder name records alpha
    elif mode == 'top1':
        # Single image with the highest task-averaged ASR, regardless of threshold.
        avg_asr = asr_matrix.mean(axis=0)
        selected_indices = [np.argmax(avg_asr)]
    elif mode == 'top1_above_threshold':
        # Best image among those above the threshold (empty if none qualify).
        avg_asr = asr_matrix.mean(axis=0)
        selected_indices = np.where(avg_asr > threshold)[0]
        if len(selected_indices) > 0:
            selected_indices = [selected_indices[np.argmax(avg_asr[selected_indices])]]
        mode = f"top1_above_threshold_alpha{alpha}"
    elif mode == 'top1_for_task0':
        # Same as top1_above_threshold, but judged on the first task's ASR only.
        asr_task0 = asr_matrix[0, :]
        selected_indices = np.where(asr_task0 > threshold)[0]
        if len(selected_indices) > 0:
            selected_indices = [selected_indices[np.argmax(asr_task0[selected_indices])]]
    else:
        raise ValueError(f"Unsupported mode: {mode!r}")

    if len(selected_indices) == 0:
        logging.info("No target images with average ASR above the threshold.")
        return

    target_folder = os.path.join(save_path, f'target_dataset_{mode}', f"{target_class}_{threshold}")
    os.makedirs(target_folder, exist_ok=True)

    # Continue numbering after any images saved by earlier calls into this folder.
    existing_files = sorted(os.listdir(target_folder))
    next_index = len(existing_files)

    for idx in selected_indices:
        target_image = target_imgs[idx]

        # CHW float in [0, 1] -> HWC uint8, as PIL expects.
        target_image_pil = Image.fromarray(np.moveaxis((target_image * 255).astype(np.uint8), 0, -1))
        target_image_name = f"{next_index}.png"
        target_image_pil.save(os.path.join(target_folder, target_image_name))

        logging.info(f"Saved target image {next_index} to {os.path.join(target_folder, target_image_name)}")
        next_index += 1

    logging.info(f"Target images saved to {target_folder}")
|
|
|
|
|
|
|
|
def load_target_imgs(logs_name, target_class, alpha, threshold):
    """Load previously saved target images for *target_class* as tensors.

    Reads every ``.png`` in ``<logs_name>/target_dataset_alpha<alpha>/<class>_<threshold>``,
    resizes to 32x32 (presumably the downstream input size — confirm against the
    consumer), and converts to CHW float tensors in [0, 1].

    Args:
        logs_name: root log directory containing the target-dataset folders.
        target_class: class label; also used as the label for every loaded image.
        alpha: coefficient embedded in the folder name.
        threshold: threshold embedded in the folder name.

    Returns:
        Tuple ``(imgs, labels)`` of a stacked float tensor and a label tensor, or
        ``(None, None)`` when the folder is missing or contains no PNG images.

    BUGFIX: an existing-but-empty folder previously reached ``torch.stack([])``,
    which raises RuntimeError; it now returns ``(None, None)`` like the
    missing-folder path.
    """
    target_folder = os.path.join(logs_name, f'target_dataset_alpha{alpha}', f"{target_class}_{threshold}")

    if not os.path.exists(target_folder):
        logging.error(f"Target folder {target_folder} does not exist.")
        return None, None

    target_imgs = []
    target_labels = []

    # Sorted so images load in the same numbered order they were saved.
    for file_name in sorted(os.listdir(target_folder)):
        if file_name.endswith(".png"):
            target_image_path = os.path.join(target_folder, file_name)
            target_image_pil = Image.open(target_image_path)
            target_image_pil = target_image_pil.resize((32, 32))
            target_image = np.array(target_image_pil)

            # HWC uint8 -> CHW float in [0, 1].
            target_image = np.moveaxis(target_image, -1, 0)
            target_image_tensor = torch.from_numpy(target_image.astype(np.float32) / 255.0)

            target_imgs.append(target_image_tensor)
            target_labels.append(target_class)

    # Guard: torch.stack on an empty list would raise RuntimeError.
    if not target_imgs:
        logging.error(f"No PNG images found in {target_folder}.")
        return None, None

    target_imgs_tensor = torch.stack(target_imgs)
    target_labels_tensor = torch.tensor(target_labels)

    logging.info(f"Loaded {len(target_imgs)} target images from {target_folder}")
    return target_imgs_tensor, target_labels_tensor
|
|
|
|
|
|
|
|
def load_json(settings_path):
    """Parse the JSON file at *settings_path* and return its contents."""
    with open(settings_path) as fh:
        return json.load(fh)