# RenAI — utils/postprocessing.py
# Dictionary-based fuzzy post-processing for Spanish OCR output.
# (Initial RenAI app, commit ebcc7d1)
import re
import unicodedata
from collections import defaultdict
from typing import List, Tuple, Dict, Set
import heapq
from loguru import logger
class SpanishFuzzyMatcher:
    """Fuzzy spell-corrector for Spanish OCR output.

    Loads a word list from disk, indexes it by word length and by
    character trigrams, and exposes ``find_best_matches`` /
    ``correct_sentence`` for repairing noisy OCR tokens.
    """

    def __init__(self, dictionary_path: str):
        """Load the word list at ``dictionary_path`` and build all lookup
        structures.

        Raises:
            FileNotFoundError: if the dictionary file does not exist.
        """
        # All known words (lowercased, cleaned).
        self.dictionary = set()
        # word length -> list of words of that length (candidate buckets).
        self.word_by_length = defaultdict(list)
        # trigram -> set of words containing it ('$' marks word edges).
        self.ngram_index = defaultdict(set)
        # High-frequency Spanish words that are also in the dictionary.
        self.common_words = set()
        # Order matters: both indexes and the common-word set are derived
        # from the dictionary loaded in the first call.
        self._load_dictionary(dictionary_path)
        self._build_indexes()
        self._load_common_words()
