# GNN / train.py
# Source: Hugging Face repository page — uploaded by Huxxshadow ("Upload 10 files"),
# commit f9f7f3b (verified).
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, average_precision_score
import numpy as np
from model.HeteroGNN import HeteroGNN, LinkPredictor, NodeClassifier
def _sample_negative_edges(pos_edge_index, num_src_nodes, num_dst_nodes,
                           num_neg, max_rounds=10):
    """Sample ``num_neg`` random (src, dst) pairs to use as negative edges.

    Pairs are drawn uniformly at random. Candidates that coincide with a column
    of ``pos_edge_index`` are rejected and redrawn for up to ``max_rounds``
    rounds. If the quota is still not met (e.g. an extremely dense relation),
    the remainder is filled with unfiltered random pairs, so exactly
    ``num_neg`` columns are always returned.

    Args:
        pos_edge_index: 2-row long tensor of known positive edges to avoid.
        num_src_nodes: Number of source-type nodes (exclusive upper bound for src ids).
        num_dst_nodes: Number of destination-type nodes (exclusive upper bound for dst ids).
        num_neg: Number of negative edges to return.
        max_rounds: Rejection-sampling rounds before falling back to unfiltered pairs.

    Returns:
        A ``[2, num_neg]`` long tensor on the CPU (callers move it to their device).
    """
    pos_edge_index = pos_edge_index.cpu()
    # Encode each (src, dst) pair as one integer so membership tests are vectorized.
    pos_ids = pos_edge_index[0] * num_dst_nodes + pos_edge_index[1]

    chunks = []
    remaining = num_neg
    for _ in range(max_rounds):
        if remaining <= 0:
            break
        src = torch.randint(0, num_src_nodes, (remaining,))
        dst = torch.randint(0, num_dst_nodes, (remaining,))
        keep = ~torch.isin(src * num_dst_nodes + dst, pos_ids)
        chunks.append(torch.stack([src[keep], dst[keep]]))
        remaining -= int(keep.sum())

    if remaining > 0:
        # Best-effort fallback: accept possibly-positive pairs rather than loop forever.
        src = torch.randint(0, num_src_nodes, (remaining,))
        dst = torch.randint(0, num_dst_nodes, (remaining,))
        chunks.append(torch.stack([src, dst]))

    if not chunks:
        return torch.empty((2, 0), dtype=torch.long)
    return torch.cat(chunks, dim=1)


