| """ |
| Graph Reasoning Network (GRN) - Core Model |
| |
| An LLM alternative that operates on knowledge graphs instead of text tokens. |
| See the full docstrings in the training script for architecture details. |
| """ |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch_geometric.nn import MessagePassing |
| from typing import Optional, Tuple, Dict |
| import math |
|
|
|
|
class QueryEncoder(nn.Module):
    """Encode a batch of token-id sequences into one vector per sequence.

    Token embeddings and learned position embeddings are summed, run through a
    pre-norm Transformer encoder, averaged over the non-padding positions, and
    finally projected and layer-normalized.

    Args:
        vocab_size: Number of rows in the token embedding table.
        embed_dim: Width of the embeddings and of the returned vectors.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        max_seq_len: Number of rows in the position embedding table; longer
            inputs are rejected by ``forward``.
        dropout: Dropout probability used inside each encoder layer.
    """

    def __init__(self, vocab_size=32000, embed_dim=256, num_heads=8, num_layers=4,
                 max_seq_len=512, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(max_seq_len, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True,
        )
        # norm_first=True already rules out the nested-tensor path; disabling
        # it explicitly keeps the output length equal to the input length.
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers, enable_nested_tensor=False)
        self.output_proj = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, token_ids, attention_mask=None):
        """Return one pooled embedding per input sequence.

        Args:
            token_ids: 2-D integer tensor ``(batch, seq_len)`` of token ids.
            attention_mask: Optional ``(batch, seq_len)`` tensor whose non-zero
                entries mark real tokens and zero entries mark padding.

        Returns:
            Tensor of shape ``(batch, embed_dim)``.

        Raises:
            IndexError: If ``seq_len`` exceeds ``max_seq_len``.
        """
        batch_size, seq_len = token_ids.shape
        if seq_len > self.max_seq_len:
            # Fail early with a readable message instead of an out-of-range
            # lookup inside the position embedding.
            raise IndexError(
                f"sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}")

        positions = torch.arange(seq_len, device=token_ids.device)
        positions = positions.unsqueeze(0).expand(batch_size, -1)
        x = self.token_embedding(token_ids) + self.position_embedding(positions)

        if attention_mask is None:
            pooled = self.transformer(x).mean(dim=1)
        else:
            keep = attention_mask.bool()
            x = self.transformer(x, src_key_padding_mask=~keep)
            keep = keep.unsqueeze(-1)
            # masked_fill (rather than multiplying by the mask) also discards
            # NaNs that attention can emit for rows made up only of padding.
            x = x.masked_fill(~keep, 0.0)
            token_count = keep.sum(dim=1).clamp(min=1).to(x.dtype)
            pooled = x.sum(dim=1) / token_count
        return self.layer_norm(self.output_proj(pooled))
|
|
|
|
class RelationAwareMessagePassing(MessagePassing):
    """Sum-aggregating message-passing layer with a learned update gate.

    Each edge produces a message from the neighbour state concatenated with
    the edge features. The summed messages are blended with the previous node
    state through a sigmoid gate, and the blend is layer-normalized.

    Args:
        hidden_dim: Width of the node states.
        edge_dim: Width of the edge feature vectors.
        num_relation_types: Accepted for signature compatibility; this layer
            does not use it.
    """

    def __init__(self, hidden_dim, edge_dim, num_relation_types=256):
        super().__init__(aggr='add')
        message_layers = [
            nn.Linear(hidden_dim + edge_dim, hidden_dim * 2),
            nn.GELU(),
            nn.Linear(hidden_dim * 2, hidden_dim),
        ]
        self.message_mlp = nn.Sequential(*message_layers)
        gate_layers = [nn.Linear(hidden_dim * 2, hidden_dim), nn.Sigmoid()]
        self.gate_mlp = nn.Sequential(*gate_layers)
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(self, x, edge_index, edge_attr):
        """Run one round of message passing and return the updated node states."""
        aggregated = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        gate_input = torch.cat([x, aggregated], dim=-1)
        gate = self.gate_mlp(gate_input)
        # Convex mix of the aggregated messages and the previous state.
        blended = gate * aggregated + (1 - gate) * x
        return self.layer_norm(blended)

    def message(self, x_j, edge_attr):
        """Build the per-edge message from the neighbour state and edge features."""
        neighbour_and_edge = torch.cat([x_j, edge_attr], dim=-1)
        return self.message_mlp(neighbour_and_edge)
