dropout-decay / scripts /prepare_wikitext103.py
Mandeep Sidhu
Add reproducible WikiText corpus prep
e39c73c
#!/usr/bin/env python3
"""Download the public WikiText-103 raw parquet used for corpus holdouts."""
from __future__ import annotations
import argparse
import hashlib
from pathlib import Path
from urllib.request import urlretrieve
# Direct-download URL of the single parquet file this script fetches.
# NOTE(review): the 40-hex path segment after "resolve/" looks like a pinned
# dataset revision, presumably so the served bytes never change -- confirm.
# NOTE(review): the file name reads "00001-of-00002", i.e. one shard of two;
# only this shard is downloaded here -- confirm that is intended.
WIKITEXT103_RAW_TRAIN_URL = (
"https://huggingface.co/datasets/Salesforce/wikitext/resolve/"
"6231e49f19a707241d6b84d9cff60a3a86b85a85/"
"wikitext-103-raw-v1/train-00001-of-00002.parquet?download=true"
)
# Exact size in bytes that the file must have; compared in verify_file().
EXPECTED_BYTES = 156_700_942
# Lowercase hex SHA-256 digest that the file must have; compared in verify_file().
EXPECTED_SHA256 = "75aa65dee9de2a7c10ba1808efd2408c3f4eb008104c3ccac47f8ed19300ebdd"
def sha256(path: Path) -> str:
    """Return the hex-encoded SHA-256 digest of the file at *path*.

    The file is read in 1 MiB pieces, so it is never loaded into memory
    all at once.
    """
    block_size = 1024 * 1024
    hasher = hashlib.sha256()
    with path.open("rb") as stream:
        while True:
            block = stream.read(block_size)
            if not block:
                # An empty read means end of file.
                break
            hasher.update(block)
    return hasher.hexdigest()
def verify_file(path: Path) -> None:
    """Check that the file at *path* has the pinned size and SHA-256 digest.

    The byte count is compared first, so a file of the wrong size is
    rejected without being hashed.

    Raises:
        SystemExit: with a message giving the observed and the expected
            value, if either the size or the digest differs.
    """
    size = path.stat().st_size
    if size != EXPECTED_BYTES:
        size_problem = f"{path} has {size:,} bytes; expected {EXPECTED_BYTES:,}."
        raise SystemExit(size_problem)

    actual = sha256(path)
    if actual != EXPECTED_SHA256:
        digest_problem = f"{path} has sha256 {actual}; expected {EXPECTED_SHA256}."
        raise SystemExit(digest_problem)
def parse_args() -> argparse.Namespace:
    """Parse the script's command-line options from ``sys.argv``.

    Returns:
        A namespace with ``output_dir`` (a ``Path``; defaults to
        ``data/wikitext103_raw``) and ``force`` (``True`` only when
        ``--force`` is passed).
    """
    cli = argparse.ArgumentParser(
        description="Prepare the WikiText-103 raw parquet corpus holdout."
    )
    # Each entry is (flag, keyword arguments for add_argument), registered in order.
    option_table = (
        (
            "--output-dir",
            {
                "type": Path,
                "default": Path("data/wikitext103_raw"),
                "help": "Directory where the parquet file should be stored.",
            },
        ),
        (
            "--force",
            {
                "action": "store_true",
                "help": "Download again even if the target file already exists.",
            },
        ),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
def main() -> None:
    """Ensure a verified copy of the parquet file exists in the output directory.

    If the target file already exists and ``--force`` was not given, it is
    only verified. Otherwise the file is downloaded to a sibling ``.part``
    path, verified there, and then moved onto the final path. A failed or
    interrupted download therefore never leaves an unverified file at the
    final path, and ``--force`` never destroys an existing copy before the
    replacement has passed verification.

    The final path is printed on success.

    Raises:
        SystemExit: if the existing file or the downloaded file fails
            verification (for a download, the message names the ``.part``
            path that was checked).
    """
    args = parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)
    target = args.output_dir / "train-00001-of-00002.parquet"

    if target.exists() and not args.force:
        verify_file(target)
        print(target)
        return

    print(f"Downloading WikiText-103 raw train parquet to {target}")
    # Same directory as the target so the final replace() stays on one filesystem.
    partial = target.with_name(target.name + ".part")
    try:
        urlretrieve(WIKITEXT103_RAW_TRAIN_URL, partial)
        verify_file(partial)
        partial.replace(target)
    finally:
        # On success the partial file has already been moved away; on any
        # failure (network error, bad checksum, Ctrl-C) remove the leftover.
        try:
            partial.unlink()
        except FileNotFoundError:
            pass
    print(target)


# Run the preparation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()