File size: 977 Bytes
3df0cd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import torch.nn as nn
import torch.nn.functional as F


class PyTorchAudioModel(nn.Module):
    """CNN + bidirectional-LSTM classifier over per-frame audio features.

    Architecture:
      * two ``Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2, 2)`` stages
        (64 then 128 channels) that halve the time axis twice (4x total);
      * a single-layer bidirectional LSTM (64 hidden units per direction)
        whose final forward and backward hidden states are concatenated;
      * a ``Linear(128, 128) -> ReLU -> Linear(128, num_labels)`` head.

    Args:
        num_labels: Size of the output logit vector (number of classes).
        in_channels: Number of feature channels per time step on the input's
            channel axis. Defaults to 13, the value previously hard-coded.
            NOTE(review): 13 presumably means MFCC coefficients; confirm
            against the feature extractor.
    """

    def __init__(self, num_labels=6, in_channels=13):
        super().__init__()
        # Convolutional front end: "same" padding keeps the time length
        # unchanged, so only the pooling layers downsample.
        self.conv1 = nn.Conv1d(in_channels, 64, kernel_size=5, padding="same")
        self.bn1 = nn.BatchNorm1d(64)
        self.pool1 = nn.MaxPool1d(2, 2)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding="same")
        self.bn2 = nn.BatchNorm1d(128)
        self.pool2 = nn.MaxPool1d(2, 2)
        # 128 conv channels in, 64 hidden units per direction. forward()
        # concatenates both directions' final states, giving 2 * 64 = 128
        # features, which is why dense1 takes 128 inputs.
        self.bilstm = nn.LSTM(128, 64, bidirectional=True, batch_first=True)
        self.dense1 = nn.Linear(128, 128)
        self.dense2 = nn.Linear(128, num_labels)

    def forward(self, x):
        """Classify a batch of feature sequences.

        Args:
            x: Tensor laid out as ``(batch, in_channels, time)``, which is the
                channel-first layout ``Conv1d`` consumes. ``time`` must be at
                least 4, or the second pooling stage has nothing to pool.

        Returns:
            Tensor of shape ``(batch, num_labels)`` with raw, unnormalised
            scores. No softmax is applied here.
        """
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = x.permute(0, 2, 1)  # (B, C, L) -> (B, L, C) for batch_first LSTM
        _, (h_n, _) = self.bilstm(x)
        # h_n is (num_layers * 2, B, 64). The last two entries are the final
        # layer's forward and backward hidden states.
        x = torch.cat([h_n[-2], h_n[-1]], dim=1)
        x = F.relu(self.dense1(x))
        return self.dense2(x)