"""Fast parallel token cache builder. Reads parquet shards DIRECTLY via pyarrow (no HF streaming overhead), tokenizes with multiprocessing.Pool, writes packed (T+1) int32 rows. Uses the pre-downloaded shards in ~/.cache/huggingface/hub/ — no network. Usage: python scripts/build_token_cache.py [--gb 2] [--workers 8] """ from __future__ import annotations import argparse import glob import os import sys import time from pathlib import Path from multiprocessing import Pool sys.stdout.reconfigure(line_buffering=True) import numpy as np import pyarrow.parquet as pq sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from prepare import Tokenizer HF_HUB_CACHE = os.path.expanduser("~/.cache/huggingface/hub") # Which column each dataset uses for text TEXT_COLS: dict[str, list[str]] = { "fineweb-edu": ["text"], "fineweb": ["text"], "stack-v2": ["text", "content"], "nemotron-math": ["text"], "nemotron-specialized": ["text"], "wikipedia": ["text"], "cosmopedia": ["text"], } # Dataset repo → cache dir mapping REPO_DIRS = { "fineweb-edu": "datasets--HuggingFaceFW--fineweb-edu", "fineweb": "datasets--HuggingFaceFW--fineweb", "stack-v2": "datasets--OpenCoder-LLM--opc-fineweb-code-corpus", "nemotron-math": "datasets--nvidia--Nemotron-CC-Math-v1", "nemotron-specialized": "datasets--nvidia--Nemotron-Pretraining-Specialized-v1.1", "wikipedia": "datasets--wikimedia--wikipedia", "cosmopedia": "datasets--HuggingFaceTB--cosmopedia", } def find_parquet_files() -> list[tuple[str, str]]: """Return [(dataset_name, parquet_path), ...] for all cached shards.""" results = [] for name, dirname in REPO_DIRS.items(): base = os.path.join(HF_HUB_CACHE, dirname, "snapshots") if not os.path.isdir(base): continue for snap in os.listdir(base): snap_dir = os.path.join(base, snap) for root, _, files in os.walk(snap_dir): for f in files: if f.endswith(".parquet"): results.append((name, os.path.join(root, f))) return results # Tokenizer loaded once per worker process _WORKER_TOKENIZER = None _WORKER_BOS = None def _worker_init(): global _WORKER_TOKENIZER, _WORKER_BOS _WORKER_TOKENIZER = Tokenizer.from_directory() _WORKER_BOS = _WORKER_TOKENIZER.get_bos_token_id() def _tokenize_batch(args: tuple[list[str], int]) -> list[list[int]]: """Tokenize a batch of text strings. 
def _tokenize_batch(args: tuple[list[str], int]) -> list[list[int]]:
    """Tokenize a batch of text strings. Returns a list of token-id lists."""
    texts, _ = args
    return _WORKER_TOKENIZER.encode(texts, prepend=_WORKER_BOS)


def iter_text_from_parquet(name: str, path: str, batch_size: int = 512):
    """Stream text batches from one parquet file."""
    cols = TEXT_COLS.get(name, ["text"])
    try:
        pf = pq.ParquetFile(path)
    except Exception as e:
        print(f"  [skip] {path}: {e}", flush=True)
        return
    # Find which candidate column actually exists in this shard
    schema_names = set(pf.schema_arrow.names)
    col = next((c for c in cols if c in schema_names), None)
    if col is None:
        return
    for batch in pf.iter_batches(batch_size=batch_size, columns=[col]):
        texts = batch.column(col).to_pylist()
        texts = [t for t in texts if t]
        if texts:
            yield texts


def pack_rows(token_lists: list[list[int]], row_capacity: int) -> np.ndarray:
    """Pack variable-length token sequences into (N, row_capacity) rows
    using simple greedy concatenation."""
    rows = []
    current = []
    for doc in token_lists:
        if len(current) + len(doc) > row_capacity:
            # Flush the current row (zero-padded), but only if it is at
            # least half full; shorter trailing fragments are dropped.
            if len(current) >= row_capacity // 2:
                row = current[:row_capacity]
                if len(row) < row_capacity:
                    row = row + [0] * (row_capacity - len(row))
                rows.append(row)
            # Start a new row with this doc (truncated if too long)
            current = doc[:row_capacity]
        else:
            current.extend(doc)
        # Emit full rows as we fill up
        while len(current) >= row_capacity:
            rows.append(current[:row_capacity])
            current = current[row_capacity:]
    # Note: whatever remains in `current` here is discarded, i.e. at most
    # row_capacity - 1 tokens lost per batch; a deliberate simplification.
    if not rows:
        return np.empty((0, row_capacity), dtype=np.int32)
    return np.asarray(rows, dtype=np.int32)
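# Illustrative trace of pack_rows (token values made up for the example):
# with row_capacity=8, a 5-token doc cannot share a row with a 6-token doc
# (5 + 6 > 8); since 5 >= 8 // 2, the first doc is flushed as a zero-padded
# row, and the second becomes the new tail, which this batch never fills:
#
#   pack_rows([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]], row_capacity=8)
#   -> array([[1, 2, 3, 4, 5, 0, 0, 0]], dtype=int32)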
def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("--gb", type=float, default=2.0)
    ap.add_argument("--seq-len", type=int, default=512)
    ap.add_argument("--workers", type=int, default=max(1, (os.cpu_count() or 4) - 2))
    ap.add_argument("--batch-size", type=int, default=512, help="docs per tokenizer call")
    args = ap.parse_args()

    T = args.seq_len
    row_capacity = T + 1
    target_bytes = int(args.gb * 1024**3)
    target_rows = target_bytes // (row_capacity * 4)

    # Load tokenizer in the main process for the vocab size
    tok = Tokenizer.from_directory()
    V = tok.get_vocab_size()

    cache_path = os.path.expanduser(
        f"~/.cache/autoresearch/packed_tokens_v1_T{T}_V{V}_train.bin"
    )
    tmp_path = cache_path + ".tmp"
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)

    print(f"[cache-build] target: {args.gb:.1f} GB = {target_rows} rows of (T+1)={row_capacity} int32", flush=True)
    print(f"[cache-build] workers: {args.workers}", flush=True)

    parquet_files = find_parquet_files()
    print(f"[cache-build] found {len(parquet_files)} parquet shards", flush=True)
    for name, path in parquet_files:
        sz = os.path.getsize(path) / 1024**2
        print(f"  [{name}] {path.split('/blobs/')[-1]} ({sz:.0f} MB)", flush=True)
    if not parquet_files:
        print("[cache-build] no shards found — run predownload first", flush=True)
        sys.exit(1)

    t_start = time.time()
    rows_written = 0

    pool = Pool(processes=args.workers, initializer=_worker_init)
    pending_batches = []  # batches of texts waiting to be tokenized
    PENDING_LIMIT = args.workers * 4

    def flush_to_tokenize():
        """Submit pending batches to the pool; write results as they arrive."""
        nonlocal rows_written
        if not pending_batches:
            return
        batch_args = [(b, 0) for b in pending_batches]
        # imap_unordered streams results back as each worker finishes
        for token_lists in pool.imap_unordered(_tokenize_batch, batch_args, chunksize=1):
            rows = pack_rows(token_lists, row_capacity)
            if len(rows) > 0:
                fout.write(rows.tobytes())
                rows_written += len(rows)
                if rows_written >= target_rows:
                    return
                # Progress line roughly every 8192 rows
                if rows_written % 8192 < len(rows):
                    elapsed = time.time() - t_start
                    bw = rows_written * row_capacity * 4 / 1024**3
                    mbps = bw * 1024 / max(elapsed, 0.001)
                    pct = 100 * rows_written / target_rows
                    print(f"  {rows_written:>8} rows  {bw:.2f} GB  {pct:5.1f}%  {mbps:.1f} MB/s  t={elapsed:.0f}s", flush=True)
        pending_batches.clear()

    with open(tmp_path, "wb") as fout:
        try:
            done = False
            # Round-robin across datasets to get a diverse blend
            iterators = []
            for name, path in parquet_files:
                iterators.append((name, iter_text_from_parquet(name, path, args.batch_size)))
            while iterators and not done:
                for i in range(len(iterators) - 1, -1, -1):
                    name, it = iterators[i]
                    try:
                        texts = next(it)
                    except StopIteration:
                        iterators.pop(i)
                        continue
                    pending_batches.append(texts)
                    if len(pending_batches) >= PENDING_LIMIT:
                        flush_to_tokenize()
                        if rows_written >= target_rows:
                            done = True
                            break
            # Final flush
            if not done and pending_batches:
                flush_to_tokenize()
        finally:
            # terminate() rather than close()+join(): if we bailed out of
            # imap_unordered early, workers may still hold queued tasks.
            pool.terminate()
            pool.join()

    os.replace(tmp_path, cache_path)
    elapsed = time.time() - t_start
    total_bytes = rows_written * row_capacity * 4
    print(f"\n[cache-build] DONE — {rows_written} rows, {total_bytes/1024**3:.2f} GB "
          f"in {elapsed:.0f}s ({total_bytes/1024**2/elapsed:.1f} MB/s)", flush=True)
    print(f"[cache-build] cache: {cache_path}", flush=True)


if __name__ == "__main__":
    main()
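# Reading the cache back (illustrative sketch, not used by this script):
# each row holds row_capacity = T + 1 int32 tokens, so a consumer can memmap
# the file and split every row into next-token (input, target) pairs. The
# T value and path below are assumptions matching the defaults above.
#
#   T = 512
#   data = np.memmap(cache_path, dtype=np.int32, mode="r").reshape(-1, T + 1)
#   inputs, targets = data[:, :-1], data[:, 1:]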