|
|
import os
|
|
|
from foolbox import PyTorchModel, accuracy
|
|
|
from foolbox.attacks.base import *
|
|
|
from foolbox.attacks.gradient_descent_base import *
|
|
|
from torchvision import transforms
|
|
|
from utils.data_manager import DataManager, get_dataloader
|
|
|
import torch
|
|
|
import logging
|
|
|
import eagerpy as ep
|
|
|
from utils.data_manager import load_all_task_models
|
|
|
|
|
|
|
|
|
|
|
|
class SustainableAttack(Attack):
    """Base driver for adversarial attacks against a sequence of continual-learning (CL) models.

    Builds a ``DataManager`` for the configured dataset, locates the per-task model
    checkpoints under ``args['logs_name']``, and provides helpers (``to_all`` /
    ``to_alls``) that filter a batch down to samples classified correctly by *every*
    task model. Subclasses are expected to implement ``run_attack``.

    NOTE(review): the ``accuracy`` calls below are indexed like tuples
    (``accuracy(...)[1]``, ``accuracy(...)[1:]``). Stock foolbox ``accuracy``
    returns a float, so this project presumably ships a patched version that
    returns (acc, per-sample-correct-mask[, logits]) — confirm against the
    local foolbox fork before modifying these call sites.
    """

    def __init__(self, args, device='cuda'):
        """Set up data, checkpoint paths, and normalization for the attack.

        Args:
            args: configuration dict. Keys read here: ``dataset``, ``shuffle``,
                ``seed``, ``init_cls``, ``increment``, ``attack``, ``batch_size``,
                ``logs_name``, ``target_class``. This constructor also *writes*
                ``args['target_class_list']`` and ``args['target_class_dict']``
                back into the dict (visible to the caller).
            device: torch device string used for tensors created by the helpers
                below (default ``'cuda'``).
        """
        super().__init__()
        self.device = device
        self.args = args

        # DataManager defines the class order and per-task class increments
        # for the continual-learning scenario.
        self.data_manager = DataManager(
            args["dataset"],
            args["shuffle"],
            args["seed"],
            args["init_cls"],
            args["increment"],
            args["attack"]
        )
        # Target classes = the classes of the *first* task (first increment of
        # the shuffled class order). The dict maps original class id -> index
        # within the first task (0..init_cls-1).
        self.args['target_class_list'] = self.data_manager._class_order[:self.data_manager._increments[0]]
        self.args['target_class_dict'] = dict(zip(self.args['target_class_list'], range(len(self.args['target_class_list']))))
        # Image side length: CIFAR-100 is 32x32; every other dataset is assumed
        # to be ImageNet-like 224x224.
        self.img_s = 32 if args["dataset"] == 'cifar100' else 224
        self.batch_size = args['batch_size']
        # Loader over the first task's classes only (0..10).
        # NOTE(review): end_class is hard-coded to 10 rather than derived from
        # args['init_cls'] — confirm this is intentional for init_cls != 10.
        self.loader = get_dataloader(self.data_manager, batch_size=self.batch_size,
                                     start_class=0, end_class=10,
                                     train=True, shuffle=True, num_workers=0)

        # One .pkl checkpoint per CL task, sorted by filename so the list is
        # in task order (assumes filenames sort chronologically by task).
        ckpts = sorted([f for f in os.listdir(args['logs_name']) if f.endswith('.pkl')])
        self.ckpt_paths = [os.path.join(args['logs_name'], ckpt_file) for ckpt_file in ckpts]
        # Populated lazily by subclasses / run_attack.
        self.model = None
        self.model0 = None
        self.attack = None

        # Optional targeted-attack class; None means an untargeted attack
        # (see the labels_t handling in to_alls).
        self.target_class = args['target_class']
        # Dataset normalization constants. `norm` is a torchvision transform
        # (for manual normalization); `preprocessing` is the equivalent dict
        # consumed by foolbox's PyTorchModel (axis=-3 = channel dim of CHW).
        if args["dataset"] == "cifar100":
            self.norm = transforms.Normalize(mean=[0.5071, 0.4867, 0.4408],
                                             std=[0.2675, 0.2565, 0.2761])
            self.preprocessing = dict(mean=[0.5071, 0.4867, 0.4408],
                                      std=[0.2675, 0.2565, 0.2761], axis=-3)
        else:
            # ImageNet statistics for all non-CIFAR datasets.
            self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
            self.preprocessing = dict(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225], axis=-3)

    def run_attack(self):
        """Execute the attack. Intentionally a no-op here; subclasses override."""
        pass

    def to_alls(self, imgs, labels, labels_t=None,
                target_imgs=None, target_labels=None, return_index=False):
        """Filter both a source batch and a target batch to samples that every
        task model classifies correctly.

        Loads all task models, evaluates each on ``(imgs, labels)`` and on
        ``(target_imgs, target_labels)``, and AND-combines the per-sample
        correctness masks across tasks. Logits for the target batch are taken
        from the task-0 model only.

        Args:
            imgs / labels: source batch (eagerpy tensors on ``self.device``).
            labels_t: targeted labels for the source batch; only filtered when
                ``self.target_class`` is set.
            target_imgs / target_labels: target-class batch; assumed non-None
                (``len(target_imgs)`` is taken unconditionally).
            return_index: if True, return the two boolean masks instead of the
                filtered tensors.

        Returns:
            ``(correct_index, correct_index_t)`` when ``return_index`` is True;
            otherwise ``(imgs, labels, labels_t, target_imgs, target_labels,
            target_logits)`` with each group replaced by ``None`` when no
            sample survived filtering.
        """
        # Start from all-True masks, wrapped through ep.astensors so they are
        # eagerpy tensors compatible with the masks `accuracy` returns.
        correct_index = ep.full_like(ep.astensors(torch.ones((len(imgs),), dtype=bool, device=self.device))[0], fill_value=True)
        correct_index_t = ep.full_like(ep.astensors(torch.ones((len(target_imgs),), dtype=bool, device=self.device))[0], fill_value=True)
        # [0]: load_all_task_models returns a tuple; first element is the list
        # of per-task model wrappers.
        models = load_all_task_models(self.args, self.args['logs_name'], self.data_manager,
                                      batch_size=self.batch_size,
                                      train=True,
                                      load_type='model')[0]
        for task in range(len(models)):
            model = PyTorchModel(models[task]._network, bounds=(0, 1), preprocessing=self.preprocessing)
            # [1] = per-sample correctness mask (patched foolbox accuracy —
            # see class docstring).
            acc_bool = accuracy(model, imgs, labels)[1]
            if task == 0:
                # Task-0 pass also captures the target-batch logits; they are
                # kept (not overwritten) for later tasks.
                # NOTE(review): if `models` were empty, target_logits would be
                # unbound at its use below — relies on at least one checkpoint.
                acc_bool_t, target_logits = accuracy(model, target_imgs, target_labels)[1:]
            else:
                acc_bool_t = accuracy(model, target_imgs, target_labels)[1]
            # A sample survives only if *every* task model got it right.
            correct_index = ep.logical_and(correct_index, acc_bool)
            correct_index_t = ep.logical_and(correct_index_t, acc_bool_t)
            # Free the foolbox wrapper/masks before loading the next task.
            del model, acc_bool, acc_bool_t
        if correct_index.any():
            imgs = imgs[correct_index]
            labels = labels[correct_index]
            if self.target_class is not None:
                # Targeted attack: keep targeted labels aligned with imgs.
                labels_t = labels_t[correct_index]
            logging.info(
                f"Filtering {len(labels)} Correct samples for all CL models.")
        else:
            print("No valid samples found for IMGS, skipping this batch.")
            imgs, labels, labels_t = None, None, None

        if correct_index_t.any():
            target_imgs = target_imgs[correct_index_t]
            target_labels = target_labels[correct_index_t]
            target_logits = target_logits[correct_index_t]
            logging.info(
                f"Filtering {len(target_labels)} Target samples for all CL models.")
        else:
            logging.info("No valid samples found for TARGET IMGS, skipping this batch.")
            target_imgs, target_labels, target_logits = None, None, None
        if return_index:
            # NOTE(review): the locals were already filtered above, but only the
            # masks are returned here — callers re-apply them to their own copies.
            return correct_index, correct_index_t
        del models, correct_index, correct_index_t
        return imgs, labels, labels_t, target_imgs, target_labels, target_logits

    def to_all(self, imgs, labels, return_index=False):
        """Filter a single batch to samples every task model classifies correctly.

        Single-batch counterpart of ``to_alls`` (no target batch, no logits).

        Args:
            imgs / labels: batch to filter (eagerpy tensors on ``self.device``).
            return_index: if True, return the boolean mask instead of tensors.

        Returns:
            ``correct_index`` when ``return_index`` is True; otherwise
            ``(imgs, labels)``, both ``None`` if nothing survived.
        """
        correct_index = ep.full_like(ep.astensors(torch.ones((len(imgs),), dtype=bool, device=self.device))[0], fill_value=True)
        models = load_all_task_models(self.args, self.args['logs_name'], self.data_manager,
                                      batch_size=self.batch_size,
                                      train=True,
                                      load_type='model')[0]
        for task in range(len(models)):
            model = PyTorchModel(models[task]._network, bounds=(0, 1), preprocessing=self.preprocessing)
            # [1] = per-sample correctness mask (patched foolbox accuracy).
            acc_bool = accuracy(model, imgs, labels)[1]
            correct_index = ep.logical_and(correct_index, acc_bool)
            del model, acc_bool
        if correct_index.any():
            imgs = imgs[correct_index]
            labels = labels[correct_index]
            logging.info(
                f"Filtering {len(labels)} Correct samples for all CL models.")
        else:
            logging.info("No valid samples found for IMGS, skipping this batch.")
            imgs, labels = None, None
        if return_index:
            return correct_index
        del models, correct_index
        return imgs, labels

    def __call__(
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        *,
        epsilons: Sequence[Union[float, None]],
        **kwargs: Any,
    ) -> Tuple[List[T], List[T], T]:
        # Required by the foolbox Attack interface; subclasses implement.
        ...

    def repeat(self, times: int) -> "SustainableAttack":
        # Required by the foolbox Attack interface; subclasses implement.
        ...
|
|
|
|
|
|
|