| | |
| | |
| |
|
| | from typing import Optional |
| |
|
| | import torch |
| | import triton |
| | import triton.language as tl |
| |
|
| | from fla.ops.utils.op import exp |
| |
|
| |
|
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['D']
)
@triton.jit
def softmax_fwd_kernel(
    x,
    p,
    D: tl.constexpr,
    B: tl.constexpr
):
    # One program instance handles one row of the flattened (N, D) input.
    i_n = tl.program_id(0)
    # Column offsets, padded up to the power-of-two block size B.
    o_d = tl.arange(0, B)
    # Mask off the padding lanes beyond the true row width D.
    m_d = o_d < D

    # Padding lanes read -inf so they contribute nothing to max or sum.
    b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf'))
    # Numerically stable softmax: subtract the row max before exponentiating.
    b_m = tl.max(b_x, 0)
    b_e = exp(b_x - b_m)
    b_p = b_e / tl.sum(b_e, 0)

    tl.store(p + i_n * D + o_d, b_p.to(p.dtype.element_ty), mask=m_d)
| |
|
| |
|
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['D']
)
@triton.jit
def softmax_bwd_kernel(
    p,
    dp,
    ds,
    D: tl.constexpr,
    B: tl.constexpr
):
    # Each program instance computes the input gradient for one row.
    i_n = tl.program_id(0)
    # Column offsets over the power-of-two block; mask the padding lanes.
    o_d = tl.arange(0, B)
    m_d = o_d < D

    # Masked lanes load 0 so they do not perturb the reduction below.
    b_p = tl.load(p + i_n * D + o_d, mask=m_d, other=0.)
    b_dp = tl.load(dp + i_n * D + o_d, mask=m_d, other=0.)
    # Softmax backward: ds = p * dp - p * <p, dp>, where <p, dp> is the
    # row-wise dot product of probabilities and upstream gradient.
    b_pp = tl.sum(b_p * b_dp, 0)
    b_ds = b_p * b_dp - b_p * b_pp
    tl.store(ds + i_n * D + o_d, b_ds.to(ds.dtype.element_ty), mask=m_d)
| |
|
| |
|
def softmax_fwd(
    x: torch.Tensor,
    dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Row-wise softmax over the last dimension, computed by a Triton kernel.

    Args:
        x: Input tensor of arbitrary leading shape; softmax is taken over
            the last dimension.
        dtype: dtype of the returned probabilities (defaults to float32).

    Returns:
        Tensor of softmax probabilities with the same shape as ``x``.
    """
    shape = x.shape
    # The kernel addresses rows with flat `i_n * D + o_d` offsets, i.e. it
    # assumes a dense row-major layout. Force contiguity before flattening:
    # plain `.view(...)` raises on non-contiguous inputs, and even a legal
    # non-contiguous 2-D view would be indexed incorrectly by the kernel.
    x = x.contiguous().view(-1, shape[-1])

    N, D = x.shape
    # tl.arange requires a power-of-two block size; pad D up accordingly
    # (the kernel masks the padding lanes).
    B = triton.next_power_of_2(D)

    p = torch.empty_like(x, dtype=dtype)
    # One program per row.
    softmax_fwd_kernel[(N,)](
        x=x,
        p=p,
        D=D,
        B=B
    )
    return p.view(*shape)
| |
|
| |
|
def softmax_bwd(
    p: torch.Tensor,
    dp: torch.Tensor,
    dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Backward pass of row-wise softmax, computed by a Triton kernel.

    Args:
        p: Softmax probabilities from the forward pass (any leading shape;
            last dimension is the softmax dimension).
        dp: Upstream gradient w.r.t. ``p``; must have the same shape as ``p``.
        dtype: dtype of the returned gradient (defaults to float32).

    Returns:
        Gradient w.r.t. the softmax input, with the same shape as ``p``.
    """
    shape = p.shape
    # The kernel assumes a dense row-major (N, D) layout for ALL of its
    # pointer arguments (flat `i_n * D + o_d` indexing).
    p = p.contiguous().view(-1, shape[-1])
    # Previously `dp` was passed through untouched, so a non-contiguous
    # upstream gradient was indexed with the wrong strides. Normalize it
    # to the same dense layout as `p`.
    dp = dp.contiguous().view(-1, shape[-1])
    # `p` is contiguous here, so `ds` is allocated dense as the kernel expects.
    ds = torch.empty_like(p, dtype=dtype)

    N, D = p.shape
    # Pad the block size to the next power of two for tl.arange.
    B = triton.next_power_of_2(D)
    # One program per row.
    softmax_bwd_kernel[(N,)](
        p=p,
        dp=dp,
        ds=ds,
        D=D,
        B=B
    )
    return ds.view(*shape)
| |
|