Spaces:
Sleeping
Sleeping
File size: 537 Bytes
493de78 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | from .backend import xp
from .deterministic import reduce_tree_fixed
def softmax_canonical(logits, mask=None):
    """Return a max-shifted (numerically stable) softmax of ``logits``.

    The max and sum reductions are delegated to ``reduce_tree_fixed`` so
    that the reduction order is the one that helper defines.
    NOTE(review): the reduction axis and the shape of the reduced values are
    decided by ``reduce_tree_fixed`` and are not visible here -- confirm
    that they broadcast against ``logits`` as intended.

    Args:
        logits: Array-like object exposing ``astype`` and ``dtype``
            (presumably an ``xp`` array -- confirm against callers).
        mask: Optional array exposing ``astype``. Positions where the mask
            is falsy are replaced by ``-inf`` before the softmax, so they
            receive zero probability.

    Returns:
        The probabilities, cast back to ``logits.dtype`` when that dtype is
        not a boolean or integer type. For boolean/integer logits the
        floating-point working result is returned as is, because casting
        values in [0, 1] back to an integer dtype would truncate them to 0.
        Any non-finite probability (for example when every entry is
        ``-inf`` or the input contains NaN) is replaced by 0.
    """
    # Do the arithmetic in float64 regardless of the input precision.
    work = logits.astype(xp.float64, copy=False)
    if mask is not None:
        # Masked-out entries become -inf so exp() maps them to exactly 0.
        work = xp.where(mask.astype(bool), work, -xp.inf)

    # Subtract the maximum before exponentiating to avoid overflow.
    peak = reduce_tree_fixed(work, op="max")
    exps = xp.exp(work - peak)
    total = reduce_tree_fixed(exps, op="sum")
    probs = exps / total

    # When everything is -inf, (-inf) - (-inf) yields NaN; map every
    # non-finite result to 0 instead of propagating it.
    probs = xp.where(xp.isfinite(probs), probs, 0.0)

    # Casting probabilities to a bool/int/uint dtype would destroy them, so
    # only restore the caller's dtype when it is not one of those kinds.
    if logits.dtype.kind in "biu":
        return probs
    return probs.astype(logits.dtype, copy=False)
|