"""model.py — EdgeGNN definition, exactly matching the training notebook.

Do not rename the submodules (``conv1``, ``conv2``, ``edge_mlp``) or change
the layer order inside ``edge_mlp``: ``state_dict`` keys such as
``edge_mlp.0.weight`` / ``edge_mlp.3.weight`` are derived from them, so any
change breaks loading weights saved with the notebook's definition.
"""

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv


class EdgeGNN(nn.Module):
    """2-layer GraphSAGE node encoder + MLP edge classifier.

    Each edge is scored from the concatenation
    ``[node_emb[src], node_emb[dst], edge_attr]``, where the node embeddings
    come from two ``SAGEConv`` layers. The result is one raw logit per edge
    (no sigmoid is applied; presumably paired with a logits-based loss such
    as ``BCEWithLogitsLoss`` — confirm against the training code).

    Args:
        in_dim: Number of input features per node.
        edge_dim: Number of features per edge in ``edge_attr``.
        hidden_dim: Width of the node embeddings and of the MLP hidden layer.
        dropout: Dropout probability, applied between the two convolutions
            (training mode only) and inside the edge MLP.
    """

    def __init__(
        self,
        in_dim: int,
        edge_dim: int,
        hidden_dim: int = 64,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.dropout = dropout
        # Sequential indices (0: Linear, 1: ReLU, 2: Dropout, 3: Linear) are
        # part of the checkpoint key names — keep this exact layout.
        self.edge_mlp = nn.Sequential(
            nn.Linear(hidden_dim * 2 + edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def encode_nodes(
        self, x: torch.Tensor, edge_index: torch.Tensor
    ) -> torch.Tensor:
        """Compute node embeddings with the two GraphSAGE layers.

        Args:
            x: Node feature matrix (presumably ``(num_nodes, in_dim)``, as
                ``SAGEConv`` expects).
            edge_index: Graph connectivity in PyG's ``(2, num_edges)`` format.

        Returns:
            Output of ``conv2``, with no activation applied after it.
        """
        h = F.relu(self.conv1(x, edge_index))
        # Functional dropout honours self.training, so it is a no-op in eval().
        h = F.dropout(h, p=self.dropout, training=self.training)
        return self.conv2(h, edge_index)

    def edge_logits(
        self,
        node_emb: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        local_idx: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Score edges with the edge MLP.

        Args:
            node_emb: Node embeddings, e.g. from :meth:`encode_nodes`.
            edge_index: ``(2, num_edges)`` connectivity; row 0 is used as the
                source node and row 1 as the destination node.
            edge_attr: Per-edge features with ``edge_dim`` columns. A 1-D
                tensor is accepted and treated as a single feature column
                (only valid when ``edge_dim == 1``).
            local_idx: Optional index (or boolean mask) selecting which edges
                to score. When ``None``, every edge is scored.

        Returns:
            One logit per scored edge (the trailing size-1 dim is squeezed).
        """
        if local_idx is None:
            src, dst = edge_index[0], edge_index[1]
            ea = edge_attr
        else:
            src = edge_index[0, local_idx]
            dst = edge_index[1, local_idx]
            ea = edge_attr[local_idx]
        # A flat (E,) edge-feature vector would make torch.cat(dim=1) fail;
        # lift it to an (E, 1) column.
        if ea.dim() == 1:
            ea = ea.unsqueeze(-1)
        pair_features = torch.cat([node_emb[src], node_emb[dst], ea], dim=1)
        return self.edge_mlp(pair_features).squeeze(-1)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
    ) -> torch.Tensor:
        """Encode nodes over ``edge_index`` and return a logit for every edge."""
        return self.edge_logits(
            self.encode_nodes(x, edge_index), edge_index, edge_attr
        )