from collections import Counter
from functools import lru_cache
import requests
from datasets import IterableDataset, Dataset
from pyarrow import ChunkedArray
from joblib import Parallel, delayed, cpu_count
import time
import os
import regex as re
import csv
from mana_tokenizer.helper import _process_string_scalar, render_token, merge
class Tokenizer:
"""Base class for Tokenizers"""
def __init__(self, pattern=None, multiprocess=True, store_dict=False, stop_list_size=0, freq_cutoff=1):
# default: vocab size of 256 (all bytes), no merges, no patterns
MANA_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re|می|نمی|به|بی|در|باز|بر|فرا|هم|ور|وا|ف|ک|چ|ن|پ|ا|از|ای|ی|ها|ترین|تر|ات|ان|ت|ٔ|یی|ا)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
self.merges = {} # (int, int) -> int
self.pattern = "" # str
        self.special_tokens = {} # str -> int, e.g. {'<|endoftext|>': 100257}
        self.inverse_special_tokens = {} # int -> str, filled in by register_special_tokens(); read by decode()
self.vocab = self._build_vocab() # int -> bytes
self.pattern = MANA_SPLIT_PATTERN if pattern is None else pattern
self.compiled_pattern = re.compile(self.pattern)
self.multiprocess = multiprocess
if multiprocess:
self._cpus = cpu_count()
else:
self._cpus = 1
self.store_dict = store_dict
self.stop_list_size = stop_list_size
self.stop_words = {}
self.freq_cutoff = freq_cutoff
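    # Construction sketch (argument values are illustrative, not defaults):
    #   tok = Tokenizer(multiprocess=False, store_dict=True, stop_list_size=100, freq_cutoff=2)
    # The stop list and cutoff take effect when _import_data() prepares training
    # data: the 100 most frequent multi-character chunks become single "stop word"
    # tokens, chunks seen fewer than 2 times are dropped, and the chunk counts are
    # written to a timestamped csv for reuse.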
def _id_dict_to_list(self, ids):
if self.stop_list_size:
            # fetch twice as many candidates so that, after discarding
            # single-character chunks, we can still collect stop_list_size entries
top2X = ids.most_common(2*self.stop_list_size)
index = len(self.vocab)
stop_index = index + self.stop_list_size
stop_words = {}
for key, val in top2X:
if len(key) > 1: # and re.match(r'^ [A-Za-z\'’`]+$[A-Za-z]*', key):
stop_words[key] = index
self.vocab[index] = key.encode('utf-8')
index += 1
if index == stop_index:
break
self.stop_words = stop_words
if self.freq_cutoff > 1:
return [([*key.encode('utf-8')], val) for key, val in ids.items()
if (val >= self.freq_cutoff and key not in self.stop_words)]
else:
return [([*key.encode('utf-8')], val) for key, val in ids.items()
if key not in self.stop_words]
else: # self.stop_list_size == 0
if self.freq_cutoff > 1:
return [([*key.encode('utf-8')], val) for key, val in ids.items()
if val >= self.freq_cutoff]
else:
return [([*key.encode('utf-8')], val) for key, val in ids.items()]
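    # Illustrative return shape, assuming stop_list_size=0 and freq_cutoff=1:
    #   _id_dict_to_list(Counter({' the': 5, ' a': 3}))
    #   -> [([32, 116, 104, 101], 5), ([32, 97], 3)]
    # i.e. each chunk becomes the list of its UTF-8 byte values paired with its count.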
def _import_data(self, data):
        # Determine whether `data` is a text string, a path to a file, a URL to a
        # text document, a chunk->count dictionary (or a .csv of one written by a
        # previous run), a datasets Dataset/IterableDataset or pyarrow ChunkedArray,
        # or a list of any of the above. Returns a list of 2-tuples pairing each
        # chunk's UTF-8 byte values with its count.
ids = Counter()
if not isinstance(data, (list, tuple)):
data = (data,)
for item in data:
# convert to ChunkedArray, dict, or str of text to parse
if isinstance(item, Dataset):
item = item.data['text']
elif isinstance(item, str) and item.endswith('.csv'): # csv file from previous data load
with open(item, 'r') as f:
reader = csv.reader(f)
next(reader)
item = {k: int(v) for k, v in reader}
elif isinstance(item, str):
if item.startswith('https://') or item.startswith('http://'):
item = requests.get(item).text # if it's a url, assume it's to a text file
elif os.path.isfile(item) and item.endswith('.txt'):
with open(item, 'r', encoding='utf-8') as f:
item = f.read()
# process data
if isinstance(item, dict):
last_item = item.popitem()
if last_item[1] != 0:
print(f'Warning: the csv file or dictionary passed does not seem to have been made by this tokenizer.')
item[last_item[0]] = last_item[1]
elif last_item[0] != self.pattern:
print(f'Warning: the dictionary or csv file passed did not use the same split pattern.')
ids.update(item)
elif isinstance(item, str): # assume the string is the text itself
ids.update(re.findall(self.compiled_pattern, item))
elif isinstance(item, ChunkedArray):
batch_size = len(item) // (self._cpus*2) or 1
batches = [item[i:i + batch_size] for i in range(0, len(item), batch_size)]
print(f'Processing {len(batches)} batches of size {batch_size}')
results = Parallel(n_jobs=self._cpus)(delayed(_process_string_scalar)(batch, self.compiled_pattern) for batch in batches)
for result in results: # Aggregate results into one Counter
ids.update(result)
elif isinstance(item, IterableDataset):
print('Serially processing IterableDataset...')
for _dict in item:
ids.update(re.findall(self.compiled_pattern, _dict['text']))
if self.store_dict: # store dict compression of dataset to a csv file if requested
ids[self.pattern] = 0 # store the pattern used to split the text as the last key
formatted_time = time.strftime('%Y-%m-%d-%H_%M', time.localtime())
filename = f'{formatted_time}-dataset-dict.csv'
try:
with open(filename, 'w', newline='') as f:
writer = csv.writer(f)
writer.writerow(['text_chunk', 'count'])
for key, value in ids.items():
writer.writerow([key, value])
print(f"Stored dictionary of {len(ids)} keys to {filename}")
            except Exception as e:
                print(f'Failed to store dictionary of dataset: {e}')
del ids[self.pattern] # remove the pattern key from the ids dict
ids = self._id_dict_to_list(ids)
return ids
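    # Hedged usage sketch for _import_data (file names and URL are hypothetical):
    #   ids = self._import_data(['corpus.txt', 'https://example.com/corpus.txt',
    #                            '2024-01-01-00_00-dataset-dict.csv'])
    # Plain strings that are neither paths nor URLs are treated as the text itself;
    # everything is reduced to the (byte values, count) list built by _id_dict_to_list().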
def train(self, text, vocab_size, verbose=False):
# Tokenizer can train a vocabulary of size vocab_size from text
raise NotImplementedError
def _build_vocab(self):
# vocab is simply and deterministically derived from merges
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in self.merges.items():
vocab[idx] = vocab[p0] + vocab[p1]
for special, idx in self.special_tokens.items():
vocab[idx] = special.encode("utf-8")
return vocab
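    # Sketch: with merges == {(104, 105): 256} (hypothetical), _build_vocab() maps
    # ids 0..255 to their single bytes and id 256 to b'h' + b'i' == b'hi'.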
def register_special_tokens(self, special_tokens):
# special_tokens is a dictionary of str -> int
# example: {"<|endoftext|>": 100257}
self.special_tokens = special_tokens
self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}
def save(self, file_prefix):
"""
Saves two files: file_prefix.vocab and file_prefix.model
        This is inspired by (but not equivalent to) sentencepiece's model saving:
- model file is the critical one, intended for load() later
- vocab file is just a pretty printed version for human inspection only
"""
# write the model: to be used in load() later
model_file = file_prefix + ".model"
with open(model_file, 'w', encoding='utf-8') as f: # Added encoding='utf-8'
# write the version, pattern and merges, that's all that's needed
f.write("mana v1\n")
f.write(f"{self.pattern}\n")
# write the special tokens, first the number of them, then each one
f.write(f"{len(self.special_tokens)}\n")
for special, idx in self.special_tokens.items():
f.write(f"{special} {idx}\n")
# the merges dict
for key in self.merges:
if isinstance(key, tuple):
f.write(f"{key[0]} {key[1]}\n")
else:
f.write(f"{key}\n")
# write the vocab: for the human to look at
vocab_file = file_prefix + ".vocab"
inverted_merges = {idx: pair for pair, idx in self.merges.items()}
with open(vocab_file, "w", encoding="utf-8") as f: # Ensure this is also utf-8
for idx, token in self.vocab.items():
s = render_token(token)
# find the children of this token, if any
if idx in inverted_merges:
idx0, idx1 = inverted_merges[idx]
s0 = render_token(self.vocab[idx0])
s1 = render_token(self.vocab[idx1])
f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
else:
f.write(f"[{s}] {idx}\n")
def load(self, model_file):
"""Inverse of save() but only for the model file"""
assert model_file.endswith(".model")
# read the model file
merges = {}
special_tokens = {}
idx = 256
with open(model_file, 'r', encoding="utf-8") as f:
# read the version
version = f.readline().strip()
assert version == "mana v1"
# read the pattern
self.pattern = f.readline().strip()
# read the special tokens
num_special = int(f.readline().strip())
for _ in range(num_special):
special, special_idx = f.readline().strip().split()
special_tokens[special] = int(special_idx)
# read the merges
for line in f:
idx1, idx2 = map(int, line.split())
merges[(idx1, idx2)] = idx
idx += 1
self.merges = merges
self.special_tokens = special_tokens
self.vocab = self._build_vocab()
def decode(self, ids):
# given ids (list of integers), return Python string
part_bytes = [self.vocab[idx] if idx in self.vocab
else self.inverse_special_tokens[idx].encode("utf-8")
for idx in ids] # raises KeyError if any idx is not a valid token
text_bytes = b"".join(part_bytes)
text = text_bytes.decode("utf-8", errors="replace")
return text
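    # Round-trip sketch: with no merges trained, every id is a raw byte, so
    #   tok.decode(tok.encode_ordinary("hi")) == "hi"
    # holds because b'h' + b'i' decodes back to the original string.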
@lru_cache(maxsize=131072)
def _encode_chunk(self, chunk):
if chunk in self.stop_words: # TODO: revisit this if statement
return [self.stop_words[chunk]]
# return the token chunk as a list of ints, similar to a bytes object
chunk = [*chunk.encode("utf-8")]
len_chunk = len(chunk)
while len_chunk >= 2:
# find the pair with the lowest merge index
low = 987654321
for i in range(len_chunk - 1):
current_pair = (chunk[i], chunk[i+1])
new_val = self.merges.get(current_pair, 987654321)
if new_val < low:
pair = current_pair
low = new_val
if low == 987654321: # no merges were found
break # nothing else can be merged
# otherwise let's merge the best pair (lowest merge index)
idx = self.merges[pair]
len_chunk = merge(chunk, pair, idx, len_chunk)
return chunk # list of ints
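    # Worked example with a hypothetical merge table {(104, 105): 256}:
    #   _encode_chunk("hi") starts from the bytes [104, 105], finds (104, 105) as
    #   the pair with the lowest merge index, and (assuming the imported merge()
    #   helper collapses the pair in place) returns [256].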
def encode_ordinary(self, text):
"""Encoding that ignores any special tokens."""
ids = []
for chunk in re.findall(self.compiled_pattern, text):
ids.extend(self._encode_chunk(chunk))
return ids
def encode(self, text, allowed_special="none_raise"):
"""
Unlike encode_ordinary, this function handles special tokens.
allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens
if none_raise, then an error is raised if any special token is encountered in text
this is the default tiktoken behavior right now as well
any other behavior is either annoying, or a major footgun
"""
# decode the user desire w.r.t. handling of special tokens
special = None
if allowed_special == "all":
special = self.special_tokens
elif allowed_special == "none":
special = {}
elif allowed_special == "none_raise":
special = {}
assert all(token not in text for token in self.special_tokens)
elif isinstance(allowed_special, set):
special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
else:
raise ValueError(f"allowed_special={allowed_special} not understood")
if not special: # shortcut: if no special tokens, just use the ordinary encoding
return self.encode_ordinary(text)
# split on special tokens. Note that surrounding the pattern with ()
# makes it into a capturing group, so the special tokens will be included
special_pattern = f"({'|'.join([re.escape(k) for k in special])})"
special_chunks = re.split(special_pattern, text)
# now all the special characters are separated from the rest of the text
# all chunks of text are encoded separately, then results are joined
ids = []
for part in special_chunks:
special_token = special.get(part)
if special_token is None: # this is an ordinary sequence, encode it normally
ids.extend(self.encode_ordinary(part))
else: # this is a special token, encode it separately as a special case
ids.append(special_token)
return ids
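    # Hedged example (the id 300 is chosen purely for illustration): after
    #   tok.register_special_tokens({"<|endoftext|>": 300})
    # a call such as tok.encode("hi<|endoftext|>", allowed_special="all") returns
    # the ordinary ids for "hi" followed by 300.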
def batch_encode(self, texts, allowed_special="none_raise"):
"""
Encode a list of texts in batch mode.
Each text will be encoded according to the handling of special tokens specified in allowed_special.
Parameters:
texts (list of str): List of texts to encode.
allowed_special (str|set): Special token handling mode.
Returns:
list of list of int: A list where each element is the encoded form of a text in `texts`.
"""
return [self.encode(text, allowed_special=allowed_special) for text in texts]
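

# Minimal usage sketch (an illustrative addition, not part of the library API).
# With no merges trained, encode() falls back to raw UTF-8 byte ids; the
# package-relative import at the top of this file means it is normally executed
# as part of the mana_tokenizer package rather than as a standalone script.
if __name__ == "__main__":
    tok = Tokenizer(multiprocess=False)
    ids = tok.encode("hello world", allowed_special="none")
    assert tok.decode(ids) == "hello world"
    print(ids)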