Spaces:
Runtime error
Runtime error
| """Lightning DataModule + IterableDataset for HYDRA pretraining. | |
| Replaces the custom threading/queue pipeline in prepare_nemotron.make_dataloader | |
| with a standard multiprocessing DataLoader approach. | |
| Design: | |
| • IterableStreamDataset: each worker opens its own HF streams for the 7-way | |
| blend, tokenizes with rustbpe, packs into (T+1,) rows via best-fit, and | |
| yields one row per __next__. | |
| • HydraDataModule: wraps the dataset with a standard DataLoader using | |
| num_workers (default 2; 0 iterates in the main process), prefetch_factor | |
| (default 4), pin_memory=True. Lightning handles | |
| device transfer. | |
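| • Val stream: deterministic seed 12345; uses the full blend weights when | |
| HYDRA_USE_FULL_BLEND=1, otherwise only the | |
| Nemotron-Pretraining-Multiple-Choice config. | |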
| The worker RNG is seeded per-worker so the weighted-sampling schedule is | |
| independent across workers (else all workers request the same config at | |
| the same step and prefetching serializes). | |
| Env vars (all preserved from prepare_nemotron): | |
| HYDRA_SEQ_LEN — sequence length T (default 512) | |
| HYDRA_BATCH_SIZE — batch size B (default 1) — passed through | |
| to DataLoader | |
| HYDRA_STREAM_SHUFFLE_BUFFER — HF shuffle buffer (default 2048) | |
| HYDRA_USE_FULL_BLEND — 7-way blend vs 5-way Nemotron phase | |
| HYDRA_USE_NEMOTRON — enables streaming path (else shard path) | |
| HYDRA_FACTUAL_INJECT_RATE — factual doc injection cadence | |
| HYDRA_NEMOTRON_PHASE — phase1|phase2 (when not full blend) | |
| HYDRA_DATA_NUM_WORKERS — DataLoader num_workers (default 2) | |
| HYDRA_DATA_PREFETCH — DataLoader prefetch_factor (default 4) | |
| HYDRA_DATA_BUFFER — doc_buffer size for best-fit packing | |
| (default 1000) | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import random | |
| from typing import Iterator | |
| import numpy as np | |
| import torch | |
| import lightning as L | |
| from torch.utils.data import DataLoader, IterableDataset, get_worker_info | |
| import prepare as _prepare | |
| import prepare_nemotron as _p_nemo | |
| from prepare_nemotron import ( | |
| FULL_BLEND_WEIGHTS, | |
| PHASE1_WEIGHTS, | |
| PHASE2_WEIGHTS, | |
| _BLEND_REGISTRY, | |
| _extract_text, | |
| _open_stream, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Worker-local weighted stream. A stripped version of prepare_nemotron's | |
| # _WeightedStream that is constructed inside each worker. Adds worker sharding: | |
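| # when num_workers > 1 the RNG is seeded per-worker, so different workers | |
| # sample different config sequences. NOTE(review): every worker opens its | |
| # streams with the same _open_stream(config, "train") call, so whether the | |
| # workers actually read disjoint data depends on _open_stream — confirm. | |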
| # --------------------------------------------------------------------------- | |
class _WorkerWeightedStream:
    """Endless iterator of ``(text, epoch)`` pairs drawn from a weighted blend.

    One instance is meant to be built inside each DataLoader worker. On every
    ``__next__`` it either injects a "factual" document (every
    ``HYDRA_FACTUAL_INJECT_RATE``-th call, when a facts file was found) or
    picks a config at random according to ``weights`` and returns the next
    row of that config's stream, passed through ``_extract_text``.

    Args:
        weights: Mapping of config name -> sampling weight.
        base_seed: Seed shared by all workers of one loader.
        worker_id: Worker index; offsets the RNG seed so each worker follows
            its own config-choice sequence.

    NOTE(review): streams are always opened with the ``"train"`` split, even
    when the owning dataset is the validation one — confirm this is intended.
    """

    def __init__(self, weights: dict[str, float], base_seed: int, worker_id: int):
        self.configs = list(weights.keys())
        self.weights = [weights[c] for c in self.configs]
        self.base_seed = base_seed
        self.worker_id = worker_id
        # Each worker opens its own streams, one per config.
        self.streams = {c: _open_stream(c, "train") for c in self.configs}
        # Per-worker RNG so the config-choice trajectory is independent.
        self.rng = random.Random(base_seed + worker_id * 7919)
        # Incremented every time any config's stream is exhausted and reopened.
        self.epoch = 1
        # Factual-injection state; docs stay None when injection is disabled
        # or no usable facts file exists.
        self._factual_docs: list[str] | None = None
        self._factual_idx = 0
        self._inject_counter = 0
        inject_rate = int(os.environ.get("HYDRA_FACTUAL_INJECT_RATE", "50"))
        self._inject_rate = inject_rate
        if inject_rate > 0:
            # facts.txt lives next to the prepare_nemotron module.
            factual_path = os.path.join(
                os.path.dirname(os.path.abspath(_p_nemo.__file__)),
                "data", "factual", "facts.txt",
            )
            self._factual_docs = self._load_factual_docs(factual_path) or None

    @staticmethod
    def _load_factual_docs(path: str) -> list[str]:
        """Return the non-blank lines of ``path``; ``[]`` if the file is missing.

        The file is decoded as UTF-8 explicitly so the result does not depend
        on the host locale. Blank / whitespace-only lines are skipped so they
        never consume an injection slot with an empty document.
        """
        try:
            with open(path, encoding="utf-8") as fh:
                raw = fh.read()
        except FileNotFoundError:
            return []
        return [line for line in raw.split("\n") if line.strip()]

    def _reopen(self, config: str) -> None:
        """Restart the stream of ``config`` and bump the epoch counter."""
        self.streams[config] = _open_stream(config, "train")
        self.epoch += 1

    def __iter__(self):
        return self

    def __next__(self) -> tuple[str, int]:
        """Return the next ``(text, epoch)`` pair.

        If a freshly reopened stream is empty as well, the resulting
        ``StopIteration`` propagates to the caller.
        """
        # Factual injection: every ``_inject_rate``-th call returns the next
        # factual doc (cycling through the list) instead of a streamed row.
        if self._inject_rate > 0 and self._factual_docs:
            self._inject_counter += 1
            if self._inject_counter >= self._inject_rate:
                self._inject_counter = 0
                doc = self._factual_docs[self._factual_idx % len(self._factual_docs)]
                self._factual_idx += 1
                return doc, self.epoch
        config = self.rng.choices(self.configs, weights=self.weights, k=1)[0]
        try:
            row = next(self.streams[config])
        except StopIteration:
            # Stream exhausted: start it over and retry once.
            self._reopen(config)
            row = next(self.streams[config])
        return _extract_text(row), self.epoch
| # --------------------------------------------------------------------------- | |
| # IterableStreamDataset — yields (T+1,) packed rows. No threads. No queues. | |
| # Lives inside each DataLoader worker. DataLoader's own multiprocessing stacks | |
| # rows into batches of shape (B, T+1) and sends them to the main process. | |
| # --------------------------------------------------------------------------- | |
class IterableStreamDataset(IterableDataset):
    """Streams docs, tokenizes them and packs them into ``(T+1,)`` rows.

    Iteration never ends. Each call to ``__iter__`` (one per DataLoader
    worker, or one in the main process when ``num_workers=0``) builds its own
    tokenizer and its own ``_WorkerWeightedStream``, keeps a buffer of
    tokenized documents and fills each row by best-fit packing.

    Args:
        split: ``"train"`` or ``"val"``.
        seq_len: Sequence length T; every yielded row has ``seq_len + 1`` tokens.
        base_seed: Seed for the weighted stream (ignored for ``"val"``, which
            always uses 12345).
        doc_buffer_size: Number of tokenized docs kept buffered before packing.
            Values below 1 are treated as 1.
        tokenizer_batch: Number of docs pulled from the stream per tokenizer
            call. Values below 1 are treated as 1.
    """

    def __init__(
        self,
        split: str,
        seq_len: int,
        *,
        base_seed: int = 0,
        doc_buffer_size: int = 1000,
        tokenizer_batch: int = 128,
    ):
        super().__init__()
        assert split in ("train", "val"), split
        self.split = split
        self.seq_len = seq_len
        self.row_capacity = seq_len + 1
        self.base_seed = base_seed
        self.doc_buffer_size = doc_buffer_size
        self.tokenizer_batch = tokenizer_batch

    def _pick_weights(self) -> dict[str, float]:
        """Return the config -> weight mapping for this split.

        ``HYDRA_USE_FULL_BLEND=1`` selects ``FULL_BLEND_WEIGHTS`` for both
        splits. Otherwise validation uses a single config and training uses
        the phase selected by ``HYDRA_NEMOTRON_PHASE`` (anything other than
        ``phase2`` means phase 1).
        """
        if self.split == "val":
            if os.environ.get("HYDRA_USE_FULL_BLEND", "0") == "1":
                return FULL_BLEND_WEIGHTS
            return {"Nemotron-Pretraining-Multiple-Choice": 1.0}
        if os.environ.get("HYDRA_USE_FULL_BLEND", "0") == "1":
            return FULL_BLEND_WEIGHTS
        phase = os.environ.get("HYDRA_NEMOTRON_PHASE", "phase1").strip().lower()
        return PHASE2_WEIGHTS if phase == "phase2" else PHASE1_WEIGHTS

    @staticmethod
    def _take_tokens(doc_buffer: list[list[int]], remaining: int) -> list[int]:
        """Pop the next doc to place and return the tokens to write.

        Best-fit: the longest buffered doc that fits entirely in ``remaining``
        slots (first one on ties). If none fits, the shortest doc is popped
        and cropped to ``remaining`` tokens; its tail is discarded.
        ``doc_buffer`` must be non-empty.
        """
        best_idx = -1
        best_len = 0
        for i, doc in enumerate(doc_buffer):
            dlen = len(doc)
            if best_len < dlen <= remaining:
                best_idx, best_len = i, dlen
                if dlen == remaining:
                    # Perfect fit; no later doc can do better.
                    break
        if best_idx >= 0:
            return doc_buffer.pop(best_idx)
        shortest_idx = min(
            range(len(doc_buffer)),
            key=lambda i: len(doc_buffer[i]),
        )
        return doc_buffer.pop(shortest_idx)[:remaining]

    def __iter__(self) -> Iterator[torch.Tensor]:
        """Yield packed ``torch.long`` rows of length ``seq_len + 1`` forever."""
        info = get_worker_info()
        worker_id = 0 if info is None else info.id
        # Tokenizer and stream are created here, i.e. once per iterating
        # process, rather than in __init__.
        tokenizer = _prepare.Tokenizer.from_directory()
        bos = tokenizer.get_bos_token_id()
        val_seed = 12345  # deterministic val
        seed = val_seed if self.split == "val" else self.base_seed
        stream = _WorkerWeightedStream(
            self._pick_weights(), base_seed=seed, worker_id=worker_id,
        )
        row_capacity = self.row_capacity
        # Clamp to >= 1: with an empty buffer there is nothing to pack, and
        # with a zero-sized tokenizer batch the buffer could never fill.
        min_buffered = max(1, self.doc_buffer_size)
        doc_batch_size = max(1, self.tokenizer_batch)
        doc_buffer: list[list[int]] = []

        def refill_buffer() -> None:
            # Collect doc_batch_size text strings, then batch-tokenize.
            texts: list[str] = []
            for _ in range(doc_batch_size):
                text, _epoch = next(stream)
                if text:
                    texts.append(text)
            if texts:
                token_lists = tokenizer.encode(texts, prepend=bos)
                # Zero-length docs can never be placed; drop them up front.
                doc_buffer.extend(t for t in token_lists if len(t) > 0)

        while True:
            pos = 0
            row = torch.empty(row_capacity, dtype=torch.long)
            while pos < row_capacity:
                while len(doc_buffer) < min_buffered:
                    refill_buffer()
                tokens = self._take_tokens(doc_buffer, row_capacity - pos)
                row[pos : pos + len(tokens)] = torch.tensor(tokens, dtype=torch.long)
                pos += len(tokens)
            yield row
| # --------------------------------------------------------------------------- | |
| # LightningDataModule | |
| # --------------------------------------------------------------------------- | |
class HydraDataModule(L.LightningDataModule):
    """LightningDataModule that serves train/val loaders over IterableStreamDataset.

    Constructor arguments left unset fall back to environment variables:
    ``HYDRA_BATCH_SIZE`` (default 1), ``HYDRA_SEQ_LEN`` (default 512),
    ``HYDRA_DATA_NUM_WORKERS`` (default 2) and ``HYDRA_DATA_PREFETCH``
    (default 4). The packing buffer size always comes from
    ``HYDRA_DATA_BUFFER`` (default 1000).
    """

    def __init__(
        self,
        batch_size: int | None = None,
        seq_len: int | None = None,
        num_workers: int | None = None,
        prefetch_factor: int | None = None,
    ):
        super().__init__()
        env = os.environ

        # batch_size / seq_len: any falsy value (None or 0) defers to the env.
        if not batch_size:
            batch_size = int(env.get("HYDRA_BATCH_SIZE", "1"))
        self.batch_size = batch_size

        if not seq_len:
            seq_len = int(env.get("HYDRA_SEQ_LEN", "512"))
        self.seq_len = seq_len

        # num_workers / prefetch_factor: only None defers to the env, so an
        # explicit 0 is honoured.
        if num_workers is None:
            num_workers = int(env.get("HYDRA_DATA_NUM_WORKERS", "2"))
        self.num_workers = num_workers

        if prefetch_factor is None:
            prefetch_factor = int(env.get("HYDRA_DATA_PREFETCH", "4"))
        self.prefetch_factor = prefetch_factor

        self.doc_buffer = int(env.get("HYDRA_DATA_BUFFER", "1000"))

    def _make_loader(self, split: str, seed: int) -> DataLoader:
        """Build a DataLoader over a fresh IterableStreamDataset for ``split``."""
        # prefetch_factor / persistent_workers are only passed when worker
        # processes are used; with num_workers=0 the dataset is iterated in
        # the main process.
        worker_opts: dict = {}
        if self.num_workers > 0:
            worker_opts["prefetch_factor"] = self.prefetch_factor
            worker_opts["persistent_workers"] = True

        stream_dataset = IterableStreamDataset(
            split=split,
            seq_len=self.seq_len,
            base_seed=seed,
            doc_buffer_size=self.doc_buffer,
        )
        return DataLoader(
            dataset=stream_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True,
            **worker_opts,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the training loader (base seed 0)."""
        return self._make_loader("train", seed=0)

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader (base seed 12345)."""
        return self._make_loader("val", seed=12345)