from typing import List, Tuple

import re

from transformers import PreTrainedTokenizerFast

import fast_disambig

# --- Pre-compiled normalization patterns (module level: compiled once) -------
_TATWEEL_RE = re.compile(r"\u0640")  # tatweel (kashida) elongation mark
_ALIF_RE = re.compile(r"[آأإٱ]")  # hamzated / wasla alif variants -> bare alif
_ALIF_MAK_RE = re.compile(r"ى")  # alif maksura -> yeh
_TEH_MARB_RE = re.compile(r"ة")  # teh marbuta -> heh
_ZERO_WIDTH_RE = re.compile(r"[\u200B-\u200D\u200E\u200F\uFEFF]")  # zero-width chars, bidi marks, BOM

# Combining marks treated as diacritics by separate_diacritics(): short vowels,
# tanween, shadda, sukun, and Quranic annotation signs.
ARABIC_DIACRITICS = {
    "ً", "ٌ", "ٍ", "َ", "ُ", "ِ", "ّ", "ْ",
    "ٗ", "٘", "ٙ", "ٚ", "ٛ", "ٜ", "ٝ", "ٞ", "ٟ",
    "ؐ", "ؑ", "ؒ", "ؓ", "ؔ", "ؕ", "ؖ", "ؗ", "ؘ", "ؙ", "ؚ",
    "ۖ", "ۗ", "ۘ", "ۙ", "ۚ", "ۛ", "ۜ",
    "۟", "۠", "ۡ", "ۢ", "ۣ", "ۤ", "ۧ", "ۨ",
    "۪", "۫", "۬", "ۭ",
}


def separate_diacritics(text):
    """Rewrite every diacritized token as "<base letters> <diacritic slots>".

    The text is split on runs of whitespace and on the literal "[+]" marker
    (presumably a segment separator from an upstream morphological splitter —
    confirm against the caller); both separators are preserved verbatim.
    For a token that carries Arabic diacritics, the bare letters are emitted
    first, then a space, then one slot per base letter: the diacritics
    attached to that letter, or the dotted-circle placeholder "◌" for an
    undiacritized letter.  Tokens with no diacritics pass through unchanged.
    """
    tokens = re.split(r'(\s+|\[\+\])', text)
    processed_tokens = []
    for token in tokens:
        if not token:
            # re.split with a capturing group yields empty strings between
            # adjacent separators; skip them.
            continue
        if token.isspace() or token == '[+]':
            processed_tokens.append(token)
            continue
        if not any(c in ARABIC_DIACRITICS for c in token):
            processed_tokens.append(token)
            continue
        base_chars = []
        diac_groups = []
        for char in token:
            if char in ARABIC_DIACRITICS:
                # A leading diacritic has no base letter: attach it to a
                # synthetic space so slots stay aligned with base letters.
                if not diac_groups:
                    base_chars.append(" ")
                    diac_groups.append([])
                diac_groups[-1].append(char)
            else:
                base_chars.append(char)
                diac_groups.append([])
        base_word = "".join(base_chars)
        diac_string = []
        for group in diac_groups:
            diac_string.append("".join(group) if group else "◌")
        processed_tokens.append(base_word + " " + "".join(diac_string))
    return "".join(processed_tokens)


def normalize_arabic(text):
    """Apply standard Arabic orthographic normalization.

    Removes tatweel and zero-width/bidi characters, then folds alif variants
    to bare alif, alif maksura to yeh, and teh marbuta to heh.
    """
    text = _TATWEEL_RE.sub("", text)
    text = _ZERO_WIDTH_RE.sub("", text)
    text = _ALIF_RE.sub("ا", text)
    text = _ALIF_MAK_RE.sub("ي", text)
    text = _TEH_MARB_RE.sub("ه", text)
    return text


class ArabicMorphTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that morphologically preprocesses Arabic input.

    Before delegating to the underlying fast tokenizer, every input string is
    optionally stemmed (via ``fast_disambig.camel.Stemmer`` — presumably a
    CAMeL-style morphological stemmer; confirm against that package), then
    orthographically normalized, then has its diacritics separated into a
    trailing diacritic string (see :func:`separate_diacritics`).

    A per-instance reentrancy flag (``_processing``) ensures preprocessing
    runs only once even when HF internals route one public entry point
    through another (e.g. ``__call__`` -> ``encode_plus``).
    """

    # No slow tokenizer counterpart exists for this class.
    slow_tokenizer_class = None

    def __init__(self, tokenizer_file=None, apply_stemming=True, **kwargs):
        """Build the tokenizer.

        apply_stemming sets the instance-wide default; it can be overridden
        per call with the ``apply_stemming=`` keyword on any encode method.
        """
        super().__init__(tokenizer_file=tokenizer_file, **kwargs)
        self.apply_stemming = apply_stemming
        # Reentrancy guard; the getattr() checks below keep working for
        # instances restored without going through __init__.
        self._processing = False
        # Only eagerly build the stemmer when stemming is on by default; it
        # is created lazily in _preprocess_one otherwise, so a per-call
        # apply_stemming=True override no longer raises AttributeError
        # (bug in the original: self.stemmer did not exist in that case).
        self.stemmer = fast_disambig.camel.Stemmer() if self.apply_stemming else None

    def _preprocess_one(self, s, do_stem):
        """Preprocess one string (recursing into nested lists/tuples)."""
        if isinstance(s, (list, tuple)):
            return [self._preprocess_one(x, do_stem) for x in s]
        if do_stem:
            if getattr(self, "stemmer", None) is None:
                # Lazy construction: supports per-call stemming on instances
                # whose default disabled it.
                self.stemmer = fast_disambig.camel.Stemmer()
            s = self.stemmer.stem(s, preserve_diacritics=True)
        s = normalize_arabic(s)
        s = separate_diacritics(s)
        return s

    def _preprocess_pair(self, text, text_pair, do_stem):
        """Preprocess a (text, text_pair) argument pair; non-str items pass through."""
        def maybe(s):
            return self._preprocess_one(s, do_stem) if isinstance(s, str) else s

        if isinstance(text, (list, tuple)):
            text = [maybe(x) for x in text]
        else:
            text = maybe(text)
        if isinstance(text_pair, (list, tuple)):
            text_pair = [maybe(x) for x in text_pair]
        else:
            text_pair = maybe(text_pair)
        return text, text_pair

    def _pop_flag(self, kwargs):
        """Pop a per-call ``apply_stemming`` override; fall back to the instance default."""
        v = kwargs.pop("apply_stemming", None)
        return self.apply_stemming if v is None else bool(v)

    def __call__(self, text=None, text_pair=None, *args, **kwargs):
        flag = self._pop_flag(kwargs)
        if not getattr(self, "_processing", False):
            self._processing = True
            try:
                text, text_pair = self._preprocess_pair(text, text_pair, flag)
                return super().__call__(text=text, text_pair=text_pair, *args, **kwargs)
            finally:
                self._processing = False
        # Reentrant call (guard already held): input was already preprocessed.
        return super().__call__(text=text, text_pair=text_pair, *args, **kwargs)

    def encode(self, text, text_pair=None, *args, **kwargs):
        flag = self._pop_flag(kwargs)
        if not getattr(self, "_processing", False):
            self._processing = True
            try:
                text, text_pair = self._preprocess_pair(text, text_pair, flag)
                return super().encode(text, text_pair, *args, **kwargs)
            finally:
                self._processing = False
        # Reentrant call: input was already preprocessed.
        return super().encode(text, text_pair, *args, **kwargs)

    def encode_plus(self, text=None, text_pair=None, *args, **kwargs):
        flag = self._pop_flag(kwargs)
        if not getattr(self, "_processing", False):
            self._processing = True
            try:
                text, text_pair = self._preprocess_pair(text, text_pair, flag)
                return super().encode_plus(text=text, text_pair=text_pair, *args, **kwargs)
            finally:
                self._processing = False
        # Reentrant call: input was already preprocessed.
        return super().encode_plus(text=text, text_pair=text_pair, *args, **kwargs)

    def batch_encode_plus(self, batch_text_or_text_pairs=None, *args, **kwargs):
        flag = self._pop_flag(kwargs)
        if not getattr(self, "_processing", False):
            self._processing = True
            try:
                data = batch_text_or_text_pairs
                if isinstance(data, (list, tuple)):
                    new_data = []
                    for item in data:
                        if isinstance(item, (list, tuple)) and len(item) == 2:
                            # (text, text_pair) tuple -> preprocess both sides.
                            new_data.append(self._preprocess_pair(item[0], item[1], flag))
                        else:
                            new_data.append(self._preprocess_one(item, flag))
                    batch_text_or_text_pairs = new_data
                return super().batch_encode_plus(batch_text_or_text_pairs=batch_text_or_text_pairs, *args, **kwargs)
            finally:
                self._processing = False
        # Reentrant call: input was already preprocessed.
        return super().batch_encode_plus(batch_text_or_text_pairs=batch_text_or_text_pairs, *args, **kwargs)

    def preprocess(self, text, apply_stemming=None):
        """Run the full preprocessing pipeline without tokenizing.

        apply_stemming=None (the default) falls back to the instance-level
        setting; pass True/False to force stemming on or off for this call.
        (Original bug: the default was True, which made the None-fallback
        branch dead code and silently forced stemming.)
        """
        flag = self.apply_stemming if apply_stemming is None else bool(apply_stemming)
        return self._preprocess_one(text, flag)