#!/usr/bin/env python3
"""
Usage
-----
./examples/ens-gen.py -d -v -s 0 \
    --dataset imagenet -b 16 --eps 8 --workdir "workdirs" \
    --device "cuda:0" \
    train --n-ep 1 \
    --surrogate-model-ids vgg19 inception_v3 resnet152 densenet169 \
    --lr 0.0002 --beta 0.5 0.999 \
    --use-logit-loss --use-logit-weights --use-logit-softmax-weights
"""
import argparse
import json
from pathlib import Path
from pprint import pformat
from typing import List, Tuple, Union

import torch
import torchvision
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from gat.datasets import build_dataset, list_datasets
from gat.datasets.transforms import norm
from gat.models.attack import CDAAttack
from gat.models.attack.optim import (SAM, disable_running_stats,
                                     enable_running_stats)
from gat.models.surrogate import (build_surrogate, feat_col, list_surrogates,
                                  midlayer_dict)
from gat.runtime import AverageMeter, calc_cls_accuracy, fix_random, randid


class CLIParser:
    """Command-line interface: argument groups plus post-parse fixups."""

    @staticmethod
    def init_basic_parser(p: argparse.ArgumentParser):
        """Register the options shared by every sub-command."""
        g_basic = p.add_argument_group('Basic Settings')
        g_basic.add_argument('-v', '--verbose', action='store_true',
                             default=False)
        g_basic.add_argument('-d', '--dev', action='store_true', default=False)
        g_basic.add_argument('-s', '--seed', type=int, default=0)
        # NOTE: randid(4) is evaluated once at parser-construction time, so
        # one process always sees the same default experiment id.
        g_basic.add_argument('--expid', type=str, default=randid(4))
        g_basic.add_argument('--device', type=str, default='cuda')

        g_path = p.add_argument_group('Path Settings')
        g_path.add_argument('--workdir', type=str, default='workdirs')
        g_path.add_argument('--data-root', type=str,
                            default=Path(__file__).parent / '../data' / 'in_1k')

        g_ds = p.add_argument_group('Dataset Settings')
        g_ds.add_argument('--dataset', type=str, default='imagenet',
                          choices=list_datasets())
        g_ds.add_argument('-b', '--batch-size', type=int, default=16)

        g_at_basic = p.add_argument_group('General Attack Settings')
        # Epsilon is given on the 0-255 pixel scale; post_basic_parser
        # rescales it into [0, 1].
        g_at_basic.add_argument('--eps', '--epsilon', dest='epsilon', type=int,
                                default=8, choices=[1, 2, 4, 8, 16])

    @staticmethod
    def post_basic_parser(args: argparse.Namespace):
        """Derive runtime objects (paths, device, loggers) from raw args."""
        if args.dev:
            args.workdir = args.workdir.replace('workdirs', 'workdirs-dev')
        args.workdir = Path(args.workdir) / args.expid
        args.workdir.mkdir(parents=True, exist_ok=True)
        args.device = torch.device(args.device)
        args.ckpt = args.workdir / 'model.pth'
        args.tf_logger = SummaryWriter(args.workdir / 'tf_log')
        # Rescale the pixel-valued attack budget(s) to the [0, 1] image range.
        args.epsilon /= 255.0
        if args.command == 'evaluate-pgd':
            args.alpha /= 255.0
        fix_random(args.seed)
        if args.verbose:
            print(pformat(vars(args)))
        with open(args.workdir / f'args-{args.command}.txt', 'w') as f:
            f.write(pformat(vars(args)))

    @staticmethod
    def init_train_parser(p: argparse.ArgumentParser):
        """Register the options of the `train` sub-command."""
        g_at = p.add_argument_group('Attack Settings')
        g_at.add_argument('--sur-ids', '--surrogate-model-ids',
                          dest='surrogate_model_ids', type=str,
                          default=['resnet152'], nargs='+',
                          choices=list_surrogates())
        g_at.add_argument('--n-ep', '--num-epoch', dest='num_epoch', type=int,
                          default=10)

        g_optim = p.add_argument_group('Optimization Settings')
        g_optim.add_argument('--use-sam', action='store_true', default=False)
        g_optim.add_argument('--lr', type=float, default=0.0002)
        g_optim.add_argument('--betas', type=float, nargs=2,
                             default=(0.5, 0.999))

        g_loss = p.add_argument_group('Loss Func Settings')
        g_loss.add_argument('--use-logit-loss', action='store_true',
                            default=False)
        g_loss.add_argument('--use-logit-kl', action='store_true',
                            default=False)
        g_loss.add_argument('--use-logit-weights', action='store_true',
                            default=False)
        g_loss.add_argument('--use-logit-softmax-weights', action='store_true',
                            default=False)
        g_loss.add_argument('--use-feat-loss', action='store_true',
                            default=False)
        g_loss.add_argument('--use-feat-attn', action='store_true',
                            default=False)

    @staticmethod
    def post_train_parser(args: argparse.Namespace):
        """Validate that the selected loss flags form a coherent set."""
        if args.command == 'train':
            # Exactly one primary loss must be selected; each modifier flag
            # only makes sense together with its corresponding loss.
            assert args.use_logit_loss ^ args.use_feat_loss
            if args.use_logit_kl:
                assert not args.use_feat_loss
            if args.use_feat_attn:
                assert not args.use_logit_loss
            if args.use_logit_weights:
                assert args.use_logit_loss
            if args.use_logit_softmax_weights:
                assert args.use_logit_loss

    @staticmethod
    def init_evaluate_parser(p: argparse.ArgumentParser):
        """`evaluate` needs no extra options beyond the basic set."""
        pass

    @staticmethod
    def post_evaluate_parser(args: argparse.Namespace):
        """No post-processing needed for `evaluate`."""
        pass

    @staticmethod
    def init_evaluate_pgd_parser(p: argparse.ArgumentParser):
        """Register the options of the `evaluate-pgd` sub-command."""
        g_at = p.add_argument_group('Attack Settings')
        g_at.add_argument('--surrogate-model-ids', type=str,
                          default=['resnet152'], nargs='+',
                          choices=list_surrogates())

        g_optim = p.add_argument_group('Optimization Settings')
        g_optim.add_argument('--num-step', type=int, default=100)
        # Step size on the 0-255 scale; rescaled in post_basic_parser.
        g_optim.add_argument('--alpha', type=int, default=2,
                             choices=[1, 2, 4, 8, 16])

        g_loss = p.add_argument_group('Loss Func Settings')
        g_loss.add_argument('--use-loss-avg', action='store_true',
                            default=False)
        g_loss.add_argument('--use-logit-avg', action='store_true',
                            default=False)

    @staticmethod
    def post_evaluate_pgd_parser(args: argparse.Namespace):
        """Exactly one ensemble-aggregation mode must be selected."""
        if args.command == 'evaluate-pgd':
            assert args.use_loss_avg ^ args.use_logit_avg

    @staticmethod
    def parse_args():
        """Build the full CLI, parse argv and run every post hook.

        Each post hook is a no-op unless its sub-command was selected;
        post_basic_parser runs last so the epsilon/alpha rescaling sees the
        final parsed values.
        """
        p = argparse.ArgumentParser()
        CLIParser.init_basic_parser(p)
        sub_p = p.add_subparsers(dest='command')
        CLIParser.init_train_parser(sub_p.add_parser('train'))
        CLIParser.init_evaluate_parser(sub_p.add_parser('evaluate'))
        CLIParser.init_evaluate_pgd_parser(sub_p.add_parser('evaluate-pgd'))
        args = p.parse_args()
        CLIParser.post_train_parser(args)
        CLIParser.post_evaluate_parser(args)
        CLIParser.post_evaluate_pgd_parser(args)
        CLIParser.post_basic_parser(args)
        return args


