# msj19's picture
# Add files using upload-large-folder tool
# e73a905 verified
# -*- coding: utf-8 -*-
from typing import Optional
import torch
import triton
import triton.language as tl
from fla.utils import input_guard
def token_shift_ref(
    x: torch.Tensor,
    cu_seqlens: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    Pure-PyTorch reference for token shift: y[t] = x[t-1] - x[t], where the
    token before the start of each sequence is treated as zero (so the first
    position of every sequence yields -x[0]).

    Args:
        x: Input tensor of shape [B, T, D].
        cu_seqlens: Optional cumulative sequence lengths of shape [N+1] for
            packed variable-length inputs. When given, B must be 1 and the
            shift restarts at every sequence boundary.

    Returns:
        A tensor of the same shape as ``x``.
    """
    if cu_seqlens is None:
        # ZeroPad2d((left, right, top, bottom)) on [B, T, D]: insert one zero
        # step at the front of the time axis and drop the last step.
        prev = torch.nn.ZeroPad2d((0, 0, 1, -1))(x)
        return prev - x

    assert x.dim() == 3, "Input must be [B, T, D]"
    assert x.shape[0] == 1, "Batch size must be 1 when using cu_seqlens"
    out = torch.zeros_like(x)
    bounds = cu_seqlens.tolist()
    for bos, eos in zip(bounds[:-1], bounds[1:]):
        seg = x[0, bos:eos]
        if eos - bos <= 1:
            # A 0/1-length sequence has no predecessor token: delta is -x.
            out[0, bos:eos] = -seg
        else:
            # Shift the segment right by one with a zero front pad.
            prev = torch.zeros_like(seg)
            prev[1:] = seg[:-1]
            out[0, bos:eos] = prev - seg
    return out
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [1, 2, 3, 4]
    ],
    key=['BD'],
)
@triton.jit
def token_shift_fwd_kernel(
    x,  # input pointer, logically [B, T, D] (varlen: [1, total_T, D] packed)
    y,  # output pointer, same layout as x
    cu_seqlens,  # cumulative sequence lengths [N+1], or None in fixed-len mode
    T,
    D: tl.constexpr,
    BD: tl.constexpr,  # next power of 2 >= D; block width for tl.arange
    IS_VARLEN: tl.constexpr,
):
    # One program per (batch-or-sequence index, time step). Computes
    # y[t] = x[t-1] - x[t], where the token before the first position of a
    # sequence is treated as zero, i.e. y[first] = -x[first].
    i_b, i_t = tl.program_id(0), tl.program_id(1)
    if IS_VARLEN:
        i_n = i_b
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        # i_t spans the full packed length; programs outside this sequence's
        # [bos, eos) window have nothing to do.
        if i_t < bos or i_t >= eos:
            return
        is_first_pos = (i_t - bos == 0)
    else:
        is_first_pos = (i_t == 0)
    o_d = tl.arange(0, BD)
    m_d = o_d < D  # mask the padded tail when D is not a power of 2
    if IS_VARLEN:
        # Packed varlen input lives in batch 0, so i_b contributes no offset.
        base_offset = i_t * D + o_d
    else:
        base_offset = i_b * T*D + i_t * D + o_d
    b_x = tl.load(x + base_offset, mask=m_d)
    if is_first_pos:
        # First position in sequence: delta = -hidden_states
        tl.store(y + base_offset, -b_x, mask=m_d)
    else:
        # Other positions: delta = prev - curr
        if IS_VARLEN:
            prev_offset = (i_t - 1) * D + o_d
        else:
            prev_offset = i_b * T*D + (i_t-1) * D + o_d
        prev_values = tl.load(x + prev_offset, mask=m_d)
        delta = prev_values - b_x
        tl.store(y + base_offset, delta, mask=m_d)
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [1, 2, 3, 4]
    ],
    # Key on the padded block width BD to match the forward kernel's autotune
    # cache granularity (was 'D', which fragments the cache across D values
    # that compile to the same BD block).
    key=['BD'],
)
@triton.jit
def token_shift_bwd_kernel(
    dx,  # output pointer: gradient w.r.t. x, same layout as dy
    dy,  # input pointer: upstream gradient [B, T, D] (varlen: [1, total_T, D])
    cu_seqlens,  # cumulative sequence lengths [N+1], or None in fixed-len mode
    T,
    D: tl.constexpr,
    BD: tl.constexpr,  # next power of 2 >= D; block width for tl.arange
    IS_VARLEN: tl.constexpr,
):
    # Backward of y[t] = x[t-1] - x[t]:
    #   dx[t] = -dy[t] + dy[t+1], with dy past the end of each sequence
    #   treated as zero, i.e. dx[last] = -dy[last].
    i_b, i_t = tl.program_id(0), tl.program_id(1)
    if IS_VARLEN:
        i_n = i_b
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        # Programs outside this sequence's [bos, eos) window do nothing.
        if i_t < bos or i_t >= eos:
            return
        local_pos = i_t - bos
        is_last_pos = (local_pos == eos - bos - 1)
    else:
        is_last_pos = (i_t == T - 1)
    o_d = tl.arange(0, BD)
    m_d = o_d < D  # mask the padded tail when D is not a power of 2
    if IS_VARLEN:
        # Packed varlen input lives in batch 0, so i_b contributes no offset.
        base_offset = i_t * D + o_d
    else:
        base_offset = i_b * T*D + i_t * D + o_d
    b_dy = tl.load(dy + base_offset, mask=m_d)
    if is_last_pos:
        # Last position: b_dx = -grad_delta[t]
        b_dx = -b_dy
    else:
        # Other positions: b_dx = -grad_delta[t] + grad_delta[t+1]
        if IS_VARLEN:
            next_offset = (i_t+1) * D + o_d
        else:
            next_offset = i_b * T*D + (i_t+1) * D + o_d
        b_dx = -b_dy + tl.load(dy + next_offset, mask=m_d)
    tl.store(dx + base_offset, b_dx, mask=m_d)
def token_shift_fwd(
    x: torch.Tensor,
    cu_seqlens: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    Launch the forward token-shift kernel: y[t] = x[t-1] - x[t], with the
    token before the start of each sequence treated as zero.

    Args:
        x: Input tensor of shape [B, T, D].
        cu_seqlens: Optional cumulative sequence lengths [N+1] for packed
            variable-length inputs; requires B == 1.

    Returns:
        A tensor of the same shape as ``x``.
    """
    B, T, D = x.shape
    if cu_seqlens is not None:
        # The varlen kernel addresses x as a flat [total_T, D] buffer, which
        # is only valid when all sequences are packed into a single batch.
        assert B == 1, "Batch size must be 1 when using cu_seqlens"
        N = len(cu_seqlens) - 1
    else:
        N = B
    BD = triton.next_power_of_2(D)
    y = torch.empty_like(x)
    # One program per (batch-or-sequence, time step); varlen programs outside
    # their sequence's [bos, eos) range exit early inside the kernel.
    grid = (N, T)
    token_shift_fwd_kernel[grid](
        x=x,
        y=y,
        cu_seqlens=cu_seqlens,
        T=T,
        D=D,
        BD=BD,
    )
    return y
def token_shift_bwd(
    dy: torch.Tensor,
    cu_seqlens: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    Launch the backward token-shift kernel: dx[t] = -dy[t] + dy[t+1], with
    dy past the end of each sequence treated as zero.

    Args:
        dy: Upstream gradient of shape [B, T, D].
        cu_seqlens: Optional cumulative sequence lengths [N+1] for packed
            variable-length inputs; requires B == 1.

    Returns:
        Gradient w.r.t. the forward input, same shape as ``dy``.
    """
    B, T, D = dy.shape
    if cu_seqlens is not None:
        # The varlen kernel addresses dy as a flat [total_T, D] buffer, which
        # is only valid when all sequences are packed into a single batch.
        assert B == 1, "Batch size must be 1 when using cu_seqlens"
        N = len(cu_seqlens) - 1
    else:
        N = B
    BD = triton.next_power_of_2(D)
    dx = torch.empty_like(dy)
    # One program per (batch-or-sequence, time step); varlen programs outside
    # their sequence's [bos, eos) range exit early inside the kernel.
    grid = (N, T)
    token_shift_bwd_kernel[grid](
        dy=dy,
        dx=dx,
        cu_seqlens=cu_seqlens,
        T=T,
        D=D,
        BD=BD,
    )
    return dx
class TokenShift(torch.autograd.Function):
    """Autograd bridge between the Triton token-shift forward/backward kernels."""

    @staticmethod
    @input_guard
    def forward(ctx, x: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None):
        # cu_seqlens is non-differentiable metadata, so it is stashed on ctx
        # directly rather than via save_for_backward.
        ctx.cu_seqlens = cu_seqlens
        y = token_shift_fwd(x, cu_seqlens)
        return y

    @staticmethod
    @input_guard
    def backward(ctx, dy: torch.Tensor):
        # Gradient flows only to x; cu_seqlens receives None.
        return token_shift_bwd(dy, ctx.cu_seqlens), None
def token_shift(
    x: torch.Tensor,
    cu_seqlens: Optional[torch.Tensor] = None
):
    """
    Apply token shift with Triton kernels.

    Produces ``x[t-1] - x[t]`` at every position, treating the token before
    the start of each sequence as zero.

    Args:
        x: Input tensor of shape [B, T, D].
        cu_seqlens: Cumulative sequence lengths (optional). When given, the
            input must be a packed single-batch tensor.

    Returns:
        Tensor of the same shape as the input with token shift applied.
    """
    if cu_seqlens is None:
        return TokenShift.apply(x, cu_seqlens)
    # Varlen mode: validate the packed layout before dispatching.
    assert x.dim() == 3, "Input must be [B, T, D]"
    assert x.shape[0] == 1, "Batch size must be 1 when using cu_seqlens"
    return TokenShift.apply(x, cu_seqlens)