# Source: BPL-RAG-Spring-2026 — graph/entity_extractor.py
# (commit 3b69792, "fix: add graphrag code", uploaded by han-na)
"""
graph/entity_extractor.py
Extracts named entities from document chunks using spaCy.
Entity types: PERSON, GPE (places), ORG, EVENT, DATE
Usage:
from graph.entity_extractor import extractor
entities = extractor.extract("Mayor Fitzgerald attended the rally in Boston.")
"""
from __future__ import annotations

from collections import Counter
from dataclasses import dataclass
from typing import List

import spacy
# Map spaCy entity labels to our simplified types.
# The keys of this dict double as the set of labels we keep (see
# RELEVANT_TYPES below), so adding a label here automatically makes it
# relevant — the two can no longer drift apart.
TYPE_MAP = {
    "PERSON": "PERSON",
    "GPE": "PLACE",    # Geo-political entity (cities, countries)
    "LOC": "PLACE",    # Non-GPE locations
    "FAC": "PLACE",    # Facilities (buildings, bridges)
    "ORG": "ORG",
    "EVENT": "EVENT",
    "DATE": "DATE",
    "NORP": "ORG",     # Nationalities, religious groups, political groups
}

# Entity types we care about — derived from TYPE_MAP's keys.
RELEVANT_TYPES = set(TYPE_MAP)
@dataclass
class Entity:
    """A named entity aggregated over a document, with an occurrence count."""

    text: str      # normalized entity text
    type: str      # simplified type (PERSON, PLACE, ORG, EVENT, DATE)
    raw_type: str  # original spaCy label
    count: int = 1  # occurrences observed; defaults to a single sighting
class EntityExtractor:
    """Lazily-loading wrapper around a spaCy pipeline for NER.

    The spaCy model is loaded on first use (and auto-downloaded if missing),
    so importing this module stays cheap.
    """

    def __init__(self, model: str = "en_core_web_sm"):
        self._nlp = None      # loaded spaCy pipeline, or None until first use
        self._model = model   # spaCy model name to load

    def _load(self) -> None:
        """Load the spaCy model, downloading it first if it is not installed."""
        if self._nlp is not None:
            return
        print(f"Loading spaCy model: {self._model}")
        try:
            self._nlp = spacy.load(self._model)
        except OSError:
            print(f"Model {self._model} not found. Downloading...")
            import subprocess
            import sys

            # Use the *current* interpreter: a bare "python" on PATH may
            # resolve to a different environment and install the model
            # where this process cannot see it.
            subprocess.run(
                [sys.executable, "-m", "spacy", "download", self._model],
                check=True,
            )
            self._nlp = spacy.load(self._model)

    def extract(self, text: str) -> List[Entity]:
        """
        Extract named entities from text.

        Processes in 100K character windows to bound memory usage while
        covering the full document. Windows overlap so entities spanning a
        boundary are not missed; entities lying wholly inside an overlap
        region are skipped on the later window so they are not double-counted.

        Returns entities sorted by descending occurrence count.
        """
        # Guard before _load(): don't pay model-loading cost for empty input.
        if not text or not text.strip():
            return []
        self._load()

        WINDOW_SIZE = 100_000
        OVERLAP = 1_000  # overlap to avoid missing entities at boundaries

        entity_counts: Counter[tuple[str, str]] = Counter()
        entity_raw: dict[tuple[str, str], str] = {}

        start = 0
        while start < len(text):
            end = min(start + WINDOW_SIZE, len(text))
            doc = self._nlp(text[start:end])
            for ent in doc.ents:
                if ent.label_ not in RELEVANT_TYPES:
                    continue
                # An entity fully inside the overlap was fully visible in
                # the previous window and already counted there; only
                # boundary-spanning entities are new in this window.
                if start > 0 and ent.end_char <= OVERLAP:
                    continue
                normalized = ent.text.strip().lower()
                # Drop degenerate or absurdly long surface forms.
                if len(normalized) < 2 or len(normalized) > 100:
                    continue
                # Bare numbers are only meaningful as dates.
                if ent.label_ != "DATE" and normalized.replace(" ", "").isdigit():
                    continue
                mapped_type = TYPE_MAP.get(ent.label_, ent.label_)
                key = (normalized, mapped_type)
                entity_counts[key] += 1
                entity_raw[key] = ent.label_
            if end == len(text):
                break
            # Move window forward, keeping OVERLAP chars of context.
            start += WINDOW_SIZE - OVERLAP

        return [
            Entity(
                text=text_,
                type=type_,
                raw_type=entity_raw[(text_, type_)],
                count=count,
            )
            for (text_, type_), count in entity_counts.most_common()
        ]

    def extract_top(self, text: str, n: int = 20) -> List[Entity]:
        """Return only the top N most frequent entities."""
        return self.extract(text)[:n]
# Module-level singleton — import this instead of constructing your own
# EntityExtractor so the spaCy model is loaded at most once per process
# (loading is deferred until the first extract() call).
extractor = EntityExtractor()