tspersian committed
Commit 4128ba5 · 1 Parent(s): 8eb50fc
Files changed (5)
  1. README.md +19 -0
  2. __init__.py +3 -0
  3. base.py +300 -0
  4. helper.py +94 -0
  5. mana_tokenizer.py +70 -0
README.md CHANGED
@@ -10,6 +10,25 @@ language:
 
 The Mana Tokenizer is a custom-trained BPE tokenizer designed for Persian text. It is trained on a combination of large Persian corpora. The tokenizer is built with BPE using high character coverage to handle diverse Persian text.
 
+ ## Quick Start
+
+ ```python
+ from mana_tokenizer import ManaTokenizer
+ tokenizer = ManaTokenizer()
+ text = "سلام من یک متن تست برای تست این تست هستم."
+ print(tokenizer.encode(text))
+ print(tokenizer.decode(tokenizer.encode(text)))
+ ```
+
+ You can also register custom special tokens:
+ ```python
+ tokenizer.register_special_tokens({"</s>": 100269})
+ ```
+ Batch encode a list of texts:
+ ```python
+ tokenizer.batch_encode(["یک متن طولانی"])
+ ```
+
 ## Special Tokens
 
 - **user Token:** `<|user|>`
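
These special tokens pair with the `allowed_special` argument of `encode()` in `base.py` below. A minimal sketch of how they behave, assuming the package is importable as `mana_tokenizer` (its `__init__.py` is added in this commit):

```python
from mana_tokenizer import ManaTokenizer

tokenizer = ManaTokenizer()
chat = "<|user|>سلام<|end|>"

# The default mode "none_raise" raises if raw text contains a special token;
# "all" encodes each registered special token to its single reserved id.
ids = tokenizer.encode(chat, allowed_special="all")
print(ids)                    # [100258, ..., 100257]
print(tokenizer.decode(ids))  # round-trips to the original string
```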
__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .base import Tokenizer
+ from .mana_tokenizer import ManaTokenizer
+ from . import helper
base.py ADDED
@@ -0,0 +1,300 @@
+ from collections import Counter
+ from functools import lru_cache
+ import requests
+ from datasets import IterableDataset, Dataset
+ from pyarrow import ChunkedArray
+ from joblib import Parallel, delayed, cpu_count
+ import time
+ import os
+ import regex as re
+ import csv
+ from . import helper
+
+ class Tokenizer:
+     """Base class for Tokenizers"""
+     def __init__(self, pattern=None, multiprocess=True, store_dict=False, stop_list_size=0, freq_cutoff=1):
+         # default: vocab size of 256 (all bytes), no merges, no special tokens
+         MANA_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re|می|نمی|به|بی|در|باز|بر|فرا|هم|ور|وا|ف|ک|چ|ن|پ|ا|از|ای|ی|ها|ترین|تر|ات|ان|ت|ٔ|یی|‌ا)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
+         self.merges = {}  # (int, int) -> int
+         self.special_tokens = {}  # str -> int, e.g. {'<|endoftext|>': 100257}
+         self.vocab = self._build_vocab()  # int -> bytes
+         self.pattern = MANA_SPLIT_PATTERN if pattern is None else pattern
+         self.compiled_pattern = re.compile(self.pattern)
+         self.multiprocess = multiprocess
+         self._cpus = cpu_count() if multiprocess else 1
+         self.store_dict = store_dict
+         self.stop_list_size = stop_list_size
+         self.stop_words = {}
+         self.freq_cutoff = freq_cutoff
+
+     def _id_dict_to_list(self, ids):
+         if self.stop_list_size:
+             # fetch twice as many candidates to be sure of finding
+             # stop_list_size chunks longer than one character
+             top2X = ids.most_common(2 * self.stop_list_size)
+             index = len(self.vocab)
+             stop_index = index + self.stop_list_size
+             stop_words = {}
+             for key, val in top2X:
+                 if len(key) > 1:
+                     stop_words[key] = index
+                     self.vocab[index] = key.encode('utf-8')
+                     index += 1
+                 if index == stop_index:
+                     break
+             self.stop_words = stop_words
+             if self.freq_cutoff > 1:
+                 return [([*key.encode('utf-8')], val) for key, val in ids.items()
+                         if val >= self.freq_cutoff and key not in self.stop_words]
+             return [([*key.encode('utf-8')], val) for key, val in ids.items()
+                     if key not in self.stop_words]
+         else:  # self.stop_list_size == 0
+             if self.freq_cutoff > 1:
+                 return [([*key.encode('utf-8')], val) for key, val in ids.items()
+                         if val >= self.freq_cutoff]
+             return [([*key.encode('utf-8')], val) for key, val in ids.items()]
+
+     def _import_data(self, data):
+         # `data` may be a text string, a path to a .txt or .csv file, a url to a
+         # text document, a datasets Dataset/IterableDataset, or a list of any of
+         # the above. Returns a list of 2-tuples of byte-id lists and their counts.
+         ids = Counter()
+         if not isinstance(data, (list, tuple)):
+             data = (data,)
+         for item in data:
+             # convert to ChunkedArray, dict, or str of text to parse
+             if isinstance(item, Dataset):
+                 item = item.data['text']
+             elif isinstance(item, str) and item.endswith('.csv'):  # csv file from a previous data load
+                 with open(item, 'r') as f:
+                     reader = csv.reader(f)
+                     next(reader)  # skip the header row
+                     item = {k: int(v) for k, v in reader}
+             elif isinstance(item, str):
+                 if item.startswith('https://') or item.startswith('http://'):
+                     item = requests.get(item).text  # if it's a url, assume it points to a text file
+                 elif os.path.isfile(item) and item.endswith('.txt'):
+                     with open(item, 'r', encoding='utf-8') as f:
+                         item = f.read()
+             # process data
+             if isinstance(item, dict):
+                 last_item = item.popitem()
+                 if last_item[1] != 0:
+                     print('Warning: the csv file or dictionary passed does not seem to have been made by this tokenizer.')
+                     item[last_item[0]] = last_item[1]
+                 elif last_item[0] != self.pattern:
+                     print('Warning: the dictionary or csv file passed did not use the same split pattern.')
+                 ids.update(item)
+             elif isinstance(item, str):  # assume the string is the text itself
+                 ids.update(re.findall(self.compiled_pattern, item))
+             elif isinstance(item, ChunkedArray):
+                 batch_size = len(item) // (self._cpus * 2) or 1
+                 batches = [item[i:i + batch_size] for i in range(0, len(item), batch_size)]
+                 print(f'Processing {len(batches)} batches of size {batch_size}')
+                 results = Parallel(n_jobs=self._cpus)(
+                     delayed(helper._process_string_scalar)(batch, self.compiled_pattern)
+                     for batch in batches)
+                 for result in results:  # aggregate results into one Counter
+                     ids.update(result)
+             elif isinstance(item, IterableDataset):
+                 print('Serially processing IterableDataset...')
+                 for _dict in item:
+                     ids.update(re.findall(self.compiled_pattern, _dict['text']))
+
+         if self.store_dict:  # store dict compression of dataset to a csv file if requested
+             ids[self.pattern] = 0  # store the split pattern as the last key
+             formatted_time = time.strftime('%Y-%m-%d-%H_%M', time.localtime())
+             filename = f'{formatted_time}-dataset-dict.csv'
+             try:
+                 with open(filename, 'w', newline='') as f:
+                     writer = csv.writer(f)
+                     writer.writerow(['text_chunk', 'count'])
+                     for key, value in ids.items():
+                         writer.writerow([key, value])
+                 print(f"Stored dictionary of {len(ids)} keys to {filename}")
+             except Exception:
+                 print('Failed to store dictionary of dataset.')
+             del ids[self.pattern]  # remove the pattern key from the ids dict
+
+         return self._id_dict_to_list(ids)
+
+     def train(self, text, vocab_size, verbose=False):
+         # Tokenizer can train a vocabulary of size vocab_size from text
+         raise NotImplementedError
+
+     def _build_vocab(self):
+         # vocab is simply and deterministically derived from merges
+         vocab = {idx: bytes([idx]) for idx in range(256)}
+         for (p0, p1), idx in self.merges.items():
+             vocab[idx] = vocab[p0] + vocab[p1]
+         for special, idx in self.special_tokens.items():
+             vocab[idx] = special.encode("utf-8")
+         return vocab
+
+     def register_special_tokens(self, special_tokens):
+         # special_tokens is a dictionary of str -> int
+         # example: {"<|endoftext|>": 100257}
+         # note: this replaces any previously registered special tokens
+         self.special_tokens = special_tokens
+         self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}
+
+     def save(self, file_prefix):
+         """
+         Saves two files: file_prefix.vocab and file_prefix.model
+         This is inspired by (but not equivalent to!) sentencepiece's model saving:
+         - the model file is the critical one, intended for load() later
+         - the vocab file is just a pretty-printed version for human inspection only
+         """
+         # write the model: to be used in load() later
+         model_file = file_prefix + ".model"
+         with open(model_file, 'w', encoding='utf-8') as f:
+             # write the version, pattern and merges, that's all that's needed
+             f.write("mana v1\n")
+             f.write(f"{self.pattern}\n")
+             # write the special tokens, first the number of them, then each one
+             f.write(f"{len(self.special_tokens)}\n")
+             for special, idx in self.special_tokens.items():
+                 f.write(f"{special} {idx}\n")
+             # the merges dict (keys are always (int, int) pairs)
+             for idx1, idx2 in self.merges:
+                 f.write(f"{idx1} {idx2}\n")
+
+         # write the vocab: for the human to look at
+         vocab_file = file_prefix + ".vocab"
+         inverted_merges = {idx: pair for pair, idx in self.merges.items()}
+         with open(vocab_file, "w", encoding="utf-8") as f:
+             for idx, token in self.vocab.items():
+                 s = helper.render_token(token)
+                 # find the children of this token, if any
+                 if idx in inverted_merges:
+                     idx0, idx1 = inverted_merges[idx]
+                     s0 = helper.render_token(self.vocab[idx0])
+                     s1 = helper.render_token(self.vocab[idx1])
+                     f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
+                 else:
+                     f.write(f"[{s}] {idx}\n")
+
+     def load(self, model_file):
+         """Inverse of save() but only for the model file"""
+         assert model_file.endswith(".model")
+         # read the model file
+         merges = {}
+         special_tokens = {}
+         idx = 256
+         with open(model_file, 'r', encoding="utf-8") as f:
+             # read the version
+             version = f.readline().strip()
+             assert version == "mana v1"
+             # read the pattern and keep the compiled pattern in sync
+             self.pattern = f.readline().strip()
+             self.compiled_pattern = re.compile(self.pattern)
+             # read the special tokens
+             num_special = int(f.readline().strip())
+             for _ in range(num_special):
+                 special, special_idx = f.readline().strip().split()
+                 special_tokens[special] = int(special_idx)
+             # read the merges
+             for line in f:
+                 idx1, idx2 = map(int, line.split())
+                 merges[(idx1, idx2)] = idx
+                 idx += 1
+         self.merges = merges
+         self.special_tokens = special_tokens
+         self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}
+         self.vocab = self._build_vocab()
+
+     def decode(self, ids):
+         # given ids (list of integers), return Python string
+         part_bytes = [self.vocab[idx] if idx in self.vocab
+                       else self.inverse_special_tokens[idx].encode("utf-8")
+                       for idx in ids]  # raises KeyError if any idx is not a valid token
+         text_bytes = b"".join(part_bytes)
+         text = text_bytes.decode("utf-8", errors="replace")
+         return text
+
+     @lru_cache(maxsize=131072)
+     def _encode_chunk(self, chunk):
+         if chunk in self.stop_words:  # TODO: revisit this if statement
+             return [self.stop_words[chunk]]
+         # return the token chunk as a list of ints, similar to a bytes object
+         chunk = [*chunk.encode("utf-8")]
+         len_chunk = len(chunk)
+         while len_chunk >= 2:
+             # find the pair with the lowest merge index
+             low = 987654321  # sentinel larger than any merge index
+             for i in range(len_chunk - 1):
+                 current_pair = (chunk[i], chunk[i+1])
+                 new_val = self.merges.get(current_pair, 987654321)
+                 if new_val < low:
+                     pair = current_pair
+                     low = new_val
+             if low == 987654321:  # no merges were found
+                 break  # nothing else can be merged
+             # otherwise merge the best pair (lowest merge index)
+             idx = self.merges[pair]
+             len_chunk = helper.merge(chunk, pair, idx, len_chunk)
+         return chunk  # list of ints
+
+     def encode_ordinary(self, text):
+         """Encoding that ignores any special tokens."""
+         ids = []
+         for chunk in re.findall(self.compiled_pattern, text):
+             ids.extend(self._encode_chunk(chunk))
+         return ids
+
+     def encode(self, text, allowed_special="none_raise"):
+         """
+         Unlike encode_ordinary, this function handles special tokens.
+         allowed_special can be "all"|"none"|"none_raise" or a custom set of special tokens;
+         if "none_raise", an error is raised if any special token is encountered in text.
+         This is the default tiktoken behavior right now as well;
+         any other behavior is either annoying or a major footgun.
+         """
+         # decode the user desire w.r.t. handling of special tokens
+         special = None
+         if allowed_special == "all":
+             special = self.special_tokens
+         elif allowed_special == "none":
+             special = {}
+         elif allowed_special == "none_raise":
+             special = {}
+             assert all(token not in text for token in self.special_tokens)
+         elif isinstance(allowed_special, set):
+             special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
+         else:
+             raise ValueError(f"allowed_special={allowed_special} not understood")
+         if not special:  # shortcut: if no special tokens, just use the ordinary encoding
+             return self.encode_ordinary(text)
+         # split on special tokens. Note that surrounding the pattern with ()
+         # makes it into a capturing group, so the special tokens will be included
+         special_pattern = f"({'|'.join([re.escape(k) for k in special])})"
+         special_chunks = re.split(special_pattern, text)
+         # now all the special characters are separated from the rest of the text
+         # all chunks of text are encoded separately, then results are joined
+         ids = []
+         for part in special_chunks:
+             special_token = special.get(part)
+             if special_token is None:  # this is an ordinary sequence, encode it normally
+                 ids.extend(self.encode_ordinary(part))
+             else:  # this is a special token, encode it separately as a special case
+                 ids.append(special_token)
+         return ids
+
+     def batch_encode(self, texts, allowed_special="none_raise"):
+         """
+         Encode a list of texts in batch mode.
+         Each text is encoded according to the handling of special tokens specified in allowed_special.
+
+         Parameters:
+             texts (list of str): List of texts to encode.
+             allowed_special (str|set): Special token handling mode.
+
+         Returns:
+             list of list of int: A list where each element is the encoded form of a text in `texts`.
+         """
+         return [self.encode(text, allowed_special=allowed_special) for text in texts]
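
A short sketch of the `save()`/`load()` round trip defined above, using a toy corpus and the hypothetical file prefix `mana`; a real run would train on a much larger corpus:

```python
from mana_tokenizer import ManaTokenizer

tokenizer = ManaTokenizer()
tokenizer.train("سلام دنیا! " * 1000, vocab_size=266)  # toy corpus, 10 merges

tokenizer.save("mana")       # writes mana.model (for load) and mana.vocab
restored = ManaTokenizer()
restored.load("mana.model")  # restores pattern, special tokens, and merges
assert restored.decode(restored.encode("سلام دنیا!")) == "سلام دنیا!"
```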
helper.py ADDED
@@ -0,0 +1,94 @@
+ from collections import Counter, defaultdict
+ import unicodedata
+ import regex as re
+
+ def get_stats(ids):
+     """
+     Given `ids`, a list of 2-tuples of (list of ints, count), return a
+     defaultdict mapping each consecutive pair of ints within a chunk to the
+     summed counts of the chunks it appears in. The count of each chunk acts
+     as a multiplier for every pair inside it. Pairs never span two chunks:
+     the last element of one chunk is not paired with the first element of
+     the next.
+
+     Example:
+     get_stats([([97, 98, 99], 2), ([98, 99, 100], 1)])
+     -> defaultdict(<class 'int'>, {(97, 98): 2, (98, 99): 3, (99, 100): 1})
+     """
+     counts = defaultdict(int)
+     for chunk, num in ids:
+         last_index = len(chunk) - 1
+         i = 0
+         while i < last_index:
+             j = i + 1
+             counts[(chunk[i], chunk[j])] += num
+             i = j
+     return counts
+
+ def merge_batch_get_stats(ids, pairs):
+     """
+     Apply every merge in `pairs` ((int, int) -> new token id) in place across
+     all chunks, and return counts of the candidate pairs formed around the
+     merge points for the next training round.
+     """
+     counts = defaultdict(int)
+     for chunk, num in ids:
+         last_index = len(chunk) - 1
+         i = 0
+         while i < last_index:
+             j = i + 1
+             token = pairs.get((chunk[i], chunk[j]))
+             if token is not None:
+                 chunk[i] = token
+                 del chunk[j]
+                 last_index -= 1
+                 if i:
+                     counts[(chunk[i-1], chunk[i])] += num
+             i = j
+         if i and i == last_index:
+             counts[(chunk[-2], chunk[i])] += num
+     return counts
+
+ def merge(ids, pair, idx, len_ids):
+     """
+     In the list of integers `ids`, replace all consecutive occurrences
+     of `pair` with the new integer token `idx`, returning the new length.
+     Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
+     """
+     i = 0
+     while i + 1 < len_ids:
+         j = i + 1
+         if ids[i] == pair[0] and ids[j] == pair[1]:
+             ids[i] = idx
+             del ids[j]
+             len_ids -= 1
+         i = j
+     return len_ids
+
+ def replace_control_characters(s: str) -> str:
+     # we don't want to print control characters
+     # which distort the output (e.g. \n or much worse)
+     # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
+     # http://www.unicode.org/reports/tr44/#GC_Values_Table
+     chars = []
+     for ch in s:
+         if unicodedata.category(ch)[0] != "C":
+             chars.append(ch)  # this character is ok
+         else:
+             chars.append(f"\\u{ord(ch):04x}")  # escape
+     return "".join(chars)
+
+ def render_token(t: bytes) -> str:
+     # pretty print a token, escaping control characters
+     s = t.decode('utf-8', errors='replace')
+     s = replace_control_characters(s)
+     return s
+
+ def _process_dicts(batch, compiled_pattern):  # for a raw datasets.Dataset
+     counter = Counter()
+     for item in batch:
+         counter.update(re.findall(compiled_pattern, item))
+     return counter
+
+ def _process_string_scalar(batch, compiled_pattern):  # for pyarrow string scalars
+     counter = Counter()
+     for item in batch:
+         counter.update(re.findall(compiled_pattern, item.as_py()))
+     return counter
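
A worked example of the two core helpers, matching the docstring semantics above:

```python
from mana_tokenizer.helper import get_stats, merge

# Chunks are (list-of-byte-ids, count) pairs; counts weight each pair.
ids = [([97, 98, 99], 2), ([98, 99, 100], 1)]
stats = get_stats(ids)
print(dict(stats))  # {(97, 98): 2, (98, 99): 3, (99, 100): 1}

# Merge the most frequent pair (98, 99) into a new token 256, in place.
chunk = [97, 98, 99, 98, 99]
new_len = merge(chunk, (98, 99), 256, len(chunk))
print(chunk[:new_len])  # [97, 256, 256]
```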
mana_tokenizer.py ADDED
@@ -0,0 +1,70 @@
+ from .base import Tokenizer
+ from .helper import get_stats, merge_batch_get_stats
+ from heapq import nlargest
+ import time
+
+ MANA_SPECIAL_TOKENS = {
+     '<|end|>': 100257,
+     '<|user|>': 100258,
+     '<|assistant|>': 100259,
+     '<|system|>': 100260
+ }
+
+ class ManaTokenizer(Tokenizer):
+     def __init__(self, pattern=None, multiprocess=True, store_dict=False, stop_list_size=0, freq_cutoff=1):
+         """
+         - pattern: optional string to override the default Mana split pattern
+         - multiprocess: chunk large datasets across all cpu cores
+         - store_dict: cache the chunk-count dictionary of the training data to a csv file
+         - stop_list_size: reserve vocab ids for the most frequent multi-character chunks
+         - freq_cutoff: drop chunks seen fewer than this many times during training
+         """
+         super().__init__(pattern, multiprocess, store_dict, stop_list_size, freq_cutoff)
+         # register after the base init, which would otherwise reset special_tokens
+         self.register_special_tokens(MANA_SPECIAL_TOKENS)
+
+     def train(self, data, vocab_size, cap_divisor=2, max_batch_size=0, verbose=False):
+         t0 = time.time()
+         ids = self._import_data(data)  # [([int], int)] -> byte-id chunks and their counts
+         t1 = time.time()
+         print(f'Time spent loading data: {t1-t0:.2f} sec')
+
+         merges = self.merges  # {(int, int): int} -> token pair to new token
+         vocab = self.vocab  # {int: bytes} -> token to its bytes representation
+         batch_count = 0
+         curr_vocab_size = len(vocab)
+         num_merges = vocab_size - curr_vocab_size
+         merges_remaining = num_merges
+         if max_batch_size < 1:
+             max_batch_size = num_merges
+         stats = get_stats(ids)  # stats are later updated by merge_batch_get_stats
+         start_time = time.time()
+         while merges_remaining > 0:
+             if not stats:  # corpus exhausted before reaching vocab_size
+                 break
+             seen_first = set()  # tokens seen in the first position in pairs
+             seen_last = set()  # tokens seen in the last position in pairs
+             pairs_to_merge = {}
+             num_pairs_to_search = min(merges_remaining // cap_divisor, len(vocab), max_batch_size) or 1
+             top_pairs = nlargest(num_pairs_to_search, stats, key=stats.get)
+             for first, last in top_pairs:  # pairs are (first, last) tuples
+                 if first in seen_last or last in seen_first:  # unsafe merge
+                     seen_first.add(first)
+                     seen_last.add(last)
+                     continue  # skip this pair but keep looking for safe merges in top_pairs
+                 seen_first.add(first)
+                 seen_last.add(last)
+                 pairs_to_merge[(first, last)] = curr_vocab_size
+                 vocab[curr_vocab_size] = vocab[first] + vocab[last]
+                 curr_vocab_size += 1
+             merges_remaining -= len(pairs_to_merge)
+             merges.update(pairs_to_merge)  # save the merges
+             batch_count += 1
+             if merges_remaining:  # no need to merge after the last batch
+                 stats = merge_batch_get_stats(ids, pairs_to_merge)  # replace pairs_to_merge keys in ids with their values
+             if verbose:
+                 t2 = time.time()
+                 time_taken = t2 - start_time
+                 merges_done = num_merges - merges_remaining
+                 estimated_remaining_time = time_taken / merges_done * merges_remaining
+                 estimated_end_time = time.strftime("%H:%M:%S", time.localtime(time.time() + estimated_remaining_time))
+                 print(f"Batch {batch_count} merged {len(pairs_to_merge)} pairs in {t2-t1:.2f} sec. "
+                       f"Merges remaining: {merges_remaining}. Estimated end time: {estimated_end_time}")
+                 t1 = t2
+
+         self.merges = merges  # used in encode()
+         self.vocab = vocab  # used in decode()
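
Putting the pieces together, a full training run could look like the sketch below. The corpus path `persian_corpus.txt` is hypothetical; any input type accepted by `_import_data` (raw text, .txt paths, urls, datasets objects, or a cached chunk-count csv) would work:

```python
from mana_tokenizer import ManaTokenizer

# store_dict=True caches the chunk-count dictionary to a timestamped csv,
# so a later run can pass that csv instead of re-reading the raw corpus.
tokenizer = ManaTokenizer(store_dict=True)

# cap_divisor bounds each batch to merges_remaining // cap_divisor candidate
# pairs; max_batch_size caps how many pairs can merge in a single batch.
tokenizer.train("persian_corpus.txt", vocab_size=32000,
                cap_divisor=2, max_batch_size=1000, verbose=True)
tokenizer.save("mana")  # writes mana.model and mana.vocab
```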