import os.path
import torch
from foolbox.attacks.base import *
from foolbox.attacks.gradient_descent_base import *
from tqdm import tqdm
import pandas as pd
from attacks.attack_config import SustainableAttack
from utils.plot import save_grad_cam, plot_asr_per_target
import logging
from foolbox import PyTorchModel, accuracy
import numpy as np
from utils import factory
from utils.data_manager import DataManager, get_dataloader
from attacks.UnivIntruder.att import train
from attacks.UnivIntruder.loss import UniversalPerturbation
class UnivIntruder(SustainableAttack):
    """Evaluate a trained universal perturbation against per-task checkpoints.

    ``train_adv`` delegates training to ``attacks.UnivIntruder.att.train``;
    ``run_test`` loads the checkpoint with the highest step number, wraps it in
    a ``UniversalPerturbation`` and hands the first evaluation batch to
    ``attacks``, which measures clean accuracy and targeted success on every
    task checkpoint.

    NOTE(review): ``self.data_manager``, ``self.ckpt_paths``,
    ``self.preprocessing`` and ``self.to_all`` are not defined in this class;
    they are presumably provided by ``SustainableAttack`` -- confirm.
    """

    def __init__(self, args, device='cuda'):
        """Store the configuration and create the output directory.

        Args:
            args: Mapping read for the keys ``'target_class'``, ``'dataset'``,
                ``'logs_eval_name'`` and ``'eval'`` here, and ``'model_name'``
                in ``attacks``.
            device: Device identifier used for tensors and checkpoint loading.
        """
        super().__init__(args, device)
        self.device = device
        self.args = args
        self.surrogate_model = None  # surrogate__model.to(device).eval()
        self.target_class = args['target_class']
        # Handed to UniversalPerturbation as ``epsilon / 255`` in run_test.
        self.epsilon = 32
        self.image_size = 32 if self.args['dataset'] == 'cifar100' else 224
        # The attribute name keeps its historical spelling so that any external
        # reader of ``eval_batch_szie`` continues to work.
        self.eval_batch_szie = 128
        self.adv_name = f'adv_eps{self.epsilon}_tc{self.target_class}'
        self.out_path = f'{self.args["logs_eval_name"]}/{self.adv_name}'
        os.makedirs(self.out_path, exist_ok=True)
        self.eval = args['eval']
        self.plot_gradcam = True
        self.ckpt_num = None  # None

    def train_adv(self):
        """Train the perturbation unless the run is evaluation-only."""
        if not self.eval:
            train(self.args)

    def run_test(self):
        """Load the latest perturbation checkpoint and attack the first batch.

        Only the first batch yielded by the loader is evaluated; samples whose
        label already equals ``self.target_class`` are dropped beforehand.

        Raises:
            FileNotFoundError: If the checkpoint directory is missing or holds
                no step-numbered ``.pth`` file.
        """
        ckpt_dir = f'{self.args["logs_eval_name"]}/{self.adv_name}/ckpts'
        pth_name = self.get_max_step_filename(ckpt_dir)
        self.ckpt = f'{ckpt_dir}/{pth_name}.pth'
        self.prefix = f'{self.adv_name}_{pth_name.split("_")[-1]}'

        # map_location mirrors the model-checkpoint loading in ``attacks`` so a
        # perturbation saved on another device can still be restored here.
        perturbation_init = torch.load(self.ckpt, map_location=self.device)
        self.adv = UniversalPerturbation((3, self.image_size, self.image_size), self.epsilon / 255,
                                         initialization=perturbation_init, device=self.device)
        self.adv.eval()

        # Load batch data
        self.loader = get_dataloader(self.data_manager, batch_size=self.eval_batch_szie,
                                     start_class=0, end_class=10,
                                     train=False, shuffle=False, num_workers=0)
        # The description reports the batch size the loader was built with.
        progress = tqdm(self.loader, total=len(self.loader),
                        desc=f'Loading Data with Batch Size of {self.eval_batch_szie}) :')
        for i, (_, imgs, labels) in enumerate(progress):
            if i > 0:
                # Evaluation is deliberately limited to the first batch.
                break
            imgs, labels = ep.astensors(*(imgs.to(self.device), labels.to(self.device)))
            # Keep only samples that are not already of the target class.
            keep = labels != self.target_class
            imgs_f = imgs[keep]
            labels_f = labels[keep]
            labels_t_f = ep.full_like(labels_f, fill_value=self.target_class)
            self.attacks(i, imgs_f, labels_f, labels_t_f)

    def attacks(self, i_batch, imgs, labels, labels_t):
        """Apply the perturbation to one batch for every task checkpoint.

        For each task the network weights are loaded from
        ``self.ckpt_paths[task]``, clean accuracy and the accuracy of the
        perturbed inputs w.r.t. the target labels (reported as ASR) are
        computed, and Grad-CAM images are optionally saved. After the loop the
        per-task ASR values are plotted and written to an ``.xlsx`` file in the
        evaluation directory.

        Args:
            i_batch: Index of the batch; used only in output file names.
            imgs: Batch of input images (an eagerpy tensor when called from
                ``run_test``).
            labels: Ground-truth labels matching ``imgs``.
            labels_t: Target labels matching ``imgs``.
        """
        if self.args["model_name"] != 'finetune':
            imgs_f, labels_f = self.to_all(imgs, labels)
            # Keep the target labels aligned with whatever to_all returned.
            labels_t = labels_t[:len(imgs_f)]
        else:
            imgs_f, labels_f = imgs, labels

        nb_tasks = self.data_manager.nb_tasks
        clean_acc_matrix = []
        asr_matrix = np.ones((nb_tasks, 1))
        self.model = factory.get_model(self.args["model_name"], self.args)
        eval_path = os.path.join(self.args["logs_eval_name"], self.adv_name)

        for task in range(nb_tasks):
            logging.info("***** Starting attack on task [{}]. *****".format(task))
            self.model.incremental_train(self.data_manager)
            network = self.model._network
            network.load_state_dict(
                torch.load(self.ckpt_paths[task], map_location=self.device)['model_state_dict'])
            network.to(self.device)
            network.eval()

            # Run the attack on each target image
            current_model = PyTorchModel(network, bounds=(0, 1), preprocessing=self.preprocessing)
            verify_input_bounds(imgs_f, current_model)

            # Evaluate the model performance with clean data
            acc = accuracy(current_model, imgs, labels)[0]
            logging.info("Clean accuracy on task {}: {}%".format(task, acc * 100))
            clean_acc_matrix.append(acc)

            advs = self.adv(imgs_f.raw)
            # Accuracy against the target labels is recorded as the ASR.
            asr = accuracy(current_model, ep.astensor(advs), labels_t)[0]
            asr_matrix[task] = asr

            if self.plot_gradcam:
                save_grad_cam(self.args, torch.clip(advs.detach(), 0, 1), labels_t.raw,
                              network, self.out_path + "/GradCam", prefix=f'task{task}',
                              layer_name='stage_3', save_num=100, save_raw=True)

            # Drop per-task references before asking CUDA to release its cache.
            del advs, current_model, network
            torch.cuda.empty_cache()
            self.model.after_task()

        # Save all target images info: average ASR
        prefix = f'batch{i_batch}_{self.prefix}'
        plot_asr_per_target(asr_matrix, eval_path, prefix, self.args, clean_acc_matrix)
        df = pd.DataFrame(asr_matrix, columns=['ASR'])
        df.to_excel(os.path.join(eval_path, f"{prefix}.xlsx"), index=False)

        del asr_matrix, imgs, labels, labels_t, imgs_f, labels_f
        torch.cuda.empty_cache()

    def get_max_step_filename(self, folder_path):
        """Return the stem of the ``.pth`` file with the highest step number.

        The step number is the integer between the last ``_`` of the file name
        and the first ``.`` that follows it (e.g. ``adv_1200.pth`` -> ``1200``).
        ``.pth`` files without such a numeric suffix are skipped with a warning.

        Args:
            folder_path: Directory to scan.

        Returns:
            The selected file name without its extension.

        Raises:
            FileNotFoundError: If ``folder_path`` does not exist or contains no
                step-numbered ``.pth`` file.
        """
        step_files = []
        for file_name in os.listdir(folder_path):
            if not file_name.endswith('.pth'):
                continue
            step_token = file_name.split('_')[-1].split('.')[0]
            try:
                step_files.append((file_name, int(step_token)))
            except ValueError:
                logging.warning("Skipping checkpoint without numeric step suffix: %s", file_name)

        if not step_files:
            raise FileNotFoundError(
                f'No step-numbered .pth checkpoint found in {folder_path!r}')

        # max() returns the first maximal entry, matching a stable descending sort.
        max_step_file, _ = max(step_files, key=lambda entry: entry[1])
        return os.path.splitext(max_step_file)[0]

    def __call__(
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        *,
        epsilons: Union[Sequence[Union[float, None]], float, None],
        **kwargs: Any,
    ) -> Union[Tuple[List[T], List[T], T], Tuple[T, T, T]]:
        """Placeholder with no implementation; returns ``None``."""
        ...

    def repeat(self, times: int) -> "UnivIntruder":
        """Placeholder with no implementation; returns ``None``."""
        ...