# Uploaded via huggingface_hub (commit c383594, verified).
"""Lightning DataModule + IterableDataset for HYDRA pretraining.
Replaces the custom threading/queue pipeline in prepare_nemotron.make_dataloader
with a standard multiprocessing DataLoader approach.
Design:
• IterableStreamDataset: each worker opens its own HF streams for the 7-way
blend, tokenizes with rustbpe, packs into (T+1,) rows via best-fit, and
yields one row per __next__.
• HydraDataModule: wraps the dataset with a standard DataLoader using
num_workers>=1, prefetch_factor=4, pin_memory=True. Lightning handles
device transfer.
• Val stream: deterministic seed 12345, weights match training blend.
The worker RNG is seeded per-worker so the weighted-sampling schedule is
independent across workers (else all workers request the same config at
the same step and prefetching serializes).
Env vars (all preserved from prepare_nemotron):
HYDRA_SEQ_LEN — sequence length T (default 512)
HYDRA_BATCH_SIZE — batch size B (default 1) — passed through
to DataLoader
HYDRA_STREAM_SHUFFLE_BUFFER — HF shuffle buffer (default 2048)
HYDRA_USE_FULL_BLEND — 7-way blend vs 5-way Nemotron phase
HYDRA_USE_NEMOTRON — enables streaming path (else shard path)
HYDRA_FACTUAL_INJECT_RATE — factual doc injection cadence
HYDRA_NEMOTRON_PHASE — phase1|phase2 (when not full blend)
HYDRA_DATA_NUM_WORKERS — DataLoader num_workers (default 2)
HYDRA_DATA_PREFETCH — DataLoader prefetch_factor (default 4)
HYDRA_DATA_BUFFER — doc_buffer size for best-fit packing
(default 1000)
"""
from __future__ import annotations
import os
import random
from typing import Iterator
import numpy as np
import torch
import lightning as L
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
import prepare as _prepare
import prepare_nemotron as _p_nemo
from prepare_nemotron import (
FULL_BLEND_WEIGHTS,
PHASE1_WEIGHTS,
PHASE2_WEIGHTS,
_BLEND_REGISTRY,
_extract_text,
_open_stream,
)
# ---------------------------------------------------------------------------
# Worker-local weighted stream. A stripped version of prepare_nemotron's
# _WeightedStream that is constructed inside each worker. Adds worker sharding:
# when num_workers > 1 the RNG is seeded per-worker, so different workers
# sample different config sequences and pull disjoint shard assignments from
# HF's shuffle buffer.
# ---------------------------------------------------------------------------
class _WorkerWeightedStream:
    """Per-worker weighted sampler over HF streaming datasets.

    A trimmed-down counterpart of prepare_nemotron's _WeightedStream that is
    constructed inside each DataLoader worker.  Seeding the RNG with a
    worker-dependent offset keeps the config-sampling trajectories of
    different workers independent, so prefetching does not serialize.
    """

    def __init__(self, weights: dict[str, float], base_seed: int, worker_id: int):
        self.configs = list(weights)
        self.weights = [weights[name] for name in self.configs]
        self.base_seed = base_seed
        self.worker_id = worker_id
        # One HF stream per config, owned by this worker. _open_stream
        # returns an iterator over a streaming dataset with its own
        # internal shuffle buffer.
        self.streams = {name: _open_stream(name, "train") for name in self.configs}
        # 7919 is a prime stride so per-worker seeds never collide.
        self.rng = random.Random(base_seed + worker_id * 7919)
        self.epoch = 1
        # Factual-document injection state, lazily populated once per worker
        # (mirrors prepare_nemotron._WeightedStream).
        self._factual_docs: list[str] | None = None
        self._factual_idx = 0
        self._inject_counter = 0
        self._inject_rate = int(os.environ.get("HYDRA_FACTUAL_INJECT_RATE", "50"))
        if self._inject_rate > 0:
            factual_path = os.path.join(
                os.path.dirname(os.path.abspath(_p_nemo.__file__)),
                "data", "factual", "facts.txt",
            )
            if os.path.exists(factual_path):
                with open(factual_path) as fh:
                    self._factual_docs = fh.read().strip().split("\n")

    def _reopen(self, config: str) -> None:
        # An exhausted stream is simply re-opened; one pass = one epoch.
        self.streams[config] = _open_stream(config, "train")
        self.epoch += 1

    def __iter__(self):
        return self

    def __next__(self) -> tuple[str, int]:
        # Periodic factual-doc injection, same cadence as prepare_nemotron.
        if self._inject_rate > 0 and self._factual_docs:
            self._inject_counter += 1
            if self._inject_counter >= self._inject_rate:
                self._inject_counter = 0
                doc = self._factual_docs[self._factual_idx % len(self._factual_docs)]
                self._factual_idx += 1
                return doc, self.epoch
        chosen = self.rng.choices(self.configs, weights=self.weights, k=1)[0]
        try:
            row = next(self.streams[chosen])
        except StopIteration:
            self._reopen(chosen)
            row = next(self.streams[chosen])
        return _extract_text(row), self.epoch
