""" Fetch additional training shards from karpathy/climbmix-400b-shuffle. The repo already has ~500 shards (~31B tokens). This script is a resumable, parallel downloader for cases where more shards are needed (e.g., multi-day training, experiments requiring fresh-unseen data, or when we want to split the corpus across processes). Usage: # Fetch shards up to index 600 (total cap) python scripts/fetch_corpus.py --target-shards 600 # Fetch a specific range python scripts/fetch_corpus.py --start 500 --end 800 # Dry-run (list what would be downloaded) python scripts/fetch_corpus.py --target-shards 600 --dry-run Notes: - Safe to run while training is active; only writes files not touched by the training process. - Resumable: skips shards already on disk. - Downloads to the same DATA_DIR used by prepare.py so they're picked up on next training launch. """ from __future__ import annotations import argparse import os import shutil import sys import time from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path import requests REPO_ROOT = Path(__file__).resolve().parent.parent sys.path.insert(0, str(REPO_ROOT)) from prepare import BASE_URL, DATA_DIR, MAX_SHARD, VAL_SHARD # noqa: E402 def human_bytes(n: int) -> str: for unit in ("B", "KB", "MB", "GB", "TB"): if n < 1024: return f"{n:.1f}{unit}" n /= 1024 return f"{n:.1f}PB" def download_one( index: int, data_dir: str, timeout: int = 30, max_attempts: int = 5 ) -> tuple[int, bool, int, str]: """ Download a single parquet shard. Resumable + retry with exponential backoff. Returns (index, success, bytes_written, message). """ filename = f"shard_{index:05d}.parquet" filepath = os.path.join(data_dir, filename) tmp_path = filepath + ".tmp" if os.path.exists(filepath): return index, True, 0, "already-present" url = f"{BASE_URL}/{filename}" for attempt in range(1, max_attempts + 1): try: with requests.get(url, stream=True, timeout=timeout) as r: r.raise_for_status() bytes_written = 0 with open(tmp_path, "wb") as f: for chunk in r.iter_content(chunk_size=1 << 20): if chunk: f.write(chunk) bytes_written += len(chunk) os.rename(tmp_path, filepath) return index, True, bytes_written, f"ok (attempt {attempt})" except (requests.RequestException, OSError) as e: # Clean up partial file. for p in (tmp_path, filepath): if os.path.exists(p): try: os.remove(p) except OSError: pass if attempt < max_attempts: wait = 2 ** attempt time.sleep(wait) continue return index, False, 0, f"failed after {max_attempts} attempts: {e}" return index, False, 0, "unknown failure" def check_disk_space(required_bytes: int, data_dir: str) -> tuple[bool, int]: """Ensure we have at least required_bytes + 10% headroom free.""" os.makedirs(data_dir, exist_ok=True) stats = shutil.disk_usage(data_dir) headroom = int(required_bytes * 1.1) return stats.free >= headroom, stats.free def main() -> int: parser = argparse.ArgumentParser( description="Fetch additional climbmix-400b-shuffle shards" ) parser.add_argument( "--target-shards", type=int, default=None, help="Total train-shard count to reach (0..target-1). Mutually exclusive with --start/--end.", ) parser.add_argument("--start", type=int, default=None, help="Starting shard index (inclusive)") parser.add_argument("--end", type=int, default=None, help="Ending shard index (exclusive)") parser.add_argument("--workers", type=int, default=8, help="Parallel download workers") parser.add_argument( "--include-val", action="store_true", help="Also fetch the pinned validation shard (normally present already)", ) parser.add_argument( "--dry-run", action="store_true", help="List what would be downloaded without fetching", ) args = parser.parse_args() # Resolve shard range. if args.target_shards is not None: if args.start is not None or args.end is not None: print("ERROR: --target-shards is exclusive with --start/--end") return 1 ids = list(range(min(args.target_shards, MAX_SHARD))) else: start = args.start or 0 end = args.end if args.end is not None else MAX_SHARD end = min(end, MAX_SHARD) ids = list(range(start, end)) if args.include_val and VAL_SHARD not in ids: ids.append(VAL_SHARD) os.makedirs(DATA_DIR, exist_ok=True) present = set() for p in Path(DATA_DIR).glob("shard_*.parquet"): try: idx = int(p.stem.split("_")[1]) present.add(idx) except (IndexError, ValueError): continue to_fetch = [i for i in ids if i not in present] if not to_fetch: print(f"All {len(ids)} shards already present at {DATA_DIR}") return 0 # Estimate space: shards are ~88MB; leave 10% headroom. avg_shard_bytes = 90 * (1 << 20) # 90MB required = avg_shard_bytes * len(to_fetch) ok, free = check_disk_space(required, DATA_DIR) print(f"Plan: fetch {len(to_fetch)} shards (~{human_bytes(required)}); " f"disk free: {human_bytes(free)}") if not ok: print("ERROR: insufficient disk space (need 1.1x required)") return 2 if args.dry_run: preview = to_fetch[:10] print( f"Dry-run — would fetch {len(to_fetch)} shards. First {len(preview)}: {preview}" ) return 0 print(f"Downloading {len(to_fetch)} shards with {args.workers} workers...") t_start = time.time() success = 0 failed = 0 total_bytes = 0 with ThreadPoolExecutor(max_workers=args.workers) as ex: futs = {ex.submit(download_one, i, DATA_DIR): i for i in to_fetch} for fut in as_completed(futs): idx, ok, nbytes, msg = fut.result() if ok: success += 1 total_bytes += nbytes if success % 10 == 0 or success == len(to_fetch): elapsed = time.time() - t_start rate = total_bytes / max(elapsed, 1) print( f" [{success}/{len(to_fetch)}] shard_{idx:05d} ok " f"({human_bytes(total_bytes)} @ {human_bytes(int(rate))}/s)" ) else: failed += 1 print(f" [FAIL] shard_{idx:05d}: {msg}") elapsed = time.time() - t_start print() print("=" * 60) print(f"Downloaded {success}/{len(to_fetch)} shards in {elapsed:.1f}s") print(f"Failed: {failed}") print(f"Total bytes: {human_bytes(total_bytes)}") print("=" * 60) return 0 if failed == 0 else 3 if __name__ == "__main__": raise SystemExit(main())