# smi08's picture
# Upload folder using huggingface_hub
# 188f311 verified
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import copy
import torch.nn.functional as F
import torch.nn as nn
import torch
from attacks import create_attack
from attacks.utils import ctx_noparamgrad_and_eval
from proard.utils import AverageMeter, get_net_device, DistributedTensor
from proard.classification.elastic_nn.modules.dynamic_op import DynamicBatchNorm2d
# Public API: the only name exported by a star-import of this module.
__all__ = ["set_running_statistics"]
def set_running_statistics(model, data_loader, distributed=False):
    """Re-estimate the BatchNorm2d running statistics of ``model`` from data.

    A deep copy of ``model`` is run over ``data_loader`` with the forward of
    every ``nn.BatchNorm2d`` module replaced by a version that normalizes with
    the statistics of the current batch and records them. The accumulated
    per-channel mean and variance are then copied into ``running_mean`` and
    ``running_var`` of the same-named modules of the original ``model``.

    Args:
        model: Network whose BatchNorm2d ``running_mean`` / ``running_var``
            buffers are overwritten in place.
        data_loader: Iterable yielding ``(images, labels)`` pairs. ``images``
            is moved to the model's device with ``.to(...)``; ``labels`` is
            ignored.
        distributed: When true, statistics are accumulated in
            ``DistributedTensor`` objects instead of ``AverageMeter`` objects.

    Returns:
        None. Returns early, without touching ``data_loader``, when the model
        contains no ``nn.BatchNorm2d`` module.
    """

    def make_stat_forward(bn, mean_est, var_est):
        """Build a forward for ``bn`` that records and uses batch statistics."""

        def stat_forward(x):
            # Reduce over batch and both spatial axes, one axis at a time.
            batch_mean = (
                x.mean(0, keepdim=True)
                .mean(2, keepdim=True)
                .mean(3, keepdim=True)
            )  # 1, C, 1, 1
            centered = x - batch_mean
            batch_var = (
                (centered * centered)
                .mean(0, keepdim=True)
                .mean(2, keepdim=True)
                .mean(3, keepdim=True)
            )  # 1, C, 1, 1

            # Flatten to shape (C,). A bare torch.squeeze() would also drop
            # the channel axis when C == 1, leaving a 0-dim tensor on which
            # size(0) fails.
            batch_mean = batch_mean.reshape(-1)
            batch_var = batch_var.reshape(-1)

            mean_est.update(batch_mean.data, x.size(0))
            var_est.update(batch_var.data, x.size(0))

            # bn forward using calculated mean & var; the affine parameters
            # are sliced to the number of channels actually present in x.
            feature_dim = batch_mean.size(0)
            return F.batch_norm(
                x,
                batch_mean,
                batch_var,
                bn.weight[:feature_dim],
                bn.bias[:feature_dim],
                False,
                0.0,
                bn.eps,
            )

        return stat_forward

    bn_mean = {}
    bn_var = {}

    # Work on a copy so the patched forwards never leak into ``model``.
    forward_model = copy.deepcopy(model)
    for name, m in forward_model.named_modules():
        if not isinstance(m, nn.BatchNorm2d):
            continue
        if distributed:
            bn_mean[name] = DistributedTensor(name + "#mean")
            bn_var[name] = DistributedTensor(name + "#var")
        else:
            bn_mean[name] = AverageMeter()
            bn_var[name] = AverageMeter()
        m.forward = make_stat_forward(m, bn_mean[name], bn_var[name])

    if not bn_mean:
        # skip if there is no batch normalization layers in the network
        return

    with torch.no_grad():
        DynamicBatchNorm2d.SET_RUNNING_STATISTICS = True
        try:
            for images, _ in data_loader:
                images = images.to(get_net_device(forward_model))
                forward_model(images)
        finally:
            # The flag is class-level state: always clear it, even when a
            # forward pass or the data loader raises.
            DynamicBatchNorm2d.SET_RUNNING_STATISTICS = False

    for name, m in model.named_modules():
        if name in bn_mean and bn_mean[name].count > 0:
            mean_avg = bn_mean[name].avg
            feature_dim = mean_avg.size(0)
            assert isinstance(m, nn.BatchNorm2d)
            # Only the first ``feature_dim`` channels were observed.
            m.running_mean.data[:feature_dim].copy_(mean_avg)
            m.running_var.data[:feature_dim].copy_(bn_var[name].avg)