File size: 1,945 Bytes
3b97420
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from __future__ import annotations

from pathlib import Path

from tokenizers import Tokenizer
from tokenizers import decoders as _decoders


class TextTokenizer:
    """
    Thin convenience layer over a ``tokenizers.Tokenizer`` loaded from disk.

    Per the original author's note, a JSON file written by
    ByteLevelBPETokenizer carries no decoder block, so a tokenizer reloaded
    with ``Tokenizer.from_file()`` decodes to raw byte-level symbols and
    U+FFFD replacement characters rather than readable UTF-8. The constructor
    therefore installs a ByteLevel decoder whenever the loaded tokenizer
    reports none.

    Attributes set by the constructor:
        path: the tokenizer file location as a ``Path``.
        tokenizer: the underlying ``tokenizers.Tokenizer``.
        pad_id, bos_id, eos_id, unk_id: vocabulary ids of ``<pad>``,
            ``<bos>``, ``<eos>`` and ``<unk>``; when a token is absent from
            the vocabulary the ids fall back to 0, 1, 2 and 3 respectively.
        vocab_size: value reported by ``get_vocab_size()``.
    """

    def __init__(self, path: str | Path):
        """Load the tokenizer stored at *path* and resolve special-token ids."""
        self.path = Path(path)
        self.tokenizer = Tokenizer.from_file(str(self.path))

        # A decoder that cannot even be read is treated the same as a
        # missing one: in both cases ByteLevel gets installed.
        try:
            needs_decoder = self.tokenizer.decoder is None
        except Exception:
            needs_decoder = True
        if needs_decoder:
            self.tokenizer.decoder = _decoders.ByteLevel()

        # The fallback id of each special token equals its position below.
        lookup = self.tokenizer.get_vocab().get
        self.pad_id, self.bos_id, self.eos_id, self.unk_id = (
            lookup(token, fallback)
            for fallback, token in enumerate(("<pad>", "<bos>", "<eos>", "<unk>"))
        )
        self.vocab_size = self.tokenizer.get_vocab_size()

    def encode(self, text: str, add_bos: bool = False, add_eos: bool = False) -> list[int]:
        """
        Turn *text* into token ids.

        ``bos_id`` is placed in front when *add_bos* is true and ``eos_id``
        is placed at the end when *add_eos* is true.
        """
        token_ids = self.tokenizer.encode(text).ids
        head = [self.bos_id] if add_bos else []
        tail = [self.eos_id] if add_eos else []
        return head + token_ids + tail

    def decode(self, ids: list[int], skip_special_tokens: bool = True) -> str:
        """
        Turn token ids back into text.

        Every id is coerced with ``int()`` first. When *skip_special_tokens*
        is true, the pad/bos/eos/unk ids are removed before the remaining
        ids are handed to the underlying tokenizer; the flag itself is also
        forwarded to that tokenizer.
        """
        if not skip_special_tokens:
            kept = [int(raw) for raw in ids]
        else:
            hidden = {self.pad_id, self.bos_id, self.eos_id, self.unk_id}
            kept = [tid for tid in map(int, ids) if tid not in hidden]
        return self.tokenizer.decode(kept, skip_special_tokens=skip_special_tokens)