"""
Knowledge Graph data structure for the Graph Reasoning Network.

Each node represents a fact; each edge represents a relationship between facts.
The graph supports dynamic node/edge creation, navigation, and DAG subgraph extraction.
"""
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.utils import subgraph, k_hop_subgraph, to_undirected
from typing import Optional, List, Tuple, Dict
import json
class KnowledgeGraph:
    """
    A dynamic knowledge graph that can grow over time.

    Nodes: facts (with text descriptions and learned embeddings)
    Edges: relationships between facts (typed, directed)

    Supports:
    - Adding new nodes and edges
    - Navigating to find relevant subgraphs
    - Extracting DAG subgraphs as traceable reasoning chains
    - Serialization for persistence

    Node data is kept in parallel containers indexed by node id
    (``node_features`` rows, ``node_texts``, ``node_types``); edge data is
    kept in parallel containers indexed by edge id (``edge_index`` columns,
    ``edge_features`` rows, ``edge_texts``, ``relation_types``). Every method
    that mutates the graph must keep these containers the same length.
    """

    def __init__(self, node_dim: int = 256, edge_dim: int = 64):
        """Create an empty graph.

        Args:
            node_dim: Width of each node feature vector.
            edge_dim: Width of each edge feature vector.
        """
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        # Per-node storage (row i / element i describes node i).
        self.node_features = torch.zeros(0, node_dim)
        self.node_texts: List[str] = []
        self.node_types: List[str] = []
        # Per-edge storage (column j / row j / element j describes edge j).
        # edge_index[0] holds source node ids, edge_index[1] destination ids.
        self.edge_index = torch.zeros(2, 0, dtype=torch.long)
        self.edge_features = torch.zeros(0, edge_dim)
        self.edge_texts: List[str] = []
        self.relation_types: List[str] = []
        # Type-name vocabularies; ids are assigned in first-seen order.
        self.entity_type_to_id: Dict[str, int] = {}
        self.relation_type_to_id: Dict[str, int] = {}
        # Optional markers; nothing in this class assigns them after init
        # except load(), which restores whatever save() wrote.
        self.start_node_idx: Optional[int] = None
        self.end_node_idx: Optional[int] = None

    @property
    def num_nodes(self) -> int:
        """Number of nodes (rows of ``node_features``)."""
        return self.node_features.shape[0]

    @property
    def num_edges(self) -> int:
        """Number of edges (columns of ``edge_index``)."""
        return self.edge_index.shape[1]

    @property
    def num_entity_types(self) -> int:
        """Number of distinct entity type names registered so far."""
        return len(self.entity_type_to_id)

    @property
    def num_relation_types(self) -> int:
        """Number of distinct relation type names registered so far."""
        return len(self.relation_type_to_id)

    def add_node(self, features: torch.Tensor, text: str = "",
                 entity_type: str = "default") -> int:
        """Add a new fact node to the graph.

        Args:
            features: Feature vector of shape ``(node_dim,)`` or
                ``(1, node_dim)``.
            text: Text description stored alongside the node.
            entity_type: Type name; registered in ``entity_type_to_id`` the
                first time it is seen.

        Returns:
            The index of the newly added node.

        Raises:
            ValueError: If ``features`` does not describe exactly one node of
                width ``node_dim``.
        """
        if features.dim() == 1:
            features = features.unsqueeze(0)
        # Exactly one row is required: a multi-row tensor would add several
        # feature rows but only one text/type entry and desynchronise the
        # parallel node containers.
        if tuple(features.shape) != (1, self.node_dim):
            raise ValueError(
                f"node features must have shape ({self.node_dim},) or "
                f"(1, {self.node_dim}), got {tuple(features.shape)}")
        self.node_features = torch.cat([self.node_features, features], dim=0)
        self.node_texts.append(text)
        self.node_types.append(entity_type)
        if entity_type not in self.entity_type_to_id:
            self.entity_type_to_id[entity_type] = len(self.entity_type_to_id)
        return self.num_nodes - 1

    def add_edge(self, src: int, dst: int, features: Optional[torch.Tensor] = None,
                 text: str = "", relation_type: str = "default") -> int:
        """Add a directed edge (relationship) between two fact nodes.

        Args:
            src: Index of the source node.
            dst: Index of the destination node.
            features: Feature vector of shape ``(edge_dim,)`` or
                ``(1, edge_dim)``; a zero vector is used when omitted.
            text: Text description stored alongside the edge.
            relation_type: Type name; registered in ``relation_type_to_id``
                the first time it is seen.

        Returns:
            The index of the newly added edge.

        Raises:
            IndexError: If ``src`` or ``dst`` is not an existing node index.
            ValueError: If ``features`` does not describe exactly one edge of
                width ``edge_dim``.
        """
        if not (0 <= src < self.num_nodes and 0 <= dst < self.num_nodes):
            raise IndexError(
                f"edge ({src}, {dst}) references a node outside "
                f"[0, {self.num_nodes})")
        if features is None:
            # new_zeros matches the dtype/device of the existing edge features.
            features = self.edge_features.new_zeros(1, self.edge_dim)
        elif features.dim() == 1:
            features = features.unsqueeze(0)
        if tuple(features.shape) != (1, self.edge_dim):
            raise ValueError(
                f"edge features must have shape ({self.edge_dim},) or "
                f"(1, {self.edge_dim}), got {tuple(features.shape)}")

        # All validation happens above so a rejected call leaves the graph
        # untouched (no edge_index column without a matching feature row).
        new_edge = torch.tensor([[src], [dst]], dtype=torch.long)
        self.edge_index = torch.cat([self.edge_index, new_edge], dim=1)
        self.edge_features = torch.cat([self.edge_features, features], dim=0)
        self.edge_texts.append(text)
        self.relation_types.append(relation_type)
        if relation_type not in self.relation_type_to_id:
            self.relation_type_to_id[relation_type] = len(self.relation_type_to_id)
        return self.num_edges - 1

    def to_pyg_data(self) -> "Data":
        """Convert to PyG Data object for GNN processing.

        Returns:
            A ``Data`` object with ``x``, ``edge_index``, ``edge_attr``,
            integer ``edge_type`` / ``node_type`` ids and ``num_nodes``.
            Type names missing from the vocabularies map to id 0.
        """
        # torch.tensor([]) with an explicit dtype already yields an empty
        # 1-D long tensor, so empty graphs need no special case.
        rel_type_ids = torch.tensor(
            [self.relation_type_to_id.get(r, 0) for r in self.relation_types],
            dtype=torch.long)
        entity_type_ids = torch.tensor(
            [self.entity_type_to_id.get(t, 0) for t in self.node_types],
            dtype=torch.long)
        return Data(x=self.node_features, edge_index=self.edge_index,
                    edge_attr=self.edge_features, edge_type=rel_type_ids,
                    node_type=entity_type_ids, num_nodes=self.num_nodes)

    def _induced_subgraph(self, node_mask: torch.Tensor,
                          edge_filter: Optional[torch.Tensor] = None) -> 'KnowledgeGraph':
        """Build the sub-graph induced by ``node_mask``.

        Args:
            node_mask: Boolean tensor with one entry per node; ``True`` keeps
                the node.
            edge_filter: Optional boolean tensor with one entry per edge of
                this graph; induced edges whose entry is ``False`` are
                dropped as well.

        Returns:
            A new ``KnowledgeGraph`` whose nodes are relabelled in ascending
            order of their original index and which shares copies of the
            type vocabularies.
        """
        new_edge_index, new_edge_attr, edge_mask = subgraph(
            node_mask, self.edge_index, self.edge_features,
            relabel_nodes=True, return_edge_mask=True)
        kept_edges = edge_mask.nonzero(as_tuple=True)[0]
        if edge_filter is not None:
            keep = edge_filter[edge_mask]
            new_edge_index = new_edge_index[:, keep]
            new_edge_attr = new_edge_attr[keep]
            kept_edges = kept_edges[keep]

        # Node metadata is read in mask (ascending index) order so that it
        # lines up with node_features[node_mask] and the relabelled edges.
        kept_node_ids = node_mask.nonzero(as_tuple=True)[0].tolist()
        kept_edge_ids = kept_edges.tolist()

        sub_kg = KnowledgeGraph(self.node_dim, self.edge_dim)
        sub_kg.node_features = self.node_features[node_mask]
        sub_kg.node_texts = [self.node_texts[i] for i in kept_node_ids]
        sub_kg.node_types = [self.node_types[i] for i in kept_node_ids]
        sub_kg.edge_index = new_edge_index
        sub_kg.edge_features = new_edge_attr
        sub_kg.edge_texts = [self.edge_texts[j] for j in kept_edge_ids]
        sub_kg.relation_types = [self.relation_types[j] for j in kept_edge_ids]
        sub_kg.entity_type_to_id = dict(self.entity_type_to_id)
        sub_kg.relation_type_to_id = dict(self.relation_type_to_id)
        return sub_kg

    def extract_subgraph(self, node_indices: torch.Tensor) -> 'KnowledgeGraph':
        """Extract a subgraph containing only the specified nodes.

        Args:
            node_indices: Indices of the nodes to keep. Order and duplicates
                are ignored: the result always lists the kept nodes in
                ascending order of their original index.

        Returns:
            A new ``KnowledgeGraph`` with the selected nodes and every edge
            whose two endpoints are both selected.
        """
        mask = torch.zeros(self.num_nodes, dtype=torch.bool,
                           device=self.edge_index.device)
        mask[node_indices] = True
        return self._induced_subgraph(mask)

    def extract_dag_subgraph(self, node_scores: torch.Tensor,
                             threshold: float = 0.5) -> 'KnowledgeGraph':
        """Extract a DAG subgraph. Removes cycles by keeping only forward edges.

        Nodes with ``node_scores > threshold`` are kept. An edge between two
        kept nodes survives only if it points "forward" in the strict order
        ``(score, node index)``: from a lower score to a higher one, or, for
        equal scores, from a lower node index to a higher one. Because every
        kept edge increases that order, the result cannot contain a cycle
        (self-loops and back-and-forth edges between equal-scored nodes are
        dropped).

        Args:
            node_scores: Per-node scores.
                # assumes a 1-D tensor with num_nodes entries - TODO confirm
            threshold: Nodes must score strictly above this value.

        Returns:
            A new ``KnowledgeGraph``; an empty graph without type
            vocabularies when no node passes the threshold.
        """
        mask = node_scores > threshold
        if not bool(mask.any()):
            return KnowledgeGraph(self.node_dim, self.edge_dim)

        src, dst = self.edge_index[0], self.edge_index[1]
        src_score, dst_score = node_scores[src], node_scores[dst]
        forward = (src_score < dst_score) | ((src_score == dst_score) & (src < dst))
        return self._induced_subgraph(mask, edge_filter=forward)

    def get_neighbors(self, node_idx: int, direction: str = "out") -> torch.Tensor:
        """Return the indices of nodes adjacent to ``node_idx``.

        Args:
            node_idx: Node whose neighbours are requested.
            direction: ``"out"`` for destinations of outgoing edges, ``"in"``
                for sources of incoming edges; any other value returns the
                de-duplicated union of both.

        Returns:
            A 1-D long tensor of node indices. The ``"out"`` and ``"in"``
            results keep edge order and may repeat a node when parallel
            edges exist; the union is unique.
        """
        src, dst = self.edge_index[0], self.edge_index[1]
        if direction == "out":
            return dst[src == node_idx]
        if direction == "in":
            return src[dst == node_idx]
        return torch.cat([dst[src == node_idx], src[dst == node_idx]]).unique()

    def save(self, path: str):
        """Serialize the whole graph state to ``path`` with ``torch.save``."""
        torch.save({
            'node_features': self.node_features, 'node_texts': self.node_texts,
            'node_types': self.node_types, 'edge_index': self.edge_index,
            'edge_features': self.edge_features, 'edge_texts': self.edge_texts,
            'relation_types': self.relation_types,
            'entity_type_to_id': self.entity_type_to_id,
            'relation_type_to_id': self.relation_type_to_id,
            'node_dim': self.node_dim, 'edge_dim': self.edge_dim,
            'start_node_idx': self.start_node_idx, 'end_node_idx': self.end_node_idx,
        }, path)

    @classmethod
    def load(cls, path: str) -> 'KnowledgeGraph':
        """Rebuild a graph from a file written by :meth:`save`.

        The start/end markers are optional in the file and default to
        ``None`` when absent.
        """
        # NOTE(security): weights_only=False lets torch.load unpickle
        # arbitrary objects, which can execute code. Only load files from a
        # trusted source.
        state = torch.load(path, weights_only=False)
        kg = cls(state['node_dim'], state['edge_dim'])
        kg.node_features = state['node_features']
        kg.node_texts = state['node_texts']
        kg.node_types = state['node_types']
        kg.edge_index = state['edge_index']
        kg.edge_features = state['edge_features']
        kg.edge_texts = state['edge_texts']
        kg.relation_types = state['relation_types']
        kg.entity_type_to_id = state['entity_type_to_id']
        kg.relation_type_to_id = state['relation_type_to_id']
        kg.start_node_idx = state.get('start_node_idx')
        kg.end_node_idx = state.get('end_node_idx')
        return kg

    def __repr__(self):
        return (f"KnowledgeGraph(nodes={self.num_nodes}, edges={self.num_edges}, "
                f"entity_types={self.num_entity_types}, relation_types={self.num_relation_types})")
|