# panacea-api / src/evaluation/conformal.py
# Uploaded by DTanzillo via huggingface_hub (commit a4b5ecb, verified)
# Generated by Claude Code — 2026-02-13
"""Conformal prediction for calibrated risk bounds.
Provides distribution-free prediction sets with guaranteed marginal coverage:
P(true_label ∈ prediction_set) ≥ 1 - alpha
This directly addresses NASA CARA's criticism about uncertainty quantification
in ML-based collision risk assessment. Instead of a single probability, we
output a prediction set (e.g., {LOW, MODERATE}) that provably covers the
true risk tier at the specified confidence level.
Method: Split conformal prediction (Vovk et al. 2005, Lei et al. 2018)
- Calibrate on a held-out set separate from training AND model selection
- Compute nonconformity scores
- Use quantile of calibration scores to construct prediction sets at test time
References:
- Vovk, Gammerman, Shafer (2005) "Algorithmic Learning in a Random World"
- Lei et al. (2018) "Distribution-Free Predictive Inference for Regression"
- Angelopoulos & Bates (2021) "A Gentle Introduction to Conformal Prediction"
"""
import numpy as np
from dataclasses import dataclass
@dataclass
class ConformalResult:
    """Result of conformal prediction for a single example.

    Bundles the discrete tier prediction set with the continuous probability
    interval it was derived from (produced by ConformalPredictor.predict).
    """
    prediction_set: list[str]  # tiers that may contain the true risk, e.g. ["LOW", "MODERATE"]
    set_size: int              # |prediction_set|; smaller sets are more informative
    risk_prob: float           # raw model probability for this example
    lower_bound: float         # lower end of the probability interval, clipped to [0, 1]
    upper_bound: float         # upper end of the probability interval, clipped to [0, 1]
