RLM-GemmaS-Code-v0 / tokenization_p10.py
"""Hugging Face compatible tokenizer combining text encoder and numeric decoder."""
from __future__ import annotations
import json
import math
import re
from pathlib import Path
from typing import Dict, Iterable, List, Sequence
import numpy as np
from transformers import AutoTokenizer, PreTrainedTokenizer
DELIMITERS = ("<", ">")
def _to_token(value: str | int) -> str:
left, right = DELIMITERS
return f"{left}{value}{right}"
def _from_token(token: str) -> str:
left, right = DELIMITERS
match = re.fullmatch(rf"{re.escape(left)}(.*?){re.escape(right)}", token)
if not match:
raise ValueError(f"Cannot deserialize token: {token}")
return match.group(1)
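# Illustrative note (added for clarity; not part of the upstream export): these helpers
# wrap and unwrap numeric primitives, e.g. _to_token(7) -> "<7>" and
# _from_token("<E-3>") -> "E-3"; _from_token raises ValueError for un-delimited strings.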
class _NumericTokenizerBase(PreTrainedTokenizer):
"""Shared utilities for numeric decoder tokenizers."""
vocab_files_names: Dict[str, str] = {}
model_input_names = ["input_ids"]
vocab_filename = "numeric_vocab.json"
def __init__(
self,
*,
encoder_tokenizer_dir: str | None = None,
encoder_tokenizer_name: str | None = None,
encoder_tokenizer: PreTrainedTokenizer | None = None,
bos_token: str = "<pad>",
eos_token: str | None = None,
pad_token: str | None = None,
unk_token: str | None = None,
**kwargs,
) -> None:
eos_token = eos_token or bos_token
pad_token = pad_token or bos_token
self.encoder_tokenizer_dir = encoder_tokenizer_dir
self.encoder_tokenizer_name = encoder_tokenizer_name
self.encoder_tokenizer = encoder_tokenizer
base_tokens = self._build_base_tokens()
base_tokens = sorted(base_tokens) # ensure lexicographic order of strings like "<10>" vs "<2>"
tokens: List[str] = [pad_token] + base_tokens
self._tokens = tokens
self._token_to_id = {token: idx for idx, token in enumerate(tokens)}
self._id_to_token = {idx: token for token, idx in self._token_to_id.items()}
init_kwargs = dict(kwargs)
init_kwargs.update(self._extra_init_kwargs())
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
unk_token=unk_token,
encoder_tokenizer_dir=encoder_tokenizer_dir,
encoder_tokenizer_name=encoder_tokenizer_name,
**init_kwargs,
)
self._load_encoder_tokenizer()
# ------------------------------------------------------------------
# Hooks implemented by subclasses
# ------------------------------------------------------------------
def _build_base_tokens(self) -> List[str]:
raise NotImplementedError
def _extra_init_kwargs(self) -> Dict[str, object]:
return {}
def float_to_tokens(self, value: float) -> List[str]:
raise NotImplementedError
def tokens_to_float(self, tokens: Sequence[str]) -> float:
raise NotImplementedError
def _possible_next_tokens(self, prev_tokens: Sequence[str]) -> List[str]:
raise NotImplementedError
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _load_encoder_tokenizer(self) -> None:
if self.encoder_tokenizer is not None:
return
if not self.encoder_tokenizer_dir and not self.encoder_tokenizer_name:
return
base_dir = None
if getattr(self, "name_or_path", None):
base_dir = Path(self.name_or_path)
if self.encoder_tokenizer_dir and base_dir is not None:
candidate = base_dir / self.encoder_tokenizer_dir
if candidate.exists():
self.encoder_tokenizer = AutoTokenizer.from_pretrained(str(candidate))
return
if self.encoder_tokenizer_name:
self.encoder_tokenizer = AutoTokenizer.from_pretrained(self.encoder_tokenizer_name)
def _ensure_encoder_tokenizer(self) -> None:
if self.encoder_tokenizer is None:
raise NotImplementedError(
"Text tokenization requires encoder tokenizer assets. "
"Ensure `encoder_tokenizer_dir` or `encoder_tokenizer_name` are provided."
)
# ------------------------------------------------------------------
# Text tokenizer passthrough
# ------------------------------------------------------------------
def __call__(self, *args, **kwargs): # type: ignore[override]
self._ensure_encoder_tokenizer()
return self.encoder_tokenizer(*args, **kwargs)
def encode(self, *args, **kwargs): # type: ignore[override]
self._ensure_encoder_tokenizer()
return self.encoder_tokenizer.encode(*args, **kwargs)
def encode_plus(self, *args, **kwargs): # type: ignore[override]
self._ensure_encoder_tokenizer()
return self.encoder_tokenizer.encode_plus(*args, **kwargs)
def batch_encode_plus(self, *args, **kwargs): # type: ignore[override]
self._ensure_encoder_tokenizer()
return self.encoder_tokenizer.batch_encode_plus(*args, **kwargs)
def tokenize(self, *args, **kwargs): # type: ignore[override]
self._ensure_encoder_tokenizer()
return self.encoder_tokenizer.tokenize(*args, **kwargs)
def _tokenize(self, text: str) -> List[str]: # pragma: no cover - unused but required by base class.
raise NotImplementedError("Numeric tokenizers operate directly on floats, not text.")
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: List[int] | None = None
) -> List[int]:
if token_ids_1:
raise ValueError("Numeric decoder tokenizer does not support pair inputs.")
return token_ids_0
# ------------------------------------------------------------------
# Vocabulary helpers
# ------------------------------------------------------------------
def get_vocab(self) -> Dict[str, int]:
vocab = dict(self._token_to_id)
if self.encoder_tokenizer is not None:
vocab.update(self.encoder_tokenizer.get_vocab())
return vocab
@property
def vocab_size(self) -> int: # type: ignore[override]
if self.encoder_tokenizer is not None and getattr(self.encoder_tokenizer, "vocab_size", None):
return int(self.encoder_tokenizer.vocab_size)
return len(self._tokens)
@property
def decoder_vocab_size(self) -> int:
return len(self._tokens)
def _convert_token_to_id(self, token: str) -> int:
if token not in self._token_to_id:
if self.encoder_tokenizer is None:
raise KeyError(f"Unknown token: {token}")
return self.encoder_tokenizer.convert_tokens_to_ids(token)
return self._token_to_id[token]
def _convert_id_to_token(self, index: int) -> str:
if index not in self._id_to_token:
if self.encoder_tokenizer is None:
raise KeyError(f"Unknown token id: {index}")
return self.encoder_tokenizer.convert_ids_to_tokens(index)
return self._id_to_token[index]
def save_vocabulary(self, save_directory: str | Path, filename_prefix: str | None = None) -> tuple[str]:
save_directory = Path(save_directory)
save_directory.mkdir(parents=True, exist_ok=True)
name = self.vocab_filename if filename_prefix is None else f"{filename_prefix}-{self.vocab_filename}"
path = save_directory / name
with path.open("w", encoding="utf-8") as f:
json.dump({token: idx for idx, token in enumerate(self._tokens)}, f, indent=2)
return (str(path),)
def save_pretrained(self, save_directory: str | Path, filename_prefix: str | None = None, **kwargs):  # type: ignore[override]
paths = super().save_pretrained(save_directory, filename_prefix=filename_prefix, **kwargs)
if self.encoder_tokenizer is not None and self.encoder_tokenizer_dir:
encoder_dir = Path(save_directory) / self.encoder_tokenizer_dir
encoder_dir.mkdir(parents=True, exist_ok=True)
self.encoder_tokenizer.save_pretrained(encoder_dir)
return paths
# ------------------------------------------------------------------
# Numeric helpers used by the logits processor / callers.
# ------------------------------------------------------------------
def float_to_token_ids(self, value: float) -> List[int]:
tokens = self.float_to_tokens(value)
return [self._convert_token_to_id(token) for token in tokens]
def token_ids_to_floats(self, token_ids: Sequence[int]) -> List[float]:
cleaned = list(token_ids[1:]) if token_ids else []
if not cleaned:
return []
if len(cleaned) % self.num_tokens_per_obj != 0:
raise ValueError("Token ids length is not a multiple of tokens per object.")
floats: List[float] = []
for start in range(0, len(cleaned), self.num_tokens_per_obj):
chunk = cleaned[start : start + self.num_tokens_per_obj]
tokens = []
for idx in chunk:
token = self._convert_id_to_token(idx)
if token not in self._token_to_id:
raise ValueError(
"Token id is not part of the numeric decoder vocabulary: "
f"{idx}"
)
tokens.append(token)
floats.append(self.tokens_to_float(tokens))
return floats
# ------------------------------------------------------------------
# Generation helpers
# ------------------------------------------------------------------
def possible_next_token_ids(self, prev_token_ids: Sequence[int]) -> List[int]:
prev_core = list(prev_token_ids[1:]) if prev_token_ids else []
if not prev_core:
local_context: List[int] = []
else:
remainder = len(prev_core) % self.num_tokens_per_obj
local_context = prev_core[-remainder:] if remainder else []
local_tokens = [self._convert_id_to_token(idx) for idx in local_context]
allowed_tokens = self._possible_next_tokens(local_tokens)
return [self._convert_token_to_id(token) for token in allowed_tokens]
# ------------------------------------------------------------------
# Required hooks for the base class – kept minimal.
# ------------------------------------------------------------------
def convert_tokens_to_string(self, tokens: List[str]) -> str: # pragma: no cover - not used for numeric decoding.
if tokens and all(token in self._token_to_id for token in tokens):
return " ".join(tokens)
if self.encoder_tokenizer is not None:
return self.encoder_tokenizer.convert_tokens_to_string(tokens)
return " ".join(tokens)
def decode(self, token_ids: Sequence[int], **kwargs) -> str: # pragma: no cover - rely on floats helper instead.
token_list = list(token_ids)
if token_list and all(0 <= idx < len(self._tokens) for idx in token_list):
floats = self.token_ids_to_floats(token_list)
return " ".join(f"{value:.6g}" for value in floats)
if self.encoder_tokenizer is not None:
return self.encoder_tokenizer.decode(token_ids, **kwargs)
raise ValueError("Cannot decode token ids without encoder tokenizer assets.")
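# Minimal usage sketch (illustrative; not part of the original module). `prev_ids` and
# `scores` are assumed to come from a Hugging Face generation loop with PyTorch tensors:
#
#     allowed = tokenizer.possible_next_token_ids(prev_ids)
#     mask = torch.full_like(scores, float("-inf"))
#     mask[:, allowed] = 0.0
#     scores = scores + mask  # keep only valid numeric continuations
#
# Once generation finishes, tokenizer.token_ids_to_floats(sequence) converts the sampled
# ids (including the leading decoder-start token) back into floats.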
class P10Tokenizer(_NumericTokenizerBase):
"""Tokenizer that mirrors :class:`regress_lm.tokenizers.P10Tokenizer`."""
vocab_filename = "p10_vocab.json"
def __init__(
self,
num_digits: int = 6,
exponent_range: int = 10,
**kwargs,
) -> None:
self.num_digits = int(num_digits)
self.exponent_range = int(exponent_range)
if self.num_digits < 1:
raise ValueError("num_digits must be >= 1")
if self.exponent_range < 0:
raise ValueError("exponent_range must be >= 0")
super().__init__(**kwargs)
self.num_tokens_per_obj = 2 + self.num_digits
self.decoder_tokenizer = "P10"
def _extra_init_kwargs(self) -> Dict[str, object]:
return {
"num_digits": self.num_digits,
"exponent_range": self.exponent_range,
"decoder_tokenizer": "P10",
"auto_map": {"AutoTokenizer": ["tokenization_p10.P10Tokenizer", None]},
"tokenizer_class": "P10Tokenizer",
}
def _build_base_tokens(self) -> List[str]:
tokens: List[str] = []
tokens.extend(_to_token(sign) for sign in ["+", "-"])
tokens.extend(_to_token(digit) for digit in range(10))
exponents = [f"E{value}" for value in range(-self.exponent_range, self.exponent_range + 1)]
tokens.extend(_to_token(exp) for exp in exponents)
return tokens
def _round_float(self, value: float) -> float:
abs_value = abs(value)
max_abs = float("9" * self.num_digits) * (10.0**self.exponent_range)
min_abs = float("1" + "0" * (self.num_digits - 1)) * (10.0 ** (-self.exponent_range))
abs_value = min(abs_value, max_abs)
if abs_value < min_abs:
zero_or_min = round(abs_value / min_abs)
abs_value = min_abs * zero_or_min
return abs_value if value >= 0 else -abs_value
def float_to_tokens(self, value: float) -> List[str]:
rounded = self._round_float(value)
sci = np.format_float_scientific(
rounded,
precision=self.num_digits - 1,
min_digits=self.num_digits - 1,
sign=True,
)
match = re.fullmatch(r"([+-])([0-9.]*)e(.*)", sci)
if not match:
raise RuntimeError(f"Unexpected scientific notation from numpy: {sci}")
sign = match.group(1)
digits = list(match.group(2).replace(".", ""))
exponent = (int(match.group(3)) - len(digits) + 1) if rounded else 0
tokens = [sign] + digits + [f"E{exponent}"]
return [_to_token(token) for token in tokens]
def tokens_to_float(self, tokens: Sequence[str]) -> float:
primitives = [_from_token(token) for token in tokens]
sign = -1 if primitives[0] == "-" else 1
mantissa = int("".join(map(str, primitives[1:-1])))
exponent = int(primitives[-1].lstrip("E"))
return float(sign * mantissa * (10 ** exponent))
def _possible_next_tokens(self, prev_tokens: Sequence[str]) -> List[str]:
index = len(prev_tokens)
if index < 0 or index >= self.num_tokens_per_obj:
raise ValueError(
f"Index {index} out of bounds for tokens per object {self.num_tokens_per_obj}."
)
if index == 0:
candidates: Iterable[str | int] = ["+", "-"]
elif index == self.num_tokens_per_obj - 1:
candidates = [
f"E{value}" for value in range(-self.exponent_range, self.exponent_range + 1)
]
else:
candidates = range(10)
return [_to_token(candidate) for candidate in candidates]
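# Worked example (illustrative; with the defaults num_digits=6 and exponent_range=10):
# float_to_tokens(1.5) yields
#     ["<+>", "<1>", "<5>", "<0>", "<0>", "<0>", "<0>", "<E-5>"]
# i.e. a sign token, num_digits mantissa digits, and one power-of-ten exponent token;
# tokens_to_float() maps the list back to 150000 * 10**-5, i.e. 1.5 up to float rounding.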
class IEEEFloatTokenizer(_NumericTokenizerBase):
"""Tokenizer that mirrors :class:`regress_lm.tokenizers.IEEEFloatTokenizer`."""
vocab_filename = "ieee_vocab.json"
def __init__(
self,
*,
base: int = 10,
num_exponent_digits: int = 1,
num_mantissa_digits: int = 4,
**kwargs,
) -> None:
if base < 2:
raise ValueError("base must be >= 2")
if num_exponent_digits < 1:
raise ValueError("num_exponent_digits must be >= 1")
if num_mantissa_digits < 1:
raise ValueError("num_mantissa_digits must be >= 1")
self.base = int(base)
self.num_exponent_digits = int(num_exponent_digits)
self.num_mantissa_digits = int(num_mantissa_digits)
super().__init__(**kwargs)
self.num_tokens_per_obj = 2 + self.num_exponent_digits + self.num_mantissa_digits
self.decoder_tokenizer = f"IEEE_{self.num_mantissa_digits}_{self.num_exponent_digits}"
def _extra_init_kwargs(self) -> Dict[str, object]:
return {
"base": self.base,
"num_exponent_digits": self.num_exponent_digits,
"num_mantissa_digits": self.num_mantissa_digits,
"auto_map": {"AutoTokenizer": ["tokenization_p10.IEEEFloatTokenizer", None]},
"tokenizer_class": "IEEEFloatTokenizer",
}
def _build_base_tokens(self) -> List[str]:
tokens = ["+", "-"] + list(range(self.base))
return [_to_token(token) for token in tokens]
def float_to_tokens(self, value: float) -> List[str]:
sign = "+" if value >= 0 else "-"
abs_value = abs(value)
exponent = (
math.floor(np.log(abs_value) / np.log(self.base)) if abs_value > 0 else 0
)
exponent_sign = "+" if exponent >= 0 else "-"
abs_exponent = abs(exponent)
exponent_repr = np.base_repr(abs_exponent, base=self.base)
if len(exponent_repr) > self.num_exponent_digits and exponent_sign == "+":
raise ValueError(f"Overflow: Exponent {abs_exponent} too large.")
if len(exponent_repr) > self.num_exponent_digits and exponent_sign == "-":
all_zeros = ["0"] * (self.num_exponent_digits + self.num_mantissa_digits)
out = [sign, "-"] + all_zeros
return [_to_token(s) for s in out]
exponent_repr = exponent_repr.zfill(self.num_exponent_digits)
mantissa = np.base_repr(
abs_value * self.base ** (self.num_mantissa_digits - 1 - exponent),
base=self.base,
)
if len(mantissa) > self.num_mantissa_digits:
mantissa = mantissa[: self.num_mantissa_digits]
if len(mantissa) < self.num_mantissa_digits:
mantissa += "0" * (self.num_mantissa_digits - len(mantissa))
raw_str = sign + exponent_sign + exponent_repr + mantissa
return [_to_token(s) for s in raw_str]
def tokens_to_float(self, tokens: Sequence[str]) -> float:
primitives = [_from_token(token) for token in tokens]
sign = -1 if primitives[0] == "-" else 1
exponent_sign = -1 if primitives[1] == "-" else 1
abs_exponent_str = "".join(
map(str, primitives[2 : 2 + self.num_exponent_digits])
)
abs_exponent = int(abs_exponent_str, base=self.base)
exponent = exponent_sign * abs_exponent
mantissa_str = "".join(map(str, primitives[2 + self.num_exponent_digits :]))
mantissa_unscaled = int(mantissa_str, base=self.base)
mantissa = mantissa_unscaled / self.base ** (self.num_mantissa_digits - 1)
return sign * (self.base**exponent) * mantissa
def _possible_next_tokens(self, prev_tokens: Sequence[str]) -> List[str]:
index = len(prev_tokens)
if index < 0 or index >= self.num_tokens_per_obj:
raise ValueError(
f"Index {index} out of bounds for tokens per object {self.num_tokens_per_obj}."
)
if index in (0, 1):
candidates: Iterable[str | int] = ["+", "-"]
else:
candidates = range(self.base)
return [_to_token(candidate) for candidate in candidates]
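# ----------------------------------------------------------------------
# Illustrative self-check (added for documentation purposes; not part of the
# original regress_lm export). It assumes both tokenizers can be constructed
# without encoder tokenizer assets, which suffices for the pure float helpers
# exercised below but leaves text encoding unavailable.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    p10 = P10Tokenizer(num_digits=6, exponent_range=10)
    for value in (1.5, -0.00123, 42.0):
        tokens = p10.float_to_tokens(value)
        # Round-trip through the decoder vocabulary; equality only up to float rounding.
        assert math.isclose(p10.tokens_to_float(tokens), value)
        print(f"{value} -> {tokens}")

    ieee = IEEEFloatTokenizer(base=10, num_exponent_digits=1, num_mantissa_digits=4)
    # "+", "+", "0", "2", "7", "1", "8" encodes +2.718: exponent +0, mantissa 2718 / 10**3.
    print(ieee.tokens_to_float([_to_token(char) for char in "++02718"]))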