junxue commited on
Commit
9e4f268
·
1 Parent(s): 166bcef

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +562 -0
app.py ADDED
@@ -0,0 +1,562 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import gradio as gr
3
+ from rdkit import Chem
4
+ import torch
5
+ import os
6
+ import pandas as pd
7
+ import hashlib
8
+ from torch_geometric.loader import DataLoader
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import torchmetrics
13
+ from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
14
+ from torch_geometric.nn import GCNConv, global_mean_pool, GATConv, GAE, GATv2Conv, GraphSAGE, GENConv, GMMConv, \
15
+ GravNetConv, MessagePassing, global_max_pool, global_add_pool, GAT, GINConv, GINEConv, GraphNorm, SAGEConv, RGATConv
16
+ from torch.nn.functional import sigmoid
17
+ from torch import nn
18
+ import numpy as np
19
+ import torch.nn.functional as F
20
+ from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool, MessagePassing
21
+ from torch_geometric.utils import add_self_loops
22
+ from tqdm import tqdm
23
+ from torch.nn import Conv1d
24
+ from typing import Optional, Callable, Union, List, Tuple
25
+ from torch_geometric.data import Data, in_memory_dataset, Dataset, InMemoryDataset
26
+ from torch_geometric.loader import DataLoader
27
+ import numpy as np
28
+ import os
29
+ import torch
30
+ from torch_geometric.data import Dataset, Data
31
+ from torch_geometric.utils import to_networkx, to_dense_adj
32
+ import networkx as nx
33
+ import pandas as pd
34
+ from rdkit import Chem
35
+ from rdkit.Chem.rdchem import HybridizationType
36
+ from rdkit.Chem.rdchem import BondType as BT
37
+ from rdkit.Chem import AllChem
38
+ from sklearn.preprocessing import OneHotEncoder
39
+ import warnings
40
+
41
+
42
+ CHIRALITY_LIST = [
43
+ Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
44
+ Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
45
+ Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
46
+ Chem.rdchem.ChiralType.CHI_OTHER
47
+ ]
48
+ BOND_LIST = [
49
+ BT.SINGLE,
50
+ BT.DOUBLE,
51
+ BT.TRIPLE,
52
+ BT.AROMATIC
53
+ ]
54
+ BONDDIR_LIST = [
55
+ Chem.rdchem.BondDir.NONE,
56
+ Chem.rdchem.BondDir.ENDUPRIGHT,
57
+ Chem.rdchem.BondDir.ENDDOWNRIGHT
58
+ ]
59
+ hybridization_list = ['OTHER', 'S', 'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'UNSPECIFIED']
60
+ hybridization_encoder = OneHotEncoder()
61
+ hybridization_encoder.fit(torch.range(0, len(hybridization_list) - 1).unsqueeze(-1))
62
+
63
+ atom_list = ['H', 'C', 'O', 'S', 'N', 'P', 'F', 'Cl', 'Br', 'I', 'Si']
64
+ atom_encoder = OneHotEncoder()
65
+ atom_encoder.fit(torch.range(0, len(atom_list) - 1).unsqueeze(-1))
66
+
67
+ chirarity_encoder = OneHotEncoder()
68
+ chirarity_encoder.fit(torch.range(0, len(CHIRALITY_LIST) - 1).unsqueeze(-1))
69
+
70
+ def get_data_list(mol_list):
71
+ data_list = []
72
+ # mol = Chem.MolFromInchi(inchi, sanitize=False, removeHs=False)
73
+ # mol = Chem.AddHs(mol)
74
+ for mol in mol_list:
75
+ weights = []
76
+ type_idx = []
77
+ chirality_idx = []
78
+ atomic_number = []
79
+ degrees = []
80
+ total_degrees = []
81
+ formal_charges = []
82
+ hybridization_types = []
83
+ explicit_valences = []
84
+ implicit_valences = []
85
+ total_valences = []
86
+ atom_map_nums = []
87
+ isotopes = []
88
+ radical_electrons = []
89
+ inrings = []
90
+ atom_is_aromatic = []
91
+
92
+ for atom in mol.GetAtoms():
93
+ atom_is_aromatic.append(atom.GetIsAromatic())
94
+
95
+ type_idx.append(atom_list.index(atom.GetSymbol()))
96
+ chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag()))
97
+ atomic_number.append(atom.GetAtomicNum())
98
+ degrees.append(atom.GetDegree())
99
+ weights.append(atom.GetMass())
100
+ total_degrees.append(atom.GetTotalDegree())
101
+ formal_charges.append(atom.GetFormalCharge())
102
+ hybridization_types.append(hybridization_list.index(str(atom.GetHybridization())))
103
+ explicit_valences.append(atom.GetExplicitValence())
104
+ implicit_valences.append(atom.GetImplicitValence())
105
+ total_valences.append(atom.GetTotalValence())
106
+ atom_map_nums.append(atom.GetAtomMapNum())
107
+ isotopes.append(atom.GetIsotope())
108
+ radical_electrons.append(atom.GetNumRadicalElectrons())
109
+ inrings.append(int(atom.IsInRing()))
110
+
111
+ x1 = torch.tensor(type_idx, dtype=torch.float32).view(-1, 1)
112
+ x2 = torch.tensor(chirality_idx, dtype=torch.float32).view(-1, 1)
113
+ x3 = torch.tensor(weights, dtype=torch.float32).view(-1, 1)
114
+ x4 = torch.tensor(degrees, dtype=torch.float32).view(-1, 1)
115
+ x5 = torch.tensor(total_degrees, dtype=torch.float32).view(-1, 1)
116
+ x6 = torch.tensor(formal_charges, dtype=torch.float32).view(-1, 1)
117
+ x7 = torch.tensor(hybridization_types, dtype=torch.float32).view(-1, 1)
118
+ x8 = torch.tensor(explicit_valences, dtype=torch.float32).view(-1, 1)
119
+ x9 = torch.tensor(implicit_valences, dtype=torch.float32).view(-1, 1)
120
+ x10 = torch.tensor(total_valences, dtype=torch.float32).view(-1, 1)
121
+ x11 = torch.tensor(atom_map_nums, dtype=torch.float32).view(-1, 1)
122
+ x12 = torch.tensor(isotopes, dtype=torch.float32).view(-1, 1)
123
+ x13 = torch.tensor(radical_electrons, dtype=torch.float32).view(-1, 1)
124
+ x14 = torch.tensor(inrings, dtype=torch.float32).view(-1, 1)
125
+ # x15 = torch.tensor(atom_is_aromatic, dtype=torch.float32).view(-1, 1)
126
+
127
+ # x = [x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14]
128
+
129
+ x = torch.cat([torch.tensor(atom_encoder.transform(x1).toarray(), dtype=torch.float32),
130
+ torch.tensor(chirarity_encoder.transform(x2).toarray(), dtype=torch.float32),
131
+ x3,
132
+ x4,
133
+ x5,
134
+ x6,
135
+ torch.tensor(hybridization_encoder.transform(x7).toarray(), dtype=torch.float32),
136
+ x8,
137
+ x9,
138
+ x10,
139
+ x11,
140
+ x12,
141
+ x13,
142
+ x14, ], dim=-1)
143
+
144
+ row, col, edge_feat = [], [], []
145
+ for bond in mol.GetBonds():
146
+ start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
147
+ row += [start, end]
148
+ col += [end, start]
149
+ edge_feat.append([
150
+ BOND_LIST.index(bond.GetBondType()),
151
+ BONDDIR_LIST.index(bond.GetBondDir()),
152
+ float(int(bond.IsInRing())),
153
+ float(int(bond.GetIsAromatic())),
154
+ float(int(bond.GetIsConjugated()))
155
+ ])
156
+ edge_feat.append([
157
+ BOND_LIST.index(bond.GetBondType()),
158
+ BONDDIR_LIST.index(bond.GetBondDir()),
159
+ float(int(bond.IsInRing())),
160
+ float(int(bond.GetIsAromatic())),
161
+ float(int(bond.GetIsConjugated()))
162
+ ])
163
+ edge_index = torch.tensor([row, col], dtype=torch.long)
164
+ edge_attr = torch.tensor(np.array(edge_feat), dtype=torch.float32)
165
+ fingerprint = torch.tensor(AllChem.GetMorganFingerprintAsBitVect(mol, 2), dtype=torch.float32)
166
+ data = Data(x=x,
167
+ edge_index=edge_index,
168
+ edge_attr=edge_attr,
169
+ fingerprint=fingerprint,)
170
+ data_list.append(data)
171
+ return data_list
172
+
173
+ class GraphTransformerBlock(nn.Module):
174
+ def __init__(self, in_channels, out_channels, heads=3, edge_dim=5, dropout=0, **kwargs):
175
+ super(GraphTransformerBlock, self).__init__(**kwargs)
176
+ self.edge_dim = edge_dim
177
+ self.in_channels = in_channels
178
+ self.out_channels = out_channels
179
+
180
+ self.conv = GATConv(in_channels, out_channels, heads=heads, edge_dim=edge_dim)
181
+ self.linear = nn.Linear(heads * out_channels, out_channels)
182
+ self.layerNorm = nn.LayerNorm(out_channels)
183
+ self.dropout = dropout
184
+
185
+ def forward(self, x, edge_index, edge_attr):
186
+
187
+ x_gat = self.conv(x=x, edge_index=edge_index, edge_attr=edge_attr)
188
+ x_gat = self.linear(x_gat)
189
+ x_gat = self.layerNorm(x + x_gat)
190
+
191
+ return F.dropout(x_gat, self.dropout, training=self.training)
192
+
193
+
194
+ class GraphTransformerBlock2(nn.Module):
195
+ def __init__(self, in_channels, out_channels, heads=3, edge_dim=5, dropout=0, **kwargs):
196
+ super(GraphTransformerBlock2, self).__init__(**kwargs)
197
+ self.edge_dim = edge_dim
198
+ self.in_channels = in_channels
199
+ self.out_channels = out_channels
200
+
201
+ self.conv = GATConv(in_channels, out_channels, heads=heads, edge_dim=edge_dim)
202
+ self.linear1 = nn.Linear(heads * out_channels, out_channels)
203
+ self.layerNorm1 = nn.LayerNorm(out_channels)
204
+ self.linear2 = nn.Linear(out_channels, out_channels)
205
+ self.layerNorm2 = nn.LayerNorm(out_channels)
206
+ self.dropout = dropout
207
+
208
+ def forward(self, x, edge_index, edge_attr):
209
+ x_gat = self.conv(x=x, edge_index=edge_index, edge_attr=edge_attr)
210
+ x_gat = self.linear1(x_gat)
211
+ x_gat = self.layerNorm1(x + x_gat)
212
+ linear_ = self.linear2(x_gat)
213
+ linear_ = self.layerNorm2(linear_ + x_gat)
214
+
215
+ return F.dropout(linear_, self.dropout, training=self.training)
216
+
217
+ class Trainer(object):
218
+ def __init__(self, model, lr, device):
219
+ self.model = model
220
+ from torch import optim
221
+ self.optimizer = optim.AdamW(self.model.parameters(), lr=lr)
222
+ torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.1, patience=10,
223
+ verbose=False, threshold=0.0001, threshold_mode='rel',
224
+ cooldown=0, min_lr=0, eps=1e-08)
225
+
226
+ self.device = device
227
+
228
+ def train(self, data_loader):
229
+ criterion = torch.nn.L1Loss()
230
+ for i, data in enumerate(data_loader):
231
+ data.to(self.device)
232
+ y_hat = self.model(data)
233
+ loss = criterion(y_hat, data.y)
234
+ self.optimizer.zero_grad()
235
+ loss.backward()
236
+ self.optimizer.step()
237
+ return 0
238
+
239
+
240
+ class Tester(object):
241
+ def __init__(self, model, device):
242
+ self.model = model
243
+ self.device = device
244
+
245
+ def test_regressor(self, data_loader):
246
+ y_true = []
247
+ y_pred = []
248
+ with torch.no_grad():
249
+ for data in data_loader:
250
+ data.to(self.device, non_blocking=True)
251
+ y_hat = self.model(data)
252
+ # total_loss += torch.abs(y_hat - data.y).sum()
253
+ # mre_total = torch.div(torch.abs(y_hat - data.y), data.y).sum()
254
+ y_true.append(data.y)
255
+ y_pred.append(y_hat)
256
+
257
+ y_true = torch.concat(y_true)
258
+ y_pred = torch.concat(y_pred)
259
+
260
+ mae = torch.abs(y_true - y_pred).mean()
261
+ # mre = torch.div(torch.abs(y_true - y_pred), y_true).mean()
262
+ # medAE = torch.median(torch.abs(y_true - y_pred))
263
+ # medRE = torch.median(torch.div(torch.abs(y_true - y_pred), y_true))
264
+ #
265
+ # score = torchmetrics.R2Score().to(self.device)
266
+ # r2 = score(y_pred, y_true)
267
+ # return mae.item(), medAE.item(), mre.item(), medRE.item(), r2.item()
268
+ return mae.item()
269
+
270
+
271
+ class MyNet(nn.Module):
272
+ def __init__(self, emb_dim=512, feat_dim=256, edge_dim=5, heads=3, drop_ratio=0, pool='add'):
273
+ super(MyNet, self).__init__()
274
+ self.emb_dim = emb_dim
275
+ self.feat_dim = feat_dim
276
+ self.drop_ratio = drop_ratio
277
+
278
+ self.in_linear = nn.Linear(34, emb_dim)
279
+
280
+ self.conv1 = GraphTransformerBlock(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
281
+ self.conv2 = GraphTransformerBlock(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
282
+ self.conv3 = GraphTransformerBlock(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
283
+ self.conv4 = GraphTransformerBlock(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
284
+ self.conv5 = GraphTransformerBlock(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
285
+ self.conv6 = GraphTransformerBlock(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
286
+ self.conv7 = GraphTransformerBlock(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
287
+ self.conv8 = GraphTransformerBlock(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
288
+ self.conv9 = GraphTransformerBlock(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
289
+
290
+ if pool == 'mean':
291
+ self.pool = global_mean_pool
292
+ elif pool == 'max':
293
+ self.pool = global_max_pool
294
+ elif pool == 'add':
295
+ self.pool = global_add_pool
296
+
297
+ self.feat_lin = nn.Linear(self.emb_dim, self.feat_dim)
298
+
299
+ self.out_lin = nn.Sequential(
300
+ nn.Linear(self.feat_dim, self.feat_dim // 8),
301
+ nn.ReLU(inplace=True),
302
+ nn.Linear(self.feat_dim // 8, self.feat_dim // 64),
303
+ nn.ReLU(inplace=True),
304
+ nn.Linear(self.feat_dim // 64, 1),
305
+ )
306
+
307
+ self.conv1d1 = OneDimConvBlock()
308
+ self.conv1d2 = OneDimConvBlock()
309
+ self.conv1d3 = OneDimConvBlock()
310
+ self.conv1d4 = OneDimConvBlock()
311
+ self.conv1d5 = OneDimConvBlock()
312
+ self.conv1d6 = OneDimConvBlock()
313
+ self.conv1d7 = OneDimConvBlock()
314
+ self.conv1d8 = OneDimConvBlock()
315
+ self.conv1d9 = OneDimConvBlock()
316
+ self.conv1d10 = OneDimConvBlock()
317
+ self.conv1d11 = OneDimConvBlock()
318
+ self.conv1d12 = OneDimConvBlock()
319
+
320
+ self.preconcat1 = nn.Linear(2048, 1024)
321
+ self.preconcat2 = nn.Linear(1024, self.feat_dim)
322
+
323
+ self.afterconcat1 = nn.Linear(2 * self.feat_dim, self.feat_dim)
324
+ self.after_cat_drop = nn.Dropout(self.drop_ratio)
325
+
326
+ def forward(self, data):
327
+ x = data.x
328
+ edge_index = data.edge_index
329
+ edge_attr = data.edge_attr
330
+ batch = data.batch
331
+ fringerprint = data.fingerprint.reshape(-1, 2048)
332
+
333
+ h = self.in_linear(x)
334
+
335
+ h = F.relu(self.conv1(h, edge_index, edge_attr), inplace=True)
336
+ h = F.relu(self.conv2(h, edge_index, edge_attr), inplace=True)
337
+ h = F.relu(self.conv3(h, edge_index, edge_attr), inplace=True)
338
+ h = F.relu(self.conv4(h, edge_index, edge_attr), inplace=True)
339
+ h = F.relu(self.conv5(h, edge_index, edge_attr), inplace=True)
340
+ h = F.relu(self.conv6(h, edge_index, edge_attr), inplace=True)
341
+ h = F.relu(self.conv7(h, edge_index, edge_attr), inplace=True)
342
+ h = F.relu(self.conv8(h, edge_index, edge_attr), inplace=True)
343
+ h = F.relu(self.conv9(h, edge_index, edge_attr), inplace=True)
344
+
345
+ fringerprint = self.conv1d1(fringerprint)
346
+ fringerprint = self.conv1d2(fringerprint)
347
+ fringerprint = self.conv1d3(fringerprint)
348
+ fringerprint = self.conv1d4(fringerprint)
349
+ fringerprint = self.conv1d5(fringerprint)
350
+ fringerprint = self.conv1d6(fringerprint)
351
+ fringerprint = self.conv1d7(fringerprint)
352
+ fringerprint = self.conv1d8(fringerprint)
353
+ fringerprint = self.conv1d9(fringerprint)
354
+ fringerprint = self.conv1d10(fringerprint)
355
+ fringerprint = self.conv1d11(fringerprint)
356
+ fringerprint = self.conv1d12(fringerprint)
357
+ fringerprint = self.preconcat1(fringerprint)
358
+ fringerprint = self.preconcat2(fringerprint)
359
+
360
+ h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
361
+ h = self.pool(h, batch)
362
+ h = self.feat_lin(h)
363
+
364
+ concat = torch.concat([h, fringerprint], dim=-1)
365
+ concat = self.afterconcat1(concat)
366
+ concat = self.after_cat_drop(concat)
367
+
368
+ out = self.out_lin(concat)
369
+
370
+ return out.squeeze()
371
+
372
+
373
+ class OneDimConvBlock(nn.Module):
374
+ def __init__(self, in_channel=2048, out_channel=2048):
375
+ super().__init__()
376
+ self.attention_conv = OneDimAttention(in_channel, in_channel)
377
+ self.batchnorm1 = torch.nn.BatchNorm1d(in_channel)
378
+ self.batchnorm2 = torch.nn.BatchNorm1d(in_channel)
379
+ self.linear1 = nn.Linear(in_channel, in_channel)
380
+ self.linear2 = nn.Linear(in_channel, out_channel)
381
+ self.ffn = nn.Sequential(
382
+ nn.Linear(in_channel, in_channel),
383
+ nn.ReLU(),
384
+ nn.Linear(in_channel, in_channel),
385
+ nn.ReLU()
386
+ )
387
+
388
+ def forward(self, x):
389
+ h = self.attention_conv(x, x, x)
390
+ h = self.batchnorm1(x + h)
391
+
392
+ h_new = self.ffn(h)
393
+ h_new = self.batchnorm2(h + h_new)
394
+ return F.dropout1d(self.linear2(h_new), training=self.training)
395
+
396
+
397
+ class OneDimAttention(nn.Module):
398
+ def __init__(self, in_size, out_size):
399
+ super().__init__()
400
+ self.in_size = torch.tensor(in_size)
401
+ self.out_size = out_size
402
+ self.linear = nn.Linear(in_size, out_size)
403
+
404
+ def forward(self, q, k, v):
405
+ attention = torch.mul(q, k) / torch.sqrt(self.in_size)
406
+ attention = self.linear(attention)
407
+ return torch.mul(F.softmax(attention, dim=-1), v)
408
+
409
+
410
+ class MyNetTest(nn.Module):
411
+ def __init__(self, emb_dim=512, feat_dim=256, edge_dim=5, heads=3, drop_ratio=0, pool='add'):
412
+ super(MyNetTest, self).__init__()
413
+ self.emb_dim = emb_dim
414
+ self.feat_dim = feat_dim
415
+ self.drop_ratio = drop_ratio
416
+
417
+ self.in_linear = nn.Linear(34, emb_dim)
418
+
419
+ self.conv1 = GraphTransformerBlock2(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
420
+ self.conv2 = GraphTransformerBlock2(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
421
+ self.conv3 = GraphTransformerBlock2(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
422
+ self.conv4 = GraphTransformerBlock2(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
423
+ self.conv5 = GraphTransformerBlock2(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
424
+ self.conv6 = GraphTransformerBlock2(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
425
+ self.conv7 = GraphTransformerBlock2(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
426
+ self.conv8 = GraphTransformerBlock2(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
427
+ self.conv9 = GraphTransformerBlock2(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
428
+
429
+ if pool == 'mean':
430
+ self.pool = global_mean_pool
431
+ elif pool == 'max':
432
+ self.pool = global_max_pool
433
+ elif pool == 'add':
434
+ self.pool = global_add_pool
435
+
436
+ self.feat_lin = nn.Linear(self.emb_dim, self.feat_dim)
437
+
438
+ self.out_lin = nn.Sequential(
439
+ nn.Linear(self.feat_dim, self.feat_dim // 8),
440
+ nn.ReLU(inplace=True),
441
+ nn.Linear(self.feat_dim // 8, self.feat_dim // 64),
442
+ nn.ReLU(inplace=True),
443
+ nn.Linear(self.feat_dim // 64, 1),
444
+ )
445
+
446
+ self.conv1d1 = OneDimConvBlock()
447
+ self.conv1d2 = OneDimConvBlock()
448
+ self.conv1d3 = OneDimConvBlock()
449
+ self.conv1d4 = OneDimConvBlock()
450
+ self.conv1d5 = OneDimConvBlock()
451
+ self.conv1d6 = OneDimConvBlock()
452
+ self.conv1d7 = OneDimConvBlock()
453
+ self.conv1d8 = OneDimConvBlock()
454
+ self.conv1d9 = OneDimConvBlock()
455
+ self.conv1d10 = OneDimConvBlock()
456
+ self.conv1d11 = OneDimConvBlock()
457
+ self.conv1d12 = OneDimConvBlock()
458
+
459
+ self.preconcat1 = nn.Linear(2048, 1024)
460
+ self.preconcat2 = nn.Linear(1024, self.feat_dim)
461
+
462
+ self.afterconcat1 = nn.Linear(2 * self.feat_dim, self.feat_dim)
463
+ self.after_cat_drop = nn.Dropout(self.drop_ratio)
464
+
465
+ def forward(self, data):
466
+ x = data.x
467
+ edge_index = data.edge_index
468
+ edge_attr = data.edge_attr
469
+ batch = data.batch
470
+ fringerprint = data.fingerprint.reshape(-1, 2048)
471
+
472
+ h = self.in_linear(x)
473
+
474
+ h = F.relu(self.conv1(h, edge_index, edge_attr), inplace=True)
475
+ h = F.relu(self.conv2(h, edge_index, edge_attr), inplace=True)
476
+ h = F.relu(self.conv3(h, edge_index, edge_attr), inplace=True)
477
+ h = F.relu(self.conv4(h, edge_index, edge_attr), inplace=True)
478
+ h = F.relu(self.conv5(h, edge_index, edge_attr), inplace=True)
479
+ h = F.relu(self.conv6(h, edge_index, edge_attr), inplace=True)
480
+ h = F.relu(self.conv7(h, edge_index, edge_attr), inplace=True)
481
+ h = F.relu(self.conv8(h, edge_index, edge_attr), inplace=True)
482
+ h = F.relu(self.conv9(h, edge_index, edge_attr), inplace=True)
483
+
484
+ fringerprint = self.conv1d1(fringerprint)
485
+ fringerprint = self.conv1d2(fringerprint)
486
+ fringerprint = self.conv1d3(fringerprint)
487
+ fringerprint = self.conv1d4(fringerprint)
488
+ fringerprint = self.conv1d5(fringerprint)
489
+ fringerprint = self.conv1d6(fringerprint)
490
+ fringerprint = self.conv1d7(fringerprint)
491
+ fringerprint = self.conv1d8(fringerprint)
492
+ fringerprint = self.conv1d9(fringerprint)
493
+ fringerprint = self.conv1d10(fringerprint)
494
+ fringerprint = self.conv1d11(fringerprint)
495
+ fringerprint = self.conv1d12(fringerprint)
496
+ fringerprint = self.preconcat1(fringerprint)
497
+ fringerprint = self.preconcat2(fringerprint)
498
+
499
+ h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
500
+ h = self.pool(h, batch)
501
+ h = self.feat_lin(h)
502
+
503
+ concat = torch.concat([h, fringerprint], dim=-1)
504
+ concat = self.afterconcat1(concat)
505
+ concat = self.after_cat_drop(concat)
506
+
507
+ out = self.out_lin(concat)
508
+
509
+ return out.squeeze()
510
+
511
+
512
+ model = MyNet(emb_dim=512, feat_dim=512)
513
+ state = torch.load('./best_state_download_dict.pth')
514
+ model.load_state_dict(state)
515
+ model.eval()
516
+ try:
517
+ os.mkdir('./save_df/')
518
+ except:
519
+ pass
520
+
521
+ def get_rt_from_mol(mol):
522
+ data_list = get_data_list([mol])
523
+ loader = DataLoader(data_list,batch_size=1)
524
+ for batch in loader:
525
+ break
526
+ return model(batch).item()
527
+
528
+ def pred_file_btyes(file_bytes,progress=gr.Progress()):
529
+ progress(0,desc='Starting')
530
+ file_name = os.path.join(
531
+ './save_df/',
532
+ (hashlib.md5(str(file_bytes).encode('utf-8')).hexdigest()+'.csv')
533
+ )
534
+ if os.path.exists(file_name):
535
+ print('该文件已经存在')
536
+ return file_name
537
+ with open('temp.sdf','bw') as f:
538
+ f.write(file_bytes)
539
+ sup = Chem.SDMolSupplier('temp.sdf')
540
+ df = pd.DataFrame(columns=['InChI','Predicted RT'])
541
+ for mol in progress.tqdm(sup):
542
+ try:
543
+ inchi = Chem.MolToInchi(mol)
544
+ rt = get_rt_from_mol(mol)
545
+ df.loc[len(df)] = [inchi,rt]
546
+ except:
547
+ pass
548
+
549
+ df.to_csv(file_name)
550
+ return file_name
551
+
552
+ demo = gr.Interface(
553
+ pred_file_btyes,
554
+ gr.File(type='binary'),
555
+ gr.File(type='filepath'),
556
+ title='RT-Transformer Rentention Time Predictor',
557
+ description='Input SDF Molecule File,Predicted RT output with a CSV File',
558
+ )
559
+
560
+
561
+ if __name__ == "__main__":
562
+ demo.launch()