| from typing import List |
|
|
| import torch as T |
| import numpy as np |
|
|
| from pyarabic.araby import ( |
| tokenize, |
| strip_tashkeel, |
| strip_tatweel, |
| DIACRITICS |
| ) |
|
|
# Haraka ids used by shakkel_char() when rendering the separate
# (haraka, tanween, shadda) heads back into text; 0 means "no haraka".
SEPARATE_DIACRITICS = dict(FATHA=1, KASRA=2, DAMMA=3, SUKUN=4)


# Flat diacritic class id -> (haraka, tanween, shadda).
# Ids 0-7 carry no shadda; ids 8-14 repeat ids 0-6 with the shadda flag set.
# Haraka 4 is never combined with tanween or shadda.
# The trailing (0, 0, 0) at id 15 lets HARAKAT_MAP[DIAC_PAD_IDX] (index -1)
# decode a separator/padding position as "no diacritic"; list.index() still
# returns 0 for (0, 0, 0), so id 15 is never produced as a label.
HARAKAT_MAP = [
    # without shadda
    (0, 0, 0),  # 0: no diacritic
    (1, 0, 0),  # 1
    (1, 1, 0),  # 2
    (2, 0, 0),  # 3
    (2, 1, 0),  # 4
    (3, 0, 0),  # 5
    (3, 1, 0),  # 6
    (4, 0, 0),  # 7
    # with shadda
    (0, 0, 1),  # 8: shadda only
    (1, 0, 1),  # 9
    (1, 1, 1),  # 10
    (2, 0, 1),  # 11
    (2, 1, 1),  # 12
    (3, 0, 1),  # 13
    (3, 1, 1),  # 14
    # alias reached through the negative index DIAC_PAD_IDX
    (0, 0, 0),  # 15
]


# Word-separator value that diac_ids_of_line() inserts between words; as a
# negative index it selects the last HARAKAT_MAP entry.
DIAC_PAD_IDX = -1
|
|
SPECIAL_TOKENS = ['<pad>', '<unk>', '<num>', '<punc>']
# Special tokens followed by the 36 Arabic base letters
# "ءآأؤإئابةتثجحخدذرزسشصضطظعغفقكلمنهوىي": hamza U+0621 .. ghain U+063A,
# then feh U+0641 .. yeh U+064A. Built from code points because the literal
# was mis-encoded (UTF-8 bytes read as Thai TIS-620) in the copy under review,
# which put Thai characters in the vocabulary and made char_type() map every
# Arabic letter to '<unk>'.
LETTER_LIST = (
    SPECIAL_TOKENS
    + [chr(cp) for cp in range(0x0621, 0x063B)]
    + [chr(cp) for cp in range(0x0641, 0x064B)]
)
# NOTE(review): every diacritic literal in these two lists was mis-encoded in
# the copy under review (UTF-8 bytes decoded as Thai TIS-620), and the original
# characters could not be recovered. They are rebuilt from code points in the
# only order this file implies: create_labels() treats DIACRITICS_SHORT indices
# 1/3/5/7 as haraka ids 1/2/3/4, index + 1 as the matching tanween form and
# index 8 as shadda, while shakkel_char() reads haraka ids 1/2/3/4 through
# SEPARATE_DIACRITICS as fatha/kasra/damma/sukun. Verify this order against the
# class list used to build the training data / saved models.
DIACRITICS_SHORT = [
    ' ',       # 0: no diacritic
    '\u064E',  # 1: fatha
    '\u064B',  # 2: fathatan
    '\u0650',  # 3: kasra
    '\u064D',  # 4: kasratan
    '\u064F',  # 5: damma
    '\u064C',  # 6: dammatan
    '\u0652',  # 7: sukun
    '\u0651',  # 8: shadda
]
# Flat classes, index-aligned with HARAKAT_MAP[0:15]: the nine entries above,
# then shadda followed by each of entries 1-6 (shadda first, the same order in
# which shakkel_char() emits the marks).
CLASSES_LIST = DIACRITICS_SHORT + ['\u0651' + diac for diac in DIACRITICS_SHORT[1:7]]
NUMBERS = list("0123456789")
# Punctuation that char_type() maps to '<punc>'. Non-ASCII entries are written
# as escapes because the literals were mis-encoded in the copy under review:
# U+060C Arabic comma, U+061B Arabic semicolon, U+00AB / U+00BB guillemets,
# U+061F Arabic question mark.
DELIMITERS = [
    "\u060C", "\u061B", ",", ";", "\u00AB", "\u00BB",
    "{", "}", "(", ")", "[", "]",
    ".", "*", "-", ":", "?", "!", "\u061F",
]
|
|
# pyarabic DIACRITICS that are not part of DIACRITICS_SHORT. Sorted so the list
# order is reproducible: iteration order of a set of str depends on the
# per-process hash seed (PYTHONHASHSEED).
UNKNOWN_DIACRITICS = sorted(set(DIACRITICS) - set(DIACRITICS_SHORT))
|
|
def shakkel_char(diac: int, tanween: bool, shadda: bool) -> str:
    '''
    Build the diacritic marks for one character from its three head outputs.

    diac: haraka id compared against SEPARATE_DIACRITICS (FATHA, KASRA, DAMMA,
        SUKUN, checked in that order); any other value adds no haraka mark.
    tanween: use the tanween form of fatha/kasra/damma (has no effect on sukun).
    shadda: prefix U+0651 SHADDA, except when diac is SUKUN.
    Returns the marks (possibly "") to place after the base letter.
    '''
    prefix = "\u0651" if shadda and diac != SEPARATE_DIACRITICS["SUKUN"] else ""

    # (SEPARATE_DIACRITICS key, plain mark, tanween mark)
    marks = (
        ("FATHA", "\u064E", "\u064B"),
        ("KASRA", "\u0650", "\u064D"),
        ("DAMMA", "\u064F", "\u064C"),
        ("SUKUN", "\u0652", "\u0652"),
    )
    for key, plain_mark, tanween_mark in marks:
        if diac == SEPARATE_DIACRITICS[key]:
            return prefix + (tanween_mark if tanween else plain_mark)
    return prefix
