import os.path
import torch
from foolbox.attacks.base import *
from foolbox.attacks.gradient_descent_base import *
from tqdm import tqdm
import pandas as pd
from attacks.attack_config import SustainableAttack
from utils.plot import plot_asr_per_target, save_grad_cam
import logging
from foolbox import PyTorchModel, accuracy
import numpy as np
from utils import factory
from utils.data_manager import get_dataloader
from attacks.CleanSheet.utils_ import Trigger
from attacks.CleanSheet.generate_kd import train


class CleanSheet(SustainableAttack):
    """CleanSheet trigger attack evaluated across incremental-learning tasks.

    ``train_adv`` trains one trigger per target class through
    ``attacks.CleanSheet.generate_kd.train``. ``run_test`` loads the trigger
    checkpoint with the largest step for ``self.target_class``. ``attacks``
    then measures that trigger's attack success rate (ASR) against the model
    checkpoint of each of the 10 tasks.
    """

    def __init__(self, args, device='cuda'):
        super().__init__(args, device)
        self.device = device
        self.args = args
        self.surrogate_model = None
        # Attribute name (typo included) kept for compatibility with external readers.
        self.eval_batch_szie = 128
        self.eval = args['eval']
        # Note: mutates the caller's args dict in place.
        self.args['run_baseline'] = True
        # Sub-directory under args["logs_eval_name"] holding trigger checkpoints.
        self.trigger_name = 'Trigger'
        self.plot_gradcam = True

    def train_adv(self):
        """Train one trigger for each target class 0..9; no-op in eval mode.

        Side effect: ``self.args['target_class']`` is left set to the last class.
        """
        if self.eval:
            return
        for target in range(10):
            self.args['target_class'] = target
            train(self.args)
            torch.cuda.empty_cache()

    def run_test(self):
        """Load the newest trigger for ``self.target_class`` and evaluate it.

        Only the first test batch is used. Samples already labelled with the
        target class are dropped; the remaining ones get the target class as
        their attack label and are passed to ``attacks``.
        """
        trigger_dir = f'{self.args["logs_eval_name"]}/{self.trigger_name}/{self.target_class}'
        pth_name = self.get_max_step_filename(trigger_dir)
        self.ckpt = f'{trigger_dir}/{pth_name}.pth'
        self.prefix = f'{self.trigger_name}_{pth_name.split("_")[-1]}'
        # map_location mirrors how task checkpoints are loaded in `attacks`, so a
        # trigger saved on a different device (e.g. GPU) still loads here.
        state_dict = torch.load(self.ckpt, map_location=self.device)
        self.trigger = Trigger(size=32).to(self.device)
        self.trigger.load_state_dict(state_dict)
        self.trigger.eval()

        # Load Batch Data
        self.loader = get_dataloader(self.data_manager, batch_size=self.eval_batch_szie,
                                     start_class=0, end_class=10, train=False,
                                     shuffle=True, num_workers=0)
        desc = f'Loading Data with Batch Size of {self.eval_batch_szie}:'
        for i, (_, imgs, labels) in enumerate(tqdm(self.loader, total=len(self.loader), desc=desc)):
            # Only batch 0 is evaluated; if it is skipped below, the loop ends anyway.
            if i > 0:
                break
            imgs, labels = ep.astensors(imgs.to(self.device), labels.to(self.device))
            non_target = labels != self.target_class
            imgs_f = imgs[non_target]
            labels_f = labels[non_target]
            labels_t_f = ep.full_like(labels_f, fill_value=self.target_class)
            if self.args["model_name"] != 'finetune':
                # NOTE(review): `to_all` is defined outside this file; it appears to
                # filter samples and may return None for the images — confirm.
                imgs_f, labels_f = self.to_all(imgs_f, labels_f)
            if imgs_f is None:
                continue
            # Keep target labels aligned with however many images survived filtering.
            labels_t_f = labels_t_f[:len(imgs_f)]
            self.attacks(i, imgs_f, labels_f, labels_t_f)

    def attacks(self, i_batch, imgs, labels, labels_t):
        """Evaluate the loaded trigger against the checkpoint of every task.

        Args:
            i_batch: Batch index; used only in the output file prefix.
            imgs: eagerpy tensor of clean inputs (bounds are checked against the
                model's (0, 1) bounds by ``verify_input_bounds``).
            labels: True labels of ``imgs``; only released at the end.
            labels_t: Attack target labels; accuracy against them is the ASR.

        Writes an ASR plot and ``<prefix>.xlsx`` under
        ``<logs_eval_name>/<trigger_name>``.
        """
        num_tasks = 10
        asr_matrix = np.ones((num_tasks, 1))
        self.model = factory.get_model(self.args["model_name"], self.args)
        eval_path = os.path.join(self.args["logs_eval_name"], self.trigger_name)
        for task in range(num_tasks):
            logging.info("***** Starting attack on task [{}]. *****".format(task))
            self.model.incremental_train(self.data_manager)
            checkpoint = torch.load(self.ckpt_paths[task], map_location=self.device)
            self.model._network.load_state_dict(checkpoint['model_state_dict'])
            self.model._network.to(self.device)
            self.model._network.eval()

            # Run attack on each target image
            current_model = PyTorchModel(self.model._network, bounds=(0, 1),
                                         preprocessing=self.preprocessing)
            verify_input_bounds(imgs, current_model)
            logging.info("Eval attack on each target images.")
            # NOTE(review): triggered images are clipped to [-1, 1] although the
            # model bounds are (0, 1); confirm the intended output range of Trigger.
            advs = ep.astensor(self.trigger(imgs.raw)).clip(-1, 1)
            # NOTE(review): indexing [0] assumes this `accuracy` returns a sequence;
            # upstream foolbox 3 returns a float — confirm the installed version.
            asr = accuracy(current_model, advs, labels_t)[0]
            asr_matrix[task] = asr
            if self.plot_gradcam:
                save_grad_cam(self.args, torch.clip(advs.raw.detach(), -1, 1), labels_t.raw,
                              self.model._network, eval_path + "/GradCam", prefix=f'task{task}',
                              layer_name='stage_3', save_num=100, save_raw=True)
            del advs, current_model
            torch.cuda.empty_cache()
            self.model.after_task()

        # Save per-task ASR for this batch as a plot and a spreadsheet.
        prefix = f'batch{i_batch}_{self.prefix}_tc{self.target_class}'
        plot_asr_per_target(asr_matrix, eval_path, prefix, self.args)
        df = pd.DataFrame(asr_matrix, columns=['ASR'])
        df.to_excel(os.path.join(eval_path, f"{prefix}.xlsx"), index=False)
        # Drop references to (possibly GPU) tensors before releasing cached memory.
        del asr_matrix, imgs, labels, labels_t
        torch.cuda.empty_cache()

    def get_max_step_filename(self, folder_path):
        """Return the stem of the ``.pth`` file in ``folder_path`` with the largest step.

        The step is parsed as the integer after the last ``_`` and before the
        first ``.`` of that segment. On ties, the first file in ``os.listdir``
        order wins.

        Raises:
            FileNotFoundError: if ``folder_path`` contains no ``.pth`` files.
        """
        files = [f for f in os.listdir(folder_path) if f.endswith('.pth')]
        if not files:
            raise FileNotFoundError(f'No .pth checkpoint found in {folder_path}')

        def step_of(name):
            return int(name.split('_')[-1].split('.')[0])

        return os.path.splitext(max(files, key=step_of))[0]

    def __call__(
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        *,
        epsilons: Union[Sequence[Union[float, None]], float, None],
        **kwargs: Any,
    ) -> Union[Tuple[List[T], List[T], T], Tuple[T, T, T]]:
        # Stub: no implementation here; presumably present only to satisfy the
        # base class's attack interface — confirm.
        ...

    def repeat(self, times: int) -> "CleanSheet":
        # Stub: no implementation here (returns None).
        ...