def init_loader(dataset: str,
                data_root: Union[str, Path],
                num_epoch: int = 1,
                batch_size: int = 16,
                command: str = 'train'
                ) -> Tuple[torch.utils.data.DataLoader,
                           torchvision.transforms.Compose]:
    """Build the data loader and the matching input normalizer.

    The sampler draws with replacement ``len(ds) * num_epoch`` samples, so a
    single pass over the loader covers all requested epochs.

    FIX: the return annotation previously claimed ``List[DataLoader]`` while
    the function returns a ``(loader, normalizer)`` pair.
    """
    ds = build_dataset(dataset, data_root=data_root,
                       is_train=(command == 'train'))
    dataloader = torch.utils.data.DataLoader(
        ds,
        batch_size=batch_size,
        sampler=torch.utils.data.RandomSampler(
            ds, replacement=True, num_samples=len(ds) * num_epoch),
        num_workers=4,
        pin_memory=True,
    )
    normalizer = norm(dataset, _callable=True)
    return dataloader, normalizer


def init_models(model_ids: Union[str, List[str]],
                device: Union[str, torch.device] = torch.device('cuda')):
    """Build the requested pretrained surrogate models and put them in eval
    mode on *device*. A single id is accepted and wrapped into a list."""
    if isinstance(model_ids, str):
        model_ids = [model_ids]
    models = [
        build_surrogate(_surrogate_id, pretrain=True).to(device)
        for _surrogate_id in model_ids
    ]
    for model in models:
        model.eval()
    return models


