File size: 4,121 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Optional

import torch
import torch.nn as nn
from torch.utils import checkpoint as cp

from mmaction.registry import MODELS
from ..common import TAM
from .resnet import Bottleneck, ResNet


class TABlock(nn.Module):
    """Temporal Adaptive Block (TA-Block) for TANet.

    This block is proposed in `TAM: TEMPORAL ADAPTIVE MODULE FOR VIDEO
    RECOGNITION <https://arxiv.org/pdf/2005.06803>`_

    The temporal adaptive module (TAM) is embedded into ResNet-Block
    after the first Conv2D, which turns the vanilla ResNet-Block
    into TA-Block.

    Args:
        block (nn.Module): Residual blocks to be substituted.
        num_segments (int): Number of frame segments.
        tam_cfg (dict): Config for temporal adaptive module (TAM).

    Raises:
        NotImplementedError: If ``block`` is not a ``Bottleneck`` block.
    """

    def __init__(self, block: nn.Module, num_segments: int,
                 tam_cfg: dict) -> None:
        super().__init__()
        # Fail fast: only the Bottleneck pattern is supported, so validate
        # before reading ``block.conv1`` or constructing the TAM (the
        # original checked after, which could surface as an AttributeError).
        if not isinstance(block, Bottleneck):
            raise NotImplementedError('TA-Blocks have not been fully '
                                      'implemented except the pattern based '
                                      'on Bottleneck block.')
        self.tam_cfg = deepcopy(tam_cfg)
        self.block = block
        self.num_segments = num_segments
        self.tam = TAM(
            in_channels=block.conv1.out_channels,
            num_segments=num_segments,
            **self.tam_cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call."""
        assert isinstance(self.block, Bottleneck)

        def _inner_forward(x):
            """Forward wrapper for utilizing checkpoint."""
            identity = x

            # conv1 -> TAM -> conv2 -> conv3: the TAM is inserted right
            # after the first conv, matching the TA-Block layout from the
            # paper.
            out = self.block.conv1(x)
            out = self.tam(out)
            out = self.block.conv2(out)
            out = self.block.conv3(out)

            if self.block.downsample is not None:
                identity = self.block.downsample(x)

            out = out + identity

            return out

        # Use gradient checkpointing when the wrapped block requests it
        # (``with_cp``) and grads are actually needed.
        if self.block.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        # ReLU is applied outside the checkpointed section, as in the
        # original Bottleneck forward.
        out = self.block.relu(out)

        return out


@MODELS.register_module()
class TANet(ResNet):
    """Temporal Adaptive Network (TANet) backbone.

    This backbone is proposed in `TAM: TEMPORAL ADAPTIVE MODULE FOR VIDEO
    RECOGNITION <https://arxiv.org/pdf/2005.06803>`_

    Embedding the temporal adaptive module (TAM) into ResNet to
    instantiate TANet.

    Args:
        depth (int): Depth of resnet, from ``{18, 34, 50, 101, 152}``.
        num_segments (int): Number of frame segments. Must be >= 3.
        tam_cfg (dict, optional): Config for temporal adaptive module (TAM).
            Defaults to None.

    Raises:
        ValueError: If ``num_segments`` is smaller than 3.
    """

    def __init__(self,
                 depth: int,
                 num_segments: int,
                 tam_cfg: Optional[dict] = None,
                 **kwargs) -> None:
        # TAM needs at least 3 segments. Use a real exception instead of
        # ``assert`` so the check survives ``python -O``.
        if num_segments < 3:
            raise ValueError(
                f'num_segments must be >= 3, but got {num_segments}')
        super().__init__(depth, **kwargs)
        self.num_segments = num_segments
        self.tam_cfg = dict() if tam_cfg is None else deepcopy(tam_cfg)
        # Initialize the plain ResNet weights once, *before* splicing in the
        # TA-Blocks; ``init_weights`` below is a no-op so later framework
        # calls cannot clobber the TAM modules.
        super().init_weights()
        self.make_tam_modeling()

    def init_weights(self):
        """Initialize weights.

        Intentionally a no-op: ``super().init_weights()`` is already invoked
        from ``__init__`` before the TA-Blocks are inserted.
        """
        pass

    def make_tam_modeling(self):
        """Replace ResNet-Block with TA-Block."""

        # ``tam_cfg`` defaults to None (not ``dict()``) to avoid a shared
        # mutable default argument; each TABlock receives its own deep copy.
        def make_tam_block(stage, num_segments, tam_cfg=None):
            tam_cfg = dict() if tam_cfg is None else tam_cfg
            blocks = [
                TABlock(block, num_segments, deepcopy(tam_cfg))
                for block in stage.children()
            ]
            return nn.Sequential(*blocks)

        for i in range(self.num_stages):
            layer_name = f'layer{i + 1}'
            res_layer = getattr(self, layer_name)
            setattr(self, layer_name,
                    make_tam_block(res_layer, self.num_segments, self.tam_cfg))