"""
Dataset audit — diagnostic tool for HYDRA's pretraining corpus.
Usage:
python scripts/dataset_audit.py # Quick audit
python scripts/dataset_audit.py --sample 10 # Sample 10 shards for token counts
python scripts/dataset_audit.py --full # Full tokenize of every shard (slow)
Reports:
- Shard count, total disk usage
- Estimated total tokens (character-based + tokenized sample)
- Training budget sufficiency vs 12h @ 65k tok/s = 2.8B token target
- Document diversity sample
- Warnings about shard ordering, shuffle, and streaming behavior
"""
from __future__ import annotations
import argparse
import os
import sys
import time
from pathlib import Path
import pyarrow.parquet as pq
# Resolve repo root so the script works regardless of CWD.
REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))
from prepare import ( # noqa: E402
DATA_DIR,
MAX_SHARD,
TOKENIZER_DIR,
VAL_FILENAME,
VAL_SHARD,
)
# Token budget the corpus is compared against in the audit report.
# 65k tok/s * 12h * 3600 s/h = 2.808B, rounded here to 2.8B.
TARGET_TOKENS_12H = 2_800_000_000  # ~65k tok/s * 12h * 3600 s/h
# Rough characters-per-token ratio used for the size-based estimate when no
# tokenizer is available.
CHARS_PER_TOKEN_HEURISTIC = 4.0
def human_bytes(n: int) -> str:
    """Format a byte count with a binary (1024-based) unit suffix.

    The value is divided by 1024 until it drops below 1024 or the unit list
    (B .. TB) is exhausted; anything larger is reported in PB. One decimal
    place is always shown.
    """
    suffixes = ["B", "KB", "MB", "GB", "TB"]
    value = n
    position = 0
    while position < len(suffixes):
        if value < 1024:
            return f"{value:.1f}{suffixes[position]}"
        value = value / 1024
        position += 1
    # Ran past TB: whatever is left is expressed in petabytes.
    return f"{value:.1f}PB"
def human_tokens(n: int | float) -> str:
    """Format a token count with a decimal K / M / B suffix.

    Billions are shown with two decimals, millions and thousands with one,
    and anything below a thousand as a whole number without suffix.
    """
    # (threshold, decimals, suffix), checked from largest to smallest.
    scales = (
        (1e9, 2, "B"),
        (1e6, 1, "M"),
        (1e3, 1, "K"),
    )
    for threshold, decimals, suffix in scales:
        if n >= threshold:
            return f"{n / threshold:.{decimals}f}{suffix}"
    return f"{n:.0f}"
def list_shards() -> tuple[list[Path], Path | None]:
    """Find the parquet shards stored under DATA_DIR.

    Returns:
        A pair ``(train, val)``: ``train`` is the name-sorted list of
        ``shard_*.parquet`` files excluding the one named VAL_FILENAME, and
        ``val`` is the path DATA_DIR/VAL_FILENAME if that file exists, else
        None. When DATA_DIR is not a directory the result is ``([], None)``.
    """
    if not os.path.isdir(DATA_DIR):
        return [], None

    data_root = Path(DATA_DIR)
    train: list[Path] = []
    for candidate in sorted(data_root.glob("shard_*.parquet")):
        # The validation shard matches the same glob; keep it out of train.
        if candidate.name == VAL_FILENAME:
            continue
        train.append(candidate)

    val_candidate = data_root / VAL_FILENAME
    if val_candidate.exists():
        return train, val_candidate
    return train, None
def tokenized_sample(shard_path: Path, enc, row_groups: int = 5) -> tuple[int, int, int]:
    """Tokenize the leading row groups of one parquet shard.

    Args:
        shard_path: Parquet shard to read; only its ``text`` column is used.
        enc: Tokenizer object exposing
            ``encode_ordinary_batch(texts, num_threads=...)`` and returning
            one token-id sequence per input text.
        row_groups: Maximum number of row groups to tokenize, counted from
            the start of the shard. Values larger than the shard's row-group
            count are clamped, so a very large value tokenizes the whole
            shard.

    Returns:
        ``(tokens, docs, total_row_groups)``: tokens and documents seen in
        the row groups that were tokenized, plus the total number of row
        groups in the shard (needed by callers to extrapolate).
    """
    pf = pq.ParquetFile(shard_path)
    total_row_groups = pf.num_row_groups
    tokens = 0
    docs = 0
    for i in range(min(row_groups, total_row_groups)):
        # Only the text column is tokenized; skip decoding any other columns.
        rg = pf.read_row_group(i, columns=["text"])
        texts = rg.column("text").to_pylist()
        ids = enc.encode_ordinary_batch(texts, num_threads=8)
        tokens += sum(len(x) for x in ids)
        docs += len(texts)
    return tokens, docs, total_row_groups
