| |
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
|
|
|
|
class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It contains non-trainable buffers called
    "weight" and "bias", "running_mean", "running_var",
    initialized to perform identity transformation.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    are filled with identity values (zeros and `1 - eps`), so the module
    computes exactly `x * weight + bias`.

    Other pre-trained backbone models may contain all 4 parameters.

    When the input does not require grad, the forward is implemented by
    `F.batch_norm(..., training=False)`; otherwise the same computation is
    written out as an explicit per-channel affine transform.

    Args:
        num_features: number of channels; every buffer has this length.
        eps: value added to "running_var" before taking the reciprocal
            square root.
    """

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        # forward() always adds eps to running_var, so start at 1 - eps to
        # make the freshly constructed module an identity transform.
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        """
        Normalize `x` with the frozen statistics and affine buffers.

        The buffers are broadcast as (1, C, 1, 1), i.e. channels are taken
        from dimension 1 of `x`. The result has the dtype of `x` in the
        grad-enabled branch.
        """
        if x.requires_grad:
            # Fold the four buffers into one per-channel scale and shift and
            # apply them with plain tensor ops.
            # NOTE(review): presumably this avoids F.batch_norm's backward
            # doing work for the frozen weight/bias -- confirm.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            # Cast to the input dtype, which may differ from the buffers'.
            out_dtype = x.dtype
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # No gradient needed for x: use the single built-in op in
            # inference mode (training=False).
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        Adapt `state_dict` in place before delegating to the base loader.

        * A "num_batches_tracked" entry is removed, since this module
          registers no such buffer.
        * For checkpoints without a version, or with version < 2, missing
          "running_mean" / "running_var" entries are filled with identity
          values so they are not reported as missing.
        """
        state_dict.pop(prefix + "num_batches_tracked", None)
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # Early checkpoints carry only "weight" and "bias"; supply
            # identity statistics for the two absent buffers.
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                # Must be 1 - eps (as in __init__), not 1: forward() uses
                # running_var + eps, so plain ones would scale the output
                # by 1 / sqrt(1 + eps) instead of leaving it untouched.
                state_dict[prefix + "running_var"] = (
                    torch.ones_like(self.running_var) - self.eps
                )

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)
|
|