Spaces:
Runtime error
Runtime error
| import os | |
| import pickle | |
| import numpy as np | |
| from tqdm import tqdm | |
| from prettytable import PrettyTable | |
| from pyarabic.araby import tokenize, strip_tashkeel | |
| import diac_utils as du | |
| class DatasetUtils: | |
def __init__(self, config):
    """Populate the helper from a nested configuration mapping.

    Reads the "paths", "sentence-break", "predictor", "train", "debug"
    and (only when word embeddings are configured) "loader" sections of
    ``config``.

    Side effects visible in this method:
      * when ``config["predictor"]["gt-signal-prob"]`` is > 0, the file
        ``<base>/test/test_gt_mask_<prob>_<seed-idx>.txt`` is read into
        ``self.gt_mask`` (one entry per line, line endings kept);
      * when ``config["paths"]["word-embs"]`` exists and is not blank, the
        embeddings are loaded, normalized and indexed into ``self.w2idx``.
    """
    paths_cfg = config["paths"]
    self.base_path = paths_cfg["base"]
    self.special_tokens = ['<pad>', '<unk>', '<num>', '<punc>']

    break_cfg = config["sentence-break"]
    self.delimeters = break_cfg["delimeters"]
    self.load_constants(paths_cfg["constants"])
    self.debug = config["debug"]

    # Windowing parameters; the validation stride falls back to the
    # training stride when "val-stride" is not configured.
    self.stride = break_cfg["stride"]
    self.window = break_cfg["window"]
    self.val_stride = break_cfg.get("val-stride", self.stride)

    predictor_cfg = config["predictor"]
    self.test_stride = predictor_cfg["stride"]
    self.test_window = predictor_cfg["window"]

    train_cfg = config["train"]
    self.max_word_len = train_cfg["max-word-len"]
    self.max_sent_len = train_cfg["max-sent-len"]
    self.max_token_count = train_cfg["max-token-count"]

    self.pad_target_val = -100
    self.pad_char_id = du.LETTER_LIST.index('<pad>')

    self.markov_signal = train_cfg.get('markov-signal', False)
    self.batch_first = train_cfg.get('batch-first', True)

    self.gt_prob = predictor_cfg["gt-signal-prob"]
    if self.gt_prob > 0:
        self.s_idx = predictor_cfg["seed-idx"]
        mask_name = f"test_gt_mask_{self.gt_prob}_{self.s_idx}.txt"
        mask_file = os.path.join(self.base_path, "test", mask_name)
        with open(mask_file, 'r') as mask_in:
            self.gt_mask = mask_in.readlines()

    embs_configured = (
        "word-embs" in paths_cfg and paths_cfg["word-embs"].strip() != ""
    )
    if embs_configured:
        self.pad_val = self.special_tokens.index("<pad>")
        raw_embs, self.vocab = self.load_embeddings(
            paths_cfg["word-embs"], config["loader"]["wembs-limit"]
        )
        self.embeddings = raw_embs
        self.embeddings = self.normalize(
            self.embeddings, ["unit", "centeremb", "unit"]
        )
        self.w2idx = {token: pos for pos, token in enumerate(self.vocab)}
def load_file(self, path):
    """Unpickle the object stored at ``path`` and return it as a list.

    NOTE: ``pickle`` can execute arbitrary code while loading; only use
    this on files from a trusted source.
    """
    with open(path, 'rb') as handle:
        payload = pickle.load(handle)
        return list(payload)