def calc_loss(x_nat: torch.Tensor,
              y_nat: torch.Tensor,
              x_adv: torch.Tensor,
              feat_collecter: List,
              surrogate_models: List[torch.nn.Module],
              normalizer: torchvision.transforms.Compose,
              use_logit_loss: bool,
              use_logit_kl: bool,
              use_logit_weights: bool,
              use_logit_softmax_weights: bool,
              use_feat_loss: bool,
              use_feat_attn: bool,
              device: Union[str, torch.device] = torch.device('cuda')):
    """Aggregate the generator loss over all surrogate models.

    Per model, either a (negated) logit-space loss — cross-entropy or
    symmetric KL — or a feature cosine-similarity loss is computed; the
    per-model losses are then combined uniformly or with softmax/softmin
    weights.
    """
    loss_sur = []
    for surrogate_model in surrogate_models:
        # Each forward pass pushes the mid-layer feature onto feat_collecter;
        # pop immediately after each call so nat/adv features stay paired.
        logit_nat = surrogate_model(normalizer(x_nat))
        feat_nat = feat_collecter.pop()
        logit_adv = surrogate_model(normalizer(x_adv))
        feat_adv = feat_collecter.pop()
        if use_logit_loss:
            if use_logit_kl:
                # Symmetric KL, negated so minimizing drives adv away from
                # nat. NOTE(review): F.kl_div's default reduction ('mean')
                # averages over all elements; PyTorch recommends 'batchmean'
                # for a true KL — kept as-is to preserve behavior.
                loss_sur.append(
                    -(F.kl_div(F.log_softmax(logit_adv, dim=1),
                               F.softmax(logit_nat, dim=1)) +
                      F.kl_div(F.log_softmax(logit_nat, dim=1),
                               F.softmax(logit_adv, dim=1))))
            else:
                loss_sur.append(-(F.cross_entropy(logit_adv, y_nat).mean()))
        elif use_feat_loss:
            if use_feat_attn:
                # Channel-mean magnitude as a spatial attention map.
                attn = torch.abs(torch.mean(feat_nat, dim=1, keepdim=True))
            else:
                attn = torch.ones_like(feat_nat)
            loss_sur.append(
                1 + F.cosine_similarity(attn * feat_nat,
                                        attn * feat_adv).mean())
        else:
            raise NotImplementedError
    loss_sur = torch.stack(loss_sur)
    if use_logit_weights:
        # FIX: pass dim=0 explicitly — the implicit-dim softmax/softmin call
        # is deprecated; loss_sur is 1-D so the result is unchanged.
        if use_logit_softmax_weights:
            loss_weights = F.softmax(loss_sur, dim=0)
        else:
            loss_weights = F.softmin(loss_sur, dim=0)
        loss_all = torch.sum(loss_weights * loss_sur)
    else:
        loss_all = loss_sur.mean()
    return loss_all


