Mandeep Sidhu committed on
Commit
e39c73c
·
1 Parent(s): 555cf14

Add reproducible WikiText corpus prep

Browse files
.gitignore CHANGED
@@ -2,6 +2,7 @@
2
  __pycache__/
3
  *.py[cod]
4
  .cache/
 
5
  *.npy
6
  *.pdf
7
  .venv/
 
2
  __pycache__/
3
  *.py[cod]
4
  .cache/
5
+ data/
6
  *.npy
7
  *.pdf
8
  .venv/
scripts/prepare_wikitext103.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Download the public WikiText-103 raw parquet used for corpus holdouts."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import hashlib
8
+ from pathlib import Path
9
+ from urllib.request import urlretrieve
10
+
11
+
12
+ WIKITEXT103_RAW_TRAIN_URL = (
13
+ "https://huggingface.co/datasets/Salesforce/wikitext/resolve/"
14
+ "6231e49f19a707241d6b84d9cff60a3a86b85a85/"
15
+ "wikitext-103-raw-v1/train-00001-of-00002.parquet?download=true"
16
+ )
17
+ EXPECTED_BYTES = 156_700_942
18
+ EXPECTED_SHA256 = "75aa65dee9de2a7c10ba1808efd2408c3f4eb008104c3ccac47f8ed19300ebdd"
19
+
20
+
21
def sha256(path: Path) -> str:
    """Return the hexadecimal SHA-256 digest of the file at *path*.

    The file is read in 1 MiB pieces so that large files are hashed
    without being loaded into memory all at once.
    """
    piece_size = 1024 * 1024
    hasher = hashlib.sha256()
    with path.open("rb") as stream:
        while True:
            piece = stream.read(piece_size)
            if not piece:
                # An empty read marks end of file.
                break
            hasher.update(piece)
    return hasher.hexdigest()
27
+
28
+
29
def verify_file(path: Path) -> None:
    """Abort unless *path* matches the pinned byte count and SHA-256 digest.

    The size comparison runs first, so a file of the wrong length is
    rejected without being hashed.

    Raises:
        SystemExit: If the file size differs from ``EXPECTED_BYTES`` or
            its digest differs from ``EXPECTED_SHA256``.
    """
    byte_count = path.stat().st_size
    if byte_count != EXPECTED_BYTES:
        size_message = f"{path} has {byte_count:,} bytes; expected {EXPECTED_BYTES:,}."
        raise SystemExit(size_message)

    checksum = sha256(path)
    if checksum == EXPECTED_SHA256:
        return
    digest_message = f"{path} has sha256 {checksum}; expected {EXPECTED_SHA256}."
    raise SystemExit(digest_message)
