File size: 3,293 Bytes
ef18673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""Packed training dataset with deterministic resume support."""

from __future__ import annotations

import random
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Iterator

import torch
from torch.utils.data import IterableDataset

try:
    import pyarrow.parquet as pq
except ImportError:  # pragma: no cover - optional at import time
    pq = None


@dataclass(frozen=True)
class DatasetConfig:
    """Configuration for packing token streams into training batches.

    Attributes:
        shard_paths: Paths of the parquet shards read by ``PackedDataset``.
            Any iterable of paths is accepted and stored as a tuple so the
            frozen config stays hashable; a single string is treated as one
            path instead of being split into characters.
        context_length: Number of input positions per packed sequence. Each
            packed window holds ``context_length + 1`` tokens. Must be >= 1.
        split: Value a row's ``split`` column must equal for the row to be
            used.
        seed: Seed of the random generator that shuffles the shard order.

    Raises:
        ValueError: If ``context_length`` is smaller than 1.
    """

    shard_paths: tuple[str, ...]
    context_length: int
    split: str = "train"
    seed: int = 42

    def __post_init__(self) -> None:
        """Normalize ``shard_paths`` and reject unusable context lengths."""
        paths = self.shard_paths
        # A bare string would otherwise be iterated character by character.
        normalized = (paths,) if isinstance(paths, str) else tuple(paths)
        # The dataclass is frozen, so the generated __setattr__ must be bypassed.
        object.__setattr__(self, "shard_paths", normalized)
        # A non-positive context length would make the packing loop spin forever.
        if self.context_length < 1:
            raise ValueError(
                f"context_length must be at least 1, got {self.context_length}."
            )


class PackedDataset(IterableDataset):
    """Iterate packed token sequences with document-boundary masks.

    Rows are read shard by shard, their tokens are concatenated into one
    stream, and the stream is cut into windows of ``context_length + 1``
    tokens. Consecutive windows overlap by one token, so the last label of a
    window is the first input of the next one. Tokens still buffered when the
    shards run out (fewer than ``context_length + 1``) are not emitted.

    NOTE(review): iteration never consults ``torch.utils.data.get_worker_info``,
    so with several DataLoader workers each worker presumably yields the full
    stream - confirm against how the loader is configured.
    """

    def __init__(self, config: DatasetConfig):
        """Store the configuration and start with no batches skipped.

        Args:
            config: Shard paths, context length, split name and shuffle seed.

        Raises:
            ValueError: If ``config.context_length`` is smaller than 1. The
                packing loop could never advance through its buffer in that
                case and would yield forever.
        """
        super().__init__()
        if config.context_length < 1:
            raise ValueError(
                f"context_length must be at least 1, got {config.context_length}."
            )
        self.config = config
        self._skip = 0

    def skip(self, n_batches: int) -> None:
        """Fast-forward the iterator by discarding the first n batches.

        The count is stored on the instance and is not reset after use, so
        every later iteration discards the same number of leading batches.
        Negative values are clamped to zero.
        """
        self._skip = max(0, int(n_batches))

    def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
        """Yield packed batches, dropping the first ``skip`` of them."""
        skipped = 0
        for batch in self._generate():
            if skipped < self._skip:
                skipped += 1
                continue
            yield batch

    def _generate(self) -> Iterator[dict[str, torch.Tensor]]:
        """Pack the row stream into overlapping fixed-size windows.

        Rows with fewer than two tokens are ignored. The final token of every
        kept row is flagged with 1 in the boundary stream, all others with 0.
        """
        stride = self.config.context_length
        window = stride + 1
        token_buffer: list[int] = []
        boundary_buffer: list[int] = []
        for row in self._iter_rows():
            tokens = list(row["tokens"])
            if len(tokens) < 2:
                continue
            token_buffer.extend(tokens)
            boundary_buffer.extend([0] * (len(tokens) - 1) + [1])
            # Walk the buffer with a cursor and trim once per row; re-slicing
            # the whole remainder after every window is quadratic for long rows.
            start = 0
            while len(token_buffer) - start >= window:
                stop = start + window
                yield pack_sequence(
                    token_buffer[start:stop], boundary_buffer[start:stop]
                )
                # Advance by context_length only, so the last token of this
                # window is reused as the first input of the next one.
                start += stride
            if start:
                del token_buffer[:start]
                del boundary_buffer[:start]

    def _iter_rows(self) -> Iterator[dict[str, object]]:
        """Yield the rows of the configured split from every shard.

        Shards are visited in an order shuffled by ``random.Random(seed)``, so
        the order is identical on every iteration with the same config.

        Raises:
            ImportError: If pyarrow could not be imported.
        """
        if pq is None:
            raise ImportError("pyarrow is required to read parquet shards.")
        shard_paths = [Path(path) for path in self.config.shard_paths]
        random.Random(self.config.seed).shuffle(shard_paths)
        wanted_split = self.config.split
        for path in shard_paths:
            table = pq.read_table(path, columns=["tokens", "split"])
            for row in table.to_pylist():
                if row["split"] == wanted_split:
                    yield row


def pack_sequence(tokens: list[int], boundaries: list[int]) -> dict[str, torch.Tensor]:
    """Build the shifted input/label tensors for a single packed window.

    Args:
        tokens: Token ids of the window; position ``i`` serves as an input and
            position ``i + 1`` as its label.
        boundaries: Per-token flags aligned with ``tokens``.

    Returns:
        A dict with ``input_ids`` (every token but the last), ``labels``
        (every token but the first), ``loss_mask`` (float32 ones shaped like
        ``input_ids``) and ``document_boundaries`` (``boundaries`` without its
        last entry). The id and boundary tensors are int64.
    """
    inputs = torch.tensor(tokens[:-1], dtype=torch.long)
    batch = {
        "input_ids": inputs,
        "labels": torch.tensor(tokens[1:], dtype=torch.long),
        # Every position contributes to the loss; nothing is masked out here.
        "loss_mask": torch.ones_like(inputs, dtype=torch.float32),
        "document_boundaries": torch.tensor(boundaries[:-1], dtype=torch.long),
    }
    return batch