# HONEST-RL-Calibrator / eval/plot_reliability.py
# Author: Rushhaabhhh — HONEST-RL-Calibrator-v0 (commit 3040767, verified)
"""
eval/plot_reliability.py — Publication-quality reliability / calibration diagrams.
Usage:
    python eval/plot_reliability.py                         # uses eval/baseline_results.json
    python eval/plot_reliability.py --results path/to.json  # custom results file
    python eval/plot_reliability.py --prefix after_rl       # changes output file prefix
    python eval/plot_reliability.py --out-dir some/dir      # changes output directory (default: eval/plots)
Comparison helper (importable):
    from eval.plot_reliability import plot_comparison
    plot_comparison("eval/baseline_results.json", "eval/after_rl_results.json")
"""
import argparse
import json
import os
from pathlib import Path
from typing import Optional
import matplotlib
matplotlib.use("Agg") # non-interactive backend — works without a display
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import seaborn as sns
# ---------------------------------------------------------------------------
# Styling — global seaborn theme, applied once at import time
# ---------------------------------------------------------------------------
sns.set_theme(
    context="paper",
    style="whitegrid",
    font_scale=1.2,
)
# Colour palette shared by every diagram in this module.
PALETTE = {
    "bar": "#4C72B0",      # Seaborn blue — observed-accuracy bars
    "perfect": "#DD8452",  # Warm orange — perfect-calibration diagonal
    "gap_pos": "#55A868",  # Green — underconfident gap (accuracy >= bin centre)
    "gap_neg": "#C44E52",  # Red — overconfident gap (accuracy < bin centre)
    "bg": "#F8F9FA",       # Light grey axes background
}
# Reliability-diagram binning: 10 equal-width bins over [0, 1].
BIN_EDGES = np.linspace(0.0, 1.0, 11)  # [0,.1), [.1,.2), …, [.9,1] (last bin closed)
BIN_CENTRES = (BIN_EDGES[:-1] + BIN_EDGES[1:]) / 2
BIN_WIDTH = 0.09  # drawn bar width; slightly narrower than the 0.1 spacing for readability
# ---------------------------------------------------------------------------
# Core: build calibration bins from a flat list of (confidence, correct) pairs
# ---------------------------------------------------------------------------
def build_bins(confidences: list, correctness: list) -> dict:
"""Return per-bin stats used by the reliability diagram."""
conf = np.array(confidences, dtype=float)
corr = np.array(correctness, dtype=float)
bin_acc, bin_conf, bin_count = [], [], []
for lo, hi in zip(BIN_EDGES[:-1], BIN_EDGES[1:]):
mask = (conf >= lo) & (conf < hi)
# last bin is inclusive on the right
if hi == 1.0:
mask = (conf >= lo) & (conf <= hi)
n = mask.sum()
bin_count.append(int(n))
bin_acc.append(float(corr[mask].mean()) if n > 0 else np.nan)
bin_conf.append(float(conf[mask].mean()) if n > 0 else np.nan)
return {
"bin_acc": np.array(bin_acc),
"bin_conf": np.array(bin_conf),
"bin_count": np.array(bin_count),
}
def compute_ece_from_bins(bins: dict) -> float:
    """Expected Calibration Error from the output of ``build_bins``.

    ECE = sum over non-empty bins of (n_b / N) * |acc_b - conf_b|.
    Returns NaN when no samples were binned at all.
    """
    counts, accs, confs = bins["bin_count"], bins["bin_acc"], bins["bin_conf"]
    n_total = counts.sum()
    if n_total == 0:
        return float("nan")

    weighted_gap = 0.0
    for n_bin, acc_b, conf_b in zip(counts, accs, confs):
        # Empty bins carry NaN accuracy; they contribute nothing.
        if not n_bin > 0 or np.isnan(acc_b):
            continue
        weighted_gap += (n_bin / n_total) * abs(acc_b - conf_b)
    return weighted_gap
