#!/usr/bin/env python3
# Usage:
#   ./examples/aim_attack.py -h
"""Train and evaluate the AIM targeted adversarial attack against surrogate models.

Subcommands:
    train     -- fit the AIM generator against one surrogate network.
    evaluate  -- load a trained AIM checkpoint and measure clean accuracy /
                 targeted attack success rate across all registered surrogates.
Running with no subcommand falls back to training with default hyperparameters.
"""
import argparse
import json
from pathlib import Path
from pprint import pformat
from typing import List, Union

import torch
from tqdm import tqdm

from attacks.GAT.src.gat.datasets import build_dataset, list_datasets
from attacks.GAT.src.gat.datasets.transforms import norm
from attacks.GAT.src.gat.models.attack import AIMAttack, ContrastiveLoss
from attacks.GAT.src.gat.models.surrogate import (build_surrogate,
                                                  list_surrogates,
                                                  midlayer_dict,
                                                  register_collecter)
from attacks.GAT.src.gat.runtime import (AverageMeter, calc_cls_accuracy,
                                         fix_random, randid)


def parse_args():
    """Parse CLI arguments, create the experiment workdir, and seed RNGs.

    Side effects: creates ``<workdir>/<expid>/``, dumps the parsed arguments
    to ``args.txt`` inside it, and calls ``fix_random(seed)``.

    Returns:
        argparse.Namespace with extra derived fields ``workdir`` (Path),
        ``device`` (torch.device) and ``ckpt`` (Path to ``model.pth``).
        ``command`` is ``None`` when no subcommand was given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--expid', type=str, default=randid(4))
    parser.add_argument('--workdir', type=str, default='workdirs')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--tar-classes', type=int, default=24)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--dataset', type=str, default='imagenet',
                        choices=list_datasets())
    parser.add_argument('--data-root', type=str,
                        default=Path(__file__).parent / '../data' / 'in_1k')

    sub_parsers = parser.add_subparsers(dest='command')
    train_parser = sub_parsers.add_parser('train')
    train_parser.add_argument('--surrogate-id', type=str, default='resnet152',
                              choices=list_surrogates())
    train_parser.add_argument('--num-epoch', type=int, default=10)
    train_parser.add_argument('--lr', type=float, default=0.0002)
    train_parser.add_argument('--betas', type=float, nargs=2,
                              default=(0.5, 0.999))
    sub_parsers.add_parser('evaluate')

    args = parser.parse_args()
    args.workdir = Path(args.workdir) / args.expid
    args.workdir.mkdir(parents=True, exist_ok=True)
    args.device = torch.device(args.device)
    args.ckpt = args.workdir / 'model.pth'
    fix_random(args.seed)
    with open(args.workdir / 'args.txt', 'w') as f:
        f.write(pformat(vars(args)))
    return args


def init_loader(dataset: str,
                data_root: Union[str, Path],
                tar_classes: Union[int, List[int]],
                batch_size: int = 16,
                command: str = 'train') -> List[torch.utils.data.DataLoader]:
    """Build the (natural, target) dataloader pair.

    Args:
        dataset: registered dataset name.
        data_root: dataset root directory.
        tar_classes: class id(s) the target loader is restricted to.
        batch_size: per-loader batch size.
        command: 'train' uses the training split for natural images,
            anything else uses the eval split.

    Returns:
        ``[train_loader, target_loader]``; the target loader samples with
        replacement so it yields exactly as many samples as the natural
        loader, keeping ``zip`` over both loaders full-length.
    """
    train_ds = build_dataset(dataset, data_root=data_root,
                             is_train=(command == 'train'))
    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )
    target_ds = build_dataset(dataset, data_root=data_root, is_train=True,
                              filter_class=tar_classes)
    target_loader = torch.utils.data.DataLoader(
        target_ds,
        batch_size=batch_size,
        # Oversample the (small) target-class subset to match the natural
        # loader's length so zip() does not truncate the epoch.
        sampler=torch.utils.data.RandomSampler(target_ds, replacement=True,
                                               num_samples=len(train_ds)),
        num_workers=4,
        pin_memory=True,
    )
    return train_loader, target_loader


def train(surrogate_id: str,
          dataset: str,
          data_root: Union[str, Path],
          tar_classes: Union[int, List[int]] = 24,
          num_epoch: int = 10,
          batch_size: int = 16,
          lr: float = 0.0002,
          betas: Union[float, List[float]] = (0.5, 0.999),
          device: Union[str, torch.device] = torch.device('cuda'),
          command: str = 'train',
          workdir: Union[str, Path] = Path(__file__).parents[1] / 'workdirs'
          ) -> None:
    """Train the AIM attack generator against a frozen surrogate.

    The loss pulls adversarial logits toward the target class (contrastive
    term) while pushing mid-layer features away from the natural image and
    toward the target image. Saves the generator to ``workdir/model.pth``.
    """
    train_loader, target_loader = init_loader(dataset, data_root, tar_classes,
                                              batch_size, command)
    normalizer = norm(dataset, _callable=True)

    # Surrogate is frozen (eval mode); a forward hook collects mid-layer
    # features which feat_collecter.pop() hands back after each forward.
    surrogate = build_surrogate(surrogate_id, pretrain=True).to(device)
    surrogate.eval()
    feat_collecter_handler, feat_collecter = register_collecter(
        surrogate, midlayer_dict[surrogate_id])

    attack = AIMAttack(device=device)
    attack.set_mode('train')
    optim = torch.optim.Adam(attack.get_params(), lr=lr, betas=betas)
    contrastive_loss = ContrastiveLoss(0.2)
    sim_loss = torch.nn.functional.cosine_similarity

    for epoch in range(1, num_epoch + 1):
        attack.set_mode('train')
        enumerator = enumerate(zip(train_loader, target_loader))
        enumerator = tqdm(enumerator, total=len(train_loader),
                          desc=f'Epoch {epoch}')
        for batch_idx, ((x_nat, y_nat), (x_tar, y_tar)) in enumerator:
            # Skip batches where any natural label already equals its target
            # label -- such pairs give a degenerate targeted objective.
            if torch.any(y_nat == y_tar):
                continue
            x_nat, x_tar = x_nat.to(device), x_tar.to(device)
            y_nat, y_tar = y_nat.to(device), y_tar.to(device)

            x_adv = attack(x_nat, x_tar)
            logits_nat = surrogate(normalizer(x_nat))
            feat_nat = feat_collecter.pop()
            logits_tar = surrogate(normalizer(x_tar))
            feat_tar = feat_collecter.pop()
            logits_adv = surrogate(normalizer(x_adv))
            feat_adv = feat_collecter.pop()

            # Maximize target-feature similarity, minimize natural-feature
            # similarity, plus the logit-space contrastive term.
            loss = (contrastive_loss(logits_adv, logits_nat, logits_tar)
                    + sim_loss(feat_nat, feat_adv)
                    - sim_loss(feat_tar, feat_adv)).mean()
            optim.zero_grad()
            loss.backward()
            optim.step()

    # NOTE(review): hook removal and checkpointing are placed after the epoch
    # loop; removing the hook inside the loop would break feature collection
    # from epoch 2 onward -- confirm against the original intent.
    feat_collecter_handler.remove()
    attack.save_ckpt(workdir / 'model.pth')


@torch.no_grad()
def evaluate(ckpt: Union[str, Path],
             dataset: str,
             data_root: Union[str, Path],
             tar_classes: Union[int, List[int]] = 24,
             batch_size: int = 16,
             device: Union[str, torch.device] = torch.device('cuda'),
             command: str = 'train',
             workdir: Union[str, Path] = Path(__file__).parents[1] / 'workdirs'
             ) -> None:
    """Evaluate a trained AIM checkpoint across every registered surrogate.

    For each model, reports clean accuracy ('acc') on natural images and
    targeted attack success rate ('asr', accuracy of adversarial logits
    w.r.t. the target labels). Results are printed and written to
    ``workdir/results.json``.
    """
    # init dataloader
    eval_loader, target_loader = init_loader(dataset, data_root, tar_classes,
                                             batch_size, command)
    normalizer = norm(dataset, _callable=True)

    # init attack method
    attack = AIMAttack(device=device)
    attack.load_ckpt(ckpt)
    attack.set_mode('eval')

    # init evaluate models
    models = {
        surrogate_id: build_surrogate(surrogate_id, pretrain=True).to(device)
        for surrogate_id in list_surrogates()
    }
    for surrogate_id in models.keys():
        models[surrogate_id].eval()
    # Per model: meter[0] tracks clean accuracy, meter[1] tracks targeted ASR.
    model_meters = {
        surrogate_id: [AverageMeter() for _ in range(2)]
        for surrogate_id in models.keys()
    }

    # evaluate
    enumerator = enumerate(zip(eval_loader, target_loader))
    enumerator = tqdm(enumerator, total=len(eval_loader), desc='Eval')
    for batch_idx, ((x_nat, y_nat), (x_tar, y_tar)) in enumerator:
        x_nat, y_nat = x_nat.to(device), y_nat.to(device)
        x_tar, y_tar = x_tar.to(device), y_tar.to(device)
        x_adv = attack(x_nat, x_tar)
        for surrogate_id, model in models.items():
            logits_nat = model(normalizer(x_nat))
            logits_adv = model(normalizer(x_adv))
            # collect metrics
            acc = calc_cls_accuracy(logits_nat, y_nat)
            asr = calc_cls_accuracy(logits_adv, y_tar)
            model_meters[surrogate_id][0].update(acc[0].item(), x_nat.size(0))
            model_meters[surrogate_id][1].update(asr[0].item(), x_nat.size(0))

    # print result
    results = {
        surrogate_id: {
            'acc': meters[0].avg,
            'asr': meters[1].avg
        }
        for surrogate_id, meters in model_meters.items()
    }
    print(pformat(results))
    with open(workdir / 'results.json', 'w') as f:
        json.dump(results, f)


def main() -> None:
    """Entry point: dispatch to train/evaluate based on the subcommand.

    BUG FIX: the original unconditionally overwrote ``args.command``,
    ``surrogate_id``, ``num_epoch``, ``lr`` and ``betas`` after parsing,
    which made the 'evaluate' subcommand unreachable and silently discarded
    user-supplied training flags. The values are now applied only as
    fallback defaults when no subcommand was given, preserving the old
    bare-invocation behavior.
    """
    args = parse_args()
    if args.command is None:
        args.command = 'train'
        args.surrogate_id = 'resnet152'
        args.num_epoch = 10
        args.lr = 0.0002
        args.betas = (0.5, 0.999)

    if args.command == 'train':
        train(args.surrogate_id, args.dataset, args.data_root,
              args.tar_classes, args.num_epoch, args.batch_size, args.lr,
              args.betas, args.device, args.command, args.workdir)
    elif args.command == 'evaluate':
        evaluate(args.ckpt, args.dataset, args.data_root, args.tar_classes,
                 args.batch_size, args.device, args.command, args.workdir)
    else:
        raise NotImplementedError


if __name__ == '__main__':
    main()