Spaces:
Runtime error
Runtime error
| from typing import Iterable, Union, Tuple | |
| from collections import Counter | |
| import argparse | |
| import os | |
| import yaml | |
| from pyarabic.araby import tokenize, strip_tatweel | |
| from tqdm import tqdm | |
| import numpy as np | |
| import torch as T | |
| from torch.utils.data import DataLoader | |
| from diac_utils import HARAKAT_MAP, shakkel_char, diac_ids_of_line | |
| from model_partial import PartialDD | |
| from model_dd import DiacritizerD2 | |
| from data_utils import DatasetUtils | |
| from dataloader import DataRetriever | |
| from segment import segment | |
class Predictor:
    """Shared set-up for diacritization inference on a single piece of text.

    On construction this:

    1. segments ``text`` with the ``config["segment"]`` settings
       (``stride``, ``window``, ``min-window``) and loads the resulting
       (sentence, segment, word, char) index records through ``DatasetUtils``;
    2. picks a torch device: ``config["predictor"]["device"]`` (default
       ``"cuda:0"``) when CUDA is available, otherwise ``"cpu"``;
    3. builds a ``DiacritizerD2`` model, restores the ``'state_dict'`` entry of
       the checkpoint at ``config["paths"]["load"]``, moves it to the device
       and switches it to eval mode;
    4. wraps the segments in an unshuffled ``DataLoader``
       (``predictor.batch-size`` default 32, ``loader.num-workers`` default 0).

    Attributes set: ``data_utils``, ``mapping``, ``original_lines``,
    ``segments``, ``device``, ``model``, ``data_loader``.
    """

    def __init__(self, config, text):
        self.data_utils = DatasetUtils(config)
        n_letters = len(self.data_utils.letter_list)
        embeddings = self.data_utils.embeddings

        # Split the single input line into overlapping windows.
        seg_cfg = config["segment"]
        segments, raw_mapping = segment(
            [text],
            seg_cfg["stride"],
            seg_cfg["window"],
            seg_cfg["min-window"],
        )

        # The mapping loader takes text records, so every index tuple coming
        # out of `segment` is rendered as a ", "-separated line first.
        records = [
            f"{sent}, {seg}, {word}, {char}"
            for sent, seg, word, char in raw_mapping
        ]
        self.mapping = self.data_utils.load_mapping_v3_from_list(records)
        self.original_lines = [text]
        self.segments = segments

        # The configured device is honoured only when CUDA is actually present.
        if T.cuda.is_available():
            device_name = config['predictor'].get('device', 'cuda:0')
        else:
            device_name = 'cpu'
        self.device = T.device(device_name)

        # Build the network and restore trained weights onto the chosen device.
        self.model = DiacritizerD2(config)
        self.model.build(embeddings, n_letters)
        checkpoint = T.load(config["paths"]["load"], map_location=T.device(self.device))
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.to(self.device)
        self.model.eval()

        self.data_loader = DataLoader(
            DataRetriever(self.data_utils, segments),
            batch_size=config["predictor"].get("batch-size", 32),
            shuffle=False,
            num_workers=config['loader'].get('num-workers', 0),
        )