def _detect_encoding(self, path: str) -> str:
encodings = ['utf-8', 'latin1', 'iso-8859-1', 'cp1252', 'utf-16']
for encoding in encodings:
try:
with open(path, 'r', encoding=encoding) as f:
f.read(1024) # Try to read first 1KB
return encoding
except (UnicodeDecodeError, UnicodeError):
continue
return 'utf-8'
def _load_dictionary(self, path: str):
try:
encoding = self._detect_encoding(path)
print(f"Detected encoding: {encoding}")
with open(path, 'r', encoding=encoding, errors='ignore') as f:
for line_num, line in enumerate(f, 1):
try:
word = line.strip().lower()
if word and len(word) > 1:
# Remove any non-alphabetic characters except hyphens and apostrophes
cleaned_word = re.sub(r"[^a-záéíóúüñç\-']", "", word)
if cleaned_word and len(cleaned_word) > 1:
self.dictionary.add(cleaned_word)
self.word_by_length[len(cleaned_word)].append(cleaned_word)
except Exception as e:
print(f"Warning: Skipping line {line_num} due to error: {e}")
continue
print(f"Loaded {len(self.dictionary)} words from dictionary")
except FileNotFoundError:
raise FileNotFoundError(f"Dictionary file not found: {path}")
except Exception as e:
raise Exception(f"Error loading dictionary: {e}")
def _load_common_words(self):
common_spanish = {
'el', 'la', 'de', 'que', 'y', 'a', 'en', 'un', 'es', 'se', 'no', 'te', 'lo', 'le', 'da', 'su', 'por', 'son', 'con', 'para', 'al', 'las', 'del', 'los', 'una', 'mi', 'muy', 'mas', 'me', 'si', 'ya', 'todo', 'como', 'pero', 'hay', 'o', 'cuando', 'esta', 'ser', 'tiene', 'estar', 'hacer', 'sobre', 'entre', 'poder', 'antes', 'tiempo', 'año', 'casa', 'día', 'vida', 'trabajo', 'hombre', 'mujer', 'mundo', 'parte', 'momento', 'lugar', 'país', 'forma', 'manera', 'estado', 'caso', 'grupo', 'agua', 'punto', 'vez', 'donde', 'quien', 'haber', 'tener', 'hacer', 'decir', 'ir', 'ver', 'dar', 'saber', 'querer', 'llegar', 'pasar', 'deber', 'poner', 'parecer', 'quedar', 'creer', 'hablar', 'llevar', 'dejar', 'seguir', 'encontrar', 'llamar', 'venir', 'pensar', 'salir', 'volver', 'tomar', 'conocer', 'vivir', 'sentir', 'tratar', 'mirar', 'contar', 'empezar', 'esperar', 'buscar', 'existir', 'entrar', 'trabajar', 'escribir', 'perder', 'producir', 'ocurrir', 'entender', 'pedir', 'recibir', 'recordar', 'terminar', 'permitir', 'aparecer', 'conseguir', 'comenzar', 'servir', 'sacar', 'necesitar', 'mantener', 'resultar', 'leer', 'caer', 'cambiar', 'presentar', 'crear', 'abrir', 'considerar', 'oír', 'acabar', 'convertir', 'ganar', 'traer', 'realizar', 'suponer', 'comprender', 'explicar', 'dedicar', 'andar', 'estudiar', 'mano', 'cabeza', 'ojo', 'cara', 'pie', 'corazón', 'vez', 'palabra', 'número', 'color', 'mesa', 'silla', 'libro', 'papel', 'coche', 'calle', 'puerta', 'ventana', 'ciudad', 'pueblo', 'escuela', 'hospital', 'iglesia', 'tienda', 'mercado', 'banco', 'hotel', 'restaurante', 'café', 'bar', 'teatro', 'cine', 'museo', 'parque', 'jardín', 'playa', 'montaña', 'río', 'mar', 'lago', 'bosque', 'árbol', 'flor', 'animal', 'perro', 'gato', 'pájaro', 'pez', 'comida', 'pan', 'carne', 'pollo', 'pescado', 'leche', 'huevo', 'queso', 'fruta', 'verdura', 'patata', 'tomate', 'cebolla', 'ajo', 'sal', 'azúcar', 'aceite', 'vino', 'cerveza', 'café', 'té', 'agua', 'fuego', 'aire', 'tierra', 'sol', 'luna', 
'estrella', 'nube', 'lluvia', 'nieve', 'viento', 'calor', 'frío', 'luz', 'sombra', 'mañana', 'tarde', 'noche', 'hoy', 'ayer', 'mañana', 'semana', 'mes', 'año', 'hora', 'minuto', 'segundo', 'lunes', 'martes', 'miércoles', 'jueves', 'viernes', 'sábado', 'domingo', 'enero', 'febrero', 'marzo', 'abril', 'mayo', 'junio', 'julio', 'agosto', 'septiembre', 'octubre', 'noviembre', 'diciembre', 'primavera', 'verano', 'otoño', 'invierno', 'bueno', 'malo', 'grande', 'pequeño', 'alto', 'bajo', 'largo', 'corto', 'ancho', 'estrecho', 'grueso', 'delgado', 'fuerte', 'débil', 'rápido', 'lento', 'fácil', 'difícil', 'nuevo', 'viejo', 'joven', 'mayor', 'blanco', 'negro', 'rojo', 'azul', 'verde', 'amarillo', 'gris', 'marrón', 'rosa', 'naranja', 'morado', 'feliz', 'triste', 'contento', 'enfadado', 'cansado', 'aburrido', 'interesante', 'divertido', 'importante', 'necesario', 'posible', 'imposible', 'seguro', 'peligroso', 'rico', 'pobre', 'caro', 'barato', 'limpio', 'sucio', 'sano', 'enfermo', 'vivo', 'muerto', 'lleno', 'vacío', 'abierto', 'cerrado', 'caliente', 'frío', 'seco', 'mojado', 'duro', 'blando', 'suave', 'áspero', 'dulce', 'amargo', 'salado', 'picante', 'conocerte', 'tengas'
}
self.common_words = {word for word in common_spanish if word in self.dictionary}
print(f"Loaded {len(self.common_words)} common words")
def _is_common_spanish_error(self, ocr_word: str, dict_word: str) -> bool:
ocr_lower = ocr_word.lower()
dict_lower = dict_word.lower()
# Common OCR confusions in Spanish
ocr_substitutions = {
'b': 'v', 'v': 'b', # b/v confusion
'c': 's', 's': 'c', # c/s confusion
'z': 's', 's': 'z', # z/s confusion
'j': 'g', 'g': 'j', # j/g confusion
'y': 'i', 'i': 'y', # y/i confusion
'u': 'n', 'n': 'u', # u/n confusion (handwriting)
'll': 'y', 'y': 'll', # ll/y confusion
'ñ': 'n', 'n': 'ñ', # ñ/n confusion
}
if len(ocr_lower) == len(dict_lower):
diff_count = sum(1 for a, b in zip(ocr_lower, dict_lower) if a != b)
if diff_count == 1:
for i, (a, b) in enumerate(zip(ocr_lower, dict_lower)):
if a != b:
return a in ocr_substitutions and ocr_substitutions[a] == b
return False
def _build_indexes(self):
for word in self.dictionary:
padded_word = f"${word}$"
for i in range(len(padded_word) - 2):
trigram = padded_word[i:i+3]
self.ngram_index[trigram].add(word)
def _normalize_text(self, text: str) -> str:
text = unicodedata.normalize('NFD', text)
text = ''.join(c for c in text if unicodedata.category(c) != 'Mn')
return text.lower()
def _levenshtein_distance(self, s1: str, s2: str) -> int:
if len(s1) < len(s2):
return self._levenshtein_distance(s2, s1)
if len(s2) == 0:
return len(s1)
previous_row = list(range(len(s2) + 1))
for i, c1 in enumerate(s1):
current_row = [i + 1]
for j, c2 in enumerate(s2):
insertions = previous_row[j + 1] + 1
deletions = current_row[j] + 1
substitutions = previous_row[j] + (c1 != c2)
current_row.append(min(insertions, deletions, substitutions))
previous_row = current_row
return previous_row[-1]
def _damerau_levenshtein_distance(self, s1: str, s2: str) -> int:
    """Damerau-Levenshtein edit distance between `s1` and `s2`.

    Counts insertions, deletions, substitutions and transpositions
    (the "distance with alphabet" formulation, which also handles
    non-adjacent transpositions).  Returns a non-negative int.
    """
    len1, len2 = len(s1), len(s2)
    # da[c] = index (1-based row) of the last occurrence of character c
    # in s1 processed so far; 0 means "not seen yet".
    da = {}
    for char in s1 + s2:
        if char not in da:
            da[char] = 0
    # max_dist acts as "infinity" for the sentinel border cells.
    max_dist = len1 + len2
    # (len1+2) x (len2+2) DP matrix with an extra sentinel row/column.
    h = [[max_dist for _ in range(len2 + 2)] for _ in range(len1 + 2)]
    h[0][0] = max_dist
    for i in range(0, len1 + 1):
        h[i + 1][0] = max_dist
        h[i + 1][1] = i
    for j in range(0, len2 + 1):
        h[0][j + 1] = max_dist
        h[1][j + 1] = j
    for i in range(1, len1 + 1):
        # db = column of the most recent character match in this row
        # (0 = none yet); needed for the transposition term.
        db = 0
        for j in range(1, len2 + 1):
            k = da[s2[j - 1]]
            l = db
            if s1[i - 1] == s2[j - 1]:
                cost = 0
                db = j
            else:
                cost = 1
            h[i + 1][j + 1] = min(
                h[i][j] + cost, # substitution
                h[i + 1][j] + 1, # insertion
                h[i][j + 1] + 1, # deletion
                h[k][l] + (i - k - 1) + 1 + (j - l - 1) # transposition
            )
        # Record the row of this character after the whole row is done.
        da[s1[i - 1]] = i
    return h[len1 + 1][len2 + 1]
