| | from typing import List, Tuple, Any |
| |
|
| | import os |
| | from functools import lru_cache |
| |
|
| | from pyarabic.araby import tokenize, strip_tashkeel |
| |
|
| | import numpy as np |
| | import torch as T |
| | from torch.utils.data import Dataset |
| |
|
| | try: |
| | from transformers import PreTrainedTokenizer |
| | except: |
| | from typing import Any as PreTrainedTokenizer |
| |
|
| | from data_utils import DatasetUtils |
| | import diac_utils as du |
| |
|
class DataRetriever(Dataset):
    """Map-style dataset that turns raw text lines into diacritization model inputs.

    Item ``idx`` is built from ``lines[idx]``:
      * the line is split into words with ``pyarabic`` ``tokenize``;
      * each word is split into characters and diacritic labels via ``diac_utils``;
      * each word, with its diacritics removed (``strip_tashkeel``), is fed to the subword
        tokenizer.
    The per-sentence sequences are then padded/truncated with
    ``data_utils.pad_and_truncate_sequence`` to the limits configured on ``data_utils``.

    Built items are memoized per instance in a bounded least-recently-used cache. Cached
    tensors are returned as the same objects on every hit, so callers must not modify
    them in place.
    """

    # Maximum number of built items kept in each instance's cache.
    _ITEM_CACHE_SIZE = 1024 * 2

    def __init__(
        self,
        lines,
        data_utils: DatasetUtils,
        is_test: bool = False,
        *,
        tokenizer: PreTrainedTokenizer,
        lines_mode: bool = False,
        **kwargs,
    ):
        """Store the raw lines and cache limits/padding values from ``data_utils``.

        Args:
            lines: Indexable sequence of raw text lines. Item ``i`` is built from
                ``lines[i]``.
            data_utils: Supplies the length limits (``max_token_count``,
                ``max_sent_len``, ``max_word_len``), the padding values
                (``pad_char_id``, ``pad_target_val``), ``test_stride``, and the
                ``pad_and_truncate_sequence`` / ``create_decoder_input`` helpers.
            is_test: Stored as ``self.is_test``. Not read inside this class.
            tokenizer: Subword tokenizer. It must expose ``pad_token_id``, plus
                ``bos_token_id`` or ``cls_token_id``, and ``eos_token_id`` or
                ``sep_token_id``.
            lines_mode: Accepted for call compatibility. Currently unused.
            **kwargs: Accepted for call compatibility. Currently ignored.

        Raises:
            ValueError: If the tokenizer defines neither a BOS nor a CLS token id,
                or neither an EOS nor a SEP token id.
        """
        super().__init__()

        self.data_utils = data_utils
        self.is_test = is_test
        self.tokenizer = tokenizer

        self.stride = data_utils.test_stride

        self.data_points = lines

        # Select with ``is not None`` rather than ``or``: token id 0 is a valid id
        # and must not be skipped in favour of the fallback.
        self.bos_token_id = self._first_token_id(
            "BOS/CLS", tokenizer.bos_token_id, tokenizer.cls_token_id
        )
        self.eos_token_id = self._first_token_id(
            "EOS/SEP", tokenizer.eos_token_id, tokenizer.sep_token_id
        )

        self.max_tokens = data_utils.max_token_count
        self.max_slen = data_utils.max_sent_len
        self.max_wlen = data_utils.max_word_len

        self.p_val = tokenizer.pad_token_id
        self.pc_val = data_utils.pad_char_id
        self.pt_val = data_utils.pad_target_val

        # Whole-word padding rows used when a sentence has fewer than max_slen words.
        self.char_x_padding = [self.pc_val] * self.max_wlen
        # Build a distinct inner list per row so the padding rows never alias each other.
        # NOTE(review): the width 8 presumably matches the per-character row width
        # produced by data_utils.create_decoder_input — confirm.
        self.diac_x_padding = [[self.pc_val] * 8 for _ in range(self.max_wlen)]
        self.diac_y_padding = [self.pt_val] * self.max_wlen

        # Per-instance LRU cache: idx -> built item. Plain dicts keep insertion order,
        # so the first key is always the least recently used one.
        self._item_cache = {}

    @staticmethod
    def _first_token_id(role, *candidates):
        """Return the first candidate token id that is not ``None``, as an ``int``."""
        for token_id in candidates:
            if token_id is not None:
                return int(token_id)
        raise ValueError(f"tokenizer defines no {role} token id")

    def preprocess(self, data, dtype=T.long):
        """Convert each nested list in ``data`` to a tensor of ``dtype``."""
        return [T.tensor(np.array(x), dtype=dtype) for x in data]

    def __len__(self):
        """Return the number of raw lines."""
        return len(self.data_points)

    def __getitem__(self, idx: int) -> Tuple[List[T.Tensor], T.Tensor, T.Tensor]:
        """Return ``([subwords_x, char_x, diac_x], diac_y, subword_lengths)`` as long tensors.

        Results are memoized per instance, holding at most ``_ITEM_CACHE_SIZE`` entries
        with least-recently-used eviction.
        """
        cache = self._item_cache
        if idx in cache:
            # Re-insert the entry to mark it as most recently used.
            item = cache.pop(idx)
            cache[idx] = item
            return item

        word_x, char_x, diac_x, diac_y, subword_lengths = self.create_sentence(idx)
        item = (
            self.preprocess([word_x, char_x, diac_x]),
            T.tensor(diac_y, dtype=T.long),
            T.tensor(subword_lengths, dtype=T.long),
        )

        cache[idx] = item
        if len(cache) > self._ITEM_CACHE_SIZE:
            # Evict the least recently used entry (the first key in insertion order).
            del cache[next(iter(cache))]
        return item

    def create_sentence(self, idx):
        """Build the padded (not yet tensorized) features for line ``idx``.

        Returns:
            ``(subwords_x, char_x, diac_x, diac_y, subword_lengths)``:

            * ``subwords_x``: BOS, then every word's subword ids, then EOS,
              padded/truncated to ``max_tokens`` with the tokenizer's pad id.
            * ``char_x`` and ``diac_y``: one row per word, each row padded/truncated to
              ``max_wlen``, then padded/truncated to ``max_slen`` rows.
            * ``diac_x``: the output of ``data_utils.create_decoder_input`` on the
              3-head labels, padded/truncated to ``max_slen`` rows.
            * ``subword_lengths``: the number of subword ids each word contributed,
              padded with 0 to ``max_slen``.
        """
        line = self.data_points[idx]
        words: List[str] = tokenize(line.strip())

        pad_seq = self.data_utils.pad_and_truncate_sequence

        subwords_x = [self.bos_token_id]
        subword_lengths = []
        char_x = []
        diac_y = []
        diac_y_3head = []

        for word in words:
            word = du.strip_unknown_tashkeel(word)
            word_chars = du.split_word_on_characters_with_diacritics(word)
            cx, cy, cy_3head = du.create_label_for_word(word_chars)

            # The tokenizer sees the word with its diacritics removed.
            # [1:-1] drops one leading and one trailing id, on the assumption that the
            # tokenizer wraps every call in exactly one special token on each side.
            # NOTE(review): confirm this holds for the tokenizer in use.
            word_sub_ids = self.tokenizer(strip_tashkeel(word))['input_ids'][1:-1]
            subword_lengths.append(len(word_sub_ids))
            subwords_x += word_sub_ids

            char_x.append(pad_seq(cx, self.max_wlen))
            diac_y.append(pad_seq(cy, self.max_wlen, pad=self.pt_val))
            diac_y_3head.append(pad_seq(cy_3head, self.max_wlen, pad=[self.pt_val] * 3))

        # Every word contributes exactly one entry to each per-word list.
        assert len(char_x) == len(subword_lengths) == len(words), (
            f"{char_x=}; {subword_lengths=} ;;"
        )

        diac_x = self.data_utils.create_decoder_input(diac_y_3head)

        subwords_x.append(self.eos_token_id)

        # NOTE(review): if the subword sequence is longer than max_tokens, truncation
        # presumably drops the trailing EOS, and sum(subword_lengths) can exceed the
        # number of ids kept — confirm that downstream code tolerates this.
        subwords_x = pad_seq(subwords_x, self.max_tokens, pad=self.p_val)
        subword_lengths = pad_seq(subword_lengths, self.max_slen, pad=0)

        char_x = pad_seq(char_x, self.max_slen, pad=self.char_x_padding)
        diac_x = pad_seq(diac_x, self.max_slen, pad=self.diac_x_padding)
        diac_y = pad_seq(diac_y, self.max_slen, pad=self.diac_y_padding)

        return subwords_x, char_x, diac_x, diac_y, subword_lengths