# testGAT / model_utils.py
# Last commit: 1891ed9 by QQ2S3R -- "Update model_utils.py" (verified)
# -*- coding: utf-8 -*-
"""
Created on Tue Jul 8 15:53:41 2025
@author: User
"""
import numpy as np
import torch
from rdkit import Chem
from sklearn.preprocessing import MinMaxScaler
from torch_geometric.nn import GATConv, global_mean_pool
import torch.nn as nn
import matplotlib.pyplot as plt
from rdkit.Chem import Draw, BondType
from PIL import Image
import io
import matplotlib
# Select matplotlib's non-interactive 'Agg' backend: figures are rendered
# off-screen and saved into in-memory buffers (presumably because this module
# runs without a display, e.g. on a server -- confirm).
matplotlib.use('Agg')
# -------------------- 模型定义 --------------------
class EnhancedGAT(nn.Module):
    """Two-layer graph attention network that yields one output row per graph.

    Layer 1 is a multi-head ``GATConv`` whose heads are concatenated (hence the
    ``hidden_dim * num_heads`` BatchNorm width). Layer 2 is a single-head
    ``GATConv``. Both layers take 1-dimensional edge features. Node embeddings
    are mean-pooled per graph and projected to ``output_dim`` by a linear layer.

    Args:
        input_dim: Number of input node features.
        hidden_dim: Hidden width per attention head.
        output_dim: Size of the per-graph output (e.g. number of classes).
        num_heads: Number of attention heads in the first layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_heads=8):
        super().__init__()
        concat_width = hidden_dim * num_heads
        # Attribute names are kept stable: they determine the state_dict keys,
        # and ``conv1`` is also accessed directly for attention extraction.
        self.conv1 = GATConv(input_dim, hidden_dim, heads=num_heads, edge_dim=1)
        self.bn1 = nn.BatchNorm1d(concat_width)
        self.conv2 = GATConv(concat_width, hidden_dim, heads=1, edge_dim=1)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, data):
        """Return the pooled, linearly projected output for each graph.

        ``data`` must expose ``x``, ``edge_index``, ``edge_attr`` and ``batch``
        (as a PyTorch Geometric ``Data``/``Batch`` object does).
        """
        node_feats = data.x
        edges = data.edge_index
        edge_feats = data.edge_attr
        graph_ids = data.batch

        # Block 1: attention -> batch norm -> ReLU -> dropout.
        hidden = torch.relu(self.bn1(self.conv1(node_feats, edges, edge_attr=edge_feats)))
        hidden = self.dropout(hidden)

        # Block 2: attention -> batch norm -> ReLU (no dropout on this block).
        hidden = torch.relu(self.bn2(self.conv2(hidden, edges, edge_attr=edge_feats)))

        per_graph = global_mean_pool(hidden, graph_ids)
        return self.fc(per_graph)
# -------------------- SMILES转图 --------------------
def smiles_to_graph(smiles):
    """Convert a SMILES string into scaled atom features and a weighted edge list.

    Each atom gets 12 raw descriptors, in this order: atomic number, total H
    count, degree, hybridisation, aromaticity, formal charge, ring membership,
    chiral tag, total valence, mass / 100, radical electrons, and whether it
    has more than two neighbours. The descriptor matrix is then min-max scaled
    column-wise *within this molecule only*.

    Every bond becomes two directed edges (i->j and j->i). The edge value is the
    bond order: single=1, double=2, triple=3, aromatic=1.5. Bonds of any other
    type map to 0 and therefore produce no edge.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        ``(atom_features, (rows, cols, edge_values), mol)``:
        ``atom_features`` is a float32 array of shape (num_atoms, 12);
        ``rows``/``cols``/``edge_values`` list the directed edges in row-major
        order of the adjacency matrix; ``mol`` is the parsed RDKit molecule.

    Raises:
        ValueError: if the SMILES cannot be parsed or contains no atoms.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles}")
    num_atoms = mol.GetNumAtoms()
    if num_atoms == 0:
        # e.g. "" parses to an empty molecule; MinMaxScaler cannot fit zero
        # samples, so fail here with a clear message instead.
        raise ValueError(f"SMILES contains no atoms: {smiles!r}")

    atom_features = [
        [
            atom.GetAtomicNum(),
            atom.GetTotalNumHs(),
            atom.GetDegree(),
            int(atom.GetHybridization()),
            atom.GetIsAromatic(),
            atom.GetFormalCharge(),
            atom.IsInRing(),
            int(atom.GetChiralTag()),
            atom.GetTotalValence(),
            atom.GetMass() / 100.0,
            atom.GetNumRadicalElectrons(),
            len(atom.GetNeighbors()) > 2,
        ]
        for atom in mol.GetAtoms()
    ]
    # NOTE(review): the scaler is refit for every molecule, so identical atoms
    # can receive different scaled values in different molecules. Kept as-is
    # because a trained model presumably depends on this exact preprocessing.
    atom_features = MinMaxScaler().fit_transform(atom_features).astype(np.float32)

    # Loop-invariant bond-type -> bond-order lookup, built once.
    bond_order = {
        BondType.SINGLE: 1,
        BondType.DOUBLE: 2,
        BondType.TRIPLE: 3,
        BondType.AROMATIC: 1.5,
    }
    adj = np.zeros((num_atoms, num_atoms), dtype=np.float32)
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        # Unlisted bond types get 0 and are dropped by np.nonzero below.
        adj[i, j] = adj[j, i] = bond_order.get(bond.GetBondType(), 0)

    rows, cols = np.nonzero(adj)
    edge_values = adj[rows, cols]
    return atom_features, (rows, cols, edge_values), mol
