File size: 1,637 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# Copyright (c) OpenMMLab. All rights reserved.
from mmaction.registry import MODELS
from mmaction.utils import get_str_type
from .resnet3d import ResNet3d


@MODELS.register_module()
class ResNet2Plus1d(ResNet3d):
    """ResNet (2+1)d backbone.

    This model is proposed in `A Closer Look at Spatiotemporal Convolutions for
    Action Recognition <https://arxiv.org/abs/1711.11248>`_

    All positional and keyword arguments are forwarded unchanged to
    ``ResNet3d.__init__``. Once the parent constructor has run, the instance
    must have ``pretrained2d`` set to ``False`` and a ``conv_cfg`` whose
    ``type`` resolves to ``'Conv2plus1d'``; otherwise construction fails.
    """

    def __init__(self, *args, **kwargs):
        """Build the backbone and validate its (2+1)d configuration.

        Args:
            *args: Positional arguments passed through to ``ResNet3d``.
            **kwargs: Keyword arguments passed through to ``ResNet3d``.

        Raises:
            AssertionError: If ``self.pretrained2d`` is not ``False`` or if
                ``self.conv_cfg['type']`` does not resolve to
                ``'Conv2plus1d'``.
        """
        super().__init__(*args, **kwargs)

        # Identity check on purpose: only the literal ``False`` is accepted,
        # not other falsy values such as ``None`` or ``0``.
        assert self.pretrained2d is False, (
            f'{type(self).__name__} requires pretrained2d=False, '
            f'but got {self.pretrained2d!r}')

        # ``get_str_type`` normalises the configured type to its string name
        # so it can be compared against the expected layer name.
        conv_type = get_str_type(self.conv_cfg['type'])
        assert conv_type == 'Conv2plus1d', (
            f'{type(self).__name__} requires conv_cfg type "Conv2plus1d", '
            f'but got {conv_type!r}')

    def _freeze_stages(self):
        """Prevent all the parameters from being optimized before
        ``self.frozen_stages``.

        When ``self.frozen_stages`` is non-negative the stem ``conv1`` is put
        in eval mode and its parameters stop requiring gradients. The same is
        then done for ``layer1`` .. ``layer{frozen_stages}``. A negative
        ``frozen_stages`` leaves every module untouched.
        """
        if self.frozen_stages >= 0:
            self.conv1.eval()
            for param in self.conv1.parameters():
                param.requires_grad = False

        # Residual stages are stored as attributes named ``layer1``,
        # ``layer2``, ... (1-based), hence the range starting at 1.
        for stage_idx in range(1, self.frozen_stages + 1):
            stage = getattr(self, f'layer{stage_idx}')
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Defines the computation performed at every call.

        The input goes through ``conv1`` and ``maxpool``, then through every
        residual layer named in ``self.res_layers``, in that order.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The feature of the input
            samples extracted by the backbone.
        """
        x = self.conv1(x)
        x = self.maxpool(x)
        for layer_name in self.res_layers:
            res_layer = getattr(self, layer_name)
            # no pool2 in R(2+1)d
            x = res_layer(x)

        return x