# nsa/core/collate.py
from __future__ import annotations

from typing import List, Tuple

import torch

def collate_token_batch(
    sequences: List[List[int]],
    *,
    pad_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Collate token id sequences (var-length) into padded tensors and masks with label shift.
Args:
sequences: list of token id lists
pad_id: id used for padding
Returns:
input_ids: [B,S_max]
labels: [B,S_max] (next-token labels; last position masked out)
attn_mask: [B,S_max] (True for valid tokens)
loss_mask: [B,S_max] (True for positions to include in loss)
lengths: [B]
cu_seqlens:[B+1] cumulative lengths
"""
    B = len(sequences)
    lengths = torch.tensor([len(seq) for seq in sequences], dtype=torch.int32)
    S_max = int(lengths.max().item()) if B > 0 else 0

    input_ids = torch.full((B, S_max), pad_id, dtype=torch.long)
    labels = torch.full((B, S_max), pad_id, dtype=torch.long)
    attn_mask = torch.zeros((B, S_max), dtype=torch.bool)
    loss_mask = torch.zeros((B, S_max), dtype=torch.bool)

    for b, seq in enumerate(sequences):
        L = len(seq)
        if L == 0:
            continue
        input_ids[b, :L] = torch.tensor(seq, dtype=torch.long)
        attn_mask[b, :L] = True
        # Next-token labels (shifted left by 1); the last token has no next label.
        labels[b, : L - 1] = input_ids[b, 1:L]
        loss_mask[b, : L - 1] = True

    # cu_seqlens for varlen APIs
    cu = torch.zeros((B + 1,), dtype=torch.int32)
    cu[1:] = torch.cumsum(lengths, dim=0)
    return input_ids, labels, attn_mask, loss_mask, lengths, cu
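

# --- Usage sketch (illustrative; the sequences, pad_id, and vocab size below
# are made-up example values, not part of this module's API). A minimal smoke
# test of collate_token_batch on two uneven sequences, plus one common way to
# consume labels/loss_mask in a masked cross-entropy loss.
if __name__ == "__main__":
    import torch.nn.functional as F

    seqs = [[5, 6, 7, 8], [9, 10]]
    input_ids, labels, attn_mask, loss_mask, lengths, cu = collate_token_batch(seqs, pad_id=0)

    # Row 0: labels are the inputs shifted left; the final slot is not scored.
    assert labels[0, :3].tolist() == [6, 7, 8] and not loss_mask[0, 3]
    # Row 1 is padded to S_max=4; only its first token gets a loss target.
    assert attn_mask[1].tolist() == [True, True, False, False]
    assert loss_mask[1].tolist() == [True, False, False, False]
    # cu_seqlens gives offsets into the packed (unpadded) token stream.
    assert cu.tolist() == [0, 4, 6]

    # Masked loss: per-token CE averaged over valid positions only.
    # `logits` is a random stand-in for model output, shape [B, S_max, vocab].
    logits = torch.randn(2, 4, 32)
    per_tok = F.cross_entropy(logits.transpose(1, 2), labels, reduction="none")
    loss = (per_tok * loss_mask).sum() / loss_mask.sum().clamp(min=1)
    print(f"masked loss: {loss.item():.4f}")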