|
|
|
|
|
|
|
|
|
|
|
|
|
|
from proard.utils import calc_learning_rate, build_optimizer |
|
|
from proard.classification.data_providers import ImagenetDataProvider |
|
|
from proard.classification.data_providers import Cifar10DataProvider |
|
|
from proard.classification.data_providers import Cifar100DataProvider |
|
|
from robust_loss.trades import trades_loss |
|
|
from robust_loss.adaad import adaad_loss |
|
|
from robust_loss.ard import ard_loss |
|
|
from robust_loss.hat import hat_loss |
|
|
from robust_loss.mart import mart_loss |
|
|
from robust_loss.sat import sat_loss |
|
|
from robust_loss.rslad import rslad_loss |
|
|
import torch |
|
|
__all__ = ["RunConfig", "ClassificationRunConfig", "DistributedClassificationRunConfig"] |
|
|
|
|
|
|
|
|
class RunConfig: |
|
|
def __init__( |
|
|
self, |
|
|
n_epochs, |
|
|
init_lr, |
|
|
lr_schedule_type, |
|
|
lr_schedule_param, |
|
|
dataset, |
|
|
train_batch_size, |
|
|
test_batch_size, |
|
|
valid_size, |
|
|
opt_type, |
|
|
opt_param, |
|
|
weight_decay, |
|
|
label_smoothing, |
|
|
no_decay_keys, |
|
|
mixup_alpha, |
|
|
model_init, |
|
|
validation_frequency, |
|
|
print_frequency, |
|
|
): |
|
|
self.n_epochs = n_epochs |
|
|
self.init_lr = init_lr |
|
|
self.lr_schedule_type = lr_schedule_type |
|
|
self.lr_schedule_param = lr_schedule_param |
|
|
|
|
|
self.dataset = dataset |
|
|
self.train_batch_size = train_batch_size |
|
|
self.test_batch_size = test_batch_size |
|
|
self.valid_size = valid_size |
|
|
|
|
|
self.opt_type = opt_type |
|
|
self.opt_param = opt_param |
|
|
self.weight_decay = weight_decay |
|
|
self.label_smoothing = label_smoothing |
|
|
self.no_decay_keys = no_decay_keys |
|
|
|
|
|
self.mixup_alpha = mixup_alpha |
|
|
|
|
|
self.model_init = model_init |
|
|
self.validation_frequency = validation_frequency |
|
|
self.print_frequency = print_frequency |
|
|
|
|
|
@property |
|
|
def config(self): |
|
|
config = {} |
|
|
for key in self.__dict__: |
|
|
if not key.startswith("_"): |
|
|
config[key] = self.__dict__[key] |
|
|
return config |
|
|
|
|
|
def copy(self): |
|
|
return RunConfig(**self.config) |
|
|
|
|
|
""" learning rate """ |
|
|
|
|
|
def adjust_learning_rate(self, optimizer, epoch, batch=0, nBatch=None): |
|
|
"""adjust learning of a given optimizer and return the new learning rate""" |
|
|
new_lr = calc_learning_rate( |
|
|
epoch, self.init_lr, self.n_epochs, batch, nBatch, self.lr_schedule_type |
|
|
) |
|
|
for param_group in optimizer.param_groups: |
|
|
param_group["lr"] = new_lr |
|
|
return new_lr |
|
|
|
|
|
def warmup_adjust_learning_rate( |
|
|
self, optimizer, T_total, nBatch, epoch, batch=0, warmup_lr=0 |
|
|
): |
|
|
T_cur = epoch * nBatch + batch + 1 |
|
|
new_lr = T_cur / T_total * (self.init_lr - warmup_lr) + warmup_lr |
|
|
for param_group in optimizer.param_groups: |
|
|
param_group["lr"] = new_lr |
|
|
return new_lr |
|
|
|
|
|
""" data provider """ |
|
|
|
|
|
@property |
|
|
def data_provider(self): |
|
|
raise NotImplementedError |
|
|
|
|
|
@property |
|
|
def train_loader(self): |
|
|
return self.data_provider.train |
|
|
|
|
|
@property |
|
|
def valid_loader(self): |
|
|
return self.data_provider.valid |
|
|
|
|
|
@property |
|
|
def test_loader(self): |
|
|
return self.data_provider.test |
|
|
|
|
|
def random_sub_train_loader( |
|
|
self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None |
|
|
): |
|
|
return self.data_provider.build_sub_train_loader( |
|
|
n_images, batch_size, num_worker, num_replicas, rank |
|
|
) |
|
|
|
|
|
""" optimizer """ |
|
|
|
|
|
def build_optimizer(self, net_params): |
|
|
return build_optimizer( |
|
|
net_params, |
|
|
self.opt_type, |
|
|
self.opt_param, |
|
|
self.init_lr, |
|
|
self.weight_decay, |
|
|
self.no_decay_keys, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
class ClassificationRunConfig(RunConfig): |
|
|
def __init__( |
|
|
self, |
|
|
n_epochs=150, |
|
|
init_lr=0.05, |
|
|
lr_schedule_type="cosine", |
|
|
lr_schedule_param=None, |
|
|
dataset="imagenet", |
|
|
train_batch_size=256, |
|
|
test_batch_size=500, |
|
|
valid_size=None, |
|
|
opt_type="sgd", |
|
|
opt_param=None, |
|
|
weight_decay=4e-5, |
|
|
label_smoothing=0.1, |
|
|
no_decay_keys=None, |
|
|
mixup_alpha=None, |
|
|
model_init="he_fout", |
|
|
validation_frequency=1, |
|
|
print_frequency=10, |
|
|
n_worker=32, |
|
|
resize_scale=0.08, |
|
|
distort_color="tf", |
|
|
image_size=224, |
|
|
robust_mode = False, |
|
|
epsilon_train = 0.031, |
|
|
num_steps_train = 10, |
|
|
step_size_train = 0.0078, |
|
|
clip_min_train = 0 , |
|
|
clip_max_train = 1, |
|
|
const_init_train = False, |
|
|
beta_train = 6.0, |
|
|
distance_train ="l_inf", |
|
|
epsilon_test = 0.031, |
|
|
num_steps_test = 20, |
|
|
step_size_test = 0.0078, |
|
|
clip_min_test = 0, |
|
|
clip_max_test = 1, |
|
|
const_init_test = False, |
|
|
beta_test = 6.0, |
|
|
distance_test = "l_inf", |
|
|
train_criterion = "trades", |
|
|
test_criterion = "ce", |
|
|
kd_criterion = 'rslad', |
|
|
attack_type = "linf-pgd", |
|
|
**kwargs |
|
|
): |
|
|
super(ClassificationRunConfig, self).__init__( |
|
|
n_epochs, |
|
|
init_lr, |
|
|
lr_schedule_type, |
|
|
lr_schedule_param, |
|
|
dataset, |
|
|
train_batch_size, |
|
|
test_batch_size, |
|
|
valid_size, |
|
|
opt_type, |
|
|
opt_param, |
|
|
weight_decay, |
|
|
label_smoothing, |
|
|
no_decay_keys, |
|
|
mixup_alpha, |
|
|
model_init, |
|
|
validation_frequency, |
|
|
print_frequency, |
|
|
) |
|
|
|
|
|
self.n_worker = n_worker |
|
|
self.resize_scale = resize_scale |
|
|
self.distort_color = distort_color |
|
|
self.image_size = image_size |
|
|
self.epsilon_train = epsilon_train |
|
|
self.num_steps_train = num_steps_train |
|
|
self.step_size_train = step_size_train |
|
|
self.clip_min_train = clip_min_train |
|
|
self.clip_max_train = clip_max_train |
|
|
self.const_init_train = const_init_train |
|
|
self.beta_train = beta_train |
|
|
self.distance_train = distance_train |
|
|
self.epsilon_test = epsilon_test |
|
|
self.num_steps_test = num_steps_test |
|
|
self.step_size_test = step_size_test |
|
|
self.clip_min_test = clip_min_test |
|
|
self.clip_max_test = clip_max_test |
|
|
self.const_init_test = const_init_test |
|
|
self.beta_test = beta_test |
|
|
self.distance_test = distance_test |
|
|
self.train_criterion = train_criterion |
|
|
self.test_criterion = test_criterion |
|
|
self.kd_criterion = kd_criterion |
|
|
self.attack_type = attack_type |
|
|
self.robust_mode = robust_mode |
|
|
@property |
|
|
def data_provider(self): |
|
|
if self.__dict__.get("_data_provider", None) is None: |
|
|
if self.dataset == ImagenetDataProvider.name(): |
|
|
DataProviderClass = ImagenetDataProvider |
|
|
elif self.dataset == Cifar10DataProvider.name(): |
|
|
DataProviderClass = Cifar10DataProvider |
|
|
elif self.dataset == Cifar100DataProvider.name(): |
|
|
DataProviderClass = Cifar100DataProvider |
|
|
else: |
|
|
raise NotImplementedError |
|
|
self.__dict__["_data_provider"] = DataProviderClass( |
|
|
train_batch_size=self.train_batch_size, |
|
|
test_batch_size=self.test_batch_size, |
|
|
valid_size=self.valid_size, |
|
|
n_worker=self.n_worker, |
|
|
resize_scale=self.resize_scale, |
|
|
distort_color=self.distort_color, |
|
|
image_size=self.image_size, |
|
|
) |
|
|
return self.__dict__["_data_provider"] |
|
|
@property |
|
|
def train_criterion_loss (self): |
|
|
if self.train_criterion == "trades" : |
|
|
return trades_loss |
|
|
elif self.train_criterion == "mart" : |
|
|
return mart_loss |
|
|
elif self.train_criterion == "sat" : |
|
|
return sat_loss |
|
|
elif self.train_criterion == "hat" : |
|
|
return hat_loss |
|
|
@property |
|
|
def test_criterion_loss (self) : |
|
|
if self.test_criterion == "ce" : |
|
|
return torch.nn.CrossEntropyLoss() |
|
|
@property |
|
|
def kd_criterion_loss (self) : |
|
|
if self.kd_criterion =="ard" : |
|
|
return ard_loss |
|
|
elif self.kd_criterion == "adaad" : |
|
|
return adaad_loss |
|
|
elif self.kd_criterion == "rslad" : |
|
|
return rslad_loss |
|
|
class DistributedClassificationRunConfig(ClassificationRunConfig): |
|
|
def __init__( |
|
|
self, |
|
|
n_epochs=150, |
|
|
init_lr=0.05, |
|
|
lr_schedule_type="cosine", |
|
|
lr_schedule_param=None, |
|
|
dataset="imagenet", |
|
|
train_batch_size=64, |
|
|
test_batch_size=64, |
|
|
valid_size=None, |
|
|
opt_type="sgd", |
|
|
opt_param=None, |
|
|
weight_decay=4e-5, |
|
|
label_smoothing=0.1, |
|
|
no_decay_keys=None, |
|
|
mixup_alpha=None, |
|
|
model_init="he_fout", |
|
|
validation_frequency=1, |
|
|
print_frequency=10, |
|
|
n_worker=8, |
|
|
resize_scale=0.08, |
|
|
distort_color="tf", |
|
|
image_size=224, |
|
|
robust_mode = False, |
|
|
epsilon = 0.031, |
|
|
num_steps = 10, |
|
|
step_size = 0.0078, |
|
|
clip_min = 0, |
|
|
clip_max = 1, |
|
|
const_init = False, |
|
|
beta = 6.0, |
|
|
distance = "l_inf", |
|
|
train_criterion = "trades", |
|
|
test_criterion = "ce", |
|
|
kd_criterion = 'rslad', |
|
|
attack_type = "linf-pgd", |
|
|
**kwargs |
|
|
): |
|
|
super(DistributedClassificationRunConfig, self).__init__( |
|
|
n_epochs, |
|
|
init_lr, |
|
|
lr_schedule_type, |
|
|
lr_schedule_param, |
|
|
dataset, |
|
|
train_batch_size, |
|
|
test_batch_size, |
|
|
valid_size, |
|
|
opt_type, |
|
|
opt_param, |
|
|
weight_decay, |
|
|
label_smoothing, |
|
|
no_decay_keys, |
|
|
mixup_alpha, |
|
|
model_init, |
|
|
validation_frequency, |
|
|
print_frequency, |
|
|
n_worker, |
|
|
resize_scale, |
|
|
distort_color, |
|
|
image_size, |
|
|
robust_mode, |
|
|
epsilon, |
|
|
num_steps, |
|
|
step_size, |
|
|
clip_min, |
|
|
clip_max, |
|
|
const_init, |
|
|
beta, |
|
|
distance, |
|
|
epsilon, |
|
|
num_steps * 2, |
|
|
step_size, |
|
|
clip_min,clip_max, |
|
|
const_init, |
|
|
beta, |
|
|
distance, |
|
|
train_criterion, |
|
|
test_criterion, |
|
|
kd_criterion, |
|
|
attack_type, |
|
|
**kwargs |
|
|
) |
|
|
|
|
|
self._num_replicas = kwargs["num_replicas"] |
|
|
self._rank = kwargs["rank"] |
|
|
|
|
|
@property |
|
|
def data_provider(self): |
|
|
if self.__dict__.get("_data_provider", None) is None: |
|
|
if self.dataset == ImagenetDataProvider.name(): |
|
|
DataProviderClass = ImagenetDataProvider |
|
|
elif self.dataset == Cifar10DataProvider.name(): |
|
|
DataProviderClass = Cifar10DataProvider |
|
|
elif self.dataset == Cifar100DataProvider.name(): |
|
|
DataProviderClass = Cifar100DataProvider |
|
|
else: |
|
|
raise NotImplementedError |
|
|
if self.dataset == "imagenet": |
|
|
self.__dict__["_data_provider"] = DataProviderClass( |
|
|
train_batch_size=self.train_batch_size, |
|
|
test_batch_size=self.test_batch_size, |
|
|
valid_size=self.valid_size, |
|
|
n_worker=self.n_worker, |
|
|
resize_scale=self.resize_scale, |
|
|
distort_color=self.distort_color, |
|
|
image_size=self.image_size, |
|
|
num_replicas=self._num_replicas, |
|
|
rank=self._rank, |
|
|
) |
|
|
else: |
|
|
self.__dict__["_data_provider"] = DataProviderClass( |
|
|
train_batch_size=self.train_batch_size, |
|
|
test_batch_size=self.test_batch_size, |
|
|
valid_size=self.valid_size, |
|
|
n_worker=self.n_worker, |
|
|
resize_scale=None, |
|
|
distort_color=None, |
|
|
image_size=self.image_size, |
|
|
num_replicas=self._num_replicas, |
|
|
rank=self._rank, |
|
|
) |
|
|
return self.__dict__["_data_provider"] |
|
|
@property |
|
|
def train_criterion_loss (self): |
|
|
if self.train_criterion == "trades" : |
|
|
return trades_loss |
|
|
elif self.train_criterion == "mart" : |
|
|
return mart_loss |
|
|
elif self.train_criterion == "sat" : |
|
|
return sat_loss |
|
|
elif self.train_criterion == "hat" : |
|
|
return hat_loss |
|
|
@property |
|
|
def test_criterion_loss (self) : |
|
|
if self.test_criterion == "ce" : |
|
|
return torch.nn.CrossEntropyLoss() |
|
|
@property |
|
|
def kd_criterion_loss (self) : |
|
|
if self.kd_criterion =="ard" : |
|
|
return ard_loss |
|
|
elif self.kd_criterion == "adaad" : |
|
|
return adaad_loss |
|
|
elif self.kd_criterion == "rslad" : |
|
|
return rslad_loss |
|
|
|
|
|
|