File size: 2,115 Bytes
20c4ec9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#!/usr/bin/env python3
"""Download the public TinyStories parquet shard used for corpus holdouts."""

from __future__ import annotations

import argparse
import hashlib
from pathlib import Path
from urllib.request import urlretrieve


# The 40-hex path segment after "resolve/" pins one dataset revision, so the
# size and digest below keep describing exactly the same file.
TINYSTORIES_TRAIN_URL = "/".join(
    (
        "https://huggingface.co/datasets/roneneldan/TinyStories/resolve",
        "f54c09fd23315a6f9c86f9dc80f725de7d8f9c64",
        "data/train-00000-of-00004-2d5a1467fff1081b.parquet?download=true",
    )
)
# Exact size of the shard in bytes.
EXPECTED_BYTES = 248731111
# Hex-encoded SHA-256 digest of the shard.
EXPECTED_SHA256 = (
    "77cf780cebe52b6e83e3a2ac84bc56d8059363113e41d17a023f1d8b2ed0fc0b"
)


def sha256(path: Path) -> str:
    """Return the lowercase hex SHA-256 digest of the file at *path*.

    The file is streamed in 1 MiB blocks so large files are never held in
    memory all at once.
    """
    block_size = 1024 * 1024
    hasher = hashlib.sha256()
    with path.open("rb") as stream:
        while True:
            block = stream.read(block_size)
            if not block:  # b"" signals end of file
                break
            hasher.update(block)
    return hasher.hexdigest()


def verify_file(path: Path) -> None:
    """Exit the program unless *path* has the pinned size and SHA-256 digest.

    The size is compared first, so a file of the wrong length is rejected
    without being hashed.

    Raises:
        SystemExit: carrying a message that names the first mismatch found.
    """
    actual_size = path.stat().st_size
    if actual_size != EXPECTED_BYTES:
        size_message = f"{path} has {actual_size:,} bytes; expected {EXPECTED_BYTES:,}."
        raise SystemExit(size_message)

    digest = sha256(path)
    if digest != EXPECTED_SHA256:
        digest_message = f"{path} has sha256 {digest}; expected {EXPECTED_SHA256}."
        raise SystemExit(digest_message)


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the TinyStories download script.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before. Passing an explicit list lets
            callers and tests drive the parser without touching ``sys.argv``.

    Returns:
        A namespace with ``output_dir`` (a ``Path``, default
        ``data/tinystories``) and ``force`` (a ``bool``, default ``False``).
    """
    parser = argparse.ArgumentParser(
        description="Prepare the TinyStories public corpus holdout."
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("data/tinystories"),
        help="Directory where the parquet file should be stored.",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Download again even if the target file already exists.",
    )
    return parser.parse_args(argv)


def main() -> None:
    """Ensure the pinned TinyStories shard is present and print its path.

    If the target file already exists and --force was not given, it is
    verified and reused. Otherwise the shard is downloaded into a sibling
    ``.part`` file, verified there, and only then moved onto the target.
    As a result, an interrupted or corrupt download never leaves a bad file
    at the final path, and never clobbers a previously good one when --force
    is used.

    Raises:
        SystemExit: if an existing or freshly downloaded file fails
            verification.
    """
    args = parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)
    target = args.output_dir / "train-00000-of-00004.parquet"
    if target.exists() and not args.force:
        verify_file(target)
        print(target)
        return

    # Stage the download next to the target so the final replace() is a
    # same-directory rename.
    partial = target.with_name(target.name + ".part")
    print(f"Downloading TinyStories train parquet shard to {target}")
    try:
        urlretrieve(TINYSTORIES_TRAIN_URL, partial)
        verify_file(partial)
        partial.replace(target)
    finally:
        # Runs on network errors, failed verification (SystemExit) and Ctrl-C.
        # After a successful replace() the staging file no longer exists.
        if partial.exists():
            partial.unlink()
    print(target)


if __name__ == "__main__":
    main()