|
|
| import torch
|
| import torch.distributed as dist
|
| from fvcore.nn.distributed import differentiable_all_reduce
|
| from torch import nn
|
| from torch.nn import functional as F
|
|
|
| from detectron2.utils import comm, env
|
|
|
| from .wrappers import BatchNorm2d
|
|
|
|
|
class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It contains non-trainable buffers called
    "weight" and "bias", "running_mean", "running_var",
    initialized to perform identity transformation.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.

    Other pre-trained backbone models may contain all 4 parameters.

    The forward is implemented by `F.batch_norm(..., training=False)`.
    """

    # State-dict layout version consumed by `_load_from_state_dict` below:
    # checkpoints saved with version < 2 may lack "running_mean"/"running_var".
    _version = 3

    def __init__(self, num_features: int, eps: float = 1e-5):
        """
        Args:
            num_features: number of channels C of the expected (N, C, H, W) input.
            eps: epsilon added to ``running_var`` before ``rsqrt`` in ``forward``.
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        # All four statistics are registered as buffers, not Parameters:
        # they receive no gradient and are never touched by an optimizer.
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        # Initialized to (1 - eps) so that forward(), which computes
        # rsqrt(running_var + eps), starts out as an exact identity transform.
        self.register_buffer("running_var", torch.ones(num_features) - eps)
        # Registered (as None) only for state-dict compatibility with
        # BatchNorm2d checkpoints; never used in computation here.
        self.register_buffer("num_batches_tracked", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the frozen normalization to an (N, C, H, W) tensor."""
        if x.requires_grad:
            # Fold the frozen statistics into a per-channel scale/bias and
            # apply them as plain elementwise ops.
            # NOTE(review): presumably chosen over F.batch_norm because its
            # backward would also compute weight/bias gradients (extra memory)
            # even though these are frozen — confirm.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            # Cast scale/bias to the input dtype (e.g. fp16 inputs) so the
            # output dtype matches the input.
            out_dtype = x.dtype
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # No gradients needed: frozen BN is exactly batch_norm in
            # inference mode (training=False ignores batch statistics).
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """
        Backward-compatible loading for old checkpoints (version < 2), e.g.
        converted Caffe2 backbones that carry only "weight" and "bias":
        missing running statistics are filled in with identity values.
        """
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # Old checkpoints have no running stats; default to the identity
            # transform (mean 0, var 1) so weight/bias alone define the output.
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        # Only these two BN flavors are converted; other norm layers pass through.
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                # Clone so the frozen copy does not alias trainable parameters.
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            # NOTE(review): running stats share storage with the source module
            # (no clone here, unlike weight/bias above) — intentional? confirm.
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
            res.num_batches_tracked = module.num_batches_tracked
        else:
            # Recurse into children and splice converted ones back in place.
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res

    @classmethod
    def convert_frozenbatchnorm2d_to_batchnorm2d(cls, module: nn.Module) -> nn.Module:
        """
        Convert all FrozenBatchNorm2d to BatchNorm2d

        Args:
            module (torch.nn.Module):

        Returns:
            If module is FrozenBatchNorm2d, returns a new module.
            Otherwise, in-place convert module and return it.

        This is needed for quantization:
            https://fb.workplace.com/groups/1043663463248667/permalink/1296330057982005/
        """

        res = module
        if isinstance(module, FrozenBatchNorm2d):
            res = torch.nn.BatchNorm2d(module.num_features, module.eps)

            # Unlike convert_frozen_batchnorm, every statistic is cloned here,
            # so the resulting BatchNorm2d shares no storage with the source.
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data.clone().detach()
            res.running_var.data = module.running_var.data.clone().detach()
            res.eps = module.eps
            res.num_batches_tracked = module.num_batches_tracked
        else:
            # Recurse into children and splice converted ones back in place.
            for name, child in module.named_children():
                new_child = cls.convert_frozenbatchnorm2d_to_batchnorm2d(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res
|
|
|
|
|
def get_norm(norm, out_channels):
    """
    Build a normalization layer for ``out_channels`` channels.

    Args:
        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module.
        out_channels (int): channel count passed to the layer factory.

    Returns:
        nn.Module or None: the normalization layer (None when ``norm`` is
        None or the empty string).
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if not norm:
            # Empty string means "no normalization".
            return None
        # Unknown names raise KeyError, limited to the set below.
        factories = {
            "BN": BatchNorm2d,
            # NaiveSyncBatchNorm works around incorrect nn.SyncBatchNorm
            # gradients in PyTorch <= 1.5.
            "SyncBN": NaiveSyncBatchNorm if env.TORCH_VERSION <= (1, 5) else nn.SyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            "nnSyncBN": nn.SyncBatchNorm,
            "naiveSyncBN": NaiveSyncBatchNorm,
            "naiveSyncBN_N": lambda channels: NaiveSyncBatchNorm(channels, stats_mode="N"),
            "LN": lambda channels: LayerNorm(channels),
        }
        norm = factories[norm]
    return norm(out_channels)
|
|
|
|
|
class NaiveSyncBatchNorm(BatchNorm2d):
    """
    In PyTorch<=1.5, ``nn.SyncBatchNorm`` has incorrect gradient
    when the batch size on each worker is different.
    (e.g., when scale augmentation is used, or when it is applied to mask head).

    This is a slower but correct alternative to `nn.SyncBatchNorm`.

    Note:
        There isn't a single definition of Sync BatchNorm.

        When ``stats_mode==""``, this module computes overall statistics by using
        statistics of each worker with equal weight. The result is true statistics
        of all samples (as if they are all on one worker) only when all workers
        have the same (N, H, W). This mode does not support inputs with zero batch size.

        When ``stats_mode=="N"``, this module computes overall statistics by weighting
        the statistics of each worker by their ``N``. The result is true statistics
        of all samples (as if they are all on one worker) only when all workers
        have the same (H, W). It is slower than ``stats_mode==""``.

        Even though the result of this module may not be the true statistics of all samples,
        it may still be reasonable because it might be preferrable to assign equal weights
        to all workers, regardless of their (H, W) dimension, instead of putting larger weight
        on larger images. From preliminary experiments, little difference is found between such
        a simplified implementation and an accurate computation of overall mean & variance.
    """

    def __init__(self, *args, stats_mode: str = "", **kwargs):
        """
        Args:
            stats_mode: "" (equal worker weighting) or "N" (weight each worker
                by its batch size). See the class docstring for semantics.
            *args, **kwargs: forwarded to ``BatchNorm2d``.
        """
        super().__init__(*args, **kwargs)
        assert stats_mode in ["", "N"]
        self._stats_mode = stats_mode

    def forward(self, input):
        """Synchronized BN over all workers; falls back to plain BN when
        running single-process or in eval mode."""
        if comm.get_world_size() == 1 or not self.training:
            return super().forward(input)

        B, C = input.shape[0], input.shape[1]

        half_input = input.dtype == torch.float16
        if half_input:
            # Compute statistics in fp32 and cast the result back at the end.
            input = input.float()
        # Per-worker first and second moments over the (N, H, W) dimensions.
        mean = torch.mean(input, dim=[0, 2, 3])
        meansqr = torch.mean(input * input, dim=[0, 2, 3])

        if self._stats_mode == "":
            assert B > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.'
            # Pack both moments into one tensor so a single all-reduce suffices.
            vec = torch.cat([mean, meansqr], dim=0)
            vec = differentiable_all_reduce(vec) * (1.0 / dist.get_world_size())
            mean, meansqr = torch.split(vec, C)
            momentum = self.momentum
        else:
            if B == 0:
                # Empty local batch: contribute zeros. Adding input.sum()
                # (which is 0) keeps the output connected to the autograd
                # graph of `input` so backward still runs on this worker.
                vec = torch.zeros([2 * C + 1], device=mean.device, dtype=mean.dtype)
                vec = vec + input.sum()
            else:
                # Append a trailing 1 so that after the all-reduce of
                # (vec * B) the last element holds the total batch size
                # across all workers.
                vec = torch.cat(
                    [
                        mean,
                        meansqr,
                        torch.ones([1], device=mean.device, dtype=mean.dtype),
                    ],
                    dim=0,
                )
            # Scale by B before reducing: workers are weighted by batch size.
            vec = differentiable_all_reduce(vec * B)

            total_batch = vec[-1].detach()
            # momentum becomes 0 when total_batch == 0, so running stats are
            # not updated by an iteration where every worker is empty.
            momentum = total_batch.clamp(max=1) * self.momentum
            # clamp(min=1) guards the division when total_batch == 0.
            mean, meansqr, _ = torch.split(vec / total_batch.clamp(min=1), C)

        # var = E[x^2] - E[x]^2, then fold stats into per-channel scale/bias.
        var = meansqr - mean * mean
        invstd = torch.rsqrt(var + self.eps)
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)

        # Update running statistics in place; detach so they carry no grad.
        self.running_mean += momentum * (mean.detach() - self.running_mean)
        self.running_var += momentum * (var.detach() - self.running_var)
        ret = input * scale + bias
        if half_input:
            ret = ret.half()
        return ret
