|
|
|
|
|
"""
|
|
|
Usage
|
|
|
-----
|
|
|
./examples/ens-gen.py -d -v -s 0 \
|
|
|
--dataset imagenet -b 16 --eps 8 --workdir "workdirs" \
|
|
|
--device "cuda:0" \
|
|
|
train --n-ep 1 \
|
|
|
--surrogate-model-ids vgg19 inception_v3 resnet152 densenet169 \
|
|
|
--lr 0.0002 --beta 0.5 0.999 \
|
|
|
--use-logit-loss --use-logit-weights --use-logit-softmax-weights
|
|
|
"""
|
|
|
import argparse
|
|
|
import json
|
|
|
from pathlib import Path
|
|
|
from pprint import pformat
|
|
|
from typing import Callable, List, Tuple, Union
|
|
|
|
|
|
import torch
|
|
|
import torchvision
|
|
|
from torch.nn import functional as F
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
from gat.datasets import build_dataset, list_datasets
|
|
|
from gat.datasets.transforms import norm
|
|
|
from gat.models.attack import CDAAttack
|
|
|
from gat.models.attack.optim import (SAM, disable_running_stats,
|
|
|
enable_running_stats)
|
|
|
from gat.models.surrogate import (build_surrogate, feat_col, list_surrogates,
|
|
|
midlayer_dict)
|
|
|
from gat.runtime import AverageMeter, calc_cls_accuracy, fix_random, randid
|
|
|
|
|
|
|
|
|
class CLIParser:
    """Namespace of static helpers that build and post-process the CLI.

    ``init_*`` methods register arguments on a (sub)parser; ``post_*``
    methods validate/transform the parsed ``Namespace`` in place. Each
    ``post_*`` hook checks ``args.command`` itself, so all of them can be
    run unconditionally after parsing.
    """

    @staticmethod
    def init_basic_parser(p: argparse.ArgumentParser):
        """Register options shared by every subcommand."""
        g_basic = p.add_argument_group('Basic Settings')
        g_basic.add_argument('-v',
                             '--verbose',
                             action='store_true',
                             default=False)
        g_basic.add_argument('-d', '--dev', action='store_true', default=False)
        g_basic.add_argument('-s', '--seed', type=int, default=0)
        # NOTE(review): randid(4) is evaluated once when this method is
        # called, so repeated parse_args() calls in one process would get
        # fresh ids, but the default shown in --help is fixed per call.
        g_basic.add_argument('--expid', type=str, default=randid(4))
        g_basic.add_argument('--device', type=str, default='cuda')

        g_path = p.add_argument_group('Path Settings')
        g_path.add_argument('--workdir', type=str, default='workdirs')
        g_path.add_argument('--data-root',
                            type=str,
                            default=Path(__file__).parent / '../data' /
                            'in_1k')

        g_ds = p.add_argument_group('Dataset Settings')
        g_ds.add_argument('--dataset',
                          type=str,
                          default='imagenet',
                          choices=list_datasets())
        g_ds.add_argument('-b', '--batch-size', type=int, default=16)

        g_at_basic = p.add_argument_group('General Attack Settings')
        # Epsilon is taken in integer pixel units (out of 255) and converted
        # to the [0, 1] image scale later in post_basic_parser.
        g_at_basic.add_argument('--eps',
                                '--epsilon',
                                dest='epsilon',
                                type=int,
                                default=8,
                                choices=[1, 2, 4, 8, 16])

    @staticmethod
    def post_basic_parser(args: argparse.Namespace):
        """Resolve paths/device, rescale budgets, seed RNGs, dump args."""
        if args.dev:
            # Dev runs are isolated under a separate workdir tree.
            args.workdir = args.workdir.replace('workdirs', 'workdirs-dev')
        args.workdir = Path(args.workdir) / args.expid
        args.workdir.mkdir(parents=True, exist_ok=True)
        args.device = torch.device(args.device)

        # Checkpoint path and TensorBoard logger live inside the workdir.
        args.ckpt = args.workdir / 'model.pth'
        args.tf_logger = SummaryWriter(args.workdir / 'tf_log')

        # Convert integer pixel budgets to the [0, 1] image scale.
        args.epsilon /= 255.0
        if args.command == 'evaluate-pgd':
            args.alpha /= 255.0

        fix_random(args.seed)

        if args.verbose:
            print(pformat(vars(args)))
        # Persist the fully-resolved arguments next to the run artifacts.
        with open(args.workdir / f'args-{args.command}.txt', 'w') as f:
            f.write(pformat(vars(args)))

    @staticmethod
    def init_train_parser(p: argparse.ArgumentParser):
        """Register options for the 'train' subcommand."""
        g_at = p.add_argument_group('Attack Settings')
        g_at.add_argument('--sur-ids',
                          '--surrogate-model-ids',
                          dest='surrogate_model_ids',
                          type=str,
                          default=['resnet152'],
                          nargs='+',
                          choices=list_surrogates())
        g_at.add_argument('--n-ep',
                          '--num-epoch',
                          dest='num_epoch',
                          type=int,
                          default=10)

        g_optim = p.add_argument_group('Optimization Settings')
        g_optim.add_argument('--use-sam', action='store_true', default=False)
        g_optim.add_argument('--lr', type=float, default=0.0002)
        g_optim.add_argument('--betas',
                             type=float,
                             nargs=2,
                             default=(0.5, 0.999))

        g_loss = p.add_argument_group('Loss Func Settings')
        g_loss.add_argument('--use-logit-loss',
                            action='store_true',
                            default=False)
        g_loss.add_argument('--use-logit-kl',
                            action='store_true',
                            default=False)
        g_loss.add_argument('--use-logit-weights',
                            action='store_true',
                            default=False)
        g_loss.add_argument('--use-logit-softmax-weights',
                            action='store_true',
                            default=False)
        g_loss.add_argument('--use-feat-loss',
                            action='store_true',
                            default=False)
        g_loss.add_argument('--use-feat-attn',
                            action='store_true',
                            default=False)

    @staticmethod
    def post_train_parser(args: argparse.Namespace):
        """Validate mutually-exclusive loss flags for the 'train' command.

        NOTE(review): these use ``assert``, which is stripped under
        ``python -O``; ``parser.error`` or raising would be more robust.
        """
        if args.command == 'train':
            # Exactly one of the two loss families must be selected.
            assert args.use_logit_loss ^ args.use_feat_loss
            if args.use_logit_kl:
                assert not args.use_feat_loss
            if args.use_feat_attn:
                assert not args.use_logit_loss
            if args.use_logit_weights:
                assert args.use_logit_loss
            if args.use_logit_softmax_weights:
                assert args.use_logit_loss

    @staticmethod
    def init_evaluate_parser(p: argparse.ArgumentParser):
        """'evaluate' needs no options beyond the basic ones."""
        pass

    @staticmethod
    def post_evaluate_parser(args: argparse.Namespace):
        """No extra validation for the 'evaluate' command."""
        pass

    @staticmethod
    def init_evaluate_pgd_parser(p: argparse.ArgumentParser):
        """Register options for the 'evaluate-pgd' subcommand."""
        g_at = p.add_argument_group('Attack Settings')
        g_at.add_argument('--surrogate-model-ids',
                          type=str,
                          default=['resnet152'],
                          nargs='+',
                          choices=list_surrogates())

        g_optim = p.add_argument_group('Optimization Settings')
        g_optim.add_argument('--num-step', type=int, default=100)
        # PGD step size in integer pixel units; rescaled by
        # post_basic_parser for the 'evaluate-pgd' command.
        g_optim.add_argument('--alpha',
                             type=int,
                             default=2,
                             choices=[1, 2, 4, 8, 16])

        g_loss = p.add_argument_group('Loss Func Settings')
        g_loss.add_argument('--use-loss-avg',
                            action='store_true',
                            default=False)
        g_loss.add_argument('--use-logit-avg',
                            action='store_true',
                            default=False)

    @staticmethod
    def post_evaluate_pgd_parser(args: argparse.Namespace):
        """Exactly one ensembling mode must be chosen for PGD."""
        if args.command == 'evaluate-pgd':
            assert args.use_loss_avg ^ args.use_logit_avg

    @staticmethod
    def parse_args():
        """Build the full parser, parse argv, and run all post hooks."""
        p = argparse.ArgumentParser()
        CLIParser.init_basic_parser(p)
        sub_p = p.add_subparsers(dest='command')

        CLIParser.init_train_parser(sub_p.add_parser('train'))
        CLIParser.init_evaluate_parser(sub_p.add_parser('evaluate'))
        CLIParser.init_evaluate_pgd_parser(sub_p.add_parser('evaluate-pgd'))
        args = p.parse_args()
        # Each post hook checks args.command itself, so all can run safely.
        CLIParser.post_train_parser(args)
        CLIParser.post_evaluate_parser(args)
        CLIParser.post_evaluate_pgd_parser(args)

        # Basic post-processing runs last so it sees command-specific args
        # (e.g. alpha rescaling for 'evaluate-pgd').
        CLIParser.post_basic_parser(args)

        return args
