#!/usr/bin/env python3
| from __future__ import annotations | |
| """ | |
| Fetch additional training shards from karpathy/climbmix-400b-shuffle. | |
| The repo already has ~500 shards (~31B tokens). This script is a | |
| resumable, parallel downloader for cases where more shards are needed | |
| (e.g., multi-day training, experiments requiring fresh-unseen data, | |
| or when we want to split the corpus across processes). | |
| Usage: | |
| # Fetch shards up to index 600 (total cap) | |
| python scripts/fetch_corpus.py --target-shards 600 | |
| # Fetch a specific range | |
| python scripts/fetch_corpus.py --start 500 --end 800 | |
| # Dry-run (list what would be downloaded) | |
| python scripts/fetch_corpus.py --target-shards 600 --dry-run | |
| Notes: | |
| - Safe to run while training is active; only writes files not touched | |
| by the training process. | |
| - Resumable: skips shards already on disk. | |
| - Downloads to the same DATA_DIR used by prepare.py so they're picked | |
| up on next training launch. | |
| """ | |
| import argparse | |
| import os | |
| import shutil | |
| import sys | |
| import time | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from pathlib import Path | |
| import requests | |
| REPO_ROOT = Path(__file__).resolve().parent.parent | |
| sys.path.insert(0, str(REPO_ROOT)) | |
| from prepare import BASE_URL, DATA_DIR, MAX_SHARD, VAL_SHARD # noqa: E402 | |
def human_bytes(n: int) -> str:
    """Render a byte count as a short human-readable string, e.g. '12.3MB'.

    Uses 1024-based units; anything past TB is reported in PB.
    """
    size = float(n)
    units = ("B", "KB", "MB", "GB", "TB")
    idx = 0
    while idx < len(units) and size >= 1024:
        size /= 1024
        idx += 1
    if idx < len(units):
        return f"{size:.1f}{units[idx]}"
    return f"{size:.1f}PB"
def download_one(
    index: int, data_dir: str, timeout: int = 30, max_attempts: int = 5
) -> tuple[int, bool, int, str]:
    """
    Download a single parquet shard into *data_dir*.

    Resumable: returns immediately if the shard file already exists.
    Retries up to *max_attempts* times with exponential backoff on
    network/file errors. Streams into a ".tmp" file and renames it into
    place at the end, so a crash never leaves a truncated shard behind.

    Returns (index, success, bytes_written, message).
    """
    filename = f"shard_{index:05d}.parquet"
    filepath = os.path.join(data_dir, filename)
    tmp_path = filepath + ".tmp"
    if os.path.exists(filepath):
        return index, True, 0, "already-present"
    # FIX: the URL previously contained a mangled "(unknown)" placeholder
    # instead of the shard filename, so every download 404'd.
    url = f"{BASE_URL}/{filename}"
    for attempt in range(1, max_attempts + 1):
        try:
            with requests.get(url, stream=True, timeout=timeout) as r:
                r.raise_for_status()
                bytes_written = 0
                with open(tmp_path, "wb") as f:
                    # 1MB chunks keep memory flat for ~90MB shards.
                    for chunk in r.iter_content(chunk_size=1 << 20):
                        if chunk:
                            f.write(chunk)
                            bytes_written += len(chunk)
                # Atomic publish: readers never observe a partial shard.
                os.rename(tmp_path, filepath)
                return index, True, bytes_written, f"ok (attempt {attempt})"
        except (requests.RequestException, OSError) as e:
            # Clean up partial output before retrying or giving up.
            for p in (tmp_path, filepath):
                if os.path.exists(p):
                    try:
                        os.remove(p)
                    except OSError:
                        pass
            if attempt < max_attempts:
                # Exponential backoff: 2s, 4s, 8s, ...
                time.sleep(2 ** attempt)
                continue
            return index, False, 0, f"failed after {max_attempts} attempts: {e}"
    return index, False, 0, "unknown failure"
def check_disk_space(required_bytes: int, data_dir: str) -> tuple[bool, int]:
    """Check that data_dir's filesystem can hold required_bytes plus 10% headroom.

    Creates *data_dir* if it does not exist yet. Returns (enough, free_bytes).
    """
    os.makedirs(data_dir, exist_ok=True)
    free = shutil.disk_usage(data_dir).free
    needed = int(required_bytes * 1.1)  # 10% safety margin
    return free >= needed, free
