| | |
| | |
| |
|
| | from typing import Optional |
| |
|
| | import torch |
| | import triton |
| | import triton.language as tl |
| |
|
| | from ...ops.utils.op import exp, log |
| |
|
| |
|
@triton.heuristics({
    # Specialize the kernel on whether a scale factor was supplied, so the
    # multiply is compiled out entirely when scale is None.
    'HAS_SCALE': lambda args: args['scale'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    # Re-tune whenever the row length D changes.
    key=['D']
)
@triton.jit
def logsumexp_fwd_kernel(
    x,  # input tensor, read as a row-major (N, D) buffer — flat offset i_n * D + o_d
    z,  # output buffer of per-chunk partial results, shape (N, cdiv(D, B))
    scale,  # optional multiplicative factor applied to x before the reduction
    D: tl.constexpr,  # row length (size of the last dimension)
    B: tl.constexpr,  # chunk (block) size along D; one program reduces one chunk
    HAS_SCALE: tl.constexpr  # compile-time flag set by the heuristic above
):
    # Each program instance reduces one B-sized chunk of one row:
    # grid axis 0 -> row index, grid axis 1 -> chunk index within the row.
    # int64 indices avoid overflow in i_n * D for large tensors.
    i_n, i_d = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
    o_d = i_d * B + tl.arange(0, B)
    # Mask off lanes past the end of the row (last chunk may be partial).
    m_d = o_d < D

    # Masked lanes load -inf so exp(-inf) = 0 and they do not affect the sum.
    b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf'))
    if HAS_SCALE:
        b_x = b_x * scale
    # Numerically stable logsumexp over the chunk: subtract the chunk max
    # before exponentiating, then add it back after the log.
    b_m = tl.max(b_x, 0)
    b_z = log(tl.sum(exp(b_x - b_m), 0)) + b_m
    # Store the partial logsumexp; the host-side wrapper reduces across chunks.
    tl.store(z + i_n * tl.cdiv(D, B) + i_d, b_z)
| |
|
| |
|
def logsumexp_fwd(
    x,
    scale: Optional[float] = None,
    dtype: Optional[torch.dtype] = None
):
    r"""
    Compute the logsumexp of the input tensor over the last dimension.

    The reduction runs in two stages: a Triton kernel produces a partial
    logsumexp per B-sized chunk of each row, and the chunk partials are
    then combined on the host with a second (tiny) logsumexp.

    Args:
        x (Tensor):
            The input tensor of any shape (including 1-D).
        scale (Optional[float]):
            The scale applied to the input tensor. Default: `None`.
        dtype (Optional[torch.dtype]):
            The data type of the output tensor. Default: `None`.
    Returns:
        Tensor: The logsumexp of the input tensor, with the last dimension
        reduced away (a 0-dim tensor for 1-D input).
    """

    shape = x.shape
    # The kernel addresses x as a flat row-major (N, D) buffer, so make the
    # layout contiguous first; .view alone would raise on non-contiguous
    # inputs (e.g. transposed tensors). No-op copy for contiguous inputs.
    x = x.contiguous().view(-1, shape[-1])
    N, D = x.shape
    # One chunk per program, capped at 64K elements per chunk.
    B = min(triton.next_power_of_2(D), 64 * 1024)
    ND = triton.cdiv(D, B)

    # Partials accumulate in float32 for numerical stability regardless of
    # the input dtype.
    z = x.new_empty(N, ND, dtype=torch.float)
    logsumexp_fwd_kernel[(N, ND)](
        x=x,
        z=z,
        scale=scale,
        D=D,
        B=B
    )
    # Combine the per-chunk partials, then restore the leading shape.
    # Pass the size tuple unexpanded: for a 1-D input shape[:-1] is empty,
    # and view(()) correctly yields a 0-dim tensor where view(*()) would
    # raise for lack of arguments.
    z = z.logsumexp(-1).view(shape[:-1])
    if dtype is not None and dtype != torch.float:
        z = z.to(dtype)
    return z
| |
|