File size: 1,710 Bytes
edbbb7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from __future__ import annotations

import json
from pathlib import Path
from typing import Any

from transformers import PreTrainedTokenizerFast


class LizzyTokenizerFast(PreTrainedTokenizerFast):
    """Family-agnostic fast tokenizer wrapper for Lizzy checkpoints.

    After the base-class constructor has run, a fixed set of init options
    (see ``_PRESERVED_INIT_KEYS``) is re-applied as instance attributes so
    that their values remain readable on the tokenizer object.
    """

    model_input_names = ["input_ids", "attention_mask"]

    # Init options that are copied onto the instance after construction.
    _PRESERVED_INIT_KEYS = (
        "add_prefix_space",
        "add_bos_token",
        "add_eos_token",
        "clean_up_tokenization_spaces",
        "use_default_system_prompt",
        "legacy",
        "fix_mistral_regex",
    )

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Build the tokenizer and restore the preserved init options.

        For each preserved key the value is resolved in this order:

        1. the keyword argument passed to this constructor, if present;
        2. otherwise the entry in ``self.init_kwargs``;
        3. if the result is ``None``, the entry in the local
           ``tokenizer_config.json`` next to ``self.name_or_path``.

        A key whose resolved value is still ``None`` is left untouched.

        Args:
            *args: Forwarded unchanged to the base-class constructor.
            **kwargs: Forwarded unchanged to the base-class constructor.
        """
        # Snapshot the explicitly supplied options before delegating.
        explicit = {
            key: kwargs[key]
            for key in self._PRESERVED_INIT_KEYS
            if key in kwargs
        }
        super().__init__(*args, **kwargs)

        init_kwargs = getattr(self, "init_kwargs", {})
        local_payload = self._read_local_tokenizer_config(
            getattr(self, "name_or_path", None)
        )
        for key in self._PRESERVED_INIT_KEYS:
            value = explicit[key] if key in explicit else init_kwargs.get(key)
            if value is None:
                value = local_payload.get(key)
            if value is not None:
                setattr(self, key, value)

    @staticmethod
    def _read_local_tokenizer_config(name_or_path: Any) -> dict[str, Any]:
        """Return the parsed local ``tokenizer_config.json``, or ``{}``.

        This is a best-effort read: a missing, unreadable, malformed, or
        non-object config file yields an empty dict instead of raising.

        Args:
            name_or_path: Directory expected to contain the config file.
                Falsy values (``None``, ``""``) are treated as "no directory".

        Returns:
            The decoded JSON object, or an empty dict when unavailable.
        """
        # An empty path would resolve to the current working directory and
        # could pick up an unrelated tokenizer_config.json; skip it instead.
        if not name_or_path:
            return {}
        config_path = Path(str(name_or_path)) / "tokenizer_config.json"
        try:
            if not config_path.is_file():
                return {}
            payload = json.loads(config_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # OSError: filesystem failure. ValueError: invalid JSON or
            # undecodable bytes (JSONDecodeError / UnicodeDecodeError).
            return {}
        # Only a JSON object supports the key lookups done by the caller.
        return payload if isinstance(payload, dict) else {}

    @property
    def all_special_tokens_extended(self) -> list[str]:
        """Compatibility shim for runtimes still expecting the pre-5.4 API.

        Returns:
            A new list built from ``self.all_special_tokens``.
        """
        return list(self.all_special_tokens)