"""
Named Entity Recognition (NER) Extractor
Extracts entities from queries:
- Locations (Ethiopia, Addis Ababa, Tigray)
- Organizations (BBC, Al Jazeera, UN)
- Persons (Abiy Ahmed, etc.)
- Dates (today, yesterday, May 2026)
Uses a lightweight spaCy model for fast extraction (<10ms), with a
pattern-based fallback when spaCy is unavailable.
"""
import logging
import re
from typing import Dict, List, Any, Optional
from dataclasses import dataclass
from datetime import datetime, timedelta
import threading
logger = logging.getLogger(__name__)
@dataclass
class ExtractedEntities:
    """Extracted entities from query"""
    locations: List[str]
    organizations: List[str]
    persons: List[str]
    dates: List[str]
    temporal_keywords: List[str]
    source_keywords: List[str]
    raw_entities: List[Dict[str, Any]]
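
# Illustrative shape of a populated result (hypothetical values, not taken
# from a real extraction run):
#
#   ExtractedEntities(
#       locations=["Tigray"],
#       organizations=["BBC"],
#       persons=[],
#       dates=["3 days ago"],
#       temporal_keywords=["latest"],
#       source_keywords=["bbc"],
#       raw_entities=[{"text": "Tigray", "label": "GPE", "start": 15, "end": 21}],
#   )
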
class EntityExtractor:
"""
Extract named entities from queries using spaCy.
Features:
- Fast extraction (<10ms)
- Lazy loading (only loads when first used)
- Thread-safe
- Caching support
"""
    # Known news sources for better extraction
    NEWS_SOURCES = {
        "bbc", "al jazeera", "aljazeera", "reuters", "cnn", "guardian",
        "the guardian", "financial times", "ft", "new york times", "nyt",
        "washington post", "wapo", "associated press", "ap", "afp",
        "dw", "deutsche welle", "france24", "africanews", "allafrica",
        "financial afrik", "africa news"
    }
    # Temporal keywords
    TEMPORAL_KEYWORDS = {
        "today", "yesterday", "tomorrow", "tonight", "now", "currently",
        "latest", "breaking", "recent", "just", "this morning", "this evening",
        "this week", "this month", "this year", "last week", "last month",
        "last year", "past", "ago"
    }
    # Ethiopian locations for better recognition
    ETHIOPIAN_LOCATIONS = {
        "ethiopia", "addis ababa", "addis", "tigray", "amhara", "oromia",
        "oromo", "afar", "somali", "sidama", "snnpr", "gambela", "harari",
        "dire dawa", "bahir dar", "mekelle", "gondar", "hawassa", "jimma",
        "gonder", "dessie", "harar"
    }
    def __init__(self, cache=None):
        """
        Initialize entity extractor.

        Args:
            cache: Cache adapter for storing extractions
        """
        self._nlp = None
        self._lock = threading.Lock()
        self._load_failed = False
        self.cache = cache
    def _load(self):
        """Lazy load spaCy model (thread-safe, double-checked locking)"""
        if self._nlp is not None or self._load_failed:
            return
        with self._lock:
            # Re-check after acquiring the lock: another thread may have
            # loaded (or failed to load) the model in the meantime.
            if self._nlp is not None or self._load_failed:
                return
            try:
                import spacy
                # Try to load small English model
                try:
                    self._nlp = spacy.load("en_core_web_sm")
                    logger.info("✅ Loaded spaCy en_core_web_sm model")
                except OSError:
                    # Model not installed, use pattern-based extraction
                    logger.warning("spaCy model not found, using pattern-based extraction")
                    self._nlp = None
                    self._load_failed = True
            except ImportError:
                logger.warning("spaCy not installed, using pattern-based extraction")
                self._nlp = None
                self._load_failed = True
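
    # Note: if spaCy is installed but the model load above raises OSError,
    # the small English model can be installed with:
    #   python -m spacy download en_core_web_sm
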
    def extract(self, query: str) -> ExtractedEntities:
        """
        Extract entities from query.

        Args:
            query: User query

        Returns:
            ExtractedEntities with all extracted information
        """
        # Check cache first
        if self.cache:
            cache_key = f"entity_extraction:{query.lower()}"
            cached = self.cache.get(cache_key)
            if cached:
                logger.debug(f"Entity extraction cache hit: {query}")
                return ExtractedEntities(**cached)

        # Try spaCy extraction first
        self._load()
        if self._nlp:
            result = self._extract_with_spacy(query)
        else:
            # Fall back to pattern-based extraction
            result = self._extract_with_patterns(query)

        # Cache result
        if self.cache:
            cache_key = f"entity_extraction:{query.lower()}"
            self.cache.set(
                cache_key,
                {
                    "locations": result.locations,
                    "organizations": result.organizations,
                    "persons": result.persons,
                    "dates": result.dates,
                    "temporal_keywords": result.temporal_keywords,
                    "source_keywords": result.source_keywords,
                    "raw_entities": result.raw_entities
                },
                expiration=3600  # 1 hour
            )

        return result
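
    # Illustrative call (hypothetical output; the exact entities depend on
    # the spaCy model and the keyword sets above):
    #
    #   extractor.extract("latest BBC news on Tigray")
    #   -> locations expected to include "Tigray", source_keywords ["bbc"],
    #      temporal_keywords ["latest"]
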
    def _extract_with_spacy(self, query: str) -> ExtractedEntities:
        """Extract entities using spaCy NER"""
        doc = self._nlp(query)

        locations = []
        organizations = []
        persons = []
        dates = []
        raw_entities = []

        for ent in doc.ents:
            entity_info = {
                "text": ent.text,
                "label": ent.label_,
                "start": ent.start_char,
                "end": ent.end_char
            }
            raw_entities.append(entity_info)

            if ent.label_ in ["GPE", "LOC"]:  # Geopolitical entity or location
                locations.append(ent.text)
            elif ent.label_ == "ORG":  # Organization
                organizations.append(ent.text)
            elif ent.label_ == "PERSON":  # Person
                persons.append(ent.text)
            elif ent.label_ == "DATE":  # Date
                dates.append(ent.text)

        # Add pattern-based extraction to supplement spaCy
        pattern_result = self._extract_with_patterns(query)

        # Merge results (deduplicate)
        locations = list(set(locations + pattern_result.locations))
        organizations = list(set(organizations + pattern_result.organizations))
        persons = list(set(persons + pattern_result.persons))
        dates = list(set(dates + pattern_result.dates))

        return ExtractedEntities(
            locations=locations,
            organizations=organizations,
            persons=persons,
            dates=dates,
            temporal_keywords=pattern_result.temporal_keywords,
            source_keywords=pattern_result.source_keywords,
            raw_entities=raw_entities
        )
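
    # For reference: the labels checked above (GPE, LOC, ORG, PERSON, DATE)
    # are standard en_core_web_sm NER labels; e.g. nlp("Abiy Ahmed visited
    # Addis Ababa") would be expected to yield ("Abiy Ahmed", PERSON) and
    # ("Addis Ababa", GPE).
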
    def _extract_with_patterns(self, query: str) -> ExtractedEntities:
        """Extract entities using regex patterns (fallback)"""
        query_lower = query.lower()

        def _has_word(term: str) -> bool:
            # Word-boundary match so short sources like "ft" or "ap" do not
            # fire inside unrelated words ("often", "apple").
            return re.search(rf"\b{re.escape(term)}\b", query_lower) is not None

        # Extract locations
        locations = []
        for loc in self.ETHIOPIAN_LOCATIONS:
            if _has_word(loc):
                locations.append(loc.title())

        # Extract organizations (news sources)
        organizations = []
        source_keywords = []
        for source in self.NEWS_SOURCES:
            if _has_word(source):
                organizations.append(source.title())
                source_keywords.append(source)

        # Extract temporal keywords
        temporal_keywords = []
        for keyword in self.TEMPORAL_KEYWORDS:
            if _has_word(keyword):
                temporal_keywords.append(keyword)

        # Extract dates using patterns
        dates = []

        # Pattern: "May 2026", "April 30", "April 30, 2026".
        # The whole date is one capturing group so re.findall returns the
        # full phrase rather than just the month name.
        date_pattern = (
            r'\b((?:january|february|march|april|may|june|july|august'
            r'|september|october|november|december)'
            r'\s+(?:\d{1,2}(?:,?\s+\d{4})?|\d{4}))\b'
        )
        dates.extend(re.findall(date_pattern, query_lower))

        # Pattern: "2026-05-03", "2026/05/03"
        iso_pattern = r'\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b'
        dates.extend(re.findall(iso_pattern, query))

        # Pattern: "3 days ago", "2 weeks ago" (again capturing the whole
        # phrase, not just the unit, so findall yields usable strings)
        relative_pattern = r'\b(\d+\s+(?:days?|weeks?|months?|years?)\s+ago)\b'
        dates.extend(re.findall(relative_pattern, query_lower))

        return ExtractedEntities(
            locations=list(set(locations)),
            organizations=list(set(organizations)),
            persons=[],  # Pattern-based person extraction is unreliable
            dates=list(set(dates)),
            temporal_keywords=list(set(temporal_keywords)),
            source_keywords=list(set(source_keywords)),
            raw_entities=[]
        )
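
    # Illustrative fallback matches for a hypothetical query
    # "BBC report from 2 weeks ago about Addis Ababa, 2026/05/03":
    #   organizations ["Bbc"], source_keywords ["bbc"],
    #   locations ["Addis Ababa", "Addis"] (both keywords match),
    #   dates ["2 weeks ago", "2026/05/03"], temporal_keywords ["ago"]
    # (list order is not guaranteed because results are deduplicated via set()).
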
    def get_source_filter(self, entities: ExtractedEntities) -> Optional[str]:
        """
        Get source filter from extracted entities.

        Returns:
            Source name if found, None otherwise
        """
        if entities.source_keywords:
            # Return first source keyword
            return entities.source_keywords[0]

        if entities.organizations:
            # Check if any organization is a known news source
            for org in entities.organizations:
                org_lower = org.lower()
                if org_lower in self.NEWS_SOURCES:
                    return org_lower

        return None
    def get_location_filter(self, entities: ExtractedEntities) -> Optional[str]:
        """
        Get location filter from extracted entities.

        Returns:
            Location name if found, None otherwise
        """
        if entities.locations:
            # Return first location
            return entities.locations[0]
        return None

    def has_temporal_context(self, entities: ExtractedEntities) -> bool:
        """Check if query has temporal context"""
        return len(entities.temporal_keywords) > 0 or len(entities.dates) > 0

# ═══════════════════════════════════════════════════════════════════════════
# SINGLETON INSTANCE
# ═══════════════════════════════════════════════════════════════════════════
# Will be initialized with dependencies in main.py
entity_extractor: Optional[EntityExtractor] = None

def initialize_entity_extractor(cache=None):
    """Initialize global entity extractor instance"""
    global entity_extractor
    entity_extractor = EntityExtractor(cache)
    logger.info("Entity extractor initialized")
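
# A minimal, self-contained usage sketch (assumes no cache adapter is wired
# in; the output depends on whether en_core_web_sm is installed, since the
# extractor falls back to pattern matching otherwise):
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    extractor = EntityExtractor()
    entities = extractor.extract("Latest BBC news on Tigray from 3 days ago")
    print("locations:", entities.locations)
    print("organizations:", entities.organizations)
    print("dates:", entities.dates)
    print("temporal keywords:", entities.temporal_keywords)
    print("source filter:", extractor.get_source_filter(entities))
    print("location filter:", extractor.get_location_filter(entities))
    print("temporal context:", extractor.has_temporal_context(entities))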