# ProBERT-1.0 / eval_calibration.py
# Uploaded by collapseindex: repro script and dataset (commit ead9fb1, verified)
"""Standalone calibration evaluator for a frozen HF text classifier.
Computes ECE (Expected Calibration Error) and a few helpful supporting stats.
Example:
python eval_calibration.py --model_dir probert_model --csv training_data/probert_training_20260131_004706.csv
# Or use auto-detection for ProBERT:
python eval_calibration.py --probert
CSV requirements:
- text column (default: text)
- label column (default: label). Can be string labels (uses model.config.label2id)
or integer ids.
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
def _softmax_np(x: np.ndarray, axis: int = -1) -> np.ndarray:
x = x - np.max(x, axis=axis, keepdims=True)
ex = np.exp(x)
return ex / np.sum(ex, axis=axis, keepdims=True)
def ece_score(probs: np.ndarray, labels: np.ndarray, n_bins: int = 15) -> float:
    """Expected Calibration Error (ECE) with equal-width bins on max-prob confidence."""
    probs = np.asarray(probs)
    labels = np.asarray(labels)
    conf = probs.max(axis=1)
    acc = (probs.argmax(axis=1) == labels).astype(np.float32)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    total = 0.0
    # Bins are half-open (lo, hi]; empty bins contribute nothing.
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        weight = in_bin.mean()
        if weight > 0:
            total += float(abs(acc[in_bin].mean() - conf[in_bin].mean()) * weight)
    return float(total)
def nll_score(probs: np.ndarray, labels: np.ndarray) -> float:
    """Mean negative log-likelihood of the true class (probabilities clipped at 1e-12)."""
    probs = np.asarray(probs)
    labels = np.asarray(labels)
    true_class_p = np.clip(probs[np.arange(labels.shape[0]), labels], 1e-12, 1.0)
    return float(np.mean(-np.log(true_class_p)))
def infer_label_id(series: pd.Series, label2id: dict | None) -> np.ndarray:
if pd.api.types.is_integer_dtype(series) or pd.api.types.is_bool_dtype(series):
return series.astype(int).to_numpy()
# Try numeric strings
try:
return series.astype(int).to_numpy()
except Exception:
pass
if not label2id:
raise ValueError(
"Labels look non-numeric, but model has no label2id mapping. "
"Pass integer labels in the CSV or use a model with label2id configured."
)
unknown = sorted(set(series.astype(str)) - set(label2id.keys()))
if unknown:
raise ValueError(f"Unknown labels not in model.config.label2id: {unknown[:10]}")
return series.astype(str).map(label2id).astype(int).to_numpy()
def run(args: argparse.Namespace) -> dict:
    """Score a saved HF sequence classifier on a CSV and return calibration metrics.

    Loads the tokenizer/model from ``args.model_dir``, runs batched inference
    over ``args.text_col`` of ``args.csv``, and returns a dict with accuracy,
    mean confidence, NLL, ECE, and per-threshold coverage/wrong-rate.

    Raises:
        ValueError: if the text or label column is missing from the CSV.
    """
    # Resolve the torch device: "auto" prefers CUDA when available.
    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)

    model_dir = Path(args.model_dir)
    csv_path = Path(args.csv)

    tokenizer = AutoTokenizer.from_pretrained(str(model_dir))
    model = AutoModelForSequenceClassification.from_pretrained(str(model_dir))
    model.to(device)
    model.eval()

    df = pd.read_csv(csv_path)
    if args.text_col not in df.columns:
        raise ValueError(f"Missing text column '{args.text_col}' in CSV")
    if args.label_col not in df.columns:
        raise ValueError(f"Missing label column '{args.label_col}' in CSV")

    label2id = getattr(getattr(model, "config", None), "label2id", None)
    labels = infer_label_id(df[args.label_col], label2id)
    texts = df[args.text_col].astype(str).tolist()

    # Batched forward passes; collect raw logits back on the CPU.
    all_logits: list[np.ndarray] = []
    with torch.no_grad():
        for offset in range(0, len(texts), args.batch_size):
            chunk = texts[offset : offset + args.batch_size]
            encoded = tokenizer(
                chunk,
                truncation=True,
                max_length=args.max_length,
                padding=True,
                return_tensors="pt",
            )
            encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
            output = model(**encoded)
            all_logits.append(output.logits.detach().cpu().numpy())

    probs = _softmax_np(np.concatenate(all_logits, axis=0), axis=1)
    preds = probs.argmax(axis=1)
    conf = probs.max(axis=1)
    wrong = preds != labels

    result: dict = {
        "n": int(len(labels)),
        "accuracy": float((preds == labels).mean()),
        "mean_conf": float(conf.mean()),
        "nll": nll_score(probs, labels),
        f"ece_{args.n_bins}": ece_score(probs, labels, n_bins=args.n_bins),
        "wrong_count": int(wrong.sum()),
        # 0.0 when every prediction is correct (no wrong confidences to max over).
        "max_wrong_conf": float(conf[wrong].max()) if wrong.any() else 0.0,
    }
    # NOTE: wrong_at_conf_ge_* is the fraction of ALL samples that are both
    # wrong and at/above the threshold — not the error rate among covered samples.
    for threshold in args.thresholds:
        suffix = str(threshold).replace(".", "_")
        result[f"coverage_at_conf_ge_{suffix}"] = float((conf >= threshold).mean())
        result[f"wrong_at_conf_ge_{suffix}"] = float((wrong & (conf >= threshold)).mean())
    return result
def parse_args() -> argparse.Namespace:
    """Build and parse the command-line interface for the calibration evaluator."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_dir",
        default="probert_model",
        help="Path to saved HF model directory",
    )
    parser.add_argument(
        "--csv",
        help="CSV with text + label (auto-detects latest in training_data/ if not provided)",
    )
    parser.add_argument(
        "--probert",
        action="store_true",
        help="ProBERT mode: auto-detect model + latest training CSV",
    )
    parser.add_argument("--text_col", default="text")
    parser.add_argument("--label_col", default="label")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--n_bins", type=int, default=15)
    # Zero or more confidence cutoffs may be given, e.g. --thresholds 0.5 0.9
    parser.add_argument("--thresholds", type=float, nargs="*", default=[0.7, 0.8, 0.9],
                        help="Confidence thresholds for coverage/wrong-rate")
    parser.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda"],
                        help="auto uses cuda if available")
    parser.add_argument("--out_json", default="", help="Optional path to write metrics JSON")
    return parser.parse_args()
