debuglevel oriyonay committed on
Commit
0bbc70a
·
0 Parent(s):

Duplicate from oriyonay/musicnn-pytorch

Browse files

Co-authored-by: ori yonay <oriyonay@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - audio
5
+ - music
6
+ - music-tagging
7
+ - pytorch
8
+ ---
9
+
10
+ # MusicNN-PyTorch
11
+
12
+ This is a PyTorch reimplementation of the [MusicNN](https://github.com/jordipons/musicnn) library for music audio tagging.
13
+
14
+ It contains the model architecture and converted weights from the original TensorFlow 1.x checkpoints.
15
+
16
+ ## Supported Models
17
+
18
+ - `MTT_musicnn`: Trained on MagnaTagATune (50 tags) - **Default model**
19
+ - `MSD_musicnn`: Trained on Million Song Dataset (50 tags)
20
+ - `MSD_musicnn_big`: Larger version trained on MSD (512 filters)
21
+
22
+ ## Super Simple Usage (Hugging Face Transformers)
23
+
24
+ ```python
25
+ from transformers import AutoModel
26
+
27
+ # Load the model (downloads automatically)
28
+ model = AutoModel.from_pretrained("oriyonay/musicnn-pytorch", trust_remote_code=True)
29
+
30
+ # Use the model
31
+ tags = model.predict_tags("your_audio.mp3", top_k=5)
32
+ print(f"Top 5 tags: {tags}")
33
+ ```
34
+
35
+ ## Embeddings (Optional)
36
+
37
+ ```python
38
+ from transformers import AutoModel
39
+
40
+ model = AutoModel.from_pretrained("oriyonay/musicnn-pytorch", trust_remote_code=True)
41
+
42
+ # Extract embeddings from any layer
43
+ emb = model.extract_embeddings("your_audio.mp3", layer="penultimate", pool="mean")
44
+ print(emb.shape)
45
+ ```
46
+
47
+ ## Colab Example
48
+
49
+ ```python
50
+ # Install dependencies
51
+ !pip install transformers torch librosa soundfile
52
+
53
+ # Load with AutoModel
54
+ from transformers import AutoModel
55
+ model = AutoModel.from_pretrained("oriyonay/musicnn-pytorch", trust_remote_code=True)
56
+
57
+ # Use the model
58
+ tags = model.predict_tags("your_audio.mp3", top_k=5)
59
+ print(tags)
60
+ ```
61
+
62
+ ## Traditional Usage
63
+
64
+ If you prefer to download the code manually:
65
+
66
+ ```python
67
+ from musicnn_torch import top_tags
68
+
69
+ # Get top 5 tags for an audio file
70
+ tags = top_tags('path/to/audio.mp3', model='MTT_musicnn', topN=5)
71
+ print(tags)
72
+ ```
73
+
74
+ ## Installation
75
+
76
+ ```bash
77
+ pip install transformers torch librosa soundfile
78
+ ```
79
+
80
+ ## Credits
81
+
82
+ Original implementation by [Jordi Pons](https://github.com/jordipons).
83
+ PyTorch port by Gemini.
config.json ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "num_classes": 50,
3
+ "mid_filt": 64,
4
+ "backend_units": 200,
5
+ "dataset": "MTT",
6
+ "return_dict": true,
7
+ "output_hidden_states": false,
8
+ "output_attentions": false,
9
+ "torchscript": false,
10
+ "torch_dtype": "float32",
11
+ "use_bfloat16": false,
12
+ "tf_legacy_loss": false,
13
+ "pruned_heads": {},
14
+ "tie_word_embeddings": true,
15
+ "chunk_size_feed_forward": 0,
16
+ "is_encoder_decoder": false,
17
+ "is_decoder": false,
18
+ "cross_attention_hidden_size": null,
19
+ "add_cross_attention": false,
20
+ "tie_encoder_decoder": false,
21
+ "max_length": 20,
22
+ "min_length": 0,
23
+ "do_sample": false,
24
+ "early_stopping": false,
25
+ "num_beams": 1,
26
+ "num_beam_groups": 1,
27
+ "diversity_penalty": 0.0,
28
+ "temperature": 1.0,
29
+ "top_k": 50,
30
+ "top_p": 1.0,
31
+ "typical_p": 1.0,
32
+ "repetition_penalty": 1.0,
33
+ "length_penalty": 1.0,
34
+ "no_repeat_ngram_size": 0,
35
+ "encoder_no_repeat_ngram_size": 0,
36
+ "bad_words_ids": null,
37
+ "num_return_sequences": 1,
38
+ "output_scores": false,
39
+ "return_dict_in_generate": false,
40
+ "forced_bos_token_id": null,
41
+ "forced_eos_token_id": null,
42
+ "remove_invalid_values": false,
43
+ "exponential_decay_length_penalty": null,
44
+ "suppress_tokens": null,
45
+ "begin_suppress_tokens": null,
46
+ "architectures": [
47
+ "MusicNN"
48
+ ],
49
+ "finetuning_task": null,
50
+ "id2label": {
51
+ "0": "LABEL_0",
52
+ "1": "LABEL_1"
53
+ },
54
+ "label2id": {
55
+ "LABEL_0": 0,
56
+ "LABEL_1": 1
57
+ },
58
+ "tokenizer_class": null,
59
+ "prefix": null,
60
+ "bos_token_id": null,
61
+ "pad_token_id": null,
62
+ "eos_token_id": null,
63
+ "sep_token_id": null,
64
+ "decoder_start_token_id": null,
65
+ "task_specific_params": null,
66
+ "problem_type": null,
67
+ "_name_or_path": "oriyonay/musicnn-pytorch",
68
+ "_attn_implementation_autoset": false,
69
+ "transformers_version": "4.48.0",
70
+ "model_type": "musicnn",
71
+ "auto_map": {
72
+ "AutoConfig": "musicnn.MusicNNConfig",
73
+ "AutoModel": "musicnn.MusicNN"
74
+ }
75
+ }
configuration_musicnn.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+ class MusicNNConfig(PretrainedConfig):
4
+ model_type = 'musicnn'
5
+
6
+ def __init__(
7
+ self,
8
+ num_classes=50,
9
+ mid_filt=64,
10
+ backend_units=200,
11
+ dataset='MTT',
12
+ **kwargs
13
+ ):
14
+ self.num_classes = num_classes
15
+ self.mid_filt = mid_filt
16
+ self.backend_units = backend_units
17
+ self.dataset = dataset
18
+ super().__init__(**kwargs)
inference.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from musicnn_torch import top_tags
2
+ import os
3
+
4
+ # Use the absolute paths you provided
5
+ files = [
6
+ '/Users/oriyonay/Desktop/CRAZY BEAT.mp3',
7
+ '/Users/oriyonay/Desktop/burn the stage/bounces/02 the type of girl.mp3',
8
+ '/Users/oriyonay/Desktop/burn the stage/extras/jazzy red roses.mp3'
9
+ ]
10
+
11
+ for f in files:
12
+ if os.path.exists(f):
13
+ print(f"\n--- Predicting top tags for {os.path.basename(f)} ---")
14
+ try:
15
+ tags = top_tags(f, model='MTT_musicnn', topN=5)
16
+ print(f"Top 5 tags: {tags}")
17
+ except Exception as e:
18
+ print(f"Error processing {f}: {e}")
19
+ else:
20
+ print(f"\nWarning: File not found at {f}")
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc0b9400fcaed6e9ce7fbcfa97ec91e4fcb5f2ab34ca3a0cd6bef4af74753e1a
3
+ size 3175212
modeling_musicnn.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import soundfile as sf
6
+ import librosa
7
+ from transformers import PreTrainedModel, PretrainedConfig
8
+
9
+ class MusicNNConfig(PretrainedConfig):
10
+ model_type = 'musicnn'
11
+
12
+ def __init__(
13
+ self,
14
+ num_classes=50,
15
+ mid_filt=64,
16
+ backend_units=200,
17
+ dataset='MTT',
18
+ **kwargs
19
+ ):
20
+ self.num_classes = num_classes
21
+ self.mid_filt = mid_filt
22
+ self.backend_units = backend_units
23
+ self.dataset = dataset
24
+ super().__init__(**kwargs)
25
+
26
+ # -------------------------
27
+ # Building blocks
28
+ # -------------------------
29
+ class ConvReLUBN(nn.Module):
30
+ def __init__(self, in_ch, out_ch, kernel_size, padding=0):
31
+ super().__init__()
32
+ self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding)
33
+ self.bn = nn.BatchNorm2d(out_ch, eps=0.001, momentum=0.01)
34
+
35
+ def forward(self, x):
36
+ return self.bn(F.relu(self.conv(x)))
37
+
38
+
39
+ class TimbralBlock(nn.Module):
40
+ def __init__(self, mel_bins, out_ch):
41
+ super().__init__()
42
+ self.conv_block = ConvReLUBN(1, out_ch, kernel_size=(7, mel_bins), padding=0)
43
+
44
+ def forward(self, x):
45
+ x = F.pad(x, (0, 0, 3, 3))
46
+ x = self.conv_block(x)
47
+ return torch.max(x, dim=3).values
48
+
49
+
50
+ class TemporalBlock(nn.Module):
51
+ def __init__(self, kernel_size, out_ch):
52
+ super().__init__()
53
+ self.conv_block = ConvReLUBN(1, out_ch, kernel_size=(kernel_size, 1), padding='same')
54
+
55
+ def forward(self, x):
56
+ x = self.conv_block(x)
57
+ return torch.max(x, dim=3).values
58
+
59
+
60
+ class MidEnd(nn.Module):
61
+ def __init__(self, in_ch, num_filt):
62
+ super().__init__()
63
+ self.c1_conv = nn.Conv2d(1, num_filt, kernel_size=(7, in_ch), padding=0)
64
+ self.c1_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)
65
+ self.c2_conv = nn.Conv2d(1, num_filt, kernel_size=(7, num_filt), padding=0)
66
+ self.c2_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)
67
+ self.c3_conv = nn.Conv2d(1, num_filt, kernel_size=(7, num_filt), padding=0)
68
+ self.c3_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)
69
+
70
+ def forward(self, x):
71
+ x = x.transpose(1, 2).unsqueeze(3)
72
+
73
+ x_perm = x.permute(0, 2, 3, 1)
74
+ x1_pad = F.pad(x_perm, (3, 3, 0, 0))
75
+ x1 = x1_pad.permute(0, 2, 3, 1)
76
+ x1 = self.c1_bn(F.relu(self.c1_conv(x1)))
77
+ x1_t = x1.permute(0, 2, 1, 3)
78
+
79
+ x2_perm = x1_t.permute(0, 2, 3, 1)
80
+ x2_pad = F.pad(x2_perm, (3, 3, 0, 0))
81
+ x2 = x2_pad.permute(0, 2, 3, 1)
82
+ x2 = self.c2_bn(F.relu(self.c2_conv(x2)))
83
+ x2_t = x2.permute(0, 2, 1, 3)
84
+ res_conv2 = x2_t + x1_t
85
+
86
+ x3_perm = res_conv2.permute(0, 2, 3, 1)
87
+ x3_pad = F.pad(x3_perm, (3, 3, 0, 0))
88
+ x3 = x3_pad.permute(0, 2, 3, 1)
89
+ x3 = self.c3_bn(F.relu(self.c3_conv(x3)))
90
+ x3_t = x3.permute(0, 2, 1, 3)
91
+ res_conv3 = x3_t + res_conv2
92
+
93
+ return [x.squeeze(3), x1_t.squeeze(3), res_conv2.squeeze(3), res_conv3.squeeze(3)]
94
+
95
+
96
+ class Backend(nn.Module):
97
+ def __init__(self, in_ch, num_classes, hidden):
98
+ super().__init__()
99
+ self.bn_in = nn.BatchNorm1d(in_ch * 2, eps=0.001, momentum=0.01)
100
+ self.fc1 = nn.Linear(in_ch * 2, hidden)
101
+ self.bn_fc1 = nn.BatchNorm1d(hidden, eps=0.001, momentum=0.01)
102
+ self.fc2 = nn.Linear(hidden, num_classes)
103
+
104
+ def forward(self, x):
105
+ max_pool = torch.max(x, dim=1).values
106
+ mean_pool = torch.mean(x, dim=1)
107
+ z = torch.stack([max_pool, mean_pool], dim=2)
108
+ z = z.view(z.size(0), -1)
109
+
110
+ z = self.bn_in(z)
111
+ z = F.dropout(z, p=0.5, training=self.training)
112
+ z = self.bn_fc1(F.relu(self.fc1(z)))
113
+ z = F.dropout(z, p=0.5, training=self.training)
114
+
115
+ logits = self.fc2(z)
116
+ return logits, mean_pool, max_pool
117
+
118
+
119
+ class MusicNNModel(PreTrainedModel):
120
+ config_class = MusicNNConfig
121
+
122
+ def __init__(self, config):
123
+ super().__init__(config)
124
+ self.bn_input = nn.BatchNorm2d(1, eps=0.001, momentum=0.01)
125
+ self.timbral_1 = TimbralBlock(int(0.4 * 96), int(1.6 * 128))
126
+ self.timbral_2 = TimbralBlock(int(0.7 * 96), int(1.6 * 128))
127
+ self.temp_1 = TemporalBlock(128, int(1.6 * 32))
128
+ self.temp_2 = TemporalBlock(64, int(1.6 * 32))
129
+ self.temp_3 = TemporalBlock(32, int(1.6 * 32))
130
+ self.midend = MidEnd(in_ch=561, num_filt=config.mid_filt)
131
+ self.backend = Backend(in_ch=config.mid_filt * 3 + 561, num_classes=config.num_classes, hidden=config.backend_units)
132
+
133
+ def forward(self, x):
134
+ # x is [B, T, M]
135
+ x = x.unsqueeze(1)
136
+ x = self.bn_input(x)
137
+ f74 = self.timbral_1(x).transpose(1, 2)
138
+ f77 = self.timbral_2(x).transpose(1, 2)
139
+ s1 = self.temp_1(x).transpose(1, 2)
140
+ s2 = self.temp_2(x).transpose(1, 2)
141
+ s3 = self.temp_3(x).transpose(1, 2)
142
+ frontend_features = torch.cat([f74, f77, s1, s2, s3], dim=2)
143
+ mid_feats = self.midend(frontend_features.transpose(1, 2))
144
+ z = torch.cat(mid_feats, dim=2)
145
+ logits, mean_pool, max_pool = self.backend(z)
146
+ return logits, mean_pool, max_pool
147
+
148
+ @staticmethod
149
+ def preprocess_audio(audio_file, sr=16000):
150
+ # Try librosa first (works well for many formats)
151
+ try:
152
+ audio, file_sr = librosa.load(audio_file, sr=None)
153
+ if len(audio) == 0:
154
+ raise ValueError("Empty audio from librosa")
155
+ except Exception:
156
+ # Fallback to soundfile (better for some MP3s)
157
+ try:
158
+ audio, file_sr = sf.read(audio_file)
159
+ # Convert to mono if stereo
160
+ if len(audio.shape) > 1:
161
+ audio = np.mean(audio, axis=1)
162
+ except Exception as e:
163
+ raise ValueError(f'Could not load audio file {audio_file}: {e}')
164
+
165
+ # Resample to target sample rate if necessary
166
+ if file_sr != sr:
167
+ audio = librosa.resample(audio, orig_sr=file_sr, target_sr=sr)
168
+
169
+ if len(audio) == 0:
170
+ raise ValueError(f'Audio file {audio_file} is empty or could not be loaded.')
171
+
172
+ # Create mel spectrogram
173
+ audio_rep = librosa.feature.melspectrogram(
174
+ y=audio, sr=sr, hop_length=256, n_fft=512, n_mels=96
175
+ ).T
176
+ audio_rep = audio_rep.astype(np.float32)
177
+ audio_rep = np.log10(10000 * audio_rep + 1)
178
+
179
+ return audio_rep
180
+
181
+ def predict_tags(self, audio_file, top_k=5):
182
+ # Use the same batching approach as the original implementation
183
+ # This matches musicnn_torch.py extractor function
184
+
185
+ # Load and preprocess audio (similar to batch_data in musicnn_torch.py)
186
+ audio, file_sr = sf.read(audio_file)
187
+
188
+ # Convert to mono if stereo
189
+ if len(audio.shape) > 1:
190
+ audio = np.mean(audio, axis=1)
191
+
192
+ # Resample to 16000 if necessary
193
+ if file_sr != 16000:
194
+ audio = librosa.resample(audio, orig_sr=file_sr, target_sr=16000)
195
+
196
+ if len(audio) == 0:
197
+ raise ValueError(f'Audio file {audio_file} is empty or could not be loaded.')
198
+
199
+ # Create mel spectrogram
200
+ audio_rep = librosa.feature.melspectrogram(
201
+ y=audio, sr=16000, hop_length=256, n_fft=512, n_mels=96
202
+ ).T
203
+ audio_rep = audio_rep.astype(np.float32)
204
+ audio_rep = np.log10(10000 * audio_rep + 1)
205
+
206
+ # Batch the data (same as original implementation)
207
+ n_frames = 187 # librosa.time_to_frames(3, sr=16000, n_fft=512, hop_length=256) + 1
208
+ overlap = n_frames # No overlap for simplicity
209
+
210
+ last_frame = audio_rep.shape[0] - n_frames + 1
211
+ batches = []
212
+ if last_frame <= 0:
213
+ # Pad with zeros if audio is too short
214
+ patch = np.zeros((n_frames, 96), dtype=np.float32)
215
+ patch[:audio_rep.shape[0], :] = audio_rep
216
+ batches.append(patch)
217
+ else:
218
+ # Create overlapping windows
219
+ for time_stamp in range(0, last_frame, overlap):
220
+ patch = audio_rep[time_stamp : time_stamp + n_frames, :]
221
+ batches.append(patch)
222
+
223
+ # Convert to tensor and run inference
224
+ batch_tensor = torch.from_numpy(np.stack(batches))
225
+
226
+ all_probs = []
227
+ with torch.no_grad():
228
+ self.eval()
229
+ for i in range(0, len(batches), 1): # Process in batches if needed
230
+ batch_subset = batch_tensor[i:i+1]
231
+ logits, _, _ = self(batch_subset)
232
+ probs = torch.sigmoid(logits).squeeze(0).numpy()
233
+ all_probs.append(probs)
234
+
235
+ # Average probabilities across all windows
236
+ avg_probs = np.mean(all_probs, axis=0)
237
+
238
+ # Get labels based on config
239
+ if self.config.dataset == 'MTT':
240
+ labels = [
241
+ 'guitar', 'classical', 'slow', 'techno', 'strings', 'drums', 'electronic', 'rock',
242
+ 'fast', 'piano', 'ambient', 'beat', 'violin', 'vocal', 'synth', 'female', 'indian',
243
+ 'opera', 'male', 'singing', 'vocals', 'no vocals', 'harpsichord', 'loud', 'quiet',
244
+ 'flute', 'woman', 'male vocal', 'no vocal', 'pop', 'soft', 'sitar', 'solo', 'man',
245
+ 'classic', 'choir', 'voice', 'new age', 'dance', 'male voice', 'female vocal',
246
+ 'beats', 'harp', 'cello', 'no voice', 'weird', 'country', 'metal', 'female voice',
247
+ 'choral'
248
+ ]
249
+ elif self.config.dataset == 'MSD':
250
+ labels = [
251
+ 'rock', 'pop', 'alternative', 'indie', 'electronic', 'female vocalists', 'dance',
252
+ '00s', 'alternative rock', 'jazz', 'beautiful', 'metal', 'chillout', 'male vocalists',
253
+ 'classic rock', 'soul', 'indie rock', 'Mellow', 'electronica', '80s', 'folk', '90s',
254
+ 'chill', 'instrumental', 'punk', 'oldies', 'blues', 'hard rock', 'ambient', 'acoustic',
255
+ 'experimental', 'female vocalist', 'guitar', 'Hip-Hop', '70s', 'party', 'country',
256
+ 'easy listening', 'sexy', 'catchy', 'funk', 'electro', 'heavy metal',
257
+ 'Progressive rock', '60s', 'rnb', 'indie pop', 'sad', 'House', 'happy'
258
+ ]
259
+ else:
260
+ raise ValueError(f"Unknown dataset: {self.config.dataset}")
261
+
262
+ # Get top k tags
263
+ top_indices = np.argsort(avg_probs)[-top_k:][::-1]
264
+ return [labels[i] for i in top_indices]
265
+
266
+
267
+ def create_musicnn_model(model_type='MTT_musicnn'):
268
+ """
269
+ Factory function to create MusicNN models with different configurations.
270
+
271
+ Args:
272
+ model_type (str): One of 'MTT_musicnn', 'MSD_musicnn', or 'MSD_musicnn_big'
273
+
274
+ Returns:
275
+ MusicNNModel: Configured model instance
276
+ """
277
+ from transformers import AutoConfig
278
+
279
+ # Model configurations
280
+ configs = {
281
+ 'MTT_musicnn': {
282
+ 'num_classes': 50,
283
+ 'mid_filt': 64,
284
+ 'backend_units': 200,
285
+ 'dataset': 'MTT'
286
+ },
287
+ 'MSD_musicnn': {
288
+ 'num_classes': 50,
289
+ 'mid_filt': 64,
290
+ 'backend_units': 200,
291
+ 'dataset': 'MSD'
292
+ },
293
+ 'MSD_musicnn_big': {
294
+ 'num_classes': 50,
295
+ 'mid_filt': 512,
296
+ 'backend_units': 500,
297
+ 'dataset': 'MSD'
298
+ }
299
+ }
300
+
301
+ if model_type not in configs:
302
+ raise ValueError(f"Unknown model type: {model_type}. Choose from: {list(configs.keys())}")
303
+
304
+ # For now, we'll load the default model and modify its config
305
+ # In the future, we could have separate model files for each type
306
+ config = AutoConfig.from_pretrained("oriyonay/musicnn-pytorch", trust_remote_code=True)
307
+ config.num_classes = configs[model_type]['num_classes']
308
+ config.mid_filt = configs[model_type]['mid_filt']
309
+ config.backend_units = configs[model_type]['backend_units']
310
+ config.dataset = configs[model_type]['dataset']
311
+
312
+ model = MusicNNModel(config)
313
+ return model
musicnn.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import soundfile as sf
6
+ import librosa
7
+ from transformers import PretrainedConfig, PreTrainedModel
8
+ from huggingface_hub import PyTorchModelHubMixin
9
+
10
+ # Suppress warnings
11
+ import warnings
12
+ warnings.filterwarnings('ignore')
13
+
14
+
15
+ class MusicNNConfig(PretrainedConfig):
16
+ model_type = 'musicnn'
17
+
18
+ def __init__(
19
+ self,
20
+ num_classes=50,
21
+ mid_filt=64,
22
+ backend_units=200,
23
+ dataset='MTT',
24
+ **kwargs
25
+ ):
26
+ self.num_classes = num_classes
27
+ self.mid_filt = mid_filt
28
+ self.backend_units = backend_units
29
+ self.dataset = dataset
30
+ super().__init__(**kwargs)
31
+
32
+
33
+ # -------------------------
34
+ # Building blocks
35
+ # -------------------------
36
+ class ConvReLUBN(nn.Module):
37
+ def __init__(self, in_ch, out_ch, kernel_size, padding=0):
38
+ super().__init__()
39
+ self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding)
40
+ self.bn = nn.BatchNorm2d(out_ch, eps=0.001, momentum=0.01)
41
+
42
+ def forward(self, x):
43
+ return self.bn(F.relu(self.conv(x)))
44
+
45
+
46
+ class TimbralBlock(nn.Module):
47
+ def __init__(self, mel_bins, out_ch):
48
+ super().__init__()
49
+ self.conv_block = ConvReLUBN(1, out_ch, kernel_size=(7, mel_bins), padding=0)
50
+
51
+ def forward(self, x):
52
+ x = F.pad(x, (0, 0, 3, 3))
53
+ x = self.conv_block(x)
54
+ return torch.max(x, dim=3).values
55
+
56
+
57
+ class TemporalBlock(nn.Module):
58
+ def __init__(self, kernel_size, out_ch):
59
+ super().__init__()
60
+ self.conv_block = ConvReLUBN(1, out_ch, kernel_size=(kernel_size, 1), padding='same')
61
+
62
+ def forward(self, x):
63
+ x = self.conv_block(x)
64
+ return torch.max(x, dim=3).values
65
+
66
+
67
+ class MidEnd(nn.Module):
68
+ def __init__(self, in_ch, num_filt):
69
+ super().__init__()
70
+ self.c1_conv = nn.Conv2d(1, num_filt, kernel_size=(7, in_ch), padding=0)
71
+ self.c1_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)
72
+ self.c2_conv = nn.Conv2d(1, num_filt, kernel_size=(7, num_filt), padding=0)
73
+ self.c2_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)
74
+ self.c3_conv = nn.Conv2d(1, num_filt, kernel_size=(7, num_filt), padding=0)
75
+ self.c3_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)
76
+
77
+ def forward(self, x):
78
+ x = x.transpose(1, 2).unsqueeze(3)
79
+
80
+ x_perm = x.permute(0, 2, 3, 1)
81
+ x1_pad = F.pad(x_perm, (3, 3, 0, 0))
82
+ x1 = x1_pad.permute(0, 2, 3, 1)
83
+ x1 = self.c1_bn(F.relu(self.c1_conv(x1)))
84
+ x1_t = x1.permute(0, 2, 1, 3)
85
+
86
+ x2_perm = x1_t.permute(0, 2, 3, 1)
87
+ x2_pad = F.pad(x2_perm, (3, 3, 0, 0))
88
+ x2 = x2_pad.permute(0, 2, 3, 1)
89
+ x2 = self.c2_bn(F.relu(self.c2_conv(x2)))
90
+ x2_t = x2.permute(0, 2, 1, 3)
91
+ res_conv2 = x2_t + x1_t
92
+
93
+ x3_perm = res_conv2.permute(0, 2, 3, 1)
94
+ x3_pad = F.pad(x3_perm, (3, 3, 0, 0))
95
+ x3 = x3_pad.permute(0, 2, 3, 1)
96
+ x3 = self.c3_bn(F.relu(self.c3_conv(x3)))
97
+ x3_t = x3.permute(0, 2, 1, 3)
98
+ res_conv3 = x3_t + res_conv2
99
+
100
+ return [x.squeeze(3), x1_t.squeeze(3), res_conv2.squeeze(3), res_conv3.squeeze(3)]
101
+
102
+
103
+ class Backend(nn.Module):
104
+ def __init__(self, in_ch, num_classes, hidden):
105
+ super().__init__()
106
+ self.bn_in = nn.BatchNorm1d(in_ch * 2, eps=0.001, momentum=0.01)
107
+ self.fc1 = nn.Linear(in_ch * 2, hidden)
108
+ self.bn_fc1 = nn.BatchNorm1d(hidden, eps=0.001, momentum=0.01)
109
+ self.fc2 = nn.Linear(hidden, num_classes)
110
+
111
+ def forward(self, x):
112
+ max_pool = torch.max(x, dim=1).values
113
+ mean_pool = torch.mean(x, dim=1)
114
+ z = torch.stack([max_pool, mean_pool], dim=2)
115
+ z = z.view(z.size(0), -1)
116
+
117
+ z = self.bn_in(z)
118
+ z = F.dropout(z, p=0.5, training=self.training)
119
+ z = self.bn_fc1(F.relu(self.fc1(z)))
120
+ z = F.dropout(z, p=0.5, training=self.training)
121
+
122
+ logits = self.fc2(z)
123
+ return logits, mean_pool, max_pool
124
+
125
+
126
+ class MusicNN(PreTrainedModel, PyTorchModelHubMixin):
127
+ config_class = MusicNNConfig
128
+
129
+ def __init__(self, config):
130
+ super().__init__(config)
131
+ self.bn_input = nn.BatchNorm2d(1, eps=0.001, momentum=0.01)
132
+ self.timbral_1 = TimbralBlock(int(0.4 * 96), int(1.6 * 128))
133
+ self.timbral_2 = TimbralBlock(int(0.7 * 96), int(1.6 * 128))
134
+ self.temp_1 = TemporalBlock(128, int(1.6 * 32))
135
+ self.temp_2 = TemporalBlock(64, int(1.6 * 32))
136
+ self.temp_3 = TemporalBlock(32, int(1.6 * 32))
137
+ self.midend = MidEnd(in_ch=561, num_filt=config.mid_filt)
138
+ self.backend = Backend(in_ch=config.mid_filt * 3 + 561, num_classes=config.num_classes, hidden=config.backend_units)
139
+
140
+ def forward(self, x):
141
+ x = x.unsqueeze(1)
142
+ x = self.bn_input(x)
143
+ f74 = self.timbral_1(x).transpose(1, 2)
144
+ f77 = self.timbral_2(x).transpose(1, 2)
145
+ s1 = self.temp_1(x).transpose(1, 2)
146
+ s2 = self.temp_2(x).transpose(1, 2)
147
+ s3 = self.temp_3(x).transpose(1, 2)
148
+ frontend_features = torch.cat([f74, f77, s1, s2, s3], dim=2)
149
+ mid_feats = self.midend(frontend_features.transpose(1, 2))
150
+ z = torch.cat(mid_feats, dim=2)
151
+ logits, mean_pool, max_pool = self.backend(z)
152
+ return logits, mean_pool, max_pool
153
+
154
+ @staticmethod
155
+ def preprocess_audio(audio_file, sr=16000):
156
+ # Try librosa first (works well for many formats)
157
+ try:
158
+ audio, file_sr = librosa.load(audio_file, sr=None)
159
+ if len(audio) == 0:
160
+ raise ValueError("Empty audio from librosa")
161
+ except Exception:
162
+ # Fallback to soundfile (better for some MP3s)
163
+ try:
164
+ audio, file_sr = sf.read(audio_file)
165
+ # Convert to mono if stereo
166
+ if len(audio.shape) > 1:
167
+ audio = np.mean(audio, axis=1)
168
+ except Exception as e:
169
+ raise ValueError(f'Could not load audio file {audio_file}: {e}')
170
+
171
+ # Resample to target sample rate if necessary
172
+ if file_sr != sr:
173
+ audio = librosa.resample(audio, orig_sr=file_sr, target_sr=sr)
174
+
175
+ if len(audio) == 0:
176
+ raise ValueError(f'Audio file {audio_file} is empty or could not be loaded.')
177
+
178
+ # Create mel spectrogram
179
+ audio_rep = librosa.feature.melspectrogram(
180
+ y=audio, sr=sr, hop_length=256, n_fft=512, n_mels=96
181
+ ).T
182
+ audio_rep = audio_rep.astype(np.float32)
183
+ audio_rep = np.log10(10000 * audio_rep + 1)
184
+
185
+ return audio_rep
186
+
187
+ def predict_tags(self, audio_file, top_k=5):
188
+ # Auto-detect device and move model to it
189
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
190
+ self.to(device)
191
+
192
+ # Use the same batching approach as the original implementation
193
+ # This matches musicnn_torch.py extractor function
194
+
195
+ # Load and preprocess audio (similar to batch_data in musicnn_torch.py)
196
+ audio, file_sr = sf.read(audio_file)
197
+
198
+ # Convert to mono if stereo
199
+ if len(audio.shape) > 1:
200
+ audio = np.mean(audio, axis=1)
201
+
202
+ # Resample to 16000 if necessary
203
+ if file_sr != 16000:
204
+ audio = librosa.resample(audio, orig_sr=file_sr, target_sr=16000)
205
+
206
+ if len(audio) == 0:
207
+ raise ValueError(f'Audio file {audio_file} is empty or could not be loaded.')
208
+
209
+ # Create mel spectrogram
210
+ audio_rep = librosa.feature.melspectrogram(
211
+ y=audio, sr=16000, hop_length=256, n_fft=512, n_mels=96
212
+ ).T
213
+ audio_rep = audio_rep.astype(np.float32)
214
+ audio_rep = np.log10(10000 * audio_rep + 1)
215
+
216
+ # Batch the data (same as original implementation)
217
+ n_frames = 187 # librosa.time_to_frames(3, sr=16000, n_fft=512, hop_length=256) + 1
218
+ overlap = n_frames # No overlap for simplicity
219
+
220
+ last_frame = audio_rep.shape[0] - n_frames + 1
221
+ batches = []
222
+ if last_frame <= 0:
223
+ # Pad with zeros if audio is too short
224
+ patch = np.zeros((n_frames, 96), dtype=np.float32)
225
+ patch[:audio_rep.shape[0], :] = audio_rep
226
+ batches.append(patch)
227
+ else:
228
+ # Create overlapping windows
229
+ for time_stamp in range(0, last_frame, overlap):
230
+ patch = audio_rep[time_stamp : time_stamp + n_frames, :]
231
+ batches.append(patch)
232
+
233
+ # Convert to tensor and run inference
234
+ batch_tensor = torch.from_numpy(np.stack(batches)).to(device)
235
+
236
+ all_probs = []
237
+ with torch.no_grad():
238
+ self.eval()
239
+ for i in range(0, len(batches), 1): # Process in batches if needed
240
+ batch_subset = batch_tensor[i:i+1]
241
+ logits, _, _ = self(batch_subset)
242
+ probs = torch.sigmoid(logits).squeeze(0).cpu().numpy()
243
+ all_probs.append(probs)
244
+
245
+ # Average probabilities across all windows
246
+ avg_probs = np.mean(all_probs, axis=0)
247
+
248
+ # Get labels based on config
249
+ if self.config.dataset == 'MTT':
250
+ labels = [
251
+ 'guitar', 'classical', 'slow', 'techno', 'strings', 'drums', 'electronic', 'rock',
252
+ 'fast', 'piano', 'ambient', 'beat', 'violin', 'vocal', 'synth', 'female', 'indian',
253
+ 'opera', 'male', 'singing', 'vocals', 'no vocals', 'harpsichord', 'loud', 'quiet',
254
+ 'flute', 'woman', 'male vocal', 'no vocal', 'pop', 'soft', 'sitar', 'solo', 'man',
255
+ 'classic', 'choir', 'voice', 'new age', 'dance', 'male voice', 'female vocal',
256
+ 'beats', 'harp', 'cello', 'no voice', 'weird', 'country', 'metal', 'female voice',
257
+ 'choral'
258
+ ]
259
+ elif self.config.dataset == 'MSD':
260
+ labels = [
261
+ 'rock', 'pop', 'alternative', 'indie', 'electronic', 'female vocalists', 'dance',
262
+ '00s', 'alternative rock', 'jazz', 'beautiful', 'metal', 'chillout', 'male vocalists',
263
+ 'classic rock', 'soul', 'indie rock', 'Mellow', 'electronica', '80s', 'folk', '90s',
264
+ 'chill', 'instrumental', 'punk', 'oldies', 'blues', 'hard rock', 'ambient', 'acoustic',
265
+ 'experimental', 'female vocalist', 'guitar', 'Hip-Hop', '70s', 'party', 'country',
266
+ 'easy listening', 'sexy', 'catchy', 'funk', 'electro', 'heavy metal',
267
+ 'Progressive rock', '60s', 'rnb', 'indie pop', 'sad', 'House', 'happy'
268
+ ]
269
+ else:
270
+ raise ValueError(f"Unknown dataset: {self.config.dataset}")
271
+
272
+ # Get top k tags
273
+ top_indices = np.argsort(avg_probs)[-top_k:][::-1]
274
+ return [labels[i] for i in top_indices]
275
+
276
+ def extract_embeddings(self, audio_file, layer=None, pool='mean'):
277
+ """
278
+ Extract embeddings from audio file.
279
+ Args:
280
+ audio_file: path to audio file
281
+ layer: which layer to extract from (ignored for simplicity, uses final embeddings)
282
+ pool: pooling method ('mean', 'max', or 'both')
283
+ Returns:
284
+ embeddings as numpy array
285
+ """
286
+ # Auto-detect device and move model to it
287
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
288
+ self.to(device)
289
+
290
+ # Load and preprocess audio
291
+ audio, file_sr = sf.read(audio_file)
292
+
293
+ # Convert to mono if stereo
294
+ if len(audio.shape) > 1:
295
+ audio = np.mean(audio, axis=1)
296
+
297
+ # Resample to 16000 if necessary
298
+ if file_sr != 16000:
299
+ audio = librosa.resample(audio, orig_sr=file_sr, target_sr=16000)
300
+
301
+ if len(audio) == 0:
302
+ raise ValueError(f'Audio file {audio_file} is empty or could not be loaded.')
303
+
304
+ # Create mel spectrogram
305
+ audio_rep = librosa.feature.melspectrogram(
306
+ y=audio, sr=16000, hop_length=256, n_fft=512, n_mels=96
307
+ ).T
308
+ audio_rep = audio_rep.astype(np.float32)
309
+ audio_rep = np.log10(10000 * audio_rep + 1)
310
+
311
+ # Batch the data
312
+ n_frames = 187 # librosa.time_to_frames(3, sr=16000, n_fft=512, hop_length=256) + 1
313
+ overlap = n_frames # No overlap
314
+
315
+ last_frame = audio_rep.shape[0] - n_frames + 1
316
+ batches = []
317
+ if last_frame <= 0:
318
+ # Pad with zeros if audio is too short
319
+ patch = np.zeros((n_frames, 96), dtype=np.float32)
320
+ patch[:audio_rep.shape[0], :] = audio_rep
321
+ batches.append(patch)
322
+ else:
323
+ # Create windows
324
+ for time_stamp in range(0, last_frame, overlap):
325
+ patch = audio_rep[time_stamp : time_stamp + n_frames, :]
326
+ batches.append(patch)
327
+
328
+ # Convert to tensor and run inference
329
+ batch_tensor = torch.from_numpy(np.stack(batches)).to(device)
330
+
331
+ all_embeddings = []
332
+ with torch.no_grad():
333
+ self.eval()
334
+ for i in range(0, len(batches), 1):
335
+ batch_subset = batch_tensor[i:i+1]
336
+ logits, mean_pool, max_pool = self(batch_subset)
337
+
338
+ if pool == 'mean':
339
+ embeddings = mean_pool.squeeze(0).cpu().numpy()
340
+ elif pool == 'max':
341
+ embeddings = max_pool.squeeze(0).cpu().numpy()
342
+ elif pool == 'both':
343
+ embeddings = torch.cat([mean_pool, max_pool], dim=1).squeeze(0).cpu().numpy()
344
+ else:
345
+ embeddings = mean_pool.squeeze(0).cpu().numpy() # default to mean
346
+
347
+ all_embeddings.append(embeddings)
348
+
349
+ # Average embeddings across all windows
350
+ avg_embeddings = np.mean(all_embeddings, axis=0)
351
+ return avg_embeddings
352
+
353
+
354
+ # For uploading to Hugging Face Hub
355
+ if __name__ == '__main__':
356
+ import json
357
+ import os
358
+ from huggingface_hub import HfApi
359
+ import shutil
360
+
361
+ # Create the model with MTT config
362
+ config = MusicNNConfig(
363
+ num_classes=50,
364
+ mid_filt=64,
365
+ backend_units=200,
366
+ dataset='MTT'
367
+ )
368
+
369
+ model = MusicNN(config)
370
+
371
+ # Load the weights
372
+ state_dict = torch.load('weights/MTT_musicnn.pt')
373
+ model.load_state_dict(state_dict)
374
+
375
+ # Save and push to Hugging Face
376
+ save_dir = 'musicnn-pytorch'
377
+ os.makedirs(save_dir, exist_ok=True)
378
+
379
+ model.save_pretrained(save_dir)
380
+ shutil.copy('musicnn.py', save_dir)
381
+
382
+ # Create config.json
383
+ config_dict = config.to_dict()
384
+ config_dict.update({
385
+ '_name_or_path': 'oriyonay/musicnn-pytorch',
386
+ 'architectures': ['MusicNN'],
387
+ 'auto_map': {
388
+ 'AutoConfig': 'musicnn.MusicNNConfig',
389
+ 'AutoModel': 'musicnn.MusicNN'
390
+ },
391
+ 'model_type': 'musicnn'
392
+ })
393
+
394
+ with open(os.path.join(save_dir, 'config.json'), 'w') as f:
395
+ json.dump(config_dict, f, indent=4)
396
+
397
+ # Push to Hugging Face
398
+ api = HfApi()
399
+ api.upload_folder(
400
+ folder_path=save_dir,
401
+ repo_id='oriyonay/musicnn-pytorch',
402
+ repo_type='model'
403
+ )
404
+
405
+ print("✅ Model uploaded to Hugging Face!")
406
+ print("Usage: model = MusicNN.from_pretrained('oriyonay/musicnn-pytorch')")
musicnn_torch.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import librosa
7
+ import soundfile as sf
8
+ import warnings
9
+
10
# Suppress the PyTorch padding warning and other user warnings.
# NOTE(review): this filter is process-wide, so it also hides UserWarnings
# raised by code that merely imports this module — consider scoping it.
warnings.filterwarnings('ignore', category=UserWarning)

