# Copyright (c) Facebook, Inc. and its affiliates.
import torch
import torch.distributed as dist
from fvcore.nn.distributed import differentiable_all_reduce
from torch import nn
from torch.nn import functional as F
from .wrappers import BatchNorm2d
class FrozenBatchNorm2d(nn.Module):
    """
    A BatchNorm2d whose batch statistics and affine parameters are fixed.

    The four tensors "weight", "bias", "running_mean" and "running_var" are
    registered as buffers (not parameters), and their initial values make the
    layer an identity transformation.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    For those models the affine transform `x * weight + bias` performs the
    equivalent computation of
    `(x - running_mean) / sqrt(running_var) * weight + bias`, and
    "running_mean" / "running_var" are left as the identity transformation.

    Other pre-trained backbone models may contain all 4 parameters.

    The forward uses `F.batch_norm(..., training=False)` when the input does not
    require gradients, and an explicit scale-and-shift otherwise.
    """

    _version = 3

    def __init__(self, num_features, eps=1e-5):
        """
        Args:
            num_features: length of each of the four per-channel buffers.
            eps: value added to "running_var" before taking the inverse square root.
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        # eps is added back in forward(), so the effective variance starts at 1.
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        """
        Apply the frozen normalization to `x`.

        The affine terms are reshaped to (1, C, 1, 1) on the gradient path,
        so `x` is expected to carry its channels in dimension 1.
        """
        if not x.requires_grad:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provides more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

        # When gradients are needed, F.batch_norm would use extra memory because
        # its backward op computes gradients for weight/bias as well, so the
        # statistics are folded into one multiplier and one offset instead.
        gain = self.weight * (self.running_var + self.eps).rsqrt()
        shift = self.bias - self.running_mean * gain
        target_dtype = x.dtype  # may be half
        gain = gain.reshape(1, -1, 1, 1).to(target_dtype)
        shift = shift.reshape(1, -1, 1, 1).to(target_dtype)
        return x * gain + shift

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        Load buffers, filling in running statistics that old checkpoints lack.

        Checkpoints with no version, or a version below 2, have no
        running_mean/var. Missing entries are inserted into `state_dict`
        (zeros for the mean, ones for the variance) so that loading does not
        report them as missing keys.
        """
        version = local_metadata.get("version", None)
        is_legacy = version is None or version < 2

        if is_legacy:
            legacy_defaults = (
                ("running_mean", torch.zeros_like),
                ("running_var", torch.ones_like),
            )
            for buffer_name, make_default in legacy_defaults:
                key = prefix + buffer_name
                if key not in state_dict:
                    state_dict[key] = make_default(getattr(self, buffer_name))

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        template = "FrozenBatchNorm2d(num_features={}, eps={})"
        return template.format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        batchnorm = nn.modules.batchnorm
        convertible = (batchnorm.BatchNorm2d, batchnorm.SyncBatchNorm)

        if not isinstance(module, convertible):
            # Container or unrelated layer: recurse and swap converted children in place.
            for child_name, child in module.named_children():
                replacement = cls.convert_frozen_batchnorm(child)
                if replacement is not child:
                    module.add_module(child_name, replacement)
            return module

        frozen = cls(module.num_features)
        if module.affine:
            frozen.weight.data = module.weight.data.clone().detach()
            frozen.bias.data = module.bias.data.clone().detach()
        # The running statistics are assigned without cloning, so they share
        # storage with the source module's buffers.
        frozen.running_mean.data = module.running_mean.data
        frozen.running_var.data = module.running_var.data
        frozen.eps = module.eps
        return frozen
