File size: 5,258 Bytes
3b97420 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 | from __future__ import annotations
import argparse
import json
from pathlib import Path
import numpy as np
from datasets import load_dataset
from tokenizers import ByteLevelBPETokenizer, Tokenizer
from tokenizers import decoders as _decoders
from tqdm import tqdm
# Special tokens registered with the BPE trainer in load_or_train_tokenizer.
# main() refuses to continue unless <pad>, <bos> and <eos> are in the vocab;
# their ids are recorded in meta.json. The trainer presumably assigns them the
# lowest ids in list order — confirm via meta.json before relying on that.
SPECIAL_TOKENS = ["<pad>", "<bos>", "<eos>", "<unk>"]
def clean_lines(dataset):
    """Lazily yield the non-blank ``text`` values of *dataset*, whitespace-stripped.

    Each row is indexed as ``row["text"]``; rows whose text is empty after
    stripping are skipped.
    """
    stripped = (record["text"].strip() for record in dataset)
    # filter(None, ...) keeps only truthy (non-empty) strings.
    yield from filter(None, stripped)
class _TokenizerAdapter:
"""
Small adapter so the rest of the script can call .encode(text).ids and
.get_vocab() / .get_vocab_size() regardless of whether the tokenizer was
freshly trained (ByteLevelBPETokenizer) or reloaded from JSON (Tokenizer).
"""
def __init__(self, tokenizer):
self._t = tokenizer
def encode(self, text: str):
return self._t.encode(text)
def get_vocab(self):
return self._t.get_vocab()
def get_vocab_size(self):
return self._t.get_vocab_size()
def _ensure_byte_level_decoder(tokenizer):
    """Attach a ByteLevel decoder to *tokenizer* if none is set; return *tokenizer*."""
    # getattr with a default tolerates tokenizer objects lacking the attribute
    # without swallowing unrelated errors the way a broad except would.
    if getattr(tokenizer, "decoder", None) is None:
        tokenizer.decoder = _decoders.ByteLevel()
    return tokenizer


def load_or_train_tokenizer(tokenizer_path: Path, train_dataset, vocab_size: int, min_frequency: int):
    """Return a tokenizer adapter, reusing *tokenizer_path* if it exists.

    If *tokenizer_path* exists it is loaded as-is. *vocab_size* and
    *min_frequency* are then NOT applied, and a warning is printed if the
    loaded vocab size differs from *vocab_size*. Otherwise a byte-level BPE
    tokenizer is trained on the non-blank lines of *train_dataset* with
    SPECIAL_TOKENS, saved to *tokenizer_path*, and reloaded.

    Both branches return a ``_TokenizerAdapter`` wrapping a ``Tokenizer`` that
    has a decoder attached.
    """
    if tokenizer_path.exists():
        print(f"Using existing tokenizer at {tokenizer_path}")
        # Reload via the generic Tokenizer class. ByteLevelBPETokenizer does NOT
        # accept tokenizer_file= in current tokenizers releases.
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
        existing_size = tokenizer.get_vocab_size()
        if existing_size != vocab_size:
            # Stale tokenizers are otherwise reused silently when the CLI asks
            # for a different vocabulary size.
            print(
                f"Warning: existing tokenizer has vocab size {existing_size}, "
                f"which differs from the requested vocab_size={vocab_size}; "
                f"reusing it as-is. Delete {tokenizer_path} to retrain."
            )
        return _TokenizerAdapter(_ensure_byte_level_decoder(tokenizer))

    print("Training byte-level BPE tokenizer...")
    bpe = ByteLevelBPETokenizer()
    bpe.train_from_iterator(
        clean_lines(train_dataset),
        vocab_size=vocab_size,
        min_frequency=min_frequency,
        special_tokens=SPECIAL_TOKENS,
    )
    bpe.save(str(tokenizer_path))
    # Reopen via the generic Tokenizer so both branches return the same type
    # and get the same decoder handling.
    reopened = Tokenizer.from_file(str(tokenizer_path))
    return _TokenizerAdapter(_ensure_byte_level_decoder(reopened))
