"""
DF-Arc Tokenizer

Morphology-aware, dialect-inclusive tokenization for Arabic LLMs.

Every input string is pre-processed before it reaches the underlying fast
tokenizer, in this order:

1. ``ArabicNormalizer``: Unicode and orthographic normalization.
2. ``MorphologicalPreTokenizer``: rule-based prefix/stem/suffix splitting,
   with the segments of one word joined by ``_``.
3. ``PhraseMerger``: merging of known word n-grams into a single unit.
"""

import json
import logging
import os
import re
import unicodedata
from typing import List, Dict, Any, Optional, Tuple, Union

from transformers import PreTrainedTokenizerFast
from tokenizers import Tokenizer

logger = logging.getLogger(__name__)


class ArabicNormalizer:
    """Normalize Arabic text with configurable orthographic rules.

    ``normalize`` always applies NFKC normalization, removes URLs and e-mail
    addresses, and collapses whitespace; every other rule is toggled by the
    corresponding constructor flag (all enabled by default).
    """

    # Arabic harakat, fathatan (U+064B) through sukun (U+0652).
    DIACRITICS_PATTERN = re.compile(r'[\u064B-\u0652]')
    # Tatweel (kashida) elongation character.
    TATWEEL_PATTERN = re.compile(r'\u0640')
    # Alef with hamza above / hamza below / madda -> bare alef.
    ALEF_PATTERN = re.compile(r'[أإآ]')
    # Alef maksura -> yeh.
    YEH_PATTERN = re.compile(r'ى')
    # Teh marbuta -> heh.
    TEH_MARBUTA_PATTERN = re.compile(r'ة')
    # A non-digit character repeated three or more times in a row (e.g. the
    # elongated "جمييييل"). Digits are excluded on purpose: collapsing them
    # would corrupt numbers ("2000" -> "20").
    REPEATS_PATTERN = re.compile(r'(\D)\1{2,}')
    # "https..." is already covered by the "http" alternative.
    URL_PATTERN = re.compile(r'http\S+|www\S+')
    EMAIL_PATTERN = re.compile(r'\S+@\S+')
    WHITESPACE_PATTERN = re.compile(r'\s+')

    def __init__(self, unify_alef: bool = True, unify_yeh: bool = True,
                 unify_teh_marbuta: bool = True, remove_diacritics: bool = True,
                 remove_tatweel: bool = True, remove_repeats: bool = True):
        self.unify_alef = unify_alef
        self.unify_yeh = unify_yeh
        self.unify_teh_marbuta = unify_teh_marbuta
        self.remove_diacritics = remove_diacritics
        self.remove_tatweel = remove_tatweel
        self.remove_repeats = remove_repeats

    def normalize(self, text: str) -> str:
        """Return ``text`` normalized according to the enabled rules.

        Args:
            text: Raw input text. Falsy input (``""`` or ``None``) yields ``""``.

        Returns:
            The normalized text, with runs of whitespace collapsed to a single
            space and leading/trailing whitespace stripped.
        """
        if not text:
            return ""
        text = unicodedata.normalize("NFKC", text)
        # URLs and e-mail addresses are dropped entirely, not normalized.
        text = self.URL_PATTERN.sub('', text)
        text = self.EMAIL_PATTERN.sub('', text)
        if self.remove_diacritics:
            text = self.DIACRITICS_PATTERN.sub('', text)
        if self.remove_tatweel:
            text = self.TATWEEL_PATTERN.sub('', text)
        if self.unify_alef:
            text = self.ALEF_PATTERN.sub('ا', text)
        if self.unify_yeh:
            text = self.YEH_PATTERN.sub('ي', text)
        if self.unify_teh_marbuta:
            text = self.TEH_MARBUTA_PATTERN.sub('ه', text)
        if self.remove_repeats:
            # Keep a single copy of the repeated character.
            text = self.REPEATS_PATTERN.sub(r'\1', text)
        return self.WHITESPACE_PATTERN.sub(' ', text).strip()


