File size: 1,727 Bytes
64ec292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch.nn as nn


class GLU(nn.Module):
    """Gated linear unit.

    The input is split into two equal halves along ``dim``. The first half
    (the values) is multiplied element-wise by the sigmoid of the second half
    (the gate). The result is therefore half as large as the input along
    ``dim``.

    Args:
        dim: dimension along which the input is split into value / gate halves.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        value, gate_logits = x.chunk(2, dim=self.dim)
        gate = gate_logits.sigmoid()
        return value.mul(gate)


class conform_conv(nn.Module):
    """Conformer-style convolution module.

    Layers are applied in this order:
    1. Pointwise conv (``channels -> 2 * channels``).
    2. GLU over the channel axis (back to ``channels``).
    3. Depthwise conv with symmetric "same" padding.
    4. BatchNorm1d.
    5. SiLU.
    6. Pointwise conv (``channels -> channels``).
    7. Dropout.

    Args:
        channels: number of input and output feature channels.
        kernel_size: depthwise convolution kernel size. Must be a positive odd
            integer so that symmetric padding preserves the sequence length.
        DropoutL: dropout probability applied to the output. Values <= 0
            disable dropout (an ``nn.Identity`` is used instead).
        bias: whether the three convolutions use a bias term.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """

    def __init__(
        self, channels: int, kernel_size: int = 31, DropoutL=0.1, bias: bool = True
    ):
        super().__init__()
        # Validate before building any layers. An explicit raise (instead of
        # ``assert``) keeps the check active under ``python -O``. Without it,
        # an even kernel would silently make the output one frame shorter
        # than the input.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )

        self.act2 = nn.SiLU()
        # Gate over dim 1, the channel axis of the (batch, channels, length)
        # layout used inside forward().
        self.act1 = GLU(1)

        # Expands to 2 * channels so the GLU can halve it back to `channels`.
        self.pointwise_conv1 = nn.Conv1d(
            channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias
        )

        # Symmetric "same" padding: with an odd kernel, (k - 1) // 2 frames on
        # each side keep the sequence length unchanged. This module has no
        # causal (left-only) padding mode.
        padding = (kernel_size - 1) // 2

        # groups=channels makes this a depthwise (per-channel) convolution.
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
            bias=bias,
        )

        self.norm = nn.BatchNorm1d(channels)

        self.pointwise_conv2 = nn.Conv1d(
            channels, channels, kernel_size=1, stride=1, padding=0, bias=bias
        )
        self.drop = nn.Dropout(DropoutL) if DropoutL > 0.0 else nn.Identity()

    def forward(self, x):
        """Apply the convolution module.

        Args:
            x: 3-D tensor of shape (batch, length, channels).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Conv1d / BatchNorm1d expect (batch, channels, length).
        x = x.transpose(1, 2)
        x = self.act1(self.pointwise_conv1(x))
        x = self.depthwise_conv(x)
        x = self.norm(x)
        x = self.act2(x)
        x = self.pointwise_conv2(x)
        # Restore the caller's (batch, length, channels) layout.
        return self.drop(x).transpose(1, 2)