|
|
|
|
|
|
|
|
|
def init_loader(dataset: str,
                data_root: Union[str, Path],
                num_epoch: int = 1,
                batch_size: int = 16,
                command: str = 'train'
                ) -> Tuple[torch.utils.data.DataLoader, Callable]:
    """Build the dataloader and the matching input normalizer.

    The sampler draws with replacement for ``len(ds) * num_epoch`` samples,
    so a single pass over this loader covers all requested epochs.

    Args:
        dataset: dataset identifier accepted by ``build_dataset``.
        data_root: root directory holding the dataset files.
        num_epoch: number of epochs folded into the sampler length.
        batch_size: samples per batch.
        command: ``'train'`` selects the training split; anything else the
            evaluation split.

    Returns:
        ``(dataloader, normalizer)``. (The previous annotation claimed a
        ``List[DataLoader]``, which did not match the actual return value.)
    """
    ds = build_dataset(dataset,
                       data_root=data_root,
                       is_train=(command == 'train'))
    sampler = torch.utils.data.RandomSampler(ds,
                                             replacement=True,
                                             num_samples=len(ds) * num_epoch)
    dataloader = torch.utils.data.DataLoader(
        ds,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=4,  # NOTE(review): hard-coded worker count
        pin_memory=True,
    )
    # Per-dataset normalization as a callable, applied lazily by callers.
    normalizer = norm(dataset, _callable=True)
    return dataloader, normalizer
