scripts
- README.md +19 -0
- __init__.py +3 -0
- base.py +300 -0
- helper.py +94 -0
- mana_tokenizer.py +70 -0
README.md
CHANGED
@@ -10,6 +10,25 @@ language:

The Mana Tokenizer is a custom-trained BPE tokenizer designed for Persian text. It is trained on a large combined Persian corpus and built with BPE using high character coverage to handle diverse Persian text.

## Quick Start

```python
from mana_tokenizer import ManaTokenizer

tokenizer = ManaTokenizer()
text = "سلام من یک متن تست برای تست این تست هستم."
print(tokenizer.encode(text))
print(tokenizer.decode(tokenizer.encode(text)))
```

You can also add special tokens:

```python
tokenizer.register_special_tokens({"</s>": 100269})
```

Batch encode:

```python
tokenizer.batch_encode(["یک متن طولانی"])
```

## Special Tokens

- **user Token:** `<|user|>`
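Beyond the Quick Start, the `Tokenizer` base class in `base.py` also provides `save()` and `load()`. A minimal round-trip sketch, assuming write access to the working directory (the `mana` file prefix is only an example):

```python
from mana_tokenizer import ManaTokenizer

tokenizer = ManaTokenizer()
tokenizer.save("mana")        # writes mana.model (for load()) and mana.vocab (human-readable)

restored = ManaTokenizer()
restored.load("mana.model")   # restores the split pattern, special tokens, and merges
print(restored.decode(restored.encode("سلام")))
```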
__init__.py
ADDED
@@ -0,0 +1,3 @@

```python
from .base import Tokenizer
from .mana_tokenizer import ManaTokenizer
from . import helper  # relative import so the module resolves inside the package
```
base.py
ADDED
@@ -0,0 +1,300 @@

```python
from collections import Counter
from functools import lru_cache
import requests
from datasets import IterableDataset, Dataset
from pyarrow import ChunkedArray
from joblib import Parallel, delayed, cpu_count
import time
import os
import regex as re
import csv
from . import helper


class Tokenizer:
    """Base class for Tokenizers"""

    def __init__(self, pattern=None, multiprocess=True, store_dict=False, stop_list_size=0, freq_cutoff=1):
        # default: vocab size of 256 (all bytes), no merges, no patterns
        MANA_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re|می|نمی|به|بی|در|باز|بر|فرا|هم|ور|وا|ف|ک|چ|ن|پ|ا|از|ای|ی|ها|ترین|تر|ات|ان|ت|ٔ|یی|ا)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
        self.merges = {}  # (int, int) -> int
        self.pattern = ""  # str
        self.special_tokens = {}  # str -> int, e.g. {'<|endoftext|>': 100257}
        self.vocab = self._build_vocab()  # int -> bytes
        self.pattern = MANA_SPLIT_PATTERN if pattern is None else pattern
        self.compiled_pattern = re.compile(self.pattern)
        self.multiprocess = multiprocess
        if multiprocess:
            self._cpus = cpu_count()
        else:
            self._cpus = 1
        self.store_dict = store_dict
        self.stop_list_size = stop_list_size
        self.stop_words = {}
        self.freq_cutoff = freq_cutoff

    def _id_dict_to_list(self, ids):
        if self.stop_list_size:
            # get twice as many to be sure to be able to get X chunks of length > 1
            top2X = ids.most_common(2 * self.stop_list_size)
            index = len(self.vocab)
            stop_index = index + self.stop_list_size
            stop_words = {}
            for key, val in top2X:
                if len(key) > 1:  # and re.match(r'^ [A-Za-z\'’`]+$[A-Za-z]*', key):
                    stop_words[key] = index
                    self.vocab[index] = key.encode('utf-8')
                    index += 1
                    if index == stop_index:
                        break
            self.stop_words = stop_words
            if self.freq_cutoff > 1:
                return [([*key.encode('utf-8')], val) for key, val in ids.items()
                        if (val >= self.freq_cutoff and key not in self.stop_words)]
            else:
                return [([*key.encode('utf-8')], val) for key, val in ids.items()
                        if key not in self.stop_words]
        else:  # self.stop_list_size == 0
            if self.freq_cutoff > 1:
                return [([*key.encode('utf-8')], val) for key, val in ids.items()
                        if val >= self.freq_cutoff]
            else:
                return [([*key.encode('utf-8')], val) for key, val in ids.items()]

    def _import_data(self, data):
        # determine if `data` is a text as a string, a path to a file, a url to
        # a text document, a dictionary of datasets kwargs, or a list of any of
        # the above. Return a list of 2-tuples of bytes objects and their counts.
        ids = Counter()
        if not isinstance(data, (list, tuple)):
            data = (data,)
        for item in data:
            # convert to ChunkedArray, dict, or str of text to parse
            if isinstance(item, Dataset):
                item = item.data['text']
            elif isinstance(item, str) and item.endswith('.csv'):  # csv file from previous data load
                with open(item, 'r') as f:
                    reader = csv.reader(f)
                    next(reader)
                    item = {k: int(v) for k, v in reader}
            elif isinstance(item, str):
                if item.startswith('https://') or item.startswith('http://'):
                    item = requests.get(item).text  # if it's a url, assume it's to a text file
                elif os.path.isfile(item) and item.endswith('.txt'):
                    with open(item, 'r', encoding='utf-8') as f:
                        item = f.read()
            # process data
            if isinstance(item, dict):
                last_item = item.popitem()
                if last_item[1] != 0:
                    print('Warning: the csv file or dictionary passed does not seem to have been made by this tokenizer.')
                    item[last_item[0]] = last_item[1]
                elif last_item[0] != self.pattern:
                    print('Warning: the dictionary or csv file passed did not use the same split pattern.')
                ids.update(item)
            elif isinstance(item, str):  # assume the string is the text itself
                ids.update(re.findall(self.compiled_pattern, item))
            elif isinstance(item, ChunkedArray):
                batch_size = len(item) // (self._cpus * 2) or 1
                batches = [item[i:i + batch_size] for i in range(0, len(item), batch_size)]
                print(f'Processing {len(batches)} batches of size {batch_size}')
                results = Parallel(n_jobs=self._cpus)(delayed(helper._process_string_scalar)(batch, self.compiled_pattern) for batch in batches)
                for result in results:  # aggregate results into one Counter
                    ids.update(result)
            elif isinstance(item, IterableDataset):
                print('Serially processing IterableDataset...')
                for _dict in item:
                    ids.update(re.findall(self.compiled_pattern, _dict['text']))

        if self.store_dict:  # store dict compression of dataset to a csv file if requested
            ids[self.pattern] = 0  # store the pattern used to split the text as the last key
            formatted_time = time.strftime('%Y-%m-%d-%H_%M', time.localtime())
            filename = f'{formatted_time}-dataset-dict.csv'
            try:
                with open(filename, 'w', newline='') as f:
                    writer = csv.writer(f)
                    writer.writerow(['text_chunk', 'count'])
                    for key, value in ids.items():
                        writer.writerow([key, value])
                print(f"Stored dictionary of {len(ids)} keys to {filename}")
            except Exception:
                print('Failed to store dictionary of dataset.')
            del ids[self.pattern]  # remove the pattern key from the ids dict

        ids = self._id_dict_to_list(ids)
        return ids

    def train(self, text, vocab_size, verbose=False):
        # Tokenizer can train a vocabulary of size vocab_size from text
        raise NotImplementedError

    def _build_vocab(self):
        # vocab is simply and deterministically derived from merges
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def register_special_tokens(self, special_tokens):
        # special_tokens is a dictionary of str -> int
        # example: {"<|endoftext|>": 100257}
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}

    def save(self, file_prefix):
        """
        Saves two files: file_prefix.vocab and file_prefix.model
        This is inspired by (but not equivalent to!) sentencepiece's model saving:
        - model file is the critical one, intended for load() later
        - vocab file is just a pretty printed version for human inspection only
        """
        # write the model: to be used in load() later
        model_file = file_prefix + ".model"
        with open(model_file, 'w', encoding='utf-8') as f:
            # write the version, pattern and merges, that's all that's needed
            f.write("mana v1\n")
            f.write(f"{self.pattern}\n")
            # write the special tokens, first the number of them, then each one
            f.write(f"{len(self.special_tokens)}\n")
            for special, idx in self.special_tokens.items():
                f.write(f"{special} {idx}\n")
            # the merges dict
            for key in self.merges:
                if isinstance(key, tuple):
                    f.write(f"{key[0]} {key[1]}\n")
                else:
                    f.write(f"{key}\n")

        # write the vocab: for the human to look at
        vocab_file = file_prefix + ".vocab"
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in self.vocab.items():
                s = helper.render_token(token)
                # find the children of this token, if any
                if idx in inverted_merges:
                    idx0, idx1 = inverted_merges[idx]
                    s0 = helper.render_token(self.vocab[idx0])
                    s1 = helper.render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    f.write(f"[{s}] {idx}\n")

    def load(self, model_file):
        """Inverse of save() but only for the model file"""
        assert model_file.endswith(".model")
        # read the model file
        merges = {}
        special_tokens = {}
        idx = 256
        with open(model_file, 'r', encoding="utf-8") as f:
            # read the version
            version = f.readline().strip()
            assert version == "mana v1"
            # read the pattern
            self.pattern = f.readline().strip()
            # read the special tokens
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                special, special_idx = f.readline().strip().split()
                special_tokens[special] = int(special_idx)
            # read the merges
            for line in f:
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        self.merges = merges
        self.special_tokens = special_tokens
        self.vocab = self._build_vocab()

    def decode(self, ids):
        # given ids (list of integers), return Python string
        part_bytes = [self.vocab[idx] if idx in self.vocab
                      else self.inverse_special_tokens[idx].encode("utf-8")
                      for idx in ids]  # raises KeyError if any idx is not a valid token
        text_bytes = b"".join(part_bytes)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    @lru_cache(maxsize=131072)
    def _encode_chunk(self, chunk):
        if chunk in self.stop_words:  # TODO: revisit this if statement
            return [self.stop_words[chunk]]
        # return the token chunk as a list of ints, similar to a bytes object
        chunk = [*chunk.encode("utf-8")]
        len_chunk = len(chunk)
        while len_chunk >= 2:
            # find the pair with the lowest merge index
            low = 987654321
            for i in range(len_chunk - 1):
                current_pair = (chunk[i], chunk[i+1])
                new_val = self.merges.get(current_pair, 987654321)
                if new_val < low:
                    pair = current_pair
                    low = new_val
            if low == 987654321:  # no merges were found
                break  # nothing else can be merged
            # otherwise let's merge the best pair (lowest merge index)
            idx = self.merges[pair]
            len_chunk = helper.merge(chunk, pair, idx, len_chunk)
        return chunk  # list of ints

    def encode_ordinary(self, text):
        """Encoding that ignores any special tokens."""
        ids = []
        for chunk in re.findall(self.compiled_pattern, text):
            ids.extend(self._encode_chunk(chunk))
        return ids

    def encode(self, text, allowed_special="none_raise"):
        """
        Unlike encode_ordinary, this function handles special tokens.
        allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens
        if none_raise, then an error is raised if any special token is encountered in text
        this is the default tiktoken behavior right now as well
        any other behavior is either annoying, or a major footgun
        """
        # decode the user desire w.r.t. handling of special tokens
        special = None
        if allowed_special == "all":
            special = self.special_tokens
        elif allowed_special == "none":
            special = {}
        elif allowed_special == "none_raise":
            special = {}
            assert all(token not in text for token in self.special_tokens)
        elif isinstance(allowed_special, set):
            special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
        else:
            raise ValueError(f"allowed_special={allowed_special} not understood")
        if not special:  # shortcut: if no special tokens, just use the ordinary encoding
            return self.encode_ordinary(text)
        # split on special tokens. Note that surrounding the pattern with ()
        # makes it into a capturing group, so the special tokens will be included
        special_pattern = f"({'|'.join([re.escape(k) for k in special])})"
        special_chunks = re.split(special_pattern, text)
        # now all the special characters are separated from the rest of the text
        # all chunks of text are encoded separately, then results are joined
        ids = []
        for part in special_chunks:
            special_token = special.get(part)
            if special_token is None:  # this is an ordinary sequence, encode it normally
                ids.extend(self.encode_ordinary(part))
            else:  # this is a special token, encode it separately as a special case
                ids.append(special_token)
        return ids

    def batch_encode(self, texts, allowed_special="none_raise"):
        """
        Encode a list of texts in batch mode.
        Each text will be encoded according to the handling of special tokens specified in allowed_special.

        Parameters:
            texts (list of str): List of texts to encode.
            allowed_special (str|set): Special token handling mode.

        Returns:
            list of list of int: A list where each element is the encoded form of a text in `texts`.
        """
        return [self.encode(text, allowed_special=allowed_special) for text in texts]
```
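As a usage note for the special-token handling described in the `encode` docstring above, a small sketch (the token ids follow the `MANA_SPECIAL_TOKENS` table in `mana_tokenizer.py`; the text is illustrative):

```python
tok = Tokenizer()
tok.register_special_tokens({'<|user|>': 100258, '<|end|>': 100257})
text = "<|user|>سلام<|end|>"

tok.encode(text, allowed_special="all")    # '<|user|>' and '<|end|>' map to 100258 and 100257
tok.encode(text, allowed_special="none")   # the special markers are encoded as plain text
tok.encode(text)                           # default "none_raise": raises AssertionError for this text
```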
helper.py
ADDED
@@ -0,0 +1,94 @@

```python
from collections import Counter, defaultdict
import unicodedata
import regex as re


def get_stats(ids):
    """
    Given `ids`, a list of 2-tuples of (iterable of ints, int count),
    returns a defaultdict with the counts of occurrences of all consecutive
    pairs of integers within each chunk, multiplied by the count associated
    with that chunk. Pairs are not counted between the last element of one
    chunk and the first element of the next chunk.

    Example:
        get_stats([([97, 98, 99], 2), ([98, 99, 100], 1), ([101, 101, 101], 1)])
        -> defaultdict(<class 'int'>, {(97, 98): 2, (98, 99): 3, (99, 100): 1, (101, 101): 2})
    """
    counts = defaultdict(int)
    for chunk, num in ids:
        last_index = len(chunk) - 1
        i = 0
        while i < last_index:
            j = i + 1
            counts[(chunk[i], chunk[j])] += num
            i = j
    return counts


def merge_batch_get_stats(ids, pairs):
    counts = defaultdict(int)
    for chunk, num in ids:
        last_index = len(chunk) - 1
        i = 0
        while i < last_index:
            j = i + 1
            token = pairs.get((chunk[i], chunk[j]))
            if token is not None:
                chunk[i] = token
                del chunk[j]
                last_index -= 1
                if i:
                    counts[(chunk[i-1], chunk[i])] += num
            i = j
        if i and i == last_index:
            counts[(chunk[-2], chunk[i])] += num
    return counts


def merge(ids, pair, idx, len_ids):
    """
    In the list of integers (ids), replace all consecutive occurrences
    of pair with the new integer token idx
    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    i = 0
    while i + 1 < len_ids:
        j = i + 1
        if ids[i] == pair[0] and ids[j] == pair[1]:
            ids[i] = idx
            del ids[j]
            len_ids -= 1
        i = j
    return len_ids


def replace_control_characters(s: str) -> str:
    # we don't want to print control characters
    # which distort the output (e.g. \n or much worse)
    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
    # http://www.unicode.org/reports/tr44/#GC_Values_Table
    chars = []
    for ch in s:
        if unicodedata.category(ch)[0] != "C":
            chars.append(ch)  # this character is ok
        else:
            chars.append(f"\\u{ord(ch):04x}")  # escape
    return "".join(chars)


def render_token(t: bytes) -> str:
    # pretty print a token, escaping control characters
    s = t.decode('utf-8', errors='replace')
    s = replace_control_characters(s)
    return s


def _process_dicts(batch, compiled_pattern):  # for a raw datasets.Dataset
    counter = Counter()
    for item in batch:
        counter.update(re.findall(compiled_pattern, item))
    return counter


def _process_string_scalar(batch, compiled_pattern):
    counter = Counter()
    for item in batch:
        counter.update(re.findall(compiled_pattern, item.as_py()))
    return counter
```
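To make the helper contract concrete, here is a small illustrative walk-through with made-up byte values; the shape of `ids` matches what `Tokenizer._id_dict_to_list` produces:

```python
# chunks as mutable lists of byte values, paired with their occurrence counts
ids = [([104, 105, 104, 105], 3), ([104, 105], 1)]

stats = get_stats(ids)             # {(104, 105): 7, (105, 104): 3}
best = max(stats, key=stats.get)   # (104, 105) is the most frequent pair
new_len = merge(ids[0][0], best, 256, len(ids[0][0]))
# ids[0][0] is now [256, 256] and new_len == 2
```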
mana_tokenizer.py
ADDED
@@ -0,0 +1,70 @@

```python
from .base import Tokenizer
from .helper import get_stats, merge_batch_get_stats
from heapq import nlargest
import time

MANA_SPECIAL_TOKENS = {
    '<|end|>': 100257,
    '<|user|>': 100258,
    '<|assistant|>': 100259,
    '<|system|>': 100260
}


class ManaTokenizer(Tokenizer):
    def __init__(self, pattern=None, multiprocess=True, store_dict=False, stop_list_size=0, freq_cutoff=1):
        """
        - pattern: optional string to override the default Mana split pattern
        - special_tokens: str -> int dictionary of special tokens
          example: {'<|endoftext|>': 100257}
        """
        super().__init__(pattern, multiprocess, store_dict, stop_list_size, freq_cutoff)
        # register after the base __init__ so the defaults are not reset by its empty special_tokens dict
        self.register_special_tokens(MANA_SPECIAL_TOKENS)

    def train(self, data, vocab_size, cap_divisor=2, max_batch_size=0, verbose=False):
        t0 = time.time()
        ids = self._import_data(data)  # [(bytes, int)] -> text chunks and their counts
        t1 = time.time()
        print(f'Time spent loading data: {t1-t0:.2f}')

        merges = self.merges  # {(int, int): int} -> token pair to new token
        vocab = self.vocab  # {int: bytes} -> token to its bytes representation
        batch_count = 0
        curr_vocab_size = len(vocab)
        num_merges = vocab_size - curr_vocab_size
        merges_remaining = num_merges
        if max_batch_size < 1:
            max_batch_size = num_merges
        stats = get_stats(ids)  # stats are later updated by merge_batch_get_stats
        start_time = time.time()
        while merges_remaining > 0:
            seen_first = set()  # tokens seen in the first position in pairs
            seen_last = set()  # tokens seen in the last position in pairs
            pairs_to_merge = {}
            num_pairs_to_search = min(merges_remaining // cap_divisor, len(vocab), max_batch_size) or 1
            top_pairs = nlargest(num_pairs_to_search, stats, key=stats.get)
            for first, last in top_pairs:  # pairs are (first, last) tuples
                if first in seen_last or last in seen_first:  # unsafe merge
                    seen_first.add(first)
                    seen_last.add(last)
                    continue  # skip this pair but keep looking for safe merges in top_pairs
                seen_first.add(first)
                seen_last.add(last)
                pairs_to_merge[(first, last)] = curr_vocab_size
                vocab[curr_vocab_size] = vocab[first] + vocab[last]
                curr_vocab_size += 1
            merges_remaining -= len(pairs_to_merge)
            merges.update(pairs_to_merge)  # save the merges
            batch_count += 1
            if merges_remaining:  # no need to merge last batch
                stats = merge_batch_get_stats(ids, pairs_to_merge)  # replace pairs_to_merge keys in ids with their values
            if verbose:
                t2 = time.time()
                time_taken = t2 - start_time
                avg_time_per_batch = time_taken / batch_count
                estimated_remaining_time = avg_time_per_batch * (num_merges - merges_remaining)
                estimated_end_time = time.strftime("%H:%M:%S", time.localtime(time.time() + estimated_remaining_time))
                print(f"Batch {batch_count} merged {len(pairs_to_merge)} pairs in {t2-t1:.2f} sec. "
                      f"Merges remaining: {merges_remaining}. Estimated end time: {estimated_end_time}")
                t1 = t2

        self.merges = merges  # used in encode()
        self.vocab = vocab  # used in decode()
```
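A minimal training sketch based on the `train` signature above; the corpus string, vocab size, and file prefix are illustrative only:

```python
from mana_tokenizer import ManaTokenizer

tok = ManaTokenizer()
corpus = "سلام دنیا. این یک متن نمونه برای آموزش است. " * 200   # any str, .txt path, URL, or datasets.Dataset also works
tok.train(corpus, vocab_size=300, verbose=True)   # learns merges until the vocab reaches 300 entries
tok.save("mana-demo")                             # persists the result for a later load()
```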