import json
import logging
import os
import shutil

import numpy as np
import torch
from PIL import Image
from torchvision import datasets, transforms

from utils.toolkit import split_images_labels


class iData(object):
    """Base description of a dataset used for incremental learning.

    Subclasses override the transform lists / ``class_order`` and implement
    ``download_data`` to populate ``train_data``, ``train_targets``,
    ``test_data`` and ``test_targets``.
    """

    train_trsf = []  # transforms for training samples only
    test_trsf = []  # transforms for test samples only
    # transforms shared by both splits (in subclasses: Normalize); presumably
    # composed after train_trsf / test_trsf by the caller -- confirm
    common_trsf = []
    class_order = None


class iCIFAR10(iData):
    """CIFAR-10 downloaded via torchvision into ``./datasets``."""

    use_path = False  # train_data / test_data are in-memory pixel arrays
    train_trsf = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=63 / 255),
        transforms.ToTensor(),
    ]
    test_trsf = [transforms.ToTensor()]
    common_trsf = [
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)
        ),
    ]

    class_order = np.arange(10).tolist()

    def download_data(self):
        """Download (if needed) and load CIFAR-10 train/test splits."""
        train_dataset = datasets.cifar.CIFAR10("./datasets", train=True, download=True)
        test_dataset = datasets.cifar.CIFAR10("./datasets", train=False, download=True)
        self.train_data, self.train_targets = train_dataset.data, np.array(
            train_dataset.targets
        )
        self.test_data, self.test_targets = test_dataset.data, np.array(
            test_dataset.targets
        )


class iCIFAR100(iData):
    """CIFAR-100 downloaded via torchvision into ``./datasets``."""

    use_path = False  # train_data / test_data are in-memory pixel arrays
    train_trsf = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=63 / 255),
        transforms.ToTensor(),
    ]
    test_trsf = [transforms.ToTensor()]
    common_trsf = [
        transforms.Normalize(
            mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)
        ),
    ]

    class_order = np.arange(100).tolist()

    def download_data(self):
        """Download (if needed) and load CIFAR-100 train/test splits."""
        train_dataset = datasets.cifar.CIFAR100("./datasets", train=True, download=True)
        test_dataset = datasets.cifar.CIFAR100("./datasets", train=False, download=True)
        self.train_data, self.train_targets = train_dataset.data, np.array(
            train_dataset.targets
        )
        self.test_data, self.test_targets = test_dataset.data, np.array(
            test_dataset.targets
        )


