|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import triton |
|
|
import triton.language as tl |
|
|
|
|
|
from fla.utils import tensor_cache |
|
|
|
|
|
|
|
|
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=warps)
        for warps in (4, 8, 16, 32)
    ],
    key=['B'],
)
@triton.jit
def prepare_position_ids_kernel(
    y,
    offsets,
    B: tl.constexpr
):
    """Write per-sequence position ids into ``y``.

    Program ``i`` (along grid axis 0) reads ``offsets[i]`` and
    ``offsets[i + 1]`` as the begin/end of one sequence and stores the values
    ``0, 1, ..., T - 1`` (with ``T = end - begin``) at ``y[begin + k]``.

    Args:
        y: Output pointer that receives the position ids.
        offsets: Pointer to the sequence boundaries; entries ``i`` and
            ``i + 1`` delimit the sequence handled by program ``i``.
        B: Compile-time block size, i.e. how many positions are written per
            loop iteration.
    """
    seq_idx = tl.program_id(0)
    # Boundaries are cast to int32 before any arithmetic is done on them.
    bos = tl.load(offsets + seq_idx).to(tl.int32)
    eos = tl.load(offsets + seq_idx + 1).to(tl.int32)
    seq_len = eos - bos

    lane = tl.arange(0, B)
    # Walk the sequence in blocks of B; the mask on the store discards the
    # lanes of the last block that fall beyond the sequence end.
    for block_start in range(0, tl.cdiv(seq_len, B) * B, B):
        pos = lane + block_start
        tl.store(y + bos + pos, pos, pos < seq_len)
|
|
|
|
|
|
|
|
@tensor_cache
def prepare_lens(offsets: torch.LongTensor) -> torch.LongTensor:
    """Return the differences between consecutive entries of ``offsets``.

    Element ``i`` of the result is ``offsets[i + 1] - offsets[i]``, so the
    result is one element shorter than ``offsets`` along the first dimension.
    """
    starts = offsets[:-1]
    ends = offsets[1:]
    return ends - starts
|
|
|
|
|
|
|
|
@tensor_cache
def prepare_position_ids(offsets: torch.LongTensor) -> torch.LongTensor:
    """Return the concatenated position ids of all sequences in ``offsets``.

    For every sequence length ``n`` obtained from ``prepare_lens(offsets)``
    the range ``0, 1, ..., n - 1`` is generated; the ranges are concatenated
    in sequence order.

    Args:
        offsets: Sequence boundaries; consecutive entries delimit one sequence.

    Returns:
        A 1-D tensor with the same dtype and device as ``offsets``.
    """
    # Fetch all lengths with a single `.tolist()` call rather than iterating
    # `.unbind()`: the latter hands a 0-dim tensor to `torch.arange` for each
    # sequence, so every sequence pays its own tensor-to-number conversion.
    # This also matches how `prepare_chunk_indices` reads the lengths.
    lens = prepare_lens(offsets).tolist()
    return torch.cat([
        torch.arange(n, dtype=offsets.dtype, device=offsets.device)
        for n in lens
    ])
|
|
|
|
|
|
|
|
@tensor_cache
def prepare_sequence_ids(position_ids: torch.LongTensor) -> torch.LongTensor:
    """Map each position id to a zero-based running count of zeros seen so far.

    Every entry equal to ``0`` in ``position_ids`` increments the counter, so
    entries following the k-th zero (0-based) are labelled ``k``.
    """
    is_start = position_ids.eq(0)
    # The inclusive cumulative sum counts the zero at the current index too;
    # subtracting 1 makes the ids zero-based.
    return is_start.cumsum(0) - 1
|
|
|
|
|
|
|
|
@tensor_cache
def prepare_token_indices(offsets: torch.LongTensor) -> torch.LongTensor:
    """Return ``(sequence id, position id)`` pairs for every token.

    The result stacks the sequence ids and the position ids derived from
    ``offsets`` along dimension 1 and is converted to the dtype/device of
    ``offsets``.
    """
    position_ids = prepare_position_ids(offsets)
    sequence_ids = prepare_sequence_ids(position_ids)
    pairs = torch.stack([sequence_ids, position_ids], 1)
    return pairs.to(offsets)
|
|
|
|
|
|
|
|
@tensor_cache
def prepare_chunk_indices(
    offsets: torch.LongTensor,
    chunk_size: int
) -> torch.LongTensor:
    """Return ``(sequence id, chunk id)`` pairs for every chunk.

    Each sequence length from ``prepare_lens(offsets)`` is ceil-divided by
    ``chunk_size`` to get its number of chunks; chunk ids restart from 0 for
    every sequence.

    Args:
        offsets: Sequence boundaries; consecutive entries delimit one sequence.
        chunk_size: Number of tokens per chunk.

    Returns:
        The stacked pairs (along dimension 1), converted to the dtype/device
        of ``offsets``.
    """
    chunk_counts = triton.cdiv(prepare_lens(offsets), chunk_size).tolist()
    # The per-sequence ranges are built without an explicit device; only the
    # final stacked result is converted to match `offsets`.
    indices = torch.cat([torch.arange(count) for count in chunk_counts])
    sequence_ids = prepare_sequence_ids(indices)
    return torch.stack([sequence_ids, indices], 1).to(offsets)
|
|
|
|
|
|
|
|
@tensor_cache
def prepare_chunk_offsets(
    offsets: torch.LongTensor,
    chunk_size: int
) -> torch.LongTensor:
    """Return cumulative chunk counts, starting with 0.

    Each sequence length from ``prepare_lens(offsets)`` is ceil-divided by
    ``chunk_size``; a leading 0 is prepended and the cumulative sum taken, so
    entries ``i`` and ``i + 1`` of the result delimit the chunks of
    sequence ``i``.

    Args:
        offsets: Sequence boundaries; consecutive entries delimit one sequence.
        chunk_size: Number of tokens per chunk.
    """
    leading_zero = offsets.new_tensor([0])
    chunks_per_seq = triton.cdiv(prepare_lens(offsets), chunk_size)
    return torch.cat([leading_zero, chunks_per_seq]).cumsum(-1)
|
|
|