| """Fast parallel token cache builder. | |
| Reads parquet shards DIRECTLY via pyarrow (no HF streaming overhead), | |
| tokenizes with multiprocessing.Pool, writes packed (T+1) int32 rows. | |
| Uses the pre-downloaded shards in ~/.cache/huggingface/hub/ — no network. | |
| Usage: python scripts/build_token_cache.py [--gb 2] [--workers 8] | |
| """ | |
from __future__ import annotations

import argparse
import os
import sys
import time
from multiprocessing import Pool

sys.stdout.reconfigure(line_buffering=True)

import numpy as np
import pyarrow.parquet as pq

# Make the repo root importable so `prepare` resolves regardless of cwd.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from prepare import Tokenizer
HF_HUB_CACHE = os.path.expanduser("~/.cache/huggingface/hub")

# Which column each dataset uses for text
TEXT_COLS: dict[str, list[str]] = {
    "fineweb-edu": ["text"],
    "fineweb": ["text"],
    "stack-v2": ["text", "content"],
    "nemotron-math": ["text"],
    "nemotron-specialized": ["text"],
    "wikipedia": ["text"],
    "cosmopedia": ["text"],
}
# Dataset repo → cache dir mapping
REPO_DIRS = {
    "fineweb-edu": "datasets--HuggingFaceFW--fineweb-edu",
    "fineweb": "datasets--HuggingFaceFW--fineweb",
    "stack-v2": "datasets--OpenCoder-LLM--opc-fineweb-code-corpus",
    "nemotron-math": "datasets--nvidia--Nemotron-CC-Math-v1",
    "nemotron-specialized": "datasets--nvidia--Nemotron-Pretraining-Specialized-v1.1",
    "wikipedia": "datasets--wikimedia--wikipedia",
    "cosmopedia": "datasets--HuggingFaceTB--cosmopedia",
}
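# Layout note (standard HF hub cache; an assumption about the local setup):
# each repo dir holds snapshots/<revision>/... with the parquet shards,
# typically as symlinks into blobs/. find_parquet_files() walks those trees.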
def find_parquet_files() -> list[tuple[str, str]]:
    """Return [(dataset_name, parquet_path), ...] for all cached shards."""
    results = []
    for name, dirname in REPO_DIRS.items():
        base = os.path.join(HF_HUB_CACHE, dirname, "snapshots")
        if not os.path.isdir(base):
            continue
        for snap in os.listdir(base):
            snap_dir = os.path.join(base, snap)
            for root, _, files in os.walk(snap_dir):
                for f in files:
                    if f.endswith(".parquet"):
                        results.append((name, os.path.join(root, f)))
    return results
# Tokenizer loaded once per worker process
_WORKER_TOKENIZER = None
_WORKER_BOS = None


def _worker_init():
    global _WORKER_TOKENIZER, _WORKER_BOS
    _WORKER_TOKENIZER = Tokenizer.from_directory()
    _WORKER_BOS = _WORKER_TOKENIZER.get_bos_token_id()
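# Note: Pool's initializer runs once per worker process, so each worker loads
# its own Tokenizer instead of receiving a pickled copy with every task.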
def _tokenize_batch(args: tuple[list[str], int]) -> list[list[int]]:
    """Tokenize a batch of text strings. Returns list of token-id lists."""
    texts, _ = args  # second tuple element is an unused placeholder
    return _WORKER_TOKENIZER.encode(texts, prepend=_WORKER_BOS)
def iter_text_from_parquet(name: str, path: str, batch_size: int = 512):
    """Stream text batches from one parquet file."""
    cols = TEXT_COLS.get(name, ["text"])
    try:
        pf = pq.ParquetFile(path)
    except Exception as e:
        print(f" [skip] {path}: {e}", flush=True)
        return
    # Find which column exists
    schema_names = set(pf.schema_arrow.names)
    col = next((c for c in cols if c in schema_names), None)
    if col is None:
        return
    for batch in pf.iter_batches(batch_size=batch_size, columns=[col]):
        texts = batch.column(col).to_pylist()
        texts = [t for t in texts if t]
        if texts:
            yield texts
def pack_rows(token_lists: list[list[int]], row_capacity: int) -> np.ndarray:
    """Pack variable-length token sequences into (N, row_capacity) rows using simple greedy concat."""
    rows = []
    current = []
    for doc in token_lists:
        if len(current) + len(doc) > row_capacity:
            # Flush current row (pad with 0)
            if len(current) >= row_capacity // 2:  # skip too-short trailing bits
                row = current[:row_capacity]
                if len(row) < row_capacity:
                    row = row + [0] * (row_capacity - len(row))
                rows.append(row)
            # Start new row with this doc (truncate if too long)
            current = doc[:row_capacity]
        else:
            current.extend(doc)
        # Emit full rows as we fill up
        while len(current) >= row_capacity:
            rows.append(current[:row_capacity])
            current = current[row_capacity:]
    # Whatever remains in `current` after the last doc is dropped (a small,
    # bounded loss per batch).
    if not rows:
        return np.empty((0, row_capacity), dtype=np.int32)
    return np.asarray(rows, dtype=np.int32)
def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("--gb", type=float, default=2.0)
    ap.add_argument("--seq-len", type=int, default=512)
    ap.add_argument("--workers", type=int, default=max(1, (os.cpu_count() or 4) - 2))
    ap.add_argument("--batch-size", type=int, default=512, help="docs per tokenizer call")
    args = ap.parse_args()

    T = args.seq_len
    row_capacity = T + 1
    target_bytes = int(args.gb * 1024**3)
    target_rows = target_bytes // (row_capacity * 4)
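    # e.g. --gb 2 at the default T=512: 2 * 1024**3 // (513 * 4) ≈ 1.05M rows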
    # Load tokenizer in main process for vocab size
    tok = Tokenizer.from_directory()
    V = tok.get_vocab_size()

    cache_path = os.path.expanduser(
        f"~/.cache/autoresearch/packed_tokens_v1_T{T}_V{V}_train.bin"
    )
    # Make sure the cache directory exists before writing the tmp file.
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    tmp_path = cache_path + ".tmp"
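    # Reader side (a sketch only; the actual training loader is not part of
    # this script and may differ):
    #   data = np.memmap(cache_path, dtype=np.int32, mode="r").reshape(-1, T + 1)
    #   x, y = data[i, :-1], data[i, 1:]  # input / next-token target pair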
| print(f"[cache-build] target: {args.gb:.1f} GB = {target_rows} rows of (T+1)={row_capacity} int32", flush=True) | |
| print(f"[cache-build] workers: {args.workers}", flush=True) | |
| parquet_files = find_parquet_files() | |
| print(f"[cache-build] found {len(parquet_files)} parquet shards", flush=True) | |
| for name, path in parquet_files: | |
| sz = os.path.getsize(path) / 1024**2 | |
| print(f" [{name}] {path.split('/blobs/')[-1]} ({sz:.0f} MB)", flush=True) | |
    if not parquet_files:
        print("[cache-build] no shards found — run predownload first", flush=True)
        sys.exit(1)
    t_start = time.time()
    rows_written = 0

    # Worker pool; each process loads the tokenizer once via _worker_init
    pool = Pool(processes=args.workers, initializer=_worker_init)
    pending_batches = []  # batches of texts waiting to be tokenized
    PENDING_LIMIT = args.workers * 4
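    # Backpressure: cap in-flight batches at workers * 4 so raw text does not
    # pile up in memory faster than the pool can tokenize it.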
    def flush_to_tokenize():
        """Submit pending batches to the pool, write packed rows as results arrive."""
        # `fout` is the output file opened in the with-block below (a closure
        # over main()'s local).
        nonlocal rows_written
        if not pending_batches:
            return
        batch_args = [(b, 0) for b in pending_batches]
        # Use imap_unordered for streaming results
        for token_lists in pool.imap_unordered(_tokenize_batch, batch_args, chunksize=1):
            rows = pack_rows(token_lists, row_capacity)
            if len(rows) > 0:
                fout.write(rows.tobytes())
                rows_written += len(rows)
                if rows_written >= target_rows:
                    return
                if rows_written % 8192 < len(rows):
                    elapsed = time.time() - t_start
                    bw = rows_written * row_capacity * 4 / 1024**3
                    mbps = bw * 1024 / max(elapsed, 0.001)
                    pct = 100 * rows_written / target_rows
                    print(f" {rows_written:>8} rows {bw:.2f} GB {pct:5.1f}% {mbps:.1f} MB/s t={elapsed:.0f}s", flush=True)
        pending_batches.clear()
    with open(tmp_path, "wb") as fout:
        try:
            done = False
            # Round-robin across datasets to get a diverse blend
            iterators = []
            for name, path in parquet_files:
                iterators.append((name, iter_text_from_parquet(name, path, args.batch_size)))
            while iterators and not done:
                for i in range(len(iterators) - 1, -1, -1):
                    name, it = iterators[i]
                    try:
                        texts = next(it)
                    except StopIteration:
                        iterators.pop(i)
                        continue
                    pending_batches.append(texts)
                    if len(pending_batches) >= PENDING_LIMIT:
                        flush_to_tokenize()
                        if rows_written >= target_rows:
                            done = True
                            break
            # Final flush
            if not done and pending_batches:
                flush_to_tokenize()
        finally:
            # terminate() rather than close(): imap results may have been
            # abandoned mid-iteration once the target row count was reached.
            pool.terminate()
            pool.join()
    os.replace(tmp_path, cache_path)
    elapsed = time.time() - t_start
    total_bytes = rows_written * row_capacity * 4
    print(f"\n[cache-build] DONE — {rows_written} rows, {total_bytes/1024**3:.2f} GB in {elapsed:.0f}s ({total_bytes/1024**2/elapsed:.1f} MB/s)", flush=True)
    print(f"[cache-build] cache: {cache_path}", flush=True)


if __name__ == "__main__":
    main()