File size: 3,906 Bytes
3b69792
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
graph/entity_extractor.py

Extracts named entities from document chunks using spaCy.
spaCy labels kept: PERSON, GPE/LOC/FAC (places), ORG/NORP (organizations and groups), EVENT, DATE

Usage:
    from graph.entity_extractor import extractor
    entities = extractor.extract("Mayor Fitzgerald attended the rally in Boston.")
"""

from __future__ import annotations
from dataclasses import dataclass
from typing import List
import spacy


# spaCy labels that the extractor keeps; every other label is ignored.
RELEVANT_TYPES = {
    "PERSON",
    "GPE", "LOC", "FAC",
    "ORG", "NORP",
    "EVENT",
    "DATE",
}

# Translation from a spaCy label to the simplified project-level type.
TYPE_MAP = dict(
    PERSON="PERSON",
    GPE="PLACE",    # geo-political entities: cities, countries
    LOC="PLACE",    # locations that are not GPEs
    FAC="PLACE",    # facilities: buildings, bridges
    ORG="ORG",
    EVENT="EVENT",
    DATE="DATE",
    NORP="ORG",     # nationalities, religious groups, political groups
)


@dataclass
class Entity:
    """One distinct named entity together with how often it was seen."""

    # Normalized entity text.
    text: str
    # Simplified type (PERSON, PLACE, ORG, EVENT, DATE).
    type: str
    # Label originally assigned by spaCy.
    raw_type: str
    # Number of occurrences; a freshly created entity counts once.
    count: int = 1


class EntityExtractor:
    """Named-entity extractor backed by a lazily loaded spaCy pipeline.

    Long texts are processed in overlapping character windows. Each entity
    occurrence is attributed to exactly one window, so occurrences inside
    the shared overlap region are not counted twice.
    """

    def __init__(
        self,
        model: str = "en_core_web_sm",
        *,
        window_size: int = 100_000,
        overlap: int = 1_000,
    ):
        """Configure the extractor without loading the model.

        Args:
            model: Name of the spaCy pipeline to load on first use.
            window_size: Maximum number of characters handed to spaCy at once.
            overlap: Number of characters shared by consecutive windows.
                Half of this should exceed the longest entity kept
                (100 characters) so that an entity near a boundary is fully
                visible in the window that counts it.

        Raises:
            ValueError: If ``window_size`` is not positive, or ``overlap`` is
                negative or not smaller than ``window_size`` (the window
                would never advance).
        """
        if window_size <= 0:
            raise ValueError("window_size must be positive")
        if not 0 <= overlap < window_size:
            raise ValueError("overlap must satisfy 0 <= overlap < window_size")

        self._nlp = None
        self._model = model
        self._window_size = window_size
        self._overlap = overlap

    def _load(self):
        """Load the spaCy pipeline on first use, downloading it if missing."""
        if self._nlp is not None:
            return

        print(f"Loading spaCy model: {self._model}")
        try:
            self._nlp = spacy.load(self._model)
        except OSError:
            print(f"Model {self._model} not found. Downloading...")
            import subprocess
            import sys

            # Run the download with the interpreter that is executing this
            # code; a bare "python" may resolve to a different environment
            # than the one spaCy is installed in.
            subprocess.run(
                [sys.executable, "-m", "spacy", "download", self._model],
                check=True,
            )
            self._nlp = spacy.load(self._model)

    def _windows(self, text: str):
        """Yield ``(offset, window, own_from, own_until)`` for each window.

        ``offset`` is the absolute position of ``window`` in ``text``.
        ``[own_from, own_until)`` is the absolute range of entity start
        positions this window is responsible for. The ranges of consecutive
        windows meet in the middle of their overlap, so together they cover
        the text exactly once.
        """
        total = len(text)
        step = self._window_size - self._overlap
        half_overlap = self._overlap // 2

        offset = 0
        while offset < total:
            end = min(offset + self._window_size, total)
            is_last = end == total
            # The first window owns everything from the very beginning and
            # the last window owns everything up to the end of the text.
            own_from = offset + half_overlap if offset else 0
            own_until = total if is_last else offset + step + half_overlap
            yield offset, text[offset:end], own_from, own_until
            if is_last:
                break
            offset += step

    def extract(self, text: str) -> List[Entity]:
        """Extract named entities from ``text``.

        The text is processed in overlapping windows to bound memory use
        while still covering the whole document. Entity text is stripped and
        lower-cased; entities shorter than 2 or longer than 100 characters,
        and purely numeric entities that are not dates, are dropped.

        Args:
            text: Document text to analyse.

        Returns:
            One ``Entity`` per distinct ``(text, simplified type)`` pair,
            ordered by descending count. Ties keep first-seen order. Empty
            or whitespace-only input yields an empty list.
        """
        # Decide this before touching the model: empty input must not
        # trigger a model load (or download).
        if not text or not text.strip():
            return []

        self._load()

        entity_counts: dict[tuple, int] = {}
        entity_raw: dict[tuple, str] = {}

        for offset, window, own_from, own_until in self._windows(text):
            doc = self._nlp(window)

            for ent in doc.ents:
                if ent.label_ not in RELEVANT_TYPES:
                    continue
                # Count an occurrence only in the window that owns its
                # start position; otherwise entities in the overlap region
                # would be counted by both neighbouring windows.
                if not own_from <= offset + ent.start_char < own_until:
                    continue

                normalized = ent.text.strip().lower()
                if not 2 <= len(normalized) <= 100:
                    continue
                # Bare numbers are noise unless spaCy tagged them as a date.
                if ent.label_ != "DATE" and normalized.replace(" ", "").isdigit():
                    continue

                mapped_type = TYPE_MAP.get(ent.label_, ent.label_)
                key = (normalized, mapped_type)
                entity_counts[key] = entity_counts.get(key, 0) + 1
                # Several spaCy labels can share one simplified type; the
                # most recently seen label is the one reported.
                entity_raw[key] = ent.label_

        ranked = sorted(
            entity_counts.items(),
            key=lambda item: item[1],
            reverse=True,
        )
        return [
            Entity(
                text=text_,
                type=type_,
                raw_type=entity_raw[(text_, type_)],
                count=count,
            )
            for (text_, type_), count in ranked
        ]

    def extract_top(self, text: str, n: int = 20) -> List[Entity]:
        """Return only the ``n`` most frequent entities found in ``text``."""
        return self.extract(text)[:n]


# Module-level singleton shared by importers (`from graph.entity_extractor import extractor`).
# Constructing it here does not load the spaCy model; loading is deferred
# until extraction is first requested.
extractor = EntityExtractor()