# Source: AraSpell-Model / AraSpell.py (Hugging Face Hub file page)
# Uploaded by youssefreda9 — "Upload AraSpell.py", commit 2f99d61 (verified), 98.8 kB
# AraSpell — Arabic Spell Checker Pipeline
# Production-ready version
import re
import math
import logging
import torch
import os
from collections import Counter
from transformers import AutoTokenizer, EncoderDecoderModel
import Levenshtein
import jellyfish
# Configure the ROOT logger on import (INFO level, "LEVEL: message" format) so
# this module's logger.info progress messages are emitted. Note that this
# affects every logger in the host process that has no handler of its own.
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
# ═══════════════════════════════════════════════════════════════════════════════
# LOAD ARABERT SEQ2SEQ MODEL
# ═══════════════════════════════════════════════════════════════════════════════
from huggingface_hub import hf_hub_download
MODEL_REPO = 'bayan10/AraSpell-Model'
MODEL_FILENAME = 'last_model.pt'
# Fetch the fine-tuned checkpoint; hf_hub_download returns a local file path.
try:
    logger.info("Downloading/loading model from Hugging Face: %s", MODEL_REPO)
    MODEL_PATH = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
except Exception as e:
    # Chain the original error so the underlying cause stays in the traceback.
    raise RuntimeError(f"Failed to download model from Hugging Face: {e}") from e

# The architecture is rebuilt as an encoder-decoder from the AraBERT base
# checkpoint; the fine-tuned weights are loaded into it further below.
MODEL_NAME = 'aubmindlab/bert-base-arabertv02'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = EncoderDecoderModel.from_encoder_decoder_pretrained(MODEL_NAME, MODEL_NAME)

# [CLS] is used as the decoder start token and [SEP] as end-of-sequence.
# Set on both the model config and the generation config so generate()
# uses the same special tokens as training.
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.sep_token_id
model.generation_config.max_length = 128
model.generation_config.decoder_start_token_id = tokenizer.cls_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id
model.generation_config.eos_token_id = tokenizer.sep_token_id

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# SECURITY NOTE(review): weights_only=False unpickles arbitrary Python objects
# from the downloaded file, i.e. loading it can execute code. This is only
# acceptable while MODEL_REPO is trusted; switch to weights_only=True if the
# checkpoint contains nothing but tensors and primitive metadata.
checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=False)
# strict=False tolerates key mismatches between checkpoint and architecture.
# Report them instead of hiding them: unmatched parameters silently keep their
# initial (pretrained or freshly initialised) values.
_load_report = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
if _load_report.missing_keys:
    logger.warning(
        "Checkpoint is missing %d model parameter(s); they keep their initial values. First few: %s",
        len(_load_report.missing_keys), _load_report.missing_keys[:5],
    )
if _load_report.unexpected_keys:
    logger.warning(
        "Checkpoint has %d unexpected key(s) that were ignored. First few: %s",
        len(_load_report.unexpected_keys), _load_report.unexpected_keys[:5],
    )
model = model.to(device)
model.eval()
logger.info("Model loaded on %s, epoch: %s", device, checkpoint.get('epoch', 'N/A'))
from enum import Enum
from typing import List, Tuple, Optional
# ─────────────────────────────────────────────────────────────────────────────
# ERROR TYPE ENUM
# ─────────────────────────────────────────────────────────────────────────────
class ErrorType(Enum):
    """Category of spelling problem, as produced by ErrorClassifier.classify.

    Member values are plain lowercase strings, so a member can also be
    looked up by value, e.g. ``ErrorType("mixed")``.
    """

    CHAR_REPETITION = "char_repetition"      # one character repeated excessively
    WORD_MERGE = "word_merge"                # words run together (over-long token)
    CHAR_SUBSTITUTION = "char_substitution"  # non-Arabic keyboard letter present
    MIXED = "mixed"                          # two or more of the above at once
    CLEAN = "clean"                          # nothing detected