def train(surrogate_model_ids: Union[str, List[str]],
          epsilon: float = 16.0 / 255.0,
          num_epoch: int = 10,
          dataset: str = 'imagenet',
          batch_size: int = 16,
          use_sam: bool = False,
          lr: float = 0.0002,
          betas: Union[float, List[float]] = (0.5, 0.999),
          use_logit_loss: bool = False,
          use_logit_kl: bool = False,
          use_logit_weights: bool = False,
          use_logit_softmax_weights: bool = False,
          use_feat_loss: bool = False,
          use_feat_attn: bool = False,
          device: Union[str, torch.device] = torch.device('cuda'),
          workdir: Union[str, Path] = Path(__file__).parents[1] / 'workdirs',
          data_root: Union[str, Path] = (Path(__file__).parent / '../data' /
                                         'in_1k'),
          tf_logger: SummaryWriter = None) -> None:
    """Train the attack model with the given surrogate models.

    Optimizes a CDAAttack generator with Adam (optionally wrapped in SAM)
    against the ensemble loss from calc_loss, then saves the checkpoint to
    ``workdir / 'model.pth'``.
    """
    loader, normalizer = init_loader(dataset, data_root, num_epoch,
                                     batch_size, 'train')
    surrogate_models = init_models(surrogate_model_ids, device)
    attack = CDAAttack(device=device, epsilon=epsilon)
    attack.set_mode('train')  # (was set twice in the original; once suffices)
    if use_sam:
        optim = SAM(attack.get_params(), torch.optim.Adam, lr=lr, betas=betas)
    else:
        optim = torch.optim.Adam(attack.get_params(), lr=lr, betas=betas)
    with feat_col(surrogate_models,
                  [midlayer_dict[_] for _ in surrogate_model_ids]
                  ) as feat_collecter:
        enumerator = tqdm(enumerate(loader), total=len(loader), desc='')
        for step, (x_nat, y_nat) in enumerator:
            x_nat, y_nat = x_nat.to(device), y_nat.to(device)
            if use_sam:
                # SAM step 1: ascent pass with BN running stats enabled.
                enable_running_stats(attack.get_model())
                loss_v = calc_loss(x_nat, y_nat, attack(x_nat),
                                   feat_collecter, surrogate_models,
                                   normalizer, use_logit_loss, use_logit_kl,
                                   use_logit_weights,
                                   use_logit_softmax_weights, use_feat_loss,
                                   use_feat_attn, device)
                loss_v.backward()
                optim.first_step(zero_grad=True)
                # SAM step 2: descent pass with BN stats frozen so they are
                # not updated twice for the same batch.
                disable_running_stats(attack.get_model())
                calc_loss(x_nat, y_nat, attack(x_nat), feat_collecter,
                          surrogate_models, normalizer, use_logit_loss,
                          use_logit_kl, use_logit_weights,
                          use_logit_softmax_weights, use_feat_loss,
                          use_feat_attn, device).backward()
                optim.second_step(zero_grad=True)
            else:
                x_adv = attack(x_nat)
                loss_v = calc_loss(x_nat, y_nat, x_adv, feat_collecter,
                                   surrogate_models, normalizer,
                                   use_logit_loss, use_logit_kl,
                                   use_logit_weights,
                                   use_logit_softmax_weights, use_feat_loss,
                                   use_feat_attn, device)
                optim.zero_grad()
                loss_v.backward()
                optim.step()
            if tf_logger:
                tf_logger.add_scalar('loss', loss_v.item(), step)
                tf_logger.add_scalar('lr', optim.param_groups[0]['lr'], step)
    attack.save_ckpt(workdir / 'model.pth')