# hyperparams — audio front-end settings shared by all models in this file
SR = 16000      # sample rate (Hz) the models expect; inputs are resampled to this
N_MELS = 96     # mel bands in the input spectrogram
FFT_HOP = 256   # STFT hop in samples (16 ms at 16 kHz)
FFT_SIZE = 512  # STFT window in samples (32 ms at 16 kHz)

# Tag vocabulary for MagnaTagATune-trained models; index order matches the
# model's output logits (top_tags indexes this list by logit position).
MTT_LABELS = [
    'guitar', 'classical', 'slow', 'techno', 'strings', 'drums', 'electronic', 'rock',
    'fast', 'piano', 'ambient', 'beat', 'violin', 'vocal', 'synth', 'female', 'indian',
    'opera', 'male', 'singing', 'vocals', 'no vocals', 'harpsichord', 'loud', 'quiet',
    'flute', 'woman', 'male vocal', 'no vocal', 'pop', 'soft', 'sitar', 'solo', 'man',
    'classic', 'choir', 'voice', 'new age', 'dance', 'male voice', 'female vocal',
    'beats', 'harp', 'cello', 'no voice', 'weird', 'country', 'metal', 'female voice',
    'choral'
]

# Tag vocabulary for Million Song Dataset-trained models; same index-order
# contract as MTT_LABELS.
MSD_LABELS = [
    'rock', 'pop', 'alternative', 'indie', 'electronic', 'female vocalists', 'dance',
    '00s', 'alternative rock', 'jazz', 'beautiful', 'metal', 'chillout', 'male vocalists',
    'classic rock', 'soul', 'indie rock', 'Mellow', 'electronica', '80s', 'folk', '90s',
    'chill', 'instrumental', 'punk', 'oldies', 'blues', 'hard rock', 'ambient', 'acoustic',
    'experimental', 'female vocalist', 'guitar', 'Hip-Hop', '70s', 'party', 'country',
    'easy listening', 'sexy', 'catchy', 'funk', 'electro', 'heavy metal',
    'Progressive rock', '60s', 'rnb', 'indie pop', 'sad', 'House', 'happy'
]
38
+
39
+
40
+ # -------------------------
41
+ # Building blocks
42
+ # -------------------------
43
class ConvReLUBN(nn.Module):
    """2-D convolution followed by ReLU, then batch normalization.

    The activation is applied *before* the norm (conv -> relu -> bn); the
    attribute names (``conv``, ``bn``) are part of the state-dict layout and
    must not change.
    """

    def __init__(self, in_ch, out_ch, kernel_size, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding)
        self.bn = nn.BatchNorm2d(out_ch, eps=0.001, momentum=0.01)

    def forward(self, x):
        activated = F.relu(self.conv(x))
        return self.bn(activated)