def normalize(self, matrix, actions, mean=None):
    """Apply a sequence of normalization steps to a 2-D matrix.

    Args:
        matrix: 2-D numpy array (rows are indexed on axis 0).
        actions: iterable of step names applied in order. Recognized
            names are ``'unit'`` (scale every row to unit L2 norm),
            ``'center'`` (subtract ``mean``), ``'unitdim'`` (scale every
            column to unit L2 norm) and ``'centeremb'`` (subtract each
            row's own mean). Unrecognized names are ignored.
        mean: value subtracted by ``'center'``. When left as ``None`` the
            column-wise mean of the matrix at that step is used instead
            (previously ``'center'`` raised ``TypeError`` in that case).

    Returns:
        The normalized matrix; the input array is not modified in place.
    """
    def length_normalize(m):
        norms = np.sqrt(np.sum(m ** 2, axis=1))
        norms[norms == 0] = 1  # leave all-zero rows untouched
        return m / norms[:, np.newaxis]

    def mean_center(m):
        offset = np.mean(m, axis=0) if mean is None else mean
        return m - offset

    def length_normalize_dimensionwise(m):
        norms = np.sqrt(np.sum(m ** 2, axis=0))
        norms[norms == 0] = 1  # leave all-zero columns untouched
        return m / norms

    def mean_center_embeddingwise(m):
        row_avg = np.mean(m, axis=1)
        return m - row_avg[:, np.newaxis]

    steps = {
        'unit': length_normalize,
        'center': mean_center,
        'unitdim': length_normalize_dimensionwise,
        'centeremb': mean_center_embeddingwise,
    }
    for action in actions:
        step = steps.get(action)
        if step is not None:
            matrix = step(matrix)
    return matrix
def load_constants(self, path):
    """Copy the shared character tables from ``diac_utils`` onto the instance.

    ``path`` is accepted for interface compatibility but is not read here;
    the values come from the ``du`` module constants.
    """
    self.numbers, self.letter_list, self.diacritic_list = (
        du.NUMBERS,
        du.LETTER_LIST,
        du.DIACRITICS_SHORT,
    )
def split_word_on_characters_with_diacritics(self, word: str):
    """Forward ``word`` to ``du.split_word_on_characters_with_diacritics``."""
    pieces = du.split_word_on_characters_with_diacritics(word)
    return pieces
def load_mapping_v3(self, dtype, file_ext=None):
    """Read a ``.map`` file into a nested index dictionary.

    The file is ``<base_path>/<dtype>/<dtype><file_ext>``; when
    ``file_ext`` is ``None`` it defaults to
    ``-<test_stride>-<test_window>.map``. Every non-blank line must hold
    four comma-separated integers ``sent_idx,seg_idx,t_idx,c_idx``.

    Returns:
        ``{sent_idx: {seg_idx: {t_idx: [c_idx, ...]}}}`` with ``c_idx``
        values kept in file order.

    Raises:
        ValueError: if a non-blank line does not contain exactly four
            integers.
    """
    if file_ext is None:
        file_ext = f"-{self.test_stride}-{self.test_window}.map"
    f_name = os.path.join(self.base_path, dtype, dtype + file_ext)

    mapping = {}
    with open(f_name, 'r') as fin:
        for line in fin:
            # Blank lines (e.g. a trailing newline at end of file) carry
            # no record; skipping them avoids a ValueError from int('').
            if not line.strip():
                continue
            sent_idx, seg_idx, t_idx, c_idx = map(int, line.split(','))
            segments = mapping.setdefault(sent_idx, {})
            tokens = segments.setdefault(seg_idx, {})
            tokens.setdefault(t_idx, []).append(c_idx)
    return mapping
def load_mapping_v3_from_list(self, mapping_list):
    """Build the nested index dictionary from in-memory mapping lines.

    Args:
        mapping_list: iterable of strings, each holding four
            comma-separated integers ``sent_idx,seg_idx,t_idx,c_idx``.
            Empty or whitespace-only entries are skipped.

    Returns:
        ``{sent_idx: {seg_idx: {t_idx: [c_idx, ...]}}}`` with ``c_idx``
        values kept in input order.

    Raises:
        ValueError: if a non-blank entry does not contain exactly four
            integers.
    """
    mapping = {}
    for line in mapping_list:
        # Blank entries carry no record; skipping them avoids a
        # ValueError from int('').
        if not line.strip():
            continue
        sent_idx, seg_idx, t_idx, c_idx = map(int, line.split(','))
        segments = mapping.setdefault(sent_idx, {})
        tokens = segments.setdefault(seg_idx, {})
        tokens.setdefault(t_idx, []).append(c_idx)
    return mapping
