""" Graph Reasoning Network (GRN) - Core Model An LLM alternative that operates on knowledge graphs instead of text tokens. See the full docstrings in the training script for architecture details. """ import torch import torch.nn as nn import torch.nn.functional as F from torch_geometric.nn import MessagePassing from typing import Optional, Tuple, Dict import math class QueryEncoder(nn.Module): def __init__(self, vocab_size=32000, embed_dim=256, num_heads=8, num_layers=4, max_seq_len=512): super().__init__() self.embed_dim = embed_dim self.token_embedding = nn.Embedding(vocab_size, embed_dim) self.position_embedding = nn.Embedding(max_seq_len, embed_dim) encoder_layer = nn.TransformerEncoderLayer( d_model=embed_dim, nhead=num_heads, dim_feedforward=embed_dim * 4, dropout=0.1, activation='gelu', batch_first=True, norm_first=True) self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) self.output_proj = nn.Linear(embed_dim, embed_dim) self.layer_norm = nn.LayerNorm(embed_dim) def forward(self, token_ids, attention_mask=None): B, L = token_ids.shape positions = torch.arange(L, device=token_ids.device).unsqueeze(0).expand(B, -1) x = self.token_embedding(token_ids) + self.position_embedding(positions) mask = ~attention_mask.bool() if attention_mask is not None else None x = self.transformer(x, src_key_padding_mask=mask) if attention_mask is not None: m = attention_mask.unsqueeze(-1).float() pooled = (x * m).sum(dim=1) / m.sum(dim=1).clamp(min=1) else: pooled = x.mean(dim=1) return self.layer_norm(self.output_proj(pooled)) class RelationAwareMessagePassing(MessagePassing): def __init__(self, hidden_dim, edge_dim, num_relation_types=256): super().__init__(aggr='add') self.message_mlp = nn.Sequential( nn.Linear(hidden_dim + edge_dim, hidden_dim * 2), nn.GELU(), nn.Linear(hidden_dim * 2, hidden_dim)) self.gate_mlp = nn.Sequential(nn.Linear(hidden_dim * 2, hidden_dim), nn.Sigmoid()) self.layer_norm = nn.LayerNorm(hidden_dim) def forward(self, x, edge_index, edge_attr): msg = self.propagate(edge_index, x=x, edge_attr=edge_attr) gate = self.gate_mlp(torch.cat([x, msg], dim=-1)) return self.layer_norm(gate * msg + (1 - gate) * x) def message(self, x_j, edge_attr): return self.message_mlp(torch.cat([x_j, edge_attr], dim=-1)) class GraphNavigator(nn.Module): def __init__(self, hidden_dim=256, edge_dim=64, num_layers=6, num_relation_types=256): super().__init__() self.num_layers = num_layers self.query_proj = nn.Linear(hidden_dim, hidden_dim) self.mp_layers = nn.ModuleList([ RelationAwareMessagePassing(hidden_dim, edge_dim, num_relation_types) for _ in range(num_layers)]) self.query_attention = nn.ModuleList([ nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True) for _ in range(num_layers)]) self.query_norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_layers)]) self.relevance_head = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1)) def forward(self, node_features, edge_index, edge_attr, query, start_node_mask=None): h = node_features.clone() if start_node_mask is not None: qi = self.query_proj(query) if qi.dim() == 2: qi = qi.squeeze(0) h[start_node_mask.bool()] = h[start_node_mask.bool()] + qi for i in range(self.num_layers): h = self.mp_layers[i](h, edge_index, edge_attr) hu = h.unsqueeze(0) qu = query.unsqueeze(1) if query.dim() == 2 else query.unsqueeze(0).unsqueeze(1) att, _ = self.query_attention[i](hu, qu, qu) h = self.query_norms[i](h + att.squeeze(0)) return h, self.relevance_head(h) class NodeCreator(nn.Module): def __init__(self, hidden_dim=256, edge_dim=64): super().__init__() self.coverage_head = nn.Sequential( nn.Linear(hidden_dim * 2, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1), nn.Sigmoid()) self.node_generator = nn.Sequential( nn.Linear(hidden_dim * 2, hidden_dim * 2), nn.GELU(), nn.Linear(hidden_dim * 2, hidden_dim), nn.LayerNorm(hidden_dim)) self.edge_generator = nn.Sequential( nn.Linear(hidden_dim * 2, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, edge_dim)) self.connection_scorer = nn.Sequential( nn.Linear(hidden_dim * 2, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1)) def forward(self, node_features, query, relevance_scores): if query.dim() == 2: query = query.squeeze(0) w = torch.softmax(relevance_scores.squeeze(-1), dim=0) gs = (node_features * w.unsqueeze(-1)).sum(dim=0) coverage = self.coverage_head(torch.cat([gs, query], dim=-1)) new_node = self.node_generator(torch.cat([gs, query], dim=-1)).unsqueeze(0) qe = query.unsqueeze(0).expand(node_features.shape[0], -1) ci = torch.cat([node_features, qe], dim=-1) return coverage, new_node, self.connection_scorer(ci).squeeze(-1), self.edge_generator(ci) class EdgePredictor(nn.Module): def __init__(self, hidden_dim=256, edge_dim=64, num_relation_types=256, gamma=12.0): super().__init__() self.gamma = gamma self.complex_dim = hidden_dim // 2 self.head_proj = nn.Linear(hidden_dim, hidden_dim) self.tail_proj = nn.Linear(hidden_dim, hidden_dim) self.relation_phases = nn.Embedding(num_relation_types, self.complex_dim) nn.init.uniform_(self.relation_phases.weight, 0, 2 * math.pi) self.edge_feat_gen = nn.Sequential( nn.Linear(self.complex_dim, edge_dim * 2), nn.GELU(), nn.Linear(edge_dim * 2, edge_dim)) def score_edges(self, head, tail, relation_ids): h, t = self.head_proj(head), self.tail_proj(tail) cd = self.complex_dim re_h, im_h = h[:, :cd], h[:, cd:] re_t, im_t = t[:, :cd], t[:, cd:] phase = self.relation_phases(relation_ids) re_r, im_r = torch.cos(phase), torch.sin(phase) re_s = re_h * re_r - im_h * im_r - re_t im_s = re_h * im_r + im_h * re_r - im_t return self.gamma - torch.norm(torch.stack([re_s, im_s], dim=0), dim=0).sum(dim=-1) def forward(self, node_features, candidate_edges, relation_ids): scores = self.score_edges(node_features[candidate_edges[0]], node_features[candidate_edges[1]], relation_ids) return scores, self.edge_feat_gen(self.relation_phases(relation_ids)) class SubgraphExtractor(nn.Module): def __init__(self, hidden_dim=256): super().__init__() self.node_selector = nn.Sequential( nn.Linear(hidden_dim * 2, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1)) self.edge_selector = nn.Sequential( nn.Linear(hidden_dim * 3, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1)) self.order_head = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1)) def forward(self, node_features, edge_index, edge_attr, query, relevance_scores): if query.dim() == 2: query = query.squeeze(0) N = node_features.shape[0] qe = query.unsqueeze(0).expand(N, -1) nl = self.node_selector(torch.cat([node_features, qe], dim=-1)).squeeze(-1) np_ = torch.sigmoid(nl + relevance_scores.squeeze(-1)) if edge_index.shape[1] > 0: sf = node_features[edge_index[0]] df = node_features[edge_index[1]] qee = query.unsqueeze(0).expand(edge_index.shape[1], -1) el = self.edge_selector(torch.cat([sf, df, qee], dim=-1)).squeeze(-1) ep = torch.sigmoid(el) * np_[edge_index[0]] * np_[edge_index[1]] else: ep = torch.zeros(0, device=node_features.device) return np_, ep, self.order_head(node_features).squeeze(-1) class GraphReasoningNetwork(nn.Module): def __init__(self, config): super().__init__() self.config = config hd = config.get('hidden_dim', 256) ed = config.get('edge_dim', 64) self.hidden_dim = hd self.edge_dim = ed self.query_encoder = QueryEncoder(config.get('vocab_size', 32000), hd, 8, config.get('num_encoder_layers', 4)) self.navigator = GraphNavigator(hd, ed, config.get('num_nav_layers', 6), config.get('num_relation_types', 256)) self.node_creator = NodeCreator(hd, ed) self.edge_predictor = EdgePredictor(hd, ed, config.get('num_relation_types', 256)) self.subgraph_extractor = SubgraphExtractor(hd) self.node_input_proj = nn.Linear(hd, hd) self.edge_input_proj = nn.Linear(ed, ed) def forward(self, token_ids, attention_mask, node_features, edge_index, edge_attr, start_node_mask, target_node_mask=None, target_edge_mask=None, target_new_nodes=None, candidate_edges=None, candidate_edge_labels=None, candidate_edge_relations=None): results = {} query = self.query_encoder(token_ids, attention_mask) h = self.node_input_proj(node_features) e = self.edge_input_proj(edge_attr) if edge_attr.shape[0] > 0 else edge_attr h_nav, rel = self.navigator(h, edge_index, e, query, start_node_mask) results.update({'node_features': h_nav, 'relevance_scores': rel}) cov, nn_, cs, nef = self.node_creator(h_nav, query, rel) results.update({'coverage': cov, 'new_node_features': nn_, 'connection_scores': cs}) if candidate_edges is not None and candidate_edges.shape[1] > 0: es, pef = self.edge_predictor(h_nav, candidate_edges, candidate_edge_relations) results['edge_scores'] = es np_, ep, to = self.subgraph_extractor(h_nav, edge_index, e, query, rel) results.update({'node_selection_probs': np_, 'edge_selection_probs': ep, 'topological_order': to}) losses = {} if target_node_mask is not None: losses['node_selection_loss'] = F.binary_cross_entropy(np_, target_node_mask.float()) if target_edge_mask is not None and ep.numel() > 0: losses['edge_selection_loss'] = F.binary_cross_entropy(ep, target_edge_mask.float()) if target_new_nodes is not None: losses['node_creation_loss'] = F.mse_loss(nn_, target_new_nodes) if candidate_edge_labels is not None and 'edge_scores' in results: losses['edge_prediction_loss'] = F.binary_cross_entropy_with_logits(results['edge_scores'], candidate_edge_labels.float()) if target_node_mask is not None: losses['coverage_loss'] = F.mse_loss(cov.squeeze(), target_node_mask.float().mean()) if edge_index.shape[1] > 0 and target_edge_mask is not None: ov = F.relu(to[edge_index[0]] - to[edge_index[1]] + 0.1) losses['dag_ordering_loss'] = (ov * target_edge_mask.float()).mean() results['losses'] = losses if losses: results['total_loss'] = sum(losses.values()) return results def count_parameters(self): return sum(p.numel() for p in self.parameters() if p.requires_grad) @torch.no_grad() def reason(self, token_ids, attention_mask, node_features, edge_index, edge_attr, start_node_mask, node_threshold=0.5, edge_threshold=0.3, create_threshold=0.5): self.eval() r = self.forward(token_ids, attention_mask, node_features, edge_index, edge_attr, start_node_mask) sn = r['node_selection_probs'] > node_threshold se = r['edge_selection_probs'] > edge_threshold if r['edge_selection_probs'].numel() > 0 else torch.zeros(0, dtype=torch.bool) order = r['topological_order'] if se.any(): v = sn[edge_index[0][se]] & sn[edge_index[1][se]] sf = se.clone(); sf[se] = v else: sf = se if sf.any(): em = sf.nonzero(as_tuple=True)[0] fwd = order[edge_index[0][em]] < order[edge_index[1][em]] sd = torch.zeros_like(sf); sd[em[fwd]] = True else: sd = sf return {'selected_nodes': sn, 'selected_edges': sd, 'node_scores': r['node_selection_probs'], 'edge_scores': r['edge_selection_probs'], 'topological_order': order, 'relevance_scores': r['relevance_scores'], 'coverage': r['coverage'], 'new_node_features': r['new_node_features'], 'connection_scores': r['connection_scores'], 'should_create_node': r['coverage'].item() < create_threshold}