class MorphologicalPreTokenizer:
    """Rule-based Arabic morphological pre-tokenizer.

    Splits each word made only of Arabic-block characters into at most three
    units: one prefix, a stem and one suffix. The longest matching affix is
    preferred, and an affix is only stripped if the remaining stem keeps at
    least ``min_stem_length`` characters.
    """

    PREFIXES = ['و', 'ف', 'ب', 'ك', 'ل', 'ال', 'س', 'وال', 'بال', 'كال', 'لل', 'فال']
    SUFFIXES = ['ني', 'نا', 'ك', 'كم', 'ه', 'ها', 'هم', 'هن', 'ي', 'ون', 'ين', 'ان', 'ت', 'وا', 'ة']

    def __init__(self, min_stem_length: int = 2):
        self.min_stem_length = min_stem_length
        # Longest affixes first so that e.g. "وال" wins over "و".
        self.prefixes = sorted(self.PREFIXES, key=len, reverse=True)
        self.suffixes = sorted(self.SUFFIXES, key=len, reverse=True)
        # Whole-word match against the Arabic Unicode block (U+0600-U+06FF).
        self.arabic_pattern = re.compile(r'[\u0600-\u06FF]+')

    def segment_word(self, word: str) -> List[str]:
        """Split ``word`` into ``[prefix?, stem, suffix?]``.

        Words that are empty, contain any non-Arabic-block character, or are
        shorter than ``min_stem_length`` are returned unchanged as ``[word]``.
        """
        if not word or not self.arabic_pattern.fullmatch(word):
            return [word]
        # Too short to strip anything while keeping a valid stem.
        if len(word) < self.min_stem_length:
            return [word]

        min_len = self.min_stem_length
        prefix = next(
            (p for p in self.prefixes
             if word.startswith(p) and len(word) - len(p) >= min_len),
            "",
        )
        stem = word[len(prefix):]
        suffix = next(
            (s for s in self.suffixes
             if stem.endswith(s) and len(stem) - len(s) >= min_len),
            "",
        )
        if suffix:
            stem = stem[:-len(suffix)]

        segments = [prefix] if prefix else []
        segments.append(stem)
        if suffix:
            segments.append(suffix)
        return segments

    def segment_text(self, text: str) -> str:
        """Segment every whitespace-separated word of ``text``.

        The segments of one word are joined with ``_`` and words are re-joined
        with single spaces.
        """
        return ' '.join('_'.join(self.segment_word(word)) for word in text.split())


class PhraseMerger:
    """Detect and merge known word n-grams (phrases) into single units."""

    def __init__(self, phrases_file: Optional[str] = None):
        # Maps a tuple of words to the value stored for it in the phrases file.
        self.phrase_vocab: Dict[Tuple[str, ...], Any] = {}
        # Longest n-gram length tried when merging (at least 3).
        self.max_ngram = 3
        # Words of a merged phrase are concatenated with this separator.
        self.merge_char = ""
        if phrases_file:
            self.load_phrases(phrases_file)

    def load_phrases(self, path: str) -> None:
        """Load the phrase vocabulary from a JSON file.

        The file must hold a JSON object whose keys are space-separated
        phrases; the values are stored but not used for matching (presumably
        frequencies).

        A missing file is tolerated: a warning is logged and the current
        vocabulary is left untouched.

        Raises:
            ValueError: if the JSON top-level value is not an object.
        """
        try:
            with open(path, 'r', encoding='utf-8') as f:
                loaded_vocab = json.load(f)
        except FileNotFoundError:
            logger.warning("Phrases file %r not found; no phrases loaded.", path)
            return
        if not isinstance(loaded_vocab, dict):
            raise ValueError(
                f"Phrases file {path!r} must contain a JSON object mapping "
                f"phrases to values, got {type(loaded_vocab).__name__}."
            )
        self.phrase_vocab = {
            tuple(phrase_str.split()): freq
            for phrase_str, freq in loaded_vocab.items()
        }
        self.max_ngram = max(
            [self.max_ngram, *(len(ngram) for ngram in self.phrase_vocab)]
        )

    def merge_phrases(self, text: str) -> str:
        """Greedily merge known n-grams in ``text``, longest match first.

        Only n-grams of length 2..``max_ngram`` are considered. With an empty
        vocabulary ``text`` is returned as-is; otherwise words are re-joined
        with single spaces.
        """
        if not self.phrase_vocab:
            return text
        words = text.split()
        result = []
        i = 0
        while i < len(words):
            for n in range(min(self.max_ngram, len(words) - i), 1, -1):
                ngram = tuple(words[i:i + n])
                if ngram in self.phrase_vocab:
                    result.append(self.merge_char.join(ngram))
                    i += n
                    break
            else:
                # No phrase starts here: keep the single word.
                result.append(words[i])
                i += 1
        return ' '.join(result)


