| """ | |
| Calibration Analysis for ActiveMedAgent. | |
| Measures whether the VLM's reported probabilities match empirical | |
| accuracy. Key analyses for the ACL/EMNLP submission: | |
| 1. Reliability Diagram: binned confidence vs accuracy | |
| 2. Expected Calibration Error (ECE): scalar miscalibration summary | |
| 3. Temperature Scaling: post-hoc recalibration on held-out set | |
| 4. Robustness to Miscalibration: does the method work with noisy probs? | |
| 5. Per-Step Calibration: is calibration better/worse at different steps? | |
| """ | |
import json
import logging
from dataclasses import dataclass
from pathlib import Path

import numpy as np
from scipy.optimize import minimize_scalar
from scipy.stats import spearmanr

from agent import AgentResult
from datasets.base import MedicalCase

logger = logging.getLogger(__name__)

# ================================================================
# Core Calibration Metrics
# ================================================================

@dataclass
class CalibrationBin:
    """A single bin in a reliability diagram."""

    bin_lower: float
    bin_upper: float
    bin_center: float
    avg_confidence: float
    avg_accuracy: float
    count: int
    gap: float  # |avg_confidence - avg_accuracy|

@dataclass
class CalibrationResult:
    """Full calibration analysis for a set of predictions."""

    ece: float  # Expected Calibration Error
    mce: float  # Maximum Calibration Error
    ace: float  # Average Calibration Error
    bins: list[CalibrationBin]
    n_predictions: int
    mean_confidence: float
    mean_accuracy: float
    overconfidence_ratio: float  # Fraction of non-empty bins where conf > acc
    brier_score: float  # Brier score (mean squared error of probabilities)

def compute_calibration(
    confidences: list[float],
    correctness: list[bool],
    n_bins: int = 10,
) -> CalibrationResult:
    """
    Compute calibration metrics from confidence-correctness pairs.

    Args:
        confidences: Model's stated probability for its top prediction
        correctness: Whether the top prediction was correct
        n_bins: Number of bins for the reliability diagram

    Returns:
        CalibrationResult with ECE, MCE, bins, etc.
    """
    if not confidences:
        return CalibrationResult(
            ece=0, mce=0, ace=0, bins=[], n_predictions=0,
            mean_confidence=0, mean_accuracy=0,
            overconfidence_ratio=0, brier_score=0,
        )

    confs = np.array(confidences, dtype=np.float64)
    accs = np.array(correctness, dtype=np.float64)
    n = len(confs)

    bin_boundaries = np.linspace(0.0, 1.0, n_bins + 1)
    bins = []
    ece = 0.0
    mce = 0.0
    overconf_count = 0

    for i in range(n_bins):
        lower = bin_boundaries[i]
        upper = bin_boundaries[i + 1]
        # Half-open bins (lower, upper]; the first bin is closed below so a
        # confidence of exactly 0.0 is not dropped.
        if i == 0:
            mask = (confs >= lower) & (confs <= upper)
        else:
            mask = (confs > lower) & (confs <= upper)
        count = int(mask.sum())

        if count == 0:
            bins.append(CalibrationBin(
                bin_lower=lower, bin_upper=upper,
                bin_center=(lower + upper) / 2,
                avg_confidence=0, avg_accuracy=0,
                count=0, gap=0,
            ))
            continue

        avg_conf = confs[mask].mean()
        avg_acc = accs[mask].mean()
        gap = abs(avg_conf - avg_acc)

        ece += (count / n) * gap
        mce = max(mce, gap)
        if avg_conf > avg_acc:
            overconf_count += 1

        bins.append(CalibrationBin(
            bin_lower=lower, bin_upper=upper,
            bin_center=(lower + upper) / 2,
            avg_confidence=float(avg_conf),
            avg_accuracy=float(avg_acc),
            count=count,
            gap=float(gap),
        ))

    non_empty_bins = [b for b in bins if b.count > 0]
    ace = float(np.mean([b.gap for b in non_empty_bins])) if non_empty_bins else 0.0

    # Brier score: mean squared error between stated probability and outcome.
    brier = np.mean((confs - accs) ** 2)

    return CalibrationResult(
        ece=float(ece),
        mce=float(mce),
        ace=float(ace),
        bins=bins,
        n_predictions=n,
        mean_confidence=float(confs.mean()),
        mean_accuracy=float(accs.mean()),
        overconfidence_ratio=overconf_count / len(non_empty_bins) if non_empty_bins else 0,
        brier_score=float(brier),
    )
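

# Illustrative sanity check (synthetic numbers, not from any experiment):
# a predictor that claims 80% and is right ~80% of the time should score a
# much lower ECE than one that claims 95% but is right only ~50% of the
# time. Wrapped in a function so importing this module has no side effects.
def _calibration_sanity_check() -> None:
    """Synthetic demo of compute_calibration (illustrative only)."""
    rng = np.random.RandomState(0)
    well_calibrated = compute_calibration(
        confidences=[0.8] * 200,
        correctness=[bool(x) for x in rng.rand(200) < 0.8],
    )
    overconfident = compute_calibration(
        confidences=[0.95] * 200,
        correctness=[bool(x) for x in rng.rand(200) < 0.5],
    )
    logger.info(
        "sanity check: ECE(calibrated)=%.3f  ECE(overconfident)=%.3f",
        well_calibrated.ece, overconfident.ece,
    )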

# ================================================================
# Extract Predictions from Agent Results
# ================================================================

def extract_predictions(
    results: list[AgentResult],
    cases: list[MedicalCase],
) -> tuple[list[float], list[bool]]:
    """
    Extract (confidence, correctness) pairs from agent results.

    Returns:
        confidences: top-1 stated probability
        correctness: whether top-1 matches ground truth
    """
    confidences = []
    correctness = []
    for result, case in zip(results, cases):
        if not result.final_ranking:
            continue
        top = result.final_ranking[0]
        conf = top.get("confidence", 0.0)
        name = top.get("name", "").strip().lower()
        gt = case.ground_truth.strip().lower()
        # Fuzzy match: exact, or substring in either direction. Guard against
        # an empty name, which would otherwise match everything via `in`.
        correct = bool(name) and (name == gt or name in gt or gt in name)
        confidences.append(conf)
        correctness.append(correct)
    return confidences, correctness
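

# Note on the fuzzy match above (made-up strings, illustrative only):
# "pneumonia" counts as correct against a ground truth of
# "bacterial pneumonia" via the substring test, while
# "pericarditis" vs. "myocarditis" does not match.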

def extract_per_step_predictions(
    results: list[AgentResult],
    cases: list[MedicalCase],
) -> dict[int, tuple[list[float], list[bool]]]:
    """
    Extract predictions at each acquisition step.

    Returns:
        {step_idx: (confidences, correctness)}
    """
    step_data: dict[int, tuple[list, list]] = {}
    for result, case in zip(results, cases):
        gt = case.ground_truth.strip().lower()
        for step in result.steps:
            if not step.differential:
                continue
            idx = step.step
            if idx not in step_data:
                step_data[idx] = ([], [])
            top = max(step.differential, key=lambda d: d.get("confidence", 0))
            conf = top.get("confidence", 0.0)
            name = top.get("name", "").strip().lower()
            correct = bool(name) and (name == gt or name in gt or gt in name)
            step_data[idx][0].append(conf)
            step_data[idx][1].append(correct)
    return step_data

# ================================================================
# Temperature Scaling
# ================================================================

def temperature_scale(
    confidences: list[float],
    correctness: list[bool],
    candidates_per_case: list[int] | None = None,
) -> tuple[float, float]:
    """
    Find the temperature T that minimizes ECE on held-out data.

    Classic temperature scaling divides logits by T before the softmax. For
    a single top-1 probability we use the binary simplification:

        logit = log(p / (1 - p))
        scaled_logit = logit / T
        p_scaled = sigmoid(scaled_logit)

    Args:
        confidences: Raw model confidences
        correctness: Whether predictions were correct
        candidates_per_case: Number of candidates per case (reserved;
            currently unused)

    Returns:
        (optimal_temperature, calibrated_ece)
    """
    confs = np.array(confidences, dtype=np.float64)
    accs = np.array(correctness, dtype=np.float64)

    # Clip to avoid log(0) at the logit transform.
    confs = np.clip(confs, 1e-6, 1 - 1e-6)
    logits = np.log(confs / (1 - confs))

    def ece_at_temperature(T: float) -> float:
        scaled_logits = logits / T
        scaled_confs = 1.0 / (1.0 + np.exp(-scaled_logits))
        # 10-bin ECE of the rescaled confidences.
        n_bins = 10
        bins = np.linspace(0, 1, n_bins + 1)
        ece = 0.0
        n = len(scaled_confs)
        for i in range(n_bins):
            mask = (scaled_confs > bins[i]) & (scaled_confs <= bins[i + 1])
            if mask.sum() == 0:
                continue
            bin_conf = scaled_confs[mask].mean()
            bin_acc = accs[mask].mean()
            ece += (mask.sum() / n) * abs(bin_conf - bin_acc)
        return ece

    result = minimize_scalar(
        ece_at_temperature,
        bounds=(0.1, 10.0),
        method="bounded",
    )
    optimal_T = result.x
    calibrated_ece = ece_at_temperature(optimal_T)
    return float(optimal_T), float(calibrated_ece)


def apply_temperature(
    confidences: list[float], temperature: float
) -> list[float]:
    """Apply temperature scaling to a list of confidences."""
    confs = np.array(confidences, dtype=np.float64)
    confs = np.clip(confs, 1e-6, 1 - 1e-6)
    logits = np.log(confs / (1 - confs))
    scaled_logits = logits / temperature
    scaled_confs = 1.0 / (1.0 + np.exp(-scaled_logits))
    return scaled_confs.tolist()
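

# Illustrative sketch (synthetic, assumed numbers): for predictions that are
# systematically overconfident, the fitted temperature should come out well
# above 1, pulling confidences back toward the empirical accuracy.
def _temperature_scaling_example() -> None:
    """Synthetic demo of temperature_scale + apply_temperature."""
    rng = np.random.RandomState(0)
    confs = [0.9] * 200                               # model claims 90%...
    correct = [bool(x) for x in rng.rand(200) < 0.6]  # ...right ~60% of the time
    T, post_ece = temperature_scale(confs, correct)
    scaled = apply_temperature(confs, T)
    # With T > 1 the logits shrink, so 0.9 is mapped to roughly 0.6.
    logger.info("T=%.2f  scaled_conf=%.2f  post_ECE=%.4f", T, scaled[0], post_ece)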

# ================================================================
# Robustness to Miscalibration
# ================================================================

def test_calibration_robustness(
    results: list[AgentResult],
    cases: list[MedicalCase],
    noise_levels: list[float] | None = None,
    n_trials: int = 10,
    seed: int = 42,
) -> dict[float, dict]:
    """
    Test whether the agent's acquisition decisions are robust to
    probability miscalibration.

    For each noise level, we perturb the agent's reported probabilities in
    logit space and check how stable the diagnosis ranking and the
    acquisition/stopping decisions are.

    Args:
        noise_levels: Standard deviations of Gaussian noise added to logits
        n_trials: Number of random trials per noise level

    Returns:
        {noise_level: {order_stability, stop_stability, ...}}
    """
    if noise_levels is None:
        noise_levels = [0.0, 0.1, 0.25, 0.5, 1.0, 2.0]

    rng = np.random.RandomState(seed)
    robustness = {}

    # Collect original acquisition orders and per-step differentials.
    original_orders = []
    original_distributions = []
    for result in results:
        original_orders.append(tuple(result.acquired_channels))
        step_dists = []
        for step in result.steps:
            if step.differential:
                dist = {
                    d.get("name", ""): d.get("confidence", 0)
                    for d in step.differential
                }
                step_dists.append(dist)
        original_distributions.append(step_dists)

    for noise in noise_levels:
        order_matches = 0
        stop_matches = 0
        total = len(results)

        if noise == 0.0:
            robustness[noise] = {
                "order_stability": 1.0,
                "stop_stability": 1.0,
                "mean_rank_correlation": 1.0,
                "n_cases": total,
            }
            continue

        rank_correlations = []
        for _trial in range(n_trials):
            trial_order_matches = 0
            trial_stop_matches = 0
            trial_rank_corrs = []
            for i, (result, dists) in enumerate(
                zip(results, original_distributions)
            ):
                if not dists:
                    continue
                # Perturb each step's differential distribution.
                for dist in dists:
                    probs = np.array(list(dist.values()), dtype=np.float64)
                    probs = np.clip(probs, 1e-6, 1 - 1e-6)
                    # Add Gaussian noise in logit space, then renormalize.
                    logits = np.log(probs / (1 - probs))
                    noisy_logits = logits + rng.normal(0, noise, len(logits))
                    noisy_probs = 1.0 / (1.0 + np.exp(-noisy_logits))
                    noisy_probs /= noisy_probs.sum()
                    # Spearman rank correlation between the original and
                    # perturbed distributions (undefined for one candidate
                    # or for all-tied probabilities, hence the guards).
                    if len(probs) > 1:
                        corr, _ = spearmanr(probs, noisy_probs)
                        if np.isfinite(corr):
                            trial_rank_corrs.append(float(corr))
                # Without re-running the acquisition policy we cannot
                # recompute the order or stopping step, so these counters are
                # optimistic placeholders (they always match);
                # mean_rank_correlation is the informative signal here.
                if tuple(result.acquired_channels) == original_orders[i]:
                    trial_order_matches += 1
                    trial_stop_matches += 1
            if total > 0:
                order_matches += trial_order_matches / total
                stop_matches += trial_stop_matches / total
            if trial_rank_corrs:
                rank_correlations.extend(trial_rank_corrs)

        robustness[noise] = {
            "order_stability": order_matches / n_trials if n_trials > 0 else 0,
            "stop_stability": stop_matches / n_trials if n_trials > 0 else 0,
            "mean_rank_correlation": float(np.mean(rank_correlations)) if rank_correlations else 1.0,
            "n_cases": total,
        }

    return robustness
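

# Illustrative sketch of the perturbation model above (toy numbers): a
# single three-way differential perturbed once in logit space. With a noise
# scale comparable to the logit gap between the top candidates, the top-1
# diagnosis can flip.
def _logit_noise_example(noise: float = 0.5, seed: int = 0) -> None:
    """Perturb one toy probability vector in logit space."""
    rng = np.random.RandomState(seed)
    probs = np.array([0.55, 0.30, 0.15])
    logits = np.log(probs / (1 - probs))
    noisy = 1.0 / (1.0 + np.exp(-(logits + rng.normal(0, noise, len(probs)))))
    noisy /= noisy.sum()
    logger.info("original top=%d  noisy top=%d", np.argmax(probs), np.argmax(noisy))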

# ================================================================
# Full Calibration Analysis Pipeline
# ================================================================

def run_calibration_analysis(
    results: list[AgentResult],
    cases: list[MedicalCase],
    save_dir: Path | None = None,
) -> dict:
    """
    Run the complete calibration analysis suite.

    Returns a dict with all metrics; saves to disk if save_dir is provided.
    """
    logger.info("Running calibration analysis...")

    # 1. Overall calibration
    confidences, correctness = extract_predictions(results, cases)
    overall = compute_calibration(confidences, correctness)
    logger.info(f"  ECE: {overall.ece:.4f}")
    logger.info(f"  MCE: {overall.mce:.4f}")
    logger.info(f"  Brier Score: {overall.brier_score:.4f}")
    logger.info(f"  Mean Confidence: {overall.mean_confidence:.3f}")
    logger.info(f"  Mean Accuracy: {overall.mean_accuracy:.3f}")
    logger.info(f"  Overconfidence Ratio: {overall.overconfidence_ratio:.2f}")

    # 2. Temperature scaling on a 50/50 calibration/test split
    if len(confidences) >= 10:
        n = len(confidences)
        mid = n // 2
        cal_confs, cal_correct = confidences[:mid], correctness[:mid]
        test_confs, test_correct = confidences[mid:], correctness[mid:]
        opt_T, _ = temperature_scale(cal_confs, cal_correct)
        scaled_test = apply_temperature(test_confs, opt_T)
        post_cal = compute_calibration(scaled_test, test_correct)
        logger.info(f"  Optimal Temperature: {opt_T:.3f}")
        logger.info(f"  Post-calibration ECE: {post_cal.ece:.4f}")
    else:
        opt_T = 1.0
        post_cal = overall

    # 3. Per-step calibration
    step_data = extract_per_step_predictions(results, cases)
    per_step_cal = {}
    for step_idx, (step_confs, step_correct) in sorted(step_data.items()):
        if len(step_confs) >= 5:
            step_cal = compute_calibration(step_confs, step_correct, n_bins=5)
            per_step_cal[step_idx] = {
                "ece": step_cal.ece,
                "mean_confidence": step_cal.mean_confidence,
                "mean_accuracy": step_cal.mean_accuracy,
                "n_predictions": step_cal.n_predictions,
            }
            logger.info(
                f"  Step {step_idx}: ECE={step_cal.ece:.4f}, "
                f"Conf={step_cal.mean_confidence:.3f}, "
                f"Acc={step_cal.mean_accuracy:.3f} (n={step_cal.n_predictions})"
            )

    # 4. Robustness analysis
    robustness = test_calibration_robustness(results, cases)
    for noise, metrics in robustness.items():
        logger.info(
            f"  Noise={noise:.2f}: rank_corr={metrics['mean_rank_correlation']:.3f}"
        )

    # Compile output
    output = {
        "overall": {
            "ece": overall.ece,
            "mce": overall.mce,
            "ace": overall.ace,
            "brier_score": overall.brier_score,
            "mean_confidence": overall.mean_confidence,
            "mean_accuracy": overall.mean_accuracy,
            "overconfidence_ratio": overall.overconfidence_ratio,
            "n_predictions": overall.n_predictions,
            "bins": [
                {
                    "center": b.bin_center,
                    "confidence": b.avg_confidence,
                    "accuracy": b.avg_accuracy,
                    "count": b.count,
                    "gap": b.gap,
                }
                for b in overall.bins
            ],
        },
        "temperature_scaling": {
            "optimal_temperature": opt_T,
            "pre_calibration_ece": overall.ece,
            "post_calibration_ece": post_cal.ece,
        },
        "per_step_calibration": per_step_cal,
        "robustness": {
            str(k): v for k, v in robustness.items()
        },
    }

    if save_dir:
        save_dir.mkdir(parents=True, exist_ok=True)
        with open(save_dir / "calibration_analysis.json", "w") as f:
            json.dump(output, f, indent=2)
        logger.info(f"  Saved to {save_dir / 'calibration_analysis.json'}")

    return output
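

# Entry point for a quick smoke test. Running the module directly executes
# only the synthetic sketches above; wiring in real AgentResult /
# MedicalCase loading is left to the evaluation driver, since this module
# does not define a deserialization format.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    _calibration_sanity_check()
    _temperature_scaling_example()
    _logit_noise_example()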