# Once for All: Train One Network and Specialize it for Efficient Deployment # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han # International Conference on Learning Representations (ICLR), 2020. from proard.utils import calc_learning_rate, build_optimizer from proard.classification.data_providers import ImagenetDataProvider from proard.classification.data_providers import Cifar10DataProvider from proard.classification.data_providers import Cifar100DataProvider from robust_loss.trades import trades_loss from robust_loss.adaad import adaad_loss from robust_loss.ard import ard_loss from robust_loss.hat import hat_loss from robust_loss.mart import mart_loss from robust_loss.sat import sat_loss from robust_loss.rslad import rslad_loss import torch __all__ = ["RunConfig", "ClassificationRunConfig", "DistributedClassificationRunConfig"] class RunConfig: def __init__( self, n_epochs, init_lr, lr_schedule_type, lr_schedule_param, dataset, train_batch_size, test_batch_size, valid_size, opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys, mixup_alpha, model_init, validation_frequency, print_frequency, ): self.n_epochs = n_epochs self.init_lr = init_lr self.lr_schedule_type = lr_schedule_type self.lr_schedule_param = lr_schedule_param self.dataset = dataset self.train_batch_size = train_batch_size self.test_batch_size = test_batch_size self.valid_size = valid_size self.opt_type = opt_type self.opt_param = opt_param self.weight_decay = weight_decay self.label_smoothing = label_smoothing self.no_decay_keys = no_decay_keys self.mixup_alpha = mixup_alpha self.model_init = model_init self.validation_frequency = validation_frequency self.print_frequency = print_frequency @property def config(self): config = {} for key in self.__dict__: if not key.startswith("_"): config[key] = self.__dict__[key] return config def copy(self): return RunConfig(**self.config) """ learning rate """ def adjust_learning_rate(self, optimizer, epoch, batch=0, nBatch=None): """adjust learning of a given optimizer and return the new learning rate""" new_lr = calc_learning_rate( epoch, self.init_lr, self.n_epochs, batch, nBatch, self.lr_schedule_type ) for param_group in optimizer.param_groups: param_group["lr"] = new_lr return new_lr def warmup_adjust_learning_rate( self, optimizer, T_total, nBatch, epoch, batch=0, warmup_lr=0 ): T_cur = epoch * nBatch + batch + 1 new_lr = T_cur / T_total * (self.init_lr - warmup_lr) + warmup_lr for param_group in optimizer.param_groups: param_group["lr"] = new_lr return new_lr """ data provider """ @property def data_provider(self): raise NotImplementedError @property def train_loader(self): return self.data_provider.train @property def valid_loader(self): return self.data_provider.valid @property def test_loader(self): return self.data_provider.test def random_sub_train_loader( self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None ): return self.data_provider.build_sub_train_loader( n_images, batch_size, num_worker, num_replicas, rank ) """ optimizer """ def build_optimizer(self, net_params): return build_optimizer( net_params, self.opt_type, self.opt_param, self.init_lr, self.weight_decay, self.no_decay_keys, ) class ClassificationRunConfig(RunConfig): def __init__( self, n_epochs=150, init_lr=0.05, lr_schedule_type="cosine", lr_schedule_param=None, dataset="imagenet", # 'cifar10' or 'cifar100' train_batch_size=256, test_batch_size=500, valid_size=None, opt_type="sgd", opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys=None, mixup_alpha=None, model_init="he_fout", validation_frequency=1, print_frequency=10, n_worker=32, resize_scale=0.08, distort_color="tf", image_size=224, # 32 robust_mode = False, epsilon_train = 0.031, num_steps_train = 10, step_size_train = 0.0078, clip_min_train = 0 , clip_max_train = 1, const_init_train = False, beta_train = 6.0, distance_train ="l_inf", epsilon_test = 0.031, num_steps_test = 20, step_size_test = 0.0078, clip_min_test = 0, clip_max_test = 1, const_init_test = False, beta_test = 6.0, distance_test = "l_inf", train_criterion = "trades", test_criterion = "ce", kd_criterion = 'rslad', attack_type = "linf-pgd", **kwargs ): super(ClassificationRunConfig, self).__init__( n_epochs, init_lr, lr_schedule_type, lr_schedule_param, dataset, train_batch_size, test_batch_size, valid_size, opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys, mixup_alpha, model_init, validation_frequency, print_frequency, ) self.n_worker = n_worker self.resize_scale = resize_scale self.distort_color = distort_color self.image_size = image_size self.epsilon_train = epsilon_train self.num_steps_train = num_steps_train self.step_size_train = step_size_train self.clip_min_train = clip_min_train self.clip_max_train = clip_max_train self.const_init_train = const_init_train self.beta_train = beta_train self.distance_train = distance_train self.epsilon_test = epsilon_test self.num_steps_test = num_steps_test self.step_size_test = step_size_test self.clip_min_test = clip_min_test self.clip_max_test = clip_max_test self.const_init_test = const_init_test self.beta_test = beta_test self.distance_test = distance_test self.train_criterion = train_criterion self.test_criterion = test_criterion self.kd_criterion = kd_criterion self.attack_type = attack_type self.robust_mode = robust_mode @property def data_provider(self): if self.__dict__.get("_data_provider", None) is None: if self.dataset == ImagenetDataProvider.name(): DataProviderClass = ImagenetDataProvider elif self.dataset == Cifar10DataProvider.name(): DataProviderClass = Cifar10DataProvider elif self.dataset == Cifar100DataProvider.name(): DataProviderClass = Cifar100DataProvider else: raise NotImplementedError self.__dict__["_data_provider"] = DataProviderClass( train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size, valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale, distort_color=self.distort_color, image_size=self.image_size, ) return self.__dict__["_data_provider"] @property def train_criterion_loss (self): if self.train_criterion == "trades" : return trades_loss elif self.train_criterion == "mart" : return mart_loss elif self.train_criterion == "sat" : return sat_loss elif self.train_criterion == "hat" : return hat_loss @property def test_criterion_loss (self) : if self.test_criterion == "ce" : return torch.nn.CrossEntropyLoss() @property def kd_criterion_loss (self) : if self.kd_criterion =="ard" : return ard_loss elif self.kd_criterion == "adaad" : return adaad_loss elif self.kd_criterion == "rslad" : return rslad_loss class DistributedClassificationRunConfig(ClassificationRunConfig): def __init__( self, n_epochs=150, init_lr=0.05, lr_schedule_type="cosine", lr_schedule_param=None, dataset="imagenet", train_batch_size=64, test_batch_size=64, valid_size=None, opt_type="sgd", opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys=None, mixup_alpha=None, model_init="he_fout", validation_frequency=1, print_frequency=10, n_worker=8, resize_scale=0.08, distort_color="tf", image_size=224, robust_mode = False, epsilon = 0.031, num_steps = 10, step_size = 0.0078, clip_min = 0, clip_max = 1, const_init = False, beta = 6.0, distance = "l_inf", train_criterion = "trades", test_criterion = "ce", kd_criterion = 'rslad', attack_type = "linf-pgd", **kwargs ): super(DistributedClassificationRunConfig, self).__init__( n_epochs, init_lr, lr_schedule_type, lr_schedule_param, dataset, train_batch_size, test_batch_size, valid_size, opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys, mixup_alpha, model_init, validation_frequency, print_frequency, n_worker, resize_scale, distort_color, image_size, robust_mode, epsilon, num_steps, step_size, clip_min, clip_max, const_init, beta, distance, epsilon, num_steps * 2, step_size, clip_min,clip_max, const_init, beta, distance, train_criterion, test_criterion, kd_criterion, attack_type, **kwargs ) self._num_replicas = kwargs["num_replicas"] self._rank = kwargs["rank"] @property def data_provider(self): if self.__dict__.get("_data_provider", None) is None: if self.dataset == ImagenetDataProvider.name(): DataProviderClass = ImagenetDataProvider elif self.dataset == Cifar10DataProvider.name(): DataProviderClass = Cifar10DataProvider elif self.dataset == Cifar100DataProvider.name(): DataProviderClass = Cifar100DataProvider else: raise NotImplementedError if self.dataset == "imagenet": self.__dict__["_data_provider"] = DataProviderClass( train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size, valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale, distort_color=self.distort_color, image_size=self.image_size, num_replicas=self._num_replicas, rank=self._rank, ) else: self.__dict__["_data_provider"] = DataProviderClass( train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size, valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=None, distort_color=None, image_size=self.image_size, num_replicas=self._num_replicas, rank=self._rank, ) return self.__dict__["_data_provider"] @property def train_criterion_loss (self): if self.train_criterion == "trades" : return trades_loss elif self.train_criterion == "mart" : return mart_loss elif self.train_criterion == "sat" : return sat_loss elif self.train_criterion == "hat" : return hat_loss @property def test_criterion_loss (self) : if self.test_criterion == "ce" : return torch.nn.CrossEntropyLoss() @property def kd_criterion_loss (self) : if self.kd_criterion =="ard" : return ard_loss elif self.kd_criterion == "adaad" : return adaad_loss elif self.kd_criterion == "rslad" : return rslad_loss