# ispl_safe / src / lcnn_model.py
# Author: davesalvi — commit 321eba1 ("LCNN thres")
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
class PreEmphasis(torch.nn.Module):
    """First-order pre-emphasis filter: y[t] = x[t] - coef * x[t-1].

    Implemented as a fixed (non-trainable) 1-D convolution with kernel
    [-coef, 1]. The kernel is registered as a buffer so it moves with the
    module across devices/dtypes but is never touched by the optimizer.

    Args:
        coef: pre-emphasis coefficient (default 0.97).
    """

    def __init__(self, coef: float = 0.97):
        super().__init__()
        self.coef = coef
        # Kernel shape (out_channels=1, in_channels=1, kernel_size=2).
        self.register_buffer(
            'flipped_filter', torch.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Filter a batch of waveforms.

        Args:
            input: tensor of shape (batch, num_samples).

        Returns:
            Pre-emphasized tensor of the same (batch, num_samples) shape.
        """
        assert len(input.size()) == 2, 'The number of dimensions of input tensor must be 2!'
        # Reflect-pad one sample on the left so conv1d output length
        # matches the input length.
        input = input.unsqueeze(1)
        input = F.pad(input, (1, 0), 'reflect')
        return F.conv1d(input, self.flipped_filter).squeeze(1)
class LCNN(nn.Module):
    """Light CNN (LCNN) classifier with Max-Feature-Map (MFM) activations.

    The network converts a raw mono waveform into a log-mel spectrogram
    (pre-emphasis + MelSpectrogram, computed without gradients and with
    autocast disabled), runs an MFM-based convolutional feature extractor,
    and classifies with two fully connected layers followed by a softmax.

    Args:
        return_emb: if True, ``forward`` also returns the 128-d embedding
            produced by the first fully connected layer.
        fs: input sample rate; only 16000, 22050 and 24000 are supported
            (STFT parameters are tuned per rate).
        num_class: number of output classes.

    Raises:
        ValueError: if ``fs`` is not one of the supported sample rates.
    """

    def __init__(self, return_emb=False, fs=16000, num_class=2):
        super().__init__()
        # Decision threshold for downstream callers; not used inside this
        # module itself.
        self.threshold = 0.65
        self.return_emb = return_emb

        # STFT/mel parameters per sample rate (~25 ms window, ~10 ms hop).
        if fs == 16000:
            n_fft = 512
            win_length = 400
            hop_length = 160
            f_max = 7600
        elif fs == 22050:
            n_fft = 704
            win_length = 552
            hop_length = 220
            f_max = 10500
        elif fs == 24000:
            n_fft = 768
            win_length = 600
            hop_length = 240
            f_max = 11000
        else:
            raise ValueError(f"Unsupported sample rate: {fs}")

        # Feature Extraction Part (First Part).
        # Each conv produces twice the channels expected downstream because
        # mfm2 halves the channel count (e.g. conv1: 1 -> 64 -> 32 after mfm2).
        self.dropout1 = nn.Dropout(0.2)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(5, 5),
                               padding=(2, 2), stride=(1, 1))
        self.maxpool3 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.conv4 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(1, 1),
                               padding=(0, 0), stride=(1, 1))
        self.batchnorm6 = nn.BatchNorm2d(32)
        self.conv7 = nn.Conv2d(in_channels=32, out_channels=96, kernel_size=(3, 3),
                               padding=(1, 1), stride=(1, 1))
        self.maxpool9 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.batchnorm10 = nn.BatchNorm2d(48)
        self.conv11 = nn.Conv2d(in_channels=48, out_channels=96, kernel_size=(1, 1),
                                padding=(0, 0), stride=(1, 1))
        self.batchnorm13 = nn.BatchNorm2d(48)
        self.conv14 = nn.Conv2d(in_channels=48, out_channels=128, kernel_size=(3, 3),
                                padding=(1, 1), stride=(1, 1))
        self.maxpool16 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.conv17 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(1, 1),
                                padding=(0, 0), stride=(1, 1))
        self.batchnorm19 = nn.BatchNorm2d(64)
        self.conv20 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3),
                                padding=(1, 1), stride=(1, 1))
        self.batchnorm22 = nn.BatchNorm2d(32)
        self.conv23 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(1, 1),
                                padding=(0, 0), stride=(1, 1))
        self.batchnorm25 = nn.BatchNorm2d(32)
        self.conv26 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3),
                                padding=(1, 1), stride=(1, 1))
        # Fix the spatial size so the classifier input is independent of the
        # input spectrogram length.
        self.maxpool28 = nn.AdaptiveMaxPool2d((16, 8))

        # Classification Part (Second Part).
        # fc29 sees 32 channels x 16 x 8 (conv26's 64 channels halved by mfm2);
        # batchnorm31/fc32 see 64 features (fc29's 128 halved by mfm2).
        self.fc29 = nn.Linear(32 * 16 * 8, 128)
        self.batchnorm31 = nn.BatchNorm1d(64)
        self.dropout2 = nn.Dropout(0.7)
        self.fc32 = nn.Linear(64, num_class)

        # Waveform -> 80-bin mel-spectrogram front end.
        self.torchfbank = torch.nn.Sequential(
            PreEmphasis(),
            torchaudio.transforms.MelSpectrogram(sample_rate=fs, n_fft=n_fft, win_length=win_length,
                                                 hop_length=hop_length, f_min=20, f_max=f_max,
                                                 window_fn=torch.hamming_window, n_mels=80),
        )
        self.softmax = nn.Softmax(dim=1)

    def mfm2(self, x):
        """Max-Feature-Map: split channels (dim 1) in two, take elementwise max.

        Halves the channel count of ``x``.
        """
        out1, out2 = torch.chunk(x, 2, 1)
        return torch.max(out1, out2)

    def mfm3(self, x):
        """Three-way MFM variant: concatenate elementwise max and elementwise
        median of three channel chunks.

        Input channels c become 2*c/3 output channels. Not called by this
        class's forward path; kept for API compatibility.
        """
        n, c, y, z = x.shape
        out1, out2, out3 = torch.chunk(x, 3, 1)
        res1 = torch.max(torch.max(out1, out2), out3)
        # Flatten each chunk to a column so the three chunks can be compared
        # position-by-position via a median over dim 1.
        tmp1 = out1.flatten()
        tmp1 = tmp1.reshape(len(tmp1), -1)
        tmp2 = out2.flatten()
        tmp2 = tmp2.reshape(len(tmp2), -1)
        tmp3 = out3.flatten()
        tmp3 = tmp3.reshape(len(tmp3), -1)
        res2 = torch.cat((tmp1, tmp2, tmp3), 1)
        res2 = torch.median(res2, 1)[0]
        res2 = res2.reshape(n, -1, y, z)
        return torch.cat((res1, res2), 1)

    def forward(self, x):
        """Classify a batch of waveforms.

        Args:
            x: float tensor of shape (batch, num_samples).

        Returns:
            Softmax probabilities of shape (batch, num_class); if
            ``return_emb`` is True, a tuple of (probabilities, embedding)
            where the embedding has shape (batch, 128).
        """
        # The spectrogram front end runs without gradients and in full
        # precision (autocast explicitly disabled).
        with torch.no_grad():
            # torch.cuda.amp.autocast(...) is deprecated; torch.amp.autocast
            # is the documented replacement with identical semantics.
            with torch.amp.autocast("cuda", enabled=False):
                x = self.torchfbank(x) + 1e-6  # epsilon guards log(0)
                x = x.log()
                # Per-mel-bin mean normalization over time frames.
                x = x - torch.mean(x, dim=-1, keepdim=True)
        # Forward pass through the Feature Extraction Part
        features = self.feature_extraction(x)
        # Forward pass through the Classification Part
        logits, emb = self.classification(features)
        output = self.softmax(logits)
        if self.return_emb:
            return output, emb
        else:
            return output

    def feature_extraction(self, x):
        """Part 1: conv/MFM feature extractor.

        Input (batch, n_mels, frames) is lifted to a single-channel image;
        output is a (batch, 64, n_mels/8, frames/8) feature map.
        """
        x = self.conv1(x.unsqueeze(1))
        x = self.mfm2(x)
        x = self.maxpool3(x)
        x = self.conv4(x)
        x = self.mfm2(x)
        x = self.batchnorm6(x)
        x = self.conv7(x)
        x = self.mfm2(x)
        x = self.maxpool9(x)
        x = self.batchnorm10(x)
        x = self.conv11(x)
        x = self.mfm2(x)
        x = self.batchnorm13(x)
        x = self.conv14(x)
        x = self.mfm2(x)
        x = self.maxpool16(x)
        return x

    def classification(self, x):
        """Part 2: remaining conv/MFM stack plus the fully connected head.

        Returns:
            (logits, emb): raw class scores of shape (batch, num_class) and
            the pre-MFM 128-d embedding from fc29.
        """
        x = self.conv17(x)
        x = self.mfm2(x)
        x = self.batchnorm19(x)
        x = self.conv20(x)
        x = self.mfm2(x)
        x = self.batchnorm22(x)
        x = self.conv23(x)
        x = self.mfm2(x)
        x = self.batchnorm25(x)
        x = self.conv26(x)
        x = self.mfm2(x)
        x = self.maxpool28(x)
        # Part 2: Classification
        x = x.view(-1, 32 * 16 * 8)
        emb = self.fc29(x)
        x = self.mfm2(emb)
        x = self.batchnorm31(x)
        logits = self.fc32(x)
        return logits, emb