# Upload provenance (Hugging Face web UI residue, commented out so the module parses):
# msj19's picture — Add files using upload-large-folder tool — ccefec1 verified
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# Code adapted from https://github.com/mayank31398/cute-kernels
from typing import Optional
import torch
import triton
import triton.language as tl
from ...ops.utils.index import prepare_lens
from ...utils import input_guard
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [4, 8, 16, 32]
    ],
    # S and B only affect the launch grid, so they are not part of the tuning key.
    key=['D', 'PADDING_SIDE', 'PACK']
)
@triton.jit
def packunpack_sequence_kernel(
    x,
    y,
    cu_seqlens,  # [B + 1] cumulative sequence lengths; cu_seqlens[-1] = total token count
    S,           # padded (max) sequence length of the batch tensor
    D,           # flattened feature size per token
    BD: tl.constexpr,            # feature-chunk size handled by one program
    PADDING_SIDE: tl.constexpr,  # 'left' or 'right'
    PACK: tl.constexpr,          # True: padded -> packed; False: packed -> padded
):
    """Copy one (feature-chunk, sequence-position, batch) cell between a padded
    [B, S, D] layout and a packed [total_tokens, D] layout.

    One program per (i_d, i_s, i_b): i_d indexes a BD-wide feature chunk,
    i_s a position in the padded sequence, i_b a batch row. Programs that land
    on padding exit early, so in the unpack direction those positions keep
    whatever the output buffer was initialized with (callers pass zeros).
    """
    i_d, i_s, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # Packed-tensor extent of sequence i_b: rows [bos, eos).
    bos, eos = tl.load(cu_seqlens + i_b), tl.load(cu_seqlens + i_b + 1)
    T = eos - bos
    if PADDING_SIDE == 'left':
        # With left padding, the valid tokens occupy the last T of the S slots.
        NP = S - T  # number of leading pad positions
        if i_s < NP:
            return
        i_t = bos + (i_s - NP)  # row in the packed tensor
    else:
        # Right padding: valid tokens occupy the first T slots.
        if i_s >= T:
            return
        i_t = bos + i_s
    # Feature offsets for this chunk; mask guards the final partial chunk.
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D
    if PACK:
        # padded x[i_b, i_s, :] -> packed y[i_t, :]
        b_x = tl.load(x + (i_b * S + i_s) * D + o_d, mask=mask)
        tl.store(y + i_t * D + o_d, b_x, mask=mask)
    else:
        # packed x[i_t, :] -> padded y[i_b, i_s, :]
        b_x = tl.load(x + i_t * D + o_d, mask=mask)
        tl.store(y + (i_b * S + i_s) * D + o_d, b_x, mask=mask)
def pack_sequence_fwdbwd(
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    padding_side: str,
) -> torch.Tensor:
    """Gather the valid tokens of a padded [B, S, ...] tensor into a packed
    [total_tokens, ...] tensor, where total_tokens == cu_seqlens[-1].

    Used for both the forward pass of packing and the backward pass of
    unpacking (they are the same data movement).
    """
    batch, seqlen = x.shape[0], x.shape[1]
    # Flattened per-token feature size; covers trailing dims of any rank.
    feat = x.numel() // (batch * seqlen)
    block_d = min(triton.next_power_of_2(feat), 4096)
    # .item() syncs with the device once to size the output allocation.
    total_tokens = cu_seqlens[-1].item()
    y = torch.empty((total_tokens, *x.shape[2:]), device=x.device, dtype=x.dtype)
    grid = (triton.cdiv(feat, block_d), seqlen, batch)
    packunpack_sequence_kernel[grid](
        x=x,
        y=y,
        cu_seqlens=cu_seqlens,
        S=seqlen,
        D=feat,
        BD=block_d,
        PADDING_SIDE=padding_side,
        PACK=True,
    )
    return y