def train_link_prediction(data, edge_type=('company', 'owns', 'patent'),
                          epochs=100, hidden_channels=64, lr=0.001):
    """Train a heterogeneous GNN encoder plus a link predictor on one edge type.

    The edges of ``edge_type`` are randomly split 80/20 into training and
    validation positives. Only the training split of that relation is used for
    message passing, so validation links are not visible to the encoder. The
    loss is binary cross-entropy on training positives vs. freshly sampled
    random negatives each epoch. Every 10 epochs the model is scored on the
    validation positives against a fixed set of random negatives, and ROC-AUC /
    average precision are printed.

    Args:
        data: Heterogeneous graph exposing ``metadata()``, ``x_dict``,
            ``edge_index_dict``, ``to()`` and per-type stores with
            ``edge_index`` / ``num_nodes`` (presumably a PyG ``HeteroData``).
        edge_type: ``(src_type, relation, dst_type)`` triple to predict.
        epochs: Number of full-batch training epochs.
        hidden_channels: Hidden size passed to both the GNN and the predictor.
        lr: Adam learning rate (default equals the previous hard-coded value).

    Returns:
        ``(gnn, predictor)`` — the trained encoder and link predictor.

    Raises:
        ValueError: If the relation has too few edges to form non-empty
            training and validation splits.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    src_type, _, dst_type = edge_type

    # Split the positive edges 80/20 into train / validation.
    edge_index = data[edge_type].edge_index
    num_edges = edge_index.size(1)
    train_size = int(0.8 * num_edges)
    if train_size == 0 or train_size == num_edges:
        raise ValueError(
            f'edge type {edge_type} has {num_edges} edge(s); '
            'need at least 2 to build train and validation splits'
        )
    perm = torch.randperm(num_edges, device=edge_index.device)
    # Move the splits to the training device explicitly: they are sliced from
    # the tensor read before `data.to(device)` below.
    train_edge_index = edge_index[:, perm[:train_size]].to(device)
    val_edge_index = edge_index[:, perm[train_size:]].to(device)

    # Model
    gnn = HeteroGNN(hidden_channels, num_layers=2, metadata=data.metadata()).to(device)
    predictor = LinkPredictor(hidden_channels).to(device)
    optimizer = torch.optim.Adam(
        list(gnn.parameters()) + list(predictor.parameters()),
        lr=lr,
    )

    data = data.to(device)
    num_src = data[src_type].num_nodes
    num_dst = data[dst_type].num_nodes

    # Message-passing graph: the full graph, except the target relation keeps
    # only its training edges, so validation links do not leak into embeddings.
    # NOTE(review): if `data` also stores a reverse relation for `edge_type`
    # (e.g. one added by an undirected transform), its validation edges still
    # leak and would need to be filtered the same way.
    mp_edge_index_dict = dict(data.edge_index_dict)
    mp_edge_index_dict[edge_type] = train_edge_index

    # Fixed validation negatives keep metrics comparable across epochs. All
    # known positives are excluded so no true link is labelled negative.
    val_neg_edge_index = _sample_negative_edges(
        edge_index, num_src, num_dst, val_edge_index.size(1)
    ).to(device)

    for epoch in range(epochs):
        gnn.train()
        predictor.train()
        optimizer.zero_grad()

        # Forward pass
        x_dict = gnn(data.x_dict, mp_edge_index_dict)

        # Positive samples
        pos_pred = predictor(x_dict[src_type], x_dict[dst_type], train_edge_index)

        # Negative samples: resampled every epoch, avoiding training positives.
        neg_edge_index = _sample_negative_edges(
            train_edge_index, num_src, num_dst, train_edge_index.size(1)
        ).to(device)
        neg_pred = predictor(x_dict[src_type], x_dict[dst_type], neg_edge_index)

        # Loss: logits of positives vs. negatives with 1/0 targets.
        loss = F.binary_cross_entropy_with_logits(
            torch.cat([pos_pred, neg_pred]),
            torch.cat([torch.ones_like(pos_pred), torch.zeros_like(neg_pred)]),
        )
        loss.backward()
        optimizer.step()

        # Validation
        if epoch % 10 == 0:
            gnn.eval()
            predictor.eval()
            with torch.no_grad():
                x_dict = gnn(data.x_dict, mp_edge_index_dict)
                val_pos_pred = predictor(x_dict[src_type], x_dict[dst_type], val_edge_index)
                val_neg_pred = predictor(x_dict[src_type], x_dict[dst_type], val_neg_edge_index)
            preds = torch.cat([val_pos_pred, val_neg_pred]).sigmoid().cpu().numpy()
            labels = np.concatenate([np.ones(val_pos_pred.size(0)), np.zeros(val_neg_pred.size(0))])
            auc = roc_auc_score(labels, preds)
            ap = average_precision_score(labels, preds)
            print(f'Epoch {epoch:03d} | Loss: {loss.item():.4f} | Val AUC: {auc:.4f} | Val AP: {ap:.4f}')

    return gnn, predictor
def _masked_accuracy(pred, labels, mask):
    """Return the accuracy of ``pred`` vs. ``labels`` over the positions selected by ``mask``.

    Returns a Python float. The result is NaN when ``mask`` selects no
    positions, because the mean of an empty tensor is NaN.
    """
    return (pred[mask] == labels[mask]).float().mean().item()


def train_node_classification(data, node_type='company', target='industry',
                              epochs=100, hidden_channels=64, lr=0.01):
    """Train a heterogeneous GNN encoder plus a classifier to predict a node label.

    Nodes of ``node_type`` are randomly split 60/20/20 into disjoint
    train / validation / test masks. Training is full-batch cross-entropy on
    the train mask. Every 10 epochs the train and validation accuracy are
    printed, and the test accuracy is printed once after training.

    Args:
        data: Heterogeneous graph (presumably a PyG ``HeteroData``) where
            ``data[node_type][target]`` holds the labels. These are assumed to be
            non-negative integer class ids, since the class count is taken as
            ``max + 1``.
        node_type: Node type to classify.
        target: Attribute name of the label tensor on ``node_type``.
        epochs: Number of full-batch training epochs.
        hidden_channels: Hidden size passed to both the GNN and the classifier.
        lr: Adam learning rate (default equals the previous hard-coded value).

    Returns:
        ``(gnn, classifier)`` — the trained encoder and classification head.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Labels
    labels = data[node_type][target]
    num_classes = labels.max().item() + 1
    num_nodes = data[node_type].num_nodes

    # Random 60/20/20 split into disjoint boolean masks.
    perm = torch.randperm(num_nodes)
    train_size = int(0.6 * num_nodes)
    val_size = int(0.2 * num_nodes)
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[perm[:train_size]] = True
    val_mask[perm[train_size:train_size + val_size]] = True
    test_mask[perm[train_size + val_size:]] = True

    # Model
    gnn = HeteroGNN(hidden_channels, num_layers=2, metadata=data.metadata()).to(device)
    classifier = NodeClassifier(hidden_channels, num_classes).to(device)
    optimizer = torch.optim.Adam(
        list(gnn.parameters()) + list(classifier.parameters()),
        lr=lr,
    )

    data = data.to(device)
    labels = labels.to(device)
    train_mask = train_mask.to(device)
    val_mask = val_mask.to(device)
    test_mask = test_mask.to(device)

    def _predict():
        """Run a no-grad eval-mode forward pass; return the argmax class for every node."""
        gnn.eval()
        classifier.eval()
        with torch.no_grad():
            x_dict = gnn(data.x_dict, data.edge_index_dict)
            return classifier(x_dict[node_type]).argmax(dim=1)

    for epoch in range(epochs):
        gnn.train()
        classifier.train()
        optimizer.zero_grad()
        x_dict = gnn(data.x_dict, data.edge_index_dict)
        out = classifier(x_dict[node_type])
        loss = F.cross_entropy(out[train_mask], labels[train_mask])
        loss.backward()
        optimizer.step()

        if epoch % 10 == 0:
            pred = _predict()
            train_acc = _masked_accuracy(pred, labels, train_mask)
            val_acc = _masked_accuracy(pred, labels, val_mask)
            print(f'Epoch {epoch:03d} | Loss: {loss.item():.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}')

    # Test
    test_acc = _masked_accuracy(_predict(), labels, test_mask)
    print(f'\n🎯 Test Accuracy: {test_acc:.4f}')
    return gnn, classifier
