File size: 12,030 Bytes
a4b5ecb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
# Generated by Claude Code — 2026-02-13
"""Conformal prediction for calibrated risk bounds.



Provides distribution-free prediction sets with guaranteed marginal coverage:

    P(true_label ∈ prediction_set) ≥ 1 - alpha



This directly addresses NASA CARA's criticism about uncertainty quantification

in ML-based collision risk assessment. Instead of a single probability, we

output a prediction set (e.g., {LOW, MODERATE}) that provably covers the

true risk tier at the specified confidence level.



Method: Split conformal prediction (Vovk et al. 2005, Lei et al. 2018)

- Calibrate on a held-out set separate from training AND model selection

- Compute nonconformity scores

- Use quantile of calibration scores to construct prediction sets at test time



References:

- Vovk, Gammerman, Shafer (2005) "Algorithmic Learning in a Random World"

- Lei et al. (2018) "Distribution-Free Predictive Inference for Regression"

- Angelopoulos & Bates (2021) "A Gentle Introduction to Conformal Prediction"

"""

from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass
class ConformalResult:
    """Conformal prediction output for a single test example.

    Bundles the tier-level prediction set with a conformal interval on the
    underlying probability; instances are produced by
    ConformalPredictor.predict().
    """
    prediction_set: list[str]      # risk tiers compatible with this example, e.g. ["LOW", "MODERATE"]
    set_size: int                  # len(prediction_set); smaller = more informative
    risk_prob: float               # raw (uncalibrated) model probability
    lower_bound: float             # conformal lower bound on the probability, clipped to [0, 1]
    upper_bound: float             # conformal upper bound on the probability, clipped to [0, 1]


