# recitation-segmenter-app-API / arabic_aligner.py
# aboalaa147's picture
# Create arabic_aligner.py
# b0b0b0f verified
import re
from typing import List, Tuple, Dict
from dataclasses import dataclass
from enum import Enum
class ErrorType(Enum):
    """Outcome assigned to each aligned word pair.

    The string values are what gets displayed (upper-cased) when results
    are printed.
    """

    # Base letters and diacritics are identical.
    MATCH = "match"
    # The user said a different word where the reference has one.
    SUBSTITUTION = "substitution"
    # The user text contains a word the reference does not.
    INSERTION = "insertion"
    # The reference contains a word the user text does not.
    DELETION = "deletion"
    # Base letters agree but the diacritics differ.
    DIACRITIC_ERROR = "diacritic_error"
@dataclass
class AlignmentError:
    """One entry of an alignment report (including perfect matches)."""

    # Classification of this aligned pair.
    error_type: ErrorType
    # Index into the reference word list at which the entry applies.
    position: int
    # Word from the user text; empty string when the user omitted the word.
    user_word: str
    # Word from the reference text; empty string for an extra user word.
    reference_word: str
    # Human-readable explanation of the entry.
    details: str = ""
class ArabicAligner:
    """Align a user's Arabic text against a reference text, word by word.

    Words are first aligned with an edit-distance dynamic program that
    ignores diacritics; words whose base letters agree are then compared
    diacritic by diacritic.
    """

    # Arabic diacritics (code points U+064B through U+0658)
    DIACRITICS = '\u064B\u064C\u064D\u064E\u064F\u0650\u0651\u0652\u0653\u0654\u0655\u0656\u0657\u0658'
    DIACRITIC_PATTERN = f'[{DIACRITICS}]'

    # Punctuation that lives inside U+0600..U+06FF.  The tokenizer's character
    # range would otherwise keep these glued to the neighbouring word.
    _ARABIC_PUNCTUATION = '\u060C\u060D\u061B\u061E\u061F\u066A\u066B\u066C\u066D\u06D4'
    _PUNCTUATION_TO_SPACE = str.maketrans(
        _ARABIC_PUNCTUATION, ' ' * len(_ARABIC_PUNCTUATION)
    )

    # Single-pass character normalization used by normalize_text.
    _NORMALIZATION_TABLE = str.maketrans({
        '\u0640': None,      # tatweel is removed
        '\u0625': '\u0627',  # alef with hamza below -> alef
        '\u0623': '\u0627',  # alef with hamza above -> alef
        '\u0622': '\u0627',  # alef with madda above -> alef
        '\u0671': '\u0627',  # alef wasla -> alef
        '\u0624': '\u0621',  # waw with hamza above -> hamza
        '\u0626': '\u0621',  # yeh with hamza above -> hamza
        '\u0629': '\u0647',  # teh marbuta -> heh
    })

    def __init__(self):
        # Both matrices are filled in by compute_alignment.
        self.alignment_matrix = None
        self.backtrack_matrix = None

    def normalize_text(self, text: str) -> str:
        """Normalize Arabic text.

        Removes tatweel, folds alef/hamza variants and teh marbuta onto a
        single letter each, and collapses runs of whitespace to one space
        (leading and trailing whitespace is dropped).
        """
        return ' '.join(text.translate(self._NORMALIZATION_TABLE).split())

    def remove_diacritics(self, text: str) -> str:
        """Return ``text`` with every character of ``DIACRITICS`` removed."""
        return re.sub(self.DIACRITIC_PATTERN, '', text)

    def extract_diacritics(self, word: str) -> List[Tuple[int, str]]:
        """Extract diacritics and their positions from a word.

        Returns:
            ``(position, diacritic)`` pairs in order of appearance, where
            ``position`` is the number of non-diacritic characters seen
            before the mark (so marks on the first letter get position 1,
            and marks preceding any letter get position 0).
        """
        diacritics = []
        letters_seen = 0
        for char in word:
            if char in self.DIACRITICS:
                diacritics.append((letters_seen, char))
            else:
                letters_seen += 1
        return diacritics

    def tokenize(self, text: str) -> List[str]:
        """Tokenize text into words, splitting on whitespace and punctuation.

        Arabic punctuation is replaced by spaces first because it falls
        inside the U+0600..U+06FF range accepted by the word pattern.
        """
        spaced = text.translate(self._PUNCTUATION_TO_SPACE)
        words = re.findall(r'[\w\u0600-\u06FF]+', spaced)
        return [w for w in words if w.strip()]

    def compute_alignment(self, user_words: List[str], ref_words: List[str]) -> Tuple[List[List[int]], List[List[str]]]:
        """Compute word-level alignment using dynamic programming (edit distance).

        Words are compared with diacritics removed.  The matrices are also
        stored on the instance for ``traceback_alignment``.

        Returns:
            ``(dp, backtrack)`` where ``dp[i][j]`` is the edit distance
            between the first ``i`` user words and first ``j`` reference
            words, and ``backtrack[i][j]`` is one of ``'MATCH'``,
            ``'SUBST'``, ``'DEL'``, ``'INS'`` (``''`` at the origin).
        """
        m, n = len(user_words), len(ref_words)

        # Strip diacritics once per word instead of once per matrix cell.
        user_clean = [self.remove_diacritics(w) for w in user_words]
        ref_clean = [self.remove_diacritics(w) for w in ref_words]

        dp = [[0] * (n + 1) for _ in range(m + 1)]
        backtrack = [[''] * (n + 1) for _ in range(m + 1)]

        # Base cases: against an empty reference every user word is an
        # insertion; against an empty user text every reference word is a
        # deletion.
        for i in range(1, m + 1):
            dp[i][0] = i
            backtrack[i][0] = 'INS'
        for j in range(1, n + 1):
            dp[0][j] = j
            backtrack[0][j] = 'DEL'

        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if user_clean[i - 1] == ref_clean[j - 1]:
                    # Match (cost 0)
                    dp[i][j] = dp[i - 1][j - 1]
                    backtrack[i][j] = 'MATCH'
                    continue

                subst_cost = dp[i - 1][j - 1] + 1
                del_cost = dp[i][j - 1] + 1   # reference word missing from user
                ins_cost = dp[i - 1][j] + 1   # extra word in user text
                min_cost = min(subst_cost, del_cost, ins_cost)
                dp[i][j] = min_cost
                # Ties are resolved in the order SUBST, DEL, INS.
                if min_cost == subst_cost:
                    backtrack[i][j] = 'SUBST'
                elif min_cost == del_cost:
                    backtrack[i][j] = 'DEL'
                else:
                    backtrack[i][j] = 'INS'

        self.alignment_matrix = dp
        self.backtrack_matrix = backtrack
        return dp, backtrack

    def traceback_alignment(self, user_words: List[str], ref_words: List[str]) -> List[Tuple[str, int, int]]:
        """Trace back through the stored alignment to get aligned pairs.

        Returns:
            ``(operation, user_idx, ref_idx)`` tuples in text order; the
            index that does not apply (``user_idx`` for ``'DEL'``,
            ``ref_idx`` for ``'INS'``) is ``-1``.

        Raises:
            ValueError: if ``compute_alignment`` has not been called, if the
                stored matrix was built for word lists of different lengths,
                or if the matrix contains an unknown operation.
        """
        backtrack = self.backtrack_matrix
        if backtrack is None:
            raise ValueError("Must call compute_alignment first")

        i, j = len(user_words), len(ref_words)
        # A matrix left over from a different pair of texts would otherwise
        # cause an IndexError or a silently wrong alignment.
        if len(backtrack) != i + 1 or len(backtrack[0]) != j + 1:
            raise ValueError(
                "Stored alignment does not match the given word lists; "
                "call compute_alignment with the same arguments first"
            )

        alignments = []
        while i > 0 or j > 0:
            operation = backtrack[i][j]
            if operation in ('MATCH', 'SUBST'):
                alignments.append((operation, i - 1, j - 1))
                i -= 1
                j -= 1
            elif operation == 'DEL':
                alignments.append(('DEL', -1, j - 1))
                j -= 1
            elif operation == 'INS':
                alignments.append(('INS', i - 1, -1))
                i -= 1
            else:
                # Without this the loop would never terminate.
                raise ValueError(
                    f"Unknown operation {operation!r} in backtrack matrix at ({i}, {j})"
                )
        alignments.reverse()
        return alignments

    @staticmethod
    def _group_by_position(diacritics: List[Tuple[int, str]]) -> Dict[int, List[str]]:
        """Group ``(position, diacritic)`` pairs into ``{position: [diacritics]}``."""
        grouped: Dict[int, List[str]] = {}
        for pos, mark in diacritics:
            grouped.setdefault(pos, []).append(mark)
        return grouped

    def compare_diacritics(self, user_word: str, ref_word: str) -> Tuple[bool, str]:
        """Compare diacritics between two words (after confirming base match).

        All diacritics attached to the same letter are compared as a group,
        ignoring their relative order, so a letter carrying several marks
        (e.g. shadda plus a vowel) is reported correctly.

        Returns:
            ``(is_match, details_string)``.  ``details_string`` is
            ``"Perfect match"`` on success, ``"Base words don't match"``
            when the letters differ, otherwise a ``"; "``-joined list of
            the differences.
        """
        if self.remove_diacritics(user_word) != self.remove_diacritics(ref_word):
            return False, "Base words don't match"

        user_diacs = self.extract_diacritics(user_word)
        ref_diacs = self.extract_diacritics(ref_word)
        if user_diacs == ref_diacs:
            return True, "Perfect match"

        user_groups = self._group_by_position(user_diacs)
        ref_groups = self._group_by_position(ref_diacs)

        errors = []
        for pos in sorted(set(user_groups) | set(ref_groups)):
            # Multiset difference: whatever is left in ``missing`` was expected
            # but not supplied; ``extra`` was supplied but not expected.
            missing = list(ref_groups.get(pos, []))
            extra = []
            for mark in user_groups.get(pos, []):
                if mark in missing:
                    missing.remove(mark)
                else:
                    extra.append(mark)

            if len(extra) == 1 and len(missing) == 1:
                errors.append(
                    f"Wrong diacritic at position {pos}: '{extra[0]}' should be '{missing[0]}'"
                )
                continue
            errors.extend(f"Extra diacritic '{mark}' at position {pos}" for mark in extra)
            errors.extend(f"Missing diacritic '{mark}' at position {pos}" for mark in missing)

        if not errors:
            # Same marks on every letter, only typed in a different order.
            return True, "Perfect match"
        return False, "; ".join(errors)

    @staticmethod
    def _accuracy(reference_count: int, error_count: int) -> float:
        """Return the percentage of reference words without error, floored at 0.

        Insertions count as errors but do not add reference words, so the raw
        ratio can drop below zero; it is clamped to keep a valid percentage.
        """
        if reference_count <= 0:
            return 0.0
        return max(0.0, (reference_count - error_count) / reference_count * 100)

    def align_and_compare(self, user_text: str, reference_text: str) -> Dict:
        """Align two texts and report every word-level and diacritic error.

        Returns:
            A dict with keys ``'user_words'``, ``'reference_words'``,
            ``'alignments'``, ``'errors'`` (one ``AlignmentError`` per
            alignment step, matches included), ``'edit_distance'`` and
            ``'statistics'``.
        """
        # Step 1 + 2: normalize, then tokenize
        user_words = self.tokenize(self.normalize_text(user_text))
        ref_words = self.tokenize(self.normalize_text(reference_text))

        # Step 3 + 4: align, then walk the alignment
        dp, _ = self.compute_alignment(user_words, ref_words)
        alignments = self.traceback_alignment(user_words, ref_words)

        errors = []
        ref_position = 0  # number of reference words consumed so far
        for operation, user_idx, ref_idx in alignments:
            if operation == 'MATCH':
                # Base letters agree; the diacritics decide the verdict.
                user_word = user_words[user_idx]
                ref_word = ref_words[ref_idx]
                is_match, details = self.compare_diacritics(user_word, ref_word)
                errors.append(AlignmentError(
                    error_type=ErrorType.MATCH if is_match else ErrorType.DIACRITIC_ERROR,
                    position=ref_position,
                    user_word=user_word,
                    reference_word=ref_word,
                    details="Perfect match" if is_match else details,
                ))
                ref_position += 1
            elif operation == 'SUBST':
                errors.append(AlignmentError(
                    error_type=ErrorType.SUBSTITUTION,
                    position=ref_position,
                    user_word=user_words[user_idx],
                    reference_word=ref_words[ref_idx],
                    details="Word substituted",
                ))
                ref_position += 1
            elif operation == 'DEL':
                errors.append(AlignmentError(
                    error_type=ErrorType.DELETION,
                    position=ref_position,
                    user_word="",
                    reference_word=ref_words[ref_idx],
                    details="Word deleted from user text",
                ))
                ref_position += 1
            elif operation == 'INS':
                # An extra user word consumes no reference word, so the
                # position is not advanced.
                errors.append(AlignmentError(
                    error_type=ErrorType.INSERTION,
                    position=ref_position,
                    user_word=user_words[user_idx],
                    reference_word="",
                    details="Word inserted in user text",
                ))

        # Compile results
        word_level_types = (ErrorType.SUBSTITUTION, ErrorType.INSERTION, ErrorType.DELETION)
        total_errors = sum(1 for e in errors if e.error_type != ErrorType.MATCH)
        diacritic_errors = sum(1 for e in errors if e.error_type == ErrorType.DIACRITIC_ERROR)
        word_errors = sum(1 for e in errors if e.error_type in word_level_types)

        return {
            'user_words': user_words,
            'reference_words': ref_words,
            'alignments': alignments,
            'errors': errors,
            'edit_distance': dp[-1][-1],
            'statistics': {
                'total_reference_words': len(ref_words),
                'total_user_words': len(user_words),
                'total_errors': total_errors,
                'word_level_errors': word_errors,
                'diacritic_errors': diacritic_errors,
                'accuracy': self._accuracy(len(ref_words), total_errors),
            },
        }

    def print_results(self, results: Dict):
        """Print a formatted report for a dict returned by ``align_and_compare``.

        Perfect matches are omitted from the detailed section; the entries
        that are shown are numbered consecutively from 1.
        """
        print("=" * 80)
        print("ARABIC TEXT ALIGNMENT ANALYSIS")
        print("=" * 80)
        print(f"\nUser Text Words: {len(results['user_words'])}")
        print(f"Reference Text Words: {len(results['reference_words'])}")
        print(f"Edit Distance: {results['edit_distance']}")

        print("\n" + "-" * 80)
        print("STATISTICS")
        print("-" * 80)
        stats = results['statistics']
        print(f"Total Errors: {stats['total_errors']}")
        print(f" - Word-level Errors: {stats['word_level_errors']}")
        print(f" - Diacritic Errors: {stats['diacritic_errors']}")
        print(f"Accuracy: {stats['accuracy']:.2f}%")

        print("\n" + "-" * 80)
        print("DETAILED ERRORS")
        print("-" * 80)
        shown = 0
        for error in results['errors']:
            if error.error_type == ErrorType.MATCH:
                continue  # Skip perfect matches in detailed output
            shown += 1
            print(f"\n[{shown}] Position: {error.position}")
            print(f" Type: {error.error_type.value.upper()}")
            if error.error_type == ErrorType.INSERTION:
                print(f" User: '{error.user_word}' (extra word)")
                print(" Expected: [nothing]")
            elif error.error_type == ErrorType.DELETION:
                print(" User: [missing]")
                print(f" Expected: '{error.reference_word}'")
            elif error.error_type == ErrorType.SUBSTITUTION:
                print(f" User: '{error.user_word}'")
                print(f" Expected: '{error.reference_word}'")
            elif error.error_type == ErrorType.DIACRITIC_ERROR:
                print(f" User: '{error.user_word}'")
                print(f" Expected: '{error.reference_word}'")
                print(f" Details: {error.details}")