def get_norm(norm, out_channels):
    """
    Build a normalization layer from a name or a factory callable.

    Args:
        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN,
            nnSyncBN, LN; or a callable that takes a channel number and
            returns the normalization layer as a nn.Module.
        out_channels: the channel number handed to the selected factory.

    Returns:
        nn.Module or None: the normalization layer, or None when `norm` is
        None or an empty string.
    """
    is_name = isinstance(norm, str)
    if norm is None or (is_name and not norm):
        return None

    if is_name:
        factories = {
            "BN": BatchNorm2d,
            # Fixed in https://github.com/pytorch/pytorch/pull/36382
            "SyncBN": nn.SyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            # for debugging:
            "nnSyncBN": nn.SyncBatchNorm,
            "LN": LayerNorm,
        }
        # An unrecognized name raises KeyError here.
        norm = factories[norm]

    return norm(out_channels)
class CycleBatchNormList(nn.ModuleList):
    """
    Implement domain-specific BatchNorm by cycling.

    When a BatchNorm layer is used for multiple input domains or input
    features, it might need to maintain a separate test-time statistics
    for each domain. See Sec 5.2 in :paper:`rethinking-batchnorm`.

    This module implements it by using N separate BN layers
    and it cycles through them every time a forward() is called.

    NOTE: The caller of this module MUST guarantee to always call
    this module by multiple of N times. Otherwise its test-time statistics
    will be incorrect.
    """

    def __init__(self, length: int, bn_class=nn.BatchNorm2d, **kwargs):
        """
        Args:
            length: number of BatchNorm layers to cycle. Must be at least 1.
            bn_class: the BatchNorm class to use
            kwargs: arguments of the BatchNorm class, such as num_features.
                An "affine" entry (default True) is consumed here: the
                individual layers are always built with affine=False, and a
                single weight/bias pair shared by all of them is created
                instead when affine is true.

        Raises:
            ValueError: if `length` is smaller than 1.
        """
        # An empty list can never run forward(), so reject it up front.
        if length < 1:
            raise ValueError(f"length must be at least 1, got {length}")
        affine = kwargs.pop("affine", True)
        super().__init__([bn_class(**kwargs, affine=False) for _ in range(length)])
        self._affine = affine
        if affine:
            # shared affine, domain-specific BN
            channels = self[0].num_features
            self.weight = nn.Parameter(torch.ones(channels))
            self.bias = nn.Parameter(torch.zeros(channels))
        # Index of the layer that the next forward() call will use.
        self._pos = 0

    def forward(self, x):
        """
        Normalize `x` with the current layer, then advance to the next one.

        The shared affine is broadcast along dimension 1 of the normalized
        output, whatever its rank, so 1d/2d/3d BatchNorm classes all work.
        """
        ret = self[self._pos](x)
        self._pos = (self._pos + 1) % len(self)
        if not self._affine:
            return ret
        # (1, C, 1, ..., 1): equals (1, -1, 1, 1) for the 4D BatchNorm2d case.
        broadcast_shape = (1, -1) + (1,) * (ret.dim() - 2)
        w = self.weight.reshape(broadcast_shape)
        b = self.bias.reshape(broadcast_shape)
        return ret * w + b

    def extra_repr(self):
        return f"affine={self._affine}"
class LayerNorm(nn.Module):
    """
    A LayerNorm variant, popularized by Transformers, that performs point-wise mean and
    variance normalization over the channel dimension for inputs that have shape
    (batch_size, channels, height, width).
    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa B950
    """

    def __init__(self, normalized_shape, eps=1e-6):
        """
        Args:
            normalized_shape: size of the learnable weight and bias vectors;
                forward() broadcasts them along dimension 1 of the input.
            eps: value added to the variance before the square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        """
        Normalize `x` over dimension 1 and apply the learned scale and shift.

        The statistics are the mean and the biased (population) variance
        taken along dimension 1 for every remaining position.
        """
        # Center once and reuse the result for both the variance and the
        # normalized output, instead of recomputing `x - mean`.
        centered = x - x.mean(1, keepdim=True)
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        # weight/bias are viewed as (C, 1, 1) so they broadcast over the two trailing dims.
        return self.weight[:, None, None] * normalized + self.bias[:, None, None]