Cortex / retrieval /graph_builder.py
aditya-joshi-05's picture
update graph path in graph_builder.py
aa4e0f4
"""
Cortex RAG β€” Knowledge Graph Builder (Phase 3)
What this does
──────────────
During ingestion, every chunk is processed to extract:
1. Named entities (spaCy NER: PERSON, ORG, WORK_OF_ART, PRODUCT, …)
2. Relations (few-shot LLM: subject β†’ predicate β†’ object triples)
These are assembled into a NetworkX undirected graph where:
- Nodes = entities (label + type + first-seen source)
- Edges = relations (predicate label + list of source chunk_ids)
Each node also carries a list of chunk_ids it appeared in, so the
graph retriever can map entity β†’ chunks without an extra lookup.
The graph is persisted as a JSON file (graphs are small β€” a 100-doc
corpus typically has <10k nodes). On reload the full graph is
reconstructed in seconds from the JSON.
──────────────
(Phase 3, refactored)
The builder is now responsible ONLY for:
- spaCy NER (entities are always extracted the same way)
- Assembling triples into a NetworkX graph
- Persisting / loading the graph
Relation extraction is delegated to a RelationExtractor strategy:
- REBELExtractor (default) β€” local model, no API calls
- LLMExtractor β€” Groq, free-form predicates
Switch via .env:
GRAPH_EXTRACTOR=rebel # default, recommended
GRAPH_EXTRACTOR=llm # original method
Or pass explicitly:
builder = KnowledgeGraphBuilder(extractor=LLMExtractor())
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Optional
import networkx as nx
from ingestion.chunker import Chunk
from retrieval.relation_extractors import (
RelationExtractor,
Triple,
build_extractor,
)
from config import get_settings
logger = logging.getLogger(__name__)
cfg = get_settings()
_DEFAULT_GRAPH_PATH = Path(cfg.graph_path)
# spaCy entity types we care about for RAG
_ENTITY_TYPES = {
"PERSON", "ORG", "GPE", "PRODUCT", "WORK_OF_ART",
"EVENT", "LAW", "NORP", "FAC", "LOC",
}
class KnowledgeGraphBuilder:
"""
Builds and maintains the knowledge graph.
Usage (at ingestion time):
# REBEL (default β€” no API calls)
builder = KnowledgeGraphBuilder()
builder.process_chunks(chunks)
# LLM method (original)
from retrieval.relation_extractors import LLMExtractor
builder = KnowledgeGraphBuilder(extractor=LLMExtractor())
builder.process_chunks(chunks)
Usage (at query time):
builder = KnowledgeGraphBuilder()
G = builder.graph # loaded from disk automatically
"""
def __init__(
self,
graph_path: str | Path = _DEFAULT_GRAPH_PATH,
extractor: Optional[RelationExtractor] = None,
) -> None:
self._path = Path(graph_path)
self._graph: nx.Graph = nx.Graph()
# If no extractor is injected, build_extractor() reads GRAPH_EXTRACTOR from .env
self._extractor: RelationExtractor = extractor or build_extractor()
self._nlp = None
self._load_if_exists()
logger.info(
"KnowledgeGraphBuilder ready (extractor=%s)", self._extractor.name
)
# ── Public API ─────────────────────────────────────────────
@property
def graph(self) -> nx.Graph:
return self._graph
@property
def extractor_name(self) -> str:
return self._extractor.name
def process_chunks(self, chunks: list[Chunk]) -> dict:
"""
Extract entities and relations from chunks; update and save graph.
Uses the configured extractor's extract_batch() for efficiency.
Returns stats dict.
"""
if not chunks:
return {"chunks": 0, "entities": 0, "triples": 0, "errors": 0}
stats = {"chunks": len(chunks), "entities": 0, "triples": 0, "errors": 0}
# ── Batch relation extraction ──────────────────────────
# REBEL processes all chunks in one forward pass.
# LLM falls back to sequential (one API call per chunk).
try:
triple_map = self._extractor.extract_batch(chunks)
except Exception as exc:
logger.error("Batch extraction failed, falling back to sequential: %s", exc)
triple_map = {}
for chunk in chunks:
try:
triple_map[chunk.chunk_id] = self._extractor.extract(chunk)
except Exception as e:
logger.warning("Extraction failed for %s: %s", chunk.chunk_id, e)
triple_map[chunk.chunk_id] = []
stats["errors"] += 1
# ── Entity extraction + graph update ───────────────────
for chunk in chunks:
try:
entities = self._extract_entities(chunk.text)
triples = triple_map.get(chunk.chunk_id, [])
self._add_entities_to_graph(entities, chunk)
self._add_triples_to_graph(triples)
stats["entities"] += len(entities)
stats["triples"] += len(triples)
except Exception as exc:
logger.warning("Graph update failed for chunk %s: %s", chunk.chunk_id, exc)
stats["errors"] += 1
self.save()
logger.info(
"Graph updated via %s: +%d entities, +%d triples (nodes=%d, edges=%d)",
self._extractor.name,
stats["entities"], stats["triples"],
self._graph.number_of_nodes(), self._graph.number_of_edges(),
)
return stats
def save(self) -> None:
self._path.parent.mkdir(parents=True, exist_ok=True)
data = nx.node_link_data(self._graph)
with open(self._path, "w") as fh:
json.dump(data, fh, indent=2)
logger.debug("Graph saved to %s", self._path)
def stats(self) -> dict:
return {
"nodes": self._graph.number_of_nodes(),
"edges": self._graph.number_of_edges(),
"extractor": self._extractor.name,
"graph_path": str(self._path),
}
# ── Entity extraction (always spaCy β€” same for both methods) ─
def _extract_entities(self, text: str) -> list[tuple[str, str]]:
nlp = self._get_nlp()
doc = nlp(text[:10_000])
seen: set[str] = set()
entities: list[tuple[str, str]] = []
for ent in doc.ents:
if ent.label_ not in _ENTITY_TYPES:
continue
normalised = ent.text.strip().title()
if normalised in seen or len(normalised) < 2:
continue
seen.add(normalised)
entities.append((normalised, ent.label_))
return entities
# ── Graph construction (shared by both methods) ────────────
def _add_entities_to_graph(
self, entities: list[tuple[str, str]], chunk: Chunk
) -> None:
for label, etype in entities:
if self._graph.has_node(label):
existing = self._graph.nodes[label].get("chunk_ids", [])
if chunk.chunk_id not in existing:
existing.append(chunk.chunk_id)
self._graph.nodes[label]["chunk_ids"] = existing
else:
self._graph.add_node(
label,
entity_type=etype,
chunk_ids=[chunk.chunk_id],
source=chunk.source,
)
def _add_triples_to_graph(self, triples: list[Triple]) -> None:
for triple in triples:
for node in (triple.subject, triple.object):
if not self._graph.has_node(node):
self._graph.add_node(
node,
entity_type="UNKNOWN",
chunk_ids=[],
source=triple.source,
extractor=triple.extractor,
)
if self._graph.has_edge(triple.subject, triple.object):
edge = self._graph[triple.subject][triple.object]
predicates = edge.get("predicates", [])
chunk_ids = edge.get("chunk_ids", [])
if triple.predicate not in predicates:
predicates.append(triple.predicate)
if triple.chunk_id not in chunk_ids:
chunk_ids.append(triple.chunk_id)
edge["predicates"] = predicates
edge["chunk_ids"] = chunk_ids
else:
self._graph.add_edge(
triple.subject, triple.object,
predicates=[triple.predicate],
chunk_ids=[triple.chunk_id],
source=triple.source,
extractor=triple.extractor,
)
# ── Persistence ───────────────────────────────────────────
def _load_if_exists(self) -> None:
if not self._path.exists():
return
try:
with open(self._path) as fh:
data = json.load(fh)
self._graph = nx.node_link_graph(data)
logger.info(
"Knowledge graph loaded: %d nodes, %d edges",
self._graph.number_of_nodes(),
self._graph.number_of_edges(),
)
except Exception as exc:
logger.warning("Failed to load graph (%s) β€” starting fresh.", exc)
# ── spaCy ─────────────────────────────────────────────────
def _get_nlp(self):
if self._nlp is None:
try:
import spacy # type: ignore
except ImportError as exc:
raise RuntimeError("Install spacy: pip install spacy") from exc
try:
self._nlp = spacy.load("en_core_web_sm")
except OSError:
raise RuntimeError(
"Run: python -m spacy download en_core_web_sm"
)
return self._nlp