"""PyTorch implementation of the musicnn music auto-tagging model.

Provides the network (frontend timbral/temporal conv blocks, a dense
residual mid-end, and a pooling backend) plus inference utilities that
turn an audio file into per-tag probabilities.

Input features are log-compressed mel spectrograms: 16 kHz audio,
96 mel bands, 512-sample FFT, 256-sample hop.
"""

import os
import warnings

import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F

# Suppress the PyTorch padding warning and other user warnings
warnings.filterwarnings('ignore', category=UserWarning)

# hyperparams
SR = 16000      # target sample rate (Hz); audio is resampled to this
N_MELS = 96     # mel bands in the input spectrogram
FFT_HOP = 256   # STFT hop length (samples)
FFT_SIZE = 512  # STFT window size (samples)

# 50 output tags for models trained on MagnaTagATune (order matters:
# index i of the network output corresponds to MTT_LABELS[i]).
MTT_LABELS = [
    'guitar', 'classical', 'slow', 'techno', 'strings', 'drums',
    'electronic', 'rock', 'fast', 'piano', 'ambient', 'beat', 'violin',
    'vocal', 'synth', 'female', 'indian', 'opera', 'male', 'singing',
    'vocals', 'no vocals', 'harpsichord', 'loud', 'quiet', 'flute',
    'woman', 'male vocal', 'no vocal', 'pop', 'soft', 'sitar', 'solo',
    'man', 'classic', 'choir', 'voice', 'new age', 'dance', 'male voice',
    'female vocal', 'beats', 'harp', 'cello', 'no voice', 'weird',
    'country', 'metal', 'female voice', 'choral'
]

# 50 output tags for models trained on the Million Song Dataset
# (same index-to-label convention as MTT_LABELS).
MSD_LABELS = [
    'rock', 'pop', 'alternative', 'indie', 'electronic',
    'female vocalists', 'dance', '00s', 'alternative rock', 'jazz',
    'beautiful', 'metal', 'chillout', 'male vocalists', 'classic rock',
    'soul', 'indie rock', 'Mellow', 'electronica', '80s', 'folk', '90s',
    'chill', 'instrumental', 'punk', 'oldies', 'blues', 'hard rock',
    'ambient', 'acoustic', 'experimental', 'female vocalist', 'guitar',
    'Hip-Hop', '70s', 'party', 'country', 'easy listening', 'sexy',
    'catchy', 'funk', 'electro', 'heavy metal', 'Progressive rock',
    '60s', 'rnb', 'indie pop', 'sad', 'House', 'happy'
]


# -------------------------
# Building blocks
# -------------------------
class ConvReLUBN(nn.Module):
    """Conv2d -> ReLU -> BatchNorm2d.

    Note the unusual order: batch norm is applied AFTER the activation,
    mirroring the original musicnn TensorFlow implementation.
    """

    def __init__(self, in_ch, out_ch, kernel_size, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding)
        self.bn = nn.BatchNorm2d(out_ch, eps=0.001, momentum=0.01)

    def forward(self, x):
        return self.bn(F.relu(self.conv(x)))


class TimbralBlock(nn.Module):
    """Frontend filter spanning `mel_bins` frequency bands and 7 frames.

    The time axis is zero-padded by 3 on each side so the 7-wide kernel
    preserves the number of frames; the frequency axis is left unpadded
    and the conv output is max-pooled over it.
    """

    def __init__(self, mel_bins, out_ch):
        super().__init__()
        self.conv_block = ConvReLUBN(1, out_ch, kernel_size=(7, mel_bins),
                                     padding=0)

    def forward(self, x):
        # x is assumed (batch, 1, time, mels) — see MusicNN.forward.
        # F.pad pads the last dim first: (0, 0) leaves mels alone,
        # (3, 3) pads time by 3 on each side.
        x = F.pad(x, (0, 0, 3, 3))
        x = self.conv_block(x)
        # Collapse the remaining frequency dim: (B, C, T, F') -> (B, C, T)
        return torch.max(x, dim=3).values