class ConformalPredictor:
    """Split conformal prediction for binary risk classification.

    Workflow:
    1. Train model on training set
    2. Select model (early stopping) on validation set
    3. calibrate() on a SEPARATE calibration set (held out from validation)
    4. predict() on test data with coverage guarantee

    The calibration set must NOT be used for training or model selection,
    otherwise the coverage guarantee is invalidated.
    """

    # Risk tiers as half-open probability intervals [lo, hi). CRITICAL is
    # treated as closed at 1.0 so that prob == 1.0 still maps to a tier.
    TIERS = {
        "LOW": (0.0, 0.10),
        "MODERATE": (0.10, 0.40),
        "HIGH": (0.40, 0.70),
        "CRITICAL": (0.70, 1.0),
    }

    def __init__(self):
        # Calibrated state, populated by calibrate() / from_state().
        # (Previously __init__ initialized unused quantile_lower/quantile_upper
        # while q_hat/q_residual were created only inside calibrate().)
        self.q_hat = None        # conformal quantile of nonconformity scores
        self.q_residual = None   # conformal quantile of |prob - label| residuals
        self.alpha = None        # miscoverage rate used at calibration
        self.n_cal = 0           # calibration set size
        self.is_calibrated = False

    def calibrate(
        self,
        cal_probs: np.ndarray,
        cal_labels: np.ndarray,
        alpha: float = 0.10,
    ) -> dict:
        """Calibrate conformal predictor on held-out calibration set.

        Args:
            cal_probs: Model predicted probabilities on calibration set, shape (n,)
            cal_labels: True binary labels on calibration set, shape (n,)
            alpha: Desired miscoverage rate. 1-alpha = coverage level.
                   alpha=0.10 → 90% coverage guarantee.

        Returns:
            Calibration summary dict with quantiles and statistics

        Raises:
            ValueError: if the calibration set has fewer than 10 examples
                or alpha is outside the open interval (0, 1).
        """
        n = len(cal_probs)
        if n < 10:
            raise ValueError(f"Calibration set too small: {n} examples (need >= 10)")
        if not 0.0 < alpha < 1.0:
            raise ValueError(f"alpha must be in (0, 1), got {alpha}")

        self.alpha = alpha
        self.n_cal = n

        # Nonconformity score: how "wrong" is the model on each calibration example?
        # For binary classification with probabilities:
        #   score = 1 - P(true class)
        # High score = model is wrong/uncertain
        scores = np.where(
            cal_labels == 1,
            1.0 - cal_probs,    # positive: score = 1 - P(positive)
            cal_probs,           # negative: score = P(positive) = 1 - P(negative)
        )

        # Conformal quantile with finite-sample correction:
        #   q_hat = ceil((n+1)(1-alpha))/n -th empirical quantile of scores.
        # method="higher" returns an actual order statistic (no interpolation),
        # which the split-conformal coverage proof requires
        # (Angelopoulos & Bates 2021, Sec. 1).
        adjusted_level = min(np.ceil((n + 1) * (1 - alpha)) / n, 1.0)
        self.q_hat = float(np.quantile(scores, adjusted_level, method="higher"))

        # For prediction intervals on the probability itself, take the same
        # conformal quantile of the absolute residuals |P(positive) - label|.
        residuals = np.abs(cal_probs - cal_labels.astype(float))
        self.q_residual = float(np.quantile(residuals, adjusted_level, method="higher"))

        self.is_calibrated = True

        # Sanity check: empirical coverage on the calibration set itself is
        # >= 1 - alpha by construction of q_hat.
        empirical_coverage = np.mean(scores <= self.q_hat)

        summary = {
            "alpha": alpha,
            "target_coverage": 1 - alpha,
            "n_calibration": n,
            "q_hat": self.q_hat,
            "q_residual": self.q_residual,
            "empirical_coverage_cal": float(empirical_coverage),
            "mean_score": float(scores.mean()),
            "median_score": float(np.median(scores)),
            "cal_pos_rate": float(cal_labels.mean()),
        }

        print(f"  Conformal calibration (alpha={alpha}):")
        print(f"    Calibration set: {n} examples ({cal_labels.sum():.0f} positive)")
        print(f"    q_hat (nonconformity): {self.q_hat:.4f}")
        print(f"    q_residual:         {self.q_residual:.4f}")
        print(f"    Empirical coverage (cal): {empirical_coverage:.4f}")

        return summary

    def predict(self, test_probs: np.ndarray) -> list["ConformalResult"]:
        """Produce conformal prediction sets for test examples.

        For each test example, returns:
        - Prediction set: set of risk tiers that could contain the true risk
        - Probability bounds: [lower, upper] interval on the true probability

        Coverage guarantee: P(true_tier ∈ prediction_set) ≥ 1 - alpha

        Raises:
            RuntimeError: if calibrate() has not been called.
        """
        if not self.is_calibrated:
            raise RuntimeError("Must call calibrate() before predict()")

        results = []
        for p in test_probs:
            p = float(p)
            # Probability bounds from the residual quantile, clipped to [0, 1].
            lower = max(0.0, p - self.q_residual)
            upper = min(1.0, p + self.q_residual)

            # Prediction set: all tiers whose interval intersects [lower, upper].
            # Tiers are half-open [lo, hi) except CRITICAL (closed at 1.0), so
            # [lower, upper] overlaps a tier iff upper >= lo and lower < hi
            # (lower <= hi for CRITICAL). Using ">=" on the lower edge — unlike
            # the strict ">" — keeps a boundary point like upper == 0.10 in
            # MODERATE and guarantees a degenerate interval (q_residual == 0,
            # or p == 1.0) still yields a non-empty set.
            pred_set = []
            for tier_name, (tier_lo, tier_hi) in self.TIERS.items():
                below_top = lower <= tier_hi if tier_name == "CRITICAL" else lower < tier_hi
                if upper >= tier_lo and below_top:
                    pred_set.append(tier_name)

            results.append(ConformalResult(
                prediction_set=pred_set,
                set_size=len(pred_set),
                risk_prob=p,
                lower_bound=lower,
                upper_bound=upper,
            ))

        return results

    def evaluate(
        self,
        test_probs: np.ndarray,
        test_labels: np.ndarray,
    ) -> dict:
        """Evaluate conformal prediction on test set.

        Reports:
        - Marginal coverage: fraction of test examples where true label
          falls within prediction set
        - Average set size: how informative are the predictions
        - Coverage by tier: per-tier coverage (conditional coverage)
        - Efficiency: 1 - (avg_set_size / n_tiers)

        Raises:
            RuntimeError: if calibrate() has not been called.
            ValueError: if the test set is empty.
        """
        if not self.is_calibrated:
            raise RuntimeError("Must call calibrate() before evaluate()")
        if len(test_labels) == 0:
            raise ValueError("Empty test set: cannot evaluate coverage")

        results = self.predict(test_probs)

        # Map labels to tiers for coverage check.
        def label_to_tier(prob: float) -> str:
            # Tiers are half-open [lo, hi); prob == 1.0 falls through to CRITICAL.
            for tier_name, (lo, hi) in self.TIERS.items():
                if lo <= prob < hi:
                    return tier_name
            return "CRITICAL"  # prob == 1.0

        # True "tier" based on the binary label (0.0 -> LOW, 1.0 -> CRITICAL).
        true_tiers = [label_to_tier(float(l)) for l in test_labels]

        # Marginal coverage: does the prediction set contain the true tier?
        covered = [
            true_tier in result.prediction_set
            for true_tier, result in zip(true_tiers, results)
        ]
        marginal_coverage = np.mean(covered)

        # Average set size (informativeness).
        set_sizes = [r.set_size for r in results]
        avg_set_size = np.mean(set_sizes)

        # Conditional coverage split by true label value.
        pos_mask = test_labels == 1
        neg_mask = test_labels == 0
        pos_coverage = np.mean([c for c, m in zip(covered, pos_mask) if m]) if pos_mask.sum() > 0 else 0.0
        neg_coverage = np.mean([c for c, m in zip(covered, neg_mask) if m]) if neg_mask.sum() > 0 else 0.0

        # Distribution of prediction-set sizes.
        size_counts = {}
        for s in set_sizes:
            size_counts[s] = size_counts.get(s, 0) + 1

        # Efficiency: lower set sizes = more informative.
        efficiency = 1.0 - (avg_set_size / len(self.TIERS))

        # Interval width statistics.
        widths = [r.upper_bound - r.lower_bound for r in results]

        metrics = {
            "alpha": self.alpha,
            "target_coverage": 1 - self.alpha,
            "marginal_coverage": float(marginal_coverage),
            "coverage_guarantee_met": bool(marginal_coverage >= (1 - self.alpha - 0.01)),
            "avg_set_size": float(avg_set_size),
            "efficiency": float(efficiency),
            "positive_coverage": float(pos_coverage),
            "negative_coverage": float(neg_coverage),
            "set_size_distribution": {str(k): v for k, v in sorted(size_counts.items())},
            "n_test": len(test_labels),
            "mean_interval_width": float(np.mean(widths)),
            "median_interval_width": float(np.median(widths)),
        }

        print(f"\n  Conformal Prediction Evaluation (alpha={self.alpha}):")
        print(f"    Target coverage:   {1 - self.alpha:.1%}")
        print(f"    Marginal coverage: {marginal_coverage:.1%} "
              f"{'OK' if metrics['coverage_guarantee_met'] else 'VIOLATION'}")
        print(f"    Positive coverage: {pos_coverage:.1%}")
        print(f"    Negative coverage: {neg_coverage:.1%}")
        print(f"    Avg set size:      {avg_set_size:.2f} / {len(self.TIERS)} tiers")
        print(f"    Efficiency:        {efficiency:.1%}")
        print(f"    Mean interval:     [{np.mean([r.lower_bound for r in results]):.3f}, "
              f"{np.mean([r.upper_bound for r in results]):.3f}]")
        print(f"    Set size dist:     {size_counts}")

        return metrics

    def save_state(self) -> dict:
        """Serialize calibration state for checkpoint saving."""
        if not self.is_calibrated:
            return {"is_calibrated": False}
        return {
            "is_calibrated": True,
            "alpha": self.alpha,
            "q_hat": self.q_hat,
            "q_residual": self.q_residual,
            "n_cal": self.n_cal,
            "tiers": {k: list(v) for k, v in self.TIERS.items()},
        }

    @classmethod
    def from_state(cls, state: dict) -> "ConformalPredictor":
        """Restore from serialized state (inverse of save_state)."""
        obj = cls()
        if state.get("is_calibrated", False):
            obj.alpha = state["alpha"]
            obj.q_hat = state["q_hat"]
            obj.q_residual = state["q_residual"]
            obj.n_cal = state["n_cal"]
            obj.is_calibrated = True
        return obj


