# SAE / attacks / attack_config.py
# Uploaded by Ttius ("Upload 192 files"), commit 998bb30 (verified).
import os
from foolbox import PyTorchModel, accuracy
from foolbox.attacks.base import *
from foolbox.attacks.gradient_descent_base import *
from torchvision import transforms
from utils.data_manager import DataManager, get_dataloader
import torch
import logging
import eagerpy as ep
from utils.data_manager import load_all_task_models
class SustainableAttack(Attack):
    """Shared setup and sample-filtering utilities for attacks on continual-learning models.

    ``__init__`` builds a ``DataManager``, a training loader
    (``start_class=0, end_class=10``), the sorted list of ``.pkl`` checkpoints in
    ``args['logs_name']`` and the dataset-specific input normalization.
    ``to_alls`` / ``to_all`` keep only the samples that every task model returned
    by ``load_all_task_models`` classifies correctly. ``run_attack``, ``__call__``
    and ``repeat`` are placeholders here (presumably implemented by subclasses —
    TODO confirm).
    """

    # Per-channel (mean, std) normalization for the "cifar100" dataset.
    _CIFAR100_MEAN_STD = ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    # Per-channel (mean, std) for every other dataset (the standard ImageNet statistics).
    _DEFAULT_MEAN_STD = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    def __init__(self, args, device='cuda'):
        """Set up data, checkpoints and normalization for the attack.

        Args:
            args: configuration dict. Reads ``dataset``, ``shuffle``, ``seed``,
                ``init_cls``, ``increment``, ``attack``, ``batch_size``,
                ``logs_name`` and ``target_class``. NOTE: this dict is mutated in
                place — ``target_class_list`` and ``target_class_dict`` are written
                into it.
            device: torch device used for the boolean masks built in ``to_alls`` /
                ``to_all``.
        """
        super().__init__()
        self.device = device
        self.args = args
        self.data_manager = DataManager(
            args["dataset"],
            args["shuffle"],
            args["seed"],
            args["init_cls"],
            args["increment"],
            args["attack"],
        )
        # The first-task entries of the data manager's class order, and a map from
        # each of them to its position in that list.
        first_task_size = self.data_manager._increments[0]
        self.args['target_class_list'] = self.data_manager._class_order[:first_task_size]
        self.args['target_class_dict'] = {
            cls: idx for idx, cls in enumerate(self.args['target_class_list'])
        }
        # Input resolution: 32 for cifar100, 224 for any other dataset.
        self.img_s = 32 if args["dataset"] == 'cifar100' else 224
        self.batch_size = args['batch_size']
        # NOTE(review): end_class is hard-coded to 10 while target_class_list above
        # uses _increments[0]; these disagree unless the first task has 10 classes.
        self.loader = get_dataloader(self.data_manager, batch_size=self.batch_size,
                                     start_class=0, end_class=10,
                                     train=True, shuffle=True, num_workers=0)
        # NOTE(review): plain lexicographic sort — "task_10.pkl" sorts before
        # "task_2.pkl" unless checkpoint names are zero-padded.
        ckpts = sorted(f for f in os.listdir(args['logs_name']) if f.endswith('.pkl'))
        self.ckpt_paths = [os.path.join(args['logs_name'], ckpt_file) for ckpt_file in ckpts]
        # Placeholders, not assigned anywhere in this class.
        self.model = None
        self.model0 = None
        self.attack = None
        self.target_class = args['target_class']

        if args["dataset"] == "cifar100":
            mean, std = self._CIFAR100_MEAN_STD
        else:
            mean, std = self._DEFAULT_MEAN_STD
        # Same statistics for the torchvision transform and for foolbox's
        # preprocessing (axis=-3 is the channel axis of a (..., C, H, W) tensor).
        self.norm = transforms.Normalize(mean=list(mean), std=list(std))
        self.preprocessing = dict(mean=list(mean), std=list(std), axis=-3)

    def run_attack(self):
        """Placeholder; does nothing in this class."""

    # ------------------------------------------------------------------ helpers

    def _ones_mask(self, n):
        """Return an eagerpy boolean tensor of ``n`` True values on ``self.device``."""
        return ep.astensor(torch.ones((n,), dtype=torch.bool, device=self.device))

    def _load_task_models(self):
        """Load the per-task models (first element of ``load_all_task_models``'s result)."""
        return load_all_task_models(self.args, self.args['logs_name'], self.data_manager,
                                    batch_size=self.batch_size,
                                    train=True,
                                    load_type='model')[0]

    def _as_fmodel(self, task_model):
        """Wrap a task model's ``_network`` as a foolbox model on [0, 1] inputs."""
        return PyTorchModel(task_model._network, bounds=(0, 1),
                            preprocessing=self.preprocessing)

    # ---------------------------------------------------------------- filtering

    def to_alls(self, imgs, labels, labels_t=None,
                target_imgs=None, target_labels=None, return_index=False):
        """Filter source and target batches to the samples every task model gets right.

        ``accuracy(...)`` here is expected to return a 3-tuple whose second element
        is a per-sample correctness mask (assumption based on how it is indexed —
        confirm against the foolbox fork in use).

        Args:
            imgs, labels: source batch and its labels.
            labels_t: optional per-sample labels aligned with ``imgs``; filtered with
                the same mask when given.
            target_imgs, target_labels: target batch and its labels (required —
                ``len(target_imgs)`` is taken unconditionally).
            return_index: if True, return the two masks instead of filtered data.

        Returns:
            ``(correct_index, correct_index_t)`` when ``return_index`` is True.
            Otherwise ``(imgs, labels, labels_t, target_imgs, target_labels,
            target_logits)``. ``target_logits`` is the third element ``accuracy``
            returned for the task-0 model on the target batch. Each group is
            replaced by ``None`` when none of its samples survive.
        """
        correct_index = self._ones_mask(len(imgs))
        correct_index_t = self._ones_mask(len(target_imgs))
        target_logits = None  # stays None if no task models are loaded
        models = self._load_task_models()
        for task in range(len(models)):
            model = self._as_fmodel(models[task])
            acc_bool = accuracy(model, imgs, labels)[1]
            if task == 0:
                # The task-0 model additionally provides the target logits.
                acc_bool_t, target_logits = accuracy(model, target_imgs, target_labels)[1:]
            else:
                acc_bool_t = accuracy(model, target_imgs, target_labels)[1]
            correct_index = ep.logical_and(correct_index, acc_bool)
            correct_index_t = ep.logical_and(correct_index_t, acc_bool_t)
            del model, acc_bool, acc_bool_t
        del models  # free the task models before indexing the batches

        if correct_index.any():
            imgs = imgs[correct_index]
            labels = labels[correct_index]
            # Filter labels_t whenever it was given, so it stays aligned with imgs.
            if labels_t is not None:
                labels_t = labels_t[correct_index]
            logging.info(
                f"Filtering {len(labels)} Correct samples for all CL models.")
        else:
            logging.info("No valid samples found for IMGS, skipping this batch.")
            imgs, labels, labels_t = None, None, None

        if correct_index_t.any():
            target_imgs = target_imgs[correct_index_t]
            target_labels = target_labels[correct_index_t]
            if target_logits is not None:
                target_logits = target_logits[correct_index_t]
            logging.info(
                f"Filtering {len(target_labels)} Target samples for all CL models.")
        else:
            logging.info("No valid samples found for TARGET IMGS, skipping this batch.")
            target_imgs, target_labels, target_logits = None, None, None

        if return_index:
            return correct_index, correct_index_t
        return imgs, labels, labels_t, target_imgs, target_labels, target_logits

    def to_all(self, imgs, labels, return_index=False):
        """Filter a batch to the samples that every task model classifies correctly.

        Args:
            imgs, labels: the batch and its labels.
            return_index: if True, return the boolean mask instead of filtered data.

        Returns:
            ``correct_index`` when ``return_index`` is True. Otherwise
            ``(imgs, labels)`` restricted to the surviving samples, or
            ``(None, None)`` when none survive.
        """
        correct_index = self._ones_mask(len(imgs))
        models = self._load_task_models()
        for task in range(len(models)):
            model = self._as_fmodel(models[task])
            acc_bool = accuracy(model, imgs, labels)[1]
            correct_index = ep.logical_and(correct_index, acc_bool)
            del model, acc_bool
        del models  # free the task models before indexing the batch

        if correct_index.any():
            imgs = imgs[correct_index]
            labels = labels[correct_index]
            logging.info(
                f"Filtering {len(labels)} Correct samples for all CL models.")
        else:
            logging.info("No valid samples found for IMGS, skipping this batch.")
            imgs, labels = None, None

        if return_index:
            return correct_index
        return imgs, labels

    # ------------------------------------------- foolbox Attack interface stubs

    def __call__(
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        *,
        epsilons: Sequence[Union[float, None]],
        **kwargs: Any,
    ) -> Tuple[List[T], List[T], T]:
        """Stub for foolbox's ``Attack.__call__``; returns None in this class."""
        ...

    def repeat(self, times: int) -> "SustainableAttack":
        """Stub for foolbox's ``Attack.repeat``; returns None in this class."""
        ...