File size: 6,200 Bytes
998bb30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import torch
from foolbox.attacks.base import *
from foolbox.attacks.gradient_descent_base import *
from attacks.attack_config import SustainableAttack
from tqdm import tqdm
import logging
import foolbox as fb
from foolbox import PyTorchModel, accuracy
import numpy as np
from utils import factory
from utils.data_manager import get_dataloader
import os
from utils.plot import plot_asr_per_target, save_grad_cam
import pandas as pd
class BASEAttack(SustainableAttack):
    """Evaluate a stock foolbox attack across incremental-task checkpoints.

    Adversarial examples are crafted once against the model restored from the
    task-0 checkpoint and then replayed against the model restored for every
    later task, producing one attack-success-rate (ASR) value per task.

    The attack is chosen by ``args['attack_method']``; the accepted names are
    the keys of ``_ATTACK_FACTORIES``. Any other name raises ``ValueError``.
    """

    # Number of task checkpoints evaluated (indexes into self.ckpt_paths).
    _NUM_TASKS = 10

    # attack name -> zero-argument factory, so only the requested attack is built.
    # NOTE(review): 'GaussianNoise' is mapped to a *uniform* Linf noise attack
    # (foolbox also offers L2AdditiveGaussianNoiseAttack); confirm this is intended.
    _ATTACK_FACTORIES = {
        'L2FGM': lambda: fb.attacks.FGM(),
        'FGSM': lambda: fb.attacks.FGSM(),
        'MIFGSM': lambda: fb.attacks.MIFGSM(momentum=0.9, steps=10),
        'L1PGD': lambda: fb.attacks.L1PGD(steps=10),
        'L2PGD': lambda: fb.attacks.L2PGD(steps=10),
        'LinfPGD': lambda: fb.attacks.LinfPGD(steps=10),
        'L2DeepFool': lambda: fb.attacks.L2DeepFoolAttack(steps=10),
        'LinfDeepFool': lambda: fb.attacks.LinfDeepFoolAttack(steps=10),
        'BoundaryAttack': lambda: fb.attacks.BoundaryAttack(steps=10),
        'CarliniWagnerL2': lambda: fb.attacks.L2CarliniWagnerAttack(steps=10),
        'GaussianNoise': lambda: fb.attacks.LinfRepeatedAdditiveUniformNoiseAttack(),
        'UniformNoise': lambda: fb.attacks.LinfAdditiveUniformNoiseAttack(),
    }

    def __init__(self, args, device='cuda'):
        """Build the attack and prepare output directories.

        Args:
            args: configuration mapping; reads 'attack_method', 'eval' and
                'logs_eval_name' here (the parent class may read more).
            device: torch device string forwarded to the parent class.

        Raises:
            ValueError: if ``args['attack_method']`` is not a known attack name.
        """
        super().__init__(args, device)
        attack_name = args['attack_method']
        try:
            make_attack = self._ATTACK_FACTORIES[attack_name]
        except KeyError:
            raise ValueError(f"Unknown attack method: {attack_name}") from None
        self.attack = make_attack()
        self.test_mode = 'BaseAttacks'
        # Perturbation budget in 0-255 pixel units; divided by 255 when the
        # attack runs because the model's input bounds are (0, 1).
        self.epsilon = 16
        self.eval = args['eval']
        self.prefix = f'eps{self.epsilon}'
        self.save_path = os.fspath(self.args['logs_eval_name'])
        self.out_path = os.path.join(self.save_path, f'gradcam_t{self.target_class}')
        os.makedirs(self.out_path, exist_ok=True)
        self.plot_gradcam = True

    def run_test(self):
        """Attack the first evaluation batch towards ``self.target_class``.

        Only the first batch yielded by the loader is used. Samples whose
        label already equals ``self.target_class`` are separated out; the rest
        are attacked with target label ``self.target_class``.
        NOTE(review): assumes ``self.target_class`` is not None.
        """
        eval_batch_size = 1024
        self.loader = get_dataloader(self.data_manager, batch_size=eval_batch_size,
                                     start_class=0, end_class=10,
                                     train=False, shuffle=False, num_workers=0)
        progress = tqdm(self.loader, total=len(self.loader),
                        desc=f'Loading Data with Batch Size of {eval_batch_size}:')
        for i, (_, imgs, labels) in enumerate(progress):
            imgs, labels = ep.astensors(*(imgs.to(self.device), labels.to(self.device)))
            is_target = labels == self.target_class
            target_imgs = imgs[is_target]
            target_labels = labels[is_target]
            imgs_f = imgs[labels != self.target_class]
            labels_f = labels[labels != self.target_class]
            labels_t_f = ep.full_like(labels_f, fill_value=self.target_class)
            self.attacks(i, imgs_f, labels_f, labels_t_f, target_imgs, target_labels)
            # Only the first batch is evaluated; breaking here (rather than at
            # the top of the next iteration) avoids loading a batch we discard.
            break

    def attacks(self, i_batch, imgs, labels, labels_t, target_imgs=None, target_labels=None):
        """Craft adversarial examples at task 0 and replay them on later tasks.

        Args:
            i_batch: batch index; used only in output file names.
            imgs: eagerpy tensor of clean inputs.
            labels: eagerpy tensor of true labels for ``imgs``.
            labels_t: eagerpy tensor of target labels, aligned with ``labels``.
            target_imgs, target_labels: samples of the target class; accepted
                so callers can pass them, but currently unused.

        Side effects:
            Writes an ASR plot and ``<prefix>.xlsx`` under ``self.save_path``,
            and Grad-CAM outputs under ``self.out_path`` when
            ``self.plot_gradcam`` is set. Replaces ``self.model``.
        """
        clean_acc_matrix = []
        asr_matrix = np.ones((self._NUM_TASKS, 1))
        self.model = factory.get_model(self.args["model_name"], self.args)
        adv_clip = None
        for task in range(self._NUM_TASKS):
            logging.info("***** Starting attack on task [%s]. *****", task)
            self.model.incremental_train(self.data_manager)
            checkpoint = torch.load(self.ckpt_paths[task], map_location=self.device)
            self.model._network.load_state_dict(checkpoint['model_state_dict'])
            del checkpoint  # drop the extra copy of the weights before attacking
            self.model._network.to(self.device)
            self.model._network.eval()
            current_model = PyTorchModel(self.model._network, bounds=(0, 1),
                                         preprocessing=self.preprocessing)
            acc, acc_bool = accuracy(current_model, imgs, labels)[:2]
            clean_acc_matrix.append(acc)
            if task == 0:
                # Keep only samples the task-0 model gets right, then craft the
                # adversarial examples once; later tasks reuse them.
                imgs = imgs[acc_bool]
                labels = labels[acc_bool]
                labels_t = labels_t[acc_bool]
                verify_input_bounds(imgs, current_model)
                if self.target_class is None:
                    criterion = fb.criteria.Misclassification(labels)
                else:
                    criterion = fb.criteria.TargetedMisclassification(labels_t)
                _, adv_clip, asr_bool = self.attack(current_model, imgs, criterion=criterion,
                                                    epsilons=self.epsilon / 255)
                del criterion
                n_samples = len(imgs)
                # ASR is undefined when no sample survived the accuracy filter.
                asr_matrix[task] = (asr_bool.sum().raw.item() / n_samples
                                    if n_samples else float('nan'))
            else:
                asr_matrix[task] = accuracy(current_model, adv_clip, labels_t)[0]
            if self.plot_gradcam:
                save_grad_cam(self.args, torch.clip(adv_clip.raw, 0, 1), labels_t.raw,
                              self.model._network, os.path.join(self.out_path, "GradCam"),
                              prefix=f'task{task}', layer_name='stage_3',
                              save_num=100, save_raw=True)
            del current_model
            torch.cuda.empty_cache()
            self.model.after_task()
        # Save per-task ASR for this target class (plot + spreadsheet).
        prefix = f'batch{i_batch}_{self.prefix}_tc{self.target_class}'
        plot_asr_per_target(asr_matrix, self.save_path, prefix, self.args, clean_acc_matrix)
        df = pd.DataFrame(asr_matrix, columns=['ASR'])
        df.to_excel(os.path.join(self.save_path, f"{prefix}.xlsx"), index=False)
        del asr_matrix, imgs, labels, labels_t, adv_clip
        torch.cuda.empty_cache()

    def __call__(
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        *,
        epsilons: Sequence[Union[float, None]],
        **kwargs: Any,
    ) -> Tuple[List[T], List[T], T]:
        """Placeholder that appears to mirror foolbox's Attack call signature.

        The body is empty, so calling it returns None; the real attack is
        ``self.attack``.
        """
        ...

    def repeat(self, times: int) -> "BASEAttack":
        """Placeholder that appears to mirror foolbox's ``Attack.repeat``; returns None."""
        ...