import logging
import os

import eagerpy as ep
import torch
from foolbox import PyTorchModel, accuracy
from foolbox.attacks.base import *
from foolbox.attacks.gradient_descent_base import *
from torchvision import transforms

from utils.data_manager import DataManager, get_dataloader, load_all_task_models


class SustainableAttack(Attack):
    """Base class for attacks evaluated against a sequence of task models.

    The constructor builds the data manager, the data loader and the list of
    checkpoint paths; ``to_all`` / ``to_alls`` keep only the samples that every
    loaded task model classifies correctly.

    NOTE(review): ``accuracy`` is indexed here as a 3-element sequence
    (``[1]`` used as a per-sample boolean mask, ``[2]`` as logits).  Stock
    foolbox returns a plain float, so this presumably relies on a patched
    ``accuracy`` -- confirm against the installed version.
    """

    def __init__(self, args, device='cuda'):
        """Set up data, checkpoint paths and input preprocessing.

        Args:
            args: Configuration mapping.  Keys read here: ``dataset``,
                ``shuffle``, ``seed``, ``init_cls``, ``increment``, ``attack``,
                ``batch_size``, ``logs_name`` and ``target_class``.  Two keys
                are written back into it: ``target_class_list`` and
                ``target_class_dict``.
            device: Device on which the boolean masks are allocated.
        """
        super().__init__()
        self.device = device
        self.args = args

        self.data_manager = DataManager(
            args["dataset"],
            args["shuffle"],
            args["seed"],
            args["init_cls"],
            args["increment"],
            args["attack"],
        )

        # Classes of the first increment, plus a class -> position lookup.
        first_task_size = self.data_manager._increments[0]
        target_classes = self.data_manager._class_order[:first_task_size]
        self.args['target_class_list'] = target_classes
        self.args['target_class_dict'] = {
            cls_id: position for position, cls_id in enumerate(target_classes)
        }

        self.img_s = 32 if args["dataset"] == 'cifar100' else 224
        self.batch_size = args['batch_size']
        # NOTE(review): end_class is hard-coded to 10 while the target class
        # list above uses _increments[0]; these differ when the first
        # increment is not 10 classes -- confirm which one is intended.
        self.loader = get_dataloader(
            self.data_manager,
            batch_size=self.batch_size,
            start_class=0,
            end_class=10,
            train=True,
            shuffle=True,
            num_workers=0,
        )

        # Checkpoints are the *.pkl files of the log directory, in sorted order.
        logs_dir = args['logs_name']
        ckpt_files = sorted(f for f in os.listdir(logs_dir) if f.endswith('.pkl'))
        self.ckpt_paths = [os.path.join(logs_dir, f) for f in ckpt_files]

        self.model = None
        self.model0 = None
        self.attack = None
        self.target_class = args['target_class']

        # Per-channel normalisation statistics; one set for cifar100, another
        # for every other dataset.
        if args["dataset"] == "cifar100":
            mean = [0.5071, 0.4867, 0.4408]
            std = [0.2675, 0.2565, 0.2761]
        else:
            mean = [0.485, 0.456, 0.406]
            std = [0.229, 0.224, 0.225]
        # Separate list copies so the two consumers never share mutable state.
        self.norm = transforms.Normalize(mean=list(mean), std=list(std))
        self.preprocessing = dict(mean=list(mean), std=list(std), axis=-3)

    def run_attack(self):
        """Placeholder; does nothing in this class."""
        pass

    def _load_task_models(self):
        """Return the first element of ``load_all_task_models`` for this run."""
        return load_all_task_models(
            self.args,
            self.args['logs_name'],
            self.data_manager,
            batch_size=self.batch_size,
            train=True,
            load_type='model',
        )[0]

    def _wrap_model(self, task_model):
        """Wrap a task model's ``_network`` as a foolbox ``PyTorchModel``."""
        return PyTorchModel(
            task_model._network, bounds=(0, 1), preprocessing=self.preprocessing
        )

    def _all_true_mask(self, size):
        """Return an eagerpy boolean mask of ``size`` ``True`` entries."""
        return ep.astensor(
            torch.ones((size,), dtype=torch.bool, device=self.device)
        )

    def to_alls(self, imgs, labels, labels_t=None, target_imgs=None, target_labels=None, return_index=False):
        """Keep the samples and target samples every task model gets right.

        Args:
            imgs, labels: Batch to filter.
            labels_t: Optional labels filtered with the same mask as ``imgs``;
                only filtered when ``self.target_class`` is not ``None``.
            target_imgs, target_labels: Optional second batch, filtered with
                its own mask.  When ``target_imgs`` is ``None`` the target
                side is skipped entirely.
            return_index: When true, return the two boolean masks instead of
                the filtered tensors.

        Returns:
            ``(mask, target_mask)`` if ``return_index`` is true, where
            ``target_mask`` is ``None`` when no target batch was given.
            Otherwise ``(imgs, labels, labels_t, target_imgs, target_labels,
            target_logits)``.  The first three are ``None`` when no sample
            survives; the last three are ``None`` when no target sample
            survives.  ``target_logits`` comes from the first task model and
            is ``None`` when no target batch was given.
        """
        has_targets = target_imgs is not None
        keep = self._all_true_mask(len(imgs))
        keep_t = self._all_true_mask(len(target_imgs)) if has_targets else None
        target_logits = None

        models = self._load_task_models()
        # Indexed access is kept on purpose: the container type returned by
        # load_all_task_models is not visible from this file.
        for task in range(len(models)):
            fmodel = self._wrap_model(models[task])
            keep = ep.logical_and(keep, accuracy(fmodel, imgs, labels)[1])
            if not has_targets:
                continue
            if task == 0:
                # Only the first task model's logits are retained.
                hits_t, target_logits = accuracy(fmodel, target_imgs, target_labels)[1:]
            else:
                hits_t = accuracy(fmodel, target_imgs, target_labels)[1]
            keep_t = ep.logical_and(keep_t, hits_t)

        if keep.any():
            imgs = imgs[keep]
            labels = labels[keep]
            if self.target_class is not None and labels_t is not None:
                labels_t = labels_t[keep]
            logging.info(
                f"Filtering {len(labels)} Correct samples for all CL models.")
        else:
            logging.info("No valid samples found for IMGS, skipping this batch.")
            imgs, labels, labels_t = None, None, None

        if has_targets:
            if keep_t.any():
                target_imgs = target_imgs[keep_t]
                target_labels = target_labels[keep_t]
                # target_logits stays None when no task model was loaded.
                if target_logits is not None:
                    target_logits = target_logits[keep_t]
                logging.info(
                    f"Filtering {len(target_labels)} Target samples for all CL models.")
            else:
                logging.info("No valid samples found for TARGET IMGS, skipping this batch.")
                target_imgs, target_labels, target_logits = None, None, None

        if return_index:
            return keep, keep_t
        return imgs, labels, labels_t, target_imgs, target_labels, target_logits

    def to_all(self, imgs, labels, return_index=False):
        """Keep the samples that every task model classifies correctly.

        Args:
            imgs, labels: Batch to filter.
            return_index: When true, return the boolean mask instead of the
                filtered tensors.

        Returns:
            The mask if ``return_index`` is true; otherwise ``(imgs, labels)``,
            both ``None`` when no sample survives.
        """
        keep = self._all_true_mask(len(imgs))

        models = self._load_task_models()
        # Indexed access is kept on purpose: the container type returned by
        # load_all_task_models is not visible from this file.
        for task in range(len(models)):
            fmodel = self._wrap_model(models[task])
            keep = ep.logical_and(keep, accuracy(fmodel, imgs, labels)[1])

        if keep.any():
            imgs = imgs[keep]
            labels = labels[keep]
            logging.info(
                f"Filtering {len(labels)} Correct samples for all CL models.")
        else:
            logging.info("No valid samples found for IMGS, skipping this batch.")
            imgs, labels = None, None

        if return_index:
            return keep
        return imgs, labels

    def __call__(
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        *,
        epsilons: Sequence[Union[float, None]],
        **kwargs: Any,
    ) -> Tuple[List[T], List[T], T]:
        """Stub with no implementation in this class; returns ``None``."""
        ...

    def repeat(self, times: int) -> "SustainableAttack":
        """Stub with no implementation in this class; returns ``None``."""
        ...