# -*- coding: utf-8 -*-
"""
Created on Tue Jul 8 15:53:41 2025
@author: User
"""
import numpy as np
import torch
from rdkit import Chem
from sklearn.preprocessing import MinMaxScaler
from torch_geometric.nn import GATConv, global_mean_pool
import torch.nn as nn
import matplotlib.pyplot as plt
from rdkit.Chem import Draw, BondType
from PIL import Image
import io
import matplotlib
# 设置 matplotlib 使用非交互式后端
matplotlib.use('Agg')
# -------------------- 模型定义 --------------------
class EnhancedGAT(nn.Module):
    """Two-layer graph attention network with a graph-level linear head.

    The first ``GATConv`` layer runs ``num_heads`` attention heads and is
    followed by a batch norm of width ``hidden_dim * num_heads``; the second
    layer uses a single head and brings the width back to ``hidden_dim``.
    Both attention layers are constructed with ``edge_dim=1``. Node
    embeddings are mean-pooled per graph and passed through ``nn.Linear``.

    Args:
        input_dim: Number of input features per node.
        hidden_dim: Per-head output width of each attention layer.
        output_dim: Width of the final linear output.
        num_heads: Number of attention heads in the first layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_heads=8):
        super().__init__()
        multi_head_width = hidden_dim * num_heads
        # Attribute names and creation order are part of the contract:
        # saved state dicts and external code (e.g. ``model.conv1``) use them.
        self.conv1 = GATConv(input_dim, hidden_dim, heads=num_heads, edge_dim=1)
        self.bn1 = nn.BatchNorm1d(multi_head_width)
        self.conv2 = GATConv(multi_head_width, hidden_dim, heads=1, edge_dim=1)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, data):
        """Return one output row per graph in ``data``.

        Args:
            data: Object exposing ``x``, ``edge_index``, ``edge_attr`` and
                ``batch`` attributes (presumably a torch_geometric batch).

        Returns:
            The output of the final linear layer applied to the mean-pooled
            node embeddings.
        """
        node_feats = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        batch = data.batch

        # Layer 1: attention -> batch norm -> ReLU -> dropout.
        hidden = self.conv1(node_feats, edge_index, edge_attr=edge_attr)
        hidden = self.dropout(torch.relu(self.bn1(hidden)))

        # Layer 2: attention -> batch norm -> ReLU (no dropout here).
        hidden = self.conv2(hidden, edge_index, edge_attr=edge_attr)
        hidden = torch.relu(self.bn2(hidden))

        pooled = global_mean_pool(hidden, batch)
        return self.fc(pooled)
# -------------------- SMILES转图 --------------------
def smiles_to_graph(smiles):
    """Convert a SMILES string into node features, a weighted edge list and the molecule.

    Each atom is described by 12 values, in this order: atomic number,
    total hydrogen count, degree, hybridization (as int), aromatic flag,
    formal charge, in-ring flag, chiral tag (as int), total valence,
    mass / 100, radical-electron count, and a flag for more than two
    neighbours. The feature matrix is then min-max scaled column by column
    using only the atoms of this molecule.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        A tuple ``(atom_features, (rows, cols, edge_values), mol)`` where
        ``atom_features`` is a float32 array of shape ``(num_atoms, 12)``,
        ``rows``/``cols`` are the indices of the non-zero entries of the
        symmetric bond-order adjacency matrix (so every bond appears once in
        each direction), ``edge_values`` are the matching bond orders, and
        ``mol`` is the parsed RDKit molecule.

    Raises:
        ValueError: If the SMILES cannot be parsed or yields no atoms.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles}")

    num_atoms = mol.GetNumAtoms()
    if num_atoms == 0:
        # An atom-less molecule would otherwise fail inside the scaler with
        # an unrelated "expected 2D array" error.
        raise ValueError(f"SMILES contains no atoms: {smiles!r}")

    atom_features = [
        [
            atom.GetAtomicNum(),
            atom.GetTotalNumHs(),
            atom.GetDegree(),
            int(atom.GetHybridization()),
            atom.GetIsAromatic(),
            atom.GetFormalCharge(),
            atom.IsInRing(),
            int(atom.GetChiralTag()),
            atom.GetTotalValence(),
            atom.GetMass() / 100.0,
            atom.GetNumRadicalElectrons(),
            len(atom.GetNeighbors()) > 2,
        ]
        for atom in mol.GetAtoms()
    ]
    # The scaler is fitted on this molecule alone, so values are relative to
    # the atoms present here rather than to any dataset-wide range.
    scaler = MinMaxScaler()
    atom_features = scaler.fit_transform(atom_features).astype(np.float32)

    # Built once instead of once per bond.
    bond_orders = {
        BondType.SINGLE: 1,
        BondType.DOUBLE: 2,
        BondType.TRIPLE: 3,
        BondType.AROMATIC: 1.5,
    }
    adj = np.zeros((num_atoms, num_atoms), dtype=np.float32)
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        # Bond types outside the table get order 0 and therefore drop out of
        # the edge list produced by np.nonzero below.
        bond_val = bond_orders.get(bond.GetBondType(), 0)
        adj[i, j] = bond_val
        adj[j, i] = bond_val

    rows, cols = np.nonzero(adj)
    edge_values = adj[rows, cols]
    return atom_features, (rows, cols, edge_values), mol