# ═══════════════════════════════════════════════════════════════════════════════
# POST PROCESSOR
# ═══════════════════════════════════════════════════════════════════════════════
class AraSpellPostProcessor:
    """Rule-based clean-up for Arabic text produced by the correction model.

    Every method is a ``@staticmethod`` mapping ``str -> str``; the class is
    only a namespace. :meth:`full_postprocess` chains the steps in the order
    the pipeline uses.
    """

    # Arabic diacritics (tanween, short vowels, shadda, sukun).
    ARABIC_HARAKAT = 'ًٌٍَُِّْ'
    # Kashida / tatweel elongation character.
    TATWEEL = 'ـ'
    # Presentation-form ligatures -> the letter sequences they stand for.
    NORMALIZER_MAP = {
        'ﻹ': 'لإ', 'ﻷ': 'لأ', 'ﻵ': 'لآ', 'ﻻ': 'لا', 'ﷲ': 'الله'
    }
    # A final ه is only considered a misspelt ة when it follows one of these.
    ARABIC_CONSONANTS = set('بتثجحخدذرزسشصضطظعغفقكلمن')

    # Function words that join_fragments never glues to each other.
    _STANDALONE_WORDS = frozenset({
        'من', 'في', 'على', 'عن', 'مع', 'إلى', 'الى', 'حتى', 'منذ', 'خلال',
        'بعد', 'قبل', 'ب', 'ل', 'ك', 'و', 'أو', 'لا', 'ما', 'لم', 'لن',
        'هو', 'هي', 'هم', 'أن', 'إن', 'كل', 'كان', 'قد', 'قال', 'ذلك',
        'هذا', 'هذه', 'تلك', 'التي', 'الذي', 'التى', 'اللذي'
    })

    # --- Basic normalization ------------------------------------------------

    @staticmethod
    def remove_harakat(text: str) -> str:
        """Strip Arabic diacritics (the fathatan..sukun character range)."""
        return re.sub(r'[ً-ْ]', '', text)

    @staticmethod
    def remove_tatweel(text: str) -> str:
        """Delete every kashida/tatweel character."""
        return text.replace(AraSpellPostProcessor.TATWEEL, '')

    @staticmethod
    def normalize_special_chars(text: str) -> str:
        """Expand the presentation-form ligatures listed in NORMALIZER_MAP."""
        for ligature, expansion in AraSpellPostProcessor.NORMALIZER_MAP.items():
            text = text.replace(ligature, expansion)
        return text

    # --- Core functions -----------------------------------------------------

    @staticmethod
    def unified_collapse_repeated(text: str) -> str:
        """Collapse runs of one repeated character to a single occurrence.

        * Arabic-block characters: runs of 3+ collapse to 1, so a legitimately
          doubled letter survives. Arabic-Indic digits (U+0660-U+0669) and
          Extended Arabic-Indic digits (U+06F0-U+06F9) are excluded, so that
          numbers such as ١٠٠٠ are not truncated.
        * ASCII Latin letters: runs of 2+ collapse to 1.
          NOTE(review): this also shortens real words ("coffee" -> "cofe").
        """
        text = re.sub(
            r"([\u0600-\u065F\u066A-\u06EF\u06FA-\u06FF])\1{2,}", r"\1", text
        )
        text = re.sub(r"([a-zA-Z])\1+", r"\1", text)
        return text

    @staticmethod
    def remove_duplicate_words(text: str) -> str:
        """Drop any word identical to the word right before it.

        e.g. كتاب كتاب -> كتاب. Input with fewer than two words is returned
        untouched; otherwise words are re-joined with single spaces.
        """
        words = text.split()
        if len(words) < 2:
            return text
        result = [words[0]]
        for prev, word in zip(words, words[1:]):
            if word != prev:
                result.append(word)
        return ' '.join(result)

    @staticmethod
    def normalize_spaces(text: str) -> str:
        """Normalize whitespace and the spacing around punctuation.

        NBSP becomes a space, zero-width space/ZWNJ/ZWJ are removed and runs
        of spaces collapse to one. Whitespace around ، ؛ ؟ ! . is then
        replaced by exactly one space after the mark. A '.' with a digit on
        both sides (3.5) is left alone so decimal numbers survive. The result
        is stripped.
        """
        # Unicode spaces must be mapped BEFORE collapsing runs; otherwise
        # "a \u00A0b" becomes a double space that nothing removes later.
        text = text.replace('\u00A0', ' ')  # Non-breaking space
        for invisible in ('\u200B', '\u200C', '\u200D'):  # ZWSP, ZWNJ, ZWJ
            text = text.replace(invisible, '')
        text = re.sub(r' +', ' ', text).strip()
        text = re.sub(r'\s*([،؛؟!]|(?<!\d)\.|\.(?!\d))\s*', r'\1 ', text)
        return text.strip()

    @staticmethod
    def remove_word_repetition_with_wa(text: str) -> str:
        """Collapse the pattern ``X و X`` (X, a standalone و, X again) to ``X``."""
        words = text.split()
        result = []
        i = 0
        while i < len(words):
            result.append(words[i])
            if i + 2 < len(words) and words[i] == words[i + 2] and words[i + 1] == 'و':
                i += 3
            else:
                i += 1
        return ' '.join(result)

    # --- Hamza & Ta Marbuta handling ----------------------------------------

    @staticmethod
    def fix_hamza_conservative(text: str) -> str:
        """Replace a word-final أ or إ with bare ا (words of 3+ letters only).

        Hamza anywhere else in the word is left untouched.
        NOTE(review): words that legitimately end in أ (e.g. خطأ, بدأ, مبدأ)
        are rewritten too; confirm that this trade-off is intended.
        """
        result = []
        for word in text.split():
            if len(word) >= 3 and word.endswith(('أ', 'إ')):
                word = word[:-1] + 'ا'
            result.append(word)
        return ' '.join(result)

    @staticmethod
    def _is_exact_iv(vocab_manager, word: str) -> bool:
        """Exact (un-normalized) vocabulary membership for ة/ه decisions.

        ``VocabularyManager.is_iv`` also matches after folding a final ة to ه,
        so ``is_iv('Xة')`` is True whenever ``Xه`` is a vocabulary word. That
        would make every possessive ه look like a convertible ة. Managers that
        expose no ``vocab`` set fall back to ``is_iv``.
        """
        vocab = getattr(vocab_manager, 'vocab', None)
        if vocab is None:
            return vocab_manager.is_iv(word)
        return re.sub(r'[^\w]', '', word) in vocab

    @staticmethod
    def fix_ha_ta_marbuta(text: str, vocab_manager=None) -> str:
        """Rewrite a word-final ه as ة where that is the likely intent.

        A final ه is either a misspelt ta marbuta (المدرسه -> المدرسة) or an
        attached pronoun (تحقيقه), which must be kept. Only words of 4+ letters
        whose ه follows a letter in ARABIC_CONSONANTS are candidates. Words
        ending in لله (Allah-related) are never touched.

        Args:
            text: Space-separated words.
            vocab_manager: Optional object with ``is_iv`` and, ideally, a
                ``vocab`` set. When given, a candidate is converted only if
                its ة spelling is an exact vocabulary entry; otherwise the ه
                spelling is kept. Without it, every candidate is converted
                (pattern-only fallback).

        Returns:
            The words re-joined with single spaces.
        """
        protected_endings = ('لله',)
        result = []
        for word in text.split():
            is_candidate = (
                not word.endswith(protected_endings)
                and len(word) >= 4
                and word.endswith('ه')
                and word[-2] in AraSpellPostProcessor.ARABIC_CONSONANTS
            )
            if not is_candidate:
                result.append(word)
                continue
            with_ta = word[:-1] + 'ة'
            if not vocab_manager:
                # Fallback: the pattern alone decides.
                result.append(with_ta)
            elif AraSpellPostProcessor._is_exact_iv(vocab_manager, with_ta):
                # The ة spelling is a real vocabulary word: المدرسه -> المدرسة.
                result.append(with_ta)
            else:
                # Either the ه spelling is the known word (possessive such as
                # تحقيقه) or neither is known; keeping the original is safer.
                result.append(word)
        return ' '.join(result)

    # --- Hallucination removal ----------------------------------------------

    @staticmethod
    def remove_hallucinations(text: str) -> str:
        """Remove two common model artifacts.

        1. A stray 'و' glued to the end of a word longer than 4 letters whose
           previous letter is one of ة ه ا أ إ آ ء (الماضيةو -> الماضية).
        2. Two adjacent words that are equal after light normalization
           (leading article ال removed, ة->ه, أإآ->ا). One is kept,
           preferring the one that carries the article ال.
        """
        def normalize_word(w: str) -> str:
            # Strip only a LEADING article: deleting every 'ال' substring
            # would equate unrelated words (الطالب and طب both became "طب").
            if w.startswith('ال'):
                w = w[2:]
            w = w.replace('ة', 'ه')
            return re.sub(r'[أإآ]', 'ا', w)

        words = text.split()
        if not words:
            return text
        result = []
        i = 0
        while i < len(words):
            word = words[i]
            if len(word) > 4 and word.endswith('و') and word[-2] in 'ةهاأإآء':
                word = word[:-1]
            if i + 1 < len(words):
                next_word = words[i + 1]
                if normalize_word(word) == normalize_word(next_word):
                    keep = next_word if next_word.startswith('ال') and not word.startswith('ال') else word
                    result.append(keep)
                    i += 2
                    continue
            result.append(word)
            i += 1
        return ' '.join(result)

    @staticmethod
    def remove_hallucinated_prefix(text: str, original: str) -> str:
        """Drop a leading particle (و or في) that the model added.

        The particle is removed only if ``original`` does not start with it
        and the rest of ``text`` equals ``original`` (after ligature
        normalization and stripping). Otherwise ``text`` is returned as is.
        """
        if not original:
            return text
        original = original.strip()
        if not original:
            return text
        target = AraSpellPostProcessor.normalize_special_chars(original)
        for particle in ('و', 'في'):
            prefix = particle + ' '
            if text.startswith(prefix) and not original.startswith(particle):
                rest = text[len(prefix):].strip()
                if AraSpellPostProcessor.normalize_special_chars(rest) == target:
                    return rest
        return text

    # --- Word splitting & merging -------------------------------------------

    @staticmethod
    def merge_separated_al(text: str) -> str:
        """Re-attach a standalone article: ال كتاب -> الكتاب."""
        return re.sub(r'\bال\s+(\w+)', r'ال\1', text)

    @staticmethod
    def join_fragments(text: str) -> str:
        """Join adjacent short fragments into one word (unvalidated heuristic).

        Two neighbours that are both in _STANDALONE_WORDS are never joined.
        Otherwise the pair is joined when:
          1. the second token is a single character (كتا ب -> كتاب);
          2. both have 2+ characters and share the boundary letter, which is
             kept once. NOTE(review): this also fires on real words, e.g.
             كتاب بيت -> كتابيت; this method is not used by full_postprocess;
          3. a 2-4 char token is followed by a 1-2 char token (total 3-7).
        """
        words = text.split()
        if len(words) < 2:
            return text
        standalone = AraSpellPostProcessor._STANDALONE_WORDS
        result = []
        i = 0
        while i < len(words):
            word = words[i]
            if i + 1 < len(words):
                next_word = words[i + 1]
                if not (word in standalone and next_word in standalone):
                    joined = None
                    if len(next_word) == 1:
                        joined = word + next_word
                    elif len(word) >= 2 and len(next_word) >= 2 and word[-1] == next_word[0]:
                        joined = word[:-1] + next_word
                    elif (2 <= len(word) <= 4 and 1 <= len(next_word) <= 2
                          and 3 <= len(word) + len(next_word) <= 7):
                        joined = word + next_word
                    if joined is not None:
                        result.append(joined)
                        i += 2
                        continue
            result.append(word)
            i += 1
        return ' '.join(result)

    # --- Main pipeline ------------------------------------------------------

    @staticmethod
    def full_postprocess(text: str, original: str = "", vocab_manager=None) -> str:
        """Apply all post-processing steps in pipeline order.

        Args:
            text: Model output to clean.
            original: Model input; enables hallucinated-prefix removal.
            vocab_manager: Optional; enables vocabulary-checked ه/ة handling.
        """
        # 1. Remove particles the model prepended.
        if original:
            text = AraSpellPostProcessor.remove_hallucinated_prefix(text, original)
        # 2. Basic normalization.
        text = AraSpellPostProcessor.normalize_special_chars(text)
        # 3. Remove hallucinated duplicates / trailing و.
        text = AraSpellPostProcessor.remove_hallucinations(text)
        # 4. Collapse character repetitions.
        text = AraSpellPostProcessor.unified_collapse_repeated(text)
        # 5. Hamza at word end.
        text = AraSpellPostProcessor.fix_hamza_conservative(text)
        # 6. Ta marbuta.
        text = AraSpellPostProcessor.fix_ha_ta_marbuta(text, vocab_manager=vocab_manager)
        # 7. X و X -> X.
        text = AraSpellPostProcessor.remove_word_repetition_with_wa(text)
        # 8. Consecutive duplicate words.
        text = AraSpellPostProcessor.remove_duplicate_words(text)
        # 9. Final whitespace / punctuation spacing.
        text = AraSpellPostProcessor.normalize_spaces(text)
        return text
# ─────────────────────────────────────────────────────────────────────────────
# ERROR CLASSIFIER
# ─────────────────────────────────────────────────────────────────────────────
class ErrorClassifier:
    """Heuristic detector for the kind of spelling error present in a text."""

    # Letters from Persian/Urdu/Kurdish/Pashto keyboards that signal a
    # character-substitution error. Covers every key of
    # RulesBasedCorrector.SUBSTITUTION_MAP (including ړ).
    NON_ARABIC_KEYBOARD = set('پگچژکەڕڤڵڎےۀۃھیټډڼڑړ')

    @staticmethod
    def has_char_substitution(text: str) -> bool:
        """True if ``text`` contains any letter from NON_ARABIC_KEYBOARD."""
        return not ErrorClassifier.NON_ARABIC_KEYBOARD.isdisjoint(text)

    @staticmethod
    def has_char_repetition(text: str, threshold: int = 3) -> bool:
        """True if some character occurs ``threshold`` or more times in a row.

        Digits and whitespace are ignored (as in
        RulesBasedCorrector.fix_char_repetition), so numbers such as 2000 and
        runs of spaces are not reported as repetition errors.
        """
        repeats = max(threshold - 1, 0)
        pattern = r"([^\d\s])\1{" + str(repeats) + ",}"
        return re.search(pattern, text) is not None

    @staticmethod
    def has_word_merge(text: str, max_word_len: int = 8) -> bool:
        """True if words look run together.

        That is the case when any token is longer than ``max_word_len``, or
        when the whole text is a single token longer than 6 characters.
        """
        words = text.split()
        if any(len(w) > max_word_len for w in words):
            return True
        return len(words) == 1 and len(text) > 6

    @staticmethod
    def classify(text: str) -> 'ErrorType':
        """Return the ErrorType for ``text``.

        MIXED if two or more detectors fire; otherwise the single detector
        that fired, checked in the order substitution, repetition, merge;
        otherwise CLEAN.
        """
        has_rep = ErrorClassifier.has_char_repetition(text)
        has_merge = ErrorClassifier.has_word_merge(text)
        has_sub = ErrorClassifier.has_char_substitution(text)
        if sum([has_rep, has_merge, has_sub]) >= 2:
            return ErrorType.MIXED
        if has_sub:
            return ErrorType.CHAR_SUBSTITUTION
        if has_rep:
            return ErrorType.CHAR_REPETITION
        if has_merge:
            return ErrorType.WORD_MERGE
        return ErrorType.CLEAN
