|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
class PyTorchAudioModel(nn.Module):
    """Conv1d + bidirectional-LSTM classifier over channel-first sequences.

    Two ``Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d`` stages are followed by
    a single-layer bidirectional LSTM. The final hidden states of both LSTM
    directions are concatenated and passed through two linear layers to
    produce one score per label.

    Args:
        num_labels: Number of output scores produced per example.
        in_channels: Number of channels in the input sequence. Defaults to
            13, the value that was previously hard-coded (presumably the
            number of audio features per frame, e.g. MFCCs -- confirm
            against the data pipeline).
    """

    def __init__(self, num_labels: int = 6, in_channels: int = 13) -> None:
        super().__init__()

        # Stage 1: in_channels -> 64 channels; "same" padding keeps the
        # length, then the pool halves it (floor division).
        self.conv1 = nn.Conv1d(in_channels, 64, kernel_size=5, padding="same")
        self.bn1 = nn.BatchNorm1d(64)
        self.pool1 = nn.MaxPool1d(2, 2)

        # Stage 2: 64 -> 128 channels, length halved again.
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding="same")
        self.bn2 = nn.BatchNorm1d(128)
        self.pool2 = nn.MaxPool1d(2, 2)

        # 64 hidden units per direction -> 128 features once both
        # directions are concatenated, matching dense1's input size.
        self.bilstm = nn.LSTM(128, 64, bidirectional=True, batch_first=True)

        self.dense1 = nn.Linear(128, 128)
        self.dense2 = nn.Linear(128, num_labels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute per-label scores for a batch of sequences.

        Args:
            x: Batched, channel-first tensor of shape
                ``(batch, in_channels, length)``. The length is halved by
                each of the two pooling layers, so it must be at least 4 for
                the LSTM to receive a non-empty sequence.

        Returns:
            Tensor of shape ``(batch, num_labels)``. No softmax or other
            activation is applied to the output of the last linear layer.
        """
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))

        # Conv layers emit (batch, channels, length); the batch_first LSTM
        # expects (batch, length, features).
        x = x.permute(0, 2, 1)

        # Only the final hidden state is used; per-step outputs and the
        # cell state are discarded.
        _, (h_n, _) = self.bilstm(x)

        # For a bidirectional LSTM the last two entries of h_n are the final
        # forward and backward hidden states of the last layer.
        x = torch.cat([h_n[-2], h_n[-1]], dim=1)

        x = F.relu(self.dense1(x))
        return self.dense2(x)
|
|
|