# Source: feather-runtime / overlay / scripts / predownload_shards.py
# Commit e317e25 (verified): "Update Feather h200 training runtime image"
# Author avatar: Jackoatmon's picture
"""Pre-download parquet shards for every blend dataset into the Hugging Face cache.
Lists each dataset repo's parquet files and fetches the first N shards with
hf_hub_download, downloading several shards of one dataset concurrently and
printing per-shard and per-dataset throughput (MB/s). Target: 10+ MB/s (full BW).
Files are placed directly in HF cache structure so streaming=True picks them up.
Usage: python scripts/predownload_shards.py [--shards N] [--concurrent-files K]
"""
from __future__ import annotations
import argparse
import os
import sys
import time
import urllib.request
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
# Switch both standard streams to line buffering so progress lines show up
# immediately even when output is piped rather than attached to a terminal.
for _stream in (sys.stdout, sys.stderr):
    _stream.reconfigure(line_buffering=True)
del _stream
# Put the parent of this script's directory (presumably the repo root) first on
# sys.path so the local `prepare_nemotron` module imported below can be found.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from prepare_nemotron import _BLEND_REGISTRY
from huggingface_hub import HfApi, hf_hub_url, hf_hub_download
def list_parquet(repo: str, config: str | None, name: str, shards: int, token: str | None) -> list[str]:
    """Return up to ``shards`` parquet paths from a dataset repo, in sorted order.

    When a config is in effect, only paths inside a directory named after it are
    kept. If that filter matches nothing, every parquet path is used instead.
    The ``nemotron-specialized`` blend entry always filters on the
    ``Nemotron-Pretraining-Code-Concepts`` directory, regardless of ``config``.

    Args:
        repo: Hugging Face dataset repo id.
        config: Directory/config name to filter on, or None for no filter.
        name: Blend registry key of the dataset (selects the override above).
        shards: Bound passed straight to a slice (``[:shards]``).
        token: Optional Hugging Face access token.

    Returns:
        The selected repo-relative parquet file paths.
    """
    if name == "nemotron-specialized":
        wanted = "Nemotron-Pretraining-Code-Concepts"
    else:
        wanted = config

    repo_files = HfApi(token=token).list_repo_files(repo, repo_type="dataset")
    candidates = sorted(path for path in repo_files if path.endswith(".parquet"))

    if wanted is not None:
        nested_dir = f"/{wanted}/"
        top_dir = f"{wanted}/"
        in_config = [p for p in candidates if nested_dir in p or p.startswith(top_dir)]
        # No path lives under the config directory: fall back to all parquet files.
        candidates = in_config or candidates

    return candidates[:shards]
def download_one(repo: str, filename: str, token: str | None) -> tuple[str, int, float]:
    """Fetch a single file from a dataset repo with ``hf_hub_download``.

    hf_hub_download is used rather than raw HTTP; the original note says it was
    "proven to work with -L redirect from curl test".

    Returns:
        ``(filename, size_in_bytes, seconds_elapsed)``, where the size is that of
        the local file path returned by hf_hub_download.
    """
    started = time.time()
    local_path = hf_hub_download(repo_id=repo, filename=filename, repo_type="dataset", token=token)
    n_bytes = os.path.getsize(local_path)
    return filename, n_bytes, time.time() - started
def download_dataset(name: str, repo: str, config: str | None, shards: int, token: str | None, workers: int = 2) -> tuple[int, float]:
    """Download up to ``shards`` parquet shards of one dataset concurrently.

    A listing error is reported and treated as "nothing downloaded". Each
    per-shard download error is reported together with the shard's filename
    and skipped, so the remaining shards still complete.

    Args:
        name: Blend registry key (used in log lines and passed to list_parquet).
        repo: Hugging Face dataset repo id.
        config: Config directory filter, or None.
        shards: Number of shards requested from list_parquet.
        token: Optional Hugging Face access token.
        workers: Maximum shards downloaded in parallel; values below 1 are
            treated as 1 (ThreadPoolExecutor rejects max_workers < 1).

    Returns:
        ``(bytes_downloaded, seconds_elapsed)``; ``(0, 0.0)`` when listing fails
        or no parquet file matched.
    """
    t0 = time.time()
    try:
        files = list_parquet(repo, config, name, shards, token)
    except Exception as e:
        print(f"[{name}] list failed: {type(e).__name__}: {e}", flush=True)
        return (0, 0.0)
    if not files:
        print(f"[{name}] no parquet matched — skipped (config={config})", flush=True)
        return (0, 0.0)
    pool_size = max(1, workers)
    print(f"[{name}] {len(files)} shards ({pool_size} concurrent)", flush=True)
    total = 0
    failed = 0
    with ThreadPoolExecutor(max_workers=pool_size) as ex:
        # Keep future -> filename so a failed download can name its shard.
        fut_to_file = {ex.submit(download_one, repo, f, token): f for f in files}
        for fut in as_completed(fut_to_file):
            try:
                fname, sz, elapsed = fut.result()
            except Exception as e:
                # Best-effort: report and continue with the remaining shards.
                failed += 1
                print(f"  FAIL {fut_to_file[fut]}: {type(e).__name__}: {str(e)[:100]}", flush=True)
                continue
            mbps = sz / 1024**2 / max(elapsed, 0.001)
            print(f"  OK {fname}: {sz / 1024**2:.0f} MB in {elapsed:.0f}s ({mbps:.1f} MB/s)", flush=True)
            total += sz
    elapsed = time.time() - t0
    fail_note = f", {failed}/{len(files)} failed" if failed else ""
    print(f"[{name}] {total / 1024**3:.2f} GB in {elapsed:.0f}s ({total / 1024**2 / max(elapsed, 0.001):.1f} MB/s){fail_note}", flush=True)
    return (total, elapsed)
def main() -> None:
    """Pre-download the first ``--shards`` parquet shards of every blend dataset.

    Iterates over ``_BLEND_REGISTRY`` (entries unpacked as ``(repo, config, column)``),
    downloads each dataset in turn via download_dataset, and prints a grand total.
    The optional access token is read from the ``HF_TOKEN`` environment variable.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--shards", type=int, default=2, help="parquet shards to fetch per dataset")
    ap.add_argument("--concurrent-files", type=int, default=2, help="shards in parallel per dataset")
    args = ap.parse_args()
    # Reject values that would misbehave downstream: a non-positive --shards
    # slices the file list as [:0] (empty) or [:-k] (drops trailing files), and
    # ThreadPoolExecutor raises ValueError for max_workers < 1.
    if args.shards < 1:
        ap.error("--shards must be >= 1")
    if args.concurrent_files < 1:
        ap.error("--concurrent-files must be >= 1")
    token = os.environ.get("HF_TOKEN")
    datasets = list(_BLEND_REGISTRY.items())
    print(f"[predownload] {len(datasets)} datasets × {args.shards} shards, {args.concurrent_files} concurrent per dataset", flush=True)
    t_start = time.time()
    grand_total = 0
    for name, (repo, cfg, _col) in datasets:
        total, _ = download_dataset(name, repo, cfg, args.shards, token, workers=args.concurrent_files)
        grand_total += total
    elapsed = time.time() - t_start
    print(f"\n[predownload] DONE — {grand_total / 1024**3:.2f} GB in {elapsed:.0f}s ({grand_total / 1024**2 / max(elapsed, 0.001):.1f} MB/s overall)", flush=True)
# Run the pre-download only when executed as a script, not when imported.
if __name__ == "__main__":
    main()