| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| from tokenizers import Tokenizer |
| from tokenizers import decoders as _decoders |
|
|
|
|
class TextTokenizer:
    """Wrapper around ``tokenizers.Tokenizer`` that guarantees a decoder.

    Some saved tokenizer JSON files (reportedly those written from a
    ``ByteLevelBPETokenizer``) contain no ``decoder`` block. Reloading such a
    file with ``Tokenizer.from_file()`` yields a tokenizer whose ``.decode()``
    returns raw byte-level symbols (e.g. ``Ġ`` for a leading space, or
    ``Ã¤`` for the two UTF-8 bytes of ``ä``) instead of readable text. When
    the loaded tokenizer has no decoder, a ``ByteLevel`` decoder is attached
    here so ``decode()`` produces proper UTF-8 text.

    Attributes:
        path: Location of the tokenizer JSON file.
        tokenizer: The underlying ``tokenizers.Tokenizer``.
        pad_id, bos_id, eos_id, unk_id: Ids of ``<pad>``, ``<bos>``,
            ``<eos>`` and ``<unk>``. If a token is absent from the vocab, the
            fallback 0, 1, 2 or 3 (respectively) is used; such a fallback id
            may be an ordinary token.
        vocab_size: Vocabulary size reported by the underlying tokenizer.
    """

    def __init__(self, path: str | Path):
        """Load the tokenizer JSON at *path* and ensure it has a decoder."""
        self.path = Path(path)
        self.tokenizer = Tokenizer.from_file(str(self.path))

        # Only attach ByteLevel when no decoder was loaded, so an explicit
        # decoder in the JSON is respected. Reading the property is treated
        # as best-effort: any failure is handled as "no decoder".
        try:
            current_decoder = self.tokenizer.decoder
        except Exception:
            current_decoder = None
        if current_decoder is None:
            self.tokenizer.decoder = _decoders.ByteLevel()

        vocab = self.tokenizer.get_vocab()
        self.pad_id = vocab.get("<pad>", 0)
        self.bos_id = vocab.get("<bos>", 1)
        self.eos_id = vocab.get("<eos>", 2)
        self.unk_id = vocab.get("<unk>", 3)
        # Ids of the special tokens that really exist in the vocab. A fallback
        # id above may belong to an ordinary token, so decode() must not strip
        # it; only these ids are removed when skip_special_tokens is set.
        self._special_ids = frozenset(
            vocab[token]
            for token in ("<pad>", "<bos>", "<eos>", "<unk>")
            if token in vocab
        )
        self.vocab_size = self.tokenizer.get_vocab_size()

    def encode(self, text: str, add_bos: bool = False, add_eos: bool = False) -> list[int]:
        """Encode *text* to token ids.

        Args:
            text: Input string.
            add_bos: Prepend ``self.bos_id``.
            add_eos: Append ``self.eos_id``.

        Returns:
            A new list of token ids.
        """
        ids = self.tokenizer.encode(text).ids
        if add_bos:
            ids = [self.bos_id] + ids
        if add_eos:
            ids = ids + [self.eos_id]
        return ids

    def decode(self, ids: list[int], skip_special_tokens: bool = True) -> str:
        """Decode token ids back to text.

        Args:
            ids: Token ids. Each item is passed through ``int()``, so integer
                scalars from other libraries are accepted as well.
            skip_special_tokens: Drop the ids of ``<pad>``/``<bos>``/``<eos>``/
                ``<unk>`` that exist in the vocab. The flag is also forwarded
                to the underlying tokenizer's ``decode``.

        Returns:
            The decoded string.
        """
        token_ids = [int(i) for i in ids]
        if skip_special_tokens:
            token_ids = [i for i in token_ids if i not in self._special_ids]
        return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)