Spaces:
Sleeping
Sleeping
File size: 1,554 Bytes
a29c713 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 | """
model.py — EdgeGNN definition, exactly matching the training notebook.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
class EdgeGNN(nn.Module):
    """Edge classifier: two-layer GraphSAGE node encoder feeding an MLP scorer.

    Every edge is scored from the concatenation of its source-node embedding,
    its destination-node embedding and its own feature vector, yielding one
    raw (pre-sigmoid) logit per edge.

    Keep the submodule names (``conv1``, ``conv2``, ``edge_mlp``) and the layer
    order inside ``edge_mlp`` unchanged: they define the ``state_dict`` keys
    (presumably loaded from training-notebook checkpoints — confirm before
    renaming anything).
    """

    def __init__(self, in_dim: int, edge_dim: int, hidden_dim: int = 64, dropout: float = 0.2):
        super().__init__()
        # Node encoder: in_dim -> hidden_dim -> hidden_dim.
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        # Kept as a plain float; applied functionally between the conv layers.
        self.dropout = dropout
        # Scorer input is [h_src | h_dst | edge_attr].
        pair_dim = 2 * hidden_dim + edge_dim
        self.edge_mlp = nn.Sequential(
            nn.Linear(pair_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def encode_nodes(self, x, edge_index):
        """Return node embeddings produced by the two SAGEConv layers.

        ReLU and dropout (active only when ``self.training`` is True) sit
        between the layers; the second layer's output is returned as-is.
        """
        first = self.conv1(x, edge_index).relu()
        first = F.dropout(first, p=self.dropout, training=self.training)
        return self.conv2(first, edge_index)

    def edge_logits(self, node_emb, edge_index, edge_attr, local_idx=None):
        """Score edges with the MLP head.

        Args:
            node_emb: node embeddings, indexed by the node ids stored in
                ``edge_index`` (e.g. the output of ``encode_nodes``).
            edge_index: row 0 holds source node ids, row 1 destination ids.
            edge_attr: per-edge features, concatenated along dim 1 with the
                two endpoint embeddings.
            local_idx: optional selector over edges (presumably an index
                tensor or boolean mask); when ``None`` every edge is scored.

        Returns:
            One logit per selected edge (trailing size-1 dim squeezed away).
        """
        if local_idx is None:
            src_ids, dst_ids = edge_index[0], edge_index[1]
            feats = edge_attr
        else:
            src_ids = edge_index[0, local_idx]
            dst_ids = edge_index[1, local_idx]
            feats = edge_attr[local_idx]
        pair_repr = torch.cat([node_emb[src_ids], node_emb[dst_ids], feats], dim=1)
        logits = self.edge_mlp(pair_repr)
        return logits.squeeze(-1)

    def forward(self, x, edge_index, edge_attr):
        """Encode nodes over ``edge_index`` and return a logit for every edge."""
        node_emb = self.encode_nodes(x, edge_index)
        return self.edge_logits(node_emb, edge_index, edge_attr)
|