|
|
|
|
|
|
|
|
|
def init_models(model_ids: Union[str, List[str]],
                device: Union[str, torch.device] = torch.device('cuda')):
    """Build pretrained surrogate/target models on *device* in eval mode.

    Args:
        model_ids: one model id, or a list of ids, accepted by
            ``build_surrogate``.
        device: device the models are moved to.

    Returns:
        List of models in eval mode. Note ``eval()`` only switches
        train-time behavior (e.g. batch-norm/dropout); parameters still
        carry ``requires_grad`` and can receive gradients.
    """
    if isinstance(model_ids, str):
        model_ids = [model_ids]
    # Module.to() and Module.eval() both return the module itself, so the
    # build/move/freeze steps chain into a single comprehension.
    return [
        build_surrogate(model_id, pretrain=True).to(device).eval()
        for model_id in model_ids
    ]
|
|
|
|
|
|
|
|
|
def calc_loss(x_nat: torch.Tensor,
              y_nat: torch.Tensor,
              x_adv: torch.Tensor,
              feat_collecter: List,
              surrogate_models: List[torch.nn.Module],
              normalizer: Callable,
              use_logit_loss: bool,
              use_logit_kl: bool,
              use_logit_weights: bool,
              use_logit_softmax_weights: bool,
              use_feat_loss: bool,
              use_feat_attn: bool,
              device: Union[str, torch.device] = torch.device('cuda')):
    """Compute the ensemble attack loss over all surrogate models.

    For every surrogate, both the natural and the adversarial batch are
    forwarded (through ``normalizer``); each forward pass is expected to
    push the mid-layer feature map onto ``feat_collecter``, which is popped
    here (and discarded on the logit-loss path). Losses are negated so that
    minimizing the return value maximizes the surrogates' error.

    Args:
        x_nat: natural input batch.
        y_nat: ground-truth labels for ``x_nat``.
        x_adv: adversarial input batch.
        feat_collecter: stack of collected feature maps; two items are
            popped per surrogate model.
        surrogate_models: models the attack is trained against.
        normalizer: callable applied to inputs before each forward pass.
        use_logit_loss: use a logit-space loss (CE or symmetric KL).
        use_logit_kl: with ``use_logit_loss``, use symmetric KL instead of
            cross-entropy.
        use_logit_weights: weight per-model losses instead of averaging.
        use_logit_softmax_weights: softmax (vs softmin) loss weighting.
        use_feat_loss: use the (attention-weighted) feature cosine loss.
        use_feat_attn: weight features by channel-mean attention maps.
        device: unused here; kept for signature compatibility with callers.

    Returns:
        Scalar loss tensor.

    Raises:
        NotImplementedError: if neither loss family is selected.
    """
    loss_sur = []
    for surrogate_model in surrogate_models:
        logit_nat = surrogate_model(normalizer(x_nat))
        feat_nat = feat_collecter.pop()
        logit_adv = surrogate_model(normalizer(x_adv))
        feat_adv = feat_collecter.pop()
        if use_logit_loss:
            if use_logit_kl:
                # Symmetric KL between natural and adversarial predictions.
                # NOTE(review): F.kl_div's default reduction averages over
                # all elements rather than per batch ('batchmean' is the
                # mathematically correct KL) -- confirm this scale is
                # intended before changing it.
                loss_sur.append(-(F.kl_div(F.log_softmax(logit_adv, dim=1),
                                           F.softmax(logit_nat, dim=1)) +
                                  F.kl_div(F.log_softmax(logit_nat, dim=1),
                                           F.softmax(logit_adv, dim=1))))
            else:
                # cross_entropy already reduces to a mean scalar; the
                # previous trailing .mean() was a no-op and is dropped.
                loss_sur.append(-F.cross_entropy(logit_adv, y_nat))
        elif use_feat_loss:
            if use_feat_attn:
                # Channel-mean magnitude as a spatial attention map.
                attn = torch.abs(torch.mean(feat_nat, dim=1, keepdim=True))
            else:
                attn = torch.ones_like(feat_nat)
            # 1 + cos(., .) keeps the per-model loss in [0, 2].
            loss_sur.append(1 + F.cosine_similarity(attn * feat_nat, attn *
                                                    feat_adv).mean())
        else:
            raise NotImplementedError
    loss_sur = torch.stack(loss_sur)
    if use_logit_weights:
        # Explicit dim=0 over the per-model axis (the implicit-dim form of
        # softmax/softmin is deprecated; loss_sur is 1-D so this is the
        # same reduction).
        if use_logit_softmax_weights:
            loss_weights = F.softmax(loss_sur, dim=0)
        else:
            loss_weights = F.softmin(loss_sur, dim=0)
        loss_all = torch.sum(loss_weights * loss_sur)
    else:
        loss_all = loss_sur.mean()

    return loss_all