def unpack_sequence_fwdbwd(
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    padding_side: str,
    desired_shape: torch.Size,
) -> torch.Tensor:
    """Scatter a packed [total_tokens, ...] tensor back into a padded
    [B, S, ...] tensor; pad positions are zero-filled.

    Used for both the forward pass of unpacking and the backward pass of
    packing (they are the same data movement).
    """
    if desired_shape is None:
        # Infer [B, max_len, ...] from cu_seqlens when the caller gave no shape.
        max_len = prepare_lens(cu_seqlens).max().item()
        desired_shape = (len(cu_seqlens) - 1, max_len, *x.shape[1:])
    # zeros (not empty): positions skipped by the kernel must read as 0.
    y = torch.zeros(desired_shape, device=x.device, dtype=x.dtype)
    batch, seqlen = y.shape[0], y.shape[1]
    feat = y.numel() // (batch * seqlen)
    block_d = min(triton.next_power_of_2(feat), 4096)
    grid = (triton.cdiv(feat, block_d), seqlen, batch)
    packunpack_sequence_kernel[grid](
        x=x,
        y=y,
        cu_seqlens=cu_seqlens,
        S=seqlen,
        D=feat,
        BD=block_d,
        PADDING_SIDE=padding_side,
        PACK=False,
    )
    return y
class PackSequenceFunction(torch.autograd.Function):
    """Autograd-aware packing of a padded batch into a flattened varlen tensor.

    forward: [B, S, ...] padded -> [total_tokens, ...] packed.
    backward: the packed gradient is scattered back to padded layout,
    with zeros at pad positions.
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        padding_side: str,
    ) -> torch.Tensor:
        assert padding_side in ['left', 'right']
        assert x.ndim >= 2
        ctx.cu_seqlens = cu_seqlens
        ctx.padding_side = padding_side
        # Remember the padded shape so backward can rebuild it exactly.
        ctx.desired_shape = x.shape
        y = pack_sequence_fwdbwd(
            x=x,
            cu_seqlens=cu_seqlens,
            padding_side=padding_side,
        )
        return y

    @staticmethod
    @input_guard
    def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor | None, ...]:
        # Unpacking is the adjoint of packing; pad slots get zero gradient.
        dx = unpack_sequence_fwdbwd(
            x=dy,
            cu_seqlens=ctx.cu_seqlens,
            padding_side=ctx.padding_side,
            desired_shape=ctx.desired_shape,
        )
        # forward takes exactly 3 inputs (x, cu_seqlens, padding_side), so
        # return exactly one gradient per input. The original
        # `return dx, *[None] * 10` emitted 11 values and relied on autograd
        # discarding the surplus trailing Nones; it was also inconsistent with
        # UnpackSequenceFunction.backward below.
        return dx, None, None
class UnpackSequenceFunction(torch.autograd.Function):
    """Autograd-aware unpacking of a flattened varlen tensor into a padded batch.

    forward: [total_tokens, ...] packed -> [B, S, ...] padded (zeros at pads).
    backward: the padded gradient is gathered back to packed layout.
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        padding_side: str,
        desired_shape: Optional[torch.Size] = None,
    ) -> torch.Tensor:
        assert padding_side in ['left', 'right']
        assert x.ndim >= 2
        if desired_shape is not None:
            # Batch dim must match cu_seqlens; feature dims must match the input.
            assert desired_shape[0] == cu_seqlens.shape[0] - 1
            assert desired_shape[2:] == x.shape[1:]
        ctx.cu_seqlens = cu_seqlens
        ctx.padding_side = padding_side
        return unpack_sequence_fwdbwd(
            x=x,
            cu_seqlens=cu_seqlens,
            padding_side=padding_side,
            desired_shape=desired_shape,
        )

    @staticmethod
    @input_guard
    def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor | None]:
        # Packing is the adjoint of unpacking; one gradient per forward input
        # (x, cu_seqlens, padding_side, desired_shape).
        dx = pack_sequence_fwdbwd(
            x=dy,
            cu_seqlens=ctx.cu_seqlens,
            padding_side=ctx.padding_side,
        )
        return dx, None, None, None
def pack_sequence(
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    padding_side: str = 'left'
) -> torch.Tensor:
    """Pack a padded [B, S, ...] tensor into a packed [total_tokens, ...] tensor.

    Differentiable w.r.t. `x`; `padding_side` selects which end of each row
    holds the padding ('left' by default).
    """
    return PackSequenceFunction.apply(x, cu_seqlens, padding_side)
def unpack_sequence(
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    padding_side: str = 'left',
    desired_shape: Optional[torch.Size] = None,
) -> torch.Tensor:
    """Unpack a packed [total_tokens, ...] tensor into a padded [B, S, ...] tensor.

    Differentiable w.r.t. `x`. If `desired_shape` is omitted, B and S are
    inferred from `cu_seqlens`; pad positions are zero-filled.
    """
    return UnpackSequenceFunction.apply(x, cu_seqlens, padding_side, desired_shape)