Spaces:
Runtime error
Runtime error
| """Pre-download parquet shards using direct HTTP with concurrent ranged requests. | |
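| Bypasses hf_hub_download overhead — just resolves the CDN URL and streams | |
| with concurrent range chunks. Achieves 10+ MB/s (full BW). | |
| Files are placed directly in HF cache structure so streaming=True picks them up. | |
| NOTE(review): download_one() in this file currently calls hf_hub_download rather than issuing ranged HTTP requests itself; confirm whether the description above is still accurate. | |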
| Usage: python scripts/predownload_shards.py [--shards N] | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import os | |
| import sys | |
| import time | |
| import urllib.request | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from pathlib import Path | |
# Switch both standard streams to line buffering so that every printed line
# is flushed as soon as it is written.
sys.stderr.reconfigure(line_buffering=True)
sys.stdout.reconfigure(line_buffering=True)
# Prepend the parent of this script's directory to the import search path.
# NOTE(review): presumably that is where prepare_nemotron lives, since it is
# imported immediately afterwards -- confirm against the repository layout.
sys.path.insert(
    0,
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
)
| from prepare_nemotron import _BLEND_REGISTRY | |
| from huggingface_hub import HfApi, hf_hub_url, hf_hub_download | |
def list_parquet(repo: str, config: str | None, name: str, shards: int, token: str | None) -> list[str]:
    """Return the first ``shards`` parquet file paths of a Hub dataset repo.

    The repo's files are listed, reduced to names ending in ``.parquet`` and
    sorted. When a config applies, only paths that contain ``/<config>/`` or
    start with ``<config>/`` are kept; if that filter matches nothing, the
    unfiltered parquet list is used instead.

    Args:
        repo: Dataset repository id passed to ``HfApi.list_repo_files``.
        config: Directory name to filter on, or ``None`` for no filtering.
        name: Blend entry name. ``"nemotron-specialized"`` always filters on
            ``"Nemotron-Pretraining-Code-Concepts"``, whatever ``config`` is.
        shards: Maximum number of files to return. Zero or a negative value
            yields an empty list without contacting the Hub.
        token: Hub access token, or ``None``.

    Returns:
        At most ``shards`` repo-relative parquet paths, in sorted order.

    Raises:
        Whatever ``HfApi.list_repo_files`` raises; nothing is caught here.
    """
    # A negative count would make the final slice drop files from the END of
    # the list instead of limiting how many are returned, so treat every
    # non-positive count as "nothing requested".
    if shards <= 0:
        return []

    api = HfApi(token=token)
    repo_files = api.list_repo_files(repo, repo_type="dataset")
    parquet = sorted(path for path in repo_files if path.endswith(".parquet"))

    if name == "nemotron-specialized":
        effective_cfg = "Nemotron-Pretraining-Code-Concepts"
    else:
        effective_cfg = config

    if effective_cfg is not None:
        nested = f"/{effective_cfg}/"
        leading = f"{effective_cfg}/"
        matching = [p for p in parquet if nested in p or p.startswith(leading)]
        # Fall back to every parquet file when the config matches nothing.
        if matching:
            parquet = matching

    return parquet[:shards]
def download_one(repo: str, filename: str, token: str | None) -> tuple[str, int, float]:
    """Fetch one dataset file with ``hf_hub_download`` and time the call.

    Args:
        repo: Dataset repository id.
        filename: Repo-relative path of the file to fetch.
        token: Hub access token, or ``None``.

    Returns:
        ``(filename, size_in_bytes, elapsed_seconds)``. The size is read from
        the local path that ``hf_hub_download`` returns.

    Raises:
        Whatever ``hf_hub_download`` or ``os.path.getsize`` raises; nothing
        is caught here.
    """
    # Use the monotonic clock for the duration: wall-clock time can jump
    # (e.g. NTP adjustments) and would distort the measurement.
    started = time.monotonic()
    local_path = hf_hub_download(
        repo_id=repo,
        filename=filename,
        repo_type="dataset",
        token=token,
    )
    size_bytes = os.path.getsize(local_path)
    return (filename, size_bytes, time.monotonic() - started)
def download_dataset(name: str, repo: str, config: str | None, shards: int, token: str | None, workers: int = 2) -> tuple[int, float]:
    """Download up to ``shards`` parquet files of one dataset concurrently.

    Progress, per-shard results and a per-dataset summary are printed to
    stdout. A shard that fails is reported and skipped; it does not abort
    the remaining downloads.

    Args:
        name: Label used in log lines; also forwarded to ``list_parquet``.
        repo: Dataset repository id.
        config: Config/directory filter forwarded to ``list_parquet``.
        shards: Maximum number of shards to download.
        token: Hub access token, or ``None``.
        workers: Number of shards downloaded in parallel. Values below 1 are
            raised to 1, because ``ThreadPoolExecutor`` rejects a
            non-positive ``max_workers``.

    Returns:
        ``(total_bytes, elapsed_seconds)`` for the shards that succeeded, or
        ``(0, 0.0)`` when listing fails or no parquet file is selected.
    """
    started = time.monotonic()
    try:
        files = list_parquet(repo, config, name, shards, token)
    except Exception as e:
        # Best-effort: one dataset that cannot be listed must not stop the run.
        print(f"[{name}] list failed: {type(e).__name__}: {e}", flush=True)
        return (0, 0.0)
    if not files:
        print(f"[{name}] no parquet matched — skipped (config={config})", flush=True)
        return (0, 0.0)

    pool_size = max(1, workers)
    print(f"[{name}] {len(files)} shards ({pool_size} concurrent)", flush=True)

    total_bytes = 0
    with ThreadPoolExecutor(max_workers=pool_size) as pool:
        # Remember which file each future belongs to so that a failure can
        # be attributed to its shard (as_completed yields in finish order).
        pending = {pool.submit(download_one, repo, f, token): f for f in files}
        for fut in as_completed(pending):
            shard = pending[fut]
            try:
                fname, size, took = fut.result()
            except Exception as e:
                print(f" FAIL {shard}: {type(e).__name__}: {str(e)[:100]}", flush=True)
                continue
            mbps = size / 1024**2 / max(took, 0.001)
            print(f" OK {fname}: {size / 1024**2:.0f} MB in {took:.0f}s ({mbps:.1f} MB/s)", flush=True)
            total_bytes += size

    elapsed = time.monotonic() - started
    print(f"[{name}] {total_bytes / 1024**3:.2f} GB in {elapsed:.0f}s ({total_bytes / 1024**2 / max(elapsed, 0.001):.1f} MB/s)", flush=True)
    return (total_bytes, elapsed)
def main() -> None:
    """Pre-download shards for every entry in ``_BLEND_REGISTRY``.

    Command-line options:
        --shards N            maximum shards per dataset (default 2)
        --concurrent-files N  shards downloaded in parallel per dataset
                              (default 2)

    The Hub token is read from the ``HF_TOKEN`` environment variable and may
    be absent. Datasets are processed one after another; a grand total is
    printed at the end.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--shards", type=int, default=2)
    parser.add_argument("--concurrent-files", type=int, default=2, help="shards in parallel per dataset")
    args = parser.parse_args()

    token = os.environ.get("HF_TOKEN")
    # Each registry value unpacks as (repo, config, column); the column is
    # not needed for downloading.
    datasets = list(_BLEND_REGISTRY.items())
    print(f"[predownload] {len(datasets)} datasets × {args.shards} shards, {args.concurrent_files} concurrent per dataset", flush=True)

    # Monotonic clock: the overall duration must not be affected by
    # wall-clock adjustments during a long run.
    started = time.monotonic()
    grand_total = 0
    for name, (repo, cfg, _col) in datasets:
        total, _ = download_dataset(name, repo, cfg, args.shards, token, workers=args.concurrent_files)
        grand_total += total
    elapsed = time.monotonic() - started
    print(f"\n[predownload] DONE — {grand_total / 1024**3:.2f} GB in {elapsed:.0f}s ({grand_total / 1024**2 / max(elapsed, 0.001):.1f} MB/s overall)", flush=True)


# Run the pre-download only when executed as a script, not on import.
if __name__ == "__main__":
    main()