import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as T
import torchvision
from torchvision.models import resnet18


def modify_for_grayscale(model):
    """Replace ``model.conv1`` so the network accepts 1-channel input.

    The new first convolution keeps the original layer's hyperparameters
    (out_channels, kernel_size, stride, padding, bias) but takes a single
    input channel. Its weights are initialized as the mean of the original
    kernels over the input-channel axis, which approximates the original
    response on grayscale input.

    Args:
        model: A torchvision ResNet-style model exposing a ``conv1`` layer.

    Returns:
        The same model instance, with ``conv1`` replaced in place.
    """
    first_conv_layer = model.conv1
    new_first_conv_layer = nn.Conv2d(
        in_channels=1,
        out_channels=first_conv_layer.out_channels,
        kernel_size=first_conv_layer.kernel_size,
        stride=first_conv_layer.stride,
        padding=first_conv_layer.padding,
        bias=first_conv_layer.bias is not None,
    )
    with torch.no_grad():
        # Collapse the RGB kernels into one grayscale kernel by averaging.
        new_first_conv_layer.weight[:, 0] = first_conv_layer.weight.mean(dim=1)
        if first_conv_layer.bias is not None:
            # Copy the bias values instead of assigning the Parameter object,
            # so the new layer does not alias the old model's parameters.
            # (ResNet's conv1 has no bias, so this branch is normally dead.)
            new_first_conv_layer.bias.copy_(first_conv_layer.bias)
    model.conv1 = new_first_conv_layer
    return model


class ResNet_LogSpec(nn.Module):
    """ResNet-18 binary/multi-class classifier over log power spectrograms.

    The input waveform is converted to a power spectrogram (Hamming window),
    mapped to dB, normalized, and fed through a grayscale-adapted ResNet-18
    followed by a dropout / 256-d embedding / ReLU / linear head.

    Args:
        sample_rate: Input sample rate; must be 16000, 22050 or 24000.
        return_emb: If True, ``forward`` also returns the 256-d embedding.
        num_class: Number of output classes.
    """

    # STFT parameters per supported sample rate:
    # sample_rate -> (n_fft, win_length, hop_length)
    _STFT_PARAMS = {
        16000: (512, 400, 160),
        22050: (704, 552, 220),
        24000: (768, 600, 240),
    }

    def __init__(self, sample_rate=16000, return_emb=False, num_class=2):
        super().__init__()
        self.threshold = 0.1  # decision threshold used by callers — TODO confirm
        self.return_emb = return_emb
        try:
            n_fft, win_length, hop_length = self._STFT_PARAMS[sample_rate]
        except KeyError:
            raise ValueError(f"Unsupported sample rate: {sample_rate}") from None
        self.sample_rate = sample_rate
        self.stft = T.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            power=2,
            window_fn=torch.hamming_window,
        )
        self.model = modify_for_grayscale(resnet18(pretrained=False))
        num_ftrs = self.model.fc.in_features
        # Strip the ImageNet head; the classifier head below replaces it.
        self.model.fc = nn.Identity()
        self.dropout = nn.Dropout(p=0.5, inplace=True)
        self.embedding_layer = nn.Linear(num_ftrs, 256)
        self.relu = nn.ReLU()
        self.classifier = nn.Linear(256, num_class)
        # NOTE(review): forward() returns softmax probabilities; if trained
        # with nn.CrossEntropyLoss the softmax would be applied twice — verify
        # against the training loop.
        self.softmax = nn.Softmax(dim=1)
        self.to_db = T.AmplitudeToDB()
        self.normalize = torchvision.transforms.Normalize(mean=0.449, std=0.226)

    def forward(self, x):
        """Classify a batch of waveforms.

        Args:
            x: Waveform batch — assumed shape (batch, samples); TODO confirm.

        Returns:
            Softmax probabilities of shape (batch, num_class); when
            ``return_emb`` is True, a tuple ``(probs, embeddings)`` with
            embeddings of shape (batch, 256).
        """
        x = x.unsqueeze(1)      # add a channel axis: (batch, 1, samples)
        x = self.stft(x)        # power spectrogram
        x = self.to_db(x)       # log (dB) scale
        x = self.normalize(x)
        x = self.model(x)       # ResNet-18 features
        x = self.dropout(x)
        emb = self.embedding_layer(x)
        x = self.relu(emb)
        logits = self.classifier(x)
        out = self.softmax(logits)
        if self.return_emb:
            return out, emb
        return out


class ResNet_MelSpec(nn.Module):
    """ResNet-18 binary/multi-class classifier over log mel spectrograms.

    Same architecture as :class:`ResNet_LogSpec`, but the frontend is an
    80-band mel spectrogram (20 Hz to a rate-dependent ``f_max``).

    Args:
        sample_rate: Input sample rate; must be 16000, 22050 or 24000.
        return_emb: If True, ``forward`` also returns the 256-d embedding.
        num_class: Number of output classes.
    """

    # Mel-spectrogram parameters per supported sample rate:
    # sample_rate -> (n_fft, win_length, hop_length, f_max)
    _MEL_PARAMS = {
        16000: (512, 400, 160, 7600),
        22050: (704, 552, 220, 10500),
        24000: (768, 600, 240, 11000),
    }

    def __init__(self, sample_rate=16000, return_emb=False, num_class=2):
        super().__init__()
        self.threshold = 0.4  # decision threshold used by callers — TODO confirm
        self.return_emb = return_emb
        try:
            n_fft, win_length, hop_length, f_max = self._MEL_PARAMS[sample_rate]
        except KeyError:
            raise ValueError(f"Unsupported sample rate: {sample_rate}") from None
        self.melspectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            f_min=20,
            f_max=f_max,
            n_mels=80,
            window_fn=torch.hamming_window,
        )
        self.model = modify_for_grayscale(resnet18(pretrained=False))
        num_ftrs = self.model.fc.in_features
        # Strip the ImageNet head; the classifier head below replaces it.
        self.model.fc = nn.Identity()
        self.dropout = nn.Dropout(p=0.5, inplace=True)
        self.embedding_layer = nn.Linear(num_ftrs, 256)
        self.relu = nn.ReLU()
        self.classifier = nn.Linear(256, num_class)
        # NOTE(review): forward() returns softmax probabilities; if trained
        # with nn.CrossEntropyLoss the softmax would be applied twice — verify
        # against the training loop.
        self.softmax = nn.Softmax(dim=1)
        self.to_db = torchaudio.transforms.AmplitudeToDB()
        self.normalize = torchvision.transforms.Normalize(mean=0.449, std=0.226)

    def forward(self, x):
        """Classify a batch of waveforms.

        Args:
            x: Waveform batch — assumed shape (batch, samples); TODO confirm.

        Returns:
            Softmax probabilities of shape (batch, num_class); when
            ``return_emb`` is True, a tuple ``(probs, embeddings)`` with
            embeddings of shape (batch, 256).
        """
        x = x.unsqueeze(1)          # add a channel axis: (batch, 1, samples)
        x = self.melspectrogram(x)  # 80-band mel spectrogram
        x = self.to_db(x)           # log (dB) scale
        x = self.normalize(x)
        x = self.model(x)           # ResNet-18 features
        x = self.dropout(x)
        emb = self.embedding_layer(x)
        x = self.relu(emb)
        logits = self.classifier(x)
        out = self.softmax(logits)
        if self.return_emb:
            return out, emb
        return out