|
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
import json
|
|
|
from pathlib import Path
|
|
|
from pprint import pformat
|
|
|
from typing import List, Tuple, Union
|
|
|
|
|
|
import torch
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
from attacks.GAT.src.gat.datasets import build_dataset, list_datasets
|
|
|
from attacks.GAT.src.gat.datasets.transforms import norm
|
|
|
from attacks.GAT.src.gat.models.attack import AIMAttack, ContrastiveLoss
|
|
|
from attacks.GAT.src.gat.models.surrogate import (build_surrogate, list_surrogates,
|
|
|
midlayer_dict, register_collecter)
|
|
|
from attacks.GAT.src.gat.runtime import AverageMeter, calc_cls_accuracy, fix_random, randid
|
|
|
|
|
|
|
|
|
def parse_args():
    """Parse CLI options, prepare the run directory, and seed all RNGs.

    Returns the parsed namespace with ``workdir`` resolved to
    ``<workdir>/<expid>`` (created on disk), ``device`` converted to a
    ``torch.device``, and ``ckpt`` pointing at ``workdir/model.pth``.  The
    full configuration is snapshotted to ``workdir/args.txt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--expid', type=str, default=randid(4))
    parser.add_argument('--workdir', type=str, default='workdirs')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--tar-classes', type=int, default=24)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--dataset', type=str, default='imagenet',
                        choices=list_datasets())
    parser.add_argument('--data-root', type=str,
                        default=Path(__file__).parent / '../data' / 'in_1k')

    # Subcommands: 'train' carries its own hyper-parameters, 'evaluate'
    # reuses the shared options above.
    subcommands = parser.add_subparsers(dest='command')
    trainer = subcommands.add_parser('train')
    trainer.add_argument('--surrogate-id', type=str, default='resnet152',
                         choices=list_surrogates())
    trainer.add_argument('--num-epoch', type=int, default=10)
    trainer.add_argument('--lr', type=float, default=0.0002)
    trainer.add_argument('--betas', type=float, nargs=2, default=(0.5, 0.999))
    subcommands.add_parser('evaluate')

    args = parser.parse_args()

    # Resolve run-specific paths/devices directly on the namespace.
    args.workdir = Path(args.workdir) / args.expid
    args.workdir.mkdir(parents=True, exist_ok=True)
    args.device = torch.device(args.device)
    args.ckpt = args.workdir / 'model.pth'

    fix_random(args.seed)

    # Snapshot the configuration for reproducibility.
    (args.workdir / 'args.txt').write_text(pformat(vars(args)))

    return args
|
|
|
|
|
|
|
|
|
def init_loader(dataset: str,
                data_root: Union[str, Path],
                tar_classes: Union[int, List[int]],
                batch_size: int = 16,
                command: str = 'train') -> Tuple[torch.utils.data.DataLoader,
                                                 torch.utils.data.DataLoader]:
    """Build the natural-image loader and the target-class loader.

    Args:
        dataset: dataset name understood by ``build_dataset``.
        data_root: root directory holding the dataset files.
        tar_classes: class id(s) used to filter the target dataset.
        batch_size: samples per batch for both loaders.
        command: ``'train'`` selects the training split for the natural
            loader; any other value selects the evaluation split.

    Returns:
        ``(train_loader, target_loader)``.  (Fixed annotation: the function
        returns a 2-tuple, not a list.)  The target loader samples with
        replacement and is sized to ``len(train_ds)`` so both loaders can be
        ``zip``-ed batch-for-batch.
    """
    train_ds = build_dataset(dataset,
                             data_root=data_root,
                             is_train=(command == 'train'))
    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )
    # Target images always come from the training split, restricted to the
    # attack's target class(es).
    target_ds = build_dataset(dataset,
                              data_root=data_root,
                              is_train=True,
                              filter_class=tar_classes)
    target_loader = torch.utils.data.DataLoader(
        target_ds,
        batch_size=batch_size,
        # Draw with replacement so the (typically small) target subset can
        # supply one target batch for every natural batch.
        sampler=torch.utils.data.RandomSampler(target_ds,
                                               replacement=True,
                                               num_samples=len(train_ds)),
        num_workers=4,
        pin_memory=True,
    )
    return train_loader, target_loader
|
|
|
|
|
|
|
|
|
def train(
        surrogate_id: str,
        dataset: str,
        data_root: Union[str, Path],
        tar_classes: Union[int, List[int]] = 24,
        num_epoch: int = 10,
        batch_size: int = 16,
        lr: float = 0.0002,
        betas: Union[float, List[float]] = (0.5, 0.999),
        device: Union[str, torch.device] = torch.device('cuda'),
        command: str = 'train',
        workdir: Union[str,
                       Path] = Path(__file__).parents[1] / 'workdirs') -> None:
    """Train the AIM attack against a frozen surrogate classifier.

    Optimizes only the attack's parameters (the surrogate stays in eval mode
    and is never updated) with Adam, then saves the attack checkpoint to
    ``workdir/model.pth``.

    Args:
        surrogate_id: key into ``midlayer_dict`` / ``build_surrogate``.
        dataset: dataset name passed through to ``init_loader``.
        data_root: dataset root directory.
        tar_classes: target class id(s) for the targeted attack.
        num_epoch: number of passes over the paired loaders.
        batch_size: per-loader batch size.
        lr: Adam learning rate.
        betas: Adam (beta1, beta2).
        device: device the attack and surrogate run on.
        command: forwarded to ``init_loader`` to pick the data split.
        workdir: directory where the checkpoint is written.
    """
    train_loader, target_loader = init_loader(dataset, data_root, tar_classes,
                                              batch_size, command)
    normalizer = norm(dataset, _callable=True)

    # Frozen pretrained surrogate; a forward hook on its mid-layer (see
    # register_collecter) stashes features that we fetch with .pop() below.
    surrogate = build_surrogate(surrogate_id, pretrain=True).to(device)
    surrogate.eval()
    feat_collecter_handler, feat_collecter = register_collecter(
        surrogate, midlayer_dict[surrogate_id])

    attack = AIMAttack(device=device)
    attack.set_mode('train')
    optim = torch.optim.Adam(attack.get_params(), lr=lr, betas=betas)

    contrastive_loss = ContrastiveLoss(0.2)
    sim_loss = torch.nn.functional.cosine_similarity

    for epoch in range(1, num_epoch + 1):
        # NOTE(review): re-asserted each epoch although already set above.
        attack.set_mode('train')
        enumerator = enumerate(zip(train_loader, target_loader))
        enumerator = tqdm(enumerator,
                          total=len(train_loader),
                          desc=f'Epoch {epoch}')
        for batch_idx, ((x_nat, y_nat), (x_tar, y_tar)) in enumerator:
            # Skip the whole batch if any natural label coincides with its
            # paired target label (the attack must be cross-class).
            if torch.any(y_nat == y_tar):
                continue
            x_nat, x_tar = x_nat.to(device), x_tar.to(device)
            y_nat, y_tar = y_nat.to(device), y_tar.to(device)
            x_adv = attack(x_nat, x_tar)

            # Each .pop() retrieves the mid-layer features captured by the
            # hook during the immediately preceding forward pass, so the
            # forward/pop ordering here must not be rearranged.
            logits_nat = surrogate(normalizer(x_nat))
            feat_nat = feat_collecter.pop()
            logits_tar = surrogate(normalizer(x_tar))
            feat_tar = feat_collecter.pop()
            logits_adv = surrogate(normalizer(x_adv))
            feat_adv = feat_collecter.pop()

            # Minimizing this pushes adversarial features away from the
            # natural image (+sim term) and towards the target (-sim term),
            # alongside the logit-level contrastive objective.
            loss = (contrastive_loss(logits_adv, logits_nat, logits_tar) +
                    sim_loss(feat_nat, feat_adv) -
                    sim_loss(feat_tar, feat_adv)).mean()

            optim.zero_grad()
            loss.backward()
            optim.step()

    # Detach the feature hook before saving.
    feat_collecter_handler.remove()

    attack.save_ckpt(workdir / 'model.pth')
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def evaluate(
        ckpt: Union[str, Path],
        dataset: str,
        data_root: Union[str, Path],
        tar_classes: Union[int, List[int]] = 24,
        batch_size: int = 16,
        device: Union[str, torch.device] = torch.device('cuda'),
        command: str = 'train',
        workdir: Union[str,
                       Path] = Path(__file__).parents[1] / 'workdirs') -> None:
    """Evaluate a trained attack checkpoint against every surrogate model.

    Loads the attack from ``ckpt``, crafts adversarial examples for each
    batch, and accumulates per-model clean accuracy ('acc') and targeted
    success rate ('asr', accuracy of adversarial logits w.r.t. the target
    labels).  Results are printed and written to ``workdir/results.json``.
    """
    eval_loader, target_loader = init_loader(dataset, data_root, tar_classes,
                                             batch_size, command)
    normalizer = norm(dataset, _callable=True)

    attack = AIMAttack(device=device)
    attack.load_ckpt(ckpt)
    attack.set_mode('eval')

    # One frozen pretrained victim per known surrogate architecture.
    models = {}
    for model_id in list_surrogates():
        victim = build_surrogate(model_id, pretrain=True).to(device)
        victim.eval()
        models[model_id] = victim
    # Two running averages per model: index 0 -> clean acc, 1 -> targeted asr.
    model_meters = {model_id: [AverageMeter(), AverageMeter()]
                    for model_id in models}

    progress = tqdm(enumerate(zip(eval_loader, target_loader)),
                    total=len(eval_loader),
                    desc='Eval')
    for _, ((x_nat, y_nat), (x_tar, y_tar)) in progress:
        x_nat, y_nat = x_nat.to(device), y_nat.to(device)
        x_tar, y_tar = x_tar.to(device), y_tar.to(device)
        x_adv = attack(x_nat, x_tar)
        n_samples = x_nat.size(0)
        for model_id, model in models.items():
            clean_logits = model(normalizer(x_nat))
            adv_logits = model(normalizer(x_adv))

            acc = calc_cls_accuracy(clean_logits, y_nat)
            asr = calc_cls_accuracy(adv_logits, y_tar)
            acc_meter, asr_meter = model_meters[model_id]
            acc_meter.update(acc[0].item(), n_samples)
            asr_meter.update(asr[0].item(), n_samples)

    results = {model_id: {'acc': acc_meter.avg, 'asr': asr_meter.avg}
               for model_id, (acc_meter, asr_meter) in model_meters.items()}
    print(pformat(results))
    with open(workdir / 'results.json', 'w') as f:
        json.dump(results, f)
|
|
|
|
|
|
|
|
|
def main() -> None:
    """Entry point: dispatch to ``train`` or ``evaluate`` per the subcommand.

    Fix: the previous version overwrote ``args.command``, ``surrogate_id``,
    ``num_epoch``, ``lr`` and ``betas`` with hard-coded debug values right
    after parsing, which made the 'evaluate' subcommand unreachable and
    silently ignored every train-specific CLI option.  We now honor the
    parsed arguments as-is.

    Raises:
        NotImplementedError: if no (or an unknown) subcommand was given.
    """
    args = parse_args()
    if args.command == 'train':
        train(args.surrogate_id, args.dataset, args.data_root,
              args.tar_classes, args.num_epoch, args.batch_size, args.lr,
              args.betas, args.device, args.command, args.workdir)
    elif args.command == 'evaluate':
        evaluate(args.ckpt, args.dataset, args.data_root, args.tar_classes,
                 args.batch_size, args.device, args.command, args.workdir)
    else:
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
# Run the CLI only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|
|
|
|