# File size: 3,531 Bytes | fc329a3
"""Prediction-space stratification utilities."""
import numpy as np
from .simplex import entropy, ilr
from sklearn.cluster import KMeans
def stratify_by_entropy(U: np.ndarray, n_bins: int = 5) -> np.ndarray:
    """Bin predictions by Shannon entropy using equal-width bins.

    The scores returned by ``entropy(U)`` are cut at the ``n_bins - 1``
    interior points of an evenly spaced grid that spans the observed minimum
    and maximum score.

    Args:
        U: prediction array, passed unchanged to ``entropy``.
        n_bins: number of equal-width bins.

    Returns:
        Integer bin labels with the same shape as the entropy scores. An empty
        score array yields an empty label array instead of raising.
    """
    H = entropy(U)
    if np.size(H) == 0:
        # min()/max() are undefined on an empty array; there is nothing to bin.
        return np.zeros(np.shape(H), dtype=np.intp)
    # Drop the two outer grid points so only interior cut points remain.
    interior_edges = np.linspace(H.min(), H.max(), n_bins + 1)[1:-1]
    return np.digitize(H, interior_edges)
def stratify_by_boundary(U: np.ndarray, n_bins: int = 5) -> np.ndarray:
    """Bin predictions by proximity to the simplex boundary.

    Proximity is measured by the smallest component along the last axis. The
    scores are cut at the ``n_bins - 1`` interior points of an evenly spaced
    grid that spans the observed minimum and maximum score.

    Args:
        U: prediction array; the minimum is taken over its last axis.
        n_bins: number of equal-width bins.

    Returns:
        Integer bin labels, one per prediction. Input with no predictions
        yields an empty label array instead of raising.
    """
    bprox = U.min(axis=-1)
    if bprox.size == 0:
        # min()/max() are undefined on an empty array; there is nothing to bin.
        return np.zeros(bprox.shape, dtype=np.intp)
    # Drop the two outer grid points so only interior cut points remain.
    interior_edges = np.linspace(bprox.min(), bprox.max(), n_bins + 1)[1:-1]
    return np.digitize(bprox, interior_edges)
def stratify_by_kmeans(U: np.ndarray, n_clusters: int = 5, seed: int = 42) -> np.ndarray:
    """Cluster predictions with k-means in ILR space.

    Args:
        U: prediction array, mapped through ``ilr`` before clustering.
        n_clusters: number of clusters requested from ``KMeans``.
        seed: value passed to ``KMeans`` as ``random_state``.

    Returns:
        The cluster labels produced by ``KMeans.fit_predict``.
    """
    embedded = ilr(U)
    clusterer = KMeans(n_clusters=n_clusters, random_state=seed, n_init=10)
    labels = clusterer.fit_predict(embedded)
    return labels
def stratify_by_argmax_group(U: np.ndarray, split_index: int = 5) -> np.ndarray:
    """Split predictions into two groups by the index of their largest component.

    Args:
        U: prediction array; the argmax is taken over its last axis, so a
            single prediction vector or a batch with leading dimensions is
            accepted as well as an ``(n, K)`` matrix.
        split_index: first component index that belongs to group 1.

    Returns:
        Integer labels: 0 where the argmax is before ``split_index``, 1 where
        it is at or after ``split_index``.
    """
    # Last axis (not a hard-coded axis 1) keeps this in line with the other
    # per-prediction reductions and is identical for 2-D input.
    top_class = np.argmax(U, axis=-1)
    return (top_class >= split_index).astype(int)
def _quantile_edges(values: np.ndarray, n_bins: int) -> np.ndarray:
"""Return stable interior quantile edges for binning a scalar score."""
if n_bins <= 1:
return np.array([], dtype=float)
qs = np.linspace(0.0, 1.0, n_bins + 1)[1:-1]
edges = np.quantile(values, qs)
return np.unique(edges)
def _digitize_fixed(values: np.ndarray, n_bins: int) -> np.ndarray:
    """Assign each value to a bin delimited by globally fixed quantile edges.

    The edges come from ``_quantile_edges(values, n_bins)``. When that yields
    no edges, every value is placed in bin 0.
    """
    cut_points = _quantile_edges(values, n_bins)
    if cut_points.size > 0:
        return np.digitize(values, cut_points)
    return np.zeros(len(values), dtype=int)
def precompute_fixed_strata(
    U: np.ndarray,
    method: str,
    n_strata: int = 5,
    seed: int = 42,
) -> np.ndarray:
    """Compute one fixed stratification over a full cached prediction matrix.

    Meant for repeated cal/test splits of the same frozen task, where the
    stratification rule has to stay the same across repetitions.

    Args:
        U: full prediction matrix with shape (n, K).
        method: case-insensitive name of the rule, one of

            * "entropy": quantile bins of ``entropy(U)``.
            * "boundary": quantile bins of the smallest component per row.
            * "dominant": strata derived from the argmax component.
            * "random": quantile bins of ``ilr(U)`` projected onto a random
              direction drawn from ``seed``.
            * "kmeans": ``KMeans`` cluster labels computed on ``ilr(U)``.
        n_strata: target number of strata. For "dominant", this is the number
            of grouped dominant-component bins; if it is at least K, each
            dominant component gets its own stratum.
        seed: random seed for the methods that draw random numbers.

    Returns:
        Integer stratum labels of shape (n,).

    Raises:
        ValueError: if ``method`` is not one of the names above.
    """
    method = method.lower()

    if method == "kmeans":
        embedded = ilr(U)
        clusterer = KMeans(n_clusters=n_strata, random_state=seed, n_init=10)
        return clusterer.fit_predict(embedded)

    if method == "dominant":
        dominant = np.argmax(U, axis=1)
        n_components = U.shape[1]
        # Clamp the requested group count into [1, K].
        n_groups = min(max(int(n_strata), 1), n_components)
        if n_groups < n_components:
            # Map component indices proportionally onto n_groups groups.
            return np.floor(dominant * n_groups / n_components).astype(int)
        return dominant.astype(int)

    # The remaining methods reduce each prediction to one scalar score and
    # share the same quantile binning.
    if method == "entropy":
        score = entropy(U)
    elif method == "boundary":
        score = U.min(axis=-1)
    elif method == "random":
        embedded = ilr(U)
        rng = np.random.default_rng(seed)
        direction = rng.normal(size=embedded.shape[1])
        # The small constant guards the division against a zero-length draw.
        direction = direction / (np.linalg.norm(direction) + 1e-12)
        score = embedded @ direction
    else:
        raise ValueError(f"Unknown fixed stratification method: {method}")
    return _digitize_fixed(score, n_strata)
|