"""
Calibration Analysis for ActiveMedAgent.
Measures whether the VLM's reported probabilities match empirical
accuracy. Key analyses for the ACL/EMNLP submission:
1. Reliability Diagram: binned confidence vs accuracy
2. Expected Calibration Error (ECE): scalar miscalibration summary
3. Temperature Scaling: post-hoc recalibration on held-out set
4. Robustness to Miscalibration: does the method work with noisy probs?
5. Per-Step Calibration: is calibration better/worse at different steps?
"""
import json
import logging
import math
from dataclasses import dataclass, field
from pathlib import Path
import numpy as np
from scipy.optimize import minimize_scalar
from scipy.stats import spearmanr
from agent import AgentResult, AcquisitionStep
from datasets.base import MedicalCase
from evaluation import evaluate_single_case, CaseMetrics
logger = logging.getLogger(__name__)
# ================================================================
# Core Calibration Metrics
# ================================================================
@dataclass
class CalibrationBin:
"""A single bin in a reliability diagram."""
bin_lower: float
bin_upper: float
bin_center: float
avg_confidence: float
avg_accuracy: float
count: int
gap: float # |avg_confidence - avg_accuracy|
@dataclass
class CalibrationResult:
"""Full calibration analysis for a set of predictions."""
ece: float # Expected Calibration Error
mce: float # Maximum Calibration Error
ace: float # Average Calibration Error
bins: list[CalibrationBin]
n_predictions: int
mean_confidence: float
mean_accuracy: float
    overconfidence_ratio: float  # Fraction of non-empty bins where conf > acc
brier_score: float # Brier score (MSE of probabilities)
def compute_calibration(
confidences: list[float],
correctness: list[bool],
n_bins: int = 10,
) -> CalibrationResult:
"""
Compute calibration metrics from confidence-correctness pairs.
Args:
confidences: Model's stated probability for its top prediction
correctness: Whether the top prediction was correct
n_bins: Number of bins for the reliability diagram
Returns:
CalibrationResult with ECE, MCE, bins, etc.
"""
if not confidences:
return CalibrationResult(
ece=0, mce=0, ace=0, bins=[], n_predictions=0,
mean_confidence=0, mean_accuracy=0,
overconfidence_ratio=0, brier_score=0,
)
confs = np.array(confidences, dtype=np.float64)
accs = np.array(correctness, dtype=np.float64)
n = len(confs)
bin_boundaries = np.linspace(0.0, 1.0, n_bins + 1)
bins = []
ece = 0.0
mce = 0.0
overconf_count = 0
for i in range(n_bins):
lower = bin_boundaries[i]
upper = bin_boundaries[i + 1]
        # First bin includes its left edge so exact-0.0 confidences are kept.
        lower_ok = (confs >= lower) if i == 0 else (confs > lower)
        mask = lower_ok & (confs <= upper)
count = mask.sum()
if count == 0:
bins.append(CalibrationBin(
bin_lower=lower, bin_upper=upper,
bin_center=(lower + upper) / 2,
avg_confidence=0, avg_accuracy=0,
count=0, gap=0,
))
continue
avg_conf = confs[mask].mean()
avg_acc = accs[mask].mean()
gap = abs(avg_conf - avg_acc)
ece += (count / n) * gap
mce = max(mce, gap)
if avg_conf > avg_acc:
overconf_count += 1
bins.append(CalibrationBin(
bin_lower=lower, bin_upper=upper,
bin_center=(lower + upper) / 2,
avg_confidence=float(avg_conf),
avg_accuracy=float(avg_acc),
count=int(count),
gap=float(gap),
))
non_empty_bins = [b for b in bins if b.count > 0]
ace = np.mean([b.gap for b in non_empty_bins]) if non_empty_bins else 0.0
# Brier score
brier = np.mean((confs - accs) ** 2)
return CalibrationResult(
ece=float(ece),
mce=float(mce),
ace=float(ace),
bins=bins,
n_predictions=n,
mean_confidence=float(confs.mean()),
mean_accuracy=float(accs.mean()),
overconfidence_ratio=overconf_count / len(non_empty_bins) if non_empty_bins else 0,
brier_score=float(brier),
)
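# Worked example (illustrative numbers): with the default 10 bins, confidences
# [0.9, 0.9, 0.6, 0.6] and correctness [True, True, True, False] place two
# predictions in the (0.8, 0.9] bin (gap |0.9 - 1.0| = 0.1) and two in the
# (0.5, 0.6] bin (gap |0.6 - 0.5| = 0.1), each bin holding half the
# predictions, so ECE = 0.5 * 0.1 + 0.5 * 0.1 = 0.10:
#
#   cal = compute_calibration([0.9, 0.9, 0.6, 0.6], [True, True, True, False])
#   round(cal.ece, 3)  # -> 0.1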
# ================================================================
# Extract Predictions from Agent Results
# ================================================================
def extract_predictions(
results: list[AgentResult],
cases: list[MedicalCase],
) -> tuple[list[float], list[bool]]:
"""
Extract (confidence, correctness) pairs from agent results.
Returns:
confidences: top-1 stated probability
correctness: whether top-1 matches ground truth
"""
confidences = []
correctness = []
for result, case in zip(results, cases):
if not result.final_ranking:
continue
top = result.final_ranking[0]
conf = top.get("confidence", 0.0)
name = top.get("name", "").strip().lower()
gt = case.ground_truth.strip().lower()
        # Lenient substring match, but an empty predicted name must not count
        # as correct (an empty string is a substring of any ground truth).
        correct = bool(name) and (name == gt or name in gt or gt in name)
confidences.append(conf)
correctness.append(correct)
return confidences, correctness
def extract_per_step_predictions(
results: list[AgentResult],
cases: list[MedicalCase],
) -> dict[int, tuple[list[float], list[bool]]]:
"""
Extract predictions at each acquisition step.
Returns:
{step_idx: (confidences, correctness)}
"""
step_data: dict[int, tuple[list, list]] = {}
for result, case in zip(results, cases):
gt = case.ground_truth.strip().lower()
for step in result.steps:
if not step.differential:
continue
idx = step.step
if idx not in step_data:
step_data[idx] = ([], [])
top = max(step.differential, key=lambda d: d.get("confidence", 0))
conf = top.get("confidence", 0.0)
name = top.get("name", "").strip().lower()
            # An empty predicted name never counts as correct (it would
            # otherwise match any ground truth via the substring check).
            correct = bool(name) and (name == gt or name in gt or gt in name)
step_data[idx][0].append(conf)
step_data[idx][1].append(correct)
return step_data
# ================================================================
# Temperature Scaling
# ================================================================
def temperature_scale(
confidences: list[float],
correctness: list[bool],
    candidates_per_case: list[int] | None = None,
) -> tuple[float, float]:
"""
    Find the temperature T that minimizes ECE on the data provided; the caller
    is responsible for passing a held-out calibration split.
    Standard temperature scaling divides logits by T before the softmax. With
    only a top-1 probability available, we use the binary (sigmoid) form:
logit = log(p / (1 - p))
scaled_logit = logit / T
p_scaled = sigmoid(scaled_logit)
Args:
confidences: Raw model confidences
correctness: Whether predictions were correct
        candidates_per_case: Number of candidates per case (currently unused)
Returns:
(optimal_temperature, calibrated_ece)
"""
confs = np.array(confidences, dtype=np.float64)
accs = np.array(correctness, dtype=np.float64)
# Clip to avoid log(0)
confs = np.clip(confs, 1e-6, 1 - 1e-6)
logits = np.log(confs / (1 - confs))
def ece_at_temperature(T):
scaled_logits = logits / T
scaled_confs = 1.0 / (1.0 + np.exp(-scaled_logits))
# Compute ECE
n_bins = 10
bins = np.linspace(0, 1, n_bins + 1)
ece = 0.0
n = len(scaled_confs)
for i in range(n_bins):
mask = (scaled_confs > bins[i]) & (scaled_confs <= bins[i + 1])
if mask.sum() == 0:
continue
bin_conf = scaled_confs[mask].mean()
bin_acc = accs[mask].mean()
ece += (mask.sum() / n) * abs(bin_conf - bin_acc)
return ece
result = minimize_scalar(
ece_at_temperature,
bounds=(0.1, 10.0),
method="bounded",
)
optimal_T = result.x
calibrated_ece = ece_at_temperature(optimal_T)
return float(optimal_T), float(calibrated_ece)
def apply_temperature(
confidences: list[float], temperature: float
) -> list[float]:
"""Apply temperature scaling to a list of confidences."""
confs = np.array(confidences, dtype=np.float64)
confs = np.clip(confs, 1e-6, 1 - 1e-6)
logits = np.log(confs / (1 - confs))
scaled_logits = logits / temperature
scaled_confs = 1.0 / (1.0 + np.exp(-scaled_logits))
return scaled_confs.tolist()
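# Example (illustrative): with T = 2.0 a stated probability of 0.9 maps to
# sigmoid(log(0.9 / 0.1) / 2) = sigmoid(log 3) = 0.75, so T > 1 softens
# overconfident predictions toward 0.5 while T < 1 sharpens them.
#   apply_temperature([0.9], 2.0)  # -> [~0.75]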
# ================================================================
# Robustness to Miscalibration
# ================================================================
def test_calibration_robustness(
results: list[AgentResult],
cases: list[MedicalCase],
    noise_levels: list[float] | None = None,
n_trials: int = 10,
seed: int = 42,
) -> dict[float, dict]:
"""
Test whether the agent's acquisition decisions are robust to
probability miscalibration.
For each noise level, we perturb the agent's reported probabilities
and check if the same acquisition order and stopping decisions
would be made.
Args:
noise_levels: Standard deviations of Gaussian noise to add to logits
n_trials: Number of random trials per noise level
Returns:
{noise_level: {order_stability, stop_stability, ...}}
"""
if noise_levels is None:
noise_levels = [0.0, 0.1, 0.25, 0.5, 1.0, 2.0]
rng = np.random.RandomState(seed)
robustness = {}
# Collect original acquisition orders and stopping points
original_orders = []
original_stop_steps = []
original_distributions = []
for result in results:
original_orders.append(tuple(result.acquired_channels))
original_stop_steps.append(len(result.acquired_channels))
step_dists = []
for step in result.steps:
if step.differential:
dist = {
d.get("name", ""): d.get("confidence", 0)
for d in step.differential
}
step_dists.append(dist)
original_distributions.append(step_dists)
for noise in noise_levels:
order_matches = 0
stop_matches = 0
total = len(results)
if noise == 0.0:
robustness[noise] = {
"order_stability": 1.0,
"stop_stability": 1.0,
"mean_rank_correlation": 1.0,
"n_cases": total,
}
continue
rank_correlations = []
for trial in range(n_trials):
trial_order_matches = 0
trial_stop_matches = 0
trial_rank_corrs = []
for i, (result, dists) in enumerate(
zip(results, original_distributions)
):
if not dists:
continue
                # Perturb each step's differential distribution.
for dist in dists:
names = list(dist.keys())
probs = np.array(list(dist.values()), dtype=np.float64)
probs = np.clip(probs, 1e-6, 1 - 1e-6)
# Add noise in logit space
logits = np.log(probs / (1 - probs))
noisy_logits = logits + rng.normal(0, noise, len(logits))
noisy_probs = 1.0 / (1.0 + np.exp(-noisy_logits))
noisy_probs /= noisy_probs.sum()
                    # Spearman rank correlation between the original and the
                    # perturbed probabilities for this step's differential.
                    if len(probs) > 1 and np.std(probs) > 0:
                        corr, _ = spearmanr(probs, noisy_probs)
                        if not np.isnan(corr):
                            trial_rank_corrs.append(corr)
                # The acquisition order and stopping point are not re-derived
                # under noise here (that would require re-running the agent),
                # so these counts record the unchanged original decisions and
                # act as an upper bound on order/stop stability.
                if tuple(result.acquired_channels) == original_orders[i]:
                    trial_order_matches += 1
                    trial_stop_matches += 1
if total > 0:
order_matches += trial_order_matches / total
stop_matches += trial_stop_matches / total
if trial_rank_corrs:
rank_correlations.extend(trial_rank_corrs)
robustness[noise] = {
"order_stability": order_matches / n_trials if n_trials > 0 else 0,
"stop_stability": stop_matches / n_trials if n_trials > 0 else 0,
"mean_rank_correlation": float(np.mean(rank_correlations)) if rank_correlations else 1.0,
"n_cases": total,
}
return robustness
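# Reading the output (usage sketch; keys are the noise standard deviations):
#   rb = test_calibration_robustness(results, cases)
#   for sigma in sorted(rb):
#       print(f"sigma={sigma:.2f}  rank_corr={rb[sigma]['mean_rank_correlation']:.3f}")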
# ================================================================
# Full Calibration Analysis Pipeline
# ================================================================
def run_calibration_analysis(
results: list[AgentResult],
cases: list[MedicalCase],
    save_dir: Path | None = None,
) -> dict:
"""
Run the complete calibration analysis suite.
    Returns a dict with all metrics and saves it to disk if save_dir is provided.
"""
logger.info("Running calibration analysis...")
# 1. Overall calibration
confidences, correctness = extract_predictions(results, cases)
overall = compute_calibration(confidences, correctness)
logger.info(f" ECE: {overall.ece:.4f}")
logger.info(f" MCE: {overall.mce:.4f}")
logger.info(f" Brier Score: {overall.brier_score:.4f}")
logger.info(f" Mean Confidence: {overall.mean_confidence:.3f}")
logger.info(f" Mean Accuracy: {overall.mean_accuracy:.3f}")
logger.info(f" Overconfidence Ratio: {overall.overconfidence_ratio:.2f}")
# 2. Temperature scaling
if len(confidences) >= 10:
# Split into calibration and test sets
n = len(confidences)
mid = n // 2
cal_confs, cal_correct = confidences[:mid], correctness[:mid]
test_confs, test_correct = confidences[mid:], correctness[mid:]
opt_T, cal_ece = temperature_scale(cal_confs, cal_correct)
scaled_test = apply_temperature(test_confs, opt_T)
post_cal = compute_calibration(scaled_test, test_correct)
logger.info(f" Optimal Temperature: {opt_T:.3f}")
logger.info(f" Post-calibration ECE: {post_cal.ece:.4f}")
else:
opt_T = 1.0
post_cal = overall
# 3. Per-step calibration
step_data = extract_per_step_predictions(results, cases)
per_step_cal = {}
for step_idx, (step_confs, step_correct) in sorted(step_data.items()):
if len(step_confs) >= 5:
step_cal = compute_calibration(step_confs, step_correct, n_bins=5)
per_step_cal[step_idx] = {
"ece": step_cal.ece,
"mean_confidence": step_cal.mean_confidence,
"mean_accuracy": step_cal.mean_accuracy,
"n_predictions": step_cal.n_predictions,
}
logger.info(
f" Step {step_idx}: ECE={step_cal.ece:.4f}, "
f"Conf={step_cal.mean_confidence:.3f}, "
f"Acc={step_cal.mean_accuracy:.3f} (n={step_cal.n_predictions})"
)
# 4. Robustness analysis
robustness = test_calibration_robustness(results, cases)
for noise, metrics in robustness.items():
logger.info(
f" Noise={noise:.2f}: rank_corr={metrics['mean_rank_correlation']:.3f}"
)
# Compile output
output = {
"overall": {
"ece": overall.ece,
"mce": overall.mce,
"ace": overall.ace,
"brier_score": overall.brier_score,
"mean_confidence": overall.mean_confidence,
"mean_accuracy": overall.mean_accuracy,
"overconfidence_ratio": overall.overconfidence_ratio,
"n_predictions": overall.n_predictions,
"bins": [
{
"center": b.bin_center,
"confidence": b.avg_confidence,
"accuracy": b.avg_accuracy,
"count": b.count,
"gap": b.gap,
}
for b in overall.bins
],
},
"temperature_scaling": {
"optimal_temperature": opt_T,
"pre_calibration_ece": overall.ece,
"post_calibration_ece": post_cal.ece,
},
"per_step_calibration": per_step_cal,
"robustness": {
str(k): v for k, v in robustness.items()
},
}
if save_dir:
save_dir.mkdir(parents=True, exist_ok=True)
with open(save_dir / "calibration_analysis.json", "w") as f:
json.dump(output, f, indent=2)
logger.info(f" Saved to {save_dir / 'calibration_analysis.json'}")
return output
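# ================================================================
# Synthetic smoke test
# ================================================================
# Illustrative only: exercises compute_calibration, temperature_scale, and
# apply_temperature on synthetic data (the distribution parameters below are
# arbitrary). The real pipeline goes through run_calibration_analysis with
# AgentResult / MedicalCase objects.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    rng = np.random.RandomState(0)
    # Simulate an overconfident model: stated confidence around 0.9,
    # true accuracy around 0.7.
    demo_confs = np.clip(rng.normal(0.9, 0.05, size=500), 0.01, 0.99).tolist()
    demo_correct = (rng.rand(500) < 0.7).tolist()
    cal = compute_calibration(demo_confs, demo_correct)
    print(f"Raw: ECE={cal.ece:.3f}  Brier={cal.brier_score:.3f}  "
          f"conf={cal.mean_confidence:.3f}  acc={cal.mean_accuracy:.3f}")
    # Fit T on the first half, evaluate on the held-out second half.
    mid = len(demo_confs) // 2
    T, _ = temperature_scale(demo_confs[:mid], demo_correct[:mid])
    rescaled = apply_temperature(demo_confs[mid:], T)
    post = compute_calibration(rescaled, demo_correct[mid:])
    print(f"Temperature T={T:.2f}  post-calibration ECE={post.ece:.3f}")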