"""Prepare WikiText-103 for language-model training.

Downloads the dataset, trains (or reloads) a byte-level BPE tokenizer, and
writes each split as a flat binary file of token ids, plus a ``meta.json``
file describing the encoding.
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path

import numpy as np
from datasets import load_dataset
from tokenizers import ByteLevelBPETokenizer, Tokenizer
from tokenizers import decoders as _decoders
from tqdm import tqdm

# NOTE(review): in the reviewed source all four special-token literals were
# empty strings, which made BOS/EOS/PAD indistinguishable. The names and the
# order below are a restoration choice -- confirm them against any
# tokenizer.json that downstream code already depends on.
PAD_TOKEN = "<pad>"
BOS_TOKEN = "<bos>"
EOS_TOKEN = "<eos>"
UNK_TOKEN = "<unk>"
SPECIAL_TOKENS = [PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN]

# Tokens that main() looks up by name and therefore must exist in the vocab.
_REQUIRED_TOKENS = (BOS_TOKEN, EOS_TOKEN, PAD_TOKEN)


def clean_lines(dataset):
    """Yield the stripped ``"text"`` field of every row, skipping empty ones.

    Args:
        dataset: Iterable of mappings that each have a ``"text"`` string.

    Yields:
        Each non-empty text with leading/trailing whitespace removed.
    """
    for row in dataset:
        text = row["text"].strip()
        if text:
            yield text


class _TokenizerAdapter:
    """Thin wrapper exposing the three tokenizer calls this script uses.

    The rest of the script only needs ``.encode(text).ids``, ``.get_vocab()``
    and ``.get_vocab_size()``; routing them through this adapter keeps that
    surface the same whichever tokenizer class is wrapped.
    """

    def __init__(self, tokenizer):
        self._t = tokenizer

    def encode(self, text: str):
        """Encode *text* and return the wrapped tokenizer's encoding object."""
        return self._t.encode(text)

    def get_vocab(self):
        """Return the wrapped tokenizer's token -> id mapping."""
        return self._t.get_vocab()

    def get_vocab_size(self):
        """Return the wrapped tokenizer's vocabulary size."""
        return self._t.get_vocab_size()


def _ensure_byte_level_decoder(tokenizer):
    """Attach a ByteLevel decoder to *tokenizer* if it has none; return it."""
    # Reading the decoder is deliberately best-effort: any failure is treated
    # the same as "no decoder set" so that one gets attached below.
    try:
        current_decoder = tokenizer.decoder
    except Exception:
        current_decoder = None
    if current_decoder is None:
        tokenizer.decoder = _decoders.ByteLevel()
    return tokenizer


def load_or_train_tokenizer(
    tokenizer_path: Path,
    train_dataset,
    vocab_size: int,
    min_frequency: int,
):
    """Return a tokenizer adapter, training and saving a tokenizer if needed.

    If *tokenizer_path* already exists it is reused as-is (``vocab_size`` and
    ``min_frequency`` are then ignored). Otherwise a byte-level BPE tokenizer
    is trained on the cleaned lines of *train_dataset* and saved there.

    Args:
        tokenizer_path: Location of the tokenizer JSON file.
        train_dataset: Rows with a ``"text"`` field, used only for training.
        vocab_size: Target vocabulary size for training.
        min_frequency: Minimum frequency passed to the trainer.

    Returns:
        A ``_TokenizerAdapter`` around the tokenizer loaded from
        *tokenizer_path*.
    """
    if tokenizer_path.exists():
        print(f"Using existing tokenizer at {tokenizer_path}")
    else:
        print("Training byte-level BPE tokenizer...")
        trainer = ByteLevelBPETokenizer()
        trainer.train_from_iterator(
            clean_lines(train_dataset),
            vocab_size=vocab_size,
            min_frequency=min_frequency,
            special_tokens=SPECIAL_TOKENS,
        )
        trainer.save(str(tokenizer_path))

    # Always (re)load through the generic Tokenizer class so both paths hand
    # back the same kind of object. ByteLevelBPETokenizer does NOT accept
    # tokenizer_file= in current tokenizers releases.
    loaded = Tokenizer.from_file(str(tokenizer_path))
    return _TokenizerAdapter(_ensure_byte_level_decoder(loaded))