class TemporalBlock(nn.Module):
    """Frontend 1-band filter spanning `kernel_size` frames.

    Uses 'same' padding along time, then max-pools across frequency.
    """

    def __init__(self, kernel_size, out_ch):
        super().__init__()
        self.conv_block = ConvReLUBN(1, out_ch,
                                     kernel_size=(kernel_size, 1),
                                     padding='same')

    def forward(self, x):
        x = self.conv_block(x)
        # (B, C, T, F) -> (B, C, T) by max over frequency.
        return torch.max(x, dim=3).values


class MidEnd(nn.Module):
    """Dense mid-end: three 7-frame conv layers with residual connections.

    Takes (batch, in_ch, time) frontend features and returns a list of
    four (batch, time, channels) feature maps: the input itself plus the
    outputs after conv1, conv1+conv2 and conv1+conv2+conv3 residual sums.
    """

    def __init__(self, in_ch, num_filt):
        super().__init__()
        # Each conv consumes the full feature axis as its kernel width,
        # so its output has a singleton feature dim.
        self.c1_conv = nn.Conv2d(1, num_filt, kernel_size=(7, in_ch), padding=0)
        self.c1_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)
        self.c2_conv = nn.Conv2d(1, num_filt, kernel_size=(7, num_filt), padding=0)
        self.c2_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)
        self.c3_conv = nn.Conv2d(1, num_filt, kernel_size=(7, num_filt), padding=0)
        self.c3_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)

    def forward(self, x):
        # (B, C, T) -> (B, T, C, 1)
        x = x.transpose(1, 2).unsqueeze(3)

        # Layer 1: pad time by 3 each side (via a permute so time is the
        # last dim for F.pad), conv over (7, C), restore (B, T, filt, 1).
        x_perm = x.permute(0, 2, 3, 1)
        x1_pad = F.pad(x_perm, (3, 3, 0, 0))
        x1 = x1_pad.permute(0, 2, 3, 1)
        x1 = self.c1_bn(F.relu(self.c1_conv(x1)))
        x1_t = x1.permute(0, 2, 1, 3)

        # Layer 2 (same pad/conv dance) + residual from layer 1.
        x2_perm = x1_t.permute(0, 2, 3, 1)
        x2_pad = F.pad(x2_perm, (3, 3, 0, 0))
        x2 = x2_pad.permute(0, 2, 3, 1)
        x2 = self.c2_bn(F.relu(self.c2_conv(x2)))
        x2_t = x2.permute(0, 2, 1, 3)
        res_conv2 = x2_t + x1_t

        # Layer 3 + residual from the layer-2 sum.
        x3_perm = res_conv2.permute(0, 2, 3, 1)
        x3_pad = F.pad(x3_perm, (3, 3, 0, 0))
        x3 = x3_pad.permute(0, 2, 3, 1)
        x3 = self.c3_bn(F.relu(self.c3_conv(x3)))
        x3_t = x3.permute(0, 2, 1, 3)
        res_conv3 = x3_t + res_conv2

        # Drop the trailing singleton dim on every feature map.
        return [x.squeeze(3), x1_t.squeeze(3),
                res_conv2.squeeze(3), res_conv3.squeeze(3)]


class Backend(nn.Module):
    """Temporal pooling + 2-layer classifier head.

    Max- and mean-pools the (batch, time, in_ch) features over time,
    concatenates both pools (hence `in_ch * 2`), and maps them through
    a hidden layer to per-tag logits.

    Returns (logits, mean_pool, max_pool).
    """

    def __init__(self, in_ch, num_classes, hidden):
        super().__init__()
        self.bn_in = nn.BatchNorm1d(in_ch * 2, eps=0.001, momentum=0.01)
        self.fc1 = nn.Linear(in_ch * 2, hidden)
        self.bn_fc1 = nn.BatchNorm1d(hidden, eps=0.001, momentum=0.01)
        self.fc2 = nn.Linear(hidden, num_classes)

    def forward(self, x):
        max_pool = torch.max(x, dim=1).values
        mean_pool = torch.mean(x, dim=1)
        # Interleave max/mean per channel: (B, C, 2) -> (B, 2C).
        z = torch.stack([max_pool, mean_pool], dim=2)
        z = z.view(z.size(0), -1)
        z = self.bn_in(z)
        z = F.dropout(z, p=0.5, training=self.training)
        z = self.bn_fc1(F.relu(self.fc1(z)))
        z = F.dropout(z, p=0.5, training=self.training)
        logits = self.fc2(z)
        return logits, mean_pool, max_pool