|
|
|
|
|
|
|
|
|
def train(surrogate_model_ids: Union[str, List[str]],
          epsilon: float = 16.0 / 255.0,
          num_epoch: int = 10,
          dataset: str = 'imagenet',
          batch_size: int = 16,
          use_sam: bool = False,
          lr: float = 0.0002,
          betas: Union[float, List[float]] = (0.5, 0.999),
          use_logit_loss: bool = False,
          use_logit_kl: bool = False,
          use_logit_weights: bool = False,
          use_logit_softmax_weights: bool = False,
          use_feat_loss: bool = False,
          use_feat_attn: bool = False,
          device: Union[str, torch.device] = torch.device('cuda'),
          workdir: Union[str, Path] = Path(__file__).parents[1] / 'workdirs',
          data_root: Union[str,
                           Path] = Path(__file__).parent / '../data' / 'in_1k',
          tf_logger: SummaryWriter = None) -> None:
    """
    Train the attack model with the given surrogate models.

    The CDA attack generator is optimized with Adam (optionally wrapped in
    SAM) to minimize the ensemble loss from ``calc_loss``, i.e. to maximize
    the surrogates' error on the generated adversarial examples. The final
    checkpoint is written to ``workdir / 'model.pth'``.

    Args:
        surrogate_model_ids: surrogate model id(s) to attack during training.
        epsilon: perturbation budget on the [0, 1] image scale.
        num_epoch: epochs folded into the sampler length by ``init_loader``.
        dataset: dataset identifier.
        batch_size: samples per batch.
        use_sam: wrap Adam in SAM and use its two-step update.
        lr: Adam learning rate.
        betas: Adam beta coefficients.
        use_logit_*: logit-loss configuration forwarded to ``calc_loss``.
        use_feat_*: feature-loss configuration forwarded to ``calc_loss``.
        device: device for data, surrogates and the attack model.
        workdir: directory receiving the trained checkpoint.
        data_root: dataset root directory.
        tf_logger: optional TensorBoard writer for loss/lr scalars.
    """
    loader, normalizer = init_loader(dataset, data_root, num_epoch, batch_size,
                                     'train')
    surrogate_models = init_models(surrogate_model_ids, device)

    attack = CDAAttack(device=device, epsilon=epsilon)
    attack.set_mode('train')
    if use_sam:
        # SAM wraps the base optimizer; it requires the explicit
        # first_step/second_step protocol used below.
        optim = SAM(attack.get_params(), torch.optim.Adam, lr=lr, betas=betas)
    else:
        optim = torch.optim.Adam(attack.get_params(), lr=lr, betas=betas)

    # feat_col hooks the chosen mid-layer of every surrogate so that each
    # forward pass pushes its feature map onto feat_collecter (consumed by
    # calc_loss).
    with feat_col(surrogate_models,
                  [midlayer_dict[_]
                   for _ in surrogate_model_ids]) as feat_collecter:
        attack.set_mode('train')
        enumerator = tqdm(enumerate(loader), total=len(loader), desc='')
        for step, (x_nat, y_nat) in enumerator:
            x_nat, y_nat = x_nat.to(device), y_nat.to(device)

            if use_sam:

                # SAM step 1: loss at the current weights, then ascend to
                # the worst-case nearby weights (BN stats updated normally).
                enable_running_stats(attack.get_model())
                loss_v = calc_loss(x_nat, y_nat, attack(x_nat), feat_collecter,
                                   surrogate_models, normalizer,
                                   use_logit_loss, use_logit_kl,
                                   use_logit_weights,
                                   use_logit_softmax_weights, use_feat_loss,
                                   use_feat_attn, device)
                loss_v.backward()
                optim.first_step(zero_grad=True)

                # SAM step 2: re-evaluate at the perturbed weights with BN
                # running statistics frozen, then apply the actual update.
                disable_running_stats(attack.get_model())
                calc_loss(x_nat, y_nat, attack(x_nat), feat_collecter,
                          surrogate_models, normalizer, use_logit_loss,
                          use_logit_kl, use_logit_weights,
                          use_logit_softmax_weights, use_feat_loss,
                          use_feat_attn, device).backward()
                optim.second_step(zero_grad=True)
            else:
                # Plain Adam: one forward/backward/step per batch.
                x_adv = attack(x_nat)
                loss_v = calc_loss(x_nat, y_nat, x_adv, feat_collecter,
                                   surrogate_models, normalizer,
                                   use_logit_loss, use_logit_kl,
                                   use_logit_weights,
                                   use_logit_softmax_weights, use_feat_loss,
                                   use_feat_attn, device)
                optim.zero_grad()
                loss_v.backward()
                optim.step()

            if tf_logger:
                # In the SAM branch, loss_v is the first-step loss.
                tf_logger.add_scalar('loss', loss_v.item(), step)
                tf_logger.add_scalar('lr', optim.param_groups[0]['lr'], step)

    attack.save_ckpt(workdir / 'model.pth')
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def evaluate(
    ckpt: Union[str, Path],
    epsilon: float = 16.0 / 255.0,
    dataset: str = 'imagenet',
    batch_size: int = 16,
    device: Union[str, torch.device] = torch.device('cuda'),
    workdir: Union[str, Path] = Path(__file__).parents[1] / 'workdirs',
    data_root: Union[str, Path] = Path(__file__).parent / '../data' / 'in_1k',
) -> None:
    """
    Evaluate the attack model with the given surrogate models
    """
    loader, normalizer = init_loader(dataset, data_root, 1, batch_size,
                                     'evaluate')
    # Every registered surrogate doubles as a transfer target.
    target_ids = list_surrogates()
    target_models = dict(zip(target_ids, init_models(target_ids, device)))
    # Per-target meter pair: index 0 tracks natural accuracy, 1 adversarial.
    meters = {
        model_id: [AverageMeter(), AverageMeter()]
        for model_id in target_models
    }

    # Restore the trained generator and switch it to inference mode.
    attack = CDAAttack(device=device, epsilon=epsilon)
    attack.load_ckpt(ckpt)
    attack.set_mode('eval')

    progress = tqdm(enumerate(loader), total=len(loader), desc='Eval')
    for _, (x_nat, y_nat) in progress:
        x_nat = x_nat.to(device)
        y_nat = y_nat.to(device)
        x_adv = attack(x_nat)
        n_samples = x_nat.size(0)
        for model_id, model in target_models.items():
            logit_nat = model(normalizer(x_nat))
            logit_adv = model(normalizer(x_adv))

            acc_nat = calc_cls_accuracy(logit_nat, y_nat)
            acc_adv = calc_cls_accuracy(logit_adv, y_nat)
            meters[model_id][0].update(acc_nat[0].item(), n_samples)
            meters[model_id][1].update(acc_adv[0].item(), n_samples)
    # Collapse the meters into a plain dict for printing/serialization.
    results = {
        model_id: {
            'nat_acc': meter_pair[0].avg,
            'adv_acc': meter_pair[1].avg
        }
        for model_id, meter_pair in meters.items()
    }
    print(pformat(results))
    with open(workdir / 'results.json', 'w') as f:
        json.dump(results, f)
