aframson committed on
Commit
1fcedf8
·
1 Parent(s): 578f96a
Files changed (1) hide show
  1. tokenizeConfig.py +217 -125
tokenizeConfig.py CHANGED
@@ -1,130 +1,222 @@
1
- import os
2
- from shutil import copyfile
3
- from typing import Any, Dict, List, Optional, Tuple
4
-
5
- import tokenizers
6
- from tokenizers import models, pre_tokenizers, decoders, trainers
7
-
8
- from transformers.tokenization_utils import PreTrainedTokenizer
9
- from transformers.utils import logging
10
-
11
-
12
- logger = logging.get_logger(__name__)
13
-
14
- VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.json"}
15
-
16
- PRETRAINED_VOCAB_FILES_MAP = {}
17
-
18
-
19
- class OBITokenizer(PreTrainedTokenizer):
20
- """
21
- Construct a InternLM tokenizer. Based on byte-level Byte-Pair-Encoding.
22
- Args:
23
- vocab_file (`str`):
24
- Path to the vocabulary file.
25
- """
26
-
27
- vocab_files_names = VOCAB_FILES_NAMES
28
- pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
29
- model_input_names = ["input_ids", "attention_mask"]
30
- _auto_class = "AutoTokenizer"
31
-
32
- def __init__(
33
- self,
34
- vocab_file,
35
- unk_token="<unk>",
36
- bos_token="<s>",
37
- eos_token="</s>",
38
- pad_token="</s>",
39
- add_bos_token=True,
40
- add_eos_token=False,
41
- clean_up_tokenization_spaces=False,
42
- **kwargs,
43
- ):
44
- super().__init__(
45
- bos_token=bos_token,
46
- eos_token=eos_token,
47
- unk_token=unk_token,
48
- pad_token=pad_token,
49
- clean_up_tokenization_spaces=clean_up_tokenization_spaces,
50
- **kwargs,
51
- )
52
- self.vocab_file = vocab_file
53
- self.add_bos_token = add_bos_token
54
- self.add_eos_token = add_eos_token
55
- self.tokenizer = tokenizers.Tokenizer(models.BPE())
56
- self.tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
57
- self.tokenizer.decoder = decoders.ByteLevel()
58
- self.tokenizer.post_processor = tokenizers.processors.ByteLevel()
59
- self.tokenizer.enable_truncation(max_length=512) # Adjust max_length as needed
60
- self.tokenizer.enable_padding(max_length=512, pad_token="[PAD]") # Adjust max_length and pad_token as needed
61
-
62
- self._no_prefix_space_tokens = None
63
-
64
- @property
65
- def no_prefix_space_tokens(self):
66
- if self._no_prefix_space_tokens is None:
67
- vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
68
- self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith("▁")}
69
- return self._no_prefix_space_tokens
70
-
71
- @property
72
- def vocab_size(self):
73
- """Returns vocab size"""
74
- return len(self.tokenizer.get_vocab())
75
-
76
- @property
77
- def bos_token_id(self) -> Optional[int]:
78
- return self.tokenizer.token_to_id("<s>")
79
-
80
- @property
81
- def eos_token_id(self) -> Optional[int]:
82
- return self.tokenizer.token_to_id("</s>")
83
-
84
- def get_vocab(self):
85
- """Returns vocab as a dict"""
86
- vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
87
- vocab.update(self.added_tokens_encoder)
88
- return vocab
89
-
90
- def _tokenize(self, text):
91
- """Returns a tokenized string."""
92
- encoding = self.tokenizer.encode(text)
93
- return encoding.ids
94
-
95
- def _convert_token_to_id(self, token):
96
- """Converts a token (str) in an id using the vocab."""
97
- return self.tokenizer.token_to_id(token)
98
-
99
- def _convert_id_to_token(self, index):
100
- """Converts an index (integer) in a token (str) using the vocab."""
101
- return self.tokenizer.id_to_token(index)
102
-
103
- def convert_tokens_to_string(self, tokens):
104
- """Converts a sequence of tokens (string) into a single string."""
105
- return self.tokenizer.decode(tokens)
106
-
107
- def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
108
  """
109
- Save the vocabulary and special tokens file to a directory.
110
- Args:
111
- save_directory (`str`):
112
- The directory in which to save the vocabulary.
113
- Returns:
114
- `Tuple(str)`: Paths to the files saved.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  """
116
- if not os.path.isdir(save_directory):
117
- logger.error(f"Vocabulary path ({save_directory}) should be a directory")
118
- return
119
- out_vocab_file = os.path.join(
120
- save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
 
 
 
 
 
 
 
 
 
121
  )
122
 
