"""
Fetch additional training shards from karpathy/climbmix-400b-shuffle.
The repo already has ~500 shards (~31B tokens). This script is a
resumable, parallel downloader for cases where more shards are needed
(e.g., multi-day training, experiments requiring fresh, unseen data,
or when we want to split the corpus across processes).
Usage:
# Fetch shards up to index 600 (total cap)
python scripts/fetch_corpus.py --target-shards 600
# Fetch a specific range
python scripts/fetch_corpus.py --start 500 --end 800
# Dry-run (list what would be downloaded)
python scripts/fetch_corpus.py --target-shards 600 --dry-run
Notes:
- Safe to run while training is active; only writes files not touched
by the training process.
- Resumable: skips shards already on disk.
- Downloads to the same DATA_DIR used by prepare.py so they're picked
up on next training launch.
"""
import argparse
import os
import shutil
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import requests
REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))
from prepare import BASE_URL, DATA_DIR, MAX_SHARD, VAL_SHARD # noqa: E402
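# Imported from prepare.py: BASE_URL is the download endpoint the shard URLs
# are built from, DATA_DIR is the shard cache shared with training, MAX_SHARD
# caps the index range, and VAL_SHARD is the pinned validation shard.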
def human_bytes(n: float) -> str:
    """Format a byte count with binary-prefix units, e.g. 1536 -> '1.5KB'."""
for unit in ("B", "KB", "MB", "GB", "TB"):
if n < 1024:
return f"{n:.1f}{unit}"
n /= 1024
return f"{n:.1f}PB"
def download_one(
index: int, data_dir: str, timeout: int = 30, max_attempts: int = 5
) -> tuple[int, bool, int, str]:
"""
Download a single parquet shard. Resumable + retry with exponential backoff.
Returns (index, success, bytes_written, message).
"""
filename = f"shard_{index:05d}.parquet"
filepath = os.path.join(data_dir, filename)
tmp_path = filepath + ".tmp"
if os.path.exists(filepath):
return index, True, 0, "already-present"
url = f"{BASE_URL}/{filename}"
for attempt in range(1, max_attempts + 1):
try:
with requests.get(url, stream=True, timeout=timeout) as r:
r.raise_for_status()
bytes_written = 0
with open(tmp_path, "wb") as f:
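                    # Stream in 1 MiB chunks so memory stays flat regardless of shard size.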
for chunk in r.iter_content(chunk_size=1 << 20):
if chunk:
f.write(chunk)
bytes_written += len(chunk)
os.rename(tmp_path, filepath)
return index, True, bytes_written, f"ok (attempt {attempt})"
        except (requests.RequestException, OSError) as e:
            # Clean up the partial .tmp file only. The final path is never
            # touched: it only ever appears via the atomic rename above, so a
            # concurrent fetch or an active training run cannot lose a
            # completed shard.
            if os.path.exists(tmp_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass
if attempt < max_attempts:
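                # Exponential backoff: sleeps 2, 4, 8, 16 s across the retries.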
wait = 2 ** attempt
time.sleep(wait)
continue
return index, False, 0, f"failed after {max_attempts} attempts: {e}"
return index, False, 0, "unknown failure"
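# A post-download sanity check one might bolt on here (a sketch, intentionally
# left commented out so this script gains no hard pyarrow dependency): parquet
# files carry a footer, so simply opening the metadata catches truncated or
# corrupted downloads that the HTTP layer happened to report as complete.
#
#   import pyarrow.parquet as pq
#
#   def shard_looks_valid(path: str) -> bool:
#       try:
#           return pq.read_metadata(path).num_rows > 0
#       except Exception:
#           return False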
def check_disk_space(required_bytes: int, data_dir: str) -> tuple[bool, int]:
"""Ensure we have at least required_bytes + 10% headroom free."""
os.makedirs(data_dir, exist_ok=True)
stats = shutil.disk_usage(data_dir)
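    # e.g. fetching 100 shards at ~90 MB each needs ~8.8 GB, so ~9.7 GB free.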
headroom = int(required_bytes * 1.1)
return stats.free >= headroom, stats.free
def main() -> int:
parser = argparse.ArgumentParser(
description="Fetch additional climbmix-400b-shuffle shards"
)
parser.add_argument(
"--target-shards",
type=int,
default=None,
help="Total train-shard count to reach (0..target-1). Mutually exclusive with --start/--end.",
)
parser.add_argument("--start", type=int, default=None, help="Starting shard index (inclusive)")
parser.add_argument("--end", type=int, default=None, help="Ending shard index (exclusive)")
parser.add_argument("--workers", type=int, default=8, help="Parallel download workers")
parser.add_argument(
"--include-val",
action="store_true",
help="Also fetch the pinned validation shard (normally present already)",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="List what would be downloaded without fetching",
)
args = parser.parse_args()
# Resolve shard range.
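    # e.g. --target-shards 600 selects indices 0..599, while --start 500
    # --end 800 selects 500..799 (end-exclusive); both are capped at MAX_SHARD.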
if args.target_shards is not None:
if args.start is not None or args.end is not None:
print("ERROR: --target-shards is exclusive with --start/--end")
return 1
ids = list(range(min(args.target_shards, MAX_SHARD)))
else:
start = args.start or 0
end = args.end if args.end is not None else MAX_SHARD
end = min(end, MAX_SHARD)
ids = list(range(start, end))
if args.include_val and VAL_SHARD not in ids:
ids.append(VAL_SHARD)
os.makedirs(DATA_DIR, exist_ok=True)
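    # Index every shard already on disk so the run is resumable.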
present = set()
for p in Path(DATA_DIR).glob("shard_*.parquet"):
try:
idx = int(p.stem.split("_")[1])
present.add(idx)
except (IndexError, ValueError):
continue
to_fetch = [i for i in ids if i not in present]
if not to_fetch:
print(f"All {len(ids)} shards already present at {DATA_DIR}")
return 0
    # Estimate space: shards average ~88 MB; budget 90 MiB each to be safe
    # (check_disk_space adds a further 10% headroom on top).
    avg_shard_bytes = 90 * (1 << 20)  # 90 MiB
required = avg_shard_bytes * len(to_fetch)
ok, free = check_disk_space(required, DATA_DIR)
print(f"Plan: fetch {len(to_fetch)} shards (~{human_bytes(required)}); "
f"disk free: {human_bytes(free)}")
    if args.dry_run:
        preview = to_fetch[:10]
        print(
            f"Dry-run: would fetch {len(to_fetch)} shards. First {len(preview)}: {preview}"
        )
        return 0
    if not ok:
        print("ERROR: insufficient disk space (need 1.1x the estimated download size)")
        return 2
print(f"Downloading {len(to_fetch)} shards with {args.workers} workers...")
t_start = time.time()
success = 0
failed = 0
total_bytes = 0
with ThreadPoolExecutor(max_workers=args.workers) as ex:
futs = {ex.submit(download_one, i, DATA_DIR): i for i in to_fetch}
for fut in as_completed(futs):
            idx, succeeded, nbytes, msg = fut.result()
            if succeeded:
success += 1
total_bytes += nbytes
if success % 10 == 0 or success == len(to_fetch):
elapsed = time.time() - t_start
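                    # Cumulative average throughput since launch; coarse, but
                    # enough for a progress line.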
rate = total_bytes / max(elapsed, 1)
print(
f" [{success}/{len(to_fetch)}] shard_{idx:05d} ok "
f"({human_bytes(total_bytes)} @ {human_bytes(int(rate))}/s)"
)
else:
failed += 1
print(f" [FAIL] shard_{idx:05d}: {msg}")
elapsed = time.time() - t_start
print()
print("=" * 60)
print(f"Downloaded {success}/{len(to_fetch)} shards in {elapsed:.1f}s")
print(f"Failed: {failed}")
print(f"Total bytes: {human_bytes(total_bytes)}")
print("=" * 60)
return 0 if failed == 0 else 3
if __name__ == "__main__":
raise SystemExit(main())