from __future__ import annotations

import tiktoken
from transformers import PreTrainedTokenizer


class _SteerlingTokenizer:
    """
    Tokenizer for Steerling models.

    Uses tiktoken cl100k_base with 4 additional special tokens as mentioned above.
    """
    ENCODING_NAME = 'cl100k_base'

    def __init__(self):
        base_enc = tiktoken.get_encoding(self.ENCODING_NAME)
        base_vocab = base_enc.n_vocab
        # The four new special tokens are appended directly after the base vocabulary;
        # cl100k_base's existing <|endoftext|> is reused as the EOS token.
        self._pad_token_id = base_vocab
        self._bos_token_id = base_vocab + 1
        self._endofchunk_token_id = base_vocab + 2
        self._mask_token_id = base_vocab + 3
        self._eos_token_id = base_enc._special_tokens['<|endoftext|>']
        self._vocab_size = base_vocab + 4
        self._tokenizer = tiktoken.Encoding(
            name=f'{self.ENCODING_NAME}_steerling',
            pat_str=base_enc._pat_str,
            mergeable_ranks=base_enc._mergeable_ranks,
            special_tokens={
                **base_enc._special_tokens,
                '<|pad|>': self._pad_token_id,
                '<|bos|>': self._bos_token_id,
                '<|endofchunk|>': self._endofchunk_token_id,
                '<|mask|>': self._mask_token_id,
            },
        )
        self._special_token_ids = {
            self._pad_token_id, self._bos_token_id, self._eos_token_id,
            self._endofchunk_token_id, self._mask_token_id,
        }

    def encode(self, text: str, add_special_tokens: bool=True) -> list[int]:
        """
        Encode text to token IDs.

        Args:
            text: Input text
            add_special_tokens: If True, prepend BOS and append EOS

        Returns:
            List of token IDs
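
        Example (illustrative; exact token IDs depend on the cl100k_base
        vocabulary, but the BOS/EOS placement always holds):

            >>> tok = _SteerlingTokenizer()
            >>> ids = tok.encode("hello world")
            >>> ids[0] == tok.bos_token_id and ids[-1] == tok.eos_token_id
            True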
        """
        tokens = self._tokenizer.encode(text, disallowed_special=())
        if add_special_tokens:
            tokens = [self._bos_token_id] + tokens + [self._eos_token_id]
        return tokens

    def decode(self, tokens: list[int], skip_special_tokens: bool=True) -> str:
        """
        Decode token IDs to text.

        Args:
            tokens: Token IDs (list, numpy array, or torch tensor)
            skip_special_tokens: If True, filter out special tokens before decoding

        Returns:
            Decoded text
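
        Example (illustrative round trip; the BOS/EOS added by ``encode`` are
        stripped because ``skip_special_tokens`` defaults to True):

            >>> tok = _SteerlingTokenizer()
            >>> tok.decode(tok.encode("hello world"))
            'hello world'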
        """
        if skip_special_tokens:
            tokens = [int(t) for t in tokens if int(t) not in self._special_token_ids]
        else:
            tokens = [int(t) for t in tokens]
        return self._tokenizer.decode(tokens)

    @property
    def vocab_size(self) -> int:
        return self._vocab_size

    @property
    def pad_token_id(self) -> int:
        return self._pad_token_id

    @property
    def bos_token_id(self) -> int:
        return self._bos_token_id

    @property
    def eos_token_id(self) -> int:
        return self._eos_token_id

    @property
    def endofchunk_token_id(self) -> int:
        return self._endofchunk_token_id

    @property
    def mask_token_id(self) -> int:
        return self._mask_token_id

class SteerlingTokenizer(PreTrainedTokenizer):
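    """
    Hugging Face ``PreTrainedTokenizer`` wrapper around ``_SteerlingTokenizer``.

    Encoding and decoding are delegated to the tiktoken-based core; the
    "tokens" exchanged with the HF machinery are stringified token IDs plus
    the named special tokens.
    """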
    vocab_files_names: dict[str, str] = {}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(self, encoding_name="cl100k_base", pad_token_id=100277,
                 bos_token_id=100278, eos_token_id=100257,
                 endofchunk_token_id=100279, mask_token_id=100280, **kwargs):
        self._core = _SteerlingTokenizer()
        # Stored for reference; the effective IDs are derived by the core
        # tokenizer and exposed through the properties below.
        self._endofchunk_token_id = endofchunk_token_id
        self._mask_token_id = mask_token_id
        for k in ("pad_token", "bos_token", "eos_token", "additional_special_tokens"):
            kwargs.pop(k, None)
        super().__init__(pad_token="<|pad|>", bos_token="<|bos|>", eos_token="<|endoftext|>",
                         additional_special_tokens=["<|endofchunk|>", "<|mask|>"], **kwargs)

    @property
    def vocab_size(self):
        return self._core.vocab_size

    @property
    def endofchunk_token_id(self):
        return self._core.endofchunk_token_id

    @property
    def mask_token_id(self):
        return self._core.mask_token_id

    def get_vocab(self):
        # tiktoken does not expose the full byte-level vocabulary as strings,
        # so only the named special tokens are returned here.
        return dict(self._core._tokenizer._special_tokens)

    def _tokenize(self, text, **kwargs):
        # HF's slow-tokenizer machinery expects string "tokens"; expose the
        # stringified tiktoken IDs and convert them back in _convert_token_to_id.
        return [str(i) for i in self._core._tokenizer.encode(text, disallowed_special=())]

    def _convert_token_to_id(self, token):
        special = self._core._tokenizer._special_tokens
        if token in special:
            return special[token]
        try:
            return int(token)
        except ValueError:
            ids = self._core._tokenizer.encode(token, disallowed_special=())
            return ids[0] if ids else self._core.pad_token_id

    def _convert_id_to_token(self, index):
        for name, idx in self._core._tokenizer._special_tokens.items():
            if idx == index:
                return name
        try:
            return self._core._tokenizer.decode([index])
        except Exception:
            return f"<|token_{index}|>"

    def convert_tokens_to_string(self, tokens):
        special = self._core._tokenizer._special_tokens
        ids = []
        for t in tokens:
            if t in special:
                continue
            try:
                tid = int(t)
                if tid not in self._core._special_token_ids:
                    ids.append(tid)
            except ValueError:
                ids.extend(self._core._tokenizer.encode(t, disallowed_special=()))
        return self._core._tokenizer.decode(ids)

    def _decode(self, token_ids, skip_special_tokens=False, **kwargs):
        return self._core.decode(list(token_ids) if not isinstance(token_ids, list) else token_ids,
                                 skip_special_tokens=skip_special_tokens)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        # No special tokens are added here; sequence pairs are simply concatenated.
        if token_ids_1 is None:
            return token_ids_0
        return token_ids_0 + token_ids_1

    def save_vocabulary(self, save_directory, filename_prefix=None):
        # Nothing to write: the vocabulary is reconstructed from tiktoken at load time.
        return ()
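

# Minimal usage sketch (illustrative only): exercises the tiktoken core and the
# Hugging Face wrapper. The example text and printed labels are arbitrary and
# not part of the Steerling API.
if __name__ == "__main__":
    core = _SteerlingTokenizer()
    ids = core.encode("Steerling tokenizer demo")
    print("core ids:", ids)
    print("core round trip:", core.decode(ids))

    hf_tok = SteerlingTokenizer()
    batch = hf_tok("Steerling tokenizer demo")
    print("hf input_ids:", batch["input_ids"])
    print("hf round trip:", hf_tok.decode(batch["input_ids"], skip_special_tokens=True))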