| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import soundfile as sf |
| import librosa |
| from transformers import PretrainedConfig, PreTrainedModel |
| from huggingface_hub import PyTorchModelHubMixin |
|
|
| |
import warnings

# Install a catch-all "ignore" filter for every warning category. NOTE(review):
# this is a process-wide side effect of importing this module (presumably meant
# to hide noisy library warnings — confirm it is still wanted).
warnings.simplefilter('ignore')
|
|
|
|
class MusicNNConfig(PretrainedConfig):
    """Hugging Face configuration for the MusiCNN tagger.

    Args:
        num_classes: number of output logits (tags) produced by the back-end.
        mid_filt: number of filters in each mid-end convolution.
        backend_units: width of the hidden dense layer in the back-end.
        dataset: name of the tag vocabulary (``'MTT'`` or ``'MSD'`` are the
            values ``MusicNN.predict_tags`` knows how to label).
        **kwargs: forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = 'musicnn'

    def __init__(
        self,
        num_classes: int = 50,
        mid_filt: int = 64,
        backend_units: int = 200,
        dataset: str = 'MTT',
        **kwargs,
    ):
        # Model-specific fields are stored before the base initializer runs.
        self.num_classes, self.mid_filt = num_classes, mid_filt
        self.backend_units, self.dataset = backend_units, dataset
        super().__init__(**kwargs)
|
|
|
|
| |
| |
| |
class ConvReLUBN(nn.Module):
    """2D convolution, then ReLU, then batch normalization (BN after the activation)."""

    def __init__(self, in_ch, out_ch, kernel_size, padding=0):
        super().__init__()
        # The attribute names ``conv`` and ``bn`` are the state_dict key prefixes.
        self.conv = nn.Conv2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.bn = nn.BatchNorm2d(num_features=out_ch, eps=1e-3, momentum=1e-2)

    def forward(self, x):
        activated = F.relu(self.conv(x))
        return self.bn(activated)
|
|
|
|
class TimbralBlock(nn.Module):
    """Timbral front-end branch: a 7 x ``mel_bins`` filter max-pooled over dim 3.

    Assumes input shaped ``(batch, 1, time, freq)`` with ``freq >= mel_bins``
    (that is how ``MusicNN.forward`` calls it); returns ``(batch, out_ch, time)``.
    """

    def __init__(self, mel_bins, out_ch):
        super().__init__()
        self.conv_block = ConvReLUBN(1, out_ch, kernel_size=(7, mel_bins), padding=0)

    def forward(self, x):
        # Zero-pad 3 rows on each side of dim -2 (time) so the 7-tall kernel
        # keeps the number of frames; dim -1 (frequency) is left unpadded.
        time_padded = F.pad(x, (0, 0, 3, 3))
        feature_map = self.conv_block(time_padded)
        pooled, _ = feature_map.max(dim=3)
        return pooled
|
|
|
|
class TemporalBlock(nn.Module):
    """Temporal front-end branch: a ``kernel_size`` x 1 filter max-pooled over dim 3.

    The convolution uses ``padding='same'`` so both spatial sizes are kept
    before pooling; for a ``(batch, 1, time, freq)`` input the result is
    ``(batch, out_ch, time)``.
    """

    def __init__(self, kernel_size, out_ch):
        super().__init__()
        self.conv_block = ConvReLUBN(1, out_ch, kernel_size=(kernel_size, 1), padding='same')

    def forward(self, x):
        feature_map = self.conv_block(x)
        pooled, _ = feature_map.max(dim=3)
        return pooled
|
|
|
|
class MidEnd(nn.Module):
    """MusiCNN mid-end: three stacked convolutions along time with residual sums.

    Each convolution spans the full feature width of a frame (kernel
    ``(7, channels)``), i.e. it acts as a 1D convolution over time, with 3 zero
    frames of padding on each side so the frame count is preserved.

    ``forward`` takes ``(batch, in_ch, time)`` and returns a list of four
    ``(batch, time, channels)`` tensors: the input (transposed), the first
    conv output, and the two residual sums.
    """

    def __init__(self, in_ch, num_filt):
        super().__init__()
        # Attribute names are state_dict keys; keep them stable for weight loading.
        self.c1_conv = nn.Conv2d(1, num_filt, kernel_size=(7, in_ch), padding=0)
        self.c1_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)
        self.c2_conv = nn.Conv2d(1, num_filt, kernel_size=(7, num_filt), padding=0)
        self.c2_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)
        self.c3_conv = nn.Conv2d(1, num_filt, kernel_size=(7, num_filt), padding=0)
        self.c3_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)

    @staticmethod
    def _conv_over_time(seq, conv, bn):
        """Apply conv -> ReLU -> BN to a ``(batch, time, channels)`` sequence.

        Returns a ``(batch, time, filters)`` tensor.
        """
        padded = F.pad(seq.unsqueeze(1), (0, 0, 3, 3))  # (B, 1, T + 6, C)
        out = bn(F.relu(conv(padded)))                  # (B, filters, T, 1)
        return out.squeeze(3).transpose(1, 2)           # (B, T, filters)

    def forward(self, x):
        frontend = x.transpose(1, 2)  # (B, T, in_ch)
        first = self._conv_over_time(frontend, self.c1_conv, self.c1_bn)
        second = self._conv_over_time(first, self.c2_conv, self.c2_bn) + first
        third = self._conv_over_time(second, self.c3_conv, self.c3_bn) + second
        return [frontend, first, second, third]
|
|
|
|
class Backend(nn.Module):
    """MusiCNN back-end: temporal max+mean pooling, then a two-layer dense head.

    Input is ``(batch, time, in_ch)``. The max- and mean-pooled vectors are
    interleaved into ``2 * in_ch`` features before ``bn_in``/``fc1``.
    Dropout (p=0.5) is active only in training mode.
    """

    def __init__(self, in_ch, num_classes, hidden):
        super().__init__()
        pooled_dim = 2 * in_ch
        self.bn_in = nn.BatchNorm1d(pooled_dim, eps=0.001, momentum=0.01)
        self.fc1 = nn.Linear(pooled_dim, hidden)
        self.bn_fc1 = nn.BatchNorm1d(hidden, eps=0.001, momentum=0.01)
        self.fc2 = nn.Linear(hidden, num_classes)

    def forward(self, x):
        """Return ``(logits, mean_pool, max_pool)``; pooling is over dim 1."""
        max_pool = x.max(dim=1).values
        mean_pool = x.mean(dim=1)
        # Per-feature interleave: [max_0, mean_0, max_1, mean_1, ...].
        stacked = torch.stack((max_pool, mean_pool), dim=2)
        pooled = stacked.view(stacked.size(0), -1)

        hidden = F.dropout(self.bn_in(pooled), p=0.5, training=self.training)
        hidden = F.dropout(self.bn_fc1(F.relu(self.fc1(hidden))), p=0.5, training=self.training)
        return self.fc2(hidden), mean_pool, max_pool
|
|
|
|
class MusicNN(PreTrainedModel, PyTorchModelHubMixin):
    """MusiCNN music auto-tagger: musically motivated front-end, conv mid-end, dense back-end.

    ``forward`` expects log-mel patches shaped ``(batch, frames, 96)``. The
    inference helpers (``predict_tags``, ``extract_embeddings``) build such
    patches from an audio file at 16 kHz, 187 frames per patch.
    """

    config_class = MusicNNConfig

    # Tag vocabularies used by ``predict_tags``; entry i names output logit i.
    _TAG_LABELS = {
        'MTT': (
            'guitar', 'classical', 'slow', 'techno', 'strings', 'drums', 'electronic', 'rock',
            'fast', 'piano', 'ambient', 'beat', 'violin', 'vocal', 'synth', 'female', 'indian',
            'opera', 'male', 'singing', 'vocals', 'no vocals', 'harpsichord', 'loud', 'quiet',
            'flute', 'woman', 'male vocal', 'no vocal', 'pop', 'soft', 'sitar', 'solo', 'man',
            'classic', 'choir', 'voice', 'new age', 'dance', 'male voice', 'female vocal',
            'beats', 'harp', 'cello', 'no voice', 'weird', 'country', 'metal', 'female voice',
            'choral',
        ),
        'MSD': (
            'rock', 'pop', 'alternative', 'indie', 'electronic', 'female vocalists', 'dance',
            '00s', 'alternative rock', 'jazz', 'beautiful', 'metal', 'chillout', 'male vocalists',
            'classic rock', 'soul', 'indie rock', 'Mellow', 'electronica', '80s', 'folk', '90s',
            'chill', 'instrumental', 'punk', 'oldies', 'blues', 'hard rock', 'ambient', 'acoustic',
            'experimental', 'female vocalist', 'guitar', 'Hip-Hop', '70s', 'party', 'country',
            'easy listening', 'sexy', 'catchy', 'funk', 'electro', 'heavy metal',
            'Progressive rock', '60s', 'rnb', 'indie pop', 'sad', 'House', 'happy',
        ),
    }

    def __init__(self, config):
        super().__init__(config)
        timbral_ch = int(1.6 * 128)  # 204 filters per timbral branch
        temporal_ch = int(1.6 * 32)  # 51 filters per temporal branch
        frontend_ch = 2 * timbral_ch + 3 * temporal_ch  # 561 concatenated front-end channels

        # Submodule names and registration order define the state_dict layout.
        self.bn_input = nn.BatchNorm2d(1, eps=0.001, momentum=0.01)
        self.timbral_1 = TimbralBlock(int(0.4 * 96), timbral_ch)
        self.timbral_2 = TimbralBlock(int(0.7 * 96), timbral_ch)
        self.temp_1 = TemporalBlock(128, temporal_ch)
        self.temp_2 = TemporalBlock(64, temporal_ch)
        self.temp_3 = TemporalBlock(32, temporal_ch)
        self.midend = MidEnd(in_ch=frontend_ch, num_filt=config.mid_filt)
        self.backend = Backend(
            in_ch=config.mid_filt * 3 + frontend_ch,
            num_classes=config.num_classes,
            hidden=config.backend_units,
        )

    def forward(self, x):
        """Run the network on a batch of log-mel patches.

        Args:
            x: tensor shaped ``(batch, frames, 96)``.

        Returns:
            ``(logits, mean_pool, max_pool)`` from the back-end; the pooled
            tensors are the mid-end features pooled over time.
        """
        x = self.bn_input(x.unsqueeze(1))  # (B, 1, T, 96)
        f74 = self.timbral_1(x).transpose(1, 2)
        f77 = self.timbral_2(x).transpose(1, 2)
        s1 = self.temp_1(x).transpose(1, 2)
        s2 = self.temp_2(x).transpose(1, 2)
        s3 = self.temp_3(x).transpose(1, 2)
        frontend_features = torch.cat([f74, f77, s1, s2, s3], dim=2)  # (B, T, 561)
        mid_feats = self.midend(frontend_features.transpose(1, 2))
        z = torch.cat(mid_feats, dim=2)
        logits, mean_pool, max_pool = self.backend(z)
        return logits, mean_pool, max_pool

    # ------------------------------------------------------------------ audio

    @staticmethod
    def _read_soundfile_mono(audio_file):
        """Read ``audio_file`` with soundfile; average channels if multi-channel."""
        audio, file_sr = sf.read(audio_file)
        if audio.ndim > 1:
            audio = np.mean(audio, axis=1)
        return audio, file_sr

    @staticmethod
    def _to_log_mel(audio, file_sr, audio_file, sr=16000):
        """Resample ``audio`` to ``sr`` and return its ``(frames, 96)`` log-mel spectrogram.

        Raises:
            ValueError: if the (resampled) signal is empty.
        """
        if file_sr != sr:
            audio = librosa.resample(audio, orig_sr=file_sr, target_sr=sr)
        if len(audio) == 0:
            raise ValueError(f'Audio file {audio_file} is empty or could not be loaded.')
        mel = librosa.feature.melspectrogram(
            y=audio, sr=sr, hop_length=256, n_fft=512, n_mels=96
        ).T
        # Log compression computed in float32.
        return np.log10(10000 * mel.astype(np.float32) + 1)

    @staticmethod
    def _make_patches(audio_rep, n_frames=187):
        """Cut a ``(frames, mels)`` spectrogram into non-overlapping ``n_frames`` patches.

        Trailing frames that do not fill a whole patch are dropped; a
        spectrogram shorter than one patch is zero-padded into a single patch.
        Returns an array shaped ``(num_patches, n_frames, mels)``.
        """
        last_start = audio_rep.shape[0] - n_frames + 1
        if last_start <= 0:
            patch = np.zeros((n_frames, audio_rep.shape[1]), dtype=np.float32)
            patch[:audio_rep.shape[0], :] = audio_rep
            return np.stack([patch])
        return np.stack([
            audio_rep[start:start + n_frames, :]
            for start in range(0, last_start, n_frames)
        ])

    def _run_on_patches(self, audio_file):
        """Run the model over every patch of ``audio_file``.

        Side effects: moves the model to CUDA when available (else CPU) and
        leaves it in eval mode. Returns one ``(logits, mean_pool, max_pool)``
        tuple per patch, each tensor with a leading batch dimension of 1.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(device)
        audio, file_sr = self._read_soundfile_mono(audio_file)
        audio_rep = self._to_log_mel(audio, file_sr, audio_file)
        batch_tensor = torch.from_numpy(self._make_patches(audio_rep)).to(device)

        outputs = []
        with torch.no_grad():
            self.eval()
            # Patches are fed one at a time (batch size 1).
            for i in range(batch_tensor.shape[0]):
                outputs.append(self(batch_tensor[i:i + 1]))
        return outputs

    @staticmethod
    def preprocess_audio(audio_file, sr=16000):
        """Load ``audio_file`` and return its ``(frames, 96)`` log-mel spectrogram at ``sr``.

        librosa is tried first; if it raises or yields no samples, soundfile is
        used as a fallback (channels averaged). ``predict_tags`` and
        ``extract_embeddings`` do not use this method; they read with soundfile only.

        Raises:
            ValueError: if neither loader can read the file, or the audio is empty.
        """
        try:
            audio, file_sr = librosa.load(audio_file, sr=None)
            if len(audio) == 0:
                raise ValueError("Empty audio from librosa")
        except Exception:
            # Deliberate best-effort: any librosa failure falls through to soundfile.
            try:
                audio, file_sr = MusicNN._read_soundfile_mono(audio_file)
            except Exception as e:
                raise ValueError(f'Could not load audio file {audio_file}: {e}') from e
        return MusicNN._to_log_mel(audio, file_sr, audio_file, sr)

    def predict_tags(self, audio_file, top_k=5):
        """Return the ``top_k`` most probable tags for ``audio_file``, best first.

        The file is read with soundfile, averaged to mono, resampled to 16 kHz
        and cut into 187-frame patches; sigmoid probabilities are averaged over
        patches. ``top_k <= 0`` returns an empty list.

        Raises:
            ValueError: if the audio is empty or ``config.dataset`` is neither
                ``'MTT'`` nor ``'MSD'``.
        """
        outputs = self._run_on_patches(audio_file)
        all_probs = [torch.sigmoid(logits).squeeze(0).cpu().numpy() for logits, _, _ in outputs]
        avg_probs = np.mean(all_probs, axis=0)

        labels = self._TAG_LABELS.get(self.config.dataset)
        if labels is None:
            raise ValueError(f"Unknown dataset: {self.config.dataset}")

        # Reverse first, then take a prefix: ``[-top_k:]`` would turn into
        # ``[-0:]`` for top_k == 0 and select every tag.
        top_indices = np.argsort(avg_probs)[::-1][:max(top_k, 0)]
        return [labels[i] for i in top_indices]

    def extract_embeddings(self, audio_file, layer=None, pool='mean'):
        """
        Extract embeddings from audio file.
        Args:
            audio_file: path to audio file
            layer: which layer to extract from (ignored for simplicity, uses final embeddings)
            pool: pooling method ('mean', 'max', or 'both'); any other value behaves like 'mean'
        Returns:
            embeddings as numpy array, averaged over all 187-frame patches
        """
        embeddings = []
        for _, mean_pool, max_pool in self._run_on_patches(audio_file):
            if pool == 'max':
                chosen = max_pool
            elif pool == 'both':
                chosen = torch.cat([mean_pool, max_pool], dim=1)
            else:  # 'mean' and any unrecognized value
                chosen = mean_pool
            embeddings.append(chosen.squeeze(0).cpu().numpy())
        return np.mean(embeddings, axis=0)