class PredictTri(Predictor):
    """Predictor that merges overlapping segment outputs by per-character majority vote.

    The base class splits the input into overlapping segments, so one character
    of the tokenized line may be predicted by several segments. For each such
    character, every covering segment contributes the index of its
    ``(diacritic, tanween, shadda)`` triple in ``HARAKAT_MAP``. Index ``0`` is
    ignored, and the most frequent remaining index wins.
    """

    # On-disk cache used by `predict_majority_vote_context_contrastive`.
    _CACHE_DIR = "dataset/cache"
    _CACHE_FILE = "dataset/cache/cache.pt"

    def __init__(self, config, text):
        super().__init__(config, text)
        # Name -> id table for the basic harakat. It is not read by the methods
        # of this class (presumably kept for external callers).
        self.diacritics = {
            "FATHA": 1,
            "KASRA": 2,
            "DAMMA": 3,
            "SUKUN": 4,
        }
        # Tally that every `count_votes` call clears and refills.
        self.votes: Union[Counter[int], Counter[bool]] = Counter()

    def count_votes(
        self,
        things: Union[Iterable[int], Iterable[bool]]
    ):
        """Return the most frequent element of ``things``.

        Ties go to the element encountered first: ``Counter.most_common``
        orders equal counts by first occurrence.

        Raises:
            IndexError: if ``things`` is empty.
        """
        self.votes.clear()
        self.votes.update(things)
        return self.votes.most_common(1)[0][0]

    def predict_majority_vote(self):
        """Run the model over every segment and return the diacritized lines."""
        y_gen_diac, y_gen_tanween, y_gen_shadda = self.model.predict(self.data_loader)
        diacritized_lines, _ = self.coalesce_votes_by_majority(y_gen_diac, y_gen_tanween, y_gen_shadda)
        return diacritized_lines

    def predict_majority_vote_context_contrastive(self, overwrite_cache=False):
        """Partial-model variant that also returns per-line and per-segment extras.

        Segment outputs from ``predict_partial(..., eval_only='ctxt')`` are saved
        to ``dataset/cache/cache.pt``. Later calls reuse them unless
        ``overwrite_cache`` is true.

        NOTE(review): the cache is not keyed on the input text. Pass
        ``overwrite_cache=True`` whenever the text differs from the one that
        produced the cached file.

        Returns:
            ``(diacritized_lines, extra_out)``, where ``extra_out['line_data']``
            holds the extras from ``coalesce_votes_by_majority`` and
            ``extra_out['segment_data']`` is a shallow copy of the segment outputs.
        """
        assert isinstance(self.model, PartialDD)
        # Check for the same file that is written below. The old check looked
        # for a `y_gen_diac.npy` that is never produced, so the cache was
        # never reused.
        if overwrite_cache or not os.path.exists(self._CACHE_FILE):
            # makedirs also creates a missing parent `dataset/` directory;
            # plain os.mkdir raised FileNotFoundError in that case.
            os.makedirs(self._CACHE_DIR, exist_ok=True)
            segment_outputs = self.model.predict_partial(
                self.data_loader, return_extra=False, eval_only='ctxt'
            )
            T.save(segment_outputs, self._CACHE_FILE)
        else:
            segment_outputs = T.load(self._CACHE_FILE)

        y_gen_diac, y_gen_tanween, y_gen_shadda = segment_outputs['diacritics']
        diacritized_lines, extra_for_lines = self.coalesce_votes_by_majority(
            y_gen_diac, y_gen_tanween, y_gen_shadda,
        )
        extra_out = {
            'line_data': dict(extra_for_lines),
            'segment_data': dict(segment_outputs),
        }
        return diacritized_lines, extra_out

    def _collect_char_votes(self, mapping_s_i, char_idx, y_gen_diac, y_gen_tanween, y_gen_shadda):
        """Gather the non-zero ``HARAKAT_MAP`` votes for one character of a line.

        ``mapping_s_i`` is this line's entry of ``self.mapping``, indexed as
        ``[seg_idx][t_idx] -> list of line char indices``. A character's
        position in that list is the third index into the model outputs.

        Returns:
            ``(votes, hit_debug_limit)``. ``hit_debug_limit`` is true when
            debug mode stopped the scan at segment index 256.
        """
        votes = []
        for seg_idx in mapping_s_i:
            if self.data_utils.debug and seg_idx >= 256:
                return votes, True
            mapping_g_i = mapping_s_i[seg_idx]
            for t_idx in mapping_g_i:
                mapping_t_i = mapping_g_i[t_idx]
                if char_idx not in mapping_t_i:
                    continue
                c_idx = mapping_t_i.index(char_idx)
                output_idx = np.s_[seg_idx, t_idx, c_idx]
                diac_h3 = (y_gen_diac[output_idx], y_gen_tanween[output_idx], y_gen_shadda[output_idx])
                diac_char_i = HARAKAT_MAP.index(diac_h3)
                # NOTE(review): only positions < 13 inside a token may vote;
                # presumably the model's per-token character limit — confirm.
                if c_idx < 13 and diac_char_i != 0:
                    votes.append(diac_char_i)
        return votes, False

    def coalesce_votes_by_majority(
        self,
        y_gen_diac: np.ndarray,
        y_gen_tanween: np.ndarray,
        y_gen_shadda: np.ndarray,
    ):
        """Merge per-segment predictions into one diacritized string per input line.

        Args:
            y_gen_diac, y_gen_tanween, y_gen_shadda: model outputs indexed as
                ``[segment, token, char]``. The three values at one position
                form a key looked up in ``HARAKAT_MAP``.

        Returns:
            ``(diacritized_lines, extra)``. ``extra['diac_pred']`` is an int
            array with one row per processed line and one column per character.
            Each cell holds the winning ``HARAKAT_MAP`` index, or ``0`` when no
            segment voted. ``-1`` marks padding and lines absent from the mapping.
        """
        prepped_lines_og = [' '.join(tokenize(strip_tatweel(line))) for line in self.original_lines]
        max_line_chars = max((len(line) for line in prepped_lines_og), default=0)
        diacritics_pred = np.full((len(self.original_lines), max_line_chars), fill_value=-1, dtype=int)

        count_processed_sents = 0
        do_break = False
        diacritized_lines = []
        for sent_idx, line in enumerate(tqdm(prepped_lines_og)):
            count_processed_sents = sent_idx + 1
            line = line.strip()
            # Lines with no mapping entry are copied through undiacritized.
            mapping_s_i = self.mapping[sent_idx] if sent_idx in self.mapping else None
            pieces = []
            for char_idx, char in enumerate(line):
                pieces.append(char)
                if mapping_s_i is None:
                    continue
                char_vote_diacritic, do_break = self._collect_char_votes(
                    mapping_s_i, char_idx, y_gen_diac, y_gen_tanween, y_gen_shadda,
                )
                if do_break:
                    break
                if char_vote_diacritic:
                    char_mv_diac = self.count_votes(char_vote_diacritic)
                    pieces.append(shakkel_char(*HARAKAT_MAP[char_mv_diac]))
                    diacritics_pred[sent_idx, char_idx] = char_mv_diac
                else:
                    diacritics_pred[sent_idx, char_idx] = 0
            if do_break:
                # Debug limit hit: the partially processed line is not emitted.
                break
            diacritized_lines.append(''.join(pieces).strip())

        print(f'[INFO] Cutting stats from {len(diacritics_pred)} to {count_processed_sents}')
        extra = {
            'diac_pred': diacritics_pred[:count_processed_sents],
        }
        return diacritized_lines, extra