"""
dataset_overlap.py
Per-stage PyTorch datasets + val set builders with REPLAY support to prevent catastrophic forgetting.

Each stage has:
  - A train dataset (tokenized, chunked to seq_len, streamed or cached)
  - Multiple sources: current stage + replay buffer from previous stages
  - A val dataset (fixed held-out split, loaded once into memory)

Val sets for ALL three stages are always available so the training loop
can log cross-stage val losses at every eval step.
"""
|
|
import os
import pickle
import random
from itertools import islice
from pathlib import Path
from typing import Iterator, Optional

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from tokenizers import Tokenizer
|
|
|
|
| |
|
|
def tokenize_and_chunk(
    text_iter  : Iterator[str],
    tokenizer  : Tokenizer,
    seq_len    : int,
    max_tokens : Optional[int] = None,
    bos_id     : int = 2,
    eos_id     : int = 3,
) -> list[list[int]]:
    """
    Streams text, tokenizes, concatenates with BOS/EOS, then chunks into
    non-overlapping windows of seq_len+1 tokens (input + target shift).

    Args:
        text_iter:  Iterable of raw documents. Empty / whitespace-only
                    documents are skipped.
        tokenizer:  Object whose ``encode(text).ids`` gives the token ids.
        seq_len:    Model sequence length; every chunk holds seq_len+1 ids.
        max_tokens: Stop as soon as at least this many tokens have been
                    emitted (counted in whole chunks, so the total may
                    overshoot by less than one chunk). ``None`` or ``0``
                    means no limit.
        bos_id:     Id prepended to every document.
        eos_id:     Id appended to every document.

    Returns:
        List of token id lists, each of length seq_len+1. Tokens left in the
        buffer that do not fill a final chunk are dropped.

    Raises:
        ValueError: if seq_len < 1 (a non-positive window would otherwise
                    produce empty chunks or loop forever).
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")

    window = seq_len + 1
    buffer: list[int] = []
    chunks: list[list[int]] = []
    tokens_seen = 0

    # The progress bar is only shown when a token budget gives it a total.
    with tqdm(total=max_tokens, unit="tok", desc="Tokenizing & chunking",
              disable=(max_tokens is None)) as pbar:
        for text in text_iter:
            if not text or not text.strip():
                continue
            buffer.append(bos_id)
            buffer.extend(tokenizer.encode(text).ids)
            buffer.append(eos_id)

            # Walk the buffer with an offset and trim it once per document;
            # re-slicing the buffer after every chunk is quadratic on long docs.
            start = 0
            while len(buffer) - start >= window:
                chunks.append(buffer[start:start + window])
                start += window
                tokens_seen += window
                pbar.update(window)
                if max_tokens and tokens_seen >= max_tokens:
                    # Leaving the `with` block closes the progress bar.
                    return chunks
            if start:
                del buffer[:start]

    return chunks
|
|
|
|
| |
|
|
def build_val_set(
    stage     : int,
    tokenizer : Tokenizer,
    seq_len   : int,
    n_docs    : int,
    cache_dir : str = "cache",
) -> list[list[int]]:
    """
    Loads or builds + caches the val set for a given stage.

    On a cache hit the pickled chunks are returned as-is. Otherwise the
    documents from ``_val_iter(stage, n_docs)`` are tokenized and chunked with
    ``tokenize_and_chunk`` and the result is written to
    ``<cache_dir>/val_stage{stage}_seq{seq_len}.pkl``.

    Args:
        stage:     Curriculum stage whose val set is wanted.
        tokenizer: Tokenizer forwarded to ``tokenize_and_chunk``.
        seq_len:   Sequence length; chunks hold seq_len+1 ids.
        n_docs:    Number of documents to draw when building.
        cache_dir: Directory for the pickle cache (created if missing).

    Returns:
        List of token id chunks.

    NOTE: the cache key contains only ``stage`` and ``seq_len``. Changing
    ``n_docs`` or the tokenizer does NOT invalidate an existing cache file;
    delete it manually in that case.
    """
    os.makedirs(cache_dir, exist_ok=True)
    cache_path = os.path.join(cache_dir, f"val_stage{stage}_seq{seq_len}.pkl")

    if os.path.exists(cache_path):
        print(f"[dataset] Loading cached val_stage{stage} from {cache_path}")
        # SECURITY: unpickling can execute arbitrary code; only point
        # cache_dir at locations you trust.
        with open(cache_path, "rb") as f:
            return pickle.load(f)

    print(f"[dataset] Building val_stage{stage} ({n_docs} docs)...")
    chunks = tokenize_and_chunk(_val_iter(stage, n_docs), tokenizer, seq_len)

    # Write to a temp file and rename, so an interrupted dump never leaves a
    # truncated pickle at cache_path for later runs to trip over.
    tmp_path = f"{cache_path}.{os.getpid()}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(chunks, f)
    os.replace(tmp_path, cache_path)

    print(f"[dataset] val_stage{stage} saved: {len(chunks)} chunks -> {cache_path}")
    return chunks
|
|
|
|
| def _val_iter(stage: int, n_docs: int) -> Iterator[str]: |
| """Returns a deterministic held-out slice for each stage.""" |
| from datasets import load_dataset |
| rng = random.Random(42) |
|
|
| if stage == 0: |
| ds = load_dataset("roneneldan/TinyStories", split="validation", streaming=True) |
| for i, ex in enumerate(ds): |
| if i >= n_docs: break |
| yield ex["text"] |
|
|
| elif stage == 1: |
| |
| ds = load_dataset("wikimedia/wikipedia", "20231101.simple", |
| split="train", streaming=True) |
| count = 0 |
| for ex in ds: |
| if rng.random() < 0.05: |
| yield ex["text"] |
| count += 1 |
| if count >= n_docs: break |
|
|
| elif stage == 2: |
| ds = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT", |
| split="train", streaming=True) |
| count = 0 |
| for ex in ds: |
| if ex.get("score", 0) >= 3 and rng.random() < 0.01: |
| yield ex["text"] |
| count += 1 |
| if count >= n_docs: break |
|
|
|
|
| |
|
|
class ChunkedTokenDataset(Dataset):
    """
    Map-style dataset over pre-tokenized chunks.

    Each item is an ``(input, target)`` pair of ``torch.long`` tensors made
    from one chunk: the input is the chunk without its last token and the
    target is the chunk without its first token (next-token shift).
    """

    def __init__(self, chunks: list[list[int]]):
        # Kept by reference; the list is not copied.
        self.chunks = chunks

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        tokens = torch.tensor(self.chunks[idx], dtype=torch.long)
        inputs, targets = tokens[:-1], tokens[1:]
        return inputs, targets
|
|
|
|
| |
|
|
class StreamingStageDataset(Dataset):
    """
    Pre-tokenizes + chunks a stage's training data into memory.
    Supports REPLAY: mixes current stage data with previous stage data
    to prevent catastrophic forgetting.

    Call build() before using in a DataLoader; until then the dataset is empty.
    Items are ``(input, target)`` pairs of ``torch.long`` tensors, shifted by
    one token relative to each other.
    """

    def __init__(self):
        self.chunks: list[list[int]] = []

    def build(
        self,
        stage       : int,
        tokenizer   : Tokenizer,
        seq_len     : int,
        max_tokens  : int,
        cache_dir   : str = "cache",
        replay_ratio: float = 0.0,
    ):
        """
        Loads the stage's chunks from the pickle cache, or builds and caches them.

        Args:
            stage:        Curriculum stage to build.
            tokenizer:    Tokenizer forwarded to ``tokenize_and_chunk``.
            seq_len:      Sequence length; chunks hold seq_len+1 ids.
            max_tokens:   Token budget forwarded to ``tokenize_and_chunk``.
            cache_dir:    Directory for the pickle cache (created if missing).
            replay_ratio: Replay share forwarded to ``_train_iter``.

        Returns:
            ``self``, so the call can be chained.

        NOTE: the cache key covers stage, seq_len and replay_ratio only;
        changing ``max_tokens`` or the tokenizer does not invalidate an
        existing cache file.
        """
        os.makedirs(cache_dir, exist_ok=True)

        # One-decimal ratios keep their historical tag (e.g. "0.2") so existing
        # caches stay valid; any other ratio gets its exact repr, otherwise
        # 0.2 and 0.25 would both map to "replay0.2" and share a cache file.
        replay_tag = f"{replay_ratio:.1f}"
        if float(replay_tag) != replay_ratio:
            replay_tag = repr(float(replay_ratio))
        cache_path = os.path.join(
            cache_dir, f"train_stage{stage}_seq{seq_len}_replay{replay_tag}.pkl"
        )

        if os.path.exists(cache_path):
            print(f"[dataset] Loading cached train_stage{stage} (replay={replay_tag})")
            # SECURITY: unpickling can execute arbitrary code; only point
            # cache_dir at locations you trust.
            with open(cache_path, "rb") as f:
                self.chunks = pickle.load(f)
            print(f"[dataset] Loaded {len(self.chunks):,} chunks")
            return self

        print(f"[dataset] Building train_stage{stage} (max_tokens={max_tokens:,}, replay={replay_tag})...")
        text_iter = _train_iter(stage, replay_ratio=replay_ratio)
        self.chunks = tokenize_and_chunk(text_iter, tokenizer, seq_len, max_tokens)
        # Uses the global `random` state (unseeded here); the resulting order
        # is then frozen by the cache file written below.
        random.shuffle(self.chunks)

        # Write to a temp file and rename, so an interrupted dump never leaves
        # a truncated pickle at cache_path.
        tmp_path = f"{cache_path}.{os.getpid()}.tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(self.chunks, f)
        os.replace(tmp_path, cache_path)

        print(f"[dataset] train_stage{stage}: {len(self.chunks):,} chunks -> {cache_path}")
        return self

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        chunk = torch.tensor(self.chunks[idx], dtype=torch.long)
        return chunk[:-1], chunk[1:]
|
|
|
|
| def _train_iter(stage: int, replay_ratio: float = 0.0) -> Iterator[str]: |
| """ |
| Yields text from current stage, optionally interleaved with previous stage(s) |
| based on replay_ratio. |
| |
| replay_ratio: float in [0, 1] |
| - 0.0: only current stage |
| - 0.2: 20% previous stage(s), 80% current |
| - 1.0: 100% previous stage(s) (pathological) |
| """ |
| from datasets import load_dataset |
|
|
| rng = random.Random(42) |
| |
| |
| if stage == 0: |
| current_sources = [iter(load_dataset("roneneldan/TinyStories", split="train", streaming=True))] |
| |
| elif stage == 1: |
| wiki = load_dataset("wikimedia/wikipedia", "20231101.simple", split="train", streaming=True) |
| try: |
| baby = load_dataset("babylm/babylm_10M", split="train", streaming=True) |
| current_sources = [iter(wiki), iter(baby)] |
| except Exception: |
| current_sources = [iter(wiki)] |
| |
| elif stage == 2: |
| ds = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT", split="train", streaming=True) |
| current_sources = [iter(ds)] |
| |
| else: |
| raise ValueError(f"Unknown stage: {stage}") |
| |
| |
| replay_sources = [] |
| if replay_ratio > 0 and stage > 0: |
| if stage >= 1: |
| |
| replay_sources.append(iter(load_dataset("roneneldan/TinyStories", split="train", streaming=True))) |
| if stage >= 2: |
| |
| wiki = load_dataset("wikimedia/wikipedia", "20231101.simple", split="train", streaming=True) |
| try: |
| baby = load_dataset("babylm/babylm_10M", split="train", streaming=True) |
| replay_sources.extend([iter(wiki), iter(baby)]) |
| except Exception: |
| replay_sources.append(iter(wiki)) |
| |
| |
| current_active = list(range(len(current_sources))) |
| replay_active = list(range(len(replay_sources))) |
| |
| while current_active or replay_active: |
| |
| if replay_active and rng.random() < replay_ratio: |
| idx = rng.choice(replay_active) |
| try: |
| ex = next(replay_sources[idx]) |
| text = ex.get("text") or ex.get("sentence") or "" |
| if text: yield text |
| except StopIteration: |
| replay_active.remove(idx) |
| elif current_active: |
| idx = rng.choice(current_active) |
| try: |
| ex = next(current_sources[idx]) |
| text = ex.get("text") or ex.get("sentence") or "" |
| if text: yield text |
| except StopIteration: |
| current_active.remove(idx) |
| else: |
| break |
|
|
|
|
| |
|
|
def make_dataloader(
    dataset: Dataset,
    batch_size: int,
    shuffle: bool = True,
    *,
    num_workers: int = 2,
    pin_memory: bool = True,
) -> DataLoader:
    """
    Wraps ``dataset`` in a ``DataLoader``.

    Args:
        dataset:     Dataset to load from.
        batch_size:  Samples per batch.
        shuffle:     Whether the loader reshuffles every epoch.
        num_workers: Loader worker processes (default 2, the value that used
                     to be hard-coded). Use 0 to load in the main process.
        pin_memory:  Forwarded to ``DataLoader`` (default True, as before).
                     Set False on CPU-only machines.

    Returns:
        The configured ``DataLoader``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
|
|
|
|
| |
|
|
def load_all_val_sets(
    tokenizer : Tokenizer,
    cache_dir : str = "cache",
) -> dict:
    """
    Builds (or loads from cache) the val set of every stage and wraps each
    one in a DataLoader.

    Returns a dict of three loaders:
        {"s0": loader, "s1": loader, "s2": loader}
    Each loader uses batch_size=16, no shuffling and 2 workers. Per the
    module docstring, these are meant for logging cross-stage val losses
    during training.
    """
    # (key, stage, seq_len, n_docs) for each validation set.
    specs = (
        ("s0", 0, 256, 5000),
        ("s1", 1, 384, 3000),
        ("s2", 2, 512, 2000),
    )
    return {
        key: DataLoader(
            ChunkedTokenDataset(
                build_val_set(stage, tokenizer, seq_len, n_docs, cache_dir)
            ),
            batch_size=16,
            shuffle=False,
            num_workers=2,
        )
        for key, stage, seq_len, n_docs in specs
    }
|
|