"""
Dataset audit — diagnostic tool for HYDRA's pretraining corpus.
Usage:
python scripts/dataset_audit.py # Quick audit
python scripts/dataset_audit.py --sample 10 # Sample 10 shards for token counts
python scripts/dataset_audit.py --full # Full tokenize of every shard (slow)
Reports:
- Shard count, total disk usage
- Estimated total tokens (character-based + tokenized sample)
- Training budget sufficiency vs 12h @ 65k tok/s ≈ 2.8B token target
- Document diversity sample
- Warnings about shard ordering, shuffle, and streaming behavior
"""
from __future__ import annotations
import argparse
import os
import sys
import time
from pathlib import Path
import pyarrow.parquet as pq
# Resolve repo root so the script works regardless of CWD.
REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))
from prepare import ( # noqa: E402
DATA_DIR,
MAX_SHARD,
TOKENIZER_DIR,
VAL_FILENAME,
VAL_SHARD,
)
TARGET_TOKENS_12H = 2_800_000_000  # ≈ 65_000 tok/s * 12 h * 3600 s/h (exact: 2.808e9, rounded down)
CHARS_PER_TOKEN_HEURISTIC = 4.0
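# Sanity check on the budget arithmetic: the exact product is 2.808e9 and the
# constant above rounds down to 2.8e9, so they agree within 1%.
assert abs(65_000 * 12 * 3600 - TARGET_TOKENS_12H) / TARGET_TOKENS_12H < 0.01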
def human_bytes(n: int | float) -> str:
    for unit in ("B", "KB", "MB", "GB", "TB"):
        if n < 1024:
            return f"{n:.1f}{unit}"
        n /= 1024
    return f"{n:.1f}PB"
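# Quick sanity examples for human_bytes (comments only, not executed):
#   human_bytes(512)          -> "512.0B"
#   human_bytes(1536)         -> "1.5KB"
#   human_bytes(3 * 1024**3)  -> "3.0GB"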
def human_tokens(n: int | float) -> str:
if n >= 1e9:
return f"{n / 1e9:.2f}B"
if n >= 1e6:
return f"{n / 1e6:.1f}M"
if n >= 1e3:
return f"{n / 1e3:.1f}K"
return f"{n:.0f}"
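# Likewise for human_tokens:
#   human_tokens(2_800_000_000) -> "2.80B"
#   human_tokens(65_000)        -> "65.0K"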
def list_shards() -> tuple[list[Path], Path | None]:
"""Return (train_shards_sorted, val_shard_or_none)."""
if not os.path.isdir(DATA_DIR):
return [], None
all_paths = sorted(Path(DATA_DIR).glob("shard_*.parquet"))
val_path = Path(DATA_DIR) / VAL_FILENAME
train = [p for p in all_paths if p.name != VAL_FILENAME]
val = val_path if val_path.exists() else None
return train, val
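# Note: shard files follow prepare.py's 5-digit naming (e.g. shard_00000.parquet,
# matching the shard_{VAL_SHARD:05d} pattern reported in main); the pinned
# validation shard is excluded from the train list.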
def tokenized_sample(shard_path: Path, enc, row_groups: int = 5) -> tuple[int, int, int]:
    """Tokenize the first `row_groups` row groups of a shard.

    Returns (tokens, docs, total_row_groups_in_shard).
    """
    pf = pq.ParquetFile(shard_path)
    tokens = 0
    docs = 0
    n = min(row_groups, pf.num_row_groups)
    for i in range(n):
        rg = pf.read_row_group(i)
        texts = rg.column("text").to_pylist()
        # Batch-encode without special tokens (tiktoken-style API).
        ids = enc.encode_ordinary_batch(texts, num_threads=8)
        tokens += sum(len(x) for x in ids)
        docs += len(texts)
    return tokens, docs, pf.num_row_groups
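# Typical call (a sketch; assumes a tokenizer `enc` exposing the batch API above):
#   tok, docs, n_rg = tokenized_sample(train_shards[0], enc)
# This reads only the first 5 row groups, keeping per-shard memory bounded.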
def main() -> int:
parser = argparse.ArgumentParser(description="Audit the HYDRA training corpus")
parser.add_argument(
"--sample",
type=int,
default=3,
help="Number of shards to tokenize for token-count estimate",
)
parser.add_argument(
"--full",
action="store_true",
help="Tokenize every shard (slow; gives exact total)",
)
args = parser.parse_args()
print("=" * 72)
print("HYDRA corpus audit")
print("=" * 72)
print(f"DATA_DIR: {DATA_DIR}")
print(f"TOKENIZER_DIR: {TOKENIZER_DIR}")
print(f"Source dataset: karpathy/climbmix-400b-shuffle")
print(f"Max remote shard: {MAX_SHARD} (pinned val = shard_{VAL_SHARD:05d})")
print()
train_shards, val_shard = list_shards()
if not train_shards:
print("ERROR: no parquet shards found. Run `python prepare.py` first.")
return 1
total_disk = sum(p.stat().st_size for p in train_shards)
val_disk = val_shard.stat().st_size if val_shard else 0
print(f"Train shards: {len(train_shards)} ({train_shards[0].name} ... {train_shards[-1].name})")
print(f"Val shard: {'present' if val_shard else 'MISSING'} ({VAL_FILENAME})")
print(f"Disk (train): {human_bytes(total_disk)}")
print(f"Disk (val): {human_bytes(val_disk)}")
print()
    # Metadata pass (fast): read document and row-group counts from each
    # parquet footer without touching the data pages.
    t0 = time.time()
    total_docs = 0
    total_row_groups = 0
    for p in train_shards:
        pf = pq.ParquetFile(p)
        total_row_groups += pf.num_row_groups
        total_docs += pf.metadata.num_rows
dt_meta = time.time() - t0
print(f"Metadata scan: {len(train_shards)} shards in {dt_meta:.1f}s")
print(f"Train documents: {total_docs:,}")
print(f"Row groups: {total_row_groups:,}")
print()
# Tokenizer-based sampling.
try:
import pickle
with open(os.path.join(TOKENIZER_DIR, "tokenizer.pkl"), "rb") as f:
enc = pickle.load(f)
print(f"Tokenizer vocab: {enc.n_vocab}")
except FileNotFoundError:
print("WARNING: tokenizer.pkl not found — skipping tokenized sample.")
enc = None
est_total_tokens = 0
if enc is not None:
if args.full:
sample_shards = train_shards
else:
# Pick shards evenly across the range for a representative sample.
n_sample = min(args.sample, len(train_shards))
if n_sample == 1:
sample_shards = [train_shards[0]]
else:
stride = max(1, len(train_shards) // n_sample)
sample_shards = train_shards[::stride][:n_sample]
t0 = time.time()
sample_tokens = 0
sample_docs = 0
sample_row_groups = 0
sample_shard_row_groups = 0
print(f"Tokenizing sample: {len(sample_shards)} shards ...")
for p in sample_shards:
tok, docs, n_rg = tokenized_sample(p, enc, row_groups=5)
sample_tokens += tok
sample_docs += docs
sample_row_groups += min(5, n_rg)
sample_shard_row_groups += n_rg
dt_tok = time.time() - t0
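        # Extrapolate: tokens per row group (from the sample) * average row
        # groups per sampled shard * total shard count. Assumes row-group
        # sizes are roughly uniform across shards.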
tokens_per_rg = sample_tokens / max(sample_row_groups, 1)
per_shard = tokens_per_rg * (sample_shard_row_groups / len(sample_shards))
est_total_tokens = per_shard * len(train_shards)
print(
f"Sampled {sample_row_groups} row groups ({sample_docs:,} docs, "
f"{sample_tokens:,} tokens) in {dt_tok:.1f}s"
)
print(f" tokens/row_group: {tokens_per_rg:,.0f}")
print(f" tokens/shard: {per_shard:,.0f}")
print(f" tokens/shard: {human_tokens(per_shard)}")
    else:
        # Fall back to a character heuristic: parquet text compresses roughly
        # 3x, so decompressed chars ~= 3 * file size, at ~4 chars per token
        # (CHARS_PER_TOKEN_HEURISTIC).
        est_total_tokens = (total_disk * 3.0) / CHARS_PER_TOKEN_HEURISTIC
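        # Worked example with illustrative numbers (not measured here):
        # 40 GB of parquet * 3.0 ~= 120 GB of text; / 4 chars/token ~= 30B tokens.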
print()
print("-" * 72)
print("Token budget analysis")
print("-" * 72)
print(f"Estimated total train tokens: {human_tokens(est_total_tokens)} "
f"({est_total_tokens:,.0f})")
print(f"12h @ 65k tok/s target: {human_tokens(TARGET_TOKENS_12H)}")
ratio = est_total_tokens / TARGET_TOKENS_12H if TARGET_TOKENS_12H else 0
if ratio >= 1.0:
print(f" Ratio: {ratio:.1f}x ({'SUFFICIENT' if ratio >= 1.2 else 'TIGHT'})")
else:
print(f" Ratio: {ratio:.2f}x INSUFFICIENT — need {1 - ratio:.0%} more")
print()
# Warnings about the dataloader behavior.
print("-" * 72)
print("Dataloader behavior (prepare.py::_document_batches)")
print("-" * 72)
print("+ Infinite streaming: while True around shard list (no StopIteration)")
print("+ Streams per shard, never loads full corpus into RAM")
print("+ BOS-aligned best-fit packing gives document-level buffer shuffling")
print("- Cross-shard order is LEXICOGRAPHIC and FIXED on every epoch")
print("- Row groups / rows WITHIN a shard are read in fixed order")
print(" (climbmix-400b-shuffle is pre-shuffled at source, mitigating this)")
print()
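    # Possible mitigation for the fixed cross-shard order (a sketch only; the
    # actual loader in prepare.py::_document_batches is unchanged by this audit):
    #   rng = random.Random(epoch)    # seed per epoch
    #   order = list(train_shards)
    #   rng.shuffle(order)            # new shard order each pass
    # The source pre-shuffle makes this optional for single-epoch runs.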
# Quick content diversity peek.
if train_shards:
print("-" * 72)
print("Content sample (shard 0, first 3 docs)")
print("-" * 72)
pf = pq.ParquetFile(train_shards[0])
rg = pf.read_row_group(0)
texts = rg.column("text").to_pylist()
        # Dedupe indices so tiny or empty row groups don't repeat or crash.
        indices = sorted({0, len(texts) // 2, len(texts) - 1}) if texts else []
        for i, idx in enumerate(indices):
            snippet = texts[idx][:160].replace("\n", " ")
            print(f" [{i}] len={len(texts[idx])}: {snippet!r}")
print()
print("=" * 72)
print("Done.")
return 0
if __name__ == "__main__":
raise SystemExit(main())