# ---------------------------------------------------------------------------
# Single reliability diagram
# ---------------------------------------------------------------------------
def _draw_reliability(
    ax: plt.Axes,
    bins: dict,
    ece: float,
    title: str,
    show_counts: bool = True,
):
    """Draw a reliability diagram onto ``ax``.

    Bars are placed at the fixed module-level ``BIN_CENTRES``, so ``bins``
    should come from ``build_bins`` with its default edges (one entry per
    bin). For each non-empty bin a translucent gap is shaded between the bin
    centre (the diagonal's height at that x) and the observed accuracy: red
    where accuracy < centre (overconfident), green otherwise (underconfident).

    Args:
        ax: Target matplotlib axes.
        bins: Dict with ``bin_acc`` (NaN for empty bins) and ``bin_count``.
        ece: Value shown in the corner badge; NaN is rendered as "n/a".
        title: Axes title.
        show_counts: If True, write ``n=<count>`` inside every accuracy bar
            taller than 0.05.
    """
    acc = bins["bin_acc"]
    count = bins["bin_count"]

    # ── background & reference diagonal ─────────────────────────────────────
    ax.set_facecolor(PALETTE["bg"])
    ax.plot([0, 1], [0, 1], "--", color=PALETTE["perfect"],
            linewidth=1.8, label="Perfect calibration", zorder=3)

    # ── gap fill (miscalibration region) ────────────────────────────────────
    # zorder 2.5: above the accuracy bars (zorder 2), below the diagonal
    # (zorder 3). An underconfident gap (acc > centre) lies entirely inside
    # its accuracy bar, so drawing it underneath would hide it behind the
    # near-opaque (alpha 0.85) bar.
    for centre, a, n in zip(BIN_CENTRES, acc, count):
        if np.isnan(a) or n == 0:
            continue
        lo, hi = min(centre, a), max(centre, a)
        color = PALETTE["gap_neg"] if a < centre else PALETTE["gap_pos"]
        ax.bar(centre, hi - lo, bottom=lo, width=BIN_WIDTH * 0.98,
               color=color, alpha=0.25, zorder=2.5)

    # ── accuracy bars ────────────────────────────────────────────────────────
    valid = ~np.isnan(acc)
    bars = ax.bar(
        BIN_CENTRES[valid], acc[valid],
        width=BIN_WIDTH, color=PALETTE["bar"],
        edgecolor="white", linewidth=0.6,
        label="Observed accuracy", alpha=0.85, zorder=2,
    )

    # ── count annotations on bars ────────────────────────────────────────────
    if show_counts:
        for bar, n in zip(bars, count[valid]):
            h = bar.get_height()
            if h > 0.05:  # too-short bars have no room for a label
                ax.text(bar.get_x() + bar.get_width() / 2,
                        h / 2, f"n={n}",
                        ha="center", va="center",
                        fontsize=7.5, color="white", fontweight="bold", zorder=5)

    # ── ECE badge ────────────────────────────────────────────────────────────
    ece_text = f"ECE = {ece:.4f}" if not np.isnan(ece) else "ECE = n/a"
    ax.text(0.97, 0.04, ece_text,
            transform=ax.transAxes,
            ha="right", va="bottom",
            fontsize=11, fontweight="bold",
            bbox=dict(boxstyle="round,pad=0.35", fc="white", ec="#CCCCCC", alpha=0.9))

    # ── axes formatting ───────────────────────────────────────────────────────
    ax.set_xlim(-0.02, 1.02)
    ax.set_ylim(-0.02, 1.12)
    ax.set_xlabel("Confidence (predicted probability)", fontsize=11)
    ax.set_ylabel("Accuracy (fraction correct)", fontsize=11)
    ax.set_title(title, fontsize=13, fontweight="bold", pad=10)
    ax.set_xticks(BIN_EDGES)
    ax.set_xticklabels([f"{e:.1f}" for e in BIN_EDGES], fontsize=8, rotation=45)
    ax.legend(loc="upper left", fontsize=9, framealpha=0.9)
# ---------------------------------------------------------------------------
# High-level: single-results plotting
# ---------------------------------------------------------------------------
def extract_pairs(samples: list) -> tuple:
    """Split sample dicts into parallel confidence / correctness lists.

    Samples lacking ``confidence`` or ``correct`` (or holding ``None`` there)
    are skipped. Confidences are cast with ``float``; correctness becomes
    ``1`` when truthy and ``0`` otherwise.

    Returns:
        ``(confidences, correctness)`` — two lists of equal length.
    """
    pairs = [
        (float(sample["confidence"]), 1 if sample["correct"] else 0)
        for sample in samples
        if sample.get("confidence") is not None
        and sample.get("correct") is not None
    ]
    confidences = [conf for conf, _ in pairs]
    correctness = [flag for _, flag in pairs]
    return confidences, correctness
