Commit ·
0bbc70a
0
Parent(s):
Duplicate from oriyonay/musicnn-pytorch
Browse filesCo-authored-by: ori yonay <oriyonay@users.noreply.huggingface.co>
- .gitattributes +35 -0
- README.md +83 -0
- config.json +75 -0
- configuration_musicnn.py +18 -0
- inference.py +20 -0
- model.safetensors +3 -0
- modeling_musicnn.py +313 -0
- musicnn.py +406 -0
- musicnn_torch.py +255 -0
- weights/MSD_musicnn.pt +3 -0
- weights/MSD_musicnn_big.pt +3 -0
- weights/MTT_musicnn.pt +3 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
tags:
|
| 4 |
+
- audio
|
| 5 |
+
- music
|
| 6 |
+
- music-tagging
|
| 7 |
+
- pytorch
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# MusicNN-PyTorch
|
| 11 |
+
|
| 12 |
+
This is a PyTorch reimplementation of the [MusicNN](https://github.com/jordipons/musicnn) library for music audio tagging.
|
| 13 |
+
|
| 14 |
+
It contains the model architecture and converted weights from the original TensorFlow 1.x checkpoints.
|
| 15 |
+
|
| 16 |
+
## Supported Models
|
| 17 |
+
|
| 18 |
+
- `MTT_musicnn`: Trained on MagnaTagATune (50 tags) - **Default model**
|
| 19 |
+
- `MSD_musicnn`: Trained on Million Song Dataset (50 tags)
|
| 20 |
+
- `MSD_musicnn_big`: Larger version trained on MSD (512 filters)
|
| 21 |
+
|
| 22 |
+
## Super Simple Usage (Hugging Face Transformers)
|
| 23 |
+
|
| 24 |
+
```python
|
| 25 |
+
from transformers import AutoModel
|
| 26 |
+
|
| 27 |
+
# Load the model (downloads automatically)
|
| 28 |
+
model = AutoModel.from_pretrained("oriyonay/musicnn-pytorch", trust_remote_code=True)
|
| 29 |
+
|
| 30 |
+
# Use the model
|
| 31 |
+
tags = model.predict_tags("your_audio.mp3", top_k=5)
|
| 32 |
+
print(f"Top 5 tags: {tags}")
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
## Embeddings (Optional)
|
| 36 |
+
|
| 37 |
+
```python
|
| 38 |
+
from transformers import AutoModel
|
| 39 |
+
|
| 40 |
+
model = AutoModel.from_pretrained("oriyonay/musicnn-pytorch", trust_remote_code=True)
|
| 41 |
+
|
| 42 |
+
# Extract embeddings from any layer
|
| 43 |
+
emb = model.extract_embeddings("your_audio.mp3", layer="penultimate", pool="mean")
|
| 44 |
+
print(emb.shape)
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
## Colab Example
|
| 48 |
+
|
| 49 |
+
```python
|
| 50 |
+
# Install dependencies
|
| 51 |
+
!pip install transformers torch librosa soundfile
|
| 52 |
+
|
| 53 |
+
# Load with AutoModel
|
| 54 |
+
from transformers import AutoModel
|
| 55 |
+
model = AutoModel.from_pretrained("oriyonay/musicnn-pytorch", trust_remote_code=True)
|
| 56 |
+
|
| 57 |
+
# Use the model
|
| 58 |
+
tags = model.predict_tags("your_audio.mp3", top_k=5)
|
| 59 |
+
print(tags)
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
## Traditional Usage
|
| 63 |
+
|
| 64 |
+
If you prefer to download the code manually:
|
| 65 |
+
|
| 66 |
+
```python
|
| 67 |
+
from musicnn_torch import top_tags
|
| 68 |
+
|
| 69 |
+
# Get top 5 tags for an audio file
|
| 70 |
+
tags = top_tags('path/to/audio.mp3', model='MTT_musicnn', topN=5)
|
| 71 |
+
print(tags)
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
## Installation
|
| 75 |
+
|
| 76 |
+
```bash
|
| 77 |
+
pip install transformers torch librosa soundfile
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
## Credits
|
| 81 |
+
|
| 82 |
+
Original implementation by [Jordi Pons](https://github.com/jordipons).
|
| 83 |
+
PyTorch port by Gemini.
|
config.json
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"num_classes": 50,
|
| 3 |
+
"mid_filt": 64,
|
| 4 |
+
"backend_units": 200,
|
| 5 |
+
"dataset": "MTT",
|
| 6 |
+
"return_dict": true,
|
| 7 |
+
"output_hidden_states": false,
|
| 8 |
+
"output_attentions": false,
|
| 9 |
+
"torchscript": false,
|
| 10 |
+
"torch_dtype": "float32",
|
| 11 |
+
"use_bfloat16": false,
|
| 12 |
+
"tf_legacy_loss": false,
|
| 13 |
+
"pruned_heads": {},
|
| 14 |
+
"tie_word_embeddings": true,
|
| 15 |
+
"chunk_size_feed_forward": 0,
|
| 16 |
+
"is_encoder_decoder": false,
|
| 17 |
+
"is_decoder": false,
|
| 18 |
+
"cross_attention_hidden_size": null,
|
| 19 |
+
"add_cross_attention": false,
|
| 20 |
+
"tie_encoder_decoder": false,
|
| 21 |
+
"max_length": 20,
|
| 22 |
+
"min_length": 0,
|
| 23 |
+
"do_sample": false,
|
| 24 |
+
"early_stopping": false,
|
| 25 |
+
"num_beams": 1,
|
| 26 |
+
"num_beam_groups": 1,
|
| 27 |
+
"diversity_penalty": 0.0,
|
| 28 |
+
"temperature": 1.0,
|
| 29 |
+
"top_k": 50,
|
| 30 |
+
"top_p": 1.0,
|
| 31 |
+
"typical_p": 1.0,
|
| 32 |
+
"repetition_penalty": 1.0,
|
| 33 |
+
"length_penalty": 1.0,
|
| 34 |
+
"no_repeat_ngram_size": 0,
|
| 35 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 36 |
+
"bad_words_ids": null,
|
| 37 |
+
"num_return_sequences": 1,
|
| 38 |
+
"output_scores": false,
|
| 39 |
+
"return_dict_in_generate": false,
|
| 40 |
+
"forced_bos_token_id": null,
|
| 41 |
+
"forced_eos_token_id": null,
|
| 42 |
+
"remove_invalid_values": false,
|
| 43 |
+
"exponential_decay_length_penalty": null,
|
| 44 |
+
"suppress_tokens": null,
|
| 45 |
+
"begin_suppress_tokens": null,
|
| 46 |
+
"architectures": [
|
| 47 |
+
"MusicNN"
|
| 48 |
+
],
|
| 49 |
+
"finetuning_task": null,
|
| 50 |
+
"id2label": {
|
| 51 |
+
"0": "LABEL_0",
|
| 52 |
+
"1": "LABEL_1"
|
| 53 |
+
},
|
| 54 |
+
"label2id": {
|
| 55 |
+
"LABEL_0": 0,
|
| 56 |
+
"LABEL_1": 1
|
| 57 |
+
},
|
| 58 |
+
"tokenizer_class": null,
|
| 59 |
+
"prefix": null,
|
| 60 |
+
"bos_token_id": null,
|
| 61 |
+
"pad_token_id": null,
|
| 62 |
+
"eos_token_id": null,
|
| 63 |
+
"sep_token_id": null,
|
| 64 |
+
"decoder_start_token_id": null,
|
| 65 |
+
"task_specific_params": null,
|
| 66 |
+
"problem_type": null,
|
| 67 |
+
"_name_or_path": "oriyonay/musicnn-pytorch",
|
| 68 |
+
"_attn_implementation_autoset": false,
|
| 69 |
+
"transformers_version": "4.48.0",
|
| 70 |
+
"model_type": "musicnn",
|
| 71 |
+
"auto_map": {
|
| 72 |
+
"AutoConfig": "musicnn.MusicNNConfig",
|
| 73 |
+
"AutoModel": "musicnn.MusicNN"
|
| 74 |
+
}
|
| 75 |
+
}
|
configuration_musicnn.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from transformers import PretrainedConfig


class MusicNNConfig(PretrainedConfig):
    """Hugging Face configuration for the MusicNN audio-tagging model.

    Attributes:
        num_classes: Number of output tags produced by the classifier head.
        mid_filt: Number of convolutional filters in the mid-end stack.
        backend_units: Hidden width of the dense backend classifier.
        dataset: Training-dataset identifier ('MTT' or 'MSD'); selects the
            tag vocabulary used at prediction time.
    """

    model_type = 'musicnn'

    def __init__(
        self,
        num_classes=50,
        mid_filt=64,
        backend_units=200,
        dataset='MTT',
        **kwargs
    ):
        # Record model hyper-parameters first, then hand the remaining
        # generic options to the transformers base class.
        self.num_classes = num_classes
        self.mid_filt = mid_filt
        self.backend_units = backend_units
        self.dataset = dataset
        super().__init__(**kwargs)
inference.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os

from musicnn_torch import top_tags

# Demo script: print the top-5 MTT_musicnn tags for a few local audio files.
# Paths are absolute and machine-specific; missing files are reported and skipped.
AUDIO_FILES = [
    '/Users/oriyonay/Desktop/CRAZY BEAT.mp3',
    '/Users/oriyonay/Desktop/burn the stage/bounces/02 the type of girl.mp3',
    '/Users/oriyonay/Desktop/burn the stage/extras/jazzy red roses.mp3'
]

for f in AUDIO_FILES:
    if not os.path.exists(f):
        print(f"\nWarning: File not found at {f}")
        continue
    print(f"\n--- Predicting top tags for {os.path.basename(f)} ---")
    try:
        tags = top_tags(f, model='MTT_musicnn', topN=5)
        print(f"Top 5 tags: {tags}")
    except Exception as e:
        print(f"Error processing {f}: {e}")
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cc0b9400fcaed6e9ce7fbcfa97ec91e4fcb5f2ab34ca3a0cd6bef4af74753e1a
|
| 3 |
+
size 3175212
|
modeling_musicnn.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
import soundfile as sf
|
| 6 |
+
import librosa
|
| 7 |
+
from transformers import PreTrainedModel, PretrainedConfig
|
| 8 |
+
|
| 9 |
+
class MusicNNConfig(PretrainedConfig):
    """Configuration object for MusicNN (see also configuration_musicnn.py).

    Holds the hyper-parameters that shape the network: number of output
    tags, mid-end filter count, backend hidden width, and the dataset name
    ('MTT' or 'MSD') that picks the tag vocabulary.
    """

    model_type = 'musicnn'

    def __init__(self, num_classes=50, mid_filt=64, backend_units=200,
                 dataset='MTT', **kwargs):
        # Model-specific fields go on self; everything else is delegated
        # to PretrainedConfig.
        self.num_classes = num_classes
        self.mid_filt = mid_filt
        self.backend_units = backend_units
        self.dataset = dataset
        super().__init__(**kwargs)
| 25 |
+
|
| 26 |
+
# -------------------------
|
| 27 |
+
# Building blocks
|
| 28 |
+
# -------------------------
|
| 29 |
+
class ConvReLUBN(nn.Module):
    """2-D convolution followed by ReLU and batch normalisation.

    Note: the activation is applied *before* batch norm in this port;
    eps/momentum values are kept small to match the converted weights.
    """

    def __init__(self, in_ch, out_ch, kernel_size, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding)
        self.bn = nn.BatchNorm2d(out_ch, eps=0.001, momentum=0.01)

    def forward(self, x):
        out = self.conv(x)
        out = F.relu(out)
        return self.bn(out)
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class TimbralBlock(nn.Module):
    """Frontend block with a tall (7 x mel_bins) filter for timbral features.

    Consumes a [B, 1, T, M] spectrogram and returns [B, out_ch, T] after
    max-pooling whatever remains of the frequency axis.
    """

    def __init__(self, mel_bins, out_ch):
        super().__init__()
        self.conv_block = ConvReLUBN(1, out_ch, kernel_size=(7, mel_bins), padding=0)

    def forward(self, x):
        # Zero-pad 3 frames on each side of the time axis so the 7-tap
        # temporal extent of the kernel preserves sequence length.
        padded = F.pad(x, (0, 0, 3, 3))
        feats = self.conv_block(padded)
        # Collapse the frequency axis with a max.
        return torch.max(feats, dim=3).values
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class TemporalBlock(nn.Module):
    """Frontend block with a (kernel_size x 1) filter for temporal patterns.

    'same' padding keeps the time dimension unchanged; the frequency axis
    is collapsed with a max, yielding [B, out_ch, T].
    """

    def __init__(self, kernel_size, out_ch):
        super().__init__()
        self.conv_block = ConvReLUBN(1, out_ch, kernel_size=(kernel_size, 1), padding='same')

    def forward(self, x):
        feats = self.conv_block(x)
        return torch.max(feats, dim=3).values
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class MidEnd(nn.Module):
    """Mid-end: three stacked temporal conv layers with residual connections.

    Takes frontend features shaped [B, C, T] and returns a list of four
    tensors shaped [B, T, *]: the input itself plus each layer's output,
    ready for concatenation in the backend.
    """

    def __init__(self, in_ch, num_filt):
        super().__init__()
        self.c1_conv = nn.Conv2d(1, num_filt, kernel_size=(7, in_ch), padding=0)
        self.c1_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)
        self.c2_conv = nn.Conv2d(1, num_filt, kernel_size=(7, num_filt), padding=0)
        self.c2_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)
        self.c3_conv = nn.Conv2d(1, num_filt, kernel_size=(7, num_filt), padding=0)
        self.c3_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)

    def _conv_step(self, feats, conv, bn):
        """Pad 3 frames on each side of time, convolve across the full
        feature width, and return the result in [B, T, F, 1] layout."""
        padded = F.pad(feats.permute(0, 2, 3, 1), (3, 3, 0, 0))
        out = bn(F.relu(conv(padded.permute(0, 2, 3, 1))))
        return out.permute(0, 2, 1, 3)

    def forward(self, x):
        # [B, C, T] -> [B, T, C, 1]
        feats = x.transpose(1, 2).unsqueeze(3)

        layer1 = self._conv_step(feats, self.c1_conv, self.c1_bn)
        # Residual connections: each subsequent layer adds its input back in.
        layer2 = self._conv_step(layer1, self.c2_conv, self.c2_bn) + layer1
        layer3 = self._conv_step(layer2, self.c3_conv, self.c3_bn) + layer2

        # Drop the trailing singleton axis from every tap.
        return [t.squeeze(3) for t in (feats, layer1, layer2, layer3)]
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class Backend(nn.Module):
    """Dense classifier head over temporally pooled mid-end features."""

    def __init__(self, in_ch, num_classes, hidden):
        super().__init__()
        # Input is max-pool and mean-pool interleaved, hence the 2x width.
        self.bn_in = nn.BatchNorm1d(in_ch * 2, eps=0.001, momentum=0.01)
        self.fc1 = nn.Linear(in_ch * 2, hidden)
        self.bn_fc1 = nn.BatchNorm1d(hidden, eps=0.001, momentum=0.01)
        self.fc2 = nn.Linear(hidden, num_classes)

    def forward(self, x):
        # Pool across time (dim 1); x is [B, T, C].
        max_pool = torch.max(x, dim=1).values
        mean_pool = torch.mean(x, dim=1)
        # Interleave (max, mean) per feature, then flatten to [B, 2C].
        pooled = torch.stack([max_pool, mean_pool], dim=2)
        pooled = pooled.view(pooled.size(0), -1)

        h = F.dropout(self.bn_in(pooled), p=0.5, training=self.training)
        h = self.bn_fc1(F.relu(self.fc1(h)))
        h = F.dropout(h, p=0.5, training=self.training)

        return self.fc2(h), mean_pool, max_pool
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class MusicNNModel(PreTrainedModel):
    """MusicNN audio auto-tagging model (PyTorch port of jordipons/musicnn).

    Pipeline: log-mel spectrogram -> frontend (timbral + temporal conv
    blocks) -> mid-end (dense temporal convs) -> backend (pooled dense
    classifier producing one logit per tag).
    """

    config_class = MusicNNConfig

    # Tag vocabularies, index-aligned with the 50 classifier outputs.
    _MTT_LABELS = [
        'guitar', 'classical', 'slow', 'techno', 'strings', 'drums', 'electronic', 'rock',
        'fast', 'piano', 'ambient', 'beat', 'violin', 'vocal', 'synth', 'female', 'indian',
        'opera', 'male', 'singing', 'vocals', 'no vocals', 'harpsichord', 'loud', 'quiet',
        'flute', 'woman', 'male vocal', 'no vocal', 'pop', 'soft', 'sitar', 'solo', 'man',
        'classic', 'choir', 'voice', 'new age', 'dance', 'male voice', 'female vocal',
        'beats', 'harp', 'cello', 'no voice', 'weird', 'country', 'metal', 'female voice',
        'choral'
    ]
    _MSD_LABELS = [
        'rock', 'pop', 'alternative', 'indie', 'electronic', 'female vocalists', 'dance',
        '00s', 'alternative rock', 'jazz', 'beautiful', 'metal', 'chillout', 'male vocalists',
        'classic rock', 'soul', 'indie rock', 'Mellow', 'electronica', '80s', 'folk', '90s',
        'chill', 'instrumental', 'punk', 'oldies', 'blues', 'hard rock', 'ambient', 'acoustic',
        'experimental', 'female vocalist', 'guitar', 'Hip-Hop', '70s', 'party', 'country',
        'easy listening', 'sexy', 'catchy', 'funk', 'electro', 'heavy metal',
        'Progressive rock', '60s', 'rnb', 'indie pop', 'sad', 'House', 'happy'
    ]

    def __init__(self, config):
        super().__init__(config)
        self.bn_input = nn.BatchNorm2d(1, eps=0.001, momentum=0.01)
        # Frontend filter shapes follow the musicnn design: two timbral
        # blocks spanning 40%/70% of the 96 mel bins, three temporal blocks
        # with 128/64/32-frame kernels, all widened by a 1.6x factor.
        self.timbral_1 = TimbralBlock(int(0.4 * 96), int(1.6 * 128))
        self.timbral_2 = TimbralBlock(int(0.7 * 96), int(1.6 * 128))
        self.temp_1 = TemporalBlock(128, int(1.6 * 32))
        self.temp_2 = TemporalBlock(64, int(1.6 * 32))
        self.temp_3 = TemporalBlock(32, int(1.6 * 32))
        # 561 = 2 * int(1.6*128) timbral + 3 * int(1.6*32) temporal channels.
        self.midend = MidEnd(in_ch=561, num_filt=config.mid_filt)
        self.backend = Backend(in_ch=config.mid_filt * 3 + 561, num_classes=config.num_classes, hidden=config.backend_units)

    def forward(self, x):
        """Run the full tagging network.

        Args:
            x: Batch of log-mel spectrograms shaped [B, T, 96].

        Returns:
            Tuple (logits, mean_pool, max_pool) from the backend classifier.
        """
        x = x.unsqueeze(1)  # -> [B, 1, T, 96] for the 2-D convolutions
        x = self.bn_input(x)
        # Frontend: each block yields [B, ch, T]; transpose to [B, T, ch].
        f74 = self.timbral_1(x).transpose(1, 2)
        f77 = self.timbral_2(x).transpose(1, 2)
        s1 = self.temp_1(x).transpose(1, 2)
        s2 = self.temp_2(x).transpose(1, 2)
        s3 = self.temp_3(x).transpose(1, 2)
        frontend_features = torch.cat([f74, f77, s1, s2, s3], dim=2)
        # Mid-end consumes [B, C, T] and returns a list of [B, T, *] taps.
        mid_feats = self.midend(frontend_features.transpose(1, 2))
        z = torch.cat(mid_feats, dim=2)
        logits, mean_pool, max_pool = self.backend(z)
        return logits, mean_pool, max_pool

    @staticmethod
    def preprocess_audio(audio_file, sr=16000):
        """Load an audio file and return its log-scaled mel spectrogram.

        Tries librosa first, falling back to soundfile (better for some
        MP3s), downmixes to mono, and resamples to `sr` if needed.

        Returns:
            np.ndarray of shape [T, 96], float32, log10-compressed.

        Raises:
            ValueError: if the file cannot be loaded or decodes to no samples.
        """
        try:
            audio, file_sr = librosa.load(audio_file, sr=None)
            if len(audio) == 0:
                raise ValueError("Empty audio from librosa")
        except Exception:
            try:
                audio, file_sr = sf.read(audio_file)
                # Convert to mono if stereo
                if len(audio.shape) > 1:
                    audio = np.mean(audio, axis=1)
            except Exception as e:
                raise ValueError(f'Could not load audio file {audio_file}: {e}')

        # Resample to target sample rate if necessary
        if file_sr != sr:
            audio = librosa.resample(audio, orig_sr=file_sr, target_sr=sr)

        if len(audio) == 0:
            raise ValueError(f'Audio file {audio_file} is empty or could not be loaded.')

        # 96-bin mel spectrogram, log-compressed as in the original musicnn.
        audio_rep = librosa.feature.melspectrogram(
            y=audio, sr=sr, hop_length=256, n_fft=512, n_mels=96
        ).T
        audio_rep = audio_rep.astype(np.float32)
        audio_rep = np.log10(10000 * audio_rep + 1)

        return audio_rep

    def predict_tags(self, audio_file, top_k=5):
        """Return the `top_k` most probable tags for `audio_file`.

        The spectrogram is split into consecutive, non-overlapping
        3-second windows; per-window sigmoid probabilities are averaged
        before ranking. Preprocessing is shared with `preprocess_audio`,
        so the librosa-first / soundfile-fallback loading applies here too
        (the previous inline version only used soundfile).
        """
        audio_rep = self.preprocess_audio(audio_file)

        # 187 = librosa.time_to_frames(3, sr=16000, n_fft=512, hop_length=256) + 1
        n_frames = 187
        hop = n_frames  # consecutive windows, no overlap

        last_start = audio_rep.shape[0] - n_frames + 1
        if last_start <= 0:
            # Audio shorter than one window: zero-pad to a single patch.
            patch = np.zeros((n_frames, 96), dtype=np.float32)
            patch[:audio_rep.shape[0], :] = audio_rep
            patches = [patch]
        else:
            patches = [
                audio_rep[start : start + n_frames, :]
                for start in range(0, last_start, hop)
            ]

        batch_tensor = torch.from_numpy(np.stack(patches))

        self.eval()
        all_probs = []
        with torch.no_grad():
            # One window per forward pass keeps memory bounded for long files.
            for i in range(len(patches)):
                logits, _, _ = self(batch_tensor[i : i + 1])
                all_probs.append(torch.sigmoid(logits).squeeze(0).numpy())

        # Average probabilities across all windows
        avg_probs = np.mean(all_probs, axis=0)

        if self.config.dataset == 'MTT':
            labels = self._MTT_LABELS
        elif self.config.dataset == 'MSD':
            labels = self._MSD_LABELS
        else:
            raise ValueError(f"Unknown dataset: {self.config.dataset}")

        # Indices of the top-k probabilities, highest first.
        top_indices = np.argsort(avg_probs)[-top_k:][::-1]
        return [labels[i] for i in top_indices]
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def create_musicnn_model(model_type='MTT_musicnn'):
    """
    Factory function to create MusicNN models with different configurations.

    Args:
        model_type (str): One of 'MTT_musicnn', 'MSD_musicnn', or 'MSD_musicnn_big'

    Returns:
        MusicNNModel: Configured model instance
    """
    from transformers import AutoConfig

    # Per-variant hyper-parameters.
    configs = {
        'MTT_musicnn': {
            'num_classes': 50, 'mid_filt': 64, 'backend_units': 200, 'dataset': 'MTT',
        },
        'MSD_musicnn': {
            'num_classes': 50, 'mid_filt': 64, 'backend_units': 200, 'dataset': 'MSD',
        },
        'MSD_musicnn_big': {
            'num_classes': 50, 'mid_filt': 512, 'backend_units': 500, 'dataset': 'MSD',
        },
    }

    if model_type not in configs:
        raise ValueError(f"Unknown model type: {model_type}. Choose from: {list(configs.keys())}")

    # Load the shared hub config, then overwrite the fields that differ per
    # variant (separate per-variant config files could replace this later).
    config = AutoConfig.from_pretrained("oriyonay/musicnn-pytorch", trust_remote_code=True)
    for key, value in configs[model_type].items():
        setattr(config, key, value)

    return MusicNNModel(config)
musicnn.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
import soundfile as sf
|
| 6 |
+
import librosa
|
| 7 |
+
from transformers import PretrainedConfig, PreTrainedModel
|
| 8 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 9 |
+
|
| 10 |
+
# Suppress warnings
|
| 11 |
+
import warnings
|
| 12 |
+
warnings.filterwarnings('ignore')
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class MusicNNConfig(PretrainedConfig):
    """MusicNN configuration (standalone copy used by musicnn.py).

    num_classes: size of the tag output layer.
    mid_filt: filter count in the mid-end convolutions.
    backend_units: hidden width of the dense backend.
    dataset: 'MTT' or 'MSD'; determines the tag vocabulary.
    """

    model_type = 'musicnn'

    def __init__(self, num_classes=50, mid_filt=64,
                 backend_units=200, dataset='MTT', **kwargs):
        # Keep the model-specific fields on the instance; forward the rest
        # to the transformers base config.
        self.num_classes = num_classes
        self.mid_filt = mid_filt
        self.backend_units = backend_units
        self.dataset = dataset
        super().__init__(**kwargs)
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# -------------------------
|
| 34 |
+
# Building blocks
|
| 35 |
+
# -------------------------
|
| 36 |
+
class ConvReLUBN(nn.Module):
    """Conv2d -> ReLU -> BatchNorm2d (activation deliberately before BN)."""

    def __init__(self, in_ch, out_ch, kernel_size, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding)
        # Small eps/momentum to match the converted checkpoint statistics.
        self.bn = nn.BatchNorm2d(out_ch, eps=0.001, momentum=0.01)

    def forward(self, x):
        activated = F.relu(self.conv(x))
        return self.bn(activated)
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class TimbralBlock(nn.Module):
    """Frontend timbral filter: (7 x mel_bins) conv over a [B, 1, T, M] input.

    The time axis is padded by 3 on each side so the 7-tap kernel keeps T
    unchanged; the frequency axis is then max-pooled away.
    """

    def __init__(self, mel_bins, out_ch):
        super().__init__()
        self.conv_block = ConvReLUBN(1, out_ch, kernel_size=(7, mel_bins), padding=0)

    def forward(self, x):
        out = self.conv_block(F.pad(x, (0, 0, 3, 3)))
        return torch.max(out, dim=3).values
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class TemporalBlock(nn.Module):
    """Frontend temporal filter: (kernel_size x 1) conv with 'same' padding,
    followed by a max over the frequency axis -> [B, out_ch, T]."""

    def __init__(self, kernel_size, out_ch):
        super().__init__()
        self.conv_block = ConvReLUBN(1, out_ch, kernel_size=(kernel_size, 1), padding='same')

    def forward(self, x):
        out = self.conv_block(x)
        return torch.max(out, dim=3).values
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class MidEnd(nn.Module):
    """Three temporal conv layers with residual connections.

    Input: frontend features [B, C, T]. Output: a list of four [B, T, *]
    tensors — the input plus each layer's (residual-summed) output.
    """

    def __init__(self, in_ch, num_filt):
        super().__init__()
        self.c1_conv = nn.Conv2d(1, num_filt, kernel_size=(7, in_ch), padding=0)
        self.c1_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)
        self.c2_conv = nn.Conv2d(1, num_filt, kernel_size=(7, num_filt), padding=0)
        self.c2_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)
        self.c3_conv = nn.Conv2d(1, num_filt, kernel_size=(7, num_filt), padding=0)
        self.c3_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)

    def forward(self, x):
        # [B, C, T] -> [B, T, C, 1]
        inp = x.transpose(1, 2).unsqueeze(3)

        # Layer 1: pad time by 3 on both sides, convolve over all channels.
        h = F.pad(inp.permute(0, 2, 3, 1), (3, 3, 0, 0)).permute(0, 2, 3, 1)
        h = self.c1_bn(F.relu(self.c1_conv(h)))
        out1 = h.permute(0, 2, 1, 3)

        # Layer 2, plus residual from layer 1.
        h = F.pad(out1.permute(0, 2, 3, 1), (3, 3, 0, 0)).permute(0, 2, 3, 1)
        h = self.c2_bn(F.relu(self.c2_conv(h)))
        out2 = h.permute(0, 2, 1, 3) + out1

        # Layer 3, plus residual from layer 2's sum.
        h = F.pad(out2.permute(0, 2, 3, 1), (3, 3, 0, 0)).permute(0, 2, 3, 1)
        h = self.c3_bn(F.relu(self.c3_conv(h)))
        out3 = h.permute(0, 2, 1, 3) + out2

        # Drop the trailing singleton axis from each tap.
        return [inp.squeeze(3), out1.squeeze(3), out2.squeeze(3), out3.squeeze(3)]
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class Backend(nn.Module):
    """Pooling + MLP classifier head.

    Pools the mid-end features over time with both max and mean, then runs a
    BN/dropout-regularized two-layer MLP to produce tag logits. Also returns
    the pooled vectors so callers can use them as embeddings.
    """

    def __init__(self, in_ch, num_classes, hidden):
        super().__init__()
        # Input width is 2 * in_ch: max- and mean-pooled features are
        # interleaved per channel before the first linear layer.
        self.bn_in = nn.BatchNorm1d(in_ch * 2, eps=0.001, momentum=0.01)
        self.fc1 = nn.Linear(in_ch * 2, hidden)
        self.bn_fc1 = nn.BatchNorm1d(hidden, eps=0.001, momentum=0.01)
        self.fc2 = nn.Linear(hidden, num_classes)

    def forward(self, x):
        max_pool = x.max(dim=1).values
        mean_pool = x.mean(dim=1)

        # (B, C, 2) -> (B, 2C): per-channel [max, mean] pairs, flattened.
        features = torch.stack([max_pool, mean_pool], dim=2)
        features = features.reshape(features.size(0), -1)

        features = self.bn_in(features)
        features = F.dropout(features, p=0.5, training=self.training)
        features = self.bn_fc1(F.relu(self.fc1(features)))
        features = F.dropout(features, p=0.5, training=self.training)

        logits = self.fc2(features)
        return logits, mean_pool, max_pool
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class MusicNN(PreTrainedModel, PyTorchModelHubMixin):
    """musicnn music auto-tagger with Hugging Face Hub integration.

    Pipeline: input batch-norm -> five parallel front-end conv branches
    (two timbral, three temporal) -> residual mid-end -> pooling back-end
    that emits per-tag logits plus mean-/max-pooled embeddings.
    """

    config_class = MusicNNConfig

    # Mel frames per 3-second analysis window:
    # librosa.time_to_frames(3, sr=16000, n_fft=512, hop_length=256) + 1
    N_FRAMES = 187

    # Tag vocabularies, indexed by output position of the classifier.
    MTT_LABELS = [
        'guitar', 'classical', 'slow', 'techno', 'strings', 'drums', 'electronic', 'rock',
        'fast', 'piano', 'ambient', 'beat', 'violin', 'vocal', 'synth', 'female', 'indian',
        'opera', 'male', 'singing', 'vocals', 'no vocals', 'harpsichord', 'loud', 'quiet',
        'flute', 'woman', 'male vocal', 'no vocal', 'pop', 'soft', 'sitar', 'solo', 'man',
        'classic', 'choir', 'voice', 'new age', 'dance', 'male voice', 'female vocal',
        'beats', 'harp', 'cello', 'no voice', 'weird', 'country', 'metal', 'female voice',
        'choral'
    ]
    MSD_LABELS = [
        'rock', 'pop', 'alternative', 'indie', 'electronic', 'female vocalists', 'dance',
        '00s', 'alternative rock', 'jazz', 'beautiful', 'metal', 'chillout', 'male vocalists',
        'classic rock', 'soul', 'indie rock', 'Mellow', 'electronica', '80s', 'folk', '90s',
        'chill', 'instrumental', 'punk', 'oldies', 'blues', 'hard rock', 'ambient', 'acoustic',
        'experimental', 'female vocalist', 'guitar', 'Hip-Hop', '70s', 'party', 'country',
        'easy listening', 'sexy', 'catchy', 'funk', 'electro', 'heavy metal',
        'Progressive rock', '60s', 'rnb', 'indie pop', 'sad', 'House', 'happy'
    ]

    def __init__(self, config):
        super().__init__(config)
        self.bn_input = nn.BatchNorm2d(1, eps=0.001, momentum=0.01)
        # Timbral branches span 40% / 70% of the 96 mel bins.
        self.timbral_1 = TimbralBlock(int(0.4 * 96), int(1.6 * 128))
        self.timbral_2 = TimbralBlock(int(0.7 * 96), int(1.6 * 128))
        # Temporal branches with decreasing kernel lengths.
        self.temp_1 = TemporalBlock(128, int(1.6 * 32))
        self.temp_2 = TemporalBlock(64, int(1.6 * 32))
        self.temp_3 = TemporalBlock(32, int(1.6 * 32))
        # 561 = total channel count produced by the five front-end branches.
        self.midend = MidEnd(in_ch=561, num_filt=config.mid_filt)
        self.backend = Backend(
            in_ch=config.mid_filt * 3 + 561,
            num_classes=config.num_classes,
            hidden=config.backend_units,
        )

    def forward(self, x):
        """x: (batch, time_frames, 96) log-mel patches.
        Returns (logits, mean_pool, max_pool)."""
        x = x.unsqueeze(1)
        x = self.bn_input(x)
        f74 = self.timbral_1(x).transpose(1, 2)
        f77 = self.timbral_2(x).transpose(1, 2)
        s1 = self.temp_1(x).transpose(1, 2)
        s2 = self.temp_2(x).transpose(1, 2)
        s3 = self.temp_3(x).transpose(1, 2)
        frontend_features = torch.cat([f74, f77, s1, s2, s3], dim=2)
        mid_feats = self.midend(frontend_features.transpose(1, 2))
        z = torch.cat(mid_feats, dim=2)
        logits, mean_pool, max_pool = self.backend(z)
        return logits, mean_pool, max_pool

    @staticmethod
    def _load_audio(audio_file, sr=16000):
        """Load an audio file as a mono waveform resampled to `sr`.

        Tries librosa first, falling back to soundfile (more reliable for
        some MP3s). Raises ValueError if nothing usable could be read.
        """
        try:
            audio, file_sr = librosa.load(audio_file, sr=None)
            if len(audio) == 0:
                raise ValueError("Empty audio from librosa")
        except Exception:
            try:
                audio, file_sr = sf.read(audio_file)
                # Down-mix multi-channel audio to mono.
                if len(audio.shape) > 1:
                    audio = np.mean(audio, axis=1)
            except Exception as e:
                raise ValueError(f'Could not load audio file {audio_file}: {e}')

        # Resample to target sample rate if necessary.
        if file_sr != sr:
            audio = librosa.resample(audio, orig_sr=file_sr, target_sr=sr)

        if len(audio) == 0:
            raise ValueError(f'Audio file {audio_file} is empty or could not be loaded.')
        return audio

    @staticmethod
    def preprocess_audio(audio_file, sr=16000):
        """Return the log-mel spectrogram of `audio_file`, shape (frames, 96)."""
        audio = MusicNN._load_audio(audio_file, sr=sr)
        audio_rep = librosa.feature.melspectrogram(
            y=audio, sr=sr, hop_length=256, n_fft=512, n_mels=96
        ).T
        audio_rep = audio_rep.astype(np.float32)
        # Log compression, same as the original musicnn preprocessing.
        return np.log10(10000 * audio_rep + 1)

    @staticmethod
    def _batch_patches(audio_rep, n_frames=187):
        """Slice a (frames, mels) spectrogram into non-overlapping windows of
        `n_frames`, zero-padding when the clip is shorter than one window.
        Returns an array of shape (num_windows, n_frames, mels)."""
        last_frame = audio_rep.shape[0] - n_frames + 1
        if last_frame <= 0:
            patch = np.zeros((n_frames, audio_rep.shape[1]), dtype=np.float32)
            patch[:audio_rep.shape[0], :] = audio_rep
            return np.stack([patch])
        return np.stack([
            audio_rep[t : t + n_frames, :]
            for t in range(0, last_frame, n_frames)
        ])

    def _window_outputs(self, audio_file):
        """Run the model over every 3-second window of `audio_file`.

        Returns a list of (logits, mean_pool, max_pool) tuples, one per
        window, computed in eval mode on the best available device.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(device)
        audio_rep = self.preprocess_audio(audio_file)
        batch = torch.from_numpy(self._batch_patches(audio_rep, self.N_FRAMES)).to(device)
        self.eval()
        outputs = []
        with torch.no_grad():
            # One window at a time keeps peak memory low for long tracks.
            for i in range(batch.size(0)):
                outputs.append(self(batch[i:i + 1]))
        return outputs

    def _labels(self):
        """Tag vocabulary for the configured training dataset."""
        if self.config.dataset == 'MTT':
            return self.MTT_LABELS
        if self.config.dataset == 'MSD':
            return self.MSD_LABELS
        raise ValueError(f"Unknown dataset: {self.config.dataset}")

    def predict_tags(self, audio_file, top_k=5):
        """Return the `top_k` tag names, ranked by sigmoid probability
        averaged over all analysis windows."""
        all_probs = [
            torch.sigmoid(logits).squeeze(0).cpu().numpy()
            for logits, _, _ in self._window_outputs(audio_file)
        ]
        avg_probs = np.mean(all_probs, axis=0)
        labels = self._labels()
        top_indices = np.argsort(avg_probs)[-top_k:][::-1]
        return [labels[i] for i in top_indices]

    def extract_embeddings(self, audio_file, layer=None, pool='mean'):
        """
        Extract embeddings from audio file.
        Args:
            audio_file: path to audio file
            layer: which layer to extract from (ignored for simplicity, uses final embeddings)
            pool: pooling method ('mean', 'max', or 'both')
        Returns:
            embeddings as numpy array, averaged over all analysis windows
        """
        all_embeddings = []
        for _, mean_pool, max_pool in self._window_outputs(audio_file):
            if pool == 'max':
                emb = max_pool.squeeze(0).cpu().numpy()
            elif pool == 'both':
                emb = torch.cat([mean_pool, max_pool], dim=1).squeeze(0).cpu().numpy()
            else:
                # 'mean' and any unrecognized value fall back to mean pooling,
                # matching the original behavior.
                emb = mean_pool.squeeze(0).cpu().numpy()
            all_embeddings.append(emb)
        return np.mean(all_embeddings, axis=0)
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
# For uploading to Hugging Face Hub
if __name__ == '__main__':
    import json
    import os
    import shutil

    from huggingface_hub import HfApi

    # Create the model with MTT config
    config = MusicNNConfig(
        num_classes=50,
        mid_filt=64,
        backend_units=200,
        dataset='MTT'
    )

    model = MusicNN(config)

    # Load the weights on CPU so the script also runs on machines without
    # the device the checkpoint was saved from.
    state_dict = torch.load('weights/MTT_musicnn.pt', map_location='cpu')
    model.load_state_dict(state_dict)

    # Save locally, bundling the modeling code alongside the weights
    save_dir = 'musicnn-pytorch'
    os.makedirs(save_dir, exist_ok=True)

    model.save_pretrained(save_dir)
    shutil.copy('musicnn.py', save_dir)

    # Create config.json with the custom-code auto_map so AutoModel works
    config_dict = config.to_dict()
    config_dict.update({
        '_name_or_path': 'oriyonay/musicnn-pytorch',
        'architectures': ['MusicNN'],
        'auto_map': {
            'AutoConfig': 'musicnn.MusicNNConfig',
            'AutoModel': 'musicnn.MusicNN'
        },
        'model_type': 'musicnn'
    })

    with open(os.path.join(save_dir, 'config.json'), 'w') as f:
        json.dump(config_dict, f, indent=4)

    # Push to Hugging Face
    api = HfApi()
    api.upload_folder(
        folder_path=save_dir,
        repo_id='oriyonay/musicnn-pytorch',
        repo_type='model'
    )

    print("✅ Model uploaded to Hugging Face!")
    print("Usage: model = MusicNN.from_pretrained('oriyonay/musicnn-pytorch')")
|
musicnn_torch.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import librosa
|
| 7 |
+
import soundfile as sf
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
# Suppress the PyTorch padding warning and other user warnings
|
| 11 |
+
warnings.filterwarnings('ignore', category=UserWarning)
|
| 12 |
+
|
| 13 |
+
# hyperparams
SR = 16000      # target sample rate (Hz) used for all preprocessing
N_MELS = 96     # number of mel bands in the input spectrogram
FFT_HOP = 256   # hop length (samples) between STFT frames
FFT_SIZE = 512  # STFT window / FFT size (samples)

# Tag vocabulary for the 'MTT' model checkpoints; index i corresponds to
# output position i of the classifier.
MTT_LABELS = [
    'guitar', 'classical', 'slow', 'techno', 'strings', 'drums', 'electronic', 'rock',
    'fast', 'piano', 'ambient', 'beat', 'violin', 'vocal', 'synth', 'female', 'indian',
    'opera', 'male', 'singing', 'vocals', 'no vocals', 'harpsichord', 'loud', 'quiet',
    'flute', 'woman', 'male vocal', 'no vocal', 'pop', 'soft', 'sitar', 'solo', 'man',
    'classic', 'choir', 'voice', 'new age', 'dance', 'male voice', 'female vocal',
    'beats', 'harp', 'cello', 'no voice', 'weird', 'country', 'metal', 'female voice',
    'choral'
]

# Tag vocabulary for the 'MSD' model checkpoints, same index convention.
MSD_LABELS = [
    'rock', 'pop', 'alternative', 'indie', 'electronic', 'female vocalists', 'dance',
    '00s', 'alternative rock', 'jazz', 'beautiful', 'metal', 'chillout', 'male vocalists',
    'classic rock', 'soul', 'indie rock', 'Mellow', 'electronica', '80s', 'folk', '90s',
    'chill', 'instrumental', 'punk', 'oldies', 'blues', 'hard rock', 'ambient', 'acoustic',
    'experimental', 'female vocalist', 'guitar', 'Hip-Hop', '70s', 'party', 'country',
    'easy listening', 'sexy', 'catchy', 'funk', 'electro', 'heavy metal',
    'Progressive rock', '60s', 'rnb', 'indie pop', 'sad', 'House', 'happy'
]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# -------------------------
|
| 41 |
+
# Building blocks
|
| 42 |
+
# -------------------------
|
| 43 |
+
class ConvReLUBN(nn.Module):
    """Conv2d -> ReLU -> BatchNorm2d building block (note: the batch-norm
    is applied after the activation)."""

    def __init__(self, in_ch, out_ch, kernel_size, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding)
        self.bn = nn.BatchNorm2d(out_ch, eps=0.001, momentum=0.01)

    def forward(self, x):
        activated = F.relu(self.conv(x))
        return self.bn(activated)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class TimbralBlock(nn.Module):
    """Timbral front-end branch: a (7 x mel_bins) conv spanning `mel_bins`
    frequency bins, with the time axis zero-padded so its length is kept,
    then max-pooled over the remaining frequency axis."""

    def __init__(self, mel_bins, out_ch):
        super().__init__()
        self.conv_block = ConvReLUBN(1, out_ch, kernel_size=(7, mel_bins), padding=0)

    def forward(self, x):
        # Pad 3 frames before and after along the time axis ('same' padding
        # for a height-7 kernel).
        padded = F.pad(x, (0, 0, 3, 3))
        features = self.conv_block(padded)
        pooled, _ = features.max(dim=3)
        return pooled
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class TemporalBlock(nn.Module):
    """Temporal front-end branch: a (kernel_size x 1) conv over the time
    axis (with 'same' padding), max-pooled over the last axis."""

    def __init__(self, kernel_size, out_ch):
        super().__init__()
        self.conv_block = ConvReLUBN(1, out_ch, kernel_size=(kernel_size, 1), padding='same')

    def forward(self, x):
        out = self.conv_block(x)
        return out.max(dim=3).values
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class MidEnd(nn.Module):
    # Mid-end: three temporal conv layers (kernel height 7) with residual
    # connections after layers 2 and 3. Returns the input plus the output of
    # every layer so the back-end can pool multi-level features.
    def __init__(self, in_ch, num_filt):
        # in_ch: feature dimension of the incoming front-end features.
        # num_filt: number of filters in each of the three conv layers.
        super().__init__()
        self.c1_conv = nn.Conv2d(1, num_filt, kernel_size=(7, in_ch), padding=0)
        self.c1_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)
        self.c2_conv = nn.Conv2d(1, num_filt, kernel_size=(7, num_filt), padding=0)
        self.c2_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)
        self.c3_conv = nn.Conv2d(1, num_filt, kernel_size=(7, num_filt), padding=0)
        self.c3_bn = nn.BatchNorm2d(num_filt, eps=0.001, momentum=0.01)

    def forward(self, x):
        # x: (batch, in_ch, time) -> (batch, time, in_ch, 1)
        x = x.transpose(1, 2).unsqueeze(3)

        # Each permute/pad/permute triple zero-pads the time axis by 3 on both
        # sides (so the height-7 conv preserves the time length) and presents
        # the tensor to Conv2d as (batch, 1, time + 6, features).
        x_perm = x.permute(0, 2, 3, 1)
        x1_pad = F.pad(x_perm, (3, 3, 0, 0))
        x1 = x1_pad.permute(0, 2, 3, 1)
        x1 = self.c1_bn(F.relu(self.c1_conv(x1)))
        # Back to (batch, time, num_filt, 1) for the residual additions below.
        x1_t = x1.permute(0, 2, 1, 3)

        x2_perm = x1_t.permute(0, 2, 3, 1)
        x2_pad = F.pad(x2_perm, (3, 3, 0, 0))
        x2 = x2_pad.permute(0, 2, 3, 1)
        x2 = self.c2_bn(F.relu(self.c2_conv(x2)))
        x2_t = x2.permute(0, 2, 1, 3)
        res_conv2 = x2_t + x1_t  # first residual connection

        x3_perm = res_conv2.permute(0, 2, 3, 1)
        x3_pad = F.pad(x3_perm, (3, 3, 0, 0))
        x3 = x3_pad.permute(0, 2, 3, 1)
        x3 = self.c3_bn(F.relu(self.c3_conv(x3)))
        x3_t = x3.permute(0, 2, 1, 3)
        res_conv3 = x3_t + res_conv2  # second residual connection

        # Four feature maps, each squeezed to (batch, time, features): the raw
        # input and the (residual) output of every conv layer.
        return [x.squeeze(3), x1_t.squeeze(3), res_conv2.squeeze(3), res_conv3.squeeze(3)]
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class Backend(nn.Module):
    """Classification back-end: temporal max+mean pooling followed by a
    two-layer MLP with batch-norm and dropout."""

    def __init__(self, in_ch, num_classes, hidden):
        super().__init__()
        self.bn_in = nn.BatchNorm1d(in_ch * 2, eps=0.001, momentum=0.01)
        self.fc1 = nn.Linear(in_ch * 2, hidden)
        self.bn_fc1 = nn.BatchNorm1d(hidden, eps=0.001, momentum=0.01)
        self.fc2 = nn.Linear(hidden, num_classes)

    def forward(self, x):
        # Pool every feature channel over the time axis (dim 1).
        max_pool, _ = x.max(dim=1)
        mean_pool = x.mean(dim=1)
        # Interleave (max, mean) per channel, matching the 2*C bn_in width.
        pooled = torch.stack([max_pool, mean_pool], dim=2).flatten(1)

        out = F.dropout(self.bn_in(pooled), p=0.5, training=self.training)
        out = self.bn_fc1(F.relu(self.fc1(out)))
        out = F.dropout(out, p=0.5, training=self.training)

        return self.fc2(out), mean_pool, max_pool
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# -------------------------
|
| 134 |
+
# MusicNN
|
| 135 |
+
# -------------------------
|
| 136 |
+
class MusicNN(nn.Module):
    """Full musicnn network: timbral + temporal front-end branches, a
    residual mid-end, and a pooled MLP back-end.

    Input: (batch, time_frames, N_MELS) log-mel patches.
    Output: (logits, mean_pool, max_pool).
    """

    def __init__(self, num_classes, mid_filt=64, backend_units=200):
        super().__init__()
        self.bn_input = nn.BatchNorm2d(1, eps=0.001, momentum=0.01)
        # Two timbral branches spanning 40% / 70% of the mel bins.
        self.timbral_1 = TimbralBlock(int(0.4 * N_MELS), int(1.6 * 128))
        self.timbral_2 = TimbralBlock(int(0.7 * N_MELS), int(1.6 * 128))
        # Three temporal branches with decreasing kernel lengths.
        self.temp_1 = TemporalBlock(128, int(1.6 * 32))
        self.temp_2 = TemporalBlock(64, int(1.6 * 32))
        self.temp_3 = TemporalBlock(32, int(1.6 * 32))
        # 561 = total channel count produced by the five front-end branches.
        self.midend = MidEnd(in_ch=561, num_filt=mid_filt)
        self.backend = Backend(in_ch=mid_filt * 3 + 561, num_classes=num_classes, hidden=backend_units)

    def forward(self, x):
        # (B, T, mels) -> (B, 1, T, mels), normalized.
        spec = self.bn_input(x.unsqueeze(1))
        # Five parallel front-end branches, each transposed to (B, T, channels).
        branches = [
            self.timbral_1(spec),
            self.timbral_2(spec),
            self.temp_1(spec),
            self.temp_2(spec),
            self.temp_3(spec),
        ]
        frontend = torch.cat([b.transpose(1, 2) for b in branches], dim=2)
        mid_feats = self.midend(frontend.transpose(1, 2))
        # Concatenate the mid-end's multi-level features and classify.
        return self.backend(torch.cat(mid_feats, dim=2))
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# inference utils
|
| 164 |
+
def batch_data(audio_file, n_frames, overlap):
    """Load an audio file, compute its log-mel spectrogram, and slice it into
    windows of `n_frames` spaced `overlap` frames apart.

    Returns (batch, spectrogram): batch has shape (num_windows, n_frames,
    N_MELS); spectrogram is the full (frames, N_MELS) representation.
    Raises ValueError when the file yields no audio.
    """
    # soundfile handles some MP3s more reliably than librosa locally.
    audio, sr = sf.read(audio_file)

    # Down-mix multi-channel audio to mono.
    if len(audio.shape) > 1:
        audio = np.mean(audio, axis=1)

    # Resample to the model's sample rate when needed.
    if sr != SR:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=SR)

    if len(audio) == 0:
        raise ValueError(f'Audio file {audio_file} is empty or could not be loaded.')

    mel = librosa.feature.melspectrogram(
        y=audio, sr=SR, hop_length=FFT_HOP, n_fft=FFT_SIZE, n_mels=N_MELS
    ).T
    # Log compression on the float32 spectrogram.
    audio_rep = np.log10(10000 * mel.astype(np.float32) + 1)

    last_frame = audio_rep.shape[0] - n_frames + 1
    if last_frame <= 0:
        # Clip shorter than one window: zero-pad a single patch.
        patch = np.zeros((n_frames, N_MELS), dtype=np.float32)
        patch[:audio_rep.shape[0], :] = audio_rep
        batches = [patch]
    else:
        batches = [
            audio_rep[start : start + n_frames, :]
            for start in range(0, last_frame, overlap)
        ]

    return np.stack(batches), audio_rep
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def extractor(file_name, model='MTT_musicnn', input_length=3, input_overlap=False, device=None):
    """Compute per-window tag probabilities for an audio file.

    Returns (probs, labels): probs is a (num_windows, num_classes) array of
    sigmoid activations, labels the tag vocabulary of the chosen model.
    """
    # Auto-detect device if not specified.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Select the label set and architecture hyperparameters from the name.
    if 'MTT' in model:
        labels = MTT_LABELS
        config = {'num_classes': 50, 'mid_filt': 64, 'backend_units': 200}
    elif 'MSD' in model:
        labels = MSD_LABELS
        if 'big' in model:
            config = {'num_classes': 50, 'mid_filt': 512, 'backend_units': 500}
        else:
            config = {'num_classes': 50, 'mid_filt': 64, 'backend_units': 200}
    else:
        raise ValueError('Model not supported')

    net = MusicNN(**config)

    # Look for the checkpoint next to the script first, then under weights/.
    weight_path = f'{model}.pt'
    if not os.path.exists(weight_path):
        weight_path = os.path.join('weights', f'{model}.pt')

    if os.path.exists(weight_path):
        net.load_state_dict(torch.load(weight_path, map_location=device))
    else:
        # Proceed with random weights so callers can still smoke-test.
        print(f'Warning: Weights not found at {weight_path}')

    net.to(device)
    net.eval()

    # Window length in mel frames; windows tile the clip without overlap
    # unless input_overlap (seconds) is given.
    n_frames = librosa.time_to_frames(input_length, sr=SR, n_fft=FFT_SIZE, hop_length=FFT_HOP) + 1
    if not input_overlap:
        overlap = n_frames
    else:
        overlap = librosa.time_to_frames(input_overlap, sr=SR, n_fft=FFT_SIZE, hop_length=FFT_HOP)

    batch, _ = batch_data(file_name, n_frames, overlap)
    with torch.no_grad():
        logits, _, _ = net(torch.from_numpy(batch).to(device))
        probs = torch.sigmoid(logits).cpu().numpy()

    return probs, labels
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def top_tags(file_name, model='MTT_musicnn', topN=3, device=None):
    """Return the topN tag names for an audio file, ranked by probability
    averaged over all analysis windows."""
    # Auto-detect device if not specified.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    probs, labels = extractor(file_name, model=model, device=device)
    ranking = np.mean(probs, axis=0).argsort()[::-1]
    return [labels[i] for i in ranking[:topN]]
|
weights/MSD_musicnn.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6db4c22908da50888d6a259d41980988d3b9cecc5f96fd725ede09166996dd00
|
| 3 |
+
size 3191473
|
weights/MSD_musicnn_big.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b8312eddea265984e0315ecbc87a88b6fe2ab6c341a692741390880a4d1f9abe
|
| 3 |
+
size 31998829
|
weights/MTT_musicnn.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:32cb8bc12786302edc7dde58be340082c06559d979bec06615d1035fa2474f8d
|
| 3 |
+
size 3191473
|