Spaces:
Running
Running
| """ | |
| Cortex RAG — Knowledge Graph Builder (Phase 3) | |
| What this does | |
| ────────────── | |
| During ingestion, every chunk is processed to extract: | |
| 1. Named entities (spaCy NER: PERSON, ORG, WORK_OF_ART, PRODUCT, …) | |
| 2. Relations (few-shot LLM: subject → predicate → object triples) | |
| These are assembled into a NetworkX undirected graph where: | |
| - Nodes = entities (label + type + first-seen source) | |
| - Edges = relations (predicate label + list of source chunk_ids) | |
| Each node also carries a list of chunk_ids it appeared in, so the | |
| graph retriever can map entity → chunks without an extra lookup. | |
| The graph is persisted as a JSON file (graphs are small — a 100-doc | |
| corpus typically has <10k nodes). On reload the full graph is | |
| reconstructed in seconds from the JSON. | |
| ────────────── | |
| (Phase 3, refactored) | |
| The builder is now responsible ONLY for: | |
| - spaCy NER (entities are always extracted the same way) | |
| - Assembling triples into a NetworkX graph | |
| - Persisting / loading the graph | |
| Relation extraction is delegated to a RelationExtractor strategy: | |
| - REBELExtractor (default) — local model, no API calls | |
| - LLMExtractor — Groq, free-form predicates | |
| Switch via .env: | |
| GRAPH_EXTRACTOR=rebel # default, recommended | |
| GRAPH_EXTRACTOR=llm # original method | |
| Or pass explicitly: | |
| builder = KnowledgeGraphBuilder(extractor=LLMExtractor()) | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| from pathlib import Path | |
| from typing import Optional | |
| import networkx as nx | |
| from ingestion.chunker import Chunk | |
| from retrieval.relation_extractors import ( | |
| RelationExtractor, | |
| Triple, | |
| build_extractor, | |
| ) | |
| from config import get_settings | |
# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)

# Settings are read once at import time; the default graph file location
# is taken from them.
cfg = get_settings()
_DEFAULT_GRAPH_PATH = Path(cfg.graph_path)

# spaCy entity labels that are kept as graph nodes; entities carrying any
# other label are skipped during extraction.
_ENTITY_TYPES = set(
    (
        "PERSON",
        "ORG",
        "GPE",
        "PRODUCT",
        "WORK_OF_ART",
        "EVENT",
        "LAW",
        "NORP",
        "FAC",
        "LOC",
    )
)
class KnowledgeGraphBuilder:
    """
    Builds and maintains the knowledge graph.

    Entities are extracted with spaCy NER, relation triples with the
    configured ``RelationExtractor``; both are merged into a single
    undirected ``networkx.Graph`` that is persisted as node-link JSON.

    Usage (at ingestion time):
        # REBEL (default — no API calls)
        builder = KnowledgeGraphBuilder()
        builder.process_chunks(chunks)

        # LLM method (original)
        from retrieval.relation_extractors import LLMExtractor
        builder = KnowledgeGraphBuilder(extractor=LLMExtractor())
        builder.process_chunks(chunks)

    Usage (at query time):
        builder = KnowledgeGraphBuilder()
        G = builder.graph  # loaded from disk automatically
    """

    def __init__(
        self,
        graph_path: str | Path = _DEFAULT_GRAPH_PATH,
        extractor: Optional[RelationExtractor] = None,
    ) -> None:
        """
        Create a builder and load any graph already stored at ``graph_path``.

        Args:
            graph_path: JSON file the graph is loaded from and saved to.
            extractor: Relation-extraction strategy. When ``None``,
                ``build_extractor()`` chooses one.
        """
        self._path = Path(graph_path)
        self._graph: nx.Graph = nx.Graph()
        # If no extractor is injected, build_extractor() reads GRAPH_EXTRACTOR from .env
        # An explicit None check keeps an injected extractor even if it is falsy.
        self._extractor: RelationExtractor = (
            extractor if extractor is not None else build_extractor()
        )
        # spaCy pipeline; loaded on first use by _get_nlp().
        self._nlp = None
        self._load_if_exists()
        logger.info(
            "KnowledgeGraphBuilder ready (extractor=%s)", self._extractor.name
        )

    # ── Public API ─────────────────────────────────────────────

    @property
    def graph(self) -> nx.Graph:
        """The live in-memory graph object (not a copy)."""
        return self._graph

    @property
    def extractor_name(self) -> str:
        """The ``name`` attribute of the configured relation extractor."""
        return self._extractor.name

    def process_chunks(self, chunks: list[Chunk]) -> dict:
        """
        Extract entities and relations from chunks; update and save graph.

        Uses the configured extractor's extract_batch() for efficiency and
        falls back to one extract() call per chunk if the batch call raises.

        Args:
            chunks: Chunks to ingest. An empty list is a no-op (nothing is
                saved).

        Returns:
            Dict with keys ``chunks`` (number of input chunks), ``entities``
            and ``triples`` (totals processed across chunks, including ones
            that were already in the graph) and ``errors`` (per-chunk
            failures that were logged and skipped).
        """
        if not chunks:
            return {"chunks": 0, "entities": 0, "triples": 0, "errors": 0}

        stats = {"chunks": len(chunks), "entities": 0, "triples": 0, "errors": 0}

        # ── Batch relation extraction ──────────────────────────
        # REBEL processes all chunks in one forward pass.
        # LLM falls back to sequential (one API call per chunk).
        try:
            triple_map = self._extractor.extract_batch(chunks)
        except Exception as exc:
            logger.error("Batch extraction failed, falling back to sequential: %s", exc)
            triple_map = {}
            for chunk in chunks:
                try:
                    triple_map[chunk.chunk_id] = self._extractor.extract(chunk)
                except Exception as e:
                    # Best effort: one bad chunk must not abort the ingest.
                    logger.warning("Extraction failed for %s: %s", chunk.chunk_id, e)
                    triple_map[chunk.chunk_id] = []
                    stats["errors"] += 1

        # ── Entity extraction + graph update ───────────────────
        for chunk in chunks:
            try:
                entities = self._extract_entities(chunk.text)
                triples = triple_map.get(chunk.chunk_id, [])
                self._add_entities_to_graph(entities, chunk)
                self._add_triples_to_graph(triples)
                # Counted only once both graph updates succeeded.
                stats["entities"] += len(entities)
                stats["triples"] += len(triples)
            except Exception as exc:
                logger.warning("Graph update failed for chunk %s: %s", chunk.chunk_id, exc)
                stats["errors"] += 1

        self.save()
        logger.info(
            "Graph updated via %s: +%d entities, +%d triples (nodes=%d, edges=%d)",
            self._extractor.name,
            stats["entities"], stats["triples"],
            self._graph.number_of_nodes(), self._graph.number_of_edges(),
        )
        return stats

    def save(self) -> None:
        """
        Write the graph to ``graph_path`` as node-link JSON.

        The JSON is written to a sibling ``<name>.tmp`` file first and then
        renamed over the target, so a failure while serialising or writing
        never leaves a truncated graph file behind. A stale ``.tmp`` file
        from a failed save is simply overwritten by the next save.
        """
        self._path.parent.mkdir(parents=True, exist_ok=True)
        data = nx.node_link_data(self._graph)
        tmp_path = self._path.with_name(self._path.name + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as fh:
            json.dump(data, fh, indent=2)
        # Swap the finished file into place in a single rename.
        tmp_path.replace(self._path)
        logger.debug("Graph saved to %s", self._path)

    def stats(self) -> dict:
        """Return node/edge counts, the extractor name and the graph path."""
        return {
            "nodes": self._graph.number_of_nodes(),
            "edges": self._graph.number_of_edges(),
            "extractor": self._extractor.name,
            "graph_path": str(self._path),
        }

    # ── Entity extraction (always spaCy — same for both methods) ──

    def _extract_entities(self, text: str) -> list[tuple[str, str]]:
        """
        Run spaCy NER over ``text`` and return unique ``(name, label)`` pairs.

        Only the first 10,000 characters are analysed. Names are stripped and
        title-cased; labels outside ``_ENTITY_TYPES``, names shorter than two
        characters and repeated names are dropped. Order of first appearance
        is preserved.
        """
        nlp = self._get_nlp()
        doc = nlp(text[:10_000])
        seen: set[str] = set()
        entities: list[tuple[str, str]] = []
        for ent in doc.ents:
            if ent.label_ not in _ENTITY_TYPES:
                continue
            normalised = ent.text.strip().title()
            if normalised in seen or len(normalised) < 2:
                continue
            seen.add(normalised)
            entities.append((normalised, ent.label_))
        return entities

    # ── Graph construction (shared by both methods) ────────────

    def _add_entities_to_graph(
        self, entities: list[tuple[str, str]], chunk: Chunk
    ) -> None:
        """
        Add each entity as a node, or record ``chunk`` on the existing node.

        A new node stores its entity type, the chunk id and the chunk's
        source. For an existing node only ``chunk_ids`` is extended; its type
        and source keep their first-seen values.
        """
        for label, etype in entities:
            if self._graph.has_node(label):
                chunk_ids = self._graph.nodes[label].setdefault("chunk_ids", [])
                if chunk.chunk_id not in chunk_ids:
                    chunk_ids.append(chunk.chunk_id)
            else:
                self._graph.add_node(
                    label,
                    entity_type=etype,
                    chunk_ids=[chunk.chunk_id],
                    source=chunk.source,
                )

    def _add_triples_to_graph(self, triples: list[Triple]) -> None:
        """
        Add one edge per (subject, object) pair, merging repeated pairs.

        Endpoints missing from the graph are created with entity type
        ``"UNKNOWN"``. When the edge already exists, the predicate and chunk
        id are appended to its lists if not yet present; the edge's
        ``source`` and ``extractor`` keep their first-seen values.
        """
        for triple in triples:
            for node in (triple.subject, triple.object):
                if not self._graph.has_node(node):
                    # NOTE(review): nodes first seen via a triple start with
                    # no chunk_ids; triple.chunk_id is recorded on the edge
                    # only. Confirm the retriever does not expect it here.
                    self._graph.add_node(
                        node,
                        entity_type="UNKNOWN",
                        chunk_ids=[],
                        source=triple.source,
                        extractor=triple.extractor,
                    )

            if self._graph.has_edge(triple.subject, triple.object):
                edge = self._graph[triple.subject][triple.object]
                predicates = edge.setdefault("predicates", [])
                chunk_ids = edge.setdefault("chunk_ids", [])
                if triple.predicate not in predicates:
                    predicates.append(triple.predicate)
                if triple.chunk_id not in chunk_ids:
                    chunk_ids.append(triple.chunk_id)
            else:
                self._graph.add_edge(
                    triple.subject, triple.object,
                    predicates=[triple.predicate],
                    chunk_ids=[triple.chunk_id],
                    source=triple.source,
                    extractor=triple.extractor,
                )

    # ── Persistence ────────────────────────────────────────────

    def _load_if_exists(self) -> None:
        """
        Load the graph from ``graph_path`` if that file exists.

        Any failure while reading or rebuilding the graph is logged and the
        builder keeps its current (empty) graph instead of raising.
        """
        if not self._path.exists():
            return
        try:
            with open(self._path, encoding="utf-8") as fh:
                data = json.load(fh)
            self._graph = nx.node_link_graph(data)
            logger.info(
                "Knowledge graph loaded: %d nodes, %d edges",
                self._graph.number_of_nodes(),
                self._graph.number_of_edges(),
            )
        except Exception as exc:
            logger.warning("Failed to load graph (%s) — starting fresh.", exc)

    # ── spaCy ──────────────────────────────────────────────────

    def _get_nlp(self):
        """
        Return the spaCy pipeline, loading ``en_core_web_sm`` on first use.

        Raises:
            RuntimeError: if spaCy is not installed or the model cannot be
                loaded; the underlying error is chained as the cause.
        """
        if self._nlp is None:
            try:
                import spacy  # type: ignore
            except ImportError as exc:
                raise RuntimeError("Install spacy: pip install spacy") from exc
            try:
                self._nlp = spacy.load("en_core_web_sm")
            except OSError as exc:
                raise RuntimeError(
                    "Run: python -m spacy download en_core_web_sm"
                ) from exc
        return self._nlp