def plot_domain(
    domain: str,
    conditions: dict,
    out_dir: Path,
    prefix: str = "baseline",
):
    """Save one row of reliability diagrams for a single domain.

    Every condition whose key starts with ``f"{domain}_"`` gets its own panel,
    followed by a final panel that pools all of them. The key remainder after
    ``f"{domain}_"`` labels the panel: a decimal suffix is a difficulty level
    (``math_3`` -> "Difficulty 3", sorted numerically); any other suffix
    (e.g. ``ood_gsm8k`` under domain ``ood``) is shown verbatim and sorted
    after the numeric ones.

    Args:
        domain: Key prefix selecting the conditions to plot.
        conditions: ``{condition_key: condition_dict}``; each condition's
            ``samples`` list is passed to ``extract_pairs``.
        out_dir: Output directory (created if missing); a ``str`` also works.
        prefix: Filename prefix, also used as the figure-title stage label
            (``"baseline"`` -> "Baseline Calibration - ...").

    Writes ``{out_dir}/{prefix}_{domain}.png``. Prints a notice and returns
    without plotting when no condition matches ``domain``.
    """
    domain_conditions = {
        k: v for k, v in conditions.items() if k.startswith(domain + "_")
    }
    if not domain_conditions:
        print(f" ! No conditions found for domain '{domain}' — skipping.")
        return

    # Collect per-condition panels and the pooled aggregate for the domain.
    agg_conf, agg_corr = [], []
    panels = []  # (sort_key, panel_title, confidences, correctness)
    for key, cond in domain_conditions.items():
        suffix = key[len(domain) + 1:]
        cf, cr = extract_pairs(cond.get("samples", []))
        if suffix.isdecimal():
            # Difficulty level: numeric panels come first, in numeric order.
            sort_key = (0, int(suffix), "")
            label = f"Difficulty {int(suffix)}"
        else:
            # Non-numeric suffix (e.g. an OOD dataset name) — int() would raise.
            sort_key = (1, 0, suffix)
            label = suffix
        panels.append((sort_key, f"{domain.capitalize()} — {label}", cf, cr))
        agg_conf.extend(cf)
        agg_corr.extend(cr)
    panels.sort(key=lambda panel: panel[0])

    # Layout: one panel per condition + 1 aggregate panel. There are always
    # >= 2 columns here, so ``axes`` is an array.
    ncols = len(panels) + 1
    fig, axes = plt.subplots(1, ncols, figsize=(4 * ncols, 4.5))
    fig.patch.set_facecolor("white")

    for ax, (_, panel_title, cf, cr) in zip(axes, panels):
        bins = build_bins(cf, cr)
        _draw_reliability(ax, bins, compute_ece_from_bins(bins), title=panel_title)

    # aggregate panel (last)
    agg_bins = build_bins(agg_conf, agg_corr)
    _draw_reliability(
        axes[-1], agg_bins, compute_ece_from_bins(agg_bins),
        title=f"{domain.capitalize()} — All Difficulties",
    )

    # The stage label follows the prefix instead of always claiming "Baseline".
    stage = prefix.replace("_", " ").capitalize()
    fig.suptitle(
        f"{stage} Calibration - {domain.capitalize()}",
        fontsize=15, fontweight="bold", y=1.02,
    )
    fig.tight_layout()
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"{prefix}_{domain}.png"
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" ✓ Saved {out_path}")
def plot_overall(
    conditions: dict,
    overall: dict,
    out_dir: Path,
    prefix: str = "baseline",
):
    """Save the cross-domain reliability diagram plus a per-domain summary.

    Left panel: reliability diagram pooled over every condition in
    ``conditions``. Each condition is counted exactly once.
    Right panel: accuracy and ECE per domain. A condition's domain is its key
    minus the trailing ``_<suffix>`` (``math_3`` -> ``math``), which matches
    how ``main`` groups conditions. ``math``/``code``/``logic`` are listed
    first, then any other domain (e.g. OOD) in sorted order.

    Args:
        conditions: ``{condition_key: condition_dict}`` with ``samples`` lists.
        overall: Accepted for interface compatibility; not used.
        out_dir: Output directory (created if missing); a ``str`` also works.
        prefix: Filename prefix, also used as the figure-title stage label.

    Writes ``{out_dir}/{prefix}_overall.png``.
    """
    all_conf, all_corr = [], []
    per_domain: dict = {}  # domain -> (confidences, correctness)
    for key, cond in conditions.items():
        cf, cr = extract_pairs(cond.get("samples", []))
        all_conf.extend(cf)
        all_corr.extend(cr)
        d_conf, d_corr = per_domain.setdefault(key.rsplit("_", 1)[0], ([], []))
        d_conf.extend(cf)
        d_corr.extend(cr)

    preferred = ["math", "code", "logic"]
    ordered = [d for d in preferred if d in per_domain]
    ordered += sorted(d for d in per_domain if d not in preferred)

    domain_eces = {}
    domain_accs = {}
    for domain in ordered:
        d_conf, d_corr = per_domain[domain]
        if not d_conf:
            continue  # domain present but without usable samples
        bins = build_bins(d_conf, d_corr)
        domain_eces[domain] = compute_ece_from_bins(bins)
        domain_accs[domain] = float(np.mean(d_corr))

    fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(14, 5))
    fig.patch.set_facecolor("white")

    # ── left: overall reliability diagram ───────────────────────────────────
    overall_bins = build_bins(all_conf, all_corr)
    overall_ece = compute_ece_from_bins(overall_bins)
    _draw_reliability(ax_left, overall_bins, overall_ece,
                      title="Overall Reliability Diagram")

    # ── right: per-domain ECE + accuracy summary bar chart ──────────────────
    domains = list(domain_eces)
    x = np.arange(len(domains))
    width = 0.35
    ece_vals = [domain_eces[d] for d in domains]
    acc_vals = [domain_accs[d] for d in domains]
    bars1 = ax_right.bar(x - width / 2, acc_vals, width,
                         label="Accuracy", color="#4C72B0", alpha=0.85, edgecolor="white")
    bars2 = ax_right.bar(x + width / 2, ece_vals, width,
                         label="ECE (lower=better)", color="#C44E52", alpha=0.85, edgecolor="white")
    for bar, val in zip(list(bars1) + list(bars2), acc_vals + ece_vals):
        ax_right.text(bar.get_x() + bar.get_width() / 2,
                      bar.get_height() + 0.01, f"{val:.2f}",
                      ha="center", fontsize=9, fontweight="bold")
    ax_right.set_xticks(x)
    ax_right.set_xticklabels([d.capitalize() for d in domains], fontsize=11)
    ax_right.set_ylim(0, 1.15)
    ax_right.set_ylabel("Score", fontsize=11)
    ax_right.set_title("Per-Domain Summary: Accuracy vs ECE", fontsize=13,
                       fontweight="bold")
    ax_right.legend(fontsize=10)
    ax_right.set_facecolor(PALETTE["bg"])

    # The stage label follows the prefix instead of always claiming "Baseline".
    stage = prefix.replace("_", " ").capitalize()
    fig.suptitle(f"{stage} Calibration - Overall", fontsize=15,
                 fontweight="bold", y=1.02)
    fig.tight_layout()
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)  # may be called standalone
    out_path = out_dir / f"{prefix}_overall.png"
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" ✓ Saved {out_path}")
# ---------------------------------------------------------------------------
# Comparison: before vs after RL/SFT
# ---------------------------------------------------------------------------
def _collect_conditions_any_schema(results: dict) -> dict:
    """Flatten a results JSON into ``{condition_key: condition_dict}``.

    Two layouts are understood:
      * ``baseline_eval.py`` -> ``{"conditions": {"math_1": {...}, ...}}``
      * ``full_eval.py``     -> ``{"in_distribution": {...}, "ood": {...}}``
        (each value is itself a ``{condition_key: cond}`` map)

    In-distribution keys (``math_1`` / ``code_3`` / ...) are kept verbatim.
    OOD keys gain an ``ood_`` prefix so they cannot clash with the
    math/code/logic namespace. Sections that are missing or not dicts are
    ignored; later sections overwrite earlier ones on key collisions.
    """
    merged: dict = {}
    for section in ("conditions", "in_distribution"):
        section_map = results.get(section)
        if isinstance(section_map, dict):
            merged.update(section_map)
    ood_map = results.get("ood")
    if isinstance(ood_map, dict):
        merged.update({f"ood_{name}": cond for name, cond in ood_map.items()})
    return merged
