Adam-Ben-Khalifa commited on
Commit
f7dcd05
·
verified ·
1 Parent(s): beb1305

Upload grn/model.py

Browse files
Files changed (1) hide show
  1. grn/model.py +246 -0
grn/model.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Graph Reasoning Network (GRN) - Core Model
3
+
4
+ An LLM alternative that operates on knowledge graphs instead of text tokens.
5
+ See the full docstrings in the training script for architecture details.
6
+ """
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch_geometric.nn import MessagePassing
11
+ from typing import Optional, Tuple, Dict
12
+ import math
13
+
14
+
15
+ class QueryEncoder(nn.Module):
16
+ def __init__(self, vocab_size=32000, embed_dim=256, num_heads=8, num_layers=4, max_seq_len=512):
17
+ super().__init__()
18
+ self.embed_dim = embed_dim
19
+ self.token_embedding = nn.Embedding(vocab_size, embed_dim)
20
+ self.position_embedding = nn.Embedding(max_seq_len, embed_dim)
21
+ encoder_layer = nn.TransformerEncoderLayer(
22
+ d_model=embed_dim, nhead=num_heads, dim_feedforward=embed_dim * 4,
23
+ dropout=0.1, activation='gelu', batch_first=True, norm_first=True)
24
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
25
+ self.output_proj = nn.Linear(embed_dim, embed_dim)
26
+ self.layer_norm = nn.LayerNorm(embed_dim)
27
+
28
+ def forward(self, token_ids, attention_mask=None):
29
+ B, L = token_ids.shape
30
+ positions = torch.arange(L, device=token_ids.device).unsqueeze(0).expand(B, -1)
31
+ x = self.token_embedding(token_ids) + self.position_embedding(positions)
32
+ mask = ~attention_mask.bool() if attention_mask is not None else None
33
+ x = self.transformer(x, src_key_padding_mask=mask)
34
+ if attention_mask is not None:
35
+ m = attention_mask.unsqueeze(-1).float()
36
+ pooled = (x * m).sum(dim=1) / m.sum(dim=1).clamp(min=1)
37
+ else:
38
+ pooled = x.mean(dim=1)
39
+ return self.layer_norm(self.output_proj(pooled))
40
+
41
+
42
+ class RelationAwareMessagePassing(MessagePassing):
43
+ def __init__(self, hidden_dim, edge_dim, num_relation_types=256):
44
+ super().__init__(aggr='add')
45
+ self.message_mlp = nn.Sequential(
46
+ nn.Linear(hidden_dim + edge_dim, hidden_dim * 2), nn.GELU(),
47
+ nn.Linear(hidden_dim * 2, hidden_dim))
48
+ self.gate_mlp = nn.Sequential(nn.Linear(hidden_dim * 2, hidden_dim), nn.Sigmoid())
49
+ self.layer_norm = nn.LayerNorm(hidden_dim)
50
+
51
+ def forward(self, x, edge_index, edge_attr):
52
+ msg = self.propagate(edge_index, x=x, edge_attr=edge_attr)
53
+ gate = self.gate_mlp(torch.cat([x, msg], dim=-1))
54
+ return self.layer_norm(gate * msg + (1 - gate) * x)
55
+
56
+ def message(self, x_j, edge_attr):
57
+ return self.message_mlp(torch.cat([x_j, edge_attr], dim=-1))
58
+
59
+
60
+ class GraphNavigator(nn.Module):
61
+ def __init__(self, hidden_dim=256, edge_dim=64, num_layers=6, num_relation_types=256):
62
+ super().__init__()
63
+ self.num_layers = num_layers
64
+ self.query_proj = nn.Linear(hidden_dim, hidden_dim)
65
+ self.mp_layers = nn.ModuleList([
66
+ RelationAwareMessagePassing(hidden_dim, edge_dim, num_relation_types)
67
+ for _ in range(num_layers)])
68
+ self.query_attention = nn.ModuleList([
69
+ nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)
70
+ for _ in range(num_layers)])
71
+ self.query_norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_layers)])
72
+ self.relevance_head = nn.Sequential(
73
+ nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1))
74
+
75
+ def forward(self, node_features, edge_index, edge_attr, query, start_node_mask=None):
76
+ h = node_features.clone()
77
+ if start_node_mask is not None:
78
+ qi = self.query_proj(query)
79
+ if qi.dim() == 2: qi = qi.squeeze(0)
80
+ h[start_node_mask.bool()] = h[start_node_mask.bool()] + qi
81
+ for i in range(self.num_layers):
82
+ h = self.mp_layers[i](h, edge_index, edge_attr)
83
+ hu = h.unsqueeze(0)
84
+ qu = query.unsqueeze(1) if query.dim() == 2 else query.unsqueeze(0).unsqueeze(1)
85
+ att, _ = self.query_attention[i](hu, qu, qu)
86
+ h = self.query_norms[i](h + att.squeeze(0))
87
+ return h, self.relevance_head(h)
88
+
89
+
90
+ class NodeCreator(nn.Module):
91
+ def __init__(self, hidden_dim=256, edge_dim=64):
92
+ super().__init__()
93
+ self.coverage_head = nn.Sequential(
94
+ nn.Linear(hidden_dim * 2, hidden_dim), nn.GELU(),
95
+ nn.Linear(hidden_dim, 1), nn.Sigmoid())
96
+ self.node_generator = nn.Sequential(
97
+ nn.Linear(hidden_dim * 2, hidden_dim * 2), nn.GELU(),
98
+ nn.Linear(hidden_dim * 2, hidden_dim), nn.LayerNorm(hidden_dim))
99
+ self.edge_generator = nn.Sequential(
100
+ nn.Linear(hidden_dim * 2, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, edge_dim))
101
+ self.connection_scorer = nn.Sequential(
102
+ nn.Linear(hidden_dim * 2, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1))
103
+
104
+ def forward(self, node_features, query, relevance_scores):
105
+ if query.dim() == 2: query = query.squeeze(0)
106
+ w = torch.softmax(relevance_scores.squeeze(-1), dim=0)
107
+ gs = (node_features * w.unsqueeze(-1)).sum(dim=0)
108
+ coverage = self.coverage_head(torch.cat([gs, query], dim=-1))
109
+ new_node = self.node_generator(torch.cat([gs, query], dim=-1)).unsqueeze(0)
110
+ qe = query.unsqueeze(0).expand(node_features.shape[0], -1)
111
+ ci = torch.cat([node_features, qe], dim=-1)
112
+ return coverage, new_node, self.connection_scorer(ci).squeeze(-1), self.edge_generator(ci)
113
+
114
+
115
+ class EdgePredictor(nn.Module):
116
+ def __init__(self, hidden_dim=256, edge_dim=64, num_relation_types=256, gamma=12.0):
117
+ super().__init__()
118
+ self.gamma = gamma
119
+ self.complex_dim = hidden_dim // 2
120
+ self.head_proj = nn.Linear(hidden_dim, hidden_dim)
121
+ self.tail_proj = nn.Linear(hidden_dim, hidden_dim)
122
+ self.relation_phases = nn.Embedding(num_relation_types, self.complex_dim)
123
+ nn.init.uniform_(self.relation_phases.weight, 0, 2 * math.pi)
124
+ self.edge_feat_gen = nn.Sequential(
125
+ nn.Linear(self.complex_dim, edge_dim * 2), nn.GELU(), nn.Linear(edge_dim * 2, edge_dim))
126
+
127
+ def score_edges(self, head, tail, relation_ids):
128
+ h, t = self.head_proj(head), self.tail_proj(tail)
129
+ cd = self.complex_dim
130
+ re_h, im_h = h[:, :cd], h[:, cd:]
131
+ re_t, im_t = t[:, :cd], t[:, cd:]
132
+ phase = self.relation_phases(relation_ids)
133
+ re_r, im_r = torch.cos(phase), torch.sin(phase)
134
+ re_s = re_h * re_r - im_h * im_r - re_t
135
+ im_s = re_h * im_r + im_h * re_r - im_t
136
+ return self.gamma - torch.norm(torch.stack([re_s, im_s], dim=0), dim=0).sum(dim=-1)
137
+
138
+ def forward(self, node_features, candidate_edges, relation_ids):
139
+ scores = self.score_edges(node_features[candidate_edges[0]],
140
+ node_features[candidate_edges[1]], relation_ids)
141
+ return scores, self.edge_feat_gen(self.relation_phases(relation_ids))
142
+
143
+
144
+ class SubgraphExtractor(nn.Module):
145
+ def __init__(self, hidden_dim=256):
146
+ super().__init__()
147
+ self.node_selector = nn.Sequential(
148
+ nn.Linear(hidden_dim * 2, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1))
149
+ self.edge_selector = nn.Sequential(
150
+ nn.Linear(hidden_dim * 3, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1))
151
+ self.order_head = nn.Sequential(
152
+ nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1))
153
+
154
+ def forward(self, node_features, edge_index, edge_attr, query, relevance_scores):
155
+ if query.dim() == 2: query = query.squeeze(0)
156
+ N = node_features.shape[0]
157
+ qe = query.unsqueeze(0).expand(N, -1)
158
+ nl = self.node_selector(torch.cat([node_features, qe], dim=-1)).squeeze(-1)
159
+ np_ = torch.sigmoid(nl + relevance_scores.squeeze(-1))
160
+ if edge_index.shape[1] > 0:
161
+ sf = node_features[edge_index[0]]
162
+ df = node_features[edge_index[1]]
163
+ qee = query.unsqueeze(0).expand(edge_index.shape[1], -1)
164
+ el = self.edge_selector(torch.cat([sf, df, qee], dim=-1)).squeeze(-1)
165
+ ep = torch.sigmoid(el) * np_[edge_index[0]] * np_[edge_index[1]]
166
+ else:
167
+ ep = torch.zeros(0, device=node_features.device)
168
+ return np_, ep, self.order_head(node_features).squeeze(-1)
169
+
170
+
171
+ class GraphReasoningNetwork(nn.Module):
172
+ def __init__(self, config):
173
+ super().__init__()
174
+ self.config = config
175
+ hd = config.get('hidden_dim', 256)
176
+ ed = config.get('edge_dim', 64)
177
+ self.hidden_dim = hd
178
+ self.edge_dim = ed
179
+ self.query_encoder = QueryEncoder(config.get('vocab_size', 32000), hd, 8, config.get('num_encoder_layers', 4))
180
+ self.navigator = GraphNavigator(hd, ed, config.get('num_nav_layers', 6), config.get('num_relation_types', 256))
181
+ self.node_creator = NodeCreator(hd, ed)
182
+ self.edge_predictor = EdgePredictor(hd, ed, config.get('num_relation_types', 256))
183
+ self.subgraph_extractor = SubgraphExtractor(hd)
184
+ self.node_input_proj = nn.Linear(hd, hd)
185
+ self.edge_input_proj = nn.Linear(ed, ed)
186
+
187
+ def forward(self, token_ids, attention_mask, node_features, edge_index, edge_attr,
188
+ start_node_mask, target_node_mask=None, target_edge_mask=None,
189
+ target_new_nodes=None, candidate_edges=None, candidate_edge_labels=None,
190
+ candidate_edge_relations=None):
191
+ results = {}
192
+ query = self.query_encoder(token_ids, attention_mask)
193
+ h = self.node_input_proj(node_features)
194
+ e = self.edge_input_proj(edge_attr) if edge_attr.shape[0] > 0 else edge_attr
195
+ h_nav, rel = self.navigator(h, edge_index, e, query, start_node_mask)
196
+ results.update({'node_features': h_nav, 'relevance_scores': rel})
197
+ cov, nn_, cs, nef = self.node_creator(h_nav, query, rel)
198
+ results.update({'coverage': cov, 'new_node_features': nn_, 'connection_scores': cs})
199
+ if candidate_edges is not None and candidate_edges.shape[1] > 0:
200
+ es, pef = self.edge_predictor(h_nav, candidate_edges, candidate_edge_relations)
201
+ results['edge_scores'] = es
202
+ np_, ep, to = self.subgraph_extractor(h_nav, edge_index, e, query, rel)
203
+ results.update({'node_selection_probs': np_, 'edge_selection_probs': ep, 'topological_order': to})
204
+ losses = {}
205
+ if target_node_mask is not None:
206
+ losses['node_selection_loss'] = F.binary_cross_entropy(np_, target_node_mask.float())
207
+ if target_edge_mask is not None and ep.numel() > 0:
208
+ losses['edge_selection_loss'] = F.binary_cross_entropy(ep, target_edge_mask.float())
209
+ if target_new_nodes is not None:
210
+ losses['node_creation_loss'] = F.mse_loss(nn_, target_new_nodes)
211
+ if candidate_edge_labels is not None and 'edge_scores' in results:
212
+ losses['edge_prediction_loss'] = F.binary_cross_entropy_with_logits(results['edge_scores'], candidate_edge_labels.float())
213
+ if target_node_mask is not None:
214
+ losses['coverage_loss'] = F.mse_loss(cov.squeeze(), target_node_mask.float().mean())
215
+ if edge_index.shape[1] > 0 and target_edge_mask is not None:
216
+ ov = F.relu(to[edge_index[0]] - to[edge_index[1]] + 0.1)
217
+ losses['dag_ordering_loss'] = (ov * target_edge_mask.float()).mean()
218
+ results['losses'] = losses
219
+ if losses: results['total_loss'] = sum(losses.values())
220
+ return results
221
+
222
+ def count_parameters(self):
223
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
224
+
225
+ @torch.no_grad()
226
+ def reason(self, token_ids, attention_mask, node_features, edge_index, edge_attr,
227
+ start_node_mask, node_threshold=0.5, edge_threshold=0.3, create_threshold=0.5):
228
+ self.eval()
229
+ r = self.forward(token_ids, attention_mask, node_features, edge_index, edge_attr, start_node_mask)
230
+ sn = r['node_selection_probs'] > node_threshold
231
+ se = r['edge_selection_probs'] > edge_threshold if r['edge_selection_probs'].numel() > 0 else torch.zeros(0, dtype=torch.bool)
232
+ order = r['topological_order']
233
+ if se.any():
234
+ v = sn[edge_index[0][se]] & sn[edge_index[1][se]]
235
+ sf = se.clone(); sf[se] = v
236
+ else: sf = se
237
+ if sf.any():
238
+ em = sf.nonzero(as_tuple=True)[0]
239
+ fwd = order[edge_index[0][em]] < order[edge_index[1][em]]
240
+ sd = torch.zeros_like(sf); sd[em[fwd]] = True
241
+ else: sd = sf
242
+ return {'selected_nodes': sn, 'selected_edges': sd, 'node_scores': r['node_selection_probs'],
243
+ 'edge_scores': r['edge_selection_probs'], 'topological_order': order,
244
+ 'relevance_scores': r['relevance_scores'], 'coverage': r['coverage'],
245
+ 'new_node_features': r['new_node_features'], 'connection_scores': r['connection_scores'],
246
+ 'should_create_node': r['coverage'].item() < create_threshold}