# -------------------- 原子重要性计算 --------------------
def calculate_atom_importance(edge_index, alpha, x, num_atoms):
    """Score each atom by blending edge attention with a weighted feature sum.

    The edge part adds every edge's attention value to both of its endpoints
    (an edge whose source equals its destination therefore counts twice for
    that atom). The feature part is a fixed linear combination of the 12
    node-feature columns. The two parts are mixed 60 % / 40 % and the result
    is min-max rescaled over the atoms passed in this call.

    Args:
        edge_index: Integer tensor of shape ``(2, E)``; column ``i`` holds the
            ``(src, dst)`` atom indices of edge ``i``.
        alpha: 1-D array-like with at least ``E`` attention values, one per
            edge, in the same order as ``edge_index``.
        x: Node-feature tensor with 12 columns and one row per atom; expected
            to be float32 so it can be multiplied with the weight vector.
        num_atoms: Number of atoms; must match the number of rows of ``x``.

    Returns:
        A 1-D ``numpy.ndarray`` of length ``num_atoms``. The lowest-scoring
        atom maps to 0 and the highest to just under 1 (all zeros when every
        atom scores the same). An empty array is returned for zero atoms.

    Raises:
        IndexError: If ``alpha`` has fewer entries than there are edges, or
            an edge refers to an atom index outside ``num_atoms``.
    """
    # Edge-attention contribution.
    endpoints = edge_index.cpu().t().numpy()  # shape (E, 2): one (src, dst) row per edge
    num_edges = endpoints.shape[0]
    edge_weights = np.asarray(alpha)[:num_edges]
    if edge_weights.shape[0] < num_edges:
        raise IndexError("alpha has fewer entries than edge_index has edges")

    edge_based = np.zeros(num_atoms)
    # Flattening row-major gives src0, dst0, src1, dst1, ...; np.add.at is
    # unbuffered, so repeated indices accumulate in the same order as an
    # explicit per-edge loop would.
    np.add.at(edge_based, endpoints.reshape(-1), np.repeat(edge_weights, 2))

    # Atom-feature contribution with hand-picked, chemistry-motivated weights.
    # Labels assume the column order produced by smiles_to_graph.
    feature_weights = torch.tensor([
        0.25,  # atomic number
        0.04,  # attached hydrogen count
        0.10,  # degree (heavy-atom connectivity)
        0.04,  # hybridization
        0.15,  # aromaticity
        0.20,  # formal charge
        0.10,  # ring membership
        0.04,  # chirality
        0.04,  # total valence
        0.04,  # atomic mass
        0.02,  # radical electrons
        0.02,  # high connectivity flag
    ], device=x.device, dtype=torch.float32)
    feature_based = torch.matmul(x, feature_weights).cpu().numpy()

    # Fixed blend: 60 % edge attention, 40 % atom features.
    combined = 0.6 * edge_based + 0.4 * feature_based
    if combined.size == 0:
        # min()/max() are undefined on an empty array.
        return combined

    # Min-max rescale within this call; the epsilon avoids division by zero
    # when all atoms have the same score.
    lowest = combined.min()
    return (combined - lowest) / (combined.max() - lowest + 1e-8)
