Spaces:
Running
Running
| """ | |
| Named Entity Recognition (NER) Extractor | |
| Extracts entities from queries: | |
| - Locations (Ethiopia, Addis Ababa, Tigray) | |
| - Organizations (BBC, Al Jazeera, UN) | |
| - Persons (Abiy Ahmed, etc.) | |
| - Dates (today, yesterday, May 2026) | |
| Uses lightweight spaCy model for fast extraction (<10ms). | |
| """ | |
| import logging | |
| import re | |
| from typing import Dict, List, Any, Optional | |
| from dataclasses import dataclass | |
| from datetime import datetime, timedelta | |
| import threading | |
| logger = logging.getLogger(__name__) | |
@dataclass
class ExtractedEntities:
    """Entities extracted from a user query.

    Attributes:
        locations: Place names found in the query (e.g. "Ethiopia").
        organizations: Organization names, including news outlets.
        persons: Person names (empty when pattern-based fallback is used).
        dates: Date expressions ("May 2026", "2026-05-03", "3 days ago").
        temporal_keywords: Recency words found ("today", "latest", ...).
        source_keywords: Known news-source names matched in the query.
        raw_entities: Raw NER spans as dicts (text/label/start/end);
            empty when spaCy is unavailable.
    """
    # BUGFIX: the @dataclass decorator was missing. The rest of the module
    # constructs this class with keyword arguments (ExtractedEntities(**cached),
    # ExtractedEntities(locations=..., ...)), which requires the generated
    # __init__ — without the decorator every construction raised TypeError.
    locations: List[str]
    organizations: List[str]
    persons: List[str]
    dates: List[str]
    temporal_keywords: List[str]
    source_keywords: List[str]
    raw_entities: List[Dict[str, Any]]
