""" Dataset audit — diagnostic tool for HYDRA's pretraining corpus. Usage: python scripts/dataset_audit.py # Quick audit python scripts/dataset_audit.py --sample 10 # Sample 10 shards for token counts python scripts/dataset_audit.py --full # Full tokenize of every shard (slow) Reports: - Shard count, total disk usage - Estimated total tokens (character-based + tokenized sample) - Training budget sufficiency vs 12h @ 65k tok/s = 2.8B token target - Document diversity sample - Warnings about shard ordering, shuffle, and streaming behavior """ from __future__ import annotations import argparse import os import sys import time from pathlib import Path import pyarrow.parquet as pq # Resolve repo root so the script works regardless of CWD. REPO_ROOT = Path(__file__).resolve().parent.parent sys.path.insert(0, str(REPO_ROOT)) from prepare import ( # noqa: E402 DATA_DIR, MAX_SHARD, TOKENIZER_DIR, VAL_FILENAME, VAL_SHARD, ) TARGET_TOKENS_12H = 2_800_000_000 # 65k tok/s * 12h * 3600s CHARS_PER_TOKEN_HEURISTIC = 4.0 def human_bytes(n: int) -> str: for unit in ("B", "KB", "MB", "GB", "TB"): if n < 1024: return f"{n:.1f}{unit}" n /= 1024 return f"{n:.1f}PB" def human_tokens(n: int | float) -> str: if n >= 1e9: return f"{n / 1e9:.2f}B" if n >= 1e6: return f"{n / 1e6:.1f}M" if n >= 1e3: return f"{n / 1e3:.1f}K" return f"{n:.0f}" def list_shards() -> tuple[list[Path], Path | None]: """Return (train_shards_sorted, val_shard_or_none).""" if not os.path.isdir(DATA_DIR): return [], None all_paths = sorted(Path(DATA_DIR).glob("shard_*.parquet")) val_path = Path(DATA_DIR) / VAL_FILENAME train = [p for p in all_paths if p.name != VAL_FILENAME] val = val_path if val_path.exists() else None return train, val def tokenized_sample(shard_path: Path, enc, row_groups: int = 5) -> tuple[int, int]: """Tokenize first N row groups of a shard. Returns (tokens, docs).""" pf = pq.ParquetFile(shard_path) tokens = 0 docs = 0 n = min(row_groups, pf.num_row_groups) for i in range(n): rg = pf.read_row_group(i) texts = rg.column("text").to_pylist() ids = enc.encode_ordinary_batch(texts, num_threads=8) tokens += sum(len(x) for x in ids) docs += len(texts) return tokens, docs, pf.num_row_groups def main() -> int: parser = argparse.ArgumentParser(description="Audit the HYDRA training corpus") parser.add_argument( "--sample", type=int, default=3, help="Number of shards to tokenize for token-count estimate", ) parser.add_argument( "--full", action="store_true", help="Tokenize every shard (slow; gives exact total)", ) args = parser.parse_args() print("=" * 72) print("HYDRA corpus audit") print("=" * 72) print(f"DATA_DIR: {DATA_DIR}") print(f"TOKENIZER_DIR: {TOKENIZER_DIR}") print(f"Source dataset: karpathy/climbmix-400b-shuffle") print(f"Max remote shard: {MAX_SHARD} (pinned val = shard_{VAL_SHARD:05d})") print() train_shards, val_shard = list_shards() if not train_shards: print("ERROR: no parquet shards found. Run `python prepare.py` first.") return 1 total_disk = sum(p.stat().st_size for p in train_shards) val_disk = val_shard.stat().st_size if val_shard else 0 print(f"Train shards: {len(train_shards)} ({train_shards[0].name} ... {train_shards[-1].name})") print(f"Val shard: {'present' if val_shard else 'MISSING'} ({VAL_FILENAME})") print(f"Disk (train): {human_bytes(total_disk)}") print(f"Disk (val): {human_bytes(val_disk)}") print() # Character-based pass (fast): count total chars in all shards. 
    t0 = time.time()
    total_docs = 0
    total_row_groups = 0
    for p in train_shards:
        pf = pq.ParquetFile(p)
        total_row_groups += pf.num_row_groups
        total_docs += pf.metadata.num_rows
    dt_meta = time.time() - t0
    print(f"Metadata scan: {len(train_shards)} shards in {dt_meta:.1f}s")
    print(f"Train documents: {total_docs:,}")
    print(f"Row groups: {total_row_groups:,}")
    print()

    # Tokenizer-based sampling.
    try:
        import pickle

        with open(os.path.join(TOKENIZER_DIR, "tokenizer.pkl"), "rb") as f:
            enc = pickle.load(f)
        print(f"Tokenizer vocab: {enc.n_vocab}")
    except FileNotFoundError:
        print("WARNING: tokenizer.pkl not found — skipping tokenized sample.")
        enc = None

    est_total_tokens = 0
    if enc is not None:
        if args.full:
            sample_shards = train_shards
        else:
            # Pick shards evenly across the range for a representative sample.
            n_sample = min(args.sample, len(train_shards))
            if n_sample == 1:
                sample_shards = [train_shards[0]]
            else:
                stride = max(1, len(train_shards) // n_sample)
                sample_shards = train_shards[::stride][:n_sample]

        t0 = time.time()
        sample_tokens = 0
        sample_docs = 0
        sample_row_groups = 0
        sample_shard_row_groups = 0
        print(f"Tokenizing sample: {len(sample_shards)} shards ...")
        for p in sample_shards:
            tok, docs, n_rg = tokenized_sample(p, enc, row_groups=5)
            sample_tokens += tok
            sample_docs += docs
            sample_row_groups += min(5, n_rg)
            sample_shard_row_groups += n_rg
        dt_tok = time.time() - t0

        # Extrapolate: tokens per sampled row group, scaled by the average number of
        # row groups per sampled shard, then by the total shard count.
        tokens_per_rg = sample_tokens / max(sample_row_groups, 1)
        per_shard = tokens_per_rg * (sample_shard_row_groups / len(sample_shards))
        est_total_tokens = per_shard * len(train_shards)
        print(
            f"Sampled {sample_row_groups} row groups ({sample_docs:,} docs, "
            f"{sample_tokens:,} tokens) in {dt_tok:.1f}s"
        )
        print(f" tokens/row_group: {tokens_per_rg:,.0f}")
        print(f" tokens/shard: {per_shard:,.0f} ({human_tokens(per_shard)})")
    else:
        # Fall back to a character heuristic: parquet compresses text roughly 3x, so the
        # decompressed character count is ~3 * file size, and ~4 chars ≈ 1 token.
        # Example: 100 GB on disk -> ~300 GB of text -> ~75B tokens.
        est_total_tokens = (total_disk * 3.0) / CHARS_PER_TOKEN_HEURISTIC

    print()
    print("-" * 72)
    print("Token budget analysis")
    print("-" * 72)
    print(
        f"Estimated total train tokens: {human_tokens(est_total_tokens)} "
        f"({est_total_tokens:,.0f})"
    )
    print(f"12h @ 65k tok/s target: {human_tokens(TARGET_TOKENS_12H)}")
    ratio = est_total_tokens / TARGET_TOKENS_12H if TARGET_TOKENS_12H else 0
    if ratio >= 1.0:
        print(f" Ratio: {ratio:.1f}x ({'SUFFICIENT' if ratio >= 1.2 else 'TIGHT'})")
    else:
        print(f" Ratio: {ratio:.2f}x INSUFFICIENT — corpus is short by {1 - ratio:.0%} of the target")
    print()

    # Warnings about the dataloader behavior.
    print("-" * 72)
    print("Dataloader behavior (prepare.py::_document_batches)")
    print("-" * 72)
    print("+ Infinite streaming: while True around shard list (no StopIteration)")
    print("+ Streams per shard, never loads full corpus into RAM")
    print("+ BOS-aligned best-fit packing gives document-level buffer shuffling")
    print("- Cross-shard order is LEXICOGRAPHIC and FIXED on every epoch")
    print("- Row groups / rows WITHIN a shard are read in fixed order")
    print("  (climbmix-400b-shuffle is pre-shuffled at source, mitigating this)")
    print()
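    # Illustration only (not what prepare.py does today): make the fixed cross-shard
    # order warning above concrete by showing how a seeded, per-epoch shuffle would
    # reorder the first few shards. The `random.Random(epoch)` seed scheme is a
    # hypothetical sketch for this audit output; the real loader streams shards in
    # lexicographic order.
    import random  # local import: used only for this illustration

    preview = [p.name for p in train_shards[:8]]
    for epoch in (0, 1):
        shuffled = list(preview)
        random.Random(epoch).shuffle(shuffled)
        print(f"  hypothetical epoch {epoch} shard order (seeded shuffle): {shuffled}")
    print()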
    # Quick content diversity peek.
    if train_shards:
        print("-" * 72)
        print("Content sample (shard 0, row group 0: first / middle / last doc)")
        print("-" * 72)
        pf = pq.ParquetFile(train_shards[0])
        rg = pf.read_row_group(0)
        texts = rg.column("text").to_pylist()
        for i, idx in enumerate([0, len(texts) // 2, len(texts) - 1]):
            if 0 <= idx < len(texts):
                snippet = texts[idx][:160].replace("\n", " ")
                print(f" [{i}] len={len(texts[idx])}: {snippet!r}")
        print()

    print("=" * 72)
    print("Done.")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())