# -------------------------
# MusicNN
# -------------------------
class MusicNN(nn.Module):
    """Full musicnn model: frontend -> mid-end -> backend.

    Expects input of shape (batch, time, N_MELS); returns
    (logits, mean_pool, max_pool) from the backend.
    """

    def __init__(self, num_classes, mid_filt=64, backend_units=200):
        super().__init__()
        self.bn_input = nn.BatchNorm2d(1, eps=0.001, momentum=0.01)
        # Two timbral filters spanning 40% / 70% of the mel axis,
        # 204 channels each (int(1.6 * 128)).
        self.timbral_1 = TimbralBlock(int(0.4 * N_MELS), int(1.6 * 128))
        self.timbral_2 = TimbralBlock(int(0.7 * N_MELS), int(1.6 * 128))
        # Three temporal filters (128/64/32 frames), 51 channels each.
        self.temp_1 = TemporalBlock(128, int(1.6 * 32))
        self.temp_2 = TemporalBlock(64, int(1.6 * 32))
        self.temp_3 = TemporalBlock(32, int(1.6 * 32))
        # 561 = 2 * 204 timbral + 3 * 51 temporal channels.
        self.midend = MidEnd(in_ch=561, num_filt=mid_filt)
        # Backend sees the concatenation of all four mid-end maps:
        # the 561-ch input plus three mid_filt-ch residual maps.
        self.backend = Backend(in_ch=mid_filt * 3 + 561,
                               num_classes=num_classes,
                               hidden=backend_units)

    def forward(self, x):
        x = x.unsqueeze(1)            # (B, T, M) -> (B, 1, T, M)
        x = self.bn_input(x)
        # Each frontend block yields (B, C, T); transpose to (B, T, C).
        f74 = self.timbral_1(x).transpose(1, 2)
        f77 = self.timbral_2(x).transpose(1, 2)
        s1 = self.temp_1(x).transpose(1, 2)
        s2 = self.temp_2(x).transpose(1, 2)
        s3 = self.temp_3(x).transpose(1, 2)
        frontend_features = torch.cat([f74, f77, s1, s2, s3], dim=2)
        mid_feats = self.midend(frontend_features.transpose(1, 2))
        z = torch.cat(mid_feats, dim=2)  # (B, T, 561 + 3 * mid_filt)
        logits, mean_pool, max_pool = self.backend(z)
        return logits, mean_pool, max_pool


