# Source path: SAE/utils/data.py
# Repository page residue kept for provenance:
#   "Ttius's picture" / "Upload 192 files" / commit 998bb30 (verified)
import numpy as np
from torchvision import datasets, transforms
from utils.toolkit import split_images_labels
import os
import shutil
import torch
from PIL import Image
import logging
import json
class iData:
    """Base description of a dataset: transform lists and class ordering.

    Subclasses in this module override these attributes and add a
    ``download_data`` method that fills ``train_data``/``train_targets`` and
    ``test_data``/``test_targets`` on the instance.
    """

    # Transforms listed for training samples (empty in the base class).
    train_trsf = []
    # Transforms listed for evaluation samples (empty in the base class).
    test_trsf = []
    # Transforms listed for both splits; presumably appended after the
    # split-specific ones by the consumer -- TODO confirm against the caller.
    common_trsf = []
    # Ordering of class indices; None until a subclass provides one.
    class_order = None
class iCIFAR10(iData):
    """CIFAR-10 dataset description: transforms, class order and loader."""

    # download_data stores the torchvision dataset's ``.data`` attribute
    # directly; presumably this flag tells the consumer that samples are
    # arrays rather than file paths -- TODO confirm against the caller.
    use_path = False

    # Training-time augmentation, ending with conversion to a tensor.
    train_trsf = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=63 / 255),
        transforms.ToTensor(),
    ]
    # Evaluation only converts to a tensor.
    test_trsf = [transforms.ToTensor()]
    # Per-channel normalization listed for both splits.
    common_trsf = [
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2023, 0.1994, 0.2010),
        ),
    ]

    # Identity ordering over the 10 class indices.
    class_order = list(range(10))

    def download_data(self):
        """Fetch CIFAR-10 into ``./datasets`` and store both splits on self.

        Sets ``train_data``/``test_data`` to the datasets' ``.data`` and
        ``train_targets``/``test_targets`` to NumPy arrays of the labels.
        """
        train_split = datasets.cifar.CIFAR10("./datasets", train=True, download=True)
        test_split = datasets.cifar.CIFAR10("./datasets", train=False, download=True)

        self.train_data = train_split.data
        self.train_targets = np.array(train_split.targets)
        self.test_data = test_split.data
        self.test_targets = np.array(test_split.targets)
class iCIFAR100(iData):
    """CIFAR-100 dataset description: transforms, class order and loader."""

    # download_data stores the torchvision dataset's ``.data`` attribute
    # directly; presumably this flag tells the consumer that samples are
    # arrays rather than file paths -- TODO confirm against the caller.
    use_path = False

    # Training-time augmentation, ending with conversion to a tensor.
    train_trsf = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=63 / 255),
        transforms.ToTensor(),
    ]
    # Evaluation only converts to a tensor.
    test_trsf = [transforms.ToTensor()]
    # Per-channel normalization listed for both splits.
    common_trsf = [
        transforms.Normalize(
            mean=(0.5071, 0.4867, 0.4408),
            std=(0.2675, 0.2565, 0.2761),
        ),
    ]

    # Identity ordering over the 100 class indices.
    class_order = list(range(100))

    def download_data(self):
        """Fetch CIFAR-100 into ``./datasets`` and store both splits on self.

        Sets ``train_data``/``test_data`` to the datasets' ``.data`` and
        ``train_targets``/``test_targets`` to NumPy arrays of the labels.
        """
        train_split = datasets.cifar.CIFAR100("./datasets", train=True, download=True)
        test_split = datasets.cifar.CIFAR100("./datasets", train=False, download=True)

        self.train_data = train_split.data
        self.train_targets = np.array(train_split.targets)
        self.test_data = test_split.data
        self.test_targets = np.array(test_split.targets)
