import os
import json
from math import log


class MorPiece:
    """Trie-based sub-word tokenizer.

    Words are split into root and inflection pieces by walking prefix and
    suffix (reversed-word) tries and applying transitional-probability (TP)
    and branching-factor (BF) cut-offs.
    """

    def __init__(self, vocab_size=30000, min_frequency=2, cutoff=8, bf=10, special_tokens=None):
        # Debug log of the four splits produced for each word during optimize().
        self.tokenization_to_print = "TP left-right \t BF left-right \t TP right-left \t BF right-left\n"
        if special_tokens is None:
            special_tokens = ['<unk>', '<pad>', '<s>', '</s>']
        self.special_tokens = special_tokens
        # Bookkeeping keys that never represent characters in the tries.
        self.reserved_keys = {'[RSX]', '##', 'IDX', '++'}
        self.vocab_size = vocab_size
        self.min_frequency = min_frequency
        self.bf = bf
        # Final (optimized) trie: '[RSX]' holds special tokens, '++' holds continuation pieces.
        self.roots = {'[RSX]': {}, '++': {}}
        # Raw prefix trie and reversed-word (inflection) trie built from the corpus.
        self.roots_unoptimized = {}
        self.infls = {}
        # Word type -> frequency.
        self.types = {}
        self.last_item_in_trie = {}
        self.idx = 0
        self.tokens = []
        self.suffixes = []
        self.tokens_bf = []
        self.suffixes_bf = []
        self.prefix = ""
        self.n_prefix = 0
        self.n_suffix = 0
        self.tokenized_words = []
        self.tokenized_word_longest = ""
        self.tokenized_word_idx_longest = ""
        self.cutoff = cutoff
        self.num_tokens_in_corpus = 0
        self.num_chars_in_corpus = 0
        self.num_chars_in_trie = 0
        self.num_chars_in_optimized_trie = 0
        self.set_special_tokens(self.special_tokens)
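
    # Sketch of the data layout: right after __init__ with the default special
    # tokens, self.roots looks like
    #   {'[RSX]': {'<unk>': {'IDX': 0}, '<pad>': {'IDX': 1},
    #              '<s>': {'IDX': 2}, '</s>': {'IDX': 3}},
    #    '++': {}}
    # Character nodes and continuation pieces are added later by optimize().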

    def train(self, corpus: str):
        words = corpus.split()
        print("MorPiece tokenizer training: processing words...")
        for word in words:
            # Keep only alphabetic characters and apostrophes; if nothing is
            # left (e.g. pure punctuation or digits), keep the token as-is.
            word_alpha = ''.join(char for char in word if char.isalpha() or char == "'")
            if word_alpha:
                word = word_alpha
            if word:
                # Prefix trie over the word and suffix trie over the reversed word.
                self.build_trie(word, self.roots_unoptimized)
                self.build_trie(word[::-1], self.infls)
                if word not in self.types:
                    self.types[word] = 1
                else:
                    self.types[word] += 1
                self.num_tokens_in_corpus += 1
                self.num_chars_in_corpus += len(word)
        # Most frequent types first, and tries sorted by child frequency.
        self.types = dict(sorted(self.types.items(), key=lambda item: item[1], reverse=True))
        sort_trie_by_freq(self.roots_unoptimized)
        sort_trie_by_freq(self.infls)

        print("MorPiece tokenizer training: trie optimization...")
        self.optimize(self.types)

        print(f"Built final vocabulary with {self.get_vocab_size()} tokens")
        print(f"Most common tokens: {list(self.types.items())[:20]}")

    def build_trie(self, wordpiece, root):
        if wordpiece[0] in root:
            # Character already present: bump its frequency counter.
            root[wordpiece[0]]['##'] += 1
            self.num_chars_in_trie += 1
            if len(wordpiece) > 1:
                self.build_trie(wordpiece[1:], root[wordpiece[0]])
            else:
                if 'END' not in root[wordpiece[0]]:
                    root[wordpiece[0]]['END'] = None
        else:
            # New character node with frequency 1.
            root[wordpiece[0]] = {}
            root[wordpiece[0]]['##'] = 1
            if len(wordpiece) > 1:
                self.build_trie(wordpiece[1:], root[wordpiece[0]])
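
    # Worked example (assumes the trie starts empty): inserting "cat" twice via
    # build_trie("cat", trie) yields
    #   {'c': {'##': 2, 'a': {'##': 2, 't': {'##': 2, 'END': None}}}}
    # '##' counts how often a node was visited; 'END' marks that a word ended
    # on an already-existing node.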

    def set_special_tokens(self, tokens):
        for item in tokens:
            if item not in self.roots['[RSX]']:
                # Assign the next free vocabulary id to each new special token.
                self.roots['[RSX]'][item] = {'IDX': self.idx}
                self.idx += 1

    def optimize(self, words):
        for word, freq in words.items():
            if freq >= self.min_frequency and self.idx <= self.vocab_size:
                self.tokens = []
                self.suffixes = []
                self.tokens_bf = []
                self.suffixes_bf = []
                self.tokens.append(word[0])
                self.suffixes.append(word[-1])
                # Split the word left-to-right on the prefix trie.
                self.split_prefix(word, self.roots_unoptimized)
                if len(self.tokens) > 1:
                    # Also split right-to-left on the reversed-word trie.
                    self.split_suffix(word[::-1], self.infls)
                    self.suffixes = [piece[::-1] for piece in self.suffixes][::-1]
                    self.tokenization_to_print += (str(self.tokens) + '\t' + str(self.tokens_bf) + '\t'
                                                   + str(self.suffixes) + '\t' + str(self.suffixes_bf) + '\n')
                    for i in range(len(self.tokens)):
                        if i == 0:
                            # The first piece goes into the main trie ...
                            self.last_item_in_trie = self.roots
                            self.add_items_to_trie(self.tokens[i])
                        else:
                            # ... continuation pieces go under the '++' node.
                            self.last_item_in_trie = self.roots['++']
                            self.add_items_to_trie(self.tokens[i])
                        if 'IDX' not in self.last_item_in_trie:
                            self.last_item_in_trie['IDX'] = self.idx
                            self.idx += 1
                else:
                    # Unsplit word: store it whole in the main trie.
                    self.last_item_in_trie = self.roots
                    self.add_items_to_trie(word)
                    if 'IDX' not in self.last_item_in_trie:
                        self.last_item_in_trie['IDX'] = self.idx
                        self.idx += 1

        self.build_vocab_lookup()

    def build_vocab_lookup(self):
        # Flatten the optimized trie into a token-string -> id mapping. Keys keep
        # their root prefix ('[RSX]' for special tokens, '++' for continuations).
        self.vocab_to_id = {}

        def traverse(trie, path):
            for k, v in trie.items():
                if k == 'IDX':
                    token = ''.join(path)
                    self.vocab_to_id[token] = v
                elif isinstance(v, dict):
                    traverse(v, path + [k])

        traverse(self.roots, [])

    def encode(self, sentence: str):
        self.tokenized_words = []
        words = sentence.strip().split()
        token_ids = []
        for word in words:
            if word in self.roots['[RSX]']:
                # Special tokens are looked up directly.
                token_ids.append(self.roots['[RSX]'][word]['IDX'])
            else:
                self.tokenized_word_longest = ""
                self.tokenized_word_idx_longest = None
                self.retrieve(word, self.roots)
                if self.tokenized_word_idx_longest is not None:
                    token_ids.append(self.tokenized_word_idx_longest)
                else:
                    token_ids.append(self.roots['[RSX]']['<unk>']['IDX'])
        return token_ids

    def decode(self, sentence_idxs):
        tokens = []
        for idx in sentence_idxs:
            keys_path = find_idx_path(self.roots, idx)
            if keys_path:
                token = "".join(keys_path)
                # Strip the '[RSX]' prefix that special tokens carry in the trie.
                if token.startswith('[RSX]'):
                    token = token[5:]
                tokens.append(token)
        return tokens

    def retrieve(self, word, trie):
        self.longest_match_in_trie(word, trie)
        if self.tokenized_word_longest:
            self.tokenized_words.append([self.tokenized_word_longest, self.tokenized_word_idx_longest])
        else:
            self.tokenized_words.append(['<unk>', self.roots['[RSX]']['<unk>']['IDX']])

    def longest_match_in_trie(self, string, trie):
        if string[0] in trie:
            self.tokenized_word_longest += string[0]
            if 'IDX' in trie[string[0]]:
                self.tokenized_word_idx_longest = trie[string[0]]['IDX']
            if len(string) > 1:
                self.longest_match_in_trie(string[1:], trie[string[0]])
        else:
            # No match at this node: try to continue under the '++' (continuation) trie.
            if string[0] in self.roots['++'] and self.tokenized_word_idx_longest:
                self.tokenized_words.append([self.tokenized_word_longest + '++', self.tokenized_word_idx_longest])
                self.tokenized_word_longest = '++'
                self.tokenized_word_idx_longest = 0
                self.longest_match_in_trie(string, self.roots['++'])
            else:
                self.tokenized_words.append(['<unk>', self.roots['[RSX]']['<unk>']['IDX']])
                self.tokenized_word_longest = None

    def split_prefix(self, word, trie):
        n = len(word)
        if n > 1:
            self.get_pair_in_trie(word[0], word[1], trie)
            if self.check_tp(self.n_prefix, self.n_suffix) and self.get_bf(trie[word[0]]) <= self.bf:
                # Boundary found: start a new piece with the next character.
                self.tokens.append(word[1])
                self.tokens_bf.append(word[0] + str(self.get_bf(trie[word[0]])))
            else:
                # No boundary: extend the current piece.
                self.tokens[-1] = self.tokens[-1] + word[1]
            if n > 2:
                # Recurse only when the character path exists in the trie.
                if word[0] in trie:
                    self.split_prefix(word[1:], trie[word[0]])

    def split_suffix(self, word, trie):
        n = len(word)
        if n > 1:
            self.get_pair_in_trie(word[0], word[1], trie)
            if self.check_tp(self.n_prefix, self.n_suffix) and self.get_bf(trie[word[0]]) <= self.bf:
                # Boundary found: start a new (reversed) suffix piece.
                self.suffixes.append(word[1])
                self.suffixes_bf.append(word[0] + str(self.get_bf(trie[word[0]])))
            else:
                # No boundary: extend the current piece.
                self.suffixes[-1] = self.suffixes[-1] + word[1]
            if n > 2:
                if word[0] in trie:
                    self.split_suffix(word[1:], trie[word[0]])

    def get_pair_in_trie(self, prefix, suffix, trie):
        # Frequency of the current character node and of its 'suffix' child.
        self.n_prefix = 0
        self.n_suffix = 0
        if prefix in trie:
            if suffix in trie[prefix]:
                self.n_prefix = trie[prefix]["##"]
                self.n_suffix = trie[prefix][suffix]["##"]

    def check_tp(self, m, d):
        # Transitional-probability boundary test: m is the frequency of the current
        # node, d the frequency of the following character node.
        if m <= 1:
            return False
        tp = m / log(m)
        # Boundary when the node is frequent enough (cutoff <= m), the successor
        # frequency differs from it, and the successor frequency exceeds the TP threshold.
        return self.cutoff <= m and m != d and d > tp
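
    # Worked example (illustrative values, default cutoff=8): for m=20 and d=10,
    # tp = 20 / log(20) ≈ 6.68, so check_tp(20, 10) is True because
    # 8 <= 20, 20 != 10 and 10 > 6.68; check_tp(20, 5) is False since 5 < 6.68.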

    def get_bf(self, m):
        # Branching factor of a trie node: number of child keys, excluding special
        # tokens. Note that bookkeeping keys such as '##' and 'END' are counted.
        return sum(1 for k in m if k not in self.special_tokens)

    def add_items_to_trie(self, items):
        # Add each character of `items` in sequence, descending the trie as we go.
        for item in items:
            self.add_item_to_trie(item)

    def add_item_to_trie(self, item):
        if item not in self.last_item_in_trie:
            self.last_item_in_trie[item] = {}
        self.last_item_in_trie = self.last_item_in_trie[item]

    @staticmethod
    def pad_sentence(sentence, length):
        """
        Pads the given sentence with "[pad]" tokens at the beginning to reach the desired length.

        Parameters:
        - sentence (str): The original sentence to be padded.
        - length (int): The desired total number of tokens in the sentence after padding.

        Returns:
        - str: The padded sentence.
        """
        words = sentence.split()
        n_pad = max(length - len(words), 0)
        pad_tokens = ["[pad]"] * n_pad
        return ' '.join(pad_tokens + words)
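
    # Example usage (illustrative): MorPiece.pad_sentence("hello world", 4)
    # returns "[pad] [pad] hello world"; if the sentence already has at least
    # `length` tokens, its tokens are returned unchanged.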

    def get_num_chars_in_trie(self):
        return self.num_chars_in_trie

    def get_num_chars_in_corpus(self):
        return self.num_chars_in_corpus

    def get_vocab_size(self) -> int:
        return self.idx

    def get_vocab(self):
        return self.vocab_to_id.copy()

    def get_num_tokens_in_corpus(self):
        return self.num_tokens_in_corpus

    def get_num_types_in_corpus(self):
        return len(self.types)

    def get_compression_ratio(self):
        return round(self.num_chars_in_trie / self.num_chars_in_corpus, 3)

    def get_ttr(self):
        # Type-token ratio of the training corpus.
        return round(len(self.types) / self.num_tokens_in_corpus, 3)

    def save(self, save_file):
        self.build_vocab_lookup()
        with open(save_file, 'w') as f:
            json.dump({
                'roots': self.roots,
                'vocab': self.vocab_to_id
            }, f, indent=2)

    def from_pretrained(self, load_file):
        with open(os.path.join(load_file, 'tokenizer.json'), 'r') as f:
            data = json.load(f)

        if isinstance(data, dict) and 'roots' in data:
            # Current format: {'roots': ..., 'vocab': ...} as written by save().
            self.roots = data['roots']
            self.vocab_to_id = data.get('vocab', {})
        else:
            # Fallback: the file holds the trie directly.
            self.roots = data
            self.vocab_to_id = {}

        if '[RSX]' not in self.roots:
            raise ValueError("Invalid tokenizer format: Missing [RSX] root node.")
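
    # For reference, save() writes tokenizer.json in the shape
    #   {"roots": {"[RSX]": {"<unk>": {"IDX": 0}, ...}, "++": {...}, ...},
    #    "vocab": {"[RSX]<unk>": 0, ...}}
    # which is what the first branch of from_pretrained() expects.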

    def save_types(self, file):
        with open(file, 'w') as f:
            json.dump(self.types, f, indent=2)


def sort_trie_by_freq(d):
    # Recursively re-orders each trie level so that children with the highest
    # '##' frequency come first.
    if not isinstance(d, dict):
        return d

    sorted_items = sorted(
        d.items(),
        key=lambda item: item[1].get('##', float('-inf')) if isinstance(item[1], dict) else float('-inf'),
        reverse=True
    )

    d.clear()
    for k, v in sorted_items:
        d[k] = sort_trie_by_freq(v)
    return d
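
# Example (illustrative): sort_trie_by_freq({'a': {'##': 1}, 'b': {'##': 3}})
# reorders the keys in place to ['b', 'a'] and returns the same dict.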


def find_idx_path(d, target_value, path=None):
    # Depth-first search for the node whose 'IDX' equals target_value; returns the
    # list of keys (characters) leading to it, or None if the id is not present.
    if path is None:
        path = []
    for key, value in d.items():
        if key == 'IDX' and value == target_value:
            return path
        elif isinstance(value, dict):
            result = find_idx_path(value, target_value, path + [key])
            if result is not None:
                return result
    return None
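

# Minimal usage sketch (illustrative; the toy corpus and file name are made up):
if __name__ == "__main__":
    tokenizer = MorPiece(vocab_size=1000, min_frequency=1)
    tokenizer.train("walking walked walker talks talked talking")
    ids = tokenizer.encode("walking talked")
    print(ids)                      # token ids; unknown words map to <unk>
    print(tokenizer.decode(ids))    # back to token strings
    print(tokenizer.get_vocab_size())
    tokenizer.save("tokenizer.json")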