"""Hugging Face compatible tokenizer combining text encoder and numeric decoder.""" from __future__ import annotations import json import math import re from pathlib import Path from typing import Dict, Iterable, List, Sequence import numpy as np from transformers import AutoTokenizer, PreTrainedTokenizer DELIMITERS = ("<", ">") def _to_token(value: str | int) -> str: left, right = DELIMITERS return f"{left}{value}{right}" def _from_token(token: str) -> str: left, right = DELIMITERS match = re.fullmatch(rf"{re.escape(left)}(.*?){re.escape(right)}", token) if not match: raise ValueError(f"Cannot deserialize token: {token}") return match.group(1) class _NumericTokenizerBase(PreTrainedTokenizer): """Shared utilities for numeric decoder tokenizers.""" vocab_files_names: Dict[str, str] = {} model_input_names = ["input_ids"] vocab_filename = "numeric_vocab.json" def __init__( self, *, encoder_tokenizer_dir: str | None = None, encoder_tokenizer_name: str | None = None, encoder_tokenizer: PreTrainedTokenizer | None = None, bos_token: str = "", eos_token: str | None = None, pad_token: str | None = None, unk_token: str | None = None, **kwargs, ) -> None: eos_token = eos_token or bos_token pad_token = pad_token or bos_token self.encoder_tokenizer_dir = encoder_tokenizer_dir self.encoder_tokenizer_name = encoder_tokenizer_name self.encoder_tokenizer = encoder_tokenizer base_tokens = self._build_base_tokens() base_tokens = sorted(base_tokens) # ensure lexicographic order of strings like "<10>" vs "<2>" tokens: List[str] = [pad_token] + base_tokens self._tokens = tokens self._token_to_id = {token: idx for idx, token in enumerate(tokens)} self._id_to_token = {idx: token for token, idx in self._token_to_id.items()} init_kwargs = dict(kwargs) init_kwargs.update(self._extra_init_kwargs()) super().__init__( bos_token=bos_token, eos_token=eos_token, pad_token=pad_token, unk_token=unk_token, encoder_tokenizer_dir=encoder_tokenizer_dir, encoder_tokenizer_name=encoder_tokenizer_name, **init_kwargs, ) self._load_encoder_tokenizer() # ------------------------------------------------------------------ # Hooks implemented by subclasses # ------------------------------------------------------------------ def _build_base_tokens(self) -> List[str]: raise NotImplementedError def _extra_init_kwargs(self) -> Dict[str, object]: return {} def float_to_tokens(self, value: float) -> List[str]: raise NotImplementedError def tokens_to_float(self, tokens: Sequence[str]) -> float: raise NotImplementedError def _possible_next_tokens(self, prev_tokens: Sequence[str]) -> List[str]: raise NotImplementedError # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ def _load_encoder_tokenizer(self) -> None: if self.encoder_tokenizer is not None: return if not self.encoder_tokenizer_dir and not self.encoder_tokenizer_name: return base_dir = None if getattr(self, "name_or_path", None): base_dir = Path(self.name_or_path) if self.encoder_tokenizer_dir and base_dir is not None: candidate = base_dir / self.encoder_tokenizer_dir if candidate.exists(): self.encoder_tokenizer = AutoTokenizer.from_pretrained(str(candidate)) return if self.encoder_tokenizer_name: self.encoder_tokenizer = AutoTokenizer.from_pretrained(self.encoder_tokenizer_name) def _ensure_encoder_tokenizer(self) -> None: if self.encoder_tokenizer is None: raise NotImplementedError( "Text tokenization requires encoder tokenizer assets. 
" "Ensure `encoder_tokenizer_dir` or `encoder_tokenizer_name` are provided." ) # ------------------------------------------------------------------ # Text tokenizer passthrough # ------------------------------------------------------------------ def __call__(self, *args, **kwargs): # type: ignore[override] self._ensure_encoder_tokenizer() return self.encoder_tokenizer(*args, **kwargs) def encode(self, *args, **kwargs): # type: ignore[override] self._ensure_encoder_tokenizer() return self.encoder_tokenizer.encode(*args, **kwargs) def encode_plus(self, *args, **kwargs): # type: ignore[override] self._ensure_encoder_tokenizer() return self.encoder_tokenizer.encode_plus(*args, **kwargs) def batch_encode_plus(self, *args, **kwargs): # type: ignore[override] self._ensure_encoder_tokenizer() return self.encoder_tokenizer.batch_encode_plus(*args, **kwargs) def tokenize(self, *args, **kwargs): # type: ignore[override] self._ensure_encoder_tokenizer() return self.encoder_tokenizer.tokenize(*args, **kwargs) def _tokenize(self, text: str) -> List[str]: # pragma: no cover - unused but required by base class. raise NotImplementedError("Numeric tokenizers operate directly on floats, not text.") def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: List[int] | None = None ) -> List[int]: if token_ids_1: raise ValueError("Numeric decoder tokenizer does not support pair inputs.") return token_ids_0 # ------------------------------------------------------------------ # Vocabulary helpers # ------------------------------------------------------------------ def get_vocab(self) -> Dict[str, int]: vocab = dict(self._token_to_id) if self.encoder_tokenizer is not None: vocab.update(self.encoder_tokenizer.get_vocab()) return vocab @property def vocab_size(self) -> int: # type: ignore[override] if self.encoder_tokenizer is not None and getattr(self.encoder_tokenizer, "vocab_size", None): return int(self.encoder_tokenizer.vocab_size) return len(self._tokens) @property def decoder_vocab_size(self) -> int: return len(self._tokens) def _convert_token_to_id(self, token: str) -> int: if token not in self._token_to_id: if self.encoder_tokenizer is None: raise KeyError(f"Unknown token: {token}") return self.encoder_tokenizer.convert_tokens_to_ids(token) return self._token_to_id[token] def _convert_id_to_token(self, index: int) -> str: if index not in self._id_to_token: if self.encoder_tokenizer is None: raise KeyError(f"Unknown token id: {index}") return self.encoder_tokenizer.convert_ids_to_tokens(index) return self._id_to_token[index] def save_vocabulary(self, save_directory: str | Path, filename_prefix: str | None = None) -> tuple[str]: save_directory = Path(save_directory) save_directory.mkdir(parents=True, exist_ok=True) name = self.vocab_filename if filename_prefix is None else f"{filename_prefix}-{self.vocab_filename}" path = save_directory / name with path.open("w", encoding="utf-8") as f: json.dump({token: idx for idx, token in enumerate(self._tokens)}, f, indent=2) return (str(path),) def save_pretrained(self, save_directory: str | Path, filename_prefix: str | None = None): # type: ignore[override] paths = super().save_pretrained(save_directory, filename_prefix=filename_prefix) if self.encoder_tokenizer is not None and self.encoder_tokenizer_dir: encoder_dir = Path(save_directory) / self.encoder_tokenizer_dir encoder_dir.mkdir(parents=True, exist_ok=True) self.encoder_tokenizer.save_pretrained(encoder_dir) return paths # ------------------------------------------------------------------ 

    # ------------------------------------------------------------------
    # Numeric helpers used by the logits processor / callers.
    # ------------------------------------------------------------------
    def float_to_token_ids(self, value: float) -> List[int]:
        tokens = self.float_to_tokens(value)
        return [self._convert_token_to_id(token) for token in tokens]

    def token_ids_to_floats(self, token_ids: Sequence[int]) -> List[float]:
        cleaned = list(token_ids[1:]) if token_ids else []
        if not cleaned:
            return []
        if len(cleaned) % self.num_tokens_per_obj != 0:
            raise ValueError("Token ids length is not a multiple of tokens per object.")
        floats: List[float] = []
        for start in range(0, len(cleaned), self.num_tokens_per_obj):
            chunk = cleaned[start : start + self.num_tokens_per_obj]
            tokens = []
            for idx in chunk:
                token = self._convert_id_to_token(idx)
                if token not in self._token_to_id:
                    raise ValueError(
                        f"Token id is not part of the numeric decoder vocabulary: {idx}"
                    )
                tokens.append(token)
            floats.append(self.tokens_to_float(tokens))
        return floats

    # ------------------------------------------------------------------
    # Generation helpers
    # ------------------------------------------------------------------
    def possible_next_token_ids(self, prev_token_ids: Sequence[int]) -> List[int]:
        prev_core = list(prev_token_ids[1:]) if prev_token_ids else []
        if not prev_core:
            local_context: List[int] = []
        else:
            remainder = len(prev_core) % self.num_tokens_per_obj
            local_context = prev_core[-remainder:] if remainder else []
        local_tokens = [self._convert_id_to_token(idx) for idx in local_context]
        allowed_tokens = self._possible_next_tokens(local_tokens)
        return [self._convert_token_to_id(token) for token in allowed_tokens]

    # ------------------------------------------------------------------
    # Required hooks for the base class – kept minimal.
    # ------------------------------------------------------------------
    def convert_tokens_to_string(self, tokens: List[str]) -> str:  # pragma: no cover - not used for numeric decoding.
        if tokens and all(token in self._token_to_id for token in tokens):
            return " ".join(tokens)
        if self.encoder_tokenizer is not None:
            return self.encoder_tokenizer.convert_tokens_to_string(tokens)
        return " ".join(tokens)

    def decode(self, token_ids: Sequence[int], **kwargs) -> str:  # pragma: no cover - rely on floats helper instead.
        token_list = list(token_ids)
        if token_list and all(0 <= idx < len(self._tokens) for idx in token_list):
            floats = self.token_ids_to_floats(token_list)
            return " ".join(f"{value:.6g}" for value in floats)
        if self.encoder_tokenizer is not None:
            return self.encoder_tokenizer.decode(token_ids, **kwargs)
        raise ValueError("Cannot decode token ids without encoder tokenizer assets.")


class P10Tokenizer(_NumericTokenizerBase):
    """Tokenizer that mirrors :class:`regress_lm.tokenizers.P10Tokenizer`."""

    vocab_filename = "p10_vocab.json"

    def __init__(
        self,
        num_digits: int = 6,
        exponent_range: int = 10,
        **kwargs,
    ) -> None:
        self.num_digits = int(num_digits)
        self.exponent_range = int(exponent_range)
        if self.num_digits < 1:
            raise ValueError("num_digits must be >= 1")
        if self.exponent_range < 0:
            raise ValueError("exponent_range must be >= 0")
        super().__init__(**kwargs)
        self.num_tokens_per_obj = 2 + self.num_digits
        self.decoder_tokenizer = "P10"

    def _extra_init_kwargs(self) -> Dict[str, object]:
        return {
            "num_digits": self.num_digits,
            "exponent_range": self.exponent_range,
            "decoder_tokenizer": "P10",
            "auto_map": {"AutoTokenizer": ["tokenization_p10.P10Tokenizer", None]},
            "tokenizer_class": "P10Tokenizer",
        }

    def _build_base_tokens(self) -> List[str]:
        tokens: List[str] = []
        tokens.extend(_to_token(sign) for sign in ["+", "-"])
        tokens.extend(_to_token(digit) for digit in range(10))
        exponents = [f"E{value}" for value in range(-self.exponent_range, self.exponent_range + 1)]
        tokens.extend(_to_token(exp) for exp in exponents)
        return tokens

    def _round_float(self, value: float) -> float:
        abs_value = abs(value)
        max_abs = float("9" * self.num_digits) * (10.0**self.exponent_range)
        min_abs = float("1" + "0" * (self.num_digits - 1)) * (10.0 ** (-self.exponent_range))
        abs_value = min(abs_value, max_abs)
        if abs_value < min_abs:
            zero_or_min = round(abs_value / min_abs)
            abs_value = min_abs * zero_or_min
        return abs_value if value >= 0 else -abs_value

    def float_to_tokens(self, value: float) -> List[str]:
        rounded = self._round_float(value)
        sci = np.format_float_scientific(
            rounded,
            precision=self.num_digits - 1,
            min_digits=self.num_digits - 1,
            sign=True,
        )
        match = re.fullmatch(r"([+-])([0-9.]*)e(.*)", sci)
        if not match:
            raise RuntimeError(f"Unexpected scientific notation from numpy: {sci}")
        sign = match.group(1)
        digits = list(match.group(2).replace(".", ""))
        exponent = (int(match.group(3)) - len(digits) + 1) if rounded else 0
        tokens = [sign] + digits + [f"E{exponent}"]
        return [_to_token(token) for token in tokens]

    def tokens_to_float(self, tokens: Sequence[str]) -> float:
        primitives = [_from_token(token) for token in tokens]
        sign = -1 if primitives[0] == "-" else 1
        mantissa = int("".join(map(str, primitives[1:-1])))
        exponent = int(primitives[-1].lstrip("E"))
        return float(sign * mantissa * (10 ** exponent))

    def _possible_next_tokens(self, prev_tokens: Sequence[str]) -> List[str]:
        index = len(prev_tokens)
        if index < 0 or index >= self.num_tokens_per_obj:
            raise ValueError(
                f"Index {index} out of bounds for tokens per object {self.num_tokens_per_obj}."
            )
        if index == 0:
            candidates: Iterable[str | int] = ["+", "-"]
        elif index == self.num_tokens_per_obj - 1:
            candidates = [
                f"E{value}" for value in range(-self.exponent_range, self.exponent_range + 1)
            ]
        else:
            candidates = range(10)
        return [_to_token(candidate) for candidate in candidates]
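

# Illustrative worked example of the P10 encoding above (not part of the original
# module, shown only for orientation): with num_digits=6, the value 1.2345 is
# formatted by numpy as "+1.23450e+00", so float_to_tokens emits
#   ["<+>", "<1>", "<2>", "<3>", "<4>", "<5>", "<0>", "<E-5>"]
# and tokens_to_float reconstructs 123450 * 10**-5, i.e. roughly 1.2345 again
# (up to float rounding).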


class IEEEFloatTokenizer(_NumericTokenizerBase):
    """Tokenizer that mirrors :class:`regress_lm.tokenizers.IEEEFloatTokenizer`."""

    vocab_filename = "ieee_vocab.json"

    def __init__(
        self,
        *,
        base: int = 10,
        num_exponent_digits: int = 1,
        num_mantissa_digits: int = 4,
        **kwargs,
    ) -> None:
        if base < 2:
            raise ValueError("base must be >= 2")
        if num_exponent_digits < 1:
            raise ValueError("num_exponent_digits must be >= 1")
        if num_mantissa_digits < 1:
            raise ValueError("num_mantissa_digits must be >= 1")
        self.base = int(base)
        self.num_exponent_digits = int(num_exponent_digits)
        self.num_mantissa_digits = int(num_mantissa_digits)
        super().__init__(**kwargs)
        self.num_tokens_per_obj = 2 + self.num_exponent_digits + self.num_mantissa_digits
        self.decoder_tokenizer = f"IEEE_{self.num_mantissa_digits}_{self.num_exponent_digits}"

    def _extra_init_kwargs(self) -> Dict[str, object]:
        return {
            "base": self.base,
            "num_exponent_digits": self.num_exponent_digits,
            "num_mantissa_digits": self.num_mantissa_digits,
            "auto_map": {"AutoTokenizer": ["tokenization_p10.IEEEFloatTokenizer", None]},
            "tokenizer_class": "IEEEFloatTokenizer",
        }

    def _build_base_tokens(self) -> List[str]:
        tokens = ["+", "-"] + list(range(self.base))
        return [_to_token(token) for token in tokens]

    def float_to_tokens(self, value: float) -> List[str]:
        sign = "+" if value >= 0 else "-"
        abs_value = abs(value)
        exponent = (
            math.floor(np.log(abs_value) / np.log(self.base)) if abs_value > 0 else 0
        )
        exponent_sign = "+" if exponent >= 0 else "-"
        abs_exponent = abs(exponent)
        exponent_repr = np.base_repr(abs_exponent, base=self.base)
        if len(exponent_repr) > self.num_exponent_digits and exponent_sign == "+":
            raise ValueError(f"Overflow: Exponent {abs_exponent} too large.")
        if len(exponent_repr) > self.num_exponent_digits and exponent_sign == "-":
            # Underflow: the exponent is too negative to represent, so emit all zeros.
            all_zeros = ["0"] * (self.num_exponent_digits + self.num_mantissa_digits)
            out = [sign, "-"] + all_zeros
            return [_to_token(s) for s in out]
        exponent_repr = exponent_repr.zfill(self.num_exponent_digits)
        # Truncate to an integer so np.base_repr always receives an int.
        mantissa = np.base_repr(
            int(abs_value * self.base ** (self.num_mantissa_digits - 1 - exponent)),
            base=self.base,
        )
        if len(mantissa) > self.num_mantissa_digits:
            mantissa = mantissa[: self.num_mantissa_digits]
        if len(mantissa) < self.num_mantissa_digits:
            mantissa += "0" * (self.num_mantissa_digits - len(mantissa))
        raw_str = sign + exponent_sign + exponent_repr + mantissa
        return [_to_token(s) for s in raw_str]

    def tokens_to_float(self, tokens: Sequence[str]) -> float:
        primitives = [_from_token(token) for token in tokens]
        sign = -1 if primitives[0] == "-" else 1
        exponent_sign = -1 if primitives[1] == "-" else 1
        abs_exponent_str = "".join(map(str, primitives[2 : 2 + self.num_exponent_digits]))
        abs_exponent = int(abs_exponent_str, base=self.base)
        exponent = exponent_sign * abs_exponent
        mantissa_str = "".join(map(str, primitives[2 + self.num_exponent_digits :]))
        mantissa_unscaled = int(mantissa_str, base=self.base)
        mantissa = mantissa_unscaled / self.base ** (self.num_mantissa_digits - 1)
        return sign * (self.base**exponent) * mantissa
{self.num_tokens_per_obj}." ) if index in (0, 1): candidates: Iterable[str | int] = ["+", "-"] else: candidates = range(self.base) return [_to_token(candidate) for candidate in candidates]