"""Lightning DataModule + IterableDataset for HYDRA pretraining.

Replaces the custom threading/queue pipeline in
prepare_nemotron.make_dataloader with a standard multiprocessing DataLoader
approach.

Design:
  • IterableStreamDataset: each worker opens its own HF streams for the 7-way
    blend, tokenizes with rustbpe, packs into (T+1,) rows via best-fit, and
    yields one row per __next__.
  • HydraDataModule: wraps the dataset with a standard DataLoader using
    num_workers>=1, prefetch_factor=4, pin_memory=True. Lightning handles
    device transfer.
  • Val stream: deterministic seed 12345, weights match training blend.

The worker RNG is seeded per-worker so the weighted-sampling schedule is
independent across workers (else all workers request the same config at the
same step and prefetching serializes).

Env vars (all preserved from prepare_nemotron):
  HYDRA_SEQ_LEN               — sequence length T (default 512)
  HYDRA_BATCH_SIZE            — batch size B (default 1) — passed through to
                                DataLoader
  HYDRA_STREAM_SHUFFLE_BUFFER — HF shuffle buffer (default 2048)
  HYDRA_USE_FULL_BLEND        — 7-way blend vs 5-way Nemotron phase
  HYDRA_USE_NEMOTRON          — enables streaming path (else shard path)
  HYDRA_FACTUAL_INJECT_RATE   — factual doc injection cadence
  HYDRA_NEMOTRON_PHASE        — phase1|phase2 (when not full blend)
  HYDRA_DATA_NUM_WORKERS      — DataLoader num_workers (default 2)
  HYDRA_DATA_PREFETCH         — DataLoader prefetch_factor (default 4)
  HYDRA_DATA_BUFFER           — doc_buffer size for best-fit packing
                                (default 1000)
"""

from __future__ import annotations

import os
import random
from typing import Iterator

import numpy as np
import torch
import lightning as L
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

import prepare as _prepare
import prepare_nemotron as _p_nemo
from prepare_nemotron import (
    FULL_BLEND_WEIGHTS,
    PHASE1_WEIGHTS,
    PHASE2_WEIGHTS,
    _BLEND_REGISTRY,
    _extract_text,
    _open_stream,
)


# ---------------------------------------------------------------------------
# Worker-local weighted stream. A stripped version of prepare_nemotron's
# _WeightedStream that is constructed inside each worker. When num_workers > 1
# the RNG is seeded per-worker, so different workers sample different config
# sequences.
# NOTE(review): every worker opens the same configs through _open_stream; the
# only per-worker state visible here is the config-choice RNG. Whether workers
# actually read disjoint documents depends on _open_stream — confirm.
# ---------------------------------------------------------------------------
class _WorkerWeightedStream:
    """Infinite iterator of ``(text, epoch)`` pairs drawn from weighted streams.

    Each call to ``__next__`` either returns an injected factual document
    (every ``HYDRA_FACTUAL_INJECT_RATE``-th call, when facts are available) or
    picks a config at random according to ``weights`` and returns the next
    row's text from that config's stream.
    """

    def __init__(self, weights: dict[str, float], base_seed: int, worker_id: int):
        """Open one stream per config and set up the per-worker RNG.

        Args:
            weights: Mapping of config name to sampling weight.
            base_seed: Seed shared by all workers of one loader.
            worker_id: DataLoader worker index (0 in the main process).
        """
        self.configs = list(weights.keys())
        self.weights = [weights[c] for c in self.configs]
        self.base_seed = base_seed
        self.worker_id = worker_id

        # Each worker opens its own HF streams. _open_stream returns an iter()
        # over a streaming dataset, with an internal shuffle buffer.
        self.streams = {c: _open_stream(c, "train") for c in self.configs}

        # Per-worker RNG so the config-choice trajectory is independent.
        self.rng = random.Random(base_seed + worker_id * 7919)
        self.epoch = 1

        # Factual docs are read once per worker, here in the constructor.
        # They stay None when injection is disabled or the file is missing.
        self._factual_docs: list[str] | None = None
        self._factual_idx = 0
        self._inject_counter = 0
        inject_rate = int(os.environ.get("HYDRA_FACTUAL_INJECT_RATE", "50"))
        self._inject_rate = inject_rate
        if inject_rate > 0:
            factual_path = os.path.join(
                os.path.dirname(os.path.abspath(_p_nemo.__file__)),
                "data",
                "factual",
                "facts.txt",
            )
            if os.path.exists(factual_path):
                # Explicit encoding: the locale default differs between hosts
                # and would make workers decode the same file differently.
                with open(factual_path, encoding="utf-8") as fh:
                    self._factual_docs = fh.read().strip().split("\n")

    def _reopen(self, config: str) -> None:
        """Restart an exhausted config stream and bump the epoch counter."""
        self.streams[config] = _open_stream(config, "train")
        self.epoch += 1

    def __iter__(self):
        return self

    def __next__(self) -> tuple[str, int]:
        """Return the next ``(text, epoch)`` pair.

        If a stream is exhausted it is reopened once; a StopIteration from the
        freshly reopened stream is not caught and propagates to the caller.
        """
        # Factual injection (preserves prepare_nemotron cadence).
        if self._inject_rate > 0 and self._factual_docs:
            self._inject_counter += 1
            if self._inject_counter >= self._inject_rate:
                self._inject_counter = 0
                doc = self._factual_docs[self._factual_idx % len(self._factual_docs)]
                self._factual_idx += 1
                return doc, self.epoch

        config = self.rng.choices(self.configs, weights=self.weights, k=1)[0]
        try:
            row = next(self.streams[config])
        except StopIteration:
            self._reopen(config)
            row = next(self.streams[config])
        return _extract_text(row), self.epoch


