| """ | |
| Dataset audit — diagnostic tool for HYDRA's pretraining corpus. | |
| Usage: | |
| python scripts/dataset_audit.py # Quick audit | |
| python scripts/dataset_audit.py --sample 10 # Sample 10 shards for token counts | |
| python scripts/dataset_audit.py --full # Full tokenize of every shard (slow) | |
| Reports: | |
| - Shard count, total disk usage | |
| - Estimated total tokens (character-based + tokenized sample) | |
| - Training budget sufficiency vs 12h @ 65k tok/s = 2.8B token target | |
| - Document diversity sample | |
| - Warnings about shard ordering, shuffle, and streaming behavior | |
| """ | |
from __future__ import annotations

import argparse
import os
import sys
import time
from pathlib import Path

import pyarrow.parquet as pq

# Resolve repo root so the script works regardless of CWD.
REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))
from prepare import (  # noqa: E402
    DATA_DIR,
    MAX_SHARD,
    TOKENIZER_DIR,
    VAL_FILENAME,
    VAL_SHARD,
)
TARGET_TOKENS_12H = 2_800_000_000  # ≈ 65k tok/s * 12 h * 3600 s/h, rounded down
CHARS_PER_TOKEN_HEURISTIC = 4.0
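# NOTE: ~4 chars/token is a common rule of thumb for English text under
# BPE-style tokenizers; the size-based fallback below is therefore only an
# order-of-magnitude estimate.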


def human_bytes(n: int | float) -> str:
    for unit in ("B", "KB", "MB", "GB", "TB"):
        if n < 1024:
            return f"{n:.1f}{unit}"
        n /= 1024
    return f"{n:.1f}PB"


def human_tokens(n: int | float) -> str:
    if n >= 1e9:
        return f"{n / 1e9:.2f}B"
    if n >= 1e6:
        return f"{n / 1e6:.1f}M"
    if n >= 1e3:
        return f"{n / 1e3:.1f}K"
    return f"{n:.0f}"


def list_shards() -> tuple[list[Path], Path | None]:
    """Return (train_shards_sorted, val_shard_or_none)."""
    if not os.path.isdir(DATA_DIR):
        return [], None
    all_paths = sorted(Path(DATA_DIR).glob("shard_*.parquet"))
    val_path = Path(DATA_DIR) / VAL_FILENAME
    train = [p for p in all_paths if p.name != VAL_FILENAME]
    val = val_path if val_path.exists() else None
    return train, val


def tokenized_sample(shard_path: Path, enc, row_groups: int = 5) -> tuple[int, int, int]:
    """Tokenize the first N row groups of a shard.

    Returns (tokens, docs, total_row_groups_in_shard).
    """
    pf = pq.ParquetFile(shard_path)
    tokens = 0
    docs = 0
    n = min(row_groups, pf.num_row_groups)
    for i in range(n):
        rg = pf.read_row_group(i)
        texts = rg.column("text").to_pylist()
        ids = enc.encode_ordinary_batch(texts, num_threads=8)
        tokens += sum(len(x) for x in ids)
        docs += len(texts)
    return tokens, docs, pf.num_row_groups


def main() -> int:
    parser = argparse.ArgumentParser(description="Audit the HYDRA training corpus")
    parser.add_argument(
        "--sample",
        type=int,
        default=3,
        help="Number of shards to tokenize for token-count estimate",
    )
    parser.add_argument(
        "--full",
        action="store_true",
        help="Tokenize every shard (slow; gives exact total)",
    )
    args = parser.parse_args()
| print("=" * 72) | |
| print("HYDRA corpus audit") | |
| print("=" * 72) | |
| print(f"DATA_DIR: {DATA_DIR}") | |
| print(f"TOKENIZER_DIR: {TOKENIZER_DIR}") | |
| print(f"Source dataset: karpathy/climbmix-400b-shuffle") | |
| print(f"Max remote shard: {MAX_SHARD} (pinned val = shard_{VAL_SHARD:05d})") | |
| print() | |
    train_shards, val_shard = list_shards()
    if not train_shards:
        print("ERROR: no parquet shards found. Run `python prepare.py` first.")
        return 1
    total_disk = sum(p.stat().st_size for p in train_shards)
    val_disk = val_shard.stat().st_size if val_shard else 0
    print(f"Train shards: {len(train_shards)} ({train_shards[0].name} ... {train_shards[-1].name})")
    print(f"Val shard: {'present' if val_shard else 'MISSING'} ({VAL_FILENAME})")
    print(f"Disk (train): {human_bytes(total_disk)}")
    print(f"Disk (val): {human_bytes(val_disk)}")
    print()
    # Metadata pass (fast): Parquet footers give row and row-group counts
    # without reading any data pages.
    t0 = time.time()
    total_docs = 0
    total_row_groups = 0
    for p in train_shards:
        pf = pq.ParquetFile(p)
        total_row_groups += pf.num_row_groups
        total_docs += pf.metadata.num_rows
    dt_meta = time.time() - t0
    print(f"Metadata scan: {len(train_shards)} shards in {dt_meta:.1f}s")
    print(f"Train documents: {total_docs:,}")
    print(f"Row groups: {total_row_groups:,}")
    print()
    # Tokenizer-based sampling.
    try:
        import pickle

        with open(os.path.join(TOKENIZER_DIR, "tokenizer.pkl"), "rb") as f:
            enc = pickle.load(f)
        print(f"Tokenizer vocab: {enc.n_vocab}")
    except FileNotFoundError:
        print("WARNING: tokenizer.pkl not found — skipping tokenized sample.")
        enc = None

    est_total_tokens = 0
    if enc is not None:
        if args.full:
            sample_shards = train_shards
        else:
            # Pick shards evenly across the range for a representative sample.
            n_sample = min(args.sample, len(train_shards))
            if n_sample == 1:
                sample_shards = [train_shards[0]]
            else:
                stride = max(1, len(train_shards) // n_sample)
                sample_shards = train_shards[::stride][:n_sample]
        t0 = time.time()
        sample_tokens = 0
        sample_docs = 0
        sample_row_groups = 0
        sample_shard_row_groups = 0
        print(f"Tokenizing sample: {len(sample_shards)} shards ...")
        for p in sample_shards:
            tok, docs, n_rg = tokenized_sample(p, enc, row_groups=5)
            sample_tokens += tok
            sample_docs += docs
            sample_row_groups += min(5, n_rg)
            sample_shard_row_groups += n_rg
        dt_tok = time.time() - t0
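        # Extrapolate: (avg tokens per sampled row group) x (avg row groups
        # per shard) x (shard count). This assumes row-group sizes are roughly
        # uniform across shards, which holds if prepare.py writes fixed-size
        # row groups; skewed shards would bias the estimate.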
        tokens_per_rg = sample_tokens / max(sample_row_groups, 1)
        per_shard = tokens_per_rg * (sample_shard_row_groups / len(sample_shards))
        est_total_tokens = per_shard * len(train_shards)
        print(
            f"Sampled {sample_row_groups} row groups ({sample_docs:,} docs, "
            f"{sample_tokens:,} tokens) in {dt_tok:.1f}s"
        )
        print(f"  tokens/row_group: {tokens_per_rg:,.0f}")
        print(f"  tokens/shard:     {per_shard:,.0f} ({human_tokens(per_shard)})")
    else:
        # Fall back to a size-based character heuristic.
        # Parquet compression ratio ~3x for text: decompressed chars ≈ 3 * file size.
        # Chars per token heuristic ≈ 4.
        est_total_tokens = (total_disk * 3.0) / CHARS_PER_TOKEN_HEURISTIC
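        # (The ~3x compression ratio is itself an assumption; it varies with
        # corpus and codec, so treat this fallback as a rough estimate.)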

    print()
    print("-" * 72)
    print("Token budget analysis")
    print("-" * 72)
    print(f"Estimated total train tokens: {human_tokens(est_total_tokens)} "
          f"({est_total_tokens:,.0f})")
    print(f"12h @ 65k tok/s target: {human_tokens(TARGET_TOKENS_12H)}")
    ratio = est_total_tokens / TARGET_TOKENS_12H if TARGET_TOKENS_12H else 0
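    # Thresholds are judgment calls: >=1.2x leaves headroom for packing
    # overhead and any skipped or filtered documents; 1.0-1.2x will work
    # but leaves little margin.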
    if ratio >= 1.0:
        print(f"  Ratio: {ratio:.1f}x ({'SUFFICIENT' if ratio >= 1.2 else 'TIGHT'})")
    else:
        print(f"  Ratio: {ratio:.2f}x INSUFFICIENT ({1 - ratio:.0%} short of target)")
    print()

    # Warnings about the dataloader behavior.
    print("-" * 72)
    print("Dataloader behavior (prepare.py::_document_batches)")
    print("-" * 72)
    print("+ Infinite streaming: while True around shard list (no StopIteration)")
    print("+ Streams per shard, never loads full corpus into RAM")
    print("+ BOS-aligned best-fit packing gives document-level buffer shuffling")
    print("- Cross-shard order is LEXICOGRAPHIC and FIXED on every epoch")
    print("- Row groups / rows WITHIN a shard are read in fixed order")
    print("  (climbmix-400b-shuffle is pre-shuffled at source, mitigating this)")
    print()
    # Quick content diversity peek.
    if train_shards:
        print("-" * 72)
        print("Content sample (shard 0: first/middle/last doc of row group 0)")
        print("-" * 72)
        pf = pq.ParquetFile(train_shards[0])
        rg = pf.read_row_group(0)
        texts = rg.column("text").to_pylist()
        for i, idx in enumerate([0, len(texts) // 2, len(texts) - 1]):
            if 0 <= idx < len(texts):
                snippet = texts[idx][:160].replace("\n", " ")
                print(f"  [{i}] len={len(texts[idx])}: {snippet!r}")
    print()
| print("=" * 72) | |
| print("Done.") | |
| return 0 | |
| if __name__ == "__main__": | |
| raise SystemExit(main()) | |