from typing import Optional

import torch
import triton
import triton.language as tl

from fla.ops.utils.op import exp, log


@triton.heuristics({
    'HAS_SCALE': lambda args: args['scale'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['D']
)
@triton.jit
def logsumexp_fwd_kernel(
    x,
    z,
    scale,
    D: tl.constexpr,
    B: tl.constexpr,
    HAS_SCALE: tl.constexpr
):
    # Each program instance reduces one length-B block of row i_n.
    i_n, i_d = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
    o_d = i_d * B + tl.arange(0, B)
    m_d = o_d < D

    # Out-of-bounds lanes are padded with -inf so they contribute exp(-inf) = 0 to the sum.
    b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf'))
    if HAS_SCALE:
        b_x = b_x * scale
    # Numerically stable logsumexp: subtract the block max before exponentiating.
    b_m = tl.max(b_x, 0)
    b_z = log(tl.sum(exp(b_x - b_m), 0)) + b_m
    # Store the partial logsumexp of this block; the host combines blocks afterwards.
    tl.store(z + i_n * tl.cdiv(D, B) + i_d, b_z)
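
# A minimal pure-PyTorch sketch of the same two-pass scheme, for illustration
# only (the helper name and default block size here are assumptions, not part
# of this module's API): each block of the last dimension is reduced to a
# partial logsumexp, and the partials are combined with a second logsumexp.
# This is the split that `logsumexp_fwd` below performs, with the first pass
# done by the kernel above.
def _blockwise_logsumexp_reference(x: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
    partials = [
        x[..., i:i + block_size].logsumexp(-1)
        for i in range(0, x.shape[-1], block_size)
    ]
    return torch.stack(partials, dim=-1).logsumexp(-1)
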

def logsumexp_fwd(
    x: torch.Tensor,
    scale: Optional[float] = None,
    dtype: Optional[torch.dtype] = None
):
    r"""
    Compute the logsumexp of the input tensor over the last dimension.

    Args:
        x (Tensor):
            The input tensor, of any shape.
        scale (Optional[float]):
            A multiplier applied to the input before the reduction.
            Default: `None` (no scaling).
        dtype (Optional[torch.dtype]):
            The data type of the output tensor.
            Default: `None`, which keeps the result in `torch.float`.
    Returns:
        Tensor: The logsumexp of the input tensor, with the last dimension reduced.
    """
    shape = x.shape
    x = x.view(-1, shape[-1])
    N, D = x.shape
    # Block size: the next power of two covering D, capped at 64K elements per block.
    B = min(triton.next_power_of_2(D), 64 * 1024)
    ND = triton.cdiv(D, B)

    # First pass: the kernel writes one partial logsumexp per (row, block), in fp32.
    z = x.new_empty(N, ND, dtype=torch.float)
    logsumexp_fwd_kernel[(N, ND)](
        x=x,
        z=z,
        scale=scale,
        D=D,
        B=B
    )
    # Second pass: combine the per-block partials on the host, using
    # logsumexp(x) = logsumexp([logsumexp(x_1), ..., logsumexp(x_k)]).
    z = z.logsumexp(-1).view(*shape[:-1])
    if dtype is not None and dtype != torch.float:
        z = z.to(dtype)
    return z
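
# A quick self-check sketch, assuming a CUDA device is available; the shapes,
# scale, and tolerance below are illustrative choices, not fixtures from this
# repository.
if __name__ == '__main__':
    x = torch.randn(8, 128, 4096, device='cuda')
    ref = (0.5 * x).logsumexp(-1)
    out = logsumexp_fwd(x, scale=0.5)
    assert out.shape == ref.shape
    assert torch.allclose(out, ref, atol=1e-4), (out - ref).abs().max().item()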