# agentic-intent-classifier / training / calibrate_confidence.py
# Uploaded by manikumargouni using huggingface_hub (commit 0584798, verified)
from __future__ import annotations
import argparse
import json
import sys
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
import torch
from sklearn.metrics import accuracy_score, log_loss
# Make the repository root importable so `config`, `model_runtime`, and
# `training.*` resolve when this file is executed directly as a script
# rather than imported as a package module.
BASE_DIR = Path(__file__).resolve().parent.parent
if str(BASE_DIR) not in sys.path:
    sys.path.insert(0, str(BASE_DIR))
from config import HEAD_CONFIGS, ensure_artifact_dirs
from model_runtime import get_head
from training.common import load_labeled_rows, write_json
def expected_calibration_error(probabilities: np.ndarray, labels: np.ndarray, bins: int = 10) -> float:
    """Return the expected calibration error (ECE) over equal-width bins.

    Samples are grouped by their top predicted probability into `bins`
    equal-width bins on [0, 1]; the ECE is the coverage-weighted sum of
    |bin accuracy - bin mean confidence|.

    Args:
        probabilities: (n_samples, n_classes) predicted probabilities.
        labels: (n_samples,) integer ground-truth class ids.
        bins: number of equal-width confidence bins.

    Returns:
        The ECE as a float in [0, 1].
    """
    top_prob = probabilities.max(axis=1)
    hits = probabilities.argmax(axis=1) == labels
    edges = np.linspace(0.0, 1.0, bins + 1)
    total = 0.0
    for idx in range(bins):
        lo, hi = edges[idx], edges[idx + 1]
        # The final bin is closed on the right so confidence == 1.0 is counted.
        in_bin = (top_prob >= lo) & ((top_prob <= hi) if hi == 1.0 else (top_prob < hi))
        weight = float(np.mean(in_bin))
        if weight == 0.0:
            continue
        gap = abs(float(np.mean(hits[in_bin])) - float(np.mean(top_prob[in_bin])))
        total += gap * weight
    return total
def optimize_temperature(logits: np.ndarray, labels: np.ndarray) -> float:
    """Fit a single softmax temperature by minimizing cross-entropy with L-BFGS.

    Args:
        logits: (n_samples, n_classes) raw model logits.
        labels: (n_samples,) integer ground-truth class ids.

    Returns:
        The fitted temperature, floored at 1e-3 to keep it strictly positive.
    """
    x = torch.tensor(logits, dtype=torch.float32)
    y = torch.tensor(labels, dtype=torch.long)
    temp = torch.nn.Parameter(torch.ones(1, dtype=torch.float32))
    nll = torch.nn.CrossEntropyLoss()
    solver = torch.optim.LBFGS([temp], lr=0.01, max_iter=50)

    def _closure():
        solver.zero_grad()
        # Clamp inside the objective so a transiently non-positive temperature
        # cannot flip the sign of, or blow up, the scaled logits.
        loss = nll(x / temp.clamp(min=1e-3), y)
        loss.backward()
        return loss

    solver.step(_closure)
    return max(float(temp.detach().item()), 1e-3)
