# File size: 977 Bytes
# Commit: 3df0cd3
import torch
import torch.nn as nn
import torch.nn.functional as F
class PyTorchAudioModel(nn.Module):
    """CNN + BiLSTM classifier for audio feature sequences.

    Expects input of shape ``(batch, 13, time)`` — 13 channels, presumably
    MFCC coefficients (TODO confirm against the feature extractor). Two
    Conv1d/BatchNorm/ReLU/MaxPool stages downsample time by 4x and widen
    channels to 128, a single-layer bidirectional LSTM summarizes the
    sequence, and two linear layers map the concatenated final hidden
    states to ``num_labels`` logits of shape ``(batch, num_labels)``.
    """

    def __init__(self, num_labels=6):
        super().__init__()
        # Stage 1: 13 -> 64 channels; "same" padding keeps the time length,
        # pooling then halves it.
        self.conv1 = nn.Conv1d(13, 64, kernel_size=5, padding="same")
        self.bn1 = nn.BatchNorm1d(64)
        self.pool1 = nn.MaxPool1d(2, 2)
        # Stage 2: 64 -> 128 channels, time halved again.
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding="same")
        self.bn2 = nn.BatchNorm1d(128)
        self.pool2 = nn.MaxPool1d(2, 2)
        # Single-layer BiLSTM: 64 hidden units per direction -> 128 total.
        self.bilstm = nn.LSTM(128, 64, bidirectional=True, batch_first=True)
        # Classifier head over the concatenated final hidden states.
        self.dense1 = nn.Linear(128, 128)
        self.dense2 = nn.Linear(128, num_labels)

    def forward(self, x):
        """Run the network; ``x`` is ``(batch, 13, time)`` -> ``(batch, num_labels)`` logits."""
        feats = self.pool1(torch.relu(self.bn1(self.conv1(x))))
        feats = self.pool2(torch.relu(self.bn2(self.conv2(feats))))
        # Conv layout is channels-first; the LSTM (batch_first=True) wants
        # (batch, time, channels).
        seq = feats.permute(0, 2, 1)
        # h_n has shape (num_directions, batch, 64) here; the last two
        # entries are the final forward and backward hidden states.
        _, (h_n, _) = self.bilstm(seq)
        summary = torch.cat([h_n[-2], h_n[-1]], dim=1)
        hidden = torch.relu(self.dense1(summary))
        return self.dense2(hidden)
|