51
+
52
+
53
class TimbralBlock(nn.Module):
    """Frontend filters spanning a fraction of the mel axis (timbral features).

    The time axis is padded by 3 frames on each side so the height-7 kernels
    preserve the temporal length; the leftover mel (width) axis is then
    max-pooled away.
    """

    def __init__(self, mel_bins, out_ch):
        super().__init__()
        self.conv_block = ConvReLUBN(1, out_ch, kernel_size=(7, mel_bins), padding=0)

    def forward(self, x):
        # Symmetric pad of 3 on the time (height) axis only.
        padded = F.pad(x, (0, 0, 3, 3))
        features = self.conv_block(padded)
        # Collapse the remaining mel positions -> (batch, out_ch, time).
        return features.max(dim=3).values
62
+
63
+
64
class TemporalBlock(nn.Module):
    """Frontend filters spanning only the time axis (width-1 kernels).

    'same' padding preserves the temporal length; the mel (width) axis is
    max-pooled away after the convolution.
    """

    def __init__(self, kernel_size, out_ch):
        super().__init__()
        self.conv_block = ConvReLUBN(1, out_ch, kernel_size=(kernel_size, 1), padding='same')

    def forward(self, x):
        features = self.conv_block(x)
        # -> (batch, out_ch, time)
        return features.max(dim=3).values
