|
|
|
|
|
from copy import deepcopy
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
|
|
|
from mmaction.registry import MODELS
|
|
|
|
|
|
|
|
|
@MODELS.register_module()
class SubBatchNorm3D(nn.Module):
    """Sub BatchNorm3d splits the batch dimension into N splits, and runs BN
    on each of them separately (so that the stats are computed on each subset
    of examples (1/N of batch) independently).

    In training mode the input is normalized by ``split_bn``, which keeps one
    set of running statistics per split. In evaluation mode the input is
    normalized by ``bn``. Call :meth:`aggregate_stats` before switching to
    eval so that ``bn`` holds the statistics aggregated over all splits; the
    aggregation is not performed automatically. When ``affine`` is enabled,
    a single learnable per-channel ``weight``/``bias`` pair, shared by all
    splits and by both modes, is applied after normalization.

    Args:
        num_features (int): Dimensions of BatchNorm, i.e. the channel count
            ``C`` of the ``(N, C, T, H, W)`` input.
        **cfg: Extra options. ``num_splits`` (int, default 1) is the number
            of splits. ``affine`` (bool, default True) controls whether the
            shared ``weight``/``bias`` are used. All remaining keys are
            forwarded to :class:`torch.nn.BatchNorm3d`.
    """

    def __init__(self, num_features, **cfg):
        super().__init__()

        self.num_features = num_features
        # Work on a copy so the ``cfg`` handed to ``init_weights`` below
        # still carries the caller's ``affine`` setting after we force it off.
        self.cfg_ = deepcopy(cfg)
        self.num_splits = self.cfg_.pop('num_splits', 1)
        self.num_features_split = self.num_features * self.num_splits

        # The inner BN layers only normalize. The learnable affine transform
        # is owned by this module so it is shared across splits and between
        # training (split_bn) and evaluation (bn).
        self.cfg_['affine'] = False
        self.bn = nn.BatchNorm3d(num_features, **self.cfg_)
        self.split_bn = nn.BatchNorm3d(self.num_features_split, **self.cfg_)
        self.init_weights(cfg)

    def init_weights(self, cfg=None):
        """Initialize the shared affine parameters (weight=1, bias=0).

        Args:
            cfg (dict, optional): Layer config. When given, its ``affine``
                entry (default True) decides whether ``weight``/``bias`` are
                used. When omitted, for example when a parent module calls
                ``init_weights()`` without arguments, the setting chosen at
                construction time is reused. Defaults to None.
        """
        if cfg is not None:
            self.affine = bool(cfg.get('affine', True))
        elif not hasattr(self, 'affine'):
            self.affine = True

        if not self.affine:
            return

        weight = getattr(self, 'weight', None)
        bias = getattr(self, 'bias', None)
        if isinstance(weight, nn.Parameter) and isinstance(
                bias, nn.Parameter):
            # Re-initialize in place. This keeps the device/dtype, the
            # ``requires_grad`` flags and the Parameter objects that an
            # optimizer may already reference.
            with torch.no_grad():
                weight.fill_(1.0)
                bias.zero_()
        else:
            self.weight = nn.Parameter(torch.ones(self.num_features))
            self.bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_aggregated_mean_std(self, means, stds, n):
        """Aggregate per-split running statistics into per-channel ones.

        ``forward`` reshapes ``(N, C, ...)`` into ``(N // n, C * n, ...)``,
        so ``split_bn``'s channel ``s * C + c`` holds channel ``c`` of split
        ``s``. Viewing the statistics as ``(n, C)`` therefore puts one split
        per row.

        Args:
            means (torch.Tensor): Per-split running means, ``n * C`` values.
            stds (torch.Tensor): Per-split running *variances* with the same
                layout as ``means``. The name is historical; the values are
                variances.
            n (int): Number of splits.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Detached per-channel mean and
            variance, each with ``C`` elements.
        """
        split_means = means.view(n, -1)
        mean = split_means.sum(0) / n
        # Law of total variance for equally sized groups: the mean of the
        # within-split variances plus the variance of the split means.
        var = stds.view(n, -1).sum(0) / n + (
            (split_means - mean)**2).sum(0) / n
        return mean.detach(), var.detach()

    def aggregate_stats(self):
        """Synchronize running_mean and running_var from split_bn into bn.

        Call this before eval, then call ``model.eval()``. In eval mode,
        ``forward`` uses ``self.bn`` instead of ``self.split_bn``, and its
        running_mean and running_var have then been obtained from
        ``self.split_bn``. Does nothing when running stats are not tracked.
        """
        if not self.split_bn.track_running_stats:
            return

        mean, var = self._get_aggregated_mean_std(
            self.split_bn.running_mean, self.split_bn.running_var,
            self.num_splits)
        # Copy in place rather than rebinding. This keeps bn's buffer objects
        # (and any external references to them) intact, and avoids sharing
        # storage with split_bn's counter, which keeps incrementing while
        # training.
        with torch.no_grad():
            self.bn.running_mean.copy_(mean)
            self.bn.running_var.copy_(var)
            self.bn.num_batches_tracked.copy_(
                self.split_bn.num_batches_tracked)

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): Input of shape ``(N, C, T, H, W)``.

        Returns:
            torch.Tensor: Normalized tensor with the same shape as ``x``.
        """
        if self.training:
            n, c, t, h, w = x.shape
            assert n % self.num_splits == 0, (
                f'batch size {n} is not divisible by num_splits '
                f'{self.num_splits}')
            # Fold the splits into the channel dim: sample ``b`` lands in
            # split ``b % num_splits``. ``reshape`` (not ``view``) also
            # accepts non-contiguous inputs such as channels_last_3d.
            x = x.reshape(n // self.num_splits, c * self.num_splits, t, h, w)
            x = self.split_bn(x)
            x = x.reshape(n, c, t, h, w)
        else:
            x = self.bn(x)
        if self.affine:
            # (C, 1, 1, 1) broadcasts over the (N, C, T, H, W) input.
            x = x * self.weight.view(-1, 1, 1, 1)
            x = x + self.bias.view(-1, 1, 1, 1)
        return x
|
|
|
|