MelodyDeterminism-Demo / sampling.py
Simo76's picture
Upload 9 files
493de78 verified
raw
history blame contribute delete
486 Bytes
from .backend import xp, _rng
def sample_canonical(probs, seed, token_idx=0):
    """Draw one token index from a probability vector by inverse-CDF search.

    A single uniform ``u`` is drawn from the module-level ``_rng`` (imported
    from ``.backend``). The function returns the first index whose cumulative
    probability reaches ``u`` (left search on the CDF).

    Args:
        probs: Array of non-negative token probabilities, expected to sum to
            approximately 1. ``cumsum`` flattens it, so a 1-D vector is
            expected. The draw uses ``probs.dtype``; presumably float32 or
            float64 (TODO confirm what callers pass).
        seed: Currently unused. Determinism relies on ``_rng`` being seeded
            elsewhere before a run. NOTE(review): despite the original
            "counter-based" wording, the result depends on how many draws
            ``_rng`` has already made, not on ``seed``/``token_idx``.
        token_idx: Currently unused; kept for interface compatibility.

    Returns:
        int: An index in ``[0, probs.size)`` of a token with positive
        probability.

    Raises:
        ValueError: If ``probs`` contains no positive entry (including when it
            is empty).
    """
    # Assuming _rng is a NumPy-style Generator, random() yields u in [0, 1).
    # Lift u == 0 to the smallest positive normal float so the left search
    # cannot stop on a leading zero-probability token (cdf[i] == 0 >= u).
    u = _rng.random((), dtype=probs.dtype)
    u = xp.maximum(u, xp.finfo(probs.dtype).tiny)

    cdf = xp.cumsum(probs, dtype=probs.dtype)
    idx = int(xp.searchsorted(cdf, u, side="left"))

    # Float rounding (or unnormalized input) can leave cdf[-1] below u.
    # The search then returns len(cdf), which is not a valid token, so fall
    # back to the last token that actually carries probability mass.
    if idx >= cdf.shape[0]:
        positive = xp.flatnonzero(probs > 0)
        if positive.size == 0:
            raise ValueError("probs must contain at least one positive entry")
        idx = int(positive[-1])
    return idx