File size: 1,917 Bytes
fb11af9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn.functional as F


def len2culen(seqlens: "torch.Tensor") -> "torch.Tensor":
    """
    Build cumulative sequence lengths from per-sequence lengths.

    The running sum of ``seqlens`` is taken along dim 0 and a single zero is
    prepended on the last dimension, so N lengths yield N + 1 boundaries.

    NOTE: the result is cast to int32 because flash attention only accepts
    int32 cu_seqlens.
    """
    running_total = seqlens.cumsum(dim=0)
    # One element of left padding (default fill value 0) is the leading boundary.
    with_leading_zero = F.pad(running_total, (1, 0))
    return with_leading_zero.to(torch.int32)


def culen2len(cu_seqlens: "torch.Tensor") -> "torch.Tensor":
    """
    Recover per-sequence lengths from cumulative sequence lengths.

    Each output element is the first-order difference between two adjacent
    entries along the last dimension, so N + 1 boundaries yield N lengths.
    """
    return torch.diff(cu_seqlens, n=1, dim=-1)


def pos2culen(position_ids: "torch.Tensor") -> "torch.Tensor":
    """
    Derive cumulative sequence lengths from (packed) position ids.

    Every token whose position id equals 0 is treated as the start of a
    sequence. The returned int32 tensor holds the flat index of each such
    token, followed by the total number of tokens as the final boundary.
    """
    if position_ids.dim() == 3:
        # (batch_size, dim, seq_length): only index 0 along dim 1 is used.
        position_ids = position_ids[:, 0, :]

    flat_positions = position_ids.flatten()
    total_tokens = flat_positions.size(0)
    token_index = torch.arange(total_tokens, dtype=torch.int32, device=flat_positions.device)
    sequence_starts = token_index[flat_positions == 0]
    # Append the total token count on the right as the closing boundary.
    return F.pad(sequence_starts, (0, 1), mode="constant", value=total_tokens)


def culen2pos(cu_seqlens: "torch.Tensor") -> "torch.Tensor":
    """
    Converts the cumulative sequence lengths to position ids.

    For every sequence described by ``cu_seqlens`` the positions
    ``0 .. length - 1`` are generated and concatenated in order.

    Args:
        cu_seqlens: 1-D tensor of cumulative sequence boundaries.

    Returns:
        A ``torch.long`` tensor of shape ``(1, total_tokens)`` on the same
        device as ``cu_seqlens``. When ``cu_seqlens`` describes no sequence
        (fewer than two boundaries) the result has shape ``(1, 0)``.
    """
    device = cu_seqlens.device
    # Adjacent differences of the boundaries are the per-sequence lengths;
    # a single host transfer turns them into plain Python numbers.
    seqlens = cu_seqlens.diff().tolist()
    if not seqlens:
        # torch.cat rejects an empty list, so return an empty row explicitly.
        return torch.empty((1, 0), dtype=torch.long, device=device)

    position_ids = torch.cat([torch.arange(length, dtype=torch.long, device=device) for length in seqlens])
    return position_ids.unsqueeze(0)