File size: 9,089 Bytes
355eea1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
from __future__ import annotations

import json
from pathlib import Path
from typing import Iterator

import numpy as np
import torch
from torch.utils.data import Dataset, IterableDataset

from sllm.utils import ensure_dir


class TokenShardWriter:
    """Buffer token ids and flush them to fixed-size raw ``uint16`` shard files.

    Every ``shard_size_tokens`` tokens are written to
    ``<output_dir>/<prefix>_<index:05d>.bin``. ``finalize`` flushes the
    remaining (shorter) tail as a last shard and writes
    ``<prefix>_manifest.json`` listing ``path``/``num_tokens``/``dtype`` for
    every shard.
    """

    def __init__(self, output_dir: str | Path, prefix: str, shard_size_tokens: int) -> None:
        # Validate before touching the filesystem: a non-positive shard size
        # would make ``add_tokens`` loop forever writing empty shards.
        if shard_size_tokens <= 0:
            raise ValueError(f"shard_size_tokens must be positive, got {shard_size_tokens}.")
        self.output_dir = ensure_dir(output_dir)
        self.prefix = prefix
        self.shard_size_tokens = shard_size_tokens
        self.buffer: list[int] = []  # tokens not yet written to a shard
        self.shard_index = 0  # index used in the next shard's file name
        self.shards: list[dict] = []  # manifest records of shards written so far

    def add_tokens(self, tokens: list[int]) -> None:
        """Append ``tokens`` and write out every complete shard now available.

        Raises:
            ValueError: if any token id cannot be stored as ``uint16``; in that
                case nothing from ``tokens`` is buffered.
        """
        tokens = list(tokens)
        self._check_token_range(tokens)
        self.buffer.extend(tokens)
        size = self.shard_size_tokens
        written = 0
        try:
            while len(self.buffer) - written >= size:
                self._write_chunk(self.buffer[written : written + size])
                written += size
        finally:
            # Trim the flushed prefix once instead of re-slicing the remainder
            # after every shard (quadratic for large inputs), and keep the
            # buffer consistent with the shards recorded if a write fails.
            del self.buffer[:written]

    def finalize(self) -> list[dict]:
        """Flush any buffered tail as a final shard, write the manifest, return it."""
        if self.buffer:
            self._write_chunk(self.buffer)
            self.buffer = []
        manifest_path = self.output_dir / f"{self.prefix}_manifest.json"
        with manifest_path.open("w", encoding="utf-8") as handle:
            json.dump(self.shards, handle, indent=2, ensure_ascii=False)
        return self.shards

    @staticmethod
    def _check_token_range(tokens: list[int]) -> None:
        """Raise ``ValueError`` unless every id fits losslessly in ``uint16``."""
        if not tokens:
            return
        values = np.asarray(tokens, dtype=np.int64)
        low, high = int(values.min()), int(values.max())
        limit = int(np.iinfo(np.uint16).max)
        if low < 0 or high > limit:
            raise ValueError(
                f"Token ids must lie in [0, {limit}] to be stored as uint16; "
                f"got values in [{low}, {high}]."
            )

    def _write_chunk(self, chunk: list[int]) -> None:
        """Write ``chunk`` as raw ``uint16`` to the next shard file and record it."""
        shard_name = f"{self.prefix}_{self.shard_index:05d}.bin"
        shard_path = self.output_dir / shard_name
        array = np.asarray(chunk, dtype=np.uint16)
        with shard_path.open("wb") as handle:
            array.tofile(handle)
        self.shards.append(
            {
                "path": shard_name,  # relative to the manifest's directory
                "num_tokens": int(array.shape[0]),
                "dtype": "uint16",
            }
        )
        self.shard_index += 1


