HusseinEid's picture
Super-squash branch 'main' using huggingface_hub
68a4c53
"""HuggingFace-compatible tokenizer wrapper.
Subclasses `PreTrainedTokenizerFast` so users get a drop-in replacement for
`AutoTokenizer` in any HF training/inference loop.
Usage:
    from cute_tokenizer import CUTETokenizerFast
    tok = CUTETokenizerFast.from_pretrained("./output")
    ids = tok("def hello(): return 42").input_ids
    text = tok.decode(ids)
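This is the two-line wrapper UX: one call to load, one call to encode. PUA
pre-tokenization (Aho-Corasick over the PUA mapping) runs in the Rust
accelerator when it is available and falls back to pure Python otherwise;
the trained byte-level BPE then handles the residual stream in Rust.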
"""
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, Any
from transformers import PreTrainedTokenizerFast
from ._accel_loader import USE_RUST, accel, prepare_mapping
from .decode import reverse_pua_substitute
from .pretokenizer import pretokenize_to_string
from .pua import PUAMapping
from .trainer import load_mapping
if TYPE_CHECKING:
    # No typing-only imports are needed at present; the guard is a no-op.
    pass
# File name of the serialized PUA mapping. It is used both as the
# `cute_mapping_file` entry of `vocab_files_names` and as the name written
# by `_save_pretrained`.
_MAPPING_FILENAME = "cute_mapping.json"
class CUTETokenizerFast(PreTrainedTokenizerFast):
    """CUTE wrapper around a byte-level BPE fast tokenizer.

    Input text goes through PUA substitution (``_cute_pretokenize``) before
    encoding is delegated to the underlying byte-level BPE tokenizer, and the
    substitution is reversed again on every decode path.

    Inherits everything else from `PreTrainedTokenizerFast` — padding,
    truncation, batch encoding, special tokens, save/load semantics.
    """

    # Static map of constructor kwarg name → on-disk filename. Required by
    # PreTrainedTokenizerFast machinery; values are immutable strings, so
    # the class-level dict is safe despite RUF012's general advice.
    vocab_files_names = {  # type: ignore[assignment]  # noqa: RUF012
        "tokenizer_file": "tokenizer.json",
        "cute_mapping_file": _MAPPING_FILENAME,
    }

    def __init__(
        self,
        tokenizer_file: str | Path | None = None,
        cute_mapping_file: str | Path | None = None,
        **kwargs: Any,
    ) -> None:
        """Load the BPE tokenizer and the PUA mapping that goes with it.

        Args:
            tokenizer_file: Path to the trained ``tokenizer.json``.
            cute_mapping_file: Path to the serialized PUA mapping.
            **kwargs: Forwarded unchanged to `PreTrainedTokenizerFast`.

        Raises:
            ValueError: If either ``tokenizer_file`` or ``cute_mapping_file``
                is ``None``.
        """
        if tokenizer_file is None:
            raise ValueError("CUTETokenizerFast requires `tokenizer_file`")
        if cute_mapping_file is None:
            raise ValueError("CUTETokenizerFast requires `cute_mapping_file`")

        # PreTrainedTokenizerFast loads the underlying tokenizer.json itself.
        super().__init__(tokenizer_file=str(tokenizer_file), **kwargs)
        self._cute_mapping: PUAMapping = load_mapping(Path(cute_mapping_file))
        self._cute_mapping_file = str(cute_mapping_file)

        # 1.1.0: build the purpose-built `cute-bpe` encoder once. The
        # `fast_encode` / `fast_decode` hot paths use it instead of HF's
        # BPE for ~2x lower encode latency at byte-equal output. Falls
        # back to the HF-backed path if the Rust extension didn't load,
        # or when `CUTE_USE_HF_BACKEND=1` is set by the user.
        import logging as _logging
        import os as _os

        # Normalise the flag so "FALSE" / " false " are read as "off" too;
        # an unset variable, "", "0" and any casing of "false" all disable it.
        flag = _os.environ.get("CUTE_USE_HF_BACKEND", "").strip().lower()
        force_hf = flag not in ("", "0", "false")

        self._cute_bpe: Any | None = None
        if USE_RUST and not force_hf and hasattr(accel, "BPEEncoder"):
            try:
                self._cute_bpe = accel.BPEEncoder(str(tokenizer_file), str(cute_mapping_file))
            except Exception:
                # Best effort: the HF-backed path still works, so do not
                # fail construction — but leave a trace of why the fast
                # encoder is missing instead of swallowing the error.
                _logging.getLogger(__name__).warning(
                    "cute-bpe encoder could not be built; using the HF-backed path",
                    exc_info=True,
                )
                self._cute_bpe = None

    # ---------------------------------------------------------------------
    # Persistence
    # ---------------------------------------------------------------------
    @property
    def cute_mapping(self) -> PUAMapping:
        """The PUA mapping loaded from ``cute_mapping_file``."""
        return self._cute_mapping

    def _save_pretrained(  # type: ignore[override]
        self,
        save_directory: Any,
        file_names: tuple[str, ...],
        legacy_format: bool | None = None,
        filename_prefix: str | None = None,
    ) -> tuple[str, ...]:
        """Save the parent's tokenizer files plus ``cute_mapping.json``.

        We hook here (rather than `save_vocabulary`) because Fast tokenizers
        bypass `save_vocabulary` entirely — `_save_pretrained` is the real
        extension point.

        Args:
            save_directory: Target directory; created if it does not exist.
            file_names: Forwarded to the parent implementation.
            legacy_format: Forwarded to the parent implementation.
            filename_prefix: Optional prefix; when set, the mapping file is
                written as ``"<prefix>-cute_mapping.json"``.

        Returns:
            The parent's tuple of written paths with the mapping path appended.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        # Let the parent write its own files first.
        out = super()._save_pretrained(
            str(save_directory),
            file_names=file_names,
            legacy_format=legacy_format,
            filename_prefix=filename_prefix,
        )

        prefix = f"{filename_prefix}-" if filename_prefix else ""
        mapping_path = save_directory / f"{prefix}{_MAPPING_FILENAME}"

        from .trainer import save_mapping

        save_mapping(self._cute_mapping, mapping_path)
        return (*out, str(mapping_path))

    # ---------------------------------------------------------------------
    # Encode path (override at the lowest convenient level)
    # ---------------------------------------------------------------------
    def _cute_pretokenize(self, text: str) -> str:
        """Run PUA substitution + identifier splitting on a single string.

        Hot path: skip the wrapper-level `pretokenize_to_string` redirection
        and call `_accel.pretokenize_to_string` directly with the cached
        prepared mapping. Saves one Python frame per encode.
        """
        if USE_RUST:
            return accel.pretokenize_to_string(text, self._prepared_mapping)
        return pretokenize_to_string(text, self._cute_mapping)

    def _batch_encode_plus(  # type: ignore[override]
        self,
        batch_text_or_text_pairs: Any,
        **kwargs: Any,
    ) -> Any:
        """Pre-tokenize every batch item, then defer to the parent."""
        return super()._batch_encode_plus(
            self._preprocess_batch(batch_text_or_text_pairs), **kwargs
        )

    def _encode_plus(  # type: ignore[override]
        self,
        text: Any,
        text_pair: Any | None = None,
        **kwargs: Any,
    ) -> Any:
        """Pre-tokenize ``text`` (and ``text_pair``), then defer to the parent."""
        # NOTE(review): in transformers 4.x the parent `_encode_plus` appears
        # to delegate to `self._batch_encode_plus`, which is overridden above.
        # If so, single inputs are pre-tokenized twice — confirm that
        # `pretokenize_to_string` is idempotent or drop one of the two hooks.
        text = self._preprocess_one(text)
        if text_pair is not None:
            text_pair = self._preprocess_one(text_pair)
        return super()._encode_plus(text, text_pair=text_pair, **kwargs)

    def _preprocess_one(self, x: Any) -> Any:
        """Pre-tokenize one input: a string or a sequence of string pieces.

        Anything else is returned unchanged. Pre-tokenized input may be a
        list or a tuple of strings; in both cases a list is returned.
        """
        if isinstance(x, str):
            return self._cute_pretokenize(x)
        # Pre-tokenized input (sequence of strings) — substitute each piece.
        if isinstance(x, (list, tuple)) and all(isinstance(p, str) for p in x):
            # Batch FFI hop: one allow_threads + Rayon over all pieces.
            if USE_RUST and hasattr(accel, "pretokenize_batch"):
                pieces = x if isinstance(x, list) else list(x)
                return list(accel.pretokenize_batch(pieces, self._prepared_mapping))
            return [self._cute_pretokenize(p) for p in x]
        return x

    def _preprocess_batch(self, batch: Any) -> Any:
        """Pre-tokenize every item of a batch.

        The batch container may be a list or a tuple; a tuple used to be
        passed through untouched, which silently skipped PUA substitution.
        Non-sequence inputs are returned unchanged.
        """
        if not isinstance(batch, (list, tuple)):
            return batch
        # Fast path: a homogeneous batch of plain strings can go through the
        # batched Rust API in one FFI hop (Rayon-parallel inside).
        if (
            USE_RUST
            and hasattr(accel, "pretokenize_batch")
            and all(isinstance(b, str) for b in batch)
        ):
            items = batch if isinstance(batch, list) else list(batch)
            return list(accel.pretokenize_batch(items, self._prepared_mapping))
        return [self._preprocess_pair_or_text(b) for b in batch]

    def _preprocess_pair_or_text(self, item: Any) -> Any:
        """Pre-tokenize a batch item, treating a 2-tuple as a text pair."""
        if isinstance(item, tuple) and len(item) == 2:
            first, second = item
            return (self._preprocess_one(first), self._preprocess_one(second))
        return self._preprocess_one(item)

    @property
    def _prepared_mapping(self) -> Any:
        """Lazily-built, instance-cached `_accel.PreparedMapping`. Avoids
        the global id-keyed cache lookup on every encode call.
        """
        cached = getattr(self, "_prepared_mapping_cache", None)
        if cached is not None:
            return cached
        prepared = prepare_mapping(self._cute_mapping)
        # Bypass dataclass-frozen-style attribute checks via __dict__.
        self.__dict__["_prepared_mapping_cache"] = prepared
        return prepared

    # ---------------------------------------------------------------------
    # Decode path
    # ---------------------------------------------------------------------
    def _cute_reverse_pua(self, text: str) -> str:
        """Undo PUA substitution on ``text`` (Rust when available, else Python)."""
        if USE_RUST:
            return accel.reverse_pua_substitute(text, self._prepared_mapping)
        return reverse_pua_substitute(text, self._cute_mapping)

    def _decode(  # type: ignore[override]
        self,
        token_ids: Any,
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool | None = None,
        **kwargs: Any,
    ) -> str:
        """Decode via the parent, then reverse the PUA substitution."""
        text = super()._decode(
            token_ids,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )
        return self._cute_reverse_pua(text)

    def convert_tokens_to_string(self, tokens: list[str]) -> str:  # type: ignore[override]
        """Join tokens via the parent, then reverse the PUA substitution."""
        return self._cute_reverse_pua(super().convert_tokens_to_string(tokens))

    # ------------------------------------------------------------------
    # Fast paths — skip PreTrainedTokenizerFast machinery
    # ------------------------------------------------------------------
    #
    # `__call__` / `_encode_plus` build a `BatchEncoding` (input_ids +
    # attention_mask + special-token insertion + optional padding /
    # truncation). That machinery costs ~400 µs per call on top of the
    # actual BPE encode. When all you want is `list[int]` of token ids,
    # use `fast_encode` / `fast_decode` to skip the wrapper and call the
    # raw HF Tokenizer directly. Trade-off: no attention_mask, no
    # special-token insertion, no padding/truncation logic. For raw
    # inference / training pipelines that's usually fine.
    def fast_encode(self, text: str) -> list[int]:
        """Return token ids for ``text``. Skips the ``BatchEncoding`` wrapper.

        Uses the 1.1.0 purpose-built ``cute-bpe`` Rust encoder when
        available (~2x lower latency than the HF-backed fallback,
        byte-equal output). Set ``CUTE_USE_HF_BACKEND=1`` to force the
        previous HF-tokenizers path.
        """
        if self._cute_bpe is not None:
            return list(self._cute_bpe.encode(text))
        pre = self._cute_pretokenize(text)
        return self._tokenizer.encode(pre, add_special_tokens=False).ids

    def fast_encode_batch(self, texts: list[str]) -> list[list[int]]:
        """Batched ``fast_encode``. Single Rayon-parallel FFI hop when
        the ``cute-bpe`` backend is loaded.
        """
        if self._cute_bpe is not None:
            return [list(ids) for ids in self._cute_bpe.encode_batch(texts)]
        if USE_RUST and hasattr(accel, "pretokenize_batch"):
            pres = list(accel.pretokenize_batch(texts, self._prepared_mapping))
        else:
            pres = [pretokenize_to_string(t, self._cute_mapping) for t in texts]
        encodings = self._tokenizer.encode_batch(pres, add_special_tokens=False)
        return [enc.ids for enc in encodings]

    def fast_decode(self, ids: list[int]) -> str:
        """Decode token ids to text. Skips the wrapper.

        Uses the 1.1.0 ``cute-bpe`` decoder (table-lookup + byte-level
        inverse + reverse-PUA in one Rust call) when available. The
        HF-backed fallback decodes with ``skip_special_tokens=True``.
        """
        if self._cute_bpe is not None:
            return self._cute_bpe.decode(list(ids))
        text = self._tokenizer.decode(ids, skip_special_tokens=True)
        return self._cute_reverse_pua(text)

    def fast_decode_batch(self, ids_list: list[list[int]]) -> list[str]:
        """Batched decode. One Rayon-parallel FFI hop when ``cute-bpe``
        is loaded.
        """
        if self._cute_bpe is not None:
            return list(self._cute_bpe.decode_batch([list(ids) for ids in ids_list]))
        texts = self._tokenizer.decode_batch(ids_list, skip_special_tokens=True)
        if USE_RUST and hasattr(accel, "reverse_pua_batch"):
            return list(accel.reverse_pua_batch(texts, self._prepared_mapping))
        return [reverse_pua_substitute(t, self._cute_mapping) for t in texts]
# Public API of this module: only the tokenizer class is exported.
__all__ = ["CUTETokenizerFast"]