class ConformalPredictor:
    """Split conformal prediction for binary risk classification.

    Workflow:
        1. Train model on training set
        2. Select model (early stopping) on validation set
        3. calibrate() on a SEPARATE calibration set (held out from validation)
        4. predict() on test data with coverage guarantee

    The calibration set must NOT be used for training or model selection,
    otherwise the coverage guarantee is invalidated.
    """

    # Risk tiers as half-open intervals [lo, hi) over the probability axis.
    # A probability of exactly 1.0 maps to CRITICAL (see evaluate()).
    TIERS = {
        "LOW": (0.0, 0.10),
        "MODERATE": (0.10, 0.40),
        "HIGH": (0.40, 0.70),
        "CRITICAL": (0.70, 1.0),
    }

    def __init__(self) -> None:
        # Calibration state, populated by calibrate(). The quantiles are
        # initialised under the names the class actually uses (the previous
        # quantile_lower / quantile_upper were never read or written), so
        # accessing them before calibration yields None instead of raising
        # AttributeError.
        self.q_hat: float | None = None       # nonconformity-score quantile
        self.q_residual: float | None = None  # |prob - label| residual quantile
        self.alpha: float | None = None       # miscoverage rate set by calibrate()
        self.n_cal: int = 0                   # calibration set size
        self.is_calibrated: bool = False

    def calibrate(
        self,
        cal_probs: np.ndarray,
        cal_labels: np.ndarray,
        alpha: float = 0.10,
    ) -> dict:
        """Calibrate conformal predictor on held-out calibration set.

        Args:
            cal_probs: Model predicted probabilities on calibration set, shape (n,)
            cal_labels: True binary labels on calibration set, shape (n,)
            alpha: Desired miscoverage rate. 1-alpha = coverage level.
                   alpha=0.10 → 90% coverage guarantee.

        Returns:
            Calibration summary dict with quantiles and statistics

        Raises:
            ValueError: If the calibration set has fewer than 10 examples.
        """
        n = len(cal_probs)
        if n < 10:
            raise ValueError(f"Calibration set too small: {n} examples (need >= 10)")
        self.alpha = alpha
        self.n_cal = n

        # Nonconformity score: how "wrong" is the model on each calibration
        # example? For binary classification with probabilities,
        # score = 1 - P(true class); high score = model is wrong/uncertain.
        scores = np.where(
            cal_labels == 1,
            1.0 - cal_probs,  # positive: score = 1 - P(positive)
            cal_probs,        # negative: score = P(positive) = 1 - P(negative)
        )

        # Conformal quantile with the finite-sample correction:
        #   q_hat = ceil((n+1)(1-alpha))/n -th quantile of scores.
        adjusted_level = np.ceil((n + 1) * (1 - alpha)) / n
        adjusted_level = min(adjusted_level, 1.0)
        # method="higher" selects the ceil-indexed order statistic. numpy's
        # default linear interpolation can fall strictly below it, which
        # slightly undercovers and voids the finite-sample guarantee
        # (Angelopoulos & Bates 2021).
        self.q_hat = float(np.quantile(scores, adjusted_level, method="higher"))

        # Residual quantile for prediction intervals on the probability
        # itself, using calibration residuals |P(positive) - is_positive|.
        residuals = np.abs(cal_probs - cal_labels.astype(float))
        self.q_residual = float(np.quantile(residuals, adjusted_level, method="higher"))

        self.is_calibrated = True

        # Report calibration statistics (coverage on the calibration set is
        # diagnostic only; the guarantee applies to exchangeable test data).
        empirical_coverage = np.mean(scores <= self.q_hat)
        summary = {
            "alpha": alpha,
            "target_coverage": 1 - alpha,
            "n_calibration": n,
            "q_hat": self.q_hat,
            "q_residual": self.q_residual,
            "empirical_coverage_cal": float(empirical_coverage),
            "mean_score": float(scores.mean()),
            "median_score": float(np.median(scores)),
            "cal_pos_rate": float(cal_labels.mean()),
        }
        print(f"  Conformal calibration (alpha={alpha}):")
        print(f"    Calibration set: {n} examples ({cal_labels.sum():.0f} positive)")
        print(f"    q_hat (nonconformity): {self.q_hat:.4f}")
        print(f"    q_residual: {self.q_residual:.4f}")
        print(f"    Empirical coverage (cal): {empirical_coverage:.4f}")
        return summary

    def predict(self, test_probs: np.ndarray) -> "list[ConformalResult]":
        """Produce conformal prediction sets for test examples.

        For each test example, returns:
        - Prediction set: set of risk tiers that could contain the true risk
        - Probability bounds: [lower, upper] interval on the true probability

        Coverage guarantee: P(true_tier ∈ prediction_set) ≥ 1 - alpha

        Raises:
            RuntimeError: If calibrate() has not been called.
        """
        if not self.is_calibrated:
            raise RuntimeError("Must call calibrate() before predict()")
        results = []
        for p in test_probs:
            # Probability bounds from the residual quantile, clipped to [0, 1].
            lower = max(0.0, p - self.q_residual)
            upper = min(1.0, p + self.q_residual)
            # Prediction set: all tiers that overlap with [lower, upper].
            pred_set = [
                tier_name
                for tier_name, (tier_lo, tier_hi) in self.TIERS.items()
                if lower < tier_hi and upper > tier_lo
            ]
            if not pred_set:
                # Degenerate case: a zero-width interval sitting exactly on a
                # tier boundary matches no tier under the strict-overlap test.
                # Fall back to the tier containing the point prediction so the
                # set is never empty (an empty set trivially breaks coverage).
                for tier_name, (tier_lo, tier_hi) in self.TIERS.items():
                    if tier_lo <= p < tier_hi:
                        pred_set.append(tier_name)
                        break
                else:
                    pred_set.append("CRITICAL")  # p == 1.0
            results.append(ConformalResult(
                prediction_set=pred_set,
                set_size=len(pred_set),
                risk_prob=float(p),
                lower_bound=lower,
                upper_bound=upper,
            ))
        return results

    def evaluate(
        self,
        test_probs: np.ndarray,
        test_labels: np.ndarray,
    ) -> dict:
        """Evaluate conformal prediction on test set.

        Reports:
        - Marginal coverage: fraction of test examples where true label
          falls within prediction set
        - Average set size: how informative are the predictions
        - Coverage by tier: per-tier coverage (conditional coverage)
        - Efficiency: 1 - (avg_set_size / n_tiers)

        Raises:
            RuntimeError: If calibrate() has not been called.
        """
        if not self.is_calibrated:
            raise RuntimeError("Must call calibrate() before evaluate()")
        results = self.predict(test_probs)

        # Map a probability to its tier for the coverage check.
        def label_to_tier(prob: float) -> str:
            for tier_name, (lo, hi) in self.TIERS.items():
                if lo <= prob < hi:
                    return tier_name
            return "CRITICAL"  # prob == 1.0 (not covered by any half-open tier)

        # True "tier" based on actual label value (binary: 0 → LOW, 1 → CRITICAL).
        true_tiers = [label_to_tier(float(l)) for l in test_labels]

        # Marginal coverage: does the prediction set contain the true tier?
        covered = [
            true_tier in result.prediction_set
            for true_tier, result in zip(true_tiers, results)
        ]
        marginal_coverage = np.mean(covered)

        # Average set size (informativeness).
        set_sizes = [r.set_size for r in results]
        avg_set_size = np.mean(set_sizes)

        # Coverage conditioned on the true label value.
        pos_mask = test_labels == 1
        neg_mask = test_labels == 0
        pos_coverage = np.mean([c for c, m in zip(covered, pos_mask) if m]) if pos_mask.sum() > 0 else 0.0
        neg_coverage = np.mean([c for c, m in zip(covered, neg_mask) if m]) if neg_mask.sum() > 0 else 0.0

        # Set size distribution.
        size_counts = {}
        for s in set_sizes:
            size_counts[s] = size_counts.get(s, 0) + 1

        # Efficiency: lower set sizes = more informative.
        efficiency = 1.0 - (avg_set_size / len(self.TIERS))

        # Interval width statistics.
        widths = [r.upper_bound - r.lower_bound for r in results]

        metrics = {
            "alpha": self.alpha,
            "target_coverage": 1 - self.alpha,
            "marginal_coverage": float(marginal_coverage),
            # 1% slack tolerates finite-sample noise in the empirical estimate.
            "coverage_guarantee_met": bool(marginal_coverage >= (1 - self.alpha - 0.01)),
            "avg_set_size": float(avg_set_size),
            "efficiency": float(efficiency),
            "positive_coverage": float(pos_coverage),
            "negative_coverage": float(neg_coverage),
            "set_size_distribution": {str(k): v for k, v in sorted(size_counts.items())},
            "n_test": len(test_labels),
            "mean_interval_width": float(np.mean(widths)),
            "median_interval_width": float(np.median(widths)),
        }
        print(f"\n  Conformal Prediction Evaluation (alpha={self.alpha}):")
        print(f"    Target coverage: {1 - self.alpha:.1%}")
        print(f"    Marginal coverage: {marginal_coverage:.1%} "
              f"{'OK' if metrics['coverage_guarantee_met'] else 'VIOLATION'}")
        print(f"    Positive coverage: {pos_coverage:.1%}")
        print(f"    Negative coverage: {neg_coverage:.1%}")
        print(f"    Avg set size: {avg_set_size:.2f} / {len(self.TIERS)} tiers")
        print(f"    Efficiency: {efficiency:.1%}")
        print(f"    Mean interval: [{np.mean([r.lower_bound for r in results]):.3f}, "
              f"{np.mean([r.upper_bound for r in results]):.3f}]")
        print(f"    Set size dist: {size_counts}")
        return metrics

    def save_state(self) -> dict:
        """Serialize calibration state for checkpoint saving."""
        if not self.is_calibrated:
            return {"is_calibrated": False}
        return {
            "is_calibrated": True,
            "alpha": self.alpha,
            "q_hat": self.q_hat,
            "q_residual": self.q_residual,
            "n_cal": self.n_cal,
            # Tiers are saved for inspection only; from_state() uses the
            # class-level TIERS constant rather than this entry.
            "tiers": {k: list(v) for k, v in self.TIERS.items()},
        }

    @classmethod
    def from_state(cls, state: dict) -> "ConformalPredictor":
        """Restore from serialized state produced by save_state()."""
        obj = cls()
        if state.get("is_calibrated", False):
            obj.alpha = state["alpha"]
            obj.q_hat = state["q_hat"]
            obj.q_residual = state["q_residual"]
            obj.n_cal = state["n_cal"]
            obj.is_calibrated = True
        return obj
