File size: 8,373 Bytes
22741d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
"""
Dataset audit — diagnostic tool for HYDRA's pretraining corpus.

Usage:
    python scripts/dataset_audit.py              # Quick audit
    python scripts/dataset_audit.py --sample 10  # Sample 10 shards for token counts
    python scripts/dataset_audit.py --full       # Full tokenize of every shard (slow)

Reports:
- Shard count, total disk usage
- Estimated total tokens (tokenized shard sample; falls back to a disk-size
  heuristic when no tokenizer is available)
- Training budget sufficiency vs 12h @ 65k tok/s ≈ 2.8B token target
- Document diversity sample
- Warnings about shard ordering, shuffle, and streaming behavior
"""
from __future__ import annotations

import argparse
import os
import sys
import time
from pathlib import Path

import pyarrow.parquet as pq

# Resolve the repo root (the parent of this script's directory) and put it
# first on sys.path so the `prepare` module below can be imported no matter
# which directory the script is launched from.
REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))

from prepare import (  # noqa: E402
    DATA_DIR,
    MAX_SHARD,
    TOKENIZER_DIR,
    VAL_FILENAME,
    VAL_SHARD,
)

# Token budget for a 12-hour run at 65k tokens/s, rounded down:
# 65_000 * 12 * 3600 = 2_808_000_000 ≈ 2.8B.
TARGET_TOKENS_12H = 2_800_000_000
# Rough characters-per-token ratio; only used by the disk-size fallback
# estimate when no tokenizer is available.
CHARS_PER_TOKEN_HEURISTIC = 4.0


def human_bytes(n: int | float) -> str:
    """Format a byte count with binary (1024-based) units, e.g. ``"1.5KB"``.

    Args:
        n: Number of bytes.

    Returns:
        The value with one decimal and a B/KB/MB/GB/TB suffix; anything of
        1024**5 bytes or more is reported in PB.
    """
    # Work on a float local instead of rebinding the int-annotated parameter
    # (the original `n /= 1024` silently changed the parameter's type).
    value = float(n)
    for unit in ("B", "KB", "MB", "GB", "TB"):
        if value < 1024:
            return f"{value:.1f}{unit}"
        value /= 1024
    return f"{value:.1f}PB"


def human_tokens(n: int | float) -> str:
    """Render a token count compactly: ``2.80B``, ``1.5M``, ``65.0K`` or ``42``.

    Billions keep two decimals, millions and thousands keep one, and anything
    below a thousand is rounded to a whole number.
    """
    scales = (
        (1e9, "B", ".2f"),
        (1e6, "M", ".1f"),
        (1e3, "K", ".1f"),
    )
    for threshold, suffix, spec in scales:
        if n >= threshold:
            return f"{n / threshold:{spec}}{suffix}"
    return f"{n:.0f}"


def list_shards() -> tuple[list[Path], Path | None]:
    """Locate the parquet shards in DATA_DIR.

    Returns:
        ``(train, val)`` where ``train`` is every ``shard_*.parquet`` file
        except VAL_FILENAME, sorted by path, and ``val`` is the path of
        VAL_FILENAME if that file exists, otherwise None. A missing DATA_DIR
        yields ``([], None)``.
    """
    data_dir = Path(DATA_DIR)
    if not data_dir.is_dir():
        return [], None

    train = sorted(
        shard
        for shard in data_dir.glob("shard_*.parquet")
        if shard.name != VAL_FILENAME
    )
    candidate_val = data_dir / VAL_FILENAME
    return train, (candidate_val if candidate_val.exists() else None)


def tokenized_sample(
    shard_path: Path,
    enc,
    row_groups: int = 5,
    *,
    num_threads: int = 8,
) -> tuple[int, int, int]:
    """Tokenize the leading row groups of one parquet shard.

    Args:
        shard_path: Parquet file containing a ``text`` column.
        enc: Tokenizer exposing ``encode_ordinary_batch(texts, num_threads=...)``
            (assumed tiktoken-compatible; loaded by the caller from
            ``tokenizer.pkl``).
        row_groups: Maximum number of leading row groups to read. Values at or
            above the shard's row-group count read the whole shard.
        num_threads: Worker threads passed to ``encode_ordinary_batch``.

    Returns:
        ``(tokens, docs, total_row_groups)``: the token and document counts
        over the row groups actually read, plus the shard's total number of
        row groups (so callers can extrapolate).
    """
    pf = pq.ParquetFile(shard_path)
    total_row_groups = pf.num_row_groups
    tokens = 0
    docs = 0
    for i in range(min(row_groups, total_row_groups)):
        # Only decode the column we tokenize; other columns are never used.
        texts = pf.read_row_group(i, columns=["text"]).column("text").to_pylist()
        ids = enc.encode_ordinary_batch(texts, num_threads=num_threads)
        tokens += sum(map(len, ids))
        docs += len(texts)
    return tokens, docs, total_row_groups


# Leading row groups tokenized per sampled shard when not running --full.
_SAMPLE_ROW_GROUPS = 5


