SAE / attacks /UnivIntruder /UnivIntruderAttack.py
Ttius's picture
Upload 192 files
998bb30 verified
import os.path
import torch
from foolbox.attacks.base import *
from foolbox.attacks.gradient_descent_base import *
from tqdm import tqdm
import pandas as pd
from attacks.attack_config import SustainableAttack
from utils.plot import save_grad_cam, plot_asr_per_target
import logging
from foolbox import PyTorchModel, accuracy
import numpy as np
from utils import factory
from utils.data_manager import DataManager, get_dataloader
from attacks.UnivIntruder.att import train
from attacks.UnivIntruder.loss import UniversalPerturbation
class UnivIntruder(SustainableAttack):
    """Evaluate a universal targeted perturbation against a sequence of task checkpoints.

    ``train_adv`` delegates training to ``attacks.UnivIntruder.att.train``.
    ``run_test`` loads the ``.pth`` perturbation checkpoint with the largest
    step suffix from ``<logs_eval_name>/<adv_name>/ckpts`` (presumably written
    by ``train`` -- confirm). It then measures, for every task checkpoint,
    the clean accuracy and the targeted attack success rate on one test batch.
    """

    def __init__(self, args, device='cuda'):
        super().__init__(args, device)
        self.device = device
        self.args = args
        self.surrogate_model = None
        self.target_class = args['target_class']
        # Perturbation budget in 0-255 pixel units; it is passed to
        # UniversalPerturbation as epsilon / 255.
        self.epsilon = 32
        # CIFAR-100 uses 32x32 inputs; every other dataset is assumed to use 224x224.
        self.image_size = 32 if self.args['dataset'] == 'cifar100' else 224
        # NOTE: attribute name is misspelled but kept for compatibility with existing callers.
        self.eval_batch_szie = 128
        self.adv_name = f'adv_eps{self.epsilon}_tc{self.target_class}'
        self.out_path = f'{self.args["logs_eval_name"]}/{self.adv_name}'
        os.makedirs(self.out_path, exist_ok=True)
        self.eval = args['eval']
        self.plot_gradcam = True
        self.ckpt_num = None

    def train_adv(self):
        """Train the universal perturbation unless ``args['eval']`` requested evaluation only."""
        if not self.eval:
            train(self.args)

    def run_test(self):
        """Load the latest perturbation checkpoint and attack the first test batch.

        The test split is restricted to ``start_class=0`` / ``end_class=10``.
        Samples whose label already equals ``target_class`` are dropped before
        attacking. Only the first batch from the loader is evaluated.
        """
        ckpt_dir = f'{self.out_path}/ckpts'
        pth_name = self.get_max_step_filename(ckpt_dir)
        self.ckpt = f'{ckpt_dir}/{pth_name}.pth'
        # The step number is the last '_'-separated token of the checkpoint name.
        self.prefix = f'{self.adv_name}_{pth_name.split("_")[-1]}'
        # map_location lets a checkpoint saved on one device load on another
        # (e.g. a CUDA-trained perturbation evaluated on CPU).
        init = torch.load(self.ckpt, map_location=self.device)
        self.adv = UniversalPerturbation((3, self.image_size, self.image_size), self.epsilon / 255,
                                         initialization=init, device=self.device)
        self.adv.eval()
        # Load batch data.
        self.loader = get_dataloader(self.data_manager, batch_size=self.eval_batch_szie,
                                     start_class=0, end_class=10,
                                     train=False, shuffle=False, num_workers=0)
        # Report the batch size the loader actually uses.
        # `with` closes the bar even though we break early.
        with tqdm(self.loader, total=len(self.loader),
                  desc=f'Loading Data with Batch Size of {self.eval_batch_szie}:') as progress:
            for i, (_, imgs, labels) in enumerate(progress):
                if i > 0:
                    # Only the first batch is evaluated.
                    break
                imgs, labels = ep.astensors(imgs.to(self.device), labels.to(self.device))
                # Drop samples already of the target class.
                not_target = labels != self.target_class
                imgs_f = imgs[not_target]
                labels_f = labels[not_target]
                labels_t_f = ep.full_like(labels_f, fill_value=self.target_class)
                self.attacks(i, imgs_f, labels_f, labels_t_f)

    def attacks(self, i_batch, imgs, labels, labels_t):
        """Evaluate the loaded perturbation on every task checkpoint.

        For each task the model is advanced via ``incremental_train``, and its
        weights are replaced by ``self.ckpt_paths[task]``. Then two metrics are
        recorded:

        - clean accuracy;
        - the attack success rate, i.e. accuracy of the perturbed inputs
          against the target labels ``labels_t``.

        The per-task ASR is plotted and written to
        ``<logs_eval_name>/<adv_name>/batch<i_batch>_<prefix>.xlsx``.

        Args:
            i_batch: index of the batch being evaluated (used in output file names).
            imgs, labels: eagerpy tensors of inputs and their true labels.
            labels_t: eagerpy tensor of target labels, aligned with ``imgs``.
        """
        if self.args["model_name"] != 'finetune':
            imgs_f, labels_f = self.to_all(imgs, labels)
            # Presumably to_all can change the sample count; keep targets aligned.
            labels_t = labels_t[:len(imgs_f)]
        else:
            imgs_f, labels_f = imgs, labels
        clean_acc_matrix = []
        asr_matrix = np.ones((self.data_manager.nb_tasks, 1))
        self.model = factory.get_model(self.args["model_name"], self.args)
        eval_path = os.path.join(self.args["logs_eval_name"], self.adv_name)
        for task in range(self.data_manager.nb_tasks):
            logging.info("***** Starting attack on task [{}]. *****".format(task))
            self.model.incremental_train(self.data_manager)
            state = torch.load(self.ckpt_paths[task], map_location=self.device)
            self.model._network.load_state_dict(state['model_state_dict'])
            self.model._network.to(self.device)
            self.model._network.eval()
            # Run the attack on each target image.
            current_model = PyTorchModel(self.model._network, bounds=(0, 1), preprocessing=self.preprocessing)
            verify_input_bounds(imgs_f, current_model)
            # Evaluate the model performance with clean data.
            # NOTE(review): clean accuracy uses the original imgs/labels, while the
            # attack uses imgs_f (output of to_all); confirm this is intended.
            # The [0] index assumes this project's `accuracy` returns a sequence.
            acc = accuracy(current_model, imgs, labels)[0]
            logging.info("Clean accuracy on task {}: {}%".format(task, acc * 100))
            clean_acc_matrix.append(acc)
            advs = self.adv(imgs_f.raw)
            # Targeted ASR = fraction of perturbed inputs classified as the target class.
            asr = accuracy(current_model, ep.astensor(advs), labels_t)[0]
            asr_matrix[task] = asr
            if self.plot_gradcam:
                save_grad_cam(self.args, torch.clip(advs.detach(), 0, 1), labels_t.raw,
                              self.model._network, self.out_path + "/GradCam", prefix=f'task{task}',
                              layer_name='stage_3', save_num=100, save_raw=True)
            del advs, current_model
            torch.cuda.empty_cache()
            self.model.after_task()
        # Save per-task ASR for this batch.
        prefix = f'batch{i_batch}_{self.prefix}'
        plot_asr_per_target(asr_matrix, eval_path, prefix, self.args, clean_acc_matrix)
        df = pd.DataFrame(asr_matrix, columns=['ASR'])
        df.to_excel(os.path.join(eval_path, f"{prefix}.xlsx"), index=False)
        del asr_matrix, imgs, labels, labels_t, imgs_f, labels_f
        torch.cuda.empty_cache()

    def get_max_step_filename(self, folder_path):
        """Return the stem of the ``.pth`` file in *folder_path* with the largest step.

        Step numbers are parsed from the last ``_``-separated token of the file
        name, e.g. ``adv_1000.pth`` -> 1000. Files whose token is not an
        integer are ignored. Ties keep the first file in ``os.listdir`` order.

        Raises:
            FileNotFoundError: if no ``.pth`` file with a numeric step suffix exists.
        """
        candidates = []
        for fname in os.listdir(folder_path):
            if not fname.endswith('.pth'):
                continue
            try:
                step = int(fname.split('_')[-1].split('.')[0])
            except ValueError:
                # e.g. 'best.pth' carries no step number.
                continue
            candidates.append((fname, step))
        if not candidates:
            raise FileNotFoundError(f'No step-numbered .pth checkpoint found in {folder_path!r}')
        best_file = max(candidates, key=lambda item: item[1])[0]
        return os.path.splitext(best_file)[0]

    def __call__(
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        *,
        epsilons: Union[Sequence[Union[float, None]], float, None],
        **kwargs: Any,
    ) -> Union[Tuple[List[T], List[T], T], Tuple[T, T, T]]:
        """Stub (returns None); presumably present only to satisfy the foolbox ``Attack`` interface."""
        ...

    def repeat(self, times: int) -> "UnivIntruder":
        """Stub (returns None); presumably present only to satisfy the foolbox ``Attack`` interface."""
        ...