"""Partition (grouped) conformal prediction."""
import numpy as np
from .base import ConformalResult
from ._split_quantile import split_conformal_quantile
def partition_conformal(
    R_cal: np.ndarray,
    R_test: np.ndarray,
    alpha: float,
    strata_cal: np.ndarray,
    strata_test: np.ndarray,
) -> ConformalResult:
    """Grouped conformal with per-stratum thresholds.

    Each test point is compared against a threshold computed only from the
    calibration residuals that carry the same stratum label as the test point.

    Args:
        R_cal: calibration residuals (n_cal,). Any array-like is accepted.
        R_test: test residuals (n_test,). Any array-like is accepted.
        alpha: miscoverage level, forwarded to ``split_conformal_quantile``.
        strata_cal: integer group labels for cal (n_cal,).
        strata_test: integer group labels for test (n_test,).

    Returns:
        ConformalResult with per-group thresholds:
          - ``covered``: bool array (n_test,), True where the test residual is
            <= the threshold of its stratum.
          - ``radius``: float array (n_test,), the threshold applied to each
            test point; ``np.inf`` for strata with no calibration points.
          - ``threshold``: dict ``int(label) -> threshold`` for every label
            that appears in ``strata_test``. Calibration-only labels are
            not included.

    Raises:
        ValueError: if ``R_cal``/``strata_cal`` or ``R_test``/``strata_test``
            have different lengths.
    """
    # Boolean-mask indexing below requires ndarrays; this also lets callers
    # pass plain lists.
    R_cal = np.asarray(R_cal)
    R_test = np.asarray(R_test)
    strata_cal = np.asarray(strata_cal)
    strata_test = np.asarray(strata_test)

    # Fail early with a clear message instead of an IndexError from masking
    # (or silent acceptance when strata_test happens to be empty).
    if len(R_cal) != len(strata_cal):
        raise ValueError(
            f"R_cal and strata_cal must have the same length, "
            f"got {len(R_cal)} and {len(strata_cal)}"
        )
    if len(R_test) != len(strata_test):
        raise ValueError(
            f"R_test and strata_test must have the same length, "
            f"got {len(R_test)} and {len(strata_test)}"
        )

    covered = np.zeros(len(R_test), dtype=bool)
    radius = np.full(len(R_test), np.inf)
    thresholds = {}
    # Only strata that occur in the test set need a threshold.
    for s in np.unique(strata_test):
        cal_mask = strata_cal == s
        test_mask = strata_test == s
        R_cal_s = R_cal[cal_mask]
        if len(R_cal_s) == 0:
            # No calibration points in this group: use an infinite threshold,
            # so every test point in it is covered (radius stays np.inf).
            covered[test_mask] = True
            thresholds[int(s)] = np.inf
            continue
        # NOTE(review): the exact quantile rule (e.g. finite-sample correction)
        # lives in split_conformal_quantile; confirm there.
        q_s = split_conformal_quantile(R_cal_s, alpha)
        covered[test_mask] = R_test[test_mask] <= q_s
        radius[test_mask] = q_s
        thresholds[int(s)] = float(q_s)
    return ConformalResult(covered=covered, radius=radius, threshold=thresholds)
|