| """Partition (grouped) conformal prediction.""" |
| import numpy as np |
| from .base import ConformalResult |
| from ._split_quantile import split_conformal_quantile |
|
|
|
|
def partition_conformal(
    R_cal: np.ndarray,
    R_test: np.ndarray,
    alpha: float,
    strata_cal: np.ndarray,
    strata_test: np.ndarray,
) -> ConformalResult:
    """Grouped conformal prediction with one threshold per stratum.

    For every stratum label that appears in ``strata_test``, the calibration
    residuals carrying the same label are passed to
    ``split_conformal_quantile(R_cal_s, alpha)``. The result becomes that
    stratum's threshold. A test point is counted as covered when its residual
    is ``<=`` its stratum's threshold.

    A test stratum may have no calibration residuals at all. In that case its
    threshold and radius are ``inf`` and its test points are marked covered,
    which amounts to an unbounded interval.

    Args:
        R_cal: calibration residuals, shape ``(n_cal,)``.
        R_test: test residuals, shape ``(n_test,)``.
        alpha: miscoverage level, forwarded unchanged to
            ``split_conformal_quantile``.
        strata_cal: integer group labels for the calibration points, shape
            ``(n_cal,)``.
        strata_test: integer group labels for the test points, shape
            ``(n_test,)``.

    Returns:
        ConformalResult with the following fields:

        - ``covered``: boolean array ``(n_test,)``.
        - ``radius``: float array ``(n_test,)`` holding each test point's
          stratum threshold.
        - ``threshold``: dict mapping ``int(label)`` to the float threshold,
          for the strata present in ``strata_test`` only.

        The per-group coverage guarantee relies on data being exchangeable
        within each stratum and on ``split_conformal_quantile`` implementing
        the split-conformal quantile. NOTE(review): confirm the helper's
        contract.

    Raises:
        ValueError: if ``R_cal`` and ``strata_cal``, or ``R_test`` and
            ``strata_test``, have different lengths.
    """
    # Coerce array-likes (e.g. lists) so that `strata == s` produces an
    # elementwise boolean mask rather than a single Python bool.
    R_cal = np.asarray(R_cal)
    R_test = np.asarray(R_test)
    strata_cal = np.asarray(strata_cal)
    strata_test = np.asarray(strata_test)

    # Fail early with a clear message instead of an obscure boolean-index
    # error from NumPy inside the loop.
    if len(R_cal) != len(strata_cal):
        raise ValueError(
            "R_cal and strata_cal must have the same length, "
            f"got {len(R_cal)} and {len(strata_cal)}"
        )
    if len(R_test) != len(strata_test):
        raise ValueError(
            "R_test and strata_test must have the same length, "
            f"got {len(R_test)} and {len(strata_test)}"
        )

    n_test = len(R_test)
    covered = np.zeros(n_test, dtype=bool)
    # Default radius is infinite. Every test point is overwritten below
    # unless its stratum has no calibration data.
    radius = np.full(n_test, np.inf)
    thresholds = {}

    for s in np.unique(strata_test):
        test_mask = strata_test == s
        R_cal_s = R_cal[strata_cal == s]

        if len(R_cal_s) == 0:
            # No calibration residuals for this stratum, so no finite
            # threshold can be calibrated. Fall back to an unbounded
            # interval: radius stays inf and the points are trivially covered.
            covered[test_mask] = True
            thresholds[int(s)] = np.inf
            continue

        q_s = split_conformal_quantile(R_cal_s, alpha)
        covered[test_mask] = R_test[test_mask] <= q_s
        radius[test_mask] = q_s
        thresholds[int(s)] = float(q_s)

    return ConformalResult(covered=covered, radius=radius, threshold=thresholds)
|
|