| |
| """ |
| Created on Tue Jul 8 15:53:41 2025 |
| |
| @author: User |
| """ |
|
|
| import numpy as np |
| import torch |
| from rdkit import Chem |
| from sklearn.preprocessing import MinMaxScaler |
| from torch_geometric.nn import GATConv, global_mean_pool |
| import torch.nn as nn |
| import matplotlib.pyplot as plt |
| from rdkit.Chem import Draw, BondType |
| from PIL import Image |
| import io |
| import matplotlib |
|
|
| |
# Use the non-interactive Agg backend: figures in this module are only
# rendered into in-memory PNG buffers, never shown in a window.
matplotlib.use('Agg')
|
|
| |
class EnhancedGAT(nn.Module):
    """Two-layer graph attention network producing one output vector per graph.

    Layer 1 is a ``num_heads``-head GATConv whose output width is
    ``hidden_dim * num_heads`` (as sized by ``bn1``); layer 2 is a single-head
    GATConv back to ``hidden_dim``. Both layers are configured for a
    1-dimensional edge feature (``edge_dim=1``). Node embeddings are
    mean-pooled per graph and projected to ``output_dim`` by a linear layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_heads=8):
        super().__init__()
        concat_width = hidden_dim * num_heads
        # Keep this construction order: it fixes parameter-initialisation
        # RNG consumption and the state_dict layout.
        self.conv1 = GATConv(input_dim, hidden_dim, heads=num_heads, edge_dim=1)
        self.bn1 = nn.BatchNorm1d(concat_width)
        self.conv2 = GATConv(concat_width, hidden_dim, heads=1, edge_dim=1)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, data):
        """Return graph-level outputs of shape ``(num_graphs, output_dim)``.

        ``data`` must expose ``x``, ``edge_index``, ``edge_attr`` and
        ``batch`` (``batch`` may be None for a single graph, in which case
        all nodes are pooled together).
        """
        edges, edge_feats = data.edge_index, data.edge_attr

        # Block 1: attention -> batch norm -> ReLU -> dropout.
        hidden = self.conv1(data.x, edges, edge_attr=edge_feats)
        hidden = self.dropout(torch.relu(self.bn1(hidden)))

        # Block 2: attention -> batch norm -> ReLU (no dropout here).
        hidden = self.conv2(hidden, edges, edge_attr=edge_feats)
        hidden = torch.relu(self.bn2(hidden))

        pooled = global_mean_pool(hidden, data.batch)
        return self.fc(pooled)
|
|
| |
def smiles_to_graph(smiles):
    """Convert a SMILES string into scaled atom features and a weighted edge list.

    Each atom is described by 12 RDKit descriptors, in this order:
    1. atomic number
    2. total H count
    3. degree
    4. hybridization
    5. aromatic flag
    6. formal charge
    7. in-ring flag
    8. chiral tag
    9. total valence
    10. mass / 100
    11. radical electrons
    12. whether the atom has more than two neighbours

    The matrix is min-max scaled column-wise with a scaler fitted on this
    molecule alone, so values are relative to the molecule's own atoms.

    Bonds become symmetric edges weighted 1 (single), 2 (double), 3 (triple)
    or 1.5 (aromatic). Any other bond type is given weight 0 and therefore
    does not appear in the returned edge list.

    Returns:
        ``(features, (rows, cols, weights), mol)``:
        - ``features`` is a float32 array of shape ``(num_atoms, 12)``.
        - ``rows``/``cols`` are the directed edge endpoints in row-major
          order.
        - ``weights`` holds the matching float32 bond values.
        - ``mol`` is the parsed RDKit molecule.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        raise ValueError(f"Invalid SMILES: {smiles}")

    raw_features = [
        [
            atom.GetAtomicNum(),
            atom.GetTotalNumHs(),
            atom.GetDegree(),
            int(atom.GetHybridization()),
            atom.GetIsAromatic(),
            atom.GetFormalCharge(),
            atom.IsInRing(),
            int(atom.GetChiralTag()),
            atom.GetTotalValence(),
            atom.GetMass() / 100.0,
            atom.GetNumRadicalElectrons(),
            len(atom.GetNeighbors()) > 2,
        ]
        for atom in molecule.GetAtoms()
    ]
    # Fitted per call: scaling is relative to this molecule's atoms only.
    node_features = MinMaxScaler().fit_transform(raw_features).astype(np.float32)

    bond_weights = {
        BondType.SINGLE: 1,
        BondType.DOUBLE: 2,
        BondType.TRIPLE: 3,
        BondType.AROMATIC: 1.5,
    }
    n_atoms = molecule.GetNumAtoms()
    adjacency = np.zeros((n_atoms, n_atoms), dtype=np.float32)
    for bond in molecule.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        # Unlisted bond types map to 0 and are dropped by np.nonzero below.
        weight = bond_weights.get(bond.GetBondType(), 0)
        adjacency[begin, end] = adjacency[end, begin] = weight

    src, dst = np.nonzero(adjacency)
    return node_features, (src, dst, adjacency[src, dst]), molecule
