MERIT: Disentangled Music Similarity Embeddings
MERIT maps audio to three disentangled 128-dimensional unit vectors, one each for melody, rhythm, and timbre similarity. A single frozen MERT-v1-330M backbone feeds three small trained projection heads, each specializing in one musical factor.
Code & training pipeline: github.com/AMAAI-Lab/MERIT
Quick Start: Get Embeddings in Minutes
No training or dataset required. Download the three pre-trained heads and encode any audio file.
1. Install dependencies
pip install torch torchaudio transformers huggingface_hub
2. Download pre-trained heads
huggingface-cli download amaai-lab/merit \
head_mel/best_head.pt head_rhy/best_head.pt head_tim/best_head.pt \
--local-dir ./models
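The same files can also be fetched from Python. The sketch below assumes the `hf_hub_download` function from `huggingface_hub` (installed in step 1). The names `HEAD_FILES`, `local_paths`, and `download_heads` are illustrative and not part of the MERIT repository.

```python
from pathlib import Path

REPO_ID = "amaai-lab/merit"
# Checkpoint paths inside the repository, as used by the CLI command above.
HEAD_FILES = [
    "head_mel/best_head.pt",
    "head_rhy/best_head.pt",
    "head_tim/best_head.pt",
]


def local_paths(local_dir="./models"):
    """Where each checkpoint ends up on disk after downloading into local_dir."""
    return [Path(local_dir) / name for name in HEAD_FILES]


def download_heads(local_dir="./models"):
    """Download the three heads (hypothetical helper; needs network access)."""
    # Imported lazily so the path helpers work without the package installed.
    from huggingface_hub import hf_hub_download

    for name in HEAD_FILES:
        hf_hub_download(repo_id=REPO_ID, filename=name, local_dir=local_dir)
    return local_paths(local_dir)
```

Calling `download_heads()` should leave the checkpoints under `./models`, at the paths passed to `load_head` in step 3.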
3. Encode audio and compute similarity
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from transformers import AutoModel, Wav2Vec2FeatureExtractor
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EXTRACT_LAYERS = (3, 4, 5, 6, 23)
MODEL_ID = "m-a-p/MERT-v1-330M"
# --- Load MERT backbone (shared for all three factors) ---
processor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_ID, trust_remote_code=True)
mert = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).to(DEVICE).eval()
# --- Head architecture ---
class ProjectionHead(nn.Module):
    def __init__(self, in_dim=5120, hidden_dim=512, out_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim, bias=False),
        )

    def forward(self, x):
        # L2-normalize so each embedding is a unit vector
        return F.normalize(self.net(x), dim=-1)

def load_head(path):
    ckpt = torch.load(path, map_location=DEVICE, weights_only=True)
    head = ProjectionHead(ckpt["in_dim"], ckpt["hidden_dim"], ckpt["out_dim"])
    head.load_state_dict(ckpt["state_dict"])
    return head.to(DEVICE).eval()
head_mel = load_head("models/head_mel/best_head.pt")
head_rhy = load_head("models/head_rhy/best_head.pt")
head_tim = load_head("models/head_tim/best_head.pt")
# --- Audio loading helper ---
def load_audio(path, sr=24_000, max_sec=30):
    wav, orig_sr = torchaudio.load(path)
    if orig_sr != sr:
        wav = torchaudio.functional.resample(wav, orig_sr, sr)
    wav = wav.mean(0)                                   # stereo -> mono
    wav = wav[: sr * max_sec]                           # truncate
    wav = F.pad(wav, (0, sr * max_sec - wav.shape[0]))  # zero-pad
    return wav
# --- Encode ---
@torch.no_grad()
def get_merit_embeddings(audio_path):
    """Return (melody, rhythm, timbre) embeddings, each a (1, 128) unit vector."""
    wav = load_audio(audio_path)
    inputs = processor(wav.numpy(), sampling_rate=24_000, return_tensors="pt")
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    out = mert(**inputs, output_hidden_states=True)
    # Mean-pool each selected layer over time, then concatenate the layers
    parts = [out.hidden_states[layer].mean(dim=1) for layer in EXTRACT_LAYERS]
    backbone = torch.cat(parts, dim=-1)  # (1, 5120)
    return head_mel(backbone), head_rhy(backbone), head_tim(backbone)
# --- Example: compare two tracks ---
emb_a = get_merit_embeddings("song_a.wav")
emb_b = get_merit_embeddings("song_b.wav")
melody_sim = (emb_a[0] * emb_b[0]).sum().item() # cosine similarity in [-1, 1]
rhythm_sim = (emb_a[1] * emb_b[1]).sum().item()
timbre_sim = (emb_a[2] * emb_b[2]).sum().item()
print(f"Melody similarity: {melody_sim:.3f}")
print(f"Rhythm similarity: {rhythm_sim:.3f}")
print(f"Timbre similarity: {timbre_sim:.3f}")
Batch encoding: For large collections, use `encode_folder.py` to encode an entire directory into a single `.pkl` file, which is much faster than encoding file by file.
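Once a collection has been encoded, retrieval for a given factor reduces to dot products, because every MERIT embedding is a unit vector. The following is a minimal sketch in plain Python. It assumes the embeddings have already been loaded into a dict (the layout of the `.pkl` written by `encode_folder.py` is not assumed here), and `rank_by_similarity` is an illustrative helper, not a function from the repository.

```python
def dot(u, v):
    """Dot product; equals cosine similarity when both vectors have unit length."""
    return sum(a * b for a, b in zip(u, v))


def rank_by_similarity(query, library):
    """Return (track_id, similarity) pairs sorted from most to least similar.

    `library` is a hypothetical dict mapping track IDs to unit-norm embeddings
    of a single factor (for example, all melody vectors).
    """
    scores = [(track_id, dot(query, emb)) for track_id, emb in library.items()]
    return sorted(scores, key=lambda pair: pair[1], reverse=True)


# Toy 2-D unit vectors stand in for real 128-D MERIT embeddings.
library = {
    "track_a": [1.0, 0.0],
    "track_b": [0.6, 0.8],
    "track_c": [0.0, 1.0],
}
ranking = rank_by_similarity([1.0, 0.0], library)
print(ranking)
```

With embeddings stacked into a tensor, the same ranking can be computed as a single matrix-vector product followed by a sort.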
Model Architecture
MERT-v1-330M (frozen)
└─ Layers 3, 4, 5, 6, 23 → mean-pool over time → concat → 5120-dim
Per-factor head (three independent heads, trained independently):
  Linear(5120 → 512) → ReLU → Linear(512 → 128, bias=False) → L2-norm
Early MERT layers (3–6) capture timbral and rhythmic features, while the late layer (23) carries melodic and pitch content. Each head learns to weight the 5120-dim multi-layer input toward its own factor.
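As a sanity check on these dimensions, the arithmetic below derives the 5120-dim input width and the size of one head. It assumes a per-layer hidden width of 1024 for MERT-v1-330M (consistent with 5 × 1024 = 5120) and float32 storage. The result lines up with the roughly 11 MB checkpoint sizes listed under Files.

```python
HIDDEN_SIZE = 1024                 # assumed per-layer width of MERT-v1-330M
EXTRACT_LAYERS = (3, 4, 5, 6, 23)

in_dim = HIDDEN_SIZE * len(EXTRACT_LAYERS)   # width of the concatenated features
hidden_dim, out_dim = 512, 128

# Linear(in_dim -> hidden_dim) has weights plus a bias vector;
# Linear(hidden_dim -> out_dim, bias=False) has weights only.
n_params = (in_dim * hidden_dim + hidden_dim) + hidden_dim * out_dim
size_mb = n_params * 4 / 1e6                 # float32 = 4 bytes per parameter

print(in_dim, n_params, round(size_mb, 2))
```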
| Training detail | Value |
|---|---|
| Loss | Circle Loss (γ=10, m=0.2) |
| Optimizer | AdamW (lr=1e-3) |
| Schedule | Cosine annealing |
| Epochs | 200 |
| Triplet source | MoisesDB v0.1 + JASCO |
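To illustrate how the loss hyperparameters act, here is a minimal pure-Python sketch of Circle Loss for a single anchor, using the γ and m values from the table. It follows the standard published formulation. Whether the training pipeline uses exactly this form (pair weighting, batching, reduction) is an assumption, so the repository remains the reference.

```python
import math


def circle_loss(pos_sims, neg_sims, gamma=10.0, margin=0.2):
    """Circle Loss for one anchor, given its cosine similarities to
    positives (pos_sims) and negatives (neg_sims)."""
    o_p, o_n = 1.0 + margin, -margin          # optima for positive / negative pairs
    delta_p, delta_n = 1.0 - margin, margin   # decision margins

    # Each pair is weighted by its distance from the optimum (clamped at 0).
    logits_p = [-gamma * max(0.0, o_p - s) * (s - delta_p) for s in pos_sims]
    logits_n = [gamma * max(0.0, s - o_n) * (s - delta_n) for s in neg_sims]

    def logsumexp(xs):
        m = max(xs)
        return m + math.log(sum(math.exp(x - m) for x in xs))

    z = logsumexp(logits_p) + logsumexp(logits_n)
    # softplus(z) = log(1 + exp(z)), computed stably for large |z|
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


well_separated = circle_loss([0.9], [0.1])   # positive close, negative far
confused = circle_loss([0.1], [0.9])         # positive far, negative close
print(well_separated, confused)
```

The loss is small when the positive is much more similar to the anchor than the negative, and it grows roughly linearly in γ once the two are confused.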
Files
| File | Description |
|---|---|
| `head_mel/best_head.pt` | Melody projection head (~11 MB) |
| `head_rhy/best_head.pt` | Rhythm projection head (~11 MB) |
| `head_tim/best_head.pt` | Timbre projection head (~11 MB) |
Citation
TODO: add after arXiv submission