Adam-Ben-Khalifa's picture
Upload grn/model.py
f7dcd05 verified
"""
Graph Reasoning Network (GRN) - Core Model
An LLM alternative that operates on knowledge graphs instead of text tokens.
See the full docstrings in the training script for architecture details.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from typing import Optional, Tuple, Dict
import math
class QueryEncoder(nn.Module):
    """Encode a tokenized query into one fixed-size vector.

    Token embeddings plus learned absolute position embeddings are run through
    a pre-norm Transformer encoder, mean-pooled over the non-padding positions,
    then linearly projected and layer-normalized.

    Args:
        vocab_size: Number of rows in the token embedding table.
        embed_dim: Width of the embeddings and of the returned vector.
        num_heads: Attention heads per encoder layer (must divide ``embed_dim``).
        num_layers: Number of stacked Transformer encoder layers.
        max_seq_len: Size of the position table; inputs longer than this
            cannot be embedded.
    """

    def __init__(self, vocab_size=32000, embed_dim=256, num_heads=8, num_layers=4, max_seq_len=512):
        super().__init__()
        self.embed_dim = embed_dim
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(max_seq_len, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=embed_dim * 4,
            dropout=0.1, activation='gelu', batch_first=True, norm_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_proj = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, token_ids, attention_mask=None):
        """Return one query vector per batch row.

        Args:
            token_ids: Integer tensor of shape ``(B, L)`` with ``L <= max_seq_len``.
            attention_mask: Optional ``(B, L)`` tensor; zero marks padding. Its
                values are also used as pooling weights.

        Returns:
            Tensor of shape ``(B, embed_dim)``.
        """
        B, L = token_ids.shape
        positions = torch.arange(L, device=token_ids.device).unsqueeze(0).expand(B, -1)
        x = self.token_embedding(token_ids) + self.position_embedding(positions)
        if attention_mask is None:
            x = self.transformer(x)
            pooled = x.mean(dim=1)
        else:
            pad = ~attention_mask.bool()
            x = self.transformer(x, src_key_padding_mask=pad)
            # A row that is entirely padding has every attention key masked,
            # which can make the encoder emit NaN for that row. Multiplying by
            # the mask alone would not help (NaN * 0 == NaN), so zero padded
            # positions explicitly before the weighted sum.
            x = x.masked_fill(pad.unsqueeze(-1), 0.0)
            m = attention_mask.unsqueeze(-1).float()
            pooled = (x * m).sum(dim=1) / m.sum(dim=1).clamp(min=1)
        return self.layer_norm(self.output_proj(pooled))
class RelationAwareMessagePassing(MessagePassing):
    """One round of gated, edge-conditioned message passing.

    For every edge, a message is computed from the source node's features
    concatenated with the edge features, and messages are summed per node
    (``aggr='add'``). A sigmoid gate computed from ``[x, aggregated]`` then
    interpolates between the aggregated message and the node's previous
    state, and the result is layer-normalized.

    Args:
        hidden_dim: Node feature width.
        edge_dim: Edge feature width.
        num_relation_types: Accepted for signature compatibility; not used here.
    """

    def __init__(self, hidden_dim, edge_dim, num_relation_types=256):
        super().__init__(aggr='add')
        wide = hidden_dim * 2
        self.message_mlp = nn.Sequential(
            nn.Linear(hidden_dim + edge_dim, wide),
            nn.GELU(),
            nn.Linear(wide, hidden_dim),
        )
        self.gate_mlp = nn.Sequential(nn.Linear(wide, hidden_dim), nn.Sigmoid())
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(self, x, edge_index, edge_attr):
        """Return updated node features with the same shape as ``x``."""
        aggregated = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        gate = self.gate_mlp(torch.cat([x, aggregated], dim=-1))
        mixed = gate * aggregated + (1 - gate) * x
        return self.layer_norm(mixed)

    def message(self, x_j, edge_attr):
        # PyG binds ``x_j`` to the per-edge source-node slice of ``x``; the
        # parameter names must match the ``propagate`` keyword arguments.
        return self.message_mlp(torch.cat([x_j, edge_attr], dim=-1))
class GraphNavigator(nn.Module):
    """Query-conditioned GNN that refines node states and scores node relevance.

    The projected query is first added to the features of the start nodes.
    Then ``num_layers`` rounds follow, each consisting of message passing,
    cross-attention from every node to the query, and a residual LayerNorm.
    A small MLP head finally produces one relevance logit per node.

    Args:
        hidden_dim: Node / query feature width.
        edge_dim: Edge feature width.
        num_layers: Number of message-passing + attention rounds.
        num_relation_types: Forwarded to each message-passing layer.
    """

    def __init__(self, hidden_dim=256, edge_dim=64, num_layers=6, num_relation_types=256):
        super().__init__()
        self.num_layers = num_layers
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.mp_layers = nn.ModuleList(
            RelationAwareMessagePassing(hidden_dim, edge_dim, num_relation_types)
            for _ in range(num_layers))
        self.query_attention = nn.ModuleList(
            nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)
            for _ in range(num_layers))
        self.query_norms = nn.ModuleList(nn.LayerNorm(hidden_dim) for _ in range(num_layers))
        self.relevance_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1))

    def forward(self, node_features, edge_index, edge_attr, query, start_node_mask=None):
        """Return ``(node_states, relevance_logits)``; logits have a trailing dim of 1.

        ``node_features`` is treated as a single graph of shape ``(N, hidden_dim)``
        (it is unsqueezed to a batch of one for attention). ``query`` may be
        ``(hidden_dim,)`` or ``(1, hidden_dim)`` -- assumes a batch of one; a larger
        query batch would not match the single-graph attention batch.
        """
        states = node_features.clone()
        if start_node_mask is not None:
            injected = self.query_proj(query)
            if injected.dim() == 2:
                injected = injected.squeeze(0)
            starts = start_node_mask.bool()
            states[starts] = states[starts] + injected

        # Keys/values for cross-attention: a length-1 sequence holding the query.
        # NOTE(review): with a single key the softmax weight is always 1, so the
        # attention output is the same vector for every node.
        if query.dim() == 2:
            query_seq = query.unsqueeze(1)
        else:
            query_seq = query.unsqueeze(0).unsqueeze(1)

        for mp, attn, norm in zip(self.mp_layers, self.query_attention, self.query_norms):
            states = mp(states, edge_index, edge_attr)
            attended, _ = attn(states.unsqueeze(0), query_seq, query_seq)
            states = norm(states + attended.squeeze(0))
        return states, self.relevance_head(states)
class NodeCreator(nn.Module):
    """Estimate how well the graph covers the query and propose a new node.

    A relevance-weighted summary of the node states is paired with the query
    and used to produce:

    - a coverage probability;
    - features for one candidate new node;
    - for every existing node, a connection score and edge features for
      linking it to the new node.

    Args:
        hidden_dim: Node / query feature width.
        edge_dim: Width of the generated edge features.
    """

    def __init__(self, hidden_dim=256, edge_dim=64):
        super().__init__()
        pair_dim = hidden_dim * 2
        self.coverage_head = nn.Sequential(
            nn.Linear(pair_dim, hidden_dim), nn.GELU(),
            nn.Linear(hidden_dim, 1), nn.Sigmoid())
        self.node_generator = nn.Sequential(
            nn.Linear(pair_dim, pair_dim), nn.GELU(),
            nn.Linear(pair_dim, hidden_dim), nn.LayerNorm(hidden_dim))
        self.edge_generator = nn.Sequential(
            nn.Linear(pair_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, edge_dim))
        self.connection_scorer = nn.Sequential(
            nn.Linear(pair_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1))

    def forward(self, node_features, query, relevance_scores):
        """Return ``(coverage, new_node, connection_scores, edge_features)``.

        Shapes: coverage ``(1,)``, new_node ``(1, hidden_dim)``,
        connection_scores ``(N,)`` and edge_features ``(N, edge_dim)``.
        """
        if query.dim() == 2:
            query = query.squeeze(0)
        # Softmax over nodes turns relevance logits into pooling weights.
        weights = relevance_scores.squeeze(-1).softmax(dim=0)
        summary = (node_features * weights.unsqueeze(-1)).sum(dim=0)
        summary_and_query = torch.cat([summary, query], dim=-1)

        coverage = self.coverage_head(summary_and_query)
        new_node = self.node_generator(summary_and_query).unsqueeze(0)

        tiled_query = query.unsqueeze(0).expand(node_features.shape[0], -1)
        node_and_query = torch.cat([node_features, tiled_query], dim=-1)
        connection_scores = self.connection_scorer(node_and_query).squeeze(-1)
        edge_features = self.edge_generator(node_and_query)
        return coverage, new_node, connection_scores, edge_features
class EdgePredictor(nn.Module):
    """RotatE-style scorer for candidate (head, relation, tail) edges.

    Node states are projected into head and tail spaces, and each projection
    is split into a real half (first ``hidden_dim // 2`` dims) and an
    imaginary half (the remaining dims). Each relation is a learned
    per-dimension phase. The head is rotated by that phase, and the score is
    ``gamma - sum_k |rot(h)_k - t_k|``, so higher means more plausible.

    NOTE(review): ``hidden_dim`` must be even. With an odd width the imaginary
    half is one element wider than the real half and the elementwise ops fail.

    Args:
        hidden_dim: Node feature width.
        edge_dim: Width of the generated edge features.
        num_relation_types: Number of learnable relation phases.
        gamma: Margin subtracted from to turn distance into a score.
    """

    def __init__(self, hidden_dim=256, edge_dim=64, num_relation_types=256, gamma=12.0):
        super().__init__()
        self.gamma = gamma
        self.complex_dim = hidden_dim // 2
        self.head_proj = nn.Linear(hidden_dim, hidden_dim)
        self.tail_proj = nn.Linear(hidden_dim, hidden_dim)
        self.relation_phases = nn.Embedding(num_relation_types, self.complex_dim)
        nn.init.uniform_(self.relation_phases.weight, 0, 2 * math.pi)
        self.edge_feat_gen = nn.Sequential(
            nn.Linear(self.complex_dim, edge_dim * 2), nn.GELU(), nn.Linear(edge_dim * 2, edge_dim))

    def score_edges(self, head, tail, relation_ids):
        """Score paired rows of ``head``/``tail`` under ``relation_ids``; returns ``(E,)``."""
        half = self.complex_dim
        h = self.head_proj(head)
        t = self.tail_proj(tail)
        h_re, h_im = h[:, :half], h[:, half:]
        t_re, t_im = t[:, :half], t[:, half:]
        theta = self.relation_phases(relation_ids)
        cos_t, sin_t = torch.cos(theta), torch.sin(theta)
        # Complex residual: (h_re + i*h_im) * (cos + i*sin) - (t_re + i*t_im).
        diff_re = h_re * cos_t - h_im * sin_t - t_re
        diff_im = h_re * sin_t + h_im * cos_t - t_im
        modulus = torch.stack([diff_re, diff_im], dim=0).norm(dim=0)
        return self.gamma - modulus.sum(dim=-1)

    def forward(self, node_features, candidate_edges, relation_ids):
        """Return ``(scores, edge_features)`` for edges given as a ``(2, E)`` index tensor."""
        src, dst = candidate_edges[0], candidate_edges[1]
        scores = self.score_edges(node_features[src], node_features[dst], relation_ids)
        edge_features = self.edge_feat_gen(self.relation_phases(relation_ids))
        return scores, edge_features
class SubgraphExtractor(nn.Module):
    """Select the nodes and edges of an answer subgraph and score node order.

    Node selection combines a query-conditioned MLP logit with the incoming
    relevance logit. An edge's probability is its own query-conditioned
    sigmoid multiplied by the selection probabilities of both endpoints. The
    order score is computed from node features alone.

    Args:
        hidden_dim: Node / query feature width.
    """

    def __init__(self, hidden_dim=256):
        super().__init__()
        self.node_selector = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1))
        self.edge_selector = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1))
        self.order_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1))

    def forward(self, node_features, edge_index, edge_attr, query, relevance_scores):
        """Return ``(node_probs (N,), edge_probs (E,), order_scores (N,))``.

        ``edge_attr`` is accepted for interface symmetry but is not read.
        """
        if query.dim() == 2:
            query = query.squeeze(0)
        num_nodes = node_features.shape[0]
        node_in = torch.cat([node_features, query.unsqueeze(0).expand(num_nodes, -1)], dim=-1)
        node_logits = self.node_selector(node_in).squeeze(-1)
        node_probs = torch.sigmoid(node_logits + relevance_scores.squeeze(-1))

        num_edges = edge_index.shape[1]
        if num_edges == 0:
            edge_probs = torch.zeros(0, device=node_features.device)
        else:
            src, dst = edge_index[0], edge_index[1]
            edge_in = torch.cat(
                [node_features[src], node_features[dst], query.unsqueeze(0).expand(num_edges, -1)],
                dim=-1)
            edge_logits = self.edge_selector(edge_in).squeeze(-1)
            edge_probs = torch.sigmoid(edge_logits) * node_probs[src] * node_probs[dst]

        order_scores = self.order_head(node_features).squeeze(-1)
        return node_probs, edge_probs, order_scores
class GraphReasoningNetwork(nn.Module):
    """Query-conditioned reasoning over a knowledge graph.

    Pipeline:

    1. Encode the token query.
    2. Project the node and edge inputs.
    3. Navigate the graph with query-conditioned message passing.
    4. Use the navigated node states to:
       (a) propose a new node,
       (b) optionally score candidate edges,
       (c) select a subgraph with a per-node order score.

    When targets are supplied, training losses are attached to the result.

    Args:
        config: Mapping read with ``.get``. Keys and defaults:
            ``hidden_dim`` (256), ``edge_dim`` (64), ``vocab_size`` (32000),
            ``num_encoder_layers`` (4), ``num_nav_layers`` (6),
            ``num_relation_types`` (256).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        hd = config.get('hidden_dim', 256)
        ed = config.get('edge_dim', 64)
        num_relations = config.get('num_relation_types', 256)
        self.hidden_dim = hd
        self.edge_dim = ed
        self.query_encoder = QueryEncoder(config.get('vocab_size', 32000), hd, 8, config.get('num_encoder_layers', 4))
        self.navigator = GraphNavigator(hd, ed, config.get('num_nav_layers', 6), num_relations)
        self.node_creator = NodeCreator(hd, ed)
        self.edge_predictor = EdgePredictor(hd, ed, num_relations)
        self.subgraph_extractor = SubgraphExtractor(hd)
        self.node_input_proj = nn.Linear(hd, hd)
        self.edge_input_proj = nn.Linear(ed, ed)

    def forward(self, token_ids, attention_mask, node_features, edge_index, edge_attr,
                start_node_mask, target_node_mask=None, target_edge_mask=None,
                target_new_nodes=None, candidate_edges=None, candidate_edge_labels=None,
                candidate_edge_relations=None):
        """Run the full pipeline and, when targets are given, compute losses.

        Args:
            token_ids, attention_mask: Passed to the query encoder.
            node_features: Input node features; width ``hidden_dim``.
            edge_index: ``(2, E)`` source/target node indices.
            edge_attr: Edge features; width ``edge_dim`` when ``E > 0``.
            start_node_mask: Marks nodes that receive the projected query.
            target_*: Optional supervision targets; see ``_compute_losses``.
            candidate_edges, candidate_edge_relations: Optional ``(2, C)``
                edges and their relation ids to score.
            candidate_edge_labels: Optional 0/1 labels for the candidate edges.

        Returns:
            A dict with the following keys:

            - ``node_features`` and ``relevance_scores``;
            - ``coverage``, ``new_node_features`` and ``connection_scores``;
            - ``edge_scores``, only if candidate edges were scored;
            - ``node_selection_probs``, ``edge_selection_probs`` and
              ``topological_order``;
            - ``losses``, a possibly empty dict;
            - ``total_loss``, only when ``losses`` is non-empty.
        """
        results = {}
        query = self.query_encoder(token_ids, attention_mask)
        h = self.node_input_proj(node_features)
        # An edge-less graph keeps its empty edge tensor as-is.
        e = self.edge_input_proj(edge_attr) if edge_attr.shape[0] > 0 else edge_attr
        h_nav, rel = self.navigator(h, edge_index, e, query, start_node_mask)
        results.update({'node_features': h_nav, 'relevance_scores': rel})

        # The node creator's per-node edge features are not exposed in the results.
        cov, new_node, conn_scores, _ = self.node_creator(h_nav, query, rel)
        results.update({'coverage': cov, 'new_node_features': new_node, 'connection_scores': conn_scores})

        if candidate_edges is not None and candidate_edges.shape[1] > 0:
            # The predicted edge features (second output) are not exposed in the results.
            edge_scores, _ = self.edge_predictor(h_nav, candidate_edges, candidate_edge_relations)
            results['edge_scores'] = edge_scores

        node_probs, edge_probs, order = self.subgraph_extractor(h_nav, edge_index, e, query, rel)
        results.update({'node_selection_probs': node_probs, 'edge_selection_probs': edge_probs,
                        'topological_order': order})

        losses = self._compute_losses(results, edge_index, target_node_mask, target_edge_mask,
                                      target_new_nodes, candidate_edge_labels)
        results['losses'] = losses
        if losses:
            results['total_loss'] = sum(losses.values())
        return results

    def _compute_losses(self, results, edge_index, target_node_mask, target_edge_mask,
                        target_new_nodes, candidate_edge_labels):
        """Build the loss dict for whichever targets were supplied (insertion order is fixed)."""
        losses = {}
        node_probs = results['node_selection_probs']
        edge_probs = results['edge_selection_probs']
        order = results['topological_order']
        if target_node_mask is not None:
            losses['node_selection_loss'] = F.binary_cross_entropy(node_probs, target_node_mask.float())
        if target_edge_mask is not None and edge_probs.numel() > 0:
            losses['edge_selection_loss'] = F.binary_cross_entropy(edge_probs, target_edge_mask.float())
        if target_new_nodes is not None:
            losses['node_creation_loss'] = F.mse_loss(results['new_node_features'], target_new_nodes)
        if candidate_edge_labels is not None and 'edge_scores' in results:
            losses['edge_prediction_loss'] = F.binary_cross_entropy_with_logits(
                results['edge_scores'], candidate_edge_labels.float())
        if target_node_mask is not None:
            # Coverage is regressed onto the fraction of nodes marked as targets.
            losses['coverage_loss'] = F.mse_loss(results['coverage'].squeeze(), target_node_mask.float().mean())
        if edge_index.shape[1] > 0 and target_edge_mask is not None:
            # Hinge with margin 0.1: on target edges, penalize order[src] + 0.1 > order[dst].
            violation = F.relu(order[edge_index[0]] - order[edge_index[1]] + 0.1)
            losses['dag_ordering_loss'] = (violation * target_edge_mask.float()).mean()
        return losses

    def count_parameters(self):
        """Return the number of trainable parameter elements."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    @torch.no_grad()
    def reason(self, token_ids, attention_mask, node_features, edge_index, edge_attr,
               start_node_mask, node_threshold=0.5, edge_threshold=0.3, create_threshold=0.5):
        """Run inference and threshold the outputs into a selected subgraph.

        The model runs in eval mode for this call, and its previous train/eval
        mode is restored afterwards, so calling this during training does not
        silently disable dropout.

        Kept nodes are those whose selection probability exceeds
        ``node_threshold``. Kept edges must satisfy all of the following:

        - their probability exceeds ``edge_threshold``;
        - both endpoints are kept;
        - they point forward in the predicted order, i.e.
          ``order[src] < order[dst]``.

        ``should_create_node`` is True when coverage is below
        ``create_threshold``.
        """
        was_training = self.training
        self.eval()
        try:
            r = self.forward(token_ids, attention_mask, node_features, edge_index, edge_attr, start_node_mask)
        finally:
            self.train(was_training)

        node_scores = r['node_selection_probs']
        edge_scores = r['edge_selection_probs']
        order = r['topological_order']

        selected_nodes = node_scores > node_threshold
        # For an edge-less graph this is an empty bool tensor on the model's device.
        selected_edges = edge_scores > edge_threshold
        src, dst = edge_index[0], edge_index[1]
        selected_edges = selected_edges & selected_nodes[src] & selected_nodes[dst]
        selected_edges = selected_edges & (order[src] < order[dst])

        return {'selected_nodes': selected_nodes, 'selected_edges': selected_edges,
                'node_scores': node_scores, 'edge_scores': edge_scores,
                'topological_order': order,
                'relevance_scores': r['relevance_scores'], 'coverage': r['coverage'],
                'new_node_features': r['new_node_features'],
                'connection_scores': r['connection_scores'],
                'should_create_node': r['coverage'].item() < create_threshold}