File size: 537 Bytes
493de78
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from .backend import xp
from .deterministic import reduce_tree_fixed

def softmax_canonical(logits, mask=None):
    """Return a numerically stable softmax of ``logits``.

    The max-shift and the normalising sum are both computed with
    ``reduce_tree_fixed`` so the reduction order is fixed. All arithmetic is
    done in float64.

    NOTE(review): which axes are normalised over is decided by
    ``reduce_tree_fixed`` (not visible here) -- presumably the whole array or
    a broadcast-compatible keepdims result; confirm against that helper.

    Args:
        logits: Array (or array-like) of scores.
        mask: Optional array-like, combined with ``logits`` via ``xp.where``;
            positions where it is falsy are replaced by ``-inf`` and so get
            zero probability.

    Returns:
        Array of probabilities. Floating (and other non-integer) inputs keep
        their original dtype. Integer/bool inputs yield float64, because
        casting probabilities back to an integer dtype would truncate them
        to 0. Any entry that comes out non-finite is set to 0. This covers a
        fully-masked / all ``-inf`` input and NaN or ``+inf`` logits.
    """
    logits = xp.asarray(logits)
    x = logits.astype(xp.float64, copy=False)
    if mask is not None:
        x = xp.where(xp.asarray(mask).astype(bool), x, -xp.inf)

    m = reduce_tree_fixed(x, op="max")
    # When everything is -inf (e.g. fully masked), x - m would be
    # -inf - (-inf) = NaN and trigger an "invalid value" warning. Shifting by
    # 0 instead makes exp() yield exact zeros. NaN / +inf maxima are left
    # as-is so those inputs behave exactly as before.
    m = xp.where(m == -xp.inf, 0.0, m)
    z = xp.exp(x - m)

    s = reduce_tree_fixed(z, op="sum")
    # s is exactly 0 only when every z is 0, since otherwise the max element
    # contributes exp(0) == 1. Divide by 1 there to avoid 0/0. The test is
    # `s == 0`, not `s > 0`, so that a NaN sum stays NaN.
    s = xp.where(s == 0, 1.0, s)
    p = z / s

    # Zero out anything non-finite (NaN/+inf logits propagate NaN here).
    p = xp.where(xp.isfinite(p), p, 0.0)

    # Integer/bool output dtypes would truncate probabilities to 0.
    out_dtype = xp.float64 if logits.dtype.kind in "biu" else logits.dtype
    return p.astype(out_dtype, copy=False)