def main() -> int:
    """Run the corpus audit and print the report to stdout.

    Command-line flags:
        --sample N  Number of shards to tokenize (first 5 row groups each)
                    for the token-count estimate. Values below 1 are treated
                    as 1.
        --full      Tokenize every row group of every shard for an exact
                    total.

    Returns:
        Process exit code: 1 if no train shards are found, otherwise 0.
    """
    parser = argparse.ArgumentParser(description="Audit the HYDRA training corpus")
    parser.add_argument(
        "--sample",
        type=int,
        default=3,
        help="Number of shards to tokenize for token-count estimate",
    )
    parser.add_argument(
        "--full",
        action="store_true",
        help="Tokenize every shard (slow; gives exact total)",
    )
    args = parser.parse_args()

    print("=" * 72)
    print("HYDRA corpus audit")
    print("=" * 72)
    print(f"DATA_DIR: {DATA_DIR}")
    print(f"TOKENIZER_DIR: {TOKENIZER_DIR}")
    print("Source dataset: karpathy/climbmix-400b-shuffle")
    print(f"Max remote shard: {MAX_SHARD} (pinned val = shard_{VAL_SHARD:05d})")
    print()

    train_shards, val_shard = list_shards()
    if not train_shards:
        print("ERROR: no parquet shards found. Run `python prepare.py` first.")
        return 1

    total_disk = sum(p.stat().st_size for p in train_shards)
    val_disk = val_shard.stat().st_size if val_shard else 0
    print(f"Train shards: {len(train_shards)} ({train_shards[0].name} ... {train_shards[-1].name})")
    print(f"Val shard: {'present' if val_shard else 'MISSING'} ({VAL_FILENAME})")
    print(f"Disk (train): {human_bytes(total_disk)}")
    print(f"Disk (val): {human_bytes(val_disk)}")
    print()

    # Metadata pass (fast): document and row-group counts are taken from the
    # parquet metadata of each shard; no text is read here.
    t0 = time.time()
    total_docs = 0
    total_row_groups = 0
    for p in train_shards:
        pf = pq.ParquetFile(p)
        total_row_groups += pf.num_row_groups
        total_docs += pf.metadata.num_rows
    dt_meta = time.time() - t0
    print(f"Metadata scan: {len(train_shards)} shards in {dt_meta:.1f}s")
    print(f"Train documents: {total_docs:,}")
    print(f"Row groups: {total_row_groups:,}")
    print()

    # Tokenizer-based sampling.
    # NOTE(security): pickle.load can execute arbitrary code from the file;
    # TOKENIZER_DIR must only ever contain a tokenizer.pkl you trust.
    try:
        import pickle

        with open(os.path.join(TOKENIZER_DIR, "tokenizer.pkl"), "rb") as f:
            enc = pickle.load(f)
        print(f"Tokenizer vocab: {enc.n_vocab}")
    except FileNotFoundError:
        print("WARNING: tokenizer.pkl not found — skipping tokenized sample.")
        enc = None

    est_total_tokens = 0
    if enc is not None:
        if args.full:
            sample_shards = train_shards
            # tokenized_sample clamps row_groups to the shard's row-group
            # count, so an effectively unbounded limit tokenizes everything.
            rg_limit = sys.maxsize
        else:
            sample_shards = _select_sample_shards(train_shards, args.sample)
            rg_limit = 5

        t0 = time.time()
        sample_tokens = 0
        sample_docs = 0
        sample_row_groups = 0  # row groups actually tokenized
        sample_shard_row_groups = 0  # row groups present in the sampled shards
        print(f"Tokenizing sample: {len(sample_shards)} shards ...")
        for p in sample_shards:
            tok, docs, n_rg = tokenized_sample(p, enc, row_groups=rg_limit)
            sample_tokens += tok
            sample_docs += docs
            sample_row_groups += min(rg_limit, n_rg)
            sample_shard_row_groups += n_rg
        dt_tok = time.time() - t0

        # Extrapolate: tokens per tokenized row group, times the average
        # row groups per sampled shard, times the number of train shards.
        tokens_per_rg = sample_tokens / max(sample_row_groups, 1)
        per_shard = tokens_per_rg * (sample_shard_row_groups / len(sample_shards))
        if args.full:
            # Every row group was tokenized, so the count is exact.
            est_total_tokens = sample_tokens
        else:
            est_total_tokens = per_shard * len(train_shards)
        print(
            f"Sampled {sample_row_groups} row groups ({sample_docs:,} docs, "
            f"{sample_tokens:,} tokens) in {dt_tok:.1f}s"
        )
        print(f" tokens/row_group: {tokens_per_rg:,.0f}")
        print(f" tokens/shard: {per_shard:,.0f}")
        print(f" tokens/shard: {human_tokens(per_shard)}")
    else:
        # No tokenizer: fall back to a size heuristic. Assumes parquet
        # compresses text ~3x (decompressed chars ~ 3 * file size) and
        # CHARS_PER_TOKEN_HEURISTIC characters per token.
        est_total_tokens = (total_disk * 3.0) / CHARS_PER_TOKEN_HEURISTIC

    print()
    print("-" * 72)
    print("Token budget analysis")
    print("-" * 72)
    print(f"Estimated total train tokens: {human_tokens(est_total_tokens)} "
          f"({est_total_tokens:,.0f})")
    print(f"12h @ 65k tok/s target: {human_tokens(TARGET_TOKENS_12H)}")
    ratio = est_total_tokens / TARGET_TOKENS_12H if TARGET_TOKENS_12H else 0
    if ratio >= 1.0:
        print(f" Ratio: {ratio:.1f}x ({'SUFFICIENT' if ratio >= 1.2 else 'TIGHT'})")
    else:
        print(f" Ratio: {ratio:.2f}x INSUFFICIENT — need {1 - ratio:.0%} more")
    print()

    # Warnings about the dataloader behavior.
    print("-" * 72)
    print("Dataloader behavior (prepare.py::_document_batches)")
    print("-" * 72)
    print("+ Infinite streaming: while True around shard list (no StopIteration)")
    print("+ Streams per shard, never loads full corpus into RAM")
    print("+ BOS-aligned best-fit packing gives document-level buffer shuffling")
    print("- Cross-shard order is LEXICOGRAPHIC and FIXED on every epoch")
    print("- Row groups / rows WITHIN a shard are read in fixed order")
    print(" (climbmix-400b-shuffle is pre-shuffled at source, mitigating this)")
    print()

    # Quick content diversity peek.
    _print_content_sample(train_shards[0])

    print("=" * 72)
    print("Done.")
    return 0


