"""Standalone calibration evaluator for a frozen HF text classifier. Computes ECE (Expected Calibration Error) and a few helpful supporting stats. Example: python eval_calibration.py --model_dir probert_model --csv training_data/probert_training_20260131_004706.csv # Or use auto-detection for ProBERT: python eval_calibration.py --probert CSV requirements: - text column (default: text) - label column (default: label). Can be string labels (uses model.config.label2id) or integer ids. """ from __future__ import annotations import argparse import json from pathlib import Path import numpy as np import pandas as pd import torch from transformers import AutoModelForSequenceClassification, AutoTokenizer def _softmax_np(x: np.ndarray, axis: int = -1) -> np.ndarray: x = x - np.max(x, axis=axis, keepdims=True) ex = np.exp(x) return ex / np.sum(ex, axis=axis, keepdims=True) def ece_score(probs: np.ndarray, labels: np.ndarray, n_bins: int = 15) -> float: """Expected Calibration Error (ECE) with equal-width bins on max-prob confidence.""" probs = np.asarray(probs) labels = np.asarray(labels) conf = probs.max(axis=1) preds = probs.argmax(axis=1) acc = (preds == labels).astype(np.float32) bins = np.linspace(0.0, 1.0, n_bins + 1) ece = 0.0 for i in range(n_bins): lo, hi = bins[i], bins[i + 1] mask = (conf > lo) & (conf <= hi) if not np.any(mask): continue ece += float(np.abs(acc[mask].mean() - conf[mask].mean()) * mask.mean()) return float(ece) def nll_score(probs: np.ndarray, labels: np.ndarray) -> float: probs = np.asarray(probs) labels = np.asarray(labels) p_true = probs[np.arange(len(labels)), labels] return float(-np.log(np.clip(p_true, 1e-12, 1.0)).mean()) def infer_label_id(series: pd.Series, label2id: dict | None) -> np.ndarray: if pd.api.types.is_integer_dtype(series) or pd.api.types.is_bool_dtype(series): return series.astype(int).to_numpy() # Try numeric strings try: return series.astype(int).to_numpy() except Exception: pass if not label2id: raise ValueError( "Labels look non-numeric, but model has no label2id mapping. " "Pass integer labels in the CSV or use a model with label2id configured." ) unknown = sorted(set(series.astype(str)) - set(label2id.keys())) if unknown: raise ValueError(f"Unknown labels not in model.config.label2id: {unknown[:10]}") return series.astype(str).map(label2id).astype(int).to_numpy() def run(args: argparse.Namespace) -> dict: device = ( torch.device("cuda") if args.device == "auto" and torch.cuda.is_available() else torch.device("cpu") if args.device == "auto" else torch.device(args.device) ) model_dir = Path(args.model_dir) csv_path = Path(args.csv) tokenizer = AutoTokenizer.from_pretrained(str(model_dir)) model = AutoModelForSequenceClassification.from_pretrained(str(model_dir)) model.to(device) model.eval() df = pd.read_csv(csv_path) if args.text_col not in df.columns: raise ValueError(f"Missing text column '{args.text_col}' in CSV") if args.label_col not in df.columns: raise ValueError(f"Missing label column '{args.label_col}' in CSV") label2id = getattr(getattr(model, "config", None), "label2id", None) labels = infer_label_id(df[args.label_col], label2id) texts = df[args.text_col].astype(str).tolist() logits_chunks: list[np.ndarray] = [] with torch.no_grad(): for start in range(0, len(texts), args.batch_size): batch = texts[start : start + args.batch_size] enc = tokenizer( batch, truncation=True, max_length=args.max_length, padding=True, return_tensors="pt", ) enc = {k: v.to(device) for k, v in enc.items()} out = model(**enc) logits_chunks.append(out.logits.detach().cpu().numpy()) logits = np.concatenate(logits_chunks, axis=0) probs = _softmax_np(logits, axis=1) preds = probs.argmax(axis=1) conf = probs.max(axis=1) wrong = preds != labels result: dict = { "n": int(len(labels)), "accuracy": float((preds == labels).mean()), "mean_conf": float(conf.mean()), "nll": nll_score(probs, labels), f"ece_{args.n_bins}": ece_score(probs, labels, n_bins=args.n_bins), "wrong_count": int(wrong.sum()), "max_wrong_conf": float(conf[wrong].max()) if wrong.any() else 0.0, } for t in args.thresholds: key = str(t).replace(".", "_") result[f"coverage_at_conf_ge_{key}"] = float((conf >= t).mean()) result[f"wrong_at_conf_ge_{key}"] = float(((wrong) & (conf >= t)).mean()) return result def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser() p.add_argument("--model_dir", default="probert_model", help="Path to saved HF model directory") p.add_argument("--csv", help="CSV with text + label (auto-detects latest in training_data/ if not provided)") p.add_argument("--probert", action="store_true", help="ProBERT mode: auto-detect model + latest training CSV") p.add_argument("--text_col", default="text") p.add_argument("--label_col", default="label") p.add_argument("--batch_size", type=int, default=32) p.add_argument("--max_length", type=int, default=128) p.add_argument("--n_bins", type=int, default=15) p.add_argument( "--thresholds", type=float, nargs="*", default=[0.7, 0.8, 0.9], help="Confidence thresholds for coverage/wrong-rate", ) p.add_argument( "--device", default="auto", choices=["auto", "cpu", "cuda"], help="auto uses cuda if available", ) p.add_argument("--out_json", default="", help="Optional path to write metrics JSON") return p.parse_args() def main() -> None: args = parse_args() # ProBERT mode: auto-detect model + latest CSV if args.probert: # Find ProBERT root (where probert_model/ should be) script_dir = Path(__file__).parent probert_root = script_dir.parent.parent # Go up from training_data/1.0 to ProBERT/ args.model_dir = str(probert_root / "probert_model") if not args.csv: # Look in multiple locations for training CSVs search_dirs = [ probert_root / "training_data", probert_root / "training_data" / "1.0", script_dir, # Same directory as script ] csv_files = [] for search_dir in search_dirs: if search_dir.exists(): csv_files.extend(search_dir.glob("probert_training_*.csv")) if not csv_files: raise ValueError(f"No training CSV found. Searched: {[str(d) for d in search_dirs]}") args.csv = str(max(csv_files, key=lambda p: p.stat().st_mtime)) print(f"🔍 Auto-detected CSV: {args.csv}") # Validate inputs if not args.csv: raise ValueError("Must provide --csv or use --probert mode") if not Path(args.model_dir).exists(): raise ValueError(f"Model directory not found: {args.model_dir}") result = run(args) print("\n" + "="*70) print("CALIBRATION METRICS") print("="*70) print(json.dumps(result, indent=2)) # ProBERT-specific interpretation if args.probert or "probert" in args.model_dir.lower(): print("\n" + "="*70) print("PROBERT CALIBRATION SUMMARY") print("="*70) ece = result.get(f"ece_{args.n_bins}", 0.0) acc = result['accuracy'] mean_conf = result['mean_conf'] # Determine if over or underconfident conf_gap = acc - mean_conf if conf_gap > 0.10: confidence_type = "UNDERCONFIDENT" gap_msg = f"Model is {conf_gap*100:.1f}% less confident than it should be (conservative)" elif conf_gap < -0.10: confidence_type = "OVERCONFIDENT" gap_msg = f"Model is {abs(conf_gap)*100:.1f}% more confident than accuracy justifies (risky)" else: confidence_type = "WELL-CALIBRATED" gap_msg = "Confidence matches accuracy closely" if ece <= 0.05: verdict = "✅ EXCELLENT - Very well calibrated" elif ece <= 0.10: verdict = "✅ GOOD - Acceptable calibration" elif ece <= 0.15: verdict = "⚠️ MODERATE - Some miscalibration" else: verdict = f"⚠️ HIGH ECE - {confidence_type}" print(f"\nECE ({args.n_bins} bins): {ece:.4f}") print(f"Verdict: {verdict}") print(f"\n{gap_msg}") print(f"Accuracy: {acc:.3f} | Mean Confidence: {mean_conf:.3f} | Gap: {conf_gap:+.3f}") print(f"NLL (log loss): {result['nll']:.4f}") print(f"\nWrong predictions: {result['wrong_count']}/{result['n']}") print(f"Max confidence on wrong: {result['max_wrong_conf']:.3f}") # Analyze high-confidence errors high_conf_wrong = result.get('wrong_at_conf_ge_0_8', 0.0) if high_conf_wrong == 0.0: print("\n✅ SAFETY: No errors at confidence ≥ 0.8 (high-confidence predictions are trustworthy)") else: print(f"\n⚠️ RISK: {high_conf_wrong*100:.1f}% wrong predictions at conf ≥ 0.8") print(f"\nCoverage at confidence thresholds:") for t in args.thresholds: key = str(t).replace(".", "_") cov = result[f"coverage_at_conf_ge_{key}"] wrong = result[f"wrong_at_conf_ge_{key}"] print(f" ≥{t:.1f}: {cov*100:.1f}% coverage, {wrong*100:.1f}% wrong") if args.out_json: out_path = Path(args.out_json) out_path.parent.mkdir(parents=True, exist_ok=True) out_path.write_text(json.dumps(result, indent=2) + "\n", encoding="utf-8") print(f"\n💾 Saved to: {args.out_json}") if __name__ == "__main__": main()