|
|
import os.path
|
|
|
import torch
|
|
|
from foolbox.attacks.base import *
|
|
|
from foolbox.attacks.gradient_descent_base import *
|
|
|
from tqdm import tqdm
|
|
|
from attacks.AIM.src.gat.models.attack import AIMAttack, ContrastiveLoss
|
|
|
from attacks.AIM.src.gat.models.surrogate import midlayer_dict, register_collecter, register_collecter_cl
|
|
|
from attacks.attack_config import SustainableAttack
|
|
|
from utils.plot import plot_asr_per_target, save_grad_cam
|
|
|
import logging
|
|
|
import pandas as pd
|
|
|
import foolbox as fb
|
|
|
from foolbox import PyTorchModel
|
|
|
import numpy as np
|
|
|
from utils import factory
|
|
|
from utils.data_manager import get_dataloader
|
|
|
|
|
|
|
|
|
class AIM(SustainableAttack):
    """AIM generator-based targeted attack against class-incremental models.

    Trains a generative adversarial-image generator (``AIMAttack``) against a
    surrogate network, then measures the attack success rate (ASR) of the
    generated images against every incremental task checkpoint.

    NOTE(review): several attributes used below (``self.target_class``,
    ``self.batch_size``, ``self.loader``, ``self.norm``, ``self.data_manager``,
    ``self.ckpt_paths``, ``self.preprocessing``) are not defined in this class
    and are presumably provided by ``SustainableAttack`` — confirm against
    ``attacks.attack_config``. ``ep`` (eagerpy) and the helpers
    ``verify_input_bounds`` / ``get_criterion`` / ``get_is_adversarial`` come
    in via the ``foolbox`` star imports at the top of the file.
    """

    def __init__(self, args, device='cuda'):
        """Set up the generator, its optimizer, the losses and output paths.

        Args:
            args: experiment configuration mapping; must contain at least
                ``'logs_eval_name'`` and ``'model_name'``.
            device: torch device string used for all models and tensors.
        """
        super().__init__(args, device)
        self.device = device
        self.args = args
        # Surrogate network is built lazily in train_generator().
        self.surrogate_model = None

        # Generator producing adversarial images from (source, target) pairs.
        self.adv_generator = AIMAttack(device=device)
        self.adv_generator.set_mode('train')
        # Adam hyper-parameters for generator training.
        self.lr = 0.001
        self.betas = (0.5, 0.999)
        self.num_epoch = 100
        self.optim = torch.optim.Adam(self.adv_generator.get_params(), lr=self.lr, betas=self.betas)
        # Contrastive loss with margin 0.2 on logits; cosine similarity on
        # mid-layer features (see the training objective in train_generator()).
        self.contrastive_loss = ContrastiveLoss(0.2)
        self.sim_loss = torch.nn.functional.cosine_similarity
        # NOTE(review): attribute name is misspelled ("szie"); kept as-is to
        # preserve the class interface for any external readers.
        self.eval_batch_szie = 128

        # Surrogate selection; a '_cl' suffix selects the continual-learning
        # branch in train_generator().
        self.surrogate_model_name = 'resnet32_cl'
        # Mid-layer name whose activations are hooked as features.
        self.layer = midlayer_dict[self.surrogate_model_name]
        # Filename prefix shared by the generator checkpoint and result files.
        self.prefix = (f'adv_generator_{self.surrogate_model_name}'
                       f'_{self.layer}'
                       f'_tclass{self.target_class}')
        self.save_path = os.path.join(self.args['logs_eval_name'], f'target{str(self.target_class)}')
        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)

        # Toggle Grad-CAM visualisation during evaluation (see attacks()).
        self.plot_gradcam = False

    def train_generator(self):
        """Train the adversarial generator against the surrogate, or load it.

        Builds the surrogate model, registers a forward hook that collects
        mid-layer features, then either restores an existing generator
        checkpoint or optimises the generator with a contrastive +
        feature-similarity objective and saves it.
        """
        if 'cl' in self.surrogate_model_name:
            # Continual-learning surrogate: rebuild the incremental model and
            # restore the checkpoint of the first task.
            s_model = factory.get_model(self.args["model_name"], self.args)
            s_model.incremental_train(self.data_manager)
            s_model._network.load_state_dict(
                torch.load(self.ckpt_paths[0], map_location=self.device)['model_state_dict'])
            s_model._network.to(self.device)
            s_model._network.eval()
            # Keep only the bare network; drop the training wrapper.
            self.surrogate_model = s_model._network
            del s_model
            torch.cuda.empty_cache()
            # Forward hook appending `self.layer` activations to the list; the
            # handler is kept so the hook can be removed after training.
            self.feat_collecter = []
            self.feat_collecter_handler, self.feat_collecter = register_collecter_cl(self.surrogate_model,
                                                                                     self.layer,
                                                                                     self.feat_collecter,
                                                                                     self.args["model_name"])
        else:
            # Plain pretrained CIFAR-100 ResNet-32 surrogate from torch.hub.
            self.surrogate_model = torch.hub.load("chenyaofo/pytorch-cifar-models", 'cifar100_resnet32', pretrained=True)
            self.surrogate_model.to(self.device)
            self.surrogate_model.eval()
            self.feat_collecter = []
            self.feat_collecter_handler, self.feat_collecter = register_collecter(self.surrogate_model,
                                                                                  self.layer,
                                                                                  self.feat_collecter)
        self.file_path = os.path.join(self.save_path, f'{self.prefix}.pth')
        if os.path.exists(self.file_path):
            # A trained generator already exists for this configuration —
            # load it and switch to evaluation mode; no training needed.
            self.adv_generator.load_ckpt(self.file_path)
            self.adv_generator.set_mode('eval')
        else:
            # Collect training images of the target class from tasks 0-10.
            loaders = get_dataloader(self.data_manager, batch_size=self.batch_size,
                                     start_class=0, end_class=10,
                                     train=True, shuffle=True, num_workers=0)

            target_images = []
            target_labels = []
            for data in loaders:
                _, image_batch, label_batch = data
                # Keep only samples belonging to the attack's target class.
                mask = label_batch == self.target_class
                selected_images = image_batch[mask]
                selected_labels = label_batch[mask]
                target_images.append(selected_images)
                target_labels.append(selected_labels)
            del loaders
            target_images = torch.cat(target_images, dim=0).to(self.device)
            target_labels = torch.cat(target_labels, dim=0).to(self.device)
            # Truncate the target pool to one batch and wrap as eagerpy tensors.
            target_images, target_labels = ep.astensors(*(target_images[:self.batch_size], target_labels[:self.batch_size]))

            total_loss = []
            for epoch in range(1, self.num_epoch + 1):
                # NOTE(review): `self.loader` is not assigned anywhere in this
                # method (only `loaders` above, which is deleted) and is set in
                # run_test(); confirm the base class provides it before
                # training, otherwise this raises AttributeError.
                laoder_tqdm = tqdm(self.loader, total=len(self.loader), desc=f'Epoch {epoch}')
                loss_np = 0
                for i, (_, x, y) in enumerate(laoder_tqdm):
                    # Sources: everything that is NOT the target class.
                    x_f = x[y != self.target_class].to(self.device)
                    del x, y
                    if len(x_f) > len(target_images):
                        # More sources than targets: trim sources to match.
                        x_f = x_f[:len(target_images)].to(self.device)
                    else:
                        # Fewer sources than targets: subsample the target pool.
                        # NOTE(review): this reassigns (shrinks) target_images
                        # in place, so subsequent batches/epochs see the
                        # reduced pool — confirm this is intended.
                        random_indices = torch.randperm(len(target_images))[:len(x_f)].to(self.device)
                        target_images = target_images[random_indices]

                    # Generate adversarial images conditioned on the targets.
                    x_adv = self.adv_generator(x_f, target_images.raw.to(self.device))

                    # Each surrogate forward pass pushes one feature tensor
                    # onto feat_collecter via the registered hook; pop() right
                    # after each call pairs logits with their features.
                    logits_nat = self.surrogate_model(self.norm(x_f))
                    feat_nat = self.feat_collecter.pop()
                    logits_tar = self.surrogate_model(self.norm(target_images.raw))
                    feat_tar = self.feat_collecter.pop()
                    logits_adv = self.surrogate_model(self.norm(x_adv))
                    feat_adv = self.feat_collecter.pop()

                    # Objective: pull adv logits toward target / away from
                    # natural (contrastive), keep adv features close to the
                    # natural image yet... minus similarity to the target
                    # pushes adv features toward the target representation.
                    loss = (self.contrastive_loss(logits_adv, logits_nat, logits_tar) +
                            self.sim_loss(feat_nat, feat_adv) -
                            self.sim_loss(feat_tar, feat_adv)).mean()

                    loss_np = loss_np + loss.item()

                    self.optim.zero_grad()
                    loss.backward()
                    self.optim.step()
                    # Free per-batch tensors aggressively to limit GPU memory.
                    del x_f, x_adv, logits_nat, logits_adv, logits_tar, feat_nat, feat_tar, feat_adv
                    torch.cuda.empty_cache()
                # Mean batch loss for this epoch (i is the last batch index).
                total_loss.append(loss_np/(i+1))
                logging.info(f'Epoch {epoch} loss: {loss_np/(len(self.loader))}')
            logging.info(f'Total loss: {total_loss}')
            # Detach the feature hook and persist the trained generator.
            self.feat_collecter_handler.remove()
            self.adv_generator.save_ckpt(self.file_path)

    def run_test(self):
        """Evaluate the trained generator on the test set.

        Collects test images of the target class, then (for the first test
        batch only) generates adversarial images from all non-target samples
        and delegates ASR measurement to attacks().
        """
        self.adv_generator.set_mode('eval')
        self.adv_generator.adv_gen.to(self.device)
        self.loader = get_dataloader(self.data_manager, batch_size=self.eval_batch_szie,
                                     start_class=0, end_class=10,
                                     train=False, shuffle=False, num_workers=0)
        target_images = []
        target_labels = []
        for data in self.loader:
            _, image_batch, label_batch = data
            mask = label_batch == self.target_class
            selected_images = image_batch[mask]
            selected_labels = label_batch[mask]
            target_images.append(selected_images)
            target_labels.append(selected_labels)
        target_imgs = torch.cat(target_images, dim=0).to(self.device)
        target_labels = torch.cat(target_labels, dim=0).to(self.device)
        target_imgs, target_labels = ep.astensors(*(target_imgs, target_labels))
        # NOTE(review): the progress-bar text reports self.batch_size although
        # the loader above was built with self.eval_batch_szie.
        for i, (_, imgs, labels) in enumerate(tqdm(self.loader, total=len(self.loader),
                                                   desc=f'Loading Data with Batch Size of {self.batch_size}) :')):
            # Only the first batch is evaluated; presumably a deliberate
            # subsample to bound evaluation cost — confirm.
            if i > 0:
                break

            imgs, labels = ep.astensors(*(imgs.to(self.device), labels.to(self.device)))

            # Attack sources: all non-target-class samples, with the target
            # class as the desired (mis)classification label.
            imgs_f = imgs[labels != self.target_class]
            labels_f = labels[labels != self.target_class]
            labels_t_f = ep.full_like(labels_f, fill_value=self.target_class)

            # Evaluate against (at most) the first 20 target images.
            self.attacks(i, imgs_f, labels_f, labels_t_f, target_imgs[:20], target_labels[:20])

    def attacks(self, i_batch, imgs, labels, labels_t, target_imgs=None, target_labels=None):
        """Measure ASR of generated adversarials on each task checkpoint.

        For each of the 10 incremental tasks, loads the task checkpoint and,
        for every target image, generates adversarials from `imgs` and records
        the fraction judged adversarial by foolbox. Results are averaged over
        target images, plotted and written to an Excel file.

        Args:
            i_batch: index of the evaluation batch (used in output filenames).
            imgs: eagerpy tensor of non-target source images.
            labels: their true labels (used for untargeted criterion).
            labels_t: target-class labels (used for targeted criterion).
            target_imgs: eagerpy tensor of target-class exemplar images.
            target_labels: their labels (unused in the body).
        """
        # asr_matrix[task, target_idx] = success rate; initialised to ones.
        asr_matrix = np.ones((10, len(target_imgs)))
        self.model = factory.get_model(self.args["model_name"], self.args)
        for task in range(10):
            logging.info("***** Starting attack on task [{}]. *****".format(task))
            # Advance the incremental model one task and load its checkpoint.
            self.model.incremental_train(self.data_manager)
            self.model._network.load_state_dict(torch.load(self.ckpt_paths[task], map_location=self.device)['model_state_dict'])
            self.model._network.to(self.device)
            self.model._network.eval()

            # Targeted criterion when a target class is set, otherwise plain
            # misclassification.
            criterion = fb.criteria.Misclassification(
                labels) if self.target_class is None else fb.criteria.TargetedMisclassification(
                labels_t)
            current_model = PyTorchModel(self.model._network, bounds=(0, 1), preprocessing=self.preprocessing)
            verify_input_bounds(imgs, current_model)
            criterion = get_criterion(criterion)
            is_adversarial = get_is_adversarial(criterion, current_model)

            logging.info("Eval attack on each target images.")
            for i, target_image in enumerate(target_imgs):
                # One generator pass per target image: the single target is
                # repeated across the whole source batch.
                advs = ep.astensor(self.adv_generator(imgs.raw.to(self.device), target_image.raw.repeat(len(imgs), 1, 1, 1).to(self.device)))
                is_adv = is_adversarial(advs)[0]
                # Fraction of sources successfully pushed to the target class.
                asr_matrix[task, i] = (is_adv.bool().sum().raw.item() / len(imgs))
                if self.plot_gradcam:
                    save_grad_cam(self.args, torch.clip(advs.raw.detach(),0,1), labels_t.raw,
                                  self.model._network, self.save_path + "/GradCam" + f"targetimg{i}", prefix=f'task{task}',
                                  layer_name='stage_3', save_num=100, save_raw=True)

                del advs, is_adv, target_image
                torch.cuda.empty_cache()

            del criterion, current_model, is_adversarial
            torch.cuda.empty_cache()

            # Let the incremental model finalise the task before moving on.
            self.model.after_task()

        # Average ASR over target images -> one value per task.
        asr_matrix = np.mean(asr_matrix, axis=1, keepdims=True)
        prefix = f'batch{i_batch}_{self.prefix}'
        plot_asr_per_target(asr_matrix, self.save_path, prefix, self.args)
        df = pd.DataFrame(asr_matrix, columns=['ASR'])
        df.to_excel(os.path.join(self.save_path, f"{prefix}.xlsx"), index=False)

        del asr_matrix, imgs, labels, labels_t, target_imgs
        torch.cuda.empty_cache()

    def __call__(
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        *,
        epsilons: Union[Sequence[Union[float, None]], float, None],
        **kwargs: Any,
    ) -> Union[Tuple[List[T], List[T], T], Tuple[T, T, T]]:
        """Foolbox ``Attack`` interface stub; intentionally unimplemented."""
        ...

    def repeat(self, times: int) -> "AIM":
        """Foolbox ``Attack`` interface stub; intentionally unimplemented."""
        ...