72
+
73
+
74
class MidEnd(nn.Module):
    """Mid-level CNN: three temporal conv layers with residual connections.

    Each layer is a Conv2d with kernel (7, C) applied to a
    (batch, 1, time, channels) layout. The permute/pad sequences below exist
    to pad ONLY the time axis by 3 on each side, so every layer preserves the
    temporal length. Layers 2 and 3 add residual skip connections.
    """

    def __init__(self, in_ch, num_filt):
        # in_ch: number of frontend feature channels (561 when built by MusicNN).
        # num_filt: filters per conv layer.
        super().__init__()
        self.c1_conv = nn.Conv2d(1, num_filt, kernel_size=(7, in_ch), padding=0)
        self.c1_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)
        self.c2_conv = nn.Conv2d(1, num_filt, kernel_size=(7, num_filt), padding=0)
        self.c2_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)
        self.c3_conv = nn.Conv2d(1, num_filt, kernel_size=(7, num_filt), padding=0)
        self.c3_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)

    def forward(self, x):
        # x: (batch, channels, time) -> (batch, time, channels, 1)
        x = x.transpose(1, 2).unsqueeze(3)

        # Layer 1: pad time by (3, 3), then conv over (7, channels).
        x_perm = x.permute(0, 2, 3, 1)        # (batch, channels, 1, time)
        x1_pad = F.pad(x_perm, (3, 3, 0, 0))  # pad the time (last) axis only
        x1 = x1_pad.permute(0, 2, 3, 1)       # (batch, 1, time+6, channels)
        x1 = self.c1_bn(F.relu(self.c1_conv(x1)))  # (batch, num_filt, time, 1)
        x1_t = x1.permute(0, 2, 1, 3)         # (batch, time, num_filt, 1)

        # Layer 2: same pad/conv dance, plus residual from layer 1.
        x2_perm = x1_t.permute(0, 2, 3, 1)
        x2_pad = F.pad(x2_perm, (3, 3, 0, 0))
        x2 = x2_pad.permute(0, 2, 3, 1)
        x2 = self.c2_bn(F.relu(self.c2_conv(x2)))
        x2_t = x2.permute(0, 2, 1, 3)
        res_conv2 = x2_t + x1_t               # residual skip

        # Layer 3: same again, residual over the accumulated features.
        x3_perm = res_conv2.permute(0, 2, 3, 1)
        x3_pad = F.pad(x3_perm, (3, 3, 0, 0))
        x3 = x3_pad.permute(0, 2, 3, 1)
        x3 = self.c3_bn(F.relu(self.c3_conv(x3)))
        x3_t = x3.permute(0, 2, 1, 3)
        res_conv3 = x3_t + res_conv2          # residual skip

        # Return all intermediate representations, each (batch, time, channels),
        # so the backend can concatenate them along the channel axis.
        return [x.squeeze(3), x1_t.squeeze(3), res_conv2.squeeze(3), res_conv3.squeeze(3)]
