# SAE/attacks/BASEAttack.py
# Uploaded by Ttius ("Upload 192 files"), commit 998bb30 (verified).
import torch
from foolbox.attacks.base import *
from foolbox.attacks.gradient_descent_base import *
from attacks.attack_config import SustainableAttack
from tqdm import tqdm
import logging
import foolbox as fb
from foolbox import PyTorchModel, accuracy
import numpy as np
from utils import factory
from utils.data_manager import get_dataloader
import os
from utils.plot import plot_asr_per_target, save_grad_cam
import pandas as pd
class BASEAttack(SustainableAttack):
    """Evaluate a standard foolbox attack across a sequence of task checkpoints.

    Adversarial examples are crafted once against the task-0 checkpoint and
    then re-evaluated against every later task's checkpoint. Per task, the
    clean accuracy and the attack success rate (ASR) are recorded, plotted and
    written to an Excel file.
    """

    # Number of task checkpoints (``self.ckpt_paths[0 .. NUM_TASKS-1]``)
    # evaluated by ``attacks``.
    NUM_TASKS = 10

    # Zero-argument factories rather than instances: only the selected entry is
    # looked up on ``fb.attacks`` and constructed, so an attack class missing
    # from the installed foolbox only matters if it is actually requested.
    _ATTACK_FACTORIES = {
        'L2FGM': lambda: fb.attacks.FGM(),
        'FGSM': lambda: fb.attacks.FGSM(),
        'MIFGSM': lambda: fb.attacks.MIFGSM(momentum=0.9, steps=10),
        'L1PGD': lambda: fb.attacks.L1PGD(steps=10),
        'L2PGD': lambda: fb.attacks.L2PGD(steps=10),
        'LinfPGD': lambda: fb.attacks.LinfPGD(steps=10),
        'L2DeepFool': lambda: fb.attacks.L2DeepFoolAttack(steps=10),
        'LinfDeepFool': lambda: fb.attacks.LinfDeepFoolAttack(steps=10),
        'BoundaryAttack': lambda: fb.attacks.BoundaryAttack(steps=10),
        'CarliniWagnerL2': lambda: fb.attacks.L2CarliniWagnerAttack(steps=10),
        # NOTE(review): despite the key, this is a repeated *uniform*-noise
        # attack, not a Gaussian one -- confirm this mapping is intended.
        'GaussianNoise': lambda: fb.attacks.LinfRepeatedAdditiveUniformNoiseAttack(),
        'UniformNoise': lambda: fb.attacks.LinfAdditiveUniformNoiseAttack(),
    }

    def __init__(self, args, device='cuda'):
        """Select the attack named by ``args['attack_method']`` and prepare output dirs.

        Args:
            args: experiment configuration mapping; this method reads
                ``'attack_method'`` and ``'eval'`` from it, and
                ``'logs_eval_name'`` via ``self.args`` (presumably stored by
                ``SustainableAttack.__init__`` -- TODO confirm).
            device: device forwarded to ``SustainableAttack.__init__``.

        Raises:
            ValueError: if ``args['attack_method']`` is not a supported name.
        """
        super().__init__(args, device)
        attack_name = args['attack_method']
        try:
            make_attack = self._ATTACK_FACTORIES[attack_name]
        except KeyError:
            raise ValueError(f"Unknown attack method: {attack_name}") from None
        self.attack = make_attack()

        self.test_mode = 'BaseAttacks'
        # Perturbation budget on a 0-255 scale; it is divided by 255 when the
        # attack runs, because the wrapped model uses input bounds (0, 1).
        self.epsilon = 16
        self.eval = args['eval']
        self.prefix = f'eps{self.epsilon}'
        self.save_path = os.path.join(self.args['logs_eval_name'])
        # self.target_class is not set here -- presumably by the base class.
        self.out_path = os.path.join(self.args['logs_eval_name'], f'gradcam_t{self.target_class}')
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.out_path, exist_ok=True)
        self.plot_gradcam = True

    def run_test(self, max_batches=1):
        """Load test batches and attack the non-target-class images in each.

        Each batch is split into images whose label equals
        ``self.target_class`` and all the others. The others are attacked with
        ``self.target_class`` as the target label. This requires
        ``self.target_class`` to be a concrete class index.

        Args:
            max_batches: how many test batches to attack. The default of 1
                reproduces the original single-batch behaviour.
        """
        from itertools import islice  # only needed here

        eval_batch_size = 1024
        self.loader = get_dataloader(self.data_manager, batch_size=eval_batch_size,
                                     start_class=0, end_class=10,
                                     train=False, shuffle=False, num_workers=0)
        # islice stops after max_batches, so no extra batch is loaded and then
        # discarded.
        batches = islice(self.loader, max_batches)
        progress = tqdm(batches, total=min(len(self.loader), max_batches),
                        desc=f'Loading Data with Batch Size of {eval_batch_size}:')
        for i, (_, imgs, labels) in enumerate(progress):
            imgs, labels = ep.astensors(imgs.to(self.device), labels.to(self.device))
            is_target = labels == self.target_class
            not_target = labels != self.target_class
            target_imgs, target_labels = imgs[is_target], labels[is_target]
            imgs_f, labels_f = imgs[not_target], labels[not_target]
            # Every non-target image is pushed towards the same target class.
            labels_t_f = ep.full_like(labels_f, fill_value=self.target_class)
            self.attacks(i, imgs_f, labels_f, labels_t_f, target_imgs, target_labels)

    def attacks(self, i_batch, imgs, labels, labels_t, target_imgs=None, target_labels=None):
        """Craft adversarial examples on task 0 and evaluate them on every task.

        For each task checkpoint, this method loads the checkpoint, records
        the clean accuracy on ``imgs``, and records an ASR:

        * task 0: the fraction of the images correctly classified at task 0
          for which the attack's criterion succeeded;
        * later tasks: the task-0 adversarial examples are re-classified. A
          targeted run counts predictions equal to ``labels_t``. An
          untargeted run counts predictions that differ from ``labels``.

        Results are plotted and saved to ``<save_path>/<prefix>.xlsx``.

        Args:
            i_batch: batch index; used only in output file names.
            imgs: image tensor to attack (assumed to lie in [0, 1], matching
                the model bounds -- TODO confirm).
            labels: true labels of ``imgs``.
            labels_t: target labels for the targeted criterion.
            target_imgs: target-class images from ``run_test``; accepted so
                that call works, currently unused.
            target_labels: labels of ``target_imgs``; currently unused.

        Raises:
            ValueError: if the task-0 model classifies none of ``imgs``
                correctly, leaving nothing to attack.
        """
        targeted = self.target_class is not None
        clean_acc_matrix = []
        asr_matrix = np.ones((self.NUM_TASKS, 1))
        adv_clip = None
        self.model = factory.get_model(self.args["model_name"], self.args)
        for task in range(self.NUM_TASKS):
            logging.info("***** Starting attack on task [%s]. *****", task)
            self.model.incremental_train(self.data_manager)
            self.model._network.load_state_dict(
                torch.load(self.ckpt_paths[task], map_location=self.device)['model_state_dict'])
            self.model._network.to(self.device)
            self.model._network.eval()
            current_model = PyTorchModel(self.model._network, bounds=(0, 1),
                                         preprocessing=self.preprocessing)

            acc, acc_bool = accuracy(current_model, imgs, labels)[:2]
            clean_acc_matrix.append(acc)
            if task == 0:
                # Keep only images the task-0 model gets right; all later
                # tasks are evaluated on this same subset.
                imgs = imgs[acc_bool]
                labels = labels[acc_bool]
                labels_t = labels_t[acc_bool]
                if len(imgs) == 0:
                    raise ValueError(
                        f"batch {i_batch}: the task-0 model misclassifies every "
                        "input image; there is nothing to attack.")
            verify_input_bounds(imgs, current_model)

            criterion = (fb.criteria.TargetedMisclassification(labels_t) if targeted
                         else fb.criteria.Misclassification(labels))
            if task == 0:
                adv, adv_clip, asr_bool = self.attack(
                    current_model, imgs, criterion=criterion, epsilons=self.epsilon / 255)
                asr_matrix[task] = asr_bool.sum().raw.item() / len(imgs)
                del adv, asr_bool
            elif targeted:
                # Fraction of adversarial examples now predicted as the target.
                asr_matrix[task] = accuracy(current_model, adv_clip, labels_t)[0]
            else:
                # Fraction of adversarial examples no longer predicted correctly.
                asr_matrix[task] = 1.0 - accuracy(current_model, adv_clip, labels)[0]

            if self.plot_gradcam:
                save_grad_cam(self.args, torch.clip(adv_clip.raw, 0, 1), labels_t.raw,
                              self.model._network, os.path.join(self.out_path, "GradCam"),
                              prefix=f'task{task}', layer_name='stage_3',
                              save_num=100, save_raw=True)
            del criterion, current_model
            torch.cuda.empty_cache()
            self.model.after_task()

        # Save per-task ASR (plot + spreadsheet) for this batch / target class.
        prefix = f'batch{i_batch}_{self.prefix}_tc{self.target_class}'
        plot_asr_per_target(asr_matrix, self.save_path, prefix, self.args, clean_acc_matrix)
        df = pd.DataFrame(asr_matrix, columns=['ASR'])
        df.to_excel(os.path.join(self.save_path, f"{prefix}.xlsx"), index=False)
        # Drop tensor references before emptying the CUDA cache so the memory
        # can actually be released.
        del asr_matrix, imgs, labels, labels_t, adv_clip
        torch.cuda.empty_cache()

    def __call__(
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        *,
        epsilons: Sequence[Union[float, None]],
        **kwargs: Any,
    ) -> Tuple[List[T], List[T], T]:
        """Stub matching the signature of foolbox's ``Attack.__call__``.

        The body is empty, so calling it returns ``None``. The real attack is
        ``self.attack``. NOTE(review): presumably kept to satisfy an abstract
        method on a base class -- confirm.
        """
        ...

    def repeat(self, times: int) -> "BASEAttack":
        """Stub matching the signature of foolbox's ``Attack.repeat``; returns ``None``."""
        ...