Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| """ | |
| Created on Tue Jul 8 15:53:41 2025 | |
| @author: User | |
| """ | |
| import numpy as np | |
| import torch | |
| from rdkit import Chem | |
| from sklearn.preprocessing import MinMaxScaler | |
| from torch_geometric.nn import GATConv, global_mean_pool | |
| import torch.nn as nn | |
| import matplotlib.pyplot as plt | |
| from rdkit.Chem import Draw, BondType | |
| from PIL import Image | |
| import io | |
| import matplotlib | |
# Select matplotlib's non-interactive Agg backend: figures are only rendered
# to in-memory PNG buffers below, never shown in a window.
matplotlib.use('Agg')
# ---------------------------------------------------------------------------
# Model definition
# ---------------------------------------------------------------------------
class EnhancedGAT(nn.Module):
    """Two-layer graph attention network producing one output vector per graph.

    Pipeline::

        GATConv (num_heads heads, concatenated) -> BatchNorm -> ReLU -> Dropout
        -> GATConv (1 head) -> BatchNorm -> ReLU
        -> mean pooling over the nodes of each graph -> Linear

    Both attention layers are built with ``edge_dim=1``, i.e. one scalar
    attribute per edge (presumably the bond-order value produced by
    ``smiles_to_graph`` -- confirm against the dataset code).

    Args:
        input_dim: Number of node features.
        hidden_dim: Per-head width of the first layer; also the width of the
            second layer and the input of the final linear head.
        output_dim: Size of the per-graph output vector.
        num_heads: Number of attention heads in the first layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_heads=8):
        super().__init__()
        # Channel count after the first layer's heads are concatenated.
        concat_width = hidden_dim * num_heads
        # Registration order is kept stable so parameter order and random
        # initialisation match across versions of this class.
        self.conv1 = GATConv(input_dim, hidden_dim, heads=num_heads, edge_dim=1)
        self.bn1 = nn.BatchNorm1d(concat_width)
        self.conv2 = GATConv(concat_width, hidden_dim, heads=1, edge_dim=1)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, data):
        """Return the per-graph outputs for a (possibly batched) graph object.

        ``data`` must expose ``x``, ``edge_index``, ``edge_attr`` and
        ``batch``; ``batch`` is handed straight to ``global_mean_pool``.
        """
        edges, edge_feats = data.edge_index, data.edge_attr

        # First attention block; dropout is applied only after this layer.
        node_repr = self.conv1(data.x, edges, edge_attr=edge_feats)
        node_repr = self.dropout(torch.relu(self.bn1(node_repr)))

        # Second (single-head) attention block.
        node_repr = self.conv2(node_repr, edges, edge_attr=edge_feats)
        node_repr = torch.relu(self.bn2(node_repr))

        # Average node embeddings per graph, then project.
        graph_repr = global_mean_pool(node_repr, data.batch)
        return self.fc(graph_repr)
# ---------------------------------------------------------------------------
# SMILES -> graph conversion
# ---------------------------------------------------------------------------
def smiles_to_graph(smiles):
    """Convert a SMILES string into scaled atom features and a weighted edge list.

    Atom features (12 per atom, in this column order): atomic number, total
    H count, degree (non-H connectivity), hybridization (as int), aromatic
    flag, formal charge, ring-membership flag, chiral tag (as int), total
    valence, atomic mass / 100, radical-electron count, and a "more than two
    neighbours" flag. The matrix is min-max scaled column-wise *within this
    molecule only*.

    Edges: every bond contributes both directed edges (i->j and j->i), weighted
    by bond order -- single 1, double 2, triple 3, aromatic 1.5. Any other bond
    type maps to 0 and is therefore absent from the edge list.

    Args:
        smiles: SMILES string to parse.

    Returns:
        ``(atom_features, (rows, cols, edge_values), mol)`` where
        ``atom_features`` is a float32 array of shape (num_atoms, 12);
        ``rows``/``cols`` are source/destination atom indices of each directed
        edge in row-major adjacency order; ``edge_values`` holds the matching
        bond-order weights; ``mol`` is the parsed RDKit molecule.

    Raises:
        ValueError: If RDKit cannot parse ``smiles`` or the molecule has no atoms.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles}")
    num_atoms = mol.GetNumAtoms()
    if num_atoms == 0:
        # RDKit can return an empty molecule rather than None (e.g. for "");
        # without this check the scaler below fails with an unrelated error.
        raise ValueError(f"SMILES contains no atoms: {smiles!r}")

    raw_features = [
        [
            atom.GetAtomicNum(),
            atom.GetTotalNumHs(),
            atom.GetDegree(),
            int(atom.GetHybridization()),
            atom.GetIsAromatic(),
            atom.GetFormalCharge(),
            atom.IsInRing(),
            int(atom.GetChiralTag()),
            atom.GetTotalValence(),
            atom.GetMass() / 100.0,
            atom.GetNumRadicalElectrons(),
            len(atom.GetNeighbors()) > 2,
        ]
        for atom in mol.GetAtoms()
    ]
    # NOTE(review): the scaler is fitted per molecule, so scaled values are not
    # comparable across molecules. Presumably a trained model depends on this
    # exact scaling, so changing it would require retraining.
    atom_features = MinMaxScaler().fit_transform(raw_features).astype(np.float32)

    # Built once, outside the bond loop. Unlisted bond types (e.g. dative) fall
    # back to 0 and therefore never become edges.
    bond_order = {
        BondType.SINGLE: 1,
        BondType.DOUBLE: 2,
        BondType.TRIPLE: 3,
        BondType.AROMATIC: 1.5,
    }
    adj = np.zeros((num_atoms, num_atoms), dtype=np.float32)
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        # Symmetric entry: the bond yields both directed edges.
        adj[i, j] = adj[j, i] = bond_order.get(bond.GetBondType(), 0)

    rows, cols = np.nonzero(adj)
    edge_values = adj[rows, cols]
    return atom_features, (rows, cols, edge_values), mol
