|
|
import os.path
|
|
|
import torchvision
|
|
|
import torch
|
|
|
from foolbox.attacks.base import *
|
|
|
from foolbox.attacks.gradient_descent_base import *
|
|
|
from tqdm import tqdm
|
|
|
import pandas as pd
|
|
|
from attacks.Gaker.Generator.Generator import Generator
|
|
|
from attacks.Gaker.Generator.train import CustomResnet50, CustomDenseNet121
|
|
|
from attacks.AIM.src.gat.models.surrogate import midlayer_dict, register_collecter_cl
|
|
|
from attacks.attack_config import SustainableAttack
|
|
|
from utils.plot import plot_asr_per_target, save_grad_cam
|
|
|
import logging
|
|
|
import foolbox as fb
|
|
|
from foolbox import PyTorchModel
|
|
|
import numpy as np
|
|
|
from utils import factory
|
|
|
from utils.data_manager import get_dataloader
|
|
|
from attacks.Gaker.utils_.gaussian_smoothing import get_gaussian_kernel
|
|
|
|
|
|
class Gaker(SustainableAttack):
    """Generator-based targeted attack evaluated against per-task checkpoints.

    A ``Generator`` network is trained so that features of the perturbed
    image (and of the perturbation itself) align, in cosine similarity, with
    features of images of ``self.target_class``.  The trained generator is
    then evaluated against the checkpoint of every incremental task.

    Attributes such as ``data_manager``, ``ckpt_paths``, ``target_class``,
    ``batch_size``, ``norm`` and ``preprocessing`` are read but never assigned
    in this class; presumably they are provided by ``SustainableAttack`` --
    TODO confirm.
    """

    def __init__(self, args, device='cuda'):
        """Build the surrogate feature extractor, generator and optimizer.

        Args:
            args: Configuration mapping; this class reads ``args["model_name"]``
                and ``args["logs_eval_name"]``.
            device: Torch device on which all networks are placed.
        """
        super().__init__(args, device)
        self.device = device
        self.args = args
        self.surrogate_model = None
        # NOTE(review): the surrogate is hard-coded, so every branch below
        # except "resnet50" is currently unreachable.
        self.surrogate_model_name = 'resnet50'

        if 'cl' in self.surrogate_model_name:
            # Use the first-task checkpoint of the continual-learning model
            # as surrogate and collect its mid-layer features.
            s_model = factory.get_model(self.args["model_name"], self.args)
            s_model.incremental_train(self.data_manager)
            s_model._network.load_state_dict(
                torch.load(self.ckpt_paths[0], map_location=self.device)['model_state_dict'])
            s_model._network.to(self.device)
            s_model._network.eval()
            self.surrogate_model = s_model._network
            del s_model
            torch.cuda.empty_cache()
            self.feat_collecter = []
            self.layer = midlayer_dict[self.surrogate_model_name]
            self.feat_collecter_handler, self.feat_collecter = register_collecter_cl(
                self.surrogate_model,
                self.layer,
                self.feat_collecter,
                self.args["model_name"])
            self.feature_channel = 2048
        elif self.surrogate_model_name == "resnet50":
            original_model = torchvision.models.resnet50(pretrained=True)
            self.feature_extraction = CustomResnet50(original_model)
            self.feature_extraction = self.feature_extraction.eval().to(self.device)
            self.feature_channel = 2048
        elif self.surrogate_model_name == "densenet121":
            original_model = torchvision.models.densenet121(pretrained=True)
            self.feature_extraction = CustomDenseNet121(original_model)
            self.feature_extraction = self.feature_extraction.eval().to(self.device)
            self.feature_channel = 1024
        elif self.surrogate_model_name == "vgg19bn":
            # NOTE(review): this loads ``vgg19`` (not a batch-norm variant),
            # keeps the model only in a local variable and never sets
            # ``self.feature_extraction`` -- confirm before enabling.
            vgg19bn = torchvision.models.vgg19(pretrained=True).eval().to(self.device)
            self.feature_channel = 4096
            global hook_output
            hook_output = None

            def hook(module, input, output):
                # Stores the latest output of classifier[5] in a module global.
                global hook_output
                hook_output = output

            vgg19bn.classifier[5].register_forward_hook(hook)

        self.adv_generator = Generator(num_target=10, ch=32, ch_mult=[1, 2, 3, 4], num_res_blocks=1,
                                       feature_channel_num=self.feature_channel).to(device)
        self.lr = 0.001
        # NOTE(review): ``betas`` is stored but not passed to the optimizer
        # below, so AdamW runs with its default betas -- confirm intent.
        self.betas = (0.5, 0.999)
        self.num_epoch = 100
        self.optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, self.adv_generator.parameters()),
            lr=self.lr, weight_decay=5e-5)
        # Smoothing applied to the generator output before projection.
        self.kernel = get_gaussian_kernel(kernel_size=3, pad=2, sigma=1).to(device)

        # Perturbation budget; images are clamped to [0, 1] elsewhere.
        self.eps = 32 / 255
        self.ran_best = 'random'
        # Used as the checkpoint / report file-name stem.
        self.prefix = f'{int(self.eps*255)}'
        self.save_path = os.path.join(self.args['logs_eval_name'], f'target{str(self.target_class)}')
        os.makedirs(self.save_path, exist_ok=True)
        self.eval_batch_size = 128
        self.plot_gradcam = False

    def _gaker_features(self, images):
        """Return surrogate features of ``images`` flattened to ``(batch, -1)``.

        ``flatten(start_dim=1)`` is used instead of ``squeeze()`` so the batch
        dimension survives when the batch holds a single sample; the
        ``dim=1`` cosine similarity used in training needs 2-D features.

        Raises:
            NotImplementedError: If no ``feature_extraction`` module was built
                for the configured surrogate.
        """
        extractor = getattr(self, 'feature_extraction', None)
        if extractor is None:
            raise NotImplementedError(
                f"no feature extractor available for surrogate '{self.surrogate_model_name}'")
        return extractor(self.norm(images)).flatten(start_dim=1)

    def _gaker_perturb(self, images, mix):
        """Generate, smooth, eps-project and clamp adversarial images."""
        perturbated_imgs = self.kernel(self.adv_generator(images, mix=mix))
        # Keep every pixel within +-eps of the clean image, then in [0, 1].
        adv = torch.min(torch.max(perturbated_imgs, images - self.eps), images + self.eps)
        return torch.clamp(adv, 0.0, 1.0)

    def _collect_target_samples(self):
        """Gather all training images/labels whose label equals the target class.

        Returns:
            Tuple ``(target_images, target_labels)`` on ``self.device``.

        Raises:
            ValueError: If the training loader yields no target-class sample.
        """
        loaders = get_dataloader(self.data_manager, batch_size=self.batch_size,
                                 start_class=0, end_class=10,
                                 train=True, shuffle=True, num_workers=0)
        image_chunks = []
        label_chunks = []
        for _, image_batch, label_batch in loaders:
            is_target = label_batch == self.target_class
            image_chunks.append(image_batch[is_target])
            label_chunks.append(label_batch[is_target])
        del loaders

        if not image_chunks or sum(len(chunk) for chunk in image_chunks) == 0:
            raise ValueError(
                f"no training images found for target class {self.target_class}")
        target_images = torch.cat(image_chunks, dim=0).to(self.device)
        target_labels = torch.cat(label_chunks, dim=0).to(self.device)
        return target_images, target_labels

    def train_generator(self):
        """Load the generator checkpoint if present, otherwise train and save it.

        The checkpoint path is ``<save_path>/<prefix>.pth``.  Training
        minimises ``1 - cos`` between target features and the features of
        (a) the adversarial image and (b) the perturbation, the latter
        weighted by 0.5.  The generator is left in eval mode on return.

        NOTE(review): training iterates ``self.loader``, which this class only
        assigns in ``run_test``; presumably the base class provides a training
        loader -- confirm.
        """
        self.file_path = os.path.join(self.save_path, f'{self.prefix}.pth')
        if os.path.exists(self.file_path):
            self.adv_generator.load_state_dict(torch.load(self.file_path, map_location=self.device))
            self.adv_generator.eval()
            return

        target_images, target_labels = self._collect_target_samples()

        self.adv_generator.train()
        for epoch in range(1, self.num_epoch + 1):
            loader_tqdm = tqdm(self.loader, total=len(self.loader), desc=f'Epoch {epoch}')
            for _, x, y in loader_tqdm:
                non_target = y != self.target_class
                x_f = x[non_target].to(self.device)
                y_f = y[non_target].to(self.device)
                del x, y

                # Pair sources with targets one-to-one.  The pool itself is
                # sliced into locals so a small batch cannot shrink it for
                # the remaining iterations.
                pair_count = min(len(x_f), len(target_images))
                if pair_count == 0:
                    # Nothing to attack in this batch; avoids a 0/0 loss.
                    continue
                x_f = x_f[:pair_count]
                y_f = y_f[:pair_count]
                batch_targets = target_images[:pair_count]
                batch_target_labels = target_labels[:pair_count]

                # Target features are constants for the optimisation.
                with torch.no_grad():
                    target_feature = self._gaker_features(batch_targets)

                mask = torch.ne(y_f, batch_target_labels).long().to(self.device)

                adv = self._gaker_perturb(x_f, target_feature)
                adv_feature = self._gaker_features(adv)
                loss = mask * (1 - torch.cosine_similarity(adv_feature, target_feature, dim=1))

                # The perturbation alone should also resemble the target.
                noise = adv - x_f
                noise_feature = self._gaker_features(noise)
                loss_noise = mask * (1 - torch.cosine_similarity(noise_feature, target_feature, dim=1)) * 0.5

                loss = (loss + loss_noise).sum() / x_f.shape[0]

                # Clear gradients from the previous step before backprop.
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                del x_f, y_f, noise, adv, adv_feature, noise_feature, loss, loss_noise, mask
                torch.cuda.empty_cache()

        torch.save(self.adv_generator.state_dict(), self.file_path)
        self.adv_generator.eval()

    def run_test(self):
        """Evaluate the generator on the first batch of the test loader.

        Replaces ``self.loader`` with a test loader, splits the first batch
        into target-class and non-target samples and forwards them, with up
        to 20 target images, to ``attacks``.
        """
        self.adv_generator.eval()
        self.loader = get_dataloader(self.data_manager, batch_size=self.eval_batch_size,
                                     start_class=0, end_class=10,
                                     train=False, shuffle=True, num_workers=0)
        progress = tqdm(self.loader, total=len(self.loader),
                        desc=f'Loading Data with Batch Size of {self.eval_batch_size}) :')
        for i, (_, imgs, labels) in enumerate(progress):
            if i > 0:
                # Only the first batch is evaluated.
                break

            imgs, labels = ep.astensors(*(imgs.to(self.device), labels.to(self.device)))

            is_target = labels == self.target_class
            target_imgs = imgs[is_target]
            target_labels = labels[is_target]
            imgs_f = imgs[labels != self.target_class]
            labels_f = labels[labels != self.target_class]
            labels_t_f = ep.full_like(labels_f, fill_value=self.target_class)

            self.attacks(i, imgs_f, labels_f, labels_t_f, target_imgs[:20], target_labels[:20])

    def attacks(self, i_batch, imgs, labels, labels_t, target_imgs=None, target_labels=None):
        """Measure attack success rate (ASR) on each task checkpoint.

        For each of the 10 tasks the matching checkpoint is loaded, one
        adversarial batch is generated per target image, and the fraction
        flagged adversarial is recorded.  Per-task means are plotted and
        written to ``<save_path>/batch<i_batch>_<prefix>.xlsx``.

        Args:
            i_batch: Batch index, used only in output file names.
            imgs: Source images (eagerpy tensor; ``.raw`` is read).
            labels: True labels of ``imgs``.
            labels_t: Target labels used for the targeted criterion.
            target_imgs: Target-class images (eagerpy tensor).  When ``None``
                or empty the evaluation is skipped with a warning.
            target_labels: Unused; kept for interface compatibility.
        """
        if target_imgs is None or len(target_imgs) == 0 or len(imgs) == 0:
            logging.warning(
                "Skipping attack evaluation for batch %s: no source or target images.", i_batch)
            return

        num_tasks = 10
        asr_matrix = np.ones((num_tasks, len(target_imgs)))
        self.model = factory.get_model(self.args["model_name"], self.args)

        # The criterion does not depend on the task, so build it once.
        if self.target_class is None:
            criterion = fb.criteria.Misclassification(labels)
        else:
            criterion = fb.criteria.TargetedMisclassification(labels_t)
        criterion = get_criterion(criterion)

        for task in range(num_tasks):
            logging.info("***** Starting attack on task [{}]. *****".format(task))
            self.model.incremental_train(self.data_manager)
            self.model._network.load_state_dict(
                torch.load(self.ckpt_paths[task], map_location=self.device)['model_state_dict'])
            self.model._network.to(self.device)
            self.model._network.eval()

            current_model = PyTorchModel(self.model._network, bounds=(0, 1), preprocessing=self.preprocessing)
            verify_input_bounds(imgs, current_model)
            is_adversarial = get_is_adversarial(criterion, current_model)

            logging.info("Eval attack on each target images.")
            for i, target_image in enumerate(target_imgs):
                # Evaluation only: no autograd graph is needed for generation.
                with torch.no_grad():
                    advs = self.adv_gen(imgs.raw, target_image.raw.repeat(len(imgs), 1, 1, 1))
                # NOTE(review): ``[0]`` assumes ``is_adversarial`` returns a
                # sequence whose first item is the per-sample flag tensor --
                # confirm against the project's foolbox version.
                is_adv = is_adversarial(ep.astensor(advs))[0]
                asr_matrix[task, i] = (is_adv.bool().sum().raw.item() / len(imgs))

                if self.plot_gradcam:
                    save_grad_cam(self.args, torch.clip(advs.detach(), 0, 1), labels_t.raw,
                                  self.model._network, self.save_path + "/GradCam" + f"targetimg{i}",
                                  prefix=f'task{task}',
                                  layer_name='stage_3', save_num=100, save_raw=True)

                del advs, is_adv, target_image
                torch.cuda.empty_cache()

            del current_model, is_adversarial
            torch.cuda.empty_cache()

            self.model.after_task()

        # Average over target images: one ASR value per task.
        asr_matrix = np.mean(asr_matrix, axis=1, keepdims=True)
        prefix = f'batch{i_batch}_{self.prefix}'
        plot_asr_per_target(asr_matrix, self.save_path, prefix, self.args)
        df = pd.DataFrame(asr_matrix, columns=['ASR'])
        df.to_excel(os.path.join(self.save_path, f"{prefix}.xlsx"), index=False)

        del asr_matrix
        torch.cuda.empty_cache()

    def adv_gen(self, images, target_images):
        """Return adversarial versions of ``images`` steered by ``target_images``.

        Args:
            images: Clean input batch.
            target_images: Batch of target images whose surrogate features
                condition the generator.

        Returns:
            Adversarial batch within ``eps`` of ``images`` and clamped to [0, 1].

        Raises:
            NotImplementedError: If ``self.ran_best`` is ``'best'``.
            ValueError: If ``self.ran_best`` is neither ``'random'`` nor ``'best'``.
        """
        if self.ran_best == 'random':
            output_to_mix = self._gaker_features(target_images)
        elif self.ran_best == 'best':
            raise NotImplementedError('not used')
        else:
            raise ValueError('please choose random or best')

        return self._gaker_perturb(images, output_to_mix)

    def __call__(
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        *,
        epsilons: Union[Sequence[Union[float, None]], float, None],
        **kwargs: Any,
    ) -> Union[Tuple[List[T], List[T], T], Tuple[T, T, T]]:
        """Placeholder with an empty body; returns ``None``.

        Presumably present only to satisfy the foolbox attack interface --
        TODO confirm.
        """
        ...

    def repeat(self, times: int) -> "Gaker":
        """Placeholder with an empty body; returns ``None``."""
        ...