""" Knowledge Graph data structure for the Graph Reasoning Network. Each node represents a fact, each edge represents a relationship between facts. The graph supports dynamic node/edge creation, navigation, and DAG subgraph extraction. """ import torch import torch.nn.functional as F from torch_geometric.data import Data from torch_geometric.utils import subgraph, k_hop_subgraph, to_undirected from typing import Optional, List, Tuple, Dict import json class KnowledgeGraph: """ A dynamic knowledge graph that can grow over time. Nodes: facts (with text descriptions and learned embeddings) Edges: relationships between facts (typed, directed) Supports: - Adding new nodes and edges - Navigating to find relevant subgraphs - Extracting DAG subgraphs as traceable reasoning chains - Serialization for persistence """ def __init__(self, node_dim: int = 256, edge_dim: int = 64): self.node_dim = node_dim self.edge_dim = edge_dim self.node_features = torch.zeros(0, node_dim) self.node_texts: List[str] = [] self.node_types: List[str] = [] self.edge_index = torch.zeros(2, 0, dtype=torch.long) self.edge_features = torch.zeros(0, edge_dim) self.edge_texts: List[str] = [] self.relation_types: List[str] = [] self.entity_type_to_id: Dict[str, int] = {} self.relation_type_to_id: Dict[str, int] = {} self.start_node_idx: Optional[int] = None self.end_node_idx: Optional[int] = None @property def num_nodes(self) -> int: return self.node_features.shape[0] @property def num_edges(self) -> int: return self.edge_index.shape[1] @property def num_entity_types(self) -> int: return len(self.entity_type_to_id) @property def num_relation_types(self) -> int: return len(self.relation_type_to_id) def add_node(self, features: torch.Tensor, text: str = "", entity_type: str = "default") -> int: """Add a new fact node to the graph. Returns the node index.""" assert features.shape[-1] == self.node_dim if features.dim() == 1: features = features.unsqueeze(0) self.node_features = torch.cat([self.node_features, features], dim=0) self.node_texts.append(text) self.node_types.append(entity_type) if entity_type not in self.entity_type_to_id: self.entity_type_to_id[entity_type] = len(self.entity_type_to_id) return self.num_nodes - 1 def add_edge(self, src: int, dst: int, features: Optional[torch.Tensor] = None, text: str = "", relation_type: str = "default") -> int: """Add a directed edge (relationship) between two fact nodes.""" assert 0 <= src < self.num_nodes and 0 <= dst < self.num_nodes new_edge = torch.tensor([[src], [dst]], dtype=torch.long) self.edge_index = torch.cat([self.edge_index, new_edge], dim=1) if features is None: features = torch.zeros(1, self.edge_dim) elif features.dim() == 1: features = features.unsqueeze(0) self.edge_features = torch.cat([self.edge_features, features], dim=0) self.edge_texts.append(text) self.relation_types.append(relation_type) if relation_type not in self.relation_type_to_id: self.relation_type_to_id[relation_type] = len(self.relation_type_to_id) return self.num_edges - 1 def to_pyg_data(self) -> Data: """Convert to PyG Data object for GNN processing.""" rel_type_ids = torch.tensor( [self.relation_type_to_id.get(r, 0) for r in self.relation_types], dtype=torch.long ) if self.relation_types else torch.zeros(0, dtype=torch.long) entity_type_ids = torch.tensor( [self.entity_type_to_id.get(t, 0) for t in self.node_types], dtype=torch.long ) if self.node_types else torch.zeros(0, dtype=torch.long) return Data(x=self.node_features, edge_index=self.edge_index, edge_attr=self.edge_features, edge_type=rel_type_ids, node_type=entity_type_ids, num_nodes=self.num_nodes) def extract_subgraph(self, node_indices: torch.Tensor) -> 'KnowledgeGraph': """Extract a subgraph containing only the specified nodes.""" mask = torch.zeros(self.num_nodes, dtype=torch.bool) mask[node_indices] = True new_edge_index, new_edge_attr, edge_mask = subgraph( mask, self.edge_index, self.edge_features, relabel_nodes=True, return_edge_mask=True) sub_kg = KnowledgeGraph(self.node_dim, self.edge_dim) sub_kg.node_features = self.node_features[mask] sub_kg.node_texts = [self.node_texts[i] for i in node_indices.tolist()] sub_kg.node_types = [self.node_types[i] for i in node_indices.tolist()] sub_kg.edge_index = new_edge_index sub_kg.edge_features = new_edge_attr sub_kg.edge_texts = [self.edge_texts[i] for i, m in enumerate(edge_mask.tolist()) if m] sub_kg.relation_types = [self.relation_types[i] for i, m in enumerate(edge_mask.tolist()) if m] sub_kg.entity_type_to_id = dict(self.entity_type_to_id) sub_kg.relation_type_to_id = dict(self.relation_type_to_id) return sub_kg def extract_dag_subgraph(self, node_scores: torch.Tensor, threshold: float = 0.5) -> 'KnowledgeGraph': """Extract a DAG subgraph. Removes cycles by keeping only forward edges.""" mask = node_scores > threshold selected_indices = mask.nonzero(as_tuple=True)[0] if selected_indices.numel() == 0: return KnowledgeGraph(self.node_dim, self.edge_dim) new_edge_index, new_edge_attr, edge_mask = subgraph( mask, self.edge_index, self.edge_features, relabel_nodes=True, return_edge_mask=True) if new_edge_index.shape[1] > 0: orig_src = self.edge_index[0][edge_mask] orig_dst = self.edge_index[1][edge_mask] dag_mask = node_scores[orig_src] <= node_scores[orig_dst] new_edge_index = new_edge_index[:, dag_mask] new_edge_attr = new_edge_attr[dag_mask] if new_edge_attr.numel() > 0 else new_edge_attr sub_kg = KnowledgeGraph(self.node_dim, self.edge_dim) sub_kg.node_features = self.node_features[mask] sub_kg.node_texts = [self.node_texts[i] for i in selected_indices.tolist()] sub_kg.node_types = [self.node_types[i] for i in selected_indices.tolist()] sub_kg.edge_index = new_edge_index sub_kg.edge_features = new_edge_attr edge_text_indices = [i for i, m in enumerate(edge_mask.tolist()) if m] if new_edge_index.shape[1] > 0: dag_indices = dag_mask.nonzero(as_tuple=True)[0].tolist() sub_kg.edge_texts = [self.edge_texts[edge_text_indices[j]] for j in dag_indices] sub_kg.relation_types = [self.relation_types[edge_text_indices[j]] for j in dag_indices] else: sub_kg.edge_texts = [] sub_kg.relation_types = [] sub_kg.entity_type_to_id = dict(self.entity_type_to_id) sub_kg.relation_type_to_id = dict(self.relation_type_to_id) return sub_kg def get_neighbors(self, node_idx: int, direction: str = "out") -> torch.Tensor: if direction == "out": mask = self.edge_index[0] == node_idx return self.edge_index[1][mask] elif direction == "in": mask = self.edge_index[1] == node_idx return self.edge_index[0][mask] else: out_mask = self.edge_index[0] == node_idx in_mask = self.edge_index[1] == node_idx return torch.cat([self.edge_index[1][out_mask], self.edge_index[0][in_mask]]).unique() def save(self, path: str): torch.save({ 'node_features': self.node_features, 'node_texts': self.node_texts, 'node_types': self.node_types, 'edge_index': self.edge_index, 'edge_features': self.edge_features, 'edge_texts': self.edge_texts, 'relation_types': self.relation_types, 'entity_type_to_id': self.entity_type_to_id, 'relation_type_to_id': self.relation_type_to_id, 'node_dim': self.node_dim, 'edge_dim': self.edge_dim, 'start_node_idx': self.start_node_idx, 'end_node_idx': self.end_node_idx, }, path) @classmethod def load(cls, path: str) -> 'KnowledgeGraph': state = torch.load(path, weights_only=False) kg = cls(state['node_dim'], state['edge_dim']) kg.node_features = state['node_features'] kg.node_texts = state['node_texts'] kg.node_types = state['node_types'] kg.edge_index = state['edge_index'] kg.edge_features = state['edge_features'] kg.edge_texts = state['edge_texts'] kg.relation_types = state['relation_types'] kg.entity_type_to_id = state['entity_type_to_id'] kg.relation_type_to_id = state['relation_type_to_id'] kg.start_node_idx = state.get('start_node_idx') kg.end_node_idx = state.get('end_node_idx') return kg def __repr__(self): return (f"KnowledgeGraph(nodes={self.num_nodes}, edges={self.num_edges}, " f"entity_types={self.num_entity_types}, relation_types={self.num_relation_types})")