from typing import Optional

import torch
import triton
import triton.language as tl

from fla.utils import input_guard


def token_shift_ref(
    x: torch.Tensor,
    cu_seqlens: Optional[torch.Tensor] = None
) -> torch.Tensor:
    if cu_seqlens is not None:
        assert x.dim() == 3, "Input must be [B, T, D]"
        B, T, D = x.shape
        assert B == 1, "Batch size must be 1 when using cu_seqlens"

        result = torch.zeros_like(x)
        N = cu_seqlens.shape[0] - 1

        for i in range(N):
            start = cu_seqlens[i].item()
            end = cu_seqlens[i+1].item()
            seq_len = end - start

            if seq_len <= 1:
                # A single-token sequence has no previous token: the shifted value is
                # zero, so the delta reduces to -x.
                result[0, start:end] = -x[0, start:end]
            else:
                # Shift the sequence right by one token (zero at the first position)
                # and subtract the original values.
                shifted = torch.zeros_like(x[0, start:end])
                shifted[1:] = x[0, start:end-1]
                delta = shifted - x[0, start:end]
                result[0, start:end] = delta

        return result
    else:
        # ZeroPad2d((0, 0, 1, -1)) pads one zero row at the front of the time dimension
        # and drops the last row, i.e. shifted[:, t] = x[:, t-1] with shifted[:, 0] = 0.
        time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))
        shifted = time_shift(x)
        delta = shifted - x
        return delta
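

# Worked example (illustrative): for a single sequence x = [[[1.], [3.], [6.]]]
# (B=1, T=3, D=1), token_shift_ref(x) returns [[[-1.], [-2.], [-3.]]], since
# y[0] = -x[0] and y[t] = x[t-1] - x[t] for t > 0.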


@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [1, 2, 3, 4]
    ],
    key=['BD'],
)
@triton.jit
def token_shift_fwd_kernel(
    x,
    y,
    cu_seqlens,
    T,
    D: tl.constexpr,
    BD: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # One program per (sequence, time step); each program handles a full feature row.
    i_b, i_t = tl.program_id(0), tl.program_id(1)

    if IS_VARLEN:
        i_n = i_b
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)

        # Skip programs whose global time index lies outside this sequence.
        if i_t < bos or i_t >= eos:
            return

        is_first_pos = (i_t - bos == 0)
    else:
        is_first_pos = (i_t == 0)

    o_d = tl.arange(0, BD)
    m_d = o_d < D

    if IS_VARLEN:
        base_offset = i_t * D + o_d
    else:
        base_offset = i_b * T*D + i_t * D + o_d

    b_x = tl.load(x + base_offset, mask=m_d)

    if is_first_pos:
        # No previous token: y[0] = 0 - x[0].
        tl.store(y + base_offset, -b_x, mask=m_d)
    else:
        # y[t] = x[t-1] - x[t]
        if IS_VARLEN:
            prev_offset = (i_t - 1) * D + o_d
        else:
            prev_offset = i_b * T*D + (i_t-1) * D + o_d

        prev_values = tl.load(x + prev_offset, mask=m_d)
        delta = prev_values - b_x
        tl.store(y + base_offset, delta, mask=m_d)


@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [1, 2, 3, 4]
    ],
    key=['D'],
)
@triton.jit
def token_shift_bwd_kernel(
    dx,
    dy,
    cu_seqlens,
    T,
    D: tl.constexpr,
    BD: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_b, i_t = tl.program_id(0), tl.program_id(1)
    if IS_VARLEN:
        i_n = i_b
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)

        # Skip programs whose global time index lies outside this sequence.
        if i_t < bos or i_t >= eos:
            return

        local_pos = i_t - bos
        is_last_pos = (local_pos == eos - bos - 1)
    else:
        is_last_pos = (i_t == T - 1)

    o_d = tl.arange(0, BD)
    m_d = o_d < D

    if IS_VARLEN:
        base_offset = i_t * D + o_d
    else:
        base_offset = i_b * T*D + i_t * D + o_d

    b_dy = tl.load(dy + base_offset, mask=m_d)

    if is_last_pos:
        # x[T-1] only appears in y[T-1], with coefficient -1.
        b_dx = -b_dy
    else:
        # x[t] enters y[t] with coefficient -1 and y[t+1] with coefficient +1,
        # so dx[t] = -dy[t] + dy[t+1].
        if IS_VARLEN:
            next_offset = (i_t+1) * D + o_d
        else:
            next_offset = i_b * T*D + (i_t+1) * D + o_d

        b_dx = -b_dy + tl.load(dy + next_offset, mask=m_d)

    tl.store(dx + base_offset, b_dx, mask=m_d)


def token_shift_fwd(
    x: torch.Tensor,
    cu_seqlens: Optional[torch.Tensor] = None
) -> torch.Tensor:
    B, T, D = x.shape
    N = len(cu_seqlens) - 1 if cu_seqlens is not None else B
    BD = triton.next_power_of_2(D)

    y = torch.empty_like(x)

    # One program per (sequence, time step); in the varlen case T is the total number
    # of tokens and out-of-range programs exit early inside the kernel.
    grid = (N, T)
    token_shift_fwd_kernel[grid](
        x=x,
        y=y,
        cu_seqlens=cu_seqlens,
        T=T,
        D=D,
        BD=BD,
    )

    return y


def token_shift_bwd(
    dy: torch.Tensor,
    cu_seqlens: Optional[torch.Tensor] = None
) -> torch.Tensor:
    B, T, D = dy.shape
    N = len(cu_seqlens) - 1 if cu_seqlens is not None else B
    BD = triton.next_power_of_2(D)

    dx = torch.empty_like(dy)

    grid = (N, T)
    token_shift_bwd_kernel[grid](
        dy=dy,
        dx=dx,
        cu_seqlens=cu_seqlens,
        T=T,
        D=D,
        BD=BD,
    )
    return dx


class TokenShift(torch.autograd.Function):

    @staticmethod
    @input_guard
    def forward(ctx, x: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None):
        ctx.cu_seqlens = cu_seqlens
        return token_shift_fwd(x, cu_seqlens)

    @staticmethod
    @input_guard
    def backward(ctx, dy: torch.Tensor):
        cu_seqlens = ctx.cu_seqlens
        dx = token_shift_bwd(dy, cu_seqlens)
        # cu_seqlens is non-differentiable, so its gradient slot is None.
        return dx, None


def token_shift(
    x: torch.Tensor,
    cu_seqlens: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    Token shift implemented with Triton kernels.

    Computes y[:, 0] = -x[:, 0] and y[:, t] = x[:, t-1] - x[:, t] for t > 0,
    applied independently to each sequence when `cu_seqlens` is provided.

    Args:
        x: Input tensor of shape [B, T, D]
        cu_seqlens: Cumulative sequence lengths (optional); requires B == 1
    Returns:
        Tensor of the same shape as the input with token shift applied
    """
    if cu_seqlens is not None:
        assert x.dim() == 3, "Input must be [B, T, D]"
        assert x.shape[0] == 1, "Batch size must be 1 when using cu_seqlens"

    return TokenShift.apply(x, cu_seqlens)
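

# A minimal sanity-check sketch (illustrative only, assumes a CUDA device is available):
# it compares the Triton path against the PyTorch reference on both the fixed-length
# and the variable-length (cu_seqlens) code paths.
if __name__ == '__main__':
    if torch.cuda.is_available():
        x = torch.randn(2, 16, 64, device='cuda')
        assert torch.allclose(token_shift(x), token_shift_ref(x), atol=1e-5)

        # Varlen: one flattened batch holding two sequences of lengths 10 and 6.
        x = torch.randn(1, 16, 64, device='cuda')
        cu_seqlens = torch.tensor([0, 10, 16], dtype=torch.int32, device='cuda')
        assert torch.allclose(token_shift(x, cu_seqlens), token_shift_ref(x, cu_seqlens), atol=1e-5)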