# musicnn-pytorch / musicnn_torch.py
# Duplicated from oriyonay/musicnn-pytorch (revision 0bbc70a).
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import librosa
import soundfile as sf
import warnings
# Suppress the PyTorch padding warning and other user warnings
warnings.filterwarnings('ignore', category=UserWarning)
# hyperparams
# Spectrogram front-end settings; these must match what the pretrained
# weights were trained with (changing them silently breaks inference).
SR = 16000        # target sample rate in Hz (audio is resampled to this)
N_MELS = 96       # number of mel bands
FFT_HOP = 256     # hop length in samples
FFT_SIZE = 512    # FFT window size in samples
# Tag vocabulary for the 'MTT' models (50 classes). The ORDER is load-bearing:
# labels[i] names output class i of the model — do not reorder.
MTT_LABELS = [
    'guitar', 'classical', 'slow', 'techno', 'strings', 'drums', 'electronic', 'rock',
    'fast', 'piano', 'ambient', 'beat', 'violin', 'vocal', 'synth', 'female', 'indian',
    'opera', 'male', 'singing', 'vocals', 'no vocals', 'harpsichord', 'loud', 'quiet',
    'flute', 'woman', 'male vocal', 'no vocal', 'pop', 'soft', 'sitar', 'solo', 'man',
    'classic', 'choir', 'voice', 'new age', 'dance', 'male voice', 'female vocal',
    'beats', 'harp', 'cello', 'no voice', 'weird', 'country', 'metal', 'female voice',
    'choral'
]
# Tag vocabulary for the 'MSD' models (50 classes); same ordering contract.
MSD_LABELS = [
    'rock', 'pop', 'alternative', 'indie', 'electronic', 'female vocalists', 'dance',
    '00s', 'alternative rock', 'jazz', 'beautiful', 'metal', 'chillout', 'male vocalists',
    'classic rock', 'soul', 'indie rock', 'Mellow', 'electronica', '80s', 'folk', '90s',
    'chill', 'instrumental', 'punk', 'oldies', 'blues', 'hard rock', 'ambient', 'acoustic',
    'experimental', 'female vocalist', 'guitar', 'Hip-Hop', '70s', 'party', 'country',
    'easy listening', 'sexy', 'catchy', 'funk', 'electro', 'heavy metal',
    'Progressive rock', '60s', 'rnb', 'indie pop', 'sad', 'House', 'happy'
]
# -------------------------
# Building blocks
# -------------------------
class ConvReLUBN(nn.Module):
    """Conv2d -> ReLU -> BatchNorm2d.

    Note the order: the activation is applied *before* batch norm
    (conv, relu, bn), as the class name spells out.
    """

    def __init__(self, in_ch, out_ch, kernel_size, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding)
        self.bn = nn.BatchNorm2d(out_ch, eps=0.001, momentum=0.01)

    def forward(self, x):
        activated = F.relu(self.conv(x))
        return self.bn(activated)
class TimbralBlock(nn.Module):
    """Frequency-spanning (vertical) convolution of the musicnn frontend.

    The kernel covers `mel_bins` mel bands by 7 frames; the time axis is
    padded by 3 on each side so the frame count is preserved, and the
    collapsed frequency axis is then max-pooled away.
    """

    def __init__(self, mel_bins, out_ch):
        super().__init__()
        self.conv_block = ConvReLUBN(1, out_ch, kernel_size=(7, mel_bins), padding=0)

    def forward(self, x):
        # F.pad takes (left, right, top, bottom): pad the time dim by 3/3.
        padded = F.pad(x, (0, 0, 3, 3))
        features = self.conv_block(padded)
        # Max-pool over what remains of the frequency axis (dim 3).
        return features.max(dim=3).values
class TemporalBlock(nn.Module):
    """Time-spanning (horizontal) convolution of the musicnn frontend.

    A (kernel_size x 1) kernel with 'same' padding keeps the frame count;
    the trailing axis is then max-pooled away.
    """

    def __init__(self, kernel_size, out_ch):
        super().__init__()
        self.conv_block = ConvReLUBN(1, out_ch, kernel_size=(kernel_size, 1), padding='same')

    def forward(self, x):
        out = self.conv_block(x)
        return out.max(dim=3).values