108
+
109
+
110
class Backend(nn.Module):
    """Dense classification head over temporally pooled mid-end features.

    Pools features over time (max and mean), interleaves the two pooled
    vectors per channel, then applies two fully connected layers with batch
    norm and dropout.
    """

    def __init__(self, in_ch, num_classes, hidden):
        super().__init__()
        self.bn_in = nn.BatchNorm1d(in_ch * 2, eps=0.001, momentum=0.01)
        self.fc1 = nn.Linear(in_ch * 2, hidden)
        self.bn_fc1 = nn.BatchNorm1d(hidden, eps=0.001, momentum=0.01)
        self.fc2 = nn.Linear(hidden, num_classes)

    def forward(self, x):
        # x: (batch, time, channels) — pool across the time axis.
        max_pool = x.max(dim=1).values
        mean_pool = x.mean(dim=1)
        # Interleave max/mean per channel: (batch, channels, 2) -> (batch, 2*channels).
        pooled = torch.stack([max_pool, mean_pool], dim=2).flatten(start_dim=1)

        hidden_act = F.dropout(self.bn_in(pooled), p=0.5, training=self.training)
        hidden_act = self.bn_fc1(F.relu(self.fc1(hidden_act)))
        hidden_act = F.dropout(hidden_act, p=0.5, training=self.training)

        logits = self.fc2(hidden_act)
        # Pooled embeddings are returned alongside logits for feature extraction.
        return logits, mean_pool, max_pool
