#!/usr/bin/env python3
# Usage:
# ./examples/aim_attack.py -h
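"""Train and evaluate the AIM targeted attack.

`train` optimizes the attack's learnable parameters against a frozen,
pretrained surrogate classifier; `evaluate` restores a saved checkpoint and
reports clean accuracy and targeted attack success rate for every registered
surrogate model.
"""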
import argparse
import json
from pathlib import Path
from pprint import pformat
from typing import List, Tuple, Union
import torch
from tqdm import tqdm
from attacks.GAT.src.gat.datasets import build_dataset, list_datasets
from attacks.GAT.src.gat.datasets.transforms import norm
from attacks.GAT.src.gat.models.attack import AIMAttack, ContrastiveLoss
from attacks.GAT.src.gat.models.surrogate import (build_surrogate, list_surrogates,
midlayer_dict, register_collecter)
from attacks.GAT.src.gat.runtime import AverageMeter, calc_cls_accuracy, fix_random, randid
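

# Parse CLI arguments, create the experiment working directory, fix the random
# seed, and snapshot the parsed arguments to `workdir/args.txt`.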
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('-v', '--verbose', action='store_true')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--expid', type=str, default=randid(4))
parser.add_argument('--workdir', type=str, default='workdirs')
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--tar-classes', type=int, default=24)
parser.add_argument('--batch-size', type=int, default=16)
parser.add_argument('--dataset',
type=str,
default='imagenet',
choices=list_datasets())
parser.add_argument('--data-root',
type=str,
default=Path(__file__).parent / '../data' / 'in_1k')
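    # Subcommands: `train` fits the attack, `evaluate` scores a checkpoint.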
sub_parsers = parser.add_subparsers(dest='command')
train_parser = sub_parsers.add_parser('train')
train_parser.add_argument('--surrogate-id',
type=str,
default='resnet152',
choices=list_surrogates())
train_parser.add_argument('--num-epoch', type=int, default=10)
train_parser.add_argument('--lr', type=float, default=0.0002)
train_parser.add_argument('--betas',
type=float,
nargs=2,
default=(0.5, 0.999))
sub_parsers.add_parser('evaluate')
args = parser.parse_args()
args.workdir = Path(args.workdir) / args.expid
args.workdir.mkdir(parents=True, exist_ok=True)
args.device = torch.device(args.device)
args.ckpt = args.workdir / 'model.pth'
fix_random(args.seed)
with open(args.workdir / 'args.txt', 'w') as f:
f.write(pformat(vars(args)))
return args
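

# Build two loaders of equal length so they can be zipped: the natural data
# split and a target-class subset resampled with replacement to match it.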
def init_loader(
    dataset: str,
    data_root: Union[str, Path],
    tar_classes: Union[int, List[int]],
    batch_size: int = 16,
    command: str = 'train',
) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
train_ds = build_dataset(dataset,
data_root=data_root,
is_train=(command == 'train'))
train_loader = torch.utils.data.DataLoader(
train_ds,
batch_size=batch_size,
shuffle=True,
num_workers=4,
pin_memory=True,
)
target_ds = build_dataset(dataset,
data_root=data_root,
is_train=True,
filter_class=tar_classes)
target_loader = torch.utils.data.DataLoader(
target_ds,
batch_size=batch_size,
sampler=torch.utils.data.RandomSampler(target_ds,
replacement=True,
num_samples=len(train_ds)),
num_workers=4,
pin_memory=True,
)
return train_loader, target_loader
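

# Train the attack against a frozen, pretrained surrogate whose mid-layer
# features are captured by the `register_collecter` hook.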
def train(
    surrogate_id: str,
    dataset: str,
    data_root: Union[str, Path],
    tar_classes: Union[int, List[int]] = 24,
    num_epoch: int = 10,
    batch_size: int = 16,
    lr: float = 0.0002,
    betas: Tuple[float, float] = (0.5, 0.999),
    device: Union[str, torch.device] = torch.device('cuda'),
    command: str = 'train',
    workdir: Union[str, Path] = Path(__file__).parents[1] / 'workdirs',
) -> None:
train_loader, target_loader = init_loader(dataset, data_root, tar_classes,
batch_size, command)
normalizer = norm(dataset, _callable=True)
surrogate = build_surrogate(surrogate_id, pretrain=True).to(device)
surrogate.eval()
feat_collecter_handler, feat_collecter = register_collecter(
surrogate, midlayer_dict[surrogate_id])
attack = AIMAttack(device=device)
attack.set_mode('train')
optim = torch.optim.Adam(attack.get_params(), lr=lr, betas=betas)
contrastive_loss = ContrastiveLoss(0.2)
sim_loss = torch.nn.functional.cosine_similarity
for epoch in range(1, num_epoch + 1):
attack.set_mode('train')
enumerator = enumerate(zip(train_loader, target_loader))
enumerator = tqdm(enumerator,
total=len(train_loader),
desc=f'Epoch {epoch}')
for batch_idx, ((x_nat, y_nat), (x_tar, y_tar)) in enumerator:
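            # Skip batches where any natural label coincides with its paired
            # target label.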
if torch.any(y_nat == y_tar):
continue
x_nat, x_tar = x_nat.to(device), x_tar.to(device)
y_nat, y_tar = y_nat.to(device), y_tar.to(device)
x_adv = attack(x_nat, x_tar)
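            # Forward natural, target, and adversarial batches through the
            # surrogate; the collecter pops the mid-layer features per pass.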
logits_nat = surrogate(normalizer(x_nat))
feat_nat = feat_collecter.pop()
logits_tar = surrogate(normalizer(x_tar))
feat_tar = feat_collecter.pop()
logits_adv = surrogate(normalizer(x_adv))
feat_adv = feat_collecter.pop()
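            # Push adversarial logits/features toward the target class and
            # away from the natural class.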
loss = (contrastive_loss(logits_adv, logits_nat, logits_tar) +
sim_loss(feat_nat, feat_adv) -
sim_loss(feat_tar, feat_adv)).mean()
optim.zero_grad()
loss.backward()
optim.step()
feat_collecter_handler.remove()
attack.save_ckpt(workdir / 'model.pth')
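

# Evaluate a trained attack: for every registered surrogate, report clean
# accuracy (`acc`) and targeted attack success rate (`asr`, the accuracy of
# adversarial predictions against the target labels).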
@torch.no_grad()
def evaluate(
    ckpt: Union[str, Path],
    dataset: str,
    data_root: Union[str, Path],
    tar_classes: Union[int, List[int]] = 24,
    batch_size: int = 16,
    device: Union[str, torch.device] = torch.device('cuda'),
    command: str = 'evaluate',
    workdir: Union[str, Path] = Path(__file__).parents[1] / 'workdirs',
) -> None:
# init dataloader
eval_loader, target_loader = init_loader(dataset, data_root, tar_classes,
batch_size, command)
normalizer = norm(dataset, _callable=True)
# init attack method
attack = AIMAttack(device=device)
attack.load_ckpt(ckpt)
attack.set_mode('eval')
# init evaluate models
models = {
surrogate_id: build_surrogate(surrogate_id, pretrain=True).to(device)
for surrogate_id in list_surrogates()
}
    for model in models.values():
        model.eval()
model_meters = {
surrogate_id: [AverageMeter() for _ in range(2)]
for surrogate_id in models.keys()
}
# evaluate
enumerator = enumerate(zip(eval_loader, target_loader))
enumerator = tqdm(enumerator, total=len(eval_loader), desc='Eval')
for batch_idx, ((x_nat, y_nat), (x_tar, y_tar)) in enumerator:
x_nat, y_nat = x_nat.to(device), y_nat.to(device)
x_tar, y_tar = x_tar.to(device), y_tar.to(device)
x_adv = attack(x_nat, x_tar)
for surrogate_id, model in models.items():
logits_nat = model(normalizer(x_nat))
logits_adv = model(normalizer(x_adv))
# collect metrics
acc = calc_cls_accuracy(logits_nat, y_nat)
asr = calc_cls_accuracy(logits_adv, y_tar)
model_meters[surrogate_id][0].update(acc[0].item(), x_nat.size(0))
model_meters[surrogate_id][1].update(asr[0].item(), x_nat.size(0))
# print result
results = {
surrogate_id: {
'acc': meters[0].avg,
'asr': meters[1].avg
}
for surrogate_id, meters in model_meters.items()
}
print(pformat(results))
with open(workdir / 'results.json', 'w') as f:
json.dump(results, f)
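

# Dispatch to `train` or `evaluate` based on the chosen subcommand.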
def main() -> None:
    args = parse_args()
if args.command == 'train':
train(args.surrogate_id, args.dataset, args.data_root,
args.tar_classes, args.num_epoch, args.batch_size, args.lr,
args.betas, args.device, args.command, args.workdir)
elif args.command == 'evaluate':
evaluate(args.ckpt, args.dataset, args.data_root, args.tar_classes,
args.batch_size, args.device, args.command, args.workdir)
    else:
        raise NotImplementedError(f'unknown command: {args.command!r}')


if __name__ == '__main__':
    main()