"""Shared construction and loading helpers for the project's tokenizer.""" from __future__ import annotations from dataclasses import dataclass, field import json from pathlib import Path import re from typing import Any, Iterable SPECIAL_TOKENS = [ "<|pad|>", "<|bos|>", "<|eos|>", "<|unk|>", "<|endoftext|>", ] EOT_ID = SPECIAL_TOKENS.index("<|endoftext|>") ARITHMETIC_TOKENS = ("+", "-", "*", "/", "=", "(", ")") MAX_PLACE_ID = 64 PLACE_OVERFLOW_ID = MAX_PLACE_ID + 1 PLACE_VOCAB_SIZE = PLACE_OVERFLOW_ID + 1 RESULT_ROLE_ID = 10 SPACE_ROLE_ID = 11 ROLE_VOCAB_SIZE = SPACE_ROLE_ID + 1 MAX_OPERAND_ROLES = 9 @dataclass(frozen=True) class FusionEncoding: ids: list[int] place_ids: list[int] role_ids: list[int] tokens: list[str] = field(default_factory=list) @property def input_ids(self) -> list[int]: return self.ids def __len__(self) -> int: return len(self.ids) def __iter__(self): return iter(self.ids) def __post_init__(self) -> None: if not (len(self.ids) == len(self.place_ids) == len(self.role_ids)): raise ValueError("Fusion tokenizer streams must have equal length") def build_tokenizer() -> Any: """Build a byte-level BPE tokenizer with explicit lossless boundaries.""" from tokenizers import Regex, Tokenizer, decoders, models, pre_tokenizers tokenizer = Tokenizer(models.BPE(unk_token="<|unk|>")) tokenizer.pre_tokenizer = pre_tokenizers.Sequence( [ pre_tokenizers.Split( Regex(r"\s+|\d|[+\-*/=()]|[^\s\d+\-*/=()]+"), behavior="isolated", ), pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False), ] ) tokenizer.decoder = decoders.ByteLevel() return tokenizer class FusionTokenizer: """Runtime wrapper adding LSD-first digit streams to a trained BPE tokenizer.""" _digit_span_re = re.compile(r"\d+") def __init__(self, tokenizer: Any): self.tokenizer = tokenizer self._digit_token_ids = frozenset( token_id for digit in "0123456789" if (token_id := self.tokenizer.token_to_id(digit)) is not None ) self._digit_id_to_text = { int(self.tokenizer.token_to_id(digit)): digit for digit in "0123456789" if self.tokenizer.token_to_id(digit) is not None } self._equals_id = self.tokenizer.token_to_id("=") self._special_token_ids = frozenset( token_id for token in SPECIAL_TOKENS if (token_id := self.tokenizer.token_to_id(token)) is not None ) if len(self._digit_token_ids) != 10: raise ValueError("Tokenizer vocabulary must contain atomic digit tokens 0-9") if self._equals_id is None: raise ValueError("Tokenizer vocabulary must contain an atomic '=' token") def __getattr__(self, name: str) -> Any: return getattr(self.tokenizer, name) @property def digit_token_ids(self) -> frozenset[int]: return self._digit_token_ids @property def special_token_ids(self) -> frozenset[int]: return self._special_token_ids def get_vocab_size(self, with_added_tokens: bool = True) -> int: return int(self.tokenizer.get_vocab_size(with_added_tokens=with_added_tokens)) def get_vocab(self, with_added_tokens: bool = True) -> dict[str, int]: return self.tokenizer.get_vocab(with_added_tokens=with_added_tokens) def token_to_id(self, token: str) -> int | None: return self.tokenizer.token_to_id(token) def id_to_token(self, token_id: int) -> str | None: return self.tokenizer.id_to_token(int(token_id)) @classmethod def _reverse_digit_spans(cls, text: str) -> str: return cls._digit_span_re.sub(lambda match: match.group(0)[::-1], text) def _decode_token_piece(self, token_id: int) -> str: return self.tokenizer.decode([int(token_id)], skip_special_tokens=False) @staticmethod def _is_equation_whitespace(piece: str) -> bool: return bool(piece) and piece.isspace() and "\n" not in piece and "\r" not in piece def _is_equation_piece(self, token_id: int, piece: str) -> bool: if token_id in self._special_token_ids: return False if token_id in self._digit_token_ids: return True if self._is_equation_whitespace(piece): return True return len(piece) == 1 and piece in set(ARITHMETIC_TOKENS) def _annotate_equation_span( self, ids: list[int], pieces: list[str], start: int, end: int, role_ids: list[int], ) -> None: equals_positions = [ index for index in range(start, end) if ids[index] == self._equals_id ] if len(equals_positions) != 1: return equals_position = equals_positions[0] digit_runs: list[tuple[int, int]] = [] index = start while index < end: if ids[index] not in self._digit_token_ids: index += 1 continue run_start = index while index < end and ids[index] in self._digit_token_ids: index += 1 digit_runs.append((run_start, index)) operand_runs = [(a, b) for a, b in digit_runs if b <= equals_position] result_runs = [(a, b) for a, b in digit_runs if a > equals_position] if not operand_runs or not result_runs or len(operand_runs) > MAX_OPERAND_ROLES: return for index in range(start, end): if self._is_equation_whitespace(pieces[index]): role_ids[index] = SPACE_ROLE_ID for role, (run_start, run_end) in enumerate(operand_runs, start=1): for index in range(run_start, run_end): role_ids[index] = role for run_start, run_end in result_runs: for index in range(run_start, run_end): role_ids[index] = RESULT_ROLE_ID def annotate_ids(self, ids: Iterable[int]) -> tuple[list[int], list[int]]: input_ids = [int(token_id) for token_id in ids] place_ids = [0] * len(input_ids) role_ids = [0] * len(input_ids) pieces = [self._decode_token_piece(token_id) for token_id in input_ids] index = 0 while index < len(input_ids): if input_ids[index] not in self._digit_token_ids: index += 1 continue run_start = index while index < len(input_ids) and input_ids[index] in self._digit_token_ids: offset = index - run_start + 1 place_ids[index] = min(offset, PLACE_OVERFLOW_ID) index += 1 span_start: int | None = None for index, (token_id, piece) in enumerate(zip(input_ids, pieces, strict=True)): if self._is_equation_piece(token_id, piece): if span_start is None: span_start = index continue if span_start is not None: self._annotate_equation_span(input_ids, pieces, span_start, index, role_ids) span_start = None if span_start is not None: self._annotate_equation_span(input_ids, pieces, span_start, len(input_ids), role_ids) return place_ids, role_ids def encode(self, text: str, *args, **kwargs) -> FusionEncoding: transformed = self._reverse_digit_spans(text) encoding = self.tokenizer.encode(transformed, *args, **kwargs) ids = [int(token_id) for token_id in encoding.ids] place_ids, role_ids = self.annotate_ids(ids) return FusionEncoding( ids=ids, place_ids=place_ids, role_ids=role_ids, tokens=list(getattr(encoding, "tokens", [])), ) def encode_batch(self, texts: list[str], *args, **kwargs) -> list[FusionEncoding]: return [self.encode(text, *args, **kwargs) for text in texts] def decode( self, token_ids: Iterable[int], skip_special_tokens: bool = True, ) -> str: pieces: list[str] = [] text_ids: list[int] = [] digit_buffer: list[str] = [] def flush_text() -> None: if text_ids: pieces.append( self.tokenizer.decode( text_ids, skip_special_tokens=skip_special_tokens, ) ) text_ids.clear() def flush_digits() -> None: if digit_buffer: pieces.extend(reversed(digit_buffer)) digit_buffer.clear() for raw_id in token_ids: token_id = int(raw_id) if token_id in self._digit_token_ids: flush_text() digit_buffer.append(self._digit_id_to_text[token_id]) continue flush_digits() text_ids.append(token_id) flush_text() flush_digits() return "".join(pieces) def build_trainer(vocab_size: int, min_frequency: int) -> Any: from tokenizers import pre_tokenizers, trainers return trainers.BpeTrainer( vocab_size=vocab_size, min_frequency=min_frequency, special_tokens=SPECIAL_TOKENS, initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), ) def tokenizer_files(tokenizer_dir: Path) -> tuple[Path, Path, Path]: return ( tokenizer_dir / "tokenizer.json", tokenizer_dir / "vocab.json", tokenizer_dir / "merges.txt", ) def validate_tokenizer(tokenizer_dir: Path) -> None: tokenizer_json, vocab_path, merges_path = tokenizer_files(tokenizer_dir) if not tokenizer_json.exists(): raise FileNotFoundError( f"Missing {tokenizer_json}. Retrain with train_tokenizer.py so the " "whitespace and digit boundary rules are preserved." ) if vocab_path.exists(): with vocab_path.open("r", encoding="utf-8") as f: vocab = json.load(f) else: with tokenizer_json.open("r", encoding="utf-8") as f: tokenizer_data = json.load(f) vocab = tokenizer_data.get("model", {}).get("vocab") if not isinstance(vocab, dict): raise FileNotFoundError(f"Missing vocab.json and no embedded vocab in {tokenizer_json}") max_id = max(vocab.values()) if max_id > 65_535: raise ValueError(f"Tokenizer max id {max_id} does not fit in uint16") if vocab.get("<|endoftext|>") != EOT_ID: raise ValueError( f"Expected <|endoftext|> id {EOT_ID}, " f"got {vocab.get('<|endoftext|>')}" ) missing = [ token for token in (*[str(value) for value in range(10)], *ARITHMETIC_TOKENS) if token not in vocab ] if missing: raise ValueError(f"Tokenizer missing required atomic tokens: {missing}") def load_tokenizer(tokenizer_dir: Path) -> Any: from tokenizers import Tokenizer validate_tokenizer(tokenizer_dir) tokenizer_json, _, _ = tokenizer_files(tokenizer_dir) return FusionTokenizer(Tokenizer.from_file(str(tokenizer_json)))