def _parse_args() -> argparse.Namespace:
    """Parse and validate the command-line flags."""
    parser = argparse.ArgumentParser(description="Audit the HYDRA training corpus")
    parser.add_argument(
        "--sample",
        type=int,
        default=3,
        help="Number of shards to tokenize for token-count estimate",
    )
    parser.add_argument(
        "--full",
        action="store_true",
        help="Tokenize every shard (slow; gives exact total)",
    )
    args = parser.parse_args()
    # A non-positive sample size used to crash with ZeroDivisionError (0) or
    # silently drop the last shard (negative); reject it with a clear message.
    if not args.full and args.sample < 1:
        parser.error("--sample must be >= 1")
    return args


def _pick_sample_shards(shards: list[Path], n_sample: int) -> list[Path]:
    """Pick up to ``n_sample`` shards spaced evenly from the first to the last.

    Unlike a ``shards[::stride][:n]`` slice, the last shard is always included,
    so the sample spans the whole lexicographic range.
    """
    if not shards:
        return []
    n_sample = max(1, min(n_sample, len(shards)))
    if n_sample == 1:
        return [shards[0]]
    step = (len(shards) - 1) / (n_sample - 1)
    # step >= 1 because n_sample <= len(shards), so rounded indices are distinct.
    return [shards[round(i * step)] for i in range(n_sample)]


def _load_tokenizer():
    """Load ``tokenizer.pkl`` from TOKENIZER_DIR, or return None if it is missing."""
    import pickle

    path = os.path.join(TOKENIZER_DIR, "tokenizer.pkl")
    try:
        with open(path, "rb") as f:
            # NOTE: pickle.load can execute arbitrary code. This file is
            # presumably written locally by prepare.py; never point
            # TOKENIZER_DIR at untrusted data.
            enc = pickle.load(f)
    except FileNotFoundError:
        print("WARNING: tokenizer.pkl not found — skipping tokenized sample.")
        return None
    print(f"Tokenizer vocab: {enc.n_vocab}")
    return enc


def _estimate_tokens_tokenized(
    enc,
    train_shards: list[Path],
    total_docs: int,
    *,
    full: bool,
    n_sample: int,
) -> int | float:
    """Tokenize shards and return the estimated total number of train tokens.

    With ``full``, every row group of every shard is tokenized and the exact
    count is returned. Otherwise the first ``_SAMPLE_ROW_GROUPS`` row groups of
    an evenly spread shard sample are tokenized. The measured tokens-per-document
    rate is then scaled by ``total_docs`` (the exact count from parquet metadata),
    which is robust to row groups of unequal size.
    """
    if full:
        sample_shards = train_shards
        # tokenized_sample reads min(row_groups, num_row_groups) groups, so
        # sys.maxsize means "every row group".
        rg_limit = sys.maxsize
        print(f"Tokenizing all {len(sample_shards)} shards (every row group) ...")
    else:
        sample_shards = _pick_sample_shards(train_shards, n_sample)
        rg_limit = _SAMPLE_ROW_GROUPS
        print(f"Tokenizing sample: {len(sample_shards)} shards ...")

    t0 = time.time()
    sample_tokens = 0
    sample_docs = 0
    sample_row_groups = 0
    for p in sample_shards:
        tok, docs, n_rg = tokenized_sample(p, enc, row_groups=rg_limit)
        sample_tokens += tok
        sample_docs += docs
        sample_row_groups += min(rg_limit, n_rg)
    dt_tok = time.time() - t0

    tokens_per_doc = sample_tokens / max(sample_docs, 1)
    tokens_per_rg = sample_tokens / max(sample_row_groups, 1)
    est_total_tokens = sample_tokens if full else tokens_per_doc * total_docs
    per_shard = est_total_tokens / max(len(train_shards), 1)

    print(
        f"Sampled {sample_row_groups} row groups ({sample_docs:,} docs, "
        f"{sample_tokens:,} tokens) in {dt_tok:.1f}s"
    )
    print(f"  tokens/doc:       {tokens_per_doc:,.1f}")
    print(f"  tokens/row_group: {tokens_per_rg:,.0f}")
    print(f"  tokens/shard:     {per_shard:,.0f} ({human_tokens(per_shard)})")
    return est_total_tokens


def _estimate_tokens_heuristic(total_disk: int) -> float:
    """Estimate train tokens from on-disk size when no tokenizer is available."""
    # Assumes parquet compresses text ~3x (decompressed ≈ 3 × file size) and
    # roughly CHARS_PER_TOKEN_HEURISTIC characters per token.
    print(
        "Using disk-size heuristic (~3x parquet compression, "
        f"{CHARS_PER_TOKEN_HEURISTIC:g} chars/token)."
    )
    return (total_disk * 3.0) / CHARS_PER_TOKEN_HEURISTIC


