# Tiny-GPT / prepare_data.py
# Uploaded by pragadeeshv23 via huggingface_hub (commit ffc0c0c, verified).
# NOTE: the two lines above are repository provenance, not code.
#!/usr/bin/env python3
"""
prepare_data.py
===============
Build tokenized binary files for training from Cosmopedia using streaming.
Outputs:
data/train.bin
data/val.bin
data/test.bin
Dataset:
HuggingFaceTB/cosmopedia (streaming)
"""
import os
from pathlib import Path
import numpy as np
import tiktoken
from datasets import load_dataset
from tqdm.auto import tqdm
# Local project cache for reproducibility and resume behavior.
# setdefault() lets an externally exported HF_HOME/... win over these defaults.
os.environ.setdefault("HF_HOME", "./hf_cache")
os.environ.setdefault("HF_DATASETS_CACHE", "./hf_cache/datasets")
os.environ.setdefault("HF_HUB_CACHE", "./hf_cache/hub")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
# Output directory for the tokenized train/val/test .bin files.
DATA_DIR = Path("data")
DATA_DIR.mkdir(parents=True, exist_ok=True)
CACHE_DIR = "./hf_cache"
# Dataset identifiers; DATASET_CONFIG picks the Cosmopedia subset (default "stories").
DATASET_NAME = "HuggingFaceTB/cosmopedia"
DATASET_CONFIG = os.environ.get("DATASET_CONFIG", "stories")
# Stream only first N rows by default, matching your requested pattern.
MAX_EXAMPLES = int(os.environ.get("MAX_EXAMPLES", "1000000"))
# Deterministic split from one stream: 98% train, 1% val, 1% test.
# The remaining fraction (1 - TRAIN_FRAC - VAL_FRAC) becomes the test split.
TRAIN_FRAC = float(os.environ.get("TRAIN_FRAC", "0.98"))
VAL_FRAC = float(os.environ.get("VAL_FRAC", "0.01"))
# Flush chunks to disk to keep RAM bounded.
FLUSH_TOKENS = int(os.environ.get("FLUSH_TOKENS", "2000000"))
# GPT-2 BPE: all token ids are < 50257, so they fit the uint16 files written below.
enc = tiktoken.get_encoding("gpt2")
EOT = enc.eot_token  # end-of-text id, appended after every document
def extract_text(row: dict) -> str:
    """Extract a usable text field across possible Cosmopedia schemas.

    Prefers a whole-document field ("text", then "content"); otherwise
    joins any known prompt/answer-style fields with blank lines.
    Returns "" when no usable string field is present.
    """
    # Whole-document fields take priority, in this fixed order.
    for primary in ("text", "content"):
        value = row.get(primary)
        if isinstance(value, str):
            return value.strip()
    # Fall back to stitching together prompt/response-style fields.
    fallback_keys = ("prompt", "question", "instruction", "input", "answer", "response", "output")
    pieces = [
        row[key].strip()
        for key in fallback_keys
        if isinstance(row.get(key), str) and row[key].strip()
    ]
    return "\n\n".join(pieces).strip()
def encode_text(text: str):
    """Tokenize *text* with the GPT-2 BPE and append the end-of-text id.

    encode_ordinary() ignores special tokens in the raw text, so EOT
    appears exactly once, as the document separator we add here.
    """
    return enc.encode_ordinary(text) + [EOT]
def flush_tokens(fp, buffer_tokens):
if not buffer_tokens:
return 0
arr = np.asarray(buffer_tokens, dtype=np.uint16)
arr.tofile(fp)
n = int(arr.size)
buffer_tokens.clear()
return n
def pick_split(i, total, train_frac=None, val_frac=None):
    """Deterministically assign stream index *i* of *total* to a split.

    Indices [0, total*train_frac) map to "train", the next total*val_frac
    indices to "val", and the remainder to "test".

    Args:
        i: zero-based index of the example in the stream.
        total: total number of examples being partitioned.
        train_frac: train fraction; defaults to the module-level TRAIN_FRAC.
        val_frac: validation fraction; defaults to the module-level VAL_FRAC.

    Returns:
        One of "train", "val", or "test".
    """
    # Resolve defaults at call time so callers can override the module
    # globals (and so the function is testable without them).
    if train_frac is None:
        train_frac = TRAIN_FRAC
    if val_frac is None:
        val_frac = VAL_FRAC
    train_cut = int(total * train_frac)
    val_cut = train_cut + int(total * val_frac)
    if i < train_cut:
        return "train"
    if i < val_cut:
        return "val"
    return "test"
if __name__ == "__main__":
    print("Loading Cosmopedia (streaming)...")
    # Streaming mode yields rows lazily without downloading the full dataset.
    # This follows your requested style while allowing MAX_EXAMPLES override.
    dataset = load_dataset(
        DATASET_NAME,
        DATASET_CONFIG,
        split="train",
        streaming=True,
        cache_dir=CACHE_DIR,
    )
    out_paths = {
        "train": DATA_DIR / "train.bin",
        "val": DATA_DIR / "val.bin",
        "test": DATA_DIR / "test.bin",
    }
    # Delete stale outputs so the append-mode writes below start from empty files.
    for p in out_paths.values():
        if p.exists():
            p.unlink()
    # Per-split in-memory token buffers and running doc/token counters.
    buffers = {"train": [], "val": [], "test": []}
    counts_examples = {"train": 0, "val": 0, "test": 0}
    counts_tokens = {"train": 0, "val": 0, "test": 0}
    with open(out_paths["train"], "ab") as f_train, open(out_paths["val"], "ab") as f_val, open(out_paths["test"], "ab") as f_test:
        fps = {"train": f_train, "val": f_val, "test": f_test}
        progress = tqdm(total=MAX_EXAMPLES, desc="Streaming+Encoding", unit="doc")
        for i, row in enumerate(dataset):
            if i >= MAX_EXAMPLES:
                break
            text = extract_text(row)
            if not text:
                # Empty docs still consume index i, keeping split boundaries
                # deterministic regardless of how many rows are skipped.
                progress.update(1)
                continue
            split = pick_split(i, MAX_EXAMPLES)
            toks = encode_text(text)
            buffers[split].extend(toks)
            counts_examples[split] += 1
            # Flush all splits together so val/test are written even if their
            # individual buffers never reach FLUSH_TOKENS (they're only 1% each).
            if len(buffers["train"]) >= FLUSH_TOKENS:
                for s in ("train", "val", "test"):
                    counts_tokens[s] += flush_tokens(fps[s], buffers[s])
            progress.update(1)
        progress.close()
        # Final flush: write whatever remains in each buffer before the
        # file handles close.
        for split in ("train", "val", "test"):
            counts_tokens[split] += flush_tokens(fps[split], buffers[split])
    print("\nDone.")
    for split in ("train", "val", "test"):
        print(f"{split:>5}: {counts_examples[split]:>10,} docs -> {counts_tokens[split]:>12,} tokens")
    print(f"Saved files in: {DATA_DIR.resolve()}")