import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, average_precision_score
import numpy as np

from model.HeteroGNN import HeteroGNN, LinkPredictor, NodeClassifier


def _sample_negative_edges(num_nodes_src, num_nodes_dst, num_neg, device=None):
    """Draw ``num_neg`` uniformly random (src, dst) index pairs as negatives.

    Simplification kept from the original implementation: sampled pairs are
    not checked against existing edges nor de-duplicated, so a small share of
    the "negatives" may actually be real edges.

    Args:
        num_nodes_src: Number of source-type nodes (exclusive upper bound for
            sampled source indices).
        num_nodes_dst: Number of destination-type nodes (exclusive upper
            bound for sampled destination indices).
        num_neg: Number of pairs to draw; ``0`` yields an empty result.
        device: Device on which the index tensor is created.

    Returns:
        Integer tensor with two rows (source indices, destination indices)
        and ``num_neg`` columns.
    """
    src = torch.randint(0, num_nodes_src, (num_neg,), device=device)
    dst = torch.randint(0, num_nodes_dst, (num_neg,), device=device)
    return torch.stack([src, dst])


def _index_mask(num_nodes, indices):
    """Return a boolean mask of length ``num_nodes`` that is True at ``indices``."""
    mask = torch.zeros(num_nodes, dtype=torch.bool)
    mask[indices] = True
    return mask