class DFArcTokenizer(PreTrainedTokenizerFast):
    """
    DF-Arc: Morphology-aware Arabic Tokenizer.

    Wrapper around PreTrainedTokenizerFast that applies custom normalization,
    morphological segmentation, and phrase merging before tokenization.

    Preprocessing is done in ``_batch_encode_plus`` only. In transformers 4.x,
    ``__call__``, ``encode``, ``encode_plus``, ``batch_encode_plus`` and
    ``tokenize`` of a fast tokenizer all route through that method, so each
    input is preprocessed exactly once. (Verify this when upgrading
    transformers.)
    """

    # An underscore touching an Arabic-block character on either side, i.e.
    # the joiner inserted by MorphologicalPreTokenizer.segment_text.
    _MORPH_JOINER_PATTERN = re.compile(r'(?<=[\u0600-\u06FF])_|_(?=[\u0600-\u06FF])')

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        tokenizer_file: Optional[str] = None,
        phrases_file: Optional[str] = None,
        normalization_config: Optional[Dict[str, bool]] = None,
        min_stem_length: int = 2,
        **kwargs
    ):
        """Build the preprocessing helpers, then the underlying fast tokenizer.

        Args:
            vocab_file: Forwarded to ``PreTrainedTokenizerFast``.
            tokenizer_file: Forwarded to ``PreTrainedTokenizerFast``.
            phrases_file: Optional JSON phrase vocabulary for ``PhraseMerger``.
            normalization_config: Keyword flags for ``ArabicNormalizer``.
            min_stem_length: Minimum stem length for morphological splitting.
            **kwargs: Forwarded to ``PreTrainedTokenizerFast``.
        """
        # Created before super().__init__(), presumably so they already exist
        # if the base initializer calls an overridden method; keep this order.
        self.normalizer_helper = ArabicNormalizer(**(normalization_config or {}))
        self.morph_helper = MorphologicalPreTokenizer(min_stem_length=min_stem_length)
        self.phrase_helper = PhraseMerger(phrases_file=phrases_file)
        super().__init__(
            vocab_file=vocab_file,
            tokenizer_file=tokenizer_file,
            **kwargs
        )

    def _apply_pipeline(self, text: Optional[str]) -> str:
        """Run normalization, segmentation and phrase merging on one string.

        Falsy input (``""`` or ``None``) yields ``""``.
        """
        if not text:
            return ""
        text = self.normalizer_helper.normalize(text)
        text = self.morph_helper.segment_text(text)
        return self.phrase_helper.merge_phrases(text)

    def _apply_pipeline_to_words(self, words: Union[List[Any], Tuple[Any, ...]]) -> List[Any]:
        """Preprocess each string of a pre-tokenized word sequence separately.

        Working word-by-word keeps the word count unchanged, so word-level
        alignment (``word_ids``) still refers to the caller's words.
        """
        return [self._apply_pipeline(w) if isinstance(w, str) else w for w in words]

    def _apply_pipeline_to_item(self, item: Any, is_split_into_words: bool) -> Any:
        """Preprocess one batch element: a text, a text pair, or pre-tokenized words."""
        if isinstance(item, str):
            return self._apply_pipeline(item)
        if not isinstance(item, (list, tuple)):
            return item
        if is_split_into_words:
            # Either a pair of word lists, or a single list of words.
            if item and isinstance(item[0], (list, tuple)):
                return tuple(self._apply_pipeline_to_words(part) for part in item)
            return self._apply_pipeline_to_words(item)
        # Regular (text, text_pair) tuple.
        return (self._apply_pipeline(item[0]), self._apply_pipeline(item[1]))

    def _batch_encode_plus(self, batch_text_or_text_pairs: Union[str, List[str], List[Tuple[str, str]], List[List[str]]], *args, **kwargs):
        """Preprocess every element of the batch, then delegate to the base class.

        ``is_split_into_words`` (read from ``kwargs``) decides whether a list
        element is a pre-tokenized word list or a (text, text_pair) tuple.
        """
        is_split_into_words = bool(kwargs.get("is_split_into_words", False))
        batch = batch_text_or_text_pairs
        if isinstance(batch, str):
            batch = self._apply_pipeline(batch)
        elif isinstance(batch, (list, tuple)):
            batch = [self._apply_pipeline_to_item(item, is_split_into_words) for item in batch]
        return super()._batch_encode_plus(batch, *args, **kwargs)

    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=None, **kwargs):
        """
        Override decode to force use of convert_tokens_to_string for readable output.

        Args:
            token_ids: A single id or a sequence of ids.
            skip_special_tokens: Drop special tokens before joining.
            clean_up_tokenization_spaces: If truthy, additionally apply the
                base class ``clean_up_tokenization``. ``None`` (the default)
                leaves the text as produced by ``convert_tokens_to_string``.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            The decoded string.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
        text = self.convert_tokens_to_string(tokens)
        if clean_up_tokenization_spaces:
            text = self.clean_up_tokenization(text)
        return text

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens with spaces and drop morphological ``_`` joiners.

        Only underscores adjacent to an Arabic-block character are removed
        (e.g. 'و_كتاب' -> 'وكتاب'), so underscores between Latin characters,
        as in snake_case, are preserved.
        """
        return self._MORPH_JOINER_PATTERN.sub('', " ".join(tokens))