def _jaro_winkler_similarity(self, s1: str, s2: str) -> float:
    """Jaro-Winkler similarity in [0, 1]; higher means more similar.

    NOTE(review): the Winkler prefix boost (up to 4 shared leading
    characters, scale 0.1) is applied unconditionally here, not only
    above the conventional 0.7 Jaro threshold.
    """
    def jaro_similarity(s1: str, s2: str) -> float:
        # Plain Jaro similarity: count matching characters within a
        # sliding window, penalised by transpositions.
        if s1 == s2:
            return 1.0
        len1, len2 = len(s1), len(s2)
        if len1 == 0 or len2 == 0:
            return 0.0
        # Characters only "match" within this distance of each other.
        match_window = max(len1, len2) // 2 - 1
        if match_window < 0:
            match_window = 0
        s1_matches = [False] * len1
        s2_matches = [False] * len2
        matches = 0
        transpositions = 0
        # Greedily pair each char of s1 with the first unused equal
        # char of s2 inside the window.
        for i in range(len1):
            start = max(0, i - match_window)
            end = min(i + match_window + 1, len2)
            for j in range(start, end):
                if s2_matches[j] or s1[i] != s2[j]:
                    continue
                s1_matches[i] = s2_matches[j] = True
                matches += 1
                break
        if matches == 0:
            return 0.0
        # Count half-transpositions: matched characters that appear in
        # a different relative order in the two strings.
        k = 0
        for i in range(len1):
            if not s1_matches[i]:
                continue
            while not s2_matches[k]:
                k += 1
            if s1[i] != s2[k]:
                transpositions += 1
            k += 1
        jaro = (matches / len1 + matches / len2 +
                (matches - transpositions / 2) / matches) / 3
        return jaro
    jaro = jaro_similarity(s1, s2)
    # Winkler boost: reward a common prefix of up to 4 characters.
    prefix_len = 0
    for i in range(min(len(s1), len(s2), 4)):
        if s1[i] == s2[i]:
            prefix_len += 1
        else:
            break
    return jaro + (0.1 * prefix_len * (1 - jaro))
