---
license: mit
language:
- en
tags:
- audio
- music
- embeddings
- similarity
- contrastive-learning
- music-information-retrieval
- disentangled-representations
pipeline_tag: feature-extraction
---

# MERIT: Disentangled Music Similarity Embeddings

**MERIT** maps audio to three *disentangled* 128-dimensional unit vectors: one each for **melody**, **rhythm**, and **timbre** similarity. A single frozen [MERT-v1-330M](https://huggingface.co/m-a-p/MERT-v1-330M) backbone feeds three small trained projection heads that each specialize in one musical factor.

> Code & training pipeline: [github.com/AMAAI-Lab/MERIT](https://github.com/AMAAI-Lab/MERIT)

---

## Quick Start: Get Embeddings in Minutes

No training or dataset is required: download the three pre-trained heads and encode any audio file. The MERT backbone itself is fetched from the Hugging Face Hub automatically the first time it is loaded.

### 1. Install dependencies

```bash
pip install torch torchaudio transformers huggingface_hub
```

### 2. Download pre-trained heads

```bash
huggingface-cli download amaai-lab/merit \
  head_mel/best_head.pt head_rhy/best_head.pt head_tim/best_head.pt \
  --local-dir ./models
```
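
If you would rather stay in Python (for example in a notebook), the same three files can likely be fetched with `huggingface_hub.hf_hub_download`. This is a minimal sketch that assumes the repo id and file paths from the CLI command; `HEAD_FILES` and `download_heads` are illustrative names, not part of the MERIT codebase.

```python
from pathlib import Path

# Repo id and repo-relative paths of the three heads (same as the CLI command).
REPO_ID = "amaai-lab/merit"
HEAD_FILES = {
    "melody": "head_mel/best_head.pt",
    "rhythm": "head_rhy/best_head.pt",
    "timbre": "head_tim/best_head.pt",
}


def download_heads(local_dir="./models"):
    """Download each head into local_dir, keeping the repo's folder layout.

    Returns a dict mapping factor name -> local file path.
    """
    # Deferred import: only needed when a download is actually performed.
    from huggingface_hub import hf_hub_download

    paths = {}
    for factor, filename in HEAD_FILES.items():
        local_path = hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=local_dir)
        paths[factor] = Path(local_path)
    return paths
```

After `paths = download_heads()`, `load_head(paths["melody"])` from step 3 should work unchanged, since `torch.load` accepts a `Path` as well as a string.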

### 3. Encode audio and compute similarity

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from transformers import AutoModel, Wav2Vec2FeatureExtractor

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EXTRACT_LAYERS = (3, 4, 5, 6, 23)
MODEL_ID = "m-a-p/MERT-v1-330M"

# ── Load MERT backbone (shared for all three factors) ──────────────────────
processor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_ID, trust_remote_code=True)
mert = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).to(DEVICE).eval()


# ── Head architecture ──────────────────────────────────────────────────────
class ProjectionHead(nn.Module):
    def __init__(self, in_dim=5120, hidden_dim=512, out_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim, bias=False),
        )

    def forward(self, x):
        # L2-normalize so every embedding is a unit vector.
        return F.normalize(self.net(x), dim=-1)


def load_head(path):
    ckpt = torch.load(path, map_location=DEVICE, weights_only=True)
    head = ProjectionHead(ckpt["in_dim"], ckpt["hidden_dim"], ckpt["out_dim"])
    head.load_state_dict(ckpt["state_dict"])
    return head.to(DEVICE).eval()


head_mel = load_head("models/head_mel/best_head.pt")
head_rhy = load_head("models/head_rhy/best_head.pt")
head_tim = load_head("models/head_tim/best_head.pt")


# ── Audio loading helper ───────────────────────────────────────────────────
def load_audio(path, sr=24_000, max_sec=30):
    wav, orig_sr = torchaudio.load(path)
    if orig_sr != sr:
        wav = torchaudio.functional.resample(wav, orig_sr, sr)
    wav = wav.mean(0)                                    # multi-channel → mono
    wav = wav[: sr * max_sec]                            # truncate to max_sec
    wav = F.pad(wav, (0, sr * max_sec - wav.shape[0]))   # zero-pad to max_sec
    return wav


# ── Encode ─────────────────────────────────────────────────────────────────
@torch.no_grad()
def get_merit_embeddings(audio_path):
    """Return (melody, rhythm, timbre) embeddings, each a (1, 128) unit vector."""
    wav = load_audio(audio_path)
    inputs = processor(wav.numpy(), sampling_rate=24_000, return_tensors="pt")
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    out = mert(**inputs, output_hidden_states=True)
    # Mean-pool each selected layer over time: (1, T, 1024) → (1, 1024).
    parts = [out.hidden_states[layer].mean(dim=1) for layer in EXTRACT_LAYERS]
    backbone = torch.cat(parts, dim=-1)                  # (1, 5120)
    return head_mel(backbone), head_rhy(backbone), head_tim(backbone)


# ── Example: compare two tracks ────────────────────────────────────────────
emb_a = get_merit_embeddings("song_a.wav")
emb_b = get_merit_embeddings("song_b.wav")

# Dot product of unit vectors = cosine similarity, in [-1, 1].
melody_sim = (emb_a[0] * emb_b[0]).sum().item()
rhythm_sim = (emb_a[1] * emb_b[1]).sum().item()
timbre_sim = (emb_a[2] * emb_b[2]).sum().item()

print(f"Melody similarity: {melody_sim:.3f}")
print(f"Rhythm similarity: {rhythm_sim:.3f}")
print(f"Timbre similarity: {timbre_sim:.3f}")
```
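
The same dot product extends to retrieval: to find the tracks in a collection whose rhythm (or melody, or timbre) is closest to a query, score every track against the query in that factor's space and sort. The sketch below is a hypothetical helper, not part of the MERIT repo. It assumes the embeddings returned by `get_merit_embeddings` have already been converted to plain Python lists (for example with `emb[0].squeeze(0).tolist()`), and it relies on each vector being unit-norm so that the dot product is the cosine similarity.

```python
import math

FACTORS = ("melody", "rhythm", "timbre")


def dot(u, v):
    """Dot product; equals cosine similarity when u and v are unit vectors."""
    return sum(a * b for a, b in zip(u, v))


def rank_by_factor(query, collection, factor="rhythm", top_k=5):
    """Rank tracks by similarity to `query` in one factor's embedding space.

    query:      tuple (melody, rhythm, timbre) of unit vectors (lists of floats)
    collection: dict mapping track name -> tuple of the same form
    Returns a list of (track name, similarity) pairs, most similar first.
    """
    idx = FACTORS.index(factor)  # raises ValueError for an unknown factor
    q = query[idx]
    # Sanity check: the heads L2-normalize their output, so |q| should be ~1.
    if not math.isclose(math.sqrt(dot(q, q)), 1.0, rel_tol=1e-3):
        raise ValueError("query embedding is not unit-norm")
    scored = [(name, dot(q, embs[idx])) for name, embs in collection.items()]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:top_k]
```

Because each factor has its own space, the same query can produce different rankings: a track can rank first by rhythm and last by timbre.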

> **Batch encoding:** For large collections, use [`encode_folder.py`](https://github.com/AMAAI-Lab/MERIT/blob/main/evaluation/encode_folder.py) to encode an entire directory into a single `.pkl` file, which is much faster than encoding files one at a time.

---

## Model Architecture

```
MERT-v1-330M (frozen)
  └─ Layers 3, 4, 5, 6, 23 → mean-pool over time → concat → 5120-dim (5 × 1024)

Per-factor head (three heads, each trained independently):
  Linear(5120 → 512) → ReLU → Linear(512 → 128, bias=False) → L2-norm
```

Early MERT layers (3–6) capture timbral and rhythmic features, while the later layer (23) carries melodic and pitch content. Each head learns to weight this 5120-dim multi-layer input toward its own factor.
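
To make the 5120-dim input concrete: with `output_hidden_states=True`, MERT-v1-330M returns 25 hidden-state tensors of width 1024 (index 0 holds the features entering the transformer, and index i is the output of transformer layer i, for i = 1..24). The sketch below is a hypothetical standard-library helper, not from the MERIT repo; it assumes those sizes and the `EXTRACT_LAYERS` order from the Quick Start code, and computes which slice of the concatenated vector comes from which layer.

```python
NUM_TRANSFORMER_LAYERS = 24   # MERT-v1-330M
HIDDEN_SIZE = 1024            # feature width of each hidden state
EXTRACT_LAYERS = (3, 4, 5, 6, 23)


def layer_slices(layers=EXTRACT_LAYERS, hidden=HIDDEN_SIZE):
    """Map each extracted layer index to its (start, stop) span in the concat vector."""
    n_states = NUM_TRANSFORMER_LAYERS + 1  # hidden_states also includes index 0
    spans = {}
    for pos, layer in enumerate(layers):
        if not 0 <= layer < n_states:
            raise IndexError(f"layer {layer} outside hidden_states[0..{n_states - 1}]")
        # Layers are concatenated in order, so the pos-th layer fills block pos.
        spans[layer] = (pos * hidden, (pos + 1) * hidden)
    return spans


spans = layer_slices()
concat_dim = len(EXTRACT_LAYERS) * HIDDEN_SIZE   # 5 * 1024 = 5120
```

In other words, the first four 1024-dim blocks of `backbone` come from the early layers 3–6, and the last block (`backbone[:, 4096:5120]`) comes from layer 23.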

| Training detail | Value |
|---|---|
| Loss | Circle Loss (γ=10, m=0.2) |
| Optimizer | AdamW (lr=1e-3) |
| Schedule | Cosine annealing |
| Epochs | 200 |
| Triplet source | MoisesDB v0.1 + JASCO |
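
For reference, the Circle Loss parameters in the table plug into the published Circle Loss formulation as follows: with optima O_p = 1 + m and O_n = -m and margins Δ_p = 1 - m and Δ_n = m, each similarity is re-weighted by how far it still is from its optimum, and γ scales the result. The sketch below evaluates that loss for a single triplet (one anchor-positive cosine similarity `s_p`, one anchor-negative similarity `s_n`) using only the standard library. It assumes MERIT uses the formula unmodified; the repo's actual training code may batch, weight, or detach terms differently.

```python
import math

GAMMA = 10.0   # scale factor gamma from the training table
MARGIN = 0.2   # relaxation margin m from the training table


def circle_loss_triplet(s_p, s_n, gamma=GAMMA, m=MARGIN):
    """Circle Loss for one positive and one negative cosine similarity.

    L = log(1 + exp(gamma * a_n * (s_n - d_n)) * exp(-gamma * a_p * (s_p - d_p)))
    with a_p = max(0, O_p - s_p), a_n = max(0, s_n - O_n),
    O_p = 1 + m, O_n = -m, d_p = 1 - m, d_n = m.
    """
    o_p, o_n = 1.0 + m, -m
    delta_p, delta_n = 1.0 - m, m
    # Adaptive weights: similarities far from their optimum get larger weight.
    # (Reference implementations typically detach these from the gradient.)
    alpha_p = max(0.0, o_p - s_p)
    alpha_n = max(0.0, s_n - o_n)
    logit = gamma * (alpha_n * (s_n - delta_n) - alpha_p * (s_p - delta_p))
    # Numerically stable softplus: log(1 + exp(x)).
    return max(logit, 0.0) + math.log1p(math.exp(-abs(logit)))
```

A well-separated triplet (`s_p` near 1, `s_n` near -1) yields a small loss, while a triplet whose negative is closer than its positive yields a large one.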

---

## Files

| File | Description |
|---|---|
| `head_mel/best_head.pt` | Melody projection head (~11 MB) |
| `head_rhy/best_head.pt` | Rhythm projection head (~11 MB) |
| `head_tim/best_head.pt` | Timbre projection head (~11 MB) |
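
The ~11 MB per file follows directly from the head architecture. A quick back-of-the-envelope check, assuming float32 weights and ignoring the small overhead of the checkpoint's metadata (such as the `in_dim`, `hidden_dim`, and `out_dim` entries read by `load_head`) and serialization framing; `head_param_count` is an illustrative helper:

```python
def head_param_count(in_dim=5120, hidden_dim=512, out_dim=128):
    """Parameters in Linear(in -> hidden, with bias) + Linear(hidden -> out, no bias)."""
    first = in_dim * hidden_dim + hidden_dim   # weight matrix + bias vector
    second = hidden_dim * out_dim              # weight matrix only (bias=False)
    return first + second


BYTES_PER_FLOAT32 = 4
n_params = head_param_count()
size_mb = n_params * BYTES_PER_FLOAT32 / 1e6   # decimal megabytes
print(f"{n_params:,} parameters, about {size_mb:.1f} MB")
# prints: 2,687,488 parameters, about 10.7 MB
```

Almost all of the weights sit in the first 5120 → 512 layer, which is where each head selects its factor from the multi-layer MERT input.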

---

## Citation

```bibtex
TODO: add after arXiv submission
```

---

## License

[MIT](https://github.com/AMAAI-Lab/MERIT/blob/main/LICENSE)