# RON-110M / code / prepare_wikitext.py
# Uploaded by endurasolution: "Upload Ron-110M: pretrain + summarizer + tokenizer + code"
# Commit 3b97420 (verified)
from __future__ import annotations
import argparse
import json
from pathlib import Path
import numpy as np
from datasets import load_dataset
from tokenizers import ByteLevelBPETokenizer, Tokenizer
from tokenizers import decoders as _decoders
from tqdm import tqdm
# Special tokens registered with the BPE trainer. main() looks up <pad>, <bos>
# and <eos> by these exact spellings, so keep the two in sync.
SPECIAL_TOKENS = ["<pad>", "<bos>", "<eos>", "<unk>"]


def clean_lines(dataset):
    """Yield each row's ``"text"`` value with surrounding whitespace removed.

    Rows whose text is empty after stripping (e.g. blank separator lines) are
    skipped. This is a generator: *dataset* is not touched until the result is
    iterated.
    """
    stripped = (record["text"].strip() for record in dataset)
    # filter(None, ...) keeps only truthy, i.e. non-empty, strings.
    yield from filter(None, stripped)
class _TokenizerAdapter:
    """Thin facade exposing only the tokenizer calls this script relies on.

    Callers use ``encode(text).ids``, ``get_vocab()`` and ``get_vocab_size()``.
    Each method forwards unchanged to the wrapped object, which in this script
    is the ``Tokenizer`` produced by ``load_or_train_tokenizer`` (either loaded
    from an existing JSON file or reopened right after training).
    """

    def __init__(self, tokenizer):
        # The wrapped tokenizer; every public method below delegates to it.
        self._t = tokenizer

    def encode(self, text: str):
        """Encode *text* with the wrapped tokenizer and return its result as-is."""
        inner = self._t
        return inner.encode(text)

    def get_vocab(self):
        """Return the wrapped tokenizer's token -> id mapping."""
        inner = self._t
        return inner.get_vocab()

    def get_vocab_size(self):
        """Return the wrapped tokenizer's vocabulary size."""
        inner = self._t
        return inner.get_vocab_size()
def _load_with_byte_level_decoder(tokenizer_path: Path) -> Tokenizer:
    """Load a tokenizer JSON file and make sure it has a ByteLevel decoder attached.

    Uses the generic ``Tokenizer`` class because ``ByteLevelBPETokenizer`` does
    NOT accept ``tokenizer_file=`` in current ``tokenizers`` releases.
    """
    tokenizer = Tokenizer.from_file(str(tokenizer_path))
    try:
        current_decoder = tokenizer.decoder
    except Exception:  # defensive: an unreadable decoder is treated as missing
        current_decoder = None
    if current_decoder is None:
        # Attach ByteLevel so downstream decoding round-trips byte-level BPE ids.
        tokenizer.decoder = _decoders.ByteLevel()
    return tokenizer


def load_or_train_tokenizer(tokenizer_path: Path, train_dataset, vocab_size: int, min_frequency: int) -> _TokenizerAdapter:
    """Return a tokenizer adapter, reusing *tokenizer_path* if it exists, else training one.

    Args:
        tokenizer_path: JSON file to load from, or to save a newly trained tokenizer to.
        train_dataset: Iterable of rows with a ``"text"`` field (read via ``clean_lines``);
            only used when training.
        vocab_size: Target vocabulary size for training. When an existing tokenizer
            is reused it is not retrained; a warning is printed if its size differs.
        min_frequency: Minimum pair frequency for BPE merges; only used when training.

    Returns:
        A ``_TokenizerAdapter`` wrapping a ``Tokenizer`` that has a decoder attached
        (ByteLevel if none was set).
    """
    if tokenizer_path.exists():
        print(f"Using existing tokenizer at {tokenizer_path}")
        tokenizer = _load_with_byte_level_decoder(tokenizer_path)
        actual_size = tokenizer.get_vocab_size()
        if actual_size != vocab_size:
            # The requested size would otherwise be ignored without any notice.
            print(
                f"Warning: existing tokenizer has vocab size {actual_size}, "
                f"but vocab_size={vocab_size} was requested. "
                f"Delete {tokenizer_path} to retrain."
            )
        return _TokenizerAdapter(tokenizer)

    # Create the output directory up front so a missing parent does not
    # discard a finished (and slow) training run at save time.
    tokenizer_path.parent.mkdir(parents=True, exist_ok=True)

    print("Training byte-level BPE tokenizer...")
    trainer = ByteLevelBPETokenizer()
    trainer.train_from_iterator(
        clean_lines(train_dataset),
        vocab_size=vocab_size,
        min_frequency=min_frequency,
        special_tokens=SPECIAL_TOKENS,
    )
    trainer.save(str(tokenizer_path))
    # Reopen via the generic Tokenizer so both paths return the same type with
    # the decoder attached the same way.
    return _TokenizerAdapter(_load_with_byte_level_decoder(tokenizer_path))