# ---------------------------------------------------------------------------
# IterableStreamDataset — yields (T+1,) packed rows. No threads. No queues.
# Lives inside each DataLoader worker. DataLoader's own multiprocessing stacks
# rows into batches of shape (B, T+1) and sends them to the main process.
# ---------------------------------------------------------------------------
class IterableStreamDataset(IterableDataset):
    """Streams docs, tokenizes, packs into (T+1,) rows via best-fit.

    Each worker gets its own instance (via fork/spawn) and initializes its own
    HF streams + rustbpe tokenizer + factual injector. The tokenizer pickled
    blob is small (~1 MB) and thread-safe per tiktoken docs.
    """

    def __init__(
        self,
        split: str,
        seq_len: int,
        *,
        base_seed: int = 0,
        doc_buffer_size: int = 1000,
        tokenizer_batch: int = 128,
    ):
        """Store the packing configuration; no streams are opened here.

        Args:
            split: ``"train"`` or ``"val"``.
            seq_len: Sequence length T; each yielded row holds T + 1 tokens.
            base_seed: Seed for the train stream (val uses a fixed seed).
            doc_buffer_size: Number of tokenized docs kept for best-fit
                packing. Values below 1 are treated as 1 during iteration.
            tokenizer_batch: Number of documents pulled per tokenizer call.
        """
        super().__init__()
        assert split in ("train", "val"), split
        self.split = split
        self.seq_len = seq_len
        self.row_capacity = seq_len + 1
        self.base_seed = base_seed
        self.doc_buffer_size = doc_buffer_size
        self.tokenizer_batch = tokenizer_batch

    def _pick_weights(self) -> dict[str, float]:
        """Choose the config -> weight mapping for this split.

        ``HYDRA_USE_FULL_BLEND=1`` selects the full blend for both splits.
        Otherwise val reads a single Multiple-Choice config, and train uses
        the phase weights named by ``HYDRA_NEMOTRON_PHASE`` (anything other
        than ``phase2`` means phase 1).
        """
        if os.environ.get("HYDRA_USE_FULL_BLEND", "0") == "1":
            return FULL_BLEND_WEIGHTS
        if self.split == "val":
            return {"Nemotron-Pretraining-Multiple-Choice": 1.0}
        phase = os.environ.get("HYDRA_NEMOTRON_PHASE", "phase1").strip().lower()
        return PHASE2_WEIGHTS if phase == "phase2" else PHASE1_WEIGHTS

    def __iter__(self) -> Iterator[torch.Tensor]:
        """Yield an endless sequence of packed ``(seq_len + 1,)`` long tensors.

        Documents are tokenized in batches into a buffer. Each row is filled
        by repeatedly taking the longest buffered doc that still fits; when
        nothing fits, the shortest doc is cropped to fill the remainder and
        its tail is discarded.
        """
        info = get_worker_info()
        worker_id = 0 if info is None else info.id

        # Each worker builds its own tokenizer instance. tiktoken's Encoding
        # object is pickleable and the underlying C++ BPE is thread-safe;
        # per-worker instantiation avoids cross-process sharing headaches.
        tokenizer = _prepare.Tokenizer.from_directory()
        bos = tokenizer.get_bos_token_id()

        # Each worker gets its own weighted HF stream. The seed offset gives
        # each worker a different config-choice trajectory; HF's own shuffle
        # buffer handles shard randomization.
        val_seed = 12345  # deterministic val
        seed = val_seed if self.split == "val" else self.base_seed
        stream = _WorkerWeightedStream(
            self._pick_weights(),
            base_seed=seed,
            worker_id=worker_id,
        )

        row_capacity = self.row_capacity
        doc_buffer: list[list[int]] = []
        doc_batch_size = self.tokenizer_batch
        # Always keep at least one doc buffered: with a target of 0 (or less)
        # the refill loop would never run and the crop fallback below would
        # call min() on an empty range and raise ValueError.
        min_buffered = max(1, self.doc_buffer_size)

        def refill_buffer() -> None:
            # Collect up to doc_batch_size non-empty texts, then batch-tokenize.
            texts: list[str] = []
            for _ in range(doc_batch_size):
                text, _epoch = next(stream)
                if text:
                    texts.append(text)
            if texts:
                token_lists = tokenizer.encode(texts, prepend=bos)
                doc_buffer.extend(token_lists)

        while True:
            pos = 0
            row = torch.empty(row_capacity, dtype=torch.long)
            while pos < row_capacity:
                while len(doc_buffer) < min_buffered:
                    refill_buffer()

                remaining = row_capacity - pos

                # Best-fit packing: largest doc that fully fits.
                best_idx = -1
                best_len = 0
                for i, doc in enumerate(doc_buffer):
                    dlen = len(doc)
                    if dlen <= remaining and dlen > best_len:
                        best_idx = i
                        best_len = dlen

                if best_idx >= 0:
                    doc = doc_buffer.pop(best_idx)
                    row[pos : pos + len(doc)] = torch.tensor(doc, dtype=torch.long)
                    pos += len(doc)
                else:
                    # No doc fits remaining space — crop shortest to fill.
                    shortest_idx = min(
                        range(len(doc_buffer)),
                        key=lambda i: len(doc_buffer[i]),
                    )
                    doc = doc_buffer.pop(shortest_idx)
                    row[pos : pos + remaining] = torch.tensor(
                        doc[:remaining],
                        dtype=torch.long,
                    )
                    pos += remaining
            yield row