def main() -> None:
    """CLI entry point.

    Resolves the model directory and evaluation CSV (explicit flags, or
    ProBERT auto-detection with --probert), runs calibration evaluation,
    prints a metrics report plus a ProBERT-specific interpretation, and
    optionally writes the metrics dict to a JSON file.
    """
    args = parse_args()
    # ProBERT mode: auto-detect model + latest CSV
    if args.probert:
        # Find ProBERT root (where probert_model/ should be)
        script_dir = Path(__file__).parent
        # NOTE(review): assumes this script lives two levels below the repo
        # root (e.g. ProBERT/training_data/1.0/) — verify against the layout.
        probert_root = script_dir.parent.parent # Go up from training_data/1.0 to ProBERT/
        args.model_dir = str(probert_root / "probert_model")
        if not args.csv:
            # Look in multiple locations for training CSVs
            search_dirs = [
                probert_root / "training_data",
                probert_root / "training_data" / "1.0",
                script_dir, # Same directory as script
            ]
            csv_files: list[Path] = []
            for search_dir in search_dirs:
                if search_dir.exists():
                    csv_files.extend(search_dir.glob("probert_training_*.csv"))
            if not csv_files:
                raise ValueError(f"No training CSV found. Searched: {[str(d) for d in search_dirs]}")
            # Most recently modified CSV wins (taken to be the latest training run).
            args.csv = str(max(csv_files, key=lambda p: p.stat().st_mtime))
            print(f"🔍 Auto-detected CSV: {args.csv}")
    # Validate inputs
    if not args.csv:
        raise ValueError("Must provide --csv or use --probert mode")
    if not Path(args.model_dir).exists():
        raise ValueError(f"Model directory not found: {args.model_dir}")
    result = run(args)
    print("\n" + "="*70)
    print("CALIBRATION METRICS")
    print("="*70)
    print(json.dumps(result, indent=2))
    # ProBERT-specific interpretation
    if args.probert or "probert" in args.model_dir.lower():
        print("\n" + "="*70)
        print("PROBERT CALIBRATION SUMMARY")
        print("="*70)
        # The ECE key embeds the bin count, so .get() guards a mismatch.
        ece = result.get(f"ece_{args.n_bins}", 0.0)
        acc = result['accuracy']
        mean_conf = result['mean_conf']
        # Determine if over or underconfident
        # conf_gap > 0: accuracy exceeds confidence (model is conservative);
        # conf_gap < 0: confidence exceeds accuracy (model is risky).
        conf_gap = acc - mean_conf
        if conf_gap > 0.10:
            confidence_type = "UNDERCONFIDENT"
            gap_msg = f"Model is {conf_gap*100:.1f}% less confident than it should be (conservative)"
        elif conf_gap < -0.10:
            confidence_type = "OVERCONFIDENT"
            gap_msg = f"Model is {abs(conf_gap)*100:.1f}% more confident than accuracy justifies (risky)"
        else:
            confidence_type = "WELL-CALIBRATED"
            gap_msg = "Confidence matches accuracy closely"
        # Verdict thresholds: <=0.05 excellent, <=0.10 good, <=0.15 moderate.
        if ece <= 0.05:
            verdict = "✅ EXCELLENT - Very well calibrated"
        elif ece <= 0.10:
            verdict = "✅ GOOD - Acceptable calibration"
        elif ece <= 0.15:
            verdict = "⚠️ MODERATE - Some miscalibration"
        else:
            verdict = f"⚠️ HIGH ECE - {confidence_type}"
        print(f"\nECE ({args.n_bins} bins): {ece:.4f}")
        print(f"Verdict: {verdict}")
        print(f"\n{gap_msg}")
        print(f"Accuracy: {acc:.3f} | Mean Confidence: {mean_conf:.3f} | Gap: {conf_gap:+.3f}")
        print(f"NLL (log loss): {result['nll']:.4f}")
        print(f"\nWrong predictions: {result['wrong_count']}/{result['n']}")
        print(f"Max confidence on wrong: {result['max_wrong_conf']:.3f}")
        # Analyze high-confidence errors
        # NOTE(review): this key only exists when 0.8 is in --thresholds;
        # otherwise the default 0.0 makes the SAFETY message print regardless.
        high_conf_wrong = result.get('wrong_at_conf_ge_0_8', 0.0)
        if high_conf_wrong == 0.0:
            print("\n✅ SAFETY: No errors at confidence ≥ 0.8 (high-confidence predictions are trustworthy)")
        else:
            print(f"\n⚠️ RISK: {high_conf_wrong*100:.1f}% wrong predictions at conf ≥ 0.8")
        print(f"\nCoverage at confidence thresholds:")
        # Both percentages are fractions of ALL samples (wrong is not
        # conditional on coverage) — see how run() computes these keys.
        for t in args.thresholds:
            key = str(t).replace(".", "_")
            cov = result[f"coverage_at_conf_ge_{key}"]
            wrong = result[f"wrong_at_conf_ge_{key}"]
            print(f" ≥{t:.1f}: {cov*100:.1f}% coverage, {wrong*100:.1f}% wrong")
    if args.out_json:
        out_path = Path(args.out_json)
        # Create parent directories so an arbitrary output path just works.
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(json.dumps(result, indent=2) + "\n", encoding="utf-8")
        print(f"\n💾 Saved to: {args.out_json}")
if __name__ == "__main__":
    main()