import math
from typing import Dict, List, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import Linear
from torch_geometric.utils import softmax  # per-destination (scatter) softmax
class HGTLayer(nn.Module):
    """A single Heterogeneous Graph Transformer (HGT) layer."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 node_types: List[str],
                 edge_types: List[Tuple[str, str, str]],
                 n_heads: int = 8,
                 dropout: float = 0.2,
                 use_norm: bool = True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.node_types = node_types
        self.edge_types = edge_types
        self.n_heads = n_heads
        self.d_k = out_channels // n_heads
        self.dropout = dropout

        # Lookup tables for node and edge types
        self.node_type_to_id = {ntype: i for i, ntype in enumerate(node_types)}
        self.edge_type_to_id = {etype: i for i, etype in enumerate(edge_types)}

        # 1. Node-type-specific linear projections
        self.k_linears = nn.ModuleDict()
        self.q_linears = nn.ModuleDict()
        self.v_linears = nn.ModuleDict()
        self.a_linears = nn.ModuleDict()
        for ntype in node_types:
            self.k_linears[ntype] = Linear(in_channels, out_channels)
            self.q_linears[ntype] = Linear(in_channels, out_channels)
            self.v_linears[ntype] = Linear(in_channels, out_channels)
            self.a_linears[ntype] = Linear(out_channels, out_channels)

        # 2. Edge-type-specific attention parameters
        self.edge_att = nn.ParameterDict()
        for src_type, edge_type, dst_type in edge_types:
            edge_key = f"{src_type}__{edge_type}__{dst_type}"
            self.edge_att[edge_key] = nn.Parameter(torch.randn(n_heads, self.d_k))

        # 3. Edge-type-specific message projections
        self.msg_linears = nn.ModuleDict()
        for src_type, edge_type, dst_type in edge_types:
            edge_key = f"{src_type}__{edge_type}__{dst_type}"
            self.msg_linears[edge_key] = Linear(in_channels, out_channels)

        # 4. LayerNorm (residual connections are applied in forward)
        if use_norm:
            self.norms = nn.ModuleDict()
            for ntype in node_types:
                self.norms[ntype] = nn.LayerNorm(out_channels)
        else:
            self.norms = None

        self.dropout_layer = nn.Dropout(dropout)

        self.reset_parameters()
    def reset_parameters(self):
        """Xavier initialization for all projections and attention parameters."""
        for module_dict in (self.k_linears, self.q_linears, self.v_linears,
                            self.a_linears, self.msg_linears):
            for linear in module_dict.values():
                nn.init.xavier_uniform_(linear.weight)
        for param in self.edge_att.values():
            nn.init.xavier_uniform_(param)
    def forward(self,
                x_dict: Dict[str, torch.Tensor],
                edge_index_dict: Dict[Tuple[str, str, str], torch.Tensor],
                edge_attr_dict: Optional[Dict[Tuple[str, str, str], torch.Tensor]] = None):
        # edge_attr_dict is accepted for interface compatibility but unused here.
        # 1. Compute K, Q, V per node type
        k_dict, q_dict, v_dict = {}, {}, {}
        for ntype, x in x_dict.items():
            k_dict[ntype] = self.k_linears[ntype](x).view(-1, self.n_heads, self.d_k)
            q_dict[ntype] = self.q_linears[ntype](x).view(-1, self.n_heads, self.d_k)
            v_dict[ntype] = self.v_linears[ntype](x).view(-1, self.n_heads, self.d_k)

        # 2. Message passing with attention-weighted aggregation.
        # Allocate outputs only for node types actually present in x_dict.
        out_dict = {ntype: x.new_zeros(x.size(0), self.out_channels)
                    for ntype, x in x_dict.items()}

        for edge_type, edge_index in edge_index_dict.items():
            if edge_index.numel() == 0:  # skip relations with no edges
                continue
            src_type, rel_type, dst_type = edge_type
            edge_key = f"{src_type}__{rel_type}__{dst_type}"
            src_idx, dst_idx = edge_index[0], edge_index[1]

            # Gather K for source nodes and Q for destination nodes
            k = k_dict[src_type][src_idx]  # [num_edges, n_heads, d_k]
            q = q_dict[dst_type][dst_idx]  # [num_edges, n_heads, d_k]

            # Edge-type-specific attention parameter
            edge_att_weight = self.edge_att[edge_key]  # [n_heads, d_k]

            # Attention score: q · k plus a relation-dependent bias term
            att_score = (q * k).sum(dim=-1) + (q * edge_att_weight.unsqueeze(0)).sum(dim=-1)
            att_score = att_score / math.sqrt(self.d_k)  # [num_edges, n_heads]

            # Normalize per destination node: softmax over each node's incoming
            # edges (a global softmax over the whole edge set would be wrong).
            att_score = softmax(att_score, dst_idx,
                                num_nodes=x_dict[dst_type].size(0))

            # Messages: this simplified variant projects the raw source features
            # with a relation-specific layer (v_dict is kept for parity with the
            # full HGT formulation but is not used as the message here).
            msg = self.msg_linears[edge_key](x_dict[src_type][src_idx])
            msg = msg.view(-1, self.n_heads, self.d_k)  # [num_edges, n_heads, d_k]

            # Weight messages by attention and aggregate onto destination nodes
            msg = msg * att_score.unsqueeze(-1)
            msg = msg.view(-1, self.out_channels)  # [num_edges, out_channels]
            out_dict[dst_type] = out_dict[dst_type].index_add(0, dst_idx, msg)

        # 3. Output projection, residual connection, and normalization
        for ntype in out_dict:
            # Residual (skip) connection when input and output widths match
            if self.in_channels == self.out_channels:
                out_dict[ntype] = self.a_linears[ntype](out_dict[ntype]) + x_dict[ntype]
            else:
                out_dict[ntype] = self.a_linears[ntype](out_dict[ntype])
            # Layer normalization
            if self.norms is not None:
                out_dict[ntype] = self.norms[ntype](out_dict[ntype])
            # Dropout
            out_dict[ntype] = self.dropout_layer(out_dict[ntype])
        return out_dict
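

# A minimal smoke test for HGTLayer on a toy two-type graph. The node/edge
# types, sizes, and random features below are illustrative assumptions; this
# only checks that shapes line up, it is not part of the original training code.
def _demo_hgt_layer():
    node_types = ['company', 'patent']
    edge_types = [('company', 'owns', 'patent')]
    layer = HGTLayer(in_channels=16, out_channels=16,
                     node_types=node_types, edge_types=edge_types, n_heads=4)
    x_dict = {'company': torch.randn(3, 16), 'patent': torch.randn(5, 16)}
    edge_index_dict = {('company', 'owns', 'patent'):
                       torch.tensor([[0, 1, 2, 0], [0, 1, 2, 3]])}
    out = layer(x_dict, edge_index_dict)
    assert out['company'].shape == (3, 16) and out['patent'].shape == (5, 16)
    return out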
class HGT(nn.Module):
    """Heterogeneous Graph Transformer."""

    def __init__(self,
                 hidden_channels: int = 64,
                 out_channels: int = 64,
                 num_layers: int = 3,
                 n_heads: int = 8,
                 dropout: float = 0.2,
                 use_norm: bool = True,
                 metadata: Optional[Tuple] = None):
        super().__init__()
        if metadata is None:
            raise ValueError("metadata is required for HGT")
        self.node_types = metadata[0]
        self.edge_types = metadata[1]
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers

        # Input projections: map each node type's feature dimension to a shared
        # width. Linear(-1, ...) infers the input dimension lazily from the data
        # on the first forward pass.
        self.input_projections = nn.ModuleDict()
        for ntype in self.node_types:
            self.input_projections[ntype] = Linear(-1, hidden_channels)

        # Stacked HGT layers
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(HGTLayer(
                in_channels=hidden_channels,
                out_channels=hidden_channels,
                node_types=self.node_types,
                edge_types=self.edge_types,
                n_heads=n_heads,
                dropout=dropout,
                use_norm=use_norm,
            ))

        # Optional output projection
        if out_channels != hidden_channels:
            self.output_projections = nn.ModuleDict()
            for ntype in self.node_types:
                self.output_projections[ntype] = Linear(hidden_channels, out_channels)
        else:
            self.output_projections = None
    def forward(self,
                x_dict: Dict[str, torch.Tensor],
                edge_index_dict: Dict[Tuple[str, str, str], torch.Tensor],
                edge_attr_dict: Optional[Dict[Tuple[str, str, str], torch.Tensor]] = None):
        # Input projection (build a new dict so the caller's x_dict is not mutated)
        x_dict = {ntype: self.input_projections[ntype](x)
                  for ntype, x in x_dict.items()}
        # HGT layers
        for layer in self.layers:
            x_dict = layer(x_dict, edge_index_dict, edge_attr_dict)
        # Output projection
        if self.output_projections is not None:
            x_dict = {ntype: self.output_projections[ntype](x)
                      for ntype, x in x_dict.items()}
        return x_dict
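

# Sketch of wiring HGT to a PyG heterogeneous graph. `data` is assumed to be a
# torch_geometric.data.HeteroData with `.x` set for every node type; its
# metadata() returns exactly the (node_types, edge_types) tuple HGT expects.
def _demo_hgt(data):
    model = HGT(hidden_channels=64, out_channels=64, num_layers=2,
                metadata=data.metadata())
    emb_dict = model(data.x_dict, data.edge_index_dict)
    # emb_dict maps each node type to a [num_nodes, 64] embedding matrix
    return emb_dict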
class HGTLinkPredictor(nn.Module):
    """MLP link predictor on top of HGT embeddings."""

    def __init__(self, in_channels: int, hidden_channels: int = 64, num_layers: int = 2):
        super().__init__()
        layers = [
            Linear(2 * in_channels, hidden_channels),
            nn.ReLU(),
            nn.Dropout(0.2),
        ]
        for _ in range(num_layers - 2):
            layers.extend([
                Linear(hidden_channels, hidden_channels),
                nn.ReLU(),
                nn.Dropout(0.2),
            ])
        layers.append(Linear(hidden_channels, 1))
        self.predictor = nn.Sequential(*layers)

    def forward(self, x_src: torch.Tensor, x_dst: torch.Tensor, edge_label_index: torch.Tensor):
        src = x_src[edge_label_index[0]]
        dst = x_dst[edge_label_index[1]]
        x = torch.cat([src, dst], dim=-1)
        # squeeze only the feature dimension so a single pair keeps its batch dim
        return self.predictor(x).squeeze(-1)
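

# Training sketch for link prediction. The ('company', 'patent') pairing,
# `edge_label_index`, and `edge_labels` are illustrative assumptions; the
# predictor outputs raw logits, so the with-logits BCE is the matching loss.
def _link_prediction_loss(emb_dict, edge_label_index, edge_labels, predictor):
    logits = predictor(emb_dict['company'], emb_dict['patent'], edge_label_index)
    return F.binary_cross_entropy_with_logits(logits, edge_labels.float())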
class HGTNodeClassifier(nn.Module):
    """MLP node classifier on top of HGT embeddings."""

    def __init__(self, in_channels: int, num_classes: int, hidden_channels: int = 64, num_layers: int = 2):
        super().__init__()
        layers = [
            Linear(in_channels, hidden_channels),
            nn.ReLU(),
            nn.Dropout(0.2),
        ]
        for _ in range(num_layers - 2):
            layers.extend([
                Linear(hidden_channels, hidden_channels),
                nn.ReLU(),
                nn.Dropout(0.2),
            ])
        layers.append(Linear(hidden_channels, num_classes))
        self.classifier = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        return self.classifier(x)
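

# Node-classification sketch. The 'patent' node type, labels `y`, and
# `train_mask` are assumed names, not part of the original code.
def _node_classification_loss(emb_dict, y, train_mask, classifier):
    logits = classifier(emb_dict['patent'])  # [num_patents, num_classes]
    return F.cross_entropy(logits[train_mask], y[train_mask])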
# Appended at the end of the existing HGT.py file:
class HGTPatentValuePredictor(nn.Module):
    """HGT-based patent value estimator."""

    def __init__(self, in_channels: int, hidden_channels: int = 64, num_layers: int = 2):
        super().__init__()
        # The four heads are concatenated back to hidden_channels below
        assert hidden_channels % 4 == 0, "hidden_channels must be divisible by 4"
        # Per-aspect feature heads
        self.tech_head = nn.Linear(in_channels, hidden_channels // 4)     # technology
        self.market_head = nn.Linear(in_channels, hidden_channels // 4)   # market
        self.network_head = nn.Linear(in_channels, hidden_channels // 4)  # network
        self.time_head = nn.Linear(in_channels, hidden_channels // 4)     # time

        # Value prediction head
        layers = [
            Linear(hidden_channels, hidden_channels // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
        ]
        for _ in range(num_layers - 2):
            layers.extend([
                Linear(hidden_channels // 2, hidden_channels // 2),
                nn.ReLU(),
                nn.Dropout(0.2),
            ])
        layers.append(Linear(hidden_channels // 2, 1))
        layers.append(nn.Sigmoid())  # value score in [0, 1]
        self.value_predictor = nn.Sequential(*layers)

        # Auxiliary task heads (to improve generalization)
        self.grant_predictor = nn.Linear(hidden_channels, 1)    # grant probability
        self.renewal_predictor = nn.Linear(hidden_channels, 1)  # renewal probability

    def forward(self, x: torch.Tensor, return_aux: bool = False):
        """
        Args:
            x: patent-node embeddings [num_patents, in_channels]
            return_aux: whether to also return the auxiliary-task predictions
        """
        # Per-aspect feature extraction
        tech_feat = F.relu(self.tech_head(x))
        market_feat = F.relu(self.market_head(x))
        network_feat = F.relu(self.network_head(x))
        time_feat = F.relu(self.time_head(x))
        # Feature fusion
        fused_feat = torch.cat([tech_feat, market_feat, network_feat, time_feat], dim=-1)
        # Value score, rescaled to 0-100
        value_score = self.value_predictor(fused_feat) * 100
        if return_aux:
            grant_prob = torch.sigmoid(self.grant_predictor(fused_feat))
            renewal_prob = torch.sigmoid(self.renewal_predictor(fused_feat))
            return value_score, grant_prob, renewal_prob
        return value_score
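

# Multi-task training sketch for the value predictor. The targets
# (value_target on a 0-100 scale, grant/renewal labels in {0, 1}) and the 0.1
# auxiliary-loss weights are illustrative assumptions.
def _value_loss(emb, value_target, grant_label, renewal_label, predictor):
    value, grant_p, renewal_p = predictor(emb, return_aux=True)
    loss = F.mse_loss(value.squeeze(-1), value_target.float())
    loss = loss + 0.1 * F.binary_cross_entropy(grant_p.squeeze(-1), grant_label.float())
    loss = loss + 0.1 * F.binary_cross_entropy(renewal_p.squeeze(-1), renewal_label.float())
    return loss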
class HGTCollaborationRecommender(nn.Module):
    """HGT-based technology-collaboration recommender."""

    def __init__(self, in_channels: int, hidden_channels: int = 64, num_layers: int = 2):
        super().__init__()
        # Similarity scoring modules
        self.tech_similarity = nn.Sequential(
            Linear(2 * in_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, 1),
        )
        self.market_similarity = nn.Sequential(
            Linear(2 * in_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, 1),
        )
        # Complementarity scoring module
        self.complementarity = nn.Sequential(
            Linear(2 * in_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, 1),
        )
        # Success-probability predictor
        layers = [
            Linear(2 * in_channels + 3, hidden_channels),  # +3 for the similarity scores
            nn.ReLU(),
            nn.Dropout(0.2),
        ]
        for _ in range(num_layers - 2):
            layers.extend([
                Linear(hidden_channels, hidden_channels),
                nn.ReLU(),
                nn.Dropout(0.2),
            ])
        layers.append(Linear(hidden_channels, 1))
        layers.append(nn.Sigmoid())
        self.success_predictor = nn.Sequential(*layers)

    def forward(self, x_company_a: torch.Tensor, x_company_b: torch.Tensor,
                edge_label_index: torch.Tensor):
        """
        Args:
            x_company_a: embeddings of company A [num_companies, in_channels]
            x_company_b: embeddings of company B [num_companies, in_channels]
            edge_label_index: company pairs to score [2, num_pairs]
        """
        src = x_company_a[edge_label_index[0]]  # [num_pairs, in_channels]
        dst = x_company_b[edge_label_index[1]]  # [num_pairs, in_channels]
        # Pairwise similarity and complementarity scores
        combined = torch.cat([src, dst], dim=-1)
        tech_sim = torch.sigmoid(self.tech_similarity(combined))
        market_sim = torch.sigmoid(self.market_similarity(combined))
        complement = torch.sigmoid(self.complementarity(combined))
        # Fuse everything to predict the collaboration-success probability
        final_features = torch.cat([combined, tech_sim, market_sim, complement], dim=-1)
        success_prob = self.success_predictor(final_features)
        return {
            'success_probability': success_prob.squeeze(-1),
            'tech_similarity': tech_sim.squeeze(-1),
            'market_similarity': market_sim.squeeze(-1),
            'complementarity': complement.squeeze(-1),
        }
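

# Recommendation sketch: score candidate company pairs and rank by predicted
# success probability. `emb` (one shared company embedding table) and `pairs`
# ([2, num_pairs] index tensor) are assumed inputs; top-k is one simple way to
# turn the scores into a recommendation list.
def _recommend_partners(emb, pairs, recommender, k=10):
    out = recommender(emb, emb, pairs)  # same table for both sides of each pair
    scores = out['success_probability']
    topk = torch.topk(scores, k=min(k, scores.numel()))
    return pairs[:, topk.indices], topk.values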