|
|
def diac_ids_of_line(line: str):
    '''
    Return the flat diacritic class ids for every character of a line.

    The line is split with pyarabic's tokenize(); consecutive words are
    separated by a single DIAC_PAD_IDX entry (none before the first word or
    after the last). Returns a 1-D integer numpy array, empty when the line
    has no words.
    '''
    diacs = []
    for word_no, word in enumerate(tokenize(line)):
        if word_no:
            diacs.append(DIAC_PAD_IDX)  # separator between words
        _, word_diacs, _ = create_label_for_word(
            split_word_on_characters_with_diacritics(word)
        )
        diacs.extend(word_diacs)
    # dtype=int keeps the empty case integer as well; np.array([]) is float64.
    return np.array(diacs, dtype=int)
|
|
def strip_unknown_tashkeel(word: str):
    '''
    Return ``word`` unchanged.

    NOTE(review): the name suggests removing UNKNOWN_DIACRITICS, but the body
    returned ``word`` before reaching its filtering statement, which was
    therefore unreachable. The unreachable line is dropped and the
    pass-through kept so callers see no change. To actually filter, return
    ``''.join(c for c in word if c not in UNKNOWN_DIACRITICS)`` instead.
    '''
    return word
|
|
def create_gt_labels(lines):
    '''
    Compute diac_ids_of_line() for each line.

    Returns a list with one result per input line, in input order.
    '''
    return [diac_ids_of_line(text_line) for text_line in lines]
|
|
def split_word_on_characters_with_diacritics(word: str):
    '''
    Group a word into [character, *following diacritics] lists.

    A new group starts at every character that is not in DIACRITICS_SHORT;
    characters that are in DIACRITICS_SHORT are appended to the current group.
    Diacritics that precede the first base character form a group of their own.

    Returns: List[List[str]] -- empty for an empty word (this used to raise
    IndexError).
    '''
    groups = []
    for char in word:
        if groups and char in DIACRITICS_SHORT:
            groups[-1].append(char)
        else:
            groups.append([char])
    return groups
|
|
|
|
def load_lines(path: str, *, strip: bool):
    '''
    Read a UTF-8 text file and return its lines, each passed through
    normalize_spaces().

    Only '\\n' terminates lines (newline='\\n'). When ``strip`` is true,
    pyarabic's strip_tashkeel() is applied to each normalized line.
    '''
    with open(path, 'r', encoding="utf-8", newline='\n') as fin:
        normalized = [normalize_spaces(raw_line) for raw_line in fin]
    if strip:
        normalized = [strip_tashkeel(text) for text in normalized]
    return normalized
|
|
def normalize_spaces(line: str):
    '''
    Trim the line, tokenize it with pyarabic's tokenize() and re-join the
    tokens with single spaces.

    Token boundaries are whatever tokenize() produces (presumably it may also
    split off punctuation -- confirm against pyarabic).
    '''
    tokens = tokenize(line.strip())
    return " ".join(tokens)
|
|
|
|
def char_type(char: str):
    '''
    Map one character to its index in LETTER_LIST.

    Entries of LETTER_LIST map to their own index; otherwise digits map to
    '<num>', DELIMITERS to '<punc>', and anything else to '<unk>'.
    '''
    key = char
    if key not in LETTER_LIST:
        if key in NUMBERS:
            key = '<num>'
        elif key in DELIMITERS:
            key = '<punc>'
        else:
            key = '<unk>'
    return LETTER_LIST.index(key)