131
+
132
+
133
+ # -------------------------
134
+ # MusicNN
135
+ # -------------------------
136
class MusicNN(nn.Module):
    """PyTorch port of the musicnn music auto-tagging model.

    Pipeline: frontend (timbral + temporal conv filters over the log-mel
    patch) -> mid-end (residual temporal convs) -> backend (pooled dense
    classifier). Returns (logits, mean_pool, max_pool).
    """

    def __init__(self, num_classes, mid_filt=64, backend_units=200):
        super().__init__()
        self.bn_input = nn.BatchNorm2d(1, eps=0.001, momentum=0.01)
        # Frontend channel budget: 2 timbral blocks of int(1.6*128)=204 filters
        # plus 3 temporal blocks of int(1.6*32)=51 filters -> 561 channels.
        self.timbral_1 = TimbralBlock(int(0.4 * N_MELS), int(1.6 * 128))
        self.timbral_2 = TimbralBlock(int(0.7 * N_MELS), int(1.6 * 128))
        self.temp_1 = TemporalBlock(128, int(1.6 * 32))
        self.temp_2 = TemporalBlock(64, int(1.6 * 32))
        self.temp_3 = TemporalBlock(32, int(1.6 * 32))
        self.midend = MidEnd(in_ch=561, num_filt=mid_filt)
        self.backend = Backend(in_ch=mid_filt * 3 + 561, num_classes=num_classes, hidden=backend_units)

    def forward(self, x):
        # x: (batch, time, mels) log-mel patches -> add a channel axis.
        spec = self.bn_input(x.unsqueeze(1))
        # Each frontend block yields (batch, out_ch, time); transpose to
        # (batch, time, out_ch) and concatenate along channels.
        frontend = torch.cat(
            [
                self.timbral_1(spec).transpose(1, 2),
                self.timbral_2(spec).transpose(1, 2),
                self.temp_1(spec).transpose(1, 2),
                self.temp_2(spec).transpose(1, 2),
                self.temp_3(spec).transpose(1, 2),
            ],
            dim=2,
        )
        # Mid-end returns [input, c1, c1+c2 residual, c1+c2+c3 residual];
        # concatenate them along the channel axis for the backend.
        z = torch.cat(self.midend(frontend.transpose(1, 2)), dim=2)
        return self.backend(z)
