# SAE/attacks/AIM/AIMAttack.py
# Ttius's picture
# Upload 192 files
# 998bb30 verified
import os.path
import torch
from foolbox.attacks.base import *
from foolbox.attacks.gradient_descent_base import *
from tqdm import tqdm
from attacks.AIM.src.gat.models.attack import AIMAttack, ContrastiveLoss
from attacks.AIM.src.gat.models.surrogate import midlayer_dict, register_collecter, register_collecter_cl
from attacks.attack_config import SustainableAttack
from utils.plot import plot_asr_per_target, save_grad_cam
import logging
import pandas as pd
import foolbox as fb
from foolbox import PyTorchModel
import numpy as np
from utils import factory
from utils.data_manager import get_dataloader
class AIM(SustainableAttack):
    """Generator-based targeted attack evaluated over a sequence of task checkpoints.

    The class trains (or loads) an ``AIMAttack`` generator against a surrogate
    network and then measures, for every checkpoint in ``self.ckpt_paths``,
    how often generated examples satisfy the targeted-misclassification
    criterion.

    Attributes such as ``target_class``, ``batch_size``, ``data_manager``,
    ``ckpt_paths``, ``norm`` and ``preprocessing`` are not assigned here;
    presumably they are provided by ``SustainableAttack`` -- TODO confirm.
    """

    def __init__(self, args, device='cuda'):
        """Create the generator, its optimizer and the output directory.

        Args:
            args: Configuration mapping; this class reads ``args['logs_eval_name']``
                and ``args['model_name']``.
            device: Device passed to the generator and used for all tensors.
        """
        super().__init__(args, device)
        self.device = device
        self.args = args
        self.surrogate_model = None

        # Generator and optimisation hyper-parameters.
        self.adv_generator = AIMAttack(device=device)
        self.adv_generator.set_mode('train')
        self.lr = 0.001
        self.betas = (0.5, 0.999)
        self.num_epoch = 100
        self.optim = torch.optim.Adam(self.adv_generator.get_params(), lr=self.lr, betas=self.betas)
        self.contrastive_loss = ContrastiveLoss(0.2)
        self.sim_loss = torch.nn.functional.cosine_similarity

        # NOTE: the attribute name is misspelled ("szie"); it is kept as-is so
        # that code outside this class which may read it keeps working.
        self.eval_batch_szie = 128

        self.surrogate_model_name = 'resnet32_cl'
        self.layer = midlayer_dict[self.surrogate_model_name]
        self.prefix = (f'adv_generator_{self.surrogate_model_name}'
                       f'_{self.layer}'
                       f'_tclass{self.target_class}')
        self.save_path = os.path.join(self.args['logs_eval_name'], f'target{str(self.target_class)}')
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.save_path, exist_ok=True)
        self.plot_gradcam = False

    def _collect_target_samples(self, loader):
        """Gather every sample of ``self.target_class`` produced by ``loader``.

        Args:
            loader: Iterable yielding ``(_, images, labels)`` batches.

        Returns:
            Tuple ``(images, labels)`` of torch tensors moved to ``self.device``.

        Raises:
            ValueError: If the loader yields no sample of the target class.
        """
        picked_images = []
        picked_labels = []
        for _, image_batch, label_batch in loader:
            mask = label_batch == self.target_class
            picked_images.append(image_batch[mask])
            picked_labels.append(label_batch[mask])
        images = torch.cat(picked_images, dim=0).to(self.device)
        labels = torch.cat(picked_labels, dim=0).to(self.device)
        if len(images) == 0:
            raise ValueError(f'No samples of target class {self.target_class} found in the loader.')
        return images, labels

    def _setup_surrogate(self):
        """Load the surrogate network and register the mid-layer feature collector."""
        if 'cl' in self.surrogate_model_name:
            # Surrogate is the project model restored from the first checkpoint.
            s_model = factory.get_model(self.args["model_name"], self.args)
            s_model.incremental_train(self.data_manager)
            s_model._network.load_state_dict(
                torch.load(self.ckpt_paths[0], map_location=self.device)['model_state_dict'])
            s_model._network.to(self.device)
            s_model._network.eval()
            self.surrogate_model = s_model._network
            del s_model
            torch.cuda.empty_cache()
            self.feat_collecter = []
            self.feat_collecter_handler, self.feat_collecter = register_collecter_cl(self.surrogate_model,
                                                                                     self.layer,
                                                                                     self.feat_collecter,
                                                                                     self.args["model_name"])
        else:
            self.surrogate_model = torch.hub.load("chenyaofo/pytorch-cifar-models", 'cifar100_resnet32', pretrained=True)
            self.surrogate_model.to(self.device)
            self.surrogate_model.eval()
            self.feat_collecter = []
            self.feat_collecter_handler, self.feat_collecter = register_collecter(self.surrogate_model,
                                                                                  self.layer,
                                                                                  self.feat_collecter)

    def _fit_generator(self):
        """Optimise the generator for ``self.num_epoch`` epochs over ``self.loader``.

        ``self.loader`` is not created in this method; presumably it is set by
        the parent class or by an earlier ``run_test`` call -- TODO confirm.
        """
        loaders = get_dataloader(self.data_manager, batch_size=self.batch_size,
                                 start_class=0, end_class=10,
                                 train=True, shuffle=True, num_workers=0)
        target_images, _ = self._collect_target_samples(loaders)
        del loaders
        # Fixed pool of at most one batch worth of target-class images.
        target_images = target_images[:self.batch_size]

        total_loss = []
        for epoch in range(1, self.num_epoch + 1):
            loader_tqdm = tqdm(self.loader, total=len(self.loader), desc=f'Epoch {epoch}')
            loss_np = 0
            num_steps = 0
            for _, x, y in loader_tqdm:
                x_f = x[y != self.target_class].to(self.device)
                del x, y
                if len(x_f) == 0:
                    # Batch held only target-class samples: nothing to attack,
                    # and an empty batch would yield a NaN loss.
                    continue
                if len(x_f) > len(target_images):
                    x_f = x_f[:len(target_images)]
                    target_batch = target_images
                else:
                    # Draw a per-batch subset; the pool itself must stay intact,
                    # otherwise it shrinks to the smallest batch ever seen.
                    random_indices = torch.randperm(len(target_images))[:len(x_f)].to(self.device)
                    target_batch = target_images[random_indices]

                x_adv = self.adv_generator(x_f, target_batch)
                # Each surrogate forward pass leaves its mid-layer features in
                # the collector; pop them right after the corresponding call.
                logits_nat = self.surrogate_model(self.norm(x_f))
                feat_nat = self.feat_collecter.pop()
                logits_tar = self.surrogate_model(self.norm(target_batch))
                feat_tar = self.feat_collecter.pop()
                logits_adv = self.surrogate_model(self.norm(x_adv))
                feat_adv = self.feat_collecter.pop()

                loss = (self.contrastive_loss(logits_adv, logits_nat, logits_tar) +
                        self.sim_loss(feat_nat, feat_adv) -
                        self.sim_loss(feat_tar, feat_adv)).mean()
                loss_np = loss_np + loss.item()
                num_steps += 1

                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
                del x_f, x_adv, target_batch, logits_nat, logits_adv, logits_tar, feat_nat, feat_tar, feat_adv
                torch.cuda.empty_cache()

            # Average over the optimisation steps that actually ran.
            epoch_loss = loss_np / max(num_steps, 1)
            total_loss.append(epoch_loss)
            logging.info(f'Epoch {epoch} loss: {epoch_loss}')
        logging.info(f'Total loss: {total_loss}')

    def train_generator(self):
        """Prepare the surrogate, then load the generator checkpoint or train one.

        If ``<save_path>/<prefix>.pth`` exists the generator is loaded from it
        and switched to eval mode; otherwise the generator is trained and the
        checkpoint is written to that path.
        """
        self._setup_surrogate()
        self.file_path = os.path.join(self.save_path, f'{self.prefix}.pth')
        if os.path.exists(self.file_path):
            self.adv_generator.load_ckpt(self.file_path)
            self.adv_generator.set_mode('eval')
        else:
            self._fit_generator()
            self.adv_generator.save_ckpt(self.file_path)
        # The collector is only needed while training; release it in both paths.
        self.feat_collecter_handler.remove()

    def run_test(self):
        """Evaluate the generator on the first evaluation batch.

        Builds the evaluation loader, collects the target-class images, and
        calls ``attacks`` with the non-target samples of the first batch and
        the first 20 target images.
        """
        self.adv_generator.set_mode('eval')
        self.adv_generator.adv_gen.to(self.device)
        self.loader = get_dataloader(self.data_manager, batch_size=self.eval_batch_szie,
                                     start_class=0, end_class=10,
                                     train=False, shuffle=False, num_workers=0)
        target_imgs, target_labels = ep.astensors(*self._collect_target_samples(self.loader))

        progress = tqdm(self.loader, total=len(self.loader),
                        desc=f'Loading Data with Batch Size of {self.eval_batch_szie} :')
        for i, (_, imgs, labels) in enumerate(progress):
            if i > 0:
                # Only the first batch is evaluated.
                break
            imgs, labels = ep.astensors(*(imgs.to(self.device), labels.to(self.device)))
            keep = labels != self.target_class
            imgs_f = imgs[keep]
            labels_f = labels[keep]
            labels_t_f = ep.full_like(labels_f, fill_value=self.target_class)
            self.attacks(i, imgs_f, labels_f, labels_t_f, target_imgs[:20], target_labels[:20])

    def attacks(self, i_batch, imgs, labels, labels_t, target_imgs=None, target_labels=None):
        """Measure the attack success rate on each of the 10 task checkpoints.

        For every task the victim network is restored from
        ``self.ckpt_paths[task]``; for every target image the generator output
        is checked against the criterion. The per-task mean over target images
        is plotted and written to ``<save_path>/batch<i_batch>_<prefix>.xlsx``.

        Args:
            i_batch: Index of the evaluated batch; used in output file names.
            imgs: eagerpy tensor of input images.
            labels: eagerpy tensor of true labels for ``imgs``.
            labels_t: eagerpy tensor of target labels for ``imgs``.
            target_imgs: eagerpy tensor of target-class images (required).
            target_labels: Labels of ``target_imgs``; not used by this method.

        Raises:
            TypeError: If ``target_imgs`` is None.
        """
        if target_imgs is None:
            raise TypeError('attacks() requires target_imgs; got None.')

        num_tasks = 10
        asr_matrix = np.ones((num_tasks, len(target_imgs)))
        self.model = factory.get_model(self.args["model_name"], self.args)

        # The criterion depends only on the labels, so build it once.
        criterion = fb.criteria.Misclassification(
            labels) if self.target_class is None else fb.criteria.TargetedMisclassification(
            labels_t)
        criterion = get_criterion(criterion)

        for task in range(num_tasks):
            logging.info("***** Starting attack on task [{}]. *****".format(task))
            self.model.incremental_train(self.data_manager)
            self.model._network.load_state_dict(torch.load(self.ckpt_paths[task], map_location=self.device)['model_state_dict'])
            self.model._network.to(self.device)
            self.model._network.eval()

            current_model = PyTorchModel(self.model._network, bounds=(0, 1), preprocessing=self.preprocessing)
            verify_input_bounds(imgs, current_model)
            is_adversarial = get_is_adversarial(criterion, current_model)

            # Run the attack on each target image.
            logging.info("Eval attack on each target images.")
            for i, target_image in enumerate(target_imgs):
                # Pure evaluation: no autograd graph is needed here.
                with torch.no_grad():
                    advs = ep.astensor(self.adv_generator(imgs.raw.to(self.device), target_image.raw.repeat(len(imgs), 1, 1, 1).to(self.device)))
                    # NOTE(review): the result is indexed with [0]; this presumes
                    # is_adversarial returns a sequence whose first item is the
                    # per-sample flag tensor -- confirm against the foolbox in use.
                    is_adv = is_adversarial(advs)[0]
                asr_matrix[task, i] = (is_adv.bool().sum().raw.item() / len(imgs))
                if self.plot_gradcam:
                    save_grad_cam(self.args, torch.clip(advs.raw.detach(), 0, 1), labels_t.raw,
                                  self.model._network, self.save_path + "/GradCam" + f"targetimg{i}", prefix=f'task{task}',
                                  layer_name='stage_3', save_num=100, save_raw=True)
                del advs, is_adv
                torch.cuda.empty_cache()

            del current_model, is_adversarial
            torch.cuda.empty_cache()
            self.model.after_task()

        # Reduce to the average ASR over all target images, one row per task.
        asr_matrix = np.mean(asr_matrix, axis=1, keepdims=True)
        prefix = f'batch{i_batch}_{self.prefix}'
        plot_asr_per_target(asr_matrix, self.save_path, prefix, self.args)
        df = pd.DataFrame(asr_matrix, columns=['ASR'])
        df.to_excel(os.path.join(self.save_path, f"{prefix}.xlsx"), index=False)
        torch.cuda.empty_cache()

    def __call__(
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        *,
        epsilons: Union[Sequence[Union[float, None]], float, None],
        **kwargs: Any,
    ) -> Union[Tuple[List[T], List[T], T], Tuple[T, T, T]]:
        """Placeholder: the body is empty, so calling the instance returns None."""
        ...

    def repeat(self, times: int) -> "AIM":
        """Placeholder: the body is empty, so this returns None."""
        ...