def write_split(tokenizer, dataset, out_file: Path, dtype, bos_id: int, eos_id: int) -> int:
    """Tokenize every non-blank line of *dataset* and append the ids to *out_file*.

    Each line is framed as ``[bos_id, *ids, eos_id]``, converted to *dtype* and
    written as raw binary. The per-line arrays are concatenated with no header,
    and *out_file* is truncated first.

    Returns the total number of token ids written, including BOS/EOS markers.
    """
    total = 0
    with out_file.open("wb") as sink:
        lines = clean_lines(dataset)
        for line in tqdm(lines, desc=f"tokenizing {out_file.name}"):
            body = tokenizer.encode(line).ids
            framed = [bos_id, *body, eos_id]
            np.asarray(framed, dtype=dtype).tofile(sink)
            total += len(framed)
    return total
def main() -> None:
    """Download the dataset, build or reuse a tokenizer, and write token files.

    Writes ``tokenizer.json`` (if newly trained), ``train.bin``, ``val.bin``,
    ``test.bin`` and ``meta.json`` under ``--data_dir``.

    Raises:
        RuntimeError: if the tokenizer lacks any of <pad>, <bos>, <eos>.
    """
    parser = argparse.ArgumentParser(
        description="Download WikiText-103, train a tokenizer, and make binary token files."
    )
    parser.add_argument("--data_dir", type=Path, default=Path("data/wikitext103"))
    parser.add_argument("--dataset", type=str, default="Salesforce/wikitext")
    parser.add_argument("--config", type=str, default="wikitext-103-raw-v1")
    parser.add_argument("--vocab_size", type=int, default=32000)
    parser.add_argument("--min_frequency", type=int, default=2)
    args = parser.parse_args()

    args.data_dir.mkdir(parents=True, exist_ok=True)
    tokenizer_path = args.data_dir / "tokenizer.json"

    print("Loading WikiText-103...")
    train = load_dataset(args.dataset, args.config, split="train")
    val = load_dataset(args.dataset, args.config, split="validation")
    test = load_dataset(args.dataset, args.config, split="test")

    tokenizer = load_or_train_tokenizer(
        tokenizer_path=tokenizer_path,
        train_dataset=train,
        vocab_size=args.vocab_size,
        min_frequency=args.min_frequency,
    )

    vocab = tokenizer.get_vocab()
    required = ("<pad>", "<bos>", "<eos>")
    missing = [tok for tok in required if tok not in vocab]
    if missing:
        # Point at the actual tokenizer file, which depends on --data_dir.
        raise RuntimeError(
            f"Tokenizer is missing required special tokens ({', '.join(missing)}). "
            f"Delete {tokenizer_path} and re-run to retrain."
        )
    bos_id = vocab["<bos>"]
    eos_id = vocab["<eos>"]
    pad_id = vocab["<pad>"]

    vocab_size = tokenizer.get_vocab_size()
    # Size the on-disk dtype by the largest id the tokenizer can emit, so every
    # id is guaranteed to fit without wrap-around.
    max_id = max(vocab.values())
    dtype = np.uint16 if max_id <= np.iinfo(np.uint16).max else np.uint32

    train_tokens = write_split(tokenizer, train, args.data_dir / "train.bin", dtype, bos_id, eos_id)
    val_tokens = write_split(tokenizer, val, args.data_dir / "val.bin", dtype, bos_id, eos_id)
    test_tokens = write_split(tokenizer, test, args.data_dir / "test.bin", dtype, bos_id, eos_id)

    meta = {
        "dataset": args.dataset,
        "config": args.config,
        "vocab_size": vocab_size,
        "dtype": np.dtype(dtype).name,  # "uint16" or "uint32"
        "bos_id": bos_id,
        "eos_id": eos_id,
        "pad_id": pad_id,
        "train_tokens": train_tokens,
        "val_tokens": val_tokens,
        "test_tokens": test_tokens,
    }
    # meta.json is written last, so its presence signals all splits finished.
    (args.data_dir / "meta.json").write_text(json.dumps(meta, indent=2), encoding="utf-8")
    print(f"Done. Wrote tokenizer and token files to {args.data_dir}")


if __name__ == "__main__":
    main()