File size: 3,531 Bytes
fc329a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""Prediction-space stratification utilities."""
import numpy as np
from .simplex import entropy, ilr
from sklearn.cluster import KMeans


def stratify_by_entropy(U: np.ndarray, n_bins: int = 5) -> np.ndarray:
    """Assign each prediction to an equal-width Shannon-entropy bin.

    The bin range spans the smallest to the largest entropy observed in ``U``.
    Only the ``n_bins - 1`` interior cut points are passed to ``np.digitize``,
    so labels fall in ``0 .. n_bins - 1`` (all zeros when ``n_bins <= 1``).
    """
    scores = entropy(U)
    lo, hi = scores.min(), scores.max()
    # linspace yields n_bins + 1 boundaries; drop the two outer ones.
    interior = np.linspace(lo, hi, n_bins + 1)[1:-1]
    return np.digitize(scores, interior)


def stratify_by_boundary(U: np.ndarray, n_bins: int = 5) -> np.ndarray:
    """Assign each prediction to an equal-width bin of its smallest component.

    The smallest component of each row serves as a proximity-to-boundary
    score. The bin range spans the observed minimum to maximum of that score.
    Only the interior cut points are passed to ``np.digitize``.
    """
    closeness = U.min(axis=-1)
    boundaries = np.linspace(closeness.min(), closeness.max(), n_bins + 1)
    # Keep only the cut points strictly between the outer boundaries.
    return np.digitize(closeness, boundaries[1:-1])


def stratify_by_kmeans(U: np.ndarray, n_clusters: int = 5, seed: int = 42) -> np.ndarray:
    """Cluster predictions with K-means after mapping them to ILR coordinates.

    ``seed`` is forwarded to ``KMeans`` as ``random_state``, and 10
    initialisations are run. The cluster labels from ``fit_predict`` are
    returned.
    """
    clusterer = KMeans(n_clusters=n_clusters, random_state=seed, n_init=10)
    return clusterer.fit_predict(ilr(U))


def stratify_by_argmax_group(U: np.ndarray, split_index: int = 5) -> np.ndarray:
    """Split predictions into two groups by the position of their argmax.

    Rows whose argmax (along axis 1) is below ``split_index`` get label 0.
    All other rows get label 1.
    """
    winners = np.argmax(U, axis=1)
    return np.where(winners >= split_index, 1, 0).astype(int)


def _quantile_edges(values: np.ndarray, n_bins: int) -> np.ndarray:
    """Return stable interior quantile edges for binning a scalar score."""
    if n_bins <= 1:
        return np.array([], dtype=float)
    qs = np.linspace(0.0, 1.0, n_bins + 1)[1:-1]
    edges = np.quantile(values, qs)
    return np.unique(edges)


def _digitize_fixed(values: np.ndarray, n_bins: int) -> np.ndarray:
    """Label each value with its quantile bin, using edges from ``values`` itself.

    The edges come from ``_quantile_edges``. When no edges remain, every value
    receives label 0.
    """
    cut_points = _quantile_edges(values, n_bins)
    if cut_points.size:
        return np.digitize(values, cut_points)
    return np.zeros(len(values), dtype=int)


def precompute_fixed_strata(
    U: np.ndarray,
    method: str,
    n_strata: int = 5,
    seed: int = 42,
) -> np.ndarray:
    """Compute one stratification of a full cached prediction matrix.

    The result is meant to be reused across repeated cal/test splits of the
    same frozen task, so that the stratification rule does not change between
    repetitions.

    Args:
        U: full prediction matrix with shape (n, K).
        method: case-insensitive name, one of
            {"entropy", "boundary", "dominant", "kmeans", "random"}.
            - "entropy": quantile bins of ``entropy(U)``.
            - "boundary": quantile bins of the row-wise minimum component.
            - "dominant": the argmax component, optionally grouped into
              contiguous index ranges.
            - "random": quantile bins of ILR coordinates projected onto a
              seeded random unit direction.
            - "kmeans": K-means cluster labels in ILR coordinates.
        n_strata: target number of strata. For "dominant", this is the
            number of grouped dominant-component bins, clamped to
            ``[1, K]``. At ``K`` or above, every component is its own stratum.
        seed: seed for the random direction ("random") or for the K-means
            ``random_state`` ("kmeans").

    Returns:
        Integer stratum labels of shape (n,).

    Raises:
        ValueError: if ``method`` is not one of the names above.
    """
    key = method.lower()

    def _by_entropy() -> np.ndarray:
        return _digitize_fixed(entropy(U), n_strata)

    def _by_boundary() -> np.ndarray:
        return _digitize_fixed(U.min(axis=-1), n_strata)

    def _by_dominant() -> np.ndarray:
        n_components = U.shape[1]
        winners = np.argmax(U, axis=1)
        groups = min(max(int(n_strata), 1), n_components)
        if groups < n_components:
            # Map component index onto `groups` contiguous, equal-width ranges.
            return np.floor(winners * groups / n_components).astype(int)
        return winners.astype(int)

    def _by_random_projection() -> np.ndarray:
        coords = ilr(U)
        generator = np.random.default_rng(seed)
        axis_vec = generator.normal(size=coords.shape[1])
        # Small epsilon guards against a zero-norm draw.
        axis_vec = axis_vec / (np.linalg.norm(axis_vec) + 1e-12)
        return _digitize_fixed(coords @ axis_vec, n_strata)

    def _by_kmeans() -> np.ndarray:
        coords = ilr(U)
        model = KMeans(n_clusters=n_strata, random_state=seed, n_init=10)
        return model.fit_predict(coords)

    builders = {
        "entropy": _by_entropy,
        "boundary": _by_boundary,
        "dominant": _by_dominant,
        "random": _by_random_projection,
        "kmeans": _by_kmeans,
    }
    build = builders.get(key)
    if build is None:
        raise ValueError(f"Unknown fixed stratification method: {key}")
    return build()