|
|
| |
def calculate_atom_importance(edge_index, alpha, x, num_atoms):
    """Score atoms by fusing edge attention with a weighted sum of atom features.

    Every edge ``(src, dst)`` adds its attention weight to *both* endpoints,
    so a self-loop edge counts twice for its atom. This edge score is blended
    0.6 / 0.4 with a fixed linear combination of the feature columns of
    ``x``. The result is min-max scaled to ``[0, 1]``.

    Args:
        edge_index: Integer tensor of shape ``(2, E)`` (source row, target
            row). Every index must be smaller than ``num_atoms``.
        alpha: One attention weight per edge, aligned with the columns of
            ``edge_index`` (NumPy array or 1-D tensor of length ``E``).
        x: Node-feature tensor of shape ``(num_atoms, 12)``.
        num_atoms: Number of atoms to score.

    Returns:
        ``np.ndarray`` of shape ``(num_atoms,)`` with values in ``[0, 1]``.
        If all scores are equal the result is all zeros. With
        ``num_atoms == 0`` an empty array is returned.
    """
    edge_based = np.zeros(num_atoms)
    if num_atoms == 0:
        # Nothing to score; min()/max() below would fail on an empty array.
        return edge_based

    # Credit each edge's attention to both endpoints. np.add.at is
    # unbuffered, so repeated indices (atoms with several edges, self loops)
    # accumulate correctly; a plain ``edge_based[src] += alpha`` would not.
    src, dst = edge_index.detach().cpu().numpy()
    if torch.is_tensor(alpha):
        alpha = alpha.detach().cpu().numpy()
    alpha = np.asarray(alpha)
    np.add.at(edge_based, src, alpha)
    np.add.at(edge_based, dst, alpha)

    # Per-column weights. The labels assume x follows the feature order
    # produced by smiles_to_graph (TODO confirm for other feature sources).
    feature_weights = torch.tensor(
        [
            0.25,  # atomic number
            0.04,  # total H count
            0.10,  # degree
            0.04,  # hybridization
            0.15,  # aromatic flag
            0.20,  # formal charge
            0.10,  # in-ring flag
            0.04,  # chiral tag
            0.04,  # total valence
            0.04,  # mass / 100
            0.02,  # radical electrons
            0.02,  # more than two neighbours
        ],
        device=x.device,
        dtype=torch.float32,
    )
    # Cast so non-float32 inputs (e.g. float64) work; no-op for float32.
    feature_based = (
        torch.matmul(x.detach().to(torch.float32), feature_weights).cpu().numpy()
    )

    combined = 0.6 * edge_based + 0.4 * feature_based

    # Min-max normalise; the epsilon keeps a constant vector at all zeros.
    lowest = combined.min()
    return (combined - lowest) / (combined.max() - lowest + 1e-8)
|
|
| |
def visualize_single_molecule(model, data, device, model_name):
    """Render a molecule with atoms shaded by model-derived importance.

    The function proceeds in five steps:
    1. Runs ``model`` on ``data`` to get the predicted class.
    2. Re-runs ``model.conv1`` with attention output enabled.
    3. Averages the attention over heads.
    4. Turns it into per-atom scores via ``calculate_atom_importance``.
    5. Draws the molecule with RDKit, coloured on a blue scale, plus a
       colorbar.

    Args:
        model: Network exposing ``conv1`` (called with
            ``return_attention_weights=True``) and returning logits of shape
            ``(num_graphs, num_classes)``. Assumes ``data`` holds a single
            molecule (``.item()`` is used on the prediction).
        data: PyG-style data object with ``x``, ``edge_index``, optional
            ``edge_attr`` and ``smiles``. ``smiles`` is either the SMILES
            string itself or a sequence whose first element is the SMILES
            (collated batch).
        device: Device ``data`` is moved to before inference.
        model_name: Currently unused; kept for API compatibility.

    Returns:
        ``(buf, pred_label)``. ``buf`` is an ``io.BytesIO`` rewound to 0
        containing a PNG, or ``None`` if the SMILES cannot be parsed.
        ``pred_label`` is the argmax class index as a Python int.
    """
    model.eval()
    data = data.to(device)
    with torch.no_grad():
        pred_label = model(data).argmax(dim=1).item()

    # A collated batch stores ``smiles`` as a list; a bare data object holds
    # the string itself, where ``[0]`` would give only its first character.
    smiles = data.smiles if isinstance(data.smiles, str) else data.smiles[0]
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None, pred_label
    num_atoms = mol.GetNumAtoms()

    with torch.no_grad():
        # Pass edge_attr as EnhancedGAT.forward does, so these are the
        # attention weights conv1 actually used for the prediction.
        _, (att_edge_index, alpha) = model.conv1(
            data.x,
            data.edge_index,
            edge_attr=getattr(data, "edge_attr", None),
            return_attention_weights=True,
        )
        if isinstance(alpha, tuple):
            alpha = alpha[1]
        if alpha.dim() > 1:
            alpha = alpha.mean(dim=1)  # average over attention heads
        alpha_np = alpha.cpu().numpy()

    atom_importance = calculate_atom_importance(
        att_edge_index, alpha_np, data.x, num_atoms
    )

    # Map each atom's [0, 1] score to an RGB colour for RDKit highlighting.
    cmap = plt.cm.Blues
    norm = plt.Normalize(vmin=0, vmax=1)
    atom_colors = {
        idx: tuple(cmap(norm(atom_importance[idx]))[:3]) for idx in range(num_atoms)
    }

    drawer = Draw.MolDraw2DCairo(400, 400)
    drawer.DrawMolecule(
        mol,
        highlightAtoms=list(range(num_atoms)),
        highlightAtomColors=atom_colors,
        highlightBonds=[],
    )
    drawer.FinishDrawing()
    mol_image = Image.open(io.BytesIO(drawer.GetDrawingText()))

    fig = plt.figure(figsize=(6, 6))
    try:
        ax = fig.add_subplot(111)
        ax.imshow(mol_image)
        ax.axis('off')

        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax, fraction=0.03, pad=0.04, orientation='vertical')
        cbar.set_label('Atom Importance', fontsize=10, labelpad=5)
        cbar.ax.tick_params(labelsize=8)

        buf = io.BytesIO()
        # Save this figure explicitly rather than pyplot's "current" figure.
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    buf.seek(0)

    return buf, pred_label
|
|