Coda / src /eval_retrieval.py
Prajanya Gupta
initial deploy
6b7b403
"""Phase 4d retrieval evaluation on held-out validation split."""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Dict, List, Tuple
import torch
import torch.nn.functional as F
# Directory containing this script, and the project root one level above it.
_SCRIPT_DIR = Path(__file__).resolve().parent
_ROOT = _SCRIPT_DIR.parent

# Put the script directory on the import path (once) so the sibling modules
# imported below resolve when this file is executed directly.
if not any(entry == str(_SCRIPT_DIR) for entry in sys.path):
    sys.path.insert(0, str(_SCRIPT_DIR))
from caption_dataloader import build_caption_dataloaders # noqa: E402
from inference_pipeline import ( # noqa: E402
_pick_device,
load_clap,
load_midi_gpt,
)
def _infer_genre_label(caption: str) -> str:
text = caption.lower()
if "jazz" in text or "swing" in text or "bebop" in text:
return "jazz"
if "electronic" in text or "synth" in text or "edm" in text:
return "electronic"
if "classical" in text or "orchestral" in text or "baroque" in text:
return "classical"
if "rock" in text or "guitar" in text or "band" in text:
return "rock"
return "other"
def _ranks_from_similarity(sim: torch.Tensor) -> torch.Tensor:
"""Return 1-indexed rank of correct pair for each row."""
n = sim.size(0)
sorted_idx = torch.argsort(sim, dim=1, descending=True)
labels = torch.arange(n, device=sim.device).unsqueeze(1)
matches = sorted_idx.eq(labels)
rank0 = torch.argmax(matches.to(torch.int64), dim=1)
return rank0 + 1
def _recall_at_k(ranks: torch.Tensor, k: int) -> float:
return float((ranks <= k).float().mean().item())
def _median_rank(ranks: torch.Tensor) -> float:
return float(torch.median(ranks.to(torch.float32)).item())
@torch.no_grad()
def collect_val_embeddings(
    clap,
    val_loader,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
    """Embed every validation batch with CLAP and gather results on the CPU.

    Puts ``clap`` and ``clap.text_encoder`` in eval mode. For each batch it
    then:

    * encodes the MIDI inputs (``input_ids`` / ``attention_mask``, moved to
      ``device``) and the caption strings;
    * passes each through its projection head;
    * L2-normalises the result along the last dimension.

    Returns:
        ``(midi_embs, text_embs, captions)``: the per-batch embeddings
        concatenated along dim 0 on the CPU, plus all captions in loader
        order.
    """
    clap.eval()
    clap.text_encoder.eval()

    midi_parts: List[torch.Tensor] = []
    text_parts: List[torch.Tensor] = []
    all_captions: List[str] = []
    for batch in val_loader:
        token_ids = batch["input_ids"].to(device)
        token_mask = batch["attention_mask"].to(device)
        batch_captions = batch["captions"]

        midi_hidden = clap.encode_midi(token_ids, token_mask)
        text_hidden = clap.encode_text(batch_captions, device=device)

        midi_parts.append(
            F.normalize(clap.midi_projection(midi_hidden), p=2, dim=-1).cpu()
        )
        text_parts.append(
            F.normalize(clap.text_projection(text_hidden), p=2, dim=-1).cpu()
        )
        all_captions.extend(batch_captions)

    midi_embs = torch.cat(midi_parts, dim=0)
    text_embs = torch.cat(text_parts, dim=0)
    return midi_embs, text_embs, all_captions
def evaluate_retrieval(
    midi_embs: torch.Tensor,
    text_embs: torch.Tensor,
) -> Dict[str, float]:
    """Compute bidirectional retrieval metrics for paired embeddings.

    Row ``i`` of ``midi_embs`` is paired with row ``i`` of ``text_embs``, and
    similarity is their dot product. ``collect_val_embeddings`` L2-normalises
    both, which makes the dot product a cosine similarity.

    Returns:
        Dict with ``n_val``, ``random_r1`` (chance-level R@1 = 1/N), and for
        each direction (``m2t`` = MIDI->text, ``t2m`` = text->MIDI) the
        R@1/5/10 and the median rank.
    """
    scores = midi_embs @ text_embs.t()
    # Both rank vectors are computed before any metric, as before.
    ranks_by_direction = {
        "m2t": _ranks_from_similarity(scores),
        "t2m": _ranks_from_similarity(scores.t()),
    }

    n_items = float(scores.size(0))
    metrics: Dict[str, float] = {
        "n_val": n_items,
        "random_r1": 1.0 / n_items,
    }
    for direction, ranks in ranks_by_direction.items():
        for k in (1, 5, 10):
            metrics[f"{direction}_r{k}"] = _recall_at_k(ranks, k)
        metrics[f"{direction}_median_rank"] = _median_rank(ranks)
    return metrics
def genre_r1_breakdown(
    midi_embs: torch.Tensor,
    text_embs: torch.Tensor,
    captions: List[str],
    top_genres: List[str],
) -> Dict[str, float]:
    """Report MIDI->text Recall@1 separately for each requested genre.

    Each caption is bucketed with ``_infer_genre_label``. For every genre in
    ``top_genres``, R@1 is averaged over the pairs whose caption falls in
    that bucket.

    Ranks come from ``_ranks_from_similarity`` on the same
    ``midi_embs @ text_embs.T`` matrix that ``evaluate_retrieval`` uses.
    The per-genre numbers therefore follow the same tie handling as the
    overall ``m2t_r1``.

    Args:
        midi_embs: MIDI embeddings, one row per validation pair.
        text_embs: text embeddings; row ``i`` is paired with ``midi_embs[i]``.
        captions: caption for each pair, in the same row order.
        top_genres: genre labels to report.

    Returns:
        Mapping genre -> R@1 in [0, 1]; ``nan`` when no caption maps to that
        genre.

    Raises:
        ValueError: if ``len(captions)`` differs from the number of rows.
    """
    sim = midi_embs @ text_embs.t()
    if len(captions) != sim.size(0):
        raise ValueError(
            f"got {len(captions)} captions for {sim.size(0)} embedding rows"
        )
    ranks = _ranks_from_similarity(sim)
    genres = [_infer_genre_label(c) for c in captions]

    out: Dict[str, float] = {}
    for genre in top_genres:
        # Built on ranks.device so CUDA-resident embeddings work as well.
        mask = torch.tensor(
            [label == genre for label in genres],
            dtype=torch.bool,
            device=ranks.device,
        )
        if not bool(mask.any()):
            out[genre] = float("nan")
            continue
        out[genre] = _recall_at_k(ranks[mask], 1)
    return out
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the retrieval evaluation.

    Options whose default is the empty string are resolved by ``main``
    relative to ``--results-dir``.
    """
    p = argparse.ArgumentParser(description="Evaluate CLAP retrieval metrics.")
    p.add_argument(
        "--results-dir",
        type=str,
        default=str(_ROOT / "results"),
        help="Base directory for default checkpoints and output "
        "(default: %(default)s).",
    )
    p.add_argument(
        "--captions-jsonl",
        type=str,
        default=str(_ROOT / "data" / "captions_llm.jsonl"),
        help="Captions JSONL file passed to build_caption_dataloaders "
        "(default: %(default)s).",
    )
    p.add_argument(
        "--midi-checkpoint",
        type=str,
        default="",
        help="MIDI GPT checkpoint "
        "(default: <results-dir>/checkpoints/best_model.pt).",
    )
    p.add_argument(
        "--clap-checkpoint",
        type=str,
        default="",
        help="CLAP checkpoint "
        "(default: <results-dir>/checkpoints_contrastive/clap_best.pt).",
    )
    p.add_argument(
        "--batch-size",
        type=int,
        default=64,
        help="Batch size passed to build_caption_dataloaders "
        "(default: %(default)s).",
    )
    p.add_argument(
        "--max-seq-len",
        type=int,
        default=512,
        help="Maximum sequence length passed to build_caption_dataloaders "
        "(default: %(default)s).",
    )
    p.add_argument(
        "--split-ratio",
        type=float,
        default=0.95,
        help="Split ratio passed to build_caption_dataloaders "
        "(default: %(default)s).",
    )
    p.add_argument(
        "--seed",
        type=int,
        default=17,
        help="Seed passed to build_caption_dataloaders (default: %(default)s).",
    )
    p.add_argument(
        "--num-workers",
        type=int,
        default=4,
        help="Number of loader workers passed to build_caption_dataloaders "
        "(default: %(default)s).",
    )
    p.add_argument(
        "--out-json",
        type=str,
        default="",
        help="Output JSON path (default: <results-dir>/retrieval_eval.json).",
    )
    return p.parse_args()
def main() -> None:
    """Run retrieval evaluation on the validation split and save the metrics.

    Steps:

    1. Resolve default checkpoint and output paths under ``--results-dir``.
    2. Build the validation loader and load the MIDI GPT and CLAP
       checkpoints.
    3. Embed every validation pair.
    4. Print overall and per-genre retrieval metrics.
    5. Write the metrics to ``--out-json``.
    """
    import math  # local: only needed for the JSON NaN sanitisation below

    args = parse_args()
    results_dir = Path(args.results_dir)
    # Empty-string CLI defaults mean "derive from --results-dir".
    if not args.midi_checkpoint:
        args.midi_checkpoint = str(
            results_dir / "checkpoints" / "best_model.pt"
        )
    if not args.clap_checkpoint:
        args.clap_checkpoint = str(
            results_dir / "checkpoints_contrastive" / "clap_best.pt"
        )
    if not args.out_json:
        args.out_json = str(results_dir / "retrieval_eval.json")

    device = _pick_device()
    print(f"[retrieval] device={device}")

    _, val_loader, stats = build_caption_dataloaders(
        jsonl_path=args.captions_jsonl,
        max_seq_len=args.max_seq_len,
        batch_size=args.batch_size,
        split_ratio=args.split_ratio,
        seed=args.seed,
        num_workers=args.num_workers,
    )
    print(
        "[retrieval] val split size="
        f"{stats.n_val_records} (total={stats.n_total_records})"
    )

    midi_gpt, _ = load_midi_gpt(Path(args.midi_checkpoint), device=device)
    clap, _ = load_clap(
        Path(args.clap_checkpoint), midi_gpt=midi_gpt, device=device
    )

    midi_embs, text_embs, captions = collect_val_embeddings(
        clap=clap,
        val_loader=val_loader,
        device=device,
    )
    metrics = evaluate_retrieval(midi_embs=midi_embs, text_embs=text_embs)
    genre_r1 = genre_r1_breakdown(
        midi_embs=midi_embs,
        text_embs=text_embs,
        captions=captions,
        top_genres=["rock", "jazz", "classical", "electronic"],
    )
    result = {"overall": metrics, "genre_r1": genre_r1}

    print(
        "[retrieval] random_r1="
        f"{metrics['random_r1']:.6f} | "
        f"m2t R@1/5/10={metrics['m2t_r1']:.4f}/"
        f"{metrics['m2t_r5']:.4f}/{metrics['m2t_r10']:.4f} "
        f"medR={metrics['m2t_median_rank']:.1f}"
    )
    print(
        "[retrieval] t2m R@1/5/10="
        f"{metrics['t2m_r1']:.4f}/{metrics['t2m_r5']:.4f}/"
        f"{metrics['t2m_r10']:.4f} "
        f"medR={metrics['t2m_median_rank']:.1f}"
    )
    print(
        "[retrieval] genre R@1 "
        + " ".join(f"{k}:{v:.4f}" for k, v in genre_r1.items())
    )

    # genre_r1_breakdown yields NaN for genres with no validation captions.
    # json.dumps would emit the bare token NaN, which strict JSON parsers
    # reject, so store null instead (allow_nan=False guards the rest).
    serializable = {
        section: {
            name: None if isinstance(value, float) and math.isnan(value) else value
            for name, value in values.items()
        }
        for section, values in result.items()
    }
    out_path = Path(args.out_json)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(
        json.dumps(serializable, indent=2, allow_nan=False), encoding="utf-8"
    )
    print(f"[retrieval] wrote {out_path}")
# Run the evaluation only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()