# slm / dataset_overlap.py
# Lomesh7777's picture
# Upload folder using huggingface_hub
# 1bbe1a8 verified
"""
dataset_overlap.py
Per-stage PyTorch datasets + val set builders with REPLAY support to prevent catastrophic forgetting.
Each stage has:
- A train dataset (tokenized, chunked to seq_len, streamed or cached)
- Multiple sources: current stage + replay buffer from previous stages
- A val dataset (fixed held-out split, loaded once into memory)
Val sets for ALL three stages are always available so the training loop
can log cross-stage val losses at every eval step.
"""
import os
import pickle
import random
import numpy as np
from pathlib import Path
from typing import Iterator, Optional
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from tokenizers import Tokenizer
# ─── Token-level chunking ─────────────────────────────────────────────────────
def tokenize_and_chunk(
    text_iter  : Iterator[str],
    tokenizer  : Tokenizer,
    seq_len    : int,
    max_tokens : Optional[int] = None,
    bos_id     : int = 2,    # [BOS] index in special_tokens
    eos_id     : int = 3,    # [EOS]
) -> list[list[int]]:
    """
    Stream text, tokenize it, and cut the token stream into fixed-size windows.

    Every non-blank document is wrapped as BOS + ids + EOS and appended to one
    running token stream, which is cut into non-overlapping windows of
    seq_len+1 tokens (input + target shift). Tokens left over at the end that
    do not fill a whole window are dropped.

    Args:
        text_iter: Iterable of raw documents; empty / whitespace-only ones are skipped.
        tokenizer: Object whose `encode(text).ids` gives the token ids.
        seq_len: Model input length; each returned chunk has seq_len+1 ids.
        max_tokens: Stop as soon as at least this many tokens have been emitted
            (counted in whole windows, so the result may overshoot by up to
            seq_len tokens). None or 0 means no limit; the progress bar is
            disabled only when it is None.
        bos_id: Id prepended to every document.
        eos_id: Id appended to every document.

    Returns:
        List of token id lists, each of length seq_len+1.

    Raises:
        ValueError: If seq_len < 1 (a non-positive window would never drain
            the buffer and the chunking loop could not terminate).
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")

    window = seq_len + 1
    buffer: list[int] = []
    chunks: list[list[int]] = []
    tokens_seen = 0

    with tqdm(total=max_tokens, unit="tok", desc="Tokenizing & chunking", disable=(max_tokens is None)) as pbar:
        for text in text_iter:
            if not text or not text.strip():
                continue

            buffer.append(bos_id)
            buffer.extend(tokenizer.encode(text).ids)
            buffer.append(eos_id)

            # Walk the buffer with an offset and trim it once per document,
            # instead of re-copying the whole remainder after every chunk.
            start = 0
            while len(buffer) - start >= window:
                chunks.append(buffer[start : start + window])
                start += window
                tokens_seen += window
                pbar.update(window)
                if max_tokens and tokens_seen >= max_tokens:
                    # The `with` block closes the progress bar on return.
                    return chunks
            if start:
                del buffer[:start]

    return chunks
# ─── Cached val set builder ───────────────────────────────────────────────────
def build_val_set(
    stage     : int,
    tokenizer : Tokenizer,
    seq_len   : int,
    n_docs    : int,
    cache_dir : str = "cache",
) -> list[list[int]]:
    """
    Load the cached val set for a stage, or build it and cache it.

    The cache file is `<cache_dir>/val_stage{stage}_seq{seq_len}.pkl`.
    `n_docs` and the tokenizer are NOT part of the cache key: they only matter
    on a cache miss, so delete the file to rebuild with different settings.

    Args:
        stage: Curriculum stage whose held-out documents are used.
        tokenizer: Tokenizer passed to `tokenize_and_chunk` on a cache miss.
        seq_len: Model input length; chunks hold seq_len+1 token ids.
        n_docs: Number of held-out documents requested from `_val_iter`.
        cache_dir: Directory for the pickle cache (created if missing).

    Returns:
        List of token id chunks, each of length seq_len+1.
    """
    os.makedirs(cache_dir, exist_ok=True)
    cache_path = os.path.join(cache_dir, f"val_stage{stage}_seq{seq_len}.pkl")

    if os.path.exists(cache_path):
        print(f"[dataset] Loading cached val_stage{stage} from {cache_path}")
        # NOTE(security): unpickling can execute arbitrary code; only point
        # cache_dir at files this project wrote itself.
        with open(cache_path, "rb") as f:
            return pickle.load(f)

    print(f"[dataset] Building val_stage{stage} ({n_docs} docs)...")
    chunks = tokenize_and_chunk(_val_iter(stage, n_docs), tokenizer, seq_len)

    # Write to a temp file and rename it into place, so an interrupted run
    # never leaves a truncated pickle that would break every later load.
    tmp_path = f"{cache_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(chunks, f)
        os.replace(tmp_path, cache_path)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    print(f"[dataset] val_stage{stage} saved: {len(chunks)} chunks → {cache_path}")
    return chunks
def _val_iter(stage: int, n_docs: int) -> Iterator[str]:
    """
    Yield up to `n_docs` held-out documents for a stage, deterministically.

    - stage 0: the first documents of the TinyStories `validation` split.
    - stage 1: a ~5% random sample of Simple English Wikipedia.
    - stage 2: a ~1% random sample of fineweb-edu documents with score >= 3.

    Sampling uses a fixed seed (42), so the same documents are picked as long
    as the upstream stream order is stable. Stages 1 and 2 sample from the
    `train` split, the same split `_train_iter` streams, so these documents
    are not excluded from training data.

    This is a generator: nothing (including the ValueError below) happens
    until the first item is requested.

    Raises:
        ValueError: If `stage` is not 0, 1 or 2.
    """
    # Validate first so a bad stage fails loudly instead of yielding an empty
    # val set that would then be cached by the caller.
    if stage not in (0, 1, 2):
        raise ValueError(f"Unknown stage: {stage}")
    if n_docs <= 0:
        return

    from itertools import islice

    from datasets import load_dataset

    rng = random.Random(42)   # fixed seed for reproducibility

    if stage == 0:
        ds = load_dataset("roneneldan/TinyStories", split="validation", streaming=True)
        selected = (ex["text"] for ex in ds)
    elif stage == 1:
        # Only SimpleWiki is sampled here (no BabyLM).
        ds = load_dataset("wikimedia/wikipedia", "20231101.simple",
                          split="train", streaming=True)
        selected = (ex["text"] for ex in ds if rng.random() < 0.05)   # ~5% holdout sample
    else:
        ds = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT",
                          split="train", streaming=True)
        # The score test short-circuits, so the RNG only advances for
        # documents that pass the quality filter.
        selected = (
            ex["text"]
            for ex in ds
            if ex.get("score", 0) >= 3 and rng.random() < 0.01
        )

    # islice stops right after the n_docs-th document without pulling more.
    yield from islice(selected, n_docs)
# ─── In-memory val dataset ────────────────────────────────────────────────────
class ChunkedTokenDataset(Dataset):
    """
    Map-style dataset over an in-memory list of token id chunks.

    Item `idx` is an (input, target) pair of long tensors: the input is the
    chunk without its last token and the target is the chunk without its first
    token, i.e. the input shifted by one position.
    """

    def __init__(self, chunks: list[list[int]]):
        self.chunks = chunks

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        tokens = torch.tensor(self.chunks[idx], dtype=torch.long)
        inputs, targets = tokens[:-1], tokens[1:]
        return inputs, targets
# ─── Streaming train dataset with REPLAY ──────────────────────────────────────
class StreamingStageDataset(Dataset):
    """
    Pre-tokenizes + chunks a stage's training data into memory.

    Supports REPLAY: the text stream for a stage can be mixed with text from
    previous stages (see `replay_ratio`) to prevent catastrophic forgetting.

    Call build() before using in a DataLoader; until then the dataset is empty.
    """

    def __init__(self):
        self.chunks: list[list[int]] = []

    @staticmethod
    def _replay_tag(replay_ratio: float) -> str:
        """Return the cache-key / log label for a replay ratio.

        Ratios with at most one decimal keep the `0.2`-style label; any other
        ratio uses its full repr, so e.g. 0.25 no longer shares a cache file
        with 0.2.
        """
        tag = f"{replay_ratio:.1f}"
        if float(tag) != replay_ratio:
            tag = repr(float(replay_ratio))
        return tag

    @staticmethod
    def _dump_atomic(obj, path: str) -> None:
        """Pickle `obj` to `path` via a temp file + rename, so a crash mid-write
        never leaves a truncated cache behind."""
        tmp_path = f"{path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, "wb") as f:
                pickle.dump(obj, f)
            os.replace(tmp_path, path)
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def build(
        self,
        stage       : int,
        tokenizer   : Tokenizer,
        seq_len     : int,
        max_tokens  : int,
        cache_dir   : str = "cache",
        replay_ratio: float = 0.0,  # fraction of draws taken from previous stage(s)
    ):
        """
        Load the stage's training chunks from cache, or build and cache them.

        The cache file is keyed by stage, seq_len and replay_ratio only.
        `max_tokens` and the tokenizer are NOT part of the key, so they take
        effect only on a cache miss.

        Args:
            stage: Curriculum stage to build.
            tokenizer: Tokenizer passed to `tokenize_and_chunk`.
            seq_len: Model input length; chunks hold seq_len+1 token ids.
            max_tokens: Token budget passed to `tokenize_and_chunk`.
            cache_dir: Directory for the pickle cache (created if missing).
            replay_ratio: Passed to `_train_iter` to mix in earlier stages.

        Returns:
            self, so the call can be chained.
        """
        os.makedirs(cache_dir, exist_ok=True)
        tag = self._replay_tag(replay_ratio)
        cache_path = os.path.join(cache_dir, f"train_stage{stage}_seq{seq_len}_replay{tag}.pkl")

        if os.path.exists(cache_path):
            print(f"[dataset] Loading cached train_stage{stage} (replay={tag})")
            # NOTE(security): unpickling can execute arbitrary code; only point
            # cache_dir at files this project wrote itself.
            with open(cache_path, "rb") as f:
                self.chunks = pickle.load(f)
            print(f"[dataset] Loaded {len(self.chunks):,} chunks")
            return self

        print(f"[dataset] Building train_stage{stage} (max_tokens={max_tokens:,}, replay={tag})...")
        text_iter = _train_iter(stage, replay_ratio=replay_ratio)
        self.chunks = tokenize_and_chunk(text_iter, tokenizer, seq_len, max_tokens)
        # Shuffle once before caching. This uses the global `random` state, so
        # the order differs between fresh builds but is fixed once cached.
        random.shuffle(self.chunks)
        self._dump_atomic(self.chunks, cache_path)
        print(f"[dataset] train_stage{stage}: {len(self.chunks):,} chunks → {cache_path}")
        return self

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        # Input is the chunk minus its last token; target is shifted by one.
        chunk = torch.tensor(self.chunks[idx], dtype=torch.long)
        return chunk[:-1], chunk[1:]
def _open_stage_sources(stage: int) -> list[Iterator[dict]]:
    """Open streaming iterators over the training corpora that belong to `stage`."""
    from datasets import load_dataset

    if stage == 0:
        return [iter(load_dataset("roneneldan/TinyStories", split="train", streaming=True))]
    if stage == 1:
        wiki = load_dataset("wikimedia/wikipedia", "20231101.simple", split="train", streaming=True)
        try:
            baby = iter(load_dataset("babylm/babylm_10M", split="train", streaming=True))
        except Exception as err:
            # Deliberate best-effort: BabyLM is optional, but say so instead of
            # silently training on a different mixture.
            print(f"[dataset] BabyLM unavailable ({err!r}); using SimpleWiki only")
            return [iter(wiki)]
        return [iter(wiki), baby]
    if stage == 2:
        ds = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT", split="train", streaming=True)
        return [iter(ds)]
    raise ValueError(f"Unknown stage: {stage}")


def _draw_text(sources: list, active: list[int], rng: random.Random) -> str:
    """Pull one example from a randomly chosen active source.

    Returns its text ("" when the example has none). When the chosen source is
    exhausted it is removed from `active` and "" is returned.
    """
    idx = rng.choice(active)
    try:
        ex = next(sources[idx])
    except StopIteration:
        active.remove(idx)
        return ""
    return ex.get("text") or ex.get("sentence") or ""


def _train_iter(stage: int, replay_ratio: float = 0.0) -> Iterator[str]:
    """
    Yields text from current stage, optionally interleaved with previous stage(s)
    based on replay_ratio.

    replay_ratio: float in [0, 1]
    - 0.0: only current stage
    - 0.2: ~20% of draws from previous stage(s), ~80% from the current one
    - 1.0: previous stage(s) are drained first, then the current stage (pathological)

    Each step draws from the replay pool with probability `replay_ratio`
    (while any replay source is left), otherwise from the current pool; within
    a pool the source is picked uniformly at random. The stream ends when a
    current-pool draw is due but every current source is exhausted. A fixed
    seed (42) makes the interleaving order reproducible.

    This is a generator: nothing (including the ValueError below) happens
    until the first item is requested.

    Raises:
        ValueError: If `stage` is not 0, 1 or 2.
    """
    # Validate up front so a bad stage fails before any dataset is opened.
    if stage not in (0, 1, 2):
        raise ValueError(f"Unknown stage: {stage}")

    rng = random.Random(42)   # fixed seed for reproducibility

    current_sources = _open_stage_sources(stage)

    # Replay pool: every earlier stage, in stage order (empty for stage 0).
    replay_sources: list = []
    if replay_ratio > 0:
        for earlier in range(stage):
            replay_sources.extend(_open_stage_sources(earlier))

    current_active = list(range(len(current_sources)))
    replay_active = list(range(len(replay_sources)))

    while current_active or replay_active:
        # rng.random() is only consumed while a replay source is still alive;
        # keeping this order keeps the sequence of draws reproducible.
        if replay_active and rng.random() < replay_ratio:
            pool, active = replay_sources, replay_active
        elif current_active:
            pool, active = current_sources, current_active
        else:
            break

        text = _draw_text(pool, active, rng)
        if text:
            yield text
# ─── DataLoader factory ───────────────────────────────────────────────────────
def make_dataloader(
    dataset     : Dataset,
    batch_size  : int,
    shuffle     : bool = True,
    num_workers : int = 2,
    pin_memory  : bool = True,
) -> DataLoader:
    """
    Wrap a dataset in a DataLoader with this project's default settings.

    Args:
        dataset: Dataset to iterate over.
        batch_size: Number of samples per batch.
        shuffle: Whether the sample order is reshuffled every epoch.
        num_workers: Worker processes used for loading (0 loads in the main
            process). Defaults to 2, the value previously hard-coded here.
        pin_memory: Forwarded to DataLoader. Defaults to True, the value
            previously hard-coded here; pass False on CPU-only machines.

    Returns:
        The configured DataLoader.
    """
    return DataLoader(
        dataset,
        batch_size  = batch_size,
        shuffle     = shuffle,
        num_workers = num_workers,
        pin_memory  = pin_memory,
    )
# ─── Load all three val sets ──────────────────────────────────────────────────
def load_all_val_sets(
    tokenizer   : Tokenizer,
    cache_dir   : str = "cache",
    batch_size  : int = 16,
    num_workers : int = 2,
) -> dict[str, DataLoader]:
    """
    Returns a dict of three pre-built val DataLoaders:
    {"s0": loader, "s1": loader, "s2": loader}

    Each stage's val set is loaded from cache or built via `build_val_set`,
    with its own sequence length and document count (see `configs` below).
    The loaders never shuffle.

    Args:
        tokenizer: Tokenizer used when a val set has to be built.
        cache_dir: Directory holding the val-set pickle caches.
        batch_size: Batch size of every val loader (default 16, the value
            previously hard-coded here).
        num_workers: Worker processes per val loader (default 2, the value
            previously hard-coded here).

    Returns:
        Mapping from "s0" / "s1" / "s2" to the stage's val DataLoader.
    """
    # key -> (stage, seq_len, n_docs)
    configs = {
        "s0": (0, 256, 5000),
        "s1": (1, 384, 3000),
        "s2": (2, 512, 2000),
    }

    loaders: dict[str, DataLoader] = {}
    for key, (stage, seq_len, n_docs) in configs.items():
        chunks = build_val_set(stage, tokenizer, seq_len, n_docs, cache_dir)
        loaders[key] = DataLoader(
            ChunkedTokenDataset(chunks),
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )
    return loaders