# SAE / attacks /AIM /examples /ens-gen.py
# Ttius's picture
# Upload 192 files
# 998bb30 verified
#!/usr/bin/env python3
"""
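Train and evaluate a CDA generative attack against an ensemble of surrogate
models (``train`` / ``evaluate``), or run an ensemble PGD baseline
(``evaluate-pgd``).

Usage
-----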
./examples/ens-gen.py -d -v -s 0 \
--dataset imagenet -b 16 --eps 8 --workdir "workdirs" \
--device "cuda:0" \
train --n-ep 1 \
--surrogate-model-ids vgg19 inception_v3 resnet152 densenet169 \
--lr 0.0002 --betas 0.5 0.999 \
--use-logit-loss --use-logit-weights --use-logit-softmax-weights
"""
import argparse
import json
from pathlib import Path
from pprint import pformat
from typing import Callable, List, Tuple, Union

import torch
import torchvision
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from gat.datasets import build_dataset, list_datasets
from gat.datasets.transforms import norm
from gat.models.attack import CDAAttack
from gat.models.attack.optim import (SAM, disable_running_stats,
                                     enable_running_stats)
from gat.models.surrogate import (build_surrogate, feat_col, list_surrogates,
                                  midlayer_dict)
from gat.runtime import AverageMeter, calc_cls_accuracy, fix_random, randid
class CLIParser:
    """Command-line interface for this script.

    ``init_*`` methods register arguments on a (sub)parser, ``post_*`` methods
    validate / post-process the parsed namespace, and ``parse_args`` wires
    them together for the ``train``, ``evaluate`` and ``evaluate-pgd``
    sub-commands.
    """

    @staticmethod
    def init_basic_parser(p: argparse.ArgumentParser):
        """Register the options shared by every sub-command on ``p``."""
        g_basic = p.add_argument_group('Basic Settings')
        g_basic.add_argument('-v',
                             '--verbose',
                             action='store_true',
                             default=False)
        g_basic.add_argument('-d', '--dev', action='store_true', default=False)
        g_basic.add_argument('-s', '--seed', type=int, default=0)
        # Evaluated once, when the parser is built.
        g_basic.add_argument('--expid', type=str, default=randid(4))
        g_basic.add_argument('--device', type=str, default='cuda')
        g_path = p.add_argument_group('Path Settings')
        g_path.add_argument('--workdir', type=str, default='workdirs')
        g_path.add_argument('--data-root',
                            type=str,
                            default=Path(__file__).parent / '../data' /
                            'in_1k')
        g_ds = p.add_argument_group('Dataset Settings')
        g_ds.add_argument('--dataset',
                          type=str,
                          default='imagenet',
                          choices=list_datasets())
        g_ds.add_argument('-b', '--batch-size', type=int, default=16)
        g_at_basic = p.add_argument_group('General Attack Settings')
        # Given as an integer in 1/255 units; divided by 255 in
        # post_basic_parser.
        g_at_basic.add_argument('--eps',
                                '--epsilon',
                                dest='epsilon',
                                type=int,
                                default=8,
                                choices=[1, 2, 4, 8, 16])

    @staticmethod
    def post_basic_parser(args: argparse.Namespace):
        """Resolve paths/devices, rescale budgets and record the run config.

        Side effects: creates ``<workdir>/<expid>``, opens a ``SummaryWriter``
        under ``<workdir>/<expid>/tf_log``, seeds RNGs via ``fix_random`` and
        writes ``args-<command>.txt`` into the work directory.

        Note: ``args.ckpt`` is ``<workdir>/<expid>/model.pth``, so
        ``evaluate`` must be given the same ``--expid`` as the training run.
        """
        if args.dev:
            args.workdir = args.workdir.replace('workdirs', 'workdirs-dev')
        args.workdir = Path(args.workdir) / args.expid
        args.workdir.mkdir(parents=True, exist_ok=True)
        args.device = torch.device(args.device)
        args.ckpt = args.workdir / 'model.pth'
        args.tf_logger = SummaryWriter(args.workdir / 'tf_log')
        # Convert the integer pixel budgets to the [0, 1] image scale.
        args.epsilon /= 255.0
        if args.command == 'evaluate-pgd':
            args.alpha /= 255.0
        fix_random(args.seed)
        if args.verbose:
            print(pformat(vars(args)))
        with open(args.workdir / f'args-{args.command}.txt', 'w') as f:
            f.write(pformat(vars(args)))

    @staticmethod
    def init_train_parser(p: argparse.ArgumentParser):
        """Register the ``train`` sub-command options on ``p``."""
        g_at = p.add_argument_group('Attack Settings')
        g_at.add_argument('--sur-ids',
                          '--surrogate-model-ids',
                          dest='surrogate_model_ids',
                          type=str,
                          default=['resnet152'],
                          nargs='+',
                          choices=list_surrogates())
        g_at.add_argument('--n-ep',
                          '--num-epoch',
                          dest='num_epoch',
                          type=int,
                          default=10)
        g_optim = p.add_argument_group('Optimization Settings')
        g_optim.add_argument('--use-sam', action='store_true', default=False)
        g_optim.add_argument('--lr', type=float, default=0.0002)
        g_optim.add_argument('--betas',
                             type=float,
                             nargs=2,
                             default=(0.5, 0.999))
        g_loss = p.add_argument_group('Loss Func Settings')
        g_loss.add_argument('--use-logit-loss',
                            action='store_true',
                            default=False)
        g_loss.add_argument('--use-logit-kl',
                            action='store_true',
                            default=False)
        g_loss.add_argument('--use-logit-weights',
                            action='store_true',
                            default=False)
        g_loss.add_argument('--use-logit-softmax-weights',
                            action='store_true',
                            default=False)
        g_loss.add_argument('--use-feat-loss',
                            action='store_true',
                            default=False)
        g_loss.add_argument('--use-feat-attn',
                            action='store_true',
                            default=False)

    @staticmethod
    def post_train_parser(args: argparse.Namespace):
        """Validate the loss-flag combination for the ``train`` command.

        Does nothing for other commands.

        Raises:
            ValueError: if the loss flags are inconsistent. Explicit raises
                are used instead of ``assert`` so the check survives
                ``python -O``.
        """
        if args.command != 'train':
            return
        if args.use_logit_loss == args.use_feat_loss:
            raise ValueError('train: pass exactly one of --use-logit-loss / '
                             '--use-feat-loss')
        if args.use_logit_kl and args.use_feat_loss:
            raise ValueError('train: --use-logit-kl cannot be combined with '
                             '--use-feat-loss')
        if args.use_feat_attn and args.use_logit_loss:
            raise ValueError('train: --use-feat-attn cannot be combined with '
                             '--use-logit-loss')
        if args.use_logit_weights and not args.use_logit_loss:
            raise ValueError('train: --use-logit-weights requires '
                             '--use-logit-loss')
        if args.use_logit_softmax_weights and not args.use_logit_loss:
            raise ValueError('train: --use-logit-softmax-weights requires '
                             '--use-logit-loss')

    @staticmethod
    def init_evaluate_parser(p: argparse.ArgumentParser):
        """``evaluate`` has no sub-command specific options."""
        pass

    @staticmethod
    def post_evaluate_parser(args: argparse.Namespace):
        """``evaluate`` needs no extra validation."""
        pass

    @staticmethod
    def init_evaluate_pgd_parser(p: argparse.ArgumentParser):
        """Register the ``evaluate-pgd`` sub-command options on ``p``."""
        g_at = p.add_argument_group('Attack Settings')
        g_at.add_argument('--surrogate-model-ids',
                          type=str,
                          default=['resnet152'],
                          nargs='+',
                          choices=list_surrogates())
        g_optim = p.add_argument_group('Optimization Settings')
        g_optim.add_argument('--num-step', type=int, default=100)
        # Step size in 1/255 units; divided by 255 in post_basic_parser.
        g_optim.add_argument('--alpha',
                             type=int,
                             default=2,
                             choices=[1, 2, 4, 8, 16])
        g_loss = p.add_argument_group('Loss Func Settings')
        g_loss.add_argument('--use-loss-avg',
                            action='store_true',
                            default=False)
        g_loss.add_argument('--use-logit-avg',
                            action='store_true',
                            default=False)

    @staticmethod
    def post_evaluate_pgd_parser(args: argparse.Namespace):
        """Require exactly one ensembling strategy for ``evaluate-pgd``.

        Raises:
            ValueError: if neither or both of ``--use-loss-avg`` /
                ``--use-logit-avg`` are given.
        """
        if args.command != 'evaluate-pgd':
            return
        if args.use_loss_avg == args.use_logit_avg:
            raise ValueError('evaluate-pgd: pass exactly one of '
                             '--use-loss-avg / --use-logit-avg')

    @staticmethod
    def parse_args():
        """Build the full parser, parse the command line and post-process it.

        Sub-command validation runs before ``post_basic_parser``, so an
        invalid command line fails before any work directory is created.
        """
        p = argparse.ArgumentParser()
        CLIParser.init_basic_parser(p)
        sub_p = p.add_subparsers(dest='command')
        # Without this, omitting the sub-command would create a work dir,
        # write ``args-None.txt`` and only then fail in ``main``.
        sub_p.required = True
        CLIParser.init_train_parser(sub_p.add_parser('train'))
        CLIParser.init_evaluate_parser(sub_p.add_parser('evaluate'))
        CLIParser.init_evaluate_pgd_parser(sub_p.add_parser('evaluate-pgd'))
        args = p.parse_args()
        try:
            CLIParser.post_train_parser(args)
            CLIParser.post_evaluate_parser(args)
            CLIParser.post_evaluate_pgd_parser(args)
        except ValueError as err:
            # Report as a normal usage error (prints usage, exits with 2).
            p.error(str(err))
        CLIParser.post_basic_parser(args)
        return args
def init_loader(dataset: str,
                data_root: Union[str, Path],
                num_epoch: int = 1,
                batch_size: int = 16,
                command: str = 'train',
                num_workers: int = 4,
                ) -> Tuple[torch.utils.data.DataLoader, Callable]:
    """Build the data loader and input normalizer for ``dataset``.

    For ``command == 'train'`` the loader draws ``len(ds) * num_epoch``
    samples uniformly *with replacement*, i.e. a multi-epoch run is one long
    random stream. For any other command the dataset is iterated once, in
    order, so every evaluation sample is scored exactly once (``num_epoch``
    is ignored in that case).

    Args:
        dataset: dataset name understood by ``build_dataset`` / ``norm``.
        data_root: root directory passed to ``build_dataset``.
        num_epoch: number of passes worth of samples (training only).
        batch_size: mini-batch size.
        command: CLI command; ``'train'`` selects the training split.
        num_workers: DataLoader worker processes.

    Returns:
        ``(dataloader, normalizer)``, where ``normalizer`` is the callable
        returned by ``norm(dataset, _callable=True)``.
    """
    is_train = command == 'train'
    ds = build_dataset(dataset, data_root=data_root, is_train=is_train)
    if is_train:
        sampler = torch.utils.data.RandomSampler(ds,
                                                 replacement=True,
                                                 num_samples=len(ds) *
                                                 num_epoch)
    else:
        # Sampling with replacement here would score some images twice and
        # skip others, making evaluation metrics noisy.
        sampler = torch.utils.data.SequentialSampler(ds)
    dataloader = torch.utils.data.DataLoader(
        ds,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=True,
    )
    normalizer = norm(dataset, _callable=True)
    return dataloader, normalizer
def init_models(model_ids: Union[str, List[str]],
                device: Union[str, torch.device] = torch.device('cuda')):
    """Load pretrained surrogate models onto ``device`` and set eval mode.

    Args:
        model_ids: a single surrogate id, or a list of ids, accepted by
            ``build_surrogate``.
        device: device every model is moved to.

    Returns:
        List of models in the same order as ``model_ids``.
    """
    ids = [model_ids] if isinstance(model_ids, str) else model_ids
    loaded = []
    for model_id in ids:
        loaded.append(build_surrogate(model_id, pretrain=True).to(device))
    for net in loaded:
        net.eval()
    return loaded
def calc_loss(x_nat: torch.Tensor,
              y_nat: torch.Tensor,
              x_adv: torch.Tensor,
              feat_collecter: List,
              surrogate_models: List[torch.nn.Module],
              normalizer: Callable[[torch.Tensor], torch.Tensor],
              use_logit_loss: bool,
              use_logit_kl: bool,
              use_logit_weights: bool,
              use_logit_softmax_weights: bool,
              use_feat_loss: bool,
              use_feat_attn: bool,
              device: Union[str, torch.device] = torch.device('cuda')):
    """Ensemble attack loss (to be minimised) for one batch.

    For every surrogate, the clean and adversarial batches are passed through
    ``normalizer`` and the model, and one per-surrogate loss is computed:

    * logit loss, CE: ``-cross_entropy(logit_adv, y_nat)`` (pushes the
      adversarial prediction away from the label);
    * logit loss, KL (``use_logit_kl``): negative symmetric KL divergence
      between the clean and adversarial softmax outputs;
    * feature loss: ``1 + cos_sim(attn * feat_nat, attn * feat_adv)``, where
      ``attn`` is ``|mean over dim 1 of feat_nat|`` if ``use_feat_attn``,
      else all ones.

    Per-surrogate losses are combined by a plain mean or, with
    ``use_logit_weights``, by a softmax- (or softmin-) weighted sum over
    surrogates.

    Args:
        feat_collecter: list that each surrogate forward is assumed to append
            exactly one mid-layer feature tensor to (popped here) -- TODO
            confirm against ``feat_col``.
        device: unused; kept for interface compatibility.

    Returns:
        Scalar loss tensor.

    Raises:
        NotImplementedError: if neither ``use_logit_loss`` nor
            ``use_feat_loss`` is set.
    """
    loss_sur = []
    for surrogate_model in surrogate_models:
        logit_nat = surrogate_model(normalizer(x_nat))
        feat_nat = feat_collecter.pop()
        logit_adv = surrogate_model(normalizer(x_adv))
        feat_adv = feat_collecter.pop()
        if use_logit_loss:
            if use_logit_kl:
                # NOTE: F.kl_div keeps its default reduction ('mean', an
                # average over all elements) to preserve the loss scale.
                loss_sur.append(-(F.kl_div(F.log_softmax(logit_adv, dim=1),
                                           F.softmax(logit_nat, dim=1)) +
                                  F.kl_div(F.log_softmax(logit_nat, dim=1),
                                           F.softmax(logit_adv, dim=1))))
            else:
                # cross_entropy already reduces to a scalar mean.
                loss_sur.append(-F.cross_entropy(logit_adv, y_nat))
        elif use_feat_loss:
            if use_feat_attn:
                attn = torch.abs(torch.mean(feat_nat, dim=1, keepdim=True))
            else:
                attn = torch.ones_like(feat_nat)
            loss_sur.append(1 + F.cosine_similarity(
                attn * feat_nat, attn * feat_adv).mean())
        else:
            raise NotImplementedError(
                'calc_loss: set use_logit_loss or use_feat_loss')
    loss_sur = torch.stack(loss_sur)
    if use_logit_weights:
        # Weights over the surrogate axis (dim 0). They are not detached, so
        # gradients also flow through them -- NOTE(review): confirm intended.
        if use_logit_softmax_weights:
            loss_weights = F.softmax(loss_sur, dim=0)
        else:
            loss_weights = F.softmin(loss_sur, dim=0)
        loss_all = torch.sum(loss_weights * loss_sur)
    else:
        loss_all = loss_sur.mean()
    return loss_all
def train(surrogate_model_ids: Union[str, List[str]],
          epsilon: float = 16.0 / 255.0,
          num_epoch: int = 10,
          dataset: str = 'imagenet',
          batch_size: int = 16,
          use_sam: bool = False,
          lr: float = 0.0002,
          betas: Union[float, List[float]] = (0.5, 0.999),
          use_logit_loss: bool = False,
          use_logit_kl: bool = False,
          use_logit_weights: bool = False,
          use_logit_softmax_weights: bool = False,
          use_feat_loss: bool = False,
          use_feat_attn: bool = False,
          device: Union[str, torch.device] = torch.device('cuda'),
          workdir: Union[str, Path] = Path(__file__).parents[1] / 'workdirs',
          data_root: Union[str,
                           Path] = Path(__file__).parent / '../data' / 'in_1k',
          tf_logger: SummaryWriter = None) -> None:
    """
    Train the CDA attack generator against an ensemble of surrogate models.

    The generator (``CDAAttack``) is optimised with Adam (optionally wrapped
    in ``SAM``) to minimise ``calc_loss`` over the surrogates, whose
    mid-layer features (``midlayer_dict``) are collected via ``feat_col``.
    The ``use_*`` loss flags are forwarded unchanged to ``calc_loss``.
    The final weights are saved to ``<workdir>/model.pth``.

    Args:
        surrogate_model_ids: one surrogate id or a list of ids.
        tf_logger: if given, per-step ``loss`` and ``lr`` scalars are logged.
    """
    # A bare id is allowed; normalise it so the mid-layer lookup below does
    # not iterate over the characters of the string.
    if isinstance(surrogate_model_ids, str):
        surrogate_model_ids = [surrogate_model_ids]
    workdir = Path(workdir)
    loader, normalizer = init_loader(dataset, data_root, num_epoch, batch_size,
                                     'train')
    surrogate_models = init_models(surrogate_model_ids, device)
    attack = CDAAttack(device=device, epsilon=epsilon)
    attack.set_mode('train')
    if use_sam:
        optim = SAM(attack.get_params(), torch.optim.Adam, lr=lr, betas=betas)
    else:
        optim = torch.optim.Adam(attack.get_params(), lr=lr, betas=betas)
    midlayers = [midlayer_dict[sid] for sid in surrogate_model_ids]
    with feat_col(surrogate_models, midlayers) as feat_collecter:

        def loss_fn(x_nat, y_nat, x_adv):
            """calc_loss for one batch with this run's fixed loss settings."""
            return calc_loss(x_nat, y_nat, x_adv, feat_collecter,
                             surrogate_models, normalizer, use_logit_loss,
                             use_logit_kl, use_logit_weights,
                             use_logit_softmax_weights, use_feat_loss,
                             use_feat_attn, device)

        enumerator = tqdm(enumerate(loader), total=len(loader), desc='')
        for step, (x_nat, y_nat) in enumerator:
            x_nat, y_nat = x_nat.to(device), y_nat.to(device)
            if use_sam:
                # SAM pass 1, with the generator's running stats enabled.
                enable_running_stats(attack.get_model())
                loss_v = loss_fn(x_nat, y_nat, attack(x_nat))
                loss_v.backward()
                optim.first_step(zero_grad=True)
                # SAM pass 2, with running stats disabled so the same batch
                # is not counted twice; its loss is not logged.
                disable_running_stats(attack.get_model())
                loss_fn(x_nat, y_nat, attack(x_nat)).backward()
                optim.second_step(zero_grad=True)
            else:
                loss_v = loss_fn(x_nat, y_nat, attack(x_nat))
                optim.zero_grad()
                loss_v.backward()
                optim.step()
            if tf_logger:
                tf_logger.add_scalar('loss', loss_v.item(), step)
                tf_logger.add_scalar('lr', optim.param_groups[0]['lr'], step)
    attack.save_ckpt(workdir / 'model.pth')
@torch.no_grad()
def evaluate(
    ckpt: Union[str, Path],
    epsilon: float = 16.0 / 255.0,
    dataset: str = 'imagenet',
    batch_size: int = 16,
    device: Union[str, torch.device] = torch.device('cuda'),
    workdir: Union[str, Path] = Path(__file__).parents[1] / 'workdirs',
    data_root: Union[str, Path] = Path(__file__).parent / '../data' / 'in_1k',
) -> None:
    """
    Evaluate a trained CDA attack on every model in ``list_surrogates()``.

    The generator is loaded from ``ckpt``. For each target model, the clean
    and adversarial accuracy on the evaluation split is accumulated. Each
    accuracy is the first value returned by ``calc_cls_accuracy``, weighted
    by batch size. The accuracies are printed and written to
    ``<workdir>/results.json`` as
    ``{target_id: {'nat_acc': ..., 'adv_acc': ...}}``.
    """
    workdir = Path(workdir)
    loader, normalizer = init_loader(dataset, data_root, 1, batch_size,
                                     'evaluate')
    # Materialise once so the ids and the loaded models line up one-to-one.
    target_ids = list(list_surrogates())
    target_models = dict(zip(target_ids, init_models(target_ids, device)))
    # Per target: [clean-accuracy meter, adversarial-accuracy meter].
    target_acc_meters = {
        target_id: [AverageMeter() for _ in range(2)]
        for target_id in target_models
    }
    # init attack method
    attack = CDAAttack(device=device, epsilon=epsilon)
    attack.load_ckpt(ckpt)
    attack.set_mode('eval')
    # evaluate
    enumerator = tqdm(enumerate(loader), total=len(loader), desc='Eval')
    for _step, (x_nat, y_nat) in enumerator:
        x_nat, y_nat = x_nat.to(device), y_nat.to(device)
        x_adv = attack(x_nat)
        batch_n = x_nat.size(0)
        for target_id, target_model in target_models.items():
            logit_nat = target_model(normalizer(x_nat))
            logit_adv = target_model(normalizer(x_adv))
            nat_acc = calc_cls_accuracy(logit_nat, y_nat)
            adv_acc = calc_cls_accuracy(logit_adv, y_nat)
            target_acc_meters[target_id][0].update(nat_acc[0].item(), batch_n)
            target_acc_meters[target_id][1].update(adv_acc[0].item(), batch_n)
    results = {
        target_id: {
            'nat_acc': meters[0].avg,
            'adv_acc': meters[1].avg
        }
        for target_id, meters in target_acc_meters.items()
    }
    print(pformat(results))
    with open(workdir / 'results.json', 'w') as f:
        json.dump(results, f)
def _pgd_attack(x_nat: torch.Tensor,
                y_nat: torch.Tensor,
                surrogate_models: List[torch.nn.Module],
                normalizer: Callable[[torch.Tensor], torch.Tensor],
                epsilon: float,
                alpha: float,
                num_step: int,
                use_loss_avg: bool,
                use_logit_avg: bool) -> torch.Tensor:
    """Run L_inf PGD from ``x_nat`` against a surrogate ensemble.

    Each step adds ``alpha * sign(grad)`` of the ensemble cross-entropy, then
    projects back into the ``epsilon`` ball around ``x_nat`` and clamps to
    ``[0, 1]``. The ensemble loss is the sum of per-model losses
    (``use_loss_avg``; ``sign`` makes sum vs. mean irrelevant) or the loss of
    the mean logits (``use_logit_avg``). Inputs are passed through
    ``normalizer`` before each surrogate, as everywhere else in this script.

    Returns:
        The detached adversarial batch.

    Raises:
        NotImplementedError: if ``num_step > 0`` and neither strategy is set.
    """
    x_adv = x_nat.detach().clone()
    for _ in range(num_step):
        x_adv.requires_grad_(True)
        x_in = normalizer(x_adv)
        if use_loss_avg:
            loss_all = sum(
                F.cross_entropy(surrogate_model(x_in), y_nat)
                for surrogate_model in surrogate_models)
        elif use_logit_avg:
            logit = torch.stack([
                surrogate_model(x_in) for surrogate_model in surrogate_models
            ]).mean(dim=0)
            loss_all = F.cross_entropy(logit, y_nat)
        else:
            raise NotImplementedError(
                'evaluate_pgd: set use_loss_avg or use_logit_avg')
        # Gradient w.r.t. the input only; avoids accumulating .grad on the
        # surrogate parameters.
        grad, = torch.autograd.grad(loss_all, x_adv)
        x_step = x_adv.detach() + alpha * grad.sign()
        eta = torch.clamp(x_step - x_nat, min=-epsilon, max=epsilon)
        x_adv = torch.clamp(x_nat + eta, min=0.0, max=1.0).detach()
    return x_adv


def evaluate_pgd(
    surrogate_model_ids: Union[str, List[str]],
    epsilon: float = 16.0 / 255.0,
    num_step: int = 1000,
    alpha: float = 2.0 / 255.0,
    dataset: str = 'imagenet',
    batch_size: int = 16,
    use_loss_avg: bool = False,
    use_logit_avg: bool = False,
    device: Union[str, torch.device] = torch.device('cuda'),
    workdir: Union[str, Path] = Path(__file__).parents[1] / 'workdirs',
    data_root: Union[str, Path] = Path(__file__).parent / '../data' / 'in_1k',
):
    """
    Evaluate transfer of an ensemble PGD attack to every listed model.

    Adversarial examples are crafted against ``surrogate_model_ids`` with
    ``_pgd_attack``. Clean and adversarial accuracy are then measured on
    every model in ``list_surrogates()``, using the first value returned by
    ``calc_cls_accuracy``, weighted by batch size. The results are printed
    and written to ``<workdir>/results-pgd.json`` as
    ``{target_id: {'nat_acc': ..., 'adv_acc': ...}}``.
    """
    workdir = Path(workdir)
    loader, normalizer = init_loader(dataset, data_root, 1, batch_size,
                                     'evaluate')
    surrogate_models = init_models(surrogate_model_ids, device)
    # Materialise once so the ids and the loaded models line up one-to-one.
    target_ids = list(list_surrogates())
    target_models = dict(zip(target_ids, init_models(target_ids, device)))
    # Per target: [clean-accuracy meter, adversarial-accuracy meter].
    target_acc_meters = {
        target_id: [AverageMeter() for _ in range(2)]
        for target_id in target_models
    }
    enumerator = tqdm(enumerate(loader), total=len(loader), desc='')
    for _step, (x_nat, y_nat) in enumerator:
        x_nat, y_nat = x_nat.to(device), y_nat.to(device)
        x_adv = _pgd_attack(x_nat, y_nat, surrogate_models, normalizer,
                            epsilon, alpha, num_step, use_loss_avg,
                            use_logit_avg)
        batch_n = x_nat.size(0)
        with torch.no_grad():
            for target_id, target_model in target_models.items():
                logit_nat = target_model(normalizer(x_nat))
                logit_adv = target_model(normalizer(x_adv))
                nat_acc = calc_cls_accuracy(logit_nat, y_nat)
                adv_acc = calc_cls_accuracy(logit_adv, y_nat)
                target_acc_meters[target_id][0].update(nat_acc[0].item(),
                                                       batch_n)
                target_acc_meters[target_id][1].update(adv_acc[0].item(),
                                                       batch_n)
    results = {
        target_id: {
            'nat_acc': meters[0].avg,
            'adv_acc': meters[1].avg
        }
        for target_id, meters in target_acc_meters.items()
    }
    print(pformat(results))
    with open(workdir / 'results-pgd.json', 'w') as f:
        json.dump(results, f)
def main() -> None:
    """Parse the command line and run the selected sub-command.

    Arguments are forwarded by keyword so the calls stay correct even if the
    long positional parameter lists are reordered.

    Raises:
        NotImplementedError: for an unknown sub-command.
    """
    args = CLIParser.parse_args()
    command = args.command
    if command == 'train':
        train(surrogate_model_ids=args.surrogate_model_ids,
              epsilon=args.epsilon,
              num_epoch=args.num_epoch,
              dataset=args.dataset,
              batch_size=args.batch_size,
              use_sam=args.use_sam,
              lr=args.lr,
              betas=args.betas,
              use_logit_loss=args.use_logit_loss,
              use_logit_kl=args.use_logit_kl,
              use_logit_weights=args.use_logit_weights,
              use_logit_softmax_weights=args.use_logit_softmax_weights,
              use_feat_loss=args.use_feat_loss,
              use_feat_attn=args.use_feat_attn,
              device=args.device,
              workdir=args.workdir,
              data_root=args.data_root,
              tf_logger=args.tf_logger)
        return
    if command == 'evaluate':
        evaluate(ckpt=args.ckpt,
                 epsilon=args.epsilon,
                 dataset=args.dataset,
                 batch_size=args.batch_size,
                 device=args.device,
                 workdir=args.workdir,
                 data_root=args.data_root)
        return
    if command == 'evaluate-pgd':
        evaluate_pgd(surrogate_model_ids=args.surrogate_model_ids,
                     epsilon=args.epsilon,
                     num_step=args.num_step,
                     alpha=args.alpha,
                     dataset=args.dataset,
                     batch_size=args.batch_size,
                     use_loss_avg=args.use_loss_avg,
                     use_logit_avg=args.use_logit_avg,
                     device=args.device,
                     workdir=args.workdir,
                     data_root=args.data_root)
        return
    raise NotImplementedError
# Script entry point: parse the CLI and run the chosen sub-command.
if __name__ == '__main__':
    main()