| """Packed training dataset with deterministic resume support.""" |
|
|
| from __future__ import annotations |
|
|
| import random |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Iterable, Iterator |
|
|
| import torch |
| from torch.utils.data import IterableDataset |
|
|
| try: |
| import pyarrow.parquet as pq |
| except ImportError: |
| pq = None |
|
|
|
|
@dataclass(frozen=True)
class DatasetConfig:
    """Configuration for packing token streams into training batches.

    Attributes:
        shard_paths: Paths of the parquet shards to read. Must be a
            collection of paths, not a single string.
        context_length: Number of input tokens per packed sequence; each
            packed window holds ``context_length + 1`` tokens so inputs and
            labels can be shifted by one. Must be at least 1.
        split: Only rows whose ``split`` column equals this value are used.
        seed: Seed for the deterministic shard shuffle.

    Raises:
        TypeError: If ``shard_paths`` is a bare ``str``/``bytes``.
        ValueError: If ``context_length`` is smaller than 1.
    """

    shard_paths: tuple[str, ...]
    context_length: int
    split: str = "train"
    seed: int = 42

    def __post_init__(self) -> None:
        # A bare string would be iterated character by character when the
        # shard list is built, silently producing bogus one-letter paths.
        if isinstance(self.shard_paths, (str, bytes)):
            raise TypeError(
                "shard_paths must be a collection of paths, not a single string."
            )
        # A non-positive context length means the packing loop can never
        # shrink its buffer, so it would spin forever.
        if self.context_length < 1:
            raise ValueError(
                f"context_length must be >= 1, got {self.context_length!r}."
            )
|
|
|
|
class PackedDataset(IterableDataset):
    """Iterate packed token sequences with document-boundary masks.

    Documents are read from parquet shards (visited in a seed-determined
    order), concatenated into one token stream, and cut into windows of
    ``context_length + 1`` tokens. Consecutive windows overlap by one token
    so that the last label of one window is the first input of the next.
    Tokens left over at the end of the stream that do not fill a window are
    dropped.

    NOTE: ``__iter__`` does not consult ``torch.utils.data.get_worker_info()``,
    so every DataLoader worker iterates the full stream.
    """

    def __init__(self, config: DatasetConfig):
        """Store the configuration; no shard is opened until iteration."""
        super().__init__()
        self.config = config
        self._skip = 0

    def skip(self, n_batches: int) -> None:
        """Fast-forward the iterator by discarding the first n batches.

        Negative values are clamped to zero. The count is kept on the
        instance, so it applies to every subsequent ``__iter__`` call until
        ``skip`` is called again.
        """
        self._skip = max(0, int(n_batches))

    def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
        """Yield packed batches, discarding the first ``skip`` count of them."""
        remaining_to_skip = self._skip
        for batch in self._generate():
            if remaining_to_skip > 0:
                remaining_to_skip -= 1
                continue
            yield batch

    def _generate(self) -> Iterator[dict[str, torch.Tensor]]:
        """Pack the document stream into fixed-size windows.

        Raises:
            ValueError: If ``config.context_length`` is smaller than 1
                (raised on the first ``next()``, since this is a generator).
        """
        context_length = self.config.context_length
        if context_length < 1:
            # Without this guard the window loop below never advances.
            raise ValueError(
                f"context_length must be >= 1, got {context_length!r}."
            )
        window_size = context_length + 1

        token_buffer: list[int] = []
        boundary_buffer: list[int] = []
        for row in self._iter_rows():
            raw_tokens = row["tokens"]
            if raw_tokens is None:
                # Null token column: treat like any other unusable document.
                continue
            tokens = list(raw_tokens)
            if len(tokens) < 2:
                continue
            token_buffer.extend(tokens)
            # Flag 1 marks the final token of each document.
            boundary_buffer.extend([0] * (len(tokens) - 1) + [1])

            # Walk an offset instead of re-slicing the whole buffer after
            # every window; the consumed prefix is trimmed once per document.
            start = 0
            while len(token_buffer) - start >= window_size:
                stop = start + window_size
                yield pack_sequence(
                    token_buffer[start:stop], boundary_buffer[start:stop]
                )
                # Advance by context_length (not window_size) so the next
                # window starts at this window's last token.
                start += context_length
            if start:
                del token_buffer[:start]
                del boundary_buffer[:start]

    def _iter_rows(self) -> Iterator[dict[str, object]]:
        """Yield rows of the configured split, shard by shard.

        The shard order is a shuffle of ``config.shard_paths`` driven by
        ``random.Random(config.seed)``, so it is identical on every call.

        Raises:
            ImportError: If pyarrow is not installed.
        """
        if pq is None:
            raise ImportError("pyarrow is required to read parquet shards.")
        shard_paths = [Path(path) for path in self.config.shard_paths]
        random.Random(self.config.seed).shuffle(shard_paths)
        wanted_split = self.config.split
        for path in shard_paths:
            table = pq.read_table(path, columns=["tokens", "split"])
            for row in table.to_pylist():
                if row["split"] == wanted_split:
                    yield row
|
|
|
|
def pack_sequence(tokens: list[int], boundaries: list[int]) -> dict[str, torch.Tensor]:
    """Build next-token-prediction tensors from one packed window.

    Args:
        tokens: Token ids of the window.
        boundaries: One flag per token; as produced by
            ``PackedDataset._generate``, 1 marks the final token of a document.

    Returns:
        A dict with ``input_ids`` (every token but the last), ``labels``
        (every token but the first), ``loss_mask`` (float32 ones, same shape
        as ``input_ids``) and ``document_boundaries`` (every boundary flag
        but the last, aligned with ``input_ids``).
    """
    inputs = torch.tensor(tokens[:-1], dtype=torch.long)
    targets = torch.tensor(tokens[1:], dtype=torch.long)

    batch = {"input_ids": inputs, "labels": targets}
    # Every position contributes to the loss; nothing is masked out here.
    batch["loss_mask"] = torch.ones_like(inputs, dtype=torch.float32)
    batch["document_boundaries"] = torch.tensor(boundaries[:-1], dtype=torch.long)
    return batch
|
|