File size: 4,890 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F


class TAM(nn.Module):
    """Temporal Adaptive Module (TAM) for TANet.

    This module is proposed in `TAM: TEMPORAL ADAPTIVE MODULE FOR VIDEO
    RECOGNITION <https://arxiv.org/pdf/2005.06803>`_

    Two branches consume the spatially averaged features. The global branch
    ``G`` maps the ``num_segments`` pooled values of every (video, channel)
    pair to a softmax-normalised temporal kernel of length
    ``adaptive_kernel_size``. The local branch ``L`` maps the pooled
    ``[channels, num_segments]`` signal of every video to a sigmoid gate of
    the same shape. The input is multiplied by the gate and then convolved
    along the segment axis with the per-(video, channel) kernel.

    Args:
        in_channels (int): Channel num of input features.
        num_segments (int): Number of frame segments.
        alpha (int): ``alpha`` in the paper and is the ratio of the
            intermediate channel number to the initial channel number in the
            global branch. Defaults to 2.
        adaptive_kernel_size (int): ``K`` in the paper and is the size of the
            adaptive kernel size in the global branch. Defaults to 3.
        beta (int): ``beta`` in the paper and is set to control the model
            complexity in the local branch. Defaults to 4.
        conv1d_kernel_size (int): Size of the convolution kernel of Conv1d in
            the local branch. Defaults to 3.
        adaptive_convolution_stride (int): The first dimension of strides in
            the adaptive convolution of ``Temporal Adaptive Aggregation``.
            Defaults to 1.
        adaptive_convolution_padding (int): The first dimension of paddings in
            the adaptive convolution of ``Temporal Adaptive Aggregation``.
            Defaults to 1.
        init_std (float): Std value for initiation of `nn.Linear`. Within
            this class it is only stored as an attribute. Defaults to 0.001.
    """

    def __init__(self,
                 in_channels: int,
                 num_segments: int,
                 alpha: int = 2,
                 adaptive_kernel_size: int = 3,
                 beta: int = 4,
                 conv1d_kernel_size: int = 3,
                 adaptive_convolution_stride: int = 1,
                 adaptive_convolution_padding: int = 1,
                 init_std: float = 0.001) -> None:
        super().__init__()

        assert beta > 0 and alpha > 0

        # Keep every constructor argument available on the instance.
        self.in_channels = in_channels
        self.num_segments = num_segments
        self.alpha = alpha
        self.adaptive_kernel_size = adaptive_kernel_size
        self.beta = beta
        self.conv1d_kernel_size = conv1d_kernel_size
        self.adaptive_convolution_stride = adaptive_convolution_stride
        self.adaptive_convolution_padding = adaptive_convolution_padding
        self.init_std = init_std

        # Global branch: pooled segment vector -> adaptive temporal kernel.
        expanded_segments = num_segments * alpha
        global_layers = [
            nn.Linear(num_segments, expanded_segments, bias=False),
            nn.BatchNorm1d(expanded_segments),
            nn.ReLU(inplace=True),
            nn.Linear(expanded_segments, adaptive_kernel_size, bias=False),
            nn.Softmax(-1),
        ]
        self.G = nn.Sequential(*global_layers)

        # Local branch: pooled [channels, segments] signal -> sigmoid gate.
        reduced_channels = in_channels // beta
        local_layers = [
            nn.Conv1d(
                in_channels,
                reduced_channels,
                conv1d_kernel_size,
                stride=1,
                padding=conv1d_kernel_size // 2,
                bias=False),
            nn.BatchNorm1d(reduced_channels),
            nn.ReLU(inplace=True),
            nn.Conv1d(reduced_channels, in_channels, 1, bias=False),
            nn.Sigmoid(),
        ]
        self.L = nn.Sequential(*local_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data, unpacked as ``[n, c, h, w]``
                where ``n`` is regrouped into ``n // num_segments`` videos of
                ``num_segments`` frames each.

        Returns:
            torch.Tensor: The output of the module, viewed back to
            ``[n, c, h, w]``.
        """
        n, c, h, w = x.shape
        segments = self.num_segments
        videos = n // segments
        assert c == self.in_channels

        # [videos, c, segments, h, w]: put the segment axis after channels.
        clips = x.view(videos, segments, c, h, w)
        clips = clips.transpose(1, 2).contiguous()

        # [videos * c, segments, 1, 1]: average over the spatial dims.
        pooled = F.adaptive_avg_pool2d(
            clips.view(-1, segments, h, w), (1, 1))

        # [videos * c, 1, adaptive_kernel_size, 1]: one temporal kernel per
        # (video, channel) pair, laid out as a grouped conv2d weight.
        kernels = self.G(pooled.view(-1, segments))
        kernels = kernels.view(videos * c, 1, -1, 1)

        # [videos, c, segments, 1, 1]: gate broadcast over h and w.
        gates = self.L(pooled.view(-1, c, segments))
        gates = gates.view(videos, c, segments, 1, 1)

        # [videos, c, segments, h, w]
        gated = clips * gates

        # Fold (video, channel) into the conv channel axis so that
        # ``groups=videos * c`` applies each kernel to its own slice; the
        # kernel width of 1 leaves the flattened spatial axis untouched.
        aggregated = F.conv2d(
            gated.view(1, videos * c, segments, h * w),
            kernels,
            bias=None,
            stride=(self.adaptive_convolution_stride, 1),
            padding=(self.adaptive_convolution_padding, 0),
            groups=videos * c)

        # Back to [n, c, h, w].
        aggregated = aggregated.view(videos, c, segments, h, w)
        aggregated = aggregated.transpose(1, 2).contiguous()
        return aggregated.view(n, c, h, w)