import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
class PreEmphasis(torch.nn.Module):
    """First-order pre-emphasis filter: y[t] = x[t] - coef * x[t-1].

    Implemented as a fixed (non-trainable) 1-D convolution with kernel
    [-coef, 1]. One sample of reflect padding on the left keeps the output
    length equal to the input length.
    """

    def __init__(self, coef: float = 0.97):
        super().__init__()
        self.coef = coef
        # Registered as a buffer (not a Parameter): it follows .to()/.cuda()
        # and is saved in state_dict, but receives no gradient updates.
        # Shape (out_channels=1, in_channels=1, kernel_size=2) as conv1d expects.
        self.register_buffer(
            'flipped_filter', torch.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply pre-emphasis to a batch of waveforms.

        Args:
            input: float tensor of shape (batch, num_samples).

        Returns:
            Tensor of shape (batch, num_samples), pre-emphasized.

        Raises:
            ValueError: if ``input`` is not 2-D.
        """
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, silently disabling the validation.
        if input.dim() != 2:
            raise ValueError('The number of dimensions of input tensor must be 2!')
        # Add a channel dim for conv1d, reflect-pad one sample on the left so
        # in/out lengths match, filter, then drop the channel dim again.
        padded = F.pad(input.unsqueeze(1), (1, 0), 'reflect')
        return F.conv1d(padded, self.flipped_filter).squeeze(1)
class LCNN(nn.Module):
    """Light CNN (LCNN) classifier over log-mel spectrograms.

    The network alternates convolutions with Max-Feature-Map (MFM)
    activations: each conv emits 2*C channels and ``mfm2`` halves them back
    to C by taking an elementwise max over the two channel halves, which is
    why every BatchNorm size is half of the preceding conv's out_channels.

    Args:
        return_emb: if True, ``forward`` also returns the 128-d embedding
            taken from ``fc29`` (before the final MFM).
        fs: input sampling rate; only 16000, 22050 and 24000 are supported.
        num_class: number of output classes of the final linear layer.
    """

    def __init__(self, return_emb=False, fs=16000, num_class=2):
        super(LCNN, self).__init__()
        # NOTE(review): `threshold` is stored but never read in this file --
        # presumably a decision threshold applied by external calling code;
        # confirm against the caller.
        self.threshold = 0.65
        self.return_emb = return_emb
        # STFT/mel parameters scale with the sampling rate so the window
        # (~25 ms) and hop (~10 ms) durations stay roughly constant.
        if fs == 16000:
            n_fft = 512
            win_length = 400
            hop_length = 160
            f_max = 7600
        elif fs == 22050:
            n_fft = 704
            win_length = 552
            hop_length = 220
            f_max = 10500
        elif fs == 24000:
            n_fft = 768
            win_length = 600
            hop_length = 240
            f_max = 11000
        else:
            raise ValueError(f"Unsupported sample rate: {fs}")
        # Feature Extraction Part (First Part)
        # NOTE(review): dropout1 is registered but never applied in forward().
        self.dropout1 = nn.Dropout(0.2)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(5, 5),
                               padding=(2, 2), stride=(1, 1))  # 64 -> 32 after MFM
        self.maxpool3 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.conv4 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(1, 1),
                               padding=(0, 0), stride=(1, 1))  # 64 -> 32 after MFM
        self.batchnorm6 = nn.BatchNorm2d(32)
        self.conv7 = nn.Conv2d(in_channels=32, out_channels=96, kernel_size=(3, 3),
                               padding=(1, 1), stride=(1, 1))  # 96 -> 48 after MFM
        self.maxpool9 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.batchnorm10 = nn.BatchNorm2d(48)
        self.conv11 = nn.Conv2d(in_channels=48, out_channels=96, kernel_size=(1, 1),
                                padding=(0, 0), stride=(1, 1))  # 96 -> 48 after MFM
        self.batchnorm13 = nn.BatchNorm2d(48)
        self.conv14 = nn.Conv2d(in_channels=48, out_channels=128, kernel_size=(3, 3),
                                padding=(1, 1), stride=(1, 1))  # 128 -> 64 after MFM
        self.maxpool16 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.conv17 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(1, 1),
                                padding=(0, 0), stride=(1, 1))  # 128 -> 64 after MFM
        self.batchnorm19 = nn.BatchNorm2d(64)
        self.conv20 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3),
                                padding=(1, 1), stride=(1, 1))  # 64 -> 32 after MFM
        self.batchnorm22 = nn.BatchNorm2d(32)
        self.conv23 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(1, 1),
                                padding=(0, 0), stride=(1, 1))  # 64 -> 32 after MFM
        self.batchnorm25 = nn.BatchNorm2d(32)
        self.conv26 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3),
                                padding=(1, 1), stride=(1, 1))  # 64 -> 32 after MFM
        # Adaptive pooling fixes the spatial size to 16x8 regardless of the
        # input spectrogram length, so the flattened FC input size is constant.
        self.maxpool28 = nn.AdaptiveMaxPool2d((16, 8))
        # Classification Part (Second Part)
        self.fc29 = nn.Linear(32 * 16 * 8, 128)
        self.batchnorm31 = nn.BatchNorm1d(64)
        # NOTE(review): dropout2 is registered but never applied in forward().
        self.dropout2 = nn.Dropout(0.7)
        self.fc32 = nn.Linear(64, num_class)
        # Front-end: pre-emphasis followed by an 80-bin mel spectrogram.
        self.torchfbank = torch.nn.Sequential(
            PreEmphasis(),
            torchaudio.transforms.MelSpectrogram(sample_rate=fs, n_fft=n_fft, win_length=win_length, hop_length=hop_length, \
                                                 f_min=20, f_max=f_max, window_fn=torch.hamming_window, n_mels=80),
        )
        self.softmax = nn.Softmax(dim=1)

    def mfm2(self, x):
        # Max-Feature-Map: split channels (dim 1) into two halves and take
        # the elementwise max -> output has half the input channels.
        out1, out2 = torch.chunk(x, 2, 1)
        return torch.max(out1, out2)

    def mfm3(self, x):
        # Three-way MFM variant: channels are split into thirds; the output
        # concatenates the elementwise max and the elementwise median of the
        # three parts (2/3 of the input channels in total).
        # NOTE(review): not called anywhere in this file.
        n, c, y, z = x.shape
        out1, out2, out3 = torch.chunk(x, 3, 1)
        res1 = torch.max(torch.max(out1, out2), out3)
        # Flatten each third to a column vector so the median can be taken
        # elementwise across the three parts via a (N, 3) concat.
        tmp1 = out1.flatten()
        tmp1 = tmp1.reshape(len(tmp1), -1)
        tmp2 = out2.flatten()
        tmp2 = tmp2.reshape(len(tmp2), -1)
        tmp3 = out3.flatten()
        tmp3 = tmp3.reshape(len(tmp3), -1)
        res2 = torch.cat((tmp1, tmp2, tmp3), 1)
        res2 = torch.median(res2, 1)[0]
        res2 = res2.reshape(n, -1, y, z)
        return torch.cat((res1, res2), 1)

    def forward(self, x):
        # Front-end runs without gradients and in full precision (autocast
        # disabled) for numerically stable log-mel features.
        # `x` is a raw waveform batch of shape (batch, num_samples) --
        # PreEmphasis inside torchfbank asserts 2-D input.
        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=False):
                x = self.torchfbank(x)+1e-6  # epsilon avoids log(0)
                x = x.log()
                # Mean-normalize each mel bin over the time axis.
                x = x - torch.mean(x, dim=-1, keepdim=True)
        # Forward pass through the Feature Extraction Part
        features = self.feature_extraction(x)
        # Forward pass through the Classification Part
        logits, emb = self.classification(features)
        output = self.softmax(logits)
        if self.return_emb:
            return output, emb
        else:
            return output

    def feature_extraction(self, x):
        # Part 1: Feature Extraction
        # Add a channel dim: (batch, 80, frames) -> (batch, 1, 80, frames).
        x = self.conv1(x.unsqueeze(1))
        x = self.mfm2(x)
        x = self.maxpool3(x)
        x = self.conv4(x)
        x = self.mfm2(x)
        x = self.batchnorm6(x)
        x = self.conv7(x)
        x = self.mfm2(x)
        x = self.maxpool9(x)
        x = self.batchnorm10(x)
        x = self.conv11(x)
        x = self.mfm2(x)
        x = self.batchnorm13(x)
        x = self.conv14(x)
        x = self.mfm2(x)
        x = self.maxpool16(x)
        return x

    def classification(self, x):
        x = self.conv17(x)
        x = self.mfm2(x)
        x = self.batchnorm19(x)
        x = self.conv20(x)
        x = self.mfm2(x)
        x = self.batchnorm22(x)
        x = self.conv23(x)
        x = self.mfm2(x)
        x = self.batchnorm25(x)
        x = self.conv26(x)
        x = self.mfm2(x)
        x = self.maxpool28(x)
        # Part 2: Classification
        # Flatten the fixed (32, 16, 8) feature map for the linear layers.
        x = x.view(-1, 32 * 16 * 8)
        # `emb` is the 128-d embedding before the final MFM; it is what
        # forward() returns when return_emb is True.
        emb = self.fc29(x)
        x = self.mfm2(emb)
        x = self.batchnorm31(x)
        logits = self.fc32(x)
        return logits, emb