|
|
|
|
class GraphNavigator(nn.Module):
    """Stack of message-passing layers, each followed by a query-conditioned update.

    Args:
        hidden_dim: Width of node states and of the query vector.
        edge_dim: Width of the edge feature vectors.
        num_layers: Number of message-passing / attention rounds.
        num_relation_types: Forwarded to every message-passing layer.
        num_heads: Attention heads in each query-attention module.
    """

    def __init__(self, hidden_dim=256, edge_dim=64, num_layers=6, num_relation_types=256,
                 num_heads=8):
        super().__init__()
        self.num_layers = num_layers
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.mp_layers = nn.ModuleList([
            RelationAwareMessagePassing(hidden_dim, edge_dim, num_relation_types)
            for _ in range(num_layers)])
        self.query_attention = nn.ModuleList([
            nn.MultiheadAttention(hidden_dim, num_heads=num_heads, batch_first=True)
            for _ in range(num_layers)])
        self.query_norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_layers)])
        self.relevance_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1))

    def forward(self, node_features, edge_index, edge_attr, query, start_node_mask=None):
        """Propagate over the graph and score every node.

        Args:
            node_features: Per-node state tensor, one row per node.
            edge_index: Edge endpoints, passed through to the message-passing layers.
            edge_attr: Per-edge features, passed through to the message-passing layers.
            query: Query vector, either 1-D or 2-D with a leading batch axis.
                NOTE(review): the reshapes below only line up when the query
                holds a single vector -- confirm against callers.
            start_node_mask: Optional per-node mask; nodes with a non-zero entry
                get the projected query added to their initial state.

        Returns:
            Tuple ``(node_states, relevance)`` where ``relevance`` has one
            trailing column per node.
        """
        # Work on a copy: the start-node injection below writes in place.
        h = node_features.clone()
        if start_node_mask is not None:
            seed = self.query_proj(query)
            if seed.dim() == 2:
                seed = seed.squeeze(0)
            is_start = start_node_mask.bool()
            h[is_start] = h[is_start] + seed

        # The key/value tensor depends only on the query, so build it once.
        if query.dim() == 2:
            query_kv = query.unsqueeze(1)
        else:
            query_kv = query.unsqueeze(0).unsqueeze(1)

        layers = zip(self.mp_layers, self.query_attention, self.query_norms)
        for mp_layer, attention, norm in layers:
            h = mp_layer(h, edge_index, edge_attr)
            # Nodes act as attention queries; the text query is the key/value.
            attended, _ = attention(h.unsqueeze(0), query_kv, query_kv)
            h = norm(h + attended.squeeze(0))
        return h, self.relevance_head(h)
