# -*- coding: utf-8 -*-
"""
Created on Tue Jul  8 15:53:41 2025

@author: User

Graph-attention classifier for molecular graphs, a SMILES-to-graph converter,
and helpers that turn first-layer attention into a per-atom importance image.
"""
import io

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from rdkit import Chem
from rdkit.Chem import BondType, Draw
from sklearn.preprocessing import MinMaxScaler
from torch_geometric.nn import GATConv, global_mean_pool

# Use a non-interactive matplotlib backend.
matplotlib.use('Agg')

# Edge weight stored in the adjacency matrix for each supported bond type.
_BOND_WEIGHTS = {
    BondType.SINGLE: 1,
    BondType.DOUBLE: 2,
    BondType.TRIPLE: 3,
    BondType.AROMATIC: 1.5,
}

# Hand-chosen per-feature weights, in the same order as the atom features
# produced by ``smiles_to_graph``.
_FEATURE_WEIGHTS = (
    0.25,  # atomic number
    0.04,  # number of attached hydrogens
    0.10,  # degree
    0.04,  # hybridization
    0.15,  # aromaticity
    0.20,  # formal charge
    0.10,  # ring membership
    0.04,  # chirality tag
    0.04,  # total valence
    0.04,  # atomic mass
    0.02,  # radical electrons
    0.02,  # more than two neighbours
)

# Relative contribution of the two importance terms.
_EDGE_SHARE = 0.6
_FEATURE_SHARE = 0.4


# -------------------- Model definition --------------------
class EnhancedGAT(nn.Module):
    """Two-layer GAT with batch norm, dropout and mean pooling.

    Args:
        input_dim: Size of each node feature vector.
        hidden_dim: Per-head output size of both attention layers.
        output_dim: Size of the final linear layer's output.
        num_heads: Number of attention heads in the first layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_heads=8):
        super().__init__()
        self.conv1 = GATConv(input_dim, hidden_dim, heads=num_heads, edge_dim=1)
        self.bn1 = nn.BatchNorm1d(hidden_dim * num_heads)
        self.conv2 = GATConv(hidden_dim * num_heads, hidden_dim, heads=1, edge_dim=1)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, data):
        """Return one output row per graph in ``data``.

        ``data`` must expose ``x``, ``edge_index``, ``edge_attr`` and
        ``batch``.
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        batch = data.batch

        x = self.conv1(x, edge_index, edge_attr=edge_attr)
        x = self.bn1(x)
        x = torch.relu(x)
        x = self.dropout(x)

        x = self.conv2(x, edge_index, edge_attr=edge_attr)
        x = self.bn2(x)
        x = torch.relu(x)

        x = global_mean_pool(x, batch)
        return self.fc(x)


