Spaces:
Sleeping
Sleeping
| from .backend import xp | |
| from .deterministic import reduce_tree_fixed | |
def softmax_canonical(logits, mask=None):
    """Return a numerically stable softmax of ``logits``.

    The computation is carried out in float64, and both the max and the sum
    reductions go through ``reduce_tree_fixed`` (presumably so the reduction
    order is fixed and results are reproducible -- confirm in
    ``.deterministic``). The axis/shape semantics of the reductions are
    whatever ``reduce_tree_fixed`` implements.

    Args:
        logits: Array or array-like of scores.
        mask: Optional array-like. Entries where ``mask`` is falsy are
            treated as ``-inf`` and therefore receive probability 0.

    Returns:
        Softmax probabilities. The dtype matches ``logits`` except for
        bool/integer inputs, which return float64 (casting probabilities back
        to an integer dtype would truncate them to 0/1). Any non-finite
        result -- e.g. when every entry of a reduction is ``-inf`` (fully
        masked) or a logit is NaN/+inf -- is replaced by 0.
    """
    logits = xp.asarray(logits)
    # Probabilities in [0, 1] cannot round-trip through bool/int dtypes.
    out_dtype = xp.float64 if logits.dtype.kind in "biu" else logits.dtype

    x = logits.astype(xp.float64, copy=False)
    if mask is not None:
        x = xp.where(xp.asarray(mask).astype(bool), x, -xp.inf)

    m = reduce_tree_fixed(x, op="max")
    # A max of -inf means every reduced entry is -inf (e.g. fully masked).
    # Shift by 0 there so exp(-inf) == 0 instead of exp(-inf - -inf) ==
    # exp(nan). NaN/+inf maxima are left as-is and zeroed below, as before.
    m = xp.where(m == -xp.inf, 0.0, m)
    z = xp.exp(x - m)

    s = reduce_tree_fixed(z, op="sum")
    # With a finite max the largest term is exp(0) == 1, so s can only be 0
    # in the all -inf case above; divide by 1 there to avoid 0/0.
    s = xp.where(s == 0, 1.0, s)
    p = z / s

    # Zero out anything still non-finite (NaN or +inf logits).
    p = xp.where(xp.isfinite(p), p, 0.0)
    return p.astype(out_dtype, copy=False)