|
|
|
|
|
class CycleBatchNormList(nn.ModuleList):
    """
    Implement domain-specific BatchNorm by cycling.

    When a BatchNorm layer is used for multiple input domains or input
    features, it might need to maintain a separate test-time statistics
    for each domain. See Sec 5.2 in :paper:`rethinking-batchnorm`.

    This module keeps N separate BN layers (each with its own statistics but
    no affine parameters) plus one affine transform shared across all of
    them, and advances to the next BN layer on every forward() call.

    NOTE: The caller of this module MUST guarantee to always call
    this module by multiple of N times. Otherwise its test-time statistics
    will be incorrect.
    """

    def __init__(self, length: int, bn_class=nn.BatchNorm2d, **kwargs):
        """
        Args:
            length: number of BatchNorm layers to cycle.
            bn_class: the BatchNorm class to use
            kwargs: arguments of the BatchNorm class, such as num_features.
        """
        # The affine transform is shared; the per-domain layers are created
        # with affine=False regardless of what the caller passed.
        self._affine = kwargs.pop("affine", True)
        super().__init__([bn_class(**kwargs, affine=False) for _ in range(length)])
        if self._affine:
            num_channels = self[0].num_features
            self.weight = nn.Parameter(torch.ones(num_channels))
            self.bias = nn.Parameter(torch.zeros(num_channels))
        self._pos = 0  # index of the BN layer used by the next forward()

    def forward(self, x):
        """Normalize with the current BN layer, then apply the shared affine."""
        out = self[self._pos](x)
        self._pos = (self._pos + 1) % len(self)

        if not self._affine:
            return out
        return out * self.weight.reshape(1, -1, 1, 1) + self.bias.reshape(1, -1, 1, 1)

    def extra_repr(self):
        return f"affine={self._affine}"
|
|
|
|
|
class LayerNorm(nn.Module):
    """
    A LayerNorm variant, popularized by Transformers, that performs point-wise mean and
    variance normalization over the channel dimension for inputs that have shape
    (batch_size, channels, height, width).
    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950
    """

    def __init__(self, normalized_shape, eps=1e-6):
        """
        Args:
            normalized_shape (int): number of channels C to normalize over.
            eps (float): added to the variance for numerical stability.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        """Normalize over dim 1 (channels), then apply the affine transform."""
        centered = x - x.mean(dim=1, keepdim=True)
        # Biased variance (mean of squared deviations), matching the ConvNeXt
        # reference implementation.
        variance = centered.pow(2).mean(dim=1, keepdim=True)
        normed = centered / torch.sqrt(variance + self.eps)
        # Broadcast the per-channel weight/bias over (H, W).
        return self.weight[:, None, None] * normed + self.bias[:, None, None]
|
|
|