def train_link_prediction(data, edge_type=('company', 'owns', 'patent'),
                          epochs=100, hidden_channels=64):
    """Train a heterogeneous GNN encoder plus link predictor on one relation.

    The edges of ``edge_type`` are randomly split 80/20 into training and
    validation sets. Each epoch scores the training edges against an equal
    number of randomly sampled negative pairs using binary cross-entropy on
    logits. Every 10 epochs, ROC-AUC and average precision on the validation
    edges (plus freshly sampled negatives) are printed.

    Validation edges of ``edge_type`` are excluded from the graph the encoder
    message-passes over, so the validation metrics are not inflated by the
    model seeing the very edges it is asked to predict.

    Args:
        data: Heterogeneous graph object providing ``metadata()``, ``x_dict``,
            ``edge_index_dict``, ``to(device)`` and per-type ``num_nodes`` /
            ``edge_index`` (looks like a PyG ``HeteroData`` -- confirm).
        edge_type: ``(src_type, relation, dst_type)`` triple to predict.
        epochs: Number of training epochs.
        hidden_channels: Hidden size passed to the encoder and predictor.

    Returns:
        Tuple ``(gnn, predictor)`` of the trained modules.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    edge_type = tuple(edge_type)
    src_type, _, dst_type = edge_type

    # Random 80/20 split of the target relation into train / validation edges.
    edge_index = data[edge_type].edge_index
    num_edges = edge_index.size(1)
    perm = torch.randperm(num_edges)
    train_size = int(0.8 * num_edges)
    train_edge_index = edge_index[:, perm[:train_size]]
    val_edge_index = edge_index[:, perm[train_size:]]

    # Models and optimizer.
    gnn = HeteroGNN(hidden_channels, num_layers=2, metadata=data.metadata()).to(device)
    predictor = LinkPredictor(hidden_channels).to(device)
    optimizer = torch.optim.Adam(
        list(gnn.parameters()) + list(predictor.parameters()),
        lr=0.001,
    )

    data = data.to(device)
    # The split tensors were sliced before the move above; put them on the
    # same device as the node embeddings they will index.
    train_edge_index = train_edge_index.to(device)
    val_edge_index = val_edge_index.to(device)

    # Message passing must not see the held-out edges, otherwise validation
    # scores leak. Replace the target relation with its training subset.
    # NOTE(review): if the graph also stores a reverse relation for
    # `edge_type`, held-out edges may still be visible through it -- confirm
    # and mask it the same way if so.
    mp_edge_index_dict = dict(data.edge_index_dict)
    mp_edge_index_dict[edge_type] = train_edge_index

    num_src = data[src_type].num_nodes
    num_dst = data[dst_type].num_nodes

    for epoch in range(epochs):
        gnn.train()
        predictor.train()
        optimizer.zero_grad()

        # Forward pass over the training graph.
        x_dict = gnn(data.x_dict, mp_edge_index_dict)

        # Positive samples: the training edges.
        pos_pred = predictor(x_dict[src_type], x_dict[dst_type], train_edge_index)

        # Negative samples: as many random pairs as there are positives.
        neg_edge_index = _sample_negative_edges(
            num_src, num_dst, train_edge_index.size(1), device=device
        )
        neg_pred = predictor(x_dict[src_type], x_dict[dst_type], neg_edge_index)

        # Binary cross-entropy on raw logits (positives -> 1, negatives -> 0).
        loss = F.binary_cross_entropy_with_logits(
            torch.cat([pos_pred, neg_pred]),
            torch.cat([torch.ones_like(pos_pred), torch.zeros_like(neg_pred)]),
        )
        loss.backward()
        optimizer.step()

        # Periodic validation.
        if epoch % 10 == 0:
            gnn.eval()
            predictor.eval()
            with torch.no_grad():
                x_dict = gnn(data.x_dict, mp_edge_index_dict)
                val_pos_pred = predictor(
                    x_dict[src_type], x_dict[dst_type], val_edge_index
                )
                val_neg_edge_index = _sample_negative_edges(
                    num_src, num_dst, val_edge_index.size(1), device=device
                )
                val_neg_pred = predictor(
                    x_dict[src_type], x_dict[dst_type], val_neg_edge_index
                )

                preds = torch.cat([val_pos_pred, val_neg_pred]).sigmoid().cpu().numpy()
                labels = np.concatenate([
                    np.ones(val_pos_pred.size(0)),
                    np.zeros(val_neg_pred.size(0)),
                ])

                auc = roc_auc_score(labels, preds)
                ap = average_precision_score(labels, preds)
                print(f'Epoch {epoch:03d} | Loss: {loss:.4f} | Val AUC: {auc:.4f} | Val AP: {ap:.4f}')

    return gnn, predictor


def train_node_classification(data, node_type='company', target='industry',
                              epochs=100, hidden_channels=64):
    """Train a heterogeneous GNN encoder plus classifier on one node type.

    Labels are read from ``data[node_type][target]`` and the number of
    classes is taken as ``max(label) + 1``. Nodes are randomly split
    60/20/20 into train / validation / test masks. Cross-entropy is
    minimised on the training nodes; train and validation accuracy are
    printed every 10 epochs and test accuracy is printed once at the end.

    Args:
        data: Heterogeneous graph object providing ``metadata()``, ``x_dict``,
            ``edge_index_dict``, ``to(device)`` and per-type ``num_nodes``
            (looks like a PyG ``HeteroData`` -- confirm).
        node_type: Node type whose nodes are classified.
        target: Attribute name under ``data[node_type]`` holding the labels
            (assumed to be non-negative integer class ids -- confirm).
        epochs: Number of training epochs.
        hidden_channels: Hidden size passed to the encoder and classifier.

    Returns:
        Tuple ``(gnn, classifier)`` of the trained modules.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Labels and class count.
    labels = data[node_type][target]
    # int(...) so the classifier always receives an integer output size.
    num_classes = int(labels.max().item()) + 1
    num_nodes = data[node_type].num_nodes

    # Random 60/20/20 split; the test set takes whatever remains.
    perm = torch.randperm(num_nodes)
    train_size = int(0.6 * num_nodes)
    val_size = int(0.2 * num_nodes)
    train_mask = _index_mask(num_nodes, perm[:train_size])
    val_mask = _index_mask(num_nodes, perm[train_size:train_size + val_size])
    test_mask = _index_mask(num_nodes, perm[train_size + val_size:])

    # Models and optimizer.
    gnn = HeteroGNN(hidden_channels, num_layers=2, metadata=data.metadata()).to(device)
    classifier = NodeClassifier(hidden_channels, num_classes).to(device)
    optimizer = torch.optim.Adam(
        list(gnn.parameters()) + list(classifier.parameters()),
        lr=0.01,
    )

    data = data.to(device)
    labels = labels.to(device)
    train_mask = train_mask.to(device)
    val_mask = val_mask.to(device)
    test_mask = test_mask.to(device)

    for epoch in range(epochs):
        gnn.train()
        classifier.train()
        optimizer.zero_grad()

        x_dict = gnn(data.x_dict, data.edge_index_dict)
        out = classifier(x_dict[node_type])

        loss = F.cross_entropy(out[train_mask], labels[train_mask])
        loss.backward()
        optimizer.step()

        # Periodic train / validation accuracy report.
        if epoch % 10 == 0:
            gnn.eval()
            classifier.eval()
            with torch.no_grad():
                x_dict = gnn(data.x_dict, data.edge_index_dict)
                out = classifier(x_dict[node_type])
                pred = out.argmax(dim=1)

                train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
                val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
                print(f'Epoch {epoch:03d} | Loss: {loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}')

    # Final evaluation on the held-out test nodes.
    gnn.eval()
    classifier.eval()
    with torch.no_grad():
        x_dict = gnn(data.x_dict, data.edge_index_dict)
        out = classifier(x_dict[node_type])
        pred = out.argmax(dim=1)
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()
        print(f'\n🎯 Test Accuracy: {test_acc:.4f}')

    return gnn, classifier


# Script entry point (train.py).
if __name__ == "__main__":
    from utils.data_generator import IPEcosystemGenerator

    # Generate the synthetic dataset.
    generator = IPEcosystemGenerator(seed=42)
    data = generator.generate(
        n_companies=500,
        n_patents=3000,
        n_trademarks=1500,
        n_persons=2000,
        n_institutions=50
    )

    print("\n" + "=" * 60)
    print("任务1: 链接预测 (企业-专利)")
    print("=" * 60)
    gnn1, pred1 = train_link_prediction(
        data,
        edge_type=('company', 'owns', 'patent'),
        epochs=500,
        hidden_channels=64
    )

    print("\n" + "=" * 60)
    print("任务2: 节点分类 (企业产业预测)")
    print("=" * 60)
    gnn2, cls2 = train_node_classification(
        data,
        node_type='company',
        target='industry',
        epochs=500,
        hidden_channels=64
    )

    print("\n✅ 训练完成!")