import os
import json
from math import log


class MorPiece:
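    """
    MorPiece: a trie-based subword tokenizer.

    Training builds two character tries over the corpus, one over the words
    themselves (roots) and one over the reversed words (inflections). Words
    are segmented at positions where the Tolerance Principle holds between a
    node and its child and where the node's branching factor does not exceed
    `bf`; the resulting pieces are stored in an optimized trie (`self.roots`),
    with non-initial pieces kept under the '++' key and special tokens under
    '[RSX]'.
    """
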
    def __init__(self, vocab_size=30000, min_frequency=2, cutoff=8, bf=10, special_tokens=None):
        self.tokenization_to_print = "TP left-right \t BF left-right \t TP right-left \t BF right-left\n"  # for debugging only
        if special_tokens is None:
            special_tokens = ['<unk>', '<pad>', '<s>', '</s>']
        self.special_tokens = special_tokens
        self.reserved_keys = {'[RSX]', '##', 'IDX', '++'}
        self.vocab_size = vocab_size
        self.min_frequency = min_frequency
        self.bf = bf
        self.roots = {'[RSX]': {}, '++': {}}
        self.roots_unoptimized = {}
        self.infls = {}
        self.types = {}
        self.last_item_in_trie = {}
        self.idx = 0
        self.tokens = []
        self.suffixes = []
        self.tokens_bf = []
        self.suffixes_bf = []
        self.prefix = ""
        self.n_prefix = 0
        self.n_suffix = 0
        self.tokenized_words = []
        self.tokenized_word_longest = ""
        self.tokenized_word_idx_longest = ""
        self.cutoff = cutoff  # ln(8) > 2, so non-branching paths will be ignored
        self.num_tokens_in_corpus = 0
        self.num_chars_in_corpus = 0
        self.num_chars_in_trie = 0
        self.num_chars_in_optimized_trie = 0
        self.vocab_to_id = {}  # token string -> id lookup, filled by build_vocab_lookup()
        self.set_special_tokens(self.special_tokens)

    def train(self, corpus: str):  # create the vocabulary
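        """
        Build the vocabulary from a whitespace-separated corpus string.

        Each word is reduced to alphabetic characters and apostrophes (falling
        back to the raw token), inserted into the forward (`roots_unoptimized`)
        and reversed (`infls`) tries, and counted in `self.types`; both tries
        are then frequency-sorted and the optimized vocabulary trie is built.
        """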
        words = corpus.split()
        print("MorPiece tokenizer training: processing words...")
        for word in words:
            # keep only alphabetic characters and apostrophes; fall back to the raw token if nothing remains
            word_alpha = ''.join(char for char in word if char.isalpha() or char == "'")
            word = word_alpha or word
            if word:
                self.build_trie(word, self.roots_unoptimized)  # create roots trie
                self.build_trie(word[::-1], self.infls)  # create inflections trie
                if word not in self.types:  # count tokens and chars in corpus
                    self.types[word] = 1
                else:
                    self.types[word] += 1
                self.num_tokens_in_corpus += 1
                self.num_chars_in_corpus += len(word)
        self.types = dict(sorted(self.types.items(), key=lambda item: item[1], reverse=True))
        sort_trie_by_freq(self.roots_unoptimized)
        sort_trie_by_freq(self.infls)

        print("MorPiece tokenizer training: trie optimization...")
        self.optimize(self.types)

        print(f"Built final vocabulary with {self.get_vocab_size()} tokens")
        print(f"Most common tokens: {list(self.types.items())[:20]}")

    def build_trie(self, wordpiece, root):  # build the trie and register # of traversals in '##'
        if wordpiece[0] in root:
            root[wordpiece[0]]['##'] += 1
            self.num_chars_in_trie += 1
            if len(wordpiece) > 1:
                self.build_trie(wordpiece[1:], root[wordpiece[0]])
            else:
                if 'END' not in root[wordpiece[0]]:
                    root[wordpiece[0]]['END'] = None
        else:
            root[wordpiece[0]] = {}
            root[wordpiece[0]]['##'] = 1
            if len(wordpiece) > 1:
                self.build_trie(wordpiece[1:], root[wordpiece[0]])

    def set_special_tokens(self, tokens):
        for item in tokens:
            if item not in self.roots['[RSX]']:
                self.roots['[RSX]'][item] = {'IDX': self.idx}
                self.idx += 1

    # assign idx based on word freq and add potential inflection links ('++') in the roots trie; frequencies are not kept in the optimized trie
    def optimize(self, words):
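        """
        Re-segment every sufficiently frequent word with `split_prefix` /
        `split_suffix`, insert the resulting pieces into the optimized
        `self.roots` trie (non-initial pieces under '++'), assign ids in
        frequency order, and finally build the token -> id lookup table.
        """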
        for word, freq in words.items():
            if freq >= self.min_frequency and self.idx <= self.vocab_size:
                self.tokens = []
                self.suffixes = []
                self.tokens_bf = []
                self.suffixes_bf = []
                self.tokens.append(word[0])
                self.suffixes.append(word[-1])
                self.split_prefix(word, self.roots_unoptimized)
                if len(self.tokens) > 1:
                    self.split_suffix(word[::-1], self.infls)
                    self.suffixes = [s[::-1] for s in self.suffixes][::-1]  # un-reverse each suffix and restore left-to-right order
                    self.tokenization_to_print += str(self.tokens) + '\t' + str(self.tokens_bf) + '\t' + str(
                        self.suffixes) + '\t' + str(self.suffixes_bf) + '\n'  # for debugging only
                    # experiment: one could use only self.suffixes or only self.tokens (prefixes) here
                    for i in range(len(self.tokens)):
                        if i == 0:
                            self.last_item_in_trie = self.roots
                            self.add_items_to_trie(self.tokens[0])
                        else:
                            self.last_item_in_trie = self.roots['++']
                            self.add_items_to_trie(self.tokens[i])
                        if 'IDX' not in self.last_item_in_trie:
                            self.last_item_in_trie['IDX'] = self.idx
                            self.idx += 1
                else:
                    self.last_item_in_trie = self.roots
                    self.add_items_to_trie(word)
                    if 'IDX' not in self.last_item_in_trie:
                        self.last_item_in_trie['IDX'] = self.idx
                        self.idx += 1

        self.build_vocab_lookup()

    def build_vocab_lookup(self):
        self.vocab_to_id = {}

        def traverse(trie, path):
            for k, v in trie.items():
                if k == 'IDX':
                    token = ''.join(path)
                    self.vocab_to_id[token] = v
                elif isinstance(v, dict):
                    traverse(v, path + [k])

        traverse(self.roots, [])

    def encode(self, sentence: str):
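        """
        Tokenize a whitespace-separated sentence into a list of token ids.

        Each word is looked up in the special-token table first and otherwise
        matched greedily against the optimized trie; the id of the longest
        match is kept (one id per word), and '<unk>' is used when no match is
        found. Piece-level segmentations accumulate in `self.tokenized_words`.
        """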
        self.tokenized_words = []
        words = sentence.strip().split()
        token_ids = []
        for word in words:
            if word in self.roots['[RSX]']:
                token_ids.append(self.roots['[RSX]'][word]['IDX'])
            else:
                self.tokenized_word_longest = ""
                self.tokenized_word_idx_longest = None
                self.retrieve(word, self.roots)
                if self.tokenized_word_idx_longest is not None:
                    token_ids.append(self.tokenized_word_idx_longest)
                else:
                    token_ids.append(self.roots['[RSX]']['<unk>']['IDX'])
        return token_ids

    def decode(self, sentence_idxs):
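        """
        Map a sequence of token ids back to token strings by searching the
        roots trie for each id's key path; the '[RSX]' prefix is stripped from
        special tokens.
        """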
        tokens = []
        for idx in sentence_idxs:
            keys_path = find_idx_path(self.roots, idx)
            if keys_path:
                token = "".join(keys_path)
                if token.startswith('[RSX]'):
                    token = token[5:]
                tokens.append(token)
        return tokens

    def retrieve(self, word, trie):
        self.longest_match_in_trie(word, trie)
        if self.tokenized_word_longest:
            self.tokenized_words.append([self.tokenized_word_longest, self.tokenized_word_idx_longest])
        else:
            self.tokenized_words.append(['<unk>', self.roots['[RSX]']['<unk>']['IDX']])

    def longest_match_in_trie(self, string, trie):
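        """
        Greedy longest-match lookup of `string` in `trie`: the running match
        and its id are kept in `self.tokenized_word_longest` and
        `self.tokenized_word_idx_longest`. When the match breaks, lookup
        continues from the '++' subtree if possible; otherwise the word is
        recorded as '<unk>'.
        """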
        if string[0] in trie:
            self.tokenized_word_longest += string[0]
            if 'IDX' in trie[string[0]]:
                self.tokenized_word_idx_longest = trie[string[0]]['IDX']
            if len(string) > 1:
                self.longest_match_in_trie(string[1:], trie[string[0]])
        else:
            # print(string[0], self.tokenized_word_longest)
            if string[0] in self.roots['++'] and self.tokenized_word_idx_longest:
                self.tokenized_words.append([self.tokenized_word_longest + '++', self.tokenized_word_idx_longest])
                self.tokenized_word_longest = '++'
                self.tokenized_word_idx_longest = 0
                self.longest_match_in_trie(string, self.roots['++'])
            else:
                self.tokenized_words.append(['<unk>', self.roots['[RSX]']['<unk>']['IDX']])
                self.tokenized_word_longest = None

    def split_prefix(self, word, trie):
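        """
        Left-to-right segmentation of `word` against the unoptimized roots trie.

        For each adjacent character pair, a new piece is started in
        `self.tokens` when the Tolerance Principle holds between the mother
        and daughter node counts and the mother's branching factor is at most
        `self.bf`; otherwise the character is merged into the current piece.
        """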
        l = len(word)
        if l > 1:
            self.get_pair_in_trie(word[0], word[1], trie)
            if self.check_tp(self.n_prefix, self.n_suffix) and self.get_bf(trie[word[0]]) <= self.bf:
                self.tokens.append(word[1])
                self.tokens_bf.append(word[0] + str(self.get_bf(trie[word[0]])))
            else:
                self.tokens[-1] += word[1]  # no split: merge the character into the current piece
        if l > 2:
            self.split_prefix(word[1:], trie[word[0]])

    def split_suffix(self, word, trie):
        l = len(word)
        if l > 1:
            self.get_pair_in_trie(word[0], word[1], trie)
            if self.check_tp(self.n_prefix, self.n_suffix) and self.get_bf(trie[word[0]]) <= self.bf:  # verify the TP and branching-factor conditions before splitting
                self.suffixes.append(word[1])
                self.suffixes_bf.append(word[0] + str(self.get_bf(trie[word[0]])))
            else:
                self.suffixes[-1] += word[1]  # no split: merge the character into the current piece
        if l > 2:
            if word[0] in trie.keys():
                self.split_suffix(word[1:], trie[word[0]])

    def get_pair_in_trie(self, prefix, suffix, trie):
        self.n_prefix = 0
        self.n_suffix = 0
        if prefix in trie:
            if suffix in trie[prefix]:
                self.n_prefix = trie[prefix]["##"]
                self.n_suffix = trie[prefix][suffix]["##"]

    def check_tp(self, m, d):  # verify if the Tolerance Principle applies between m(other) and d(aughter) nodes
        if m <= 1:
            return False
        tp = m / log(m)  # Tolerance Principle threshold: m / ln(m)
        # split only when the mother is frequent enough (>= cutoff), the counts differ,
        # and the daughter count exceeds the threshold
        return self.cutoff <= m and m != d and d > tp

    def get_bf(self, m):  # return the branching factor of the mother node, ignoring special-token keys
        return sum(1 for k in m if k not in self.special_tokens)

    def add_items_to_trie(self, items):
        for item in items:
            self.add_item_to_trie(item)

    def add_item_to_trie(self, item):
        if item not in self.last_item_in_trie:
            self.last_item_in_trie[item] = {}
        self.last_item_in_trie = self.last_item_in_trie[item]

    @staticmethod
    def pad_sentence(sentence, l):
        """
        Pads the given sentence with "[pad]" tokens at the beginning to reach the desired length.

        Parameters:
        - sentence (str): The original sentence to be padded.
        - l (int): The desired total number of tokens in the sentence after padding.

        Returns:
        - str: The padded sentence.
        """
        words = sentence.split()
        n_pad = max(l - len(words), 0)  # Ensure n_pad is not negative
        pad_tokens = ["[pad]"] * n_pad
        padded_sentence = ' '.join(pad_tokens + words)
        return padded_sentence

    def get_num_chars_in_trie(self):
        return self.num_chars_in_trie

    def get_num_chars_in_corpus(self):
        return self.num_chars_in_corpus

    def get_vocab_size(self) -> int:
        return self.idx

    def get_vocab(self):
        return self.vocab_to_id.copy()

    def get_num_tokens_in_corpus(self):
        return self.num_tokens_in_corpus

    def get_num_types_in_corpus(self):
        return len(self.types)

    def get_compression_ratio(self):
        return round(self.num_chars_in_trie / self.num_chars_in_corpus, 3)

    def get_ttr(self):
        return round(len(self.types) / self.num_tokens_in_corpus, 3)

    def save(self, save_file):
        self.build_vocab_lookup()
        with open(save_file, 'w') as f:
            json.dump({
                'roots': self.roots,
                'vocab': self.vocab_to_id
            }, f, indent=2)

    def from_pretrained(self, load_file):
        with open(os.path.join(load_file, 'tokenizer.json'), 'r') as f:
            data = json.load(f)

        # New format: a dict holding 'roots' and (optionally) 'vocab'
        if isinstance(data, dict) and 'roots' in data:
            self.roots = data['roots']
            self.vocab_to_id = data.get('vocab', {})  # fall back to an empty dict if missing
        else:
            # Backward compatibility: the old format stored only the roots trie
            self.roots = data
            self.vocab_to_id = {}

        # Ensure [RSX] exists
        if '[RSX]' not in self.roots:
            raise ValueError("Invalid tokenizer format: Missing [RSX] root node.")

    def save_types(self, file):
        with open(file, 'w') as f:
            json.dump(self.types, f, indent=2)


def sort_trie_by_freq(d):
    if not isinstance(d, dict):
        return d
    # Sort the dictionary items by the value of the nested key '##'
    sorted_items = sorted(
        d.items(),
        key=lambda item: item[1].get('##', float('-inf')) if isinstance(item[1], dict) else float('-inf'),
        reverse=True
    )
    # Clear the dictionary and update with sorted items
    d.clear()
    for k, v in sorted_items:
        d[k] = sort_trie_by_freq(v)
    return d


def find_idx_path(d, target_value, path=None):
    if path is None:
        path = []
    for key, value in d.items():
        if key == 'IDX' and value == target_value:
            return path
        elif isinstance(value, dict):
            result = find_idx_path(value, target_value, path + [key])
            if result is not None:
                return result
    return None
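

# Minimal usage sketch: the toy corpus and the output path below are
# illustrative placeholders, not fixed choices of the tokenizer.
if __name__ == '__main__':
    tokenizer = MorPiece(vocab_size=30000, min_frequency=2)
    toy_corpus = "the cats walked and the cat walks while the dog was walking"  # placeholder corpus
    tokenizer.train(toy_corpus)
    ids = tokenizer.encode("the cat walks")
    print("encoded:", ids)
    print("decoded:", tokenizer.decode(ids))
    print("vocab size:", tokenizer.get_vocab_size())
    # tokenizer.save("tokenizer.json")  # hypothetical output path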