# Once for All: Train One Network and Specialize it for Efficient Deployment # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han # International Conference on Learning Representations (ICLR), 2020. import copy import torch.nn.functional as F import torch.nn as nn import torch from attacks import create_attack from attacks.utils import ctx_noparamgrad_and_eval from proard.utils import AverageMeter, get_net_device, DistributedTensor from proard.classification.elastic_nn.modules.dynamic_op import DynamicBatchNorm2d __all__ = ["set_running_statistics"] def set_running_statistics(model, data_loader, distributed=False): bn_mean = {} bn_var = {} forward_model = copy.deepcopy(model) for name, m in forward_model.named_modules(): if isinstance(m, nn.BatchNorm2d): if distributed: bn_mean[name] = DistributedTensor(name + "#mean") bn_var[name] = DistributedTensor(name + "#var") else: bn_mean[name] = AverageMeter() bn_var[name] = AverageMeter() def new_forward(bn, mean_est, var_est): def lambda_forward(x): batch_mean = ( x.mean(0, keepdim=True) .mean(2, keepdim=True) .mean(3, keepdim=True) ) # 1, C, 1, 1 batch_var = (x - batch_mean) * (x - batch_mean) batch_var = ( batch_var.mean(0, keepdim=True) .mean(2, keepdim=True) .mean(3, keepdim=True) ) batch_mean = torch.squeeze(batch_mean) batch_var = torch.squeeze(batch_var) mean_est.update(batch_mean.data, x.size(0)) var_est.update(batch_var.data, x.size(0)) # bn forward using calculated mean & var _feature_dim = batch_mean.size(0) return F.batch_norm( x, batch_mean, batch_var, bn.weight[:_feature_dim], bn.bias[:_feature_dim], False, 0.0, bn.eps, ) return lambda_forward m.forward = new_forward(m, bn_mean[name], bn_var[name]) if len(bn_mean) == 0: # skip if there is no batch normalization layers in the network return with torch.no_grad(): DynamicBatchNorm2d.SET_RUNNING_STATISTICS = True for images, labels in data_loader: images = images.to(get_net_device(forward_model)) forward_model(images) DynamicBatchNorm2d.SET_RUNNING_STATISTICS = False for name, m in model.named_modules(): if name in bn_mean and bn_mean[name].count > 0: feature_dim = bn_mean[name].avg.size(0) assert isinstance(m, nn.BatchNorm2d) m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg) m.running_var.data[:feature_dim].copy_(bn_var[name].avg)