simplexuq-code / src /dgp /discrete_groups.py
anonymous0523ly's picture
Initial anonymous code release
fc329a3 verified
raw
history blame
1.38 kB
"""DGP with discrete heterogeneity aligned to coarse simplex groups."""
import numpy as np
from .base import BaseDGP, DGPSample
from .pure_scale import PureScaleDGP
from ..utils.simplex import aitchison_dist, ilr, ilr_inv
class DiscreteGroupsDGP(PureScaleDGP):
    """Pure-scale DGP whose noise scale takes one of two discrete values.

    A row gets ``sigma_low`` when the arg-max index of its mean composition
    is strictly below ``easy_classes`` and ``sigma_high`` otherwise, so the
    scale is a step function of the top class.
    """

    def __init__(
        self,
        K: int = 10,
        sigma_low: float = 0.08,
        sigma_high: float = 0.30,
        d_x: int = 5,
        easy_classes: int = 5,
    ):
        """Configure the two scale levels and the easy/hard class split.

        Args:
            K: Forwarded unchanged to ``PureScaleDGP``.
                (presumably the number of simplex parts -- confirm in parent)
            sigma_low: Noise scale for rows whose top class is "easy"; also
                passed to the parent as ``sigma_min``.
            sigma_high: Noise scale for every other row.
            d_x: Forwarded unchanged to ``PureScaleDGP``.
                (presumably the covariate dimension -- confirm in parent)
            easy_classes: Class indices ``0 .. easy_classes - 1`` count as
                easy.
        """
        # The parent is parameterised by a floor (sigma_min) and a range (c);
        # the range here is the gap between the two step levels.
        scale_gap = sigma_high - sigma_low
        super().__init__(K=K, sigma_min=sigma_low, c=scale_gap, d_x=d_x)

        self.sigma_low = sigma_low
        self.sigma_high = sigma_high
        self.easy_classes = easy_classes

    def _sigma(self, u: np.ndarray) -> np.ndarray:
        """Return the per-row noise scale for the rows of ``u``.

        Args:
            u: 2-D array; the arg-max is taken along axis 1.

        Returns:
            1-D array with one entry per row of ``u``, each equal to
            ``sigma_low`` or ``sigma_high``.
        """
        winner = np.argmax(u, axis=1)
        # A row is "hard" exactly when its winning index is not in the
        # leading block of easy classes.
        is_hard = winner >= self.easy_classes
        return np.where(is_hard, self.sigma_high, self.sigma_low)

    def sample(self, n: int, rng: np.random.Generator) -> DGPSample:
        """Draw ``n`` observations from the process.

        The random draws happen in a fixed order: weight initialisation,
        then the covariates, then the noise.

        Args:
            n: Number of rows to generate.
            rng: NumPy random generator used for all draws.

        Returns:
            A ``DGPSample`` holding the covariates ``X``, the noisy
            compositions ``Y``, the means ``U``, the distances ``R`` between
            ``Y`` and ``U`` as computed by ``aitchison_dist``, and the scale
            used for each row as ``sigma_true``.
        """
        self._init_weights(rng)

        covariates = rng.standard_normal((n, self.d_x))
        mean_comp = self._mu(covariates)
        scale = self._sigma(mean_comp)

        # Perturb the mean in ilr coordinates (K - 1 columns), then map the
        # result back with ilr_inv.
        mean_coords = ilr(mean_comp)
        shocks = rng.standard_normal((n, self.K - 1))
        scaled_shocks = scale[:, np.newaxis] * shocks
        noisy_comp = ilr_inv(mean_coords + scaled_shocks, K=self.K)

        distance = aitchison_dist(noisy_comp, mean_comp)
        return DGPSample(
            X=covariates,
            Y=noisy_comp,
            U=mean_comp,
            R=distance,
            sigma_true=scale,
        )