161
+
162
+
163
+ # inference utils
164
def batch_data(audio_file, n_frames, overlap):
    """Load an audio file and slice its log-mel spectrogram into fixed patches.

    Args:
        audio_file: path to the input audio file.
        n_frames: spectrogram frames per patch.
        overlap: hop (in frames) between consecutive patch start positions.

    Returns:
        (patches, spectrogram): a (num_patches, n_frames, N_MELS) float32
        array and the full (frames, N_MELS) log-mel spectrogram.

    Raises:
        ValueError: if the file decodes to zero samples.
    """
    # Use soundfile as it handles MP3 more reliably in some local environments.
    audio, sr = sf.read(audio_file)

    # Downmix multi-channel audio to mono.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    # The models expect 16 kHz input; resample when needed.
    if sr != SR:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=SR)

    if len(audio) == 0:
        raise ValueError(f'Audio file {audio_file} is empty or could not be loaded.')

    mel = librosa.feature.melspectrogram(
        y=audio, sr=SR, hop_length=FFT_HOP, n_fft=FFT_SIZE, n_mels=N_MELS
    ).T
    # Log compression of the mel spectrogram, kept in float32.
    audio_rep = np.log10(10000 * mel.astype(np.float32) + 1)

    last_start = audio_rep.shape[0] - n_frames + 1
    if last_start <= 0:
        # Clip shorter than one patch: zero-pad a single patch.
        patch = np.zeros((n_frames, N_MELS), dtype=np.float32)
        patch[:audio_rep.shape[0], :] = audio_rep
        patches = [patch]
    else:
        patches = [
            audio_rep[start : start + n_frames, :]
            for start in range(0, last_start, overlap)
        ]

    return np.stack(patches), audio_rep