|
|
|
|
|
|
|
|
|
def evaluate_pgd(
    surrogate_model_ids: Union[str, List[str]],
    epsilon: float = 16.0 / 255.0,
    num_step: int = 1000,
    alpha: float = 2.0 / 255.0,
    dataset: str = 'imagenet',
    batch_size: int = 16,
    use_loss_avg: bool = False,
    use_logit_avg: bool = False,
    device: Union[str, torch.device] = torch.device('cuda'),
    workdir: Union[str, Path] = Path(__file__).parents[1] / 'workdirs',
    data_root: Union[str, Path] = Path(__file__).parent / '../data' / 'in_1k',
):
    """Run an ensemble PGD baseline and report transfer accuracy.

    PGD is run against the chosen surrogate ensemble (averaging either the
    per-model losses or the logits), then the resulting adversarial batch
    is scored on every registered target model. Results are printed and
    written to ``workdir / 'results-pgd.json'``.

    Args:
        surrogate_model_ids: surrogate model id(s) driving the attack.
        epsilon: L-inf budget on the [0, 1] scale.
        num_step: PGD iterations per batch.
        alpha: PGD step size on the [0, 1] scale.
        dataset / batch_size / data_root: data configuration.
        use_loss_avg: average per-model CE losses across the ensemble.
        use_logit_avg: average logits, then take a single CE loss.
        device: device for data and models.
        workdir: directory receiving the results file.

    Raises:
        NotImplementedError: if neither ensembling mode is selected.
    """
    loader, normalizer = init_loader(dataset, data_root, 1, batch_size,
                                     'evaluate')
    surrogate_models = init_models(surrogate_model_ids, device)
    target_models = {
        k: v
        for k, v in zip(list_surrogates(),
                        init_models(list_surrogates(), device))
    }
    target_acc_meters = {
        target_model_id: [AverageMeter() for _ in range(2)]
        for target_model_id in target_models.keys()
    }

    enumerator = tqdm(enumerate(loader), total=len(loader), desc='')
    for step, (x_nat, y_nat) in enumerator:
        x_nat, y_nat = x_nat.to(device), y_nat.to(device)

        # Keep an untouched copy so each step can be projected back into
        # the epsilon-ball around the original input.
        # (.detach() replaces the deprecated .data; same shared storage.)
        x_nat_ori = x_nat.detach()
        for _ in range(num_step):
            x_nat.requires_grad = True
            if use_loss_avg:
                # Sum of per-model CE losses (gradient of the average up to
                # a constant factor absorbed by sign()).
                # NOTE(review): surrogates receive the raw, un-normalized
                # input here, while target models below are fed normalized
                # inputs -- confirm this asymmetry is intended.
                loss_all = 0.0
                for surrogate_model in surrogate_models:
                    logit = surrogate_model(x_nat)
                    surrogate_model.zero_grad()
                    loss_all += F.cross_entropy(logit, y_nat)
            elif use_logit_avg:
                # Average logits across the ensemble, single CE loss.
                logit = torch.stack([
                    surrogate_model(x_nat)
                    for surrogate_model in surrogate_models
                ]).mean(dim=0)
                loss_all = F.cross_entropy(logit, y_nat)
            else:
                # BUG FIX: previously raised NotADirectoryError (an OSError
                # subclass, clearly a typo); NotImplementedError matches
                # the equivalent branch in calc_loss.
                raise NotImplementedError
            loss_all.backward()
            # Signed-gradient ascent step, then project to the eps-ball
            # and clamp back into the valid image range.
            x_adv_ = x_nat + alpha * x_nat.grad.sign()
            eta = torch.clamp(x_adv_ - x_nat_ori, min=-epsilon, max=epsilon)
            x_nat = torch.clamp(x_nat_ori + eta, min=0.0, max=1.0).detach_()
        x_adv = x_nat
        x_nat = x_nat_ori

        with torch.no_grad():
            for target_model_id, target_model in target_models.items():
                logit_nat = target_model(normalizer(x_nat))
                logit_adv = target_model(normalizer(x_adv))

                target_acc_ = calc_cls_accuracy(logit_nat, y_nat)
                target_asr_ = calc_cls_accuracy(logit_adv, y_nat)
                target_acc_meters[target_model_id][0].update(
                    target_acc_[0].item(), x_nat.size(0))
                target_acc_meters[target_model_id][1].update(
                    target_asr_[0].item(), x_nat.size(0))
    results = {
        target_model_id: {
            'nat_acc': target_acc_meter[0].avg,
            'adv_acc': target_acc_meter[1].avg
        }
        for target_model_id, target_acc_meter in target_acc_meters.items()
    }
    print(pformat(results))
    with open(workdir / 'results-pgd.json', 'w') as f:
        json.dump(results, f)
