File size: 2,170 Bytes
e39c73c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#!/usr/bin/env python3
"""Download the public WikiText-103 raw parquet used for corpus holdouts."""

from __future__ import annotations

import argparse
import hashlib
from pathlib import Path
from urllib.request import urlretrieve


# Source of the holdout parquet. The 40-hex path segment after "resolve/"
# selects a fixed dataset revision, so the bytes served should match the
# pinned size and digest below.
WIKITEXT103_RAW_TRAIN_URL = (
    "https://huggingface.co/datasets/Salesforce/wikitext/resolve/"
    "6231e49f19a707241d6b84d9cff60a3a86b85a85/wikitext-103-raw-v1/"
    "train-00001-of-00002.parquet?download=true"
)
# Exact size (in bytes) and SHA-256 hex digest that a good download must have;
# checked by verify_file to reject truncated or altered files.
EXPECTED_BYTES = 156700942
EXPECTED_SHA256 = (
    "75aa65dee9de2a7c10ba1808efd2408c3f4eb008"
    "104c3ccac47f8ed19300ebdd"
)


def sha256(path: Path) -> str:
    """Return the lowercase hex SHA-256 digest of the file at *path*.

    The file is read in 1 MiB pieces, so hashing a large download does not
    require holding it in memory all at once.
    """
    hasher = hashlib.sha256()
    block_size = 1024 * 1024
    with path.open("rb") as stream:
        while True:
            block = stream.read(block_size)
            if not block:
                # Empty read means end of file.
                break
            hasher.update(block)
    return hasher.hexdigest()


def verify_file(path: Path) -> None:
    """Abort the program unless *path* has the pinned size and SHA-256.

    The size is compared first so an obviously wrong file is rejected
    without paying for a full hash of it.

    Raises:
        SystemExit: with a message describing the first mismatch found.
    """
    observed_size = path.stat().st_size
    if observed_size != EXPECTED_BYTES:
        message = f"{path} has {observed_size:,} bytes; expected {EXPECTED_BYTES:,}."
        raise SystemExit(message)

    observed_digest = sha256(path)
    if observed_digest != EXPECTED_SHA256:
        message = f"{path} has sha256 {observed_digest}; expected {EXPECTED_SHA256}."
        raise SystemExit(message)


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for this script.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, matching the
            previous zero-argument behavior.

    Returns:
        Namespace with ``output_dir`` (a ``Path``) and ``force`` (a bool).
    """
    parser = argparse.ArgumentParser(
        description="Prepare the WikiText-103 raw parquet corpus holdout."
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("data/wikitext103_raw"),
        help="Directory where the parquet file should be stored.",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Download again even if the target file already exists.",
    )
    return parser.parse_args(argv)


def main() -> None:
    """Ensure a verified copy of the parquet file exists and print its path.

    If the target already exists and ``--force`` is not given, it is
    re-verified and reused. Otherwise the file is downloaded into a sibling
    ``.part`` staging file, verified there, and only then moved onto the
    final path. An interrupted or corrupt download therefore never leaves a
    bad file at the target. ``--force`` also never discards an existing copy
    until its replacement has passed verification.

    Raises:
        SystemExit: if the existing or newly downloaded file fails
            verification.
    """
    args = parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)
    target = args.output_dir / "train-00001-of-00002.parquet"
    if target.exists() and not args.force:
        verify_file(target)
        print(target)
        return

    # Stage next to the target so the final move stays on one filesystem,
    # where Path.replace is an atomic rename on POSIX.
    partial = target.with_name(target.name + ".part")
    print(f"Downloading WikiText-103 raw train parquet to {target}")
    try:
        urlretrieve(WIKITEXT103_RAW_TRAIN_URL, partial)
        verify_file(partial)
        partial.replace(target)
    finally:
        # The staging file only survives to here if the download or the
        # verification failed (or was interrupted); don't leave it behind.
        if partial.exists():
            partial.unlink()
    print(target)


if __name__ == "__main__":
    main()