| |
| |
|
|
| import torch |
| import torch.nn.functional as F |
| import triton |
| import triton.language as tl |
|
|
| from ...utils import tensor_cache |
|
|
|
|
# Triton kernel that fills `y` with per-sequence position ids.
#
# Each program instance (indexed by `tl.program_id(0)`) handles one sequence:
# it reads that sequence's start/end offsets from `cu_seqlens` and writes
# 0, 1, ..., T-1 into `y[bos:eos]`, `B` elements at a time.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=warps)
        for warps in [4, 8, 16, 32]
    ],
    key=['B'],
)
@triton.jit
def prepare_position_ids_kernel(
    y,
    cu_seqlens,
    B: tl.constexpr
):
    # index of the sequence this program is responsible for
    i_n = tl.program_id(0)
    # start / end offsets of the sequence, narrowed to int32
    bos = tl.load(cu_seqlens + i_n).to(tl.int32)
    eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32)
    T = eos - bos

    lane = tl.arange(0, B)
    # walk the sequence in tiles of B; the last tile may be partial
    for start in range(0, tl.cdiv(T, B) * B, B):
        pos = lane + start
        # the mask keeps the final tile from writing past the sequence end
        tl.store(y + bos + pos, pos, mask=pos < T)
|
|
|
|
@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Return per-sequence lengths as differences of adjacent `cu_seqlens` entries.

    The result has one element fewer than `cu_seqlens`.
    """
    starts = cu_seqlens[:-1]
    ends = cu_seqlens[1:]
    return ends - starts
|
|
|
|
@tensor_cache
def prepare_lens_from_mask(mask: torch.BoolTensor) -> torch.LongTensor:
    """Sum `mask` along its last dimension to obtain per-row lengths.

    NOTE: the sum is accumulated and returned as ``torch.int32``, despite the
    ``LongTensor`` return annotation.
    """
    lens = mask.sum(dim=-1, dtype=torch.int32)
    return lens
|
|
|
|
@tensor_cache
def prepare_cu_seqlens_from_mask(mask: torch.BoolTensor, out_dtype: torch.dtype = torch.int32) -> torch.LongTensor:
    """Build cumulative sequence lengths from `mask`.

    The per-row lengths of `mask` are cumulatively summed along dim 0 in
    `out_dtype`, and a single leading zero is prepended.
    """
    lens = prepare_lens_from_mask(mask)
    running_total = lens.cumsum(dim=0, dtype=out_dtype)
    # pad one zero on the left of the last dimension
    return F.pad(running_total, (1, 0))
|
|
|
|
@tensor_cache
def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Return the within-sequence position of every token.

    For each sequence of length ``n`` described by `cu_seqlens`, emits
    ``0, 1, ..., n - 1``; the per-sequence ranges are concatenated in order.
    The result has the dtype and device of `cu_seqlens`.

    If `cu_seqlens` describes no sequences, an empty 1-D tensor is returned.
    """
    # Convert all lengths to Python ints in one go instead of handing
    # 0-dim tensors to `torch.arange` one at a time.
    lens = prepare_lens(cu_seqlens).tolist()
    if not lens:
        # `torch.cat` rejects an empty list
        return cu_seqlens.new_empty(0)
    return torch.cat([
        torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device)
        for n in lens
    ])
|
|
|
|
@tensor_cache
def prepare_sequence_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Return, for every token, the index of the sequence it belongs to.

    Sequence ``i`` (of length ``cu_seqlens[i + 1] - cu_seqlens[i]``)
    contributes that many copies of ``i``. The result is an int64 tensor on
    the device of `cu_seqlens`.

    Ids are derived directly from the sequence lengths rather than by counting
    tokens at position 0, so zero-length sequences do not shift the ids of the
    sequences that follow them, and an input with no sequences yields an empty
    tensor.
    """
    lens = prepare_lens(cu_seqlens).to(torch.long)
    seq_idx = torch.arange(lens.numel(), dtype=torch.long, device=cu_seqlens.device)
    return torch.repeat_interleave(seq_idx, lens)
|
|
|
|
@tensor_cache
def prepare_token_indices(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Return a two-column tensor of ``(sequence id, position id)`` per token.

    Column 0 comes from `prepare_sequence_ids`, column 1 from
    `prepare_position_ids`; the stacked result is converted to match
    `cu_seqlens` via ``.to(cu_seqlens)``.
    """
    pos_ids = prepare_position_ids(cu_seqlens)
    seq_ids = prepare_sequence_ids(cu_seqlens)
    pairs = torch.stack([seq_ids, pos_ids], 1)
    return pairs.to(cu_seqlens)
|
|
|
|
@tensor_cache
def prepare_chunk_indices(
    cu_seqlens: torch.LongTensor,
    chunk_size: int
) -> torch.LongTensor:
    """Return a two-column tensor of ``(sequence id, chunk id)`` per chunk.

    Each sequence of length ``n`` is split into ``ceil(n / chunk_size)``
    chunks; one row is emitted per chunk, holding the index of its sequence
    and the index of the chunk within that sequence. The result is converted
    to match `cu_seqlens` via ``.to(cu_seqlens)``.

    Sequence ids are derived from the per-sequence chunk counts, so sequences
    that contribute zero chunks do not shift the ids of later sequences. When
    there are no chunks at all, an empty ``(0, 2)`` tensor is returned.
    """
    num_chunks = triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
    if sum(num_chunks) == 0:
        # also covers an input with no sequences, where `torch.cat` would raise
        return cu_seqlens.new_empty((0, 2))
    # index tensors are built with default placement and moved once at the end
    chunk_ids = torch.cat([torch.arange(n) for n in num_chunks])
    seq_ids = torch.repeat_interleave(
        torch.arange(len(num_chunks)),
        torch.tensor(num_chunks, dtype=torch.long),
    )
    return torch.stack([seq_ids, chunk_ids], 1).to(cu_seqlens)
|
|
|
|
@tensor_cache
def prepare_chunk_offsets(
    cu_seqlens: torch.LongTensor,
    chunk_size: int
) -> torch.LongTensor:
    """Return cumulative chunk counts, starting from zero.

    Entry ``i`` is the total number of chunks (of size `chunk_size`, rounded
    up per sequence) in all sequences before sequence ``i``; the final entry
    is the overall chunk count.
    """
    leading_zero = cu_seqlens.new_tensor([0])
    chunks_per_seq = triton.cdiv(prepare_lens(cu_seqlens), chunk_size)
    counts = torch.cat([leading_zero, chunks_per_seq])
    return counts.cumsum(-1)
|
|