# NOTE: removed non-Python extraction residue (file-size line, git-blame hashes,
# and a line-number gutter) that would be a syntax error in this module.
import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as T
import torchvision
from torchvision.models import resnet18
def modify_for_grayscale(model):
    """Replace ``model.conv1`` with a single-input-channel (grayscale) copy.

    The replacement keeps the original layer's out_channels, kernel_size,
    stride, padding, dilation, and bias-presence.  Its weights are the mean
    of the original weights over the input-channel dimension, so a grayscale
    image produces the same response the original layer would give to that
    image replicated across its input channels.

    Args:
        model: a module exposing a ``conv1`` attribute (e.g. a torchvision
            ResNet).  Modified in place.

    Returns:
        The same ``model`` instance, with ``conv1`` replaced.
    """
    old_conv = model.conv1
    new_conv = nn.Conv2d(
        in_channels=1,
        out_channels=old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        dilation=old_conv.dilation,
        bias=old_conv.bias is not None,
    )
    with torch.no_grad():
        # Collapse the input channels (e.g. RGB) into one by averaging.
        new_conv.weight[:, 0] = old_conv.weight.mean(dim=1)
        if old_conv.bias is not None:
            # Copy values rather than aliasing the old Parameter object, so
            # the new layer owns an independent parameter.
            new_conv.bias.copy_(old_conv.bias)
    model.conv1 = new_conv
    return model
class ResNet_LogSpec(nn.Module):
    """ResNet-18 classifier over log-power (dB) spectrograms.

    Pipeline: waveform -> STFT power spectrogram -> dB scale ->
    normalization -> grayscale ResNet-18 backbone (fc stripped) ->
    dropout -> 256-d embedding -> ReLU -> linear classifier -> softmax.

    Args:
        sample_rate: audio sample rate; one of 16000, 22050, 24000.
            STFT parameters scale so the window covers ~25 ms and the
            hop ~10 ms at every supported rate.
        return_emb: if True, ``forward`` also returns the 256-d embedding.
        num_class: number of output classes.

    Raises:
        ValueError: if ``sample_rate`` is not one of the supported rates.
    """

    def __init__(self, sample_rate=16000, return_emb=False, num_class=2):
        super().__init__()
        # Decision threshold for downstream use; not applied inside forward().
        self.threshold = 0.1
        self.return_emb = return_emb
        # ~25 ms window / ~10 ms hop at each supported rate; n_fft >= win_length.
        if sample_rate == 16000:
            n_fft = 512
            win_length = 400
            hop_length = 160
        elif sample_rate == 22050:
            n_fft = 704
            win_length = 552
            hop_length = 220
        elif sample_rate == 24000:
            n_fft = 768
            win_length = 600
            hop_length = 240
        else:
            raise ValueError(f"Unsupported sample rate: {sample_rate}")
        self.sample_rate = sample_rate
        self.stft = T.Spectrogram(n_fft=n_fft,
                                  win_length=win_length,
                                  hop_length=hop_length,
                                  power=2, window_fn=torch.hamming_window)
        # Grayscale ResNet-18 backbone with the classification head removed;
        # the backbone then emits pooled features of width num_ftrs.
        self.model = resnet18(pretrained=False)
        self.model = modify_for_grayscale(self.model)
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Identity()
        self.dropout = nn.Dropout(p=0.5, inplace=True)
        self.embedding_layer = nn.Linear(num_ftrs, 256)
        self.relu = nn.ReLU()
        self.classifier = nn.Linear(256, num_class)
        self.softmax = nn.Softmax(dim=1)
        self.to_db = T.AmplitudeToDB()
        # presumably grayscale ImageNet statistics (mean/std averaged over
        # the RGB channel stats) — TODO confirm against training setup
        self.normalize = torchvision.transforms.Normalize(mean=0.449, std=0.226)

    def forward(self, x):
        """Classify a batch of raw waveforms.

        Args:
            x: float tensor; a channel dim is inserted at position 1,
               so it is expected as (batch, samples) — TODO confirm.

        Returns:
            Softmax probabilities over ``num_class`` classes; if
            ``return_emb`` was set, a ``(probs, embedding)`` tuple where
            the embedding is the pre-ReLU 256-d vector.
        """
        x = x.unsqueeze(1)   # add channel dim for the conv backbone
        x = self.stft(x)     # power spectrogram
        x = self.to_db(x)    # log (dB) scale
        x = self.normalize(x)
        x = self.model(x)    # backbone features (fc is Identity)
        x = self.dropout(x)
        emb = self.embedding_layer(x)
        x = self.relu(emb)
        logits = self.classifier(x)
        out = self.softmax(logits)
        if self.return_emb:
            return out, emb
        return out
