|
|
import logging
import os.path

import foolbox as fb
import numpy as np
import pandas as pd
import torch.optim as optim
# NOTE: wildcard imports are kept in their original relative order so that
# any name shadowing between them is unchanged.
from foolbox.attacks.base import *
from foolbox.attacks.gradient_descent_base import *
from foolbox import PyTorchModel
from tqdm import tqdm

from attacks.CGNC.models.generator import CrossAttenGenerator
from attacks.CGNC.utils_ import *
from attacks.CGNC.image_transformer import rotation
from attacks.attack_config import SustainableAttack
from attacks.AIM.src.gat.models.surrogate import build_surrogate
from utils import factory
from utils.data_manager import get_dataloader
from utils.plot import plot_asr_per_target
|
|
|
|
|
|
class CGNC(SustainableAttack):
    """CGNC targeted attack: a text-conditioned cross-attention generator.

    Trains a :class:`CrossAttenGenerator` that emits L-inf-bounded targeted
    perturbations conditioned on pre-computed text features (one feature
    vector per class), using a white-box surrogate model.  The trained
    generator is then evaluated against a sequence of incrementally-trained
    victim checkpoints, producing an attack-success-rate (ASR) matrix of
    shape (num_tasks, num_target_classes).
    """

    def __init__(self, args, device='cuda'):
        """Build the generator, optimizer, and evaluation bookkeeping.

        Args:
            args: experiment configuration mapping; this class reads at least
                ``args['logs_eval_name']`` and ``args['model_name']``.
            device: torch device string for the generator and all models.
        """
        super().__init__(args, device)
        self.device = device
        self.args = args
        # The white-box surrogate is built lazily in train_generator().
        self.surrogate_model = None

        # Conditional perturbation generator (latent noise dim nz=16).
        self.adv_generator = CrossAttenGenerator(nz=16, device=device)
        self.adv_generator = self.adv_generator.to(device)

        # Generator-training hyper-parameters (Adam betas follow the usual
        # GAN-style (0.5, 0.999) setting).
        self.lr = 0.001
        self.betas = (0.5, 0.999)
        self.num_epoch = 100
        self.optim = optim.Adam(self.adv_generator.parameters(), lr=self.lr, betas=self.betas)
        self.criterion = nn.CrossEntropyLoss()

        # Pre-computed text (e.g. CLIP) conditioning vectors, keyed by class id.
        self.text_cond_dict = torch.load("attacks/CGNC/text_feature.pth")
        self.label_set = get_classes("CL")

        # L-inf perturbation budget in [0, 1] pixel space.
        self.eps = 32 / 255
        # NOTE: attribute name keeps the original typo ("szie") so that any
        # external reader of this public attribute stays compatible.
        self.eval_batch_szie = 128
        self.surrogate_model_name = 'resnet32_cl'
        self.prefix = f'{self.surrogate_model_name}_{len(self.label_set)}classes_eps{int(self.eps * 255)}'
        self.save_path = os.path.join(self.args['logs_eval_name'])
        os.makedirs(self.save_path, exist_ok=True)

    def train_generator(self):
        """Train the conditional generator, or load it if already trained.

        If ``<save_path>/<prefix>.pth`` exists the generator weights are
        restored from it; otherwise the generator is trained for
        ``self.num_epoch`` epochs against the surrogate and the weights are
        checkpointed after every epoch.
        """
        # --- Build the white-box surrogate model ---
        if 'cl' in self.surrogate_model_name:
            # Continual-learning surrogate: restore the first-task checkpoint.
            s_model = factory.get_model(self.args["model_name"], self.args)
            s_model.incremental_train(self.data_manager)
            s_model._network.load_state_dict(
                torch.load(self.ckpt_paths[0], map_location=self.device)['model_state_dict'])
            s_model._network.to(self.device)
            s_model._network.eval()
            self.surrogate_model = s_model._network
            del s_model
            torch.cuda.empty_cache()
        else:
            self.surrogate_model = build_surrogate(self.surrogate_model_name, pretrain=True).to(self.device)
            self.surrogate_model.eval()

        file_path = os.path.join(self.save_path, f'{self.prefix}.pth')
        if os.path.exists(file_path):
            self.adv_generator.load_state_dict(torch.load(file_path, map_location=self.device))
            self.adv_generator.eval()
        else:
            self.loader = get_dataloader(self.data_manager, batch_size=self.batch_size,
                                         start_class=0, end_class=10,
                                         train=True, shuffle=True, num_workers=0)
            # The augmentation pipeline is loop-invariant (it depends only on
            # the constant image size), so build it once instead of once per
            # batch as the original code did.  It is created lazily on the
            # first batch because the image size is only known then.
            color_jitter = transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
            aug = None
            for epoch in range(1, self.num_epoch + 1):
                running_loss = 0
                loader_tqdm = tqdm(self.loader, total=len(self.loader), desc=f'Epoch {epoch}')
                loss_np = 0
                for batch_idx, (_, x, y) in enumerate(loader_tqdm):
                    imgs = x.to(self.device)
                    # Rotation-augmented view (rotation() returns a tuple;
                    # element 0 is the rotated batch).
                    imgs_rot = rotation(x)[0].to(self.device)
                    if aug is None:
                        aug = transforms.Compose([transforms.ToPILImage(),
                                                  transforms.RandomResizedCrop(size=imgs.size(-1)),
                                                  transforms.RandomHorizontalFlip(),
                                                  transforms.RandomApply([color_jitter], p=0.8),
                                                  transforms.RandomGrayscale(p=0.2),
                                                  transforms.ToTensor()])
                    imgs_aug = torch.stack([aug(img) for img in x]).to(self.device)
                    del x, y

                    # Map raw class ids onto contiguous indices for the CE loss.
                    # NOTE(review): label_set is shuffled every batch, so the
                    # id->index mapping below changes between batches, and the
                    # shuffle permanently reorders self.label_set used later by
                    # attacks(); preserved as-is — verify this is intended.
                    label_map = {self.label_set[idx]: idx for idx in range(len(self.label_set))}
                    np.random.shuffle(self.label_set)
                    # Draw one random target class per image.
                    label = np.random.choice(self.label_set, imgs.size(0))
                    cond = torch.stack([self.text_cond_dict[j] for j in label], dim=0)
                    label = torch.from_numpy(label).long().to(self.device)
                    # BUG FIX: the original loop reused `i` here, clobbering the
                    # batch index and breaking the `% 10` logging check below.
                    for k in range(len(label)):
                        label[k] = label_map.get(label[k].item(), label[k].item())
                    self.adv_generator.train()
                    self.optim.zero_grad()

                    # Generate eps-bounded perturbations for all three views.
                    noise = self.adv_generator(input=imgs, cond=cond, eps=self.eps)
                    noise_rot = self.adv_generator(input=imgs_rot, cond=cond, eps=self.eps)
                    noise_aug = self.adv_generator(input=imgs_aug, cond=cond, eps=self.eps)

                    adv = noise + imgs
                    adv = torch.clamp(adv, 0.0, 1.0)

                    adv_rot = noise_rot + imgs_rot
                    adv_rot = torch.clamp(adv_rot, 0.0, 1.0)

                    adv_aug = noise_aug + imgs_aug
                    adv_aug = torch.clamp(adv_aug, 0.0, 1.0)

                    adv_out = self.surrogate_model(normalize(adv))
                    adv_rot_out = self.surrogate_model(normalize(adv_rot))
                    adv_aug_out = self.surrogate_model(normalize(adv_aug))

                    # Targeted CE on all three augmented views: push every view
                    # of the image towards the sampled target class.
                    loss = self.criterion(adv_out, label) + self.criterion(adv_rot_out, label) + self.criterion(adv_aug_out, label)
                    loss.backward()
                    self.optim.step()

                    # BUG FIX: the original reset running_loss *before* adding
                    # to it (making it dead state); accumulate first, then
                    # report and reset every 10 batches.
                    running_loss += abs(loss.item())
                    if batch_idx % 10 == 9:
                        loader_tqdm.set_postfix(loss=running_loss / 10)
                        running_loss = 0
                    loss_np += loss.item()

                    del imgs, imgs_rot, imgs_aug, adv, label, cond, noise, noise_rot, noise_aug, adv_rot, adv_aug, adv_out, adv_rot_out, adv_aug_out
                    torch.cuda.empty_cache()
                logging.info(f'Epoch {epoch} loss: {loss_np / (len(self.loader))}')
                # Checkpoint the generator after every epoch.
                torch.save(self.adv_generator.state_dict(), file_path)

    def run_test(self):
        """Evaluate the trained generator on one held-out test batch.

        Splits the batch into target-class and non-target images, optionally
        extends it via ``to_alls`` for continual-learning victims, then
        delegates per-task evaluation to :meth:`attacks`.
        """
        self.adv_generator.eval()
        self.loader = get_dataloader(self.data_manager, batch_size=self.eval_batch_szie,
                                     start_class=0, end_class=10,
                                     train=False, shuffle=True, num_workers=0)
        # BUG FIX: progress description previously reported self.batch_size,
        # but the loader actually uses self.eval_batch_szie.
        for i, (_, imgs, labels) in enumerate(tqdm(self.loader, total=len(self.loader),
                                                   desc=f'Loading Data with Batch Size of {self.eval_batch_szie}:')):
            # Only the first batch is evaluated.
            if i > 0:
                break

            imgs, labels = ep.astensors(*(imgs.to(self.device), labels.to(self.device)))

            # Partition into images already of the target class vs the rest;
            # only non-target images are attacked.
            target_imgs = imgs[labels == self.target_class]
            target_labels = labels[labels == self.target_class]
            imgs_f = imgs[labels != self.target_class]
            labels_f = labels[labels != self.target_class]
            labels_t_f = ep.full_like(labels_f, fill_value=self.target_class)

            if self.args["model_name"] != 'finetune':
                imgs_f, labels_f, labels_t_f = self.to_alls(imgs_f, labels_f,
                                                            labels_t_f,
                                                            target_imgs,
                                                            target_labels)[:3]
            self.attacks(i, imgs_f, labels_f, labels_t_f)

    def attacks(self, i_batch, imgs, labels, labels_t):
        """Measure ASR of the generator against every victim task checkpoint.

        Args:
            i_batch: batch index, used only for output-file naming.
            imgs: (eagerpy) clean images not belonging to the target class.
            labels: their ground-truth labels.
            labels_t: the corresponding target labels.

        Side effects: writes one ASR plot and one ``.xlsx`` file per target
        class under ``self.save_path``.
        """
        asr_matrix = np.ones((10, len(self.label_set)))
        self.model = factory.get_model(self.args["model_name"], self.args)
        for task in range(10):
            logging.info("***** Starting attack on task [{}]. *****".format(task))
            # Restore the victim network for this incremental task.
            self.model.incremental_train(self.data_manager)
            self.model._network.load_state_dict(torch.load(self.ckpt_paths[task], map_location=self.device)['model_state_dict'])
            self.model._network.to(self.device)
            self.model._network.eval()

            # Untargeted criterion when no target class is set, targeted
            # otherwise.
            criterion = fb.criteria.Misclassification(
                labels) if self.target_class is None else fb.criteria.TargetedMisclassification(
                labels_t)
            current_model = PyTorchModel(self.model._network, bounds=(0, 1), preprocessing=self.preprocessing)
            verify_input_bounds(imgs, current_model)
            criterion = get_criterion(criterion)
            is_adversarial = get_is_adversarial(criterion, current_model)

            logging.info("Eval attack on each target images.")
            for idx in range(len(self.label_set)):
                # Condition the whole batch on target class label_set[idx].
                cond = torch.tile(self.text_cond_dict[self.label_set[idx]], (len(imgs), 1)).to(torch.float).to(self.device)
                noises = self.adv_generator(imgs.raw, cond, eps=self.eps)
                advs = noises + imgs.raw
                advs = torch.clamp(advs, 0.0, 1.0)
                is_adv = is_adversarial(ep.astensor(advs))[0]
                asr_matrix[task, idx] = (is_adv.bool().sum().raw.item() / len(imgs))

                del advs, noises, cond, is_adv
                torch.cuda.empty_cache()
            del criterion, current_model, is_adversarial
            torch.cuda.empty_cache()

            self.model.after_task()

        prefix = f'batch{i_batch}_{self.prefix}'
        plot_asr_per_target(asr_matrix, self.save_path, prefix, self.args)
        # One spreadsheet per target class: ASR over the 10 tasks.
        for cls_idx in range(len(self.label_set)):
            df = pd.DataFrame(asr_matrix[:, cls_idx], columns=['ASR'])
            df.to_excel(os.path.join(self.save_path, f"{prefix}_class{cls_idx}.xlsx"), index=False)

        del asr_matrix, imgs, labels, labels_t
        torch.cuda.empty_cache()

    def __call__(
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        *,
        epsilons: Union[Sequence[Union[float, None]], float, None],
        **kwargs: Any,
    ) -> Union[Tuple[List[T], List[T], T], Tuple[T, T, T]]:
        """Foolbox Attack-protocol entry point; intentionally unimplemented."""
        ...

    def repeat(self, times: int) -> "CGNC":
        """Foolbox Attack-protocol ``repeat``; intentionally unimplemented."""
        ...