File size: 5,708 Bytes
998bb30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import os.path
import torch
from foolbox.attacks.base import *
from foolbox.attacks.gradient_descent_base import *
from tqdm import tqdm
import pandas as pd
from attacks.attack_config import SustainableAttack
from utils.plot import plot_asr_per_target, save_grad_cam
import logging
from foolbox import PyTorchModel, accuracy
import numpy as np
from utils import factory
from utils.data_manager import get_dataloader
from attacks.CleanSheet.utils_ import Trigger
from attacks.CleanSheet.generate_kd import train
class CleanSheet(SustainableAttack):
    """CleanSheet trigger attack evaluated against a sequence of task checkpoints.

    ``train_adv`` trains one trigger per target class through
    ``attacks.CleanSheet.generate_kd.train``. ``run_test`` loads the newest
    trigger checkpoint for ``self.target_class`` and hands one test batch to
    ``attacks``, which records the attack success rate (ASR) of the triggered
    images against the model checkpoint of every task and saves it as a plot
    and an ``.xlsx`` file.

    NOTE(review): ``self.target_class``, ``self.data_manager``,
    ``self.preprocessing``, ``self.ckpt_paths`` and ``self.to_all`` are used
    but not defined in this class; presumably provided by ``SustainableAttack``
    or set by the caller -- confirm.
    """

    # Number of target classes a trigger is trained for / data is loaded from.
    _NUM_CLASSES = 10
    # Number of task checkpoints (entries of ``self.ckpt_paths``) evaluated.
    _NUM_TASKS = 10

    def __init__(self, args, device='cuda'):
        super().__init__(args, device)
        self.device = device
        self.args = args
        self.surrogate_model = None
        # (sic) attribute name kept as-is for backward compatibility.
        self.eval_batch_szie = 128
        self.eval = args['eval']
        self.args['run_baseline'] = True
        # Sub-directory (under logs_eval_name) holding trigger checkpoints/results.
        self.trigger_name = 'Trigger'
        self.plot_gradcam = True

    def train_adv(self):
        """Train one trigger per target class; does nothing when ``self.eval`` is truthy.

        Each run calls ``train(self.args)`` with ``self.args['target_class']``
        set to the class being trained, so after the loop that key holds the
        last class index.
        """
        if self.eval:
            return
        for target in range(self._NUM_CLASSES):
            self.args['target_class'] = target
            train(self.args)
            # Release cached GPU memory between independent training runs.
            torch.cuda.empty_cache()

    def run_test(self):
        """Evaluate the newest trigger checkpoint for ``self.target_class``.

        Loads the ``.pth`` file with the highest step suffix from
        ``<logs_eval_name>/<trigger_name>/<target_class>/`` into a ``Trigger``,
        takes the first batch of the shuffled test loader, drops samples whose
        label already equals the target class, and passes the remainder to
        :meth:`attacks` with every target label set to ``self.target_class``.
        """
        ckpt_dir = f'{self.args["logs_eval_name"]}/{self.trigger_name}/{self.target_class}'
        pth_name = self.get_max_step_filename(ckpt_dir)
        self.ckpt = f'{ckpt_dir}/{pth_name}.pth'
        # The trailing "_<step>" of the checkpoint name tags every saved output.
        self.prefix = f'{self.trigger_name}_{pth_name.split("_")[-1]}'

        # map_location lets a checkpoint saved on another device load here.
        state_dict = torch.load(self.ckpt, map_location=self.device)
        self.trigger = Trigger(size=32).to(self.device)
        self.trigger.load_state_dict(state_dict)
        self.trigger.eval()

        # Load batch data from all classes of the test split.
        self.loader = get_dataloader(self.data_manager, batch_size=self.eval_batch_szie,
                                     start_class=0, end_class=self._NUM_CLASSES,
                                     train=False, shuffle=True, num_workers=0)
        progress = tqdm(self.loader, total=len(self.loader),
                        desc=f'Loading Data with Batch Size of {self.eval_batch_szie}')
        for i, (_, imgs, labels) in enumerate(progress):
            if i > 0:
                # Only the first batch is evaluated.
                break
            imgs, labels = ep.astensors(imgs.to(self.device), labels.to(self.device))
            # Keep only samples whose true label differs from the target class.
            not_target = labels != self.target_class
            imgs_f = imgs[not_target]
            labels_f = labels[not_target]
            labels_t_f = ep.full_like(labels_f, fill_value=self.target_class)
            if self.args["model_name"] != 'finetune':
                imgs_f, labels_f = self.to_all(imgs_f, labels_f)
                if imgs_f is None:
                    continue
            # Presumably to_all may return fewer samples; keep targets aligned.
            labels_t_f = labels_t_f[:len(imgs_f)]
            self.attacks(i, imgs_f, labels_f, labels_t_f)

    def attacks(self, i_batch, imgs, labels, labels_t):
        """Record the trigger's ASR against each task checkpoint and save the results.

        For every task the model's network is restored from
        ``self.ckpt_paths[task]``. The triggered images are produced by
        ``self.trigger`` and clipped to [-1, 1]. The ASR stored for that task is
        the ``accuracy`` of the model on those images w.r.t. ``labels_t``.
        Results are plotted and written to
        ``<logs_eval_name>/<trigger_name>/batch<i_batch>_<prefix>_tc<target_class>.xlsx``.

        Args:
            i_batch: Index of the evaluated batch (used only in output names).
            imgs: eagerpy tensor of clean input images.
            labels: eagerpy tensor of their true labels (not used for scoring).
            labels_t: eagerpy tensor of target labels the trigger aims for.
        """
        asr_matrix = np.ones((self._NUM_TASKS, 1))
        self.model = factory.get_model(self.args["model_name"], self.args)
        eval_path = os.path.join(self.args["logs_eval_name"], self.trigger_name)
        for task in range(self._NUM_TASKS):
            logging.info("***** Starting attack on task [{}]. *****".format(task))
            self.model.incremental_train(self.data_manager)
            checkpoint = torch.load(self.ckpt_paths[task], map_location=self.device)
            self.model._network.load_state_dict(checkpoint['model_state_dict'])
            self.model._network.to(self.device)
            self.model._network.eval()

            # Run the attack on each target image.
            current_model = PyTorchModel(self.model._network, bounds=(0, 1),
                                         preprocessing=self.preprocessing)
            verify_input_bounds(imgs, current_model)
            logging.info("Eval attack on each target images.")
            # NOTE(review): triggered images are clipped to [-1, 1] while the model
            # is declared with bounds (0, 1) -- confirm which range is intended.
            advs = ep.astensor(self.trigger(imgs.raw)).clip(-1, 1)
            # NOTE(review): upstream foolbox 3 `accuracy` returns a float; the
            # `[0]` assumes a project-specific variant returning a sequence.
            asr = accuracy(current_model, advs, labels_t)[0]
            asr_matrix[task] = asr
            if self.plot_gradcam:
                # advs is already clipped to [-1, 1]; pass a detached copy-free view.
                save_grad_cam(self.args, advs.raw.detach(), labels_t.raw,
                              self.model._network, eval_path + "/GradCam", prefix=f'task{task}',
                              layer_name='stage_3', save_num=100, save_raw=True)
            del advs, current_model
            torch.cuda.empty_cache()
            self.model.after_task()

        # Save the per-task ASR for this batch/trigger/target class.
        prefix = f'batch{i_batch}_{self.prefix}_tc{self.target_class}'
        os.makedirs(eval_path, exist_ok=True)
        plot_asr_per_target(asr_matrix, eval_path, prefix, self.args)
        df = pd.DataFrame(asr_matrix, columns=['ASR'])
        df.to_excel(os.path.join(eval_path, f"{prefix}.xlsx"), index=False)
        del asr_matrix, imgs, labels, labels_t
        torch.cuda.empty_cache()

    def get_max_step_filename(self, folder_path):
        """Return the stem of the ``.pth`` file in *folder_path* with the largest step.

        The step is parsed as the integer between the last ``_`` and the first
        following ``.`` (e.g. ``trigger_1200.pth`` -> 1200). On ties the first
        file in ``os.listdir`` order wins.

        Raises:
            FileNotFoundError: if *folder_path* contains no ``.pth`` files.
            ValueError: if a ``.pth`` file name has no integer step suffix.
        """
        candidates = [name for name in os.listdir(folder_path) if name.endswith('.pth')]
        if not candidates:
            raise FileNotFoundError(f'No .pth checkpoint found in {folder_path!r}')
        newest = max(candidates, key=lambda name: int(name.split('_')[-1].split('.')[0]))
        return os.path.splitext(newest)[0]

    def __call__(
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        *,
        epsilons: Union[Sequence[Union[float, None]], float, None],
        **kwargs: Any,
    ) -> Union[Tuple[List[T], List[T], T], Tuple[T, T, T]]:
        """Placeholder with no implementation; calling it returns ``None``.

        NOTE(review): presumably present only to satisfy an abstract attack
        interface inherited through ``SustainableAttack`` -- confirm.
        """
        ...

    def repeat(self, times: int) -> "CleanSheet":
        """Placeholder with no implementation; calling it returns ``None``."""
        ...