# NOTE(review): the three lines below ("Spaces: / Sleeping / Sleeping") were
# Hugging Face Spaces page-header residue from the scrape, not program text;
# kept here as a comment so the module remains valid Python.
| """ | |
| graph/entity_extractor.py | |
| Extracts named entities from document chunks using spaCy. | |
| Entity types: PERSON, GPE (places), ORG, EVENT, DATE | |
| Usage: | |
| from graph.entity_extractor import extractor | |
| entities = extractor.extract("Mayor Fitzgerald attended the rally in Boston.") | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import List | |
| import spacy | |
# spaCy entity labels we keep; any other label is discarded by extract().
RELEVANT_TYPES = {"PERSON", "GPE", "LOC", "ORG", "EVENT", "DATE", "FAC", "NORP"}

# Map spaCy labels to our simplified type vocabulary
# (PERSON, PLACE, ORG, EVENT, DATE).
TYPE_MAP = {
    "PERSON": "PERSON",
    "GPE": "PLACE",  # Geo-political entity (cities, countries)
    "LOC": "PLACE",  # Non-GPE locations
    "FAC": "PLACE",  # Facilities (buildings, bridges)
    "ORG": "ORG",
    "EVENT": "EVENT",
    "DATE": "DATE",
    "NORP": "ORG",  # Nationalities, religious groups, political groups
}
@dataclass
class Entity:
    """A named entity aggregated over a document.

    Produced by EntityExtractor.extract(); one instance per distinct
    (normalized text, simplified type) pair, with an occurrence count.

    Bug fix: the @dataclass decorator was missing, so the keyword-argument
    construction in extract() — Entity(text=..., type=..., ...) — raised
    TypeError. The decorator was already imported at the top of the file.
    """

    text: str       # normalized (stripped, lower-cased) entity text
    type: str       # simplified type: PERSON, PLACE, ORG, EVENT, DATE
    raw_type: str   # original spaCy label (e.g. GPE, NORP, FAC)
    count: int = 1  # number of occurrences observed in the document
class EntityExtractor:
    """Lazily-loading wrapper around a spaCy NER pipeline.

    The model is loaded (and, if absent, downloaded) on first use, so
    importing this module stays cheap.
    """

    def __init__(self, model: str = "en_core_web_sm"):
        self._nlp = None      # loaded lazily by _load()
        self._model = model   # spaCy model/package name

    def _load(self):
        """Load the spaCy pipeline, downloading the model if it is missing."""
        if self._nlp is None:
            print(f"Loading spaCy model: {self._model}")
            try:
                self._nlp = spacy.load(self._model)
            except OSError:
                print(f"Model {self._model} not found. Downloading...")
                import subprocess
                import sys
                # Fix: use the running interpreter instead of whatever
                # "python" resolves to on PATH, so the model is installed
                # into the active environment.
                subprocess.run(
                    [sys.executable, "-m", "spacy", "download", self._model],
                    check=True
                )
                self._nlp = spacy.load(self._model)

    def extract(self, text: str) -> List[Entity]:
        """
        Extract named entities from text.

        Processes in 100K-character windows (with a small overlap so
        entities straddling a window boundary are not missed) and returns
        Entity records sorted by frequency, most frequent first.
        Returns [] for empty or whitespace-only input.
        """
        self._load()
        if not text or not text.strip():
            return []

        # Process in overlapping windows to cover the full document without
        # feeding spaCy one huge string.
        WINDOW_SIZE = 100_000
        OVERLAP = 1_000  # overlap to avoid missing entities at boundaries

        entity_counts: dict[tuple, int] = {}  # (normalized, type) -> count
        entity_raw: dict[tuple, str] = {}     # (normalized, type) -> spaCy label

        start = 0
        while start < len(text):
            end = min(start + WINDOW_SIZE, len(text))
            window = text[start:end]
            doc = self._nlp(window)
            for ent in doc.ents:
                if ent.label_ not in RELEVANT_TYPES:
                    continue
                # Fix: entities lying entirely inside the overlap region of a
                # non-first window were fully visible to the previous window
                # too, and were counted twice. Skip them here; entities that
                # merely straddle the boundary are still counted once.
                if start > 0 and ent.end_char <= OVERLAP:
                    continue
                normalized = ent.text.strip().lower()
                # Drop degenerate or absurdly long spans.
                if len(normalized) < 2 or len(normalized) > 100:
                    continue
                # Bare numbers are only meaningful as dates.
                if ent.label_ != "DATE" and normalized.replace(" ", "").isdigit():
                    continue
                mapped_type = TYPE_MAP.get(ent.label_, ent.label_)
                key = (normalized, mapped_type)
                entity_counts[key] = entity_counts.get(key, 0) + 1
                entity_raw[key] = ent.label_
            if end == len(text):
                break
            # Move the window forward, keeping OVERLAP characters of context.
            start += WINDOW_SIZE - OVERLAP

        return [
            Entity(
                text=text_,
                type=type_,
                raw_type=entity_raw[(text_, type_)],
                count=count,
            )
            for (text_, type_), count in sorted(
                entity_counts.items(),
                key=lambda item: item[1],
                reverse=True,
            )
        ]

    def extract_top(self, text: str, n: int = 20) -> List[Entity]:
        """Return only the top N most frequent entities."""
        return self.extract(text)[:n]
# Module-level singleton; construction is cheap because the spaCy model is
# loaded lazily on the first extract() call.
extractor = EntityExtractor()