|
|
|
|
| |
if __name__ == '__main__':
    # Conversion script: wrap the converted MTT checkpoint in Hugging Face
    # format, ship this module as remote code, and upload the folder to the Hub.
    import json
    import os
    import shutil

    from huggingface_hub import HfApi

    config = MusicNNConfig(
        num_classes=50,
        mid_filt=64,
        backend_units=200,
        dataset='MTT',
    )
    model = MusicNN(config)

    # The freshly built model is on CPU; map_location='cpu' also lets a
    # checkpoint that was saved from a GPU load on machines without CUDA.
    state_dict = torch.load('weights/MTT_musicnn.pt', map_location='cpu')
    model.load_state_dict(state_dict)

    save_dir = 'musicnn-pytorch'
    os.makedirs(save_dir, exist_ok=True)
    model.save_pretrained(save_dir)
    # auto_map below refers to a module named ``musicnn``; this assumes the
    # script is run from the directory that contains musicnn.py.
    shutil.copy('musicnn.py', save_dir)

    # Write config.json (replacing any written by save_pretrained) with the
    # auto_map entries that let AutoConfig/AutoModel resolve the custom classes.
    config_dict = config.to_dict()
    config_dict.update({
        '_name_or_path': 'oriyonay/musicnn-pytorch',
        'architectures': ['MusicNN'],
        'auto_map': {
            'AutoConfig': 'musicnn.MusicNNConfig',
            'AutoModel': 'musicnn.MusicNN'
        },
        'model_type': 'musicnn'
    })
    with open(os.path.join(save_dir, 'config.json'), 'w', encoding='utf-8') as f:
        json.dump(config_dict, f, indent=4)

    # Requires Hub credentials to be configured for this machine.
    api = HfApi()
    api.upload_folder(
        folder_path=save_dir,
        repo_id='oriyonay/musicnn-pytorch',
        repo_type='model'
    )

    print("✅ Model uploaded to Hugging Face!")
    print("Usage: model = MusicNN.from_pretrained('oriyonay/musicnn-pytorch')")