class iImageNet1000(iData):
    """ImageNet-1k read from local ImageFolder directories."""

    # train_data presumably holds image file paths (from ImageFolder.imgs via
    # split_images_labels) rather than pixel arrays -- confirm in utils.toolkit
    use_path = True
    train_trsf = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=63 / 255),
        transforms.ToTensor(),
    ]
    test_trsf = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    common_trsf = [
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    class_order = np.arange(1000).tolist()

    def download_data(self):
        """Load train/val splits from local ImageFolder directories."""
        # NOTE: placeholder paths -- replace with the real dataset location.
        train_dir = "you_path/Imagenet/train"
        test_dir = "you_path/Imagenet/val"

        train_dset = datasets.ImageFolder(train_dir)
        test_dset = datasets.ImageFolder(test_dir)

        self.train_data, self.train_targets = split_images_labels(train_dset.imgs)
        self.test_data, self.test_targets = split_images_labels(test_dset.imgs)


class iImageNet100(iData):
    """ImageNet-100 subset read from local ImageFolder directories."""

    # train_data presumably holds image file paths -- confirm in utils.toolkit
    use_path = True
    train_trsf = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    test_trsf = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    common_trsf = [
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    # NOTE(review): 1000 entries for a 100-class dataset. download_data below
    # points at the same directories as iImageNet1000, so this may be
    # deliberate; confirm whether np.arange(100) is intended.
    class_order = np.arange(1000).tolist()

    def download_data(self):
        """Load train/val splits from local ImageFolder directories."""
        # NOTE: placeholder paths -- replace with the ImageNet-100 location.
        train_dir = "you_path/Imagenet/train"
        test_dir = "you_path/Imagenet/val"

        train_dset = datasets.ImageFolder(train_dir)
        test_dset = datasets.ImageFolder(test_dir)

        self.train_data, self.train_targets = split_images_labels(train_dset.imgs)
        self.test_data, self.test_targets = split_images_labels(test_dset.imgs)


_SAVE_MODES = ("above_threshold", "top1", "top1_above_threshold", "top1_for_task0")


def _next_png_index(folder):
    """Return one past the largest integer-named ``*.png`` in *folder* (0 if none)."""
    indices = [
        int(stem)
        for stem, ext in (os.path.splitext(name) for name in os.listdir(folder))
        if ext == ".png" and stem.isdecimal()
    ]
    return max(indices, default=-1) + 1


def save_target_images_above_threshold(asr_matrix, target_imgs, threshold, save_path,
                                       target_class, alpha, mode="above_threshold"):
    """Select target images by attack success rate (ASR) and save them as PNGs.

    Images are written to
    ``<save_path>/target_dataset_<mode-tag>/<target_class>_<threshold>/<k>.png``,
    numbered after the highest existing numeric PNG in that folder so earlier
    files are never overwritten.

    Args:
        asr_matrix: 2-D array-like. Columns index target images; rows are
            averaged over, and row 0 is treated as "task 0" (presumably rows
            are tasks -- confirm with caller).
        target_imgs: indexable collection of images, expected in CHW layout
            with values in [0, 1] (numpy arrays or torch tensors).
        threshold: ASR threshold; also used in the output folder name.
        save_path: root directory under which the output folder is created.
        target_class: class label; used in the output folder name.
        alpha: appended to the folder tag for the "above_threshold" and
            "top1_above_threshold" modes.
        mode: selection strategy:
            - "above_threshold": all images whose row-mean ASR > threshold.
            - "top1": the single image with the highest row-mean ASR
              (threshold is not used for selection).
            - "top1_above_threshold": the highest row-mean-ASR image among
              those above threshold.
            - "top1_for_task0": the highest row-0-ASR image among those whose
              row-0 ASR is above threshold.

    Raises:
        ValueError: if ``mode`` is not one of the supported strategies.
    """
    asr_matrix = np.asarray(asr_matrix)
    if mode == "above_threshold":
        avg_asr = asr_matrix.mean(axis=0)
        selected_indices = np.where(avg_asr > threshold)[0]
        mode = f"above_threshold_alpha{alpha}"
    elif mode == "top1":
        avg_asr = asr_matrix.mean(axis=0)
        selected_indices = [np.argmax(avg_asr)]
    elif mode == "top1_above_threshold":
        avg_asr = asr_matrix.mean(axis=0)
        selected_indices = np.where(avg_asr > threshold)[0]
        if len(selected_indices) > 0:
            selected_indices = [selected_indices[np.argmax(avg_asr[selected_indices])]]
        mode = f"top1_above_threshold_alpha{alpha}"
    elif mode == "top1_for_task0":
        asr_task0 = asr_matrix[0, :]
        selected_indices = np.where(asr_task0 > threshold)[0]
        if len(selected_indices) > 0:
            selected_indices = [selected_indices[np.argmax(asr_task0[selected_indices])]]
    else:
        raise ValueError(f"Unknown mode {mode!r}; expected one of {_SAVE_MODES}.")

    if len(selected_indices) == 0:
        logging.info("No target images with average ASR above the threshold.")
        return

    target_folder = os.path.join(save_path, f'target_dataset_{mode}', f"{target_class}_{threshold}")
    os.makedirs(target_folder, exist_ok=True)
    # Continue after the highest existing index (not the file count) so that
    # gaps in the numbering can never cause an existing image to be overwritten.
    next_index = _next_png_index(target_folder)

    for idx in selected_indices:
        target_image = target_imgs[idx]
        if isinstance(target_image, torch.Tensor):
            target_image = target_image.detach().cpu().numpy()
        # Clip before the uint8 cast so out-of-range values saturate instead
        # of wrapping around (no-op for values already in [0, 1]).
        pixels = np.clip(np.asarray(target_image) * 255, 0, 255).astype(np.uint8)
        target_image_pil = Image.fromarray(np.moveaxis(pixels, 0, -1))  # CHW -> HWC
        target_image_name = f"{next_index}.png"
        target_image_pil.save(os.path.join(target_folder, target_image_name))
        logging.info(f"Saved target image {next_index} to {os.path.join(target_folder, target_image_name)}")
        next_index += 1

    logging.info(f"Target images saved to {target_folder}")


def load_target_imgs(logs_name, target_class, alpha, threshold, image_size=(32, 32)):
    """Load saved target PNGs as a float tensor batch plus matching labels.

    Reads every ``*.png`` (in lexicographic file-name order) from
    ``<logs_name>/target_dataset_alpha<alpha>/<target_class>_<threshold>``,
    converts it to RGB, resizes it to ``image_size`` and scales pixels to
    [0, 1].

    NOTE(review): save_target_images_above_threshold writes to
    ``target_dataset_<mode-tag>`` folders (e.g.
    ``target_dataset_above_threshold_alpha<alpha>``), which differ from the
    folder read here -- confirm the intended layout.

    Args:
        logs_name: root directory containing the target dataset folder.
        target_class: class label; used in the folder name and as every label.
        alpha: used in the folder name.
        threshold: used in the folder name.
        image_size: ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        ``(images, labels)`` where ``images`` is a float32 tensor of shape
        ``(N, 3, H, W)`` and ``labels`` a 1-D tensor filled with
        ``target_class``; ``(None, None)`` if the folder is missing or holds
        no PNG files.
    """
    target_folder = os.path.join(logs_name, f'target_dataset_alpha{alpha}', f"{target_class}_{threshold}")
    if not os.path.exists(target_folder):
        logging.error(f"Target folder {target_folder} does not exist.")
        return None, None

    target_imgs = []
    target_labels = []
    for file_name in sorted(os.listdir(target_folder)):
        if not file_name.endswith(".png"):
            continue
        target_image_path = os.path.join(target_folder, file_name)
        # Context manager closes the file handle; convert("RGB") guarantees
        # three channels so grayscale/RGBA/palette PNGs still yield CHW arrays.
        with Image.open(target_image_path) as target_image_pil:
            target_image = np.array(target_image_pil.convert("RGB").resize(image_size))
        target_image = np.moveaxis(target_image, -1, 0)  # HWC -> CHW
        target_imgs.append(torch.from_numpy(target_image.astype(np.float32) / 255.0))
        target_labels.append(target_class)

    if not target_imgs:
        # torch.stack would fail on an empty list; report it like a missing folder.
        logging.error(f"No PNG images found in target folder {target_folder}.")
        return None, None

    # Stack the list of CHW tensors into a single (N, C, H, W) torch tensor.
    target_imgs_tensor = torch.stack(target_imgs)
    target_labels_tensor = torch.tensor(target_labels)
    logging.info(f"Loaded {len(target_imgs)} target images from {target_folder}")
    return target_imgs_tensor, target_labels_tensor


def load_json(settings_path):
    """Parse the JSON file at *settings_path* and return the decoded object.

    The file is decoded as UTF-8 (the encoding JSON mandates) rather than the
    platform's locale default.
    """
    with open(settings_path, encoding="utf-8") as data_file:
        param = json.load(data_file)

    return param