File size: 3,100 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy

import torch
import torch.nn as nn

from mmaction.registry import MODELS


@MODELS.register_module()
class SubBatchNorm3D(nn.Module):
    """Sub BatchNorm3d that normalizes each sub-batch independently.

    In training mode the batch is split into ``num_splits`` groups and BN
    statistics are computed on each group (1/N of the batch) separately.
    In evaluation mode a single BN layer is used; its running statistics are
    obtained from the per-split statistics by calling
    :meth:`aggregate_stats`.

    Args:
        num_features (int): Dimensions of BatchNorm.
        **cfg: Keyword arguments forwarded to ``nn.BatchNorm3d``. Two keys
            are handled specially: ``num_splits`` (default 1) is the number
            of sub-batches and is not forwarded; ``affine`` (default True)
            decides whether this module owns a learnable ``weight`` and
            ``bias`` (the wrapped BN layers are always built with
            ``affine=False``).
    """

    def __init__(self, num_features, **cfg):
        super().__init__()

        self.num_features = num_features
        # Work on a copy so the caller's config dict is never mutated.
        self.cfg_ = deepcopy(cfg)
        self.num_splits = self.cfg_.pop('num_splits', 1)
        self.num_features_split = self.num_features * self.num_splits
        # only keep one set of affine params, not in .bn or .split_bn
        self.cfg_['affine'] = False
        self.bn = nn.BatchNorm3d(num_features, **self.cfg_)
        self.split_bn = nn.BatchNorm3d(self.num_features_split, **self.cfg_)
        # Pass the untouched cfg: ``self.cfg_['affine']`` was forced to False.
        self.init_weights(cfg)

    def init_weights(self, cfg):
        """Create the shared affine parameters if requested.

        Args:
            cfg (dict): Config whose ``affine`` entry (default True) decides
                whether ``weight`` (ones) and ``bias`` (zeros) are created.
        """
        if cfg.get('affine', True):
            self.weight = torch.nn.Parameter(torch.ones(self.num_features))
            self.bias = torch.nn.Parameter(torch.zeros(self.num_features))
            self.affine = True
        else:
            self.affine = False

    def _get_aggregated_mean_std(self, means, stds, n):
        """Merge per-split statistics into statistics of the whole batch.

        Args:
            means (torch.Tensor): Per-split means, ``n`` splits laid out
                one after another.
            stds (torch.Tensor): Per-split *variances* (the name is
                historical; ``running_var`` is what gets passed in), with
                the same layout as ``means``.
            n (int): Number of splits.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Detached aggregated mean and
            variance, one value per feature.
        """
        means = means.view(n, -1)
        variances = stds.view(n, -1)
        mean = means.mean(0)
        # Total variance = mean of the split variances
        #                + variance of the split means.
        var = variances.mean(0) + ((means - mean)**2).mean(0)
        return mean.detach(), var.detach()

    def aggregate_stats(self):
        """Synchronize running_mean, and running_var to self.bn.

        Call this before eval, then call model.eval(); When eval, forward
        function will call self.bn instead of self.split_bn, During this time
        the running_mean, and running_var of self.bn has been obtained from
        self.split_bn.

        Does nothing when the BN layers were built with
        ``track_running_stats=False``.
        """
        if not self.split_bn.track_running_stats:
            # Without running stats the buffers (including
            # ``num_batches_tracked``) are None, so there is nothing to copy.
            return
        self.bn.running_mean.data, self.bn.running_var.data = (
            self._get_aggregated_mean_std(self.split_bn.running_mean,
                                          self.split_bn.running_var,
                                          self.num_splits))
        self.bn.num_batches_tracked = (
            self.split_bn.num_batches_tracked.detach())

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): Input of shape ``(n, c, t, h, w)``. In training
                mode ``n`` must be divisible by ``num_splits``.

        Returns:
            torch.Tensor: Normalized tensor with the same shape as ``x``.
        """
        if self.training:
            n, c, t, h, w = x.shape
            assert n % self.num_splits == 0, (
                f'batch size {n} is not divisible by '
                f'num_splits {self.num_splits}')
            # Fold the splits into the channel dimension so one BN call
            # normalizes every split on its own: sample ``i`` ends up in
            # split ``i % num_splits``. ``reshape`` (rather than ``view``)
            # also accepts non-contiguous inputs.
            x = x.reshape(n // self.num_splits, c * self.num_splits, t, h, w)
            x = self.split_bn(x)
            x = x.reshape(n, c, t, h, w)
        else:
            x = self.bn(x)
        if self.affine:
            # (c, 1, 1, 1) broadcasts over the batch and t/h/w dimensions.
            x = x * self.weight.view(-1, 1, 1, 1)
            x = x + self.bias.view(-1, 1, 1, 1)
        return x