def load_embeddings(self, embs_path, limit=-1):
    """Load word vectors from a text file whose first line is ``"<n> <d>"``.

    In debug mode nothing is read: an all-zero matrix with
    ``200 + len(special_tokens)`` rows and 300 columns is returned together
    with a placeholder vocabulary.

    Otherwise row 0 of the matrix is reserved (left as zeros) for the first
    special token and rows 1.. hold the vectors in file order. At most
    ``limit`` vectors are read; a non-positive ``limit`` means "use the
    count from the header line".

    Returns:
        ``(embeddings, words)`` where ``embeddings`` has ``limit + 1`` rows.
    """
    if self.debug:
        placeholder = np.zeros((200+len(self.special_tokens),300))
        return placeholder, self.special_tokens + ["c"] * 200

    vocab = [self.special_tokens[0]]
    print(f"[INFO] Reading Embeddings from {embs_path}")
    with open(embs_path, encoding='utf-8', mode='r') as fin:
        header = fin.readline().split()
        n, d = map(int, header)
        if limit <= 0:
            limit = n
        table = np.zeros((limit + 1, d))
        for row, line in tqdm(enumerate(fin), total=limit):
            if row >= limit:
                break
            fields = line.rstrip().split()
            vocab.append(fields[0])
            # +1 because row 0 belongs to the reserved special token.
            table[row + 1] = [float(value) for value in fields[1:]]
    return table, vocab
def load_file_clean(self, dtype, strip=False):
    """Read ``<base_path>/<dtype>/<dtype>.txt`` and preprocess every line.

    Each line is passed through ``self.preprocess``; when ``strip`` is
    truthy the result is additionally passed through ``strip_tashkeel``.

    Returns:
        A list with one processed string per line of the file.
    """
    f_name = os.path.join(self.base_path, dtype, dtype + ".txt")
    cleaned = []
    with open(f_name, 'r', encoding="utf-8", newline='\n') as fin:
        for raw_line in fin.readlines():
            text = self.preprocess(raw_line)
            if strip:
                text = strip_tashkeel(text)
            cleaned.append(text)
    return cleaned
def preprocess(self, line):
    """Tokenize ``line`` and re-join the tokens with single spaces."""
    pieces = tokenize(line)
    return ' '.join(pieces)
def pad_and_truncate_sequence(self, tokens, max_len, pad=None):
    """Return ``tokens`` right-padded or cut to exactly ``max_len`` items.

    Args:
        tokens: list to adjust; it is not modified.
        max_len: target length.
        pad: filler value; when ``None`` the index of ``"<pad>"`` in
            ``self.special_tokens`` is used.
    """
    filler = self.special_tokens.index("<pad>") if pad is None else pad
    shortfall = max_len - len(tokens)
    if shortfall > 0:
        return tokens + [filler] * shortfall
    return tokens[:max_len]
def stats(self, freq, percentile=90, name="stats"):
    """Print a one-row table of summary statistics for ``freq``.

    The row contains ``name`` followed by mean, standard deviation,
    minimum, maximum and the requested percentile of the values.
    """
    header = ["Dataset", "Mean", "Std", "Min", "Max", f"{percentile}th Percentile"]
    table = PrettyTable(header)
    values = np.array(sorted(freq))
    summary = [
        name,
        values.mean(),
        values.std(),
        values.min(),
        values.max(),
        np.percentile(values, percentile),
    ]
    table.add_row(summary)
    print(table)
def create_gt_mask(self, lines, prob, idx, seed=1111):
    """Sample a 0/1 mask per character and write it to the test folder.

    NumPy's global random generator is seeded with ``seed``. For every
    token of every line, one Bernoulli(``prob``) draw is made per character
    of the token; the digits of a token are concatenated and tokens are
    separated by a single space. The masks are written, one line per input
    line, to ``<base_path>/test/test_gt_mask_<prob>_<idx>.txt``.
    """
    np.random.seed(seed)
    gt_masks = []
    for line in lines:
        token_masks = [
            ''.join(str(bit) for bit in np.random.binomial(1, prob, len(token)))
            for token in tokenize(line.strip())
        ]
        gt_masks.append(" ".join(token_masks))

    out_name = f"test_gt_mask_{prob}_{idx}.txt"
    out_path = os.path.join(self.base_path, "test", out_name)
    with open(out_path, 'w') as fout:
        fout.write('\n'.join(gt_masks))
