# Patenty-0.1 / tokenization_binaryllm.py
# PhysiQuanty's picture
# Update tokenization_binaryllm.py
# 7de52e4 verified
#!/usr/bin/env python3
# tokenization_binaryllm.py
# ============================================================
# BinaryLLMTokenizer (AutoTokenizer compatible) EXACTEMENT comme
# llmTalk (mode base=65536) + infer_tagged12.py:
#
# - Base vocab: 0..65535 (radix)
# - BOS id = 65536
# - EOS id = 65537
# - UNK alias = EOS (pas de nouvel id)
# - Encodage MANUEL: UTF-8 bytes -> digits base65536 BIG-ENDIAN (chunks 2 bytes)
# (si byte impair: dernier chunk = 1 byte => id 0..255)
# - Décodage: digits -> bytes BIG-ENDIAN -> UTF-8 (errors=replace)
# - build_inputs_with_special_tokens:
# single: [BOS] + ids + [EOS]
# pair : [BOS] + ids0 + [EOS] + ids1 + [EOS]
#
# IMPORTANT:
# - Ce tokenizer NE génère PAS ton pattern "...[EOS][BOS]" tout seul,
# parce que HuggingFace standard = BOS ... EOS.
# Pour llmTalk / infer_tagged*, c’est ton script qui ajoute le BOS final.
#
# Usage:
# - Mets ce fichier dans le repo HF (avec __init__.py si tu veux) et
# ajoute dans tokenizer_config.json:
# {"tokenizer_class": "BinaryLLMTokenizer", "auto_map": {...}} si besoin.
# - Puis:
# tok = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
# ============================================================
from __future__ import annotations
import json
import os
import re
from typing import Dict, List, Optional, Tuple, Any
from transformers import PreTrainedTokenizer
class BinaryLLMTokenizer(PreTrainedTokenizer):
    """Byte-level tokenizer that packs UTF-8 text into base-65536 "digits".

    Vocabulary layout (fixed, no vocab file is needed):

    * ``0..65535``: radix tokens. Text is UTF-8 encoded and split into
      2-byte big-endian chunks; each chunk becomes one id. When the byte
      count is odd, the final chunk is the lone trailing byte (id 0..255).
      The string form of a radix token is ``<UXXXX>`` (4 hex digits).
    * ``65536``: BOS.
    * ``65537``: EOS. UNK and PAD (when set) are aliases of this id; they
      never get ids of their own.

    Known lossy cases of the packing scheme (kept for compatibility):

    * Every id ``<= 255`` is decoded as a single byte, so a 2-byte chunk
      whose high byte is ``0x00`` (a NUL byte at an even byte offset) loses
      that NUL on decode.
    * The empty string encodes to ``[0]``, which decodes to ``"\\x00"``.
    """

    model_input_names = ["input_ids", "attention_mask"]
    # String form of a radix token, e.g. "<U4142>" -> id 0x4142.
    TOKEN_RE = re.compile(r"^<U([0-9A-Fa-f]{4})>$")

    def __init__(
        self,
        bos_token: str = "<BOS>",
        eos_token: str = "<EOS>",
        unk_token: str = "<UNK>",
        pad_token: Optional[str] = None,
        **kwargs: Any,
    ):
        """Create the tokenizer.

        Args:
            bos_token: string form of the BOS token (id 65536).
            eos_token: string form of the EOS token (id 65537).
            unk_token: string form of UNK; aliased to the EOS id.
            pad_token: optional string form of PAD; aliased to the EOS id.
            **kwargs: forwarded unchanged to ``PreTrainedTokenizer.__init__``.
        """
        # base ids 0..65535 reserved for radix tokens (strict)
        self._base_vocab_size = 65536
        # reserved special ids, directly after the radix range
        self._bos_id = 65536
        self._eos_id = 65537
        # UNK is an alias to EOS to preserve radix purity (no new base id)
        self._unk_id = self._eos_id
        # Special token strings. str() normalises AddedToken objects (as
        # produced when loading a tokenizer_config.json) to their content,
        # so the plain string comparisons in this class keep working.
        self._bos_str = str(bos_token)
        self._eos_str = str(eos_token)
        self._unk_str = str(unk_token)
        self._pad_str = None if pad_token is None else str(pad_token)
        # Radix token string -> id, built once. It must exist before
        # super().__init__(), which may call get_vocab() while registering
        # the special tokens.
        self._base_vocab: Dict[str, int] = {
            self._id_to_token_base(i): i for i in range(self._base_vocab_size)
        }
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Number of distinct ids: 65536 radix ids + BOS + EOS (= 65538)."""
        return self._eos_id + 1

    def get_vocab(self) -> Dict[str, int]:
        """Return the full token-string -> id mapping.

        Contains all 65536 radix tokens, the special tokens (UNK/PAD map to
        the EOS id) and any tokens registered at runtime through HF's
        ``add_tokens``. HF derives ``len(self)`` and the ids of newly added
        tokens from this mapping. Omitting the radix tokens would make
        ``len(self)`` tiny and let added tokens collide with radix ids.
        Because UNK (and a non-EOS PAD string) are extra keys aliasing EOS,
        ``len(self)`` exceeds ``vocab_size`` by the number of such aliases.
        """
        # IMPORTANT: never call self.unk_token_id here (it triggers recursion)
        v = dict(self._base_vocab)
        v[self._bos_str] = self._bos_id
        v[self._eos_str] = self._eos_id
        v[self._unk_str] = self._unk_id
        if self.pad_token is not None:
            # PAD has no dedicated id: _convert_token_to_id() aliases it to EOS
            v[self.pad_token] = self._convert_token_to_id(self.pad_token)
        # Runtime-added tokens; absent or empty during early __init__.
        added = getattr(self, "added_tokens_encoder", None)
        if added:
            v.update(added)
        return v

    # -----------------------------
    # Core manual base65536 codec
    # -----------------------------
    def _encode_base65536_be(self, text: str) -> List[int]:
        """Pack ``text`` into base-65536 digits (big-endian 2-byte chunks).

        Returns ``[0]`` for the empty string. If the UTF-8 byte count is
        odd, the last id is the lone trailing byte (0..255).

        Raises:
            UnicodeEncodeError: if ``text`` cannot be encoded as UTF-8
                (e.g. it contains lone surrogates).
        """
        data = text.encode("utf-8", errors="strict")
        if not data:
            return [0]
        # A slice running past the end is simply shorter, so an odd trailing
        # byte becomes a 1-byte chunk whose value is 0..255.
        return [int.from_bytes(data[i:i + 2], "big") for i in range(0, len(data), 2)]

    def _decode_base65536_be(self, ids: List[int]) -> str:
        """Unpack base-65536 digits back into text.

        Ids 0..255 become one byte; larger ids become two big-endian bytes
        (only their low 16 bits are used). Invalid UTF-8 is replaced by
        U+FFFD rather than raising.
        """
        buf = bytearray()
        for x in ids:
            xi = int(x) & 0xFFFFFFFF
            if xi <= 0xFF:
                buf.append(xi)
            else:
                buf.append((xi >> 8) & 0xFF)
                buf.append(xi & 0xFF)
        return buf.decode("utf-8", errors="replace")

    # -----------------------------
    # HF required overrides
    # -----------------------------
    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` into radix token strings (``<UXXXX>``)."""
        ids = self._encode_base65536_be(text)
        return [self._id_to_token_base(i) for i in ids]

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token string to its id; unknown strings map to UNK (= EOS)."""
        if token == self._bos_str:
            return self._bos_id
        if token == self._eos_str:
            return self._eos_id
        if token == self._unk_str:
            return self._unk_id
        if self.pad_token is not None and token == self.pad_token:
            # no dedicated PAD id in this vocab: PAD always aliases EOS
            return self._eos_id
        m = self.TOKEN_RE.match(token)
        if m:
            return int(m.group(1), 16)
        return self._unk_id

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id to its token string; out-of-range ids map to UNK."""
        if index == self._bos_id:
            return self._bos_str
        if index == self._eos_id:
            return self._eos_str
        # Unreachable while UNK aliases EOS; kept in case the ids diverge.
        if index == self._unk_id:
            return self._unk_str
        if self.pad_token is not None and index == self.pad_token_id:
            return self.pad_token
        if 0 <= index < self._base_vocab_size:
            return self._id_to_token_base(index)
        return self._unk_str

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Decode radix token strings to text.

        Special tokens are skipped, and so are strings that are not of the
        form ``<UXXXX>``.
        """
        ids: List[int] = []
        for t in tokens:
            if t in (self._bos_str, self._eos_str, self._unk_str):
                continue
            if self.pad_token is not None and t == self.pad_token:
                continue
            m = self.TOKEN_RE.match(t)
            if m:
                ids.append(int(m.group(1), 16))
        return self._decode_base65536_be(ids)

    def build_inputs_with_special_tokens(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
    ) -> List[int]:
        """Wrap ids as ``[BOS] a [EOS]`` or ``[BOS] a [EOS] b [EOS]``."""
        if token_ids_1 is None:
            return [self._bos_id] + token_ids_0 + [self._eos_id]
        return [self._bos_id] + token_ids_0 + [self._eos_id] + token_ids_1 + [self._eos_id]

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Return 1 for special-token positions and 0 for radix positions.

        With ``already_has_special_tokens`` the mask is computed from the
        ids in ``token_ids_0``. Otherwise it describes the layout produced by
        ``build_inputs_with_special_tokens``.
        """
        pad_id = self.pad_token_id if self.pad_token is not None else -1
        if already_has_special_tokens:
            specials = (self._bos_id, self._eos_id, self._unk_id, pad_id)
            return [1 if t in specials else 0 for t in token_ids_0]
        if token_ids_1 is None:
            return [1] + [0] * len(token_ids_0) + [1]
        return [1] + [0] * len(token_ids_0) + [1] + [0] * len(token_ids_1) + [1]

    def create_token_type_ids_from_sequences(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
    ) -> List[int]:
        """All-zero token type ids matching ``build_inputs_with_special_tokens``."""
        if token_ids_1 is None:
            return [0] * (len(token_ids_0) + 2)
        return [0] * (len(token_ids_0) + len(token_ids_1) + 3)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write a JSON description of the (fixed) vocabulary.

        The file is ``[<prefix>-]binaryllm_vocab.json`` inside
        ``save_directory``, which is created if missing.

        Returns:
            A 1-tuple containing the path of the written file.
        """
        os.makedirs(save_directory, exist_ok=True)
        name = (filename_prefix + "-" if filename_prefix else "") + "binaryllm_vocab.json"
        path = os.path.join(save_directory, name)
        data = {
            "base_vocab_size": self._base_vocab_size,
            "vocab_size": self.vocab_size,
            "bos_token": self._bos_str,
            "bos_token_id": self._bos_id,
            "eos_token": self._eos_str,
            "eos_token_id": self._eos_id,
            "unk_token": self._unk_str,
            "unk_token_id": self._unk_id,
            "pad_token": self.pad_token,
            "pad_token_id": self.pad_token_id,
            "encoding": "utf-8",
            "packing": "base65536_big_endian_2bytes",
        }
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        return (path,)

    # -----------------------------
    # Utilities
    # -----------------------------
    def _id_to_token_base(self, i: int) -> str:
        """String form of radix id ``i`` (``<UXXXX>``, uppercase hex)."""
        return f"<U{i:04X}>"

    # -----------------------------
    # Make HF "fast path" consistent
    # -----------------------------
    def _decode(
        self,
        token_ids: List[int],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool = False,
        **kwargs: Any,
    ) -> str:
        """Decode ids straight to text, bypassing the token-string round trip.

        Only radix ids (0..65535) contribute bytes. BOS/EOS/UNK and any other
        out-of-range id (e.g. runtime-added tokens) are dropped regardless of
        ``skip_special_tokens``. ``clean_up_tokenization_spaces`` and extra
        kwargs are accepted for API compatibility and ignored.
        """
        if isinstance(token_ids, int):
            # HF's decode() may pass a single id
            token_ids = [token_ids]
        # Resolve the PAD id once instead of once per element.
        skip_pad_id = (
            self.pad_token_id
            if skip_special_tokens and self.pad_token is not None
            else None
        )
        ids: List[int] = []
        for t in token_ids:
            ti = int(t)
            if skip_pad_id is not None and ti == skip_pad_id:
                continue
            if 0 <= ti < self._base_vocab_size:
                ids.append(ti)
        return self._decode_base65536_be(ids)

    def __call__(
        self,
        text: str,
        add_special_tokens: bool = True,
        return_attention_mask: bool = True,
        return_tensors: Optional[str] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Encode one string (simplified replacement for HF ``__call__``).

        Args:
            text: a single string; batching is not supported.
            add_special_tokens: wrap the ids as ``[BOS] ... [EOS]``.
            return_attention_mask: also return an all-ones attention mask.
            return_tensors: ``None`` for Python lists, or ``"pt"`` (also
                ``TensorType.PYTORCH``) for batch-of-one ``torch.long`` tensors.
            **kwargs: accepted for API compatibility and ignored (no
                padding/truncation support).

        Returns:
            A dict with ``"input_ids"`` and, if requested, ``"attention_mask"``.

        Raises:
            ValueError: for any ``return_tensors`` other than ``"pt"``.
        """
        ids = self._encode_base65536_be(text)
        if add_special_tokens:
            ids = self.build_inputs_with_special_tokens(ids)
        attn = [1] * len(ids) if return_attention_mask else None
        if return_tensors is None:
            out: Dict[str, Any] = {"input_ids": ids}
            if return_attention_mask:
                out["attention_mask"] = attn
            return out
        # HF's TensorType is a str-Enum whose str() is "TensorType.PYTORCH";
        # its .value is the plain "pt" string we compare against.
        rt = str(getattr(return_tensors, "value", return_tensors)).lower().strip()
        if rt != "pt":
            raise ValueError("Only return_tensors='pt' is supported in this tokenizer.")
        # Local import: torch is only needed for tensor output and is not
        # imported at module level.
        import torch

        input_ids_t = torch.tensor([ids], dtype=torch.long)
        out_t: Dict[str, Any] = {"input_ids": input_ids_t}
        if return_attention_mask:
            out_t["attention_mask"] = torch.tensor([attn], dtype=torch.long)
        return out_t