Mandeep Sidhu commited on
Commit ·
e39c73c
1
Parent(s): 555cf14
Add reproducible WikiText corpus prep
Browse files- .gitignore +1 -0
- scripts/prepare_wikitext103.py +76 -0
- src/dropout_decay/data.py +6 -2
.gitignore
CHANGED
|
@@ -2,6 +2,7 @@
|
|
| 2 |
__pycache__/
|
| 3 |
*.py[cod]
|
| 4 |
.cache/
|
|
|
|
| 5 |
*.npy
|
| 6 |
*.pdf
|
| 7 |
.venv/
|
|
|
|
| 2 |
__pycache__/
|
| 3 |
*.py[cod]
|
| 4 |
.cache/
|
| 5 |
+
data/
|
| 6 |
*.npy
|
| 7 |
*.pdf
|
| 8 |
.venv/
|
scripts/prepare_wikitext103.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Download the public WikiText-103 raw parquet used for corpus holdouts."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import hashlib
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from urllib.request import urlretrieve
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
WIKITEXT103_RAW_TRAIN_URL = (
|
| 13 |
+
"https://huggingface.co/datasets/Salesforce/wikitext/resolve/"
|
| 14 |
+
"6231e49f19a707241d6b84d9cff60a3a86b85a85/"
|
| 15 |
+
"wikitext-103-raw-v1/train-00001-of-00002.parquet?download=true"
|
| 16 |
+
)
|
| 17 |
+
EXPECTED_BYTES = 156_700_942
|
| 18 |
+
EXPECTED_SHA256 = "75aa65dee9de2a7c10ba1808efd2408c3f4eb008104c3ccac47f8ed19300ebdd"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def sha256(path: Path) -> str:
|
| 22 |
+
digest = hashlib.sha256()
|
| 23 |
+
with path.open("rb") as handle:
|
| 24 |
+
for chunk in iter(lambda: handle.read(1024 * 1024), b""):
|
| 25 |
+
digest.update(chunk)
|
| 26 |
+
return digest.hexdigest()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def verify_file(path: Path) -> None:
|
| 30 |
+
size = path.stat().st_size
|
| 31 |
+
if size != EXPECTED_BYTES:
|
| 32 |
+
raise SystemExit(
|
| 33 |
+
f"{path} has {size:,} bytes; expected {EXPECTED_BYTES:,}."
|
| 34 |
+
)
|
| 35 |
+
actual = sha256(path)
|
| 36 |
+
if actual != EXPECTED_SHA256:
|
| 37 |
+
raise SystemExit(
|
| 38 |
+
f"{path} has sha256 {actual}; expected {EXPECTED_SHA256}."
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def parse_args() -> argparse.Namespace:
|
| 43 |
+
parser = argparse.ArgumentParser(
|
| 44 |
+
description="Prepare the WikiText-103 raw parquet corpus holdout."
|
| 45 |
+
)
|
| 46 |
+
parser.add_argument(
|
| 47 |
+
"--output-dir",
|
| 48 |
+
type=Path,
|
| 49 |
+
default=Path("data/wikitext103_raw"),
|
| 50 |
+
help="Directory where the parquet file should be stored.",
|
| 51 |
+
)
|
| 52 |
+
parser.add_argument(
|
| 53 |
+
"--force",
|
| 54 |
+
action="store_true",
|
| 55 |
+
help="Download again even if the target file already exists.",
|
| 56 |
+
)
|
| 57 |
+
return parser.parse_args()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def main() -> None:
|
| 61 |
+
args = parse_args()
|
| 62 |
+
args.output_dir.mkdir(parents=True, exist_ok=True)
|
| 63 |
+
target = args.output_dir / "train-00001-of-00002.parquet"
|
| 64 |
+
if target.exists() and not args.force:
|
| 65 |
+
verify_file(target)
|
| 66 |
+
print(target)
|
| 67 |
+
return
|
| 68 |
+
|
| 69 |
+
print(f"Downloading WikiText-103 raw train parquet to {target}")
|
| 70 |
+
urlretrieve(WIKITEXT103_RAW_TRAIN_URL, target)
|
| 71 |
+
verify_file(target)
|
| 72 |
+
print(target)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
if __name__ == "__main__":
|
| 76 |
+
main()
|
src/dropout_decay/data.py
CHANGED
|
@@ -41,6 +41,10 @@ class CachedTokenizer:
|
|
| 41 |
vocab_size: int
|
| 42 |
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
def resolve_paths(corpus: str | None, corpus_glob: str | None) -> list[Path]:
|
| 45 |
paths: list[Path] = []
|
| 46 |
if corpus:
|
|
@@ -83,7 +87,7 @@ def load_cached_splits(
|
|
| 83 |
|
| 84 |
tokenizer = CachedTokenizer(vocab_size=vocab_size)
|
| 85 |
tokens = np.load(encoded_path, mmap_mode="r")
|
| 86 |
-
need_total = max_required_train_tokens
|
| 87 |
if len(tokens) < need_total and not allow_short_corpus:
|
| 88 |
raise ValueError(
|
| 89 |
f"cached token file has {len(tokens):,} tokens, but {need_total:,} "
|
|
@@ -173,7 +177,7 @@ def encode_corpus(
|
|
| 173 |
dtype = np.uint16 if tokenizer.vocab_size <= np.iinfo(np.uint16).max else np.uint32
|
| 174 |
encoded_path = output_dir / f"tokens-v{tokenizer.vocab_size}-{dtype.__name__}.npy"
|
| 175 |
tokenizer_path = output_dir / f"tokenizer-v{tokenizer.vocab_size}.json"
|
| 176 |
-
need_total = max_required_train_tokens
|
| 177 |
encode_needed = force_reencode or not encoded_path.exists()
|
| 178 |
if not encode_needed:
|
| 179 |
cached_tokens = np.load(encoded_path, mmap_mode="r")
|
|
|
|
| 41 |
vocab_size: int
|
| 42 |
|
| 43 |
|
| 44 |
+
def required_token_count(max_required_train_tokens: int, val_tokens: int) -> int:
|
| 45 |
+
return max(max_required_train_tokens + val_tokens, val_tokens * 10)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
def resolve_paths(corpus: str | None, corpus_glob: str | None) -> list[Path]:
|
| 49 |
paths: list[Path] = []
|
| 50 |
if corpus:
|
|
|
|
| 87 |
|
| 88 |
tokenizer = CachedTokenizer(vocab_size=vocab_size)
|
| 89 |
tokens = np.load(encoded_path, mmap_mode="r")
|
| 90 |
+
need_total = required_token_count(max_required_train_tokens, val_tokens)
|
| 91 |
if len(tokens) < need_total and not allow_short_corpus:
|
| 92 |
raise ValueError(
|
| 93 |
f"cached token file has {len(tokens):,} tokens, but {need_total:,} "
|
|
|
|
| 177 |
dtype = np.uint16 if tokenizer.vocab_size <= np.iinfo(np.uint16).max else np.uint32
|
| 178 |
encoded_path = output_dir / f"tokens-v{tokenizer.vocab_size}-{dtype.__name__}.npy"
|
| 179 |
tokenizer_path = output_dir / f"tokenizer-v{tokenizer.vocab_size}.json"
|
| 180 |
+
need_total = required_token_count(max_required_train_tokens, val_tokens)
|
| 181 |
encode_needed = force_reencode or not encoded_path.exists()
|
| 182 |
if not encode_needed:
|
| 183 |
cached_tokens = np.load(encoded_path, mmap_mode="r")
|