# ═══════════════════════════════════════════════════════════════════════════════
# RULES-BASED CORRECTOR
# ═══════════════════════════════════════════════════════════════════════════════
class RulesBasedCorrector:
    """Deterministic repairs: foreign-letter substitution, repetition
    collapse, and prefix-anchored splitting of merged words."""

    # Persian/Urdu/Kurdish/Pashto letters -> closest standard Arabic letter.
    SUBSTITUTION_MAP = {
        'ک': 'ك', 'ی': 'ي', 'ے': 'ي',
        'پ': 'ب', 'چ': 'ج', 'ژ': 'ز',
        'گ': 'ك', 'ڤ': 'ف', 'ڵ': 'ل',
        'ڕ': 'ر', 'ڎ': 'د', 'ڼ': 'ن',
        'ټ': 'ت', 'ډ': 'د', 'ړ': 'ر',
        'ۀ': 'ه', 'ۃ': 'ة', 'ھ': 'ه',
        'ە': 'ه', 'ڑ': 'ر'
    }
    # Built once. Every key and value is a single character and no value is
    # itself a key, so one translate pass equals the sequential replaces.
    _SUBSTITUTION_TABLE = str.maketrans(SUBSTITUTION_MAP)

    # 16 prepositions / prepositional prefixes.
    PREPOSITIONS = {
        'من', 'في', 'على', 'عن', 'مع', 'إلى', 'الى',
        'حتى', 'منذ', 'خلال', 'بعد', 'قبل',
        'ب', 'ل', 'ك',
        'لل'
    }

    # Free-standing particles that _recursive_split peels off the start of a
    # merged word, longest first (stable sort, computed once).
    _SEPARABLES = tuple(sorted(
        ['من', 'في', 'على', 'عن', 'مع', 'إلى', 'الى', 'حتى', 'منذ', 'خلال', 'بعد', 'قبل'],
        key=len, reverse=True,
    ))

    # Arabic keyboard layout adjacency (which keys sit next to each other).
    KEYBOARD_NEIGHBORS = {
        'ض': ['ص', 'ق'],
        'ص': ['ض', 'ث', 'ق'],
        'ث': ['ص', 'ق'],
        'ق': ['ض', 'ص', 'ث', 'ف', 'غ'],
        'ف': ['ق', 'غ', 'ع', 'ب'],
        'غ': ['ق', 'ف', 'ع', 'ه'],
        'ع': ['ف', 'غ', 'ه', 'خ'],
        'ه': ['غ', 'ع', 'خ', 'ح'],
        'خ': ['ع', 'ه', 'ح', 'ج'],
        'ح': ['ه', 'خ', 'ج'],
        'ج': ['خ', 'ح', 'د'],
        'د': ['ج', 'ذ'],
        'ذ': ['د'],
        'ش': ['س', 'ي', 'ئ'],
        'س': ['ش', 'ي', 'ب'],
        'ي': ['ش', 'س', 'ب', 'ت'],
        'ب': ['ي', 'س', 'ف', 'ل', 'ن'],
        'ل': ['ب', 'ا', 'ن', 'م'],
        'ا': ['ل', 'ت', 'م'],
        'ت': ['ي', 'ا', 'ن'],
        'ن': ['ب', 'ل', 'ت', 'م', 'ك'],
        'م': ['ل', 'ا', 'ن', 'ك'],
        'ك': ['ن', 'م', 'ط'],
        'ط': ['ك', 'ظ'],
        'ظ': ['ط'],
        'ئ': ['ش', 'ء', 'ر'],
        'ء': ['ئ', 'ؤ'],
        'ؤ': ['ء', 'ر'],
        'ر': ['ئ', 'ؤ', 'لا', 'ى', 'ز'],
        'لا': ['ر', 'ى'],
        'ى': ['ر', 'لا', 'ة', 'ز'],
        'ة': ['ى', 'و', 'ز'],
        'و': ['ة', 'ز'],
        'ز': ['ر', 'ى', 'ة', 'و'],
        # Alif variants
        'أ': ['ا', 'إ', 'آ'],
        'إ': ['ا', 'أ'],
        'آ': ['ا', 'أ'],
    }

    @staticmethod
    def is_keyboard_neighbor(char1: str, char2: str) -> bool:
        """True if ``char2`` is listed as adjacent to ``char1`` on the keyboard."""
        return char2 in RulesBasedCorrector.KEYBOARD_NEIGHBORS.get(char1, ())

    @staticmethod
    def fix_char_substitution(text: str) -> str:
        """Replace Persian/Urdu/Kurdish/Pashto letters with Arabic ones."""
        return text.translate(RulesBasedCorrector._SUBSTITUTION_TABLE)

    @staticmethod
    def fix_char_repetition(text: str) -> str:
        """Collapse 3+ consecutive copies of a non-digit, non-space char to 1.

        A doubled letter is left alone.
        """
        return re.sub(r'([^\d\s])\1{2,}', r'\1', text)

    @staticmethod
    def advanced_heuristic_repair(text: str) -> str:
        """Build an aggressive rule-based baseline candidate.

        Applies letter substitution, then repetition collapse, then
        prefix-anchored splitting of each whitespace-separated word.
        """
        text = RulesBasedCorrector.fix_char_substitution(text)
        text = RulesBasedCorrector.fix_char_repetition(text)
        return ' '.join(RulesBasedCorrector._recursive_split(w) for w in text.split())

    @staticmethod
    def _recursive_split(word: str) -> str:
        """Recursively split separable particles off the START of ``word``.

        Only leading particles from _SEPARABLES (with a remainder of 3+
        letters) and a leading vocative يا (word longer than 4) are split, so
        a mid-word split such as المنزل -> ال من زل cannot happen.
        NOTE(review): this is unvalidated, so real words that merely begin
        with a particle are split too (e.g. منطقة -> من طقة).
        """
        if len(word) < 4:
            return word
        for sep in RulesBasedCorrector._SEPARABLES:
            if word == sep:
                return word
            if word.startswith(sep):
                remainder = word[len(sep):]
                if len(remainder) >= 3:
                    return sep + " " + RulesBasedCorrector._recursive_split(remainder)
        if word.startswith('يا') and len(word) > 4:
            return 'يا ' + RulesBasedCorrector._recursive_split(word[2:])
        return word
# ═══════════════════════════════════════════════════════════════════════════════
# OUTPUT VALIDATOR (Hallucination Prevention)
# ═══════════════════════════════════════════════════════════════════════════════
class OutputValidator:
    """Sanity checks that reject model outputs likely to be hallucinations."""

    # Value of ErrorType.CHAR_REPETITION. validate() compares by value so that
    # both the enum member and its plain string are accepted.
    _CHAR_REPETITION = "char_repetition"

    @staticmethod
    def calculate_edit_distance(s1: str, s2: str) -> int:
        """Levenshtein distance between ``s1`` and ``s2``."""
        return Levenshtein.distance(s1, s2)

    @staticmethod
    def check_character_preservation(original: str, corrected: str) -> Tuple[bool, str]:
        """Reject outputs whose character set barely overlaps the input's.

        Uses the Jaccard similarity of the two character sets. Below 0.35 the
        result is ``(False, "low_character_similarity")``. An empty original
        is always accepted.
        """
        chars_original = set(original)
        if not chars_original:
            return True, "valid"
        chars_corrected = set(corrected)
        # The union is non-empty because chars_original is.
        jaccard = len(chars_original & chars_corrected) / len(chars_original | chars_corrected)
        if jaccard < 0.35:
            return False, "low_character_similarity"
        return True, "valid"

    @staticmethod
    def check_word_count(original: str, corrected: str) -> Tuple[bool, str]:
        """Check that the output word count is plausible.

        A single input word may become up to 3 words (splitting a merge such
        as فيالمدرسة), or up to 6 if it is longer than 12 characters. In all
        other cases the corrected/original word ratio must stay within
        [0.5, 2.0].
        """
        len_orig = len(original.split())
        len_corr = len(corrected.split())
        if len_orig == 1:
            if len_corr <= 3:
                return True, "valid"
            if len(original) > 12 and len_corr <= 6:
                return True, "valid"
        ratio = len_corr / len_orig if len_orig > 0 else 0
        if ratio > 2.0 or ratio < 0.5:
            return False, "word_count_mismatch"
        return True, "valid"

    def validate(self, original: str, corrected: str, error_type: str) -> Tuple[bool, str]:
        """Decide whether ``corrected`` is an acceptable rewrite of ``original``.

        Args:
            original: Model input.
            corrected: Model output.
            error_type: An ErrorType member or its string value. For
                CHAR_REPETITION the "too_short" check is waived, because
                collapsing repeated letters legitimately shrinks the text.

        Returns:
            ``(ok, reason)``. ``reason`` is one of "empty_output",
            "space_leniency_accept", "too_long", "too_short",
            "word_count_mismatch", "low_character_similarity" or "valid".
        """
        if not corrected or not corrected.strip():
            return False, "empty_output"

        # Identical once spaces and ZWNJ are removed -> only word boundaries
        # changed (a split/merge fix), so accept immediately.
        def _squash(s: str) -> str:
            return s.replace(' ', '').replace('\u200c', '')

        if _squash(original) == _squash(corrected):
            return True, "space_leniency_accept"

        len_orig = len(original)
        len_corr = len(corrected)
        if len_corr > len_orig * 2.5:
            return False, "too_long"
        if len_corr < len_orig * 0.5:
            kind = getattr(error_type, 'value', error_type)
            if kind != self._CHAR_REPETITION:
                return False, "too_short"

        for check in (self.check_word_count, self.check_character_preservation):
            ok, reason = check(original, corrected)
            if not ok:
                return False, reason
        return True, "valid"