@torch.no_grad()
def evaluate(
    ckpt: Union[str, Path],
    epsilon: float = 16.0 / 255.0,
    dataset: str = 'imagenet',
    batch_size: int = 16,
    device: Union[str, torch.device] = torch.device('cuda'),
    workdir: Union[str, Path] = Path(__file__).parents[1] / 'workdirs',
    data_root: Union[str, Path] = Path(__file__).parent / '../data' / 'in_1k',
) -> None:
    """Evaluate a trained attack checkpoint against every registered target
    model and dump natural/adversarial accuracies to ``results.json``."""
    loader, normalizer = init_loader(dataset, data_root, 1, batch_size,
                                     'evaluate')
    target_models = {
        k: v
        for k, v in zip(list_surrogates(),
                        init_models(list_surrogates(), device))
    }
    # Two meters per model: [0] natural accuracy, [1] adversarial accuracy.
    target_acc_meters = {
        target_model_id: [AverageMeter() for _ in range(2)]
        for target_model_id in target_models.keys()
    }
    # init attack method
    attack = CDAAttack(device=device, epsilon=epsilon)
    attack.load_ckpt(ckpt)
    attack.set_mode('eval')
    # evaluate
    enumerator = tqdm(enumerate(loader), total=len(loader), desc='Eval')
    for step, (x_nat, y_nat) in enumerator:
        x_nat, y_nat = x_nat.to(device), y_nat.to(device)
        x_adv = attack(x_nat)
        for target_model_id, target_model in target_models.items():
            logit_nat = target_model(normalizer(x_nat))
            logit_adv = target_model(normalizer(x_adv))
            # collect metrics (the second value is accuracy on adversarial
            # inputs, from which an attack success rate can be derived)
            target_acc = calc_cls_accuracy(logit_nat, y_nat)
            target_asr = calc_cls_accuracy(logit_adv, y_nat)
            target_acc_meters[target_model_id][0].update(
                target_acc[0].item(), x_nat.size(0))
            target_acc_meters[target_model_id][1].update(
                target_asr[0].item(), x_nat.size(0))
    results = {
        target_model_id: {
            'nat_acc': target_acc_meter[0].avg,
            'adv_acc': target_acc_meter[1].avg
        }
        for target_model_id, target_acc_meter in target_acc_meters.items()
    }
    print(pformat(results))
    with open(workdir / 'results.json', 'w') as f:
        json.dump(results, f)