# -------------------- 注意力可视化 --------------------
def visualize_single_molecule(model, data, device, model_name):
    """Render one molecule with atoms shaded by importance and return the PNG.

    The model is run once to get the predicted class, then the first
    attention layer (``model.conv1``) is queried for its attention weights.
    Those weights are averaged over heads, turned into per-atom scores by
    ``calculate_atom_importance`` and drawn as atom highlights using the
    ``Blues`` colormap, with a colorbar labelled "Atom Importance".

    Args:
        model: Network exposing a ``conv1`` layer that accepts
            ``edge_attr`` and ``return_attention_weights=True``.
        data: Graph object for a single molecule with ``x``, ``edge_index``
            and ``smiles`` (a sequence whose first item is used).
        device: Device the data is moved to before inference.
        model_name: Currently unused; kept for interface compatibility.

    Returns:
        ``(buf, pred_label)`` where ``buf`` is an ``io.BytesIO`` positioned at
        the start of the PNG image, or ``None`` if the SMILES stored on
        ``data`` cannot be parsed. ``pred_label`` is the arg-max class index.
    """
    model.eval()
    with torch.no_grad():
        data = data.to(device)
        out = model(data)
        pred_label = out.argmax(dim=1).item()

    smiles = data.smiles[0]
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None, pred_label

    # Fetch attention weights. edge_attr is passed so the weights match the
    # ones the layer computes during the model's own forward pass.
    with torch.no_grad():
        _, (edge_index, alpha) = model.conv1(
            data.x,
            data.edge_index,
            edge_attr=getattr(data, "edge_attr", None),
            return_attention_weights=True,
        )
    if isinstance(alpha, tuple):
        alpha = alpha[1]
    if alpha.dim() > 1:
        # Collapse the per-head weights to one value per edge.
        alpha = alpha.mean(dim=1)
    alpha_norm = alpha.cpu().numpy()

    num_atoms = mol.GetNumAtoms()
    atom_importance = calculate_atom_importance(edge_index, alpha_norm, data.x, num_atoms)

    # Map each atom's importance (already in [0, 1]) to an RGB colour.
    cmap = plt.cm.Blues
    norm = plt.Normalize(vmin=0, vmax=1)
    atom_colors = {}
    for atom_idx in range(num_atoms):
        rgba = cmap(norm(atom_importance[atom_idx]))
        atom_colors[atom_idx] = (rgba[0], rgba[1], rgba[2])

    # Draw the molecule with every atom highlighted and no bond highlights.
    drawer = Draw.MolDraw2DCairo(400, 400)
    drawer.DrawMolecule(
        mol,
        highlightAtoms=list(range(num_atoms)),
        highlightAtomColors=atom_colors,
        highlightBonds=[],
    )
    drawer.FinishDrawing()
    img = Image.open(io.BytesIO(drawer.GetDrawingText()))

    # Compose the final figure; always close it so a failure while rendering
    # does not leave the figure registered with pyplot.
    fig = plt.figure(figsize=(6, 6))
    try:
        ax = fig.add_subplot(111)
        ax.imshow(img)
        ax.axis('off')

        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax,
                            fraction=0.03,
                            pad=0.04,
                            orientation='vertical')
        cbar.set_label('Atom Importance',
                       fontsize=10,
                       labelpad=5)
        cbar.ax.tick_params(labelsize=8)

        buf = io.BytesIO()
        # Save this figure explicitly rather than whatever pyplot considers current.
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    finally:
        plt.close(fig)

    buf.seek(0)
    return buf, pred_label
|