"""
Cortex RAG - Knowledge Graph Builder (Phase 3)

What this does
──────────────
During ingestion, every chunk is processed to extract:
  1. Named entities  (spaCy NER: PERSON, ORG, WORK_OF_ART, PRODUCT, …)
  2. Relations       (subject → predicate → object triples, from the configured
                      RelationExtractor: REBEL by default, or a few-shot LLM)
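
A minimal sketch of the entity step that runs without a spaCy model. The
hard-coded (text, label) pairs stand in for `doc.ents`, and `filter_entities`
is a hypothetical helper that mirrors the filtering in `_extract_entities`
(allowed types, `.title()` normalisation, per-chunk de-duplication).

```python
ENTITY_TYPES = {
    "PERSON", "ORG", "GPE", "PRODUCT", "WORK_OF_ART",
    "EVENT", "LAW", "NORP", "FAC", "LOC",
}


def filter_entities(raw: list[tuple[str, str]]) -> list[tuple[str, str]]:
    # raw: (text, label) pairs, standing in for (ent.text, ent.label_).
    seen: set[str] = set()
    out: list[tuple[str, str]] = []
    for text, label in raw:
        if label not in ENTITY_TYPES:
            continue  # e.g. DATE, CARDINAL, MONEY are not graph nodes
        normalised = text.strip().title()
        if normalised in seen or len(normalised) < 2:
            continue  # de-duplicate within a chunk, drop 1-char noise
        seen.add(normalised)
        out.append((normalised, label))
    return out


raw = [
    ("ada lovelace", "PERSON"),
    ("1843", "DATE"),
    ("Ada Lovelace", "PERSON"),
    ("Analytical Engine", "PRODUCT"),
]
print(filter_entities(raw))
# [('Ada Lovelace', 'PERSON'), ('Analytical Engine', 'PRODUCT')]
```

Note that `str.title()` also rewrites acronyms ("IBM" becomes "Ibm"), which
matters whenever triple subjects/objects are matched against these labels.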

These are assembled into an undirected NetworkX graph (nx.Graph), so the
subject/object order of a triple is not preserved on its edge:
  - Nodes  = entities (label + type + first-seen source)
  - Edges  = relations (list of predicate labels + list of source chunk_ids)

Each node also carries a list of the chunk_ids it appeared in, so the
graph retriever can map entity → chunks without an extra lookup.
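
A minimal sketch of that node/edge layout using plain NetworkX.
`add_entity`, `add_relation` and `chunks_for_entity` are hypothetical helpers
that mirror `_add_entities_to_graph` / `_add_triples_to_graph`; the
neighbour expansion in `chunks_for_entity` is an assumption about how a graph
retriever might use the chunk_ids, not code from this repo.

```python
import networkx as nx


def add_entity(G: nx.Graph, label: str, etype: str, chunk_id: str, source: str) -> None:
    # First sighting creates the node; later sightings extend chunk_ids.
    if G.has_node(label):
        ids = G.nodes[label].setdefault("chunk_ids", [])
        if chunk_id not in ids:
            ids.append(chunk_id)
    else:
        G.add_node(label, entity_type=etype, chunk_ids=[chunk_id], source=source)


def add_relation(G: nx.Graph, subj: str, pred: str, obj: str, chunk_id: str) -> None:
    # One undirected edge per entity pair; new predicates/chunk_ids are merged in.
    if G.has_edge(subj, obj):
        edge = G[subj][obj]
        if pred not in edge["predicates"]:
            edge["predicates"].append(pred)
        if chunk_id not in edge["chunk_ids"]:
            edge["chunk_ids"].append(chunk_id)
    else:
        G.add_edge(subj, obj, predicates=[pred], chunk_ids=[chunk_id])


def chunks_for_entity(G: nx.Graph, label: str, hops: int = 1) -> set[str]:
    # Entity -> chunks read straight off node attributes, widened to every
    # node within `hops` edges of the entity.
    if label not in G:
        return set()
    reachable = nx.single_source_shortest_path_length(G, label, cutoff=hops)
    return {cid for node in reachable for cid in G.nodes[node].get("chunk_ids", [])}


G = nx.Graph()
add_entity(G, "Ada Lovelace", "PERSON", "c1", "notes.md")
add_entity(G, "Ada Lovelace", "PERSON", "c2", "notes.md")
add_entity(G, "Charles Babbage", "PERSON", "c3", "notes.md")
add_entity(G, "Analytical Engine", "PRODUCT", "c4", "notes.md")
add_relation(G, "Ada Lovelace", "collaborated with", "Charles Babbage", "c2")
add_relation(G, "Charles Babbage", "worked with", "Ada Lovelace", "c3")  # same edge
add_relation(G, "Charles Babbage", "designed", "Analytical Engine", "c4")

print(G["Ada Lovelace"]["Charles Babbage"]["predicates"])
# ['collaborated with', 'worked with']
print(sorted(chunks_for_entity(G, "Ada Lovelace", hops=0)))  # ['c1', 'c2']
print(sorted(chunks_for_entity(G, "Ada Lovelace", hops=1)))  # ['c1', 'c2', 'c3']
```

Because the graph is undirected, the second relation lands on the same edge
as the first, and only its predicate and chunk_id are appended.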

The graph is persisted as a JSON file in NetworkX node-link format
(graphs are small: a 100-doc corpus typically has <10k nodes). On reload
the full graph is reconstructed from the JSON in seconds.
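
A minimal round-trip sketch of that persistence format, assuming only
NetworkX and the standard `json` module. The node and edge attributes are
illustrative values shaped like the ones this builder stores.

```python
import json

import networkx as nx

G = nx.Graph()
G.add_node("Ada Lovelace", entity_type="PERSON", chunk_ids=["c1", "c2"], source="notes.md")
G.add_node("Charles Babbage", entity_type="PERSON", chunk_ids=["c3"], source="notes.md")
G.add_edge("Ada Lovelace", "Charles Babbage",
           predicates=["collaborated with"], chunk_ids=["c2"])

# Same serialisation as save(): node-link dict -> JSON text.
payload = json.dumps(nx.node_link_data(G), indent=2)

# Same reload as _load_if_exists(): the "directed"/"multigraph" flags are
# stored in the payload, so an undirected nx.Graph comes back.
H = nx.node_link_graph(json.loads(payload))

print(type(H).__name__, H.number_of_nodes(), H.number_of_edges())  # Graph 2 1
print(H.nodes["Ada Lovelace"]["chunk_ids"])                         # ['c1', 'c2']
print(H["Charles Babbage"]["Ada Lovelace"]["predicates"])           # ['collaborated with']
```

Depending on the NetworkX version, these two calls may emit a FutureWarning
about the default key used for the edge list; the round trip itself works as
long as the same default is used to write and to read the file.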

Phase 3 refactor
────────────────

The builder is now responsible ONLY for:
  - spaCy NER (entities are always extracted the same way)
  - Assembling triples into a NetworkX graph
  - Persisting / loading the graph

Relation extraction is delegated to a RelationExtractor strategy:
  - REBELExtractor  (default) - local model, no API calls
  - LLMExtractor              - Groq, free-form predicates
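
A minimal sketch of the strategy interface, assuming only the three members
this builder actually touches (`name`, `extract`, `extract_batch`). `Chunk`
and `Triple` here are hypothetical stand-ins whose fields are inferred from
the attributes the builder reads, and `ToyIsExtractor` is a deliberately
trivial rule-based strategy used in place of REBEL or an LLM so the example
runs offline.

```python
from dataclasses import dataclass
from typing import Protocol


@dataclass(frozen=True)
class Chunk:  # hypothetical stand-in for ingestion.chunker.Chunk
    chunk_id: str
    text: str
    source: str


@dataclass(frozen=True)
class Triple:  # hypothetical stand-in; fields are the ones the builder reads
    subject: str
    predicate: str
    object: str
    chunk_id: str
    source: str
    extractor: str


class RelationExtractor(Protocol):
    name: str

    def extract(self, chunk: Chunk) -> list[Triple]: ...

    def extract_batch(self, chunks: list[Chunk]) -> dict[str, list[Triple]]: ...


class SequentialBatchMixin:
    # Fallback for strategies without real batching (e.g. one API call per
    # chunk): extract_batch() simply loops over extract().
    def extract_batch(self, chunks: list[Chunk]) -> dict[str, list[Triple]]:
        return {c.chunk_id: self.extract(c) for c in chunks}


class ToyIsExtractor(SequentialBatchMixin):
    # Toy rule-based strategy: every "X is Y" sentence becomes (X, "is", Y).
    name = "toy"

    def extract(self, chunk: Chunk) -> list[Triple]:
        triples = []
        for sentence in chunk.text.split("."):
            left, sep, right = sentence.partition(" is ")
            if sep and left.strip() and right.strip():
                triples.append(Triple(left.strip(), "is", right.strip(),
                                      chunk.chunk_id, chunk.source, self.name))
        return triples


extractor: RelationExtractor = ToyIsExtractor()
chunks = [
    Chunk("c1", "Paris is a city. Rome is old.", "a.md"),
    Chunk("c2", "No relation here", "b.md"),
]
result = extractor.extract_batch(chunks)
print({cid: [(t.subject, t.predicate, t.object) for t in ts]
       for cid, ts in result.items()})
# {'c1': [('Paris', 'is', 'a city'), ('Rome', 'is', 'old')], 'c2': []}
```

Since the builder only calls these members, any object with this shape could
be passed as `extractor=`; a strategy with genuine batching (as REBEL is
described to have) would override `extract_batch()` instead of inheriting
the per-chunk loop.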

Switch via .env:
  GRAPH_EXTRACTOR=rebel    # default, recommended
  GRAPH_EXTRACTOR=llm      # original method

Or pass explicitly:
  builder = KnowledgeGraphBuilder(extractor=LLMExtractor())
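
A minimal sketch of how an env-driven factory like `build_extractor()` could
resolve the strategy name. The registry, the stub classes and
`build_extractor_sketch` are hypothetical; the real function may read the
value through `config.get_settings()` rather than `os.environ`, and the
case-insensitive match and `ValueError` on unknown names are choices of this
sketch, not documented behaviour.

```python
import os
from typing import Mapping, Optional


class RebelStub:  # hypothetical placeholder for REBELExtractor
    name = "rebel"


class LLMStub:  # hypothetical placeholder for LLMExtractor
    name = "llm"


_REGISTRY = {"rebel": RebelStub, "llm": LLMStub}


def build_extractor_sketch(env: Optional[Mapping[str, str]] = None):
    # Reads GRAPH_EXTRACTOR from the given mapping, defaulting to os.environ
    # (assumes the .env file has already been loaded into the environment).
    env = os.environ if env is None else env
    key = env.get("GRAPH_EXTRACTOR", "rebel").strip().lower()
    try:
        return _REGISTRY[key]()
    except KeyError:
        raise ValueError(
            f"Unknown GRAPH_EXTRACTOR={key!r}; expected one of {sorted(_REGISTRY)}"
        ) from None


print(build_extractor_sketch({}).name)                            # rebel
print(build_extractor_sketch({"GRAPH_EXTRACTOR": " LLM "}).name)  # llm
```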

"""
from __future__ import annotations

import json
import logging

from pathlib import Path
from typing import Optional

import networkx as nx

from ingestion.chunker import Chunk
from retrieval.relation_extractors import (
    RelationExtractor,
    Triple,
    build_extractor,
)

from config import get_settings

logger = logging.getLogger(__name__)

cfg = get_settings()
_DEFAULT_GRAPH_PATH = Path(cfg.graph_path)

# spaCy entity types we care about for RAG
_ENTITY_TYPES = {
    "PERSON", "ORG", "GPE", "PRODUCT", "WORK_OF_ART",
    "EVENT", "LAW", "NORP", "FAC", "LOC",
}

class KnowledgeGraphBuilder:
    """
    Builds and maintains the knowledge graph.

    Usage (at ingestion time):
        # REBEL (default; no API calls)
        builder = KnowledgeGraphBuilder()
        builder.process_chunks(chunks)

        # LLM method (original)
        from retrieval.relation_extractors import LLMExtractor
        builder = KnowledgeGraphBuilder(extractor=LLMExtractor())
        builder.process_chunks(chunks)

    Usage (at query time):
        builder = KnowledgeGraphBuilder()
        G = builder.graph    # loaded from disk automatically
    """

    def __init__(
        self,
        graph_path: str | Path = _DEFAULT_GRAPH_PATH,
        extractor: Optional[RelationExtractor] = None,
    ) -> None:
        self._path = Path(graph_path)
        self._graph: nx.Graph = nx.Graph()
        # If no extractor is injected, build_extractor() reads GRAPH_EXTRACTOR from .env
        self._extractor: RelationExtractor = extractor or build_extractor()
        self._nlp = None
        self._load_if_exists()
        logger.info(
            "KnowledgeGraphBuilder ready (extractor=%s)", self._extractor.name
        )

    # ── Public API ─────────────────────────────────────────────

    @property
    def graph(self) -> nx.Graph:
        return self._graph

    @property
    def extractor_name(self) -> str:
        return self._extractor.name

    def process_chunks(self, chunks: list[Chunk]) -> dict:
        """
        Extract entities and relations from chunks; update and save graph.
        Uses the configured extractor's extract_batch() for efficiency.
        Returns stats dict.
        """
        if not chunks:
            return {"chunks": 0, "entities": 0, "triples": 0, "errors": 0}

        stats = {"chunks": len(chunks), "entities": 0, "triples": 0, "errors": 0}

        # ── Batch relation extraction ──────────────────────────
        # REBEL processes all chunks in one forward pass.
        # LLM falls back to sequential (one API call per chunk).
        try:
            triple_map = self._extractor.extract_batch(chunks)
        except Exception as exc:
            logger.error("Batch extraction failed, falling back to sequential: %s", exc)
            triple_map = {}
            for chunk in chunks:
                try:
                    triple_map[chunk.chunk_id] = self._extractor.extract(chunk)
                except Exception as e:
                    logger.warning("Extraction failed for %s: %s", chunk.chunk_id, e)
                    triple_map[chunk.chunk_id] = []
                    stats["errors"] += 1

        # ── Entity extraction + graph update ───────────────────
        for chunk in chunks:
            try:
                entities = self._extract_entities(chunk.text)
                triples  = triple_map.get(chunk.chunk_id, [])

                self._add_entities_to_graph(entities, chunk)
                self._add_triples_to_graph(triples)

                stats["entities"] += len(entities)
                stats["triples"]  += len(triples)

            except Exception as exc:
                logger.warning("Graph update failed for chunk %s: %s", chunk.chunk_id, exc)
                stats["errors"] += 1

        self.save()
        logger.info(
            "Graph updated via %s: +%d entities, +%d triples (nodes=%d, edges=%d)",
            self._extractor.name,
            stats["entities"], stats["triples"],
            self._graph.number_of_nodes(), self._graph.number_of_edges(),
        )
        return stats

    def save(self) -> None:
        self._path.parent.mkdir(parents=True, exist_ok=True)
        data = nx.node_link_data(self._graph)
        with open(self._path, "w") as fh:
            json.dump(data, fh, indent=2)
        logger.debug("Graph saved to %s", self._path)

    def stats(self) -> dict:
        return {
            "nodes":      self._graph.number_of_nodes(),
            "edges":      self._graph.number_of_edges(),
            "extractor":  self._extractor.name,
            "graph_path": str(self._path),
        }

    # ── Entity extraction (always spaCy; same for both methods) ──

    def _extract_entities(self, text: str) -> list[tuple[str, str]]:
        nlp = self._get_nlp()
        doc = nlp(text[:10_000])

        seen: set[str] = set()
        entities: list[tuple[str, str]] = []
        for ent in doc.ents:
            if ent.label_ not in _ENTITY_TYPES:
                continue
            normalised = ent.text.strip().title()
            if normalised in seen or len(normalised) < 2:
                continue
            seen.add(normalised)
            entities.append((normalised, ent.label_))
        return entities

    # ── Graph construction (shared by both methods) ────────────

    def _add_entities_to_graph(
        self, entities: list[tuple[str, str]], chunk: Chunk
    ) -> None:
        for label, etype in entities:
            if self._graph.has_node(label):
                existing = self._graph.nodes[label].get("chunk_ids", [])
                if chunk.chunk_id not in existing:
                    existing.append(chunk.chunk_id)
                self._graph.nodes[label]["chunk_ids"] = existing
            else:
                self._graph.add_node(
                    label,
                    entity_type=etype,
                    chunk_ids=[chunk.chunk_id],
                    source=chunk.source,
                )

    def _add_triples_to_graph(self, triples: list[Triple]) -> None:
        for triple in triples:
            for node in (triple.subject, triple.object):
                if not self._graph.has_node(node):
                    self._graph.add_node(
                        node,
                        entity_type="UNKNOWN",
                        chunk_ids=[triple.chunk_id],  # chunk the triple came from
                        source=triple.source,
                        extractor=triple.extractor,
                    )

            if self._graph.has_edge(triple.subject, triple.object):
                edge = self._graph[triple.subject][triple.object]
                predicates = edge.get("predicates", [])
                chunk_ids  = edge.get("chunk_ids", [])
                if triple.predicate not in predicates:
                    predicates.append(triple.predicate)
                if triple.chunk_id not in chunk_ids:
                    chunk_ids.append(triple.chunk_id)
                edge["predicates"] = predicates
                edge["chunk_ids"]  = chunk_ids
            else:
                self._graph.add_edge(
                    triple.subject, triple.object,
                    predicates=[triple.predicate],
                    chunk_ids=[triple.chunk_id],
                    source=triple.source,
                    extractor=triple.extractor,
                )

    # ── Persistence ───────────────────────────────────────────

    def _load_if_exists(self) -> None:
        if not self._path.exists():
            return
        try:
            with open(self._path) as fh:
                data = json.load(fh)
            self._graph = nx.node_link_graph(data)
            logger.info(
                "Knowledge graph loaded: %d nodes, %d edges",
                self._graph.number_of_nodes(),
                self._graph.number_of_edges(),
            )
        except Exception as exc:
            logger.warning("Failed to load graph (%s) - starting fresh.", exc)

    # ── spaCy ─────────────────────────────────────────────────

    def _get_nlp(self):
        if self._nlp is None:
            try:
                import spacy  # type: ignore
            except ImportError as exc:
                raise RuntimeError("Install spacy: pip install spacy") from exc
            try:
                self._nlp = spacy.load("en_core_web_sm")
            except OSError as exc:
                raise RuntimeError(
                    "spaCy model 'en_core_web_sm' not found. "
                    "Run: python -m spacy download en_core_web_sm"
                ) from exc
        return self._nlp