ZZZCCCYYY commited on
Commit
1f35e05
·
verified ·
1 Parent(s): e7faed8

Upload model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +206 -0
model_utils.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Tue Jul 8 15:53:41 2025
4
+
5
+ @author: User
6
+ """
7
+
8
+ import numpy as np
9
+ import torch
10
+ from rdkit import Chem
11
+ from sklearn.preprocessing import MinMaxScaler
12
+ from torch_geometric.nn import GATConv, global_mean_pool
13
+ import torch.nn as nn
14
+ import matplotlib.pyplot as plt
15
+ from rdkit.Chem import Draw, BondType
16
+ from PIL import Image
17
+ import io
18
+ import matplotlib
19
+
20
+ # 设置 matplotlib 使用非交互式后端
21
+ matplotlib.use('Agg')
22
+
23
+ # -------------------- 模型定义 --------------------
24
+ class EnhancedGAT(nn.Module):
25
+ def __init__(self, input_dim, hidden_dim, output_dim, num_heads=8):
26
+ super().__init__()
27
+ self.conv1 = GATConv(input_dim, hidden_dim, heads=num_heads, edge_dim=1)
28
+ self.bn1 = nn.BatchNorm1d(hidden_dim * num_heads)
29
+ self.conv2 = GATConv(hidden_dim * num_heads, hidden_dim, heads=1, edge_dim=1)
30
+ self.bn2 = nn.BatchNorm1d(hidden_dim)
31
+ self.fc = nn.Linear(hidden_dim, output_dim)
32
+ self.dropout = nn.Dropout(0.5)
33
+
34
+ def forward(self, data):
35
+ x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
36
+ batch = data.batch
37
+
38
+ x = self.conv1(x, edge_index, edge_attr=edge_attr)
39
+ x = self.bn1(x)
40
+ x = torch.relu(x)
41
+ x = self.dropout(x)
42
+
43
+ x = self.conv2(x, edge_index, edge_attr=edge_attr)
44
+ x = self.bn2(x)
45
+ x = torch.relu(x)
46
+
47
+ x = global_mean_pool(x, batch)
48
+ return self.fc(x)
49
+
50
+ # -------------------- SMILES转图 --------------------
51
+ def smiles_to_graph(smiles):
52
+ mol = Chem.MolFromSmiles(smiles)
53
+ if mol is None:
54
+ raise ValueError(f"Invalid SMILES: {smiles}")
55
+
56
+ atom_features = []
57
+ for atom in mol.GetAtoms():
58
+ features = [
59
+ atom.GetAtomicNum(),
60
+ atom.GetTotalNumHs(),
61
+ atom.GetDegree(),
62
+ int(atom.GetHybridization()),
63
+ atom.GetIsAromatic(),
64
+ atom.GetFormalCharge(),
65
+ atom.IsInRing(),
66
+ int(atom.GetChiralTag()),
67
+ atom.GetTotalValence(),
68
+ atom.GetMass()/100.0,
69
+ atom.GetNumRadicalElectrons(),
70
+ len(atom.GetNeighbors()) > 2
71
+ ]
72
+ atom_features.append(features)
73
+
74
+ scaler = MinMaxScaler()
75
+ atom_features = scaler.fit_transform(atom_features).astype(np.float32)
76
+
77
+ adj = np.zeros((mol.GetNumAtoms(), mol.GetNumAtoms()), dtype=np.float32)
78
+ for bond in mol.GetBonds():
79
+ i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
80
+ bond_val = {
81
+ BondType.SINGLE: 1,
82
+ BondType.DOUBLE: 2,
83
+ BondType.TRIPLE: 3,
84
+ BondType.AROMATIC: 1.5
85
+ }.get(bond.GetBondType(), 0)
86
+ adj[i, j] = bond_val
87
+ adj[j, i] = bond_val
88
+
89
+ rows, cols = np.nonzero(adj)
90
+ edge_values = adj[rows, cols]
91
+ return atom_features, (rows, cols, edge_values), mol
92
+
93
+ # -------------------- 原子重要性计算 --------------------
94
+ def calculate_atom_importance(edge_index, alpha, x, num_atoms):
95
+ """改进版原子重要性计算(融合边注意力和原子特征)"""
96
+ # 边注意力贡献部分
97
+ edge_based = np.zeros(num_atoms)
98
+ edge_index_np = edge_index.cpu().t().numpy()
99
+
100
+ for i, (src, dst) in enumerate(edge_index_np):
101
+ edge_based[src] += alpha[i]
102
+ edge_based[dst] += alpha[i]
103
+
104
+ # 原子特征贡献部分(定义化学知识驱动的权重)
105
+ feature_weights = torch.tensor([
106
+ 0.25, # 原子序数 (AtomicNum)
107
+ 0.04, # 连接H数
108
+ 0.10, # 非氢连接度
109
+ 0.04, # 杂化状态
110
+ 0.15, # 芳香性
111
+ 0.20, # 形式电荷
112
+ 0.10, # 环内原子
113
+ 0.04, # 手性
114
+ 0.04, # 总价电子
115
+ 0.04, # 原子质量
116
+ 0.02, # 自由基电子
117
+ 0.02 # 高连接度
118
+ ], device=x.device, dtype=torch.float32)
119
+
120
+ feature_based = torch.matmul(x, feature_weights).cpu().numpy()
121
+
122
+ # 动态权重调整(边注意力占比60%,原子特征占比40%)
123
+ combined = 0.6 * edge_based + 0.4 * feature_based
124
+
125
+ # 跨分子归一化修正
126
+ atom_importance = (combined - combined.min()) / (combined.max() - combined.min() + 1e-8)
127
+ return atom_importance
128
+
129
+ # -------------------- 注意力可视化 --------------------
130
+ def visualize_single_molecule(model, data, device, model_name):
131
+ model.eval()
132
+ with torch.no_grad():
133
+ data = data.to(device)
134
+ out = model(data)
135
+ pred_label = out.argmax(dim=1).item()
136
+
137
+ smiles = data.smiles[0]
138
+ mol = Chem.MolFromSmiles(smiles)
139
+ if mol is None:
140
+ return None, pred_label
141
+
142
+ # 获取注意力权重
143
+ with torch.no_grad():
144
+ _, (edge_index, alpha) = model.conv1(data.x, data.edge_index, return_attention_weights=True)
145
+ if isinstance(alpha, tuple):
146
+ alpha = alpha[1]
147
+ if alpha.dim() > 1:
148
+ alpha = alpha.mean(dim=1)
149
+ alpha_norm = alpha.cpu().numpy()
150
+
151
+ atom_importance = calculate_atom_importance(edge_index, alpha_norm, data.x, mol.GetNumAtoms())
152
+
153
+ # 创建可视化图像
154
+ fig = plt.figure(figsize=(6, 6))
155
+ ax = fig.add_subplot(111)
156
+
157
+ # 绘制分子结构
158
+ drawer = Draw.MolDraw2DCairo(400, 400)
159
+ atom_colors = {}
160
+ normalized_importance = atom_importance
161
+ cmap = plt.cm.Blues
162
+ norm = plt.Normalize(vmin=0, vmax=1)
163
+
164
+ for i in range(mol.GetNumAtoms()):
165
+ rgba = cmap(norm(normalized_importance[i]))
166
+ atom_colors[i] = (rgba[0], rgba[1], rgba[2])
167
+
168
+ drawer.DrawMolecule(
169
+ mol,
170
+ highlightAtoms=list(range(mol.GetNumAtoms())),
171
+ highlightAtomColors=atom_colors,
172
+ highlightBonds=[]
173
+ )
174
+ drawer.FinishDrawing()
175
+
176
+ # 合成最终图像
177
+ img = Image.open(io.BytesIO(drawer.GetDrawingText()))
178
+ ax.imshow(img)
179
+ ax.axis('off')
180
+
181
+ # 添加预测信息
182
+ #plt.text(0.5, 0.95, f"{model_name}\nPredicted: {pred_label}",
183
+ # ha='center', va='top',
184
+ # transform=fig.transFigure,
185
+ # fontsize=10,
186
+ # bbox=dict(facecolor='white', alpha=0.8))
187
+
188
+ # 添加颜色条
189
+ sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
190
+ sm.set_array([])
191
+ cbar = fig.colorbar(sm, ax=ax,
192
+ fraction=0.03,
193
+ pad=0.04,
194
+ orientation='vertical')
195
+ cbar.set_label('Atom Importance',
196
+ fontsize=10,
197
+ labelpad=5)
198
+ cbar.ax.tick_params(labelsize=8)
199
+
200
+ # 保存到缓冲区
201
+ buf = io.BytesIO()
202
+ plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
203
+ plt.close(fig)
204
+ buf.seek(0)
205
+
206
+ return buf, pred_label