File size: 486 Bytes
493de78
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
from .backend import xp, _rng

def sample_canonical(probs, seed, token_idx=0):
    """Draw one index from ``probs`` by inverse-CDF sampling.

    A single uniform ``u`` is drawn from the module-level ``_rng``. The
    returned index is the first position whose cumulative mass exceeds
    ``u * total``, where ``total`` is the sum of ``probs``.

    NOTE(review): despite the earlier "counter-based" wording, nothing here
    derives randomness from ``(seed, token_idx)``. Both parameters are kept
    for interface compatibility but are unused. Reproducibility relies on
    ``_rng`` being seeded elsewhere and on calls happening in the same order.

    Args:
        probs: Array of token weights. It is flattened by ``xp.cumsum`` and
            presumably 1-D, non-negative, and float32/float64 (TODO confirm
            against callers). It does not need to sum to exactly 1; it is
            treated as unnormalized mass.
        seed: Unused (see note above).
        token_idx: Unused (see note above).

    Returns:
        An ``int`` index in ``[0, n - 1]``, where ``n`` is the number of
        entries in ``probs``.

    Raises:
        ValueError: If ``probs`` is empty, or if its total mass is not a
            positive finite number.
    """
    # Draw u first so every call consumes exactly one variate from the shared
    # stream, as the original implementation did.
    u = _rng.random((), dtype=probs.dtype)
    cdf = xp.cumsum(probs, dtype=probs.dtype)
    n = cdf.shape[0]
    if n == 0:
        raise ValueError("sample_canonical: probs must be non-empty")
    total = cdf[-1]
    if not (bool(xp.isfinite(total)) and bool(total > 0)):
        raise ValueError(
            "sample_canonical: probs must have positive, finite total mass"
        )
    # Scale u by the actual total instead of clamping the CDF at 1.
    # Low-precision cumsums rarely end at exactly 1.0, and an unscaled u above
    # cdf[-1] would make searchsorted return n (one past the last valid index).
    target = u * total
    # side="right" gives token i the half-open interval [cdf[i-1], cdf[i]),
    # which matches u drawn from [0, 1). Zero-probability tokens get empty
    # intervals and can never be selected. (side="left" returned index 0 when
    # u == 0, even if probs[0] == 0.)
    idx = int(xp.searchsorted(cdf, target, side="right"))
    # Belt-and-braces against rounding in u * total.
    return min(idx, n - 1)