def plot_comparison(
    before_path: str,
    after_path: str,
    out_dir: Optional[str] = None,
    output_path: Optional[str] = None,
    label_before: str = "Before Training",
    label_after: str = "After Training",
):
    """
    Generate side-by-side overall reliability diagrams from two result JSON files.
    Typically called after RL training to visualise improvement.

    Robust to both ``baseline_results.json`` (top-level ``conditions``) and
    ``full_results.json`` (top-level ``in_distribution`` + ``ood``) schemas —
    samples from any/all sections are aggregated for the diagram. The figure
    title reports ``ECE: before → after`` and the direction/size of the change.

    Args:
        before_path: Baseline JSON.
        after_path: Trained JSON.
        out_dir: Optional directory; ignored if ``output_path`` is given.
            Defaults to ``<dir of before_path>/plots``.
        output_path: Explicit PNG path. Takes precedence over ``out_dir``.
        label_before/label_after: Panel titles.

    Returns:
        str: Path of the saved PNG.
    """
    with open(before_path, encoding="utf-8") as f:
        before = json.load(f)
    with open(after_path, encoding="utf-8") as f:
        after = json.load(f)

    def _collect(results):
        """Pool every condition's samples into (confidences, correctness)."""
        confs, corrs = [], []
        for cond in _collect_conditions_any_schema(results).values():
            cf, cr = extract_pairs(cond.get("samples", []))
            confs.extend(cf)
            corrs.extend(cr)
        return confs, corrs

    b_bins = build_bins(*_collect(before))
    a_bins = build_bins(*_collect(after))
    b_ece = compute_ece_from_bins(b_bins)
    a_ece = compute_ece_from_bins(a_bins)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5))
    fig.patch.set_facecolor("white")
    _draw_reliability(ax1, b_bins, b_ece, title=label_before, show_counts=False)
    _draw_reliability(ax2, a_bins, a_ece, title=label_after, show_counts=False)

    # Positive delta means ECE went down (better calibration).
    delta = b_ece - a_ece
    if np.isnan(delta):
        change = "n/a"
    elif delta > 0:
        change = f"↓ {delta:.4f}"
    elif delta < 0:
        change = f"↑ {-delta:.4f}"
    else:
        change = f"= {0.0:.4f}"
    fig.suptitle(
        f"Calibration Comparison | ECE: {b_ece:.4f} → {a_ece:.4f} ({change})",
        fontsize=13, fontweight="bold", y=1.02,
    )
    fig.tight_layout()

    if output_path is not None:
        out_path = Path(output_path)
        out_path.parent.mkdir(parents=True, exist_ok=True)
    else:
        _out_dir = Path(out_dir) if out_dir else Path(before_path).parent / "plots"
        _out_dir.mkdir(parents=True, exist_ok=True)
        out_path = _out_dir / "comparison.png"

    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Comparison plot saved to {out_path}")
    return str(out_path)
