""" eval/plot_reliability.py — Publication-quality reliability / calibration diagrams. Usage: python eval/plot_reliability.py # uses baseline_results.json python eval/plot_reliability.py --results path/to.json # custom results file python eval/plot_reliability.py --prefix after_rl # changes output file prefix Comparison helper (importable): from eval.plot_reliability import plot_comparison plot_comparison("eval/baseline_results.json", "eval/after_rl_results.json") """ import argparse import json import os from pathlib import Path from typing import Optional import matplotlib matplotlib.use("Agg") # non-interactive backend — works without a display import matplotlib.pyplot as plt import matplotlib.patches as mpatches import numpy as np import seaborn as sns # --------------------------------------------------------------------------- # Styling # --------------------------------------------------------------------------- sns.set_theme(style="whitegrid", context="paper", font_scale=1.2) PALETTE = { "bar": "#4C72B0", # Seaborn blue "perfect": "#DD8452", # Warm orange diagonal "gap_pos": "#55A868", # Green — overconfidentundershoot "gap_neg": "#C44E52", # Red — underconfident "bg": "#F8F9FA", } BIN_EDGES = np.linspace(0.0, 1.0, 11) # 10 bins: [0,.1), [.1,.2), …, [.9,1] BIN_CENTRES = (BIN_EDGES[:-1] + BIN_EDGES[1:]) / 2 BIN_WIDTH = 0.09 # slightly narrower than the 0.1 spacing for readability # --------------------------------------------------------------------------- # Core: build calibration bins from a flat list of (confidence, correct) pairs # --------------------------------------------------------------------------- def build_bins(confidences: list, correctness: list) -> dict: """Return per-bin stats used by the reliability diagram.""" conf = np.array(confidences, dtype=float) corr = np.array(correctness, dtype=float) bin_acc, bin_conf, bin_count = [], [], [] for lo, hi in zip(BIN_EDGES[:-1], BIN_EDGES[1:]): mask = (conf >= lo) & (conf < hi) # last bin is inclusive on the right if hi == 1.0: mask = (conf >= lo) & (conf <= hi) n = mask.sum() bin_count.append(int(n)) bin_acc.append(float(corr[mask].mean()) if n > 0 else np.nan) bin_conf.append(float(conf[mask].mean()) if n > 0 else np.nan) return { "bin_acc": np.array(bin_acc), "bin_conf": np.array(bin_conf), "bin_count": np.array(bin_count), } def compute_ece_from_bins(bins: dict) -> float: counts = bins["bin_count"] accs = bins["bin_acc"] confs = bins["bin_conf"] total = counts.sum() if total == 0: return float("nan") ece = 0.0 for n, a, c in zip(counts, accs, confs): if n > 0 and not np.isnan(a): ece += (n / total) * abs(a - c) return ece # --------------------------------------------------------------------------- # Single reliability diagram # --------------------------------------------------------------------------- def _draw_reliability( ax: plt.Axes, bins: dict, ece: float, title: str, show_counts: bool = True, ): """Draw a reliability diagram onto ax.""" acc = bins["bin_acc"] count = bins["bin_count"] # ── background & reference diagonal ───────────────────────────────────── ax.set_facecolor(PALETTE["bg"]) ax.plot([0, 1], [0, 1], "--", color=PALETTE["perfect"], linewidth=1.8, label="Perfect calibration", zorder=3) # ── gap fill (miscalibration region) ──────────────────────────────────── for i, (c, a, n) in enumerate(zip(BIN_CENTRES, acc, count)): if np.isnan(a) or n == 0: continue lo, hi = min(c, a), max(c, a) color = PALETTE["gap_neg"] if a < c else PALETTE["gap_pos"] ax.bar(c, hi - lo, bottom=lo, width=BIN_WIDTH * 0.98, color=color, alpha=0.25, zorder=1) # ── accuracy bars ──────────────────────────────────────────────────────── valid = ~np.isnan(acc) bars = ax.bar( BIN_CENTRES[valid], acc[valid], width=BIN_WIDTH, color=PALETTE["bar"], edgecolor="white", linewidth=0.6, label="Observed accuracy", alpha=0.85, zorder=2, ) # ── count annotations on bars ──────────────────────────────────────────── if show_counts: for bar, n in zip(bars, count[valid]): h = bar.get_height() if h > 0.05: ax.text(bar.get_x() + bar.get_width() / 2, h / 2, f"n={n}", ha="center", va="center", fontsize=7.5, color="white", fontweight="bold", zorder=5) # ── ECE badge ──────────────────────────────────────────────────────────── ece_text = f"ECE = {ece:.4f}" if not np.isnan(ece) else "ECE = n/a" ax.text(0.97, 0.04, ece_text, transform=ax.transAxes, ha="right", va="bottom", fontsize=11, fontweight="bold", bbox=dict(boxstyle="round,pad=0.35", fc="white", ec="#CCCCCC", alpha=0.9)) # ── axes formatting ─────────────────────────────────────────────────────── ax.set_xlim(-0.02, 1.02) ax.set_ylim(-0.02, 1.12) ax.set_xlabel("Confidence (predicted probability)", fontsize=11) ax.set_ylabel("Accuracy (fraction correct)", fontsize=11) ax.set_title(title, fontsize=13, fontweight="bold", pad=10) ax.set_xticks(BIN_EDGES) ax.set_xticklabels([f"{e:.1f}" for e in BIN_EDGES], fontsize=8, rotation=45) ax.legend(loc="upper left", fontsize=9, framealpha=0.9) # --------------------------------------------------------------------------- # High-level: single-results plotting # --------------------------------------------------------------------------- def extract_pairs(samples: list) -> tuple: """Return (confidences, correctness) from a list of sample dicts.""" confs, corrs = [], [] for s in samples: if s.get("confidence") is not None and s.get("correct") is not None: confs.append(float(s["confidence"])) corrs.append(1 if s["correct"] else 0) return confs, corrs def plot_domain( domain: str, conditions: dict, out_dir: Path, prefix: str = "baseline", ): """One 2×3 sub-grid showing each difficulty level + aggregate for a domain.""" domain_conditions = { k: v for k, v in conditions.items() if k.startswith(domain + "_") } if not domain_conditions: print(f" ! No conditions found for domain '{domain}' — skipping.") return # collect aggregate pairs for the domain agg_conf, agg_corr = [], [] per_diff = {} for key, cond in sorted(domain_conditions.items()): diff = int(key.split("_")[-1]) cf, cr = extract_pairs(cond.get("samples", [])) per_diff[diff] = (cf, cr) agg_conf.extend(cf) agg_corr.extend(cr) n_diffs = len(per_diff) # layout: difficulties in a row + 1 aggregate panel ncols = n_diffs + 1 fig, axes = plt.subplots(1, ncols, figsize=(4 * ncols, 4.5)) fig.patch.set_facecolor("white") for idx, (diff, (cf, cr)) in enumerate(sorted(per_diff.items())): bins = build_bins(cf, cr) ece = compute_ece_from_bins(bins) _draw_reliability( axes[idx], bins, ece, title=f"{domain.capitalize()} — Difficulty {diff}", ) # aggregate panel (last) agg_bins = build_bins(agg_conf, agg_corr) agg_ece = compute_ece_from_bins(agg_bins) _draw_reliability( axes[-1], agg_bins, agg_ece, title=f"{domain.capitalize()} — All Difficulties", ) fig.suptitle( f"Baseline Calibration - {domain.capitalize()}", fontsize=15, fontweight="bold", y=1.02, ) fig.tight_layout() out_dir.mkdir(parents=True, exist_ok=True) out_path = out_dir / f"{prefix}_{domain}.png" fig.savefig(out_path, dpi=150, bbox_inches="tight") plt.close(fig) print(f" ✓ Saved {out_path}") def plot_overall( conditions: dict, overall: dict, out_dir: Path, prefix: str = "baseline", ): """Aggregate diagram across ALL domains and a per-domain summary bar.""" all_conf, all_corr = [], [] domain_eces = {} domain_accs = {} for domain in ["math", "code", "logic"]: d_conf, d_corr = [], [] for key, cond in conditions.items(): if key.startswith(domain + "_"): cf, cr = extract_pairs(cond.get("samples", [])) d_conf.extend(cf) d_corr.extend(cr) all_conf.extend(cf) all_corr.extend(cr) if d_conf: bins = build_bins(d_conf, d_corr) domain_eces[domain] = compute_ece_from_bins(bins) domain_accs[domain] = float(np.mean(d_corr)) if d_corr else 0.0 fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(14, 5)) fig.patch.set_facecolor("white") # ── left: overall reliability diagram ─────────────────────────────────── overall_bins = build_bins(all_conf, all_corr) overall_ece = compute_ece_from_bins(overall_bins) _draw_reliability(ax_left, overall_bins, overall_ece, title="Overall Reliability Diagram") # ── right: per-domain ECE + accuracy summary bar chart ────────────────── domains = list(domain_eces.keys()) x = np.arange(len(domains)) width = 0.35 ece_vals = [domain_eces[d] for d in domains] acc_vals = [domain_accs[d] for d in domains] bars1 = ax_right.bar(x - width / 2, acc_vals, width, label="Accuracy", color="#4C72B0", alpha=0.85, edgecolor="white") bars2 = ax_right.bar(x + width / 2, ece_vals, width, label="ECE (lower=better)", color="#C44E52", alpha=0.85, edgecolor="white") for bar, val in zip(list(bars1) + list(bars2), acc_vals + ece_vals): ax_right.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01, f"{val:.2f}", ha="center", fontsize=9, fontweight="bold") ax_right.set_xticks(x) ax_right.set_xticklabels([d.capitalize() for d in domains], fontsize=11) ax_right.set_ylim(0, 1.15) ax_right.set_ylabel("Score", fontsize=11) ax_right.set_title("Per-Domain Summary: Accuracy vs ECE", fontsize=13, fontweight="bold") ax_right.legend(fontsize=10) ax_right.set_facecolor(PALETTE["bg"]) fig.suptitle("Baseline Calibration - Overall", fontsize=15, fontweight="bold", y=1.02) fig.tight_layout() out_path = out_dir / f"{prefix}_overall.png" fig.savefig(out_path, dpi=150, bbox_inches="tight") plt.close(fig) print(f" ✓ Saved {out_path}") # --------------------------------------------------------------------------- # Comparison: before vs after RL/SFT # --------------------------------------------------------------------------- def _collect_conditions_any_schema(results: dict) -> dict: """Return a flat ``{condition_key: condition_dict}`` regardless of which schema the JSON uses. Supported shapes: * ``baseline_eval.py`` → ``{"conditions": {"math_1": {...}, ...}}`` * ``full_eval.py`` → ``{"in_distribution": {...}, "ood": {...}}`` (each value is a ``{condition_key: cond}`` map) Keys from in-distribution conditions stay verbatim (``math_1`` / ``code_3`` / …); OOD condition keys are prefixed with ``ood_`` so they don't collide with the math/code/logic namespace. """ out: dict = {} if isinstance(results.get("conditions"), dict): out.update(results["conditions"]) if isinstance(results.get("in_distribution"), dict): out.update(results["in_distribution"]) if isinstance(results.get("ood"), dict): for k, v in results["ood"].items(): out[f"ood_{k}"] = v return out def plot_comparison( before_path: str, after_path: str, out_dir: Optional[str] = None, output_path: Optional[str] = None, label_before: str = "Before Training", label_after: str = "After Training", ): """ Generate side-by-side overall reliability diagrams from two result JSON files. Typically called after RL training to visualise improvement. Robust to both ``baseline_results.json`` (top-level ``conditions``) and ``full_results.json`` (top-level ``in_distribution`` + ``ood``) schemas — samples from any/all sections are aggregated for the diagram. Args: before_path: Baseline JSON. after_path: Trained JSON. out_dir: Optional directory; ignored if ``output_path`` is given. output_path: Explicit PNG path. Takes precedence over ``out_dir``. label_before/label_after: Panel titles. Returns: str: Path of the saved PNG. """ with open(before_path) as f: before = json.load(f) with open(after_path) as f: after = json.load(f) def _collect(results): confs, corrs = [], [] for cond in _collect_conditions_any_schema(results).values(): cf, cr = extract_pairs(cond.get("samples", [])) confs.extend(cf) corrs.extend(cr) return confs, corrs b_conf, b_corr = _collect(before) a_conf, a_corr = _collect(after) fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5)) fig.patch.set_facecolor("white") b_bins = build_bins(b_conf, b_corr) a_bins = build_bins(a_conf, a_corr) _draw_reliability(ax1, b_bins, compute_ece_from_bins(b_bins), title=label_before, show_counts=False) _draw_reliability(ax2, a_bins, compute_ece_from_bins(a_bins), title=label_after, show_counts=False) b_ece = compute_ece_from_bins(b_bins) a_ece = compute_ece_from_bins(a_bins) delta = b_ece - a_ece sign = "↓" if delta > 0 else "↑" fig.suptitle( f"Calibration Comparison | ECE: {b_ece:.4f} → {a_ece:.4f} ({sign} {abs(delta):.4f})", fontsize=13, fontweight="bold", y=1.02, ) fig.tight_layout() if output_path is not None: out_path = Path(output_path) out_path.parent.mkdir(parents=True, exist_ok=True) else: _out_dir = Path(out_dir) if out_dir else Path(before_path).parent / "plots" _out_dir.mkdir(parents=True, exist_ok=True) out_path = _out_dir / "comparison.png" fig.savefig(out_path, dpi=150, bbox_inches="tight") plt.close(fig) print(f"Comparison plot saved to {out_path}") return str(out_path) # --------------------------------------------------------------------------- # CLI entry-point # --------------------------------------------------------------------------- def main(): parser = argparse.ArgumentParser( description="Generate reliability / calibration diagrams from eval results." ) parser.add_argument( "--results", default="eval/baseline_results.json", help="Path to the results JSON (default: eval/baseline_results.json)", ) parser.add_argument( "--out-dir", default="eval/plots", help="Output directory for PNG files (default: eval/plots)", ) parser.add_argument( "--prefix", default="baseline", help="Filename prefix for output PNGs (default: baseline)", ) args = parser.parse_args() results_path = Path(args.results) if not results_path.exists(): print(f"ERROR: {results_path} not found. Run baseline_eval.py first.") return with open(results_path) as f: data = json.load(f) conditions = _collect_conditions_any_schema(data) if not conditions: print(f"ERROR: {results_path} has no `conditions`, `in_distribution`, " f"or `ood` sections to plot.") return overall = data.get("overall", {}) out_dir = Path(args.out_dir) print(f"\nGenerating reliability diagrams from: {results_path}") print(f"Output directory: {out_dir}\n") domains = sorted({k.rsplit("_", 1)[0] for k in conditions}) for domain in domains: print(f"Domain: {domain}") plot_domain(domain, conditions, out_dir, prefix=args.prefix) print("\nOverall:") plot_overall(conditions, overall, out_dir, prefix=args.prefix) print(f"\nDone! {len(domains) + 1} plots saved to {out_dir}/") if __name__ == "__main__": main()