# ═══════════════════════════════════════════════════════════════════════════════
# VOCABULARY MANAGER
# ═══════════════════════════════════════════════════════════════════════════════
class VocabularyManager:
    """In-vocabulary (IV) / out-of-vocabulary (OOV) oracle built from a tokenizer.

    It supports the vocabulary-aware acceptance rule used downstream: a
    change that turns an OOV word into an IV one is accepted, and one that
    turns an IV word into an OOV one is rejected.
    """

    # Character classes folded together by normalize_for_comparison.
    HAMZA_VARIANTS = {'أ', 'إ', 'آ', 'ء', 'ؤ', 'ئ', 'ا'}
    ALEF_NORMALIZED = 'ا'
    TA_MARBUTA = 'ة'
    HA = 'ه'
    YA_VARIANTS = {'ي', 'ى'}
    YA_NORMALIZED = 'ي'

    def __init__(self, tokenizer):
        """Index the tokenizer vocabulary.

        ``vocab`` keeps alphabetic, non-``##`` tokens longer than one
        character. ``vocab_rank`` maps every token (subwords included) to its
        id, which is heuristically treated as a frequency rank.
        ``normalized_vocab`` maps each folded form back to a vocab word.
        """
        self.tokenizer = tokenizer
        token_to_id = tokenizer.get_vocab()
        self.vocab = {
            token for token in token_to_id
            if token.isalpha() and not token.startswith('##') and len(token) > 1
        }
        self.vocab_rank = dict(token_to_id)
        self.normalized_vocab = {
            self.normalize_for_comparison(token): token for token in self.vocab
        }
        logger.info(f"VocabularyManager initialized: {len(self.vocab)} words")

    @classmethod
    def normalize_for_comparison(cls, word: str) -> str:
        """Fold spelling variants for equivalence checks (not for output).

        Every hamza/alef form becomes ا, ى becomes ي, and a ة in the LAST
        position becomes ه.
        """
        last_index = len(word) - 1
        folded = []
        for position, letter in enumerate(word):
            if letter in cls.HAMZA_VARIANTS:
                letter = cls.ALEF_NORMALIZED
            elif letter in cls.YA_VARIANTS:
                letter = cls.YA_NORMALIZED
            elif position == last_index and letter == cls.TA_MARBUTA:
                letter = cls.HA
            folded.append(letter)
        return ''.join(folded)

    def is_iv(self, word: str) -> bool:
        """True if ``word`` (non-word chars removed) is a known word.

        The word matches if it is in ``vocab`` exactly or after folding with
        normalize_for_comparison. Punctuation-only input counts as IV.
        """
        token = re.sub(r'[^\w]', '', word)
        if not token:
            return True
        return token in self.vocab or self.normalize_for_comparison(token) in self.normalized_vocab

    def is_oov(self, word: str) -> bool:
        """Negation of is_iv."""
        return not self.is_iv(word)

    def get_frequency_rank(self, word: str) -> int:
        """Token id of ``word`` (lower = presumably more common); 999999 if absent."""
        token = re.sub(r'[^\w]', '', word)
        return self.vocab_rank.get(token, 999999)

    def all_words_iv(self, text: str) -> bool:
        """True if no whitespace-separated word of ``text`` is OOV."""
        return not any(not self.is_iv(w) for w in text.split())

    def count_oov_words(self, text: str) -> int:
        """Number of OOV words in ``text``."""
        return len(self.get_oov_words(text))

    def get_oov_words(self, text: str) -> List[str]:
        """OOV words of ``text``, in order."""
        return [w for w in text.split() if self.is_oov(w)]

    def words_are_equivalent(self, word1: str, word2: str) -> bool:
        """True if both words fold to the same normalized form."""
        return self.normalize_for_comparison(word1) == self.normalize_for_comparison(word2)

    @staticmethod
    def damerau_levenshtein_distance(s1: str, s2: str) -> int:
        """Edit distance where an adjacent transposition costs 1 (اقصتاديا -> اقتصاديا)."""
        return jellyfish.damerau_levenshtein_distance(s1, s2)

    def calculate_similarity(self, original: str, corrected: str) -> float:
        """1 - Damerau-Levenshtein distance / longer length (1.0 = identical)."""
        longest = max(len(original), len(corrected), 1)
        return 1.0 - self.damerau_levenshtein_distance(original, corrected) / longest
# ═══════════════════════════════════════════════════════════════════════════════
# WORD ALIGNER
# ═══════════════════════════════════════════════════════════════════════════════
class WordAligner:
    """Combine model output with the raw input word by word.

    Useful when the model fixes one word but damages another: for each
    aligned pair, the better word is chosen using the vocabulary (see
    _select_best_word).
    """

    def __init__(self, vocab_manager):
        """``vocab_manager`` must provide ``is_iv`` and ``count_oov_words``."""
        self.vocab = vocab_manager

    def align_words(self, input_text: str, output_text: str) -> str:
        """Build a hybrid of ``input_text`` and ``output_text``.

        * Word counts differ by more than 2: alignment is too risky, so the
          whole text with fewer OOV words wins (ties keep the input).
        * Equal word counts: words are paired by position.
        * Counts differ by 1-2 (a split or merge): words are aligned with
          difflib so a split does not shift every later word. Matching runs
          are kept; equal-length replaced runs are decided word by word;
          unequal replaced runs go to the side with fewer OOV words; inserted
          output words are taken; deleted input words are kept only if IV.
        """
        import difflib  # local: only this method needs it

        input_words = input_text.split()
        output_words = output_text.split()

        if abs(len(input_words) - len(output_words)) > 2:
            return self._fewer_oov(input_text, output_text)

        if len(input_words) == len(output_words):
            return ' '.join(
                self._select_best_word(a, b) for a, b in zip(input_words, output_words)
            )

        matcher = difflib.SequenceMatcher(a=input_words, b=output_words, autojunk=False)
        result = []
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            in_chunk = input_words[i1:i2]
            out_chunk = output_words[j1:j2]
            if tag == 'equal':
                result.extend(in_chunk)
            elif tag == 'insert':
                result.extend(out_chunk)
            elif tag == 'delete':
                # The model dropped these words; keep the ones that are valid.
                result.extend(w for w in in_chunk if self.vocab.is_iv(w))
            elif len(in_chunk) == len(out_chunk):
                result.extend(
                    self._select_best_word(a, b) for a, b in zip(in_chunk, out_chunk)
                )
            else:
                # Split/merge region: decide the run as a whole.
                result.extend(
                    self._fewer_oov(' '.join(in_chunk), ' '.join(out_chunk)).split()
                )
        return ' '.join(result)

    def _fewer_oov(self, input_text: str, output_text: str) -> str:
        """Return ``output_text`` if it has strictly fewer OOV words, else ``input_text``."""
        input_oov = self.vocab.count_oov_words(input_text)
        output_oov = self.vocab.count_oov_words(output_text)
        return output_text if output_oov < input_oov else input_text

    def _select_best_word(self, input_word: str, output_word: str) -> str:
        """Pick the input or output version of one aligned word.

        1. Input OOV, output IV -> output (the model fixed it).
        2. Input IV, output OOV -> input (the model broke it).
        3. Both IV -> input (conservative; avoids semantic drift).
        4. Both OOV -> if same length (3+), try swapping a single differing
           character in either direction to reach an IV word; otherwise
           take the output.
        """
        if input_word == output_word:
            return input_word
        in_iv = self.vocab.is_iv(input_word)
        out_iv = self.vocab.is_iv(output_word)
        if out_iv and not in_iv:
            return output_word
        if in_iv:
            return input_word
        if len(input_word) == len(output_word) and len(input_word) >= 3:
            for i, (a, b) in enumerate(zip(input_word, output_word)):
                if a == b:
                    continue
                hybrid = input_word[:i] + b + input_word[i + 1:]
                if self.vocab.is_iv(hybrid):
                    return hybrid
                hybrid2 = output_word[:i] + a + output_word[i + 1:]
                if self.vocab.is_iv(hybrid2):
                    return hybrid2
        return output_word
# ═══════════════════════════════════════════════════════════════════════════════
# SPLIT/MERGE SPECIALIST
# ═══════════════════════════════════════════════════════════════════════════════
class SplitMergeSpecialist:
    """Vocabulary-validated splitting of merged words and merging of fragments.

    Splitting (split_word): an OOV word is broken in two only when both
    halves are IV, e.g. فيالغالب -> في الغالب, يقعبجماعة -> يقع بجماعة.
    Merging (merge_fragments): adjacent tokens are joined when the result is
    IV and a fragment pattern applies, e.g. السوري ة -> السورية,
    ال كتاب -> الكتاب.
    """

    # Free-standing words that may be split off the front of a merged word
    # (sorted longest-first in __init__ for greedy matching).
    SEPARABLE_PREFIXES = [
        # Prepositions
        'من', 'في', 'على', 'عن', 'مع', 'إلى', 'الى', 'حتى', 'منذ', 'خلال',
        'بعد', 'قبل', 'بين', 'حول', 'تحت', 'فوق', 'أمام', 'وراء', 'دون',
        # Particles
        'أن', 'لن', 'لم', 'قد', 'سوف', 'كي', 'إذا', 'لو', 'مثل', 'غير',
        # Vocative particle
        'يا',
    ]

    # Short function words that split_word never splits.
    PROTECTED_WORDS = {
        'في', 'من', 'على', 'عن', 'مع', 'إلى', 'الى', 'ان', 'أن', 'لا', 'ما', 'هو', 'هي',
        'لم', 'لن', 'قد', 'كل', 'كان', 'ذلك', 'هذا', 'هذه', 'التي', 'الذي', 'بين',
    }

    # Prefixes normally written attached (conjunction/preposition + article,
    # or conjunction + preposition). A word starting with one of them whose
    # remainder is IV is a regular word form, not a merge error.
    ATTACHED_PREFIXES = [
        'وال', 'بال', 'فال', 'كال', 'لل',
        'وب', 'وف', 'ول', 'وك', 'وم', 'ون',
        'فب', 'فل', 'فك', 'فم',
    ]

    # Pronoun/possessive suffixes that over-eager splitting detaches from
    # their host word (علي كم -> عليكم).
    PRONOUN_SUFFIXES = {'كم', 'هم', 'ها', 'هن', 'كن', 'نا', 'هما', 'كما', 'تم', 'تن'}

    def __init__(self, vocab_manager):
        """``vocab_manager`` must provide ``is_iv`` and ``is_oov``."""
        self.vocab = vocab_manager
        self.separable_prefixes = sorted(self.SEPARABLE_PREFIXES, key=len, reverse=True)

    def split_word(self, word: str) -> str:
        """Split an OOV word into two IV words, or return it unchanged.

        Rules, in order:
        * words shorter than 5 letters, IV words and PROTECTED_WORDS are kept;
        * a word starting with an ATTACHED_PREFIXES entry is kept when the
          rest is IV, or, for prefixes that contain the article (وال، بال،
          فال، كال and the contracted لل), when ال + rest is IV;
        * a SEPARABLE_PREFIXES entry followed by an IV remainder of 3+
          letters is split off (longest prefix first);
        * otherwise, at the first cut where both sides have 3+ letters and
          are IV.
        """
        if len(word) < 5:
            return word
        if self.vocab.is_iv(word):
            return word
        if word in self.PROTECTED_WORDS:
            return word

        for prefix in self.ATTACHED_PREFIXES:
            if not word.startswith(prefix):
                continue
            remainder = word[len(prefix):]
            if self.vocab.is_iv(remainder):
                return word
            # The vocabulary may hold only the article-bearing form: for
            # والخصوصي, check الخصوصي rather than خصوصي.
            carries_article = prefix.endswith('ال') or prefix == 'لل'
            if carries_article and self.vocab.is_iv('ال' + remainder):
                return word

        for prefix in self.separable_prefixes:
            if word.startswith(prefix) and len(word) > len(prefix) + 2:
                remainder = word[len(prefix):]
                if self.vocab.is_iv(remainder):
                    return f"{prefix} {remainder}"

        for cut in range(3, len(word) - 2):
            left, right = word[:cut], word[cut:]
            if self.vocab.is_iv(left) and self.vocab.is_iv(right):
                return f"{left} {right}"
        return word

    def _should_merge(self, word: str, next_word: str, merged: str) -> bool:
        """Decide whether ``word`` + ``next_word`` should be written as ``merged``.

        The merged form must be IV, and at least one pattern must apply:
          1. detached single-letter suffix ة ه ا ي (even if ``word`` is IV);
          2. detached article: ال + a word of 2+ letters;
          3. both tokens OOV;
          4. ``word`` is an OOV fragment of 1-2 letters;
          5. ``next_word`` is a pronoun suffix and ``word`` is OOV;
          6. both tokens have 3 letters or fewer and ``merged`` has 5+.
        """
        if not self.vocab.is_iv(merged):
            return False
        return (
            (len(next_word) == 1 and next_word in 'ةهاي')
            or (word == 'ال' and len(next_word) >= 2)
            or (self.vocab.is_oov(word) and self.vocab.is_oov(next_word))
            or (len(word) <= 2 and self.vocab.is_oov(word))
            or (next_word in self.PRONOUN_SUFFIXES and not self.vocab.is_iv(word))
            or (len(word) <= 3 and len(next_word) <= 3 and len(merged) >= 5)
        )

    def merge_fragments(self, text: str) -> str:
        """Join adjacent tokens whose concatenation is IV (see _should_merge).

        Scanning is left to right and greedy: a merged pair is not merged
        again with the following token in the same pass.
        """
        words = text.split()
        if len(words) < 2:
            return text
        result = []
        i = 0
        while i < len(words):
            word = words[i]
            if i + 1 < len(words):
                next_word = words[i + 1]
                merged = word + next_word
                if self._should_merge(word, next_word, merged):
                    result.append(merged)
                    i += 2
                    continue
            result.append(word)
            i += 1
        return ' '.join(result)

    def process_text(self, text: str) -> str:
        """Merge fragments first, then try to split each remaining OOV word."""
        text = self.merge_fragments(text)
        return ' '.join(
            self.split_word(word) if self.vocab.is_oov(word) and len(word) >= 4 else word
            for word in text.split()
        )
