from .backend import xp
from .deterministic import reduce_tree_fixed


def softmax_canonical(logits, mask=None):
    """Compute a numerically stable softmax with optional masking.

    The computation runs in float64. Both the max used for stabilisation
    and the normalising sum are obtained through ``reduce_tree_fixed``
    instead of the backend's own reductions (per the original author's
    note, to keep those reductions deterministic).

    NOTE(review): no axis is passed to ``reduce_tree_fixed``, so whether it
    reduces over the whole array or along one axis is decided by that
    helper -- confirm against its definition.

    Args:
        logits: Array of scores; must provide ``astype`` and ``dtype``.
        mask: Optional array that is cast to bool. Entries where it is
            false are replaced by ``-inf`` before exponentiation, so they
            end up with probability 0. ``None`` keeps every entry.

    Returns:
        Array of probabilities. Floating-point inputs get the result back
        in their own dtype. Any other input dtype (integer, bool) gets the
        float64 result, because casting probabilities to an integer dtype
        would truncate them to 0 or 1. Non-finite results are replaced by
        0.0.
    """
    x = logits.astype(xp.float64, copy=False)
    if mask is not None:
        # Masked-out positions contribute exp(-inf) == 0 to the sum.
        x = xp.where(mask.astype(bool), x, -xp.inf)

    # Subtract the max before exponentiating so exp() cannot overflow.
    m = reduce_tree_fixed(x, op="max")
    z = xp.exp(x - m)
    s = reduce_tree_fixed(z, op="sum")
    p = z / s

    # When every entry is -inf, (-inf) - (-inf) and the division yield NaN;
    # map any non-finite probability to 0.
    p = xp.where(xp.isfinite(p), p, 0.0)

    # Only round-trip to the caller's dtype when it can represent fractions.
    if xp.issubdtype(logits.dtype, xp.floating):
        return p.astype(logits.dtype, copy=False)
    return p