| |
| |
|
|
| """BatchNorm (BN) utility functions and custom batch-size BN implementations""" |
|
|
| from functools import partial |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from pytorchvideo.layers.batch_norm import NaiveSyncBatchNorm3d |
|
|
|
|
def get_norm(cfg):
    """
    Pick the normalization layer constructor selected by the config.

    Args:
        cfg (CfgNode): model building configs. Reads cfg.BN.NORM_TYPE and,
            depending on its value, cfg.BN.NUM_SPLITS or
            cfg.BN.NUM_SYNC_DEVICES / cfg.BN.GLOBAL_SYNC.
    Returns:
        callable: nn.BatchNorm3d itself, or a functools.partial wrapping
            SubBatchNorm3d / NaiveSyncBatchNorm3d with the config-derived
            keyword arguments already bound.
    Raises:
        NotImplementedError: if cfg.BN.NORM_TYPE is not one of the
            supported names.
    """
    norm_type = cfg.BN.NORM_TYPE

    # Both of these names map to the stock 3D batch norm class.
    if norm_type in {"batchnorm", "sync_batchnorm_apex"}:
        return nn.BatchNorm3d

    if norm_type == "sub_batchnorm":
        return partial(SubBatchNorm3d, num_splits=cfg.BN.NUM_SPLITS)

    if norm_type == "sync_batchnorm":
        sync_kwargs = {
            "num_sync_devices": cfg.BN.NUM_SYNC_DEVICES,
            "global_sync": cfg.BN.GLOBAL_SYNC,
        }
        return partial(NaiveSyncBatchNorm3d, **sync_kwargs)

    raise NotImplementedError(
        "Norm type {} is not supported".format(norm_type)
    )
|
|
|
|
class SubBatchNorm3d(nn.Module):
    """
    BatchNorm3d variant that computes training statistics per sub-batch.

    The standard BN layer computes stats across all examples in a GPU. In some
    cases it is desirable to compute stats across only a subset of examples
    (e.g., in multigrid training https://arxiv.org/abs/1912.00998).
    SubBatchNorm3d splits the batch dimension into N splits and runs BN on
    each of them separately, so that the stats are computed on each subset of
    examples (1/N of the batch) independently. During evaluation, a single BN
    is used; its running stats are filled in from all splits by
    `aggregate_stats`, which must be called before switching to eval.
    """

    def __init__(self, num_splits, **args):
        """
        Args:
            num_splits (int): number of splits.
            args (dict): keyword arguments forwarded to nn.BatchNorm3d. Must
                contain "num_features"; "affine" defaults to True.
        """
        super().__init__()
        self.num_splits = num_splits
        num_features = args["num_features"]

        # The affine transform is kept on this module (one weight/bias shared
        # by every split) and applied in `forward`; both wrapped BN layers
        # are therefore built with affine disabled.
        if args.get("affine", True):
            self.affine = True
            args["affine"] = False
            self.weight = torch.nn.Parameter(torch.ones(num_features))
            self.bias = torch.nn.Parameter(torch.zeros(num_features))
        else:
            self.affine = False

        # `bn` is the evaluation-time layer over the real channel count.
        self.bn = nn.BatchNorm3d(**args)
        # `split_bn` sees each split as its own group of channels, so it
        # needs num_features * num_splits channels.
        args["num_features"] = num_features * num_splits
        self.split_bn = nn.BatchNorm3d(**args)

    def _get_aggregated_mean_std(self, means, stds, n):
        """
        Calculate the aggregated mean and variance over n equally sized sets.

        Despite the parameter name, `stds` is treated as per-set variances
        (it is fed `running_var` by `aggregate_stats`): the result is the
        average within-set variance plus the variance of the set means.

        Args:
            means (tensor): per-set mean values, n sets laid out
                back-to-back.
            stds (tensor): per-set variances, same layout as `means`.
            n (int): number of sets of means and stds.
        Returns:
            tuple(tensor, tensor): detached aggregated mean and variance.
        """
        per_set_means = means.view(n, -1)
        per_set_vars = stds.view(n, -1)

        mean = per_set_means.sum(0) / n
        within = per_set_vars.sum(0) / n
        between = ((per_set_means - mean) ** 2).sum(0) / n
        var = within + between
        return mean.detach(), var.detach()

    def aggregate_stats(self):
        """
        Synchronize running_mean, and running_var. Call this before eval.

        Copies the aggregated running statistics of `split_bn` into `bn`.
        Does nothing when `split_bn` does not track running stats.
        """
        if self.split_bn.track_running_stats:
            (
                self.bn.running_mean.data,
                self.bn.running_var.data,
            ) = self._get_aggregated_mean_std(
                self.split_bn.running_mean,
                self.split_bn.running_var,
                self.num_splits,
            )

    def forward(self, x):
        """
        Args:
            x (tensor): 5-D input unpacked as (n, c, t, h, w). In training
                mode n must be divisible by num_splits, otherwise the
                reshape below fails.
        Returns:
            tensor: normalized output with the same shape as `x`.
        """
        if self.training:
            n, c, t, h, w = x.shape
            # Fold the splits into the channel dimension so that one BN call
            # normalizes every split with its own statistics. `reshape` is
            # used instead of `view` so non-contiguous inputs (e.g. permuted
            # or channels-last tensors) are accepted; for contiguous inputs
            # it returns the same view as before.
            x = x.reshape(n // self.num_splits, c * self.num_splits, t, h, w)
            x = self.split_bn(x)
            x = x.reshape(n, c, t, h, w)
        else:
            x = self.bn(x)
        if self.affine:
            # Broadcast the per-channel parameters over (t, h, w).
            x = x * self.weight.view((-1, 1, 1, 1))
            x = x + self.bias.view((-1, 1, 1, 1))
        return x
|
|