# ---------------------------------------------------------------------------
# LightningDataModule
# ---------------------------------------------------------------------------
class HydraDataModule(L.LightningDataModule):
    """DataModule that serves packed rows from IterableStreamDataset.

    Constructor arguments take precedence; any argument left unset falls back
    to the matching ``HYDRA_*`` environment variable.
    """

    def __init__(
        self,
        batch_size: int | None = None,
        seq_len: int | None = None,
        num_workers: int | None = None,
        prefetch_factor: int | None = None,
    ):
        """Resolve loader settings from arguments or environment variables.

        Args:
            batch_size: Rows per batch; falsy values fall back to
                ``HYDRA_BATCH_SIZE`` (default 1).
            seq_len: Sequence length T; falsy values fall back to
                ``HYDRA_SEQ_LEN`` (default 512).
            num_workers: DataLoader workers; ``None`` falls back to
                ``HYDRA_DATA_NUM_WORKERS`` (default 2). 0 is honored.
            prefetch_factor: Batches prefetched per worker; ``None`` falls
                back to ``HYDRA_DATA_PREFETCH`` (default 4).
        """
        super().__init__()
        self.batch_size = batch_size or int(os.environ.get("HYDRA_BATCH_SIZE", "1"))
        self.seq_len = seq_len or int(os.environ.get("HYDRA_SEQ_LEN", "512"))
        self.num_workers = (
            num_workers
            if num_workers is not None
            else int(os.environ.get("HYDRA_DATA_NUM_WORKERS", "2"))
        )
        self.prefetch_factor = (
            prefetch_factor
            if prefetch_factor is not None
            else int(os.environ.get("HYDRA_DATA_PREFETCH", "4"))
        )
        self.doc_buffer = int(os.environ.get("HYDRA_DATA_BUFFER", "1000"))

    def _make_loader(self, split: str, seed: int) -> DataLoader:
        """Build a DataLoader over a fresh IterableStreamDataset for ``split``."""
        dataset = IterableStreamDataset(
            split=split,
            seq_len=self.seq_len,
            base_seed=seed,
            doc_buffer_size=self.doc_buffer,
        )
        # num_workers=0 → main-process iteration (useful for debugging). With
        # IterableDataset the DataLoader batches the rows into (B, T+1) via
        # default torch.stack-collate.
        kw: dict = dict(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True,
        )
        # These two options are only passed when worker processes exist.
        if self.num_workers > 0:
            kw["prefetch_factor"] = self.prefetch_factor
            kw["persistent_workers"] = True
        return DataLoader(**kw)

    def train_dataloader(self) -> DataLoader:
        """Return the training loader (base seed 0)."""
        return self._make_loader("train", seed=0)

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader (seed 12345)."""
        return self._make_loader("val", seed=12345)