"""Lightning DataModule + IterableDataset for HYDRA pretraining.

Replaces the custom threading/queue pipeline in prepare_nemotron.make_dataloader
with a standard multiprocessing DataLoader approach.

Design:
  • IterableStreamDataset: each worker opens its own HF streams for the 7-way
    blend, tokenizes with rustbpe, packs into (T+1,) rows via best-fit, and
    yields one row per __next__.
  • HydraDataModule: wraps the dataset with a standard DataLoader using
    num_workers>=1, prefetch_factor=4, pin_memory=True. Lightning handles
    device transfer.
  • Val stream: deterministic seed 12345. Uses the full blend when
    HYDRA_USE_FULL_BLEND=1, else Nemotron-Pretraining-Multiple-Choice only.
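
A minimal sketch of the best-fit packing step named above, on plain Python
lists rather than tensors. pack_row is a hypothetical helper for illustration,
not a function in this module; the real loop lives in
IterableStreamDataset.__iter__ and also refills the doc buffer as it goes.

```python
def pack_row(doc_buffer, capacity):
    # Fill one row of `capacity` tokens from `doc_buffer` (a list of token lists).
    row = []
    while len(row) < capacity and doc_buffer:
        remaining = capacity - len(row)
        # Candidates: docs that fit entirely in the remaining space.
        fitting = [i for i, d in enumerate(doc_buffer) if len(d) <= remaining]
        if fitting:
            # Best fit: the largest doc among the candidates.
            idx = max(fitting, key=lambda i: len(doc_buffer[i]))
            row.extend(doc_buffer.pop(idx))
        else:
            # Nothing fits: crop the shortest doc so the row fills exactly.
            idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
            row.extend(doc_buffer.pop(idx)[:remaining])
    return row


docs = [[1] * 5, [2] * 3, [3] * 4, [4] * 9]
row = pack_row(docs, capacity=8)
# row is [1, 1, 1, 1, 1, 2, 2, 2]; the 4- and 9-token docs stay in the buffer.
cropped = pack_row([[7] * 10, [8] * 12], capacity=4)
# cropped is [7, 7, 7, 7]: no doc fits, so the shortest one is truncated.
```

With an 8-token row the 5-token doc is placed first and the 3-token doc then
fills the rest exactly, so nothing is cropped in the first call.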

The worker RNG is seeded per-worker so the weighted-sampling schedule is
independent across workers (else all workers request the same config at
the same step and prefetching serializes).
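
The seeding scheme can be checked in isolation. The sketch below reuses the
base_seed + worker_id * 7919 derivation from _WorkerWeightedStream;
config_schedule and the two-entry toy_weights blend are hypothetical names
used only for illustration.

```python
import random


def config_schedule(weights, base_seed, worker_id, steps):
    # Same seed derivation as _WorkerWeightedStream.
    rng = random.Random(base_seed + worker_id * 7919)
    configs = list(weights)
    probs = [weights[c] for c in configs]
    # One weighted config choice per step, as in __next__.
    return [rng.choices(configs, weights=probs, k=1)[0] for _ in range(steps)]


# Hypothetical blend; the real weights come from prepare_nemotron.
toy_weights = {"web": 0.7, "code": 0.3}
w0 = config_schedule(toy_weights, base_seed=0, worker_id=0, steps=64)
w1 = config_schedule(toy_weights, base_seed=0, worker_id=1, steps=64)
w0_again = config_schedule(toy_weights, base_seed=0, worker_id=0, steps=64)
```

Each worker's schedule is reproducible for a given base_seed (w0 equals
w0_again), while different worker ids draw from differently seeded generators.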

Env vars (all preserved from prepare_nemotron):
  HYDRA_SEQ_LEN                  - sequence length T (default 512)
  HYDRA_BATCH_SIZE               - batch size B (default 1), passed through
                                    to DataLoader
  HYDRA_STREAM_SHUFFLE_BUFFER    - HF shuffle buffer (default 2048)
  HYDRA_USE_FULL_BLEND           - 7-way blend vs 5-way Nemotron phase
  HYDRA_USE_NEMOTRON             - enables streaming path (else shard path)
  HYDRA_FACTUAL_INJECT_RATE      - factual doc injection cadence
  HYDRA_NEMOTRON_PHASE           - phase1|phase2 (when not full blend)
  HYDRA_DATA_NUM_WORKERS         - DataLoader num_workers (default 2)
  HYDRA_DATA_PREFETCH            - DataLoader prefetch_factor (default 4)
  HYDRA_DATA_BUFFER              - doc_buffer size for best-fit packing
                                    (default 1000)
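
A rough sketch of how an explicit constructor argument and an env var fallback
combine, in the style HydraDataModule uses for num_workers and prefetch_factor.
resolve_int is a hypothetical helper for illustration, not part of this module.

```python
import os


def resolve_int(explicit, env_name, default):
    # An explicit argument wins, including 0 (e.g. num_workers=0 for debugging).
    if explicit is not None:
        return explicit
    # Otherwise read the env var, falling back to the documented default.
    return int(os.environ.get(env_name, str(default)))


os.environ["HYDRA_DATA_NUM_WORKERS"] = "8"
from_env = resolve_int(None, "HYDRA_DATA_NUM_WORKERS", 2)
explicit_zero = resolve_int(0, "HYDRA_DATA_NUM_WORKERS", 2)
os.environ.pop("HYDRA_DATA_PREFETCH", None)
from_default = resolve_int(None, "HYDRA_DATA_PREFETCH", 4)
```

Note that batch_size and seq_len are resolved with `value or default` instead,
so passing 0 for either of those falls back to the env var.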
"""
from __future__ import annotations

import os
import random
from typing import Iterator

import numpy as np
import torch
import lightning as L
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

import prepare as _prepare
import prepare_nemotron as _p_nemo
from prepare_nemotron import (
    FULL_BLEND_WEIGHTS,
    PHASE1_WEIGHTS,
    PHASE2_WEIGHTS,
    _BLEND_REGISTRY,
    _extract_text,
    _open_stream,
)


# ---------------------------------------------------------------------------
# Worker-local weighted stream. A stripped version of prepare_nemotron's
# _WeightedStream that is constructed inside each worker. Adds worker sharding:
# when num_workers > 1 the RNG is seeded per-worker, so different workers
# sample different config sequences and pull disjoint shard assignments from
# HF's shuffle buffer.
# ---------------------------------------------------------------------------


class _WorkerWeightedStream:
    def __init__(self, weights: dict[str, float], base_seed: int, worker_id: int):
        self.configs = list(weights.keys())
        self.weights = [weights[c] for c in self.configs]
        self.base_seed = base_seed
        self.worker_id = worker_id
        # Each worker opens its own HF streams. _open_stream returns an iter()
        # over a streaming dataset, with an internal shuffle buffer.
        self.streams = {c: _open_stream(c, "train") for c in self.configs}
        # Per-worker RNG so the config-choice trajectory is independent.
        self.rng = random.Random(base_seed + worker_id * 7919)
        self.epoch = 1

        # Lazy-init factual docs (once per worker). The main-process version
        # in prepare_nemotron._WeightedStream reads these on first __next__.
        self._factual_docs: list[str] | None = None
        self._factual_idx = 0
        self._inject_counter = 0
        inject_rate = int(os.environ.get("HYDRA_FACTUAL_INJECT_RATE", "50"))
        self._inject_rate = inject_rate
        if inject_rate > 0:
            factual_path = os.path.join(
                os.path.dirname(os.path.abspath(_p_nemo.__file__)),
                "data", "factual", "facts.txt",
            )
            if os.path.exists(factual_path):
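                # Read as UTF-8 explicitly so the result does not depend on
                # the platform's default encoding.
                with open(factual_path, encoding="utf-8") as fh: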
                    self._factual_docs = fh.read().strip().split("\n")

    def _reopen(self, config: str) -> None:
        self.streams[config] = _open_stream(config, "train")
        self.epoch += 1

    def __iter__(self):
        return self

    def __next__(self) -> tuple[str, int]:
        # Factual injection (preserves prepare_nemotron cadence).
        if self._inject_rate > 0 and self._factual_docs:
            self._inject_counter += 1
            if self._inject_counter >= self._inject_rate:
                self._inject_counter = 0
                doc = self._factual_docs[self._factual_idx % len(self._factual_docs)]
                self._factual_idx += 1
                return doc, self.epoch

        config = self.rng.choices(self.configs, weights=self.weights, k=1)[0]
        try:
            row = next(self.streams[config])
        except StopIteration:
            self._reopen(config)
            row = next(self.streams[config])
        return _extract_text(row), self.epoch


# ---------------------------------------------------------------------------
# IterableStreamDataset: yields (T+1,) packed rows. No threads. No queues.
# Lives inside each DataLoader worker. DataLoader's own multiprocessing stacks
# rows into batches of shape (B, T+1) and sends them to the main process.
# ---------------------------------------------------------------------------


class IterableStreamDataset(IterableDataset):
    """Streams docs, tokenizes, packs into (T+1,) rows via best-fit.

    Each worker gets its own instance (via fork/spawn) and initializes its
    own HF streams + rustbpe tokenizer + factual injector. The pickled
    tokenizer blob is small (~1 MB), and the encoder is thread-safe per the
    tiktoken docs.
    """

    def __init__(
        self,
        split: str,
        seq_len: int,
        *,
        base_seed: int = 0,
        doc_buffer_size: int = 1000,
        tokenizer_batch: int = 128,
    ):
        super().__init__()
        assert split in ("train", "val"), split
        self.split = split
        self.seq_len = seq_len
        self.row_capacity = seq_len + 1
        self.base_seed = base_seed
        self.doc_buffer_size = doc_buffer_size
        self.tokenizer_batch = tokenizer_batch

    def _pick_weights(self) -> dict[str, float]:
        if self.split == "val":
            if os.environ.get("HYDRA_USE_FULL_BLEND", "0") == "1":
                return FULL_BLEND_WEIGHTS
            return {"Nemotron-Pretraining-Multiple-Choice": 1.0}
        if os.environ.get("HYDRA_USE_FULL_BLEND", "0") == "1":
            return FULL_BLEND_WEIGHTS
        phase = os.environ.get("HYDRA_NEMOTRON_PHASE", "phase1").strip().lower()
        return PHASE2_WEIGHTS if phase == "phase2" else PHASE1_WEIGHTS

    def __iter__(self) -> Iterator[torch.Tensor]:
        info = get_worker_info()
        worker_id = 0 if info is None else info.id

        # Each worker builds its own tokenizer instance. tiktoken's Encoding
        # object is picklable and its underlying Rust BPE core is thread-safe;
        # per-worker instantiation avoids cross-process sharing headaches.
        tokenizer = _prepare.Tokenizer.from_directory()
        bos = tokenizer.get_bos_token_id()

        # Each worker gets its own weighted HF stream. Seed offset ensures
        # disjoint config-choice trajectories; HF's own shuffle buffer handles
        # shard randomization.
        val_seed = 12345  # deterministic val
        seed = val_seed if self.split == "val" else self.base_seed
        stream = _WorkerWeightedStream(
            self._pick_weights(), base_seed=seed, worker_id=worker_id,
        )

        row_capacity = self.row_capacity
        doc_buffer: list[list[int]] = []
        doc_batch_size = self.tokenizer_batch

        def refill_buffer() -> None:
            # Collect doc_batch_size text strings, then batch-tokenize.
            texts: list[str] = []
            for _ in range(doc_batch_size):
                text, _epoch = next(stream)
                if text:
                    texts.append(text)
            if texts:
                token_lists = tokenizer.encode(texts, prepend=bos)
                doc_buffer.extend(token_lists)

        while True:
            pos = 0
            row = torch.empty(row_capacity, dtype=torch.long)
            while pos < row_capacity:
                while len(doc_buffer) < self.doc_buffer_size:
                    refill_buffer()

                remaining = row_capacity - pos

                # Best-fit packing: largest doc that fully fits.
                best_idx = -1
                best_len = 0
                for i, doc in enumerate(doc_buffer):
                    dlen = len(doc)
                    if dlen <= remaining and dlen > best_len:
                        best_idx = i
                        best_len = dlen

                if best_idx >= 0:
                    doc = doc_buffer.pop(best_idx)
                    row[pos : pos + len(doc)] = torch.tensor(doc, dtype=torch.long)
                    pos += len(doc)
                else:
                    # No doc fits the remaining space: crop the shortest to fill.
                    shortest_idx = min(
                        range(len(doc_buffer)),
                        key=lambda i: len(doc_buffer[i]),
                    )
                    doc = doc_buffer.pop(shortest_idx)
                    row[pos : pos + remaining] = torch.tensor(
                        doc[:remaining], dtype=torch.long,
                    )
                    pos += remaining

            yield row


# ---------------------------------------------------------------------------
# LightningDataModule
# ---------------------------------------------------------------------------


class HydraDataModule(L.LightningDataModule):
    def __init__(
        self,
        batch_size: int | None = None,
        seq_len: int | None = None,
        num_workers: int | None = None,
        prefetch_factor: int | None = None,
    ):
        super().__init__()
        self.batch_size = batch_size or int(os.environ.get("HYDRA_BATCH_SIZE", "1"))
        self.seq_len = seq_len or int(os.environ.get("HYDRA_SEQ_LEN", "512"))
        self.num_workers = (
            num_workers
            if num_workers is not None
            else int(os.environ.get("HYDRA_DATA_NUM_WORKERS", "2"))
        )
        self.prefetch_factor = (
            prefetch_factor
            if prefetch_factor is not None
            else int(os.environ.get("HYDRA_DATA_PREFETCH", "4"))
        )
        self.doc_buffer = int(os.environ.get("HYDRA_DATA_BUFFER", "1000"))

    def _make_loader(self, split: str, seed: int) -> DataLoader:
        dataset = IterableStreamDataset(
            split=split,
            seq_len=self.seq_len,
            base_seed=seed,
            doc_buffer_size=self.doc_buffer,
        )
        # num_workers=0 → main-process iteration (useful for debugging). With
        # IterableDataset the DataLoader batches the rows into (B, T+1) via
        # the default collate function, which stacks them with torch.stack.
        kw: dict = dict(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True,
        )
        if self.num_workers > 0:
            kw["prefetch_factor"] = self.prefetch_factor
            kw["persistent_workers"] = True
        return DataLoader(**kw)

    def train_dataloader(self) -> DataLoader:
        return self._make_loader("train", seed=0)

    def val_dataloader(self) -> DataLoader:
        return self._make_loader("val", seed=12345)