Adam-Ben-Khalifa committed on
Commit
beb1305
·
verified ·
1 Parent(s): ecda7d8

Upload grn/knowledge_graph.py

Browse files
Files changed (1) hide show
  1. grn/knowledge_graph.py +196 -0
grn/knowledge_graph.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Knowledge Graph data structure for the Graph Reasoning Network.
3
+
4
+ Each node represents a fact, each edge represents a relationship between facts.
5
+ The graph supports dynamic node/edge creation, navigation, and DAG subgraph extraction.
6
+ """
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch_geometric.data import Data
10
+ from torch_geometric.utils import subgraph, k_hop_subgraph, to_undirected
11
+ from typing import Optional, List, Tuple, Dict
12
+ import json
13
+
14
+
15
class KnowledgeGraph:
    """
    A dynamic knowledge graph that can grow over time.

    Nodes: facts (with text descriptions and learned embeddings)
    Edges: relationships between facts (typed, directed)

    Supports:
    - Adding new nodes and edges
    - Navigating to find relevant subgraphs
    - Extracting DAG subgraphs as traceable reasoning chains
    - Serialization for persistence

    Storage is a set of parallel containers that must stay the same length:
    ``node_features`` / ``node_texts`` / ``node_types`` (one entry per node) and
    ``edge_index`` columns / ``edge_features`` / ``edge_texts`` /
    ``relation_types`` (one entry per edge).
    """

    def __init__(self, node_dim: int = 256, edge_dim: int = 64):
        """Create an empty graph.

        Args:
            node_dim: Width of every node feature vector.
            edge_dim: Width of every edge feature vector.
        """
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        self.node_features = torch.zeros(0, node_dim)
        self.node_texts: List[str] = []
        self.node_types: List[str] = []
        self.edge_index = torch.zeros(2, 0, dtype=torch.long)
        self.edge_features = torch.zeros(0, edge_dim)
        self.edge_texts: List[str] = []
        self.relation_types: List[str] = []
        # Type name -> integer id, assigned in order of first appearance.
        self.entity_type_to_id: Dict[str, int] = {}
        self.relation_type_to_id: Dict[str, int] = {}
        self.start_node_idx: Optional[int] = None
        self.end_node_idx: Optional[int] = None

    @property
    def num_nodes(self) -> int:
        """Number of nodes (rows of ``node_features``)."""
        return self.node_features.shape[0]

    @property
    def num_edges(self) -> int:
        """Number of edges (columns of ``edge_index``)."""
        return self.edge_index.shape[1]

    @property
    def num_entity_types(self) -> int:
        """Number of distinct entity type names registered so far."""
        return len(self.entity_type_to_id)

    @property
    def num_relation_types(self) -> int:
        """Number of distinct relation type names registered so far."""
        return len(self.relation_type_to_id)

    def add_node(self, features: torch.Tensor, text: str = "",
                 entity_type: str = "default") -> int:
        """Add a new fact node to the graph. Returns the node index.

        Args:
            features: Feature vector of shape ``(node_dim,)`` or ``(1, node_dim)``.
            text: Human-readable description stored alongside the node.
            entity_type: Type name; registered with a new id on first use.

        Raises:
            ValueError: If ``features`` does not describe exactly one node of
                width ``node_dim``. A ``(k, node_dim)`` batch is rejected
                because only one text/type entry is appended per call, which
                would leave the parallel containers out of sync.
        """
        if features.dim() == 1:
            features = features.unsqueeze(0)
        if tuple(features.shape) != (1, self.node_dim):
            raise ValueError(
                f"features must have shape ({self.node_dim},) or (1, {self.node_dim}), "
                f"got {tuple(features.shape)}"
            )
        self.node_features = torch.cat([self.node_features, features], dim=0)
        self.node_texts.append(text)
        self.node_types.append(entity_type)
        if entity_type not in self.entity_type_to_id:
            self.entity_type_to_id[entity_type] = len(self.entity_type_to_id)
        return self.num_nodes - 1

    def add_edge(self, src: int, dst: int, features: Optional[torch.Tensor] = None,
                 text: str = "", relation_type: str = "default") -> int:
        """Add a directed edge (relationship) between two fact nodes.

        Args:
            src: Index of the source node.
            dst: Index of the destination node.
            features: Edge feature vector of shape ``(edge_dim,)`` or
                ``(1, edge_dim)``; zeros are used when omitted.
            text: Human-readable description stored alongside the edge.
            relation_type: Type name; registered with a new id on first use.

        Returns:
            The index of the newly added edge.

        Raises:
            IndexError: If ``src`` or ``dst`` is not an existing node index.
            ValueError: If ``features`` does not describe exactly one edge of
                width ``edge_dim``.
        """
        if not (0 <= src < self.num_nodes and 0 <= dst < self.num_nodes):
            raise IndexError(
                f"edge ({src} -> {dst}) references a node outside [0, {self.num_nodes})"
            )
        if features is None:
            # Match the dtype/device of the existing edge feature storage.
            features = self.edge_features.new_zeros(1, self.edge_dim)
        else:
            if features.dim() == 1:
                features = features.unsqueeze(0)
            if tuple(features.shape) != (1, self.edge_dim):
                raise ValueError(
                    f"features must have shape ({self.edge_dim},) or (1, {self.edge_dim}), "
                    f"got {tuple(features.shape)}"
                )
        new_edge = torch.tensor([[src], [dst]], dtype=torch.long,
                                device=self.edge_index.device)
        # Build both tensors before assigning anything, so a failure cannot
        # leave edge_index one column longer than edge_features.
        grown_index = torch.cat([self.edge_index, new_edge], dim=1)
        grown_features = torch.cat([self.edge_features, features], dim=0)
        self.edge_index = grown_index
        self.edge_features = grown_features
        self.edge_texts.append(text)
        self.relation_types.append(relation_type)
        if relation_type not in self.relation_type_to_id:
            self.relation_type_to_id[relation_type] = len(self.relation_type_to_id)
        return self.num_edges - 1

    def to_pyg_data(self) -> "Data":
        """Convert to PyG Data object for GNN processing.

        The returned object carries ``x``, ``edge_index``, ``edge_attr``,
        ``edge_type`` and ``node_type``. Type names missing from the id
        mappings fall back to id 0.
        """
        # torch.tensor([]) with an explicit dtype already yields a length-0
        # tensor, so the empty case needs no special handling.
        rel_type_ids = torch.tensor(
            [self.relation_type_to_id.get(r, 0) for r in self.relation_types],
            dtype=torch.long,
        )
        entity_type_ids = torch.tensor(
            [self.entity_type_to_id.get(t, 0) for t in self.node_types],
            dtype=torch.long,
        )
        return Data(x=self.node_features, edge_index=self.edge_index,
                    edge_attr=self.edge_features, edge_type=rel_type_ids,
                    node_type=entity_type_ids, num_nodes=self.num_nodes)

    def _induced_subgraph(self, node_mask: torch.Tensor,
                          keep_edge: Optional[torch.Tensor] = None) -> 'KnowledgeGraph':
        """Build the subgraph induced by ``node_mask``, optionally dropping edges.

        Args:
            node_mask: Boolean tensor with one entry per node; True keeps the node.
            keep_edge: Optional boolean tensor with one entry per edge of this
                graph; edges marked False are dropped even if both endpoints
                are kept.

        Returns:
            A new graph whose nodes are the kept nodes in ascending original
            index order, with edges relabelled to match. The type-id mappings
            are copied from this graph; ``start_node_idx`` / ``end_node_idx``
            are not carried over.
        """
        # Derive node order from the mask so features, texts and types all
        # follow the same ascending-index order used for relabelling.
        node_ids = node_mask.nonzero(as_tuple=True)[0].tolist()
        new_edge_index, new_edge_attr, edge_mask = subgraph(
            node_mask, self.edge_index, self.edge_features,
            relabel_nodes=True, return_edge_mask=True)
        # Original edge ids that survived the node filter, in output order.
        edge_ids = edge_mask.nonzero(as_tuple=True)[0]
        if keep_edge is not None:
            keep = keep_edge[edge_ids]
            new_edge_index = new_edge_index[:, keep]
            new_edge_attr = new_edge_attr[keep]
            edge_ids = edge_ids[keep]
        edge_id_list = edge_ids.tolist()

        sub_kg = KnowledgeGraph(self.node_dim, self.edge_dim)
        sub_kg.node_features = self.node_features[node_mask]
        sub_kg.node_texts = [self.node_texts[i] for i in node_ids]
        sub_kg.node_types = [self.node_types[i] for i in node_ids]
        sub_kg.edge_index = new_edge_index
        sub_kg.edge_features = new_edge_attr
        sub_kg.edge_texts = [self.edge_texts[i] for i in edge_id_list]
        sub_kg.relation_types = [self.relation_types[i] for i in edge_id_list]
        sub_kg.entity_type_to_id = dict(self.entity_type_to_id)
        sub_kg.relation_type_to_id = dict(self.relation_type_to_id)
        return sub_kg

    def extract_subgraph(self, node_indices: torch.Tensor) -> 'KnowledgeGraph':
        """Extract a subgraph containing only the specified nodes.

        Args:
            node_indices: Indices (or a boolean mask) selecting the nodes to keep.
                Order and duplicates are ignored: the result always lists the
                selected nodes once each, in ascending original index order.

        Returns:
            A new graph with the selected nodes and every edge whose two
            endpoints are both selected.
        """
        mask = torch.zeros(self.num_nodes, dtype=torch.bool)
        mask[node_indices] = True
        return self._induced_subgraph(mask)

    def extract_dag_subgraph(self, node_scores: torch.Tensor,
                             threshold: float = 0.5) -> 'KnowledgeGraph':
        """Extract a DAG subgraph. Removes cycles by keeping only forward edges.

        Nodes with ``score > threshold`` are kept. Among the edges between kept
        nodes, an edge ``src -> dst`` survives only if it goes strictly forward
        in the ordering ``(score, node index)``: either ``score[src] <
        score[dst]``, or the scores are equal and ``src < dst``. Because that
        ordering is strict and total, the surviving edges cannot form a cycle
        (self-loops and back-edges between equal-scored nodes are dropped).

        Args:
            node_scores: One score per node; flattened before use.
            threshold: Nodes must score strictly above this to be kept.

        Returns:
            A new acyclic graph; an empty graph when no node passes the threshold.

        Raises:
            ValueError: If ``node_scores`` does not hold exactly one value per node.
        """
        node_scores = node_scores.reshape(-1)
        if node_scores.numel() != self.num_nodes:
            raise ValueError(
                f"node_scores has {node_scores.numel()} entries, expected {self.num_nodes}"
            )
        mask = node_scores > threshold
        if not bool(mask.any()):
            return KnowledgeGraph(self.node_dim, self.edge_dim)
        src, dst = self.edge_index[0], self.edge_index[1]
        src_score, dst_score = node_scores[src], node_scores[dst]
        # Ties on score are broken by node index so equal-scored nodes cannot
        # be connected in both directions (or to themselves).
        forward = (src_score < dst_score) | ((src_score == dst_score) & (src < dst))
        return self._induced_subgraph(mask, keep_edge=forward)

    def get_neighbors(self, node_idx: int, direction: str = "out") -> torch.Tensor:
        """Return the indices of nodes adjacent to ``node_idx``.

        Args:
            node_idx: Node whose neighbours are requested.
            direction: ``"out"`` for targets of outgoing edges, ``"in"`` for
                sources of incoming edges. Any other value returns the
                de-duplicated union of both.

        Returns:
            A 1-D tensor of node indices. For ``"out"`` / ``"in"`` there is one
            entry per matching edge, so parallel edges yield repeated indices.
        """
        if direction == "out":
            mask = self.edge_index[0] == node_idx
            return self.edge_index[1][mask]
        elif direction == "in":
            mask = self.edge_index[1] == node_idx
            return self.edge_index[0][mask]
        else:
            out_mask = self.edge_index[0] == node_idx
            in_mask = self.edge_index[1] == node_idx
            return torch.cat([self.edge_index[1][out_mask], self.edge_index[0][in_mask]]).unique()

    def save(self, path: str):
        """Serialize the full graph state to ``path`` with ``torch.save``."""
        torch.save({
            'node_features': self.node_features, 'node_texts': self.node_texts,
            'node_types': self.node_types, 'edge_index': self.edge_index,
            'edge_features': self.edge_features, 'edge_texts': self.edge_texts,
            'relation_types': self.relation_types,
            'entity_type_to_id': self.entity_type_to_id,
            'relation_type_to_id': self.relation_type_to_id,
            'node_dim': self.node_dim, 'edge_dim': self.edge_dim,
            'start_node_idx': self.start_node_idx, 'end_node_idx': self.end_node_idx,
        }, path)

    @classmethod
    def load(cls, path: str) -> 'KnowledgeGraph':
        """Restore a graph previously written by :meth:`save`.

        ``start_node_idx`` / ``end_node_idx`` default to ``None`` when the
        file does not contain them.
        """
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary
        # objects, so loading a file from an untrusted source can execute
        # code. Only load files you produced yourself; consider
        # weights_only=True if the saved state allows it.
        state = torch.load(path, weights_only=False)
        kg = cls(state['node_dim'], state['edge_dim'])
        kg.node_features = state['node_features']
        kg.node_texts = state['node_texts']
        kg.node_types = state['node_types']
        kg.edge_index = state['edge_index']
        kg.edge_features = state['edge_features']
        kg.edge_texts = state['edge_texts']
        kg.relation_types = state['relation_types']
        kg.entity_type_to_id = state['entity_type_to_id']
        kg.relation_type_to_id = state['relation_type_to_id']
        kg.start_node_idx = state.get('start_node_idx')
        kg.end_node_idx = state.get('end_node_idx')
        return kg

    def __repr__(self):
        """Short summary with node, edge and type counts."""
        return (f"KnowledgeGraph(nodes={self.num_nodes}, edges={self.num_edges}, "
                f"entity_types={self.num_entity_types}, relation_types={self.num_relation_types})")