Spaces:
Sleeping
Sleeping
| """Phase 4d retrieval evaluation on held-out validation split.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| from pathlib import Path | |
| from typing import Dict, List, Tuple | |
| import torch | |
| import torch.nn.functional as F | |
| _SCRIPT_DIR = Path(__file__).resolve().parent | |
| _ROOT = _SCRIPT_DIR.parent | |
| if str(_SCRIPT_DIR) not in sys.path: | |
| sys.path.insert(0, str(_SCRIPT_DIR)) | |
| from caption_dataloader import build_caption_dataloaders # noqa: E402 | |
| from inference_pipeline import ( # noqa: E402 | |
| _pick_device, | |
| load_clap, | |
| load_midi_gpt, | |
| ) | |
| def _infer_genre_label(caption: str) -> str: | |
| text = caption.lower() | |
| if "jazz" in text or "swing" in text or "bebop" in text: | |
| return "jazz" | |
| if "electronic" in text or "synth" in text or "edm" in text: | |
| return "electronic" | |
| if "classical" in text or "orchestral" in text or "baroque" in text: | |
| return "classical" | |
| if "rock" in text or "guitar" in text or "band" in text: | |
| return "rock" | |
| return "other" | |
| def _ranks_from_similarity(sim: torch.Tensor) -> torch.Tensor: | |
| """Return 1-indexed rank of correct pair for each row.""" | |
| n = sim.size(0) | |
| sorted_idx = torch.argsort(sim, dim=1, descending=True) | |
| labels = torch.arange(n, device=sim.device).unsqueeze(1) | |
| matches = sorted_idx.eq(labels) | |
| rank0 = torch.argmax(matches.to(torch.int64), dim=1) | |
| return rank0 + 1 | |
| def _recall_at_k(ranks: torch.Tensor, k: int) -> float: | |
| return float((ranks <= k).float().mean().item()) | |
| def _median_rank(ranks: torch.Tensor) -> float: | |
| return float(torch.median(ranks.to(torch.float32)).item()) | |
def collect_val_embeddings(
    clap,
    val_loader,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
    """Encode every validation batch into L2-normalised MIDI and text embeddings.

    Switches ``clap`` and ``clap.text_encoder`` to eval mode (not restored
    afterwards). For each batch it encodes the MIDI tokens and the captions,
    applies the matching projection head and L2-normalises the result.
    Embeddings are moved to CPU and concatenated in loader order.

    Args:
        clap: Model exposing ``encode_midi``, ``encode_text``,
            ``midi_projection``, ``text_projection`` and ``text_encoder``.
        val_loader: Iterable of batches with ``"input_ids"``,
            ``"attention_mask"`` and ``"captions"`` entries.
        device: Device the model inputs are moved to.

    Returns:
        ``(midi_embs, text_embs, captions)``. Both embedding tensors and the
        caption list are filled batch-by-batch in the same order, so row ``i``
        of each corresponds to ``captions[i]``.

    Raises:
        RuntimeError: if ``val_loader`` yields no batches.
    """
    midi_chunks: List[torch.Tensor] = []
    text_chunks: List[torch.Tensor] = []
    captions_all: List[str] = []
    clap.eval()
    clap.text_encoder.eval()
    # Evaluation only: without no_grad each stored chunk would keep its
    # autograd graph (and the encoders' activations) alive for the whole loop.
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            captions = batch["captions"]
            midi_feat = clap.encode_midi(input_ids, attention_mask)
            text_feat = clap.encode_text(captions, device=device)
            midi_emb = F.normalize(clap.midi_projection(midi_feat), p=2, dim=-1)
            text_emb = F.normalize(clap.text_projection(text_feat), p=2, dim=-1)
            midi_chunks.append(midi_emb.cpu())
            text_chunks.append(text_emb.cpu())
            captions_all.extend(captions)
    if not midi_chunks:
        raise RuntimeError(
            "validation loader yielded no batches; nothing to evaluate"
        )
    return (
        torch.cat(midi_chunks, dim=0),
        torch.cat(text_chunks, dim=0),
        captions_all,
    )
def evaluate_retrieval(
    midi_embs: torch.Tensor,
    text_embs: torch.Tensor,
) -> Dict[str, float]:
    """Compute bidirectional retrieval metrics over row-paired embeddings.

    Row ``i`` of ``midi_embs`` is treated as the match of row ``i`` of
    ``text_embs``, and pairs are scored by dot product. For MIDI->text
    (``m2t_*``) and text->MIDI (``t2m_*``) the result holds R@1, R@5, R@10
    and the median rank. It also reports the number of pairs (``n_val``) and
    the R@1 of random ranking, ``1 / n`` (``random_r1``).
    """
    scores = midi_embs @ text_embs.t()
    # Rank both directions up front, before any other metric is derived.
    directions = (
        ("m2t", _ranks_from_similarity(scores)),
        ("t2m", _ranks_from_similarity(scores.t())),
    )
    n_pairs = float(scores.size(0))
    metrics: Dict[str, float] = {"n_val": n_pairs, "random_r1": 1.0 / n_pairs}
    for prefix, ranks in directions:
        for k in (1, 5, 10):
            metrics[f"{prefix}_r{k}"] = _recall_at_k(ranks, k)
        metrics[f"{prefix}_median_rank"] = _median_rank(ranks)
    return metrics
def genre_r1_breakdown(
    midi_embs: torch.Tensor,
    text_embs: torch.Tensor,
    captions: List[str],
    top_genres: List[str],
) -> Dict[str, float]:
    """Return MIDI->text Recall@1 restricted to each requested genre.

    Each row's genre is inferred from its caption with ``_infer_genre_label``.
    A row counts as correct when its highest-scoring text is its own paired
    caption (the diagonal). Genres with no rows map to ``nan``.

    Args:
        midi_embs: MIDI embeddings, one row per example.
        text_embs: Text embeddings paired row-for-row with ``midi_embs``.
        captions: One caption per embedding row, in the same order.
        top_genres: Genre labels to report, in output order.

    Raises:
        ValueError: if ``len(captions)`` differs from the number of rows.
    """
    sim = midi_embs @ text_embs.t()
    n = sim.size(0)
    if len(captions) != n:
        raise ValueError(f"got {len(captions)} captions for {n} embedding rows")
    if n == 0:
        return {g: float("nan") for g in top_genres}
    # Only the top-1 candidate matters, so argmax replaces a full sort. The
    # target indices are created on sim's device so non-CPU inputs also work.
    targets = torch.arange(n, device=sim.device)
    correct_top1 = sim.argmax(dim=1).eq(targets).cpu()
    genres = [_infer_genre_label(c) for c in captions]
    out: Dict[str, float] = {}
    for g in top_genres:
        idxs = [i for i, label in enumerate(genres) if label == g]
        if not idxs:
            out[g] = float("nan")
            continue
        out[g] = float(correct_top1[idxs].float().mean().item())
    return out
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the retrieval evaluation script.

    The checkpoint and output paths default to the empty string. ``main``
    replaces an empty value with a path under ``--results-dir``.
    """
    p = argparse.ArgumentParser(description="Evaluate CLAP retrieval metrics.")
    p.add_argument(
        "--results-dir",
        type=str,
        default=str(_ROOT / "results"),
        help="Results directory; default checkpoint and output paths are "
        "resolved under it (default: %(default)s).",
    )
    p.add_argument(
        "--captions-jsonl",
        type=str,
        default=str(_ROOT / "data" / "captions_llm.jsonl"),
        help="Captions JSONL file used to build the dataloaders "
        "(default: %(default)s).",
    )
    p.add_argument(
        "--midi-checkpoint",
        type=str,
        default="",
        help="MIDI GPT checkpoint "
        "(default: <results-dir>/checkpoints/best_model.pt).",
    )
    p.add_argument(
        "--clap-checkpoint",
        type=str,
        default="",
        help="CLAP checkpoint "
        "(default: <results-dir>/checkpoints_contrastive/clap_best.pt).",
    )
    p.add_argument(
        "--batch-size",
        type=int,
        default=64,
        help="Batch size for the dataloaders (default: %(default)s).",
    )
    p.add_argument(
        "--max-seq-len",
        type=int,
        default=512,
        help="Maximum sequence length for the dataloaders "
        "(default: %(default)s).",
    )
    # NOTE(review): presumably --split-ratio and --seed must equal the values
    # used during training so the validation split stays held out - confirm.
    p.add_argument(
        "--split-ratio",
        type=float,
        default=0.95,
        help="Train/validation split ratio (default: %(default)s).",
    )
    p.add_argument(
        "--seed",
        type=int,
        default=17,
        help="Seed passed to the dataloader builder (default: %(default)s).",
    )
    p.add_argument(
        "--num-workers",
        type=int,
        default=4,
        help="Number of data-loading workers (default: %(default)s).",
    )
    p.add_argument(
        "--out-json",
        type=str,
        default="",
        help="Output JSON path (default: <results-dir>/retrieval_eval.json).",
    )
    return p.parse_args()
