Update tokenization_steerling.py

#5 by AyaGL - opened
Files changed (1)
  1. tokenization_steerling.py +187 -0
tokenization_steerling.py ADDED
@@ -0,0 +1,187 @@
+ from __future__ import annotations
+
+ import tiktoken
+ from transformers import PreTrainedTokenizer
+
+
+ class _SteerlingTokenizer:
+     """
+     Tokenizer for Steerling models.
+
+     Wraps tiktoken's cl100k_base encoding and adds four special tokens:
+     <|pad|>, <|bos|>, <|endofchunk|>, and <|mask|>. The base encoding's
+     <|endoftext|> token is reused as EOS.
+     """
+
+     ENCODING_NAME = 'cl100k_base'
+
+     def __init__(self):
+         base_enc = tiktoken.get_encoding(self.ENCODING_NAME)
+         base_vocab = base_enc.n_vocab
+         # The added tokens take the four ids directly above the base vocabulary.
+         self._pad_token_id = base_vocab
+         self._bos_token_id = base_vocab + 1
+         self._endofchunk_token_id = base_vocab + 2
+         self._mask_token_id = base_vocab + 3
+         self._eos_token_id = base_enc._special_tokens['<|endoftext|>']
+         self._vocab_size = base_vocab + 4
+         self._tokenizer = tiktoken.Encoding(
+             name=f'{self.ENCODING_NAME}_steerling',
+             pat_str=base_enc._pat_str,
+             mergeable_ranks=base_enc._mergeable_ranks,
+             special_tokens={
+                 **base_enc._special_tokens,
+                 '<|pad|>': self._pad_token_id,
+                 '<|bos|>': self._bos_token_id,
+                 '<|endofchunk|>': self._endofchunk_token_id,
+                 '<|mask|>': self._mask_token_id,
+             },
+         )
+         self._special_token_ids = {
+             self._pad_token_id,
+             self._bos_token_id,
+             self._eos_token_id,
+             self._endofchunk_token_id,
+             self._mask_token_id,
+         }
+
+     def encode(self, text: str, add_special_tokens: bool = True) -> list[int]:
+         """
+         Encode text to token IDs.
+
+         Args:
+             text: Input text
+             add_special_tokens: If True, prepend BOS and append EOS
+
+         Returns:
+             List of token IDs
+         """
+         tokens = self._tokenizer.encode(text, disallowed_special=())
+         if add_special_tokens:
+             tokens = [self._bos_token_id] + tokens + [self._eos_token_id]
+         return tokens
+
+     def decode(self, tokens: list[int], skip_special_tokens: bool = True) -> str:
+         """
+         Decode token IDs to text.
+
+         Args:
+             tokens: Token IDs (list, numpy array, or torch tensor)
+             skip_special_tokens: If True, filter out special tokens before decoding
+
+         Returns:
+             Decoded text
+         """
+         if skip_special_tokens:
+             tokens = [int(t) for t in tokens if int(t) not in self._special_token_ids]
+         else:
+             tokens = [int(t) for t in tokens]
+         return self._tokenizer.decode(tokens)
+
+     @property
+     def vocab_size(self) -> int:
+         return self._vocab_size
+
+     @property
+     def pad_token_id(self) -> int:
+         return self._pad_token_id
+
+     @property
+     def bos_token_id(self) -> int:
+         return self._bos_token_id
+
+     @property
+     def eos_token_id(self) -> int:
+         return self._eos_token_id
+
+     @property
+     def endofchunk_token_id(self) -> int:
+         return self._endofchunk_token_id
+
+     @property
+     def mask_token_id(self) -> int:
+         return self._mask_token_id
+
+
+ class SteerlingTokenizer(PreTrainedTokenizer):
+     vocab_files_names: dict[str, str] = {}
+     model_input_names = ["input_ids", "attention_mask"]
+
+     def __init__(self, encoding_name="cl100k_base", pad_token_id=100277,
+                  bos_token_id=100278, eos_token_id=100257,
+                  endofchunk_token_id=100279, mask_token_id=100280, **kwargs):
+         # The encoding name and token ids are accepted so saved tokenizer
+         # configs round-trip, but the real values always come from the core
+         # tokenizer above.
+         self._core = _SteerlingTokenizer()
+         for k in ("pad_token", "bos_token", "eos_token", "additional_special_tokens"):
+             kwargs.pop(k, None)
+         super().__init__(pad_token="<|pad|>", bos_token="<|bos|>", eos_token="<|endoftext|>",
+                          additional_special_tokens=["<|endofchunk|>", "<|mask|>"], **kwargs)
+
+     @property
+     def vocab_size(self):
+         return self._core.vocab_size
+
+     @property
+     def endofchunk_token_id(self):
+         return self._core.endofchunk_token_id
+
+     @property
+     def mask_token_id(self):
+         return self._core.mask_token_id
+
+     def get_vocab(self):
+         # Only the named special tokens have string forms; the byte-level
+         # merges live inside tiktoken.
+         return dict(self._core._tokenizer._special_tokens)
+
+     def _tokenize(self, text, **kwargs):
+         # Tokens are the decimal string form of tiktoken ids, so arbitrary
+         # byte-level tokens survive the string-based HF pipeline.
+         return [str(i) for i in self._core._tokenizer.encode(text, disallowed_special=())]
+
+     def _convert_token_to_id(self, token):
+         special = self._core._tokenizer._special_tokens
+         if token in special:
+             return special[token]
+         try:
+             return int(token)
+         except ValueError:
+             # Not a decimal-string token: encode it and keep the first id,
+             # falling back to PAD for empty input.
+             ids = self._core._tokenizer.encode(token, disallowed_special=())
+             return ids[0] if ids else self._core.pad_token_id
+
+     def _convert_id_to_token(self, index):
+         for name, idx in self._core._tokenizer._special_tokens.items():
+             if idx == index:
+                 return name
+         try:
+             return self._core._tokenizer.decode([index])
+         except Exception:
+             return f"<|token_{index}|>"
+
+     def convert_tokens_to_string(self, tokens):
+         ids, special = [], self._core._tokenizer._special_tokens
+         for t in tokens:
+             if t in special:
+                 continue
+             try:
+                 tid = int(t)
+                 if tid not in self._core._special_token_ids:
+                     ids.append(tid)
+             except ValueError:
+                 ids.extend(self._core._tokenizer.encode(t, disallowed_special=()))
+         return self._core._tokenizer.decode(ids)
+
+     def _decode(self, token_ids, skip_special_tokens=False, **kwargs):
+         return self._core.decode(list(token_ids), skip_special_tokens=skip_special_tokens)
+
+     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+         # BOS/EOS insertion is left to the caller; ids pass through unchanged.
+         return token_ids_0
+
+     def save_vocabulary(self, save_directory, filename_prefix=None):
+         # Nothing to write: the vocabulary is rebuilt from tiktoken at load time.
+         return ()
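
For reviewers, here is a minimal round-trip sketch, not part of this diff. It assumes `tiktoken` and `transformers` are installed, that the file is importable as `tokenization_steerling`, and that the stock cl100k_base vocabulary (100277 entries) is in use, as the defaults in this PR suggest. Note the asymmetry: the HF-level `encode()` adds no BOS/EOS because `build_inputs_with_special_tokens` is a passthrough, while the core tokenizer's `encode()` does.

```python
# Minimal local sanity check; assumes tokenization_steerling.py is on sys.path.
from tokenization_steerling import SteerlingTokenizer

tok = SteerlingTokenizer()

# The four added tokens sit directly above the cl100k_base vocabulary.
assert tok.vocab_size == 100281
assert tok.pad_token_id == 100277   # <|pad|>
assert tok.bos_token_id == 100278   # <|bos|>
assert tok.eos_token_id == 100257   # <|endoftext|>, reused from cl100k_base

# HF path: no BOS/EOS, since build_inputs_with_special_tokens passes through.
ids = tok.encode("hello world")
assert tok.decode(ids, skip_special_tokens=True) == "hello world"

# Core path: BOS/EOS are added explicitly.
core_ids = tok._core.encode("hello world")
assert core_ids[0] == tok.bos_token_id and core_ids[-1] == tok.eos_token_id
assert tok._core.decode(core_ids) == "hello world"
```

Once the file ships in the model repo, the class would typically be loaded with `AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)` rather than imported directly.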