---
license: mit
language:
- en
tags:
- audio
- music
- embeddings
- similarity
- contrastive-learning
- music-information-retrieval
- disentangled-representations
pipeline_tag: feature-extraction
---
# MERIT: Disentangled Music Similarity Embeddings
**MERIT** maps audio to three *disentangled* 128-dimensional unit vectors, one each for **melody**, **rhythm**, and **timbre** similarity. A single frozen [MERT-v1-330M](https://huggingface.co/m-a-p/MERT-v1-330M) backbone feeds three small trained projection heads, each specializing in one musical factor.
> Code & training pipeline → [github.com/AMAAI-Lab/MERIT](https://github.com/AMAAI-Lab/MERIT)
---
## Quick Start: Get Embeddings in Minutes
No training or dataset required. Download the three pre-trained heads and encode any audio file.
### 1. Install dependencies
```bash
pip install torch torchaudio transformers huggingface_hub
```
### 2. Download pre-trained heads
```bash
huggingface-cli download amaai-lab/merit \
head_mel/best_head.pt head_rhy/best_head.pt head_tim/best_head.pt \
--local-dir ./models
```
### 3. Encode audio and compute similarity
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from transformers import AutoModel, Wav2Vec2FeatureExtractor

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EXTRACT_LAYERS = (3, 4, 5, 6, 23)
MODEL_ID = "m-a-p/MERT-v1-330M"

# ── Load MERT backbone (shared for all three factors) ──────────────────────
processor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_ID, trust_remote_code=True)
mert = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).to(DEVICE).eval()

# ── Head architecture ──────────────────────────────────────────────────────
class ProjectionHead(nn.Module):
    def __init__(self, in_dim=5120, hidden_dim=512, out_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim, bias=False),
        )

    def forward(self, x):
        # L2-normalize so that dot products are cosine similarities
        return F.normalize(self.net(x), dim=-1)

def load_head(path):
    ckpt = torch.load(path, map_location=DEVICE, weights_only=True)
    head = ProjectionHead(ckpt["in_dim"], ckpt["hidden_dim"], ckpt["out_dim"])
    head.load_state_dict(ckpt["state_dict"])
    return head.to(DEVICE).eval()

head_mel = load_head("models/head_mel/best_head.pt")
head_rhy = load_head("models/head_rhy/best_head.pt")
head_tim = load_head("models/head_tim/best_head.pt")

# ── Audio loading helper ───────────────────────────────────────────────────
def load_audio(path, sr=24_000, max_sec=30):
    wav, orig_sr = torchaudio.load(path)
    if orig_sr != sr:
        wav = torchaudio.functional.resample(wav, orig_sr, sr)
    wav = wav.mean(0)                                    # stereo → mono
    wav = wav[: sr * max_sec]                            # truncate
    wav = F.pad(wav, (0, sr * max_sec - wav.shape[0]))   # zero-pad
    return wav

# ── Encode ─────────────────────────────────────────────────────────────────
@torch.no_grad()
def get_merit_embeddings(audio_path):
    """Return (melody, rhythm, timbre) embeddings, each a (1, 128) unit vector."""
    wav = load_audio(audio_path)
    inputs = processor(wav.numpy(), sampling_rate=24_000, return_tensors="pt")
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    out = mert(**inputs, output_hidden_states=True)
    parts = [out.hidden_states[l].mean(dim=1) for l in EXTRACT_LAYERS]
    backbone = torch.cat(parts, dim=-1)  # (1, 5120)
    return head_mel(backbone), head_rhy(backbone), head_tim(backbone)

# ── Example: compare two tracks ────────────────────────────────────────────
emb_a = get_merit_embeddings("song_a.wav")
emb_b = get_merit_embeddings("song_b.wav")

melody_sim = (emb_a[0] * emb_b[0]).sum().item()  # cosine similarity in [-1, 1]
rhythm_sim = (emb_a[1] * emb_b[1]).sum().item()
timbre_sim = (emb_a[2] * emb_b[2]).sum().item()

print(f"Melody similarity: {melody_sim:.3f}")
print(f"Rhythm similarity: {rhythm_sim:.3f}")
print(f"Timbre similarity: {timbre_sim:.3f}")
```
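Because every head returns unit vectors, the dot product is already the cosine similarity, so ranking a small library against a query for one factor reduces to sorting dot products. The following is a minimal sketch in plain Python so that it runs without the model. The hand-made 3-dimensional vectors stand in for real 128-dimensional outputs, and `unit`, `rank_by_factor`, and the track names are hypothetical helpers for illustration, not part of MERIT:

```python
import math

def unit(v):
    """Scale a vector to unit length, mirroring the heads' final L2 normalization."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def rank_by_factor(query, library):
    """Sort (name, embedding) pairs by similarity to the query, best first.

    Assumes every vector is unit length, so the dot product is the cosine.
    """
    scored = [
        (name, sum(q * e for q, e in zip(query, emb)))
        for name, emb in library
    ]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)

# Toy stand-ins for one factor's embeddings (real ones are 128-dimensional).
query = unit([1.0, 0.0, 0.0])
library = [
    ("track_far", unit([0.0, 1.0, 0.0])),
    ("track_near", unit([0.9, 0.1, 0.0])),
    ("track_opposite", unit([-1.0, 0.0, 0.0])),
]

ranking = rank_by_factor(query, library)
for name, score in ranking:
    print(f"{name}: {score:.3f}")
```

With real embeddings, the same ranking can be run separately on the melody, rhythm, and timbre vectors to get three independent orderings of the same library.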
> **Batch encoding:** For large collections, use [`encode_folder.py`](https://github.com/AMAAI-Lab/MERIT/blob/main/evaluation/encode_folder.py) to encode an entire directory to a single `.pkl` file, which is much faster than encoding file by file.
---
## Model Architecture
```
MERT-v1-330M (frozen)
  └─ Layers 3, 4, 5, 6, 23 → mean-pool over time → concat → 5120-dim

Per-factor head (three independent heads, trained independently):
  Linear(5120 → 512) → ReLU → Linear(512 → 128, bias=False) → L2-norm
```
Early MERT layers (3–6) capture timbral/rhythmic features; the later layer (23) carries melodic/pitch content. Each head learns to selectively weight the 5120-dim multi-layer input toward its specific factor.
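To make the pooling and concatenation step concrete, here is a dependency-free sketch of how five selected layers become one input vector. It uses toy hidden states with 2 features per frame instead of 1024 (5120 / 5 layers), and assumes the backbone exposes 25 hidden states (indices 0 to 24). `mean_pool` and `backbone_vector` are hypothetical names for illustration; the Quick Start code does the same thing with `torch.cat`.

```python
EXTRACT_LAYERS = (3, 4, 5, 6, 23)

def mean_pool(frames):
    """Average a (time, dim) list of frame vectors over the time axis."""
    n_frames = len(frames)
    return [sum(column) / n_frames for column in zip(*frames)]

def backbone_vector(hidden_states, layers):
    """Concatenate the time-averaged vectors of the selected layers."""
    out = []
    for layer in layers:
        out.extend(mean_pool(hidden_states[layer]))
    return out

# Toy hidden states: 25 layers, 4 time frames, 2 features per frame.
# Feature 0 is the layer index and feature 1 is the frame index, so the
# pooled result is easy to read off.
hidden_states = [
    [[float(layer), float(t)] for t in range(4)]
    for layer in range(25)
]

backbone = backbone_vector(hidden_states, EXTRACT_LAYERS)
print(len(backbone), backbone)

# With the real feature width implied above, the concatenated size is:
real_in_dim = len(EXTRACT_LAYERS) * 1024
print(real_in_dim)
```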
| Training detail | Value |
|---|---|
| Loss | Circle Loss (γ=10, m=0.2) |
| Optimizer | AdamW (lr=1e-3) |
| Schedule | Cosine annealing |
| Epochs | 200 |
| Triplet source | MoisesDB v0.1 + JASCO |
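For readers unfamiliar with Circle Loss, the sketch below evaluates the published formulation for a single anchor using the γ and m values from the table. It is an illustration only: how MERIT builds positives and negatives from its triplets is not described here, and the training code in the repository may differ. `circle_loss` is a hypothetical helper written in plain Python, so it computes the loss value but no gradients.

```python
import math

def circle_loss(sp, sn, gamma=10.0, m=0.2):
    """Circle Loss for one anchor.

    sp: cosine similarities between the anchor and its positives.
    sn: cosine similarities between the anchor and its negatives.
    """
    # Published settings: optima O_p = 1 + m and O_n = -m,
    # margins delta_p = 1 - m and delta_n = m.
    logits_p = [-gamma * max(1 + m - s, 0.0) * (s - (1 - m)) for s in sp]
    logits_n = [gamma * max(s + m, 0.0) * (s - m) for s in sn]
    total = sum(math.exp(x) for x in logits_n) * sum(math.exp(x) for x in logits_p)
    return math.log1p(total)

# A well-separated triplet should cost less than a poorly separated one.
good = circle_loss(sp=[0.9], sn=[0.1])
bad = circle_loss(sp=[0.3], sn=[0.7])
print(f"well separated: {good:.4f}")
print(f"poorly separated: {bad:.4f}")
```

The adaptive weights (the `max(...)` terms) shrink toward zero as a similarity approaches its optimum, so pairs that are already well placed contribute little to the loss.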
---
## Files
| File | Description |
|---|---|
| `head_mel/best_head.pt` | Melody projection head (~11 MB) |
| `head_rhy/best_head.pt` | Rhythm projection head (~11 MB) |
| `head_tim/best_head.pt` | Timbre projection head (~11 MB) |
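The ~11 MB file size is consistent with the head architecture if the weights are stored as 32-bit floats. That storage format is an assumption, and the small metadata fields in the checkpoint are ignored. A quick sanity check:

```python
in_dim, hidden_dim, out_dim = 5120, 512, 128

first_layer = in_dim * hidden_dim + hidden_dim   # weights + bias
second_layer = hidden_dim * out_dim              # weights only (bias=False)
n_params = first_layer + second_layer

bytes_per_param = 4  # assumption: float32 storage
size_mb = n_params * bytes_per_param / 1e6

print(f"{n_params:,} parameters, about {size_mb:.2f} MB")
```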
---
## Citation
```bibtex
TODO: add after arXiv submission
```
---
## License
[MIT](https://github.com/AMAAI-Lab/MERIT/blob/main/LICENSE)