TransKP / model.py
KangjieXu's picture
Upload 8 files
f8095e3 verified
import torch
import torch.nn as nn
from transformers import EsmModel
from torch_geometric.nn import GATv2Conv
from torch_geometric.data import Data, Batch
from rdkit import Chem
from rdkit.Chem import AllChem
# --- Helper Functions for Graph Creation ---
#将一个原子对象(atom)转换为一个 one-hot 编码向量,表示该原子的类型。返回 15维 向量,表示原子的类型,1 表示该原子类型,0 表示其他。
def get_atom_features(atom):
# Returns a one-hot encoded vector for the atom type.
possible_atoms = ['C', 'O', 'N', 'S', 'F', 'Cl', 'Br', 'I', 'P', 'Co', 'Fe', 'Cu', 'Zn', 'Mg', 'Mn', 'Cr', 'Ni']
features = [0] * (len(possible_atoms) + 1)
try:
idx = possible_atoms.index(atom.GetSymbol())
features[idx] = 1
except ValueError:
features[-1] = 1 # For 'other' atoms
return features
#将一个化学键对象(bond)转换为一个 one-hot 编码向量,表示该键的类型。返回 4维 布尔向量,表示化学键的类型,True 表示该键类型,False 表示非该键类型。
def get_bond_features(bond):
# Returns a one-hot encoded vector for the bond type.
bond_type = bond.GetBondType()
return [
bond_type == Chem.rdchem.BondType.SINGLE,
bond_type == Chem.rdchem.BondType.DOUBLE,
bond_type == Chem.rdchem.BondType.TRIPLE,
bond_type == Chem.rdchem.BondType.AROMATIC
]
#将一个 SMILES 字符串 转换为 PyTorch Geometric 图数据对象(Data)。解析 SMILES 字符串生成分子结构。添加氢原子并生成 3D 构象。为分子中的每个原子生成 one-hot 特征。为分子中的每个化学键生成类型特征。将原子和化学键信息封装为 PyG 的图数据对象 Data。
def smiles_to_pyg_graph(smiles_string):
"""
Converts a SMILES string into a PyTorch Geometric Data object.
Returns None if the SMILES string is invalid.
"""
try:
mol = Chem.MolFromSmiles(smiles_string)
if mol is None: return None
mol = Chem.AddHs(mol)
AllChem.EmbedMolecule(mol, AllChem.ETKDG())
atom_features_list = [get_atom_features(atom) for atom in mol.GetAtoms()]
x = torch.tensor(atom_features_list, dtype=torch.float)
if mol.GetNumBonds() > 0:
edge_indices, edge_attrs = [], []
for bond in mol.GetBonds():
i = bond.GetBeginAtomIdx()
j = bond.GetEndAtomIdx()
edge_indices.append((i, j))
edge_indices.append((j, i))
bond_features = get_bond_features(bond)
edge_attrs.append(bond_features)
edge_attrs.append(bond_features)
edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
edge_attr = torch.tensor(edge_attrs, dtype=torch.float)
else:
edge_index = torch.empty((2, 0), dtype=torch.long)
edge_attr = torch.empty((0, 4), dtype=torch.float)
return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
except Exception:
return None
# --- Model Components ---
#SubstrateGNN: 这是一个继承自 torch.nn.Module 的类,用于定义图神经网络模型,具体使用的是 Graph Attention Network (GATv2),专门用于处理底物的 SMILES 字符串。
class SubstrateGNN(nn.Module):
"""
Graph Attention Network (GATv2) for processing substrate SMILES strings.
"""
def __init__(self, input_dim, hidden_dim, output_dim, heads=4, dropout=0.1):
super(SubstrateGNN, self).__init__()
#conv1: 第一个图卷积层,输入维度是 input_dim,输出维度是 hidden_dim * heads。
self.conv1 = GATv2Conv(input_dim, hidden_dim, heads=heads, dropout=dropout, concat=True)
#conv2: 第二个图卷积层,输入维度是 hidden_dim * heads,输出维度是 hidden_dim * heads。
self.conv2 = GATv2Conv(hidden_dim * heads, hidden_dim, heads=heads, dropout=dropout, concat=True)
#conv3: 第三个图卷积层,输入维度是 hidden_dim * heads,输出维度是 output_dim。
self.conv3 = GATv2Conv(hidden_dim * heads, output_dim, heads=1, dropout=dropout, concat=False)
#定义了一个 ELU 激活函数和一个 dropout 层:ELU: 激活函数用于对每一层的输出进行非线性变换。Dropout: 在训练中随机丢弃一些神经元输出,以防止过拟合。
self.elu = nn.ELU()
self.dropout = nn.Dropout(p=dropout)
def forward(self, data):
#它接收图数据 data,并从中提取节点特征 x 和边索引 edge_index。
x, edge_index = data.x, data.edge_index
#通过第一个图卷积层 conv1 更新节点特征,并应用 ELU 激活函数和 Dropout:
x = self.dropout(self.elu(self.conv1(x, edge_index)))
#功能:通过第二个图卷积层 conv2 更新节点特征,并应用 ELU 激活函数和 Dropout:
x = self.dropout(self.elu(self.conv2(x, edge_index)))
#功能:通过第三个图卷积层 conv3 更新节点特征:
x = self.conv3(x, edge_index)
# 这里使用 全局平均池化(global mean pooling)操作来将节点级的特征转换为图级别的嵌入。
if hasattr(data, 'batch') and data.batch is not None:
from torch_geometric.nn import global_mean_pool
graph_embedding = global_mean_pool(x, data.batch)
else:
graph_embedding = x.mean(dim=0, keepdim=True)
return graph_embedding
#定义了一个 FusionBlock 类,它包含了跨模态(cross-modal)融合的过程,结合了自注意力(self-attention)和交叉注意力(cross-attention)机制。
class FusionBlock(nn.Module):
"""
A single block for cross-modal fusion, combining self-attention and cross-attention.
"""
def __init__(self, d_model, num_heads, dim_feedforward, dropout=0.1):
super(FusionBlock, self).__init__()
#使用 PyTorch 的 MultiheadAttention 来实现蛋白质模态中的自注意力机制。
self.self_attn_protein = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
#另一个 MultiheadAttention,用于交叉模态的注意力机制,使蛋白质能够关注底物的特征。
self.cross_attn_prot_to_sub = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
#前馈网络(Feed-Forward Network),由两层全连接层(Linear),中间有一个 ReLU 激活和 Dropout 层。
self.ffn_protein = nn.Sequential(
nn.Linear(d_model, dim_feedforward), nn.ReLU(), nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model), nn.Dropout(dropout)
)
#归一化层
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
#定义了类的前向传播过程。protein_emb 是蛋白质嵌入,substrate_emb 是底物嵌入,protein_mask 是用于掩蔽的可选参数(通常用于处理填充的部分)。
def forward(self, protein_emb, substrate_emb, protein_mask=None):
# 对蛋白质嵌入进行 自注意力(Self-Attention) 计算,并进行 层归一化(Layer Normalization)。这里的自注意力操作通过 _sa_block 实现,蛋白质嵌入会计算自己内部位置之间的相关性,并将加权结果加回原始的 protein_emb,形成 残差连接。经过残差连接后,使用层归一化处理,保证每一层的输出稳定。
protein_emb = self.norm1(protein_emb + self._sa_block(protein_emb, protein_mask))
# 执行 交叉注意力(Cross-Attention),使蛋白质嵌入(protein_emb)能够关注底物嵌入(substrate_emb)。_ca_block 用于计算蛋白质和底物之间的交叉注意力。结果经过 残差连接 加回到原始的 protein_emb 后,再应用 层归一化 进行稳定化处理。
protein_emb = self.norm2(protein_emb + self._ca_block(protein_emb, substrate_emb))
#使用 前馈神经网络(Feed-Forward Network) 进一步处理蛋白质嵌入(protein_emb)。前馈网络对嵌入进行非线性转换,以增强模型的表示能力。前馈网络的输出通过 残差连接 与原始的 protein_emb 相加,然后经过 层归一化,确保训练过程的稳定性。
protein_emb = self.norm3(protein_emb + self.ffn_protein(protein_emb))
return protein_emb
#@定义了自注意力计算的过程,接受输入 x(蛋白质嵌入)和 key_padding_mask(可选的掩码)。
def _sa_block(self, x, key_padding_mask):
x, _ = self.self_attn_protein(x, x, x, key_padding_mask=key_padding_mask)
return x
#定义了交叉注意力的计算过程,接受输入 query(蛋白质嵌入)和 key_value(底物嵌入)。
def _ca_block(self, query, key_value):
x, _ = self.cross_attn_prot_to_sub(query, key_value, key_value)
return x
#定义了一个继承自 nn.Module 的 DeepFusionKcatPredictor 类,旨在结合蛋白质序列和底物图结构的特征进行 kcat 预测。
class DeepFusionKcatPredictor(nn.Module):
"""
The main model that integrates ESM-2 for protein sequences and a GNN for substrates,
then fuses their representations to predict kcat values.
"""
def __init__(self, esm_model_name, gnn_input_dim, gnn_hidden_dim, gnn_heads, d_model,
num_fusion_blocks, num_attn_heads, dim_feedforward, dropout=0.1):
super(DeepFusionKcatPredictor, self).__init__()
#加载一个预训练的 ESM-2 模型,该模型基于 Transformer 架构,专门用于处理蛋白质序列(类似自然语言中的 BERT)。
self.esm_model = EsmModel.from_pretrained(esm_model_name)
#使用线性层将 ESM 的原始输出维度 hidden_size 映射到模型统一使用的维度 d_model。
self.protein_projection = nn.Linear(self.esm_model.config.hidden_size, d_model)
#SubstrateGNN 是一个用户自定义的图神经网络模块,用于从 SMILES 构造的图中提取底物的特征。
self.gnn = SubstrateGNN(input_dim=gnn_input_dim, hidden_dim=gnn_hidden_dim, output_dim=d_model, heads=gnn_heads)
#FusionBlock: 每个融合模块的设计目标是使蛋白质嵌入与底物嵌入进行特征交互。
self.fusion_blocks = nn.ModuleList([
FusionBlock(d_model, num_attn_heads, dim_feedforward, dropout) for _ in range(num_fusion_blocks)
])
#用于将融合后的蛋白质表示(池化后)映射为一个实数 kcat 值。
self.output_regressor = nn.Sequential(
nn.Linear(d_model, d_model // 2),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(d_model // 2, 1)
)
def forward(self, input_ids, attention_mask, smiles_list):
batch_size = input_ids.shape[0]
device = input_ids.device
# Create a placeholder for final predictions.
# This is crucial for handling batches where some SMILES might be invalid.
# Initialize with torch.float32, as this is the expected final output type.
#初始化一个全零的张量 final_predictions,用于存储最终的 kcat 预测结果。
final_predictions = torch.zeros(batch_size, device=device, dtype=torch.float32)
#将 SMILES 字符串转换为图结构(通过 smiles_to_pyg_graph 函数)。valid_indices 存储那些成功生成图结构的样本的索引。
graphs = [smiles_to_pyg_graph(s) for s in smiles_list]
valid_indices = [i for i, g in enumerate(graphs) if g is not None]
#仅处理那些有效的底物图。如果有有效的图,将它们组合成一个批次 graph_batch,并将其移动到当前设备(如 GPU)。
if valid_indices:
valid_graphs = [graphs[i] for i in valid_indices]
graph_batch = Batch.from_data_list(valid_graphs).to(device)
#底物图编码,详细步骤在上面
substrate_embedding = self.gnn(graph_batch) # Shape: [num_valid_graphs, d_model]
substrate_embedding = substrate_embedding.unsqueeze(1) # Shape: [num_valid_graphs, 1, d_model]
#蛋白质序列编码
valid_input_ids = input_ids[valid_indices]
valid_attention_mask = attention_mask[valid_indices]
#调用预训练的 ESM 模型对蛋白质序列进行编码。
esm_outputs = self.esm_model(input_ids=valid_input_ids, attention_mask=valid_attention_mask)
protein_embedding = esm_outputs.last_hidden_state # Shape: [num_valid, seq_len, esm_hidden_size]
#使用线性层将 ESM 的原始输出维度 hidden_size 映射到模型统一使用的维度 d_model。
protein_embedding = self.protein_projection(protein_embedding) # Shape: [num_valid, seq_len, d_model]
# Fusion blocks
fused_output = protein_embedding
# Create key padding mask for attention: True for padded tokens
key_padding_mask = (valid_attention_mask == 0)
#迭代调用每个融合模块 FusionBlock,进行蛋白质特征与底物特征的融合。
for block in self.fusion_blocks:
fused_output = block(fused_output, substrate_embedding, protein_mask=key_padding_mask)
#序列维度池化(Global average pooling)
masked_fused_output = fused_output * valid_attention_mask.unsqueeze(-1)
summed_output = masked_fused_output.sum(dim=1)
non_pad_count = valid_attention_mask.sum(dim=1, keepdim=True)
pooled_output = summed_output / non_pad_count.clamp(min=1e-9)
#将平均池化后的蛋白质表示送入回归网络,得到 [batch] 维度的 kcat 值。
predicted_kcat = self.output_regressor(pooled_output).squeeze(-1)
# [FIX] Cast predicted_kcat to float32 before assigning.
# This aligns the source (Half/float16) and destination (Float/float32) dtypes
# when running under torch.amp.autocast.
final_predictions[valid_indices] = predicted_kcat.to(torch.float32)
return final_predictions