from __future__ import annotations

from typing import List, Tuple

import torch


def collate_token_batch(
    sequences: List[List[int]],
    *,
    pad_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Collate variable-length token id sequences into padded tensors and masks
    with a next-token label shift.

    Args:
        sequences: list of token id lists
        pad_id: id used for padding

    Returns:
        input_ids:  [B, S_max]
        labels:     [B, S_max] (next-token labels; last position masked out)
        attn_mask:  [B, S_max] (True for valid tokens)
        loss_mask:  [B, S_max] (True for positions to include in the loss)
        lengths:    [B]
        cu_seqlens: [B+1] cumulative lengths
    """
    B = len(sequences)
    lengths = torch.tensor([len(seq) for seq in sequences], dtype=torch.int32)
    S_max = int(lengths.max().item()) if B > 0 else 0

    input_ids = torch.full((B, S_max), pad_id, dtype=torch.long)
    labels = torch.full((B, S_max), pad_id, dtype=torch.long)
    attn_mask = torch.zeros((B, S_max), dtype=torch.bool)
    loss_mask = torch.zeros((B, S_max), dtype=torch.bool)

    for b, seq in enumerate(sequences):
        L = len(seq)
        if L == 0:
            continue
        input_ids[b, :L] = torch.tensor(seq, dtype=torch.long)
        attn_mask[b, :L] = True
        # Shift left by one: position i predicts token i+1. The last valid
        # position has no target, so it stays excluded from the loss.
        labels[b, : L - 1] = input_ids[b, 1:L]
        loss_mask[b, : L - 1] = True

    # Cumulative sequence lengths, e.g. for varlen attention kernels.
    cu = torch.zeros((B + 1,), dtype=torch.int32)
    cu[1:] = torch.cumsum(lengths, dim=0)
    return input_ids, labels, attn_mask, loss_mask, lengths, cu
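

# --- Usage sketch (illustrative; not part of the original module) -------------
# A minimal example of how collate_token_batch might be called, e.g. from a
# DataLoader's collate_fn. The token ids below are made-up sample data, and
# pad_id=0 is just the function's default.
if __name__ == "__main__":
    batch = [
        [5, 11, 7, 2],       # length 4
        [9, 3],              # length 2
        [4, 8, 6, 1, 12],    # length 5
    ]
    input_ids, labels, attn_mask, loss_mask, lengths, cu_seqlens = collate_token_batch(batch)
    # Padded tensors are [3, 5]; lengths is [3]; cu_seqlens is [4].
    print(input_ids.shape, labels.shape, lengths.tolist(), cu_seqlens.tolist())
    # Each sequence contributes (length - 1) loss positions: the label shift
    # drops the final token, and padding is never included.
    assert loss_mask.sum().item() == int((lengths - 1).clamp(min=0).sum().item())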