# ═══════════════════════════════════════════════════════════════════════════════
# EDIT DISTANCE CORRECTOR
# ═══════════════════════════════════════════════════════════════════════════════
class EditDistanceCorrector:
    """
    Generates candidates based on Levenshtein distance.

    Norvig-style candidate generation (deletions, transpositions,
    replacements, insertions over an Arabic alphabet) filtered against the
    whole-word entries of the BERT tokenizer vocabulary.
    """

    # Arabic letters used for replacements/insertions. Includes hamza-below
    # alef (إ) so that common أ/إ/ا confusions can be repaired; this matches
    # the character set ContextualCorrector uses for its edits.
    LETTERS = 'أإابتثجحخدذرزسشصضطظعغفقكلمنهويءآىةئؤ'

    def __init__(self, tokenizer):
        """Build the word vocabulary and frequency ranks from ``tokenizer``.

        Args:
            tokenizer: any object exposing ``get_vocab() -> {token: id}``.
        """
        self.tokenizer = tokenizer
        token_to_id = tokenizer.get_vocab()
        # Strict vocabulary: alphabetic whole words only (no '##' sub-words,
        # no punctuation, no single characters).
        self.vocab = {
            w for w in token_to_id
            if w.isalpha() and not w.startswith('##') and len(w) > 1
        }
        # Frequency heuristic: lower token id is usually a more frequent token.
        self.vocab_rank = dict(token_to_id)

    def edits1(self, word):
        """All edits that are one edit away from `word`."""
        letters = self.LETTERS
        splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
        deletes = [L + R[1:] for L, R in splits if R]
        transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R) > 1]
        replaces = [L + c + R[1:] for L, R in splits if R for c in letters]
        inserts = [L + c + R for L, R in splits for c in letters]
        return set(deletes + transposes + replaces + inserts)

    def edits2(self, word):
        """All edits that are two edits away from `word` (lazy generator)."""
        return (e2 for e1 in self.edits1(word) for e2 in self.edits1(e1))

    def known(self, words):
        """The subset of `words` that appear in the dictionary of known words."""
        return set(w for w in words if w in self.vocab)

    @staticmethod
    def _edge_punctuation(word):
        """Return the (leading, trailing) runs of non-word characters of ``word``.

        Callers only use this when ``word`` contains at least one word
        character, so the two runs never overlap.
        """
        leading = re.match(r'\W*', word).group(0)
        trailing = re.search(r'\W*$', word).group(0)
        return leading, trailing

    def _correct_word(self, word):
        """Return the edit-distance correction of one whitespace token."""
        clean_word = re.sub(r'[^\w]', '', word)
        # Skip tokens with nothing spellable: pure punctuation (clean_word is
        # '') or tokens with digits/underscores. Without this guard '' would
        # reach edits2('') and be replaced by an arbitrary two-letter word.
        # (Vocabulary entries are all alphabetic, so a non-alpha clean_word
        # can never be "known" anyway.)
        if not clean_word.isalpha():
            return word
        # Known word: keep it untouched.
        if clean_word in self.vocab:
            return word

        # 1. Distance-1 neighbours.
        candidates = self.known(self.edits1(clean_word))
        # 2. Distance-2 neighbours, only for short words (the search is
        #    roughly quadratic in the size of edits1).
        if not candidates and len(clean_word) < 7:
            candidates = self.known(self.edits2(clean_word))
        if not candidates:
            return word

        # Most frequent (lowest vocab rank) candidate wins.
        best_candidate = min(candidates, key=lambda w: self.vocab_rank.get(w, 999999))
        # Re-attach surrounding punctuation so e.g. a trailing comma survives.
        leading, trailing = self._edge_punctuation(word)
        return leading + best_candidate + trailing

    def generate_candidate(self, text: str) -> str:
        """
        Generate a candidate sentence by fixing OOV words using Edit Distance.

        Args:
            text: whitespace-separated sentence.

        Returns:
            The sentence with each correctable OOV token replaced by its most
            frequent in-vocabulary neighbour, joined by single spaces.
        """
        return ' '.join(self._correct_word(word) for word in text.split())
