"""
Named Entity Recognition (NER) Extractor
Extracts entities from queries:
- Locations (Ethiopia, Addis Ababa, Tigray)
- Organizations (BBC, Al Jazeera, UN)
- Persons (Abiy Ahmed, etc.)
- Dates (today, yesterday, May 2026)
Uses lightweight spaCy model for fast extraction (<10ms).
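
Example usage (a minimal sketch; the query and results shown are illustrative,
and the exact lists depend on whether the spaCy model is available):

    extractor = EntityExtractor()
    entities = extractor.extract("Latest BBC news about Tigray from last week")
    # entities.locations          -> e.g. ["Tigray"]
    # entities.source_keywords    -> e.g. ["bbc"]
    # entities.temporal_keywords  -> e.g. ["latest", "last week"]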
"""
import logging
import re
from typing import Dict, List, Any, Optional
from dataclasses import dataclass
from datetime import datetime, timedelta
import threading
logger = logging.getLogger(__name__)
@dataclass
class ExtractedEntities:
"""Extracted entities from query"""
locations: List[str]
organizations: List[str]
persons: List[str]
dates: List[str]
temporal_keywords: List[str]
source_keywords: List[str]
raw_entities: List[Dict[str, Any]]
class EntityExtractor:
"""
Extract named entities from queries using spaCy.
Features:
- Fast extraction (<10ms)
- Lazy loading (only loads when first used)
- Thread-safe
- Caching support
"""
# Known news sources for better extraction
NEWS_SOURCES = {
"bbc", "al jazeera", "aljazeera", "reuters", "cnn", "guardian",
"the guardian", "financial times", "ft", "new york times", "nyt",
"washington post", "wapo", "associated press", "ap", "afp",
"dw", "deutsche welle", "france24", "africanews", "allaf rica",
"financial afrik", "africa news"
}
    # Temporal keywords
    TEMPORAL_KEYWORDS = {
        "today", "yesterday", "tomorrow", "tonight", "now", "currently",
        "latest", "breaking", "recent", "just", "this morning", "this evening",
        "this week", "this month", "this year", "last week", "last month",
        "last year", "past", "ago"
    }

    # Ethiopian locations for better recognition
    ETHIOPIAN_LOCATIONS = {
        "ethiopia", "addis ababa", "addis", "tigray", "amhara", "oromia",
        "oromo", "afar", "somali", "sidama", "snnpr", "gambela", "harari",
        "dire dawa", "bahir dar", "mekelle", "gondar", "hawassa", "jimma",
        "gonder", "dessie", "harar"
    }
    def __init__(self, cache=None):
        """
        Initialize entity extractor.

        Args:
            cache: Cache adapter for storing extractions
        """
        self._nlp = None
        self._lock = threading.Lock()
        self._load_failed = False
        self.cache = cache
    def _load(self):
        """Lazy load spaCy model (thread-safe)"""
        if self._nlp is not None or self._load_failed:
            return
        with self._lock:
            # Double-check after acquiring the lock
            if self._nlp is not None or self._load_failed:
                return
            try:
                import spacy
                # Try to load small English model
                try:
                    self._nlp = spacy.load("en_core_web_sm")
                    logger.info("Loaded spaCy en_core_web_sm model")
                except OSError:
                    # Model not installed, fall back to pattern-based extraction
                    logger.warning("spaCy model not found, using pattern-based extraction")
                    self._nlp = None
                    self._load_failed = True
            except ImportError:
                logger.warning("spaCy not installed, using pattern-based extraction")
                self._nlp = None
                self._load_failed = True
    def extract(self, query: str) -> ExtractedEntities:
        """
        Extract entities from query.

        Args:
            query: User query

        Returns:
            ExtractedEntities with all extracted information
        """
        # Check cache first
        if self.cache:
            cache_key = f"entity_extraction:{query.lower()}"
            cached = self.cache.get(cache_key)
            if cached:
                logger.debug(f"Entity extraction cache hit: {query}")
                return ExtractedEntities(**cached)

        # Try spaCy extraction first
        self._load()
        if self._nlp:
            result = self._extract_with_spacy(query)
        else:
            # Fallback to pattern-based extraction
            result = self._extract_with_patterns(query)

        # Cache result
        if self.cache:
            cache_key = f"entity_extraction:{query.lower()}"
            self.cache.set(
                cache_key,
                {
                    "locations": result.locations,
                    "organizations": result.organizations,
                    "persons": result.persons,
                    "dates": result.dates,
                    "temporal_keywords": result.temporal_keywords,
                    "source_keywords": result.source_keywords,
                    "raw_entities": result.raw_entities
                },
                expiration=3600  # 1 hour
            )
        return result
    def _extract_with_spacy(self, query: str) -> ExtractedEntities:
        """Extract entities using spaCy NER"""
        doc = self._nlp(query)
        locations = []
        organizations = []
        persons = []
        dates = []
        raw_entities = []

        for ent in doc.ents:
            entity_info = {
                "text": ent.text,
                "label": ent.label_,
                "start": ent.start_char,
                "end": ent.end_char
            }
            raw_entities.append(entity_info)
            if ent.label_ in ["GPE", "LOC"]:  # Geopolitical entity or location
                locations.append(ent.text)
            elif ent.label_ == "ORG":  # Organization
                organizations.append(ent.text)
            elif ent.label_ == "PERSON":  # Person
                persons.append(ent.text)
            elif ent.label_ == "DATE":  # Date
                dates.append(ent.text)

        # Add pattern-based extraction to supplement spaCy
        pattern_result = self._extract_with_patterns(query)

        # Merge results (deduplicate)
        locations = list(set(locations + pattern_result.locations))
        organizations = list(set(organizations + pattern_result.organizations))
        persons = list(set(persons + pattern_result.persons))
        dates = list(set(dates + pattern_result.dates))

        return ExtractedEntities(
            locations=locations,
            organizations=organizations,
            persons=persons,
            dates=dates,
            temporal_keywords=pattern_result.temporal_keywords,
            source_keywords=pattern_result.source_keywords,
            raw_entities=raw_entities
        )
    def _extract_with_patterns(self, query: str) -> ExtractedEntities:
        """Extract entities using regex patterns (fallback)"""
        query_lower = query.lower()

        # Extract locations
        locations = []
        for loc in self.ETHIOPIAN_LOCATIONS:
            if loc in query_lower:
                locations.append(loc.title())

        # Extract organizations (news sources)
        organizations = []
        source_keywords = []
        for source in self.NEWS_SOURCES:
            if source in query_lower:
                organizations.append(source.title())
                source_keywords.append(source)

        # Extract temporal keywords
        temporal_keywords = []
        for keyword in self.TEMPORAL_KEYWORDS:
            if keyword in query_lower:
                temporal_keywords.append(keyword)

        # Extract dates using patterns
        dates = []

        # Pattern: "April 30", "April 30, 2026", "May 2026"
        # Month group is non-capturing so findall returns the full match, not just the month
        date_pattern = (
            r'\b(?:january|february|march|april|may|june|july|august|'
            r'september|october|november|december)'
            r'\s+(?:\d{1,2}(?:,?\s+\d{4})?|\d{4})\b'
        )
        date_matches = re.findall(date_pattern, query_lower, re.IGNORECASE)
        dates.extend(date_matches)

        # Pattern: "2026-05-03", "2026/05/03"
        iso_pattern = r'\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b'
        iso_matches = re.findall(iso_pattern, query)
        dates.extend(iso_matches)

        # Pattern: "3 days ago", "2 weeks ago"
        # No capturing group, so findall returns the whole expression (e.g. "3 days ago")
        relative_pattern = r'\b\d+\s+(?:day|days|week|weeks|month|months|year|years)\s+ago\b'
        relative_matches = re.findall(relative_pattern, query_lower)
        dates.extend(relative_matches)

        return ExtractedEntities(
            locations=list(set(locations)),
            organizations=list(set(organizations)),
            persons=[],  # Pattern-based person extraction is unreliable
            dates=list(set(dates)),
            temporal_keywords=list(set(temporal_keywords)),
            source_keywords=list(set(source_keywords)),
            raw_entities=[]
        )
    def get_source_filter(self, entities: ExtractedEntities) -> Optional[str]:
        """
        Get source filter from extracted entities.

        Returns:
            Source name if found, None otherwise
        """
        if entities.source_keywords:
            # Return first source keyword
            return entities.source_keywords[0]
        if entities.organizations:
            # Check if any organization is a known news source
            for org in entities.organizations:
                org_lower = org.lower()
                if org_lower in self.NEWS_SOURCES:
                    return org_lower
        return None
    def get_location_filter(self, entities: ExtractedEntities) -> Optional[str]:
        """
        Get location filter from extracted entities.

        Returns:
            Location name if found, None otherwise
        """
        if entities.locations:
            # Return first location
            return entities.locations[0]
        return None
    def has_temporal_context(self, entities: ExtractedEntities) -> bool:
        """Check if query has temporal context"""
        return len(entities.temporal_keywords) > 0 or len(entities.dates) > 0
# ─────────────────────────────────────────────────────────────────────────────
# SINGLETON INSTANCE
# ─────────────────────────────────────────────────────────────────────────────
# Will be initialized with dependencies in main.py
entity_extractor: Optional[EntityExtractor] = None
def initialize_entity_extractor(cache=None):
    """Initialize global entity extractor instance"""
    global entity_extractor
    entity_extractor = EntityExtractor(cache)
    logger.info("Entity extractor initialized")