|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy |
|
|
import torch.nn.functional as F |
|
|
import torch.nn as nn |
|
|
import torch |
|
|
from attacks import create_attack |
|
|
from attacks.utils import ctx_noparamgrad_and_eval |
|
|
from proard.utils import AverageMeter, get_net_device, DistributedTensor |
|
|
from proard.classification.elastic_nn.modules.dynamic_op import DynamicBatchNorm2d |
|
|
|
|
|
# Names exported by ``from <module> import *``.
__all__ = [
    "set_running_statistics",
]
|
|
|
|
|
def set_running_statistics(model, data_loader, distributed=False):
    """Re-estimate the BatchNorm running statistics of ``model`` from data.

    A deep copy of ``model`` is made, and every ``nn.BatchNorm2d`` layer in the
    copy that owns running buffers gets its ``forward`` replaced.  The patched
    forward normalizes with the current batch's per-channel mean and biased
    variance and feeds both statistics into an estimator.  Every batch from
    ``data_loader`` is run through the copy under ``torch.no_grad()`` with
    ``DynamicBatchNorm2d.SET_RUNNING_STATISTICS`` set to True.  The averaged
    estimates are then copied into the leading channels of the matching
    layers' ``running_mean`` / ``running_var`` buffers of the ORIGINAL
    ``model`` (in place).  The copy is discarded.

    Note that the resulting ``running_var`` is an average of per-batch biased
    variances.  It is not the variance over the whole dataset.

    Args:
        model: Network whose BatchNorm buffers are overwritten in place.
        data_loader: Iterable of batches.  Each batch is either the input
            tensor itself or a tuple/list whose first element is the input
            tensor (e.g. ``(images, labels)``).
        distributed: If True, statistics are accumulated in
            ``DistributedTensor`` objects named ``"<layer>#mean"`` /
            ``"<layer>#var"``; otherwise in ``AverageMeter`` objects.

    Returns:
        None.  Returns early, without touching ``data_loader``, when the model
        has no ``nn.BatchNorm2d`` layer with running statistics.
    """
    bn_mean = {}
    bn_var = {}

    def make_stat_forward(bn, mean_est, var_est):
        # Factory so each patched forward binds its own layer and estimators
        # (a closure created directly in the loop would late-bind them).
        def stat_forward(x):
            # Per-channel statistics over batch and spatial dims.  Reducing with
            # an explicit dim tuple always yields shape (C,); ``torch.squeeze``
            # would collapse a single-channel layer to a 0-D tensor.
            batch_mean = x.mean(dim=(0, 2, 3))
            centered = x - batch_mean.view(1, -1, 1, 1)
            batch_var = (centered * centered).mean(dim=(0, 2, 3))  # biased

            # Batch size is passed as the update weight; presumably the
            # estimators compute a weighted running average -- confirm in
            # AverageMeter / DistributedTensor.
            mean_est.update(batch_mean.detach(), x.size(0))
            var_est.update(batch_var.detach(), x.size(0))

            # The input may carry fewer channels than the layer was built
            # with, so the affine parameters are sliced to match.  Layers
            # built with affine=False have no weight/bias at all.
            feature_dim = batch_mean.size(0)
            weight = bn.weight[:feature_dim] if bn.weight is not None else None
            bias = bn.bias[:feature_dim] if bn.bias is not None else None
            return F.batch_norm(
                x,
                batch_mean,
                batch_var,
                weight,
                bias,
                False,
                0.0,
                bn.eps,
            )

        return stat_forward

    forward_model = copy.deepcopy(model)
    for name, m in forward_model.named_modules():
        # Layers without running buffers (track_running_stats=False) have
        # nothing to recalibrate; leave their forward untouched.
        if not isinstance(m, nn.BatchNorm2d) or m.running_mean is None:
            continue
        if distributed:
            bn_mean[name] = DistributedTensor(name + "#mean")
            bn_var[name] = DistributedTensor(name + "#var")
        else:
            bn_mean[name] = AverageMeter()
            bn_var[name] = AverageMeter()
        m.forward = make_stat_forward(m, bn_mean[name], bn_var[name])

    if not bn_mean:
        # No BatchNorm layers to recalibrate.
        return

    with torch.no_grad():
        # Class-level flag shared by every DynamicBatchNorm2d instance; reset it
        # even if a batch raises so later code does not see it stuck at True.
        DynamicBatchNorm2d.SET_RUNNING_STATISTICS = True
        try:
            for batch in data_loader:
                images = batch[0] if isinstance(batch, (list, tuple)) else batch
                forward_model(images.to(get_net_device(forward_model)))
        finally:
            DynamicBatchNorm2d.SET_RUNNING_STATISTICS = False

    # Write the averaged estimates into the original model's buffers.  Module
    # names match because forward_model is a deep copy of model.
    for name, m in model.named_modules():
        if name in bn_mean and bn_mean[name].count > 0:
            assert isinstance(m, nn.BatchNorm2d)
            mean_avg = bn_mean[name].avg
            var_avg = bn_var[name].avg
            feature_dim = mean_avg.size(0)
            # Only the leading channels that were actually observed are updated.
            m.running_mean.data[:feature_dim].copy_(mean_avg)
            m.running_var.data[:feature_dim].copy_(var_avg)
|
|
|