"""MorPiece tokenizer (tokenizer_MorPiece.py) for the EMG model."""
import os
import json
from math import log
class MorPiece:
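    """Trie-based subword tokenizer.

    Builds a forward (prefix) trie and a reversed (inflection) trie over the training
    corpus, then splits words at boundaries licensed by the Tolerance Principle (see
    check_tp) and a maximum branching factor. The resulting pieces live in self.roots:
    word-initial pieces under their first character, continuation pieces under the
    '++' node, and special tokens under the '[RSX]' node, each carrying an 'IDX'
    vocabulary id.
    """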
def __init__(self, vocab_size=30000, min_frequency=2, cutoff=8, bf=10, special_tokens=None):
        self.tokenization_to_print = "TP left-right \t BF left-right \t TP right-left \t BF right-left\n"  # debug log header: tokens, tokens_bf, suffixes, suffixes_bf
if special_tokens is None:
special_tokens = ['<unk>', '<pad>', '<s>', '</s>']
self.special_tokens = special_tokens
self.reserved_keys = {'[RSX]', '##', 'IDX', '++'}
self.vocab_size = vocab_size
self.min_frequency = min_frequency
self.bf = bf
self.roots = {'[RSX]': {}, '++': {}}
self.roots_unoptimized = {}
self.infls = {}
self.types = {}
self.last_item_in_trie = {}
self.idx = 0
self.tokens = []
self.suffixes = []
self.tokens_bf = []
self.suffixes_bf = []
self.prefix = ""
self.n_prefix = 0
self.n_suffix = 0
self.tokenized_words = []
self.tokenized_word_longest = ""
self.tokenized_word_idx_longest = ""
        self.cutoff = cutoff  # ln(8) > 2, so non-branching paths will be ignored
self.num_tokens_in_corpus = 0
self.num_chars_in_corpus = 0
self.num_chars_in_trie = 0
self.num_chars_in_optimized_trie = 0
self.set_special_tokens(self.special_tokens)
def train(self, corpus: str): # create the vocabulary
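        """Build the unoptimized forward (roots_unoptimized) and reversed (infls) tries from the
        whitespace-split corpus, count word-type frequencies, then call optimize() to build the
        final vocabulary trie."""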
words = corpus.split()
print("MorPiece tokenizer training: processing words...")
for word in words:
            # keep only alphabetic characters and apostrophes; if nothing is left, keep the word as is
            word_alpha = ''.join(char for char in word if char.isalpha() or char == "'")
            if word_alpha:
                word = word_alpha
if word:
self.build_trie(word, self.roots_unoptimized) # create roots trie
self.build_trie(word[::-1], self.infls) # create inflections trie
                # count word-type frequencies, plus tokens and characters in the corpus
                self.types[word] = self.types.get(word, 0) + 1
                self.num_tokens_in_corpus += 1
                self.num_chars_in_corpus += len(word)
self.types = dict(sorted(self.types.items(), key=lambda item: item[1], reverse=True))
sort_trie_by_freq(self.roots_unoptimized)
sort_trie_by_freq(self.infls)
print("MorPiece tokenizer training: trie optimization...")
self.optimize(self.types)
print(f"Built final vocabulary with {self.get_vocab_size()} tokens")
        print(f"Most frequent word types: {list(self.types.items())[:20]}")
def build_trie(self, wordpiece, root): # build the trie and register # of traversals in '##'
if wordpiece[0] in root:
root[wordpiece[0]]['##'] += 1
self.num_chars_in_trie += 1
if len(wordpiece) > 1:
self.build_trie(wordpiece[1:], root[wordpiece[0]])
else:
if 'END' not in root[wordpiece[0]]:
root[wordpiece[0]]['END'] = None
else:
root[wordpiece[0]] = {}
root[wordpiece[0]]['##'] = 1
if len(wordpiece) > 1:
self.build_trie(wordpiece[1:], root[wordpiece[0]])
    def set_special_tokens(self, tokens):
        for item in tokens:
if item not in self.roots['[RSX]'].keys():
self.roots['[RSX]'][item] = {'IDX': None}
self.roots['[RSX]'][item]['IDX'] = self.idx
self.idx += 1
# assign idx based on word freq and add potential inflection links in the root trie, remove frequency at the end
def optimize(self, words):
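        """For every word type that meets min_frequency (while the vocabulary is not full), compute
        its prefix/suffix split; split words have each piece inserted into self.roots (continuations
        under '++'), unsplit words are inserted whole, and every new piece receives a fresh 'IDX'."""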
        for word, freq in words.items():
            if freq >= self.min_frequency and self.idx <= self.vocab_size:
                self.tokens = []
                self.suffixes = []
                self.tokens_bf = []
                self.suffixes_bf = []
                self.tokens.append(word[0])
                self.suffixes.append(word[len(word) - 1])
                self.split_prefix(word, self.roots_unoptimized)
                if len(self.tokens) > 1:
                    self.split_suffix(word[::-1], self.infls)
                    self.suffixes = [piece[::-1] for piece in self.suffixes][::-1]
                    self.tokenization_to_print += str(self.tokens) + '\t' + str(self.tokens_bf) + '\t' + str(self.suffixes) + '\t' + str(self.suffixes_bf) + '\n'  # for debugging only
                    for i in range(0, len(self.tokens)):  # experiments: use only self.suffixes or self.tokens (prefixes)
                        if i == 0:
                            self.last_item_in_trie = self.roots
                        else:
                            self.last_item_in_trie = self.roots['++']
                        self.add_items_to_trie(self.tokens[i])
                        if 'IDX' not in self.last_item_in_trie:
                            self.last_item_in_trie['IDX'] = self.idx
                            self.idx += 1
                else:
                    self.last_item_in_trie = self.roots
                    self.add_items_to_trie(word)
                    if 'IDX' not in self.last_item_in_trie:
                        self.last_item_in_trie['IDX'] = self.idx
                        self.idx += 1
self.build_vocab_lookup()
def build_vocab_lookup(self):
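        """Walk self.roots and build a flat {token_string: id} map from every node carrying an 'IDX'.
        Note: path keys such as '[RSX]' and '++' remain part of the joined token string."""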
self.vocab_to_id = {}
def traverse(trie, path):
for k, v in trie.items():
if k == 'IDX':
token = ''.join(path)
self.vocab_to_id[token] = v
elif isinstance(v, dict):
traverse(v, path + [k])
traverse(self.roots, [])
def encode(self, sentence: str):
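        """Convert a sentence into a list of token ids: special tokens are looked up under '[RSX]';
        every other word gets a single id, that of its longest match in the trie (intermediate pieces
        are recorded in self.tokenized_words), falling back to '<unk>' when nothing matches."""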
self.tokenized_words = []
words = sentence.strip().split()
token_ids = []
for word in words:
if word in self.roots['[RSX]']:
token_ids.append(self.roots['[RSX]'][word]['IDX'])
else:
self.tokenized_word_longest = ""
self.tokenized_word_idx_longest = None
self.retrieve(word, self.roots)
if self.tokenized_word_idx_longest is not None:
token_ids.append(self.tokenized_word_idx_longest)
else:
token_ids.append(self.roots['[RSX]']['<unk>']['IDX'])
return token_ids
def decode(self, sentence_idxs):
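        """Map token ids back to strings by searching self.roots for the node whose 'IDX' matches
        and joining the key path; the '[RSX]' prefix of special tokens is stripped."""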
tokens = []
for idx in sentence_idxs:
keys_path = find_idx_path(self.roots, idx)
if keys_path:
token = "".join(keys_path)
if token.startswith('[RSX]'):
token = token[5:]
tokens.append(token)
return tokens
def retrieve(self, word, trie):
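        """Run longest_match_in_trie on `word` and record the resulting piece and id
        (or '<unk>') in self.tokenized_words."""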
self.longest_match_in_trie(word, trie)
if self.tokenized_word_longest:
self.tokenized_words.append([self.tokenized_word_longest, self.tokenized_word_idx_longest])
else:
self.tokenized_words.append(['<unk>', self.roots['[RSX]']['<unk>']['IDX']])
def longest_match_in_trie(self, string, trie):
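        """Greedily follow `string` through `trie`, remembering the deepest node that carries an
        'IDX'; on a mismatch, try to continue from the '++' (continuation) node, otherwise
        record '<unk>'."""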
if string[0] in trie:
self.tokenized_word_longest += string[0]
if 'IDX' in trie[string[0]]:
self.tokenized_word_idx_longest = trie[string[0]]['IDX']
if len(string) > 1:
self.longest_match_in_trie(string[1:], trie[string[0]])
else:
# print(string[0], self.tokenized_word_longest)
if string[0] in self.roots['++'] and self.tokenized_word_idx_longest:
self.tokenized_words.append([self.tokenized_word_longest + '++', self.tokenized_word_idx_longest])
self.tokenized_word_longest = '++'
self.tokenized_word_idx_longest = 0
self.longest_match_in_trie(string, self.roots['++'])
else:
self.tokenized_words.append(['<unk>', self.roots['[RSX]']['<unk>']['IDX']])
self.tokenized_word_longest = None
def split_prefix(self, word, trie):
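        """Left-to-right pass over `word` in the forward trie: open a new piece in self.tokens
        whenever the Tolerance Principle check and the branching-factor limit both hold,
        otherwise extend the current piece."""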
l = len(word)
if l > 1:
self.get_pair_in_trie(word[0], word[1], trie)
if self.check_tp(self.n_prefix, self.n_suffix) and self.get_bf(trie[word[0]]) <= self.bf:
self.tokens.append(word[1])
self.tokens_bf.append(word[0] + str(self.get_bf(trie[word[0]])))
else:
self.tokens[len(self.tokens) - 1] = self.tokens[len(self.tokens) - 1] + word[1]
if l > 2:
self.split_prefix(word[1:], trie[word[0]])
def split_suffix(self, word, trie):
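        """Right-to-left counterpart of split_prefix, run on the reversed word in the inflection
        trie; the resulting pieces are collected in self.suffixes (re-reversed in optimize)."""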
l = len(word)
if l > 1:
self.get_pair_in_trie(word[0], word[1], trie)
            if self.check_tp(self.n_prefix, self.n_suffix) and self.get_bf(trie[word[0]]) <= self.bf:  # same TP and branching-factor checks as in split_prefix
self.suffixes.append(word[1])
self.suffixes_bf.append(word[0] + str(self.get_bf(trie[word[0]])))
else:
self.suffixes[len(self.suffixes) - 1] = self.suffixes[len(self.suffixes) - 1] + word[1]
if l > 2:
if word[0] in trie.keys():
self.split_suffix(word[1:], trie[word[0]])
def get_pair_in_trie(self, prefix, suffix, trie):
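        """Store the traversal counts ('##') of the mother node `prefix` and its daughter `suffix`
        in self.n_prefix / self.n_suffix; both remain 0 if the pair is absent from the trie."""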
self.n_prefix = 0
self.n_suffix = 0
if prefix in trie:
if suffix in trie[prefix]:
self.n_prefix = trie[prefix]["##"]
self.n_suffix = trie[prefix][suffix]["##"]
    def check_tp(self, m, d):  # verify if the Tolerance Principle applies between m(other) and d(aughter) nodes
        if m <= 1:
            return False
        # Tolerance Principle threshold: theta = m / ln(m); the split is licensed only when the
        # mother node is frequent enough (m >= cutoff) and the daughter count d exceeds theta
        tp = m / log(m)
        return self.cutoff <= m and m != d and d > tp
def get_bf(self, m): # return the branching factor of the mother node
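        # note: only special tokens are filtered out here, so bookkeeping keys such as '##'
        # (and 'END', when present) are included in the returned count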
keys = m.keys()
n_keys = len(keys)
for k in keys:
if k in self.special_tokens:
n_keys -= 1
return n_keys
def add_items_to_trie(self, items):
for item in items:
self.add_item_to_trie(item)
def add_item_to_trie(self, item):
if item not in self.last_item_in_trie:
self.last_item_in_trie[item] = {}
self.last_item_in_trie = self.last_item_in_trie[item]
    @staticmethod
    def pad_sentence(sentence, l):
        """
        Pads the given sentence with '<pad>' tokens at the beginning to reach the desired length.
        Parameters:
        - sentence (str): The original sentence to be padded.
        - l (int): The desired total number of tokens in the sentence after padding.
        Returns:
        - str: The padded sentence.
        """
        words = sentence.split()
        n_pad = max(l - len(words), 0)  # ensure n_pad is not negative
        pad_tokens = ['<pad>'] * n_pad
        padded_sentence = ' '.join(pad_tokens + words)
        return padded_sentence
def get_num_chars_in_trie(self):
return self.num_chars_in_trie
def get_num_chars_in_corpus(self):
return self.num_chars_in_corpus
def get_vocab_size(self) -> int:
return self.idx
def get_vocab(self):
return self.vocab_to_id.copy()
def get_num_tokens_in_corpus(self):
return self.num_tokens_in_corpus
def get_num_types_in_corpus(self):
return len(self.types)
def get_compression_ratio(self):
return round(self.num_chars_in_trie / self.num_chars_in_corpus, 3)
def get_ttr(self):
return round(len(self.types) / self.num_tokens_in_corpus, 3)
def save(self, save_file):
self.build_vocab_lookup()
with open(save_file, 'w') as f:
json.dump({
'roots': self.roots,
'vocab': self.vocab_to_id
}, f, indent=2)
    def from_pretrained(self, load_file):
        with open(os.path.join(load_file, 'tokenizer.json'), 'r') as f:
            data = json.load(f)
        if isinstance(data, dict) and 'roots' in data:
            # current format: {'roots': ..., 'vocab': ...}
            self.roots = data['roots']
            self.vocab_to_id = data.get('vocab', {})  # fall back to an empty dict if missing
        else:
            # backward compatibility: old format where tokenizer.json held only the roots trie
            self.roots = data
            self.vocab_to_id = {}
        # ensure the special-token root exists
        if '[RSX]' not in self.roots:
            raise ValueError("Invalid tokenizer format: Missing [RSX] root node.")
def save_types(self, file):
with open(file, 'w') as f:
json.dump(self.types, f, indent=2)
def sort_trie_by_freq(d):
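    """Recursively sort every level of the trie in place by the nodes' '##' traversal counts, highest first."""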
if not isinstance(d, dict):
return d
# Sort the dictionary items by the value of the nested key '##'
sorted_items = sorted(
d.items(),
key=lambda item: item[1].get('##', float('-inf')) if isinstance(item[1], dict) else float('-inf'),
reverse=True
)
# Clear the dictionary and update with sorted items
d.clear()
for k, v in sorted_items:
d[k] = sort_trie_by_freq(v)
return d
def find_idx_path(d, target_value, path=None):
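    """Depth-first search for the node whose 'IDX' equals target_value; returns the list of keys
    leading to it, or None if the id is not present."""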
if path is None:
path = []
for key, value in d.items():
if key == 'IDX' and value == target_value:
return path
elif isinstance(value, dict):
result = find_idx_path(value, target_value, path + [key])
if result is not None:
return result
return None
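

# Minimal usage sketch (not part of the original module): train on a toy corpus and round-trip a
# sentence. The corpus text and the output file name below are made up for illustration only.
if __name__ == "__main__":
    corpus = "the cat sat on the mat the cats sat on the mats"
    tok = MorPiece(vocab_size=1000, min_frequency=1)
    tok.train(corpus)
    ids = tok.encode("the cat sat")
    print("ids:", ids)
    print("tokens:", tok.decode(ids))
    tok.save("tokenizer.json")  # writes {'roots': ..., 'vocab': ...} to the working directory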