from __future__ import annotations

import json
import logging
from typing import TYPE_CHECKING, Any, cast

from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast

from distiller.model2vec.tokenizer.datamodels import Token
from distiller.model2vec.tokenizer.model import process_tokenizer
from distiller.model2vec.tokenizer.normalizer import replace_normalizer
from distiller.model2vec.tokenizer.pretokenizer import replace_pretokenizer

if TYPE_CHECKING:
    import re

    from tokenizers.normalizers import Normalizer
    from tokenizers.pre_tokenizers import (
        PreTokenizer,
    )

logger = logging.getLogger(__name__)


# Minimal post-processor: passes single sequences and sequence pairs through
# unchanged, without inserting any special tokens.
_DEFAULT_POST_PROCESSOR_TEMPLATE = {
    "type": "TemplateProcessing",
    "single": [{"Sequence": {"id": "A", "type_id": 0}}],
    "pair": [{"Sequence": {"id": "A", "type_id": 0}}, {"Sequence": {"id": "B", "type_id": 0}}],
    "special_tokens": {},
}


def _remap_added_tokens(
    special_tokens: list[dict[str, Any]],
    vocabulary: list[str],
) -> list[dict[str, Any]]:
    """
    Remap special tokens in the tokenizer.

    This function updates the special tokens in the tokenizer based on a mapping provided.
    It also ensures that the special tokens are present in the vocabulary.

    :param special_tokens: The special tokens to remap.
    :param vocabulary: The vocabulary as a list of tokens.
    :return: The updated special tokens.
    """
    # Deepcopy
    special_tokens = [{**x} for x in special_tokens]
    for token in special_tokens:
        token["id"] = vocabulary.index(token["content"])

    return special_tokens


def replace_vocabulary(
    tokenizer: Tokenizer, new_vocabulary: list[Token], unk_token: str | None, pad_token: str | None
) -> Tokenizer:
    """Replace the vocabulary of a tokenizer with a new one."""
    tokenizer_json: dict[str, Any] = json.loads(tokenizer.to_str())
    added_tokens: list[dict[str, Any]] = tokenizer_json["added_tokens"]

    pre_tokenized_tokens = [x.normalized_form for x in new_vocabulary]

    # We need to remove the added tokens but keep [UNK] and [PAD] tokens.
    added_tokens = _rename_added_token(unk_token, "[UNK]", added_tokens, pre_tokenized_tokens)
    added_tokens = _rename_added_token(pad_token, "[PAD]", added_tokens, pre_tokenized_tokens)

    # Drop all old added tokens, keeping only the [UNK] and [PAD] entries.
    tokenizer_json["added_tokens"] = [x for x in added_tokens if x["content"] in {"[UNK]", "[PAD]"}]
    tokenizer_json = process_tokenizer(
        tokenizer_json, pre_tokenized_tokens, "[UNK]" if "[UNK]" in pre_tokenized_tokens else None
    )

    # Remap special tokens
    tokenizer_json["added_tokens"] = _remap_added_tokens(
        special_tokens=tokenizer_json["added_tokens"],
        vocabulary=pre_tokenized_tokens,
    )
    tokenizer_json["post_processor"] = _DEFAULT_POST_PROCESSOR_TEMPLATE

    return Tokenizer.from_str(json.dumps(tokenizer_json))


def _rename_added_token(
    form: str | None, new_form: str, added_tokens: list[dict[str, Any]], vocabulary: list[str]
) -> list[dict[str, Any]]:
    """Rename added tokens in the tokenizer."""
    if form is None:
        return added_tokens

    idx = vocabulary.index(form)
    added_token = [x for x in added_tokens if x["content"] == form]
    if added_token:
        added_token[0]["id"] = idx
        added_token[0]["content"] = new_form
        vocabulary[idx] = new_form

    return added_tokens


def clean_and_create_vocabulary(
    tokenizer: PreTrainedTokenizerFast,
    vocabulary: list[str],
    token_remove_regex: re.Pattern | None,
) -> tuple[list[Token], Tokenizer]:
    """Cleans a vocabulary by removing duplicates and tokens that were already in the vocabulary."""
    seen_tokens = set()
    post_normalize_seen_tokens = set()
    n_empty = 0
    n_duplicates = 0

    backend_tokenizer = tokenizer.backend_tokenizer

    # Make a base list of tokens.
    internal_vocab: dict[str, int] = tokenizer.get_vocab()
    internal_tokens: list[str] = [k for k, _ in sorted(internal_vocab.items(), key=lambda x: x[1])]

    cleaned_vocabulary = _process_internal_tokens(tokenizer, backend_tokenizer, internal_tokens, token_remove_regex)
    # Copy the backend tokenizer to avoid modifying the original.
    backend_tokenizer = backend_tokenizer.from_str(backend_tokenizer.to_str())
    backend_tokenizer = replace_normalizer(backend_tokenizer)

    internal_tokens_set = {token.form for token in cleaned_vocabulary}

    normalizer: Normalizer | None = backend_tokenizer.normalizer
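    # Process the user-supplied vocabulary: normalize each token, pre-tokenize it,
    # drop empties and duplicates, and record it in its normalized form.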
    for token in vocabulary:
        if normalizer is not None:
            token = cast("str", normalizer.normalize_str(token))

        if not token:
            n_empty += 1
            continue

        pre_tokenizer: PreTokenizer | None = backend_tokenizer.pre_tokenizer
        normalized_token = token
        if pre_tokenizer is not None:
            normalized_token = _normalize_vocabulary_token(
                token=token,
                pre_tokenizer=pre_tokenizer,
            )

        # We need to check whether the pretokenized token is in the vocabulary.
        # But we need to return the original token, because that will be tokenized
        # again by the tokenizer during featurization.
        if normalized_token in seen_tokens or normalized_token in internal_tokens_set:
            n_duplicates += 1
            continue

        # Add the possibly pretokenized token to seen
        seen_tokens.add(normalized_token)

        # After the duplicate check, normalize the token into the form it will
        # take in the new vocabulary. Byte tokens are left untouched; other
        # tokens get a metaspace prefix. For multiword tokens, any internal
        # spaces are replaced with the metaspace or byte prefix character.
        if not normalized_token.startswith(("▁", "Ġ")):
            normalized_token = normalized_token.replace(" ", "▁")
            normalized_token = f"▁{normalized_token}"
        else:
            normalized_token = normalized_token.replace(" ", normalized_token[0])

        if normalized_token in post_normalize_seen_tokens:
            n_duplicates += 1
            continue

        post_normalize_seen_tokens.add(normalized_token)
        # Add the original string to the vocabulary.
        cleaned_vocabulary.append(
            Token(form=token, normalized_form=normalized_token, is_subword=False, is_internal=False)
        )

    if n_duplicates:
        logger.warning(f"Removed {n_duplicates} duplicate tokens.")
    if n_empty:
        logger.warning(f"Removed {n_empty} empty tokens.")

    return cleaned_vocabulary, replace_pretokenizer(backend_tokenizer)


def _process_internal_tokens(
    tokenizer: PreTrainedTokenizerFast,
    backend_tokenizer: Tokenizer,
    internal_tokens: list[str],
    token_remove_regex: re.Pattern | None,
) -> list[Token]:
    """Clean internal tokens."""
    # Get the pad and unk token from the tokenizer.
    pad_token: str | None = tokenizer.special_tokens_map.get("pad_token")  # type: ignore[assignment]
    unk_token: str | None = tokenizer.special_tokens_map.get("unk_token")  # type: ignore[assignment]
    # Empty set if no pad or unk token is set.
    added_tokens_to_keep: set[str] = {x for x in (pad_token, unk_token) if x is not None}
    added_tokens_to_remove = set(tokenizer.added_tokens_encoder) - added_tokens_to_keep
    cleaned_internal_tokens: list[Token] = []

    # Figure out the word and subword prefixes by encoding a long dummy word.
    encoded = backend_tokenizer.encode(f" {'a' * 25}", add_special_tokens=False)
    first_token, second_token, *_ = encoded.tokens
    # Isolate the word prefix. We can't just take first_token[0] because we don't
    # know how long the prefix is, e.g., "Ġaaaa" -> "Ġ".
    a_index = None if "a" not in first_token else first_token.index("a")
    word_prefix = first_token[:a_index]
    is_byte_prefix = word_prefix == "Ġ"
    # The second token is the first subword token. If the tokenizer uses a
    # subword prefix, this token will carry it; again, we don't know how long
    # the prefix is.
    a_index = None if "a" not in second_token else second_token.index("a")
    subword_prefix = second_token[:a_index]

    pre_tokenizer: PreTokenizer | None = backend_tokenizer.pre_tokenizer

    for token in internal_tokens:
        # Create the token object. If this returns None, the token was removed.
        if token_object := _create_single_internal_token(
            token=token,
            subword_prefix=subword_prefix,
            word_prefix=word_prefix,
            pre_tokenizer=pre_tokenizer,
            is_byte_prefix=is_byte_prefix,
            token_remove_regex=token_remove_regex,
            added_tokens_to_keep=added_tokens_to_keep,
            added_tokens_to_remove=added_tokens_to_remove,
        ):
            cleaned_internal_tokens.append(token_object)

    if len(cleaned_internal_tokens) != len(internal_tokens):
        logger.info(
            f"Removed {len(internal_tokens) - len(cleaned_internal_tokens)} internal tokens from the vocabulary."
        )

    return cleaned_internal_tokens


def _create_single_internal_token(
    token: str,
    subword_prefix: str,
    word_prefix: str,
    pre_tokenizer: PreTokenizer | None,
    is_byte_prefix: bool,
    token_remove_regex: re.Pattern | None,
    added_tokens_to_keep: set[str],
    added_tokens_to_remove: set[str],
) -> Token | None:
    """Create a token object from a string."""
    if token in added_tokens_to_remove:
        # We remove any tokens that are added tokens that aren't [UNK] or [PAD].
        return None
    if token in added_tokens_to_keep:
        # Don't put added tokens through the regular motions.
        return Token(form=token, normalized_form=token, is_subword=False, is_internal=True)
    if token_remove_regex and token_remove_regex.match(token):
        # If the regex matches, remove the token.
        return None

    # A token is a subword if there is a subword prefix and the token starts
    # with that prefix, or if there is a WORD prefix and the token does not
    # start with it. For metaspace tokenizers, for example:
    # "doghouse" -> ["▁dog", "house"]
    # We can only tell that "house" is a subword because it is not prefixed
    # while word-initial tokens are.
    is_subword = False
    if subword_prefix:
        is_subword = bool(token.startswith(subword_prefix))
    if word_prefix:
        is_subword = not bool(token.startswith(word_prefix))

    # Byte prefixed tokenizers don't need to be checked.
    if pre_tokenizer is not None and not is_byte_prefix:
        # We need to check tokens that carry no prefix. If the tokenizer uses a
        # word prefix, those are the subword tokens; the other way around for
        # subword prefixes.
        if (subword_prefix and not is_subword) or (word_prefix and is_subword):
            # If the pre-tokenizer splits the bare token into multiple pieces,
            # the token is unreachable and must be dropped.
            if len(pre_tokenizer.pre_tokenize_str(token)) > 1:
                return None

    # Turn a token into a normalized form for later processing.
    normalized_form = _create_normalized_form(token, subword_prefix, word_prefix, is_byte_prefix, is_subword)

    return Token(form=token, normalized_form=normalized_form, is_subword=is_subword, is_internal=True)


def _create_normalized_form(
    token: str, subword_prefix: str, word_prefix: str, is_byte_prefix: bool, is_subword: bool
) -> str:
    """Turn an internal token string into a normalized form."""
    # We don't need to check byte prefixed strings.
    if is_byte_prefix:
        return token
    # We need to check if the token is a subword or not and remove the prefix.
    if is_subword:
        return token.removeprefix(subword_prefix)
    # If the token is not a subword, we need to remove the word prefix, and add metaspace.
    return f"▁{token.removeprefix(word_prefix)}"


def turn_tokens_into_ids(
    tokens: list[Token], tokenizer: PreTrainedTokenizerFast, unk_token: str | None
) -> list[list[int]]:
    """
    Convert a list of Token objects to their corresponding token ID sequences.

    :param tokens: List of Token objects to convert
    :param tokenizer: The tokenizer to use for converting tokens to IDs
    :param unk_token: The string form of the unk token.
    :return: List of token IDs corresponding to the input tokens
    """
    unk_id = None if unk_token is None else tokenizer.convert_tokens_to_ids(unk_token)
    prefix, suffix = find_eos_bos(tokenizer)

    token_ids: list[list[int]] = []
    for token in tokens:
        if token.is_internal:
            # Careful: any incorrect token will just get [UNK], so this could go horribly wrong.
            # Cast because return type is wrong.
            token_id: int = cast("int", tokenizer.convert_tokens_to_ids(token.form)) or 0
            # Explicitly check and warn if `unk_id` appears, but don't crash.
            if unk_id is not None and token_id == unk_id and token.form != unk_token:
                logger.warning(f"Token {token.form} was set to unk. This is wrong.")
            token_ids.append([*prefix, token_id, *suffix])
        else:
            token_ids.append(tokenizer.encode(token.form))

    return token_ids


def find_eos_bos(tokenizer: PreTrainedTokenizerFast) -> tuple[list[int], list[int]]:
    """Finds the eos and bos tokens for a tokenizer."""
    # Little bit complicated, because not all tokenizers have eos and bos tokens.
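    # For a BERT-style tokenizer this typically returns ([cls_id], [sep_id]);
    # for a tokenizer that adds no special tokens, both lists are empty.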
    encoding = tokenizer.encode("a", add_special_tokens=True)
    if len(encoding) != 3:
        a_encoded = tokenizer.encode("a", add_special_tokens=False)
        if len(a_encoded) != 1:
            msg = f"Couldn't determine the eos and bos tokens: the tokenizer encodes 'a' as {a_encoded}."
            raise ValueError(msg)
        a_idx = encoding.index(a_encoded[0])
        prefix, suffix = encoding[:a_idx], encoding[a_idx + 1 :]
    else:
        prefix, suffix = encoding[:1], encoding[2:]
    return prefix, suffix


def _normalize_vocabulary_token(token: str, pre_tokenizer: PreTokenizer) -> str:
    """Normalize a token that is not in the initial token vocabulary."""
    # Add prefix space for byte tokenizers.
    prefixed_token = f" {token}"
    pretokenized_tokens: tuple[str, ...]
    pretokenized_tokens, offsets = zip(*pre_tokenizer.pre_tokenize_str(prefixed_token), strict=False)
    # The first item is always the start of the token.
    new_token = [pretokenized_tokens[0]]
    # Loop over the subtokens and offsets.
    for t, (s, _) in zip(pretokenized_tokens[1:], offsets[1:], strict=False):
        # Do not prefix the token with a space if it starts with a metaspace.
        if t.startswith("▁"):
            new_token.append(t)
        # If the character before the subtoken is a space, we have a
        # multiword token. e.g., "room for the moon", which is split into
        # ["room", "for", "the", "moon"].
        # If it doesn't have a space, it is part of a complex multiword token,
        # e.g., "chat-gpt", which is split into ["chat", "-", "gpt"].
        elif prefixed_token[s - 1] == " ":
            new_token.append(f" {t}")
        else:
            new_token.append(t)
    return "".join(new_token)


def create_tokenizer(
    tokenizer: PreTrainedTokenizerFast,
    vocabulary: list[str],
    token_remove_regex: re.Pattern | None = None,
) -> PreTrainedTokenizerFast:
    """
    Create a tokenizer by adding tokens to the vocabulary.

    This function turns any tokenizer into a supertoken tokenizer. It does the following:
    1. Turns the tokenizer model into a unigram model.
    2. Adds a new pretokenizer, splitting on punctuation.
    3. Adds all tokens in vocabulary to the model.
    4. Removes any internal tokens that conform to the regex.

    :param tokenizer: The tokenizer to use.
    :param vocabulary: The vocabulary to use.
    :param token_remove_regex: The regex to use to remove tokens from the vocabulary.
    :return: The created tokenizer.
    """
    unk_token = cast("str | None", tokenizer.special_tokens_map.get("unk_token"))
    pad_token = cast("str | None", tokenizer.special_tokens_map.get("pad_token"))
    cleaned_vocabulary, backend_tokenizer = clean_and_create_vocabulary(tokenizer, vocabulary, token_remove_regex)
    new_tokenizer = replace_vocabulary(backend_tokenizer, cleaned_vocabulary, unk_token, pad_token)

    return PreTrainedTokenizerFast(tokenizer_object=new_tokenizer)
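

# Example usage (a sketch, not part of the module's tested surface): assumes a fast
# tokenizer such as "bert-base-uncased" can be loaded and that the new vocabulary
# entries survive the cleaning rules above.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    base_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    super_tokenizer = create_tokenizer(base_tokenizer, vocabulary=["hello world", "credit card"])
    print(super_tokenizer.tokenize("hello world"))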