| """ |
| Derived from Andrej Karpathy's nanochat project. |
| |
| MIT License |
| |
| Copyright (c) 2025 Andrej Karpathy |
| |
| Permission is hereby granted, free of charge, to any person obtaining a copy |
| of this software and associated documentation files (the "Software"), to deal |
| in the Software without restriction, including without limitation the rights |
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| copies of the Software, and to permit persons to whom the Software is |
| furnished to do so, subject to the following conditions: |
| |
| The above copyright notice and this permission notice shall be included in all |
| copies or substantial portions of the Software. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import Iterable |
|
|
| from tokenizers import Regex |
| from tokenizers import Tokenizer as HFTokenizer |
| from tokenizers import decoders, pre_tokenizers |
| from tokenizers.models import BPE |
| from tokenizers.trainers import BpeTrainer |
|
|
# Special tokens registered with the BPE trainer in
# ``BpeTokenizer.train_from_iterator``; "<|bos|>" is additionally looked up by
# ``BpeTokenizer.bos_id``. The start/end pairs appear to delimit chat turns,
# python tool calls and their outputs — NOTE(review): confirm against the
# chat-formatting code that consumes them.
SPECIAL_TOKENS = [
    "<|bos|>",
    "<|user_start|>",
    "<|user_end|>",
    "<|assistant_start|>",
    "<|assistant_end|>",
    "<|python_start|>",
    "<|python_end|>",
    "<|output_start|>",
    "<|output_end|>",
]


# GPT-4-style pre-tokenization regex, wrapped in ``tokenizers.Regex`` and used
# by the ``Split`` pre-tokenizer in ``BpeTokenizer.train_from_iterator``. It
# relies on ``\p{...}`` Unicode property classes (not supported by Python's
# stdlib ``re``) and possessive quantifiers (``?+``, ``++``).
# Alternatives, tried left to right:
#   1. contraction suffixes 's 'd 'm 't 'll 've 're, case-insensitive;
#   2. a run of letters, optionally preceded by one character that is not a
#      newline, letter or digit (e.g. a leading space);
#   3. a run of one or two digits — NOTE(review): GPT-4's cl100k pattern
#      allows up to three; confirm the narrower grouping is intentional;
#   4. an optional space, then a run of characters that are neither
#      whitespace, letters nor digits, plus any trailing newlines;
#   5. whitespace ending in a newline;
#   6. whitespace not followed by a non-space character, which leaves the last
#      space of a run to be picked up by alternative 2 or 4;
#   7. any remaining whitespace.
# The two raw strings are joined by implicit literal concatenation; the split
# is only for line length.
SPLIT_PATTERN = (
    r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}|"""
    r""" ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
)
|
|
|
|
class BpeTokenizer:
    """Byte-level BPE tokenizer backed by a HuggingFace ``tokenizers.Tokenizer``.

    Text is chunked with the GPT-4-style ``SPLIT_PATTERN`` regex (following
    nanochat), mapped to visible byte characters by the ByteLevel
    pre-tokenizer, then merged with BPE. All real work is delegated to the
    wrapped ``self.tokenizer``.
    """

    def __init__(self, tokenizer: HFTokenizer):
        # Handle on the underlying HuggingFace tokenizer; every method below
        # delegates to it.
        self.tokenizer = tokenizer

    @classmethod
    def train_from_iterator(
        cls, text_iterator: Iterable[str], vocab_size: int
    ) -> BpeTokenizer:
        """Train a fresh byte-level BPE tokenizer and wrap it.

        Args:
            text_iterator: Strings to learn merges from; handed directly to
                ``Tokenizer.train_from_iterator``.
            vocab_size: Target vocabulary size given to ``BpeTrainer``.

        Returns:
            A new ``BpeTokenizer`` around the trained tokenizer.
        """
        # No <unk> token: the ByteLevel alphabet given to the trainer below
        # provides a symbol for every byte value.
        bpe_model = BPE(byte_fallback=True, unk_token=None, fuse_unk=False)

        regex_split = pre_tokenizers.Split(
            pattern=Regex(SPLIT_PATTERN),
            behavior="isolated",  # every regex match is kept as its own piece
            invert=False,
        )
        # use_regex=False: chunking was already done by ``regex_split``, so
        # this step only performs the byte -> visible-character mapping.
        to_bytes = pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)

        hf_tokenizer = HFTokenizer(bpe_model)
        hf_tokenizer.normalizer = None  # raw text in: no normalization step
        hf_tokenizer.pre_tokenizer = pre_tokenizers.Sequence([regex_split, to_bytes])
        hf_tokenizer.decoder = decoders.ByteLevel()  # undoes the ByteLevel mapping

        bpe_trainer = BpeTrainer(
            vocab_size=vocab_size,
            show_progress=True,
            min_frequency=0,  # no minimum pair count for a merge
            initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
            special_tokens=SPECIAL_TOKENS,
        )
        hf_tokenizer.train_from_iterator(text_iterator, bpe_trainer)
        return cls(hf_tokenizer)

    @classmethod
    def from_file(cls, path: str | Path) -> BpeTokenizer:
        """Load a serialized HuggingFace tokenizer from ``path`` and wrap it."""
        loaded = HFTokenizer.from_file(str(path))
        return cls(loaded)

    def save(self, path: str | Path) -> None:
        """Serialize the wrapped tokenizer to ``path``, creating parent dirs first."""
        destination = Path(path)
        destination.parent.mkdir(parents=True, exist_ok=True)
        self.tokenizer.save(str(destination))

    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by the wrapped tokenizer's ``get_vocab_size()``."""
        size = self.tokenizer.get_vocab_size()
        return size

    @property
    def bos_id(self) -> int:
        """Id of the ``<|bos|>`` special token.

        Raises:
            ValueError: If the vocabulary has no ``<|bos|>`` entry.
        """
        token_id = self.tokenizer.token_to_id("<|bos|>")
        if token_id is not None:
            return token_id
        raise ValueError("tokenizer is missing <|bos|>")

    def encode(self, text: str, prepend_bos: bool = False) -> list[int]:
        """Turn ``text`` into a list of token ids.

        Args:
            text: The string to tokenize.
            prepend_bos: When true, ``bos_id`` is placed in front of the ids.

        Returns:
            The token ids, starting with the BOS id if requested.

        Raises:
            ValueError: If ``prepend_bos`` is true but ``<|bos|>`` is not in
                the vocabulary.
        """
        # NOTE(review): add_special_tokens=False only skips the post-processor;
        # HF tokenizers presumably still matches registered special tokens
        # (e.g. a literal "<|user_end|>") inside ``text`` — confirm this is
        # acceptable when ``text`` can be untrusted.
        token_ids = self.tokenizer.encode(text, add_special_tokens=False).ids
        if not prepend_bos:
            return token_ids
        token_ids.insert(0, self.bos_id)
        return token_ids
|
|