Simo76's picture
Upload 9 files
493de78 verified
raw
history blame contribute delete
537 Bytes
from .backend import xp
from .deterministic import reduce_tree_fixed
def softmax_canonical(logits, mask=None):
    """Compute a numerically stable softmax with deterministic reductions.

    All arithmetic is done in float64. The max and sum reductions are
    delegated to ``reduce_tree_fixed`` (``op="max"`` / ``op="sum"``); the
    axis and output shape of those reductions are whatever that helper
    defines.

    Args:
        logits: Array of scores exposing ``astype`` and ``dtype``.
        mask: Optional array used as the condition of ``xp.where`` after
            being cast to bool. Positions where it is false are replaced
            by ``-inf`` before the softmax, so they end up with
            probability 0.

    Returns:
        Array of probabilities. Floating-point inputs get their own dtype
        back. Non-floating inputs (e.g. int or bool logits) get the float64
        result, because casting probabilities in [0, 1] back to an integer
        dtype would truncate them to 0/1. Any non-finite entry of the
        result (e.g. when every position is masked out) is replaced by 0.
    """
    # Promote once so the reductions and exp all run in float64.
    x = logits.astype(xp.float64, copy=False)
    if mask is not None:
        # Masked-out positions become -inf, i.e. exp(...) == 0 below.
        x = xp.where(mask.astype(bool), x, -xp.inf)

    # Subtract the max before exponentiating to avoid overflow.
    m = reduce_tree_fixed(x, op="max")
    z = xp.exp(x - m)
    s = reduce_tree_fixed(z, op="sum")
    p = z / s

    # If everything is -inf, (x - m) is NaN and so is p; map those to 0.
    p = xp.where(xp.isfinite(p), p, 0.0)

    # Only cast back when the input dtype can actually hold fractions.
    # Unknown dtype objects without ``kind`` keep the original cast-back.
    if getattr(logits.dtype, "kind", "f") != "f":
        return p
    return p.astype(logits.dtype, copy=False)