| |
| """Download the public TinyStories parquet shard used for corpus holdouts.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import hashlib |
| from pathlib import Path |
| from urllib.request import urlretrieve |
|
|
|
|
# Direct download link for the first TinyStories train parquet shard.
# NOTE(review): the path segment after "resolve" looks like a pinned dataset
# revision, which would keep the size/digest below stable — confirm.
TINYSTORIES_TRAIN_URL = "/".join(
    (
        "https://huggingface.co/datasets/roneneldan/TinyStories/resolve",
        "f54c09fd23315a6f9c86f9dc80f725de7d8f9c64",
        "data",
        "train-00000-of-00004-2d5a1467fff1081b.parquet?download=true",
    )
)

# Size (in bytes) and SHA-256 hex digest a file must match to pass verify_file.
EXPECTED_BYTES = 248_731_111
EXPECTED_SHA256 = (
    "77cf780cebe52b6e83e3a2ac84bc56d8059363113e41d17a023f1d8b2ed0fc0b"
)
|
|
|
|
def sha256(path: Path) -> str:
    """Return the hexadecimal SHA-256 digest of the file at *path*.

    The file is opened in binary mode and consumed in 1 MiB pieces, so the
    whole file never has to be held in memory at once.
    """
    block_size = 1024 * 1024
    hasher = hashlib.sha256()
    with path.open("rb") as stream:
        while True:
            piece = stream.read(block_size)
            if not piece:
                # An empty read marks end of file.
                break
            hasher.update(piece)
    return hasher.hexdigest()
|
|
|
|
def verify_file(path: Path) -> None:
    """Check that *path* matches the expected size and SHA-256 digest.

    The size is compared first, so the file is only hashed when its size
    already matches EXPECTED_BYTES.

    Raises:
        SystemExit: with a descriptive message if the byte count differs from
            EXPECTED_BYTES or the digest differs from EXPECTED_SHA256.
    """
    byte_count = path.stat().st_size
    if byte_count != EXPECTED_BYTES:
        size_problem = (
            f"{path} has {byte_count:,} bytes; expected {EXPECTED_BYTES:,}."
        )
        raise SystemExit(size_problem)

    found_digest = sha256(path)
    if found_digest != EXPECTED_SHA256:
        digest_problem = (
            f"{path} has sha256 {found_digest}; expected {EXPECTED_SHA256}."
        )
        raise SystemExit(digest_problem)
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for this script.

    Returns:
        A namespace with ``output_dir`` (a ``Path``, defaulting to
        ``data/tinystories``) and ``force`` (``True`` when ``--force`` is
        given, otherwise ``False``).
    """
    cli = argparse.ArgumentParser(
        description="Prepare the TinyStories public corpus holdout."
    )

    # Where the downloaded shard is stored.
    cli.add_argument(
        "--output-dir",
        default=Path("data/tinystories"),
        type=Path,
        help="Directory where the parquet file should be stored.",
    )

    # Opt-in switch to re-download over an existing file.
    cli.add_argument(
        "--force",
        help="Download again even if the target file already exists.",
        action="store_true",
    )

    options = cli.parse_args()
    return options
|
|
|
|
def main() -> None:
    """Ensure a verified copy of the TinyStories train shard exists locally.

    If the target file already exists and ``--force`` was not given, the file
    is only verified. Otherwise the shard is downloaded to a temporary
    ``.part`` file next to the target, verified, and then moved into place.
    On success the target path is printed.

    Raises:
        SystemExit: propagated from ``verify_file`` when the existing or the
            freshly downloaded file has the wrong size or digest.
    """
    args = parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)
    target = args.output_dir / "train-00000-of-00004.parquet"
    if target.exists() and not args.force:
        verify_file(target)
        print(target)
        return

    # Download under a temporary name in the same directory so an interrupted
    # or corrupt transfer never leaves a bad file at `target`, and so --force
    # does not clobber an existing good file before the new one is verified.
    partial = target.with_name(target.name + ".part")
    print(f"Downloading TinyStories train parquet shard to {target}")
    try:
        urlretrieve(TINYSTORIES_TRAIN_URL, partial)
        verify_file(partial)
        partial.replace(target)
    finally:
        # After a successful replace the partial file is gone; on any failure
        # (including Ctrl-C) remove the leftover so it cannot be mistaken for
        # a usable download.
        if partial.exists():
            partial.unlink()
    print(target)
|
|
|
|
# Run the download/verification only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|