import os.path
import torch
from foolbox.attacks.base import *
from foolbox.attacks.gradient_descent_base import *
from tqdm import tqdm
from attacks.AIM.src.gat.models.attack import AIMAttack, ContrastiveLoss
from attacks.AIM.src.gat.models.surrogate import midlayer_dict, register_collecter, register_collecter_cl
from attacks.attack_config import SustainableAttack
from utils.plot import plot_asr_per_target, save_grad_cam
import logging
import pandas as pd
import foolbox as fb
from foolbox import PyTorchModel
import numpy as np
from utils import factory
from utils.data_manager import get_dataloader


class AIM(SustainableAttack):
    """Generator-based targeted attack driven by an ``AIMAttack`` generator.

    ``train_generator`` fits (or loads) the generator against a surrogate
    network; ``run_test`` / ``attacks`` then measure the attack success rate
    of the generated samples against the checkpoints in ``self.ckpt_paths``.

    Attributes such as ``target_class``, ``data_manager``, ``ckpt_paths``,
    ``batch_size``, ``loader``, ``norm`` and ``preprocessing`` are read here
    but not assigned in this class before use.
    NOTE(review): presumably they are provided by ``SustainableAttack`` --
    confirm against the base class.
    """

    def __init__(self, args, device='cuda'):
        """Set up the generator, its optimizer, the losses and the output dir.

        Args:
            args: Configuration mapping; this class reads ``'logs_eval_name'``
                and ``'model_name'`` from it.
            device: Device passed to the base class and to ``AIMAttack``.
        """
        super().__init__(args, device)
        self.device = device
        self.args = args
        self.surrogate_model = None

        # Generator and its training hyper-parameters.
        self.adv_generator = AIMAttack(device=device)
        self.adv_generator.set_mode('train')
        self.lr = 0.001
        self.betas = (0.5, 0.999)
        self.num_epoch = 100
        self.optim = torch.optim.Adam(self.adv_generator.get_params(), lr=self.lr, betas=self.betas)
        self.contrastive_loss = ContrastiveLoss(0.2)
        self.sim_loss = torch.nn.functional.cosine_similarity
        self.eval_batch_szie = 128

        # Surrogate model and the intermediate layer whose features are collected.
        self.surrogate_model_name = 'resnet32_cl'
        self.layer = midlayer_dict[self.surrogate_model_name]
        self.prefix = (f'adv_generator_{self.surrogate_model_name}'
                       f'_{self.layer}'
                       f'_tclass{self.target_class}')

        self.save_path = os.path.join(self.args['logs_eval_name'], f'target{str(self.target_class)}')
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.save_path, exist_ok=True)
        self.plot_gradcam = False

    def train_generator(self):
        """Load the generator checkpoint if one exists, otherwise train and save it.

        The surrogate network is built first and a feature-collecting hook is
        registered on ``self.layer``. When training, every batch of non-target
        samples is paired with an equally sized set of target-class images.

        Raises:
            ValueError: If training is required but no image of
                ``self.target_class`` is found in the training data.
        """
        if 'cl' in self.surrogate_model_name:
            # Surrogate is the project's own model restored from the first checkpoint.
            s_model = factory.get_model(self.args["model_name"], self.args)
            s_model.incremental_train(self.data_manager)
            s_model._network.load_state_dict(
                torch.load(self.ckpt_paths[0], map_location=self.device)['model_state_dict'])
            s_model._network.to(self.device)
            s_model._network.eval()
            self.surrogate_model = s_model._network
            del s_model
            torch.cuda.empty_cache()
            self.feat_collecter = []
            self.feat_collecter_handler, self.feat_collecter = register_collecter_cl(
                self.surrogate_model, self.layer, self.feat_collecter, self.args["model_name"])
        else:
            # Surrogate is a pretrained CIFAR-100 ResNet-32 fetched via torch.hub.
            self.surrogate_model = torch.hub.load("chenyaofo/pytorch-cifar-models", 'cifar100_resnet32',
                                                  pretrained=True)
            self.surrogate_model.to(self.device)
            self.surrogate_model.eval()
            self.feat_collecter = []
            self.feat_collecter_handler, self.feat_collecter = register_collecter(
                self.surrogate_model, self.layer, self.feat_collecter)

        self.file_path = os.path.join(self.save_path, f'{self.prefix}.pth')
        if os.path.exists(self.file_path):
            # NOTE(review): the feature hook stays registered on this path;
            # it is only removed after training below.
            self.adv_generator.load_ckpt(self.file_path)
            self.adv_generator.set_mode('eval')
        else:
            # Gather the target-class images from the training split.
            loaders = get_dataloader(self.data_manager, batch_size=self.batch_size, start_class=0, end_class=10,
                                     train=True, shuffle=True, num_workers=0)
            target_images = []
            target_labels = []
            for data in loaders:
                _, image_batch, label_batch = data
                mask = label_batch == self.target_class
                target_images.append(image_batch[mask])
                target_labels.append(label_batch[mask])
            del loaders
            target_images = torch.cat(target_images, dim=0).to(self.device)
            target_labels = torch.cat(target_labels, dim=0).to(self.device)
            target_images, target_labels = ep.astensors(*(target_images[:self.batch_size],
                                                          target_labels[:self.batch_size]))
            if len(target_images) == 0:
                raise ValueError(
                    f'No training image found for target class {self.target_class}; '
                    'cannot train the generator.')

            total_loss = []
            for epoch in range(1, self.num_epoch + 1):
                loader_tqdm = tqdm(self.loader, total=len(self.loader), desc=f'Epoch {epoch}')
                loss_np = 0
                steps = 0
                for _, x, y in loader_tqdm:
                    x_f = x[y != self.target_class].to(self.device)
                    del x, y
                    if len(x_f) == 0:
                        # Nothing to attack in this batch; an empty batch would
                        # produce an empty-mean (NaN) loss.
                        continue

                    # Align batch sizes WITHOUT touching the shared pool: the
                    # pool must keep its full size for the following batches.
                    if len(x_f) > len(target_images):
                        x_f = x_f[:len(target_images)]
                        batch_targets = target_images
                    else:
                        random_indices = torch.randperm(len(target_images))[:len(x_f)].to(self.device)
                        batch_targets = target_images[random_indices]

                    x_adv = self.adv_generator(x_f, batch_targets.raw.to(self.device))

                    # Each surrogate forward pass pushes one feature map into
                    # the collector; pop it right after the pass that made it.
                    logits_nat = self.surrogate_model(self.norm(x_f))
                    feat_nat = self.feat_collecter.pop()
                    logits_tar = self.surrogate_model(self.norm(batch_targets.raw))
                    feat_tar = self.feat_collecter.pop()
                    logits_adv = self.surrogate_model(self.norm(x_adv))
                    feat_adv = self.feat_collecter.pop()

                    loss = (self.contrastive_loss(logits_adv, logits_nat, logits_tar)
                            + self.sim_loss(feat_nat, feat_adv)
                            - self.sim_loss(feat_tar, feat_adv)).mean()
                    loss_np = loss_np + loss.item()
                    steps += 1

                    self.optim.zero_grad()
                    loss.backward()
                    self.optim.step()

                    del x_f, x_adv, batch_targets, logits_nat, logits_adv, logits_tar, feat_nat, feat_tar, feat_adv
                    torch.cuda.empty_cache()

                # Average over the steps actually run (safe for an empty loader).
                epoch_loss = loss_np / max(steps, 1)
                total_loss.append(epoch_loss)
                logging.info(f'Epoch {epoch} loss: {epoch_loss}')

            logging.info(f'Total loss: {total_loss}')
            self.feat_collecter_handler.remove()
            self.adv_generator.save_ckpt(self.file_path)

    def run_test(self):
        """Evaluate the generator on the first test batch.

        Collects all test images of ``self.target_class``, then runs
        ``attacks`` on the non-target samples of the first batch using the
        first 20 target images.
        """
        # Load Batch Data
        self.adv_generator.set_mode('eval')
        self.adv_generator.adv_gen.to(self.device)
        self.loader = get_dataloader(self.data_manager, batch_size=self.eval_batch_szie, start_class=0,
                                     end_class=10, train=False, shuffle=False, num_workers=0)

        target_images = []
        target_labels = []
        for data in self.loader:
            _, image_batch, label_batch = data
            mask = label_batch == self.target_class
            target_images.append(image_batch[mask])
            target_labels.append(label_batch[mask])
        target_imgs = torch.cat(target_images, dim=0).to(self.device)
        target_labels = torch.cat(target_labels, dim=0).to(self.device)
        target_imgs, target_labels = ep.astensors(*(target_imgs, target_labels))

        # The loader is built with eval_batch_szie, so report that size.
        progress = tqdm(self.loader, total=len(self.loader),
                        desc=f'Loading Data with Batch Size of {self.eval_batch_szie}) :')
        for i, (_, imgs, labels) in enumerate(progress):
            if i > 0:
                # Only the first batch is evaluated.
                break
            imgs, labels = ep.astensors(*(imgs.to(self.device), labels.to(self.device)))
            keep = labels != self.target_class
            imgs_f = imgs[keep]
            labels_f = labels[keep]
            labels_t_f = ep.full_like(labels_f, fill_value=self.target_class)
            self.attacks(i, imgs_f, labels_f, labels_t_f, target_imgs[:20], target_labels[:20])

    def attacks(self, i_batch, imgs, labels, labels_t, target_imgs=None, target_labels=None):
        """Measure the attack success rate on each of the 10 task checkpoints.

        For every task the checkpoint ``self.ckpt_paths[task]`` is loaded and
        adversarial samples are generated once per target image. The success
        rates are averaged over target images, plotted and written to an
        ``.xlsx`` file under ``self.save_path``.

        Args:
            i_batch: Batch index, used only in the output file prefix.
            imgs: Source images (eagerpy tensor) to perturb.
            labels: True labels of ``imgs``.
            labels_t: Target labels, used when ``self.target_class`` is set.
            target_imgs: Target images; although it defaults to ``None`` it is
                required, since ``len(target_imgs)`` is taken immediately.
            target_labels: Accepted but not used in this method.
        """
        asr_matrix = np.ones((10, len(target_imgs)))
        self.model = factory.get_model(self.args["model_name"], self.args)
        for task in range(10):
            logging.info("***** Starting attack on task [{}]. *****".format(task))
            self.model.incremental_train(self.data_manager)
            self.model._network.load_state_dict(
                torch.load(self.ckpt_paths[task], map_location=self.device)['model_state_dict'])
            self.model._network.to(self.device)
            self.model._network.eval()

            # Run attack on each target image
            if self.target_class is None:
                criterion = fb.criteria.Misclassification(labels)
            else:
                criterion = fb.criteria.TargetedMisclassification(labels_t)
            current_model = PyTorchModel(self.model._network, bounds=(0, 1), preprocessing=self.preprocessing)
            verify_input_bounds(imgs, current_model)
            criterion = get_criterion(criterion)
            is_adversarial = get_is_adversarial(criterion, current_model)

            logging.info("Eval attack on each target images.")
            for i, target_image in enumerate(target_imgs):
                # The single target image is repeated to match the batch size.
                advs = ep.astensor(self.adv_generator(
                    imgs.raw.to(self.device),
                    target_image.raw.repeat(len(imgs), 1, 1, 1).to(self.device)))
                # NOTE(review): the [0] assumes is_adversarial returns a
                # sequence whose first item is the per-sample flags -- confirm
                # against the foolbox version used by this project.
                is_adv = is_adversarial(advs)[0]
                asr_matrix[task, i] = (is_adv.bool().sum().raw.item() / len(imgs))
                if self.plot_gradcam:
                    save_grad_cam(self.args, torch.clip(advs.raw.detach(), 0, 1), labels_t.raw,
                                  self.model._network,
                                  self.save_path + "/GradCam" + f"targetimg{i}", prefix=f'task{task}',
                                  layer_name='stage_3', save_num=100, save_raw=True)
                del advs, is_adv, target_image
                torch.cuda.empty_cache()

            del criterion, current_model, is_adversarial
            torch.cuda.empty_cache()
            self.model.after_task()

        # Save all target images info: average asr,
        asr_matrix = np.mean(asr_matrix, axis=1, keepdims=True)
        prefix = f'batch{i_batch}_{self.prefix}'
        plot_asr_per_target(asr_matrix, self.save_path, prefix, self.args)
        df = pd.DataFrame(asr_matrix, columns=['ASR'])
        df.to_excel(os.path.join(self.save_path, f"{prefix}.xlsx"), index=False)
        del asr_matrix, imgs, labels, labels_t, target_imgs
        torch.cuda.empty_cache()

    def __call__(
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        *,
        epsilons: Union[Sequence[Union[float, None]], float, None],
        **kwargs: Any,
    ) -> Union[Tuple[List[T], List[T], T], Tuple[T, T, T]]:
        """Placeholder with an empty body; calling it returns ``None``."""
        ...

    def repeat(self, times: int) -> "AIM":
        """Placeholder with an empty body; calling it returns ``None``."""
        ...