class MidEnd(nn.Module):
    """Three stacked 7-frame temporal conv layers with dense residuals.

    Returns a list of four (B, T, C) feature maps — the input plus each
    layer's (residual) output — so the backend can pool over all of them.
    """

    def __init__(self, in_ch, num_filt):
        super().__init__()
        self.c1_conv = nn.Conv2d(1, num_filt, kernel_size=(7, in_ch), padding=0)
        self.c1_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)
        self.c2_conv = nn.Conv2d(1, num_filt, kernel_size=(7, num_filt), padding=0)
        self.c2_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)
        self.c3_conv = nn.Conv2d(1, num_filt, kernel_size=(7, num_filt), padding=0)
        self.c3_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)

    @staticmethod
    def _as_conv_input(t):
        # (B, T, C, 1) -> (B, 1, T+6, C): move features to the kernel-width
        # axis and pad time by 3/3 so the height-7 conv keeps T frames.
        return F.pad(t.permute(0, 3, 1, 2), (0, 0, 3, 3))

    def forward(self, x):
        # (B, C, T) -> (B, T, C, 1)
        x = x.transpose(1, 2).unsqueeze(3)
        h1 = self.c1_bn(F.relu(self.c1_conv(self._as_conv_input(x))))
        h1 = h1.permute(0, 2, 1, 3)  # (B, T, num_filt, 1)
        h2 = self.c2_bn(F.relu(self.c2_conv(self._as_conv_input(h1))))
        h2 = h2.permute(0, 2, 1, 3)
        res2 = h2 + h1  # residual over layer 1
        h3 = self.c3_bn(F.relu(self.c3_conv(self._as_conv_input(res2))))
        h3 = h3.permute(0, 2, 1, 3)
        res3 = h3 + res2  # dense residual over layers 1 and 2
        return [x.squeeze(3), h1.squeeze(3), res2.squeeze(3), res3.squeeze(3)]
class Backend(nn.Module):
    """Temporal pooling plus a two-layer classifier head.

    Pools the (B, T, features) midend output over time with max and mean,
    then runs BN -> dropout -> FC -> BN -> dropout -> FC. Returns
    (logits, mean_pool, max_pool); callers apply the sigmoid themselves.
    """

    def __init__(self, in_ch, num_classes, hidden):
        super().__init__()
        self.bn_in = nn.BatchNorm1d(in_ch * 2, eps=0.001, momentum=0.01)
        self.fc1 = nn.Linear(in_ch * 2, hidden)
        self.bn_fc1 = nn.BatchNorm1d(hidden, eps=0.001, momentum=0.01)
        self.fc2 = nn.Linear(hidden, num_classes)

    def forward(self, x):
        # Pool over the time axis (dim 1).
        max_pool = x.max(dim=1).values
        mean_pool = x.mean(dim=1)
        # Interleave max/mean per feature: [max_0, mean_0, max_1, mean_1, ...]
        z = torch.stack([max_pool, mean_pool], dim=2).flatten(start_dim=1)
        z = self.bn_in(z)
        z = F.dropout(z, p=0.5, training=self.training)
        z = self.bn_fc1(F.relu(self.fc1(z)))
        z = F.dropout(z, p=0.5, training=self.training)
        return self.fc2(z), mean_pool, max_pool
# -------------------------
# MusicNN
# -------------------------
class MusicNN(nn.Module):
    """musicnn model: frontend (timbral + temporal filters) -> midend -> backend.

    Input: a (batch, time, N_MELS) log-mel spectrogram batch.
    Output: (logits, mean_pool, max_pool) from the backend head.
    """

    def __init__(self, num_classes, mid_filt=64, backend_units=200):
        super().__init__()
        self.bn_input = nn.BatchNorm2d(1, eps=0.001, momentum=0.01)
        # Vertical filters spanning 40% / 70% of the mel axis.
        self.timbral_1 = TimbralBlock(int(0.4 * N_MELS), int(1.6 * 128))
        self.timbral_2 = TimbralBlock(int(0.7 * N_MELS), int(1.6 * 128))
        # Horizontal filters at three temporal scales (128/64/32 frames).
        self.temp_1 = TemporalBlock(128, int(1.6 * 32))
        self.temp_2 = TemporalBlock(64, int(1.6 * 32))
        self.temp_3 = TemporalBlock(32, int(1.6 * 32))
        # 561 = 2 * int(1.6 * 128) + 3 * int(1.6 * 32) frontend channels.
        self.midend = MidEnd(in_ch=561, num_filt=mid_filt)
        self.backend = Backend(in_ch=mid_filt * 3 + 561, num_classes=num_classes, hidden=backend_units)

    def forward(self, x):
        x = self.bn_input(x.unsqueeze(1))
        # Each branch ends up (B, T, ch) after the transpose.
        branches = [
            self.timbral_1(x).transpose(1, 2),
            self.timbral_2(x).transpose(1, 2),
            self.temp_1(x).transpose(1, 2),
            self.temp_2(x).transpose(1, 2),
            self.temp_3(x).transpose(1, 2),
        ]
        frontend_features = torch.cat(branches, dim=2)
        mid_feats = self.midend(frontend_features.transpose(1, 2))
        return self.backend(torch.cat(mid_feats, dim=2))
