| from typing import List, Tuple |
| from transformers import PreTrainedTokenizerFast |
| import re |
| import fast_disambig |
|
|
# Pre-compiled single-character patterns used by normalize_arabic().
_TATWEEL_RE = re.compile(r"\u0640")  # tatweel/kashida elongation mark (removed)
_ALIF_RE = re.compile(r"[آأإٱ]")  # hamza/madda/wasla alif variants -> bare alif
_ALIF_MAK_RE = re.compile(r"ى")  # alif maqsura -> yeh
_TEH_MARB_RE = re.compile(r"ة")  # teh marbuta -> heh
_ZERO_WIDTH_RE = re.compile(r"[\u200B-\u200D\u200E\u200F\uFEFF]")  # zero-width chars, LRM/RLM, BOM (removed)
# Combining marks treated as diacritics by separate_diacritics().
ARABIC_DIACRITICS = {
    "ً", "ٌ", "ٍ",  # tanwin: fathatan, dammatan, kasratan
    "َ", "ُ", "ِ",  # short vowels: fatha, damma, kasra
    "ّ", "ْ",  # shadda, sukun
    "ٗ", "٘", "ٙ", "ٚ", "ٛ", "ٜ", "ٝ", "ٞ", "ٟ",  # U+0657-U+065F extended vowel signs
    "ؐ", "ؑ", "ؒ", "ؓ", "ؔ", "ؕ", "ؖ", "ؗ", "ؘ", "ؙ", "ؚ",  # U+0610-U+061A honorific/Quranic signs
    "ۖ", "ۗ", "ۘ", "ۙ", "ۚ", "ۛ", "ۜ", "۟", "۠", "ۡ", "ۢ", "ۣ", "ۤ", "ۧ", "ۨ",  # Quranic annotation marks (U+06D6 block)
    "۪", "۫", "۬", "ۭ",  # U+06EA-U+06ED low/high marks
}
|
|
def separate_diacritics(text):
    """Split each word into its base letters followed by its diacritics.

    Every word containing Arabic diacritics is rewritten as
    ``"<letters> <marks>"`` where ``<marks>`` holds one slot per base
    letter: the letter's diacritic run, or ``"◌"`` if it carries none.
    Whitespace runs, ``[+]`` markers, and diacritic-free words pass
    through untouched.
    """
    pieces = re.split(r'(\s+|\[\+\])', text)
    out = []

    for piece in pieces:
        if not piece:
            continue
        # Separators are preserved verbatim.
        if piece == '[+]' or piece.isspace():
            out.append(piece)
            continue

        # Fast path: nothing to separate.
        if all(ch not in ARABIC_DIACRITICS for ch in piece):
            out.append(piece)
            continue

        letters = []
        slots = []  # one list of diacritics per entry in `letters`

        for ch in piece:
            if ch in ARABIC_DIACRITICS:
                if not slots:
                    # Leading diacritic with no base letter: attach it
                    # to a synthetic space carrier.
                    letters.append(" ")
                    slots.append([])
                slots[-1].append(ch)
            else:
                letters.append(ch)
                slots.append([])

        marks = "".join("".join(group) if group else "◌" for group in slots)
        out.append("".join(letters) + " " + marks)
    return "".join(out)
|
|
# Single-pass translation table equivalent to the chained regex
# substitutions (all source character sets are disjoint, so order never
# mattered): None deletes the character, a string replaces it.
_NORMALIZATION_TABLE = str.maketrans({
    "\u0640": None,  # tatweel/kashida
    "\u200B": None, "\u200C": None, "\u200D": None,  # zero-width chars
    "\u200E": None, "\u200F": None,  # LRM / RLM
    "\uFEFF": None,  # BOM / zero-width no-break space
    "آ": "ا", "أ": "ا", "إ": "ا", "ٱ": "ا",  # alif variants -> bare alif
    "ى": "ي",  # alif maqsura -> yeh
    "ة": "ه",  # teh marbuta -> heh
})


def normalize_arabic(text):
    """Normalize Arabic orthography in one C-level pass.

    Removes tatweel and zero-width/direction characters, folds the alif
    variants to bare alif, alif maqsura to yeh, and teh marbuta to heh.
    Diacritics and all other characters are left untouched.

    Args:
        text: input string.

    Returns:
        The normalized string.
    """
    # str.translate replaces five sequential regex passes with a single
    # pass; the mapping is identical.
    return text.translate(_NORMALIZATION_TABLE)
