Coda / src /probe_completion.py
Prajanya Gupta
initial deploy
6b7b403
"""Phrase completion evaluation on JSB chorales.
For each JSB piece, we build (prompt, ground_truth_continuation) pairs by
splitting near the midpoint BAR_END. We then compare:
- model completions
- random-token completions
- ground-truth Bach continuations
using tonal stability, rhythmic regularity, and phrase closure metrics.
"""
from __future__ import annotations
import argparse
import json
import random
import sys
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Tuple
import matplotlib
import numpy as np
import pretty_midi
import torch
import torch.nn.functional as F
matplotlib.use("Agg")
import matplotlib.pyplot as plt # noqa: E402
_SCRIPT_DIR = Path(__file__).resolve().parent
_ROOT = _SCRIPT_DIR.parent
if str(_SCRIPT_DIR) not in sys.path:
sys.path.insert(0, str(_SCRIPT_DIR))
from model import GPT, GPTConfig, default_gpt_config # noqa: E402
from tokenizer import ( # noqa: E402
BAR_END,
ID2TOKEN,
PHRASE_END,
VOCAB_SIZE,
encode,
)
def _pick_device() -> torch.device:
if torch.cuda.is_available():
return torch.device("cuda")
mps = getattr(torch.backends, "mps", None)
if mps is not None and mps.is_available():
return torch.device("mps")
return torch.device("cpu")
def top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
if k <= 0 or k >= logits.size(-1):
return logits
values, _ = torch.topk(logits, k)
threshold = values[:, -1].unsqueeze(-1)
return logits.masked_fill(logits < threshold, float("-inf"))
def _extract_gpt_config_dict(raw: Dict[str, Any]) -> Dict[str, Any]:
keys = set(GPTConfig.__dataclass_fields__.keys())
return {k: raw[k] for k in keys if k in raw}
def _load_config_from_sources(
ckpt: Dict[str, Any], config_path: str
) -> GPTConfig:
cfg = default_gpt_config()
ckpt_cfg = ckpt.get("config")
if isinstance(ckpt_cfg, dict):
for k, v in _extract_gpt_config_dict(ckpt_cfg).items():
setattr(cfg, k, v)
if config_path:
loaded = json.loads(Path(config_path).read_text())
if not isinstance(loaded, dict):
raise ValueError("--config must be a JSON object.")
for k, v in _extract_gpt_config_dict(loaded).items():
setattr(cfg, k, v)
return cfg
def _pitch_class_from_id(tid: int) -> int | None:
tok = ID2TOKEN.get(tid, "")
if not tok.startswith("P"):
return None
try:
return int(tok[1:]) % 12
except ValueError:
return None
def _bar_lengths(ids: List[int]) -> List[int]:
lengths: List[int] = []
count = 0
in_bar = False
for tid in ids:
if tid == BAR_END:
if in_bar:
lengths.append(max(1, count))
in_bar = False
count = 0
continue
count += 1
tok = ID2TOKEN.get(tid, "")
if tok == "BAR_START":
in_bar = True
return lengths
def _infer_tonic(prompt_ids: List[int]) -> int:
pcs = [
pc for tid in prompt_ids if (pc := _pitch_class_from_id(tid)) is not None
]
if not pcs:
return 0
counts = np.bincount(np.array(pcs, dtype=np.int64), minlength=12)
return int(np.argmax(counts))
def _scale_pitch_classes(tonic: int) -> Tuple[set[int], set[int]]:
major = {0, 2, 4, 5, 7, 9, 11}
minor = {0, 2, 3, 5, 7, 8, 10}
maj = {(tonic + x) % 12 for x in major}
minr = {(tonic + x) % 12 for x in minor}
return maj, minr
def tonal_stability(prompt_ids: List[int], gen_ids: List[int]) -> float:
tonic = _infer_tonic(prompt_ids)
maj, minr = _scale_pitch_classes(tonic)
pcs = [pc for tid in gen_ids if (pc := _pitch_class_from_id(tid)) is not None]
if not pcs:
return 0.0
in_maj = sum(1 for pc in pcs if pc in maj) / len(pcs)
in_min = sum(1 for pc in pcs if pc in minr) / len(pcs)
return float(max(in_maj, in_min))
def rhythmic_regularity(prompt_ids: List[int], gen_ids: List[int]) -> float:
p_bars = _bar_lengths(prompt_ids)
g_bars = _bar_lengths(gen_ids)
if not p_bars or not g_bars:
return 0.0
p_len = float(np.mean(p_bars))
g_len = float(np.mean(g_bars))
if p_len <= 0:
return 0.0
score = 1.0 - abs(g_len - p_len) / p_len
return float(max(0.0, min(1.0, score)))
def phrase_closure(prompt_ids: List[int], gen_ids: List[int]) -> float:
if not gen_ids:
return 0.0
tail = gen_ids[-8:]
tonic = _infer_tonic(prompt_ids)
tonic_present = any(_pitch_class_from_id(tid) == tonic for tid in tail)
if PHRASE_END in tail or BAR_END in tail or tonic_present:
return 1.0
return 0.0
def _metrics(
prompt_ids: List[int], continuation_ids: List[int]
) -> Dict[str, float]:
return {
"tonal_stability": tonal_stability(prompt_ids, continuation_ids),
"rhythmic_regularity": rhythmic_regularity(prompt_ids, continuation_ids),
"phrase_closure": phrase_closure(prompt_ids, continuation_ids),
}
def _build_phrase_pairs(
n_pairs: int,
min_prompt_tokens: int,
min_cont_tokens: int,
seed: int,
) -> List[Tuple[List[int], List[int]]]:
from music21 import corpus
rng = random.Random(seed)
chorales = list(
corpus.chorales.Iterator(
numberingSystem="bwv",
returnType="stream",
)
)
rng.shuffle(chorales)
pairs: List[Tuple[List[int], List[int]]] = []
for score in chorales:
if len(pairs) >= n_pairs:
break
try:
with tempfile.NamedTemporaryFile(
suffix=".mid",
delete=True,
) as tmp:
score.write("midi", fp=tmp.name)
pm = pretty_midi.PrettyMIDI(tmp.name)
ids = encode(pm)
except Exception:
continue
bar_end_pos = [i for i, tid in enumerate(ids) if tid == BAR_END]
if len(bar_end_pos) < 2:
continue
split = bar_end_pos[len(bar_end_pos) // 2]
prompt = ids[:split + 1]
continuation = ids[split + 1:]
if (
len(prompt) < min_prompt_tokens
or len(continuation) < min_cont_tokens
):
continue
pairs.append((prompt, continuation))
return pairs
@torch.no_grad()
def _generate_continuation(
model: GPT,
prompt_ids: List[int],
n_tokens: int,
temperature: float,
top_k: int,
device: torch.device,
) -> List[int]:
x = torch.tensor([prompt_ids], dtype=torch.long, device=device)
out = x
for _ in range(n_tokens):
ctx = out[:, -model.config.block_size:]
logits = model(ctx)[:, -1, :] / temperature
logits = top_k_filter(logits, top_k)
probs = F.softmax(logits, dim=-1)
nxt = torch.multinomial(probs, num_samples=1)
out = torch.cat([out, nxt], dim=1)
full = out[0].tolist()
return full[len(prompt_ids):]
def _plot_metrics(
system_scores: Dict[str, Dict[str, float]], out_path: Path
) -> None:
metrics = ["tonal_stability", "rhythmic_regularity", "phrase_closure"]
systems = list(system_scores.keys())
x = np.arange(len(metrics))
width = 0.8 / len(systems)
fig, ax = plt.subplots(figsize=(8.0, 4.6))
for idx, name in enumerate(systems):
offs = x - 0.4 + width * (idx + 0.5)
vals = [system_scores[name][m] for m in metrics]
ax.bar(offs, vals, width=width, label=name)
ax.set_xticks(x)
ax.set_xticklabels(["tonal", "rhythmic", "closure"])
ax.set_ylim(0.0, 1.05)
ax.set_ylabel("score")
ax.set_title("Phrase completion metrics")
ax.grid(alpha=0.25, axis="y")
ax.legend()
fig.tight_layout()
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path, dpi=160)
plt.close(fig)
def _write_summary(
out_path: Path,
n_pairs: int,
n_samples: int,
system_scores: Dict[str, Dict[str, float]],
) -> None:
rows = [
"# Phrase completion evaluation",
"",
f"- Pairs evaluated: {n_pairs}",
f"- Model samples per prompt: {n_samples}",
"",
"| system | tonal_stability | rhythmic_regularity | phrase_closure |",
"|---|---:|---:|---:|",
]
for name, scores in system_scores.items():
rows.append(
f"| {name} | {scores['tonal_stability']:.3f} | "
f"{scores['rhythmic_regularity']:.3f} | {scores['phrase_closure']:.3f} |"
)
out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_text("\n".join(rows))
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(
description="JSB phrase completion evaluation."
)
p.add_argument(
"--checkpoint",
type=str,
default=str(_ROOT / "results" / "checkpoints" / "best_model.pt"),
)
p.add_argument("--config", type=str, default="")
p.add_argument("--n-pairs", type=int, default=60)
p.add_argument("--samples-per-prompt", type=int, default=5)
p.add_argument("--temperature", type=float, default=1.0)
p.add_argument("--top-k", type=int, default=40)
p.add_argument("--seed", type=int, default=42)
p.add_argument(
"--plot-out",
type=str,
default=str(_ROOT / "figures" / "phrase_completion_metrics.png"),
)
p.add_argument(
"--summary-out",
type=str,
default=str(_ROOT / "results" / "phrase_completion_summary.md"),
)
return p.parse_args()
@torch.no_grad()
def main() -> None:
args = parse_args()
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
device = _pick_device()
ckpt = torch.load(args.checkpoint, map_location=device, weights_only=True)
cfg = _load_config_from_sources(ckpt, args.config)
model = GPT(cfg).to(device)
state = (
ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
)
model.load_state_dict(state)
model.eval()
pairs = _build_phrase_pairs(
n_pairs=args.n_pairs,
min_prompt_tokens=32,
min_cont_tokens=32,
seed=args.seed,
)
if not pairs:
raise RuntimeError("No phrase pairs extracted from JSB.")
metric_names = ["tonal_stability", "rhythmic_regularity", "phrase_closure"]
systems: Dict[str, Dict[str, List[float]]] = {
"random": {k: [] for k in metric_names},
"model_t1.0_topk40": {k: [] for k in metric_names},
"ground_truth": {k: [] for k in metric_names},
}
rng = random.Random(args.seed + 123)
for prompt_ids, gt_cont in pairs:
target_len = len(gt_cont)
gt_scores = _metrics(prompt_ids, gt_cont)
for k, v in gt_scores.items():
systems["ground_truth"][k].append(v)
rand_ids = [rng.randrange(VOCAB_SIZE) for _ in range(target_len)]
rand_scores = _metrics(prompt_ids, rand_ids)
for k, v in rand_scores.items():
systems["random"][k].append(v)
for _ in range(args.samples_per_prompt):
gen_ids = _generate_continuation(
model=model,
prompt_ids=prompt_ids,
n_tokens=target_len,
temperature=args.temperature,
top_k=args.top_k,
device=device,
)
gen_scores = _metrics(prompt_ids, gen_ids)
for k, v in gen_scores.items():
systems["model_t1.0_topk40"][k].append(v)
system_means: Dict[str, Dict[str, float]] = {}
for name, metrics in systems.items():
system_means[name] = {
k: float(np.mean(v)) if v else 0.0 for k, v in metrics.items()
}
_plot_metrics(system_means, out_path=Path(args.plot_out))
_write_summary(
out_path=Path(args.summary_out),
n_pairs=len(pairs),
n_samples=args.samples_per_prompt,
system_scores=system_means,
)
print(
f"[probe_completion] pairs={len(pairs)} "
f"samples_per_prompt={args.samples_per_prompt}"
)
print(f"[probe_completion] plot -> {args.plot_out}")
print(f"[probe_completion] summary -> {args.summary_out}")
for system, scores in system_means.items():
print(
f"[probe_completion] {system}: "
f"tonal={scores['tonal_stability']:.3f} "
f"rhythmic={scores['rhythmic_regularity']:.3f} "
f"closure={scores['phrase_closure']:.3f}"
)
if __name__ == "__main__":
main()