| | """ |
| | Train a Byte-Level BPE tokenizer on raw text files. |
| | |
| | The tokenizer is saved in two formats: |
| | 1. Native HuggingFace ``tokenizers`` format (vocab.json + merges.txt) inside |
| | the output directory — for fast loading with ByteLevelBPETokenizer. |
| | 2. A ``tokenizer.json`` file (PreTrainedTokenizerFast) in the output directory |
| | — for easy loading with transformers.AutoTokenizer. |
| | |
| | Usage: |
| | python tokenizer/train_tokenizer.py \ |
| | --input "data/raw/*.txt" \ |
| | --output tokenizer/ \ |
| | --vocab_size 32000 \ |
| | --min_frequency 2 |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import argparse |
| | import glob |
| | import os |
| | import sys |
| | from pathlib import Path |
| |
|
| | from tokenizers import AddedToken |
| | from tokenizers.implementations import ByteLevelBPETokenizer |
| | from transformers import PreTrainedTokenizerFast |
| |
|
| |
|
| | |
| | |
| | |
# Special tokens handed to the trainer and registered on the tokenizer in
# main(). The same four strings are reused there as the pad / bos / eos / unk
# tokens of the PreTrainedTokenizerFast wrapper.
SPECIAL_TOKENS: list[str] = [
    "<pad>",
    "<s>",
    "</s>",
    "<unk>",
]
| |
|
| |
|
| | |
| | |
| | |
| |
|
def find_input_files(pattern: str) -> list[str]:
    """Resolve a glob pattern or a plain file path to a sorted list of files.

    Args:
        pattern: Either a glob expression (recognised by the presence of
            ``*``, ``?`` or ``[``; ``**`` is expanded recursively) or the
            path of a single file.

    Returns:
        The matching regular files, sorted lexicographically.

    Raises:
        FileNotFoundError: If no regular file matches ``pattern``.
    """
    if any(c in pattern for c in ("*", "?", "[")):
        candidates = sorted(glob.glob(pattern, recursive=True))
    else:
        candidates = [pattern]

    # Keep regular files only: globs such as "data/*" or "data/**" also match
    # directories, and a plain path may point at a directory. Neither can be
    # read as a training text file, so they must not reach the trainer.
    files = [path for path in candidates if Path(path).is_file()]
    if not files:
        raise FileNotFoundError(f"No files matched pattern: {pattern!r}")
    return files
| |
|
| |
|
| | |
| | |
| | |
| |
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the tokenizer training script.

    Returns:
        A namespace with ``input`` (required), ``output``, ``vocab_size`` and
        ``min_frequency`` attributes.
    """
    cli = argparse.ArgumentParser(
        description="Train a Byte-Level BPE tokenizer and save to disk."
    )

    # (flag, add_argument keyword arguments), registered in this order so the
    # generated --help output lists the options in the same sequence.
    option_specs = (
        (
            "--input",
            {
                "required": True,
                "help": 'Glob pattern for training text files, e.g. "data/raw/*.txt"',
            },
        ),
        (
            "--output",
            {
                "default": "tokenizer/",
                "help": "Output directory for the trained tokenizer (default: tokenizer/)",
            },
        ),
        (
            "--vocab_size",
            {
                "type": int,
                "default": 32000,
                "help": "Target vocabulary size (default: 32000)",
            },
        ),
        (
            "--min_frequency",
            {
                "type": int,
                "default": 2,
                "help": "Minimum frequency for a pair to be merged (default: 2)",
            },
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
| |
|
| |
|
def main() -> None:
    """Train a Byte-Level BPE tokenizer from CLI arguments and save it.

    In order: resolve the training files, create the output directory, train
    a ``ByteLevelBPETokenizer``, save its model files, wrap the trained
    tokenizer in a ``PreTrainedTokenizerFast`` and save that to the same
    directory, then print a summary of the run.
    """
    cli = parse_args()
    target_vocab = cli.vocab_size
    min_freq = cli.min_frequency

    # Resolve the training corpus.
    training_files = find_input_files(cli.input)
    print(f"Found {len(training_files)} training file(s).")

    # Make sure the destination exists before anything is written to it.
    out_dir = Path(cli.output)
    out_dir.mkdir(parents=True, exist_ok=True)

    bpe = ByteLevelBPETokenizer()

    print(
        f"\nTraining BPE tokenizer | vocab_size={target_vocab} | min_frequency={min_freq} ..."
    )
    bpe.train(
        files=training_files,
        vocab_size=target_vocab,
        min_frequency=min_freq,
        special_tokens=SPECIAL_TOKENS,
        show_progress=True,
    )

    # Register the special tokens on the trained tokenizer as well.
    bpe.add_special_tokens(SPECIAL_TOKENS)

    # Format 1: model files written by save_model.
    bpe.save_model(str(out_dir))
    print(f"\nSaved vocab.json + merges.txt to: {out_dir}")

    # Format 2: PreTrainedTokenizerFast wrapper around the same tokenizer.
    role_tokens = {
        "bos_token": "<s>",
        "eos_token": "</s>",
        "unk_token": "<unk>",
        "pad_token": "<pad>",
    }
    hf_tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=bpe._tokenizer,
        **role_tokens,
    )
    hf_tokenizer.save_pretrained(str(out_dir))
    print(f"Saved PreTrainedTokenizerFast to: {out_dir}")
    print(f" -> tokenizer.json: {out_dir / 'tokenizer.json'}")

    # Final report.
    learned_vocab = bpe.get_vocab_size()
    rule = "=" * 50
    report = [
        "\n" + rule,
        "Tokenizer training statistics",
        rule,
        f" Training files : {len(training_files):>10,}",
        f" Target vocab : {target_vocab:>10,}",
        f" Actual vocab : {learned_vocab:>10,}",
        f" Min frequency : {min_freq:>10,}",
        f" Special tokens : {SPECIAL_TOKENS}",
        f" Output dir : {out_dir.resolve()}",
        rule,
    ]
    for line in report:
        print(line)
|
| |
|
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
| |
|