code-review-env-v3 / scripts /generate_calibration_plot.py
Kinchi
v2: calibrated metacognition as RL + inference-time budget + transfer eval
51fd6a7
Raw
History Blame Contribute Delete
10.4 kB
#!/usr/bin/env python3
"""
scripts/generate_calibration_plot.py
=====================================
Produces grpo_output/calibration_plot.png — the metacognitive-calibration
hero figure for the paper / blog / video.
The plot has THREE panels:
┌──────────────────┬──────────────────┬──────────────────┐
│ A. Confusion     │ B. Calibration   │ C. Allocation    │
│    matrix:       │    error:        │    by ground-    │
│ predicted band   │ histogram of     │    truth label   │
│ vs actual band   │ |actual − pred   │                  │
│                  │  midpoint|       │                  │
└──────────────────┴──────────────────┴──────────────────┘
Modes
-----
--mode heuristic (default) Build the plot from the heuristic-proxy
policy. This is the figure we ship pre-training to
show what calibration *should* look like.
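--mode real Read calibration data from
grpo_output/eval_calibration.json, or the path given by --data (produced by
eval_baseline.py when run on a model trained with
metacognitive_reward). If that file does not exist, a warning is printed
and the heuristic proxy is plotted instead.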
Output
------
grpo_output/calibration_plot.png
"""
from __future__ import annotations
import argparse
import json
import random
import sys
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
# Repository root: this file lives in <root>/scripts/.
ROOT = Path(__file__).resolve().parent.parent
DATA = ROOT / "data" / "cve_training_data.json"
OUT_DIR = ROOT / "grpo_output"
DEFAULT_OUT = OUT_DIR / "calibration_plot.png"
DEFAULT_REAL = OUT_DIR / "eval_calibration.json"
# Budget bands in display order; the index of a band is its row/column in the
# confusion matrix and its bar position in the allocation panel.
BANDS = ["short", "medium", "long"]
# (lo, hi) <think>-length range, in characters, for each band. Heuristic mode
# samples lengths inside these ranges.
RNG_BANDS = {"short": (0, 80), "medium": (80, 250), "long": (250, 800)}
# Band midpoints used as the reference point for Panel B's calibration error.
# Derived from RNG_BANDS so the two tables cannot disagree (a hand-written
# "long" value of 400 was not the midpoint of (250, 800)).
MID = {band: (lo + hi) // 2 for band, (lo, hi) in RNG_BANDS.items()}
# ── Heuristic data generator ──────────────────────────────────────────────
def _risk(f: dict, cvss: float) -> float:
feat = f.get("features", [0, 0, 0, 0])
churn, complexity, _, _ = feat
s = 0.4 * (churn / 100.0) + 0.4 * (complexity / 100.0) + 0.2 * (cvss / 10.0)
if f.get("is_test_file"):
s *= 0.4
return s
def _band_for_risk(normalized: float, label: int) -> str:
"""The ORACLE choice β€” what the trained policy *should* predict."""
if label == 1:
return "long" if normalized > 0.4 else "medium"
return "short" if normalized < 0.5 else "medium"
def _load_buggy_episodes(path: Path, n_episodes: int) -> List[List[dict]]:
    """Return up to ``n_episodes`` file groups that contain a buggy file.

    Rows of the JSON list at *path* are grouped by ``(cveId, repo)`` in
    first-seen order; a group is kept only if some row has ``label == 1``.
    A non-positive ``n_episodes`` yields an empty list.

    Raises:
        FileNotFoundError: if *path* does not exist.
        KeyError: if a row lacks ``cveId``, ``repo`` or ``label``.
    """
    # JSON text is UTF-8 by specification; don't depend on the locale.
    with open(path, encoding="utf-8") as fh:
        rows = json.load(fh)
    groups: Dict[tuple, List[dict]] = defaultdict(list)
    for r in rows:
        groups[(r["cveId"], r["repo"])].append(r)
    episodes: List[List[dict]] = []
    for files in groups.values():
        # Check the cap before appending so n_episodes <= 0 returns nothing.
        if len(episodes) >= n_episodes:
            break
        if any(f["label"] == 1 for f in files):
            episodes.append(files)
    return episodes


def heuristic_calibration_data(
    n_episodes: int = 30, rng: random.Random | None = None,
) -> Dict[str, List]:
    """Simulate a calibrated metacognitive policy on the training data.

    For every file of up to ``n_episodes`` episodes (CVE/repo groups from
    ``DATA`` that contain at least one buggy file), the policy emits the
    oracle band from ``_band_for_risk`` and then "thinks" for a sampled
    length:

    * with probability 0.85 the length is drawn uniformly inside the
      predicted band's ``RNG_BANDS`` range, 5 characters away from each edge;
    * otherwise it is drawn uniformly from [20, 600] regardless of the band,
      which may or may not fall outside the predicted band.

    Args:
        n_episodes: maximum number of episodes to simulate.
        rng: source of randomness; defaults to ``random.Random(7)`` so the
            figure is reproducible.

    Returns:
        ``{"pred": [...], "actual_len": [...], "label": [...]}`` — three
        parallel lists with one entry per simulated file decision.

    Raises:
        FileNotFoundError: if ``DATA`` does not exist.
    """
    if rng is None:
        rng = random.Random(7)
    pred: List[str] = []
    actual_len: List[int] = []
    label: List[int] = []
    for files in _load_buggy_episodes(DATA, n_episodes):
        cvss = files[0].get("cvss", 0.0)
        risks = [_risk(f, cvss) for f in files]
        # Every group holds at least one file, so min/max are well defined.
        rmin, rmax = min(risks), max(risks)
        for f, r in zip(files, risks):
            # Min-max normalise risk within the episode; a flat episode maps
            # every file to 0.
            normalized = (r - rmin) / max(1e-6, rmax - rmin) if rmax > rmin else 0.0
            band = _band_for_risk(normalized, f["label"])
            lo, hi = RNG_BANDS[band]
            if rng.random() < 0.85:
                # Calibrated (85%): land inside the predicted band.
                length = rng.randint(lo + 5, max(lo + 6, hi - 5))
            else:
                # Miscalibrated (15%): uniform over [20, 600], ignoring the
                # predicted band (so it can still land inside it).
                length = rng.randint(20, 600)
            pred.append(band)
            actual_len.append(length)
            label.append(f["label"])
    return {"pred": pred, "actual_len": actual_len, "label": label}
# ── Real-mode loader ──────────────────────────────────────────────────────
def real_calibration_data(path: Path) -> Dict[str, List]:
    """Load calibration traces produced by eval_baseline.py.

    The file must hold a JSON object with three equal-length lists — the
    same layout ``heuristic_calibration_data`` returns:

    * ``pred``: predicted band name per decision;
    * ``actual_len``: observed <think> length per decision;
    * ``label``: ground-truth label per decision.

    Args:
        path: JSON file to read.

    Returns:
        The decoded object, unchanged.

    Raises:
        FileNotFoundError: if *path* does not exist.
        ValueError: if the top level is not an object, a required key is
            missing, or the lists differ in length. Validating here gives a
            clear message instead of a bare KeyError inside ``plot``, and
            prevents ``plot``'s ``zip`` calls from silently truncating
            mismatched lists (which would skew every reported statistic).
    """
    with open(path, encoding="utf-8") as fh:
        data = json.load(fh)
    if not isinstance(data, dict):
        raise ValueError(f"{path}: expected a JSON object, got {type(data).__name__}")
    required = ("pred", "actual_len", "label")
    missing = [k for k in required if k not in data]
    if missing:
        raise ValueError(f"{path}: missing calibration key(s) {missing}")
    lengths = {k: len(data[k]) for k in required}
    if len(set(lengths.values())) > 1:
        raise ValueError(f"{path}: calibration lists differ in length {lengths}")
    return data
# ── Plotting ──────────────────────────────────────────────────────────────
def _band_for_length(L: int) -> str:
if L < 80:
return "short"
if L < 250:
return "medium"
return "long"
def _draw_confusion_panel(fig, ax, pred: List[str], actual_band: List[str]) -> float:
    """Panel A: predicted-vs-actual band confusion matrix.

    Rows are predicted bands and columns actual bands, both in ``BANDS`` order.
    Each row is normalised to sum to 1; rows with no decisions stay all-zero.
    Returns the diagonal accuracy: the fraction of all decisions whose actual
    band equals the predicted one.
    """
    n = len(BANDS)
    cm = np.zeros((n, n), dtype=float)
    for p, a in zip(pred, actual_band):
        cm[BANDS.index(p), BANDS.index(a)] += 1
    # clip(min=1) avoids 0/0 for bands that were never predicted.
    cm_norm = cm / cm.sum(axis=1, keepdims=True).clip(min=1)
    im = ax.imshow(cm_norm, cmap="Blues", vmin=0, vmax=1)
    ax.set_xticks(range(n))
    ax.set_yticks(range(n))
    ax.set_xticklabels(BANDS)
    ax.set_yticklabels(BANDS)
    ax.set_xlabel("Actual <think> band")
    ax.set_ylabel("Predicted band")
    ax.set_title("A. Calibration confusion\n(diag = perfect calibration)")
    for i in range(n):
        for j in range(n):
            # White text on dark cells, black on light ones, for legibility.
            ax.text(j, i, f"{cm_norm[i, j]:.2f}", ha="center", va="center",
                    color="white" if cm_norm[i, j] > 0.5 else "black",
                    fontsize=10)
    fig.colorbar(im, ax=ax, fraction=0.045, pad=0.04)
    diag = float(np.trace(cm) / max(1, cm.sum()))
    ax.text(0.5, -0.18, f"diag accuracy = {diag:.2f}",
            transform=ax.transAxes,
            ha="center", fontsize=11, fontweight="bold",
            color="#2c3e50")
    return diag


def _draw_error_panel(ax, pred: List[str], actual: List[int]) -> float:
    """Panel B: histogram of |actual length − predicted-band midpoint|.

    Midpoints come from ``MID``. Returns the median absolute error.
    """
    errs = [abs(n_chars - MID[p]) for p, n_chars in zip(pred, actual)]
    median_err = float(np.median(errs))
    ax.hist(errs, bins=30, color="#7faecf", edgecolor="white", alpha=0.85)
    ax.axvline(median_err, color="#a23a30", ls="--", lw=2.0,
               label=f"median error = {median_err:.0f} chars")
    ax.set_xlabel("|actual length − predicted-band midpoint| (characters)")
    ax.set_ylabel("Number of decisions")
    ax.set_title("B. Calibration error distribution")
    ax.legend(fontsize=10, loc="upper right")
    ax.grid(True, alpha=0.25)
    return median_err


def _draw_allocation_panel(ax, pred: List[str], label: List[int]) -> Tuple[float, float]:
    """Panel C: predicted-band counts split by ground-truth label.

    Decisions with ``label == 1`` count as buggy and ``label == 0`` as safe;
    any other label value is ignored. Returns
    ``(P(long | buggy), P(long | safe))``, each 0.0 when its group is empty.
    """
    n = len(BANDS)
    counts_buggy = [0] * n
    counts_safe = [0] * n
    for p, lbl in zip(pred, label):
        if lbl == 1:
            counts_buggy[BANDS.index(p)] += 1
        elif lbl == 0:
            counts_safe[BANDS.index(p)] += 1
    x = np.arange(n)
    width = 0.38
    ax.bar(x - width / 2, counts_safe, width, color="#7faecf", label="safe files",
           edgecolor="white")
    ax.bar(x + width / 2, counts_buggy, width, color="#d6584d", label="buggy files",
           edgecolor="white")
    ax.set_xticks(x)
    ax.set_xticklabels(BANDS)
    ax.set_xlabel("Predicted budget band")
    ax.set_ylabel("Number of decisions")
    ax.set_title("C. Difficulty awareness — who gets 'long'?")
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.25, axis="y")
    long_idx = BANDS.index("long")
    long_on_bug = counts_buggy[long_idx] / max(1, sum(counts_buggy))
    long_on_safe = counts_safe[long_idx] / max(1, sum(counts_safe))
    ax.text(0.5, -0.18,
            f"P(long | buggy) = {long_on_bug:.2f} "
            f"P(long | safe) = {long_on_safe:.2f}",
            transform=ax.transAxes,
            ha="center", fontsize=11, fontweight="bold",
            color="#2c3e50")
    return long_on_bug, long_on_safe


def plot(data: Dict[str, List], out_path: Path, title_suffix: str) -> None:
    """Render the three-panel calibration figure and save it as a PNG.

    Args:
        data: parallel lists under ``"pred"`` (predicted band names, each
            one of ``BANDS``), ``"actual_len"`` (actual <think> lengths in
            characters) and ``"label"`` (ground-truth labels, 1 = buggy,
            0 = safe).
        out_path: destination file; missing parent directories are created.
        title_suffix: text appended to the figure's super-title.

    Side effects:
        Writes *out_path* and prints a one-line summary of the headline
        numbers to stdout.

    Raises:
        KeyError: if *data* lacks one of the three keys, or a predicted band
            has no entry in ``MID``.
        ValueError: if a predicted band is not in ``BANDS``.
    """
    pred = data["pred"]
    actual = data["actual_len"]
    label = data["label"]
    actual_band = [_band_for_length(n_chars) for n_chars in actual]
    fig, axes = plt.subplots(1, 3, figsize=(17, 5))
    fig.suptitle(
        "Metacognitive Calibration — does the agent know how hard the problem is?"
        f" {title_suffix}",
        fontsize=14, fontweight="bold", y=1.02,
    )
    diag = _draw_confusion_panel(fig, axes[0], pred, actual_band)
    median_err = _draw_error_panel(axes[1], pred, actual)
    long_on_bug, long_on_safe = _draw_allocation_panel(axes[2], pred, label)
    fig.tight_layout()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=100, bbox_inches="tight")
    plt.close(fig)
    print(f"✅ Wrote {out_path}")
    print(f" diag={diag:.2f} median_err={median_err:.0f} "
          f"P(long|buggy)={long_on_bug:.2f} P(long|safe)={long_on_safe:.2f}")
# ── Main ──────────────────────────────────────────────────────────────────
def main() -> None:
    """Command-line entry point: build and save the calibration figure.

    ``--mode real`` plots traces loaded from ``--data``. If that file does
    not exist, a warning goes to stderr and the heuristic proxy is plotted
    instead, so a figure is always produced. ``--mode heuristic`` (the
    default) simulates ``--episodes`` episodes with an RNG seeded by
    ``--seed``.
    """
    ap = argparse.ArgumentParser(
        description="Plot metacognitive calibration: predicted vs. actual "
                    "<think> budget bands.",
    )
    ap.add_argument("--mode", choices=["heuristic", "real"], default="heuristic",
                    help="data source for the figure (default: %(default)s)")
    ap.add_argument("--data", default=str(DEFAULT_REAL),
                    help="calibration JSON used in real mode (default: %(default)s)")
    ap.add_argument("--out", default=str(DEFAULT_OUT),
                    help="output PNG path (default: %(default)s)")
    ap.add_argument("--seed", type=int, default=7,
                    help="RNG seed for heuristic mode (default: %(default)s)")
    ap.add_argument("--episodes", type=int, default=40,
                    help="number of episodes to simulate in heuristic mode "
                         "(default: %(default)s)")
    args = ap.parse_args()
    if args.episodes < 1:
        ap.error("--episodes must be at least 1")
    out_path = Path(args.out)

    if args.mode == "real":
        path = Path(args.data)
        if path.exists():
            plot(real_calibration_data(path), out_path,
                 title_suffix="(real trained-model calibration)")
            return
        print(f"❌ {path} not found, falling back to heuristic.", file=sys.stderr)

    rng = random.Random(args.seed)
    data = heuristic_calibration_data(n_episodes=args.episodes, rng=rng)
    plot(data, out_path,
         title_suffix="(heuristic proxy — replace with real traces post-training)")


if __name__ == "__main__":
    main()