197
+
198
+
199
def extractor(file_name, model='MTT_musicnn', input_length=3, input_overlap=False, device=None):
    """Run a musicnn model over an audio file and return tag probabilities.

    Args:
        file_name: path to the audio file to analyze.
        model: model identifier; must contain 'MTT' or 'MSD' (optionally 'big').
        input_length: analysis patch length in seconds.
        input_overlap: False for back-to-back patches, or the hop in seconds.
        device: torch device string; auto-detects CUDA when None.

    Returns:
        (probs, labels): (num_patches, num_classes) sigmoid probabilities and
        the matching tag-name list.

    Raises:
        ValueError: for an unrecognized model name.
    """
    # Auto-detect device if not specified.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Map the model name to its tag vocabulary and architecture hyperparams.
    if 'MTT' in model:
        labels = MTT_LABELS
        config = {'num_classes': 50, 'mid_filt': 64, 'backend_units': 200}
    elif 'MSD' in model:
        labels = MSD_LABELS
        if 'big' in model:
            config = {'num_classes': 50, 'mid_filt': 512, 'backend_units': 500}
        else:
            config = {'num_classes': 50, 'mid_filt': 64, 'backend_units': 200}
    else:
        raise ValueError('Model not supported')

    net = MusicNN(**config)

    # Look for weights next to the script first, then under weights/.
    weight_path = f'{model}.pt'
    if not os.path.exists(weight_path):
        weight_path = os.path.join('weights', f'{model}.pt')
    if os.path.exists(weight_path):
        net.load_state_dict(torch.load(weight_path, map_location=device))
    else:
        # Best-effort: warn and continue with random weights.
        print(f'Warning: Weights not found at {weight_path}')

    net.to(device)
    net.eval()

    # Convert seconds to spectrogram frames for patching.
    n_frames = librosa.time_to_frames(input_length, sr=SR, n_fft=FFT_SIZE, hop_length=FFT_HOP) + 1
    if input_overlap:
        overlap = librosa.time_to_frames(input_overlap, sr=SR, n_fft=FFT_SIZE, hop_length=FFT_HOP)
    else:
        overlap = n_frames

    batch, _ = batch_data(file_name, n_frames, overlap)
    with torch.no_grad():
        logits, _, _ = net(torch.from_numpy(batch).to(device))
        probs = torch.sigmoid(logits).cpu().numpy()

    return probs, labels
245
+
246
+
247
def top_tags(file_name, model='MTT_musicnn', topN=3, device=None):
    """Return the topN most probable tags for an audio file.

    Probabilities are averaged across all analysis windows before ranking.
    """
    # Auto-detect device if not specified.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    probs, labels = extractor(file_name, model=model, device=device)
    avg_probs = probs.mean(axis=0)
    # Descending sort of the averaged probabilities, keep the topN indices.
    ranked = np.argsort(avg_probs)[::-1][:topN]
    return [labels[i] for i in ranked]
weights/MSD_musicnn.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6db4c22908da50888d6a259d41980988d3b9cecc5f96fd725ede09166996dd00
3
+ size 3191473
weights/MSD_musicnn_big.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b8312eddea265984e0315ecbc87a88b6fe2ab6c341a692741390880a4d1f9abe
3
+ size 31998829
weights/MTT_musicnn.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32cb8bc12786302edc7dde58be340082c06559d979bec06615d1035fa2474f8d
3
+ size 3191473