# inference utils
def batch_data(audio_file, n_frames, overlap):
    """Load an audio file and slice its log-mel spectrogram into patches.

    Returns (batch, audio_rep): batch is (num_patches, n_frames, N_MELS)
    float32, audio_rep the full (frames, N_MELS) log-mel spectrogram.
    Files shorter than one patch are zero-padded into a single patch.

    Raises ValueError if the decoded audio is empty.
    """
    # soundfile handles some formats (e.g. MP3) more reliably than
    # librosa's loader in certain local environments.
    audio, sr = sf.read(audio_file)
    if audio.ndim > 1:
        # Downmix multi-channel audio to mono.
        audio = audio.mean(axis=1)
    if sr != SR:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=SR)
    if len(audio) == 0:
        raise ValueError(f'Audio file {audio_file} is empty or could not be loaded.')
    mel = librosa.feature.melspectrogram(
        y=audio, sr=SR, hop_length=FFT_HOP, n_fft=FFT_SIZE, n_mels=N_MELS
    ).T
    # Log compression with the same constants as the original musicnn.
    audio_rep = np.log10(10000 * mel.astype(np.float32) + 1)
    last_frame = audio_rep.shape[0] - n_frames + 1
    if last_frame <= 0:
        # Too short for a full patch: zero-pad up to n_frames.
        patch = np.zeros((n_frames, N_MELS), dtype=np.float32)
        patch[:audio_rep.shape[0], :] = audio_rep
        batches = [patch]
    else:
        batches = [
            audio_rep[start : start + n_frames, :]
            for start in range(0, last_frame, overlap)
        ]
    return np.stack(batches), audio_rep
def extractor(file_name, model='MTT_musicnn', input_length=3, input_overlap=False, device=None):
    """Run a musicnn model over an audio file, patch by patch.

    Returns (probs, labels): probs is a (num_patches, num_classes) array of
    sigmoid activations; labels is the tag vocabulary matching the class axis.

    Raises ValueError if the model name contains neither 'MTT' nor 'MSD'.
    """
    if device is None:
        # Auto-detect device if not specified.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Pick label set and architecture config from the model name.
    if 'MTT' in model:
        labels = MTT_LABELS
        config = {'num_classes': 50, 'mid_filt': 64, 'backend_units': 200}
    elif 'MSD' in model:
        labels = MSD_LABELS
        if 'big' in model:
            config = {'num_classes': 50, 'mid_filt': 512, 'backend_units': 500}
        else:
            config = {'num_classes': 50, 'mid_filt': 64, 'backend_units': 200}
    else:
        raise ValueError('Model not supported')
    net = MusicNN(**config)
    # Look for weights next to the script first, then under ./weights.
    weight_path = f'{model}.pt'
    if not os.path.exists(weight_path):
        weight_path = os.path.join('weights', f'{model}.pt')
    if os.path.exists(weight_path):
        net.load_state_dict(torch.load(weight_path, map_location=device))
    else:
        # Best effort: warn and continue with randomly-initialized weights.
        print(f'Warning: Weights not found at {weight_path}')
    net.to(device)
    net.eval()
    # Patch length in frames covering `input_length` seconds of audio.
    n_frames = librosa.time_to_frames(input_length, sr=SR, n_fft=FFT_SIZE, hop_length=FFT_HOP) + 1
    if input_overlap:
        overlap = librosa.time_to_frames(input_overlap, sr=SR, n_fft=FFT_SIZE, hop_length=FFT_HOP)
    else:
        overlap = n_frames  # non-overlapping, back-to-back patches
    batch, _ = batch_data(file_name, n_frames, overlap)
    with torch.no_grad():
        logits, _, _ = net(torch.from_numpy(batch).to(device))
    return torch.sigmoid(logits).cpu().numpy(), labels
def top_tags(file_name, model='MTT_musicnn', topN=3, device=None):
    """Return the topN tag names for a file, averaged over all patches."""
    if device is None:
        # Auto-detect device if not specified.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    probs, labels = extractor(file_name, model=model, device=device)
    mean_probs = probs.mean(axis=0)
    # Indices of the topN largest averaged probabilities, in descending order.
    best = np.argsort(mean_probs)[::-1][:topN]
    return [labels[i] for i in best]