def run_conformal_at_multiple_levels(
    cal_probs: np.ndarray,
    cal_labels: np.ndarray,
    test_probs: np.ndarray,
    test_labels: np.ndarray,
    alphas: list[float] | None = None,
) -> dict:
    """Run conformal prediction at multiple coverage levels.

    Useful for reporting: "at 90% coverage, avg set size = X;
    at 95%, avg set size = Y; at 99%, avg set size = Z"

    Args:
        cal_probs: Model probabilities on the calibration set, shape (n,)
        cal_labels: True binary labels on the calibration set, shape (n,)
        test_probs: Model probabilities on the test set, shape (m,)
        test_labels: True binary labels on the test set, shape (m,)
        alphas: Miscoverage rates to evaluate; defaults to
            [0.01, 0.05, 0.10, 0.20]. (Annotation fixed: the parameter
            accepts None, so the type is list[float] | None.)

    Returns:
        Dict keyed by "alpha_{alpha}", each entry holding the evaluation
        metrics and the serialized calibration state for that level.
    """
    if alphas is None:
        alphas = [0.01, 0.05, 0.10, 0.20]
    all_results = {}
    for alpha in alphas:
        # Fresh predictor per level: each alpha has its own quantiles.
        cp = ConformalPredictor()
        cp.calibrate(cal_probs, cal_labels, alpha=alpha)
        eval_metrics = cp.evaluate(test_probs, test_labels)
        all_results[f"alpha_{alpha}"] = {
            "conformal_metrics": eval_metrics,
            "conformal_state": cp.save_state(),
        }
    return all_results