def _get_candidates(self, word: str, max_candidates: int = 200) -> Set[str]:
candidates = set()
word_len = len(word)
common_candidates = set()
for common_word in self.common_words:
if abs(len(common_word) - word_len) <= 2:
common_candidates.add(common_word)
candidates.update(common_candidates)
for length in range(max(1, word_len - 2), word_len + 3):
length_words = self.word_by_length[length]
# Sort by length (shorter words first) and limit
sorted_words = sorted(length_words, key=len)[:max_candidates//3]
candidates.update(sorted_words)
padded_word = f"${word}$"
trigram_candidates = set()
trigram_scores = defaultdict(int)
for i in range(len(padded_word) - 2):
trigram = padded_word[i:i+3]
if trigram in self.ngram_index:
for candidate in self.ngram_index[trigram]:
trigram_scores[candidate] += 1
sorted_trigram = sorted(trigram_scores.items(), key=lambda x: x[1], reverse=True)
trigram_candidates = {word for word, score in sorted_trigram[:max_candidates//2]}
candidates.update(trigram_candidates)
return candidates
def _calculate_composite_score(self, word1: str, word2: str) -> float:
norm_word1 = self._normalize_text(word1)
norm_word2 = self._normalize_text(word2)
levenshtein = self._levenshtein_distance(norm_word1, norm_word2)
damerau = self._damerau_levenshtein_distance(norm_word1, norm_word2)
jaro_winkler = self._jaro_winkler_similarity(norm_word1, norm_word2)
max_len = max(len(norm_word1), len(norm_word2))
if max_len == 0:
return 1.0
levenshtein_sim = 1 - (levenshtein / max_len)
damerau_sim = 1 - (damerau / max_len)
length_diff = abs(len(norm_word1) - len(norm_word2))
length_penalty = 1 - (length_diff / max(len(norm_word1), len(norm_word2)))
frequency_bonus = 1.0
if norm_word2 in self.common_words:
frequency_bonus = 1.3
spanish_error_bonus = 1.0
if self._is_common_spanish_error(word1, word2):
spanish_error_bonus = 1.2
exact_length_bonus = 1.0
if len(norm_word1) == len(norm_word2):
exact_length_bonus = 1.1
base_score = (
0.25 * levenshtein_sim +
0.45 * damerau_sim +
0.25 * jaro_winkler +
0.05 * length_penalty
)
final_score = base_score * frequency_bonus * spanish_error_bonus * exact_length_bonus
return min(final_score, 1.0)
def find_best_matches(self, word: str, top_k: int = 5, threshold: float = 0.4) -> List[Tuple[str, float]]:
    """Return up to `top_k` (word, score) dictionary matches for `word`,
    best score first, ties broken alphabetically.

    Exact dictionary hits (normalized or lowercased) short-circuit with
    a perfect score; words shorter than two characters get no matches.
    """
    if not word or len(word) < 2:
        return []
    normalized_word = self._normalize_text(word)
    if normalized_word in self.dictionary:
        return [(word, 1.0)]
    if word.lower() in self.dictionary:
        return [(word.lower(), 1.0)]
    scored = (
        (candidate, self._calculate_composite_score(word, candidate))
        for candidate in self._get_candidates(normalized_word)
    )
    kept = [(cand, sc) for cand, sc in scored if sc >= threshold]
    # Highest score first; alphabetical among equal scores (same order
    # a (-score, word) min-heap would pop).
    kept.sort(key=lambda pair: (-pair[1], pair[0]))
    return kept[:top_k]
def correct_sentence(self, sentence: str, confidence_threshold: float = 0.6) -> str:
    """Spell-correct every word token in `sentence`.

    Runs of punctuation/whitespace are passed through untouched.  A word
    is replaced only when its best dictionary match scores at least
    `confidence_threshold`.  The replacement inherits the original
    token's capitalisation (Title/UPPER) — previously corrections were
    always emitted lowercase.
    """
    def _match_case(original: str, replacement: str) -> str:
        # Transfer the source token's casing onto the correction.
        if len(original) > 1 and original.isupper():
            return replacement.upper()
        if original[:1].isupper():
            return replacement[:1].upper() + replacement[1:]
        return replacement

    # \b\w+\b captures word tokens; \W+ captures the separators between
    # them, so joining all tokens reconstructs the full sentence.
    tokens = re.findall(r'\b\w+\b|\W+', sentence)
    corrected_tokens = []
    for token in tokens:
        if re.match(r'\b\w+\b', token):
            matches = self.find_best_matches(token, top_k=1, threshold=0.3)
            if matches and matches[0][1] >= confidence_threshold:
                corrected_tokens.append(_match_case(token, matches[0][0]))
            else:
                corrected_tokens.append(token)
        else:
            corrected_tokens.append(token)
    return ''.join(corrected_tokens)
# Lazily-built matcher instances keyed by dictionary path: building one
# loads and indexes a ~136k-word dictionary, so do it once per process.
_MATCHER_CACHE = {}

def PostProcessing(ocr_sentence):
    """Run dictionary-based fuzzy correction over an OCR'd Spanish sentence.

    Best-effort: on any failure the raw input is returned unchanged so
    the OCR pipeline never breaks here.
    """
    try:
        logger.info("Post processing started......")
        dict_path = 'Diccionario.Espanol.136k.palabras.txt'
        matcher = _MATCHER_CACHE.get(dict_path)
        if matcher is None:
            # Previously a new matcher was built on EVERY call.
            matcher = SpanishFuzzyMatcher(dict_path)
            _MATCHER_CACHE[dict_path] = matcher
        logger.info("Dictionary loaded successfully!")
        corrected = matcher.correct_sentence(ocr_sentence, confidence_threshold=0.6)
        logger.info("Post processing completed successfully!")
        return corrected
    except Exception as e:
        # logger.error already records the failure; the old bare print(e)
        # duplicated it on stdout.
        logger.error(f"Post processing failed: {e}")
        return ocr_sentence