File size: 3,832 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model.weight_init import constant_init, kaiming_init
from torch.nn.modules.utils import _pair

from mmaction.registry import MODELS


@MODELS.register_module()
class ConvAudio(nn.Module):
    """Conv2d module for AudioResNet backbone.

    The input is fed through two parallel ``ConvModule`` branches, each
    built with BN and ReLU: ``conv_1`` uses a ``(kernel_size[0], 1)``
    kernel and ``conv_2`` uses a ``(1, kernel_size[1])`` kernel. The two
    branch outputs are then merged according to ``op``.

    Reference: `<https://arxiv.org/abs/2001.08740>`_.

    Args:
        in_channels (int): Same as ``nn.Conv2d``.
        out_channels (int): Output channels of *each* branch. With
            ``op='concat'`` the branch outputs are concatenated along
            dim 1, so the merged tensor is twice as large along that dim.
        kernel_size (Union[int, Tuple[int]]): Same as ``nn.Conv2d``. The
            first element is used by ``conv_1`` and the second by
            ``conv_2``. Each branch pads its non-trivial axis by
            ``kernel_size[i] // 2``.
        op (str): Operation to merge the output of freq
            and time feature map. Choices are ``sum`` and ``concat``.
            Defaults to ``concat``.
        stride (Union[int, Tuple[int]]): Same as ``nn.Conv2d``. Applied to
            both branches. Defaults to 1.
        padding (Union[int, Tuple[int]]): Stored on the module as
            ``self.padding`` but NOT forwarded to the branches, which
            derive their padding from ``kernel_size``. Defaults to 0.
        dilation (Union[int, Tuple[int]]): Stored on the module as
            ``self.dilation`` but NOT forwarded to the branches.
            Defaults to 1.
        groups (int): Stored on the module as ``self.groups`` but NOT
            forwarded to the branches. Defaults to 1.
        bias (Union[bool, str]): Forwarded unchanged to both
            ``ConvModule`` branches. NOTE(review): ``norm_cfg`` is fixed
            to BN here, so ``'auto'`` presumably resolves to no bias under
            ``ConvModule``'s rules -- confirm against mmcv.
            Defaults to False.

    Raises:
        AssertionError: If ``op`` is neither ``'concat'`` nor ``'sum'``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int]],
                 op: str = 'concat',
                 stride: Union[int, Tuple[int]] = 1,
                 padding: Union[int, Tuple[int]] = 0,
                 dilation: Union[int, Tuple[int]] = 1,
                 groups: int = 1,
                 bias: Union[bool, str] = False) -> None:
        super().__init__()

        # Raise explicitly instead of using ``assert``: a bare assert is
        # stripped under ``python -O``, after which an unknown ``op`` would
        # silently take the ``sum`` path in ``forward``. AssertionError is
        # kept so the exception type seen by callers does not change.
        if op not in ('concat', 'sum'):
            raise AssertionError(
                f"op must be 'concat' or 'sum', but got {op!r}")

        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)

        # These attributes mirror the ones exposed by ``nn.Conv2d``.
        # NOTE(review): presumably so that wrappers which introspect the
        # conv layer can treat this module like one -- confirm with callers.
        # ``padding``, ``dilation`` and ``groups`` are only recorded here;
        # the branches below do not receive them.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.op = op
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.output_padding = (0, 0)
        self.transposed = False

        # Branch 1: kernel spans only the first spatial axis.
        self.conv_1 = ConvModule(
            in_channels,
            out_channels,
            kernel_size=(kernel_size[0], 1),
            stride=stride,
            padding=(kernel_size[0] // 2, 0),
            bias=bias,
            conv_cfg=dict(type='Conv'),
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU'))

        # Branch 2: kernel spans only the second spatial axis.
        self.conv_2 = ConvModule(
            in_channels,
            out_channels,
            kernel_size=(1, kernel_size[1]),
            stride=stride,
            padding=(0, kernel_size[1] // 2),
            bias=bias,
            conv_cfg=dict(type='Conv'),
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU'))

        self.init_weights()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data. Both branches receive the
                same tensor.

        Returns:
            torch.Tensor: The output of the module. The two branch outputs
            concatenated along dim 1 when ``op`` is ``concat``, otherwise
            their element-wise sum.
        """
        x_1 = self.conv_1(x)
        x_2 = self.conv_2(x)
        if self.op == 'concat':
            out = torch.cat([x_1, x_2], 1)
        else:
            # ``op`` was validated in ``__init__``, so this is ``sum``.
            out = x_1 + x_2
        return out

    def init_weights(self) -> None:
        """Initiate the parameters from scratch.

        The conv layer of each branch is initialised with ``kaiming_init``;
        the BN layer of each branch is initialised with ``constant_init``
        using value 1 and bias 0.
        """
        kaiming_init(self.conv_1.conv)
        kaiming_init(self.conv_2.conv)
        constant_init(self.conv_1.bn, 1, bias=0)
        constant_init(self.conv_2.bn, 1, bias=0)