# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
def len2culen(seqlens: "torch.Tensor") -> "torch.Tensor":
    """
    Builds cumulative sequence lengths, with a leading zero, from per-sequence lengths.

    NOTE: the result is cast to int32 because flash attention only accepts int32 cu_seqlens.
    """
    running_total = seqlens.cumsum(dim=0)
    # Prepend a single zero so the first boundary is the start offset 0.
    with_leading_zero = F.pad(running_total, (1, 0))
    return with_leading_zero.type(torch.int32)
def culen2len(cu_seqlens: "torch.Tensor") -> "torch.Tensor":
    """
    Recovers per-sequence lengths from cumulative sequence lengths.

    Each output entry is the difference between two consecutive entries of `cu_seqlens`.
    """
    lengths = torch.diff(cu_seqlens)
    return lengths
def pos2culen(position_ids: "torch.Tensor") -> "torch.Tensor":
    """
    Derives cumulative sequence lengths from position ids.

    Every index whose position id equals 0 is treated as the start of a sequence; the total
    number of (flattened) positions is appended as the final boundary. The result is int32.
    """
    if position_ids.dim() == 3:  # (batch_size, dim, seq_length)
        # Only the first slice along the middle dimension is used.
        position_ids = position_ids[:, 0, :]

    flat_ids = position_ids.flatten()
    total_tokens = flat_ids.size(0)
    token_index = torch.arange(total_tokens, dtype=torch.int32, device=flat_ids.device)
    sequence_starts = token_index[flat_ids == 0]
    # Close the last sequence by appending the total token count on the right.
    return F.pad(sequence_starts, (0, 1), "constant", total_tokens)
def culen2pos(cu_seqlens: "torch.Tensor") -> "torch.Tensor":
    """
    Converts the cumulative sequence lengths to position ids.

    For every sequence described by `cu_seqlens`, positions restart at 0 and count up to
    `length - 1`; the per-sequence ranges are concatenated and a leading dimension of size 1
    is added.

    Args:
        cu_seqlens: 1-D tensor of cumulative sequence lengths.

    Returns:
        A `torch.long` tensor of shape `(1, total_length)` on the same device as `cu_seqlens`.
        When `cu_seqlens` describes no sequences, an empty `(1, 0)` tensor is returned.
    """
    # Consecutive differences give the per-sequence lengths; read them to the host once as
    # plain Python ints rather than iterating over 0-d tensors.
    seqlens = torch.diff(cu_seqlens).tolist()
    if not seqlens:
        # torch.cat rejects an empty list, so build the empty result explicitly.
        return torch.empty((1, 0), dtype=torch.long, device=cu_seqlens.device)

    position_ids = torch.cat(
        [torch.arange(length, dtype=torch.long, device=cu_seqlens.device) for length in seqlens]
    )
    return position_ids.unsqueeze(0)
|