# File size: 15,084 Bytes
# 0bbc70a
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import soundfile as sf
import librosa
from transformers import PretrainedConfig, PreTrainedModel
from huggingface_hub import PyTorchModelHubMixin
# Suppress warnings
import warnings
warnings.filterwarnings('ignore')
class MusicNNConfig(PretrainedConfig):
    """Configuration object holding the MusicNN hyper-parameters.

    Args:
        num_classes: Output width of the final classification layer.
        mid_filt: Number of filters in each mid-end convolution.
        backend_units: Width of the hidden dense layer in the back-end.
        dataset: Identifier of the tag vocabulary; ``MusicNN.predict_tags``
            recognises ``'MTT'`` and ``'MSD'``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = 'musicnn'

    def __init__(self, num_classes: int = 50, mid_filt: int = 64,
                 backend_units: int = 200, dataset: str = 'MTT', **kwargs):
        # The model-specific fields are stored first, in this order, and the
        # base class is initialised last with whatever keyword args remain.
        self.num_classes = num_classes
        self.mid_filt = mid_filt
        self.backend_units = backend_units
        self.dataset = dataset
        super().__init__(**kwargs)
# -------------------------
# Building blocks
# -------------------------
class ConvReLUBN(nn.Module):
    """2-D convolution followed by ReLU and then batch normalisation.

    Note the order: the activation is applied *before* the normalisation.

    Args:
        in_ch: Number of input channels of the convolution.
        out_ch: Number of output channels (also the BatchNorm feature count).
        kernel_size: Kernel size passed to ``nn.Conv2d``.
        padding: Padding passed to ``nn.Conv2d`` (default ``0``).
    """

    def __init__(self, in_ch, out_ch, kernel_size, padding=0):
        super().__init__()
        # Attribute names are part of the state_dict keys; do not rename.
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding)
        self.bn = nn.BatchNorm2d(out_ch, eps=0.001, momentum=0.01)

    def forward(self, x):
        """Return ``bn(relu(conv(x)))``."""
        activated = F.relu(self.conv(x))
        return self.bn(activated)
class TimbralBlock(nn.Module):
    """Front-end branch whose kernel spans 7 rows and ``mel_bins`` columns.

    Args:
        mel_bins: Kernel extent along the last input axis.
        out_ch: Number of convolution filters.
    """

    def __init__(self, mel_bins, out_ch):
        super().__init__()
        self.conv_block = ConvReLUBN(1, out_ch, kernel_size=(7, mel_bins), padding=0)

    def forward(self, x):
        """Zero-pad axis 2 by 3 on each side, convolve, then max over axis 3.

        Padding by 3 with a kernel height of 7 keeps the length of axis 2
        unchanged; the reduction removes the last axis.
        """
        padded = F.pad(x, (0, 0, 3, 3))
        features = self.conv_block(padded)
        peak_values, _ = torch.max(features, dim=3)
        return peak_values
class TemporalBlock(nn.Module):
    """Front-end branch with a tall ``(kernel_size, 1)`` kernel.

    Args:
        kernel_size: Kernel extent along axis 2 of the input.
        out_ch: Number of convolution filters.
    """

    def __init__(self, kernel_size, out_ch):
        super().__init__()
        # 'same' padding keeps both spatial dimensions of the input.
        self.conv_block = ConvReLUBN(1, out_ch, kernel_size=(kernel_size, 1), padding='same')

    def forward(self, x):
        """Convolve and reduce the last axis with a max."""
        features = self.conv_block(x)
        peak_values, _ = torch.max(features, dim=3)
        return peak_values
class MidEnd(nn.Module):
    """Three stacked conv stages over time with residual connections.

    Each stage treats its input as a one-channel image of shape
    ``(time, features)``, convolves it with a kernel that is 7 frames tall
    and as wide as the whole feature axis, and applies ReLU followed by
    batch normalisation. Stages two and three add their input back
    (residual sum).

    Args:
        in_ch: Feature width of the tensor fed to the first stage.
        num_filt: Number of filters (output feature width) of every stage.
    """

    def __init__(self, in_ch, num_filt):
        super().__init__()
        # Attribute names are state_dict keys; they must stay as they are.
        self.c1_conv = nn.Conv2d(1, num_filt, kernel_size=(7, in_ch), padding=0)
        self.c1_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)
        self.c2_conv = nn.Conv2d(1, num_filt, kernel_size=(7, num_filt), padding=0)
        self.c2_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)
        self.c3_conv = nn.Conv2d(1, num_filt, kernel_size=(7, num_filt), padding=0)
        self.c3_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)

    @staticmethod
    def _stage(feats, conv, bn):
        """Run one time-padded conv -> ReLU -> BN stage.

        Args:
            feats: Tensor shaped ``(batch, time, features, 1)``.
            conv: Convolution whose kernel width equals ``features``.
            bn: Batch-norm layer matching ``conv``'s output channels.

        Returns:
            Tensor shaped ``(batch, time, conv.out_channels, 1)``.
        """
        # (B, T, C, 1) -> (B, 1, T, C): a single permute replaces the two
        # consecutive permutes the stage used to perform around the padding.
        image = feats.permute(0, 3, 1, 2)
        # Zero-pad the time axis so the 7-tall kernel preserves its length.
        half_kernel = conv.kernel_size[0] // 2
        image = F.pad(image, (0, 0, half_kernel, half_kernel))
        out = bn(F.relu(conv(image)))  # (B, F, T, 1)
        # Back to the (B, T, F, 1) layout expected by the next stage.
        return out.permute(0, 2, 1, 3)

    def forward(self, x):
        """Compute the mid-end feature maps.

        Args:
            x: Tensor shaped ``(batch, in_ch, time)``.

        Returns:
            A list of four ``(batch, time, features)`` tensors: the
            transposed input, the first stage output, and the two residual
            sums.
        """
        x = x.transpose(1, 2).unsqueeze(3)  # (B, T, in_ch, 1)
        x1 = self._stage(x, self.c1_conv, self.c1_bn)
        res_conv2 = self._stage(x1, self.c2_conv, self.c2_bn) + x1
        res_conv3 = self._stage(res_conv2, self.c3_conv, self.c3_bn) + res_conv2
        return [t.squeeze(3) for t in (x, x1, res_conv2, res_conv3)]
class Backend(nn.Module):
    """Temporal pooling followed by a two-layer classifier.

    Args:
        in_ch: Feature width of the incoming sequence.
        num_classes: Number of output logits.
        hidden: Width of the hidden dense layer.
    """

    def __init__(self, in_ch, num_classes, hidden):
        super().__init__()
        pooled_width = in_ch * 2  # max and mean statistics per feature
        self.bn_in = nn.BatchNorm1d(pooled_width, eps=0.001, momentum=0.01)
        self.fc1 = nn.Linear(pooled_width, hidden)
        self.bn_fc1 = nn.BatchNorm1d(hidden, eps=0.001, momentum=0.01)
        self.fc2 = nn.Linear(hidden, num_classes)

    def forward(self, x):
        """Classify a ``(batch, time, in_ch)`` sequence.

        Returns:
            ``(logits, mean_pool, max_pool)`` where the pools are the
            per-feature statistics over axis 1.
        """
        max_pool = x.max(dim=1).values
        mean_pool = x.mean(dim=1)
        # Stacking on a new last axis and flattening interleaves the two
        # statistics per feature: [max_0, mean_0, max_1, mean_1, ...].
        pooled = torch.stack((max_pool, mean_pool), dim=2)
        pooled = pooled.reshape(pooled.size(0), -1)
        pooled = F.dropout(self.bn_in(pooled), p=0.5, training=self.training)
        hidden = self.bn_fc1(F.relu(self.fc1(pooled)))
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        return self.fc2(hidden), mean_pool, max_pool
class MusicNN(PreTrainedModel, PyTorchModelHubMixin):
    """MusicNN audio tagger wrapped as a Hugging Face model.

    The network consumes log-mel patches, runs five parallel front-end
    branches (two timbral, three temporal), a residual mid-end, and a
    pooling + dense back-end. ``predict_tags`` and ``extract_embeddings``
    provide file-level inference on top of ``forward``.
    """

    config_class = MusicNNConfig

    # Spectrogram / patching parameters used by the file-level helpers.
    _SAMPLE_RATE = 16000
    _N_FFT = 512
    _HOP_LENGTH = 256
    _N_MELS = 96
    # librosa.time_to_frames(3, sr=16000, n_fft=512, hop_length=256) + 1
    _PATCH_FRAMES = 187

    # Tag vocabularies, indexed by ``config.dataset``; position i names logit i.
    _TAG_LABELS = {
        'MTT': [
            'guitar', 'classical', 'slow', 'techno', 'strings', 'drums', 'electronic', 'rock',
            'fast', 'piano', 'ambient', 'beat', 'violin', 'vocal', 'synth', 'female', 'indian',
            'opera', 'male', 'singing', 'vocals', 'no vocals', 'harpsichord', 'loud', 'quiet',
            'flute', 'woman', 'male vocal', 'no vocal', 'pop', 'soft', 'sitar', 'solo', 'man',
            'classic', 'choir', 'voice', 'new age', 'dance', 'male voice', 'female vocal',
            'beats', 'harp', 'cello', 'no voice', 'weird', 'country', 'metal', 'female voice',
            'choral',
        ],
        'MSD': [
            'rock', 'pop', 'alternative', 'indie', 'electronic', 'female vocalists', 'dance',
            '00s', 'alternative rock', 'jazz', 'beautiful', 'metal', 'chillout', 'male vocalists',
            'classic rock', 'soul', 'indie rock', 'Mellow', 'electronica', '80s', 'folk', '90s',
            'chill', 'instrumental', 'punk', 'oldies', 'blues', 'hard rock', 'ambient', 'acoustic',
            'experimental', 'female vocalist', 'guitar', 'Hip-Hop', '70s', 'party', 'country',
            'easy listening', 'sexy', 'catchy', 'funk', 'electro', 'heavy metal',
            'Progressive rock', '60s', 'rnb', 'indie pop', 'sad', 'House', 'happy',
        ],
    }

    def __init__(self, config):
        super().__init__(config)
        self.bn_input = nn.BatchNorm2d(1, eps=0.001, momentum=0.01)
        # Front-end: 2 timbral branches of 204 filters + 3 temporal branches
        # of 51 filters = 561 concatenated features per frame.
        self.timbral_1 = TimbralBlock(int(0.4 * 96), int(1.6 * 128))
        self.timbral_2 = TimbralBlock(int(0.7 * 96), int(1.6 * 128))
        self.temp_1 = TemporalBlock(128, int(1.6 * 32))
        self.temp_2 = TemporalBlock(64, int(1.6 * 32))
        self.temp_3 = TemporalBlock(32, int(1.6 * 32))
        self.midend = MidEnd(in_ch=561, num_filt=config.mid_filt)
        self.backend = Backend(
            in_ch=config.mid_filt * 3 + 561,
            num_classes=config.num_classes,
            hidden=config.backend_units,
        )

    def forward(self, x):
        """Run the network on a batch of log-mel patches.

        Args:
            x: Tensor shaped ``(batch, frames, mel_bins)``; the file-level
                helpers of this class feed ``(1, 187, 96)`` patches.

        Returns:
            ``(logits, mean_pool, max_pool)``: logits of shape
            ``(batch, num_classes)`` and the mean/max over time of the
            concatenated mid-end features, each of shape
            ``(batch, mid_filt * 3 + 561)``.
        """
        x = x.unsqueeze(1)  # add the single image channel
        x = self.bn_input(x)
        # Every branch yields (batch, filters, frames); move filters last.
        f74 = self.timbral_1(x).transpose(1, 2)
        f77 = self.timbral_2(x).transpose(1, 2)
        s1 = self.temp_1(x).transpose(1, 2)
        s2 = self.temp_2(x).transpose(1, 2)
        s3 = self.temp_3(x).transpose(1, 2)
        frontend_features = torch.cat([f74, f77, s1, s2, s3], dim=2)
        mid_feats = self.midend(frontend_features.transpose(1, 2))
        z = torch.cat(mid_feats, dim=2)
        logits, mean_pool, max_pool = self.backend(z)
        return logits, mean_pool, max_pool

    @staticmethod
    def _log_mel(audio, file_sr, target_sr, audio_file):
        """Turn a mono waveform into a ``(frames, 96)`` float32 log-mel array.

        Args:
            audio: 1-D waveform.
            file_sr: Sample rate of ``audio``.
            target_sr: Sample rate to resample to before the spectrogram.
            audio_file: Path, used only in the error message.

        Raises:
            ValueError: If ``audio`` is empty.
        """
        # Reject empty audio *before* resampling so the caller gets this
        # message rather than whatever the resampler raises on empty input.
        if len(audio) == 0:
            raise ValueError(f'Audio file {audio_file} is empty or could not be loaded.')
        if file_sr != target_sr:
            audio = librosa.resample(audio, orig_sr=file_sr, target_sr=target_sr)
        audio_rep = librosa.feature.melspectrogram(
            y=audio,
            sr=target_sr,
            hop_length=MusicNN._HOP_LENGTH,
            n_fft=MusicNN._N_FFT,
            n_mels=MusicNN._N_MELS,
        ).T
        audio_rep = audio_rep.astype(np.float32)
        # Log compression of the mel power values.
        return np.log10(10000 * audio_rep + 1)

    @staticmethod
    def _split_into_patches(audio_rep):
        """Cut a spectrogram into non-overlapping 187-frame patches.

        A spectrogram shorter than one patch is zero-padded to a single
        patch; otherwise trailing frames that do not fill a whole patch
        are dropped.

        Returns:
            float32 array shaped ``(num_patches, 187, 96)``.
        """
        n_frames = MusicNN._PATCH_FRAMES
        total = audio_rep.shape[0]
        if total < n_frames:
            patch = np.zeros((n_frames, MusicNN._N_MELS), dtype=np.float32)
            patch[:total, :] = audio_rep
            return patch[np.newaxis]
        # Hop equals the patch length, i.e. the windows do not overlap.
        starts = range(0, total - n_frames + 1, n_frames)
        return np.stack([audio_rep[start:start + n_frames, :] for start in starts])

    def _infer_file(self, audio_file):
        """Load ``audio_file`` and run the model on each of its patches.

        Side effects: moves the model to CUDA when available (CPU
        otherwise) and switches it to eval mode.

        Returns:
            A list with one ``(logits, mean_pool, max_pool)`` tuple per
            patch, each computed with batch size 1.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(device)

        audio, file_sr = sf.read(audio_file)
        if audio.ndim > 1:
            # Down-mix multi-channel audio by averaging the channels.
            audio = np.mean(audio, axis=1)
        audio_rep = self._log_mel(audio, file_sr, self._SAMPLE_RATE, audio_file)
        patches = torch.from_numpy(self._split_into_patches(audio_rep)).to(device)

        self.eval()
        with torch.no_grad():
            return [self(patch.unsqueeze(0)) for patch in patches]

    @staticmethod
    def preprocess_audio(audio_file, sr=16000):
        """Load an audio file and return its log-mel spectrogram.

        Loading is attempted with librosa first and falls back to
        soundfile (with channel averaging) if librosa fails or returns no
        samples.

        Args:
            audio_file: Path to the audio file.
            sr: Target sample rate for resampling and the mel filterbank.

        Returns:
            float32 array shaped ``(frames, 96)``.

        Raises:
            ValueError: If neither loader can read the file, or the audio
                is empty.
        """
        try:
            audio, file_sr = librosa.load(audio_file, sr=None)
            if len(audio) == 0:
                raise ValueError("Empty audio from librosa")
        except Exception:
            # Deliberate best-effort: any librosa failure triggers the
            # soundfile fallback rather than aborting.
            try:
                audio, file_sr = sf.read(audio_file)
                if audio.ndim > 1:
                    audio = np.mean(audio, axis=1)
            except Exception as e:
                raise ValueError(f'Could not load audio file {audio_file}: {e}') from e
        return MusicNN._log_mel(audio, file_sr, sr, audio_file)

    def predict_tags(self, audio_file, top_k=5):
        """Return the ``top_k`` most probable tags for an audio file.

        Sigmoid probabilities are computed per 187-frame patch and
        averaged over all patches before ranking.

        Args:
            audio_file: Path to an audio file readable by soundfile.
            top_k: Number of tags to return; values ``<= 0`` yield an
                empty list.

        Returns:
            Tag names ordered from most to least probable.

        Raises:
            ValueError: If ``config.dataset`` is not ``'MTT'`` or ``'MSD'``,
                or if the audio is empty.
        """
        # Validate the vocabulary up front instead of after full inference.
        labels = self._TAG_LABELS.get(self.config.dataset)
        if labels is None:
            raise ValueError(f"Unknown dataset: {self.config.dataset}")

        all_probs = [
            torch.sigmoid(logits).squeeze(0).cpu().numpy()
            for logits, _, _ in self._infer_file(audio_file)
        ]
        avg_probs = np.mean(all_probs, axis=0)

        # Reverse first, then take a prefix: slicing with [-top_k:] would
        # return *every* tag for top_k == 0 because -0 == 0.
        top_indices = np.argsort(avg_probs)[::-1][:max(top_k, 0)]
        return [labels[i] for i in top_indices]

    def extract_embeddings(self, audio_file, layer=None, pool='mean'):
        """
        Extract embeddings from audio file.

        The embedding of each 187-frame patch is the time-pooled mid-end
        feature vector returned by ``forward``; patch embeddings are then
        averaged over the file.

        Args:
            audio_file: path to audio file
            layer: which layer to extract from (ignored for simplicity, uses final embeddings)
            pool: pooling method ('mean', 'max', or 'both'); any other
                value falls back to 'mean'

        Returns:
            embeddings as numpy array ('both' concatenates mean then max)

        Raises:
            ValueError: If the audio is empty.
        """
        all_embeddings = []
        for _, mean_pool, max_pool in self._infer_file(audio_file):
            if pool == 'max':
                embeddings = max_pool
            elif pool == 'both':
                embeddings = torch.cat([mean_pool, max_pool], dim=1)
            else:
                # 'mean' and, for backward compatibility, unknown values.
                embeddings = mean_pool
            all_embeddings.append(embeddings.squeeze(0).cpu().numpy())
        return np.mean(all_embeddings, axis=0)