def _resolve_shard_ids(args: argparse.Namespace) -> list[int] | None:
    """Turn CLI flags into the list of shard indices to consider.

    Returns None if the flag combination is invalid (error already printed).
    """
    if args.target_shards is not None:
        if args.start is not None or args.end is not None:
            print("ERROR: --target-shards is exclusive with --start/--end")
            return None
        # Cap at MAX_SHARD; indices 0..target-1.
        ids = list(range(min(args.target_shards, MAX_SHARD)))
    else:
        start = args.start or 0
        end = args.end if args.end is not None else MAX_SHARD
        end = min(end, MAX_SHARD)
        ids = list(range(start, end))
    if args.include_val and VAL_SHARD not in ids:
        ids.append(VAL_SHARD)
    return ids


def _present_shard_ids(data_dir: str) -> set[int]:
    """Scan data_dir for shard_NNNNN.parquet files already downloaded."""
    present: set[int] = set()
    for p in Path(data_dir).glob("shard_*.parquet"):
        try:
            present.add(int(p.stem.split("_")[1]))
        except (IndexError, ValueError):
            # Ignore files that don't match the naming scheme.
            continue
    return present


def main() -> int:
    """CLI entry point.

    Exit codes: 0 = success / nothing to do / dry-run, 1 = bad arguments,
    2 = insufficient disk space, 3 = one or more downloads failed.
    """
    parser = argparse.ArgumentParser(
        description="Fetch additional climbmix-400b-shuffle shards"
    )
    parser.add_argument(
        "--target-shards",
        type=int,
        default=None,
        help="Total train-shard count to reach (0..target-1). Mutually exclusive with --start/--end.",
    )
    parser.add_argument("--start", type=int, default=None, help="Starting shard index (inclusive)")
    parser.add_argument("--end", type=int, default=None, help="Ending shard index (exclusive)")
    parser.add_argument("--workers", type=int, default=8, help="Parallel download workers")
    parser.add_argument(
        "--include-val",
        action="store_true",
        help="Also fetch the pinned validation shard (normally present already)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="List what would be downloaded without fetching",
    )
    args = parser.parse_args()

    ids = _resolve_shard_ids(args)
    if ids is None:
        return 1

    os.makedirs(DATA_DIR, exist_ok=True)
    present = _present_shard_ids(DATA_DIR)
    to_fetch = [i for i in ids if i not in present]
    if not to_fetch:
        print(f"All {len(ids)} shards already present at {DATA_DIR}")
        return 0

    # Estimate space: shards are ~88MB; leave 10% headroom.
    avg_shard_bytes = 90 * (1 << 20)  # 90MB
    required = avg_shard_bytes * len(to_fetch)
    ok, free = check_disk_space(required, DATA_DIR)
    print(f"Plan: fetch {len(to_fetch)} shards (~{human_bytes(required)}); "
          f"disk free: {human_bytes(free)}")
    if not ok:
        print("ERROR: insufficient disk space (need 1.1x required)")
        return 2

    if args.dry_run:
        preview = to_fetch[:10]
        print(
            f"Dry-run — would fetch {len(to_fetch)} shards. First {len(preview)}: {preview}"
        )
        return 0

    print(f"Downloading {len(to_fetch)} shards with {args.workers} workers...")
    t_start = time.time()
    success = 0
    failed = 0
    total_bytes = 0
    with ThreadPoolExecutor(max_workers=args.workers) as ex:
        futs = {ex.submit(download_one, i, DATA_DIR): i for i in to_fetch}
        for fut in as_completed(futs):
            idx, ok, nbytes, msg = fut.result()
            if ok:
                success += 1
                total_bytes += nbytes
                # Progress line every 10 completions and on the last one.
                if success % 10 == 0 or success == len(to_fetch):
                    elapsed = time.time() - t_start
                    rate = total_bytes / max(elapsed, 1)
                    print(
                        f" [{success}/{len(to_fetch)}] shard_{idx:05d} ok "
                        f"({human_bytes(total_bytes)} @ {human_bytes(int(rate))}/s)"
                    )
            else:
                failed += 1
                print(f" [FAIL] shard_{idx:05d}: {msg}")

    elapsed = time.time() - t_start
    print()
    print("=" * 60)
    print(f"Downloaded {success}/{len(to_fetch)} shards in {elapsed:.1f}s")
    print(f"Failed: {failed}")
    print(f"Total bytes: {human_bytes(total_bytes)}")
    print("=" * 60)
    return 0 if failed == 0 else 3
| if __name__ == "__main__": | |
| raise SystemExit(main()) | |