# ispl_safe/src/resnet_model.py
# ResNet-18 based audio classifiers on (mel-)spectrogram inputs.
# Provenance: Hugging Face repo snapshot, commit 49c095a ("change model", davesalvi).
import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as T
import torchvision
from torchvision.models import resnet18
def modify_for_grayscale(model):
    """Adapt a torchvision-style CNN to single-channel (grayscale) input.

    Replaces ``model.conv1`` with a 1-input-channel convolution whose
    filters are the per-output-channel mean of the original RGB filters,
    so any pretrained weights are reused. The bias (if present) is
    carried over unchanged.

    Args:
        model: a network exposing a ``conv1`` ``nn.Conv2d`` attribute
            (e.g. torchvision's resnet18).

    Returns:
        The same ``model`` instance, modified in place.
    """
    old_conv = model.conv1
    grayscale_conv = nn.Conv2d(
        in_channels=1,
        out_channels=old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        bias=old_conv.bias is not None,
    )
    with torch.no_grad():
        # Collapse the RGB filter bank into one channel by averaging.
        grayscale_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
        if old_conv.bias is not None:
            grayscale_conv.bias = old_conv.bias
    model.conv1 = grayscale_conv
    return model
class ResNet_LogSpec(nn.Module):
    """ResNet-18 classifier on log-power spectrograms of raw audio.

    The input waveform (assumed shape (batch, num_samples) — matched by the
    ``unsqueeze(1)`` in ``forward``; confirm against callers) is turned into
    a power spectrogram, converted to dB, normalized, and fed to a
    grayscale-adapted ResNet-18. A 256-d embedding head and a linear
    classifier with softmax produce class probabilities.

    Args:
        sample_rate: one of 16000, 22050, 24000; selects the STFT geometry.
        return_emb: if True, ``forward`` also returns the 256-d embedding.
        num_class: number of output classes.

    Raises:
        ValueError: if ``sample_rate`` is not supported.
    """

    # (n_fft, win_length, hop_length) per supported sample rate; windows are
    # ~25 ms with ~10 ms hop at each rate.
    def __init__(self, sample_rate=16000, return_emb=False, num_class=2):
        super().__init__()
        # NOTE(review): threshold is unused within this module — presumably a
        # decision threshold consumed by downstream code; confirm with callers.
        self.threshold = 0.1
        self.return_emb = return_emb
        stft_params = {
            16000: (512, 400, 160),
            22050: (704, 552, 220),
            24000: (768, 600, 240),
        }
        if sample_rate not in stft_params:
            raise ValueError(f"Unsupported sample rate: {sample_rate}")
        n_fft, win_length, hop_length = stft_params[sample_rate]
        self.sample_rate = sample_rate
        self.stft = T.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            power=2,
            window_fn=torch.hamming_window,
        )
        # Grayscale-adapted ResNet-18 backbone; its fc is replaced by
        # Identity so it emits raw pooled features.
        self.model = modify_for_grayscale(resnet18(pretrained=False))
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Identity()
        self.dropout = nn.Dropout(p=0.5, inplace=True)
        self.embedding_layer = nn.Linear(num_ftrs, 256)
        self.relu = nn.ReLU()
        self.classifier = nn.Linear(256, num_class)
        self.softmax = nn.Softmax(dim=1)
        self.to_db = T.AmplitudeToDB()
        self.normalize = torchvision.transforms.Normalize(mean=0.449, std=0.226)

    def forward(self, x):
        """Return class probabilities (and the embedding if ``return_emb``)."""
        # (batch, samples) -> (batch, 1, samples): add channel dim for 2-D conv.
        feats = self.stft(x.unsqueeze(1))
        feats = self.normalize(self.to_db(feats))
        pooled = self.dropout(self.model(feats))
        emb = self.embedding_layer(pooled)
        probs = self.softmax(self.classifier(self.relu(emb)))
        if self.return_emb:
            return probs, emb
        return probs
class ResNet_MelSpec(nn.Module):
    """ResNet-18 classifier on 80-bin mel spectrograms of raw audio.

    The input waveform (assumed shape (batch, num_samples) — matched by the
    ``unsqueeze(1)`` in ``forward``; confirm against callers) is converted to
    a mel spectrogram, mapped to dB, normalized, and classified by a
    grayscale-adapted ResNet-18 with a 256-d embedding head and softmax
    output.

    Args:
        sample_rate: one of 16000, 22050, 24000; selects STFT geometry and
            the mel upper frequency bound.
        return_emb: if True, ``forward`` also returns the 256-d embedding.
        num_class: number of output classes.

    Raises:
        ValueError: if ``sample_rate`` is not supported.
    """

    def __init__(self, sample_rate=16000, return_emb=False, num_class=2):
        super().__init__()
        # NOTE(review): threshold is unused within this module — presumably a
        # decision threshold consumed by downstream code; confirm with callers.
        self.threshold = 0.4
        self.return_emb = return_emb
        # (n_fft, win_length, hop_length, f_max) per supported sample rate.
        mel_params = {
            16000: (512, 400, 160, 7600),
            22050: (704, 552, 220, 10500),
            24000: (768, 600, 240, 11000),
        }
        if sample_rate not in mel_params:
            raise ValueError(f"Unsupported sample rate: {sample_rate}")
        n_fft, win_length, hop_length, f_max = mel_params[sample_rate]
        self.melspectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            f_min=20,
            f_max=f_max,
            n_mels=80,
            window_fn=torch.hamming_window,
        )
        # Grayscale-adapted ResNet-18 backbone; its fc is replaced by
        # Identity so it emits raw pooled features.
        self.model = modify_for_grayscale(resnet18(pretrained=False))
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Identity()
        self.dropout = nn.Dropout(p=0.5, inplace=True)
        self.embedding_layer = nn.Linear(num_ftrs, 256)
        self.relu = nn.ReLU()
        self.classifier = nn.Linear(256, num_class)
        self.softmax = nn.Softmax(dim=1)
        self.to_db = torchaudio.transforms.AmplitudeToDB()
        self.normalize = torchvision.transforms.Normalize(mean=0.449, std=0.226)

    def forward(self, x):
        """Return class probabilities (and the embedding if ``return_emb``)."""
        # (batch, samples) -> (batch, 1, samples): add channel dim for 2-D conv.
        feats = self.melspectrogram(x.unsqueeze(1))
        feats = self.normalize(self.to_db(feats))
        pooled = self.dropout(self.model(feats))
        emb = self.embedding_layer(pooled)
        probs = self.softmax(self.classifier(self.relu(emb)))
        if self.return_emb:
            return probs, emb
        return probs