youssefreda9 committed on
Commit
4608bcd
·
1 Parent(s): 32a135f

FIX-44: OOV cleanup pass between spelling and grammar stages

Browse files

NEW PIPELINE STEP after spelling, before grammar:
1. Trailing و removal (from legacy AraSpell):
- المصنعو→المصنع, الماضيةو→الماضية
- Catches PC004, PC008, PC010 benchmark failures

2. Edit-distance-1 OOV→IV correction:
- For remaining OOV words, find closest IV word in BERT vocab
- Only replaces when an edit-1 candidate exists and its first letter matches the original (the alef/hamza variants أ إ آ ا ء are treated as interchangeable)
- Catches: صممو→صمموا (PC001), حضرو→حضروا (PC042)

Also adds contextual_corrector.py module (MLM-based validation).
Tests: 39 passing.

src/app.py CHANGED
@@ -1967,6 +1967,103 @@ def analyze_text():
1967
  logger.error(traceback.format_exc())
1968
  timing_ms['spelling_error'] = f"{type(e).__name__}: {str(e)[:200]}"
1969
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1970
  # ── FIX-07: Religious text already detected above (before spelling) ──
1971
  # _is_religious_text was set earlier to skip ALL stages for sacred text
1972
 
 
1967
  logger.error(traceback.format_exc())
1968
  timing_ms['spelling_error'] = f"{type(e).__name__}: {str(e)[:200]}"
1969
 
1970
+ # ── FIX-44: OOV Cleanup Pass (between spelling and grammar) ──
1971
+ # After spelling corrections, some OOV words remain because:
1972
+ # 1. The model didn't correct them (missed)
1973
+ # 2. Our guards blocked a bad correction (but word is still OOV)
1974
+ # 3. Trailing و artifacts from model output
1975
+ #
1976
+ # For each remaining OOV word, try to find the closest IV word
1977
+ # using edit-distance-1 candidates from BERT vocabulary.
1978
+ if not _is_religious_text:
1979
+ try:
1980
+ from nlp.spelling.araspell_service import get_spelling_model
1981
+ _oov_checker = get_spelling_model()
1982
+ _oov_text = ctx.current_text
1983
+ _oov_words = _oov_text.split()
1984
+ _oov_changed = False
1985
+ _oov_result = []
1986
+
1987
+ for _ow_idx, _ow in enumerate(_oov_words):
1988
+ # Skip short words (prepositions etc.)
1989
+ if len(_ow) <= 2:
1990
+ _oov_result.append(_ow)
1991
+ continue
1992
+
1993
+ # Strip trailing punctuation for IV check
1994
+ _ow_clean = _ow.rstrip('.،؛؟!?!')
1995
+
1996
+ # Skip if already IV
1997
+ if _oov_checker.vocab_manager.is_iv(_ow_clean):
1998
+ _oov_result.append(_ow)
1999
+ continue
2000
+
2001
+ # ── Trailing و removal (from legacy AraSpell L263-267) ──
2002
+ # الماضيةو → الماضية, المصنعو → المصنع, الدروسو → الدروس
2003
+ if (len(_ow_clean) > 4 and _ow_clean.endswith('و')
2004
+ and _ow_clean[-2] in 'ةهاأإآءين'):
2005
+ _wo_cand = _ow_clean[:-1]
2006
+ if _oov_checker.vocab_manager.is_iv(_wo_cand):
2007
+ _punct_suffix = _ow[len(_ow_clean):] # preserve punctuation
2008
+ logger.info(
2009
+ f"[OOV-CLEANUP] Trailing و fix: '{_ow}'→'{_wo_cand}{_punct_suffix}'"
2010
+ )
2011
+ _oov_result.append(_wo_cand + _punct_suffix)
2012
+ _oov_changed = True
2013
+
2014
+ # Create a patch for the UI
2015
+ _ow_pos = sum(len(w) + 1 for w in _oov_words[:_ow_idx])
2016
+ if _ow_pos + len(_ow) <= len(_oov_text):
2017
+ ctx.add_patch(
2018
+ 'spelling', _ow_pos, _ow_pos + len(_ow),
2019
+ _wo_cand + _punct_suffix, confidence=0.75,
2020
+ )
2021
+ continue
2022
+
2023
+ # ── Edit-distance-1 OOV→IV correction ──
2024
+ # Generate all edit-1 candidates and filter to IV words
2025
+ try:
2026
+ _ed1_candidates = _oov_checker.edit_corrector.known(
2027
+ _oov_checker.edit_corrector.edits1(_ow_clean)
2028
+ )
2029
+ if _ed1_candidates:
2030
+ # Pick best: lowest vocab rank (most frequent)
2031
+ _best_cand = min(
2032
+ _ed1_candidates,
2033
+ key=lambda w: _oov_checker.vocab_manager.get_frequency_rank(w)
2034
+ )
2035
+ # Safety: don't change first letter (same guard as FIX-42b)
2036
+ if _best_cand[0] == _ow_clean[0] or (
2037
+ _best_cand[0] in 'أإآاء' and _ow_clean[0] in 'أإآاء'
2038
+ ):
2039
+ _punct_suffix = _ow[len(_ow_clean):]
2040
+ logger.info(
2041
+ f"[OOV-CLEANUP] Edit-1 fix: '{_ow}'→'{_best_cand}{_punct_suffix}'"
2042
+ )
2043
+ _oov_result.append(_best_cand + _punct_suffix)
2044
+ _oov_changed = True
2045
+
2046
+ _ow_pos = sum(len(w) + 1 for w in _oov_words[:_ow_idx])
2047
+ if _ow_pos + len(_ow) <= len(_oov_text):
2048
+ ctx.add_patch(
2049
+ 'spelling', _ow_pos, _ow_pos + len(_ow),
2050
+ _best_cand + _punct_suffix, confidence=0.65,
2051
+ )
2052
+ continue
2053
+ except Exception:
2054
+ pass # Edit-distance fallback is best-effort
2055
+
2056
+ _oov_result.append(_ow)
2057
+
2058
+ if _oov_changed:
2059
+ _oov_new_text = ' '.join(_oov_result)
2060
+ logger.info(f"[OOV-CLEANUP] Applied OOV fixes: '{_oov_text[:80]}' → '{_oov_new_text[:80]}'")
2061
+ ctx.mutate_text(_oov_new_text, OffsetMapper)
2062
+ current_text = ctx.current_text
2063
+
2064
+ except Exception as e:
2065
+ logger.warning(f"[OOV-CLEANUP] Failed: {type(e).__name__}: {e}")
2066
+
2067
  # ── FIX-07: Religious text already detected above (before spelling) ──
