ara-route-parser / src /route_extractor.py
Amr-Yasserr's picture
Migrate to new arabic-speech-to-route-hf architecture
4f90754
"""
Route Extractor v2: NER + Gazetteer Hybrid Pipeline
Replaces the QA-based extraction with a proper NER architecture:
Pipeline:
Transcript → Normalizer → NER Model → Gazetteer Resolver → Validated Output
Components:
1. Text Normalizer: Hamza, diacritics, filler word removal
2. NER Model: EgyBERT/MARBERTv2 fine-tuned for B-ORIGIN/I-ORIGIN/B-DEST/I-DEST
3. Gazetteer: Multi-strategy location resolution (exact/normalized/fuzzy/phonetic)
4. Syntactic Fallback: Regex patterns as last resort
5. Confidence Scorer: Combined NER + gazetteer confidence
"""
import os
import re
import sys
# Ensure src directory is in path
_src_dir = os.path.dirname(os.path.abspath(__file__))
if _src_dir not in sys.path:
sys.path.insert(0, _src_dir)
from normalizer import normalize_arabic, remove_noise, full_preprocess, normalize_for_matching
from gazetteer import Gazetteer
# ============================================================
# Configuration
# ============================================================
NER_CONFIDENCE_THRESHOLD = 0.3 # Minimum NER confidence to trust extraction
GAZ_MIN_FUZZY_SCORE = 0.65 # Minimum fuzzy match score for gazetteer
COMBINED_CONFIDENCE_WEIGHTS = (0.6, 0.4) # (NER weight, Gazetteer weight)
# ============================================================
# Syntactic Fallback (from v1, refined)
# ============================================================
def syntactic_extraction(text: str) -> dict:
"""
Fallback: Egyptian Arabic syntactic pattern matching.
Only used when NER confidence is below threshold.
"""
text_norm = normalize_arabic(text)
# Pattern 0: انا في [Origin] عايز/رايح [Destination]
match_loc = re.search(
r'(?:انا\s+)?(?:في|واقف\s+في)\s+(.+?)\s+(?:(?:و\s*)?عايز\s+(?:اروح)?|(?:و\s*)?رايح|(?:و\s*)?محتاج)\s+(.+)',
text_norm
)
if match_loc:
return {'origin': match_loc.group(1).strip(), 'destination': match_loc.group(2).strip()}
# Pattern 1: من [Origin] MARKER [Destination]
match_word = re.search(
r'من\s+(.+?)\s+(?:و\s*)?(?:الي|الى|علي|على|رايح|رايحة|رايحه|اروح|اوصل|روح|هروح)\s+(.+)',
text_norm
)
if match_word:
return {'origin': match_word.group(1).strip(), 'destination': match_word.group(2).strip()}
# Pattern 1B: من [Origin] ل/لل[Destination]
match_pref = re.search(r'من\s+(.+?)\s+(?:و\s*)?(?:لل|لـ|ل)(\S+.*)', text_norm)
if match_pref:
return {'origin': match_pref.group(1).strip(), 'destination': match_pref.group(2).strip()}
# Pattern 2: MARKER [Destination] من [Origin]
match_rev = re.search(
r'(?:و\s*)?(?:رايح|رايحة|اروح|اوصل|عايز\s+اروح|خدني|وصلني|هروح|نازل)\s+(.+?)\s+من\s+(.+)',
text_norm
)
if match_rev:
return {'origin': match_rev.group(2).strip(), 'destination': match_rev.group(1).strip()}
# Pattern 3: من [Origin] [multi-word Destination]
match_from = re.search(r'من\s+(.+?)\s+(.+)', text_norm)
if match_from:
return {'origin': match_from.group(1).strip(), 'destination': match_from.group(2).strip()}
return None
# ============================================================
# Route Extractor v2
# ============================================================
class RouteExtractor:
"""
Production route extraction pipeline.
Combines NER model + Gazetteer + Syntactic fallback
for robust origin/destination extraction from Egyptian Arabic.
"""
def __init__(self, ner_model_path: str = None, locations_path: str = None,
use_onnx: bool = False):
"""
Initialize the route extractor.
Args:
ner_model_path: Path to fine-tuned NER model directory
locations_path: Path to locations.json for gazetteer
use_onnx: Use ONNX Runtime for NER inference
"""
self.ner = None
self.gazetteer = None
# Load NER model if available
if ner_model_path and os.path.exists(ner_model_path):
try:
from ner_model import NERModel
self.ner = NERModel(ner_model_path, use_onnx=use_onnx)
except Exception as e:
print(f"⚠ Failed to load NER model: {e}")
print(" Falling back to syntactic-only extraction")
# Load Gazetteer
if locations_path and os.path.exists(locations_path):
try:
self.gazetteer = Gazetteer(locations_path)
except Exception as e:
print(f"⚠ Failed to load Gazetteer: {e}")
else:
# Try default paths
candidates = [
os.path.join(os.path.dirname(_src_dir), 'data', 'locations.json'),
os.path.join(_src_dir, 'locations.json'),
]
for path in candidates:
if os.path.exists(path):
self.gazetteer = Gazetteer(path)
break
if not self.ner:
print("⚠ NER model not loaded — using syntactic extraction + gazetteer only")
def extract(self, text: str) -> dict:
"""
Extract origin and destination from Arabic routing text.
Returns:
{
'origin': str or None,
'origin_score': float,
'origin_canonical': str or None,
'destination': str or None,
'destination_score': float,
'destination_canonical': str or None,
'method': str, # 'ner', 'syntactic', 'gazetteer_only'
}
"""
if not text or not text.strip():
return self._empty_result()
# Step 1: Preprocess
# Normalize characters (hamza, ya, etc.)
normalized = normalize_arabic(text)
# Light noise removal: strip politeness fillers but KEEP routing markers
# These are words that interfere with extraction but are NOT routing signals
politeness_noise = [
"لو سمحت", "من فضلك", "يا ريس", "يا باشا", "يا حج",
"يا كابتن", "يا معلم", "يا عم", "الله يخليك",
"بسرعة", "دلوقتي", "بالراحة", "يعني", "كده", "بقي",
"طب", "يلا", "بقولك", "اسمع", "شوف",
]
light_cleaned = normalized
for noise in politeness_noise:
light_cleaned = light_cleaned.replace(noise, '')
light_cleaned = re.sub(r'\s+', ' ', light_cleaned).strip()
# Full noise removal for NER (also removes routing words)
preprocessed = remove_noise(normalized)
# Step 2: Try NER extraction
ner_result = None
if self.ner:
try:
ner_result = self.ner.predict(preprocessed)
except Exception as e:
print(f"⚠ NER prediction failed: {e}")
# Step 3: Check NER confidence
ner_origin = None
ner_dest = None
ner_origin_score = 0
ner_dest_score = 0
if ner_result:
ner_origin = ner_result.get('origin')
ner_dest = ner_result.get('destination')
ner_origin_score = ner_result.get('origin_score', 0)
ner_dest_score = ner_result.get('destination_score', 0)
# Step 4: If NER confidence is low, try syntactic fallback
# NOTE: Syntactic runs on NORMALIZED text (not noise-stripped) to see routing markers
method = 'ner'
if not ner_origin or not ner_dest or min(ner_origin_score, ner_dest_score) < NER_CONFIDENCE_THRESHOLD:
syn_result = syntactic_extraction(light_cleaned)
if syn_result:
method = 'syntactic'
if not ner_origin or ner_origin_score < NER_CONFIDENCE_THRESHOLD:
ner_origin = syn_result.get('origin')
ner_origin_score = 0.7 # Syntactic confidence
if not ner_dest or ner_dest_score < NER_CONFIDENCE_THRESHOLD:
ner_dest = syn_result.get('destination')
ner_dest_score = 0.7
# Step 5: Resolve through gazetteer
origin_canonical = None
dest_canonical = None
origin_gaz_score = 0
dest_gaz_score = 0
if self.gazetteer:
if ner_origin:
gaz_match = self.gazetteer.resolve(ner_origin, min_fuzzy_score=GAZ_MIN_FUZZY_SCORE)
if gaz_match:
origin_canonical = gaz_match['canonical_name']
origin_gaz_score = gaz_match['confidence']
if ner_dest:
gaz_match = self.gazetteer.resolve(ner_dest, min_fuzzy_score=GAZ_MIN_FUZZY_SCORE)
if gaz_match:
dest_canonical = gaz_match['canonical_name']
dest_gaz_score = gaz_match['confidence']
# Step 6: Compute combined confidence
w_ner, w_gaz = COMBINED_CONFIDENCE_WEIGHTS
origin_final_score = (w_ner * ner_origin_score + w_gaz * origin_gaz_score) if ner_origin else 0
dest_final_score = (w_ner * ner_dest_score + w_gaz * dest_gaz_score) if ner_dest else 0
# Use canonical name if available, otherwise raw NER output
origin_final = origin_canonical or ner_origin
dest_final = dest_canonical or ner_dest
return {
'origin': origin_final,
'origin_score': round(origin_final_score, 4),
'origin_canonical': origin_canonical,
'destination': dest_final,
'destination_score': round(dest_final_score, 4),
'destination_canonical': dest_canonical,
'method': method,
}
def _empty_result(self) -> dict:
return {
'origin': None, 'origin_score': 0, 'origin_canonical': None,
'destination': None, 'destination_score': 0, 'destination_canonical': None,
'method': 'none',
}
# ============================================================
# Backward-compatible API (drop-in replacement for v1)
# ============================================================
# Module-level instance (loaded on import)
_extractor = None
def _get_extractor():
"""Lazy-load the extractor singleton."""
global _extractor
if _extractor is None:
# Determine paths
project_root = os.path.dirname(_src_dir)
locations_path = os.path.join(project_root, 'data', 'locations.json')
# Try to find a trained NER model
ner_candidates = [
os.path.join(project_root, 'models', 'ner_best'),
os.path.join(project_root, 'models', 'egybert_best'),
os.path.join(project_root, 'models', 'marbert_best'),
]
ner_model_path = None
for candidate in ner_candidates:
if os.path.exists(candidate):
ner_model_path = candidate
break
# Check for ONNX version
use_onnx = False
onnx_candidates = [
os.path.join(project_root, 'models', 'egybert_onnx_int8'),
os.path.join(project_root, 'models', 'ner_onnx_int8'),
os.path.join(project_root, 'models', 'ner_onnx'),
]
for candidate in onnx_candidates:
if os.path.exists(candidate):
ner_model_path = candidate
use_onnx = True
break
_extractor = RouteExtractor(
ner_model_path=ner_model_path,
locations_path=locations_path,
use_onnx=use_onnx,
)
return _extractor
def advanced_extract_route(text: str) -> dict:
"""
Backward-compatible extraction function.
Drop-in replacement for the v1 QA-based extractor.
"""
extractor = _get_extractor()
result = extractor.extract(text)
# Return in v1-compatible format
return {
'origin': result['origin'],
'origin_score': result['origin_score'],
'destination': result['destination'],
'destination_score': result['destination_score'],
}
# ============================================================
# Quick Test
# ============================================================
if __name__ == '__main__':
print("\n" + "="*60)
print("Route Extractor v2 — Quick Test")
print("="*60)
# Test with gazetteer-only mode (no NER model yet)
test_cases = [
"من مدينة نصر الى المعادي",
"عايز اروح التجمع الخامس من الهرم",
"رايح الاسكندرية من القاهرة",
"انا في شبرا عايز اروح الدقي",
"خدني من المهندسين على الزمالك لو سمحت",
"من فضلك وصلني من رمسيس للمعادي يا باشا",
]
extractor = _get_extractor()
for text in test_cases:
result = extractor.extract(text)
print(f"\n📝 Input: {text}")
print(f" Origin: {result['origin']} (score: {result['origin_score']:.2f}, canonical: {result['origin_canonical']})")
print(f" Dest: {result['destination']} (score: {result['destination_score']:.2f}, canonical: {result['destination_canonical']})")
print(f" Method: {result['method']}")