123
- # Save the BPE vocab
124
- # Training: Fit the tokenizer on your text data
125
- trainer = trainers.BpeTrainer(special_tokens=["<unk>", "<s>", "</s>","[PAD]"])
126
- self.tokenizer.train(trainer=trainer, files=[out_vocab_file])
127
- # Save the trained tokenizer to a file
128
- self.tokenizer.save(out_vocab_file)
129
 
130
- return (out_vocab_file,)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This is an educational implementation of the byte pair encoding algorithm."""
2
+ import collections
3
+ from typing import Optional
4
+
5
+ import regex
6
+
7
+ import tiktoken
8
+
9
+
10
class OBITokenizer:
    """A minimal BPE tokeniser driven by a split regex and a bytes->rank merge table."""

    def __init__(self, *, pat_str: str, mergeable_ranks: dict[bytes, int]) -> None:
        """Creates an Encoding object."""
        # Regex pattern string used to pre-split input text into word-like pieces.
        self.pat_str = pat_str
        # Maps token bytes to their rank; rank doubles as token id and merge priority.
        self.mergeable_ranks = mergeable_ranks

        # Inverse table (token id -> token bytes) used by the decode methods.
        self._decoder = {rank: token_bytes for token_bytes, rank in mergeable_ranks.items()}
        self._pat = regex.compile(pat_str)

    def encode(self, text: str, visualise: Optional[str] = "colour") -> list[int]:
        """Encodes a string into tokens.

        >>> enc.encode("hello world")
        [388, 372]
        """
        # Split the text into (approximately) words, then BPE each word on its own.
        tokens: list[int] = []
        for piece in self._pat.findall(text):
            raw = piece.encode("utf-8")
            tokens.extend(bpe_encode(self.mergeable_ranks, raw, visualise=visualise))
        return tokens

    def decode_bytes(self, tokens: list[int]) -> bytes:
        """Decodes a list of tokens into bytes.

        >>> enc.decode_bytes([388, 372])
        b'hello world'
        """
        return b"".join(map(self._decoder.__getitem__, tokens))

    def decode(self, tokens: list[int]) -> str:
        """Decodes a list of tokens into a string.

        Decoded bytes are not guaranteed to be valid UTF-8; invalid byte runs are
        replaced with the replacement character "�".

        >>> enc.decode([388, 372])
        'hello world'
        """
        return self.decode_bytes(tokens).decode("utf-8", errors="replace")

    def decode_tokens_bytes(self, tokens: list[int]) -> list[bytes]:
        """Decodes a list of tokens into a list of bytes.

        Useful for visualising how a string is tokenised.

        >>> enc.decode_tokens_bytes([388, 372])
        [b'hello', b' world']
        """
        return [self._decoder[t] for t in tokens]

    @staticmethod
    def train(training_data: str, vocab_size: int, pat_str: str):
        """Train a BPE tokeniser on some data!"""
        ranks = bpe_train(data=training_data, vocab_size=vocab_size, pat_str=pat_str)
        return OBITokenizer(pat_str=pat_str, mergeable_ranks=ranks)

    @staticmethod
    def from_tiktoken(encoding):
        """Wrap an existing tiktoken encoding (or its registered name) as an OBITokenizer."""
        if isinstance(encoding, str):
            encoding = tiktoken.get_encoding(encoding)
        # NOTE(review): reads tiktoken's private attributes — fragile across versions.
        return OBITokenizer(
            pat_str=encoding._pat_str, mergeable_ranks=encoding._mergeable_ranks
        )
79
 
 
 
 
 
 
 
80
 
81
def bpe_encode(
    mergeable_ranks: dict[bytes, int], input: bytes, visualise: Optional[str] = "colour"
) -> list[int]:
    """Byte-pair-encode *input*: repeatedly merge the lowest-ranked adjacent pair."""
    # Start from one single-byte piece per input byte.
    pieces = [input[i:i + 1] for i in range(len(input))]
    while True:
        # See the intermediate merges play out!
        if visualise == "colour" or visualise == "color":
            visualise_tokens(pieces)
        elif visualise == "simple":
            print(pieces)

        # Scan every adjacent pair and keep the one with the lowest merge rank.
        best_idx, best_rank = None, None
        for idx in range(len(pieces) - 1):
            rank = mergeable_ranks.get(pieces[idx] + pieces[idx + 1])
            if rank is None:
                continue
            if best_rank is None or rank < best_rank:
                best_idx, best_rank = idx, rank

        # No mergeable pair left — encoding is complete.
        if best_rank is None:
            break
        assert best_idx is not None

        # Merge the winning pair in place and repeat.
        pieces[best_idx:best_idx + 2] = [pieces[best_idx] + pieces[best_idx + 1]]

    if visualise:
        print()

    # Every remaining piece is a known token; look up its id.
    return [mergeable_ranks[piece] for piece in pieces]
116
+
117
+
118
def bpe_train(
    data: str, vocab_size: int, pat_str: str, visualise: Optional[str] = "colour"
) -> dict[bytes, int]:
    """Learn a BPE merge table from raw text.

    Args:
        data: Training corpus.
        vocab_size: Target vocabulary size; must be at least 256 so every byte
            gets its own token.
        pat_str: Regex used to pre-split the corpus into words.
        visualise: "colour"/"color", "simple", or None — controls progress output.

    Returns:
        Mapping of token bytes -> rank (the rank doubles as token id and merge
        priority). May contain fewer than ``vocab_size`` entries if the corpus is
        too small to support that many merges.

    Raises:
        ValueError: If ``vocab_size`` is smaller than 256.
    """
    # First, add tokens for each individual byte value
    if vocab_size < 2**8:
        raise ValueError("vocab_size must be at least 256, so we can encode all bytes")
    ranks = {}
    for i in range(2**8):
        ranks[bytes([i])] = i

    # Splinter up our data into lists of bytes
    # data = "Hello world"
    # words = [
    #     [b'H', b'e', b'l', b'l', b'o'],
    #     [b' ', b'w', b'o', b'r', b'l', b'd']
    # ]
    words: list[list[bytes]] = [
        [bytes([b]) for b in word.encode("utf-8")] for word in regex.findall(pat_str, data)
    ]

    # Now, use our data to figure out which merges we should make
    while len(ranks) < vocab_size:
        # Find the most common pair. This will become our next token
        stats = collections.Counter()
        for piece in words:
            for pair in zip(piece[:-1], piece[1:]):
                stats[pair] += 1

        # Fix: if no adjacent pairs remain (tiny corpus fully merged), we cannot
        # grow the vocabulary further — stop early instead of letting max()
        # raise "ValueError: max() arg is an empty sequence".
        if not stats:
            break

        most_common_pair = max(stats, key=lambda x: stats[x])
        token_bytes = most_common_pair[0] + most_common_pair[1]
        token = len(ranks)
        # Add the new token!
        ranks[token_bytes] = token

        # Now merge that most common pair in all the words. That is, update our training data
        # to reflect our decision to make that pair into a new token.
        new_words = []
        for word in words:
            new_word = []
            i = 0
            while i < len(word) - 1:
                if (word[i], word[i + 1]) == most_common_pair:
                    # We found our pair! Merge it
                    new_word.append(token_bytes)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            # The trailing element survives unmerged when the loop stops just before it.
            if i == len(word) - 1:
                new_word.append(word[i])
            new_words.append(new_word)
        words = new_words

        # See the intermediate merges play out!
        if visualise:
            print(f"The current most common pair is {most_common_pair[0]} + {most_common_pair[1]}")
            print(f"So we made {token_bytes} our {len(ranks)}th token")
            if visualise in ["colour", "color"]:
                print("Now the first fifty words in our training data look like:")
                visualise_tokens([token for word in words[:50] for token in word])
            elif visualise == "simple":
                print("Now the first twenty words in our training data look like:")
                for word in words[:20]:
                    print(word)
                print("\n")

    return ranks
185
+
186
+
187
def visualise_tokens(token_values: list[bytes]) -> None:
    """Print the tokens on one line, giving each an ANSI background colour."""
    palette = [f"\u001b[48;5;{code}m" for code in [167, 179, 185, 77, 80, 68, 134]]
    # If token boundaries do not occur at unicode character boundaries, it's unclear how best to
    # visualise the token. Here, we'll just use the unicode replacement character to represent some
    # fraction of a character.
    texts = [raw.decode("utf-8", errors="replace") for raw in token_values]

    offset = 0
    previous = None
    for text in texts:
        chosen = palette[offset % len(palette)]
        # Never paint two neighbouring tokens with the same colour.
        if chosen == previous:
            chosen = palette[(offset + 1) % len(palette)]
            assert chosen != previous
        previous = chosen
        offset += len(text)
        print(chosen + text, end="")
    # Reset terminal attributes and finish the line.
    print("\u001b[0m")
205
+
206
+
207
def train_simple_encoding():
    """Train a toy 600-token BPE on this file's own source and sanity-check it."""
    split_pattern = (
        r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    )
    # This module's own source text doubles as the training corpus.
    with open(__file__, "r") as source:
        corpus = source.read()

    tokenizer = OBITokenizer.train(corpus, vocab_size=600, pat_str=split_pattern)

    print("This is the sequence of merges performed in order to encode 'hello world':")
    ids = tokenizer.encode("hello world")
    # Round-trip checks: decoding must reproduce the encoded input exactly.
    assert tokenizer.decode(ids) == "hello world"
    assert tokenizer.decode_bytes(ids) == b"hello world"
    assert tokenizer.decode_tokens_bytes(ids) == [b"hello", b" world"]

    return tokenizer