# ---------------------------------------------------------------------------
# Atom importance
# ---------------------------------------------------------------------------
def calculate_atom_importance(edge_index, alpha, x, num_atoms,
                              edge_fraction=0.6, feature_weights=None):
    """Score atoms by blending edge attention with weighted atom features.

    Two per-atom terms are combined:

    * edge term -- every directed edge ``i`` adds ``alpha[i]`` to both its
      source and its destination atom (a self-loop therefore adds it twice to
      the same atom);
    * feature term -- ``x @ feature_weights``, a fixed, chemistry-motivated
      weighted sum of the node-feature columns.

    The blend ``edge_fraction * edge + (1 - edge_fraction) * feature`` is then
    min-max scaled to [0, 1] *within this molecule* (scores are not comparable
    across molecules).

    Args:
        edge_index: Integer tensor of shape (2, E): source row, destination row.
        alpha: Array-like of E attention scores aligned with ``edge_index``
            columns.
        x: Node-feature tensor of shape (num_atoms, F).
        num_atoms: Number of atoms; length of the returned array.
        edge_fraction: Weight of the edge term; the feature term gets
            ``1 - edge_fraction``. Defaults to 0.6 (i.e. a 60/40 blend).
        feature_weights: Optional length-F weights for the feature term. The
            default holds 12 weights, one per column produced by
            ``smiles_to_graph`` in the same order.

    Returns:
        numpy array of shape (num_atoms,) with values in [0, 1]; all zeros
        when every atom gets the same combined score.
    """
    if feature_weights is None:
        feature_weights = (
            0.25,  # atomic number
            0.04,  # total attached H count
            0.10,  # degree (non-H connectivity)
            0.04,  # hybridization
            0.15,  # aromaticity
            0.20,  # formal charge
            0.10,  # ring membership
            0.04,  # chirality
            0.04,  # total valence
            0.04,  # atomic mass
            0.02,  # radical electrons
            0.02,  # high connectivity (> 2 neighbours)
        )

    # Edge term, vectorised. np.add.at accumulates correctly when an index
    # repeats, which plain fancy-indexed `+=` would not.
    src, dst = edge_index.detach().cpu().numpy()
    scores = np.asarray(alpha, dtype=np.float64).reshape(-1)
    edge_based = np.zeros(num_atoms)
    np.add.at(edge_based, src, scores)
    np.add.at(edge_based, dst, scores)

    # Feature term.
    weights = torch.as_tensor(feature_weights, dtype=torch.float32, device=x.device)
    feature_based = torch.matmul(x, weights).detach().cpu().numpy()

    combined = edge_fraction * edge_based + (1.0 - edge_fraction) * feature_based

    # Per-molecule min-max scaling; the epsilon avoids division by zero when
    # all atoms share the same score.
    low = combined.min()
    return (combined - low) / (combined.max() - low + 1e-8)
