from .backend import xp, _rng


def sample_canonical(probs, seed, token_idx=0):
    """Sample one index from ``probs`` by inverse-CDF (left) search.

    A single uniform value ``u`` is drawn from the module-level generator
    ``_rng``. The returned index is the first one whose cumulative
    probability is ``>= u``.

    Args:
        probs: Array of probabilities (presumably non-negative and summing to
            ~1 -- confirm with callers). ``cumsum`` flattens it, and its dtype
            is used for both the uniform draw and the CDF.
        seed: Currently unused. Reproducibility depends on ``_rng`` having
            been seeded before the run. Kept for interface compatibility.
        token_idx: Currently unused. Kept for interface compatibility.

    Returns:
        int: An index into the flattened ``probs``.

    Raises:
        ValueError: If ``probs`` is empty.
    """
    cdf = xp.cumsum(probs, dtype=probs.dtype)
    if cdf.shape[0] == 0:
        raise ValueError("probs must contain at least one element")
    one = probs.dtype.type(1.0)
    # Guard against the CDF overshooting 1 due to accumulated rounding.
    cdf = xp.minimum(cdf, one)

    # Generator.random draws from [0, 1). With a left search, u == 0 would
    # select index 0 even when probs[0] == 0. Nudge that single value to the
    # smallest positive float; every other draw is used unchanged.
    u = _rng.random((), dtype=probs.dtype)
    if u == 0:
        u = xp.nextafter(probs.dtype.type(0.0), one)

    idx = int(xp.searchsorted(cdf, u, side="left"))
    if idx >= cdf.shape[0]:
        # Rounding can leave cdf[-1] slightly below u, which would yield an
        # out-of-range index. Fall back to the first index where the CDF
        # reaches its final value. For non-negative probs, that is the last
        # entry with positive probability.
        idx = int(xp.searchsorted(cdf, cdf[-1], side="left"))
    return idx