# -------------------- 原子重要性计算 --------------------
def calculate_atom_importance(edge_index, alpha, x, num_atoms, *,
                              feature_weights=None, edge_ratio=0.6):
    """Score atoms by blending edge attention with weighted atom features.

    The edge term gives each atom the sum of the attention coefficients of all
    edges in which it is the source or the destination. A self-loop ``(i, i)``
    therefore contributes twice. The feature term is a weighted sum of the
    columns of ``x``. The two terms are mixed as
    ``edge_ratio * edge + (1 - edge_ratio) * feature`` and then min-max
    normalised within this molecule.

    Args:
        edge_index: Integer tensor of shape (2, E) holding source and
            destination atom indices.
        alpha: Array-like of length E with one attention value per edge.
        x: Node-feature tensor of shape (num_atoms, F).
        num_atoms: Number of atoms to score.
        feature_weights: Optional length-F sequence of per-column weights. The
            default is the 12 hand-chosen weights below, ordered like the
            descriptors produced by ``smiles_to_graph``.
        edge_ratio: Weight of the edge-attention term. The feature term gets
            ``1 - edge_ratio``; the default 0.6 gives the original 60/40 mix.

    Returns:
        float64 NumPy array of length ``num_atoms`` with values in about [0, 1].
        It is all zeros when every atom scores the same and empty when
        ``num_atoms`` is 0.
    """
    if num_atoms == 0:
        return np.zeros(0)

    # Edge-attention term. np.add.at accumulates correctly when an atom index
    # repeats, which a fancy-indexed ``+=`` would not.
    src, dst = edge_index.cpu().numpy()
    alpha = np.asarray(alpha, dtype=np.float64)
    edge_based = np.zeros(num_atoms)
    np.add.at(edge_based, src, alpha)
    np.add.at(edge_based, dst, alpha)

    # Atom-feature term with chemistry-motivated weights, one per column of x.
    if feature_weights is None:
        feature_weights = (
            0.25,  # atomic number
            0.04,  # number of attached hydrogens
            0.10,  # degree (explicit neighbours)
            0.04,  # hybridisation
            0.15,  # aromaticity
            0.20,  # formal charge
            0.10,  # ring membership
            0.04,  # chirality
            0.04,  # total valence
            0.04,  # atomic mass
            0.02,  # radical electrons
            0.02,  # high connectivity (> 2 neighbours)
        )
    weights = torch.as_tensor(feature_weights, dtype=torch.float32, device=x.device)
    feature_based = torch.matmul(x.to(torch.float32), weights).detach().cpu().numpy()

    combined = edge_ratio * edge_based + (1.0 - edge_ratio) * feature_based

    # Per-molecule min-max normalisation. The epsilon avoids division by zero
    # when all atoms have the same score.
    lowest = combined.min()
    return (combined - lowest) / (combined.max() - lowest + 1e-8)