40
+
41
+
42
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the corpus preparation script.

    Returns:
        A namespace with ``output_dir`` (a ``Path``) and ``force`` (a bool).
    """
    default_output_dir = Path("data/wikitext103_raw")
    cli = argparse.ArgumentParser(
        description="Prepare the WikiText-103 raw parquet corpus holdout."
    )
    cli.add_argument(
        "--output-dir",
        default=default_output_dir,
        type=Path,
        help="Directory where the parquet file should be stored.",
    )
    cli.add_argument(
        "--force",
        help="Download again even if the target file already exists.",
        action="store_true",
    )
    options = cli.parse_args()
    return options
58
+
59
+
60
def main() -> None:
    """Ensure the verified WikiText-103 parquet file exists and print its path.

    If the target file already exists and ``--force`` was not given, it is
    only verified. Otherwise the file is downloaded to a sibling ``.part``
    file, verified there, and then moved into place. A failed or corrupt
    download therefore never leaves a bad file at the target path and never
    destroys a previously good copy.

    Raises:
        SystemExit: Propagated from ``verify_file`` when the size or digest
            does not match. For a fresh download the message names the
            ``.part`` file.
    """
    args = parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)
    target = args.output_dir / "train-00001-of-00002.parquet"
    if target.exists() and not args.force:
        verify_file(target)
        print(target)
        return

    print(f"Downloading WikiText-103 raw train parquet to {target}")
    partial = target.with_name(target.name + ".part")
    try:
        urlretrieve(WIKITEXT103_RAW_TRAIN_URL, partial)
        verify_file(partial)
        # Same directory as the target, so this is a rename rather than a copy.
        partial.replace(target)
    finally:
        # Remove leftovers after a failure; a no-op once the rename succeeded.
        partial.unlink(missing_ok=True)
    print(target)
73
+
74
+
75
+ if __name__ == "__main__":
76
+ main()
src/dropout_decay/data.py CHANGED
@@ -41,6 +41,10 @@ class CachedTokenizer:
41
  vocab_size: int
42
 
43
 
 
 
 
 
44
  def resolve_paths(corpus: str | None, corpus_glob: str | None) -> list[Path]:
45
  paths: list[Path] = []
46
  if corpus:
@@ -83,7 +87,7 @@ def load_cached_splits(
83
 
84
  tokenizer = CachedTokenizer(vocab_size=vocab_size)
85
  tokens = np.load(encoded_path, mmap_mode="r")
86
- need_total = max_required_train_tokens + val_tokens
87
  if len(tokens) < need_total and not allow_short_corpus:
88
  raise ValueError(
89
  f"cached token file has {len(tokens):,} tokens, but {need_total:,} "
@@ -173,7 +177,7 @@ def encode_corpus(
173
  dtype = np.uint16 if tokenizer.vocab_size <= np.iinfo(np.uint16).max else np.uint32
174
  encoded_path = output_dir / f"tokens-v{tokenizer.vocab_size}-{dtype.__name__}.npy"
175
  tokenizer_path = output_dir / f"tokenizer-v{tokenizer.vocab_size}.json"
176
- need_total = max_required_train_tokens + val_tokens
177
  encode_needed = force_reencode or not encoded_path.exists()
178
  if not encode_needed:
179
  cached_tokens = np.load(encoded_path, mmap_mode="r")
 
41
  vocab_size: int
42
 
43
 
44
def required_token_count(
    max_required_train_tokens: int,
    val_tokens: int,
    *,
    val_multiple: int = 10,
) -> int:
    """Return the minimum number of corpus tokens callers should require.

    The result is the larger of ``max_required_train_tokens + val_tokens``
    and ``val_tokens * val_multiple``.

    Args:
        max_required_train_tokens: Largest number of training tokens needed.
        val_tokens: Number of tokens reserved for validation.
        val_multiple: Factor applied to ``val_tokens`` for the second bound.
            Defaults to 10, the value that was previously hard-coded.
    """
    train_plus_val = max_required_train_tokens + val_tokens
    scaled_val = val_tokens * val_multiple
    return max(train_plus_val, scaled_val)
46
+
47
+
48
  def resolve_paths(corpus: str | None, corpus_glob: str | None) -> list[Path]:
49
  paths: list[Path] = []
50
  if corpus:
 
87
 
88
  tokenizer = CachedTokenizer(vocab_size=vocab_size)
89
  tokens = np.load(encoded_path, mmap_mode="r")
90
+ need_total = required_token_count(max_required_train_tokens, val_tokens)
91
  if len(tokens) < need_total and not allow_short_corpus:
92
  raise ValueError(
93
  f"cached token file has {len(tokens):,} tokens, but {need_total:,} "
 
177
  dtype = np.uint16 if tokenizer.vocab_size <= np.iinfo(np.uint16).max else np.uint32
178
  encoded_path = output_dir / f"tokens-v{tokenizer.vocab_size}-{dtype.__name__}.npy"
179
  tokenizer_path = output_dir / f"tokenizer-v{tokenizer.vocab_size}.json"
180
+ need_total = required_token_count(max_required_train_tokens, val_tokens)
181
  encode_needed = force_reencode or not encoded_path.exists()
182
  if not encode_needed:
183
  cached_tokens = np.load(encoded_path, mmap_mode="r")