|
|
|
|
class NodeCreator(nn.Module):
    """Propose a new node and score how it would attach to the existing nodes.

    Args:
        hidden_dim: Width of node states and of the query vector.
        edge_dim: Width of the generated edge feature vectors.
    """

    def __init__(self, hidden_dim=256, edge_dim=64):
        super().__init__()
        pair_dim = hidden_dim * 2
        self.coverage_head = nn.Sequential(
            nn.Linear(pair_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )
        self.node_generator = nn.Sequential(
            nn.Linear(pair_dim, pair_dim),
            nn.GELU(),
            nn.Linear(pair_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
        )
        self.edge_generator = nn.Sequential(
            nn.Linear(pair_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, edge_dim),
        )
        self.connection_scorer = nn.Sequential(
            nn.Linear(pair_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, node_features, query, relevance_scores):
        """Compute coverage, a candidate node, and per-node attachment outputs.

        Args:
            node_features: ``(num_nodes, hidden_dim)`` node states.
            query: Query vector, 1-D or 2-D with a leading axis of size one.
            relevance_scores: Per-node scores with a trailing singleton axis.

        Returns:
            Tuple ``(coverage, new_node, connection_scores, edge_features)``:
            a sigmoid coverage value, the generated node with a leading axis of
            one, one connection logit per existing node, and one generated edge
            feature vector per existing node.
        """
        query_vec = query.squeeze(0) if query.dim() == 2 else query

        # Relevance-weighted summary of the whole graph.
        node_weights = torch.softmax(relevance_scores.squeeze(-1), dim=0)
        graph_summary = (node_features * node_weights.unsqueeze(-1)).sum(dim=0)

        summary_and_query = torch.cat([graph_summary, query_vec], dim=-1)
        coverage = self.coverage_head(summary_and_query)
        new_node = self.node_generator(summary_and_query).unsqueeze(0)

        # Pair every existing node with the query to score a possible link.
        num_nodes = node_features.shape[0]
        tiled_query = query_vec.unsqueeze(0).expand(num_nodes, -1)
        node_and_query = torch.cat([node_features, tiled_query], dim=-1)
        connection_scores = self.connection_scorer(node_and_query).squeeze(-1)
        edge_features = self.edge_generator(node_and_query)
        return coverage, new_node, connection_scores, edge_features
|
|
|
|
class EdgePredictor(nn.Module):
    """Score candidate edges by rotating the head embedding in the complex plane.

    Projected node states are split into real and imaginary halves. The head
    is rotated by a learned per-relation phase, and the score is ``gamma``
    minus the summed modulus of ``rotated_head - tail``.

    Args:
        hidden_dim: Width of the node states; must be even so it splits into
            equal real and imaginary halves.
        edge_dim: Width of the generated edge feature vectors.
        num_relation_types: Number of rows in the relation phase table.
        gamma: Margin from which the distance is subtracted.

    Raises:
        ValueError: If ``hidden_dim`` is odd.
    """

    def __init__(self, hidden_dim=256, edge_dim=64, num_relation_types=256, gamma=12.0):
        super().__init__()
        if hidden_dim % 2 != 0:
            # An odd width would give unequal real/imaginary halves and a
            # shape error on the first scoring call.
            raise ValueError(f"hidden_dim must be even, got {hidden_dim}")
        self.gamma = gamma
        self.complex_dim = hidden_dim // 2
        self.head_proj = nn.Linear(hidden_dim, hidden_dim)
        self.tail_proj = nn.Linear(hidden_dim, hidden_dim)
        self.relation_phases = nn.Embedding(num_relation_types, self.complex_dim)
        nn.init.uniform_(self.relation_phases.weight, 0, 2 * math.pi)
        self.edge_feat_gen = nn.Sequential(
            nn.Linear(self.complex_dim, edge_dim * 2),
            nn.GELU(),
            nn.Linear(edge_dim * 2, edge_dim),
        )

    def score_edges(self, head, tail, relation_ids):
        """Return one plausibility score per (head, relation, tail) triple.

        Args:
            head: Head node states, last axis of size ``hidden_dim``.
            tail: Tail node states, same shape as ``head``.
            relation_ids: Integer relation indices matching the leading axes
                of ``head``.

        Returns:
            Tensor with the leading shape of ``head``; larger means the rotated
            head lies closer to the tail.
        """
        h = self.head_proj(head)
        t = self.tail_proj(tail)
        cd = self.complex_dim
        # Ellipsis indexing keeps this valid for any number of leading axes.
        re_h, im_h = h[..., :cd], h[..., cd:]
        re_t, im_t = t[..., :cd], t[..., cd:]

        phase = self.relation_phases(relation_ids)
        re_r, im_r = torch.cos(phase), torch.sin(phase)

        # (re_h + i*im_h) * (re_r + i*im_r) - (re_t + i*im_t)
        re_diff = re_h * re_r - im_h * im_r - re_t
        im_diff = re_h * im_r + im_h * re_r - im_t

        modulus = torch.linalg.vector_norm(torch.stack([re_diff, im_diff], dim=0), dim=0)
        return self.gamma - modulus.sum(dim=-1)

    def forward(self, node_features, candidate_edges, relation_ids):
        """Score candidate edges and generate a feature vector for each.

        Args:
            node_features: Node states indexed by the entries of ``candidate_edges``.
            candidate_edges: Index tensor whose row 0 holds head indices and
                row 1 holds tail indices.
            relation_ids: Integer relation index per candidate edge.

        Returns:
            Tuple ``(scores, edge_features)``; the features are generated from
            the relation phases alone.
        """
        head = node_features[candidate_edges[0]]
        tail = node_features[candidate_edges[1]]
        scores = self.score_edges(head, tail, relation_ids)
        edge_features = self.edge_feat_gen(self.relation_phases(relation_ids))
        return scores, edge_features
|
|
|
|
class SubgraphExtractor(nn.Module):
    """Select the nodes and edges of an answer subgraph and give nodes an ordering score.

    Args:
        hidden_dim: Width of node states and of the query vector.
    """

    def __init__(self, hidden_dim=256):
        super().__init__()
        self.node_selector = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1))
        self.edge_selector = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1))
        self.order_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1))

    def forward(self, node_features, edge_index, edge_attr, query, relevance_scores):
        """Compute selection probabilities and ordering scores.

        Args:
            node_features: ``(num_nodes, hidden_dim)`` node states.
            edge_index: Index tensor whose row 0 holds source indices and row 1
                holds destination indices.
            edge_attr: Accepted for signature compatibility; not used here.
            query: Query vector, 1-D or 2-D with a leading axis of size one.
            relevance_scores: Per-node scores with a trailing singleton axis,
                added to the node selection logits.

        Returns:
            Tuple ``(node_probs, edge_probs, order)``: one probability per node,
            one per edge, and one unbounded ordering score per node.
        """
        if query.dim() == 2:
            query = query.squeeze(0)
        num_nodes = node_features.shape[0]
        num_edges = edge_index.shape[1]

        node_input = torch.cat(
            [node_features, query.unsqueeze(0).expand(num_nodes, -1)], dim=-1)
        node_logits = self.node_selector(node_input).squeeze(-1)
        node_probs = torch.sigmoid(node_logits + relevance_scores.squeeze(-1))

        if num_edges > 0:
            src, dst = edge_index[0], edge_index[1]
            edge_input = torch.cat(
                [node_features[src], node_features[dst],
                 query.unsqueeze(0).expand(num_edges, -1)], dim=-1)
            edge_logits = self.edge_selector(edge_input).squeeze(-1)
            # An edge can only be kept if both of its endpoints are kept.
            edge_probs = torch.sigmoid(edge_logits) * node_probs[src] * node_probs[dst]
        else:
            # new_zeros keeps dtype and device aligned with the non-empty branch.
            edge_probs = node_features.new_zeros(0)

        return node_probs, edge_probs, self.order_head(node_features).squeeze(-1)