# ---------------------------------------------------------------------------
# IterableStreamDataset — yields (T+1,) packed rows. No threads. No queues.
# Lives inside each DataLoader worker. DataLoader's own multiprocessing stacks
# rows into batches of shape (B, T+1) and sends them to the main process.
# ---------------------------------------------------------------------------
class IterableStreamDataset(IterableDataset):
    """Streams docs, tokenizes, packs into (T+1,) rows via best-fit.

    Each DataLoader worker gets its own instance (via fork/spawn) and builds
    its own HF streams, rustbpe tokenizer, and factual injector.  The
    tokenizer pickled blob is small (~1 MB) and thread-safe per tiktoken
    docs.
    """

    def __init__(
        self,
        split: str,
        seq_len: int,
        *,
        base_seed: int = 0,
        doc_buffer_size: int = 1000,
        tokenizer_batch: int = 128,
    ):
        super().__init__()
        assert split in ("train", "val"), split
        self.split = split
        self.seq_len = seq_len
        # Rows carry one extra token so (input, target) pairs can be sliced.
        self.row_capacity = seq_len + 1
        self.base_seed = base_seed
        self.doc_buffer_size = doc_buffer_size
        self.tokenizer_batch = tokenizer_batch

    def _pick_weights(self) -> dict[str, float]:
        # Full blend overrides everything, for both train and val.
        if os.environ.get("HYDRA_USE_FULL_BLEND", "0") == "1":
            return FULL_BLEND_WEIGHTS
        if self.split == "val":
            return {"Nemotron-Pretraining-Multiple-Choice": 1.0}
        phase = os.environ.get("HYDRA_NEMOTRON_PHASE", "phase1").strip().lower()
        return PHASE2_WEIGHTS if phase == "phase2" else PHASE1_WEIGHTS

    def __iter__(self) -> Iterator[torch.Tensor]:
        info = get_worker_info()
        worker_id = info.id if info is not None else 0
        # Per-worker tokenizer: tiktoken Encoding objects pickle cheaply and
        # the underlying BPE is thread-safe, so building one per worker
        # sidesteps cross-process sharing entirely.
        tokenizer = _prepare.Tokenizer.from_directory()
        bos = tokenizer.get_bos_token_id()
        # Validation is pinned to a fixed seed for determinism; training
        # uses base_seed.  The worker-id offset inside the stream keeps
        # per-worker sampling trajectories disjoint.
        seed = 12345 if self.split == "val" else self.base_seed
        stream = _WorkerWeightedStream(
            self._pick_weights(), base_seed=seed, worker_id=worker_id,
        )
        capacity = self.row_capacity
        buffer: list[list[int]] = []
        batch_size = self.tokenizer_batch

        def refill() -> None:
            # Pull batch_size docs off the stream, drop empties, then
            # tokenize the survivors in a single batched call.
            texts: list[str] = []
            for _ in range(batch_size):
                text, _epoch = next(stream)
                if text:
                    texts.append(text)
            if texts:
                buffer.extend(tokenizer.encode(texts, prepend=bos))

        while True:
            row = torch.empty(capacity, dtype=torch.long)
            filled = 0
            while filled < capacity:
                while len(buffer) < self.doc_buffer_size:
                    refill()
                space = capacity - filled
                # Best-fit packing: first-seen longest doc that fully fits.
                best_idx = -1
                best_len = 0
                for i, tokens in enumerate(buffer):
                    if best_len < len(tokens) <= space:
                        best_idx = i
                        best_len = len(tokens)
                if best_idx >= 0:
                    tokens = buffer.pop(best_idx)
                    row[filled : filled + len(tokens)] = torch.tensor(
                        tokens, dtype=torch.long,
                    )
                    filled += len(tokens)
                else:
                    # Nothing fits the gap — crop the shortest doc to plug it.
                    shortest = min(
                        range(len(buffer)), key=lambda i: len(buffer[i]),
                    )
                    tokens = buffer.pop(shortest)
                    row[filled : filled + space] = torch.tensor(
                        tokens[:space], dtype=torch.long,
                    )
                    filled += space
            yield row