# inference utils
def batch_data(audio_file, n_frames, overlap):
    """Load an audio file and slice its log-mel spectrogram into patches.

    Args:
        audio_file: path to the audio file.
        n_frames: frames per patch.
        overlap: hop (in frames) between consecutive patch starts.

    Returns:
        (batch, audio_rep): a (num_patches, n_frames, N_MELS) float32
        array and the full (frames, N_MELS) spectrogram. Audio shorter
        than one patch is zero-padded into a single patch.

    Raises:
        ValueError: if the decoded audio is empty.
    """
    # Use soundfile as it handles MP3 more reliably in some local environments
    audio, sr = sf.read(audio_file)

    # Convert to mono if stereo
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)

    # Resample to 16000 if necessary
    if sr != SR:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=SR)

    if len(audio) == 0:
        raise ValueError(f'Audio file {audio_file} is empty or could not be loaded.')

    audio_rep = librosa.feature.melspectrogram(
        y=audio, sr=SR, hop_length=FFT_HOP, n_fft=FFT_SIZE, n_mels=N_MELS
    ).T
    audio_rep = audio_rep.astype(np.float32)
    # musicnn's log compression: log10(10000 * mel + 1).
    audio_rep = np.log10(10000 * audio_rep + 1)

    last_frame = audio_rep.shape[0] - n_frames + 1
    batches = []
    if last_frame <= 0:
        # Clip shorter than one patch: zero-pad to a single full patch.
        patch = np.zeros((n_frames, N_MELS), dtype=np.float32)
        patch[:audio_rep.shape[0], :] = audio_rep
        batches.append(patch)
    else:
        for time_stamp in range(0, last_frame, overlap):
            patch = audio_rep[time_stamp : time_stamp + n_frames, :]
            batches.append(patch)
    return np.stack(batches), audio_rep


def extractor(file_name, model='MTT_musicnn', input_length=3,
              input_overlap=False, device=None):
    """Run a musicnn model over an audio file.

    Args:
        file_name: path to the audio file.
        model: model name; must contain 'MTT' or 'MSD' (and optionally
            'big' for the large MSD configuration). Also used to locate
            the weight file '<model>.pt' (cwd, then 'weights/').
        input_length: patch length in seconds.
        input_overlap: patch hop in seconds, or a falsy value for
            non-overlapping patches.
        device: torch device string; auto-detects CUDA when None.

    Returns:
        (probs, labels): sigmoid tag probabilities of shape
        (num_patches, 50) and the matching label list.

    Raises:
        ValueError: if `model` names an unsupported model family.
    """
    # Auto-detect device if not specified
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if 'MTT' in model:
        labels = MTT_LABELS
        config = {'num_classes': 50, 'mid_filt': 64, 'backend_units': 200}
    elif 'MSD' in model:
        labels = MSD_LABELS
        if 'big' in model:
            config = {'num_classes': 50, 'mid_filt': 512, 'backend_units': 500}
        else:
            config = {'num_classes': 50, 'mid_filt': 64, 'backend_units': 200}
    else:
        raise ValueError('Model not supported')

    # Load model
    net = MusicNN(**config)
    weight_path = f'{model}.pt'
    if not os.path.exists(weight_path):
        weight_path = os.path.join('weights', f'{model}.pt')
    if os.path.exists(weight_path):
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # weight files from trusted sources (weights_only=True would be
        # safer on torch >= 1.13).
        net.load_state_dict(torch.load(weight_path, map_location=device))
    else:
        # Deliberate best-effort: proceed with random weights but warn.
        print(f'Warning: Weights not found at {weight_path}')
    net.to(device)
    net.eval()

    # Prep data: convert seconds to spectrogram frames (+1 to include
    # the frame at the boundary).
    n_frames = librosa.time_to_frames(input_length, sr=SR, n_fft=FFT_SIZE,
                                      hop_length=FFT_HOP) + 1
    if not input_overlap:
        overlap = n_frames
    else:
        overlap = librosa.time_to_frames(input_overlap, sr=SR, n_fft=FFT_SIZE,
                                         hop_length=FFT_HOP)
    batch, _ = batch_data(file_name, n_frames, overlap)

    batch_torch = torch.from_numpy(batch).to(device)
    with torch.no_grad():
        logits, _, _ = net(batch_torch)
        probs = torch.sigmoid(logits).cpu().numpy()
    return probs, labels


def top_tags(file_name, model='MTT_musicnn', topN=3, device=None):
    """Return the `topN` tags for an audio file, best first.

    Probabilities are averaged over all patches before ranking.
    """
    # Auto-detect device if not specified
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    probs, labels = extractor(file_name, model=model, device=device)
    avg_probs = np.mean(probs, axis=0)
    top_indices = avg_probs.argsort()[-topN:][::-1]
    return [labels[i] for i in top_indices]