| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
| from typing import Any |
|
|
| from transformers import PreTrainedTokenizerFast |
|
|
|
|
class LizzyTokenizerFast(PreTrainedTokenizerFast):
    """Family-agnostic fast tokenizer wrapper for Lizzy checkpoints.

    Re-applies a fixed set of family-specific init flags that the base
    ``PreTrainedTokenizerFast.__init__`` may consume or drop. Each flag is
    resolved in priority order from:

    1. kwargs passed explicitly to this constructor,
    2. the ``init_kwargs`` recorded by the base class,
    3. a local ``tokenizer_config.json`` next to the checkpoint
       (best-effort; ignored if missing, unreadable, or malformed).
    """

    model_input_names = ["input_ids", "attention_mask"]

    # Init flags that must survive the round-trip through the base __init__.
    _PRESERVED_KEYS: tuple[str, ...] = (
        "add_prefix_space",
        "add_bos_token",
        "add_eos_token",
        "clean_up_tokenization_spaces",
        "use_default_system_prompt",
        "legacy",
        "fix_mistral_regex",
    )

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Snapshot explicitly-passed flags BEFORE the base __init__ runs,
        # since it may pop or ignore them.
        explicit: dict[str, Any] = {
            key: kwargs[key] for key in self._PRESERVED_KEYS if key in kwargs
        }
        super().__init__(*args, **kwargs)

        init_kwargs = getattr(self, "init_kwargs", {})

        # Best-effort read of a checkpoint-local tokenizer_config.json.
        # Guard against an empty name_or_path: Path("") resolves to ".",
        # which would otherwise pick up an unrelated tokenizer_config.json
        # from the current working directory.
        local_payload: dict[str, Any] = {}
        name_or_path = str(getattr(self, "name_or_path", "") or "")
        if name_or_path:
            config_path = Path(name_or_path) / "tokenizer_config.json"
            if config_path.is_file():
                try:
                    payload = json.loads(
                        config_path.read_text(encoding="utf-8")
                    )
                except (OSError, ValueError):
                    # Unreadable or malformed JSON: fall back silently,
                    # matching the best-effort contract of this lookup.
                    payload = {}
                # The file could legally hold a non-object JSON value;
                # only accept a dict so .get() below cannot blow up.
                if isinstance(payload, dict):
                    local_payload = payload

        # Resolution priority: explicit kwargs > recorded init_kwargs >
        # local config file. A key absent from all three is left untouched.
        for key in self._PRESERVED_KEYS:
            value = explicit.get(key, init_kwargs.get(key))
            if value is None:
                value = local_payload.get(key)
            if value is not None:
                setattr(self, key, value)

    @property
    def all_special_tokens_extended(self) -> list[str]:
        """Compatibility shim for runtimes still expecting the pre-5.4 API."""
        return list(self.all_special_tokens)
|
|