File size: 5,413 Bytes
b3d361d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""
Dataset classes for LLM training.

TextDataset:    Sliding window (stride 1) over a memory-mapped uint16 binary file.
PackedDataset: Non-overlapping windows (stride = seq_len) over the same file format.
"""

from __future__ import annotations

from pathlib import Path
from typing import Tuple, Union

import numpy as np
import torch
from torch.utils.data import Dataset


class TextDataset(Dataset):
    """
    Sliding-window dataset over a memory-mapped numpy uint16 binary token file.

    Each sample is an (input_ids, targets) pair of length seq_len, where
    targets is input_ids shifted by one position.  Windows overlap by
    (seq_len - 1) tokens, i.e. stride = 1.

    Indexing follows the usual sequence rules: negative indices count from
    the end, and an out-of-range index raises IndexError instead of silently
    returning a truncated sample.  Raising IndexError is also what lets plain
    ``for sample in dataset`` iteration terminate.

    Args:
        data_path: Path to the .bin file produced by data/prepare.py.
        seq_len:   Number of tokens per sample (context length).  Must be > 0.

    Raises:
        ValueError: If seq_len is not positive, or the file holds fewer than
            seq_len + 1 tokens.
        FileNotFoundError: If data_path does not exist.
    """

    def __init__(self, data_path: Union[str, Path], seq_len: int) -> None:
        super().__init__()
        if seq_len <= 0:
            raise ValueError(f"seq_len must be a positive integer, got {seq_len}.")
        self.seq_len = seq_len
        path = Path(data_path)
        if not path.exists():
            raise FileNotFoundError(f"Data file not found: {path}")
        # Memory-map for zero-copy random access.
        self.data: np.ndarray = np.memmap(path, dtype="uint16", mode="r")
        # Best-effort kernel hint.  MADV_SEQUENTIAL requests aggressive
        # read-ahead for in-order access; it does not by itself load the
        # whole file into the page cache.
        import mmap as _mmap
        try:
            self.data._mmap.madvise(_mmap.MADV_SEQUENTIAL)
        except (AttributeError, OSError):
            pass  # madvise not available on all platforms
        if len(self.data) < seq_len + 1:
            raise ValueError(
                f"Data file has only {len(self.data)} tokens, "
                f"need at least {seq_len + 1}."
            )

    def __len__(self) -> int:
        # Each window needs seq_len tokens plus one extra for the target shift.
        return len(self.data) - self.seq_len

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the (input_ids, targets) int32 tensors for window ``idx``.

        Raises:
            IndexError: If idx is outside [-len(self), len(self)).
        """
        num_windows = len(self)
        requested = idx
        if idx < 0:
            # Python-style negative indexing (non-mutating add so that an
            # index object passed by the caller is never modified in place).
            idx = idx + num_windows
        if not 0 <= idx < num_windows:
            raise IndexError(
                f"index {requested} is out of range for dataset of length {num_windows}"
            )

        # Slice from the memmap (returns a uint16 numpy view).
        chunk = self.data[idx : idx + self.seq_len + 1]
        # Cast to int32 (not int64) to halve CPU worker memory usage:
        # uint16 (2 B) -> int32 (4 B) instead of uint16 -> int64 (8 B).
        # Every uint16 token id (max 65535) fits in int32.
        # NOTE(review): the original comment states the int32 -> int64 (long)
        # promotion happens on GPU inside _step(); that code is not in this
        # file, so confirm the training step does the cast.
        chunk = torch.from_numpy(chunk.astype(np.int32))
        input_ids = chunk[:-1]   # [seq_len]
        targets   = chunk[1:]    # [seq_len]
        return input_ids, targets


class PackedDataset(Dataset):
    """
    Non-overlapping packed dataset over a memory-mapped uint16 binary token file.

    Intended for data that has already been packed (documents concatenated with
    EOS tokens).  Windows do not overlap; stride = seq_len.  Trailing tokens
    that do not fill a complete window are never used as inputs.

    The target sequence is input_ids shifted by one token.  The target for the
    last position of a window is the first token of the following window, read
    directly from the file.  Only when no such token exists (the final window
    ends exactly at end-of-file) is that last target set to -1.

    NOTE: -1 is NOT the default ``ignore_index`` of ``nn.CrossEntropyLoss``
    (the PyTorch default is -100); the loss must be configured with
    ``ignore_index=-1`` for the padded position to be skipped.

    Indexing follows the usual sequence rules: negative indices count from
    the end, and an out-of-range index raises IndexError instead of silently
    returning a truncated sample.

    Args:
        data_path: Path to the .bin file produced by data/prepare.py.
        seq_len:   Number of tokens per sample (context length).  Must be > 0.

    Raises:
        ValueError: If seq_len is not positive, or the file holds fewer than
            seq_len tokens.
        FileNotFoundError: If data_path does not exist.
    """

    def __init__(self, data_path: Union[str, Path], seq_len: int) -> None:
        super().__init__()
        if seq_len <= 0:
            # Guard early: seq_len == 0 would otherwise surface later as a
            # ZeroDivisionError in __len__.
            raise ValueError(f"seq_len must be a positive integer, got {seq_len}.")
        self.seq_len = seq_len
        path = Path(data_path)
        if not path.exists():
            raise FileNotFoundError(f"Data file not found: {path}")
        self.data: np.ndarray = np.memmap(path, dtype="uint16", mode="r")
        # Best-effort kernel hints for a shuffled (random) access pattern:
        #   MADV_RANDOM   - disable kernel read-ahead
        #   MADV_WILLNEED - ask the kernel to start paging the file in
        # Each hint is applied independently so that one being unavailable
        # on this platform does not prevent the other from taking effect.
        import mmap as _mmap
        for advice_name in ("MADV_RANDOM", "MADV_WILLNEED"):
            try:
                self.data._mmap.madvise(getattr(_mmap, advice_name))
            except (AttributeError, OSError):
                continue  # madvise / this advice not available here
        if len(self.data) < seq_len:
            raise ValueError(
                f"Data file has only {len(self.data)} tokens, "
                f"need at least {seq_len}."
            )

    def __len__(self) -> int:
        # Number of complete, non-overlapping windows; any remainder is dropped.
        return len(self.data) // self.seq_len

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the (input_ids, targets) int32 tensors for window ``idx``.

        Raises:
            IndexError: If idx is outside [-len(self), len(self)).
        """
        num_windows = len(self)
        requested = idx
        if idx < 0:
            # Python-style negative indexing (non-mutating add so that an
            # index object passed by the caller is never modified in place).
            idx = idx + num_windows
        if not 0 <= idx < num_windows:
            raise IndexError(
                f"index {requested} is out of range for dataset of length {num_windows}"
            )

        start = idx * self.seq_len
        end   = start + self.seq_len

        # Cast to int32 (not int64) to halve CPU worker memory usage; every
        # uint16 token id fits in int32, and so does the -1 padding value.
        input_ids = torch.from_numpy(
            self.data[start:end].astype(np.int32)
        )  # [seq_len]

        # Targets are the inputs shifted by one.  numpy clamps the slice at
        # end-of-file, so `shifted` is one token short exactly when the
        # window ends on the final token of the file.
        shifted = self.data[start + 1 : end + 1]
        if len(shifted) == self.seq_len:
            targets = torch.from_numpy(shifted.astype(np.int32))  # [seq_len]
        else:
            # Last window with no following token: pad the tail with -1.
            padded = np.full(self.seq_len, -1, dtype=np.int32)
            padded[: len(shifted)] = shifted
            targets = torch.from_numpy(padded)  # [seq_len]

        return input_ids, targets