import torch
from foolbox.attacks.base import *
from foolbox.attacks.gradient_descent_base import *
from attacks.attack_config import SustainableAttack
from tqdm import tqdm
import logging
import foolbox as fb
from foolbox import PyTorchModel, accuracy
import numpy as np
from utils import factory
from utils.data_manager import get_dataloader
import os
from utils.plot import plot_asr_per_target, save_grad_cam
import pandas as pd


class BASEAttack(SustainableAttack):
    """Run one foolbox attack and track its success across sequential task checkpoints.

    Adversarial examples are crafted once, against the model restored for
    task 0. They are then re-scored (without re-attacking) on the model
    restored for each later task, giving one attack-success-rate row per task.

    NOTE(review): ``self.args``, ``self.target_class``, ``self.data_manager``,
    ``self.ckpt_paths``, ``self.preprocessing`` and ``self.device`` are read
    here but not assigned in this class; presumably ``SustainableAttack``
    sets them -- TODO confirm.
    """

    def __init__(self, args, device='cuda'):
        """Select the foolbox attack named by ``args['attack_method']``.

        Args:
            args: Config mapping. Keys read here: ``'attack_method'``,
                ``'eval'`` and (through ``self.args``) ``'logs_eval_name'``.
            device: Device forwarded to ``SustainableAttack.__init__``.

        Raises:
            ValueError: If ``args['attack_method']`` is not a known attack name.
        """
        super().__init__(args, device)

        # Lambdas defer construction (and the fb.attacks attribute lookup)
        # so that only the selected attack is ever built.
        attack_factories = {
            'L2FGM': lambda: fb.attacks.FGM(),
            'FGSM': lambda: fb.attacks.FGSM(),
            'MIFGSM': lambda: fb.attacks.MIFGSM(momentum=0.9, steps=10),
            'L1PGD': lambda: fb.attacks.L1PGD(steps=10),
            'L2PGD': lambda: fb.attacks.L2PGD(steps=10),
            'LinfPGD': lambda: fb.attacks.LinfPGD(steps=10),
            'L2DeepFool': lambda: fb.attacks.L2DeepFoolAttack(steps=10),
            'LinfDeepFool': lambda: fb.attacks.LinfDeepFoolAttack(steps=10),
            'BoundaryAttack': lambda: fb.attacks.BoundaryAttack(steps=10),
            'CarliniWagnerL2': lambda: fb.attacks.L2CarliniWagnerAttack(steps=10),
            # NOTE(review): the 'GaussianNoise' key builds a *uniform*
            # repeated-noise attack; confirm this mapping is intended.
            'GaussianNoise': lambda: fb.attacks.LinfRepeatedAdditiveUniformNoiseAttack(),
            'UniformNoise': lambda: fb.attacks.LinfAdditiveUniformNoiseAttack(),
        }
        attack_name = args['attack_method']
        if attack_name not in attack_factories:
            raise ValueError(f"Unknown attack method: {attack_name}")
        self.attack = attack_factories[attack_name]()

        self.test_mode = 'BaseAttacks'
        # Budget on a 0-255 pixel scale; divided by 255 in `attacks` because
        # the wrapped model uses input bounds (0, 1).
        self.epsilon = 16
        self.eval = args['eval']
        self.prefix = f'eps{self.epsilon}'
        self.save_path = os.path.join(self.args['logs_eval_name'])
        self.out_path = os.path.join(self.args['logs_eval_name'], f'gradcam_t{self.target_class}')
        os.makedirs(self.out_path, exist_ok=True)
        self.plot_gradcam = True

    def run_test(self):
        """Load the first evaluation batch and attack it towards ``self.target_class``.

        Only the first batch from the loader is used (iteration stops after
        index 0). Samples already labelled ``self.target_class`` are split
        off. The rest are attacked with every target label set to
        ``self.target_class``.

        NOTE(review): ``ep.full_like(..., fill_value=self.target_class)``
        presumably fails when ``target_class`` is None, so the untargeted
        branch in ``attacks`` looks unreachable from here -- confirm.
        """
        # Load a single batch of evaluation data.
        eval_batch_size = 1024
        self.loader = get_dataloader(self.data_manager, batch_size=eval_batch_size, start_class=0,
                                     end_class=10, train=False, shuffle=False, num_workers=0)
        progress = tqdm(self.loader, total=len(self.loader),
                        desc=f'Loading Data with Batch Size of {eval_batch_size}:')
        for i, (_, imgs, labels) in enumerate(progress):
            if i > 0:
                break
            imgs, labels = ep.astensors(imgs.to(self.device), labels.to(self.device))

            is_target = labels == self.target_class
            not_target = labels != self.target_class
            target_imgs = imgs[is_target]
            target_labels = labels[is_target]
            imgs_f = imgs[not_target]
            labels_f = labels[not_target]
            labels_t_f = ep.full_like(labels_f, fill_value=self.target_class)

            self.attacks(i, imgs_f, labels_f, labels_t_f, target_imgs, target_labels)

    def attacks(self, i_batch, imgs, labels, labels_t, target_imgs=None, target_labels=None):
        """Attack ``imgs`` on task 0 and re-score the result on every later task.

        For each task:

        * the model is advanced with ``incremental_train``;
        * its weights are overwritten from ``self.ckpt_paths[task]``;
        * clean accuracy on ``imgs`` is recorded.

        On task 0 the inputs are first reduced to the correctly classified
        samples and the attack is run. On later tasks the clipped adversarial
        examples from task 0 are scored against ``labels_t``. Clean accuracy
        on tasks > 0 is therefore measured on the filtered subset.

        Results are saved as a plot and an ``.xlsx`` sheet under
        ``self.save_path``.

        Args:
            i_batch: Batch index; used only in the output file prefix.
            imgs: Input images (eagerpy tensor, as produced by ``run_test``).
            labels: Ground-truth labels for ``imgs``.
            labels_t: Target labels for the targeted criterion.
            target_imgs: Samples already labelled ``self.target_class``.
                Accepted so ``run_test`` can pass them; currently unused.
            target_labels: Labels of ``target_imgs``; currently unused.

        Raises:
            RuntimeError: If the task-0 model classifies no sample correctly,
                leaving nothing to attack (previously a ZeroDivisionError).
        """
        num_tasks = 10
        clean_acc_matrix = []
        asr_matrix = np.ones((num_tasks, 1))
        self.model = factory.get_model(self.args["model_name"], self.args)

        for task in range(num_tasks):
            logging.info("***** Starting attack on task [{}]. *****".format(task))
            # Advance the incremental learner to this task, then replace its
            # weights with the stored checkpoint for the task.
            self.model.incremental_train(self.data_manager)
            self.model._network.load_state_dict(
                torch.load(self.ckpt_paths[task], map_location=self.device)['model_state_dict'])
            self.model._network.to(self.device)
            self.model._network.eval()
            current_model = PyTorchModel(self.model._network, bounds=(0, 1),
                                         preprocessing=self.preprocessing)

            # The project's `accuracy` returns a sequence; the first two items
            # are used as (accuracy, per-sample correctness mask).
            acc, acc_bool = accuracy(current_model, imgs, labels)[:2]
            clean_acc_matrix.append(acc)

            if task == 0:
                # Attack only the samples the task-0 model gets right.
                imgs = imgs[acc_bool]
                labels = labels[acc_bool]
                labels_t = labels_t[acc_bool]
                if len(imgs) == 0:
                    raise RuntimeError(
                        "No correctly classified samples left to attack on task 0.")

            verify_input_bounds(imgs, current_model)
            if self.target_class is None:
                criterion = fb.criteria.Misclassification(labels)
            else:
                criterion = fb.criteria.TargetedMisclassification(labels_t)

            if task == 0:
                # Craft the adversarial examples once, against the task-0 model.
                _, adv_clip, asr_bool = self.attack(current_model, imgs, criterion=criterion,
                                                    epsilons=self.epsilon / 255)
                asr_matrix[task] = asr_bool.sum().raw.item() / len(imgs)
            else:
                # Re-score the task-0 adversarial examples on this task's model;
                # presumably the fraction predicted as labels_t -- confirm
                # against the project's `accuracy`.
                asr_matrix[task] = accuracy(current_model, adv_clip, labels_t)[0]

            if self.plot_gradcam:
                save_grad_cam(self.args, torch.clip(adv_clip.raw, 0, 1), labels_t.raw,
                              self.model._network, self.out_path + "/GradCam",
                              prefix=f'task{task}', layer_name='stage_3',
                              save_num=100, save_raw=True)

            del criterion, current_model
            torch.cuda.empty_cache()
            self.model.after_task()

        # Save the per-task attack success rates (plot + Excel sheet).
        prefix = f'batch{i_batch}_{self.prefix}_tc{self.target_class}'
        plot_asr_per_target(asr_matrix, self.save_path, prefix, self.args, clean_acc_matrix)
        df = pd.DataFrame(asr_matrix, columns=['ASR'])
        df.to_excel(os.path.join(self.save_path, f"{prefix}.xlsx"), index=False)

        del asr_matrix, imgs, labels, labels_t
        torch.cuda.empty_cache()

    # NOTE(review): signature-only stubs (bodies are `...`, so they return
    # None). The real attack is `self.attack`; these presumably satisfy an
    # abstract interface inherited via SustainableAttack -- TODO confirm.
    def __call__(
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        *,
        epsilons: Sequence[Union[float, None]],
        **kwargs: Any,
    ) -> Tuple[List[T], List[T], T]:
        ...

    def repeat(self, times: int) -> "BASEAttack":
        ...