|
|
|
|
|
|
|
|
|
def main() -> None:
    """CLI entry point: parse arguments and dispatch on the subcommand."""
    args = CLIParser.parse_args()
    command = args.command
    if command == 'train':
        train(args.surrogate_model_ids,
              epsilon=args.epsilon,
              num_epoch=args.num_epoch,
              dataset=args.dataset,
              batch_size=args.batch_size,
              use_sam=args.use_sam,
              lr=args.lr,
              betas=args.betas,
              use_logit_loss=args.use_logit_loss,
              use_logit_kl=args.use_logit_kl,
              use_logit_weights=args.use_logit_weights,
              use_logit_softmax_weights=args.use_logit_softmax_weights,
              use_feat_loss=args.use_feat_loss,
              use_feat_attn=args.use_feat_attn,
              device=args.device,
              workdir=args.workdir,
              data_root=args.data_root,
              tf_logger=args.tf_logger)
    elif command == 'evaluate':
        evaluate(args.ckpt,
                 epsilon=args.epsilon,
                 dataset=args.dataset,
                 batch_size=args.batch_size,
                 device=args.device,
                 workdir=args.workdir,
                 data_root=args.data_root)
    elif command == 'evaluate-pgd':
        evaluate_pgd(args.surrogate_model_ids,
                     epsilon=args.epsilon,
                     num_step=args.num_step,
                     alpha=args.alpha,
                     dataset=args.dataset,
                     batch_size=args.batch_size,
                     use_loss_avg=args.use_loss_avg,
                     use_logit_avg=args.use_logit_avg,
                     device=args.device,
                     workdir=args.workdir,
                     data_root=args.data_root)
    else:
        # Unknown / missing subcommand.
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
# Allow the file to be imported as a module without running the CLI.
if __name__ == '__main__':
    main()
|
|
|
|