class ResNet_MelSpec(nn.Module):
    """ResNet-18 classifier over log-mel (dB) spectrograms.

    Pipeline: waveform -> 80-band mel spectrogram -> dB scale ->
    normalization -> grayscale ResNet-18 backbone (fc stripped) ->
    dropout -> 256-d embedding -> ReLU -> linear classifier -> softmax.

    Args:
        sample_rate: audio sample rate; one of 16000, 22050, 24000.
            STFT parameters scale so the window covers ~25 ms and the
            hop ~10 ms; f_max stays below the Nyquist frequency.
        return_emb: if True, ``forward`` also returns the 256-d embedding.
        num_class: number of output classes.

    Raises:
        ValueError: if ``sample_rate`` is not one of the supported rates.
    """

    def __init__(self, sample_rate=16000, return_emb=False, num_class=2):
        super().__init__()
        # Decision threshold for downstream use; not applied inside forward().
        self.threshold = 0.4
        self.return_emb = return_emb
        # ~25 ms window / ~10 ms hop at each supported rate; n_fft >= win_length.
        if sample_rate == 16000:
            n_fft = 512
            win_length = 400
            hop_length = 160
            f_max = 7600
        elif sample_rate == 22050:
            n_fft = 704
            win_length = 552
            hop_length = 220
            f_max = 10500
        elif sample_rate == 24000:
            n_fft = 768
            win_length = 600
            hop_length = 240
            f_max = 11000
        else:
            raise ValueError(f"Unsupported sample rate: {sample_rate}")
        self.melspectrogram = T.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            f_min=20, f_max=f_max,
            n_mels=80, window_fn=torch.hamming_window
        )
        # Grayscale ResNet-18 backbone with the classification head removed;
        # the backbone then emits pooled features of width num_ftrs.
        self.model = resnet18(pretrained=False)
        self.model = modify_for_grayscale(self.model)
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Identity()
        self.dropout = nn.Dropout(p=0.5, inplace=True)
        self.embedding_layer = nn.Linear(num_ftrs, 256)
        self.relu = nn.ReLU()
        self.classifier = nn.Linear(256, num_class)
        self.softmax = nn.Softmax(dim=1)
        self.to_db = T.AmplitudeToDB()
        # presumably grayscale ImageNet statistics (mean/std averaged over
        # the RGB channel stats) — TODO confirm against training setup
        self.normalize = torchvision.transforms.Normalize(mean=0.449, std=0.226)

    def forward(self, x):
        """Classify a batch of raw waveforms.

        Args:
            x: float tensor; a channel dim is inserted at position 1,
               so it is expected as (batch, samples) — TODO confirm.

        Returns:
            Softmax probabilities over ``num_class`` classes; if
            ``return_emb`` was set, a ``(probs, embedding)`` tuple where
            the embedding is the pre-ReLU 256-d vector.
        """
        x = x.unsqueeze(1)          # add channel dim for the conv backbone
        x = self.melspectrogram(x)  # 80-band mel spectrogram
        x = self.to_db(x)           # log (dB) scale
        x = self.normalize(x)
        x = self.model(x)           # backbone features (fc is Identity)
        x = self.dropout(x)
        emb = self.embedding_layer(x)
        x = self.relu(emb)
        logits = self.classifier(x)
        out = self.softmax(logits)
        if self.return_emb:
            return out, emb
        return out