# ═══════════════════════════════════════════════════════════════════════════════
# CONTEXTUAL CORRECTOR (MLM-based with Batch Scoring)
# ═══════════════════════════════════════════════════════════════════════════════
class ContextualCorrector:
    """MLM-based contextual correction for confusion pairs.

    A word is scored in context by replacing it with ``[MASK]`` and reading
    the probability the BERT masked-LM assigns to the word's first sub-token
    at that position. Scores are memoised in a bounded FIFO cache.
    """

    # Common confusion pairs in Arabic. _build_confusion_map() makes every
    # pair bidirectional, so each unordered pair is listed only once
    # (و/ؤ confusion is covered by ('ؤ', 'و')).
    CONFUSION_PAIRS = [
        ('ض', 'ظ'), ('ذ', 'ز'), ('ث', 'س'), ('ص', 'س'),
        ('ط', 'ت'), ('ق', 'ك'), ('ه', 'ة'), ('ا', 'ى'),
        ('ت', 'د'), ('د', 'ض'), ('غ', 'ق'), ('ج', 'ش'),
        ('س', 'ز'), ('ف', 'ب'), ('ؤ', 'و'), ('ئ', 'ي'),
        ('ء', 'أ'), ('إ', 'أ'),
    ]

    # Letters tried for single-character insertions and substitutions.
    COMMON_CHARS = 'ابتثجحخدذرزسشصضطظعغفقكلمنهويأإآءئؤةى'

    def __init__(self, model_name: str = 'aubmindlab/bert-base-arabertv02', cache_size: int = 10000):
        """Load the masked-LM, its tokenizer and set up the score cache.

        Args:
            model_name: Hugging Face id of a BERT masked-LM.
            cache_size: maximum number of cached scores; values <= 0 disable
                caching entirely.
        """
        from transformers import AutoTokenizer, AutoModelForMaskedLM

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(self.device)
        self.model.eval()

        # char -> list of characters it is commonly confused with
        self.confusion_map = self._build_confusion_map()

        # Cache statistics
        self.cache_hits = 0
        self.cache_misses = 0

        # FIFO score cache: (text, position, word) -> probability
        self._score_cache = {}
        self.cache_size = cache_size

        # token -> id, used to filter edit candidates to single known tokens
        self.vocab = self.tokenizer.get_vocab()

    def _build_confusion_map(self):
        """Build a bidirectional, duplicate-free confusion map from CONFUSION_PAIRS."""
        confusion_map = {}
        for char1, char2 in self.CONFUSION_PAIRS:
            if char1 == char2:
                continue  # a self-pair can only produce the unchanged word
            for src, dst in ((char1, char2), (char2, char1)):
                targets = confusion_map.setdefault(src, [])
                if dst not in targets:
                    targets.append(dst)
        return confusion_map

    def get_confusable_chars(self, char: str) -> List[str]:
        """Get confusable characters for a given char (empty list if none)."""
        return self.confusion_map.get(char, [])

    def generate_candidates(self, word: str) -> List[str]:
        """Generate candidate corrections for a word.

        The original word is always first. Confusion-pair substitutions and
        repeated-letter deletions are kept unconditionally; general edit-1
        insertions/substitutions/deletions are kept only if they are a single
        vocabulary token (to avoid hallucinations and scoring errors).
        """
        candidates = [word]
        seen = {word}

        def _add(candidate, require_vocab):
            if candidate in seen:
                return
            if require_vocab and candidate not in self.vocab:
                return
            seen.add(candidate)
            candidates.append(candidate)

        # 1. Substitute confusable chars
        for i, char in enumerate(word):
            for conf_char in self.get_confusable_chars(char):
                _add(word[:i] + conf_char + word[i + 1:], require_vocab=False)

        # 2. Remove one copy of a doubled character (مدررسة -> مدرسة)
        for i in range(len(word) - 1):
            if word[i] == word[i + 1]:
                _add(word[:i] + word[i + 1:], require_vocab=False)

        # 3. Insertions (missing char)
        for i in range(len(word) + 1):
            for char in self.COMMON_CHARS:
                _add(word[:i] + char + word[i:], require_vocab=True)

        # 4. Substitutions (wrong char) — only for short words to bound cost
        if len(word) < 7:
            for i in range(len(word)):
                for char in self.COMMON_CHARS:
                    if char != word[i]:
                        _add(word[:i] + char + word[i + 1:], require_vocab=True)

        # 5. Deletions (extra char); single-character results are skipped
        for i in range(len(word)):
            candidate = word[:i] + word[i + 1:]
            if len(candidate) > 1:
                _add(candidate, require_vocab=True)

        return candidates

    def _mask_probabilities(self, text: str, position: int):
        """Run the MLM with word ``position`` of ``text`` masked.

        Returns:
            The softmax distribution over the vocabulary at the mask, or
            None if ``position`` is past the last word or the mask token was
            lost (e.g. truncated away).
        """
        words = text.split()
        if position >= len(words):
            return None
        masked_words = list(words)
        masked_words[position] = '[MASK]'
        masked_text = ' '.join(masked_words)

        inputs = self.tokenizer(masked_text, return_tensors='pt', padding=True, truncation=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = self.model(**inputs).logits

        mask_positions = (inputs['input_ids'] == self.tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
        if len(mask_positions) == 0:
            return None
        return torch.softmax(logits[0, mask_positions[0], :], dim=0)

    def _word_probability(self, probs, word: str):
        """Probability of ``word``'s first sub-token under ``probs`` (None if it has no tokens)."""
        word_tokens = self.tokenizer.encode(word, add_special_tokens=False)
        if not word_tokens:
            return None
        return probs[word_tokens[0]].item()

    def _cache_store(self, key, score: float) -> None:
        """Insert into the FIFO cache, evicting oldest entries; no-op if caching is disabled."""
        if self.cache_size <= 0:
            return
        while self._score_cache and len(self._score_cache) >= self.cache_size:
            self._score_cache.pop(next(iter(self._score_cache)))
        self._score_cache[key] = score

    def score_with_mlm(self, text: str, position: int, word: str) -> float:
        """Score a word in context using BERT MLM.

        Returns the probability of ``word``'s first sub-token at the masked
        ``position``; 0.0 when the position/mask/word cannot be scored
        (those 0.0 results are not cached).
        """
        cache_key = (text, position, word)
        if cache_key in self._score_cache:
            self.cache_hits += 1
            return self._score_cache[cache_key]
        self.cache_misses += 1

        probs = self._mask_probabilities(text, position)
        if probs is None:
            return 0.0
        score = self._word_probability(probs, word)
        if score is None:
            return 0.0
        self._cache_store(cache_key, score)
        return score

    def score_candidates_batch(self, text: str, position: int, candidates: List[str]) -> dict:
        """
        Score multiple candidates for the same masked position.

        All candidates share one masked input, so at most one forward pass is
        run for the whole batch; cached candidates need none.
        Returns: {candidate: score}
        """
        scores = {}
        probs = None
        probs_ready = False
        for candidate in candidates:
            cache_key = (text, position, candidate)
            if cache_key in self._score_cache:
                self.cache_hits += 1
                scores[candidate] = self._score_cache[cache_key]
                continue
            self.cache_misses += 1
            if not probs_ready:
                probs = self._mask_probabilities(text, position)
                probs_ready = True
            if probs is None:
                scores[candidate] = 0.0
                continue
            score = self._word_probability(probs, candidate)
            if score is None:
                scores[candidate] = 0.0
                continue
            self._cache_store(cache_key, score)
            scores[candidate] = score
        return scores

    def predict_masked_token(self, text: str, position: int, top_k: int = 5) -> List[Tuple[str, float]]:
        """Predict words for a masked position. Returns list of (word, score).

        Sub-word pieces (``##...``) and special tokens are dropped, so fewer
        than ``top_k`` results may be returned.
        """
        probs = self._mask_probabilities(text, position)
        if probs is None:
            return []
        top_k_weights, top_k_indices = torch.topk(probs, top_k, sorted=True)
        special_tokens = set(self.tokenizer.all_special_tokens)
        results = []
        for token_id, score in zip(top_k_indices.tolist(), top_k_weights.tolist()):
            token = self.tokenizer.decode([token_id]).strip()
            if not token.startswith("##") and token not in special_tokens:
                results.append((token, score))
        return results

    def refine_sentence_with_mask(self, text: str, threshold: float = 0.001, vocab_manager=None, raw_model_output=None) -> str:
        """Refine sentence by masking weak words and predicting replacements.

        IV-Safe + Strict similarity + BERT Kill Switch:
          * IV words (per ``vocab_manager``) are never replaced;
          * words identical to the raw model output at the same index are kept;
          * words of length <= 2 are skipped;
          * a replacement must differ by at most one character in length,
            have normalized Levenshtein similarity >= 0.90, be IV, have
            probability >= 0.12, and beat the original's score by a large factor.
        """
        words = text.split()
        refined_words = list(words)
        # Raw model words for the kill switch
        raw_words = raw_model_output.split() if raw_model_output else []

        for i, word in enumerate(words):
            if vocab_manager and vocab_manager.is_iv(word):
                continue
            if i < len(raw_words) and word == raw_words[i]:
                continue
            if len(word) <= 2:
                continue

            # 1. Only low-confidence words are candidates for replacement
            current_score = self.score_with_mlm(text, i, word)
            if current_score > threshold:
                continue

            # 2./3. Mask, predict and filter strictly
            for pred_word, pred_score in self.predict_masked_token(text, i, top_k=10):
                if pred_word == word:
                    continue
                if abs(len(pred_word) - len(word)) > 1:
                    continue
                # Note: with a 0.90 floor, a single edit is only accepted for
                # words of 10+ characters.
                dist = Levenshtein.distance(word, pred_word)
                similarity = 1.0 - (dist / max(len(word), len(pred_word)))
                if similarity < 0.90:
                    continue
                if vocab_manager and vocab_manager.is_oov(pred_word):
                    continue
                # Minimum absolute confidence gate (12%)
                if pred_score < 0.12:
                    continue
                # NOTE: with the default threshold (0.001) this branch is never
                # true, because words scoring above the threshold were skipped.
                is_original_common = current_score > 0.001
                if is_original_common:
                    if pred_score > current_score * 1000:
                        refined_words[i] = pred_word
                        break
                else:
                    if pred_score > current_score * 50 and pred_score > 0.2:
                        refined_words[i] = pred_word
                        break

        return ' '.join(refined_words)

    def calculate_sentence_score(self, text: str) -> float:
        """Calculate fluency as the mean MLM probability of each word in its own context.

        Returns 0.0 for empty/whitespace-only text. Runs one MLM pass per
        uncached word.
        """
        words = text.split()
        if not words:
            return 0.0
        total_score = sum(self.score_with_mlm(text, i, word) for i, word in enumerate(words))
        return total_score / len(words)
# ═══════════════════════════════════════════════════════════════════════════════
# MAIN SPELL CHECKER CLASS
# ═══════════════════════════════════════════════════════════════════════════════
class ArabicSpellChecker:
    """Main Arabic Spell Checker class.

    Orchestrates rule-based preprocessing, candidate generation (seq2seq
    beams, heuristic rules, edit distance, word-aligned hybrids, token
    voting), reranking, and a series of conservative checks against the raw
    model output.
    """

    # One word token (letters/digits/underscore). Regex repairs are applied
    # token by token so that in-vocabulary words can be protected.
    _WORD_TOKEN_RE = re.compile(r'\w+')

    # Function words that may be peeled off the front of an over-long token.
    # Only 2- and 3-letter prefixes are ever compared.
    _LONG_WORD_PREFIXES = ('في', 'من', 'على', 'عن', 'مع', 'كل')

    def __init__(self, model, tokenizer, device, use_contextual: bool = True):
        """Initialize spell checker with model and components.

        Args:
            model: seq2seq model exposing ``generate``.
            tokenizer: tokenizer matching ``model``.
            device: torch device the model lives on.
            use_contextual: try to load the BERT-MLM ContextualCorrector; if
                loading fails it is disabled with a warning.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

        # Pipeline components
        self.postprocessor = AraSpellPostProcessor()
        self.classifier = ErrorClassifier()
        self.rules = RulesBasedCorrector()
        self.validator = OutputValidator()
        self.vocab_manager = VocabularyManager(tokenizer)
        self.edit_corrector = EditDistanceCorrector(tokenizer)  # Edit Distance candidates
        self.split_merge = SplitMergeSpecialist(self.vocab_manager)
        # WordAligner for word-level hybrid corrections
        self.word_aligner = WordAligner(self.vocab_manager)

        # candidate -> beam sequence score from the latest model_inference()
        self._last_beam_scores = {}

        # Contextual corrector (optional)
        self.use_contextual = use_contextual
        self.contextual = None
        if use_contextual:
            try:
                self.contextual = ContextualCorrector()
                logger.info("Contextual correction enabled")
            except Exception as e:
                logger.warning(f"Contextual correction disabled: {e}")
                self.contextual = None
                self.use_contextual = False

    # ───────────────────────────── preprocessing ─────────────────────────────

    def _rewrite_oov_words(self, text: str, rewrite) -> str:
        """Apply ``rewrite`` to every out-of-vocabulary word token of ``text``.

        In-vocabulary tokens are returned untouched, so heuristics cannot
        damage valid words such as سبب, عدد or المملكة. Because every regex
        used with this helper matches only word characters, applying it per
        token is equivalent to applying it to the whole text.
        """
        def _replace(match):
            word = match.group(0)
            if self.vocab_manager.is_iv(word):
                return word
            return rewrite(word)
        return self._WORD_TOKEN_RE.sub(_replace, text)

    def _fix_repeated_end_chars(self, text: str) -> str:
        """
        Collapse repeated characters at the end of OOV words.

        Examples:
            اليومم → اليوم
            جميلل → جميل
            صباحح → صباح
        Valid words that legitimately end in a doubled letter (سبب, عدد)
        are left alone.
        """
        def _collapse(word):
            return re.sub(r'([ا-ي])\1+\b', r'\1', word)
        return self._rewrite_oov_words(text, _collapse)

    def _fix_merged_with_errors(self, text: str) -> str:
        """Fix OOV words that contain a duplicated letter.

        Examples:
            الممدرسة → المدرسة
            الكتابب → الكتاب
            الططالب → الطالب
        """
        def _repair(word):
            # Pattern 1: ال + doubled letter + rest → keep ONE copy of the letter
            word = re.sub(r'ال([ا-ي])\1+([ا-ي]{2,})', r'ال\1\2', word)
            # Pattern 2: word + repeated char at end
            word = re.sub(r'\b([ا-ي]{3,})([ا-ي])\2+\b', r'\1\2', word)
            return word
        return self._rewrite_oov_words(text, _repair)

    def _split_merged_words_linguistic(self, text: str) -> str:
        """Split OOV tokens that look like merged words, using linguistic patterns.

        Examples:
            كلصباح → كل صباح
            فيالطريق → في الطريق
            السلامعليكم → السلام عليكم
        In-vocabulary words (e.g. منطقة, which starts with من) are not split.
        """
        def _split(word):
            # Pattern 1: Prepositions + (article)? + word
            word = re.sub(
                r'\b(في|من|إلى|الى|حتى|منذ|خلال|بعد|قبل)(ال)?([ا-ي]{3,})',
                r'\1 \2\3',
                word
            )
            # Pattern 2: كل + word
            word = re.sub(r'\b(كل)([ا-ي]{3,})', r'\1 \2', word)
            # Pattern 3: Article in the middle of a token
            word = re.sub(r'([ا-ي]{3,})(ال)([ا-ي]{3,})', r'\1 \2\3', word)
            # Pattern 4: Single-letter prepositions (re-joined later by
            # _normalize_tanween_patterns)
            word = re.sub(r'\b([بلك])(ال)?([ا-ي]{3,})', r'\1 \2\3', word)
            # Pattern 5: Word + عليكم/عليك/عليه/عليها
            word = re.sub(r'([ا-ي]{4,})(عليكم|عليك|عليه|عليها)', r'\1 \2', word)
            # Pattern 6: على/عن in the middle of merged words
            word = re.sub(r'([ا-ي]{3,})(على|عن)([ا-ي]{3,})', r'\1 \2 \3', word)
            return word
        return self._rewrite_oov_words(text, _split)

    def _split_long_words_heuristic(self, text: str, max_length: int = 15) -> str:
        """Split suspiciously long words (longer than ``max_length``) using heuristics.

        1. An embedded article ``ال`` at index >= 2 with >= 3 letters after it
           splits the word before the article.
        2. Otherwise a leading function word from ``_LONG_WORD_PREFIXES`` is
           split off (only for words of length >= 8).
        """
        result = []
        for word in text.split():
            if len(word) <= max_length:
                result.append(word)
                continue

            # Embedded article: search from index 2 so a word that itself
            # starts with ال can still be split at a later ال.
            al_index = word.find('ال', 2)
            if al_index != -1 and len(word) - al_index - 2 >= 3:
                result.extend([word[:al_index], word[al_index:]])
                continue

            if len(word) >= 8:
                for split_pos in (2, 3):
                    if word[:split_pos] in self._LONG_WORD_PREFIXES:
                        result.extend([word[:split_pos], word[split_pos:]])
                        break
                else:
                    result.append(word)
            else:
                result.append(word)
        return ' '.join(result)

    def _normalize_tanween_patterns(self, text: str) -> str:
        """Normalize tanween patterns.

        Examples:
            جدأ → جداً
            كثيرأ → كثيراً
        """
        # أ at word end → اً
        text = re.sub(r'([ا-ي]{2,})أ\b', r'\1اً', text)
        # Remove standalone أ
        text = re.sub(r'\s+أ\s+', ' ', text)
        # Re-join a detached single-letter preposition (e.g. ب + space + word)
        text = re.sub(r'\b([بلك])\s+([ا-ي])', r'\1\2', text)
        return text

    def preprocess(self, text: str) -> str:
        """Preprocessing pipeline (basic normalization plus the integrated heuristics)."""
        # Basic normalization
        text = self.postprocessor.remove_harakat(text)
        text = self.postprocessor.remove_tatweel(text)
        text = self.postprocessor.normalize_special_chars(text)
        # Fix repeated chars and merged words with errors FIRST
        text = self._fix_repeated_end_chars(text)
        text = self._fix_merged_with_errors(text)
        # Then split merged words
        text = self._split_merged_words_linguistic(text)
        text = self._split_long_words_heuristic(text)
        text = self._normalize_tanween_patterns(text)
        # Merge separated 'ال'
        text = self.postprocessor.merge_separated_al(text)
        # Collapse repetitions
        text = self.postprocessor.unified_collapse_repeated(text)
        # Rules-based fixes
        text = self.rules.fix_char_substitution(text)
        text = self.rules.fix_char_repetition(text)
        # Normalize spaces
        text = self.postprocessor.normalize_spaces(text)
        return text

    def postprocess(self, text: str, original: str = "") -> str:
        """Postprocessing pipeline — passes vocab_manager for smart ه/ة handling."""
        return self.postprocessor.full_postprocess(text, original, vocab_manager=self.vocab_manager)

    # ───────────────────────────── model & candidates ─────────────────────────────

    def model_inference(self, text: str, num_return_sequences: int = 5) -> List[str]:
        """Run seq2seq beam search and return the top ``num_return_sequences`` decodes.

        Sequence scores are stored in ``self._last_beam_scores`` for diagnostics.
        """
        inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # The model was tuned with 5 beams; widen only when more sequences are
        # requested, since generate() requires num_return_sequences <= num_beams.
        num_beams = max(5, num_return_sequences)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                num_beams=num_beams,
                num_return_sequences=num_return_sequences,
                early_stopping=True,
                return_dict_in_generate=True,
                output_scores=True
            )

        candidates = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)

        self._last_beam_scores = {}
        sequences_scores = getattr(outputs, 'sequences_scores', None)
        if sequences_scores is not None:
            for cand, score in zip(candidates, sequences_scores.tolist()):
                self._last_beam_scores[cand] = score
        return candidates

    @staticmethod
    def _vote_tokens(beams: List[str]) -> str:
        """Majority-vote each word position across ``beams`` (ties: first seen wins)."""
        beam_word_lists = [b.split() for b in beams]
        max_words = max(len(wl) for wl in beam_word_lists)
        voted_words = []
        for pos in range(max_words):
            words_at_pos = [wl[pos] for wl in beam_word_lists if pos < len(wl)]
            voted_words.append(Counter(words_at_pos).most_common(1)[0][0])
        return ' '.join(voted_words)

    def _generate_candidates(self, text: str, preprocessed_text: str):
        """Build the de-duplicated candidate pool.

        Returns:
            (candidates, raw_model_output) where ``raw_model_output`` is the
            top beam, or None if inference failed or returned nothing.
        """
        # A. Baseline (preprocessed)
        candidates = [preprocessed_text]
        # B. Aggressive heuristic rules candidate
        rules_candidate = self.rules.advanced_heuristic_repair(text)
        candidates.append(rules_candidate)
        # B2. Edit distance candidate
        edit_candidate = self.edit_corrector.generate_candidate(text)
        if edit_candidate != text and edit_candidate != rules_candidate:
            candidates.append(edit_candidate)

        raw_model_output = None
        try:
            # C. Model beams
            model_candidates = self.model_inference(preprocessed_text, num_return_sequences=5)
            raw_model_output = model_candidates[0] if model_candidates else None
            candidates.extend(model_candidates)

            # D/E. Word-aligned hybrids against the top 3 beams
            for beam in model_candidates[:3]:
                hybrid = self.word_aligner.align_words(preprocessed_text, beam)
                if hybrid not in candidates:
                    candidates.append(hybrid)

            # D2. Token-level voting across all beams
            if len(model_candidates) >= 3:
                try:
                    voted_candidate = self._vote_tokens(model_candidates)
                    if voted_candidate not in candidates:
                        candidates.append(voted_candidate)
                except Exception as e:
                    logger.debug(f"Token voting skipped: {e}")
        except Exception as e:
            logger.warning(f"Model inference failed: {e}")

        # Remove duplicates while preserving order
        return list(dict.fromkeys(candidates)), raw_model_output

    # ───────────────────────────── post-selection checks ─────────────────────────────

    def _revert_iv_to_oov_postprocess(self, result: str, best_candidate: str) -> str:
        """Undo postprocessor edits that turned an IV word into an OOV word.

        Only applies when the word counts match; when both words are IV the
        postprocessor's choice is kept (it applies hamza/ta-marbuta rules).
        """
        if result == best_candidate:
            return result
        result_words = result.split()
        best_words = best_candidate.split()
        if len(result_words) != len(best_words):
            return result
        fixed_words = []
        for rw, bw in zip(result_words, best_words):
            if rw != bw and self.vocab_manager.is_iv(bw) and not self.vocab_manager.is_iv(rw):
                fixed_words.append(bw)
            else:
                fixed_words.append(rw)
        return ' '.join(fixed_words)

    def _restore_raw_model_words(self, result: str, raw_model_output, preprocessed_text: str) -> str:
        """Bidirectional word-level validation against the raw model output.

        Per aligned word (equal word counts only): take the raw word if ours
        is OOV and raw is IV, or if both are IV and raw is closer to the
        input word. The change is kept only if it does not increase OOV count.
        """
        if not raw_model_output or result == raw_model_output:
            return result
        result_words = result.split()
        raw_words = raw_model_output.split()
        if len(result_words) != len(raw_words):
            return result

        input_words = preprocessed_text.split()
        corrected_words = []
        changed = False
        for idx, (rw, raw_w) in enumerate(zip(result_words, raw_words)):
            take_raw = False
            if rw != raw_w:
                rw_iv = self.vocab_manager.is_iv(rw)
                raw_iv = self.vocab_manager.is_iv(raw_w)
                if not rw_iv and raw_iv:
                    take_raw = True
                elif rw_iv and raw_iv and idx < len(input_words):
                    input_w = input_words[idx]
                    # Raw closer to input → our pipeline likely made an unnecessary change
                    take_raw = Levenshtein.distance(input_w, raw_w) < Levenshtein.distance(input_w, rw)
            corrected_words.append(raw_w if take_raw else rw)
            changed = changed or take_raw

        if not changed:
            return result
        new_result = ' '.join(corrected_words)
        if self.vocab_manager.count_oov_words(new_result) <= self.vocab_manager.count_oov_words(result):
            return new_result
        return result

    def _apply_raw_output_safety_net(self, result: str, raw_model_output, original: str) -> str:
        """Switch to the raw model output only when it is clearly better.

        Case A: raw is all-IV while ours has OOV words → raw, if it validates
                (or is a space-leniency accept).
        Case B: both all-IV, raw is closer to the input and within 3 edits of
                our result → raw, if it validates.
        Case C: both all-IV, raw is closer to the input and the word counts
                differ (raw may have the right splitting) → raw, if it validates.
        """
        if not raw_model_output or raw_model_output == result:
            return result
        raw_oov = self.vocab_manager.count_oov_words(raw_model_output)
        if raw_oov != 0:
            return result
        our_oov = self.vocab_manager.count_oov_words(result)

        # Case A
        if our_oov > 0:
            is_valid, reason = self.validator.validate(original, raw_model_output, "mixed")
            if is_valid or reason == "space_leniency_accept":
                return raw_model_output
            return result

        # Both all-IV: raw must be closer to the input in either case below.
        raw_dist = VocabularyManager.damerau_levenshtein_distance(original, raw_model_output)
        our_dist = VocabularyManager.damerau_levenshtein_distance(original, result)
        if raw_dist >= our_dist:
            return result
        # Case B: threshold of 3 chars (widening to 5 reverted valid hybrid corrections)
        small_difference = VocabularyManager.damerau_levenshtein_distance(result, raw_model_output) <= 3
        # Case C: word count differs
        word_count_differs = len(raw_model_output.split()) != len(result.split())
        if small_difference or word_count_differs:
            raw_valid, _ = self.validator.validate(original, raw_model_output, "mixed")
            if raw_valid:
                return raw_model_output
        return result

    # ───────────────────────────── main entry point ─────────────────────────────

    def correct(self, text: str) -> str:
        """
        Main correction pipeline (RERANKING APPROACH).

        Steps:
            1. Preprocess
            2. Classify error type
            3. Generate candidates (model beams, hybrids, rules, edit distance)
            4. Rerank candidates (validity × fluency × similarity × boosts)
            5. Postprocess the winner (with IV-safe revert)
            6. BERT masked refinement (optional)
            7. Safe merge of fragments
            8-10. Stability, word-level and raw-output safety checks

        Returns ``text`` unchanged if it is empty or whitespace only.
        """
        if not text or not text.strip():
            return text
        original = text

        # 1. Preprocess — also the conservative baseline candidate
        preprocessed_text = self.preprocess(text)

        # 2. Classify error type (consumed by the validator)
        error_type = self.classifier.classify(preprocessed_text)

        # 3. Generate candidates
        candidates, raw_model_output = self._generate_candidates(text, preprocessed_text)

        # 4. Rerank candidates
        best_candidate = preprocessed_text
        best_score = -1.0
        candidate_scores = []
        # Loop invariants
        input_oov_count = self.vocab_manager.count_oov_words(original)
        input_words = preprocessed_text.split()

        for cand in candidates:
            # A. Validation (hard penalty)
            is_valid, reason = self.validator.validate(original, cand, error_type.value)
            if len(cand) < len(original) * 0.5:
                is_valid = False
                reason = "too_short"

            # Vocabulary-aware acceptance: OOV→IV boosts, IV→OOV penalizes.
            cand_oov_count = self.vocab_manager.count_oov_words(cand)
            vocab_boost = 1.0
            if input_oov_count > 0 and cand_oov_count < input_oov_count:
                vocab_boost = 1.0 + ((input_oov_count - cand_oov_count) * 0.3)  # +30% per OOV fixed
                # All words now IV → override validation rejection
                if cand_oov_count == 0 and self.vocab_manager.all_words_iv(cand):
                    if not is_valid and reason not in ["empty_output"]:
                        is_valid = True
                        reason = "vocab_aware_accept"
            elif cand_oov_count > input_oov_count:
                vocab_boost = 0.5  # introduced new OOV words

            # Valid: 1.0; invalid: 0.001 (effectively disqualified unless all are invalid)
            validity_factor = 1.0 if is_valid else 0.001

            # B. Fluency (BERT MLM)
            if self.use_contextual and self.contextual:
                try:
                    fluency_score = self.contextual.calculate_sentence_score(cand)
                except Exception as e:
                    logger.warning(f"Scoring failed: {e}")
                    fluency_score = 0.5  # Default fallback
            else:
                fluency_score = 1.0

            # C. Similarity (Damerau-Levenshtein) to the preprocessed input
            dist = VocabularyManager.damerau_levenshtein_distance(preprocessed_text, cand)
            max_len = max(len(preprocessed_text), len(cand), 1)
            similarity = 1.0 - (dist / max_len)
            if cand == preprocessed_text:
                similarity = 1.0

            # Keyboard proximity bonus: same-length word edits between
            # keyboard-adjacent keys are likely typo fixes (+5% each).
            keyboard_bonus = 1.0
            cand_words = cand.split()
            if len(input_words) == len(cand_words):
                for iw, cw in zip(input_words, cand_words):
                    if iw != cw and len(iw) == len(cw):
                        for ic, cc in zip(iw, cw):
                            if ic != cc and RulesBasedCorrector.is_keyboard_neighbor(ic, cc):
                                keyboard_bonus *= 1.05

            # High-confidence gating: very fluent all-IV candidates may bypass
            # length/similarity/word-count rejections (unless far too short).
            if fluency_score > 0.85 and cand_oov_count == 0:
                if not is_valid and reason in ["too_short", "low_character_similarity", "word_count_mismatch"]:
                    if len(cand) >= len(original) * 0.4:
                        is_valid = True
                        reason = "high_confidence_override"
                        vocab_boost *= 1.2
                        validity_factor = 1.0

            # Beam 0 gets 15% priority
            beam_boost = 1.15 if raw_model_output and cand == raw_model_output else 1.0

            # Final = Fluency^0.3 × Similarity^3 × Validity × VocabBoost × KeyboardBonus × BeamBoost
            final_score = (
                (fluency_score ** 0.3) * (similarity ** 3.0) * validity_factor
                * vocab_boost * keyboard_bonus * beam_boost
            )

            candidate_scores.append({
                'text': cand,
                'is_valid': is_valid,
                'reason': reason,
                'fluency': fluency_score,
                'similarity': similarity,
                'vocab_boost': vocab_boost,
                'input_oov': input_oov_count,
                'cand_oov': cand_oov_count,
                'final_score': final_score
            })

            if final_score > best_score:
                best_score = final_score
                best_candidate = cand

        # Minimum score threshold: a winner that barely beats the baseline
        # (< 1.05x) and adds OOV words is not trusted.
        if best_candidate != preprocessed_text:
            preprocessed_score = next(
                (cs['final_score'] for cs in candidate_scores if cs['text'] == preprocessed_text), 0.0
            )
            if preprocessed_score > 0 and best_score < preprocessed_score * 1.05:
                best_oov = self.vocab_manager.count_oov_words(best_candidate)
                prep_oov = self.vocab_manager.count_oov_words(preprocessed_text)
                if best_oov > prep_oov:
                    best_candidate = preprocessed_text

        # Contextual validation: reject a correction that is markedly less
        # fluent (input > 1.5x) without reducing OOV words.
        if best_candidate != preprocessed_text and self.use_contextual and self.contextual:
            try:
                input_fluency = self.contextual.calculate_sentence_score(preprocessed_text)
                best_fluency = next(
                    (cs['fluency'] for cs in candidate_scores if cs['text'] == best_candidate), 0.0
                )
                if input_fluency > 0 and best_fluency > 0 and input_fluency > best_fluency * 1.5:
                    input_oov = self.vocab_manager.count_oov_words(preprocessed_text)
                    best_oov = self.vocab_manager.count_oov_words(best_candidate)
                    if input_oov <= best_oov:
                        best_candidate = preprocessed_text
            except Exception as e:
                logger.debug(f"Contextual validation skipped: {e}")

        # 5. Postprocess winner, reverting IV→OOV damage
        result = self.postprocess(best_candidate, original)
        result = self._revert_iv_to_oov_postprocess(result, best_candidate)

        # 6. BERT masked refinement (IV-safe, with raw-model kill switch)
        if self.use_contextual and self.contextual and len(result) > 3:
            result = self.contextual.refine_sentence_with_mask(
                result, vocab_manager=self.vocab_manager,
                raw_model_output=raw_model_output
            )

        # 7. Safe merge (only merges when the result is IV)
        result = self.split_merge.merge_fragments(result)

        # 8. Output stability: if re-preprocessing changes > 15% of the
        #    result, prefer the raw model output when it is more stable and
        #    has no more OOV words.
        if result != preprocessed_text and raw_model_output:
            try:
                re_preprocessed = self.preprocess(result)
                stability_dist = VocabularyManager.damerau_levenshtein_distance(result, re_preprocessed)
                stability_ratio = stability_dist / max(len(result), 1)
                if stability_ratio > 0.15:
                    raw_re = self.preprocess(raw_model_output)
                    raw_stability = VocabularyManager.damerau_levenshtein_distance(
                        raw_model_output, raw_re
                    ) / max(len(raw_model_output), 1)
                    if raw_stability < stability_ratio:
                        raw_oov = self.vocab_manager.count_oov_words(raw_model_output)
                        our_oov = self.vocab_manager.count_oov_words(result)
                        if raw_oov <= our_oov:
                            result = raw_model_output
            except Exception as e:
                logger.debug(f"Stability check skipped: {e}")

        # 9. Bidirectional word-level validation
        result = self._restore_raw_model_words(result, raw_model_output, preprocessed_text)

        # 10. Conservative safety net against the raw model output
        result = self._apply_raw_output_safety_net(result, raw_model_output, original)

        return result
# ═══════════════════════════════════════════════════════════════════════════════
# PUBLIC API
# ═══════════════════════════════════════════════════════════════════════════════
# Exported for use by benchmark.py and external consumers
spell_checker = None  # Populated by initialize(); read by benchmark.py and other consumers.


def initialize(use_contextual=True):
    """Build the module-level ArabicSpellChecker and return it.

    Call once before using ``spell_checker``. The checker wraps the seq2seq
    ``model``, ``tokenizer`` and ``device`` loaded at import time.

    Args:
        use_contextual: forwarded to ArabicSpellChecker to enable or disable
            the BERT-MLM contextual corrector.
    """
    global spell_checker
    checker = ArabicSpellChecker(model, tokenizer, device, use_contextual=use_contextual)
    spell_checker = checker
    logger.info("Spell checker initialized")
    return checker
if __name__ == "__main__":
    # Quick demo on a few hand-picked noisy inputs.
    checker = initialize(use_contextual=True)
    demo_inputs = (
        "السلام عليكممم",
        "فيالمدرسه",
        "الطقص جميل اليومم",
    )
    rule = "=" * 60

    print("\n" + rule)
    print("AraSpell Demo")
    print(rule)

    for sample in demo_inputs:
        fixed = checker.correct(sample)
        print(f"\n Input: {sample}")
        print(f" Corrected: {fixed}")

    print("\n" + rule)
    print("For full benchmark, run: python benchmark.py")
    print(rule)