# -------------------- 注意力可视化 --------------------
def visualize_single_molecule(model, data, device, model_name):
    """Render one molecule coloured by per-atom importance and return it as PNG.

    The function puts ``model`` in eval mode and predicts a class for ``data``.
    It then extracts the first GAT layer's attention coefficients (averaged
    over heads) and turns them into per-atom scores with
    ``calculate_atom_importance``. Finally it draws the molecule with every
    atom highlighted on a ``Blues`` scale, next to a colour bar.

    Args:
        model: Trained model exposing ``conv1`` (e.g. ``EnhancedGAT``).
        data: Graph for a single molecule with ``x``, ``edge_index``,
            ``smiles`` and optionally ``edge_attr``. ``smiles`` may be a plain
            string or a one-element list (as in a batched graph).
        device: Device to run inference on.
        model_name: Currently unused; kept for API compatibility.

    Returns:
        ``(buf, pred_label)``. ``buf`` is an ``io.BytesIO`` positioned at 0
        that holds the PNG, or ``None`` if the SMILES cannot be parsed.
        ``pred_label`` is the argmax class index as an ``int``.
    """
    model.eval()
    with torch.no_grad():
        data = data.to(device)
        pred_label = model(data).argmax(dim=1).item()

    # An unbatched Data keeps ``smiles`` as a str, while a batched one holds a
    # list. Indexing a str with [0] would return only its first character.
    smiles = data.smiles if isinstance(data.smiles, str) else data.smiles[0]
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None, pred_label

    with torch.no_grad():
        # Pass edge_attr exactly as forward() does, so these are the attention
        # coefficients the model actually used for the prediction above.
        _, (edge_index, alpha) = model.conv1(
            data.x,
            data.edge_index,
            edge_attr=getattr(data, 'edge_attr', None),
            return_attention_weights=True,
        )
    if isinstance(alpha, tuple):
        alpha = alpha[1]
    if alpha.dim() > 1:
        alpha = alpha.mean(dim=1)  # average over attention heads
    atom_importance = calculate_atom_importance(
        edge_index, alpha.cpu().numpy(), data.x, mol.GetNumAtoms()
    )

    return _render_importance_png(mol, atom_importance), pred_label


def _render_importance_png(mol, atom_importance):
    """Draw ``mol`` with atoms coloured by ``atom_importance`` plus a colour bar; return PNG bytes."""
    cmap = plt.cm.Blues
    norm = plt.Normalize(vmin=0, vmax=1)
    num_atoms = mol.GetNumAtoms()

    # Highlight colours are RGB tuples keyed by atom index (alpha dropped).
    atom_colors = {i: tuple(cmap(norm(atom_importance[i]))[:3]) for i in range(num_atoms)}
    drawer = Draw.MolDraw2DCairo(400, 400)
    drawer.DrawMolecule(
        mol,
        highlightAtoms=list(range(num_atoms)),
        highlightAtomColors=atom_colors,
        highlightBonds=[],
    )
    drawer.FinishDrawing()
    mol_image = Image.open(io.BytesIO(drawer.GetDrawingText()))

    fig = plt.figure(figsize=(6, 6))
    try:
        ax = fig.add_subplot(111)
        ax.imshow(mol_image)
        ax.axis('off')

        mappable = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        mappable.set_array([])
        cbar = fig.colorbar(mappable, ax=ax, fraction=0.03, pad=0.04, orientation='vertical')
        cbar.set_label('Atom Importance', fontsize=10, labelpad=5)
        cbar.ax.tick_params(labelsize=8)

        buf = io.BytesIO()
        # Save this figure explicitly rather than pyplot's global "current" one.
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    finally:
        # Release the figure even if rendering fails, so repeated calls do not
        # accumulate open figures.
        plt.close(fig)
    buf.seek(0)
    return buf