Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| from sklearn.metrics import roc_auc_score, average_precision_score | |
| import numpy as np | |
| from model.HeteroGNN import HeteroGNN, LinkPredictor, NodeClassifier | |
def train_link_prediction(data, edge_type=('company', 'owns', 'patent'),
                          epochs=100, hidden_channels=64):
    """Train a heterogeneous GNN encoder and a link predictor for one edge type.

    The edges of ``edge_type`` are randomly split 80/20 into training and
    validation positives. Each epoch draws as many negative (src, dst) pairs
    as there are training positives and optimises binary cross-entropy on the
    predictor's logits. Every 10th epoch, ROC-AUC and average precision are
    computed on the validation positives plus freshly sampled negatives and
    printed together with the training loss.

    Validation edges of ``edge_type`` are removed from the graph used for
    message passing so the encoder cannot see the edges it is evaluated on.

    Args:
        data: Heterogeneous graph object. The code relies on ``metadata()``,
            ``x_dict``, ``edge_index_dict``, ``to(device)`` and on per-type
            stores exposing ``edge_index`` / ``num_nodes`` (presumably a
            PyTorch Geometric ``HeteroData`` -- confirm against the caller).
        edge_type: ``(src_type, relation, dst_type)`` triple to predict.
        epochs: Number of full-batch optimisation steps.
        hidden_channels: Width passed to ``HeteroGNN`` and ``LinkPredictor``.

    Returns:
        The trained ``(gnn, predictor)`` pair.

    Raises:
        ValueError: If the edge type has too few edges to form a non-empty
            training split.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    src_type, _, dst_type = edge_type

    # Split the positive edges 80/20 into training and validation sets.
    edge_index = data[edge_type].edge_index
    num_edges = edge_index.size(1)
    train_size = int(0.8 * num_edges)
    if train_size == 0:
        raise ValueError(
            f'edge type {edge_type!r} has {num_edges} edge(s); '
            'at least 2 are required for a train/validation split'
        )
    perm = torch.randperm(num_edges).to(edge_index.device)
    # Move the splits to the training device explicitly: they are sliced out
    # before ``data.to(device)`` and would otherwise stay where ``data`` was.
    train_edge_index = edge_index[:, perm[:train_size]].to(device)
    val_edge_index = edge_index[:, perm[train_size:]].to(device)

    num_src = data[src_type].num_nodes
    num_dst = data[dst_type].num_nodes
    # Encode every known positive (train and validation) as one integer key
    # so sampled negatives can be checked against them in a single call.
    pos_keys = (edge_index[0] * num_dst + edge_index[1]).cpu()

    def get_neg_edges(num_neg, max_rounds=10):
        """Sample ``num_neg`` (src, dst) pairs that are not known positives.

        Candidates are drawn uniformly and pairs matching an existing edge are
        rejected. After ``max_rounds`` rejection rounds any shortfall is
        filled with unchecked samples so the result always has ``num_neg``
        columns, even on a nearly complete graph.
        """
        chunks = []
        remaining = num_neg
        for _ in range(max_rounds):
            if remaining <= 0:
                break
            src = torch.randint(0, num_src, (remaining,))
            dst = torch.randint(0, num_dst, (remaining,))
            keep = ~torch.isin(src * num_dst + dst, pos_keys)
            chunks.append(torch.stack([src[keep], dst[keep]]))
            remaining -= int(keep.sum())
        if remaining > 0:
            src = torch.randint(0, num_src, (remaining,))
            dst = torch.randint(0, num_dst, (remaining,))
            chunks.append(torch.stack([src, dst]))
        if not chunks:
            return torch.empty((2, 0), dtype=torch.long, device=device)
        return torch.cat(chunks, dim=1).to(device)

    # Models and optimiser.
    gnn = HeteroGNN(hidden_channels, num_layers=2, metadata=data.metadata()).to(device)
    predictor = LinkPredictor(hidden_channels).to(device)
    optimizer = torch.optim.Adam(
        list(gnn.parameters()) + list(predictor.parameters()),
        lr=0.001
    )
    data = data.to(device)

    # Message-passing graph: identical to the input graph except that the
    # target edge type only contains the training edges.
    # NOTE(review): if the graph also stores a reverse relation for this edge
    # type, validation edges are still visible through it -- confirm and mask
    # that relation as well if so.
    message_edges = dict(data.edge_index_dict)
    message_edges[tuple(edge_type)] = train_edge_index

    for epoch in range(epochs):
        gnn.train()
        predictor.train()
        optimizer.zero_grad()

        # Forward pass over the (training-only) message-passing graph.
        x_dict = gnn(data.x_dict, message_edges)

        # Positive samples.
        pos_pred = predictor(
            x_dict[src_type],
            x_dict[dst_type],
            train_edge_index
        )

        # Negative samples: one per training positive, redrawn every epoch.
        neg_edge_index = get_neg_edges(train_edge_index.size(1))
        neg_pred = predictor(
            x_dict[src_type],
            x_dict[dst_type],
            neg_edge_index
        )

        # The predictor outputs are treated as logits (label 1 for positives,
        # 0 for negatives).
        loss = F.binary_cross_entropy_with_logits(
            torch.cat([pos_pred, neg_pred]),
            torch.cat([torch.ones_like(pos_pred), torch.zeros_like(neg_pred)])
        )
        loss.backward()
        optimizer.step()

        # Validation every 10th epoch.
        if epoch % 10 == 0:
            gnn.eval()
            predictor.eval()
            with torch.no_grad():
                x_dict = gnn(data.x_dict, message_edges)
                val_pos_pred = predictor(x_dict[src_type], x_dict[dst_type], val_edge_index)
                val_neg_edge_index = get_neg_edges(val_edge_index.size(1))
                val_neg_pred = predictor(x_dict[src_type], x_dict[dst_type], val_neg_edge_index)
                preds = torch.cat([val_pos_pred, val_neg_pred]).sigmoid().cpu().numpy()
                labels = np.concatenate([np.ones(val_pos_pred.size(0)), np.zeros(val_neg_pred.size(0))])
                auc = roc_auc_score(labels, preds)
                ap = average_precision_score(labels, preds)
                print(f'Epoch {epoch:03d} | Loss: {loss.item():.4f} | Val AUC: {auc:.4f} | Val AP: {ap:.4f}')

    return gnn, predictor
def train_node_classification(data, node_type='company', target='industry',
                              epochs=100, hidden_channels=64):
    """Train a heterogeneous GNN encoder with a classification head.

    Nodes of ``node_type`` are randomly split 60/20/20 into train, validation
    and test sets. The model is optimised full-batch with cross-entropy on the
    training nodes; every 10th epoch the training loss plus train/validation
    accuracy are printed, and the test accuracy is printed once at the end.

    Args:
        data: Heterogeneous graph object exposing ``metadata()``, ``x_dict``,
            ``edge_index_dict``, ``to(device)`` and per-type stores
            (presumably a PyTorch Geometric ``HeteroData`` -- confirm).
        node_type: Node type whose attribute is predicted.
        target: Name of the label attribute stored on ``data[node_type]``.
            The number of classes is taken as ``max(label) + 1``.
        epochs: Number of full-batch optimisation steps.
        hidden_channels: Width passed to ``HeteroGNN`` and ``NodeClassifier``.

    Returns:
        The trained ``(gnn, classifier)`` pair.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Labels and class count.
    labels = data[node_type][target]
    num_classes = labels.max().item() + 1
    num_nodes = data[node_type].num_nodes

    # Random 60/20/20 split expressed as boolean node masks.
    shuffled = torch.randperm(num_nodes)
    n_train = int(0.6 * num_nodes)
    n_val = int(0.2 * num_nodes)
    split_indices = {
        'train': shuffled[:n_train],
        'val': shuffled[n_train:n_train + n_val],
        'test': shuffled[n_train + n_val:],
    }
    masks = {}
    for split_name, node_ids in split_indices.items():
        mask = torch.zeros(num_nodes, dtype=torch.bool)
        mask[node_ids] = True
        masks[split_name] = mask

    # Models and optimiser.
    gnn = HeteroGNN(hidden_channels, num_layers=2, metadata=data.metadata()).to(device)
    classifier = NodeClassifier(hidden_channels, num_classes).to(device)
    optimizer = torch.optim.Adam(
        [*gnn.parameters(), *classifier.parameters()],
        lr=0.01
    )

    data = data.to(device)
    labels = labels.to(device)
    train_mask = masks['train'].to(device)
    val_mask = masks['val'].to(device)
    test_mask = masks['test'].to(device)

    def forward_logits():
        """Encode the graph and return class scores for ``node_type`` nodes."""
        embeddings = gnn(data.x_dict, data.edge_index_dict)
        return classifier(embeddings[node_type])

    def accuracy(pred, mask):
        """Fraction of masked nodes whose predicted class matches the label."""
        return (pred[mask] == labels[mask]).float().mean()

    for epoch in range(epochs):
        # One full-batch optimisation step on the training nodes.
        gnn.train()
        classifier.train()
        optimizer.zero_grad()
        out = forward_logits()
        loss = F.cross_entropy(out[train_mask], labels[train_mask])
        loss.backward()
        optimizer.step()

        if epoch % 10 != 0:
            continue

        # Periodic evaluation on the train and validation splits.
        gnn.eval()
        classifier.eval()
        with torch.no_grad():
            pred = forward_logits().argmax(dim=1)
            train_acc = accuracy(pred, train_mask)
            val_acc = accuracy(pred, val_mask)
            print(f'Epoch {epoch:03d} | Loss: {loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}')

    # Final evaluation on the held-out test split.
    gnn.eval()
    classifier.eval()
    with torch.no_grad():
        pred = forward_logits().argmax(dim=1)
        test_acc = accuracy(pred, test_mask)
        print(f'\n🎯 Test Accuracy: {test_acc:.4f}')

    return gnn, classifier
| # train.py | |
if __name__ == "__main__":
    from utils.data_generator import IPEcosystemGenerator

    # Generate the graph that both training tasks below consume.
    generator = IPEcosystemGenerator(seed=42)
    entity_counts = {
        'n_companies': 500,
        'n_patents': 3000,
        'n_trademarks': 1500,
        'n_persons': 2000,
        'n_institutions': 50,
    }
    data = generator.generate(**entity_counts)

    divider = "=" * 60
    # Settings shared by both tasks.
    run_settings = {'epochs': 500, 'hidden_channels': 64}

    # Task 1: link prediction on company -> patent ownership edges.
    print("\n" + divider)
    print("任务1: 链接预测 (企业-专利)")
    print(divider)
    gnn1, pred1 = train_link_prediction(
        data,
        edge_type=('company', 'owns', 'patent'),
        **run_settings
    )

    # Task 2: node classification of each company's industry label.
    print("\n" + divider)
    print("任务2: 节点分类 (企业产业预测)")
    print(divider)
    gnn2, cls2 = train_node_classification(
        data,
        node_type='company',
        target='industry',
        **run_settings
    )

    print("\n✅ 训练完成!")