# Source: graph-reasoning-network/grn/knowledge_graph.py
# Uploaded by Adam-Ben-Khalifa (commit beb1305, verified).
"""
Knowledge Graph data structure for the Graph Reasoning Network.
Each node represents a fact, each edge represents a relationship between facts.
The graph supports dynamic node/edge creation, navigation, and DAG subgraph extraction.
"""
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.utils import subgraph, k_hop_subgraph, to_undirected
from typing import Optional, List, Tuple, Dict
import json
class KnowledgeGraph:
    """
    A dynamic knowledge graph that can grow over time.

    Nodes are facts: a feature row of width ``node_dim``, a text description
    and an entity-type label. Edges are typed, directed relationships between
    facts: a column of ``edge_index`` (row 0 = source, row 1 = destination), a
    feature row of width ``edge_dim``, a text description and a relation-type
    label.

    Invariants maintained by every method:
    - ``node_texts`` / ``node_types`` are aligned with the rows of
      ``node_features``.
    - ``edge_texts`` / ``relation_types`` are aligned with the columns of
      ``edge_index`` and the rows of ``edge_features``.

    Supports:
    - Adding new nodes and edges
    - Navigating to find relevant subgraphs
    - Extracting DAG subgraphs as traceable reasoning chains
    - Serialization for persistence
    """

    def __init__(self, node_dim: int = 256, edge_dim: int = 64):
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        # Node storage: one feature row / text / type label per node.
        self.node_features = torch.zeros(0, node_dim)
        self.node_texts: List[str] = []
        self.node_types: List[str] = []
        # Edge storage: edge_index has shape (2, num_edges); row 0 holds
        # sources and row 1 destinations.
        self.edge_index = torch.zeros(2, 0, dtype=torch.long)
        self.edge_features = torch.zeros(0, edge_dim)
        self.edge_texts: List[str] = []
        self.relation_types: List[str] = []
        # Label -> integer id vocabularies, ids assigned in first-seen order.
        self.entity_type_to_id: Dict[str, int] = {}
        self.relation_type_to_id: Dict[str, int] = {}
        # Optional node indices (presumably the start / end of a reasoning
        # chain). They are set by callers; this class stores, saves and
        # remaps them.
        self.start_node_idx: Optional[int] = None
        self.end_node_idx: Optional[int] = None

    @property
    def num_nodes(self) -> int:
        """Number of nodes (rows of ``node_features``)."""
        return self.node_features.shape[0]

    @property
    def num_edges(self) -> int:
        """Number of directed edges (columns of ``edge_index``)."""
        return self.edge_index.shape[1]

    @property
    def num_entity_types(self) -> int:
        """Number of distinct entity-type labels seen so far."""
        return len(self.entity_type_to_id)

    @property
    def num_relation_types(self) -> int:
        """Number of distinct relation-type labels seen so far."""
        return len(self.relation_type_to_id)

    def add_node(self, features: torch.Tensor, text: str = "",
                 entity_type: str = "default") -> int:
        """Add a new fact node to the graph and return its index.

        Args:
            features: One feature vector, shape ``(node_dim,)`` or
                ``(1, node_dim)``.
            text: Human-readable description of the fact.
            entity_type: Type label; an unseen label gets the next free id.

        Raises:
            ValueError: If ``features`` is not exactly one row of width
                ``node_dim``. A multi-row tensor would add several feature
                rows but only one text/type and break alignment.
        """
        if features.dim() == 1:
            features = features.unsqueeze(0)
        if tuple(features.shape) != (1, self.node_dim):
            raise ValueError(
                f"node features must have shape ({self.node_dim},) or "
                f"(1, {self.node_dim}), got {tuple(features.shape)}")
        self.node_features = torch.cat([self.node_features, features], dim=0)
        self.node_texts.append(text)
        self.node_types.append(entity_type)
        if entity_type not in self.entity_type_to_id:
            self.entity_type_to_id[entity_type] = len(self.entity_type_to_id)
        return self.num_nodes - 1

    def add_edge(self, src: int, dst: int, features: Optional[torch.Tensor] = None,
                 text: str = "", relation_type: str = "default") -> int:
        """Add a directed edge ``src -> dst`` and return its index.

        Args:
            src: Index of an existing source node.
            dst: Index of an existing destination node.
            features: One feature vector of shape ``(edge_dim,)`` or
                ``(1, edge_dim)``; zeros when omitted.
            text: Human-readable description of the relationship.
            relation_type: Relation label; an unseen label gets the next id.

        Raises:
            IndexError: If ``src`` or ``dst`` is not an existing node index.
            ValueError: If ``features`` is not one row of width ``edge_dim``.
        """
        if not (0 <= src < self.num_nodes and 0 <= dst < self.num_nodes):
            raise IndexError(
                f"edge ({src}, {dst}) references a node outside "
                f"[0, {self.num_nodes})")
        if features is None:
            features = torch.zeros(1, self.edge_dim)
        elif features.dim() == 1:
            features = features.unsqueeze(0)
        if tuple(features.shape) != (1, self.edge_dim):
            raise ValueError(
                f"edge features must have shape ({self.edge_dim},) or "
                f"(1, {self.edge_dim}), got {tuple(features.shape)}")
        # Mutate only after all validation has passed, so a bad call cannot
        # leave edge_index and edge_features with different lengths.
        new_edge = torch.tensor([[src], [dst]], dtype=torch.long)
        self.edge_index = torch.cat([self.edge_index, new_edge], dim=1)
        self.edge_features = torch.cat([self.edge_features, features], dim=0)
        self.edge_texts.append(text)
        self.relation_types.append(relation_type)
        if relation_type not in self.relation_type_to_id:
            self.relation_type_to_id[relation_type] = len(self.relation_type_to_id)
        return self.num_edges - 1

    def to_pyg_data(self) -> 'Data':
        """Convert to a PyG ``Data`` object for GNN processing.

        ``edge_type`` / ``node_type`` carry the integer ids from the relation /
        entity vocabularies; labels missing from a vocabulary map to 0.
        """
        # torch.tensor([]) with an explicit dtype already yields shape (0,),
        # so empty graphs need no special case.
        edge_type = torch.tensor(
            [self.relation_type_to_id.get(r, 0) for r in self.relation_types],
            dtype=torch.long)
        node_type = torch.tensor(
            [self.entity_type_to_id.get(t, 0) for t in self.node_types],
            dtype=torch.long)
        return Data(x=self.node_features, edge_index=self.edge_index,
                    edge_attr=self.edge_features, edge_type=edge_type,
                    node_type=node_type, num_nodes=self.num_nodes)

    def extract_subgraph(self, node_indices: torch.Tensor) -> 'KnowledgeGraph':
        """Extract the subgraph induced by the given nodes.

        Args:
            node_indices: Node indices to keep, in any order (duplicates are
                ignored), or a boolean mask with one entry per node. Python
                lists are accepted as well.

        Returns:
            A new graph whose nodes are the selected ones in ascending
            original-index order. It contains every edge whose two endpoints
            are selected, in original edge order. Vocabularies are copied, and
            start/end markers are remapped when their node is kept.
        """
        idx = torch.as_tensor(node_indices)
        if idx.dtype == torch.bool:
            node_mask = idx
        else:
            node_mask = torch.zeros(self.num_nodes, dtype=torch.bool,
                                    device=self.edge_index.device)
            # .long() also covers an empty list, which as_tensor makes float.
            node_mask[idx.reshape(-1).long().to(node_mask.device)] = True
        return self._materialize(node_mask, None)

    def extract_dag_subgraph(self, node_scores: torch.Tensor,
                             threshold: float = 0.5) -> 'KnowledgeGraph':
        """Extract a DAG over the nodes scoring above ``threshold``.

        Nodes with ``score > threshold`` are kept. Among edges between kept
        nodes, only "forward" edges survive: from a lower to a higher score,
        with equal scores broken by original node index (lower to higher).
        That is a strict total order on nodes, so the result is guaranteed
        acyclic. Self-loops are dropped.

        Args:
            node_scores: One score per node; any shape with ``num_nodes``
                elements (e.g. ``(N,)`` or ``(N, 1)``).
            threshold: Nodes scoring strictly above this value are kept.

        Raises:
            ValueError: If ``node_scores`` does not hold one score per node.
        """
        scores = node_scores.detach().reshape(-1).to(self.edge_index.device)
        if scores.numel() != self.num_nodes:
            raise ValueError(
                f"expected {self.num_nodes} node scores, got {scores.numel()}")
        src, dst = self.edge_index[0], self.edge_index[1]
        src_score, dst_score = scores[src], scores[dst]
        # A plain `<=` would keep both directions of an equal-score pair (and
        # every self-loop), leaving cycles in the "DAG".
        forward = (src_score < dst_score) | ((src_score == dst_score) & (src < dst))
        return self._materialize(scores > threshold, forward)

    def _materialize(self, node_mask: torch.Tensor,
                     edge_filter: Optional[torch.Tensor]) -> 'KnowledgeGraph':
        """Build a new graph from a node mask plus an optional extra edge mask.

        Keeps every edge whose endpoints are both in ``node_mask`` and, when
        given, whose ``edge_filter`` entry is True. Node and edge metadata are
        gathered with the same index tensors as the features, so they stay
        aligned.
        """
        node_mask = node_mask.reshape(-1).to(device=self.edge_index.device,
                                             dtype=torch.bool)
        if node_mask.numel() != self.num_nodes:
            raise ValueError(
                f"node mask has {node_mask.numel()} entries, expected {self.num_nodes}")
        src, dst = self.edge_index[0], self.edge_index[1]
        edge_keep = node_mask[src] & node_mask[dst]
        if edge_filter is not None:
            edge_keep &= edge_filter
        node_ids = node_mask.nonzero(as_tuple=True)[0]
        edge_ids = edge_keep.nonzero(as_tuple=True)[0]
        # Old node id -> new id (position among kept nodes); -1 if dropped.
        remap = torch.full((self.num_nodes,), -1, dtype=torch.long,
                           device=node_mask.device)
        remap[node_ids] = torch.arange(node_ids.numel(), device=node_mask.device)

        sub = KnowledgeGraph(self.node_dim, self.edge_dim)
        node_list = node_ids.tolist()
        sub.node_features = self.node_features[node_ids]
        sub.node_texts = [self.node_texts[i] for i in node_list]
        sub.node_types = [self.node_types[i] for i in node_list]
        edge_list = edge_ids.tolist()
        sub.edge_index = remap[self.edge_index[:, edge_ids]]
        sub.edge_features = self.edge_features[edge_ids]
        sub.edge_texts = [self.edge_texts[i] for i in edge_list]
        sub.relation_types = [self.relation_types[i] for i in edge_list]
        sub.entity_type_to_id = dict(self.entity_type_to_id)
        sub.relation_type_to_id = dict(self.relation_type_to_id)
        sub.start_node_idx = self._remap_marker(self.start_node_idx, remap)
        sub.end_node_idx = self._remap_marker(self.end_node_idx, remap)
        return sub

    @staticmethod
    def _remap_marker(idx: Optional[int], remap: torch.Tensor) -> Optional[int]:
        """Translate a parent node index through ``remap``; None if unset or dropped."""
        if idx is None:
            return None
        idx = int(idx)
        if not 0 <= idx < remap.numel():
            return None
        new_idx = int(remap[idx])
        return new_idx if new_idx >= 0 else None

    def get_neighbors(self, node_idx: int, direction: str = "out") -> torch.Tensor:
        """Return the neighbours of ``node_idx`` as a 1-D long tensor.

        Args:
            node_idx: Node to look up.
            direction: ``"out"`` for targets of edges leaving the node,
                ``"in"`` for sources of edges entering it. Any other value
                returns both sets, deduplicated and sorted.

        For ``"out"`` / ``"in"`` the result follows edge insertion order and
        repeats a neighbour once per parallel edge.
        """
        src, dst = self.edge_index[0], self.edge_index[1]
        successors = dst[src == node_idx]
        if direction == "out":
            return successors
        predecessors = src[dst == node_idx]
        if direction == "in":
            return predecessors
        return torch.cat([successors, predecessors]).unique()

    def save(self, path: str):
        """Serialise the whole graph state to ``path`` with ``torch.save``."""
        torch.save({
            'node_features': self.node_features, 'node_texts': self.node_texts,
            'node_types': self.node_types, 'edge_index': self.edge_index,
            'edge_features': self.edge_features, 'edge_texts': self.edge_texts,
            'relation_types': self.relation_types,
            'entity_type_to_id': self.entity_type_to_id,
            'relation_type_to_id': self.relation_type_to_id,
            'node_dim': self.node_dim, 'edge_dim': self.edge_dim,
            'start_node_idx': self.start_node_idx, 'end_node_idx': self.end_node_idx,
        }, path)

    @classmethod
    def load(cls, path: str, *, map_location=None) -> 'KnowledgeGraph':
        """Load a graph written by :meth:`save`.

        Args:
            path: File produced by :meth:`save`.
            map_location: Forwarded to ``torch.load``, e.g. ``"cpu"`` to load
                a graph saved from GPU tensors on a CPU-only machine.

        Files written before start/end markers were saved load with both
        markers set to None.
        """
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # objects, so only load files from trusted sources.
        state = torch.load(path, map_location=map_location, weights_only=False)
        kg = cls(state['node_dim'], state['edge_dim'])
        kg.node_features = state['node_features']
        kg.node_texts = state['node_texts']
        kg.node_types = state['node_types']
        kg.edge_index = state['edge_index']
        kg.edge_features = state['edge_features']
        kg.edge_texts = state['edge_texts']
        kg.relation_types = state['relation_types']
        kg.entity_type_to_id = state['entity_type_to_id']
        kg.relation_type_to_id = state['relation_type_to_id']
        kg.start_node_idx = state.get('start_node_idx')
        kg.end_node_idx = state.get('end_node_idx')
        return kg

    def __repr__(self):
        return (f"KnowledgeGraph(nodes={self.num_nodes}, edges={self.num_edges}, "
                f"entity_types={self.num_entity_types}, relation_types={self.num_relation_types})")