# Source: GNN/model/HeteroGNN.py
# Uploaded by Huxxshadow ("Upload 10 files"), commit f9f7f3b (verified).
# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv, Linear
class HeteroGNN(nn.Module):
    """Heterogeneous graph neural network encoder.

    Stacks ``num_layers`` ``HeteroConv`` layers. Each layer holds one
    ``SAGEConv`` per edge type in ``metadata[1]``, and the per-edge-type
    results for the same destination node type are combined with ``sum``.
    Every layer except the last is followed by a BatchNorm; every layer,
    including the last, is followed by ReLU and dropout.

    Args:
        hidden_channels: Output width of every convolution.
        num_layers: Number of ``HeteroConv`` layers; must be >= 1.
        metadata: ``(node_types, edge_types)`` pair. Only ``metadata[1]``
            (the edge types) is read here. It is required; the ``None``
            default only exists for signature compatibility.
        dropout: Dropout probability applied after every layer. Defaults to
            0.2, the value that was previously hard-coded.

    Raises:
        ValueError: If ``metadata`` is ``None`` or ``num_layers < 1``.
    """

    def __init__(self, hidden_channels=64, num_layers=2, metadata=None,
                 dropout=0.2):
        super().__init__()
        # Fail early with a clear message instead of a TypeError on metadata[1].
        if metadata is None:
            raise ValueError(
                "HeteroGNN requires `metadata` as a (node_types, edge_types) pair")
        # Previously num_layers <= 0 silently produced a 1-layer model.
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        edge_types = metadata[1]
        self.dropout = dropout
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        # First layer: (-1, -1) asks SAGEConv to infer the source/destination
        # input widths lazily, since node types may have different feature sizes.
        self.convs.append(HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_channels)
            for edge_type in edge_types
        }, aggr='sum'))
        # Remaining layers: hidden -> hidden. Each adds one BatchNorm, so
        # there are num_layers - 1 norms in total (the last layer has none).
        for _ in range(num_layers - 1):
            self.convs.append(HeteroConv({
                edge_type: SAGEConv(hidden_channels, hidden_channels)
                for edge_type in edge_types
            }, aggr='sum'))
            self.norms.append(nn.BatchNorm1d(hidden_channels))

    def forward(self, x_dict, edge_index_dict):
        """Run every layer and return the updated per-node-type features.

        Args:
            x_dict: Mapping from node type to node feature tensor (presumably
                ``(num_nodes, feat_dim)`` per type -- confirm against caller).
            edge_index_dict: Mapping from edge type to its edge index tensor.

        Returns:
            dict: Node type -> feature tensor produced by the last layer.
        """
        for i, conv in enumerate(self.convs):
            x_dict = conv(x_dict, edge_index_dict)
            # Only the first len(self.norms) layers are normalized.
            # NOTE(review): a single BatchNorm1d per layer is applied to every
            # node type's tensor, so its running mean/var are updated by all
            # node types in turn. Per-node-type norms would avoid mixing
            # statistics, but would change state_dict keys.
            if i < len(self.norms):
                x_dict = {key: self.norms[i](x) for key, x in x_dict.items()}
            x_dict = {
                key: F.dropout(F.relu(x), p=self.dropout, training=self.training)
                for key, x in x_dict.items()
            }
        return x_dict
class LinkPredictor(nn.Module):
    """Link-prediction head that scores candidate (source, destination) pairs.

    Concatenates the source and destination embeddings of each pair and feeds
    them through ``Linear -> ReLU -> Linear``. The result is one unnormalized
    score per pair; no sigmoid is applied.

    Args:
        in_channels: Embedding width of a single node (source or destination).
        hidden_channels: Width of the hidden layer.
    """

    def __init__(self, in_channels, hidden_channels=64):
        super().__init__()
        # The input is one source embedding concatenated with one destination
        # embedding, hence twice the per-node width.
        self.lin1 = Linear(2 * in_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def forward(self, x_src, x_dst, edge_label_index):
        """Score the node pairs listed in ``edge_label_index``.

        Args:
            x_src: Source-node embeddings, indexed by ``edge_label_index[0]``.
            x_dst: Destination-node embeddings, indexed by
                ``edge_label_index[1]``.
            edge_label_index: Pair indices; row 0 holds source indices and
                row 1 holds destination indices (presumably shape
                ``(2, num_pairs)`` -- confirm against caller).

        Returns:
            1-D tensor with one unnormalized score per pair.
        """
        src = x_src[edge_label_index[0]]
        dst = x_dst[edge_label_index[1]]
        x = torch.cat([src, dst], dim=-1)
        x = F.relu(self.lin1(x))
        # squeeze(-1) rather than squeeze(): a single-pair batch stays shape
        # (1,) instead of collapsing to a 0-d scalar that no longer lines up
        # with a (1,)-shaped label tensor.
        return self.lin2(x).squeeze(-1)
class NodeClassifier(nn.Module):
    """Node-classification head: one linear layer mapping embeddings to class scores."""

    def __init__(self, in_channels, num_classes):
        super().__init__()
        # A single affine projection from embedding width to the number of classes.
        self.lin = Linear(in_channels=in_channels, out_channels=num_classes)

    def forward(self, x):
        """Return unnormalized per-class scores for the node embeddings ``x``.

        No softmax is applied here; callers receive the raw linear output.
        """
        logits = self.lin(x)
        return logits