| """Lightning DataModule + IterableDataset for HYDRA pretraining. |
| |
| Replaces the custom threading/queue pipeline in prepare_nemotron.make_dataloader |
| with a standard multiprocessing DataLoader approach. |
| |
| Design: |
| • IterableStreamDataset: each worker opens its own HF streams for the 7-way |
| blend, tokenizes with rustbpe, packs into (T+1,) rows via best-fit, and |
| yields one row per __next__. |
| • HydraDataModule: wraps the dataset with a standard DataLoader using |
| num_workers>=1, prefetch_factor=4, pin_memory=True. Lightning handles |
| device transfer. |
| • Val stream: deterministic seed 12345, weights match training blend. |
| |
| The worker RNG is seeded per-worker so the weighted-sampling schedule is |
| independent across workers (else all workers request the same config at |
| the same step and prefetching serializes). |
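For example, worker 0 draws from random.Random(base_seed), worker 1 from
random.Random(base_seed + 7919), and so on; the stride 7919 (the 1000th
prime) just spreads the seeds apart.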

Env vars (all preserved from prepare_nemotron):
  HYDRA_SEQ_LEN               — sequence length T (default 512)
  HYDRA_BATCH_SIZE            — batch size B (default 1), passed through
                                to the DataLoader
  HYDRA_STREAM_SHUFFLE_BUFFER — HF shuffle buffer (default 2048)
  HYDRA_USE_FULL_BLEND        — 7-way blend vs. 5-way Nemotron phase
  HYDRA_USE_NEMOTRON          — enables the streaming path (else shard path)
  HYDRA_FACTUAL_INJECT_RATE   — factual doc injection cadence
  HYDRA_NEMOTRON_PHASE        — phase1|phase2 (when not full blend)
  HYDRA_DATA_NUM_WORKERS      — DataLoader num_workers (default 2)
  HYDRA_DATA_PREFETCH         — DataLoader prefetch_factor (default 4)
  HYDRA_DATA_BUFFER           — doc_buffer size for best-fit packing
                                (default 1000)
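
Typical wiring (illustrative; `model` is whatever LightningModule is being
trained):

    dm = HydraDataModule()                 # reads HYDRA_* env vars
    trainer = L.Trainer(max_steps=10_000)
    trainer.fit(model, datamodule=dm)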
| """ |
from __future__ import annotations

import os
import random
from typing import Iterator

import torch
import lightning as L
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

import prepare as _prepare
import prepare_nemotron as _p_nemo
from prepare_nemotron import (
    FULL_BLEND_WEIGHTS,
    PHASE1_WEIGHTS,
    PHASE2_WEIGHTS,
    _extract_text,
    _open_stream,
)

class _WorkerWeightedStream:
    """Weighted sampler over per-config HF streams for one DataLoader worker.

    Streams always open the HF "train" split; the val loader differs only in
    its seed and weights. Optionally injects one factual document every
    HYDRA_FACTUAL_INJECT_RATE draws.
    """

    def __init__(self, weights: dict[str, float], base_seed: int, worker_id: int):
        self.configs = list(weights.keys())
        self.weights = [weights[c] for c in self.configs]
        self.base_seed = base_seed
        self.worker_id = worker_id

        # One HF stream per blend config; each is reopened individually on
        # exhaustion (see _reopen).
        self.streams = {c: _open_stream(c, "train") for c in self.configs}
        # Per-worker RNG: the prime stride decorrelates the sampling
        # schedules across workers.
        self.rng = random.Random(base_seed + worker_id * 7919)
        self.epoch = 1

        # Factual injection: facts.txt holds one document per line; every
        # `inject_rate` draws, __next__ yields one of them instead of a
        # stream document.
        self._factual_docs: list[str] | None = None
        self._factual_idx = 0
        self._inject_counter = 0
        inject_rate = int(os.environ.get("HYDRA_FACTUAL_INJECT_RATE", "50"))
        self._inject_rate = inject_rate
        if inject_rate > 0:
            factual_path = os.path.join(
                os.path.dirname(os.path.abspath(_p_nemo.__file__)),
                "data", "factual", "facts.txt",
            )
            if os.path.exists(factual_path):
                with open(factual_path) as fh:
                    self._factual_docs = fh.read().strip().split("\n")

    def _reopen(self, config: str) -> None:
        self.streams[config] = _open_stream(config, "train")
        self.epoch += 1

    def __iter__(self):
        return self

    def __next__(self) -> tuple[str, int]:
        # Factual injection takes priority over the weighted draw.
        if self._inject_rate > 0 and self._factual_docs:
            self._inject_counter += 1
            if self._inject_counter >= self._inject_rate:
                self._inject_counter = 0
                doc = self._factual_docs[self._factual_idx % len(self._factual_docs)]
                self._factual_idx += 1
                return doc, self.epoch

        # Weighted draw of a config; reopen its stream if exhausted.
        config = self.rng.choices(self.configs, weights=self.weights, k=1)[0]
        try:
            row = next(self.streams[config])
        except StopIteration:
            self._reopen(config)
            row = next(self.streams[config])
        return _extract_text(row), self.epoch
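
# A minimal sketch of how the stream is consumed (the weights dict below is
# illustrative, not a real blend config):
#
#     stream = _WorkerWeightedStream({"cfg-a": 0.7, "cfg-b": 0.3},
#                                    base_seed=0, worker_id=1)
#     text, epoch = next(stream)  # one raw document + the stream's epoch count
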
class IterableStreamDataset(IterableDataset):
    """Streams docs, tokenizes, and packs them into (T+1,) rows via best-fit.

    Each worker gets its own instance (via fork/spawn) and initializes its
    own HF streams, rustbpe tokenizer, and factual injector. The pickled
    tokenizer blob is small (~1 MB) and, per the tiktoken docs, thread-safe.
    """

    def __init__(
        self,
        split: str,
        seq_len: int,
        *,
        base_seed: int = 0,
        doc_buffer_size: int = 1000,
        tokenizer_batch: int = 128,
    ):
        super().__init__()
        assert split in ("train", "val"), split
        self.split = split
        self.seq_len = seq_len
        self.row_capacity = seq_len + 1  # each row holds T inputs + 1 target
        self.base_seed = base_seed
        self.doc_buffer_size = doc_buffer_size
        self.tokenizer_batch = tokenizer_batch

    def _pick_weights(self) -> dict[str, float]:
        if self.split == "val":
            if os.environ.get("HYDRA_USE_FULL_BLEND", "0") == "1":
                return FULL_BLEND_WEIGHTS
            return {"Nemotron-Pretraining-Multiple-Choice": 1.0}
        if os.environ.get("HYDRA_USE_FULL_BLEND", "0") == "1":
            return FULL_BLEND_WEIGHTS
        phase = os.environ.get("HYDRA_NEMOTRON_PHASE", "phase1").strip().lower()
        return PHASE2_WEIGHTS if phase == "phase2" else PHASE1_WEIGHTS
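
    # Example resolution (illustrative): with HYDRA_USE_FULL_BLEND unset and
    # HYDRA_NEMOTRON_PHASE=phase2, a train instance samples PHASE2_WEIGHTS
    # while a val instance pins the Multiple-Choice config.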

    def __iter__(self) -> Iterator[torch.Tensor]:
        info = get_worker_info()
        worker_id = 0 if info is None else info.id

        # Per-worker tokenizer: each worker process builds its own instance.
        tokenizer = _prepare.Tokenizer.from_directory()
        bos = tokenizer.get_bos_token_id()

        # Val is deterministic regardless of base_seed; train offsets by
        # worker_id inside _WorkerWeightedStream.
        val_seed = 12345
        seed = val_seed if self.split == "val" else self.base_seed
        stream = _WorkerWeightedStream(
            self._pick_weights(), base_seed=seed, worker_id=worker_id,
        )

        row_capacity = self.row_capacity
        doc_buffer: list[list[int]] = []
        doc_batch_size = self.tokenizer_batch

        def refill_buffer() -> None:
            # Pull a batch of raw docs and tokenize them in one batched call.
            texts: list[str] = []
            for _ in range(doc_batch_size):
                text, _epoch = next(stream)
                if text:
                    texts.append(text)
            if texts:
                token_lists = tokenizer.encode(texts, prepend=bos)
                doc_buffer.extend(token_lists)

        while True:
            pos = 0
            row = torch.empty(row_capacity, dtype=torch.long)
            while pos < row_capacity:
                while len(doc_buffer) < self.doc_buffer_size:
                    refill_buffer()

                remaining = row_capacity - pos

                # Best fit: the longest buffered doc that still fits whole.
                best_idx = -1
                best_len = 0
                for i, doc in enumerate(doc_buffer):
                    dlen = len(doc)
                    if dlen <= remaining and dlen > best_len:
                        best_idx = i
                        best_len = dlen

                if best_idx >= 0:
                    doc = doc_buffer.pop(best_idx)
                    row[pos : pos + len(doc)] = torch.tensor(doc, dtype=torch.long)
                    pos += len(doc)
                else:
                    # Nothing fits whole: truncate the shortest buffered doc
                    # to fill the row (its tail is discarded).
                    shortest_idx = min(
                        range(len(doc_buffer)),
                        key=lambda i: len(doc_buffer[i]),
                    )
                    doc = doc_buffer.pop(shortest_idx)
                    row[pos : pos + remaining] = torch.tensor(
                        doc[:remaining], dtype=torch.long,
                    )
                    pos += remaining

            yield row
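
# Worked example of the best-fit packing above (token counts illustrative):
# with row_capacity=8 and buffered doc lengths [5, 3, 9, 2], the row fills as
# 5 (largest doc that fits in 8), then 3 (exact fit for the remainder),
# giving a full row with no truncation. If only [9, 12] were buffered,
# nothing fits whole, so the shortest doc (9) is truncated to the 8
# remaining slots.
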
class HydraDataModule(L.LightningDataModule):
    """DataModule wrapping IterableStreamDataset in a standard DataLoader.

    Constructor arguments take precedence; the HYDRA_* env vars supply the
    defaults.
    """

    def __init__(
        self,
        batch_size: int | None = None,
        seq_len: int | None = None,
        num_workers: int | None = None,
        prefetch_factor: int | None = None,
    ):
        super().__init__()
        self.batch_size = batch_size or int(os.environ.get("HYDRA_BATCH_SIZE", "1"))
        self.seq_len = seq_len or int(os.environ.get("HYDRA_SEQ_LEN", "512"))
        self.num_workers = (
            num_workers
            if num_workers is not None
            else int(os.environ.get("HYDRA_DATA_NUM_WORKERS", "2"))
        )
        self.prefetch_factor = (
            prefetch_factor
            if prefetch_factor is not None
            else int(os.environ.get("HYDRA_DATA_PREFETCH", "4"))
        )
        self.doc_buffer = int(os.environ.get("HYDRA_DATA_BUFFER", "1000"))

    def _make_loader(self, split: str, seed: int) -> DataLoader:
        dataset = IterableStreamDataset(
            split=split,
            seq_len=self.seq_len,
            base_seed=seed,
            doc_buffer_size=self.doc_buffer,
        )
        # prefetch_factor and persistent_workers are only valid when
        # num_workers > 0, so add them conditionally.
        kw: dict = dict(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True,
        )
        if self.num_workers > 0:
            kw["prefetch_factor"] = self.prefetch_factor
            kw["persistent_workers"] = True
        return DataLoader(**kw)

    def train_dataloader(self) -> DataLoader:
        return self._make_loader("train", seed=0)

    def val_dataloader(self) -> DataLoader:
        return self._make_loader("val", seed=12345)
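
# Smoke test (a sketch: it needs the HF datasets and rustbpe setup that
# prepare_nemotron expects, so it is illustrative rather than CI-ready):
#
#     if __name__ == "__main__":
#         dm = HydraDataModule(batch_size=2, seq_len=128, num_workers=0)
#         batch = next(iter(dm.train_dataloader()))
#         assert batch.shape == (2, 129)  # (B, T + 1)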