# train.py — script entry point: build a synthetic graph, then run both training tasks.
if __name__ == "__main__":
    from utils.data_generator import IPEcosystemGenerator

    def _print_banner(title):
        """Print ``title`` framed above and below by a 60-character '=' rule."""
        rule = "=" * 60
        print("\n" + rule)
        print(title)
        print(rule)

    # Generate data with the generator seeded at 42.
    # NOTE(review): torch's global RNG (used for the splits, negative sampling and
    # model init) is not seeded in this file — confirm whether the generator seeds it.
    data = IPEcosystemGenerator(seed=42).generate(
        n_companies=500,
        n_patents=3000,
        n_trademarks=1500,
        n_persons=2000,
        n_institutions=50,
    )

    # Task 1: link prediction on company -> patent ownership edges.
    _print_banner("任务1: 链接预测 (企业-专利)")
    gnn1, pred1 = train_link_prediction(
        data,
        edge_type=('company', 'owns', 'patent'),
        epochs=500,
        hidden_channels=64,
    )

    # Task 2: node classification of each company's industry.
    _print_banner("任务2: 节点分类 (企业产业预测)")
    gnn2, cls2 = train_node_classification(
        data,
        node_type='company',
        target='industry',
        epochs=500,
        hidden_channels=64,
    )

    print("\n✅ 训练完成!")