# byte-gpt2-2layer / byte_tokenizer.py
# Uploaded by gvadhul ("Upload tokenizer", commit b127f2a, 1.47 kB).
"""Minimal byte tokenizer: token id == UTF-8 byte value, everything in [0, 256).
Mirrors the UTF8Tokenizer design principle (no out-of-range ids; special roles
ride on C0 control bytes) without an external dependency. Pad = NUL (byte 0x00).
"""
from transformers import PreTrainedTokenizer
class ByteTokenizer(PreTrainedTokenizer):
    """Byte-level tokenizer whose token ids are the raw UTF-8 byte values.

    Text is encoded to UTF-8 and every byte ``b`` becomes the one-character
    token ``chr(b)`` with id ``b``, so the base vocabulary is exactly
    ``range(256)``. The pad token is pre-registered at id 0 (NUL by default)
    so that it reuses an existing byte id instead of getting a fresh one.
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(self, pad_token="\x00", **kwargs):
        """Create the tokenizer.

        Args:
            pad_token: Padding token, as a string or ``AddedToken``. It is
                bound to id 0. ``None`` means no pad token is pre-registered.
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizer.__init__``.
        """
        from transformers import AddedToken

        # Map pad to an existing byte id (0) BEFORE super().__init__, so it is
        # NOT allocated a fresh id at 256. id == byte stays true for everything.
        self._added_tokens_decoder = {}
        if pad_token is not None:
            # Registering None here would break the base class, which reads
            # ``.content`` from every decoder entry.
            self._added_tokens_decoder[0] = (
                AddedToken(pad_token, special=True)
                if isinstance(pad_token, str)
                else pad_token
            )
        super().__init__(pad_token=pad_token, **kwargs)

    @property
    def vocab_size(self):
        """Size of the base byte vocabulary (always 256)."""
        return 256

    def get_vocab(self):
        """Return the token -> id mapping: all 256 byte tokens plus added tokens."""
        vocab = {chr(i): i for i in range(256)}
        # Added tokens overlay the byte entries (e.g. the NUL pad maps to 0).
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text, **kwargs):
        """Split ``text`` into one single-character token per UTF-8 byte.

        ``**kwargs`` is accepted to match the base-class hook signature and
        is ignored.
        """
        return [chr(b) for b in text.encode("utf-8")]

    def _convert_token_to_id(self, token):
        """Map a single-byte token to its byte value, anything else to ``unk_token_id``.

        NOTE(review): ``unk_token_id`` is presumably ``None`` when no unk
        token was configured; confirm callers tolerate that.
        """
        if len(token) == 1 and ord(token) < 256:
            return ord(token)
        return self.unk_token_id

    def _convert_id_to_token(self, index):
        """Map a byte id in ``[0, 256)`` back to its single-character token.

        Raises:
            ValueError: if ``index`` is outside ``[0, 256)``. Such an id has
                no byte meaning in this vocabulary.
        """
        if not 0 <= index < 256:
            raise ValueError(
                f"token id {index} is outside the byte range [0, 256)"
            )
        return chr(index)

    def convert_tokens_to_string(self, tokens):
        """Reassemble a token sequence into text.

        A single-character token with code point < 256 is a raw byte and is
        emitted as that byte. Any other token (e.g. a multi-character added
        special token such as ``"</s>"``) contributes its own UTF-8
        encoding. The combined byte stream is decoded as UTF-8, and invalid
        sequences are replaced with U+FFFD.
        """
        buf = bytearray()
        for tok in tokens:
            if len(tok) == 1 and ord(tok) < 256:
                buf.append(ord(tok))
            else:
                # ord() on a multi-character token would raise TypeError.
                buf.extend(tok.encode("utf-8"))
        return buf.decode("utf-8", errors="replace")

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Write no vocabulary files and return an empty tuple.

        The vocabulary is implicit (id == byte value), so there is nothing to
        persist. NOTE: the transformers base-class default for this hook
        raises NotImplementedError, which would make ``save_pretrained`` fail.
        """
        return ()