Spaces:
Sleeping
Sleeping
| """Phrase completion evaluation on JSB chorales. | |
| For each JSB piece, we build (prompt, ground_truth_continuation) pairs by | |
| splitting near the midpoint BAR_END. We then compare: | |
| - model completions | |
| - random-token completions | |
| - ground-truth Bach continuations | |
| using tonal stability, rhythmic regularity, and phrase closure metrics. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import random | |
| import sys | |
| import tempfile | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Tuple | |
| import matplotlib | |
| import numpy as np | |
| import pretty_midi | |
| import torch | |
| import torch.nn.functional as F | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt # noqa: E402 | |
| _SCRIPT_DIR = Path(__file__).resolve().parent | |
| _ROOT = _SCRIPT_DIR.parent | |
| if str(_SCRIPT_DIR) not in sys.path: | |
| sys.path.insert(0, str(_SCRIPT_DIR)) | |
| from model import GPT, GPTConfig, default_gpt_config # noqa: E402 | |
| from tokenizer import ( # noqa: E402 | |
| BAR_END, | |
| ID2TOKEN, | |
| PHRASE_END, | |
| VOCAB_SIZE, | |
| encode, | |
| ) | |
| def _pick_device() -> torch.device: | |
| if torch.cuda.is_available(): | |
| return torch.device("cuda") | |
| mps = getattr(torch.backends, "mps", None) | |
| if mps is not None and mps.is_available(): | |
| return torch.device("mps") | |
| return torch.device("cpu") | |
| def top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor: | |
| if k <= 0 or k >= logits.size(-1): | |
| return logits | |
| values, _ = torch.topk(logits, k) | |
| threshold = values[:, -1].unsqueeze(-1) | |
| return logits.masked_fill(logits < threshold, float("-inf")) | |
| def _extract_gpt_config_dict(raw: Dict[str, Any]) -> Dict[str, Any]: | |
| keys = set(GPTConfig.__dataclass_fields__.keys()) | |
| return {k: raw[k] for k in keys if k in raw} | |
| def _load_config_from_sources( | |
| ckpt: Dict[str, Any], config_path: str | |
| ) -> GPTConfig: | |
| cfg = default_gpt_config() | |
| ckpt_cfg = ckpt.get("config") | |
| if isinstance(ckpt_cfg, dict): | |
| for k, v in _extract_gpt_config_dict(ckpt_cfg).items(): | |
| setattr(cfg, k, v) | |
| if config_path: | |
| loaded = json.loads(Path(config_path).read_text()) | |
| if not isinstance(loaded, dict): | |
| raise ValueError("--config must be a JSON object.") | |
| for k, v in _extract_gpt_config_dict(loaded).items(): | |
| setattr(cfg, k, v) | |
| return cfg | |
| def _pitch_class_from_id(tid: int) -> int | None: | |
| tok = ID2TOKEN.get(tid, "") | |
| if not tok.startswith("P"): | |
| return None | |
| try: | |
| return int(tok[1:]) % 12 | |
| except ValueError: | |
| return None | |
| def _bar_lengths(ids: List[int]) -> List[int]: | |
| lengths: List[int] = [] | |
| count = 0 | |
| in_bar = False | |
| for tid in ids: | |
| if tid == BAR_END: | |
| if in_bar: | |
| lengths.append(max(1, count)) | |
| in_bar = False | |
| count = 0 | |
| continue | |
| count += 1 | |
| tok = ID2TOKEN.get(tid, "") | |
| if tok == "BAR_START": | |
| in_bar = True | |
| return lengths | |
| def _infer_tonic(prompt_ids: List[int]) -> int: | |
| pcs = [ | |
| pc for tid in prompt_ids if (pc := _pitch_class_from_id(tid)) is not None | |
| ] | |
| if not pcs: | |
| return 0 | |
| counts = np.bincount(np.array(pcs, dtype=np.int64), minlength=12) | |
| return int(np.argmax(counts)) | |
| def _scale_pitch_classes(tonic: int) -> Tuple[set[int], set[int]]: | |
| major = {0, 2, 4, 5, 7, 9, 11} | |
| minor = {0, 2, 3, 5, 7, 8, 10} | |
| maj = {(tonic + x) % 12 for x in major} | |
| minr = {(tonic + x) % 12 for x in minor} | |
| return maj, minr | |
| def tonal_stability(prompt_ids: List[int], gen_ids: List[int]) -> float: | |
| tonic = _infer_tonic(prompt_ids) | |
| maj, minr = _scale_pitch_classes(tonic) | |
| pcs = [pc for tid in gen_ids if (pc := _pitch_class_from_id(tid)) is not None] | |
| if not pcs: | |
| return 0.0 | |
| in_maj = sum(1 for pc in pcs if pc in maj) / len(pcs) | |
| in_min = sum(1 for pc in pcs if pc in minr) / len(pcs) | |
| return float(max(in_maj, in_min)) | |
| def rhythmic_regularity(prompt_ids: List[int], gen_ids: List[int]) -> float: | |
| p_bars = _bar_lengths(prompt_ids) | |
| g_bars = _bar_lengths(gen_ids) | |
| if not p_bars or not g_bars: | |
| return 0.0 | |
| p_len = float(np.mean(p_bars)) | |
| g_len = float(np.mean(g_bars)) | |
| if p_len <= 0: | |
| return 0.0 | |
| score = 1.0 - abs(g_len - p_len) / p_len | |
| return float(max(0.0, min(1.0, score))) | |
| def phrase_closure(prompt_ids: List[int], gen_ids: List[int]) -> float: | |
| if not gen_ids: | |
| return 0.0 | |
| tail = gen_ids[-8:] | |
| tonic = _infer_tonic(prompt_ids) | |
| tonic_present = any(_pitch_class_from_id(tid) == tonic for tid in tail) | |
| if PHRASE_END in tail or BAR_END in tail or tonic_present: | |
| return 1.0 | |
| return 0.0 | |
| def _metrics( | |
| prompt_ids: List[int], continuation_ids: List[int] | |
| ) -> Dict[str, float]: | |
| return { | |
| "tonal_stability": tonal_stability(prompt_ids, continuation_ids), | |
| "rhythmic_regularity": rhythmic_regularity(prompt_ids, continuation_ids), | |
| "phrase_closure": phrase_closure(prompt_ids, continuation_ids), | |
| } | |
| def _build_phrase_pairs( | |
| n_pairs: int, | |
| min_prompt_tokens: int, | |
| min_cont_tokens: int, | |
| seed: int, | |
| ) -> List[Tuple[List[int], List[int]]]: | |
| from music21 import corpus | |
| rng = random.Random(seed) | |
| chorales = list( | |
| corpus.chorales.Iterator( | |
| numberingSystem="bwv", | |
| returnType="stream", | |
| ) | |
| ) | |
| rng.shuffle(chorales) | |
| pairs: List[Tuple[List[int], List[int]]] = [] | |
| for score in chorales: | |
| if len(pairs) >= n_pairs: | |
| break | |
| try: | |
| with tempfile.NamedTemporaryFile( | |
| suffix=".mid", | |
| delete=True, | |
| ) as tmp: | |
| score.write("midi", fp=tmp.name) | |
| pm = pretty_midi.PrettyMIDI(tmp.name) | |
| ids = encode(pm) | |
| except Exception: | |
| continue | |
| bar_end_pos = [i for i, tid in enumerate(ids) if tid == BAR_END] | |
| if len(bar_end_pos) < 2: | |
| continue | |
| split = bar_end_pos[len(bar_end_pos) // 2] | |
| prompt = ids[:split + 1] | |
| continuation = ids[split + 1:] | |
| if ( | |
| len(prompt) < min_prompt_tokens | |
| or len(continuation) < min_cont_tokens | |
| ): | |
| continue | |
| pairs.append((prompt, continuation)) | |
| return pairs | |
| def _generate_continuation( | |
| model: GPT, | |
| prompt_ids: List[int], | |
| n_tokens: int, | |
| temperature: float, | |
| top_k: int, | |
| device: torch.device, | |
| ) -> List[int]: | |
| x = torch.tensor([prompt_ids], dtype=torch.long, device=device) | |
| out = x | |
| for _ in range(n_tokens): | |
| ctx = out[:, -model.config.block_size:] | |
| logits = model(ctx)[:, -1, :] / temperature | |
| logits = top_k_filter(logits, top_k) | |
| probs = F.softmax(logits, dim=-1) | |
| nxt = torch.multinomial(probs, num_samples=1) | |
| out = torch.cat([out, nxt], dim=1) | |
| full = out[0].tolist() | |
| return full[len(prompt_ids):] | |
| def _plot_metrics( | |
| system_scores: Dict[str, Dict[str, float]], out_path: Path | |
| ) -> None: | |
| metrics = ["tonal_stability", "rhythmic_regularity", "phrase_closure"] | |
| systems = list(system_scores.keys()) | |
| x = np.arange(len(metrics)) | |
| width = 0.8 / len(systems) | |
| fig, ax = plt.subplots(figsize=(8.0, 4.6)) | |
| for idx, name in enumerate(systems): | |
| offs = x - 0.4 + width * (idx + 0.5) | |
| vals = [system_scores[name][m] for m in metrics] | |
| ax.bar(offs, vals, width=width, label=name) | |
| ax.set_xticks(x) | |
| ax.set_xticklabels(["tonal", "rhythmic", "closure"]) | |
| ax.set_ylim(0.0, 1.05) | |
| ax.set_ylabel("score") | |
| ax.set_title("Phrase completion metrics") | |
| ax.grid(alpha=0.25, axis="y") | |
| ax.legend() | |
| fig.tight_layout() | |
| out_path.parent.mkdir(parents=True, exist_ok=True) | |
| fig.savefig(out_path, dpi=160) | |
| plt.close(fig) | |
| def _write_summary( | |
| out_path: Path, | |
| n_pairs: int, | |
| n_samples: int, | |
| system_scores: Dict[str, Dict[str, float]], | |
| ) -> None: | |
| rows = [ | |
| "# Phrase completion evaluation", | |
| "", | |
| f"- Pairs evaluated: {n_pairs}", | |
| f"- Model samples per prompt: {n_samples}", | |
| "", | |
| "| system | tonal_stability | rhythmic_regularity | phrase_closure |", | |
| "|---|---:|---:|---:|", | |
| ] | |
| for name, scores in system_scores.items(): | |
| rows.append( | |
| f"| {name} | {scores['tonal_stability']:.3f} | " | |
| f"{scores['rhythmic_regularity']:.3f} | {scores['phrase_closure']:.3f} |" | |
| ) | |
| out_path.parent.mkdir(parents=True, exist_ok=True) | |
| out_path.write_text("\n".join(rows)) | |
| def parse_args() -> argparse.Namespace: | |
| p = argparse.ArgumentParser( | |
| description="JSB phrase completion evaluation." | |
| ) | |
| p.add_argument( | |
| "--checkpoint", | |
| type=str, | |
| default=str(_ROOT / "results" / "checkpoints" / "best_model.pt"), | |
| ) | |
| p.add_argument("--config", type=str, default="") | |
| p.add_argument("--n-pairs", type=int, default=60) | |
| p.add_argument("--samples-per-prompt", type=int, default=5) | |
| p.add_argument("--temperature", type=float, default=1.0) | |
| p.add_argument("--top-k", type=int, default=40) | |
| p.add_argument("--seed", type=int, default=42) | |
| p.add_argument( | |
| "--plot-out", | |
| type=str, | |
| default=str(_ROOT / "figures" / "phrase_completion_metrics.png"), | |
| ) | |
| p.add_argument( | |
| "--summary-out", | |
| type=str, | |
| default=str(_ROOT / "results" / "phrase_completion_summary.md"), | |
| ) | |
| return p.parse_args() | |
| def main() -> None: | |
| args = parse_args() | |
| random.seed(args.seed) | |
| np.random.seed(args.seed) | |
| torch.manual_seed(args.seed) | |
| device = _pick_device() | |
| ckpt = torch.load(args.checkpoint, map_location=device, weights_only=True) | |
| cfg = _load_config_from_sources(ckpt, args.config) | |
| model = GPT(cfg).to(device) | |
| state = ( | |
| ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt | |
| ) | |
| model.load_state_dict(state) | |
| model.eval() | |
| pairs = _build_phrase_pairs( | |
| n_pairs=args.n_pairs, | |
| min_prompt_tokens=32, | |
| min_cont_tokens=32, | |
| seed=args.seed, | |
| ) | |
| if not pairs: | |
| raise RuntimeError("No phrase pairs extracted from JSB.") | |
| metric_names = ["tonal_stability", "rhythmic_regularity", "phrase_closure"] | |
| systems: Dict[str, Dict[str, List[float]]] = { | |
| "random": {k: [] for k in metric_names}, | |
| "model_t1.0_topk40": {k: [] for k in metric_names}, | |
| "ground_truth": {k: [] for k in metric_names}, | |
| } | |
| rng = random.Random(args.seed + 123) | |
| for prompt_ids, gt_cont in pairs: | |
| target_len = len(gt_cont) | |
| gt_scores = _metrics(prompt_ids, gt_cont) | |
| for k, v in gt_scores.items(): | |
| systems["ground_truth"][k].append(v) | |
| rand_ids = [rng.randrange(VOCAB_SIZE) for _ in range(target_len)] | |
| rand_scores = _metrics(prompt_ids, rand_ids) | |
| for k, v in rand_scores.items(): | |
| systems["random"][k].append(v) | |
| for _ in range(args.samples_per_prompt): | |
| gen_ids = _generate_continuation( | |
| model=model, | |
| prompt_ids=prompt_ids, | |
| n_tokens=target_len, | |
| temperature=args.temperature, | |
| top_k=args.top_k, | |
| device=device, | |
| ) | |
| gen_scores = _metrics(prompt_ids, gen_ids) | |
| for k, v in gen_scores.items(): | |
| systems["model_t1.0_topk40"][k].append(v) | |
| system_means: Dict[str, Dict[str, float]] = {} | |
| for name, metrics in systems.items(): | |
| system_means[name] = { | |
| k: float(np.mean(v)) if v else 0.0 for k, v in metrics.items() | |
| } | |
| _plot_metrics(system_means, out_path=Path(args.plot_out)) | |
| _write_summary( | |
| out_path=Path(args.summary_out), | |
| n_pairs=len(pairs), | |
| n_samples=args.samples_per_prompt, | |
| system_scores=system_means, | |
| ) | |
| print( | |
| f"[probe_completion] pairs={len(pairs)} " | |
| f"samples_per_prompt={args.samples_per_prompt}" | |
| ) | |
| print(f"[probe_completion] plot -> {args.plot_out}") | |
| print(f"[probe_completion] summary -> {args.summary_out}") | |
| for system, scores in system_means.items(): | |
| print( | |
| f"[probe_completion] {system}: " | |
| f"tonal={scores['tonal_stability']:.3f} " | |
| f"rhythmic={scores['rhythmic_regularity']:.3f} " | |
| f"closure={scores['phrase_closure']:.3f}" | |
| ) | |
| if __name__ == "__main__": | |
| main() | |