| def main() -> None: | |
| args = parse_args() | |
| results_dir = Path(args.results_dir) | |
| if not args.midi_checkpoint: | |
| args.midi_checkpoint = str( | |
| results_dir / "checkpoints" / "best_model.pt" | |
| ) | |
| if not args.clap_checkpoint: | |
| args.clap_checkpoint = str( | |
| results_dir / "checkpoints_contrastive" / "clap_best.pt" | |
| ) | |
| if not args.out_json: | |
| args.out_json = str(results_dir / "retrieval_eval.json") | |
| device = _pick_device() | |
| print(f"[retrieval] device={device}") | |
| _, val_loader, stats = build_caption_dataloaders( | |
| jsonl_path=args.captions_jsonl, | |
| max_seq_len=args.max_seq_len, | |
| batch_size=args.batch_size, | |
| split_ratio=args.split_ratio, | |
| seed=args.seed, | |
| num_workers=args.num_workers, | |
| ) | |
| print( | |
| "[retrieval] val split size=" | |
| f"{stats.n_val_records} (total={stats.n_total_records})" | |
| ) | |
| midi_gpt, _ = load_midi_gpt(Path(args.midi_checkpoint), device=device) | |
| clap, _ = load_clap( | |
| Path(args.clap_checkpoint), midi_gpt=midi_gpt, device=device | |
| ) | |
| midi_embs, text_embs, captions = collect_val_embeddings( | |
| clap=clap, | |
| val_loader=val_loader, | |
| device=device, | |
| ) | |
| metrics = evaluate_retrieval(midi_embs=midi_embs, text_embs=text_embs) | |
| genre_r1 = genre_r1_breakdown( | |
| midi_embs=midi_embs, | |
| text_embs=text_embs, | |
| captions=captions, | |
| top_genres=["rock", "jazz", "classical", "electronic"], | |
| ) | |
| result = {"overall": metrics, "genre_r1": genre_r1} | |
| print( | |
| "[retrieval] random_r1=" | |
| f"{metrics['random_r1']:.6f} | " | |
| f"m2t R@1/5/10={metrics['m2t_r1']:.4f}/" | |
| f"{metrics['m2t_r5']:.4f}/{metrics['m2t_r10']:.4f} " | |
| f"medR={metrics['m2t_median_rank']:.1f}" | |
| ) | |
| print( | |
| "[retrieval] t2m R@1/5/10=" | |
| f"{metrics['t2m_r1']:.4f}/{metrics['t2m_r5']:.4f}/" | |
| f"{metrics['t2m_r10']:.4f} " | |
| f"medR={metrics['t2m_median_rank']:.1f}" | |
| ) | |
| print( | |
| "[retrieval] genre R@1 " | |
| + " ".join(f"{k}:{v:.4f}" for k, v in genre_r1.items()) | |
| ) | |
| out_path = Path(args.out_json) | |
| out_path.parent.mkdir(parents=True, exist_ok=True) | |
| out_path.write_text(json.dumps(result, indent=2)) | |
| print(f"[retrieval] wrote {out_path}") | |
| if __name__ == "__main__": | |
| main() | |