# Source: aml_app / model.py
# Uploaded by waltertaya ("Added files", commit a29c713).
"""
model.py — EdgeGNN definition, exactly matching the training notebook.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
class EdgeGNN(nn.Module):
    """Edge classifier: 2-layer GraphSAGE node encoder followed by an edge MLP.

    Each selected edge ``(u, v)`` is scored from the concatenation
    ``[h_u, h_v, edge_attr_e]``. Here ``h_u`` and ``h_v`` are the endpoint
    embeddings produced by the encoder, and ``edge_attr_e`` is that edge's
    feature row. The result is one raw logit per edge; the final ``Linear``
    applies no activation.

    NOTE: the attribute names (``conv1``, ``conv2``, ``edge_mlp``) and the
    layer order inside ``edge_mlp`` determine the ``state_dict`` keys. Keep
    them unchanged so that weights saved from the training notebook still
    load; this presumably matters for deployed checkpoints.

    Args:
        in_dim: Number of input node features (columns of ``x``).
        edge_dim: Number of edge features (columns of ``edge_attr``).
        hidden_dim: Width of the node embeddings and of the MLP hidden layer.
        dropout: Drop probability used after the first SAGE layer and inside
            the edge MLP.
    """

    def __init__(self, in_dim: int, edge_dim: int, hidden_dim: int = 64, dropout: float = 0.2):
        super().__init__()
        # Node encoder: in_dim -> hidden_dim -> hidden_dim.
        # Construction order matters for seeded weight initialisation.
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        # Stored as a plain float; it drives the functional dropout in encode_nodes.
        self.dropout = dropout
        # The edge head's input is the source embedding, the destination
        # embedding and the edge features, concatenated.
        pair_dim = 2 * hidden_dim + edge_dim
        head_layers = [
            nn.Linear(pair_dim, hidden_dim),  # edge_mlp.0
            nn.ReLU(),                        # edge_mlp.1
            nn.Dropout(dropout),              # edge_mlp.2
            nn.Linear(hidden_dim, 1),         # edge_mlp.3 -> one logit per edge
        ]
        self.edge_mlp = nn.Sequential(*head_layers)

    def encode_nodes(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Compute node embeddings of width ``hidden_dim``.

        The pipeline is conv1 -> ReLU -> dropout -> conv2. Dropout is active
        only while the module is in training mode. The output of conv2 is
        returned without a further activation.
        """
        hidden = self.conv1(x, edge_index).relu()
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.conv2(hidden, edge_index)

    def edge_logits(self, node_emb: torch.Tensor, edge_index: torch.Tensor,
                    edge_attr: torch.Tensor, local_idx=None) -> torch.Tensor:
        """Score edges from precomputed node embeddings.

        Args:
            node_emb: Node embeddings, indexed by the node ids in ``edge_index``.
            edge_index: Tensor whose row 0 holds source node ids and row 1
                holds destination node ids.
            edge_attr: Per-edge features. They are concatenated along dim 1,
                so this must be 2-D with ``edge_dim`` columns.
            local_idx: Optional selector applied along the edge axis. It can
                be anything valid for tensor indexing, e.g. a tensor of edge
                positions or a boolean mask. When given, only the selected
                edges are scored, in selection order.

        Returns:
            A 1-D tensor of raw logits, one per scored edge.
        """
        # Default to scoring every edge; narrow to the requested subset if any.
        src_nodes = edge_index[0]
        dst_nodes = edge_index[1]
        feats = edge_attr
        if local_idx is not None:
            src_nodes = src_nodes[local_idx]
            dst_nodes = dst_nodes[local_idx]
            feats = feats[local_idx]
        pair_repr = torch.cat((node_emb[src_nodes], node_emb[dst_nodes], feats), dim=1)
        scores = self.edge_mlp(pair_repr)
        # The MLP output has shape (num_edges, 1); drop the trailing axis.
        return scores.squeeze(-1)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                edge_attr: torch.Tensor) -> torch.Tensor:
        """Encode all nodes, then return a logit for every edge in ``edge_index``."""
        node_emb = self.encode_nodes(x, edge_index)
        return self.edge_logits(node_emb, edge_index, edge_attr)