Spaces:
Sleeping
Sleeping
"""
model.py — EdgeGNN definition, exactly matching the training notebook.
"""
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch_geometric.nn import SAGEConv | |
class EdgeGNN(nn.Module):
    """Edge classifier: two GraphSAGE layers embed nodes, an MLP scores edges.

    Each edge is represented by concatenating the embeddings of its two
    endpoints with its attribute vector (width ``2 * hidden_dim + edge_dim``)
    and mapped to a single raw logit; no sigmoid is applied here.

    Submodule names (``conv1``, ``conv2``, ``edge_mlp``) and their registration
    order define the ``state_dict`` keys, so keep them unchanged to stay
    compatible with the training notebook.

    Args:
        in_dim: Width of the input node features.
        edge_dim: Width of the per-edge attribute vectors.
        hidden_dim: Width of the node embeddings and of the MLP hidden layer.
        dropout: Dropout probability applied between the two SAGE layers and
            inside the edge MLP.
    """

    def __init__(
        self,
        in_dim: int,
        edge_dim: int,
        hidden_dim: int = 64,
        dropout: float = 0.2,
    ):
        super().__init__()
        # Node encoder: in_dim -> hidden_dim -> hidden_dim.
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        # Kept as a plain float; encode_nodes applies it via F.dropout.
        self.dropout = dropout
        # The MLP sees [h_src, h_dst, edge_attr] concatenated per edge.
        pair_width = 2 * hidden_dim + edge_dim
        self.edge_mlp = nn.Sequential(
            nn.Linear(pair_width, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def encode_nodes(self, x, edge_index):
        """Return per-node embeddings of width ``hidden_dim``.

        ReLU followed by dropout (active only in training mode) sits between
        the two SAGE layers; the second layer's output is returned without a
        final activation.
        """
        hidden = F.dropout(
            F.relu(self.conv1(x, edge_index)),
            p=self.dropout,
            training=self.training,
        )
        return self.conv2(hidden, edge_index)

    def edge_logits(self, node_emb, edge_index, edge_attr, local_idx=None):
        """Score edges from precomputed node embeddings.

        Args:
            node_emb: Node embeddings, indexed by the node ids stored in
                ``edge_index``.
            edge_index: Edge endpoints; row 0 is used as the source node and
                row 1 as the destination node of each edge column.
            edge_attr: Per-edge attributes, indexed along the same edge axis
                as the columns of ``edge_index`` (concatenated on dim 1, so
                presumably ``(num_edges, edge_dim)``).
            local_idx: Optional index into the edge axis (anything accepted by
                tensor indexing, e.g. a LongTensor of edge positions) selecting
                which edges to score; ``None`` scores every edge.

        Returns:
            One logit per scored edge, with the MLP's trailing size-1
            dimension squeezed away.
        """
        if local_idx is None:
            src, dst, feats = edge_index[0], edge_index[1], edge_attr
        else:
            src = edge_index[0, local_idx]
            dst = edge_index[1, local_idx]
            feats = edge_attr[local_idx]
        pair_features = torch.cat([node_emb[src], node_emb[dst], feats], dim=1)
        logits = self.edge_mlp(pair_features)
        return logits.squeeze(-1)

    def forward(self, x, edge_index, edge_attr):
        """Embed nodes over ``edge_index`` and return a logit for every edge."""
        node_emb = self.encode_nodes(x, edge_index)
        return self.edge_logits(node_emb, edge_index, edge_attr)