File size: 1,444 Bytes
fc329a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""Partition (grouped) conformal prediction."""
import numpy as np
from .base import ConformalResult
from ._split_quantile import split_conformal_quantile


def partition_conformal(
    R_cal: np.ndarray,
    R_test: np.ndarray,
    alpha: float,
    strata_cal: np.ndarray,
    strata_test: np.ndarray,
) -> ConformalResult:
    """Grouped conformal prediction with a separate threshold per stratum.

    Each test point is judged only against calibration residuals that share
    its stratum label. The threshold for a stratum is
    ``split_conformal_quantile`` applied to that stratum's calibration
    residuals at level ``alpha``.

    Args:
        R_cal: calibration residuals, shape (n_cal,). Array-like is accepted.
        R_test: test residuals, shape (n_test,). Array-like is accepted.
        alpha: miscoverage level, passed through to
            ``split_conformal_quantile``.
        strata_cal: group labels for the calibration points, shape (n_cal,).
            Any hashable scalar labels work (ints, strings, floats).
        strata_test: group labels for the test points, shape (n_test,).

    Returns:
        ConformalResult with:
          * ``covered``: bool array (n_test,). True where the test residual is
            <= its stratum's threshold.
          * ``radius``: float array (n_test,) holding each point's stratum
            threshold. It is ``inf`` for strata that have no calibration
            points.
          * ``threshold``: dict mapping each test-stratum label (as a plain
            Python scalar) to its float threshold. Labels that appear only in
            the calibration set are not included.

    Raises:
        ValueError: if a residual array and its label array differ in length.
    """
    # Coerce to arrays up front. With a plain list, ``strata == s`` is a
    # single bool rather than an element-wise mask. That would silently select
    # zero calibration points for every stratum.
    R_cal = np.asarray(R_cal)
    R_test = np.asarray(R_test)
    strata_cal = np.asarray(strata_cal)
    strata_test = np.asarray(strata_test)

    if len(R_cal) != len(strata_cal):
        raise ValueError(
            f"R_cal has {len(R_cal)} entries but strata_cal has {len(strata_cal)}"
        )
    if len(R_test) != len(strata_test):
        raise ValueError(
            f"R_test has {len(R_test)} entries but strata_test has {len(strata_test)}"
        )

    covered = np.zeros(len(R_test), dtype=bool)
    # Default radius is +inf. Test points in a stratum without calibration
    # data keep it.
    radius = np.full(len(R_test), np.inf)
    thresholds = {}

    for s in np.unique(strata_test):
        # Use a plain Python scalar as the dict key. int(s) would reject string
        # labels and merge distinct non-integer labels (e.g. 1.2 and 1.7 -> 1).
        # For integer labels .item() gives the same key as int(s).
        key = s.item() if isinstance(s, np.generic) else s
        cal_mask = strata_cal == s
        test_mask = strata_test == s
        R_cal_s = R_cal[cal_mask]

        if R_cal_s.size == 0:
            # No calibration points in this group: cover conservatively with an
            # infinite radius (radius already holds inf for these points).
            covered[test_mask] = True
            thresholds[key] = np.inf
            continue

        q_s = split_conformal_quantile(R_cal_s, alpha)
        covered[test_mask] = R_test[test_mask] <= q_s
        radius[test_mask] = q_s
        thresholds[key] = float(q_s)

    return ConformalResult(covered=covered, radius=radius, threshold=thresholds)