# For uploading to Hugging Face Hub
if __name__ == '__main__':
    # Script entry point: build the MTT model, load local weights, export it
    # with the custom-code mapping, and upload the folder to the Hub.
    import json
    import os
    from huggingface_hub import HfApi
    import shutil

    # Create the model with MTT config
    config = MusicNNConfig(
        num_classes=50,
        mid_filt=64,
        backend_units=200,
        dataset='MTT'
    )
    model = MusicNN(config)

    # Load the weights. map_location='cpu' lets a checkpoint that was saved
    # from CUDA tensors load on a machine without a GPU.
    # NOTE(review): torch.load unpickles the file; only use trusted weights.
    state_dict = torch.load('weights/MTT_musicnn.pt', map_location='cpu')
    model.load_state_dict(state_dict)

    # Save the model and ship this module next to it so that the
    # 'auto_map' entries below can resolve 'musicnn.*'.
    save_dir = 'musicnn-pytorch'
    os.makedirs(save_dir, exist_ok=True)
    model.save_pretrained(save_dir)
    shutil.copy('musicnn.py', save_dir)

    # Overwrite the config.json written by save_pretrained with a version
    # that carries the custom-code mapping.
    config_dict = config.to_dict()
    config_dict.update({
        '_name_or_path': 'oriyonay/musicnn-pytorch',
        'architectures': ['MusicNN'],
        'auto_map': {
            'AutoConfig': 'musicnn.MusicNNConfig',
            'AutoModel': 'musicnn.MusicNN'
        },
        'model_type': 'musicnn'
    })
    with open(os.path.join(save_dir, 'config.json'), 'w', encoding='utf-8') as f:
        json.dump(config_dict, f, indent=4)

    # Push to Hugging Face
    api = HfApi()
    api.upload_folder(
        folder_path=save_dir,
        repo_id='oriyonay/musicnn-pytorch',
        repo_type='model'
    )
    print("✅ Model uploaded to Hugging Face!")
    print("Usage: model = MusicNN.from_pretrained('oriyonay/musicnn-pytorch')")