from __future__ import annotations import json import os import re from typing import Dict, List, Optional, Tuple from transformers import PreTrainedTokenizer class BinaryLLMTokenizer(PreTrainedTokenizer): model_input_names = ["input_ids", "attention_mask"] TOKEN_RE = re.compile(r"^$") def __init__( self, bos_token: str = "", eos_token: str = "", unk_token: str = "", pad_token: Optional[str] = None, **kwargs, ): # base ids 0..65535 reserved for radix tokens (strict) self._base_vocab_size = 65536 # reserve ids self._bos_id = 65536 self._eos_id = 65537 # UNK is an alias to EOS to preserve radix purity (no new base id) self._unk_id = self._eos_id # special token strings self._bos_str = bos_token self._eos_str = eos_token self._unk_str = unk_token self._pad_str = pad_token super().__init__( bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, pad_token=pad_token, **kwargs, ) @property def vocab_size(self) -> int: return 65538 def get_vocab(self) -> Dict[str, int]: # IMPORTANT: never call self.unk_token_id here (it triggers recursion) v = { self._bos_str: self._bos_id, self._eos_str: self._eos_id, self._unk_str: self._unk_id, } if self.pad_token is not None: # if pad_token == "", it will map to eos id via _convert_token_to_id() v[self.pad_token] = self._convert_token_to_id(self.pad_token) return v def _tokenize(self, text: str) -> List[str]: ids = self._encode_to_uint16(text) return [self._id_to_token_base(i) for i in ids] def _convert_token_to_id(self, token: str) -> int: if token == self._bos_str: return self._bos_id if token == self._eos_str: return self._eos_id if token == self._unk_str: return self._unk_id if self.pad_token is not None and token == self.pad_token: # common case: pad_token is "" if self.pad_token == self._eos_str: return self._eos_id # otherwise: no dedicated PAD id in this vocab, alias to EOS return self._eos_id m = self.TOKEN_RE.match(token) if m: return int(m.group(1), 16) return self._unk_id def _convert_id_to_token(self, index: int) -> str: if index == self._bos_id: return self._bos_str if index == self._eos_id: return self._eos_str if index == self._unk_id: return self._unk_str if self.pad_token is not None and index == self.pad_token_id: return self.pad_token if 0 <= index < self._base_vocab_size: return self._id_to_token_base(index) return self._unk_str def convert_tokens_to_string(self, tokens: List[str]) -> str: ids: List[int] = [] for t in tokens: if t in (self._bos_str, self._eos_str, self._unk_str): continue if self.pad_token is not None and t == self.pad_token: continue m = self.TOKEN_RE.match(t) if m: ids.append(int(m.group(1), 16)) return self._decode_from_uint16(ids) def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, ) -> List[int]: if token_ids_1 is None: return [self._bos_id] + token_ids_0 + [self._eos_id] return [self._bos_id] + token_ids_0 + [self._eos_id] + token_ids_1 + [self._eos_id] def get_special_tokens_mask( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False, ) -> List[int]: pad_id = self.pad_token_id if self.pad_token is not None else -1 if already_has_special_tokens: return [ 1 if t in (self._bos_id, self._eos_id, self._unk_id, pad_id) else 0 for t in token_ids_0 ] if token_ids_1 is None: return [1] + [0] * len(token_ids_0) + [1] return [1] + [0] * len(token_ids_0) + [1] + [0] * len(token_ids_1) + [1] def create_token_type_ids_from_sequences( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, ) -> List[int]: if token_ids_1 is None: return [0] * (len(token_ids_0) + 2) return [0] * (len(token_ids_0) + len(token_ids_1) + 3) def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: if not os.path.isdir(save_directory): os.makedirs(save_directory, exist_ok=True) name = (filename_prefix + "-" if filename_prefix else "") + "binaryllm_vocab.json" path = os.path.join(save_directory, name) data = { "base_vocab_size": 65536, "vocab_size": 65538, "bos_token": self._bos_str, "bos_token_id": self._bos_id, "eos_token": self._eos_str, "eos_token_id": self._eos_id, "unk_token": self._unk_str, "unk_token_id": self._unk_id, "pad_token": self.pad_token, "pad_token_id": self.pad_token_id, } with open(path, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=2) return (path,) def _id_to_token_base(self, i: int) -> str: return f"" def _encode_to_uint16(self, text: str) -> List[int]: b = text.encode("utf-8", errors="strict") if len(b) % 2 == 1: b += b"\x00" out: List[int] = [] for k in range(0, len(b), 2): out.append(b[k] | (b[k + 1] << 8)) return out def _decode_from_uint16(self, ids: List[int]) -> str: bb = bytearray() for x in ids: x &= 0xFFFF bb.append(x & 0xFF) bb.append((x >> 8) & 0xFF) if len(bb) and bb[-1] == 0: bb = bb[:-1] return bb.decode("utf-8", errors="replace")