def select_threshold(confidences: np.ndarray, correct: np.ndarray, target_precision: float, step: float) -> dict:
    """Choose a confidence threshold from a uniform grid over [0, 1].

    Every grid threshold is scored by coverage (fraction of samples at or
    above it) and accepted accuracy (accuracy within that subset).  The
    candidate with the widest coverage among those meeting
    ``target_precision`` wins; if none qualifies, fall back to the candidate
    maximizing accepted_accuracy * coverage (ties broken by higher accuracy,
    then lower threshold).

    Args:
        confidences: per-sample top predicted probability.
        correct: per-sample boolean correctness of the argmax prediction.
        target_precision: minimum accepted-subset accuracy to qualify.
        step: grid spacing in (0, 1].

    Returns:
        A dict with rounded "threshold", "coverage", "accepted_accuracy".
    """
    candidates = []
    index = 0
    # Multiply-and-round instead of the former repeated `threshold += step`:
    # accumulated float addition drifts (e.g. 0.01 summed 30 times is
    # 0.30000000000000016), which skews the `confidences >= threshold`
    # comparison near grid points and can disagree with the rounded
    # threshold recorded in the candidate.
    while (threshold := round(index * step, 9)) <= 1.000001:
        index += 1
        accepted = confidences >= threshold
        coverage = float(np.mean(accepted))
        if coverage <= 0.0:
            continue  # empty accepted set: accuracy undefined, skip this point
        candidates.append(
            {
                "threshold": round(float(threshold), 4),
                "coverage": round(coverage, 4),
                "accepted_accuracy": round(float(np.mean(correct[accepted])), 4),
            }
        )
    eligible = [candidate for candidate in candidates if candidate["accepted_accuracy"] >= target_precision]
    if eligible:
        # Among precision-qualified thresholds prefer the widest coverage,
        # breaking ties with the lower threshold.
        return max(eligible, key=lambda candidate: (candidate["coverage"], -candidate["threshold"]))
    # Fallback: balance accuracy against coverage when the target is unreachable.
    return max(
        candidates,
        key=lambda candidate: (
            candidate["accepted_accuracy"] * candidate["coverage"],
            candidate["accepted_accuracy"],
            -candidate["threshold"],
        ),
    )
def summarize_threshold(confidences: np.ndarray, correct: np.ndarray, threshold: float) -> dict:
    """Report coverage and accepted-subset accuracy at a fixed threshold.

    Args:
        confidences: per-sample top predicted probability.
        correct: per-sample boolean correctness of the argmax prediction.
        threshold: confidence cutoff for accepting a prediction.

    Returns:
        A dict with rounded "threshold", "coverage", "accepted_accuracy"
        (accuracy reported as 0.0 when nothing is accepted).
    """
    accepted_mask = confidences >= threshold
    coverage = float(accepted_mask.mean())
    if coverage > 0:
        accuracy = float(correct[accepted_mask].mean())
    else:
        accuracy = 0.0
    return {
        "threshold": round(float(threshold), 4),
        "coverage": round(coverage, 4),
        "accepted_accuracy": round(accuracy, 4),
    }
def collect_logits(head_name: str, split: str) -> tuple[np.ndarray, np.ndarray]:
    """Run a head's model over a labeled split and return (logits, labels).

    Args:
        head_name: key identifying the classifier head.
        split: labeled split name to load (looked up in config.split_paths).

    Returns:
        A tuple of (logits as (n, n_classes) float ndarray,
        labels as (n,) int64 ndarray).
    """
    head = get_head(head_name)
    cfg = head.config
    rows = load_labeled_rows(cfg.split_paths[split], cfg.label_field, cfg.label2id)
    labels = np.array([row["label"] for row in rows], dtype=np.int64)
    encoded = head.tokenizer(
        [row["text"] for row in rows],
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=cfg.max_length,
    )
    # Drop tokenizer outputs the model's forward() does not accept.
    model_inputs = {name: tensor for name, tensor in encoded.items() if name in head.forward_arg_names}
    with torch.no_grad():
        logits = head.model(**model_inputs).logits.detach().cpu().numpy()
    return logits, labels