# -------------------- SMILES to graph --------------------
def smiles_to_graph(smiles):
    """Convert a SMILES string into node features and a weighted edge list.

    Args:
        smiles: SMILES string to parse with RDKit.

    Returns:
        A tuple ``(atom_features, (rows, cols, edge_values), mol)`` where
        ``atom_features`` is a float32 array with 12 columns per atom,
        min-max scaled column-wise within this molecule; ``rows``/``cols``
        index the non-zero entries of the symmetric adjacency matrix and
        ``edge_values`` holds their bond weights; ``mol`` is the RDKit
        molecule.

    Raises:
        ValueError: If RDKit cannot parse ``smiles`` or the molecule has no
            atoms.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles}")

    num_atoms = mol.GetNumAtoms()
    if num_atoms == 0:
        # Without this guard the scaler below fails with an opaque error.
        raise ValueError(f"SMILES contains no atoms: {smiles!r}")

    atom_features = [
        [
            atom.GetAtomicNum(),
            atom.GetTotalNumHs(),
            atom.GetDegree(),
            int(atom.GetHybridization()),
            atom.GetIsAromatic(),
            atom.GetFormalCharge(),
            atom.IsInRing(),
            int(atom.GetChiralTag()),
            atom.GetTotalValence(),
            atom.GetMass() / 100.0,
            atom.GetNumRadicalElectrons(),
            len(atom.GetNeighbors()) > 2,
        ]
        for atom in mol.GetAtoms()
    ]
    # The scaler is fitted on this molecule only, so scaling is per molecule.
    atom_features = MinMaxScaler().fit_transform(atom_features).astype(np.float32)

    adj = np.zeros((num_atoms, num_atoms), dtype=np.float32)
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        # Bond types missing from the table get weight 0 and therefore
        # produce no edge, because only non-zero entries are extracted below.
        bond_val = _BOND_WEIGHTS.get(bond.GetBondType(), 0)
        adj[i, j] = bond_val
        adj[j, i] = bond_val

    rows, cols = np.nonzero(adj)
    edge_values = adj[rows, cols]
    return atom_features, (rows, cols, edge_values), mol


# -------------------- Atom importance --------------------
def calculate_atom_importance(edge_index, alpha, x, num_atoms):
    """Blend edge attention and weighted atom features into atom scores.

    Args:
        edge_index: Tensor of shape ``(2, num_edges)`` with source and
            destination node indices.
        alpha: One attention value per edge, as an array-like.
        x: Node feature tensor with 12 columns, matching the feature order of
            ``smiles_to_graph``.
        num_atoms: Number of atoms to score.

    Returns:
        A NumPy array of scores min-max normalised over the atoms passed in.
    """
    edges = edge_index.cpu().numpy()
    edge_alpha = np.asarray(alpha, dtype=np.float64)[: edges.shape[1]]

    # Each edge credits its attention to both endpoints. ``np.add.at``
    # accumulates correctly when an index occurs more than once.
    edge_based = np.zeros(num_atoms)
    np.add.at(edge_based, edges[0], edge_alpha)
    np.add.at(edge_based, edges[1], edge_alpha)

    feature_weights = torch.tensor(
        _FEATURE_WEIGHTS, device=x.device, dtype=torch.float32
    )
    feature_based = torch.matmul(x, feature_weights).cpu().numpy()

    combined = _EDGE_SHARE * edge_based + _FEATURE_SHARE * feature_based

    # The epsilon keeps the division defined when all scores are equal.
    spread = combined.max() - combined.min() + 1e-8
    return (combined - combined.min()) / spread


# -------------------- Attention visualisation --------------------
def visualize_single_molecule(model, data, device, model_name):
    """Render a molecule coloured by atom importance and return it as PNG.

    Args:
        model: Model exposing a ``conv1`` attention layer that supports
            ``return_attention_weights=True``.
        data: Graph batch holding one molecule; its SMILES string is read
            from ``data.smiles[0]``.
        device: Device the data is moved to before inference.
        model_name: Unused by the current implementation; kept so existing
            callers continue to work.

    Returns:
        ``(buf, pred_label)`` where ``buf`` is an ``io.BytesIO`` positioned at
        the start of the PNG image, or ``None`` if RDKit cannot parse the
        SMILES, and ``pred_label`` is the argmax class index.
    """
    model.eval()
    with torch.no_grad():
        data = data.to(device)
        pred_label = model(data).argmax(dim=1).item()

    mol = Chem.MolFromSmiles(data.smiles[0])
    if mol is None:
        return None, pred_label

    # Pass edge_attr exactly as forward() does; otherwise the attention shown
    # here is not the attention that produced the prediction.
    with torch.no_grad():
        _, (edge_index, alpha) = model.conv1(
            data.x,
            data.edge_index,
            edge_attr=getattr(data, "edge_attr", None),
            return_attention_weights=True,
        )
    if alpha.dim() > 1:
        # Average over attention heads to get one value per edge.
        alpha = alpha.mean(dim=1)

    num_atoms = mol.GetNumAtoms()
    atom_importance = calculate_atom_importance(
        edge_index, alpha.cpu().numpy(), data.x, num_atoms
    )

    cmap = plt.cm.Blues
    norm = plt.Normalize(vmin=0, vmax=1)
    atom_colors = {}
    for idx in range(num_atoms):
        rgba = cmap(norm(atom_importance[idx]))
        atom_colors[idx] = (float(rgba[0]), float(rgba[1]), float(rgba[2]))

    # Draw the molecule with every atom highlighted in its importance colour.
    drawer = Draw.MolDraw2DCairo(400, 400)
    drawer.DrawMolecule(
        mol,
        highlightAtoms=list(range(num_atoms)),
        highlightAtomColors=atom_colors,
        highlightBonds=[],
    )
    drawer.FinishDrawing()
    img = Image.open(io.BytesIO(drawer.GetDrawingText()))

    fig = plt.figure(figsize=(6, 6))
    try:
        ax = fig.add_subplot(111)
        ax.imshow(img)
        ax.axis('off')

        # Colour bar mapping the highlight colours back to importance values.
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = fig.colorbar(
            sm, ax=ax, fraction=0.03, pad=0.04, orientation='vertical'
        )
        cbar.set_label('Atom Importance', fontsize=10, labelpad=5)
        cbar.ax.tick_params(labelsize=8)

        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)

    buf.seek(0)
    return buf, pred_label