| """Weighted conformal prediction under test-dependent importance weights.""" |
| import numpy as np |
|
|
| from .base import ConformalResult |
|
|
|
|
def _weighted_quantile_with_test_mass(
    values: np.ndarray,
    weights: np.ndarray,
    alpha: float,
    test_weight: float,
) -> float:
    """Compute the weighted conformal threshold for one test point.

    The calibration scores carry ``weights`` and the test point contributes
    ``test_weight`` of mass at +inf, matching the weighted split conformal
    recipe used under covariate shift. The threshold is the smallest
    calibration score whose cumulative weight reaches ``(1 - alpha)`` of the
    total (calibration + test) mass.

    Args:
        values: calibration scores; converted to a float array.
        weights: weight of each calibration score, aligned with ``values``.
        alpha: miscoverage level.
        test_weight: mass assigned to the test point (placed at +inf).

    Returns:
        The selected calibration score, or ``inf`` when the calibration mass
        alone never reaches the cutoff (including the empty-calibration case).
    """
    values = np.asarray(values, dtype=float)
    weights = np.asarray(weights, dtype=float)

    if values.size == 0:
        # With no calibration scores only the +inf atom is left.
        return np.inf

    order = np.argsort(values)
    sorted_values = values[order]
    csum = np.cumsum(weights[order])

    # Take the calibration mass from the running sum itself rather than from
    # a separate ``sum()``: the two use different summation orders and can
    # disagree in the last bits, which would push the cutoff just above
    # ``csum[-1]`` and spuriously yield +inf.
    total_weight = csum[-1] + test_weight
    cutoff = (1.0 - alpha) * total_weight

    # First sorted position whose cumulative weight is >= cutoff.
    idx = np.searchsorted(csum, cutoff, side="left")
    if idx >= len(sorted_values):
        # The cutoff falls into the test point's mass at +inf.
        return np.inf
    return float(sorted_values[idx])
|
|
|
|
def weighted_conformal(
    R_cal: np.ndarray,
    R_test: np.ndarray,
    alpha: float,
    weights_cal: np.ndarray,
    weights_test: np.ndarray | None = None,
) -> ConformalResult:
    """Weighted split conformal.

    For every test point a separate threshold is computed from the weighted
    calibration residuals, with that test point's own weight added as mass
    at +inf. A test point is covered when its residual does not exceed its
    threshold.

    Args:
        R_cal: calibration residuals (n_cal,)
        R_test: test residuals (n_test,)
        alpha: miscoverage level in [0, 1]
        weights_cal: nonnegative calibration weights
        weights_test: nonnegative test weights. If omitted, uses ones.

    Returns:
        ConformalResult with test-specific weighted thresholds: ``covered``
        is the per-test-point coverage indicator, ``radius`` holds the
        thresholds and ``threshold`` is a copy of ``radius``.

    Raises:
        ValueError: if ``alpha`` is outside [0, 1] (or NaN), if a weight
            array does not match the shape of its residual array, contains
            non-finite or negative entries, or if the calibration weights
            have no positive total mass.
    """
    # NaN fails the chained comparison, so it is rejected here as well.
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")

    # Coerce the residuals once so shape checks and the final comparison
    # all operate on the same arrays.
    R_cal = np.asarray(R_cal)
    R_test = np.asarray(R_test)

    weights_cal = np.asarray(weights_cal, dtype=float)
    if weights_cal.shape != R_cal.shape:
        raise ValueError("weights_cal must have the same shape as R_cal")
    if not np.all(np.isfinite(weights_cal)):
        raise ValueError("weights_cal must be finite")
    if np.any(weights_cal < 0):
        raise ValueError("weights_cal must be nonnegative")
    if float(weights_cal.sum()) <= 0.0:
        raise ValueError("weights_cal must have positive total mass")

    if weights_test is None:
        weights_test = np.ones(len(R_test), dtype=float)
    else:
        weights_test = np.asarray(weights_test, dtype=float)
        if weights_test.shape != R_test.shape:
            raise ValueError("weights_test must have the same shape as R_test")
        if not np.all(np.isfinite(weights_test)):
            raise ValueError("weights_test must be finite")
        if np.any(weights_test < 0):
            raise ValueError("weights_test must be nonnegative")

    # One threshold per test point: the test weight changes the total mass
    # and therefore the cutoff.
    radius = np.array([
        _weighted_quantile_with_test_mass(R_cal, weights_cal, alpha, w_t)
        for w_t in weights_test
    ])
    covered = R_test <= radius
    return ConformalResult(covered=covered, radius=radius, threshold=radius.copy())
|
|