def evaluate_pgd(
    surrogate_model_ids: Union[str, List[str]],
    epsilon: float = 16.0 / 255.0,
    num_step: int = 1000,
    alpha: float = 2.0 / 255.0,
    dataset: str = 'imagenet',
    batch_size: int = 16,
    use_loss_avg: bool = False,
    use_logit_avg: bool = False,
    device: Union[str, torch.device] = torch.device('cuda'),
    workdir: Union[str, Path] = Path(__file__).parents[1] / 'workdirs',
    data_root: Union[str, Path] = Path(__file__).parent / '../data' / 'in_1k',
):
    """PGD ensemble baseline: attack with an ensemble of surrogates (loss- or
    logit-averaged) and evaluate transfer to every registered target model.

    Results are dumped to ``workdir / 'results-pgd.json'``.
    """
    loader, normalizer = init_loader(dataset, data_root, 1, batch_size,
                                     'evaluate')
    surrogate_models = init_models(surrogate_model_ids, device)
    target_models = {
        k: v
        for k, v in zip(list_surrogates(),
                        init_models(list_surrogates(), device))
    }
    # Two meters per model: [0] natural accuracy, [1] adversarial accuracy.
    target_acc_meters = {
        target_model_id: [AverageMeter() for _ in range(2)]
        for target_model_id in target_models.keys()
    }
    # evaluate
    enumerator = tqdm(enumerate(loader), total=len(loader), desc='')
    for step, (x_nat, y_nat) in enumerator:
        x_nat, y_nat = x_nat.to(device), y_nat.to(device)
        # attack: standard L-inf PGD around the clean batch
        x_nat_ori = x_nat.data
        for _ in range(num_step):
            x_nat.requires_grad = True
            if use_loss_avg:
                # Average the per-model cross-entropy losses.
                # NOTE(review): the surrogates here see the raw image while
                # the final evaluation below normalizes inputs — confirm this
                # asymmetry is intentional.
                loss_all = 0.0
                for surrogate_model in surrogate_models:
                    logit = surrogate_model(x_nat)
                    surrogate_model.zero_grad()
                    loss_all += F.cross_entropy(logit, y_nat)
            elif use_logit_avg:
                # Average the logits first, then take one cross-entropy.
                logit = torch.stack([
                    surrogate_model(x_nat)
                    for surrogate_model in surrogate_models
                ]).mean(dim=0)
                loss_all = F.cross_entropy(logit, y_nat)
            else:
                # FIX: the original raised NotADirectoryError here — an
                # unrelated OS exception; this branch means no aggregation
                # mode was selected.
                raise NotImplementedError
            loss_all.backward()
            # Ascent step, then project back into the epsilon ball and the
            # valid [0, 1] image range.
            x_adv_ = x_nat + alpha * x_nat.grad.sign()
            eta = torch.clamp(x_adv_ - x_nat_ori, min=-epsilon, max=epsilon)
            x_nat = torch.clamp(x_nat_ori + eta, min=0.0, max=1.0).detach_()
        x_adv = x_nat
        x_nat = x_nat_ori
        # eval
        with torch.no_grad():
            for target_model_id, target_model in target_models.items():
                logit_nat = target_model(normalizer(x_nat))
                logit_adv = target_model(normalizer(x_adv))
                # collect
                target_acc_ = calc_cls_accuracy(logit_nat, y_nat)
                target_asr_ = calc_cls_accuracy(logit_adv, y_nat)
                target_acc_meters[target_model_id][0].update(
                    target_acc_[0].item(), x_nat.size(0))
                target_acc_meters[target_model_id][1].update(
                    target_asr_[0].item(), x_nat.size(0))
    results = {
        target_model_id: {
            'nat_acc': target_acc_meter[0].avg,
            'adv_acc': target_acc_meter[1].avg
        }
        for target_model_id, target_acc_meter in target_acc_meters.items()
    }
    print(pformat(results))
    with open(workdir / 'results-pgd.json', 'w') as f:
        json.dump(results, f)


def main() -> None:
    """Parse the CLI and dispatch to the selected sub-command."""
    args = CLIParser.parse_args()
    if args.command == 'train':
        train(args.surrogate_model_ids, args.epsilon, args.num_epoch,
              args.dataset, args.batch_size, args.use_sam, args.lr,
              args.betas, args.use_logit_loss, args.use_logit_kl,
              args.use_logit_weights, args.use_logit_softmax_weights,
              args.use_feat_loss, args.use_feat_attn, args.device,
              args.workdir, args.data_root, args.tf_logger)
    elif args.command == 'evaluate':
        evaluate(args.ckpt, args.epsilon, args.dataset, args.batch_size,
                 args.device, args.workdir, args.data_root)
    elif args.command == 'evaluate-pgd':
        evaluate_pgd(args.surrogate_model_ids, args.epsilon, args.num_step,
                     args.alpha, args.dataset, args.batch_size,
                     args.use_loss_avg, args.use_logit_avg, args.device,
                     args.workdir, args.data_root)
    else:
        raise NotImplementedError


if __name__ == '__main__':
    main()