def calibrate_head(head_name: str, split: str, step: float) -> dict:
    """Fit temperature scaling and a confidence threshold for one head.

    Runs the head over the given labeled split, optimizes a softmax
    temperature, keeps it only if it does not worsen NLL on that split,
    selects a confidence threshold on a grid (subject to the head's
    configured floor), and writes the resulting calibration payload to
    head.config.calibration_path.

    Args:
        head_name: key identifying the classifier head.
        split: labeled split used for calibration fitting.
        step: grid spacing for threshold selection.

    Returns:
        The calibration payload dict that was written to disk.
    """
    head = get_head(head_name)
    logits, labels = collect_logits(head_name, split)
    # Uncalibrated (temperature = 1) probabilities and derived quantities.
    raw_probs = torch.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1).numpy()
    raw_confidences = raw_probs.max(axis=1)
    raw_preds = raw_probs.argmax(axis=1)
    raw_correct = raw_preds == labels
    # Pass the full label set so log_loss stays well-defined even when some
    # classes are absent from this split.
    raw_nll = float(log_loss(labels, raw_probs, labels=list(range(len(head.config.labels)))))
    optimized_temperature = optimize_temperature(logits, labels)
    calibrated_probs = torch.softmax(torch.tensor(logits / optimized_temperature, dtype=torch.float32), dim=-1).numpy()
    calibrated_confidences = calibrated_probs.max(axis=1)
    calibrated_preds = calibrated_probs.argmax(axis=1)
    calibrated_correct = calibrated_preds == labels
    calibrated_nll = float(log_loss(labels, calibrated_probs, labels=list(range(len(head.config.labels)))))
    temperature = optimized_temperature
    # Keep temperature scaling only when it does not increase NLL on the
    # calibration split; otherwise fall back to the raw probabilities.
    used_temperature_scaling = calibrated_nll <= raw_nll
    if not used_temperature_scaling:
        temperature = 1.0
        calibrated_probs = raw_probs
        calibrated_confidences = raw_confidences
        calibrated_preds = raw_preds
        calibrated_correct = raw_correct
        calibrated_nll = raw_nll
    selected_threshold_summary = select_threshold(
        calibrated_confidences,
        calibrated_correct,
        target_precision=head.config.target_accept_precision,
        step=step,
    )
    # Never go below the configured per-head confidence floor.
    applied_threshold = max(
        float(selected_threshold_summary["threshold"]),
        float(head.config.min_calibrated_confidence_threshold),
    )
    threshold_summary = summarize_threshold(calibrated_confidences, calibrated_correct, applied_threshold)
    payload = {
        "head": head_name,
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "calibrated": True,
        "temperature": round(float(temperature), 6),
        "temperature_scaling_applied": used_temperature_scaling,
        "optimized_temperature_candidate": round(float(optimized_temperature), 6),
        "confidence_threshold": threshold_summary["threshold"],
        "selection_target_precision": head.config.target_accept_precision,
        "selection_split": split,
        "minimum_threshold_floor": round(float(head.config.min_calibrated_confidence_threshold), 4),
        "metrics": {
            "raw_accuracy": round(float(accuracy_score(labels, raw_preds)), 4),
            "calibrated_accuracy": round(float(accuracy_score(labels, calibrated_preds)), 4),
            "raw_negative_log_likelihood": round(raw_nll, 4),
            "calibrated_negative_log_likelihood": round(calibrated_nll, 4),
            "raw_expected_calibration_error": round(float(expected_calibration_error(raw_probs, labels)), 4),
            "calibrated_expected_calibration_error": round(
                float(expected_calibration_error(calibrated_probs, labels)),
                4,
            ),
            "mean_raw_confidence": round(float(np.mean(raw_confidences)), 4),
            "mean_calibrated_confidence": round(float(np.mean(calibrated_confidences)), 4),
        },
        "selected_threshold_before_floor": selected_threshold_summary,
        "threshold_summary": threshold_summary,
    }
    write_json(head.config.calibration_path, payload)
    return payload
def main() -> None:
    """Parse CLI arguments, calibrate the requested head(s), and print a JSON summary."""
    parser = argparse.ArgumentParser(description="Calibrate confidence scores for classifier heads.")
    parser.add_argument(
        "--head",
        choices=["all", *HEAD_CONFIGS.keys()],
        default="all",
        help="Which head to calibrate.",
    )
    parser.add_argument(
        "--split",
        choices=["val", "test"],
        default="val",
        help="Which labeled split to use for calibration fitting.",
    )
    parser.add_argument(
        "--threshold-step",
        type=float,
        default=0.01,
        help="Grid step for confidence threshold selection.",
    )
    args = parser.parse_args()
    ensure_artifact_dirs()
    if args.head == "all":
        targets = list(HEAD_CONFIGS.keys())
    else:
        targets = [args.head]
    # Clamp the grid step so a zero or negative value cannot stall selection.
    grid_step = max(args.threshold_step, 0.001)
    results = {}
    for name in targets:
        results[name] = calibrate_head(name, split=args.split, step=grid_step)
    print(json.dumps(results, indent=2))


if __name__ == "__main__":
    main()