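"""Tokenizer loading speedups.

Two complementary mechanisms are provided:

1. ``patch_pretrained_tokenizer_fast`` temporarily replaces
   ``transformers.PreTrainedTokenizerFast.__init__`` so that, for tokenizer
   files under an explicit allow-list of paths, the fast (Rust) tokenizer is
   loaded directly from ``tokenizer.json`` and the slow-tokenizer conversion
   machinery is skipped entirely.
2. ``load_cached_lm_tokenizer`` optionally pickles the fully constructed
   tokenizer next to its source files and reuses the pickle as long as the
   interpreter version, library versions, and source files are unchanged
   (currently disabled via ``_DISABLE_FULL_TOKENIZER_PICKLE_CACHE``).
"""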
import json
import os
import pickle
import sys


_PATCH_ALLOWED_PATHS = None
_ORIG_FAST_INIT = None
_DISABLE_FULL_TOKENIZER_PICKLE_CACHE = True  # Temporary kill switch: skip both loading and dumping the pickle cache.


def _normalize_path(path):
    if not path:
        return None
    try:
        return os.path.normcase(os.path.abspath(path))
    except Exception:
        return None


def _path_allowed(path):
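    """Return True if ``path`` lies under one of the allow-listed directories."""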
    if not _PATCH_ALLOWED_PATHS:
        return False
    norm = _normalize_path(path)
    if norm is None:
        return False
    for allowed in _PATCH_ALLOWED_PATHS:
        if allowed is None:
            continue
        try:
            if os.path.commonpath([norm, allowed]) == allowed:
                return True
        except Exception:
            if norm.startswith(allowed):
                return True
    return False


def _load_cached_tokenizer(tokenizer_file, TokenizerFast):
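    """Load the Rust tokenizer directly from ``tokenizer.json`` (fast path)."""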
    if not tokenizer_file:
        return None
    return TokenizerFast.from_file(tokenizer_file)


def patch_pretrained_tokenizer_fast(allow_paths=None):
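    """Monkey-patch ``PreTrainedTokenizerFast.__init__`` for fast loading.

    Only tokenizer files located under ``allow_paths`` take the shortcut;
    everything else falls through to the original constructor.
    """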
    global _PATCH_ALLOWED_PATHS
    global _ORIG_FAST_INIT
    if allow_paths is not None:
        _PATCH_ALLOWED_PATHS = [_normalize_path(p) for p in allow_paths if p]

    try:
        import transformers.tokenization_utils_fast as tuf
    except Exception:
        return

    cls = tuf.PreTrainedTokenizerFast
    if getattr(cls, "_wan2gp_fast_init_patched", False):
        return

    if _ORIG_FAST_INIT is None:
        _ORIG_FAST_INIT = cls.__init__

    def _patched_init(self, *args, **kwargs):
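        # Take the fast path only when a tokenizer.json is given, conversion
        # from a slow tokenizer was not requested, and the file is
        # allow-listed. The body below mirrors PreTrainedTokenizerFast.__init__
        # from transformers, minus the slow-tokenizer conversion branches
        # (which is also why the positional *args are unused here).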
        fast_tokenizer_file = kwargs.get("tokenizer_file")
        from_slow = kwargs.get("from_slow", False)
        if not fast_tokenizer_file or from_slow or not _path_allowed(fast_tokenizer_file):
            return _ORIG_FAST_INIT(self, *args, **kwargs)

        try:
            fast_tokenizer = _load_cached_tokenizer(fast_tokenizer_file, tuf.TokenizerFast)
            if fast_tokenizer is None:
                return _ORIG_FAST_INIT(self, *args, **kwargs)
            kwargs["tokenizer_object"] = fast_tokenizer
        except Exception:
            return _ORIG_FAST_INIT(self, *args, **kwargs)

        tokenizer_object = kwargs.pop("tokenizer_object", None)
        slow_tokenizer = kwargs.pop("__slow_tokenizer", None)
        fast_tokenizer_file = kwargs.pop("tokenizer_file", None)
        from_slow = kwargs.pop("from_slow", False)
        added_tokens_decoder = kwargs.pop("added_tokens_decoder", {})
        self.add_prefix_space = kwargs.get("add_prefix_space", False)

        if from_slow and slow_tokenizer is None and self.slow_tokenizer_class is None:
            raise ValueError(
                "Cannot instantiate this tokenizer from a slow version. If it's based on sentencepiece, make sure you "
                "have sentencepiece installed."
            )

        if tokenizer_object is not None:
            fast_tokenizer = tokenizer_object
        else:
            fast_tokenizer = tuf.TokenizerFast.from_file(fast_tokenizer_file)

        self._tokenizer = fast_tokenizer

        if slow_tokenizer is not None:
            kwargs.update(slow_tokenizer.init_kwargs)

        self._decode_use_source_tokenizer = False

        _truncation = self._tokenizer.truncation

        if _truncation is not None:
            self._tokenizer.enable_truncation(**_truncation)
            kwargs.setdefault("max_length", _truncation["max_length"])
            kwargs.setdefault("truncation_side", _truncation["direction"])
            kwargs.setdefault("stride", _truncation["stride"])
            kwargs.setdefault("truncation_strategy", _truncation["strategy"])
        else:
            self._tokenizer.no_truncation()

        _padding = self._tokenizer.padding
        if _padding is not None:
            self._tokenizer.enable_padding(**_padding)
            kwargs.setdefault("pad_token", _padding["pad_token"])
            kwargs.setdefault("pad_token_type_id", _padding["pad_type_id"])
            kwargs.setdefault("padding_side", _padding["direction"])
            kwargs.setdefault("max_length", _padding["length"])
            kwargs.setdefault("pad_to_multiple_of", _padding["pad_to_multiple_of"])

        tuf.PreTrainedTokenizerBase.__init__(self, **kwargs)
        self._tokenizer.encode_special_tokens = self.split_special_tokens

        # Hash the AddedToken values; iterating the dict directly would yield
        # integer indices and make the de-duplication filter below a no-op.
        added_tokens_decoder_hash = {hash(repr(token)) for token in self.added_tokens_decoder.values()}
        tokens_to_add = [
            token
            for index, token in sorted(added_tokens_decoder.items(), key=lambda x: x[0])
            if hash(repr(token)) not in added_tokens_decoder_hash
        ]
        encoder_set = set(self.added_tokens_encoder.keys())
        for token in tokens_to_add:
            if isinstance(token, tuf.AddedToken):
                encoder_set.add(token.content)
            else:
                encoder_set.add(str(token))
        tokens_to_add_set = set(tokens_to_add)
        tokens_to_add += [
            token
            for token in self.all_special_tokens_extended
            if token not in encoder_set and token not in tokens_to_add_set
        ]

        if len(tokens_to_add) > 0:
            special_tokens = set(self.all_special_tokens)
            tokens = []
            append = tokens.append
            for token in tokens_to_add:
                if isinstance(token, tuf.AddedToken):
                    content = token.content
                    if (not token.special) and (content in special_tokens):
                        token.special = True
                    append(token)
                else:
                    append(tuf.AddedToken(token, special=(token in special_tokens)))
            if tokens:
                self.add_tokens(tokens)

        try:
            pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
            if pre_tok_state.get("add_prefix_space", self.add_prefix_space) != self.add_prefix_space:
                pre_tok_class = getattr(tuf.pre_tokenizers_fast, pre_tok_state.pop("type"))
                pre_tok_state["add_prefix_space"] = self.add_prefix_space
                self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
        except Exception:
            pass

    cls.__init__ = _patched_init
    cls._wan2gp_fast_init_patched = True


def unpatch_pretrained_tokenizer_fast():
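    """Restore the original ``PreTrainedTokenizerFast.__init__``."""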
    global _ORIG_FAST_INIT
    if _ORIG_FAST_INIT is None:
        return
    try:
        import transformers.tokenization_utils_fast as tuf
    except Exception:
        return
    cls = tuf.PreTrainedTokenizerFast
    if not getattr(cls, "_wan2gp_fast_init_patched", False):
        return
    cls.__init__ = _ORIG_FAST_INIT
    cls._wan2gp_fast_init_patched = False


def _get_transformers_version():
    try:
        import transformers as _transformers
        return getattr(_transformers, "__version__", None)
    except Exception:
        return None


def _get_tokenizers_version():
    try:
        import tokenizers as _tokenizers
        return getattr(_tokenizers, "__version__", None)
    except Exception:
        return None


def _collect_tokenizer_files(tokenizer_dir):
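    """Record name/mtime/size for the tokenizer assets present in the directory.

    These fingerprints are stored in the cache metadata so the cache can be
    invalidated when any source file changes.
    """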
    candidates = [
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "added_tokens.json",
        "vocab.json",
        "merges.txt",
        "config.json",
        "sentencepiece.bpe.model",
        "tokenizer.model",
    ]
    files = []
    for name in candidates:
        path = os.path.join(tokenizer_dir, name)
        if os.path.isfile(path):
            try:
                stat = os.stat(path)
                files.append({"path": name, "mtime": stat.st_mtime, "size": stat.st_size})
            except OSError:
                files.append({"path": name, "mtime": None, "size": None})
    return files


def _sanitize_cache_tag(tag):
    if not tag:
        return ""
    safe = "".join(ch if ch.isalnum() or ch in ("-", "_", ".") else "_" for ch in str(tag))
    return safe.strip("._-")


def _cache_paths(tokenizer_dir, cache_tag=None):
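    """Return the (pickle path, metadata path) pair for this directory and tag."""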
    suffix = _sanitize_cache_tag(cache_tag)
    if suffix:
        cache_file = os.path.join(tokenizer_dir, f"tokenizer.wgp.full.{suffix}.pkl")
        meta_file = os.path.join(tokenizer_dir, f"tokenizer.wgp.full.{suffix}.meta.json")
    else:
        cache_file = os.path.join(tokenizer_dir, "tokenizer.wgp.full.pkl")
        meta_file = os.path.join(tokenizer_dir, "tokenizer.wgp.full.meta.json")
    return cache_file, meta_file


def _read_cache_meta(meta_file):
    try:
        with open(meta_file, "r", encoding="utf-8") as handle:
            return json.load(handle)
    except Exception:
        return None


def _meta_matches(meta, tokenizer_dir):
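    """Check that Python, transformers, tokenizers, and all source files match."""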
    if not meta:
        return False
    if tuple(meta.get("py_version", [])) != tuple(sys.version_info[:3]):
        return False
    if meta.get("transformers_version") != _get_transformers_version():
        return False
    if meta.get("tokenizers_version") != _get_tokenizers_version():
        return False
    expected_files = meta.get("files", [])
    current_files = _collect_tokenizer_files(tokenizer_dir)
    if len(expected_files) != len(current_files):
        return False
    current_map = {f.get("path"): f for f in current_files}
    for entry in expected_files:
        cur = current_map.get(entry.get("path"))
        if cur is None:
            return False
        if entry.get("mtime") != cur.get("mtime") or entry.get("size") != cur.get("size"):
            return False
    return True


def _load_full_tokenizer_cache(tokenizer_dir, cache_tag=None):
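    """Unpickle a cached tokenizer if its metadata still matches, else None.

    Note: ``pickle.load`` executes code from the file, so the cache should
    only live in directories the user already trusts.
    """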
    if _DISABLE_FULL_TOKENIZER_PICKLE_CACHE:
        return None
    cache_file, meta_file = _cache_paths(tokenizer_dir, cache_tag=cache_tag)
    if not os.path.isfile(cache_file) or not os.path.isfile(meta_file):
        return None
    meta = _read_cache_meta(meta_file)
    if not _meta_matches(meta, tokenizer_dir):
        return None
    try:
        with open(cache_file, "rb") as handle:
            return pickle.load(handle)
    except Exception:
        return None


def _save_full_tokenizer_cache(tokenizer_dir, tokenizer, cache_tag=None):
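    """Pickle the tokenizer plus a metadata sidecar for later validation."""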
    if _DISABLE_FULL_TOKENIZER_PICKLE_CACHE:
        return
    cache_file, meta_file = _cache_paths(tokenizer_dir, cache_tag=cache_tag)
    meta = {
        "py_version": list(sys.version_info[:3]),
        "transformers_version": _get_transformers_version(),
        "tokenizers_version": _get_tokenizers_version(),
        "files": _collect_tokenizer_files(tokenizer_dir),
    }
    try:
        with open(cache_file, "wb") as handle:
            pickle.dump(tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(meta_file, "w", encoding="utf-8") as handle:
            json.dump(meta, handle)
    except Exception:
        pass


def load_cached_lm_tokenizer(tokenizer_dir, loader_fn, cache_tag=None):
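    """Load a tokenizer with both speedups applied.

    ``loader_fn`` is a zero-argument callable that performs the real load
    (e.g. a ``functools.partial`` around ``AutoTokenizer.from_pretrained``).
    """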
    if not tokenizer_dir:
        return loader_fn()
    cached = _load_full_tokenizer_cache(tokenizer_dir, cache_tag=cache_tag)
    if cached is not None:
        return cached
    patch_pretrained_tokenizer_fast(allow_paths=[tokenizer_dir])
    try:
        tokenizer = loader_fn()
    finally:
        unpatch_pretrained_tokenizer_fast()
    _save_full_tokenizer_cache(tokenizer_dir, tokenizer, cache_tag=cache_tag)
    return tokenizer
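

# Example usage (a minimal sketch; assumes a local tokenizer directory and the
# standard transformers.AutoTokenizer API -- the path below is hypothetical):
#
#   from functools import partial
#   from transformers import AutoTokenizer
#
#   tokenizer_dir = "ckpts/umt5-xxl"  # hypothetical local checkpoint path
#   tokenizer = load_cached_lm_tokenizer(
#       tokenizer_dir,
#       partial(AutoTokenizer.from_pretrained, tokenizer_dir),
#   )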