| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import torch |
| import torch.nn.functional as F |
|
|
|
|
def len2culen(seqlens: "torch.Tensor") -> "torch.Tensor":
    """
    Build cumulative sequence lengths from per-sequence lengths.

    The lengths are accumulated along dim 0 and a single zero is prepended
    on the last dimension, so a 1-D input of ``k`` lengths yields ``k + 1``
    boundaries starting at 0.

    NOTE: flash attention only accepts int32 cu_seqlens, hence the final cast.
    """
    running_total = seqlens.cumsum(dim=0)
    # (1, 0) pads one element on the left and none on the right; the default
    # constant fill value is 0, which becomes the leading boundary.
    with_leading_zero = F.pad(running_total, (1, 0))
    return with_leading_zero.type(torch.int32)
|
|
|
|
def culen2len(cu_seqlens: "torch.Tensor") -> "torch.Tensor":
    """
    Recover per-sequence lengths from cumulative sequence lengths.

    Each output element is the first-order difference between two adjacent
    entries of ``cu_seqlens`` along its last dimension, so the result has one
    element fewer than the input on that dimension.
    """
    adjacent_gaps = torch.diff(cu_seqlens, n=1, dim=-1)
    return adjacent_gaps
|
|
|
|
def pos2culen(position_ids: "torch.Tensor") -> "torch.Tensor":
    """
    Derive cumulative sequence lengths from position ids.

    Every position equal to 0 is treated as the start of a sequence. The
    flattened indices of those starts, followed by the total number of
    positions, form the returned int32 boundaries.
    """
    # A 3-D input is reduced to index 0 along its second axis before flattening.
    # NOTE(review): presumably the layout is (batch, dim, seq) — confirm with callers.
    selected = position_ids[:, 0, :] if position_ids.dim() == 3 else position_ids
    flat_positions = selected.flatten()
    total_tokens = flat_positions.size(0)

    token_index = torch.arange(total_tokens, dtype=torch.int32, device=flat_positions.device)
    sequence_starts = token_index[flat_positions == 0]
    # Append the total length as the closing boundary.
    return F.pad(sequence_starts, (0, 1), mode="constant", value=total_tokens)
|
|
|
|
def culen2pos(cu_seqlens: "torch.Tensor") -> "torch.Tensor":
    """
    Converts the cumulative sequence lengths to position ids.

    Each sequence of length ``n`` contributes the positions ``0, 1, ..., n - 1``;
    the per-sequence ranges are concatenated and a leading dimension of size 1
    is added.

    Args:
        cu_seqlens: cumulative sequence boundaries (expected to be 1-D, as
            produced by ``len2culen``).

    Returns:
        A ``torch.long`` tensor of shape ``(1, total_length)`` on the same
        device as ``cu_seqlens``. When ``cu_seqlens`` holds fewer than two
        boundaries there are no sequences and an empty ``(1, 0)`` tensor is
        returned.
    """
    # Convert to plain Python numbers once so torch.arange is not handed
    # 0-dim tensors on every iteration.
    seqlens = culen2len(cu_seqlens).tolist()
    if not seqlens:
        # torch.cat raises on an empty list, so short-circuit the no-sequence case.
        return torch.empty((1, 0), dtype=torch.long, device=cu_seqlens.device)

    position_ids = torch.cat(
        [torch.arange(length, dtype=torch.long, device=cu_seqlens.device) for length in seqlens]
    )
    return position_ids.unsqueeze(0)
|
|