|
|
class ArabicMorphTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that preprocesses Arabic text before tokenizing.

    The pipeline applied to every plain-string input is: optional
    stemming (``fast_disambig``), orthographic normalization
    (:func:`normalize_arabic`), and diacritic separation
    (:func:`separate_diacritics`).  The per-call keyword
    ``apply_stemming`` (popped before delegating to the superclass)
    overrides the instance-level default set at construction time.
    """

    slow_tokenizer_class = None

    def __init__(self, tokenizer_file=None, apply_stemming=True, **kwargs):
        """
        Args:
            tokenizer_file: path to the serialized fast-tokenizer file,
                forwarded to ``PreTrainedTokenizerFast``.
            apply_stemming: default for whether the stemmer runs during
                preprocessing; overridable per call.
        """
        super().__init__(tokenizer_file=tokenizer_file, **kwargs)
        self.apply_stemming = apply_stemming
        # Re-entrancy guard: the HF encode entry points can route through
        # each other internally; the guard prevents preprocessing the
        # same text twice.  Initialized explicitly instead of relying on
        # getattr defaults.
        self._processing = False
        if self.apply_stemming:
            self.stemmer = fast_disambig.camel.Stemmer()

    def _ensure_stemmer(self):
        # Lazily create the stemmer so a per-call apply_stemming=True
        # works even when the instance default is False (previously that
        # raised AttributeError because __init__ skipped creating it).
        if not hasattr(self, "stemmer"):
            self.stemmer = fast_disambig.camel.Stemmer()
        return self.stemmer

    def _preprocess_one(self, s, do_stem):
        """Stem (optionally), normalize, and diacritic-split one string.

        Lists/tuples are processed element-wise, recursively.
        """
        if isinstance(s, (list, tuple)):
            return [self._preprocess_one(x, do_stem) for x in s]
        if do_stem:
            s = self._ensure_stemmer().stem(s, preserve_diacritics=True)
        s = normalize_arabic(s)
        s = separate_diacritics(s)
        return s

    def _preprocess_pair(self, text, text_pair, do_stem):
        """Preprocess a ``(text, text_pair)`` argument pair.

        Only plain strings are transformed; non-str elements inside a
        batch list (e.g. pre-tokenized input) pass through unchanged.
        """
        def maybe(s):
            return self._preprocess_one(s, do_stem) if isinstance(s, str) else s
        if isinstance(text, (list, tuple)):
            text = [maybe(x) for x in text]
        else:
            text = maybe(text)
        if isinstance(text_pair, (list, tuple)):
            text_pair = [maybe(x) for x in text_pair]
        else:
            text_pair = maybe(text_pair)
        return text, text_pair

    def _pop_flag(self, kwargs):
        # Pop the per-call override; None means "use the instance default".
        v = kwargs.pop("apply_stemming", None)
        return self.apply_stemming if v is None else bool(v)

    def __call__(self, text=None, text_pair=None, *args, **kwargs):
        """Preprocess, then delegate to ``PreTrainedTokenizerFast.__call__``."""
        flag = self._pop_flag(kwargs)
        if not getattr(self, "_processing", False):
            self._processing = True
            try:
                text, text_pair = self._preprocess_pair(text, text_pair, flag)
                return super().__call__(text=text, text_pair=text_pair, *args, **kwargs)
            finally:
                self._processing = False
        # Re-entrant call: input was already preprocessed upstream.
        return super().__call__(text=text, text_pair=text_pair, *args, **kwargs)

    def encode(self, text, text_pair=None, *args, **kwargs):
        """Preprocess, then delegate to ``PreTrainedTokenizerFast.encode``."""
        flag = self._pop_flag(kwargs)
        if not getattr(self, "_processing", False):
            self._processing = True
            try:
                text, text_pair = self._preprocess_pair(text, text_pair, flag)
                return super().encode(text, text_pair, *args, **kwargs)
            finally:
                self._processing = False
        # Re-entrant call: input was already preprocessed upstream.
        return super().encode(text, text_pair, *args, **kwargs)

    def encode_plus(self, text=None, text_pair=None, *args, **kwargs):
        """Preprocess, then delegate to ``PreTrainedTokenizerFast.encode_plus``."""
        flag = self._pop_flag(kwargs)
        if not getattr(self, "_processing", False):
            self._processing = True
            try:
                text, text_pair = self._preprocess_pair(text, text_pair, flag)
                return super().encode_plus(text=text, text_pair=text_pair, *args, **kwargs)
            finally:
                self._processing = False
        # Re-entrant call: input was already preprocessed upstream.
        return super().encode_plus(text=text, text_pair=text_pair, *args, **kwargs)

    def batch_encode_plus(self, batch_text_or_text_pairs=None, *args, **kwargs):
        """Preprocess each batch item, then delegate to the superclass.

        A 2-tuple/2-list item is treated as a (text, text_pair) pair;
        anything else is preprocessed as a single input.
        """
        flag = self._pop_flag(kwargs)
        if not getattr(self, "_processing", False):
            self._processing = True
            try:
                data = batch_text_or_text_pairs
                if isinstance(data, (list, tuple)):
                    new_data = []
                    for item in data:
                        if isinstance(item, (list, tuple)) and len(item) == 2:
                            new_data.append(self._preprocess_pair(item[0], item[1], flag))
                        else:
                            new_data.append(self._preprocess_one(item, flag))
                    batch_text_or_text_pairs = new_data
                return super().batch_encode_plus(batch_text_or_text_pairs=batch_text_or_text_pairs, *args, **kwargs)
            finally:
                self._processing = False
        # Re-entrant call: input was already preprocessed upstream.
        return super().batch_encode_plus(batch_text_or_text_pairs=batch_text_or_text_pairs, *args, **kwargs)

    def preprocess(self, text, apply_stemming=None):
        """Run the preprocessing pipeline without tokenizing.

        Args:
            text: a string, or a (possibly nested) list/tuple of strings.
            apply_stemming: ``None`` (default) uses the instance setting,
                matching the semantics of the encode-time keyword; a
                boolean forces it for this call.  (The previous default
                of ``True`` made the ``None`` fallback unreachable and
                crashed instances built with ``apply_stemming=False``,
                which have no stemmer.)

        Returns:
            The preprocessed string(s).
        """
        flag = self.apply_stemming if apply_stemming is None else bool(apply_stemming)
        return self._preprocess_one(text, flag)
|
|