from collections import Counter
from functools import lru_cache
import csv
import os
import time

import regex as re
import requests
from datasets import IterableDataset, Dataset
from joblib import Parallel, delayed, cpu_count
from pyarrow import ChunkedArray

from mana_tokenizer.helper import _process_string_scalar, render_token, merge

class Tokenizer:
    """Base class for Tokenizers"""
    def __init__(self, pattern=None, multiprocess=True, store_dict=False, stop_list_size=0, freq_cutoff=1):
        # default: vocab size of 256 (all bytes), no merges, no patterns
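        # split regex in the GPT-4 style, extended with common Persian prefixes,
        # suffixes and clitics so that they are pre-segmented before BPE merging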
        MANA_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re|می|نمی|به|بی|در|باز|بر|فرا|هم|ور|وا|ف|ک|چ|ن|پ|ا|از|ای|ی|ها|ترین|تر|ات|ان|ت|ٔ|یی|‌ا)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
        self.merges = {} # (int, int) -> int
        self.pattern = "" # str
        self.special_tokens = {} # str -> int, e.g. {'<|endoftext|>': 100257}
        self.vocab = self._build_vocab() # int -> bytes
        self.pattern = MANA_SPLIT_PATTERN if pattern is None else pattern
        self.compiled_pattern = re.compile(self.pattern)
        self.multiprocess = multiprocess
        if multiprocess:
            self._cpus = cpu_count()
        else:
            self._cpus = 1
        self.store_dict = store_dict
        self.stop_list_size = stop_list_size
        self.stop_words = {}
        self.freq_cutoff = freq_cutoff

    def _id_dict_to_list(self, ids):
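        """Convert a Counter of text chunks into a list of ([utf-8 byte values], count)
        pairs. If stop_list_size is set, the most frequent multi-character chunks are
        promoted to whole-token stop words (added to the vocab and excluded from the
        returned list); chunks occurring fewer than freq_cutoff times are dropped."""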
        if self.stop_list_size:
            # get twice as many to be sure to be able to get X chunks of length > 1
            top2X = ids.most_common(2*self.stop_list_size)
            index = len(self.vocab)
            stop_index = index + self.stop_list_size
            stop_words = {}
            for key, val in top2X:
                if len(key) > 1: # and re.match(r'^ [A-Za-z\'’`]+$[A-Za-z]*', key):
                    stop_words[key] = index
                    self.vocab[index] = key.encode('utf-8')
                    index += 1
                if index == stop_index:
                    break
            self.stop_words = stop_words
            if self.freq_cutoff > 1:
                return [([*key.encode('utf-8')], val) for key, val in ids.items()
                        if (val >= self.freq_cutoff and key not in self.stop_words)]
            else:
                return [([*key.encode('utf-8')], val) for key, val in ids.items()
                        if key not in self.stop_words]
        else:   # self.stop_list_size == 0
            if self.freq_cutoff > 1:
                return [([*key.encode('utf-8')], val) for key, val in ids.items()
                        if val >= self.freq_cutoff]
            else:
                return [([*key.encode('utf-8')], val) for key, val in ids.items()]

    def _import_data(self, data):
        # determine if `data` is a text as a string, a path to a file, a url to
        # a text document, a dictionary of datasets kwargs, or a list of any of
        # the above. Return a list of 2-tuples of bytes objects and their counts.
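        # e.g. data may be raw text, 'corpus.txt', 'https://example.com/corpus.txt',
        # a '*-dataset-dict.csv' written by a previous run with store_dict=True,
        # a datasets Dataset / IterableDataset, or a list mixing any of the above
        # (the file name and URL here are purely illustrative)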
        ids = Counter()
        if not isinstance(data, (list, tuple)):
            data = (data,)
        for item in data:
            # convert to ChunkedArray, dict, or str of text to parse
            if isinstance(item, Dataset):
                item = item.data['text']
            elif isinstance(item, str) and item.endswith('.csv'):   # csv file from previous data load
                with open(item, 'r') as f:
                    reader = csv.reader(f)
                    next(reader)
                    item = {k: int(v) for k, v in reader}
            elif isinstance(item, str):
                if item.startswith('https://') or item.startswith('http://'):
                    item = requests.get(item).text    # if it's a url, assume it's to a text file
                elif os.path.isfile(item) and item.endswith('.txt'):
                    with open(item, 'r', encoding='utf-8') as f:
                        item = f.read()
            # process data
            if isinstance(item, dict):
                last_item = item.popitem()
                if last_item[1] != 0:
                    print('Warning: the csv file or dictionary passed does not seem to have been made by this tokenizer.')
                    item[last_item[0]] = last_item[1]
                elif last_item[0] != self.pattern:
                    print('Warning: the dictionary or csv file passed did not use the same split pattern.')
                ids.update(item)
            elif isinstance(item, str):   # assume the string is the text itself
                ids.update(re.findall(self.compiled_pattern, item))
            elif isinstance(item, ChunkedArray):
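                # split the Arrow column into roughly two batches per CPU and
                # count the regex chunks of each batch in parallel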
                batch_size = len(item) // (self._cpus*2) or 1
                batches = [item[i:i + batch_size] for i in range(0, len(item), batch_size)]
                print(f'Processing {len(batches)} batches of size {batch_size}')
                results = Parallel(n_jobs=self._cpus)(delayed(_process_string_scalar)(batch, self.compiled_pattern) for batch in batches)
                for result in results:  # Aggregate results into one Counter
                    ids.update(result)
            elif isinstance(item, IterableDataset):
                print('Serially processing IterableDataset...')
                for _dict in item:
                    ids.update(re.findall(self.compiled_pattern, _dict['text']))

        if self.store_dict:   # store dict compression of dataset to a csv file if requested
            ids[self.pattern] = 0   # store the pattern used to split the text as the last key
            formatted_time = time.strftime('%Y-%m-%d-%H_%M', time.localtime())
            filename = f'{formatted_time}-dataset-dict.csv'
            try:
                with open(filename, 'w', newline='') as f:
                    writer = csv.writer(f)
                    writer.writerow(['text_chunk', 'count'])
                    for key, value in ids.items():
                        writer.writerow([key, value])
                print(f"Stored dictionary of {len(ids)} keys to {filename}")
            except Exception as e:
                print(f'Failed to store dictionary of dataset: {e}')
            del ids[self.pattern]   # remove the pattern key from the ids dict

        ids = self._id_dict_to_list(ids)
        return ids

    def train(self, text, vocab_size, verbose=False):
        # Tokenizer can train a vocabulary of size vocab_size from text
        raise NotImplementedError

    def _build_vocab(self):
        # vocab is simply and deterministically derived from merges
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def register_special_tokens(self, special_tokens):
        # special_tokens is a dictionary of str -> int
        # example: {"<|endoftext|>": 100257}
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}

    def save(self, file_prefix):
        """
        Saves two files: file_prefix.vocab and file_prefix.model
        This is inspired (but not equivalent to!) sentencepiece's model saving:
        - model file is the critical one, intended for load() later
        - vocab file is just a pretty printed version for human inspection only
        """
        # write the model: to be used in load() later
        model_file = file_prefix + ".model"
        with open(model_file, 'w', encoding='utf-8') as f:
            # write the version, pattern and merges, that's all that's needed
            f.write("mana v1\n")
            f.write(f"{self.pattern}\n")
            # write the special tokens, first the number of them, then each one
            f.write(f"{len(self.special_tokens)}\n")
            for special, idx in self.special_tokens.items():
                f.write(f"{special} {idx}\n")
            # the merges dict
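            # tuple keys are written as "idx1 idx2"; the resulting token id is
            # implicit, since load() re-assigns merge ids sequentially from 256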
            for key in self.merges:
                if isinstance(key, tuple):
                    f.write(f"{key[0]} {key[1]}\n")
                else:
                    f.write(f"{key}\n")
        
        # write the vocab: for the human to look at
        vocab_file = file_prefix + ".vocab"
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in self.vocab.items():
                s = render_token(token)
                # find the children of this token, if any
                if idx in inverted_merges:
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    f.write(f"[{s}] {idx}\n")

    def load(self, model_file):
        """Inverse of save() but only for the model file"""
        assert model_file.endswith(".model")
        # read the model file
        merges = {}
        special_tokens = {}
        idx = 256
        with open(model_file, 'r', encoding="utf-8") as f:
            # read the version
            version = f.readline().strip()
            assert version == "mana v1"
            # read the pattern
            self.pattern = f.readline().strip()
            # read the special tokens
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                special, special_idx = f.readline().strip().split()
                special_tokens[special] = int(special_idx)
            # read the merges
            for line in f:
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        self.merges = merges
        self.register_special_tokens(special_tokens)
        self.vocab = self._build_vocab()

    def decode(self, ids):
        # given ids (list of integers), return Python string
        part_bytes = [self.vocab[idx] if idx in self.vocab
            else self.inverse_special_tokens[idx].encode("utf-8")
            for idx in ids] # raises KeyError if any idx is not a valid token
        text_bytes = b"".join(part_bytes)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    @lru_cache(maxsize=131072)
    def _encode_chunk(self, chunk):
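        # cached because the same pre-split chunk (word) recurs constantly in
        # natural text, so its merge result only has to be computed once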
        if chunk in self.stop_words:   # TODO: revisit this if statement
            return [self.stop_words[chunk]]
        # return the token chunk as a list of ints, similar to a bytes object
        chunk = [*chunk.encode("utf-8")]
        len_chunk = len(chunk)
        while len_chunk >= 2:
            # find the pair with the lowest merge index
            low = 987654321
            for i in range(len_chunk - 1):
                current_pair = (chunk[i], chunk[i+1])
                new_val = self.merges.get(current_pair, 987654321)
                if new_val < low:
                    pair = current_pair
                    low = new_val
            if low == 987654321:   # no merges were found
                break   # nothing else can be merged
            # otherwise let's merge the best pair (lowest merge index)
            idx = self.merges[pair]
            len_chunk = merge(chunk, pair, idx, len_chunk)
        return chunk   # list of ints

    def encode_ordinary(self, text):
        """Encoding that ignores any special tokens."""
        ids = []
        for chunk in re.findall(self.compiled_pattern, text):
            ids.extend(self._encode_chunk(chunk))
        return ids

    def encode(self, text, allowed_special="none_raise"):
        """
        Unlike encode_ordinary, this function handles special tokens.
        allowed_special can be "all" | "none" | "none_raise" or a custom set of
        special tokens. With "none_raise" (the default, and also tiktoken's
        current behavior) an error is raised if any special token is found in
        the text; any other behavior is either annoying or a major footgun.
        """
        # decode the user desire w.r.t. handling of special tokens
        special = None
        if allowed_special == "all":
            special = self.special_tokens
        elif allowed_special == "none":
            special = {}
        elif allowed_special == "none_raise":
            special = {}
            assert all(token not in text for token in self.special_tokens)
        elif isinstance(allowed_special, set):
            special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
        else:
            raise ValueError(f"allowed_special={allowed_special} not understood")
        if not special:   # shortcut: if no special tokens, just use the ordinary encoding
            return self.encode_ordinary(text)
        # split on special tokens. Note that surrounding the pattern with ()
        # makes it into a capturing group, so the special tokens will be included
        special_pattern = f"({'|'.join([re.escape(k) for k in special])})"
        special_chunks = re.split(special_pattern, text)
        # now all the special characters are separated from the rest of the text
        # all chunks of text are encoded separately, then results are joined
        ids = []
        for part in special_chunks:
            special_token = special.get(part)
            if special_token is None:   # this is an ordinary sequence, encode it normally
                ids.extend(self.encode_ordinary(part))
            else:   # this is a special token, encode it separately as a special case
                ids.append(special_token)
        return ids
        
    def batch_encode(self, texts, allowed_special="none_raise"):
        """
        Encode a list of texts in batch mode.
        Each text will be encoded according to the handling of special tokens specified in allowed_special.
        
        Parameters:
            texts (list of str): List of texts to encode.
            allowed_special (str|set): Special token handling mode.

        Returns:
            list of list of int: A list where each element is the encoded form of a text in `texts`.
        """
        return [self.encode(text, allowed_special=allowed_special) for text in texts]
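

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the base Tokenizer has no trained
    # merges, so encoding falls back to raw UTF-8 byte ids, which is still enough
    # to demonstrate the encode/decode round trip and special-token handling.
    tok = Tokenizer(multiprocess=False)
    tok.register_special_tokens({"<|endoftext|>": 256})  # arbitrary demo id

    ids = tok.encode("سلام دنیا <|endoftext|>", allowed_special="all")
    print(ids)              # byte-level ids followed by the special token id 256
    print(tok.decode(ids))  # round-trips back to the original text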