class iImageNet1000(iData):
    """ImageNet (1000 classes) description: transforms, class order, loader."""

    # download_data stores what split_images_labels returns for the
    # ImageFolder ``imgs`` entries; presumably those are file paths that the
    # consumer opens lazily -- TODO confirm against the caller.
    use_path = True

    # Training-time augmentation, ending with conversion to a tensor.
    train_trsf = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=63 / 255),
        transforms.ToTensor(),
    ]
    # Deterministic resize + center crop for evaluation.
    test_trsf = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    # ToTensor already appears in both split-specific lists, so only the
    # normalization is listed here.
    common_trsf = [
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    # Identity ordering over the 1000 class indices.
    class_order = list(range(1000))

    def download_data(self):
        """Index the train/val image folders and store both splits on self.

        NOTE: the directory strings below look like placeholders and
        presumably must be edited to point at a local ImageNet copy.
        """
        train_folder = datasets.ImageFolder("you_path/Imagenet/train")
        test_folder = datasets.ImageFolder("you_path/Imagenet/val")

        train_pair = split_images_labels(train_folder.imgs)
        self.train_data, self.train_targets = train_pair
        test_pair = split_images_labels(test_folder.imgs)
        self.test_data, self.test_targets = test_pair
class iImageNet100(iData):
    """ImageNet-100 description: transforms, class order and loader."""

    # download_data stores what split_images_labels returns for the
    # ImageFolder ``imgs`` entries; presumably those are file paths that the
    # consumer opens lazily -- TODO confirm against the caller.
    use_path = True

    # Training-time augmentation (no color jitter here), ending with
    # conversion to a tensor.
    train_trsf = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    # Deterministic resize + center crop for evaluation.
    test_trsf = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    # ToTensor already appears in both split-specific lists, so only the
    # normalization is listed here.
    common_trsf = [
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    # NOTE(review): 1000 entries although the class is named ImageNet-100 --
    # confirm whether the consumer expects 100 entries here.
    class_order = list(range(1000))

    def download_data(self):
        """Index the train/val image folders and store both splits on self.

        NOTE: the directory strings below look like placeholders and
        presumably must be edited to point at a local ImageNet-100 copy.
        """
        train_folder = datasets.ImageFolder("you_path/Imagenet/train")
        test_folder = datasets.ImageFolder("you_path/Imagenet/val")

        train_pair = split_images_labels(train_folder.imgs)
        self.train_data, self.train_targets = train_pair
        test_pair = split_images_labels(test_folder.imgs)
        self.test_data, self.test_targets = test_pair
def _best_index_above_threshold(scores, threshold):
    """Return the index of the largest score strictly above ``threshold``.

    The result is a one-element list, or an empty index array when no score
    exceeds the threshold.
    """
    candidates = np.where(scores > threshold)[0]
    if len(candidates) == 0:
        return candidates
    return [candidates[np.argmax(scores[candidates])]]


def save_target_images_above_threshold(asr_matrix, target_imgs, threshold, save_path, target_class, alpha,
                                       mode="above_threshold"):
    """Select target images by their ASR scores and save them as PNG files.

    Args:
        asr_matrix: 2-D NumPy array of scores. Columns are indexed like
            ``target_imgs``; rows are averaged (``mean(axis=0)``) except in
            ``'top1_for_task0'``, which uses the first row only.
        target_imgs: Indexable collection of NumPy image arrays. Each image is
            scaled by 255 and its first axis is moved last before saving, so
            it is presumably channel-first with values in [0, 1] -- TODO
            confirm against the caller.
        threshold: Score a column must strictly exceed to be selected (unused
            by ``'top1'``). Also part of the output folder name.
        save_path: Root directory under which the output folder is created.
        target_class: Class identifier used in the output folder name.
        alpha: Value embedded in the folder name for the ``'above_threshold'``
            and ``'top1_above_threshold'`` modes.
        mode: Selection strategy:
            ``'above_threshold'`` -- every column whose mean exceeds threshold;
            ``'top1'`` -- the column with the highest mean;
            ``'top1_above_threshold'`` -- the highest-mean column, only if it
            exceeds threshold;
            ``'top1_for_task0'`` -- the highest first-row column, only if it
            exceeds threshold.

    Returns:
        None. Nothing is written when no column is selected.

    Raises:
        ValueError: If ``mode`` is not one of the strategies above.
    """
    # The default used to be the misspelling "aobve_threshold", which matched
    # no branch and crashed with UnboundLocalError on every default call.
    if mode == "above_threshold":
        avg_asr = asr_matrix.mean(axis=0)
        selected_indices = np.where(avg_asr > threshold)[0]
        mode = f"above_threshold_alpha{alpha}"
    elif mode == 'top1':
        avg_asr = asr_matrix.mean(axis=0)
        selected_indices = [np.argmax(avg_asr)]
    elif mode == 'top1_above_threshold':
        avg_asr = asr_matrix.mean(axis=0)
        selected_indices = _best_index_above_threshold(avg_asr, threshold)
        mode = f"top1_above_threshold_alpha{alpha}"
    elif mode == 'top1_for_task0':
        selected_indices = _best_index_above_threshold(asr_matrix[0, :], threshold)
    else:
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'above_threshold', 'top1', "
            "'top1_above_threshold' or 'top1_for_task0'."
        )

    if len(selected_indices) == 0:
        logging.info("No target images with average ASR above the threshold.")
        return

    # ``mode`` may have been rewritten above to carry the alpha suffix.
    target_folder = os.path.join(save_path, f'target_dataset_{mode}', f"{target_class}_{threshold}")
    os.makedirs(target_folder, exist_ok=True)

    # Continue numbering after whatever the folder already contains.
    next_index = len(os.listdir(target_folder))
    for idx in selected_indices:
        # Clip before the uint8 cast so slightly out-of-range values saturate
        # instead of wrapping around.
        pixels = np.clip(target_imgs[idx] * 255, 0, 255).astype(np.uint8)
        image_path = os.path.join(target_folder, f"{next_index}.png")
        Image.fromarray(np.moveaxis(pixels, 0, -1)).save(image_path)
        logging.info(f"Saved target image {next_index} to {image_path}")
        next_index += 1
    logging.info(f"Target images saved to {target_folder}")
def load_target_imgs(logs_name, target_class, alpha, threshold, image_size=(32, 32)):
    """Load saved target PNG images for one class as a stacked torch tensor.

    Images are read from
    ``<logs_name>/target_dataset_alpha<alpha>/<target_class>_<threshold>``.
    NOTE(review): ``save_target_images_above_threshold`` writes to
    ``target_dataset_<mode>`` folders whose names carry a mode prefix --
    confirm that the folder read here is the intended one.

    Args:
        logs_name: Root directory containing the target dataset folders.
        target_class: Class identifier; part of the folder name and the label
            assigned to every loaded image.
        alpha: Value embedded in the folder name.
        threshold: Value embedded in the folder name.
        image_size: ``(width, height)`` each image is resized to before
            conversion. Defaults to the previously hard-coded ``(32, 32)``.

    Returns:
        ``(images, labels)``: a float32 tensor of the stacked images, scaled
        by 1/255 with the last array axis moved first, and a tensor holding
        ``target_class`` once per image. ``(None, None)`` when the folder is
        missing or contains no ``.png`` file.
    """
    target_folder = os.path.join(logs_name, f'target_dataset_alpha{alpha}', f"{target_class}_{threshold}")
    if not os.path.exists(target_folder):
        logging.error(f"Target folder {target_folder} does not exist.")
        return None, None

    target_imgs = []
    for file_name in sorted(os.listdir(target_folder)):
        if not file_name.endswith(".png"):
            continue
        # The context manager releases the file handle once the pixels have
        # been copied into the array.
        with Image.open(os.path.join(target_folder, file_name)) as target_image_pil:
            target_image = np.array(target_image_pil.resize(image_size))
        # Move the last axis (channels, for multi-channel images) first.
        target_image = np.moveaxis(target_image, -1, 0)
        target_imgs.append(torch.from_numpy(target_image.astype(np.float32) / 255.0))

    if not target_imgs:
        # torch.stack raises on an empty list; report it like a missing folder.
        logging.error(f"No .png target images found in {target_folder}.")
        return None, None

    # Convert the list of per-image tensors into a single torch tensor.
    target_imgs_tensor = torch.stack(target_imgs)
    target_labels_tensor = torch.tensor([target_class] * len(target_imgs))
    logging.info(f"Loaded {len(target_imgs)} target images from {target_folder}")
    return target_imgs_tensor, target_labels_tensor
def load_json(settings_path):
    """Parse the JSON file at ``settings_path`` and return the decoded object.

    The file is opened in binary mode so that ``json`` detects the encoding
    (UTF-8, UTF-16 or UTF-32, with or without a BOM) itself, instead of
    relying on the platform's locale-dependent default text encoding.

    Args:
        settings_path: Path to the JSON file.

    Returns:
        The decoded JSON value (a dict for a top-level JSON object).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(settings_path, "rb") as data_file:
        return json.load(data_file)