import torch
import torch.nn as nn
import torch.nn.functional as F


class PyTorchAudioModel(nn.Module):
    """Conv1d -> Conv1d -> BiLSTM -> MLP classifier over feature sequences.

    Pipeline:
        1. Two ``Conv1d + BatchNorm1d + ReLU + MaxPool1d(2, 2)`` stages. Each
           pool halves the sequence length, so the LSTM sees ``L // 2 // 2``
           time steps.
        2. A single-layer bidirectional LSTM (64 hidden units per direction)
           whose final forward and backward hidden states are concatenated
           into a 128-dimensional summary vector.
        3. ``Linear(128, 128) + ReLU`` followed by ``Linear(128, num_labels)``.

    Args:
        num_labels: Size of the output layer (number of classes).
        in_channels: Number of feature channels per time step of the input.
            Defaults to 13, the value that was previously hard-coded
            (presumably 13 MFCC coefficients -- confirm against the data
            pipeline).
    """

    def __init__(self, num_labels: int = 6, *, in_channels: int = 13) -> None:
        super().__init__()
        # NOTE: layer names and creation order are kept unchanged so existing
        # checkpoints (state_dict keys) and seeded initialisation still match.
        self.conv1 = nn.Conv1d(in_channels, 64, kernel_size=5, padding="same")
        self.bn1 = nn.BatchNorm1d(64)
        self.pool1 = nn.MaxPool1d(2, 2)

        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding="same")
        self.bn2 = nn.BatchNorm1d(128)
        self.pool2 = nn.MaxPool1d(2, 2)

        # 64 hidden units per direction -> 128 features once both directions
        # are concatenated, which is why dense1 takes 128 inputs.
        self.bilstm = nn.LSTM(128, 64, bidirectional=True, batch_first=True)

        self.dense1 = nn.Linear(128, 128)
        self.dense2 = nn.Linear(128, num_labels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class scores for a batch of feature sequences.

        Args:
            x: Tensor of shape ``(B, in_channels, L)``. ``L`` should be at
                least 4 so that time steps remain after the two stride-2
                pooling layers.

        Returns:
            Tensor of shape ``(B, num_labels)``. No softmax/sigmoid is
            applied; these are the raw outputs of the last linear layer.
        """
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))

        # Conv layers work on (B, C, L); the batch_first LSTM wants (B, L, C).
        x = x.permute(0, 2, 1)

        # Only the final hidden state is used; per-step outputs and the cell
        # state are discarded.
        _, (h_n, _) = self.bilstm(x)

        # h_n is (num_directions, B, 64) for this single-layer BiLSTM:
        # h_n[-2] is the forward direction's last state, h_n[-1] the
        # backward direction's.
        x = torch.cat([h_n[-2], h_n[-1]], dim=1)

        x = F.relu(self.dense1(x))
        return self.dense2(x)