File size: 2,234 Bytes
3461076
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
import torch.nn as nn
import math

class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm -> ReLU -> MaxPool -> Dropout2d unit.

    'Same' padding is derived from the kernel size so the conv itself
    preserves spatial dims; only the pooling stage downsamples.
    """

    def __init__(self, in_ch, out_ch, kernel_size=(3, 3), pool=(2, 2)):
        super().__init__()
        # Half-kernel padding per axis keeps H/W unchanged through the conv.
        pad = tuple(k // 2 for k in kernel_size)
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size, padding=pad),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(pool),
            nn.Dropout2d(0.2),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the conv stack; output is (B, out_ch, H//pool_h, W//pool_w)."""
        out = self.net(x)
        return out


class AttentionLayer(nn.Module):
    """Learned soft-attention pooling over the time axis.

    Scores each timestep with a single linear unit, normalizes the scores
    with softmax along dim=1 (time), and returns the weighted sum — collapsing
    (B, T, H) input to a (B, H) context vector.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.attention = nn.Linear(hidden_dim, 1)

    def forward(self, lstm_out):
        """Pool (B, T, H) -> (B, H) via attention-weighted averaging."""
        scores = self.attention(lstm_out)          # (B, T, 1)
        alphas = torch.softmax(scores, dim=1)      # weights sum to 1 over T
        context = (alphas * lstm_out).sum(dim=1)   # (B, H)
        return context


class CLSTMModel(nn.Module):
    """CNN + bidirectional LSTM + attention classifier for (B, 1, n_mels, T) spectrograms.

    Three ConvBlocks halve both the frequency and time axes (each MaxPool2d(2)),
    the remaining frequency bins are folded into the channel dim as the LSTM
    feature size, and attention pools the LSTM sequence before the MLP head.

    Args:
        n_mels: number of mel-frequency bins in the input spectrogram.
        n_classes: number of output classes.
        conv_channels: output channels of the three ConvBlocks.
        lstm_hidden: hidden size per LSTM direction.
        lstm_layers: stacked LSTM layers.
        dropout: dropout rate for the LSTM (between layers) and the classifier.
    """

    def __init__(
        self,
        n_mels=40,
        n_classes=8,
        conv_channels=[32, 64, 128],
        lstm_hidden=128,
        lstm_layers=2,
        dropout=0.4
    ):
        super().__init__()

        self.conv1 = ConvBlock(1, conv_channels[0])
        self.conv2 = ConvBlock(conv_channels[0], conv_channels[1])
        self.conv3 = ConvBlock(conv_channels[1], conv_channels[2])

        # BUG FIX: nn.MaxPool2d defaults to ceil_mode=False, so each of the
        # three pools FLOORS the frequency dim. Three successive floored
        # halvings compose to integer division by 8; math.ceil over-estimated
        # freq_after for any n_mels not divisible by 8 (e.g. 44 -> 6 vs the
        # true 5), causing an LSTM input-size mismatch at forward time.
        freq_after = n_mels // (2 ** 3)
        self.lstm_input = conv_channels[2] * freq_after

        self.lstm = nn.LSTM(
            self.lstm_input,
            lstm_hidden,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
            # PyTorch warns if dropout is set on a single-layer LSTM.
            dropout=dropout if lstm_layers > 1 else 0
        )

        # Bidirectional -> features are 2 * lstm_hidden per timestep.
        self.attention = AttentionLayer(lstm_hidden * 2)

        self.classifier = nn.Sequential(
            nn.Linear(lstm_hidden * 2, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, n_classes)
        )

    def forward(self, x):
        """Classify a batch of spectrograms.

        Args:
            x: tensor of shape (B, 1, n_mels, T).

        Returns:
            Logits of shape (B, n_classes).
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        # (B, C, F, T) -> (B, T, C*F): treat pooled time steps as the
        # sequence and fold channels x frequency into the feature dim.
        b, c, f, t = x.size()
        x = x.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)

        out, _ = self.lstm(x)
        out = self.attention(out)  # (B, T, 2H) -> (B, 2H)
        return self.classifier(out)