def write_split(tokenizer, dataset, out_file: Path, dtype, bos_id: int, eos_id: int) -> int:
    """Tokenize *dataset* and write the ids to *out_file* as raw binary.

    Each non-empty line is encoded, wrapped as ``[bos_id] + ids + [eos_id]``
    and appended to the file with ``ndarray.tofile`` using *dtype*.

    Args:
        tokenizer: Object whose ``encode(text)`` result has an ``ids`` list.
        dataset: Rows with a ``"text"`` field.
        out_file: Destination path; overwritten if it exists.
        dtype: NumPy integer dtype used for the on-disk ids.
        bos_id: Id written before every line.
        eos_id: Id written after every line.

    Returns:
        Total number of ids written, including the BOS/EOS markers.
    """
    token_count = 0
    with out_file.open("wb") as f:
        for text in tqdm(clean_lines(dataset), desc=f"tokenizing {out_file.name}"):
            ids = [bos_id] + tokenizer.encode(text).ids + [eos_id]
            np.asarray(ids, dtype=dtype).tofile(f)
            token_count += len(ids)
    return token_count


def main() -> None:
    """Parse CLI arguments and build the tokenizer, token files and metadata."""
    parser = argparse.ArgumentParser(
        description="Download WikiText-103, train a tokenizer, and make binary token files."
    )
    parser.add_argument("--data_dir", type=Path, default=Path("data/wikitext103"))
    parser.add_argument("--dataset", type=str, default="Salesforce/wikitext")
    parser.add_argument("--config", type=str, default="wikitext-103-raw-v1")
    parser.add_argument("--vocab_size", type=int, default=32000)
    parser.add_argument("--min_frequency", type=int, default=2)
    args = parser.parse_args()

    args.data_dir.mkdir(parents=True, exist_ok=True)
    tokenizer_path = args.data_dir / "tokenizer.json"

    print("Loading WikiText-103...")
    train = load_dataset(args.dataset, args.config, split="train")
    val = load_dataset(args.dataset, args.config, split="validation")
    test = load_dataset(args.dataset, args.config, split="test")

    tokenizer = load_or_train_tokenizer(
        tokenizer_path=tokenizer_path,
        train_dataset=train,
        vocab_size=args.vocab_size,
        min_frequency=args.min_frequency,
    )

    vocab = tokenizer.get_vocab()
    missing = [token for token in _REQUIRED_TOKENS if token not in vocab]
    if missing:
        # Report the real path: --data_dir may differ from the default.
        raise RuntimeError(
            f"Tokenizer is missing required special tokens ({', '.join(missing)}). "
            f"Delete {tokenizer_path} and re-run to retrain."
        )

    bos_id = vocab[BOS_TOKEN]
    eos_id = vocab[EOS_TOKEN]
    pad_id = vocab[PAD_TOKEN]

    vocab_size = tokenizer.get_vocab_size()
    # Use the narrowest dtype that can hold every id in the vocabulary.
    dtype = np.uint16 if vocab_size <= np.iinfo(np.uint16).max else np.uint32

    train_tokens = write_split(tokenizer, train, args.data_dir / "train.bin", dtype, bos_id, eos_id)
    val_tokens = write_split(tokenizer, val, args.data_dir / "val.bin", dtype, bos_id, eos_id)
    test_tokens = write_split(tokenizer, test, args.data_dir / "test.bin", dtype, bos_id, eos_id)

    # Everything a reader of the .bin files needs to interpret them.
    meta = {
        "dataset": args.dataset,
        "config": args.config,
        "vocab_size": vocab_size,
        "dtype": np.dtype(dtype).name,
        "bos_id": bos_id,
        "eos_id": eos_id,
        "pad_id": pad_id,
        "train_tokens": train_tokens,
        "val_tokens": val_tokens,
        "test_tokens": test_tokens,
    }
    (args.data_dir / "meta.json").write_text(json.dumps(meta, indent=2), encoding="utf-8")

    print(f"Done. Wrote tokenizer and token files to {args.data_dir}")


if __name__ == "__main__":
    main()