|
|
|
|
class GraphReasoningNetwork(nn.Module):
    """End-to-end model: encode a query, navigate a graph, and extract a subgraph.

    Args:
        config: Mapping read with ``.get``. Recognised keys are ``hidden_dim``
            (default 256), ``edge_dim`` (64), ``vocab_size`` (32000),
            ``num_encoder_layers`` (4), ``num_nav_layers`` (6) and
            ``num_relation_types`` (256).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        hd = config.get('hidden_dim', 256)
        ed = config.get('edge_dim', 64)
        num_relations = config.get('num_relation_types', 256)
        self.hidden_dim = hd
        self.edge_dim = ed
        self.query_encoder = QueryEncoder(
            config.get('vocab_size', 32000), hd, 8, config.get('num_encoder_layers', 4))
        self.navigator = GraphNavigator(hd, ed, config.get('num_nav_layers', 6), num_relations)
        self.node_creator = NodeCreator(hd, ed)
        self.edge_predictor = EdgePredictor(hd, ed, num_relations)
        self.subgraph_extractor = SubgraphExtractor(hd)
        self.node_input_proj = nn.Linear(hd, hd)
        self.edge_input_proj = nn.Linear(ed, ed)

    def forward(self, token_ids, attention_mask, node_features, edge_index, edge_attr,
                start_node_mask, target_node_mask=None, target_edge_mask=None,
                target_new_nodes=None, candidate_edges=None, candidate_edge_labels=None,
                candidate_edge_relations=None):
        """Run every sub-module and, when targets are given, compute the losses.

        Args:
            token_ids: Query token ids, passed to the query encoder.
            attention_mask: Padding mask for ``token_ids``.
            node_features: Per-node input features.
            edge_index: Index tensor whose row 0 holds source indices and row 1
                holds destination indices.
            edge_attr: Per-edge input features.
            start_node_mask: Per-node mask of the nodes where navigation starts.
            target_node_mask: Optional per-node selection targets.
            target_edge_mask: Optional per-edge selection targets.
            target_new_nodes: Optional regression target for the generated node.
            candidate_edges: Optional index tensor of candidate edges to score.
            candidate_edge_labels: Optional binary labels for ``candidate_edges``.
            candidate_edge_relations: Relation id per candidate edge; required
                whenever ``candidate_edges`` is non-empty.

        Returns:
            Dict of model outputs. It always contains ``'losses'`` (possibly
            empty); ``'total_loss'`` is present only when at least one loss was
            computed.

        Raises:
            ValueError: If candidate edges are given without relation ids.
        """
        results = {}
        query = self.query_encoder(token_ids, attention_mask)
        h = self.node_input_proj(node_features)
        e = self.edge_input_proj(edge_attr) if edge_attr.shape[0] > 0 else edge_attr

        h_nav, relevance = self.navigator(h, edge_index, e, query, start_node_mask)
        results.update({'node_features': h_nav, 'relevance_scores': relevance})

        coverage, new_node, connection_scores, new_edge_features = self.node_creator(
            h_nav, query, relevance)
        results.update({
            'coverage': coverage,
            'new_node_features': new_node,
            'connection_scores': connection_scores,
            # Previously computed and then dropped; returned so callers can use it.
            'new_edge_features': new_edge_features,
        })

        if candidate_edges is not None and candidate_edges.shape[1] > 0:
            if candidate_edge_relations is None:
                raise ValueError(
                    "candidate_edge_relations is required when candidate_edges is given")
            edge_scores, predicted_edge_features = self.edge_predictor(
                h_nav, candidate_edges, candidate_edge_relations)
            results['edge_scores'] = edge_scores
            results['predicted_edge_features'] = predicted_edge_features

        node_probs, edge_probs, order = self.subgraph_extractor(
            h_nav, edge_index, e, query, relevance)
        results.update({
            'node_selection_probs': node_probs,
            'edge_selection_probs': edge_probs,
            'topological_order': order,
        })

        losses = {}
        if target_node_mask is not None:
            losses['node_selection_loss'] = F.binary_cross_entropy(
                node_probs, target_node_mask.float())
        if target_edge_mask is not None and edge_probs.numel() > 0:
            losses['edge_selection_loss'] = F.binary_cross_entropy(
                edge_probs, target_edge_mask.float())
        if target_new_nodes is not None:
            losses['node_creation_loss'] = F.mse_loss(new_node, target_new_nodes)
        if candidate_edge_labels is not None and 'edge_scores' in results:
            losses['edge_prediction_loss'] = F.binary_cross_entropy_with_logits(
                results['edge_scores'], candidate_edge_labels.float())
        if target_node_mask is not None:
            # Coverage is regressed onto the fraction of nodes marked as targets.
            losses['coverage_loss'] = F.mse_loss(
                coverage.squeeze(), target_node_mask.float().mean())
        if edge_index.shape[1] > 0 and target_edge_mask is not None:
            # Hinge penalty when a target edge's source is not ordered at least
            # 0.1 before its destination.
            violation = F.relu(order[edge_index[0]] - order[edge_index[1]] + 0.1)
            losses['dag_ordering_loss'] = (violation * target_edge_mask.float()).mean()

        results['losses'] = losses
        if losses:
            results['total_loss'] = sum(losses.values())
        return results

    def count_parameters(self):
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    @torch.no_grad()
    def reason(self, token_ids, attention_mask, node_features, edge_index, edge_attr,
               start_node_mask, node_threshold=0.5, edge_threshold=0.3, create_threshold=0.5):
        """Run inference and threshold the outputs into a selected subgraph.

        The forward pass runs in eval mode without gradients; the module's
        previous train/eval mode is restored afterwards.

        Args:
            token_ids: Query token ids.
            attention_mask: Padding mask for ``token_ids``.
            node_features: Per-node input features.
            edge_index: Index tensor of source (row 0) and destination (row 1) nodes.
            edge_attr: Per-edge input features.
            start_node_mask: Per-node mask of the navigation start nodes.
            node_threshold: Minimum probability for a node to be selected.
            edge_threshold: Minimum probability for an edge to be selected.
            create_threshold: A new node is requested when coverage falls below this.

        Returns:
            Dict with boolean ``'selected_nodes'`` / ``'selected_edges'`` masks,
            the raw scores they were derived from, the generated node, and the
            ``'should_create_node'`` flag.
        """
        was_training = self.training
        self.eval()
        try:
            r = self.forward(token_ids, attention_mask, node_features, edge_index,
                             edge_attr, start_node_mask)
        finally:
            # Do not leave a training run silently stuck in eval mode.
            self.train(was_training)

        node_probs = r['node_selection_probs']
        edge_probs = r['edge_selection_probs']
        order = r['topological_order']

        selected_nodes = node_probs > node_threshold
        src, dst = edge_index[0], edge_index[1]
        # Keep an edge only if it clears the threshold, both endpoints are
        # selected, and it points forward in the predicted ordering. Deriving
        # the mask from the inputs keeps it on their device even with no edges.
        selected_edges = (
            (edge_probs > edge_threshold)
            & selected_nodes[src]
            & selected_nodes[dst]
            & (order[src] < order[dst])
        )

        return {
            'selected_nodes': selected_nodes,
            'selected_edges': selected_edges,
            'node_scores': node_probs,
            'edge_scores': edge_probs,
            'topological_order': order,
            'relevance_scores': r['relevance_scores'],
            'coverage': r['coverage'],
            'new_node_features': r['new_node_features'],
            'connection_scores': r['connection_scores'],
            'should_create_node': r['coverage'].item() < create_threshold,
        }
|
|