def run_conformal_at_multiple_levels(
    cal_probs: np.ndarray,
    cal_labels: np.ndarray,
    test_probs: np.ndarray,
    test_labels: np.ndarray,
    alphas: Optional[list[float]] = None,
) -> dict:
    """Run conformal prediction at multiple coverage levels.

    Useful for reporting: "at 90% coverage, avg set size = X;
    at 95%, avg set size = Y; at 99%, avg set size = Z"

    Args:
        cal_probs: Calibration-set predicted probabilities, shape (n,)
        cal_labels: Calibration-set binary labels, shape (n,)
        test_probs: Test-set predicted probabilities, shape (m,)
        test_labels: Test-set binary labels, shape (m,)
        alphas: Miscoverage rates to sweep. Defaults to
            [0.01, 0.05, 0.10, 0.20] when None.

    Returns:
        Dict keyed by "alpha_<value>", each entry holding the evaluation
        metrics and the serialized calibration state for that level.
    """
    if alphas is None:
        alphas = [0.01, 0.05, 0.10, 0.20]

    all_results = {}
    for alpha in alphas:
        # Fresh predictor per level: the conformal quantiles depend on alpha.
        cp = ConformalPredictor()
        cp.calibrate(cal_probs, cal_labels, alpha=alpha)
        eval_metrics = cp.evaluate(test_probs, test_labels)
        all_results[f"alpha_{alpha}"] = {
            "conformal_metrics": eval_metrics,
            "conformal_state": cp.save_state(),
        }

    return all_results