import os.path
from foolbox.attacks.base import *
from foolbox.attacks.gradient_descent_base import *
from tqdm import tqdm
import torch.optim as optim
from attacks.CGNC.models.generator import CrossAttenGenerator
from attacks.CGNC.utils_ import *
from attacks.CGNC.image_transformer import rotation
from attacks.attack_config import SustainableAttack
from utils.plot import plot_asr_per_target
import logging
import foolbox as fb
from foolbox import PyTorchModel
import numpy as np
from utils import factory
from utils.data_manager import get_dataloader
from attacks.AIM.src.gat.models.surrogate import build_surrogate

# NOTE(review): several names used below (torch, nn, transforms, pd, ep,
# normalize, get_classes, verify_input_bounds, get_criterion,
# get_is_adversarial, Model, T, Any, Union, Sequence, Tuple, List) are not
# imported explicitly; presumably they come from the wildcard imports above,
# so the original import order is kept to preserve name resolution.


class CGNC(SustainableAttack):
    """Text-conditioned generative targeted attack (CGNC) on a class-incremental model.

    A ``CrossAttenGenerator`` is trained against a surrogate classifier to turn a
    clean image plus the text feature of a target class into a perturbation
    (the generator receives ``eps``; how it bounds the noise is not visible here).
    The trained generator is then evaluated against each of 10 task checkpoints.
    """

    def __init__(self, args, device='cuda'):
        """Build the generator, optimiser and evaluation settings.

        Args:
            args: Experiment config dict. ``'logs_eval_name'`` (output directory)
                is read here; ``'model_name'`` is read when task models are built.
            device: Torch device for the generator, surrogate and task models.
        """
        super().__init__(args, device)
        self.device = device
        self.args = args
        self.surrogate_model = None
        self.adv_generator = CrossAttenGenerator(nz=16, device=device).to(device)
        self.lr = 0.001  # alternative tried: 2e-4
        self.betas = (0.5, 0.999)
        self.num_epoch = 100
        self.optim = optim.Adam(self.adv_generator.parameters(), lr=self.lr, betas=self.betas)
        self.criterion = nn.CrossEntropyLoss()
        # Class id -> text feature. map_location lets a file saved on one device
        # be loaded on another.
        self.text_cond_dict = torch.load("attacks/CGNC/text_feature.pth", map_location=device)
        self.label_set = get_classes("CL")
        self.eps = 32 / 255
        self.eval_batch_szie = 128  # (sic) attribute name kept for compatibility
        self.surrogate_model_name = 'resnet32_cl'
        self.prefix = f'{self.surrogate_model_name}_{len(self.label_set)}classes_eps{int(self.eps * 255)}'
        self.save_path = os.path.join(self.args['logs_eval_name'])
        os.makedirs(self.save_path, exist_ok=True)

    def train_generator(self):
        """Prepare the surrogate, then load or train the adversarial generator.

        If ``<save_path>/<prefix>.pth`` exists the generator weights are loaded
        from it. Otherwise the generator is trained for ``self.num_epoch`` epochs
        on classes 0-9 of the training split and saved to that path. In both
        cases the generator is left in eval mode.
        """
        self.surrogate_model = self._load_surrogate()

        file_path = os.path.join(self.save_path, f'{self.prefix}.pth')
        if os.path.exists(file_path):
            self.adv_generator.load_state_dict(torch.load(file_path, map_location=self.device))
            self.adv_generator.eval()
            return

        self.loader = get_dataloader(self.data_manager, batch_size=self.batch_size, start_class=0,
                                     end_class=10, train=True, shuffle=True, num_workers=0)
        # Class id -> surrogate logit index, built once from the label set (which
        # is never reordered), so every text condition is always paired with the
        # same target index.
        label_map = {cls: idx for idx, cls in enumerate(self.label_set)}
        augmentations = {}  # spatial size -> augmentation pipeline, built lazily
        num_batches = max(len(self.loader), 1)  # avoid division by zero on an empty loader
        for epoch in range(1, self.num_epoch + 1):
            epoch_loss = 0.0
            for _, x, _ in tqdm(self.loader, total=len(self.loader), desc=f'Epoch {epoch}'):
                size = x.size(-1)
                if size not in augmentations:
                    augmentations[size] = self._build_augmentation(size)
                epoch_loss += self._train_step(x, augmentations[size], label_map)
                torch.cuda.empty_cache()
            logging.info(f'Epoch {epoch} loss: {epoch_loss / num_batches}')
        torch.save(self.adv_generator.state_dict(), file_path)
        self.adv_generator.eval()

    def _load_surrogate(self):
        """Build the surrogate classifier in eval mode with frozen weights."""
        if 'cl' in self.surrogate_model_name:
            # Surrogate = the incremental model restored from the first task checkpoint.
            s_model = factory.get_model(self.args["model_name"], self.args)
            s_model.incremental_train(self.data_manager)
            s_model._network.load_state_dict(
                torch.load(self.ckpt_paths[0], map_location=self.device)['model_state_dict'])
            s_model._network.to(self.device)
            surrogate = s_model._network
            del s_model
            torch.cuda.empty_cache()
        else:
            surrogate = build_surrogate(self.surrogate_model_name, pretrain=True).to(self.device)
        surrogate.eval()
        # Only the generator is optimised; without this the surrogate's .grad
        # buffers would accumulate on every backward pass and never be used.
        surrogate.requires_grad_(False)
        return surrogate

    @staticmethod
    def _build_augmentation(size):
        """Return the random per-image augmentation used for the third training view."""
        color_jitter = transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomResizedCrop(size=size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
        ])

    def _perturb(self, imgs, cond):
        """Return ``imgs`` plus generator noise for ``cond``, clamped to [0, 1]."""
        noise = self.adv_generator(input=imgs, cond=cond, eps=self.eps)
        return torch.clamp(noise + imgs, 0.0, 1.0)

    def _sample_targets(self, n, label_map):
        """Sample ``n`` target classes; return (text conditions, target indices).

        Classes are drawn uniformly with replacement from ``self.label_set``.
        Conditions are cast to float on ``self.device``, as in ``attacks``.
        Ids missing from ``label_map`` are used as indices unchanged.
        """
        classes = np.random.choice(self.label_set, n)
        cond = torch.stack([self.text_cond_dict[j] for j in classes], dim=0)
        cond = cond.to(device=self.device, dtype=torch.float)
        target = torch.tensor([label_map.get(int(j), int(j)) for j in classes],
                              dtype=torch.long, device=self.device)
        return cond, target

    def _train_step(self, x, aug, label_map):
        """Run one generator update on the CPU batch ``x`` and return the loss value.

        The same sampled conditions are applied to three views of the batch
        (original, rotated, randomly augmented). The surrogate cross-entropy
        losses towards the target indices are summed and minimised.
        """
        imgs = x.to(self.device)
        imgs_rot = rotation(x)[0].to(self.device)
        imgs_aug = torch.stack([aug(img) for img in x]).to(self.device)
        cond, target = self._sample_targets(imgs.size(0), label_map)

        self.adv_generator.train()
        self.optim.zero_grad()
        loss = sum(self.criterion(self.surrogate_model(normalize(self._perturb(view, cond))), target)
                   for view in (imgs, imgs_rot, imgs_aug))
        loss.backward()
        self.optim.step()
        return loss.item()

    def run_test(self):
        """Evaluate the generator on the first shuffled test batch of classes 0-9.

        Samples of ``self.target_class`` are separated from the rest. Unless
        ``args['model_name'] == 'finetune'``, the remaining samples are passed
        through ``self.to_alls``. The result is then handed to ``attacks``.
        """
        self.adv_generator.eval()
        self.loader = get_dataloader(self.data_manager, batch_size=self.eval_batch_szie, start_class=0,
                                     end_class=10, train=False, shuffle=True, num_workers=0)
        desc = f'Loading Data with Batch Size of {self.eval_batch_szie}:'
        for i, (_, imgs, labels) in enumerate(tqdm(self.loader, total=len(self.loader), desc=desc)):
            imgs, labels = ep.astensors(imgs.to(self.device), labels.to(self.device))
            is_target = labels == self.target_class
            target_imgs = imgs[is_target]
            target_labels = labels[is_target]
            imgs_f = imgs[labels != self.target_class]
            labels_f = labels[labels != self.target_class]
            labels_t_f = ep.full_like(labels_f, fill_value=self.target_class)
            if self.args["model_name"] != 'finetune':
                imgs_f, labels_f, labels_t_f = self.to_alls(imgs_f, labels_f, labels_t_f,
                                                            target_imgs, target_labels)[:3]
            self.attacks(i, imgs_f, labels_f, labels_t_f)
            break  # only the first batch is evaluated; don't load another one

    def attacks(self, i_batch, imgs, labels, labels_t):
        """Evaluate the generator on every task checkpoint and save ASR results.

        For each of 10 task checkpoints and each class in ``self.label_set``, the
        generator is conditioned on that class's text feature. The fraction of
        ``imgs`` that the foolbox criterion marks adversarial is stored in
        ``asr_matrix[task, class_idx]``. The matrix is then plotted, and one
        ``.xlsx`` file per class is written to ``self.save_path``.

        Args:
            i_batch: Batch index, used only in output file names.
            imgs: eagerpy tensor of images (model bounds are (0, 1)).
            labels: eagerpy tensor of true labels (untargeted criterion).
            labels_t: eagerpy tensor of target labels (targeted criterion).
        """
        num_tasks = 10
        asr_matrix = np.ones((num_tasks, len(self.label_set)))
        self.model = factory.get_model(self.args["model_name"], self.args)
        for task in range(num_tasks):
            logging.info("***** Starting attack on task [{}]. *****".format(task))
            self.model.incremental_train(self.data_manager)
            self.model._network.load_state_dict(
                torch.load(self.ckpt_paths[task], map_location=self.device)['model_state_dict'])
            self.model._network.to(self.device)
            self.model._network.eval()

            criterion = (fb.criteria.Misclassification(labels) if self.target_class is None
                         else fb.criteria.TargetedMisclassification(labels_t))
            current_model = PyTorchModel(self.model._network, bounds=(0, 1), preprocessing=self.preprocessing)
            verify_input_bounds(imgs, current_model)
            criterion = get_criterion(criterion)
            is_adversarial = get_is_adversarial(criterion, current_model)
            logging.info("Eval attack on each target images.")
            for idx, cls in enumerate(self.label_set):
                cond = torch.tile(self.text_cond_dict[cls], (len(imgs), 1)).to(torch.float).to(self.device)
                with torch.no_grad():  # evaluation only; no autograd graph needed
                    advs = self._perturb(imgs.raw, cond)
                    # NOTE(review): with stock foolbox 3, get_is_adversarial returns a
                    # per-sample bool tensor, so [0] would keep only the first sample;
                    # presumably the project's foolbox returns a tuple here - verify.
                    is_adv = is_adversarial(ep.astensor(advs))[0]
                asr_matrix[task, idx] = is_adv.bool().sum().raw.item() / len(imgs)
                del advs, cond, is_adv
                torch.cuda.empty_cache()
            del criterion, current_model, is_adversarial
            torch.cuda.empty_cache()
            self.model.after_task()

        prefix = f'batch{i_batch}_{self.prefix}'
        plot_asr_per_target(asr_matrix, self.save_path, prefix, self.args)
        for idx in range(len(self.label_set)):
            df = pd.DataFrame(asr_matrix[:, idx], columns=['ASR'])
            df.to_excel(os.path.join(self.save_path, f"{prefix}_class{idx}.xlsx"), index=False)
        torch.cuda.empty_cache()

    def __call__(
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        *,
        epsilons: Union[Sequence[Union[float, None]], float, None],
        **kwargs: Any,
    ) -> Union[Tuple[List[T], List[T], T], Tuple[T, T, T]]:
        """Stub, not implemented here (returns None).

        Presumably present only to satisfy the foolbox ``Attack`` interface.
        """
        ...

    def repeat(self, times: int) -> "CGNC":
        """Stub for foolbox's ``Attack.repeat`` hook; not implemented (returns None)."""
        ...