class SFTShardWriter:
    """Stream fixed-length SFT examples into two flat binary files.

    ``input_ids`` rows are appended as ``uint16`` to ``<prefix>_input_ids.bin``
    and ``labels`` rows as ``int32`` to ``<prefix>_labels.bin``, one
    ``seq_len``-long row per example. ``finalize`` closes both files and writes
    ``<prefix>_metadata.json`` describing them.
    """

    def __init__(self, output_dir: str | Path, prefix: str, seq_len: int) -> None:
        if seq_len <= 0:
            raise ValueError(f"seq_len must be positive, got {seq_len}.")
        self.output_dir = ensure_dir(output_dir)
        self.prefix = prefix
        self.seq_len = seq_len
        self.num_examples = 0
        self.input_path = self.output_dir / f"{self.prefix}_input_ids.bin"
        self.label_path = self.output_dir / f"{self.prefix}_labels.bin"
        self.input_handle = self.input_path.open("wb")
        try:
            self.label_handle = self.label_path.open("wb")
        except OSError:
            # Do not leak the already-open input file if the second open fails.
            self.input_handle.close()
            raise

    def add_example(self, input_ids: list[int], labels: list[int]) -> None:
        """Append one packed example.

        Raises:
            ValueError: if either sequence is not exactly ``seq_len`` long, or an
                input id does not fit in ``uint16``. Nothing is written then.
        """
        if len(input_ids) != self.seq_len or len(labels) != self.seq_len:
            raise ValueError("Packed SFT example must match fixed seq_len.")
        ids = np.asarray(input_ids, dtype=np.int64)
        limit = int(np.iinfo(np.uint16).max)
        if ids.size and (int(ids.min()) < 0 or int(ids.max()) > limit):
            raise ValueError(f"SFT input_ids must lie in [0, {limit}] to be stored as uint16.")
        label_array = np.asarray(labels, dtype=np.int32)
        # Convert both rows before writing either, so a failing conversion
        # cannot leave the two files with different row counts.
        ids.astype(np.uint16).tofile(self.input_handle)
        label_array.tofile(self.label_handle)
        self.num_examples += 1

    def finalize(self) -> dict:
        """Close both files, write the metadata JSON and return its contents.

        Raises:
            RuntimeError: if no example was added.
        """
        self.input_handle.close()
        self.label_handle.close()
        if self.num_examples == 0:
            raise RuntimeError("No SFT examples were written.")
        metadata = {
            "num_examples": self.num_examples,
            "seq_len": self.seq_len,
            "input_ids_path": self.input_path.name,
            "labels_path": self.label_path.name,
        }
        with (self.output_dir / f"{self.prefix}_metadata.json").open("w", encoding="utf-8") as handle:
            json.dump(metadata, handle, indent=2, ensure_ascii=False)
        return metadata


def load_shard_manifest(data_dir: str | Path, split: str) -> list[dict]:
    """Gather the shard records for ``split`` from one or more manifest files.

    Lookup order; the first lookup that finds any file wins:
      1. ``<data_dir>/<split>_manifest.json``;
      2. every ``*_manifest.json`` in ``data_dir`` when the directory itself is
         named ``split``;
      3. every ``*_manifest.json`` in ``<data_dir>/<split>`` when that exists.

    Each returned record gets an ``absolute_path`` entry: its ``path`` resolved
    against the directory of the manifest that listed it.

    Raises:
        FileNotFoundError: when none of the lookups finds a manifest.
        RuntimeError: when the manifests found list no shards at all.
    """
    root = Path(data_dir)
    found = sorted(root.glob(f"{split}_manifest.json"))
    if not found and root.name == split:
        found = sorted(root.glob("*_manifest.json"))
    nested = root / split
    if not found and nested.exists():
        found = sorted(nested.glob("*_manifest.json"))
    if not found:
        raise FileNotFoundError(f"Shard manifest not found in {root}.")

    records: list[dict] = []
    for manifest in found:
        base = manifest.parent
        with manifest.open("r", encoding="utf-8") as fh:
            entries = json.load(fh)
        for entry in entries:
            entry["absolute_path"] = str((base / entry["path"]).resolve())
        records.extend(entries)

    if not records:
        raise RuntimeError(f"Shard manifest {found} is empty.")
    return records


class RandomTokenDataset(IterableDataset):
    """Endless stream of randomly positioned ``seq_len``-token training windows.

    Each sample picks a shard with probability proportional to its number of
    valid window start positions, then a uniform start inside it. It yields
    next-token-prediction ``input_ids``/``labels`` (labels are the inputs
    shifted by one token) plus an all-ones ``attention_mask``.
    """

    def __init__(
        self,
        data_dir: str | Path,
        split: str,
        seq_len: int,
        seed: int = 42,
    ) -> None:
        super().__init__()
        self.seq_len = seq_len
        self.seed = seed
        self.shards = load_shard_manifest(data_dir, split)
        self.arrays = [np.memmap(item["absolute_path"], dtype=np.uint16, mode="r") for item in self.shards]
        # Clamp to the on-disk length so a manifest that overstates a shard
        # cannot produce truncated windows.
        capacities = [
            self._num_window_starts(min(int(item["num_tokens"]), len(array)), seq_len)
            for item, array in zip(self.shards, self.arrays)
        ]
        valid_pairs = [
            (item, array, capacity)
            for item, array, capacity in zip(self.shards, self.arrays, capacities)
            if capacity > 0
        ]
        if not valid_pairs:
            raise RuntimeError("No shard contains enough tokens for the selected sequence length.")
        self.shards, self.arrays, capacities = map(list, zip(*valid_pairs))
        weights = np.asarray(capacities, dtype=np.float64)
        self.probabilities = weights / weights.sum()
        self.capacities = capacities  # number of valid start offsets per kept shard

    @staticmethod
    def _num_window_starts(num_tokens: int, seq_len: int) -> int:
        """Return how many start offsets leave room for ``seq_len + 1`` tokens."""
        # A window spans seq_len + 1 tokens (inputs plus one shifted label), so
        # valid starts are 0 .. num_tokens - seq_len - 1 inclusive.
        return max(0, num_tokens - seq_len)

    def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        # Each DataLoader worker gets its own stream via seed + worker id.
        # NOTE(review): nothing here varies the seed per distributed rank; confirm
        # callers pass distinct seeds per rank if that matters.
        rng = np.random.default_rng(self.seed + worker_id)

        while True:
            shard_index = int(rng.choice(len(self.arrays), p=self.probabilities))
            capacity = self.capacities[shard_index]
            start = int(rng.integers(0, capacity))  # high bound is exclusive
            array = self.arrays[shard_index]
            window = np.asarray(array[start : start + self.seq_len + 1], dtype=np.int64)
            input_ids = torch.from_numpy(window[:-1].copy()).long()
            labels = torch.from_numpy(window[1:].copy()).long()
            attention_mask = torch.ones(self.seq_len, dtype=torch.long)
            yield {
                "input_ids": input_ids,
                "labels": labels,
                "attention_mask": attention_mask,
            }


