from __future__ import annotations import json from pathlib import Path from typing import Any from transformers import PreTrainedTokenizerFast class LizzyTokenizerFast(PreTrainedTokenizerFast): """Family-agnostic fast tokenizer wrapper for Lizzy checkpoints.""" model_input_names = ["input_ids", "attention_mask"] def __init__(self, *args: Any, **kwargs: Any) -> None: preserved_keys = ( "add_prefix_space", "add_bos_token", "add_eos_token", "clean_up_tokenization_spaces", "use_default_system_prompt", "legacy", "fix_mistral_regex", ) preserved_init_attrs = { key: kwargs.get(key) for key in preserved_keys if key in kwargs } super().__init__(*args, **kwargs) init_kwargs = getattr(self, "init_kwargs", {}) local_payload: dict[str, Any] = {} config_path = ( Path(str(getattr(self, "name_or_path", ""))) / "tokenizer_config.json" ) if config_path.is_file(): try: local_payload = json.loads(config_path.read_text(encoding="utf-8")) except Exception: local_payload = {} for key in preserved_keys: value = preserved_init_attrs.get(key, init_kwargs.get(key)) if value is None: value = local_payload.get(key) if value is not None: setattr(self, key, value) @property def all_special_tokens_extended(self) -> list[str]: """Compatibility shim for runtimes still expecting the pre-5.4 API.""" return list(self.all_special_tokens)