| |
| """Download the public WikiText-103 raw parquet used for corpus holdouts.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import hashlib |
| from pathlib import Path |
| from urllib.request import urlretrieve |
|
|
|
|
# Source location of the parquet shard. The 40-hex-digit path segment after
# ``resolve/`` looks like a pinned dataset revision, which is what lets the
# size and checksum below stay fixed — confirm before bumping it.
_HF_RESOLVE_ROOT = "https://huggingface.co/datasets/Salesforce/wikitext/resolve/"
_HF_REVISION = "6231e49f19a707241d6b84d9cff60a3a86b85a85"
_HF_SHARD = "wikitext-103-raw-v1/train-00001-of-00002.parquet"
WIKITEXT103_RAW_TRAIN_URL = f"{_HF_RESOLVE_ROOT}{_HF_REVISION}/{_HF_SHARD}?download=true"

# Integrity pins for the downloaded shard: exact byte length and hex SHA-256.
EXPECTED_BYTES = 156_700_942
EXPECTED_SHA256 = "75aa65dee9de2a7c10ba1808efd2408c3f4eb008104c3ccac47f8ed19300ebdd"
|
|
|
|
def sha256(path: Path) -> str:
    """Return the lowercase hex SHA-256 digest of the file at *path*.

    The file is consumed in 1 MiB pieces, so memory use stays bounded no
    matter how large the file is.
    """
    piece_size = 1024 * 1024
    hasher = hashlib.sha256()
    with path.open("rb") as stream:
        while True:
            piece = stream.read(piece_size)
            if not piece:
                break
            hasher.update(piece)
    return hasher.hexdigest()
|
|
|
|
def verify_file(path: Path) -> None:
    """Abort unless *path* has the pinned byte length and SHA-256 digest.

    The cheap size comparison runs first, so an obviously wrong file is
    rejected without hashing its whole contents.

    Raises:
        SystemExit: carrying a message that describes the first mismatch found.
    """
    size_on_disk = path.stat().st_size
    if size_on_disk != EXPECTED_BYTES:
        problem = f"{path} has {size_on_disk:,} bytes; expected {EXPECTED_BYTES:,}."
        raise SystemExit(problem)

    digest_on_disk = sha256(path)
    if digest_on_disk != EXPECTED_SHA256:
        problem = f"{path} has sha256 {digest_on_disk}; expected {EXPECTED_SHA256}."
        raise SystemExit(problem)
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of this download script.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which preserves the
            original behavior. Passing an explicit list lets tests or other
            code drive the parser without touching ``sys.argv``.

    Returns:
        A namespace with ``output_dir`` (a :class:`~pathlib.Path`, defaulting
        to ``data/wikitext103_raw``) and ``force`` (``bool``, default ``False``).
    """
    parser = argparse.ArgumentParser(
        description="Prepare the WikiText-103 raw parquet corpus holdout."
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("data/wikitext103_raw"),
        help="Directory where the parquet file should be stored.",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Download again even if the target file already exists.",
    )
    return parser.parse_args(argv)
|
|
|
|
def main() -> None:
    """Ensure a verified copy of the WikiText-103 raw train parquet exists.

    If the target file already exists and ``--force`` was not given, it is
    verified and reused. Otherwise the parquet is downloaded to a sibling
    ``.part`` file and verified there. Only then is it moved over the target,
    so an interrupted or corrupted download never leaves a bad file at the
    target path and never clobbers a previously good one. The final path is
    printed on success.

    Raises:
        SystemExit: if the existing or freshly downloaded file fails
            verification.
    """
    args = parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)
    target = args.output_dir / "train-00001-of-00002.parquet"
    if target.exists() and not args.force:
        verify_file(target)
        print(target)
        return

    print(f"Downloading WikiText-103 raw train parquet to {target}")
    partial = target.with_name(target.name + ".part")
    try:
        urlretrieve(WIKITEXT103_RAW_TRAIN_URL, partial)
        verify_file(partial)
    except BaseException:
        # BaseException so Ctrl-C and the SystemExit raised by verify_file
        # also clean up; otherwise a truncated file would linger on disk.
        if partial.exists():
            partial.unlink()
        raise
    # Same directory, hence same filesystem: os.replace semantics make the
    # swap atomic on POSIX, and it overwrites an existing target.
    partial.replace(target)
    print(target)
|
|
|
|
if __name__ == "__main__":
    # main() returns None, so this exits with status 0 on success; failures
    # propagate as the SystemExit raised inside main().
    raise SystemExit(main())
|
|