# Scraped from a Hugging Face Space (page status banner read "Spaces: Sleeping").
import math

import torch
import torch.nn as nn
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm -> ReLU -> MaxPool -> Dropout2d feature block.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: (kh, kw) convolution kernel.
        pool: (ph, pw) max-pooling window.
    """

    def __init__(self, in_ch, out_ch, kernel_size=(3, 3), pool=(2, 2)):
        super().__init__()
        # "Same"-style padding for odd kernels: H/W preserved until the pool.
        pad = (kernel_size[0] // 2, kernel_size[1] // 2)
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size, padding=pad),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(pool),
            nn.Dropout2d(0.2),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block to a (batch, in_ch, H, W) tensor."""
        return self.net(x)
class AttentionLayer(nn.Module):
    """Attention pooling: learn per-timestep weights, return their weighted sum.

    Collapses a (batch, time, hidden_dim) sequence into a single
    (batch, hidden_dim) context vector.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # Produces one scalar attention score per timestep.
        self.attention = nn.Linear(hidden_dim, 1)

    def forward(self, lstm_out):
        """Pool (batch, time, hidden) down to (batch, hidden)."""
        scores = self.attention(lstm_out)         # (batch, time, 1)
        alphas = torch.softmax(scores, dim=1)     # normalized over the time axis
        context = (alphas * lstm_out).sum(dim=1)  # weighted sum over time
        return context
class CLSTMModel(nn.Module):
    """Convolutional-recurrent audio classifier (CNN -> BiLSTM -> attention -> MLP).

    Pipeline: three ConvBlocks (each halves the freq and time axes via 2x2
    max-pooling) -> reshape to a time-major sequence -> bidirectional LSTM
    -> attention pooling -> fully connected classifier head.

    Args:
        n_mels: number of mel-frequency bins in the input spectrogram.
        n_classes: number of output classes; `forward` returns raw logits.
        conv_channels: output channels of the three conv blocks
            (None means the default (32, 64, 128)).
        lstm_hidden: LSTM hidden size per direction.
        lstm_layers: number of stacked LSTM layers.
        dropout: dropout probability for the LSTM (between layers) and the head.
    """

    def __init__(
        self,
        n_mels=40,
        n_classes=8,
        conv_channels=None,
        lstm_hidden=128,
        lstm_layers=2,
        dropout=0.4,
    ):
        super().__init__()
        # Avoid a mutable default argument; None selects the standard ladder.
        if conv_channels is None:
            conv_channels = (32, 64, 128)
        self.conv1 = ConvBlock(1, conv_channels[0])
        self.conv2 = ConvBlock(conv_channels[0], conv_channels[1])
        self.conv3 = ConvBlock(conv_channels[1], conv_channels[2])
        # BUGFIX: nn.MaxPool2d floors its output size, and three stacked
        # floor-halvings compose to n_mels // 8.  The previous
        # math.ceil(n_mels / 2**3) overestimated the frequency dimension
        # whenever n_mels was not a multiple of 8 (e.g. 41 -> ceil gives 6,
        # actual is 5), making the LSTM input size mismatch the flattened
        # conv features and crashing forward() at runtime.
        freq_after = n_mels // (2 ** 3)
        self.lstm_input = conv_channels[2] * freq_after
        self.lstm = nn.LSTM(
            self.lstm_input,
            lstm_hidden,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
            # nn.LSTM warns about (and ignores) dropout with a single layer.
            dropout=dropout if lstm_layers > 1 else 0,
        )
        self.attention = AttentionLayer(lstm_hidden * 2)  # *2: bidirectional
        self.classifier = nn.Sequential(
            nn.Linear(lstm_hidden * 2, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, n_classes),
        )

    def forward(self, x):
        """Return (batch, n_classes) logits for x of shape (batch, 1, n_mels, time)."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        b, c, f, t = x.size()
        # (b, c, f, t) -> (b, t, c*f): one flattened feature vector per timestep.
        x = x.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)
        out, _ = self.lstm(x)
        out = self.attention(out)  # (b, t, 2*hidden) -> (b, 2*hidden)
        return self.classifier(out)