---
license: mit
language:
- en
tags:
- audio
- music
- embeddings
- similarity
- contrastive-learning
- music-information-retrieval
- disentangled-representations
pipeline_tag: feature-extraction
---

# MERIT: Disentangled Music Similarity Embeddings

**MERIT** maps audio to three *disentangled* 128-dimensional unit vectors, one each for **melody**, **rhythm**, and **timbre** similarity. A single frozen [MERT-v1-330M](https://huggingface.co/m-a-p/MERT-v1-330M) backbone feeds three small trained projection heads that each specialize in one musical factor.

> Code & training pipeline → [github.com/AMAAI-Lab/MERIT](https://github.com/AMAAI-Lab/MERIT)

---

## Quick Start: Get Embeddings in Minutes

No training or dataset required. Download the three pre-trained heads and encode any audio file.

### 1. Install dependencies

```bash
pip install torch torchaudio transformers huggingface_hub
```

### 2. Download pre-trained heads

```bash
huggingface-cli download amaai-lab/merit \
    head_mel/best_head.pt head_rhy/best_head.pt head_tim/best_head.pt \
    --local-dir ./models
```

### 3. Encode audio and compute similarity

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from transformers import AutoModel, Wav2Vec2FeatureExtractor

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EXTRACT_LAYERS = (3, 4, 5, 6, 23)
MODEL_ID = "m-a-p/MERT-v1-330M"

# ── Load MERT backbone (shared for all three factors) ──────────────────────
processor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_ID, trust_remote_code=True)
mert = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).to(DEVICE).eval()

# ── Head architecture ──────────────────────────────────────────────────────
class ProjectionHead(nn.Module):
    def __init__(self, in_dim=5120, hidden_dim=512, out_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim, bias=False),
        )

    def forward(self, x):
        return F.normalize(self.net(x), dim=-1)

def load_head(path):
    ckpt = torch.load(path, map_location=DEVICE, weights_only=True)
    head = ProjectionHead(ckpt["in_dim"], ckpt["hidden_dim"], ckpt["out_dim"])
    head.load_state_dict(ckpt["state_dict"])
    return head.to(DEVICE).eval()

head_mel = load_head("models/head_mel/best_head.pt")
head_rhy = load_head("models/head_rhy/best_head.pt")
head_tim = load_head("models/head_tim/best_head.pt")

# ── Audio loading helper ───────────────────────────────────────────────────
def load_audio(path, sr=24_000, max_sec=30):
    wav, orig_sr = torchaudio.load(path)
    if orig_sr != sr:
        wav = torchaudio.functional.resample(wav, orig_sr, sr)
    wav = wav.mean(0)                                    # stereo → mono
    wav = wav[: sr * max_sec]                            # truncate
    wav = F.pad(wav, (0, sr * max_sec - wav.shape[0]))   # zero-pad
    return wav

# ── Encode ─────────────────────────────────────────────────────────────────
@torch.no_grad()
def get_merit_embeddings(audio_path):
    """Return (melody, rhythm, timbre) embeddings, each a (1, 128) unit vector."""
    wav = load_audio(audio_path)
    inputs = processor(wav.numpy(), sampling_rate=24_000, return_tensors="pt")
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    out = mert(**inputs, output_hidden_states=True)
    parts = [out.hidden_states[layer].mean(dim=1) for layer in EXTRACT_LAYERS]
    backbone = torch.cat(parts, dim=-1)                  # (1, 5120)
    return head_mel(backbone), head_rhy(backbone), head_tim(backbone)

# ── Example: compare two tracks ───────────────────────────────────────────
emb_a = get_merit_embeddings("song_a.wav")
emb_b = get_merit_embeddings("song_b.wav")

melody_sim = (emb_a[0] * emb_b[0]).sum().item()   # cosine similarity in [-1, 1]
rhythm_sim = (emb_a[1] * emb_b[1]).sum().item()
timbre_sim = (emb_a[2] * emb_b[2]).sum().item()

print(f"Melody similarity: {melody_sim:.3f}")
print(f"Rhythm similarity: {rhythm_sim:.3f}")
print(f"Timbre similarity: {timbre_sim:.3f}")
```

> **Batch encoding:** For large collections, use [`encode_folder.py`](https://github.com/AMAAI-Lab/MERIT/blob/main/evaluation/encode_folder.py) to encode an entire directory to a single `.pkl` file, which is much faster than encoding file by file.

---

## Model Architecture

```
MERT-v1-330M (frozen)
  └─ Layers 3, 4, 5, 6, 23  →  mean-pool over time  →  concat  →  5120-dim

Per-factor head (three independent heads, trained independently):
  Linear(5120 → 512) → ReLU → Linear(512 → 128, bias=False) → L2-norm
```

Early MERT layers (3–6) capture timbral/rhythmic features; the later layer (23) carries melodic/pitch content. Each head learns to selectively weight the 5120-dim multi-layer input toward its specific factor.

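
Each head is trained with Circle Loss (γ=10, m=0.2, as listed in the training details for this section). To give a feel for what that objective rewards, here is a minimal pure-Python sketch of the generic Circle Loss formulation applied to cosine similarities. It is not the repository's training code: the function names (`cosine`, `logsumexp`, `softplus`, `circle_loss`) are illustrative, and how MERIT builds positive and negative pairs from its triplets is an assumption here (one anchor-positive and one anchor-negative similarity per triplet).

```python
import math


def cosine(u, v):
    """Cosine similarity of two vectors given as lists of floats."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def logsumexp(xs):
    """Numerically stable log(sum(exp(x))) over a non-empty list."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def circle_loss(sim_pos, sim_neg, gamma=10.0, margin=0.2):
    """Generic Circle Loss over lists of positive and negative similarities.

    gamma and margin default to the values in the training table; everything
    else follows the standard formulation, not necessarily MERIT's exact code.
    """
    opt_pos, opt_neg = 1.0 + margin, -margin      # optima O_p, O_n
    delta_pos, delta_neg = 1.0 - margin, margin   # decision margins

    # Each similarity is weighted by how far it is from its optimum,
    # so pairs that are already well optimized contribute little.
    logits_pos = [-gamma * max(opt_pos - s, 0.0) * (s - delta_pos) for s in sim_pos]
    logits_neg = [gamma * max(s - opt_neg, 0.0) * (s - delta_neg) for s in sim_neg]

    return softplus(logsumexp(logits_pos) + logsumexp(logits_neg))


# Hypothetical 2-D embeddings for one triplet (anchor, positive, negative).
anchor, positive, negative = [1.0, 0.0], [0.9, 0.1], [0.1, 0.9]
loss = circle_loss([cosine(anchor, positive)], [cosine(anchor, negative)])
print(f"Circle loss for a well-separated triplet: {loss:.4f}")
```

The loss shrinks as the anchor-positive similarity approaches 1 and the anchor-negative similarity falls toward 0, and it grows sharply when the two are swapped, which is the behavior a per-factor head needs in order to pull same-factor clips together.
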

| Training detail | Value |
|---|---|
| Loss | Circle Loss (γ=10, m=0.2) |
| Optimizer | AdamW (lr=1e-3) |
| Schedule | Cosine annealing |
| Epochs | 200 |
| Triplet source | MoisesDB v0.1 + JASCO |

---

## Files

| File | Description |
|---|---|
| `head_mel/best_head.pt` | Melody projection head (~11 MB) |
| `head_rhy/best_head.pt` | Rhythm projection head (~11 MB) |
| `head_tim/best_head.pt` | Timbre projection head (~11 MB) |

---

## Citation

```bibtex
TODO: add after arXiv submission
```

---

## License

[MIT](https://github.com/AMAAI-Lab/MERIT/blob/main/LICENSE)