"""DGP with discrete heterogeneity aligned to coarse simplex groups."""

import numpy as np

from .base import BaseDGP, DGPSample
from .pure_scale import PureScaleDGP
from ..utils.simplex import aitchison_dist, ilr, ilr_inv


class DiscreteGroupsDGP(PureScaleDGP):
    """Two-level (step-function) noise scale keyed on the arg-max class.

    Each row is assigned ``sigma_low`` when the index of its largest mean
    component is below ``easy_classes`` and ``sigma_high`` otherwise.

    Parameters
    ----------
    K : int
        Forwarded to the parent constructor; the noise drawn in
        :meth:`sample` has ``K - 1`` columns.
    sigma_low : float
        Noise scale for rows in the "easy" group.  Also forwarded to the
        parent as ``sigma_min``.
    sigma_high : float
        Noise scale for all other rows.  The parent receives
        ``c = sigma_high - sigma_low``.
    d_x : int
        Forwarded to the parent constructor; :meth:`sample` draws ``d_x``
        covariate columns.
    easy_classes : int
        Rows whose arg-max index is strictly below this value are "easy".
    """

    def __init__(
        self,
        K: int = 10,
        sigma_low: float = 0.08,
        sigma_high: float = 0.30,
        d_x: int = 5,
        easy_classes: int = 5,
    ):
        # The parent is parameterised by a floor and a spread, so pass the
        # gap between the two levels as ``c``.
        spread = sigma_high - sigma_low
        super().__init__(K=K, sigma_min=sigma_low, c=spread, d_x=d_x)
        self.sigma_low = sigma_low
        self.sigma_high = sigma_high
        self.easy_classes = easy_classes

    def _sigma(self, u: np.ndarray) -> np.ndarray:
        """Return the per-row noise scale for a 2-D array ``u``.

        The arg-max is taken along axis 1, so ``u`` must have at least two
        dimensions; the result has one entry per row.
        """
        easy_mask = np.argmax(u, axis=1) < self.easy_classes
        return np.where(easy_mask, self.sigma_low, self.sigma_high)

    def sample(self, n: int, rng: np.random.Generator) -> DGPSample:
        """Draw ``n`` observations using ``rng``.

        The random draws happen in a fixed order (weight initialisation,
        covariates, then noise), so the output is reproducible for a given
        generator state.

        Returns a ``DGPSample`` whose ``U`` is the mean returned by the
        inherited ``_mu``, whose ``Y`` is that mean perturbed in ``ilr``
        coordinates, whose ``R`` is ``aitchison_dist(Y, U)`` and whose
        ``sigma_true`` is the per-row scale from :meth:`_sigma`.
        """
        self._init_weights(rng)

        covariates = rng.standard_normal((n, self.d_x))
        center = self._mu(covariates)
        scale = self._sigma(center)

        # Perturb the mean in ilr coordinates, then map back with ilr_inv.
        center_coords = ilr(center)
        noise = rng.standard_normal((n, self.K - 1))
        perturbed = center_coords + scale[:, None] * noise
        outcome = ilr_inv(perturbed, K=self.K)

        residual = aitchison_dist(outcome, center)
        return DGPSample(
            X=covariates,
            Y=outcome,
            U=center,
            R=residual,
            sigma_true=scale,
        )