"""Prediction-space stratification utilities.""" import numpy as np from .simplex import entropy, ilr from sklearn.cluster import KMeans def stratify_by_entropy(U: np.ndarray, n_bins: int = 5) -> np.ndarray: """Bin predictions by Shannon entropy.""" H = entropy(U) return np.digitize(H, np.linspace(H.min(), H.max(), n_bins + 1)[1:-1]) def stratify_by_boundary(U: np.ndarray, n_bins: int = 5) -> np.ndarray: """Bin predictions by proximity to simplex boundary (min component).""" bprox = U.min(axis=-1) return np.digitize(bprox, np.linspace(bprox.min(), bprox.max(), n_bins + 1)[1:-1]) def stratify_by_kmeans(U: np.ndarray, n_clusters: int = 5, seed: int = 42) -> np.ndarray: """K-means clustering in ILR space.""" Z = ilr(U) km = KMeans(n_clusters=n_clusters, random_state=seed, n_init=10) return km.fit_predict(Z) def stratify_by_argmax_group(U: np.ndarray, split_index: int = 5) -> np.ndarray: """Two-group stratification based on whether argmax(U) is before split_index.""" top_class = np.argmax(U, axis=1) return (top_class >= split_index).astype(int) def _quantile_edges(values: np.ndarray, n_bins: int) -> np.ndarray: """Return stable interior quantile edges for binning a scalar score.""" if n_bins <= 1: return np.array([], dtype=float) qs = np.linspace(0.0, 1.0, n_bins + 1)[1:-1] edges = np.quantile(values, qs) return np.unique(edges) def _digitize_fixed(values: np.ndarray, n_bins: int) -> np.ndarray: """Digitize values using globally fixed quantile edges.""" edges = _quantile_edges(values, n_bins) if edges.size == 0: return np.zeros(len(values), dtype=int) return np.digitize(values, edges) def precompute_fixed_strata( U: np.ndarray, method: str, n_strata: int = 5, seed: int = 42, ) -> np.ndarray: """Precompute a fixed stratification on a full cached prediction matrix. This is intended for repeated cal/test splits of the same frozen task, where the stratification rule should remain constant across repetitions. Args: U: full prediction matrix with shape (n, K). method: one of {"entropy", "boundary", "dominant", "kmeans", "random"}. n_strata: target number of strata. For "dominant", this is treated as the number of grouped dominant-component bins; if it is at least K, each dominant component gets its own stratum. seed: random seed for methods that require stochastic initialization. Returns: Integer stratum labels of shape (n,). """ method = method.lower() if method == "entropy": return _digitize_fixed(entropy(U), n_strata) if method == "boundary": return _digitize_fixed(U.min(axis=-1), n_strata) if method == "dominant": top = np.argmax(U, axis=1) k = U.shape[1] n_groups = min(max(int(n_strata), 1), k) if n_groups >= k: return top.astype(int) return np.floor(top * n_groups / k).astype(int) if method == "random": z = ilr(U) rng = np.random.default_rng(seed) direction = rng.normal(size=z.shape[1]) direction /= np.linalg.norm(direction) + 1e-12 score = z @ direction return _digitize_fixed(score, n_strata) if method == "kmeans": z = ilr(U) km = KMeans(n_clusters=n_strata, random_state=seed, n_init=10) return km.fit_predict(z) raise ValueError(f"Unknown fixed stratification method: {method}")