2068
  # _is_religious_text was set earlier to skip ALL stages for sacred text
2069
 
src/nlp/spelling/contextual_corrector.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ContextualCorrector — MLM-based contextual validation for spelling corrections
2
+ # Adapted from legacy AraSpell ContextualCorrector.
3
+ #
4
+ # Purpose: After the spelling model produces corrections, this module validates
5
+ # each OOV word by masking it and asking BERT what word should go there.
6
+ # If the correction scores poorly in context and the original word scores
7
+ # much better, the original word is restored (the model hallucinated).
8
+ #
9
+ # Usage in pipeline: Called AFTER spelling correction, BEFORE grammar.
10
+ # Only processes OOV words (never touches IV words).
11
+
12
+ import logging
13
+ import torch
14
+ from typing import List, Tuple, Optional, Dict
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # Singleton instance
19
+ _instance = None
20
+ _loading = False
21
+
22
+
23
+ class ContextualCorrector:
24
+ """MLM-based contextual validation for spelling corrections.
25
+
26
+ Uses BERT's masked language model to validate spelling corrections.
27
+ For each OOV word in the corrected text:
28
+ 1. Masks the word and asks BERT for predictions
29
+ 2. If BERT strongly disagrees with the correction, reverts to original
30
+ 3. Never touches IV words (they're already correct)
31
+ """
32
+
33
+ def __init__(self, model_name: str = 'aubmindlab/bert-base-arabertv02'):
34
+ """Initialize with BERT MLM model."""
35
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
36
+
37
+ logger.info(f"[MLM] Loading contextual corrector: {model_name}")
38
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
39
+ self.model = AutoModelForMaskedLM.from_pretrained(model_name)
40
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
41
+ self.model = self.model.to(self.device)
42
+ self.model.eval()
43
+
44
+ # Simple cache for scores
45
+ self._cache: Dict[str, float] = {}
46
+ self._cache_max = 5000
47
+
48
+ # Vocab for candidate filtering
49
+ self.vocab = self.tokenizer.get_vocab()
50
+
51
+ logger.info(f"[MLM] Contextual corrector loaded on {self.device}")
52
+
53
+ def score_word_in_context(self, text: str, position: int, word: str) -> float:
54
+ """Score how well a word fits in context using BERT MLM.
55
+
56
+ Args:
57
+ text: Full sentence
58
+ position: Word index (0-based) in the sentence
59
+ word: The word to score
60
+
61
+ Returns:
62
+ Probability score (0.0 to 1.0) — higher = better fit
63
+ """
64
+ cache_key = f"{text[:100]}|{position}|{word}"
65
+ if cache_key in self._cache:
66
+ return self._cache[cache_key]
67
+
68
+ words = text.split()
69
+ if position >= len(words):
70
+ return 0.0
71
+
72
+ # Create masked text
73
+ masked_words = words.copy()
74
+ masked_words[position] = '[MASK]'
75
+ masked_text = ' '.join(masked_words)
76
+
77
+ try:
78
+ inputs = self.tokenizer(
79
+ masked_text, return_tensors='pt',
80
+ padding=True, truncation=True, max_length=128
81
+ )
82
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
83
+
84
+ with torch.no_grad():
85
+ outputs = self.model(**inputs)
86
+
87
+ # Find [MASK] token position
88
+ mask_idx = (inputs['input_ids'] == self.tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
89
+ if len(mask_idx) == 0:
90
+ return 0.0
91
+
92
+ # Get probability for the target word
93
+ logits = outputs.logits[0, mask_idx[0], :]
94
+ probs = torch.softmax(logits, dim=0)
95
+
96
+ word_tokens = self.tokenizer.encode(word, add_special_tokens=False)
97
+ if not word_tokens:
98
+ return 0.0
99
+
100
+ score = probs[word_tokens[0]].item()
101
+
102
+ except Exception as e:
103
+ logger.warning(f"[MLM] Score error for '{word}': {e}")
104
+ score = 0.0
105
+
106
+ # Cache management
107
+ if len(self._cache) >= self._cache_max:
108
+ # Remove oldest 20% of entries
109
+ keys_to_remove = list(self._cache.keys())[:self._cache_max // 5]
110
+ for k in keys_to_remove:
111
+ del self._cache[k]
112
+ self._cache[cache_key] = score
113
+
114
+ return score
115
+
116
+ def validate_corrections(
117
+ self,
118
+ original_text: str,
119
+ corrected_text: str,
120
+ vocab_manager=None,
121
+ confidence_threshold: float = 0.001,
122
+ min_pred_score: float = 0.12,
123
+ similarity_threshold: float = 0.90,
124
+ ) -> str:
125
+ """Validate spelling corrections using MLM context.
126
+
127
+ For each word that changed between original and corrected:
128
+ - If the correction is OOV: revert (model hallucinated)
129
+ - If the correction scores very low in context AND the original
130
+ scores much better: revert
131
+ - If BERT has a better suggestion that's similar to original: use it
132
+
133
+ Args:
134
+ original_text: Text before spelling correction
135
+ corrected_text: Text after spelling correction
136
+ vocab_manager: VocabManager for IV/OOV checks
137
+ confidence_threshold: Min BERT score to keep a word without checking
138
+ min_pred_score: Min BERT score for a replacement candidate
139
+ similarity_threshold: Min similarity (Levenshtein) for replacements
140
+
141
+ Returns:
142
+ Validated text with hallucinations reverted
143
+ """
144
+ orig_words = original_text.split()
145
+ corr_words = corrected_text.split()
146
+
147
+ # Only process when word counts match (1:1 mapping)
148
+ if len(orig_words) != len(corr_words):
149
+ return corrected_text
150
+
151
+ result_words = corr_words.copy()
152
+ changes_made = 0
153
+
154
+ for i, (orig_w, corr_w) in enumerate(zip(orig_words, corr_words)):
155
+ # Skip unchanged words
156
+ if orig_w == corr_w:
157
+ continue
158
+
159
+ # Never touch IV words in correction
160
+ if vocab_manager and vocab_manager.is_iv(corr_w):
161
+ continue
162
+
163
+ # Score the correction in context
164
+ corr_score = self.score_word_in_context(corrected_text, i, corr_w)
165
+
166
+ # If correction has decent BERT confidence, keep it
167
+ if corr_score > confidence_threshold:
168
+ continue
169
+
170
+ # Score the original word in the corrected context
171
+ orig_score = self.score_word_in_context(corrected_text, i, orig_w)
172
+
173
+ # If original scores better, revert
174
+ if orig_score > corr_score * 10 and orig_score > 0.01:
175
+ logger.info(
176
+ f"[MLM] Reverting hallucination: '{corr_w}'→'{orig_w}' "
177
+ f"(corr_score={corr_score:.4f}, orig_score={orig_score:.4f})"
178
+ )
179
+ result_words[i] = orig_w
180
+ changes_made += 1
181
+ continue
182
+
183
+ # Try BERT's own top predictions as alternatives
184
+ predictions = self._predict_top_k(corrected_text, i, top_k=5)
185
+
186
+ for pred_word, pred_score in predictions:
187
+ if pred_word == corr_w or pred_word == orig_w:
188
+ continue
189
+
190
+ # Must be IV
191
+ if vocab_manager and not vocab_manager.is_iv(pred_word):
192
+ continue
193
+
194
+ # Must be similar to the original (not a random word)
195
+ similarity = self._similarity(orig_w, pred_word)
196
+ if similarity < similarity_threshold:
197
+ continue
198
+
199
+ # Must have strong BERT confidence
200
+ if pred_score < min_pred_score:
201
+ continue
202
+
203
+ # Must be a big improvement
204
+ if pred_score > corr_score * 50 and pred_score > 0.2:
205
+ logger.info(
206
+ f"[MLM] Replacing with BERT prediction: '{corr_w}'→'{pred_word}' "
207
+ f"(pred_score={pred_score:.4f}, corr_score={corr_score:.4f})"
208
+ )
209
+ result_words[i] = pred_word
210
+ changes_made += 1
211
+ break
212
+
213
+ if changes_made:
214
+ logger.info(f"[MLM] Contextual validation: {changes_made} corrections adjusted")
215
+
216
+ return ' '.join(result_words)
217
+
218
+ def _predict_top_k(self, text: str, position: int, top_k: int = 5) -> List[Tuple[str, float]]:
219
+ """Predict top-k words for a masked position."""
220
+ words = text.split()
221
+ if position >= len(words):
222
+ return []
223
+
224
+ masked_words = words.copy()
225
+ masked_words[position] = '[MASK]'
226
+ masked_text = ' '.join(masked_words)
227
+
228
+ try:
229
+ inputs = self.tokenizer(
230
+ masked_text, return_tensors='pt',
231
+ padding=True, truncation=True, max_length=128
232
+ ).to(self.device)
233
+
234
+ with torch.no_grad():
235
+ outputs = self.model(**inputs)
236
+
237
+ mask_idx = (inputs['input_ids'] == self.tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
238
+ if len(mask_idx) == 0:
239
+ return []
240
+
241
+ logits = outputs.logits[0, mask_idx[0], :]
242
+ probs = torch.softmax(logits, dim=0)
243
+ top_k_weights, top_k_indices = torch.topk(probs, top_k, sorted=True)
244
+
245
+ results = []
246
+ for j in range(top_k):
247
+ token_id = top_k_indices[j].item()
248
+ score = top_k_weights[j].item()
249
+ token = self.tokenizer.decode([token_id]).strip()
250
+ # Skip subword tokens and special tokens
251
+ if not token.startswith("##") and token not in self.tokenizer.all_special_tokens:
252
+ results.append((token, score))
253
+
254
+ return results
255
+
256
+ except Exception as e:
257
+ logger.warning(f"[MLM] Prediction error: {e}")
258
+ return []
259
+
260
+ @staticmethod
261
+ def _similarity(a: str, b: str) -> float:
262
+ """Calculate normalized Levenshtein similarity between two strings."""
263
+ if not a or not b:
264
+ return 0.0
265
+ max_len = max(len(a), len(b))
266
+ if max_len == 0:
267
+ return 1.0
268
+ # Inline Levenshtein to avoid extra dependency
269
+ m, n = len(a), len(b)
270
+ dp = list(range(n + 1))
271
+ for i in range(1, m + 1):
272
+ prev = dp[0]
273
+ dp[0] = i
274
+ for j in range(1, n + 1):
275
+ temp = dp[j]
276
+ if a[i-1] == b[j-1]:
277
+ dp[j] = prev
278
+ else:
279
+ dp[j] = 1 + min(prev, dp[j], dp[j-1])
280
+ prev = temp
281
+ dist = dp[n]
282
+ return 1.0 - (dist / max_len)
283
+
284
+
285
+ def get_contextual_corrector() -> Optional[ContextualCorrector]:
286
+ """Get or create the singleton ContextualCorrector instance.
287
+
288
+ Returns None if loading fails (graceful degradation).
289
+ """
290
+ global _instance, _loading
291
+
292
+ if _instance is not None:
293
+ return _instance
294
+
295
+ if _loading:
296
+ return None # Prevent recursive loading
297
+
298
+ _loading = True
299
+ try:
300
+ _instance = ContextualCorrector()
301
+ return _instance
302
+ except Exception as e:
303
+ logger.warning(f"[MLM] Failed to load contextual corrector: {e}")
304
+ return None
305
+ finally:
306
+ _loading = False
307
+
308
+
309
+ def is_loaded() -> bool:
310
+ """Check if the contextual corrector is loaded."""
311
+ return _instance is not None