# ---------------------------------------------------------------------------
# LightningDataModule
# ---------------------------------------------------------------------------
class HydraDataModule(L.LightningDataModule):
    """LightningDataModule wrapping IterableStreamDataset in a standard
    multiprocessing DataLoader.

    Constructor arguments override the corresponding HYDRA_* env vars; any
    argument left as None falls back to the env var (or its default).
    """

    def __init__(
        self,
        batch_size: int | None = None,
        seq_len: int | None = None,
        num_workers: int | None = None,
        prefetch_factor: int | None = None,
    ):
        super().__init__()
        # Fix: use `is not None` for all four overrides. The previous
        # `batch_size or int(env)` / `seq_len or int(env)` idiom treated a
        # caller-supplied 0 as "unset" and silently fell back to the env
        # default, inconsistent with num_workers/prefetch_factor below.
        self.batch_size = (
            batch_size
            if batch_size is not None
            else int(os.environ.get("HYDRA_BATCH_SIZE", "1"))
        )
        self.seq_len = (
            seq_len
            if seq_len is not None
            else int(os.environ.get("HYDRA_SEQ_LEN", "512"))
        )
        self.num_workers = (
            num_workers
            if num_workers is not None
            else int(os.environ.get("HYDRA_DATA_NUM_WORKERS", "2"))
        )
        self.prefetch_factor = (
            prefetch_factor
            if prefetch_factor is not None
            else int(os.environ.get("HYDRA_DATA_PREFETCH", "4"))
        )
        # Best-fit packing buffer size for the dataset.
        self.doc_buffer = int(os.environ.get("HYDRA_DATA_BUFFER", "1000"))

    def _make_loader(self, split: str, seed: int) -> DataLoader:
        """Build a DataLoader over a fresh IterableStreamDataset for *split*."""
        dataset = IterableStreamDataset(
            split=split,
            seq_len=self.seq_len,
            base_seed=seed,
            doc_buffer_size=self.doc_buffer,
        )
        # num_workers=0 → main-process iteration (useful for debugging). With
        # IterableDataset the DataLoader batches the rows into (B, T+1) via
        # default torch.stack-collate.
        kw: dict = dict(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True,
        )
        if self.num_workers > 0:
            # prefetch_factor/persistent_workers are only valid with workers.
            kw["prefetch_factor"] = self.prefetch_factor
            kw["persistent_workers"] = True
        return DataLoader(**kw)

    def train_dataloader(self) -> DataLoader:
        return self._make_loader("train", seed=0)

    def val_dataloader(self) -> DataLoader:
        # Deterministic val stream (seed matches the dataset's internal
        # val seed of 12345).
        return self._make_loader("val", seed=12345)