def _select_sample_shards(train_shards: list[Path], requested: int) -> list[Path]:
    """Pick up to `requested` shards spread evenly over the sorted shard list."""
    # Clamp to [1, len]: a request of 0 would divide by zero below, and a
    # negative request would silently slice off the tail of the list.
    n_sample = max(1, min(requested, len(train_shards)))
    if n_sample == 1:
        return train_shards[:1]
    stride = max(1, len(train_shards) // n_sample)
    return train_shards[::stride][:n_sample]


def _print_content_sample(shard_path: Path) -> None:
    """Print snippets of the first, middle and last doc of the first row group."""
    print("-" * 72)
    # NOTE: despite the heading, the docs shown are first / middle / last of
    # row group 0, not the first three.
    print("Content sample (shard 0, first 3 docs)")
    print("-" * 72)
    pf = pq.ParquetFile(shard_path)
    texts = pf.read_row_group(0).column("text").to_pylist() if pf.num_row_groups else []
    # dict.fromkeys drops repeated indices (tiny row groups) but keeps order.
    picks = dict.fromkeys([0, len(texts) // 2, len(texts) - 1])
    for i, idx in enumerate(picks):
        # Lower bound matters: an empty row group yields idx == -1.
        if 0 <= idx < len(texts):
            snippet = texts[idx][:160].replace("\n", " ")
            print(f" [{i}] len={len(texts[idx])}: {snippet!r}")
    print()
# Script entry point: propagate main()'s return value as the exit status.
if __name__ == "__main__":
    sys.exit(main())
|