# ---------------------------------------------------------------------------
# CLI entry-point
# ---------------------------------------------------------------------------
def main() -> int:
    """CLI entry point: render per-domain and overall diagrams for one results file.

    Returns:
        Process exit status: 0 on success, 1 when the results file is missing
        or contains no plottable ``conditions`` / ``in_distribution`` / ``ood``
        sections. Error messages go to stderr.
    """
    import sys  # only needed to route error messages to stderr

    parser = argparse.ArgumentParser(
        description="Generate reliability / calibration diagrams from eval results."
    )
    parser.add_argument(
        "--results", default="eval/baseline_results.json",
        help="Path to the results JSON (default: eval/baseline_results.json)",
    )
    parser.add_argument(
        "--out-dir", default="eval/plots",
        help="Output directory for PNG files (default: eval/plots)",
    )
    parser.add_argument(
        "--prefix", default="baseline",
        help="Filename prefix for output PNGs (default: baseline)",
    )
    args = parser.parse_args()

    results_path = Path(args.results)
    if not results_path.exists():
        print(f"ERROR: {results_path} not found. Run baseline_eval.py first.",
              file=sys.stderr)
        return 1
    with open(results_path, encoding="utf-8") as f:
        data = json.load(f)

    # A non-object JSON top level has none of the expected sections.
    conditions = _collect_conditions_any_schema(data) if isinstance(data, dict) else {}
    if not conditions:
        print(f"ERROR: {results_path} has no `conditions`, `in_distribution`, "
              f"or `ood` sections to plot.", file=sys.stderr)
        return 1

    overall = data.get("overall", {})
    out_dir = Path(args.out_dir)
    print(f"\nGenerating reliability diagrams from: {results_path}")
    print(f"Output directory: {out_dir}\n")

    # Domain = condition key minus its trailing "_<suffix>" (math_3 -> math).
    domains = sorted({k.rsplit("_", 1)[0] for k in conditions})
    for domain in domains:
        print(f"Domain: {domain}")
        plot_domain(domain, conditions, out_dir, prefix=args.prefix)

    print("\nOverall:")
    plot_overall(conditions, overall, out_dir, prefix=args.prefix)
    print(f"\nDone! {len(domains) + 1} plots saved to {out_dir}/")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())