File size: 1,469 Bytes
b127f2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""Minimal byte tokenizer: token id == UTF-8 byte value, everything in [0, 256).

Mirrors the UTF8Tokenizer design principle (no out-of-range ids; special roles
ride on C0 control bytes) without an external dependency. Pad = NUL (byte 0x00).
"""
from transformers import PreTrainedTokenizer


class ByteTokenizer(PreTrainedTokenizer):
    """Tokenizer whose token ids are the UTF-8 byte values of the input text.

    Each token is a one-character string ``chr(b)`` for a byte ``b`` of the
    UTF-8 encoding of the text, and its id is ``b`` itself, so every id
    produced by tokenization lies in ``[0, 256)``. The pad token is mapped
    onto id 0 (the NUL byte by default) instead of receiving a new id.
    """

    model_input_names = ["input_ids", "attention_mask"]

    # Number of distinct byte values; ids 0..255 map one-to-one onto bytes.
    _NUM_BYTES = 256

    def __init__(self, pad_token="\x00", **kwargs):
        """Create the tokenizer.

        Args:
            pad_token: Padding token, either a ``str`` or an ``AddedToken``.
                It is registered under id 0. ``None`` registers no pad token.
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizer.__init__``.
        """
        from transformers import AddedToken

        # Map pad to an existing byte id (0) BEFORE super().__init__, so it is
        # NOT allocated a fresh id at 256. id == byte stays true for everything.
        if pad_token is None:
            # Nothing to pin to id 0; storing None here would break the base
            # class, which reads `.content` from every registered entry.
            self._added_tokens_decoder = {}
        else:
            pad_entry = (
                AddedToken(pad_token, special=True)
                if isinstance(pad_token, str)
                else pad_token
            )
            self._added_tokens_decoder = {0: pad_entry}
        super().__init__(pad_token=pad_token, **kwargs)

    @property
    def vocab_size(self):
        """Size of the base vocabulary: one id per byte value (256)."""
        return self._NUM_BYTES

    def get_vocab(self):
        """Return the token -> id mapping: all 256 byte tokens plus added tokens."""
        vocab = {chr(i): i for i in range(self._NUM_BYTES)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text):
        """Split ``text`` into one single-character token per UTF-8 byte."""
        return [chr(b) for b in text.encode("utf-8")]

    def _convert_token_to_id(self, token):
        """Return the byte value of a single-character byte token.

        Any other token (multi-character, or code point >= 256) yields
        ``self.unk_token_id``; this class does not configure an unk token
        itself, so that value comes from whatever the caller passed in.
        """
        if len(token) == 1 and ord(token) < self._NUM_BYTES:
            return ord(token)
        return self.unk_token_id

    def _convert_id_to_token(self, index):
        """Return the one-character token for ``index`` (``chr(index)``)."""
        return chr(index)

    def convert_tokens_to_string(self, tokens):
        """Reassemble tokens into text by decoding their bytes as UTF-8.

        Single-character tokens below 256 contribute their byte value.
        Any other token (for example a multi-character added token) cannot
        be expressed as one byte, so it contributes its own UTF-8 encoding
        instead of raising. Invalid byte sequences decode to U+FFFD.
        """
        raw = bytearray()
        for token in tokens:
            if len(token) == 1 and ord(token) < self._NUM_BYTES:
                raw.append(ord(token))
            else:
                # Not a byte token: keep its text rather than crash in ord()
                # (multi-char) or bytearray.append (code point >= 256).
                raw.extend(token.encode("utf-8", errors="replace"))
        return raw.decode("utf-8", errors="replace")

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Return no vocabulary files: the vocabulary is implicit in the bytes.

        Overridden because the base-class implementation raises
        ``NotImplementedError``, which would make ``save_pretrained`` fail.
        """
        return ()