def write_split(
    tokenizer,
    dataset,
    out_file: Path,
    dtype,
    bos_id: int,
    eos_id: int,
    *,
    flush_tokens: int = 1 << 20,
) -> int:
    """Tokenize every non-empty line of *dataset* into a flat binary token file.

    Each line becomes ``[bos_id] + tokenizer.encode(text).ids + [eos_id]``. All
    sequences are concatenated and stored as raw *dtype* values with no header,
    the same layout ``ndarray.tofile`` produces.

    Ids are buffered and written in chunks of at least *flush_tokens* tokens,
    instead of allocating an array and issuing a tiny write per line. Data goes
    to a sibling ``<name>.tmp`` file, which atomically replaces *out_file* only
    after every line is written. An interrupted run therefore never leaves a
    truncated file that looks complete.

    Args:
        tokenizer: Object whose ``encode(text).ids`` returns a list of ints.
        dataset: Iterable of rows with a ``"text"`` field.
        out_file: Destination ``.bin`` path.
        dtype: NumPy integer dtype used to store ids.
        bos_id: Id written before each line.
        eos_id: Id written after each line.
        flush_tokens: Buffer size, in tokens, that triggers a write (must be >= 1).

    Returns:
        Total number of tokens written, including the BOS/EOS markers.

    Raises:
        ValueError: If *flush_tokens* < 1.
    """
    if flush_tokens < 1:
        raise ValueError(f"flush_tokens must be >= 1, got {flush_tokens}")

    tmp_file = out_file.with_name(out_file.name + ".tmp")
    token_count = 0
    buffer: list[int] = []

    def _flush(handle) -> None:
        # Convert the pending ids in one shot and append them to the file.
        if buffer:
            np.asarray(buffer, dtype=dtype).tofile(handle)
            buffer.clear()

    try:
        with tmp_file.open("wb") as f:
            for text in tqdm(clean_lines(dataset), desc=f"tokenizing {out_file.name}"):
                ids = tokenizer.encode(text).ids
                buffer.append(bos_id)
                buffer.extend(ids)
                buffer.append(eos_id)
                token_count += len(ids) + 2
                if len(buffer) >= flush_tokens:
                    _flush(f)
            _flush(f)
        tmp_file.replace(out_file)
    except BaseException:
        # Also covers KeyboardInterrupt: drop the partial file, then re-raise.
        if tmp_file.exists():
            tmp_file.unlink()
        raise
    return token_count
def main() -> None:
    """Command-line entry point: load the dataset, build a tokenizer, write token files.

    Inside ``--data_dir`` this produces ``tokenizer.json`` (reused if already
    present), ``train.bin`` / ``val.bin`` / ``test.bin`` (flat token-id arrays
    written by ``write_split``) and ``meta.json`` describing them.

    Raises:
        RuntimeError: If the tokenizer lacks ``<pad>``, ``<bos>`` or ``<eos>``.
    """
    parser = argparse.ArgumentParser(
        description="Download WikiText-103, train a tokenizer, and make binary token files."
    )
    parser.add_argument("--data_dir", type=Path, default=Path("data/wikitext103"))
    parser.add_argument("--dataset", type=str, default="Salesforce/wikitext")
    parser.add_argument("--config", type=str, default="wikitext-103-raw-v1")
    parser.add_argument("--vocab_size", type=int, default=32000)
    parser.add_argument("--min_frequency", type=int, default=2)
    args = parser.parse_args()

    args.data_dir.mkdir(parents=True, exist_ok=True)
    tokenizer_path = args.data_dir / "tokenizer.json"

    # Report what is actually being loaded; --dataset/--config are configurable.
    print(f"Loading {args.dataset} ({args.config})...")
    train = load_dataset(args.dataset, args.config, split="train")
    val = load_dataset(args.dataset, args.config, split="validation")
    test = load_dataset(args.dataset, args.config, split="test")

    tokenizer = load_or_train_tokenizer(
        tokenizer_path=tokenizer_path,
        train_dataset=train,
        vocab_size=args.vocab_size,
        min_frequency=args.min_frequency,
    )

    vocab = tokenizer.get_vocab()
    missing = [tok for tok in ("<pad>", "<bos>", "<eos>") if tok not in vocab]
    if missing:
        # Point at the real tokenizer path, which follows --data_dir.
        raise RuntimeError(
            f"Tokenizer is missing required special tokens ({', '.join(missing)}). "
            f"Delete {tokenizer_path} and re-run to retrain."
        )
    bos_id = vocab["<bos>"]
    eos_id = vocab["<eos>"]
    pad_id = vocab["<pad>"]

    vocab_size = tokenizer.get_vocab_size()
    # Size the storage dtype by the largest id actually in use. Ids are 0-based,
    # so 65536 tokens still fit in uint16, and added-token ids need not be
    # contiguous with the base vocabulary.
    max_id = max(vocab.values())
    dtype = np.uint16 if max_id <= np.iinfo(np.uint16).max else np.uint32

    token_counts = {}
    for split_name, split_data in (("train", train), ("val", val), ("test", test)):
        token_counts[f"{split_name}_tokens"] = write_split(
            tokenizer,
            split_data,
            args.data_dir / f"{split_name}.bin",
            dtype,
            bos_id,
            eos_id,
        )

    meta = {
        "dataset": args.dataset,
        "config": args.config,
        "vocab_size": vocab_size,
        "dtype": np.dtype(dtype).name,
        "bos_id": bos_id,
        "eos_id": eos_id,
        "pad_id": pad_id,
        **token_counts,  # train_tokens, val_tokens, test_tokens (in that order)
    }
    (args.data_dir / "meta.json").write_text(json.dumps(meta, indent=2), encoding="utf-8")
    print(f"Done. Wrote tokenizer and token files to {args.data_dir}")


if __name__ == "__main__":
    main()