|
|
def create_labels(char_w_diac: List[str]):
    '''
    Encode one character group into its character and diacritic labels.

    char_w_diac: a group as produced by
        split_word_on_characters_with_diacritics(); element 0 is the character,
        the rest are the diacritics following it (a str works the same way).

    Returns (char_idx, diacritic_index, diac_h3):
        char_idx -- char_type() of the first element,
        diac_h3 -- [haraka, tanween, shadda],
        diacritic_index -- position of tuple(diac_h3) in HARAKAT_MAP.

    Raises AssertionError if sukun is combined with shadda (no such class).

    A malformed group with several harakat keeps only the first one in text
    order, plus shadda if present. Elements not in DIACRITICS_SHORT and the
    ' ' "no diacritic" entry are ignored.
    '''
    # DIACRITICS_SHORT index -> haraka id; tanween forms sit at index + 1.
    remap_dict = {1: 1, 3: 2, 5: 3, 7: 4}
    tanween_indices = (2, 4, 6)
    shadda_index = 8

    char_idx = char_type(char_w_diac[0])

    # dict.fromkeys de-duplicates while keeping text order; a set() here would
    # make malformed groups resolve differently depending on the hash seed.
    diac_indices = [
        DIACRITICS_SHORT.index(diac)
        for diac in dict.fromkeys(char_w_diac[1:])
        if diac in DIACRITICS_SHORT
    ]
    harakat = [idx for idx in diac_indices if idx not in (0, shadda_index)]

    diac_h3 = [0, 0, int(shadda_index in diac_indices)]
    if harakat:
        first = harakat[0]
        if first in tanween_indices:
            diac_h3[0] = remap_dict[first - 1]
            diac_h3[1] = 1
        else:
            diac_h3[0] = remap_dict[first]

    # Sukun (haraka id 4) never carries tanween or shadda in HARAKAT_MAP.
    assert not (diac_h3[0] == 4 and (diac_h3[1] or diac_h3[2])), char_w_diac
    diacritic_index = HARAKAT_MAP.index(tuple(diac_h3))
    return char_idx, diacritic_index, diac_h3
|
|
def create_label_for_word(split_word: List[List[str]]):
    '''
    Label every character group of one word.

    split_word: groups as returned by split_word_on_characters_with_diacritics().
    Returns three parallel lists with one entry per group: character indices,
    flat diacritic class ids, and [haraka, tanween, shadda] triples.
    Raises ValueError if create_labels() yields no character index.
    '''
    word_char_indices = []
    word_diac_indices = []
    word_diac_indices_h3 = []
    for char_w_diac in split_word:
        char_idx, diac_idx, diac_h3 = create_labels(char_w_diac)
        if char_idx is None:
            # Put the offending group in the exception instead of printing it.
            raise ValueError(
                f"no character index for {char_w_diac!r} in word {split_word!r}"
            )
        word_char_indices.append(char_idx)
        word_diac_indices.append(diac_idx)
        word_diac_indices_h3.append(diac_h3)
    return word_char_indices, word_diac_indices, word_diac_indices_h3
|
|
|
|
def flat_2_3head(output: T.Tensor):
    '''
    Decode flat diacritic ids into separate haraka / tanween / shadda labels.

    output: 3-D [b, ts, tw] tensor of HARAKAT_MAP indices (presumably
        batch x words x characters per word -- TODO confirm). Values are
        truncated toward zero like int(). Negative ids index from the end of
        HARAKAT_MAP, so DIAC_PAD_IDX (-1) decodes as (0, 0, 0).

    Returns (haraka, tanween, shadda), each a nested Python list of ints with
    layout [b][ts][tw].

    Raises ValueError if output is not 3-D and IndexError for ids outside
    HARAKAT_MAP.
    '''
    ids = T.as_tensor(output).detach().cpu().long()
    if ids.dim() != 3:
        raise ValueError(
            f"expected a 3-D tensor [b, ts, tw], got shape {tuple(ids.shape)}"
        )
    # One gather through a lookup table instead of a Python call per element.
    lookup = T.tensor(HARAKAT_MAP, dtype=T.long)
    heads = lookup[ids]  # [b, ts, tw, 3]
    return heads[..., 0].tolist(), heads[..., 1].tolist(), heads[..., 2].tolist()
|
|
def flat2_3head(diac_idx):
    '''
    Split a 1-D sequence of flat diacritic ids into three numpy arrays.

    diac_idx: iterable of HARAKAT_MAP indices [tw]; DIAC_PAD_IDX (-1) maps to
        the last entry, (0, 0, 0).
    Returns (haraka, tanween, shadda) arrays of equal length.
    '''
    columns = ([], [], [])
    for diac in diac_idx:
        for column, value in zip(columns, HARAKAT_MAP[diac]):
            column.append(value)
    haraka_col, tanween_col, shadda_col = columns
    return np.array(haraka_col), np.array(tanween_col), np.array(shadda_col)