class EntityExtractor:
    """
    Extract named entities from queries using spaCy, with a regex-based
    fallback when spaCy (or its English model) is unavailable.

    Features:
    - Fast extraction (<10ms)
    - Lazy loading (only loads when first used)
    - Thread-safe
    - Caching support
    """

    # Known news sources for better extraction
    # (fixed typo: "allaf rica" -> "allafrica")
    NEWS_SOURCES = {
        "bbc", "al jazeera", "aljazeera", "reuters", "cnn", "guardian",
        "the guardian", "financial times", "ft", "new york times", "nyt",
        "washington post", "wapo", "associated press", "ap", "afp",
        "dw", "deutsche welle", "france24", "africanews", "allafrica",
        "financial afrik", "africa news"
    }

    # Temporal keywords
    TEMPORAL_KEYWORDS = {
        "today", "yesterday", "tomorrow", "tonight", "now", "currently",
        "latest", "breaking", "recent", "just", "this morning", "this evening",
        "this week", "this month", "this year", "last week", "last month",
        "last year", "past", "ago"
    }

    # Ethiopian locations for better recognition
    ETHIOPIAN_LOCATIONS = {
        "ethiopia", "addis ababa", "addis", "tigray", "amhara", "oromia",
        "oromo", "afar", "somali", "sidama", "snnpr", "gambela", "harari",
        "dire dawa", "bahir dar", "mekelle", "gondar", "hawassa", "jimma",
        "gonder", "dessie", "harar"
    }

    def __init__(self, cache=None):
        """
        Initialize entity extractor.

        Args:
            cache: Cache adapter for storing extractions. Must expose
                get(key) and set(key, value, expiration=...) if provided.
        """
        self._nlp = None                  # spaCy pipeline, loaded lazily
        self._lock = threading.Lock()     # guards the lazy model load
        self._load_failed = False         # remembered so we never retry a failed load
        self.cache = cache

    def _load(self):
        """Lazy load spaCy model (thread-safe, double-checked locking)."""
        # Fast path: already loaded, or a previous attempt failed.
        if self._nlp is not None or self._load_failed:
            return
        with self._lock:
            # Re-check under the lock in case another thread loaded it.
            if self._nlp is not None or self._load_failed:
                return
            try:
                import spacy
                # Try to load small English model
                try:
                    self._nlp = spacy.load("en_core_web_sm")
                    logger.info("β Loaded spaCy en_core_web_sm model")
                except OSError:
                    # Model not installed, use pattern-based extraction instead.
                    logger.warning("spaCy model not found, using pattern-based extraction")
                    self._nlp = None
                    self._load_failed = True
            except ImportError:
                logger.warning("spaCy not installed, using pattern-based extraction")
                self._nlp = None
                self._load_failed = True

    def extract(self, query: str) -> ExtractedEntities:
        """
        Extract entities from query.

        Args:
            query: User query.

        Returns:
            ExtractedEntities with all extracted information.
        """
        # Check cache first (keyed case-insensitively).
        if self.cache:
            cache_key = f"entity_extraction:{query.lower()}"
            cached = self.cache.get(cache_key)
            if cached:
                logger.debug(f"Entity extraction cache hit: {query}")
                return ExtractedEntities(**cached)

        # Try spaCy extraction first; fall back to patterns if unavailable.
        self._load()
        if self._nlp:
            result = self._extract_with_spacy(query)
        else:
            result = self._extract_with_patterns(query)

        # Cache result as a plain dict so it round-trips through any backend.
        if self.cache:
            cache_key = f"entity_extraction:{query.lower()}"
            self.cache.set(
                cache_key,
                {
                    "locations": result.locations,
                    "organizations": result.organizations,
                    "persons": result.persons,
                    "dates": result.dates,
                    "temporal_keywords": result.temporal_keywords,
                    "source_keywords": result.source_keywords,
                    "raw_entities": result.raw_entities
                },
                expiration=3600  # 1 hour
            )
        return result

    def _extract_with_spacy(self, query: str) -> ExtractedEntities:
        """Extract entities using spaCy NER, supplemented by pattern matching.

        Args:
            query: User query.

        Returns:
            ExtractedEntities merging spaCy spans with pattern-based hits
            (deduplicated; ordering is unspecified because of set()).
        """
        doc = self._nlp(query)

        locations = []
        organizations = []
        persons = []
        dates = []
        raw_entities = []

        for ent in doc.ents:
            entity_info = {
                "text": ent.text,
                "label": ent.label_,
                "start": ent.start_char,
                "end": ent.end_char
            }
            raw_entities.append(entity_info)

            if ent.label_ in ["GPE", "LOC"]:  # Geopolitical entity or location
                locations.append(ent.text)
            elif ent.label_ == "ORG":  # Organization
                organizations.append(ent.text)
            elif ent.label_ == "PERSON":  # Person
                persons.append(ent.text)
            elif ent.label_ == "DATE":  # Date
                dates.append(ent.text)

        # Add pattern-based extraction to supplement spaCy (catches local
        # names and news sources the generic model misses).
        pattern_result = self._extract_with_patterns(query)

        # Merge results (deduplicate)
        locations = list(set(locations + pattern_result.locations))
        organizations = list(set(organizations + pattern_result.organizations))
        persons = list(set(persons + pattern_result.persons))
        dates = list(set(dates + pattern_result.dates))

        return ExtractedEntities(
            locations=locations,
            organizations=organizations,
            persons=persons,
            dates=dates,
            temporal_keywords=pattern_result.temporal_keywords,
            source_keywords=pattern_result.source_keywords,
            raw_entities=raw_entities
        )

    @staticmethod
    def _contains_term(text: str, term: str) -> bool:
        """Whole-word containment check.

        Plain substring tests produced false positives for short terms
        ("ago" matched inside "Chicago", "ap" inside "apple"), so match
        on word boundaries instead. `text` is expected to be lower-cased.
        """
        return re.search(r'\b' + re.escape(term) + r'\b', text) is not None

    def _extract_with_patterns(self, query: str) -> ExtractedEntities:
        """Extract entities using regex patterns (fallback).

        Args:
            query: User query.

        Returns:
            ExtractedEntities; persons and raw_entities are always empty
            because pattern-based person extraction is unreliable.
        """
        query_lower = query.lower()

        # Extract locations (whole-word match against the known set).
        locations = [
            loc.title() for loc in self.ETHIOPIAN_LOCATIONS
            if self._contains_term(query_lower, loc)
        ]

        # Extract organizations (news sources)
        organizations = []
        source_keywords = []
        for source in self.NEWS_SOURCES:
            if self._contains_term(query_lower, source):
                organizations.append(source.title())
                source_keywords.append(source)

        # Extract temporal keywords
        temporal_keywords = [
            keyword for keyword in self.TEMPORAL_KEYWORDS
            if self._contains_term(query_lower, keyword)
        ]

        # Extract dates using patterns
        dates = []

        # Pattern: "May 2026", "April 30", etc.
        # BUGFIX: the month alternation must be non-capturing — with a
        # capturing group, re.findall returned only the month name instead
        # of the full date expression.
        date_pattern = (
            r'\b(?:january|february|march|april|may|june|july|august|'
            r'september|october|november|december)\s+\d{1,2}(?:,?\s+\d{4})?\b'
        )
        dates.extend(re.findall(date_pattern, query_lower, re.IGNORECASE))

        # Pattern: "2026-05-03", "2026/05/03"
        iso_pattern = r'\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b'
        dates.extend(re.findall(iso_pattern, query))

        # Pattern: "3 days ago", "2 weeks ago".
        # BUGFIX: also non-capturing. The old capturing group made findall
        # return just the unit string ("days"), and ' '.join over that
        # string produced garbage like "d a y s".
        relative_pattern = r'\b\d+\s+(?:day|days|week|weeks|month|months|year|years)\s+ago\b'
        dates.extend(re.findall(relative_pattern, query_lower))

        return ExtractedEntities(
            locations=list(set(locations)),
            organizations=list(set(organizations)),
            persons=[],  # Pattern-based person extraction is unreliable
            dates=list(set(dates)),
            temporal_keywords=list(set(temporal_keywords)),
            source_keywords=list(set(source_keywords)),
            raw_entities=[]
        )

    def get_source_filter(self, entities: ExtractedEntities) -> Optional[str]:
        """
        Get source filter from extracted entities.

        Args:
            entities: Previously extracted entities.

        Returns:
            Lower-cased source name if found, None otherwise.
        """
        if entities.source_keywords:
            # Return first source keyword
            return entities.source_keywords[0]
        if entities.organizations:
            # Check if any organization is a known news source
            for org in entities.organizations:
                org_lower = org.lower()
                if org_lower in self.NEWS_SOURCES:
                    return org_lower
        return None

    def get_location_filter(self, entities: ExtractedEntities) -> Optional[str]:
        """
        Get location filter from extracted entities.

        Args:
            entities: Previously extracted entities.

        Returns:
            First location name if any were found, None otherwise.
        """
        if entities.locations:
            return entities.locations[0]
        return None

    def has_temporal_context(self, entities: ExtractedEntities) -> bool:
        """Return True if the query contained temporal keywords or dates."""
        return len(entities.temporal_keywords) > 0 or len(entities.dates) > 0
# ===========================================================================
# SINGLETON INSTANCE
# ===========================================================================
# Populated by initialize_entity_extractor(), called from main.py at startup.
entity_extractor: Optional[EntityExtractor] = None


def initialize_entity_extractor(cache=None):
    """Build the module-wide EntityExtractor singleton.

    Args:
        cache: Optional cache adapter forwarded to the extractor.
    """
    global entity_extractor
    entity_extractor = EntityExtractor(cache=cache)
    logger.info("Entity extractor initialized")