def _print_budget(est_total_tokens: int | float) -> None:
    """Compare the token estimate against the 12h training target."""
    print()
    print("-" * 72)
    print("Token budget analysis")
    print("-" * 72)
    print(f"Estimated total train tokens: {human_tokens(est_total_tokens)} "
          f"({est_total_tokens:,.0f})")
    print(f"12h @ 65k tok/s target:       {human_tokens(TARGET_TOKENS_12H)}")
    ratio = est_total_tokens / TARGET_TOKENS_12H if TARGET_TOKENS_12H else 0
    if ratio >= 1.0:
        print(f"  Ratio: {ratio:.1f}x  ({'SUFFICIENT' if ratio >= 1.2 else 'TIGHT'})")
    else:
        # Express the gap relative to the target; "need (1 - ratio) more" was
        # wrong relative to the corpus (at 0.5x you need 100% more data).
        shortfall = TARGET_TOKENS_12H - est_total_tokens
        print(
            f"  Ratio: {ratio:.2f}x  INSUFFICIENT — {1 - ratio:.0%} short of "
            f"target (~{human_tokens(shortfall)} tokens missing)"
        )
    print()


def _print_dataloader_notes() -> None:
    """Print fixed notes about prepare.py's streaming/shuffling behavior."""
    print("-" * 72)
    print("Dataloader behavior (prepare.py::_document_batches)")
    print("-" * 72)
    print("+ Infinite streaming: while True around shard list (no StopIteration)")
    print("+ Streams per shard, never loads full corpus into RAM")
    print("+ BOS-aligned best-fit packing gives document-level buffer shuffling")
    print("- Cross-shard order is LEXICOGRAPHIC and FIXED on every epoch")
    print("- Row groups / rows WITHIN a shard are read in fixed order")
    print("  (climbmix-400b-shuffle is pre-shuffled at source, mitigating this)")
    print()


def _print_content_sample(shard: Path) -> None:
    """Print snippets of the first, middle and last docs of row group 0."""
    print("-" * 72)
    print(f"Content sample ({shard.name}, row group 0: first/middle/last docs)")
    print("-" * 72)
    pf = pq.ParquetFile(shard)
    texts = (
        pf.read_row_group(0, columns=["text"]).column("text").to_pylist()
        if pf.num_row_groups
        else []
    )
    if not texts:
        # The original indexed texts[-1] here and raised IndexError.
        print("  (row group 0 is empty)")
    else:
        # A set collapses duplicate indices for 1- or 2-document row groups.
        picks = sorted({0, len(texts) // 2, len(texts) - 1})
        for i, idx in enumerate(picks):
            text = texts[idx] or ""  # tolerate null text cells
            snippet = text[:160].replace("\n", " ")
            print(f"  [{i}] len={len(text)}: {snippet!r}")
    print()


def main() -> int:
    """Run the corpus audit and print the report to stdout.

    Returns:
        0 on success, 1 if no training shards were found. The ``__main__``
        guard uses this as the process exit code.
    """
    args = _parse_args()

    print("=" * 72)
    print("HYDRA corpus audit")
    print("=" * 72)
    print(f"DATA_DIR:        {DATA_DIR}")
    print(f"TOKENIZER_DIR:   {TOKENIZER_DIR}")
    print("Source dataset:  karpathy/climbmix-400b-shuffle")
    print(f"Max remote shard: {MAX_SHARD} (pinned val = shard_{VAL_SHARD:05d})")
    print()

    train_shards, val_shard = list_shards()
    if not train_shards:
        print("ERROR: no parquet shards found. Run `python prepare.py` first.")
        return 1

    total_disk = sum(p.stat().st_size for p in train_shards)
    val_disk = val_shard.stat().st_size if val_shard else 0

    print(f"Train shards:    {len(train_shards)}  ({train_shards[0].name} ... {train_shards[-1].name})")
    print(f"Val shard:       {'present' if val_shard else 'MISSING'}  ({VAL_FILENAME})")
    print(f"Disk (train):    {human_bytes(total_disk)}")
    print(f"Disk (val):      {human_bytes(val_disk)}")
    print()

    # Metadata-only pass (fast): document and row-group counts come from the
    # parquet footer; no column data is read.
    t0 = time.time()
    total_docs = 0
    total_row_groups = 0
    for p in train_shards:
        pf = pq.ParquetFile(p)
        total_row_groups += pf.num_row_groups
        total_docs += pf.metadata.num_rows
    dt_meta = time.time() - t0
    print(f"Metadata scan:   {len(train_shards)} shards in {dt_meta:.1f}s")
    print(f"Train documents: {total_docs:,}")
    print(f"Row groups:      {total_row_groups:,}")
    print()

    enc = _load_tokenizer()
    if enc is not None:
        est_total_tokens = _estimate_tokens_tokenized(
            enc,
            train_shards,
            total_docs,
            full=args.full,
            n_sample=args.sample,
        )
    else:
        est_total_tokens = _estimate_tokens_heuristic(total_disk)

    _print_budget(est_total_tokens)
    _print_dataloader_notes()
    _print_content_sample(train_shards[0])

    print("=" * 72)
    print("Done.")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())