def create_gt_labels(self, lines):
    """Build one flat label list per input line.

    Every token of a line is split with
    ``split_word_on_characters_with_diacritics`` and passed to
    ``du.create_label_for_word``; the second element of its result is
    appended to the line's labels. A single ``0`` is inserted between
    consecutive tokens (not after the last one).
    """
    gt_labels = []
    for line in lines:
        tokens = tokenize(line.strip())
        last_pos = len(tokens) - 1
        line_labels = []
        for pos, word in enumerate(tokens):
            parts = self.split_word_on_characters_with_diacritics(word)
            _, cy_flat, _ = du.create_label_for_word(parts)
            line_labels.extend(cy_flat)
            if pos < last_pos:
                line_labels.append(0)
        gt_labels.append(line_labels)
    return gt_labels
def get_ce(self, diac_word_y, e_idx=None, return_idx=False):
    """Find the last entry before ``e_idx`` that differs from ``[0, 0, 0]``.

    Args:
        diac_word_y: sequence of per-character 3-element label lists.
        e_idx: exclusive upper bound of the search; defaults to the length
            of ``diac_word_y``.
        return_idx: return the position instead of the entry itself.

    Returns:
        The matching entry (or its index). When no entry in the searched
        range differs from ``[0, 0, 0]``, position ``e_idx - 1`` is used.
    """
    #^ diac_word_y: [Tw 3]
    stop = len(diac_word_y) if e_idx is None else e_idx
    empty = [0, 0, 0]
    candidates = (pos for pos in range(stop - 1, -1, -1) if diac_word_y[pos] != empty)
    found = next(candidates, stop - 1)
    if return_idx:
        return found
    return diac_word_y[found]
| def create_decoder_input(self, diac_code_y, prob=0): | |
| # Build the decoder-side input array from the label codes in diac_code_y. | |
| # The result starts as zeros with diac_code_y's shape, except that the last axis has size 8, | |
| # and is returned as a list (one array per entry along the first axis). | |
| # `prob` is only referenced by the commented-out lines inside the loop below. | |
| #^ diac_code_y: [Ts Tw 3] | |
| diac_code_x = np.zeros((*np.array(diac_code_y).shape[:-1], 8)) | |
| # Without the markov signal the all-zero array is returned unchanged. | |
| if not self.markov_signal: | |
| return list(diac_code_x) | |
| # Initial value: one-hot of the last of 6 classes followed by two zeros (8 values in total). | |
| prev_ce = list(np.eye(6)[-1]) + [0,0] # bos tag | |
| for w_idx, word in enumerate(diac_code_y): | |
| # Position 0 of every word receives the value carried over in prev_ce. | |
| diac_code_x[w_idx, 0, :] = prev_ce | |
| # The last character of the word is never used as an input (word[:-1]). | |
| for c_idx, char in enumerate(word[:-1]): | |
| # if np.random.rand() < prob: | |
| # continue | |
| # Stop at the first character whose first component equals the padding target value. | |
| if char[0] == self.pad_target_val: | |
| break | |
| # One-hot of char[0] over 6 classes, followed by char[1:], stored one position to the right. | |
| # `haraka + char[1:]` presumably relies on char being a Python list (list concatenation) - TODO confirm. | |
| haraka = list(np.eye(6)[char[0]]) | |
| diac_code_x[w_idx, c_idx+1, :] = haraka + char[1:] | |
| # NOTE(review): the original indentation is lost in this copy, so it is unclear whether the next | |
| # two lines run for every character or once per word after the inner loop - confirm against the | |
| # upstream file before refactoring. | |
| ce = self.get_ce(diac_code_y[w_idx], c_idx) | |
| prev_ce = list(np.eye(6)[ce[0]]) + ce[1:] | |
| return list(diac_code_x) | |