Spaces:
Sleeping
Sleeping
| # model.py | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv, Linear | |
class HeteroGNN(nn.Module):
    """Heterogeneous graph neural network built from stacked HeteroConv layers.

    Each layer applies one ``SAGEConv`` per edge type and sums the messages
    arriving at each node type (``aggr='sum'``). After every layer the
    embeddings go through BatchNorm (all layers except the last), then ReLU,
    then dropout.

    Args:
        hidden_channels: Output width of every layer.
        num_layers: Total number of message-passing layers; must be >= 1.
        metadata: Graph metadata. Only ``metadata[1]`` is used, iterated as
            the edge types that key each ``HeteroConv``. Presumably the
            ``(node_types, edge_types)`` tuple from ``HeteroData.metadata()``
            -- confirm against callers.
        dropout: Dropout probability applied after every layer. Defaults to
            0.2, the value that was previously hard-coded.

    Raises:
        ValueError: If ``metadata`` is ``None`` or ``num_layers < 1``.
    """

    def __init__(self, hidden_channels=64, num_layers=2, metadata=None,
                 dropout=0.2):
        super().__init__()
        # Fail fast with a clear message instead of `TypeError: 'NoneType'
        # object is not subscriptable` from `metadata[1]` below.
        if metadata is None:
            raise ValueError(
                "HeteroGNN requires graph metadata (node_types, edge_types)")
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        self.dropout = dropout
        # Materialize once: the edge types are iterated again for every layer.
        edge_types = list(metadata[1])

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()

        # First layer: (-1, -1) lets PyG infer source/destination input sizes
        # lazily, since node types may have different feature widths.
        self.convs.append(HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_channels)
            for edge_type in edge_types
        }, aggr='sum'))

        # Remaining layers map hidden_channels -> hidden_channels.
        for _ in range(num_layers - 1):
            self.convs.append(HeteroConv({
                edge_type: SAGEConv(hidden_channels, hidden_channels)
                for edge_type in edge_types
            }, aggr='sum'))
            # There are num_layers - 1 norms. In forward(), norms[i] follows
            # convs[i], so the last conv's output is not normalized.
            # NOTE(review): one BatchNorm is shared by all node types, so its
            # running statistics mix node types. Per-node-type norms would be
            # more typical, but that would change state_dict keys.
            self.norms.append(nn.BatchNorm1d(hidden_channels))

    def forward(self, x_dict, edge_index_dict):
        """Run every layer and return the updated node-embedding dict.

        Args:
            x_dict: Mapping of node type -> node feature tensor.
            edge_index_dict: Mapping of edge type -> edge index tensor.

        Returns:
            Dict mapping node type -> embeddings of width ``hidden_channels``.
            Only the node types that ``HeteroConv`` returns are present
            (presumably those that are the destination of some edge type).
        """
        for i, conv in enumerate(self.convs):
            x_dict = conv(x_dict, edge_index_dict)
            if i < len(self.norms):
                x_dict = {key: self.norms[i](x) for key, x in x_dict.items()}
            x_dict = {
                key: F.dropout(F.relu(x), p=self.dropout,
                               training=self.training)
                for key, x in x_dict.items()
            }
        return x_dict
class LinkPredictor(nn.Module):
    """Link-prediction head that scores (source, destination) node pairs.

    The two endpoint embeddings are concatenated and passed through a
    two-layer MLP (Linear -> ReLU -> Linear) that outputs one value per pair.

    Args:
        in_channels: Embedding width of each endpoint; ``lin1`` therefore
            takes ``2 * in_channels`` inputs.
        hidden_channels: Width of the hidden layer.
    """

    def __init__(self, in_channels, hidden_channels=64):
        super().__init__()
        self.lin1 = Linear(2 * in_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def forward(self, x_src, x_dst, edge_label_index):
        """Return one raw score per candidate edge (no sigmoid is applied).

        Args:
            x_src: Source-node embeddings, indexed by ``edge_label_index[0]``.
            x_dst: Destination-node embeddings, indexed by
                ``edge_label_index[1]``.
            edge_label_index: Source and destination index rows; presumably
                of shape ``(2, num_edges)``.

        Returns:
            Tensor of shape ``(num_edges,)``.
        """
        src = x_src[edge_label_index[0]]
        dst = x_dst[edge_label_index[1]]
        x = torch.cat([src, dst], dim=-1)
        x = F.relu(self.lin1(x))
        # squeeze(-1) rather than squeeze(): with a single edge, squeeze()
        # would collapse (1, 1) into a 0-d scalar and break losses that
        # require the prediction and label shapes to match.
        return self.lin2(x).squeeze(-1)
class NodeClassifier(nn.Module):
    """Node-classification head: a single linear layer yielding class scores.

    Args:
        in_channels: Width of the incoming node embeddings.
        num_classes: Number of classes, i.e. the width of the returned scores.
    """

    def __init__(self, in_channels, num_classes):
        nn.Module.__init__(self)
        # The attribute name ``lin`` is kept so state_dict keys stay stable.
        self.lin = Linear(in_channels, num_classes)

    def forward(self, x):
        """Project embeddings ``x`` to unnormalized class scores (no softmax)."""
        logits = self.lin(x)
        return logits