from __future__ import annotations import argparse import json import sys from datetime import datetime, timezone from pathlib import Path import numpy as np import torch from sklearn.metrics import accuracy_score, log_loss BASE_DIR = Path(__file__).resolve().parent.parent if str(BASE_DIR) not in sys.path: sys.path.insert(0, str(BASE_DIR)) from config import HEAD_CONFIGS, ensure_artifact_dirs from model_runtime import get_head from training.common import load_labeled_rows, write_json def expected_calibration_error(probabilities: np.ndarray, labels: np.ndarray, bins: int = 10) -> float: confidences = probabilities.max(axis=1) predictions = probabilities.argmax(axis=1) accuracies = predictions == labels ece = 0.0 bin_edges = np.linspace(0.0, 1.0, bins + 1) for lower, upper in zip(bin_edges[:-1], bin_edges[1:]): if upper == 1.0: mask = (confidences >= lower) & (confidences <= upper) else: mask = (confidences >= lower) & (confidences < upper) if not np.any(mask): continue bin_accuracy = float(np.mean(accuracies[mask])) bin_confidence = float(np.mean(confidences[mask])) ece += abs(bin_accuracy - bin_confidence) * float(np.mean(mask)) return ece def optimize_temperature(logits: np.ndarray, labels: np.ndarray) -> float: logits_tensor = torch.tensor(logits, dtype=torch.float32) labels_tensor = torch.tensor(labels, dtype=torch.long) temperature = torch.nn.Parameter(torch.ones(1, dtype=torch.float32)) criterion = torch.nn.CrossEntropyLoss() optimizer = torch.optim.LBFGS([temperature], lr=0.01, max_iter=50) def closure(): optimizer.zero_grad() loss = criterion(logits_tensor / temperature.clamp(min=1e-3), labels_tensor) loss.backward() return loss optimizer.step(closure) return max(float(temperature.detach().item()), 1e-3) def select_threshold(confidences: np.ndarray, correct: np.ndarray, target_precision: float, step: float) -> dict: candidates = [] threshold = 0.0 while threshold <= 1.000001: accepted = confidences >= threshold coverage = float(np.mean(accepted)) if coverage > 0: accepted_accuracy = float(np.mean(correct[accepted])) candidates.append( { "threshold": round(float(threshold), 4), "coverage": round(coverage, 4), "accepted_accuracy": round(accepted_accuracy, 4), } ) threshold += step eligible = [candidate for candidate in candidates if candidate["accepted_accuracy"] >= target_precision] if eligible: return max(eligible, key=lambda candidate: (candidate["coverage"], -candidate["threshold"])) return max( candidates, key=lambda candidate: ( candidate["accepted_accuracy"] * candidate["coverage"], candidate["accepted_accuracy"], -candidate["threshold"], ), ) def summarize_threshold(confidences: np.ndarray, correct: np.ndarray, threshold: float) -> dict: accepted = confidences >= threshold coverage = float(np.mean(accepted)) accepted_accuracy = float(np.mean(correct[accepted])) if coverage > 0 else 0.0 return { "threshold": round(float(threshold), 4), "coverage": round(coverage, 4), "accepted_accuracy": round(accepted_accuracy, 4), } def collect_logits(head_name: str, split: str) -> tuple[np.ndarray, np.ndarray]: head = get_head(head_name) config = head.config rows = load_labeled_rows(config.split_paths[split], config.label_field, config.label2id) texts = [row["text"] for row in rows] labels = np.array([row["label"] for row in rows], dtype=np.int64) inputs = head.tokenizer( texts, return_tensors="pt", truncation=True, padding=True, max_length=config.max_length, ) inputs = { key: value for key, value in inputs.items() if key in head.forward_arg_names } with torch.no_grad(): logits = head.model(**inputs).logits.detach().cpu().numpy() return logits, labels def calibrate_head(head_name: str, split: str, step: float) -> dict: head = get_head(head_name) logits, labels = collect_logits(head_name, split) raw_probs = torch.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1).numpy() raw_confidences = raw_probs.max(axis=1) raw_preds = raw_probs.argmax(axis=1) raw_correct = raw_preds == labels raw_nll = float(log_loss(labels, raw_probs, labels=list(range(len(head.config.labels))))) optimized_temperature = optimize_temperature(logits, labels) calibrated_probs = torch.softmax(torch.tensor(logits / optimized_temperature, dtype=torch.float32), dim=-1).numpy() calibrated_confidences = calibrated_probs.max(axis=1) calibrated_preds = calibrated_probs.argmax(axis=1) calibrated_correct = calibrated_preds == labels calibrated_nll = float(log_loss(labels, calibrated_probs, labels=list(range(len(head.config.labels))))) temperature = optimized_temperature used_temperature_scaling = calibrated_nll <= raw_nll if not used_temperature_scaling: temperature = 1.0 calibrated_probs = raw_probs calibrated_confidences = raw_confidences calibrated_preds = raw_preds calibrated_correct = raw_correct calibrated_nll = raw_nll selected_threshold_summary = select_threshold( calibrated_confidences, calibrated_correct, target_precision=head.config.target_accept_precision, step=step, ) applied_threshold = max( float(selected_threshold_summary["threshold"]), float(head.config.min_calibrated_confidence_threshold), ) threshold_summary = summarize_threshold(calibrated_confidences, calibrated_correct, applied_threshold) payload = { "head": head_name, "generated_at": datetime.now(timezone.utc).isoformat(), "calibrated": True, "temperature": round(float(temperature), 6), "temperature_scaling_applied": used_temperature_scaling, "optimized_temperature_candidate": round(float(optimized_temperature), 6), "confidence_threshold": threshold_summary["threshold"], "selection_target_precision": head.config.target_accept_precision, "selection_split": split, "minimum_threshold_floor": round(float(head.config.min_calibrated_confidence_threshold), 4), "metrics": { "raw_accuracy": round(float(accuracy_score(labels, raw_preds)), 4), "calibrated_accuracy": round(float(accuracy_score(labels, calibrated_preds)), 4), "raw_negative_log_likelihood": round(raw_nll, 4), "calibrated_negative_log_likelihood": round(calibrated_nll, 4), "raw_expected_calibration_error": round(float(expected_calibration_error(raw_probs, labels)), 4), "calibrated_expected_calibration_error": round( float(expected_calibration_error(calibrated_probs, labels)), 4, ), "mean_raw_confidence": round(float(np.mean(raw_confidences)), 4), "mean_calibrated_confidence": round(float(np.mean(calibrated_confidences)), 4), }, "selected_threshold_before_floor": selected_threshold_summary, "threshold_summary": threshold_summary, } write_json(head.config.calibration_path, payload) return payload def main() -> None: parser = argparse.ArgumentParser(description="Calibrate confidence scores for classifier heads.") parser.add_argument( "--head", choices=["all", *HEAD_CONFIGS.keys()], default="all", help="Which head to calibrate.", ) parser.add_argument( "--split", choices=["val", "test"], default="val", help="Which labeled split to use for calibration fitting.", ) parser.add_argument( "--threshold-step", type=float, default=0.01, help="Grid step for confidence threshold selection.", ) args = parser.parse_args() ensure_artifact_dirs() head_names = list(HEAD_CONFIGS.keys()) if args.head == "all" else [args.head] summary = { head_name: calibrate_head(head_name, split=args.split, step=max(args.threshold_step, 0.001)) for head_name in head_names } print(json.dumps(summary, indent=2)) if __name__ == "__main__": main()