# Hosting-page header residue ("Spaces: Sleeping"); not part of the module.
| """Linear probing on frozen GPT activations for musical structure signals.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import random | |
| import sys | |
| import tempfile | |
| from pathlib import Path | |
| from typing import Any, Dict, List | |
| import matplotlib | |
| import numpy as np | |
| import pretty_midi | |
| import torch | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.metrics import accuracy_score | |
| from sklearn.model_selection import train_test_split | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt # noqa: E402 | |
# Directory that holds this script, and the directory one level above it,
# which the CLI defaults use as the base for results/ and figures/.
_SCRIPT_DIR = Path(__file__).resolve().parent
_ROOT = _SCRIPT_DIR.parent

# The sibling modules below are imported by bare name, so the script
# directory has to be on sys.path; prepend it only if it is missing.
if sys.path.count(str(_SCRIPT_DIR)) == 0:
    sys.path.insert(0, str(_SCRIPT_DIR))

from model import GPT, GPTConfig, default_gpt_config  # noqa: E402
from tokenizer import ID2TOKEN, PHRASE_END, encode  # noqa: E402

# Number of bins for the voice label; voice indices are clamped into
# [0, N_VOICE_BINS - 1] when labels are built.
N_VOICE_BINS = 8
| def _pick_device() -> torch.device: | |
| if torch.cuda.is_available(): | |
| return torch.device("cuda") | |
| mps = getattr(torch.backends, "mps", None) | |
| if mps is not None and mps.is_available(): | |
| return torch.device("mps") | |
| return torch.device("cpu") | |
def _extract_gpt_config_dict(raw: Dict[str, Any]) -> Dict[str, Any]:
    """Return the entries of *raw* whose keys are ``GPTConfig`` dataclass fields.

    Keys in *raw* that are not fields of ``GPTConfig`` are dropped; fields
    missing from *raw* are simply absent from the result.
    """
    field_names = set(GPTConfig.__dataclass_fields__.keys())
    selected: Dict[str, Any] = {}
    for name in field_names:
        if name in raw:
            selected[name] = raw[name]
    return selected
def _load_config_from_sources(
    ckpt: Dict[str, Any], config_path: str
) -> GPTConfig:
    """Build the model config from defaults, checkpoint, and an optional file.

    Starts from ``default_gpt_config()``, then overlays the checkpoint's
    ``"config"`` entry (when it is a dict), then overlays the JSON object
    read from *config_path* (when the path is non-empty). Later sources win.

    Raises:
        ValueError: if the JSON file does not contain an object.
    """
    cfg = default_gpt_config()

    def _overlay(overrides: Dict[str, Any]) -> None:
        # Only keys that are GPTConfig fields are copied onto cfg.
        for name, value in _extract_gpt_config_dict(overrides).items():
            setattr(cfg, name, value)

    embedded = ckpt.get("config")
    if isinstance(embedded, dict):
        _overlay(embedded)

    if not config_path:
        return cfg

    loaded = json.loads(Path(config_path).read_text())
    if not isinstance(loaded, dict):
        raise ValueError("--config must be a JSON object.")
    _overlay(loaded)
    return cfg
def _load_jsb_sequences(
    n_files: int, seq_len: int, seed: int
) -> List[List[int]]:
    """Load up to *n_files* tokenised chorales, each cut to *seq_len* tokens.

    The music21 chorale corpus (BWV numbering) is materialised, shuffled with
    a ``random.Random(seed)`` instance, and walked in that order. Each score
    is written to a MIDI file, re-read with ``pretty_midi`` and tokenised with
    ``encode``. Scores that yield fewer than *seq_len* tokens are skipped.

    Conversion is best-effort: a score that raises during write/parse/encode
    is skipped, and a single summary line is printed to stderr so that a
    systematic failure is visible instead of silently producing no data.

    Args:
        n_files: maximum number of sequences to return.
        seq_len: exact length of every returned sequence.
        seed: seed for the shuffle order.

    Returns:
        A list of at most *n_files* token-id lists of length *seq_len*.
    """
    from music21 import corpus

    rng = random.Random(seed)
    chorales = list(
        corpus.chorales.Iterator(
            numberingSystem="bwv",
            returnType="stream",
        )
    )
    rng.shuffle(chorales)

    seqs: List[List[int]] = []
    n_failed = 0
    first_error = ""
    # A path inside a temporary directory is used rather than an open
    # NamedTemporaryFile: reopening such a file by name while it is still
    # open is not portable (it fails on Windows with delete=True).
    with tempfile.TemporaryDirectory() as tmp_dir:
        midi_path = str(Path(tmp_dir) / "chorale.mid")
        for score in chorales:
            if len(seqs) >= n_files:
                break
            try:
                score.write("midi", fp=midi_path)
                pm = pretty_midi.PrettyMIDI(midi_path)
                ids = encode(pm)
            except Exception as exc:  # best-effort: skip scores that fail
                n_failed += 1
                if not first_error:
                    first_error = f"{type(exc).__name__}: {exc}"
                continue
            if len(ids) >= seq_len:
                seqs.append(ids[:seq_len])

    if n_failed:
        print(
            f"[probe_linear] skipped {n_failed} chorale(s) that failed to "
            f"convert; first error: {first_error}",
            file=sys.stderr,
        )
    return seqs
def _build_labels(ids: List[int]) -> Dict[str, np.ndarray]:
    """Derive one label per token position from the token names.

    Returns a dict of int64 arrays, each of length ``len(ids)``:

    * ``beat_position``: number of ``TS*`` tokens seen since the last
      ``BAR_START``, modulo 4.
    * ``pitch_class``: ``int(name[1:]) % 12`` for tokens whose name starts
      with ``P`` and has an integer suffix; ``-1`` everywhere else.
    * ``cadence_soon``: 1 when ``PHRASE_END`` occurs among the next 8 ids.
    * ``voice_bin``: most recent ``V<n>`` value (0 if the suffix is not an
      integer), clamped into ``[0, N_VOICE_BINS - 1]``.
    """
    length = len(ids)
    beat_position = np.zeros(length, dtype=np.int64)
    pitch_class = np.full(length, -1, dtype=np.int64)
    cadence_soon = np.zeros(length, dtype=np.int64)
    voice_bin = np.zeros(length, dtype=np.int64)

    lookahead = 8  # how many following ids are searched for PHRASE_END
    top_bin = N_VOICE_BINS - 1
    steps_in_bar = 0
    current_voice = 0

    for pos, token_id in enumerate(ids):
        name = ID2TOKEN.get(token_id, "")

        # Update the running bar/voice state before labelling this position.
        if name == "BAR_START":
            steps_in_bar = 0
        elif name.startswith("TS"):
            steps_in_bar += 1
        elif name.startswith("V"):
            try:
                current_voice = int(name[1:])
            except ValueError:
                current_voice = 0

        beat_position[pos] = steps_in_bar % 4
        voice_bin[pos] = max(0, min(top_bin, current_voice))

        if name.startswith("P"):
            try:
                pitch_class[pos] = int(name[1:]) % 12
            except ValueError:
                # Name starts with "P" but has no integer suffix.
                pitch_class[pos] = -1

        # Slicing clamps at the end of the list, so no explicit bound needed.
        upcoming = ids[pos + 1:pos + 1 + lookahead]
        if PHRASE_END in upcoming:
            cadence_soon[pos] = 1

    return {
        "beat_position": beat_position,
        "pitch_class": pitch_class,
        "cadence_soon": cadence_soon,
        "voice_bin": voice_bin,
    }
| def _collect_layer_activations( | |
| model: GPT, | |
| ids: List[int], | |
| device: torch.device, | |
| ) -> torch.Tensor: | |
| """Return tensor (L+1, T, D): embedding baseline + block outputs. | |
| Index 0 corresponds to layer -1 (raw token embeddings). | |
| Indices 1..L correspond to transformer block outputs for layers 0..L-1. | |
| """ | |
| activations: Dict[int, torch.Tensor] = {} | |
| hooks = [] | |
| def make_hook(layer_idx: int): | |
| def hook(_module, _inputs, output): | |
| # TransformerBlock forward returns (x, attn_weights). | |
| x = output[0] if isinstance(output, tuple) else output | |
| activations[layer_idx] = x.detach() | |
| return hook | |
| for i, block in enumerate(model.blocks): | |
| hooks.append(block.register_forward_hook(make_hook(i))) | |
| try: | |
| x = torch.tensor([ids], dtype=torch.long, device=device) | |
| emb = model.wte(x).detach()[0].cpu() | |
| _ = model(x) | |
| finally: | |
| for h in hooks: | |
| h.remove() | |
| layer_tensors = [emb] + [ | |
| activations[i][0].cpu() for i in range(model.config.n_layers) | |
| ] | |
| return torch.stack(layer_tensors, dim=0) | |
| def _probe_one_target( | |
| X_by_layer: List[np.ndarray], | |
| y: np.ndarray, | |
| seed: int, | |
| ) -> List[float]: | |
| unique_classes = np.unique(y) | |
| if unique_classes.size < 2: | |
| # Degenerate target on this sample split; all predictions are trivial. | |
| return [1.0 for _ in X_by_layer] | |
| idx = np.arange(len(y)) | |
| train_idx, val_idx = train_test_split( | |
| idx, | |
| test_size=0.2, | |
| random_state=seed, | |
| shuffle=True, | |
| ) | |
| y_train = y[train_idx] | |
| y_val = y[val_idx] | |
| scores: List[float] = [] | |
| for X in X_by_layer: | |
| X_train = X[train_idx] | |
| X_val = X[val_idx] | |
| train_classes = np.unique(y_train) | |
| if train_classes.size < 2: | |
| majority = train_classes[0] | |
| pred = np.full_like(y_val, fill_value=majority) | |
| scores.append(float(accuracy_score(y_val, pred))) | |
| continue | |
| clf = LogisticRegression( | |
| max_iter=1000, | |
| random_state=seed, | |
| solver="lbfgs", | |
| ) | |
| clf.fit(X_train, y_train) | |
| pred = clf.predict(X_val) | |
| scores.append(float(accuracy_score(y_val, pred))) | |
| return scores | |
def _plot_probe_lines(
    results: Dict[str, List[float]],
    out_path: Path,
) -> None:
    """Save a line chart of validation accuracy per layer, one line per target.

    The x axis has one tick per entry of the score lists; ticks are labelled
    -1, 0, 1, ... so that the first entry reads as layer -1. The number of
    ticks is taken from the first target in *results*. Parent directories of
    *out_path* are created if needed.
    """
    fig, ax = plt.subplots(figsize=(7.5, 4.5))

    first_series = next(iter(results.values()))
    positions = list(range(len(first_series)))
    tick_text = ["-1"] + [str(k) for k in range(len(first_series) - 1)]

    for target_name, accuracies in results.items():
        ax.plot(
            positions,
            accuracies,
            marker="o",
            linewidth=2,
            label=target_name,
        )

    ax.set_xlabel("layer")
    ax.set_ylabel("validation accuracy")
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_text)
    ax.set_title("Linear probe accuracy by layer")
    ax.grid(alpha=0.25)
    ax.legend()
    fig.tight_layout()

    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=160)
    plt.close(fig)
| def _write_summary_md( | |
| out_path: Path, | |
| results: Dict[str, List[float]], | |
| baselines: Dict[str, float], | |
| ) -> None: | |
| n_levels = len(next(iter(results.values()))) | |
| layer_labels = [-1] + list(range(n_levels - 1)) | |
| layer_cols = " | ".join([f"layer_{x}" for x in layer_labels]) | |
| align_cols = "|".join(["---:"] * len(layer_labels)) | |
| rows = [ | |
| "# Linear probe summary", | |
| "", | |
| "| target | random_baseline | " | |
| f"{layer_cols} | best_layer | best_acc |", | |
| f"|---|---:|{align_cols}|---:|---:|", | |
| ] | |
| for target, vals in results.items(): | |
| best_idx = int(np.argmax(vals)) | |
| best_acc = float(np.max(vals)) | |
| best_layer = layer_labels[best_idx] | |
| per_layer = " | ".join(f"{v:.3f}" for v in vals) | |
| rows.append( | |
| f"| {target} | {baselines[target]:.3f} | {per_layer} | " | |
| f"{best_layer} | {best_acc:.3f} |" | |
| ) | |
| out_path.parent.mkdir(parents=True, exist_ok=True) | |
| out_path.write_text("\n".join(rows)) | |
def main() -> None:
    """Command-line entry point: probe every layer and write plot + summary.

    Steps, in order: parse arguments and seed the RNGs; load the checkpoint
    and build the model in eval mode; load chorale token sequences (clipped
    to the model's block size); collect per-layer activations and heuristic
    labels for every sequence; fit one probe per target and layer; write the
    accuracy plot and Markdown summary; print a short report.

    For the ``pitch_class`` target only positions with a label >= 0 are used.
    The baseline figures reported next to the results are fixed constants
    defined here, not values measured from the data.

    Raises:
        RuntimeError: if no chorale sequence could be loaded.
    """
    parser = argparse.ArgumentParser(
        description="Linear probing on frozen GPT activations."
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=str(_ROOT / "results" / "checkpoints" / "best_model.pt"),
    )
    parser.add_argument("--config", type=str, default="")
    parser.add_argument("--n-chorales", type=int, default=20)
    parser.add_argument("--seq-len", type=int, default=256)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--plot-out",
        type=str,
        default=str(_ROOT / "figures" / "linear_probe_accuracy_by_layer.png"),
    )
    parser.add_argument(
        "--summary-out",
        type=str,
        default=str(_ROOT / "results" / "linear_probe_summary.md"),
    )
    args = parser.parse_args()

    # Seed every RNG the pipeline touches.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(args.seed)

    device = _pick_device()
    ckpt = torch.load(args.checkpoint, map_location=device, weights_only=True)
    cfg = _load_config_from_sources(ckpt, args.config)
    model = GPT(cfg).to(device)
    # The checkpoint is either a dict holding the weights under "model" or
    # the state dict itself.
    if isinstance(ckpt, dict) and "model" in ckpt:
        state = ckpt["model"]
    else:
        state = ckpt
    model.load_state_dict(state)
    model.eval()

    seq_len = min(args.seq_len, cfg.block_size)
    seqs = _load_jsb_sequences(
        args.n_chorales,
        seq_len=seq_len,
        seed=args.seed,
    )
    if not seqs:
        raise RuntimeError(
            "No JSB chorales loaded; check music21 install/dataset."
        )

    n_levels = cfg.n_layers + 1
    target_names = (
        "beat_position",
        "pitch_class",
        "cadence_soon",
        "voice_bin",
    )
    feature_chunks: List[List[np.ndarray]] = [[] for _ in range(n_levels)]
    label_chunks: Dict[str, List[np.ndarray]] = {
        name: [] for name in target_names
    }
    for ids in seqs:
        window = ids[:seq_len]
        # acts has one entry per level; entry 0 is the layer -1 embedding.
        acts = _collect_layer_activations(
            model,
            ids=window,
            device=device,
        )
        labels = _build_labels(window)
        for level, chunks in enumerate(feature_chunks):
            chunks.append(acts[level].numpy())
        for name, chunks in label_chunks.items():
            chunks.append(labels[name])

    X_by_layer = [np.concatenate(chunks, axis=0) for chunks in feature_chunks]
    y_targets = {
        name: np.concatenate(chunks, axis=0)
        for name, chunks in label_chunks.items()
    }

    results: Dict[str, List[float]] = {}
    for target, y in y_targets.items():
        if target != "pitch_class":
            X_use, y_use = X_by_layer, y
        else:
            # Positions without a pitch label (-1) are excluded.
            keep = y >= 0
            y_use = y[keep]
            X_use = [X[keep] for X in X_by_layer]
        results[target] = _probe_one_target(X_use, y_use, seed=args.seed)

    baselines = {
        "beat_position": 0.25,
        "pitch_class": 1.0 / 12.0,
        "cadence_soon": 0.15,
        "voice_bin": 1.0 / N_VOICE_BINS,
    }

    _plot_probe_lines(results, out_path=Path(args.plot_out))
    _write_summary_md(
        Path(args.summary_out),
        results=results,
        baselines=baselines,
    )

    print(f"[probe_linear] chorales_used={len(seqs)} seq_len={seq_len}")
    print("[probe_linear] layer index convention: -1=embedding, 0..N-1=blocks")
    print(f"[probe_linear] plot -> {args.plot_out}")
    print(f"[probe_linear] summary -> {args.summary_out}")
    for target, vals in results.items():
        best = max(vals)
        print(
            f"[probe_linear] {target}: "
            f"layers={['%.3f' % v for v in vals]} best={best:.3f} "
            f"(baseline≈{baselines[target]:.3f})"
        )
| if __name__ == "__main__": | |
| main() | |