# ---------------------------------------------------------------------------
# Attention visualisation
# ---------------------------------------------------------------------------
def visualize_single_molecule(model, data, device, model_name):
    """Predict one molecule and render its atoms shaded by importance.

    Runs ``model`` on ``data``, extracts the first GAT layer's attention
    (averaged over heads when there are several), converts it into per-atom
    scores with ``calculate_atom_importance`` and draws the molecule with every
    atom highlighted on matplotlib's Blues colour map, plus a colour bar.

    Args:
        model: Trained model exposing ``conv1`` (e.g. ``EnhancedGAT``). It is
            switched to eval mode as a side effect.
        data: Graph for a single molecule with ``x``, ``edge_index``,
            ``edge_attr`` and ``smiles``. ``smiles`` may be a plain string or,
            after batching, a one-element list.
        device: Torch device to run on.
        model_name: Currently unused; kept for interface compatibility.

    Returns:
        ``(buf, pred_label)``: ``buf`` is a ``BytesIO`` holding the PNG,
        rewound to the start, or ``None`` if the SMILES cannot be parsed;
        ``pred_label`` is the argmax over the model output.
    """
    model.eval()
    data = data.to(device)
    with torch.no_grad():
        pred_label = model(data).argmax(dim=1).item()

    # An unbatched graph keeps `smiles` as a str; indexing it would yield the
    # first character instead of the whole SMILES.
    smiles = data.smiles if isinstance(data.smiles, str) else data.smiles[0]
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None, pred_label
    num_atoms = mol.GetNumAtoms()

    with torch.no_grad():
        # Pass edge_attr exactly as forward() does; otherwise the layer computes
        # attention without the edge term and the picture would not reflect the
        # attention behind the prediction.
        _, (edge_index, alpha) = model.conv1(
            data.x,
            data.edge_index,
            edge_attr=data.edge_attr,
            return_attention_weights=True,
        )
    if isinstance(alpha, tuple):
        alpha = alpha[1]
    if alpha.dim() > 1:
        alpha = alpha.mean(dim=1)  # average over attention heads
    atom_importance = calculate_atom_importance(
        edge_index, alpha.cpu().numpy(), data.x, num_atoms)

    # Render the molecule with one RGB highlight colour per atom.
    cmap = plt.cm.Blues
    norm = plt.Normalize(vmin=0, vmax=1)
    atom_colors = {}
    for idx in range(num_atoms):
        rgba = cmap(norm(atom_importance[idx]))
        atom_colors[idx] = (rgba[0], rgba[1], rgba[2])
    drawer = Draw.MolDraw2DCairo(400, 400)
    drawer.DrawMolecule(
        mol,
        highlightAtoms=list(range(num_atoms)),
        highlightAtomColors=atom_colors,
        highlightBonds=[],
    )
    drawer.FinishDrawing()
    mol_img = Image.open(io.BytesIO(drawer.GetDrawingText()))

    fig = plt.figure(figsize=(6, 6))
    try:
        ax = fig.add_subplot(111)
        ax.imshow(mol_img)
        ax.axis('off')

        mappable = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        mappable.set_array([])
        cbar = fig.colorbar(mappable, ax=ax,
                            fraction=0.03,
                            pad=0.04,
                            orientation='vertical')
        cbar.set_label('Atom Importance', fontsize=10, labelpad=5)
        cbar.ax.tick_params(labelsize=8)

        buf = io.BytesIO()
        # Save this figure explicitly rather than pyplot's global "current" one.
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    finally:
        # pyplot keeps every open figure alive until closed; close it even on
        # error so repeated calls cannot accumulate figures.
        plt.close(fig)
    buf.seek(0)
    return buf, pred_label