import json
import re
from typing import List

from tokenizers import NormalizedString, PreTokenizedString
from tokenizers.decoders import Decoder
from tokenizers.pre_tokenizers import PreTokenizer, Sequence as PreTokenizerSequence
from transformers import BertTokenizerFast

from .splinter_json import splinter_data

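
# Bidirectional map between the Hebrew final (sofit) letters and their
# regular forms, used both to normalize a word before splintering and to
# restore the final letter when a word is rebuilt.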
final_letters_map = {
    'ך': 'כ',
    'ם': 'מ',
    'ץ': 'צ',
    'ף': 'פ',
    'ן': 'נ',
    'כ': 'ך',
    'מ': 'ם',
    'צ': 'ץ',
    'פ': 'ף',
    'נ': 'ן'
}

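
# Remove the character at `position` from `word`. Negative positions are
# taken relative to `word_length`, so -1 removes the last character of a
# word of that length.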
def get_permutation(word, position, word_length):
    if position < 0:
        permutation = word[:word_length + position] + word[(word_length + position + 1):]
    else:
        permutation = word[:position] + word[(position + 1):]
    return permutation


def replace_final_letters(text):
    """Swap the trailing letter of `text` between its final (sofit) and regular form."""
    if text == '':
        return text
    if text[-1] in final_letters_map:
        return replace_last_letter(text, final_letters_map[text[-1]])
    return text


def replace_last_letter(text, replacement):
    return text[:-1] + replacement

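
# U+05D0..U+05EA spans the Hebrew alphabet, א through ת (final forms included).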
def is_hebrew_letter(char):
    return '\u05D0' <= char <= '\u05EA'


def is_word_contains_non_hebrew_letters(word) -> bool:
    return re.search(r'[^\u05D0-\u05EA]', word) is not None

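
# Core encoder/decoder: rewrites a Hebrew word as a sequence of
# frequency-scored "position:letter" reductions plus leftover characters,
# each mapped to a dedicated single-character code point.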
class Splinter:
    def __init__(self, path, use_cache=True):
        # `path` is either a filesystem path to the splinter JSON file or an
        # already parsed dict with the same structure.
        if isinstance(path, str):
            with open(path, 'r', encoding='utf-8-sig') as r:
                parsed = json.loads(r.read())
        else:
            parsed = path

        self.reductions_map = {int(key): value for key, value in parsed['reductions_map'].items()}
        self.new_unicode_chars_map = parsed['new_unicode_chars']
        self.new_unicode_chars_inverted_map = {v: k for k, v in self.new_unicode_chars_map.items()}
        self.word_reductions_cache = {}
        self.use_cache = use_cache

    def splinter_word(self, word: str):
        # Membership test rather than a truthiness check, so that cached
        # empty strings are also returned.
        if self.use_cache and word in self.word_reductions_cache:
            return self.word_reductions_cache[word]

        clean_word = replace_final_letters(word)

        if len(clean_word) > 15 or is_word_contains_non_hebrew_letters(clean_word):
            encoded_word = self.get_word_with_non_hebrew_chars_reduction(clean_word)
        else:
            word_reductions = self.get_word_reductions(clean_word)
            encoded_word = ''.join([self.new_unicode_chars_map[reduction] for reduction in word_reductions])
        if self.use_cache:
            self.word_reductions_cache[word] = encoded_word
        return encoded_word

    def unsplinter_word(self, word: str):
        decoded_word = self.decode_word(word)
        return self.rebuild_reduced_word(decoded_word)

    def decode_word(self, word: str):
        # Map each encoded character back to its reduction string; characters
        # outside the inverted map pass through unchanged.
        return [self.new_unicode_chars_inverted_map.get(char, char) for char in word]

    def rebuild_reduced_word(self, decoded_word):
        original_word = ""
        for reduction in decoded_word:
            if ':' in reduction and len(reduction) > 1:
                # A "position:letter" reduction: re-insert the letter at the
                # position recorded when the word was splintered.
                position, letter = reduction.split(':')
                position = int(position)
                if position < 0:
                    position = len(original_word) + position + 1
                if len(original_word) == position - 1:
                    # The insertion point is one past the end of the word
                    # rebuilt so far; keep the raw reduction rather than
                    # drop it.
                    original_word += reduction
                else:
                    original_word = original_word[:position] + letter + original_word[position:]
            else:
                # A plain character reduction is appended as-is.
                original_word += reduction

        original_word = replace_final_letters(original_word)
        return original_word

    def get_word_reductions(self, word):
        reduced_word = word
        reductions = []
        while len(reduced_word) > 3:
            if len(reduced_word) not in self.reductions_map:
                # No learned reductions for this word length; fall back to
                # plain characters.
                reductions.extend(self.get_single_chars_reductions(reduced_word))
                break
            reduction = self.get_reduction(reduced_word, 3, 3)
            if reduction is not None:
                position = int(reduction.split(':')[0])
                reductions.append(reduction)
                reduced_word = get_permutation(reduced_word, position, len(reduced_word))
            else:
                reductions.extend(self.get_single_chars_reductions(reduced_word))
                break

        # Whatever remains (at most three characters) is emitted as plain
        # single-character reductions.
        if len(reduced_word) < 4:
            reductions.extend(self.get_single_chars_reductions(reduced_word))

        reductions.reverse()
        return reductions

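
    # Beam search over candidate reductions: expand up to `width` candidates
    # per step for at most `depth` steps, then return the first-step
    # reduction on the highest-scoring path.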
    def get_reduction(self, word, depth, width):
        curr_step_reductions = [{"word": word, "reduction": None, "root_reduction": None, "score": 1}]
        word_length = len(word)
        i = 0
        while i < depth and curr_step_reductions and word_length > 3:
            next_step_reductions = []
            for reduction in curr_step_reductions:
                possible_reductions = self.get_most_frequent_reduction_keys(
                    reduction["word"],
                    reduction["root_reduction"],
                    reduction["score"],
                    width,
                    word_length
                )
                next_step_reductions += possible_reductions
            curr_step_reductions = next_step_reductions
            i += 1
            word_length -= 1

        max_score_reduction = None
        if curr_step_reductions:
            max_score_reduction = max(curr_step_reductions, key=lambda x: x["score"])["root_reduction"]
        return max_score_reduction

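
    # Assumes the reductions stored for each word length are ordered most
    # frequent first, so scanning in order and stopping after
    # `number_of_reductions` matches yields the top applicable reductions.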
    def get_most_frequent_reduction_keys(self, word, root_reduction, parent_score, number_of_reductions, word_length):
        possible_reductions = []
        for reduction, score in self.reductions_map[len(word)].items():
            position, letter = reduction.split(':')
            position = int(position)
            if word[position] == letter:
                permutation = get_permutation(word, position, word_length)
                possible_reductions.append({
                    "word": permutation,
                    "reduction": reduction,
                    "root_reduction": root_reduction if root_reduction is not None else reduction,
                    "score": parent_score * score
                })
            if len(possible_reductions) >= number_of_reductions:
                break
        return possible_reductions

    def get_word_with_non_hebrew_chars_reduction(self, word):
        # Long or mixed-script words are encoded character by character:
        # Hebrew letters through the new-unicode map, everything else as-is.
        return ''.join(self.new_unicode_chars_map[char] if is_hebrew_letter(char) else char for char in word)

    @staticmethod
    def get_single_chars_reductions(reduced_word):
        # Plain characters, reversed so that the final reversed reductions
        # list reads left to right.
        return list(reduced_word[::-1])

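
# Custom pre-tokenizer for `tokenizers`: an instance is wrapped with
# PreTokenizer.custom and must expose a pre_tokenize method.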
class SplinterPreTokenizer:
    def __init__(self, splinter: Splinter):
        super().__init__()
        self.splinter = splinter

    def splinter_split(self, i: int, normalized: NormalizedString):
        # Replace each character of the word in place with the next character
        # of its splintered encoding; mapping in place preserves the original
        # offsets. The encoding should not be longer than the word (each
        # reduction encodes one original character), so any leftover
        # positions become spaces and are stripped.
        splintered_word = iter(self.splinter.splinter_word(normalized.normalized))
        normalized.map(lambda _: next(splintered_word, ' '))
        normalized.strip()
        return [normalized]

    def pre_tokenize(self, pretok: PreTokenizedString):
        pretok.split(self.splinter_split)

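
# Custom decoder: wrapped with Decoder.custom, it must expose a decode_chain
# method that maps the token sequence back to text pieces.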
class SplinterDecoder:
    def __init__(self, splinter: Splinter):
        self.splinter = splinter

    def decode_chain(self, tokens: List[str]) -> List[str]:
        # Merge WordPiece continuation tokens ('##...') into the preceding
        # token so that each whole word can be unsplintered at once.
        combined_tokens = []
        for token in tokens:
            if token.startswith('##') and combined_tokens:
                combined_tokens[-1] += token[2:]
            else:
                combined_tokens.append(token)

        return [f' {t}' for t in map(self.splinter.unsplinter_word, combined_tokens)]

class SplinterBertTokenizerFast(BertTokenizerFast):
    def __init__(self, *args, use_cache=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.splinter = Splinter(splinter_data, use_cache=use_cache)
        # Run the standard BERT pre-tokenization first, then splinter each
        # resulting word.
        self._tokenizer.pre_tokenizer = PreTokenizerSequence([
            self._tokenizer.pre_tokenizer,
            PreTokenizer.custom(SplinterPreTokenizer(self.splinter))
        ])
        self._tokenizer.decoder = Decoder.custom(SplinterDecoder(self.splinter))

    def save_pretrained(self, *args, **kwargs):
        self._save_pretrained(*args, **kwargs)

    def _save_pretrained(self, *args, **kwargs):
        # Custom pre-tokenizers and decoders cannot be serialized by the
        # tokenizers library, so saving is intentionally unsupported.
        print('Cannot save SplinterBertTokenizerFast, please copy the files directly from the repository')
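

# Minimal usage sketch, assuming a BERT-style tokenizer directory (vocab and
# config) exists locally; "./splinter_tokenizer" is an illustrative path, not
# part of this module. Run via `python -m`, since this module uses a relative
# import.
if __name__ == "__main__":
    tokenizer = SplinterBertTokenizerFast.from_pretrained("./splinter_tokenizer")  # hypothetical path
    encoded = tokenizer("שלום עולם")
    print(encoded["input_ids"])
    print(tokenizer.decode(encoded["input_ids"], skip_special_tokens=True))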