class SequentialEvalDataset(IterableDataset):
    """Deterministic evaluation windows read in order from ``uint16`` shards.

    Shards are scanned in manifest order. Window starts advance by ``seq_len``,
    so the ``input_ids`` spans are contiguous and non-overlapping. Each window
    holds ``seq_len + 1`` tokens, giving inputs plus labels shifted by one.
    Iteration stops after ``max_batches`` windows in total.
    """

    def __init__(
        self,
        data_dir: str | Path,
        split: str,
        seq_len: int,
        max_batches: int,
    ) -> None:
        super().__init__()
        self.seq_len = seq_len
        # NOTE(review): this caps individual windows (samples) yielded, not
        # DataLoader batches; confirm that matches how callers size it.
        self.max_batches = max_batches
        self.shards = load_shard_manifest(data_dir, split)
        self.arrays = [np.memmap(item["absolute_path"], dtype=np.uint16, mode="r") for item in self.shards]

    def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
        span = self.seq_len + 1
        yielded = 0
        for array in self.arrays:
            # Valid starts satisfy start + span <= len(array), i.e. every start
            # strictly below len(array) - seq_len. Shards shorter than one span
            # produce an empty range and are skipped naturally.
            for start in range(0, len(array) - self.seq_len, self.seq_len):
                if yielded >= self.max_batches:
                    return
                window = np.asarray(array[start : start + span], dtype=np.int64)
                input_ids = torch.from_numpy(window[:-1].copy()).long()
                labels = torch.from_numpy(window[1:].copy()).long()
                attention_mask = torch.ones(self.seq_len, dtype=torch.long)
                yield {
                    "input_ids": input_ids,
                    "labels": labels,
                    "attention_mask": attention_mask,
                }
                yielded += 1
                yielded += 1


class FixedSFTDataset(Dataset):
    """Random-access view over packed, fixed-length SFT examples on disk.

    Reads ``<split>_metadata.json`` from ``dataset_dir`` and memory-maps the
    ``uint16`` input ids and ``int32`` labels it references as
    ``(num_examples, seq_len)`` arrays.

    Args:
        dataset_dir: directory holding the metadata JSON and binary files.
        split: file-name prefix of the metadata file (e.g. ``"train"``).
        pad_token_id: id treated as padding when building ``attention_mask``.
            Defaults to 0, the value this class previously hard-coded.

    Raises:
        FileNotFoundError: if the metadata file does not exist.
    """

    def __init__(self, dataset_dir: str | Path, split: str, pad_token_id: int = 0) -> None:
        dataset_dir = Path(dataset_dir)
        metadata_path = dataset_dir / f"{split}_metadata.json"
        if not metadata_path.exists():
            raise FileNotFoundError(f"Metadata file not found: {metadata_path}")
        with metadata_path.open("r", encoding="utf-8") as handle:
            metadata = json.load(handle)

        self.pad_token_id = pad_token_id
        self.seq_len = int(metadata["seq_len"])
        self.num_examples = int(metadata["num_examples"])
        self.input_ids = np.memmap(
            dataset_dir / metadata["input_ids_path"],
            dtype=np.uint16,
            mode="r",
            shape=(self.num_examples, self.seq_len),
        )
        self.labels = np.memmap(
            dataset_dir / metadata["labels_path"],
            dtype=np.int32,
            mode="r",
            shape=(self.num_examples, self.seq_len),
        )

    def __len__(self) -> int:
        return self.num_examples

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        # np.array(..., dtype=int64) always copies, detaching the row from the
        # read-only memmap before handing it to torch.
        input_ids = torch.from_numpy(np.array(self.input_ids[index], dtype=np.int64))
        labels = torch.from_numpy(np.array(self.labels[index], dtype=np.int64))
        attention_mask = (input_ids != self.pad_token_id).long()
        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
        }