|
|
import os.path
|
|
|
import torch
|
|
|
from foolbox.attacks.base import *
|
|
|
from foolbox.attacks.gradient_descent_base import *
|
|
|
from tqdm import tqdm
|
|
|
import pandas as pd
|
|
|
from attacks.attack_config import SustainableAttack
|
|
|
from utils.plot import plot_asr_per_target, save_grad_cam
|
|
|
import logging
|
|
|
from foolbox import PyTorchModel, accuracy
|
|
|
import numpy as np
|
|
|
from utils import factory
|
|
|
from utils.data_manager import get_dataloader
|
|
|
from attacks.CleanSheet.utils_ import Trigger
|
|
|
from attacks.CleanSheet.generate_kd import train
|
|
|
|
|
|
|
|
|
class CleanSheet(SustainableAttack):
    """CleanSheet trigger attack evaluated against a sequence of task checkpoints.

    ``train_adv`` calls ``attacks.CleanSheet.generate_kd.train`` once per target
    class. ``run_test`` loads the latest saved trigger for ``self.target_class``
    and ``attacks`` measures its attack success rate (ASR) on every task
    checkpoint in ``self.ckpt_paths``.

    NOTE(review): ``self.data_manager``, ``self.target_class``,
    ``self.ckpt_paths``, ``self.preprocessing`` and ``self.to_all`` are not set
    here; presumably they come from ``SustainableAttack`` — confirm.
    """

    # Number of target classes a trigger is trained for / loaded for evaluation.
    NUM_CLASSES = 10
    # Number of task checkpoints evaluated (rows of the ASR matrix).
    NUM_TASKS = 10

    def __init__(self, args, device='cuda'):
        """Store the configuration and evaluation defaults.

        Args:
            args: Experiment configuration dict. Must contain ``'eval'``. This
                constructor also sets ``args['run_baseline'] = True`` in place.
            device: Torch device the trigger and evaluation data are moved to.
        """
        super().__init__(args, device)
        self.device = device
        self.args = args
        self.surrogate_model = None
        self.eval_batch_size = 128
        self.eval = args['eval']
        self.args['run_baseline'] = True
        # Sub-directory of ``logs_eval_name`` holding trigger checkpoints and results.
        self.trigger_name = 'Trigger'
        self.plot_gradcam = True

    @property
    def eval_batch_szie(self):
        """Backward-compatible alias for the original misspelled attribute name."""
        return self.eval_batch_size

    @eval_batch_szie.setter
    def eval_batch_szie(self, value):
        self.eval_batch_size = value

    def train_adv(self):
        """Train one CleanSheet trigger per target class; do nothing in eval mode.

        ``args['target_class']`` is overwritten in place before each ``train``
        call, so after the loop it holds the last class trained.
        """
        if self.eval:
            # Evaluation-only run: nothing to train (run_test loads triggers from disk).
            return
        for target in range(self.NUM_CLASSES):
            self.args['target_class'] = target
            train(self.args)
            torch.cuda.empty_cache()

    def run_test(self):
        """Evaluate the latest saved trigger for ``self.target_class``.

        Loads the highest-step ``.pth`` trigger from
        ``<logs_eval_name>/<trigger_name>/<target_class>/``, takes the first
        batch of the test loader, drops samples whose label already equals the
        target class, and passes the rest with target labels to :meth:`attacks`.
        """
        ckpt_dir = os.path.join(self.args["logs_eval_name"], self.trigger_name, str(self.target_class))
        pth_name = self.get_max_step_filename(ckpt_dir)
        self.ckpt = os.path.join(ckpt_dir, f'{pth_name}.pth')
        self.prefix = f'{self.trigger_name}_{pth_name.split("_")[-1]}'

        # map_location lets a checkpoint saved on a different device load here
        # (matches how task checkpoints are loaded in ``attacks``).
        state_dict = torch.load(self.ckpt, map_location=self.device)
        self.trigger = Trigger(size=32).to(self.device)
        self.trigger.load_state_dict(state_dict)
        self.trigger.eval()

        self.loader = get_dataloader(self.data_manager, batch_size=self.eval_batch_size,
                                     start_class=0, end_class=self.NUM_CLASSES,
                                     train=False, shuffle=True, num_workers=0)

        progress = tqdm(self.loader, total=len(self.loader),
                        desc=f'Loading Data with Batch Size of {self.eval_batch_size}:')
        for i, (_, imgs, labels) in enumerate(progress):
            # Only the first batch is evaluated.
            if i > 0:
                break

            imgs, labels = ep.astensors(imgs.to(self.device), labels.to(self.device))

            # Samples already of the target class cannot demonstrate a targeted flip.
            non_target = labels != self.target_class
            imgs_f = imgs[non_target]
            labels_f = labels[non_target]
            labels_t_f = ep.full_like(labels_f, fill_value=self.target_class)

            if self.args["model_name"] != 'finetune':
                imgs_f, labels_f = self.to_all(imgs_f, labels_f)
            if imgs_f is None:
                continue
            # ``to_all`` may shrink the batch; keep target labels aligned with it.
            labels_t_f = labels_t_f[:len(imgs_f)]

            self.attacks(i, imgs_f, labels_f, labels_t_f)

    def attacks(self, i_batch, imgs, labels, labels_t):
        """Compute the trigger's attack success rate (ASR) on each task checkpoint.

        Args:
            i_batch: Batch index; used only in the output file prefix.
            imgs: eagerpy tensor of clean images to apply the trigger to.
            labels: Their true labels (not used by the ASR computation).
            labels_t: eagerpy tensor of target labels, one per image.

        For every task, ``incremental_train`` is called on the model, its network
        weights are replaced by ``self.ckpt_paths[task]['model_state_dict']``, and
        ASR is the ``accuracy`` of the model on triggered images w.r.t.
        ``labels_t``. The per-task ASR column is plotted and written to
        ``<logs_eval_name>/<trigger_name>/<prefix>.xlsx``.
        """
        asr_matrix = np.ones((self.NUM_TASKS, 1))
        self.model = factory.get_model(self.args["model_name"], self.args)
        eval_path = os.path.join(self.args["logs_eval_name"], self.trigger_name)

        for task in range(self.NUM_TASKS):
            logging.info("***** Starting attack on task [{}]. *****".format(task))
            self.model.incremental_train(self.data_manager)
            self.model._network.load_state_dict(
                torch.load(self.ckpt_paths[task], map_location=self.device)['model_state_dict'])
            self.model._network.to(self.device)
            self.model._network.eval()

            current_model = PyTorchModel(self.model._network, bounds=(0, 1), preprocessing=self.preprocessing)
            verify_input_bounds(imgs, current_model)

            logging.info("Eval attack on each target images.")
            # The trigger is only applied here, never optimized: skip building a graph.
            with torch.no_grad():
                triggered = self.trigger(imgs.raw)
            # NOTE(review): clipped to [-1, 1] while the model bounds are (0, 1) — confirm intended range.
            advs = ep.astensor(triggered).clip(-1, 1)
            # NOTE(review): stock foolbox ``accuracy`` returns a float; ``[0]`` assumes a
            # variant returning a sequence — confirm against the installed foolbox.
            asr_matrix[task] = accuracy(current_model, advs, labels_t)[0]

            if self.plot_gradcam:
                # ``advs`` is already clipped to [-1, 1]; detach before Grad-CAM.
                save_grad_cam(self.args, advs.raw.detach(), labels_t.raw,
                              self.model._network, eval_path + "/GradCam", prefix=f'task{task}',
                              layer_name='stage_3', save_num=100, save_raw=True)

            del triggered, advs, current_model
            torch.cuda.empty_cache()
            self.model.after_task()

        prefix = f'batch{i_batch}_{self.prefix}_tc{self.target_class}'
        plot_asr_per_target(asr_matrix, eval_path, prefix, self.args)
        df = pd.DataFrame(asr_matrix, columns=['ASR'])
        df.to_excel(os.path.join(eval_path, f"{prefix}.xlsx"), index=False)
        torch.cuda.empty_cache()

    def get_max_step_filename(self, folder_path):
        """Return the stem of the ``.pth`` file in *folder_path* with the largest step.

        The step is the integer between the last ``_`` and the first following
        ``.``, e.g. ``ckpt_500.pth`` -> step 500, returned as ``'ckpt_500'``.
        Ties keep the first such file in ``os.listdir`` order.

        Raises:
            FileNotFoundError: if *folder_path* is missing or has no ``.pth`` files.
            ValueError: if a ``.pth`` file name does not end in an integer step.
        """
        pth_files = [f for f in os.listdir(folder_path) if f.endswith('.pth')]
        if not pth_files:
            raise FileNotFoundError(f"No .pth checkpoint found in {folder_path!r}")
        latest = max(pth_files, key=lambda f: int(f.split('_')[-1].split('.')[0]))
        return os.path.splitext(latest)[0]

    def __call__(
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        *,
        epsilons: Union[Sequence[Union[float, None]], float, None],
        **kwargs: Any,
    ) -> Union[Tuple[List[T], List[T], T], Tuple[T, T, T]]:
        """No-op stub (body is ``...``, returns ``None``).

        Evaluation in this class goes through ``run_test``/``attacks``.
        Presumably defined only to satisfy the foolbox ``Attack`` interface — confirm.
        """
        ...

    def repeat(self, times: int) -> "CleanSheet":
        """No-op stub (body is ``...``, returns ``None`` despite the annotation).

        Presumably defined only to satisfy the foolbox ``Attack`` interface — confirm.
        """
        ...