OliXio committed
Commit d5233a9 · verified · 1 Parent(s): bcfb88f

Upload 13 files

code/GNN/__init__.py ADDED
File without changes
code/GNN/featurizer.py ADDED
@@ -0,0 +1,138 @@
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem, MACCSkeys, rdMolDescriptors as rdDesc
from utils import *
import torch
import copy
from . import subgraphfp as subfp

PERIODIC_TABLE = Chem.GetPeriodicTable()
POSSIBLE_ATOMS = ['H', 'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'I', 'B']
HYBRIDS = [Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
           Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
           Chem.rdchem.HybridizationType.SP3D2]
CHIRALS = [Chem.rdchem.ChiralType.CHI_UNSPECIFIED, Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
           Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, Chem.rdchem.ChiralType.CHI_OTHER]
BOND_TYPES = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
              Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]

def one_of_k_encoding(x, allowable_set):
    if x not in allowable_set:
        raise Exception("input {0} not in allowable set {1}".format(x, allowable_set))
    return list(map(lambda s: x == s, allowable_set))

def one_of_k_encoding_unk(x, allowable_set):
    """Maps inputs not in the allowable set to the last element."""
    if x not in allowable_set:
        x = allowable_set[-1]

    return list(map(lambda s: x == s, allowable_set))

def calc_atom_features_onehot(atom, feature):
    '''
    Computes atom-level features from an RDKit atom object.
    '''
    atom_features = one_of_k_encoding_unk(atom.GetSymbol(), POSSIBLE_ATOMS)
    atom_features += one_of_k_encoding_unk(atom.GetExplicitValence(), list(range(7)))
    atom_features += one_of_k_encoding_unk(atom.GetImplicitValence(), list(range(7)))
    atom_features += one_of_k_encoding_unk(atom.GetTotalNumHs(), list(range(5)))
    atom_features += one_of_k_encoding_unk(atom.GetNumRadicalElectrons(), list(range(5)))
    atom_features += one_of_k_encoding_unk(atom.GetTotalDegree(), list(range(7)))
    atom_features += one_of_k_encoding_unk(atom.GetFormalCharge(), list(range(-2, 3)))
    atom_features += one_of_k_encoding_unk(atom.GetHybridization(), HYBRIDS)
    atom_features += one_of_k_encoding_unk(atom.GetIsAromatic(), [False, True])
    atom_features += one_of_k_encoding_unk(atom.IsInRing(), [False, True])
    atom_features += one_of_k_encoding_unk(atom.GetChiralTag(), CHIRALS)
    # encode the CIP code itself; HasProp returns 0/1 and would never match 'R'/'S'
    atom_features += one_of_k_encoding_unk(atom.GetProp('_CIPCode') if atom.HasProp('_CIPCode') else '', ['R', 'S'])
    atom_features += [PERIODIC_TABLE.GetRvdw(atom.GetSymbol())]
    atom_features += [atom.HasProp('_ChiralityPossible')]
    atom_features += [atom.GetAtomicNum()]
    atom_features += [atom.GetMass() * 0.01]
    atom_features += [atom.GetDegree()]
    # append the feature invariant as six individual bits
    atom_features += [int(i) for i in list('{0:06b}'.format(feature))]

    return atom_features

def calc_adjacent_tensor(bonds, atom_num, with_ring_conj=False):
    '''
    Constructs an adjacency tensor that stacks one adjacency matrix per bond channel.
    :param bonds: bonds of an RDKit mol
    :param atom_num: the number of atoms in the RDKit mol
    :param with_ring_conj: whether to add "bond is in ring" and "bond is
        conjugated" channels to the adjacency tensor
    :return: adjacency tensor A shaped [N, F, N], where N is the atom count and F the number of bond channels
    '''
    bond_types = len(BOND_TYPES)
    if with_ring_conj:
        bond_types += 2

    A = np.zeros([atom_num, bond_types, atom_num])

    for bond in bonds:
        b, e = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        try:
            bond_type = BOND_TYPES.index(bond.GetBondType())
            A[b, bond_type, e] = 1
            A[e, bond_type, b] = 1
            if with_ring_conj:
                if bond.IsInRing():
                    A[b, bond_types-2, e] = 1
                    A[e, bond_types-2, b] = 1
                if bond.GetIsConjugated():
                    A[b, bond_types-1, e] = 1
                    A[e, bond_types-1, b] = 1
        except ValueError:
            # bond type not in BOND_TYPES; leave this bond out of the tensor
            pass
    return A

def calc_data_from_smile(smiles, addh=False, with_ring_conj=False, with_atom_feats=True, with_submol_fp=True, radius=2):
    '''
    Constructs the graph data of a molecule.
    :param smiles: SMILES representation of the molecule
    :param addh: whether to add explicit Hs to the mol
    :param with_ring_conj: whether to add "bond is in ring" and "bond is
        conjugated" channels to the adjacency tensor
    :return: dict with node features V, adjacency tensor A and mol_size
    '''
    mol = Chem.MolFromSmiles(smiles, sanitize=True)
    #mol.UpdatePropertyCache(strict=False)

    if addh:
        mol = Chem.AddHs(mol)
    #else:
    #    mol = Chem.RemoveHs(mol, sanitize=False)

    mol_size = torch.IntTensor([mol.GetNumAtoms()])

    V = []

    if with_atom_feats:
        features = rdDesc.GetFeatureInvariants(mol)

    submoldict = {}
    if with_submol_fp:
        atoms, submols = subfp.get_atom_submol_radn(mol, radius, sanitize=True)
        submoldict = dict(zip([a.GetIdx() for a in atoms], submols))

    for i in range(mol.GetNumAtoms()):
        atom_i = mol.GetAtomWithIdx(i)
        if with_atom_feats:
            atom_i_features = calc_atom_features_onehot(atom_i, features[i])
        else:
            atom_i_features = []

        if with_submol_fp:
            submol = submoldict[i]
            #print(Chem.MolToSmiles(submol))
            submolfp = subfp.gen_fps_from_mol(submol)
            atom_i_features.extend(submolfp)

        V.append(atom_i_features)

    V = torch.FloatTensor(V)

    if len(V.shape) != 2:
        return None

    A = calc_adjacent_tensor(mol.GetBonds(), mol.GetNumAtoms(), with_ring_conj)
    A = torch.FloatTensor(A)

    return {'V': V, 'A': A, 'mol_size': mol_size}
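A minimal usage sketch for the featurizer above (hypothetical, not part of the commit; it assumes the code/ directory is on PYTHONPATH so that GNN.featurizer and its utils/subgraphfp imports resolve):

from GNN.featurizer import calc_data_from_smile

# 'CCO' (ethanol) is an arbitrary test molecule
data = calc_data_from_smile('CCO', with_ring_conj=True)
if data is not None:
    # V: [N, C] per-atom features; A: [N, 6, N] -- four bond-type channels
    # plus the ring/conjugation channels enabled by with_ring_conj=True
    print(data['V'].shape, data['A'].shape, data['mol_size'])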
code/GNN/layers.py ADDED
@@ -0,0 +1,443 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import utils
import pickle

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class GraphCNNLayer(nn.Module):
    def __init__(self, n_feats, adj_chans=4, n_filters=64, bias=True):
        super(GraphCNNLayer, self).__init__()
        self.n_feats = n_feats
        self.adj_chans = adj_chans
        self.n_filters = n_filters
        self.has_bias = bias

        # [C*L, F], C = n_feats, L = adj_chans, F = n_filters; this is for the edge feats
        self.weight_e = nn.Parameter(torch.FloatTensor(adj_chans*n_feats, n_filters))
        # [C, F], this is for 𝐈𝐕in𝐖0
        self.weight_i = nn.Parameter(torch.FloatTensor(n_feats, self.n_filters))

        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(n_filters))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight_e)
        nn.init.xavier_uniform_(self.weight_i)

        if self.bias is not None:
            self.bias.data.fill_(0.01)

    def forward(self, V, A):
        '''V node features: [b, N, C], A adjs: [b, N, L, N], L = adj_chans'''
        b, N, C = V.shape
        b, N, L, _ = A.shape

        # formula: 𝐕out = 𝐈𝐕in𝐖0 + GConv(𝐕in, 𝐹) + 𝐛; 𝐈𝐕in = 𝐕in, so 𝐈𝐕in𝐖0 = 𝐕in𝐖0

        # A [b, N, L, N] -> [b, N*L, N]
        A_reshape = A.view(-1, N*L, N)
        # [b, N*L, N] * [b, N, C] -> [b, N*L, C]
        n = torch.bmm(A_reshape, V)
        # [b, N*L, C] -> [b, N, L*C]
        n = n.view(-1, N, L*self.n_feats)

        # n [b, N, L*C], W [C*L, F], V [b, N, C], W_I [C, F]
        # -> [b, N, F] + [b, N, F] + b
        output = torch.matmul(n, self.weight_e) + torch.matmul(V, self.weight_i)

        if self.has_bias:
            output += self.bias

        # output: [b, N, F]
        return output

    def __repr__(self):
        return f'{self.__class__.__name__}(n_feats={self.n_feats},adj_chans={self.adj_chans},n_filters={self.n_filters},bias={self.has_bias}) -> [b, N, {self.n_filters}]'

class GraphResidualCNNLayer(nn.Module):
    def __init__(self, n_feats, adj_chans=4, bias=True):
        super(GraphResidualCNNLayer, self).__init__()
        self.n_feats = n_feats
        self.adj_chans = adj_chans
        self.has_bias = bias

        # one [C, C] linear map per adjacency channel, C = n_feats
        self.weight_layers = nn.ModuleList([nn.Linear(n_feats, n_feats) for _ in range(adj_chans)])

        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(n_feats))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        if self.bias is not None:
            self.bias.data.fill_(0.01)

    def forward(self, V, A):
        '''V node features: [b, N, C], A adjs: [b, N, L, N], L = adj_chans'''
        b, N, C = V.shape
        b, N, L, _ = A.shape

        for i in range(self.adj_chans):
            # [b, N, C] -> [b, N, C]
            hs = F.relu(self.weight_layers[i](V))
            # [b, N, N]
            a = A[:, :, i, :]
            a = a.view(-1, N, N)
            # [b, N, N] * [b, N, C] -> [b, N, C]
            V = V + torch.bmm(a, hs)

        if self.has_bias:
            V += self.bias

        # output: [b, N, C]
        return V

    def __repr__(self):
        return f'{self.__class__.__name__}(n_feats={self.n_feats},adj_chans={self.adj_chans},bias={self.has_bias}) -> [b, N, {self.n_feats}]'

class GraphAttentionLayer(nn.Module):
    def __init__(self, n_feats, adj_chans=4, n_filters=64, bias=True, dropout=0., alpha=0.2):
        super(GraphAttentionLayer, self).__init__()
        self.n_feats = n_feats
        self.adj_chans = adj_chans
        self.n_filters = n_filters
        self.has_bias = bias
        self.dropout = dropout
        self.alpha = alpha

        # one [C, F] projection plus two attention vectors per adjacency channel
        self.weight_list = nn.ParameterList([nn.Parameter(torch.FloatTensor(n_feats, n_filters)) for _ in range(adj_chans)])
        self.a1_list = nn.ParameterList([nn.Parameter(torch.FloatTensor(n_filters, 1)) for _ in range(adj_chans)])
        self.a2_list = nn.ParameterList([nn.Parameter(torch.FloatTensor(n_filters, 1)) for _ in range(adj_chans)])

        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(n_filters))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        for w in self.weight_list:
            nn.init.xavier_uniform_(w)
        for w in self.a1_list:
            nn.init.xavier_uniform_(w)
        for w in self.a2_list:
            nn.init.xavier_uniform_(w)
        if self.bias is not None:
            self.bias.data.fill_(0.01)

    def forward(self, V, A):
        '''V node features: [b, N, C], A adjs: [b, N, L, N], L = adj_chans'''
        b, N, C = V.shape
        b, N, L, _ = A.shape

        output = None

        for i in range(self.adj_chans):
            # [b, N, 1, N] -> [b, N, N]
            adj = A[:, :, i, :].view(-1, N, N)

            # [b, N, C] * [C, F] -> [b, N, F]
            h = torch.matmul(V, self.weight_list[i])
            # [b, N, F] * [F, 1] -> [b, N, 1]
            f_1 = torch.matmul(h, self.a1_list[i])
            # [b, N, F] * [F, 1] -> [b, N, 1]
            f_2 = torch.matmul(h, self.a2_list[i])

            # leaky_relu([b, N, 1] + [b, 1, N]) -> [b, N, N]
            e = F.leaky_relu(f_1 + f_2.transpose(1, 2), self.alpha)

            zero_vec = -9e15 * torch.ones_like(e)
            # [b, N, N]
            att = torch.where(adj > 0, e, zero_vec)
            att = F.softmax(att, dim=1)
            att = F.dropout(att, self.dropout, training=self.training)
            # [b, N, N] * [b, N, F] -> [b, N, F]
            if output is None:
                output = torch.matmul(att, h)
            else:
                output += torch.matmul(att, h)

        if self.has_bias:
            output += self.bias

        # output: [b, N, F]
        return output

    def __repr__(self):
        return f'{self.__class__.__name__}(n_feats={self.n_feats},adj_chans={self.adj_chans},n_filters={self.n_filters},bias={self.has_bias},dropout={self.dropout},alpha={self.alpha}) -> [b, N, {self.n_filters}]'

class GraphNodeCatGlobalFeatures(nn.Module):
    def __init__(self, global_feats, out_feats, mols=1, bias=True):
        super(GraphNodeCatGlobalFeatures, self).__init__()
        self.global_feats = global_feats
        self.out_feats = out_feats
        self.mols = mols
        self.has_bias = bias

        self.weights = nn.ParameterList([nn.Parameter(torch.FloatTensor(int(global_feats/mols), out_feats)) for _ in range(mols)])

        self.biass = []
        if bias:
            self.biass = nn.ParameterList([nn.Parameter(torch.FloatTensor(out_feats)) for _ in range(mols)])
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        for weight in self.weights:
            nn.init.xavier_uniform_(weight)
        for bias in self.biass:
            bias.data.fill_(0.01)

    def forward(self, V, global_state, graph_size, subgraph_size=None):
        # V: [b, N, Ov], global_state: [b, F], subgraph_size: [b, mols]
        b, N, Ov = V.shape
        O = self.out_feats
        if self.mols == 1:
            subgraph_size = graph_size.view(-1, 1)
            global_state = torch.mm(global_state, self.weights[0])
        else:
            # global_state: [b, F] view -> [b*mols, F/mols]
            global_state_view = global_state.view(b*self.mols, -1)

            # split global_state into that of individual mols
            idxmols = []
            for i in range(self.mols):
                idxmols.append(torch.IntTensor(list(range(i, b*self.mols, self.mols))).to(self.weights[0].device))

            global_states = []
            for i, idx in enumerate(idxmols):
                # select the global_state of mol i from global_state_view [b*mols, F/mols]; out shape is [b, F/mols]
                gs = global_state_view.index_select(dim=0, index=idx)
                # gs: [b, F/mols] * weight: [F/mols, O] -> [b, O]; F = global_feats, O = out_feats
                gs = torch.mm(gs, self.weights[i])

                if self.has_bias:
                    gs += self.biass[i]

                global_states.append(F.relu(gs))

            # convert global_states back to global_state
            # [[b, O] ... ] -> [b, mols*O]
            global_state = torch.cat(global_states, dim=1)

        # [b, mols*O] || [b, O] -> [b, (mols+1)*O]
        global_state_new = torch.cat([global_state, torch.zeros(b, O).to(self.weights[0].device)], dim=-1)
        # [b*(mols+1), O]
        global_state_new = global_state_new.view(-1, O)

        repeats = []
        for sz in subgraph_size:
            repeats.extend(sz.tolist() + [N-sz.sum()])
        repeats = torch.tensor(repeats).to(self.weights[0].device)

        # repeat from [b*(mols+1), O] -> [b*N, O]; the content looks like [m1_feats, m2_feats, ... mn_feats, pads, ...]
        global_state_new = global_state_new.repeat_interleave(repeats, dim=0)

        # V view: [b*N, Ov], global_state_new: [b*N, O]
        output = torch.cat([V.contiguous().view(-1, Ov), global_state_new], dim=1)

        # output: [b, N, Ov+O]
        return output.view(-1, N, Ov+O), global_state

    def __repr__(self):
        return f'{self.__class__.__name__}(global_feats={self.global_feats},out_feats={self.out_feats},bias={self.has_bias}) -> [b, N, {self.global_feats+self.out_feats}], [b, out_feats]'

class MultiHeadGlobalAttention(nn.Module):
    '''Input [b, N, C] -> output [b, n_head*C] if concat else [b, C]'''
    def __init__(self, n_feats, n_head=5, alpha=0.2, concat=True, bias=True):
        super(MultiHeadGlobalAttention, self).__init__()

        self.n_feats = n_feats
        self.n_head = n_head
        self.alpha = alpha
        self.concat = concat
        self.has_bias = bias

        self.weight = nn.Parameter(torch.FloatTensor(n_feats, n_head*n_feats))
        self.tune_weight = nn.Parameter(torch.FloatTensor(1, n_head, n_feats))

        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(n_head*n_feats))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)
        nn.init.xavier_uniform_(self.tune_weight)
        if self.bias is not None:
            self.bias.data.fill_(0.01)

    def forward(self, V, graph_size):
        # Gather V of the mols in a batch; after this, the padding is removed.
        if V.shape[0] == 1:
            Vg = torch.squeeze(V)
            graph_size = [graph_size]
        else:
            Vg = torch.cat([torch.split(v.view(-1, v.shape[-1]), graph_size[i])[0] for i, v in enumerate(torch.split(V, 1))], dim=0)

        Vg = torch.matmul(Vg, self.weight)
        if self.has_bias:
            Vg += self.bias
        Vg = Vg.view(-1, self.n_head, self.n_feats)

        alpha = torch.mul(self.tune_weight, Vg)
        alpha = torch.sum(alpha, dim=-1)
        alpha = F.leaky_relu(alpha, self.alpha)  # original TF code: "alpha = tf.nn.leaky_relu(alpha, alpha=0.2)"
        alpha = utils.segment_softmax(alpha, graph_size)

        #alpha_collect = torch.mean(alpha, dim=-1)  # present in the original code but unused
        alpha = alpha.view(-1, self.n_head, 1)
        V = torch.mul(Vg, alpha)

        if self.concat:
            V = utils.segment_sum(V, graph_size)
            V = V.view(-1, self.n_head*self.n_feats)
        else:
            V = torch.mean(V, dim=1)
            V = utils.segment_sum(V, graph_size)

        return V

    def __repr__(self):
        if self.concat:
            outc = self.n_head*self.n_feats
        else:
            outc = self.n_feats
        return f'{self.__class__.__name__}(n_feats={self.n_feats},n_head={self.n_head},alpha={self.alpha},concat={self.concat},bias={self.has_bias}) -> [b, {outc}]'

class GraphEmbedPoolingLayer(nn.Module):
    def __init__(self, n_feats, n_filters=1, mask=None, bias=True):
        super(GraphEmbedPoolingLayer, self).__init__()
        self.n_feats = n_feats
        self.n_filters = n_filters
        self.mask = mask
        self.has_bias = bias

        self.emb = nn.Linear(n_feats, n_filters, bias=bias)

    def forward(self, V, A):
        # [b, N, F]
        factors = self.emb(V)

        if self.mask is not None:
            factors = torch.mul(factors, self.mask)

        factors = F.softmax(factors, dim=1)
        # [b, N, F] trans -> [b, F, N] * [b, N, C] -> [b, F, C]
        result = torch.matmul(factors.transpose(1, 2).contiguous(), V)

        if self.n_filters == 1:
            return result.view(-1, self.n_feats), A

        result_A = A.view(A.shape[0], -1, A.shape[-1])
        result_A = torch.matmul(result_A, factors)
        result_A = result_A.view(A.shape[0], A.shape[-1], -1)
        result_A = torch.matmul(factors.transpose(1, 2).contiguous(), result_A)
        result_A = result_A.view(A.shape[0], self.n_filters, A.shape[2], self.n_filters)

        return result, result_A

    def __repr__(self):
        return f'{self.__class__.__name__}(n_feats={self.n_feats},n_filters={self.n_filters},mask={self.mask},bias={self.has_bias}) -> [b, {self.n_filters}, {self.n_feats}], [b, {self.n_filters}, L, {self.n_filters}]'

class GConvBlockWithGF(nn.Module):
    def __init__(self,
                 n_feats,
                 n_filters,
                 global_feats,
                 global_out_feats,
                 mols=1,
                 adj_chans=4,
                 bias=True,
                 usegat=False):

        super(GConvBlockWithGF, self).__init__()

        self.n_feats = n_feats
        self.n_filters = n_filters
        self.global_out_feats = global_out_feats
        self.global_feats = global_feats
        self.mols = mols
        self.adj_chans = adj_chans
        self.has_bias = bias
        self.usegat = usegat

        self.broadcast_global_state = GraphNodeCatGlobalFeatures(global_feats, global_out_feats, mols, bias)
        if usegat:
            self.graph_conv = GraphAttentionLayer(n_feats+global_out_feats, adj_chans, n_filters)
        else:
            self.graph_conv = GraphCNNLayer(n_feats+global_out_feats, adj_chans, n_filters, bias)

        self.bn_global = nn.BatchNorm1d(global_out_feats*mols)
        self.bn_graph = nn.BatchNorm1d(n_filters)

    def forward(self, V, A, global_state, graph_size, subgraph_size):
        ######## broadcast global_state to the nodes ########
        # V shape from [b, N, C] to [b, N, C+G], G is global_out_feats
        V, global_state = self.broadcast_global_state(V, global_state, graph_size, subgraph_size)

        ######## graph convolution ########
        # V shape from [b, N, C+G] to [b, N, F1], F1 is n_filters
        V = self.graph_conv(V, A)
        V = self.bn_graph(V.transpose(1, 2).contiguous())
        V = F.relu(V.transpose(1, 2))

        global_state = F.relu(self.bn_global(global_state))

        return V, global_state

    def __repr__(self):
        return f'{self.__class__.__name__}(n_feats={self.n_feats},n_filters={self.n_filters},global_feats={self.global_feats},global_out_feats={self.global_out_feats},mols={self.mols},adj_chans={self.adj_chans},bias={self.has_bias},usegat={self.usegat}) -> [b, N, {self.n_filters}], [b, {self.global_out_feats*self.mols}]'

class GConvBlockNoGF(nn.Module):
    def __init__(self,
                 n_feats,
                 n_filters,
                 mols=1,
                 adj_chans=4,
                 bias=True):

        super(GConvBlockNoGF, self).__init__()

        self.n_feats = n_feats
        self.n_filters = n_filters
        self.mols = mols
        self.adj_chans = adj_chans
        self.has_bias = bias

        #self.graph_conv = GraphCNNLayer(n_feats+n_filters, adj_chans, n_filters, bias)
        self.graph_conv = GraphCNNLayer(n_feats, adj_chans, n_filters, bias)

        #self.bn_global = nn.BatchNorm1d(n_filters*mols)
        self.bn_graph = nn.BatchNorm1d(n_filters)

    def forward(self, V, A):
        ######## graph convolution ########
        # V shape from [b, N, C] to [b, N, F1], F1 is n_filters
        V = self.graph_conv(V, A)
        V = self.bn_graph(V.transpose(1, 2).contiguous())
        V = F.relu(V.transpose(1, 2))

        return V

    def __repr__(self):
        return f'{self.__class__.__name__}(n_feats={self.n_feats},n_filters={self.n_filters},mols={self.mols},adj_chans={self.adj_chans},bias={self.has_bias}) -> [b, N, {self.n_filters}]'
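A quick, hypothetical shape check for GraphCNNLayer (random stand-in tensors, not from the repo; assumes code/ is on PYTHONPATH):

import torch
from GNN.layers import GraphCNNLayer

b, N, C, L = 2, 9, 32, 4             # batch, atoms, node feats, adjacency channels
layer = GraphCNNLayer(n_feats=C, adj_chans=L, n_filters=64)
V = torch.randn(b, N, C)             # node features
A = torch.zeros(b, N, L, N)          # one [N, N] adjacency matrix per bond channel
out = layer(V, A)                    # V_out = V*W_0 + GConv(V, A) + bias
print(out.shape)                     # torch.Size([2, 9, 64])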
code/GNN/subgraphfp.py ADDED
@@ -0,0 +1,138 @@
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import AllChem, MACCSkeys, rdMolDescriptors as rdDesc
from collections import defaultdict
import numpy as np
import os, pickle, hashlib

AllChem.SetPreferCoordGen(True)

FINGERPRINT_DICT = defaultdict(lambda: len(FINGERPRINT_DICT))

ELEMENTS = ['H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al',
            'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn',
            'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb',
            'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In',
            'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm',
            'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta',
            'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Po', 'At',
            'Rn', 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk',
            'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt',
            'Ds', 'Rg', 'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og']

for e in ELEMENTS:
    FINGERPRINT_DICT[e]

if os.path.exists('rdkit_fingerprint_list_r1.pkl'):
    l = pickle.load(open('rdkit_fingerprint_list_r1.pkl', 'rb'))

    for smi in l:
        FINGERPRINT_DICT[smi]

# the element symbols were inserted above, so they are already counted here
print('Len fingerprint_list: %s' % len(FINGERPRINT_DICT))

def mol_with_atom_index(mol):
    atoms = mol.GetNumAtoms()
    for idx in range(atoms):
        mol.GetAtomWithIdx(idx).SetProp('molAtomMapNumber', str(mol.GetAtomWithIdx(idx).GetIdx()))
    return mol

def prepare_mol_for_drawing(mol):
    try:
        mol_draw = Draw.rdMolDraw2D.PrepareMolForDrawing(mol)
    except Chem.KekulizeException:
        mol_draw = Draw.rdMolDraw2D.PrepareMolForDrawing(mol, kekulize=False)
        Chem.SanitizeMol(mol_draw, Chem.SANITIZE_ALL ^ Chem.SANITIZE_KEKULIZE)
    return mol_draw

def get_atom_submol_radn(mol, radius, sanitize=True):
    atoms = []
    submols = []
    #smis = []
    for atom in mol.GetAtoms():
        atoms.append(atom)
        r = radius
        # shrink the radius until a valid environment is found; note that if
        # every radius fails for an atom, submols ends up shorter than atoms
        while r > 0:
            try:
                env = Chem.FindAtomEnvironmentOfRadiusN(mol, r, atom.GetIdx())
                amap = {}
                submol = Chem.PathToSubmol(mol, env, atomMap=amap)
                if sanitize:
                    Chem.SanitizeMol(submol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL ^ Chem.SanitizeFlags.SANITIZE_KEKULIZE)
                #smis.append(Chem.MolToSmiles(submol))
                submols.append(submol)
                break
            except Exception as e:
                print(64, e)
                r -= 1

    return atoms, submols  #, smis

def gen_fps_from_mol(mol, nbits=256, use_morgan=True, use_macc=False, use_rdkit=False):
    fp = []
    if use_morgan:
        # Morgan fingerprint
        fp_vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=nbits)
        fp1 = np.frombuffer(fp_vec.ToBitString().encode(), 'u1') - ord('0')
        fp = fp1.tolist()
    if use_macc:
        # MACCS keys
        fp_vec = MACCSkeys.GenMACCSKeys(mol)
        fp1 = np.frombuffer(fp_vec.ToBitString().encode(), 'u1') - ord('0')
        fp.extend(fp1.tolist())
    if use_rdkit:
        fp_vec = Chem.RDKFingerprint(mol)
        fp1 = np.frombuffer(fp_vec.ToBitString().encode(), 'u1') - ord('0')
        fp.extend(fp1.tolist())

    return fp

def gen_subgraph_fps_from_str(s, wordsdict={}):
    if s in wordsdict:
        return [wordsdict[s]]
    else:
        return [len(wordsdict)]

def gen_subgraph_fps_from_mol(mol, wordsdict={}):
    try:
        k = Chem.MolToSmiles(mol)
        return gen_subgraph_fps_from_str(k, wordsdict)
    except Exception as e:
        print(e)
        return [len(wordsdict)]

def calc_subgraph_fps_from_mol(mol, radius=2, nbits=128, use_macc=True, fptype=1, wordsdict={}):
    #atoms, submols, smis = get_atom_submol_radn(mol, radius, True)
    atoms, submols = get_atom_submol_radn(mol, radius, True)
    feats = []
    for idx, submol in enumerate(submols):
        if fptype == 1:
            # pass use_macc by keyword: positionally it would land on use_morgan
            feat = gen_fps_from_mol(submol, nbits, use_macc=use_macc)
            feats.append(feat)
        elif fptype == 2:
            feat = gen_subgraph_fps_from_mol(submol, wordsdict)
            feats.append(feat)

    return np.array(feats)

if __name__ == '__main__':
    smi = 'C=C(S)C(N)(O)C'
    smi = 'CC1CCN(CC1N(C)C2=NC=NC3=C2C=CN3)C(=O)CC#N'

    mol = Chem.MolFromSmiles(smi, sanitize=False)

    print(calc_subgraph_fps_from_mol(mol, 3))

    mol = mol_with_atom_index(mol)
    # get_atom_submol_radn returns (atoms, submols); only the submols are drawn
    atoms, submols = get_atom_submol_radn(mol, 3)
    submols = [prepare_mol_for_drawing(m) for m in submols]
    hl = []
    for idx, m in enumerate(submols):
        for a in m.GetAtoms():
            if int(a.GetProp('molAtomMapNumber')) == idx:
                hl.append([a.GetIdx()])
                break

    draw = Draw.MolsToGridImage([mol] + submols, highlightAtomLists=[[]] + hl, molsPerRow=5)
    draw.show()
code/GNN/utils.py ADDED
@@ -0,0 +1,240 @@
import torch

def gather(x, indices):
    indices = indices.view(-1, indices.shape[-1]).tolist()
    out = torch.cat([x[i] for i in indices])

    return out

def gather_nd(x, indices):
    newshape = indices.shape[:-1] + x.shape[indices.shape[-1]:]
    indices = indices.view(-1, indices.shape[-1]).tolist()
    out = torch.cat([x[tuple(i)] for i in indices])

    return out.reshape(newshape)

def gen_node_indices(size_list):
    '''Generate node indices for extracting the nodes of each graph from batched data.'''
    node_num = []
    node_range = []
    size_list = [int(i) for i in size_list]
    for i, n in enumerate(size_list):
        node_num.extend([i]*n)
        node_range.extend(list(range(n)))

    node_num = torch.tensor(node_num)
    node_range = torch.tensor(node_range)
    indices = torch.stack([node_num, node_range], dim=1)
    return indices, node_num, node_range

def segment_max(x, size_list):
    size_list = [int(i) for i in size_list]
    return torch.stack([torch.max(v, 0).values for v in torch.split(x, size_list)])

def segment_sum(x, size_list):
    size_list = [int(i) for i in size_list]
    return torch.stack([torch.sum(v, 0) for v in torch.split(x, size_list)])

def segment_softmax(gate, size_list):
    segmax = segment_max(gate, size_list)
    # expand segmax to the shape of gate
    segmax_expand = torch.cat([segmax[i].repeat(n, 1) for i, n in enumerate(size_list)], dim=0)
    subtract = gate - segmax_expand
    exp = torch.exp(subtract)
    segsum = segment_sum(exp, size_list)
    # expand segsum to the shape of gate
    segsum_expand = torch.cat([segsum[i].repeat(n, 1) for i, n in enumerate(size_list)], dim=0)
    attention = exp / (segsum_expand + 1e-16)

    return attention

def pad_V(V, max_n):
    N, C = V.shape
    if max_n > N:
        zeros = torch.zeros(max_n-N, C)
        V = torch.cat([V, zeros], dim=0)
    return V

def pad_A(A, max_n):
    N, L, _ = A.shape
    if max_n > N:
        zeros = torch.zeros(N, L, max_n-N)
        A = torch.cat([A, zeros], dim=-1)
        zeros = torch.zeros(max_n-N, L, max_n)
        A = torch.cat([A, zeros], dim=0)

    return A

def pad_prot(P, max_n):
    N, = P.shape
    if max_n > N:
        zeros = torch.zeros(max_n-N)
        P = torch.cat([P, zeros], dim=0)

    return P.type(torch.IntTensor)

def create_batch(input, pad=False, device=torch.device('cpu')):
    vl = []
    al = []
    gsl = []
    msl = []
    ssl = []
    lbl = []
    idxs = []
    smis = []

    for d in input:
        vl.append(d['V'])
        al.append(d['A'])
        gsl.append(d['G'])
        msl.append(d['mol_size'])
        ssl.append(d['subgraph_size'])
        lbl.append(d['label'])
        idxs.append(d['index'])
        smis.append(d['smiles'])

    if gsl[0] is not None:
        gsl = torch.stack(gsl, dim=0).to(device)

    if pad:
        max_n = max(map(lambda x: x.shape[0], vl))
        vl1 = []
        for v in vl:
            vl1.append(pad_V(v, max_n))
        al1 = []
        for a in al:
            al1.append(pad_A(a, max_n))

        return {'V': torch.stack(vl1, dim=0).to(device),
                'A': torch.stack(al1, dim=0).to(device),
                'G': gsl,
                'mol_size': torch.cat(msl, dim=0).to(device),
                'subgraph_size': torch.stack(ssl, dim=0).to(device),
                'label': torch.stack(lbl, dim=0).to(device),
                'index': idxs,
                'smiles': smis}

    return {'V': torch.stack(vl, dim=0).to(device),
            'A': torch.stack(al, dim=0).to(device),
            'G': gsl,
            'mol_size': torch.cat(msl, dim=0).to(device),
            'subgraph_size': torch.stack(ssl, dim=0).to(device),
            'label': torch.stack(lbl, dim=0).to(device),
            'index': idxs,
            'smiles': smis}

def create_mol_protein_batch(input, pad=False, device=torch.device('cpu'), pr=True):
    vl = []
    al = []
    gsl = []
    msl = []
    ssl = []
    prot = []
    seq = []
    lbl = []
    idxs = []
    smis = []
    fpl = []

    for d in input:
        vl.append(d['V'])
        al.append(d['A'])
        gsl.append(d['G'])
        msl.append(d['mol_size'])
        ssl.append(d['subgraph_size'])
        prot.append(d['protein_seq'])
        seq.append(d['protein'])
        lbl.append(d['label'])
        idxs.append(d['index'])
        smis.append(d['smiles'])
        if 'fp' in d:
            fpl.append(d['fp'])

    if gsl[0] is not None:
        if pad:
            gsl = torch.stack(gsl, dim=0).to(device)
        else:
            gsl = [torch.unsqueeze(g, 0) for g in gsl]

    # build fpt before branching: the no-pad return below references it too
    fpt = None
    if fpl:
        fpt = torch.stack(fpl, dim=0).to(device)

    if pad:
        max_n = max(map(lambda x: x.shape[0], vl))
        vl1 = []
        if pr:
            print('\tPadding V to max_n:', max_n)
        for v in vl:
            vl1.append(pad_V(v, max_n))

        al1 = []
        if pr:
            print('\tPadding A to max_n:', max_n)
        for a in al:
            al1.append(pad_A(a, max_n))

        max_prot = max(map(lambda x: x.shape[0], prot))
        prot1 = []
        if pr:
            print('\tPadding protein_seq to max_n:', max_prot)
        for p in prot:
            prot1.append(pad_prot(p, max_prot))

        return {'V': torch.stack(vl1, dim=0).to(device),
                'A': torch.stack(al1, dim=0).to(device),
                'G': gsl,
                'fp': fpt,
                'mol_size': torch.cat(msl, dim=0).to(device),
                'subgraph_size': torch.stack(ssl, dim=0).to(device),
                'protein_seq': torch.stack(prot1, dim=0).to(device),
                'label': torch.stack(lbl, dim=0).view(-1).to(device),
                'index': idxs,
                'smiles': smis,
                'protein': seq}

    return {'V': [torch.unsqueeze(v, 0) for v in vl],
            'A': [torch.unsqueeze(a, 0) for a in al],
            'G': gsl,
            'fp': fpt,
            'mol_size': torch.cat(msl, dim=0).to(device),
            'subgraph_size': [torch.unsqueeze(s, 0) for s in ssl],
            'protein_seq': [torch.unsqueeze(p, 0) for p in prot],
            'label': torch.stack(lbl, dim=0).view(-1).to(device),
            'index': idxs,
            'smiles': smis,
            'protein': seq}

def create_mol_protein_fp_batch(input, pad=False, device=torch.device('cpu'), pr=True):
    fp = []
    prot = []
    lbl = []
    idxs = []
    smis = []

    for d in input:
        fp.append(d['fp'])
        prot.append(d['protein_seq'])
        lbl.append(d['label'])
        idxs.append(d['index'])
        smis.append(d['smiles'])

    if pad:
        max_prot = max(map(lambda x: x.shape[0], prot))
        prot1 = []
        if pr:
            print('\tPadding protein_seq to max_n:', max_prot)
        for p in prot:
            prot1.append(pad_prot(p, max_prot))

        return {'fp': torch.stack(fp, dim=0).to(device),
                'protein_seq': torch.stack(prot1, dim=0).to(device),
                'label': torch.stack(lbl, dim=0).view(-1).to(device),
                'index': idxs,
                'smiles': smis}

    return {'fp': [torch.unsqueeze(f, 0) for f in fp],
            'protein_seq': [torch.unsqueeze(p, 0) for p in prot],
            'label': torch.stack(lbl, dim=0).view(-1).to(device),
            'index': idxs,
            'smiles': smis}
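A toy demonstration of the segment helpers above, on a "batch" of two graphs with 3 and 2 nodes (hypothetical, assuming code/GNN is on PYTHONPATH):

import torch
from utils import segment_sum, segment_softmax

x = torch.randn(5, 4)               # node rows of both graphs, stacked
sizes = [3, 2]                      # nodes per graph
att = segment_softmax(x, sizes)     # softmax taken within each graph separately
print(att[:3].sum(0))               # ~ones(4): the first graph's rows sum to 1
print(segment_sum(x, sizes).shape)  # torch.Size([2, 4]) -- one pooled row per graph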
code/cliplayers.py ADDED
@@ -0,0 +1,432 @@
from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu3(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x[:1], key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return x.squeeze(0)


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            x = self.relu1(self.bn1(self.conv1(x)))
            x = self.relu2(self.bn2(self.conv2(x)))
            x = self.relu3(self.bn3(self.conv3(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x

class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)

class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)

class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x


class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text


def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)


def build_model(state_dict: dict):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()
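An illustrative forward pass through a tiny, randomly initialized CLIP (made-up hyperparameters chosen only to keep it small; real checkpoints go through build_model(state_dict)):

import torch
from cliplayers import CLIP

model = CLIP(embed_dim=64, image_resolution=32, vision_layers=2,
             vision_width=64, vision_patch_size=8, context_length=16,
             vocab_size=1000, transformer_width=64, transformer_heads=4,
             transformer_layers=2)
images = torch.randn(2, 3, 32, 32)
texts = torch.randint(0, 1000, (2, 16))
logits_per_image, logits_per_text = model(images, texts)
print(logits_per_image.shape)  # torch.Size([2, 2]) -- pairwise cosine-similarity logits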
code/config.py ADDED
@@ -0,0 +1,95 @@
import torch, json, math, os

d = {
    'debug': True,
    'dataset_path': 'data/path_to_your_dataset.json',
    'fptype': 'morgan',
    'valid_ratio': 0.1,
    'batch_size': 128,
    'lr': 1e-3,
    'weight_decay': 1e-3,
    'patience': 2,
    'factor': 0.5,
    'add_nl': True,
    'binary_intn': False,
    'max_mz': 2000,
    'min_mz': 20,
    'energy': 'Energy1',
    'epochs': 50,
    'bin_size': 0.05,
    'ms_embedding_dim': 300,
    'projection_dim': 256,
    'ms_projection_layers': 1,
    'mol_embedding_dim': 2048,
    'mol_projection_layers': 1,
    'tsfm_in_ms': True,
    'tsfm_in_mol': False,
    'tsfm_layers': 6,
    'tsfm_heads': 8,
    'lstm_layers': 2,
    'lstm_in_ms': False,
    'lstm_in_mol': False,
    'dropout': 0.1,
    'nmodels': 1,
    'mol_encoder': 'fp',  # fp, gnn or gnn+fp
    'molgnn_n_filters_list': [256, 256, 256],
    'molgnn_nhead': 4,
    'molgnn_readout_layers': 2,
    'seed': 1234,
    'dev_name': 'cuda',
    'keep_best_models_num': 3
}

class ConfigDict(dict):
    '''
    Makes a dictionary behave like an object, with attribute-style access.
    '''
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        self[name] = value

    def save(self, fn, onlyprint=False):
        if onlyprint:
            print(self)
        else:
            json.dump(self, open(fn, 'w'), indent=2)

    def load_dict(self, dic):
        for k, v in dic.items():
            self[k] = v
        self.calc_ms_embedding_dim()

    def load(self, fn):
        try:
            if type(fn) is dict:
                d = fn
            elif type(fn) is str:
                if os.path.exists(fn):
                    d = json.load(open(fn, 'r'))
                else:
                    d = json.loads(fn)
            self.load_dict(d)
        except Exception as e:
            print(e)

    def calc_ms_embedding_dim(self):
        if 'bin_size' in self:
            self['ms_embedding_dim'] = math.ceil((self['max_mz'] - self['min_mz']) / self['bin_size'])
        if 'ms_embedding_dim' in self and 'add_nl' in self and self['add_nl']:
            # extra bins appended when neutral losses are added (fixed 200 m/z window)
            self['ms_embedding_dim'] += math.ceil(200 / self['bin_size'])

    @property
    def device(self):
        try:
            return torch.device(self['dev_name'])
        except Exception:
            return torch.device('cpu')


CFG = ConfigDict()
CFG.load_dict(d)
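A small usage sketch of the attribute-style access described in the docstring (the output file name is hypothetical):

from config import CFG

print(CFG.batch_size)      # 128 -- equivalent to CFG['batch_size']
CFG.lr = 5e-4              # attribute writes update the underlying dict
CFG.save('run_cfg.json')   # serialized as plain JSON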
code/dataset.py ADDED
@@ -0,0 +1,142 @@
+ import os, json
+ import torch
+ import utils
+
+ def calc_feats(smi, ms, nls, cfg):
+     item = {}
+     item['ms_bins'] = utils.ms_binner(ms, nls,
+                                       min_mz=cfg.min_mz,
+                                       max_mz=cfg.max_mz,
+                                       bin_size=cfg.bin_size,
+                                       add_nl=cfg.add_nl,
+                                       binary_intn=cfg.binary_intn)
+
+     fmcalced = False
+     if 'fp' in cfg.mol_encoder:
+         if 'fm' not in cfg.mol_encoder:
+             item['mol_fps'] = utils.mol_fp_encoder(smi,
+                                                    tp=cfg.fptype,
+                                                    nbits=cfg.mol_embedding_dim)
+         else:
+             item['mol_fps'], item['mol_fmvec'] = utils.mol_fp_fm_encoder(smi,
+                                                                          tp=cfg.fptype,
+                                                                          nbits=cfg.mol_embedding_dim)
+             fmcalced = True
+     if 'gnn' in cfg.mol_encoder:
+         f = utils.mol_graph_featurizer(smi)
+         if not f:
+             return None
+         item.update(f)
+     if 'fm' in cfg.mol_encoder and not fmcalced:
+         item['mol_fmvec'] = utils.smi2fmvec(smi)
+
+     return item
+
+ class Dataset(torch.utils.data.Dataset):
+     def __init__(self, inp, cfg):
+         if type(inp) is str:
+             self.data = json.load(open(inp))
+         else:
+             self.data = inp
+
+         self.cfg = cfg
+
+     def __getitem__(self, idx):
+         item = {}
+         try:
+             if 'nls' in self.data[idx]:
+                 nls = self.data[idx]['nls']
+             else:
+                 nls = []
+
+             ms = self.data[idx]['ms']
+             smi = self.data[idx]['smiles']
+
+             item = calc_feats(smi, ms, nls, self.cfg)
+
+         except Exception as e:
+             print('='*50, idx, str(e))
+             return None
+
+         return item
+
+     def __len__(self):
+         return len(self.data)
+
+ class DatasetGNNFP(torch.utils.data.Dataset):
+     def __init__(self, inp, cfg):
+         if type(inp) is str:
+             self.data = json.load(open(inp))
+         else:
+             self.data = inp
+
+         self.cfg = cfg
+
+     def __getitem__(self, idx):
+         try:
+             smi = self.data[idx]['smiles']
+             item = {}
+             item['mol_fps'] = utils.mol_fp_encoder(smi,
+                                                    tp=self.cfg.fptype,
+                                                    nbits=self.cfg.mol_embedding_dim)
+             item.update(utils.mol_graph_featurizer(smi))
+         except Exception as e:
+             print('='*50, idx, str(e))
+             return None
+
+         return item
+
+     def __len__(self):
+         return len(self.data)
+
+ class PathDataset(torch.utils.data.Dataset):
+     def __init__(self, pathlist, cfg):
+         self.fns = pathlist
+         self.cfg = cfg
+         self.data = {}
+
+     def __getitem__(self, idx):
+         try:
+             item = {}
+             nls = []
+             if idx not in self.data:
+                 out = self.proc_data(self.fns[idx], self.cfg.energy)
+                 if out is None:
+                     return None
+                 self.data[idx] = out
+
+             ms = self.data[idx]['ms']
+             smi = self.data[idx]['smiles']
+
+             item = calc_feats(smi, ms, nls, self.cfg)
+
+         except Exception as e:
+             print('='*50, idx, str(e))
+             return None
+
+         return item
+
+     def proc_data(self, fn, energy='Energy1'):
+         tl = open(fn).readlines()
+         l = []
+         smi = None
+         try:
+             flag = False
+             for i in tl:
+                 if energy in i:
+                     smi = i.split(';')[-2]
+                     flag = True
+                     continue
+                 if 'END IONS' in i:
+                     if flag:
+                         break
+                 if flag:
+                     mz, intn = i.split(' ')
+                     l.append((float(mz), float(intn)))
+         except Exception:
+             return None
+
+         if smi is None:
+             # the requested energy level was not found in this file
+             return None
+         out = {'ms': l, 'smiles': smi}
+         return out
+
+     def __len__(self):
+         return len(self.fns)
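A minimal sketch (spectrum and SMILES made up for illustration) of pushing an in-memory record through Dataset; with the default mol_encoder='fp', the item carries the binned spectrum and a Morgan fingerprint:

from config import CFG
from dataset import Dataset

records = [{'smiles': 'CCO', 'ms': [(31.02, 100.0), (45.03, 40.0)]}]
ds = Dataset(records, CFG)
item = ds[0]
print(item['ms_bins'].shape)   # (ms_embedding_dim,) binned spectrum (+ neutral-loss bins)
print(item['mol_fps'].shape)   # (mol_embedding_dim,) fingerprint vector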
code/modules.py ADDED
@@ -0,0 +1,158 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from config import CFG
+ import utils
+ import math
+ import numpy as np
+ from cliplayers import QuickGELU, Transformer as MSTsfmEncoder
+ from GNN import layers as gly
+
+ loss_func_ms = nn.CrossEntropyLoss()
+ loss_func = nn.CrossEntropyLoss()
+
+ class MolGNNEncoder(nn.Module):
+     def __init__(self,
+                  outdim,
+                  n_feats=74,  # 330 = 74 + 256 when a 256-bit Morgan sub-fingerprint is appended
+                  n_filters_list=[256, 256, 256],
+                  n_head=4,
+                  mols=1,
+                  adj_chans=6,
+                  readout_layers=2,
+                  bias=True):
+
+         super().__init__()
+
+         n_filters_list = [i for i in n_filters_list if i is not None]
+         lys = []
+
+         for i, nf in enumerate(n_filters_list):
+             if i == 0:
+                 nf1 = n_feats
+             else:
+                 nf1 = prevnf
+
+             prevnf = nf
+
+             ly = gly.GConvBlockNoGF(nf1, nf, mols, adj_chans, bias)
+             lys.append(ly)
+
+         self.block_layers = nn.ModuleList(lys)
+         self.attention_layer = gly.MultiHeadGlobalAttention(nf, n_head=n_head, concat=True, bias=bias)
+         self.readout_layers = nn.ModuleList([nn.Linear(nf*n_head, outdim, bias=bias)] + [nn.Linear(outdim, outdim) for _ in range(readout_layers-1)])
+         self.gelu = QuickGELU()
+
+     def forward(self, batch):
+         V = batch['V']
+         A = batch['A']
+         mol_size = batch['mol_size']
+
+         for ly in self.block_layers:
+             V = ly(V, A)
+
+         X = self.attention_layer(V, mol_size)
+
+         for ly in self.readout_layers:
+             X = self.gelu(ly(X))
+
+         return X
+
+ class ProjectionHead(nn.Module):
+     def __init__(self,
+                  embedding_dim,
+                  projection_dim,
+                  cfg,
+                  transformer=True,
+                  lstm=False):
+
+         super().__init__()
+
+         self.projection = nn.Linear(embedding_dim, projection_dim)
+         self.gelu = nn.GELU()  # QuickGELU()
+         self.transformer = None
+         if transformer:
+             self.transformer = MSTsfmEncoder(projection_dim, cfg.tsfm_layers, cfg.tsfm_heads)
+         self.lstm = None
+         if lstm:
+             self.lstm = nn.LSTM(input_size=projection_dim, hidden_size=projection_dim, num_layers=cfg.lstm_layers, batch_first=True)
+         self.dropout = nn.Dropout(cfg.dropout)
+
+     def forward(self, x):
+         projected = self.projection(x)
+         if self.transformer is None:
+             x = self.gelu(projected)
+         else:
+             x = self.transformer(projected)
+         if self.lstm is not None:
+             x, (_, _) = self.lstm(x)
+         x = self.dropout(x)
+
+         return x
+
+ # New name in the paper is CMSSPModel
+ class FragSimiModel(nn.Module):
+     def __init__(
+         self,
+         cfg
+     ):
+         super().__init__()
+
+         self.cfg = cfg
+         self.mol_gnn_encoder = None
+         mol_embedding_dim = cfg.mol_embedding_dim
+
+         if 'gnn' in self.cfg.mol_encoder:
+             self.mol_gnn_encoder = MolGNNEncoder(outdim=cfg.mol_embedding_dim,
+                                                  n_filters_list=cfg.molgnn_n_filters_list,
+                                                  n_head=cfg.molgnn_nhead,
+                                                  readout_layers=cfg.molgnn_readout_layers)
+             if 'fp' in self.cfg.mol_encoder:
+                 # GNN output and fingerprint are concatenated, doubling the input dim
+                 mol_embedding_dim = 2*cfg.mol_embedding_dim
+
+         if 'fm' in self.cfg.mol_encoder:
+             mol_embedding_dim += 10
+
+         self.ms_projection = ProjectionHead(cfg.ms_embedding_dim,
+                                             cfg.projection_dim,
+                                             cfg,
+                                             cfg.tsfm_in_ms,
+                                             cfg.lstm_in_ms)
+
+         self.mol_projection = ProjectionHead(mol_embedding_dim,
+                                              cfg.projection_dim,
+                                              cfg,
+                                              cfg.tsfm_in_mol,
+                                              cfg.lstm_in_mol)
+
+     def forward(self, batch):
+         ms_features = batch["ms_bins"]
+         mol_feat_list = []
+         if 'gnn' in self.cfg.mol_encoder:
+             mol_feat_list.append(self.mol_gnn_encoder(batch))
+         if 'fp' in self.cfg.mol_encoder:
+             mol_feat_list.append(batch["mol_fps"])
+         if 'fm' in self.cfg.mol_encoder:
+             mol_feat_list.append(batch["mol_fmvec"])
+
+         if len(mol_feat_list) > 1:
+             mol_features = torch.cat(mol_feat_list, dim=1)
+         else:
+             mol_features = mol_feat_list[0]
+
+         # Getting ms and mol embeddings (with the same dimension)
+         ms_embeddings = self.ms_projection(ms_features)
+         mol_embeddings = self.mol_projection(mol_features)
+
+         # Calculating the symmetric contrastive loss
+         #logits = (mol_embeddings @ ms_embeddings.t())
+         #logit_scale = self.logit_scale.exp()
+         logits = mol_embeddings @ ms_embeddings.t()
+
+         ground_truth = torch.arange(ms_features.shape[0], dtype=torch.long, device=self.cfg.device)
+
+         ms_loss = loss_func(logits, ground_truth)
+         mol_loss = loss_func(logits.t(), ground_truth)
+         loss = (ms_loss + mol_loss) / 2.0  # scalar: CrossEntropyLoss averages over the batch by default
+
+         return loss.mean()
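The loss above is the CLIP-style symmetric cross-entropy over the batch similarity matrix: row i should pick column i and vice versa. A self-contained toy version with random embeddings:

import torch
from torch import nn

ce = nn.CrossEntropyLoss()
mol = torch.randn(4, 256)      # 4 molecule embeddings
ms = torch.randn(4, 256)       # the 4 paired spectrum embeddings
logits = mol @ ms.t()          # (4, 4) similarity matrix
target = torch.arange(4)       # the i-th molecule matches the i-th spectrum
loss = (ce(logits, target) + ce(logits.t(), target)) / 2.0
print(loss.item())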
code/predict.py ADDED
@@ -0,0 +1,347 @@
+ from modules import *
+ import os, sys
+ import numpy as np
+ from tqdm import tqdm
+ import torch
+ from torch import nn
+ from config import CFG
+ import utils
+ import json
+ import pandas as pd
+ import pickle
+
+ MolFeatsCached = {}
+
+ def calc_mol_embeddings0(model, smis, cfg):
+     model.eval()
+
+     valid_mol_embeddings = []
+     with torch.no_grad():
+         for smi in smis:
+             try:
+                 mol_features = utils.mol_fp_encoder(smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim).to(cfg.device)
+                 mol_embeddings = model.mol_projection(mol_features.unsqueeze(0))
+                 valid_mol_embeddings.append(mol_embeddings.squeeze(0))
+             except Exception as e:
+                 print(smi, e)
+                 continue
+
+     return torch.stack(valid_mol_embeddings)
+
+ def calc_mol_embeddings1(model, smis, cfg):
+     model.eval()
+     mol_embeddings = []
+
+     with torch.no_grad():
+         for smi in smis:
+             try:
+                 if cfg.mol_encoder == 'fp':
+                     k = hash(smi + f'fp-{cfg.fptype}-{cfg.mol_embedding_dim}')
+                     if k in MolFeatsCached:
+                         feats = MolFeatsCached[k]
+                     else:
+                         feats = utils.mol_fp_encoder(smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim).to(cfg.device)
+                         MolFeatsCached[k] = feats
+                     me = model.mol_projection(feats.unsqueeze(0))
+                     mol_embeddings.append(me.squeeze(0))
+                 elif cfg.mol_encoder == 'gnn':
+                     k = hash(smi + 'gnn')
+                     if k in MolFeatsCached:
+                         gfeats = MolFeatsCached[k]
+                     else:
+                         gfeats = utils.mol_graph_featurizer(smi)
+                         MolFeatsCached[k] = gfeats
+
+                     bat = {'A': gfeats['A'].unsqueeze(0).to(cfg.device),
+                            'V': gfeats['V'].unsqueeze(0).to(cfg.device),
+                            'mol_size': gfeats['mol_size'].unsqueeze(0).to(cfg.device)}
+
+                     feats = model.mol_gnn_encoder(bat)
+                     me = model.mol_projection(feats)
+                     mol_embeddings.append(me.squeeze(0))
+             except Exception as e:
+                 print(smi, e)
+                 continue
+
+     return torch.stack(mol_embeddings)
+
+ def calc_mol_embeddings(model, smis, cfg):
+     model.eval()
+     fp_featsl = []
+     gnn_featsl = []
+     fm_featsl = []
+
+     for smi in smis:
+         try:
+             if 'gnn' in cfg.mol_encoder:
+                 k = hash(smi + 'gnn')
+                 if k in MolFeatsCached:
+                     gnn_feats = MolFeatsCached[k]
+                     if gnn_feats is None:
+                         continue
+                 else:
+                     gnn_feats = utils.mol_graph_featurizer(smi)
+                     MolFeatsCached[k] = gnn_feats
+                     if gnn_feats is None:
+                         continue
+                 gnn_featsl.append(gnn_feats)
+             if 'fp' in cfg.mol_encoder:
+                 k = hash(smi + f'fp-{cfg.fptype}-{cfg.mol_embedding_dim}')
+                 if k in MolFeatsCached:
+                     fp_feats = MolFeatsCached[k]
+                     if fp_feats is None:
+                         continue
+                 else:
+                     fp_feats = utils.mol_fp_encoder(smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim).to(cfg.device)
+                     MolFeatsCached[k] = fp_feats
+                 fp_featsl.append(fp_feats)
+             if 'fm' in cfg.mol_encoder:
+                 k = hash(smi + f'fm-{cfg.fptype}-{cfg.mol_embedding_dim}')
+                 if k in MolFeatsCached:
+                     fm_feats = MolFeatsCached[k]
+                     if fm_feats is None:
+                         continue
+                 else:
+                     fm_feats = utils.smi2fmvec(smi).to(cfg.device)
+                     MolFeatsCached[k] = fm_feats
+                 fm_featsl.append(fm_feats)
+         except Exception as e:
+             print(smi, e)
+             MolFeatsCached[k] = None
+             continue
+
+     mol_feat_list = []
+     if 'gnn' in cfg.mol_encoder:
+         vl, al, msl = [], [], []
+         bat = {}
+         for b in gnn_featsl:
+             if 'V' in b:
+                 vl.append(b['V'])
+             if 'A' in b:
+                 al.append(b['A'])
+             if 'mol_size' in b:
+                 msl.append(b['mol_size'])
+
+         vl1, al1 = [], []
+         if vl and al and msl:
+             max_n = max(map(lambda x: x.shape[0], vl))
+             for v in vl:
+                 vl1.append(utils.pad_V(v, max_n))
+             for a in al:
+                 al1.append(utils.pad_A(a, max_n))
+
+             bat['V'] = torch.stack(vl1).to(cfg.device)
+             bat['A'] = torch.stack(al1).to(cfg.device)
+             bat['mol_size'] = torch.cat(msl, dim=0).to(cfg.device)
+
+         mol_feat_list.append(model.mol_gnn_encoder(bat))
+
+     if 'fp' in cfg.mol_encoder:
+         mol_feat_list.append(torch.stack(fp_featsl).to(cfg.device))
+
+     if 'fm' in cfg.mol_encoder:
+         mol_feat_list.append(torch.stack(fm_featsl).to(cfg.device))
+
+     if len(mol_feat_list) > 1:
+         mol_features = torch.cat(mol_feat_list, dim=1).to(cfg.device)
+     else:
+         mol_features = mol_feat_list[0].to(cfg.device)
+
+     with torch.no_grad():
+         mol_embeddings = model.mol_projection(mol_features)
+
+     return mol_embeddings
+
+ def find_matches(model, ms, smis, cfg, n=10):
+     model.eval()
+     with torch.no_grad():
+         ms_features = utils.ms_binner(ms, min_mz=cfg.min_mz, max_mz=cfg.max_mz, bin_size=cfg.bin_size, add_nl=cfg.add_nl, binary_intn=cfg.binary_intn).to(cfg.device)
+         ms_features = ms_features.unsqueeze(0)
+         ms_embeddings = model.ms_projection(ms_features).squeeze(0)
+
+         #print(43, ms_features.shape, ms_embeddings.shape)
+
+         mol_embeddings = calc_mol_embeddings(model, smis, cfg)
+
+         mol_embeddings_n = F.normalize(mol_embeddings, p=2, dim=-1)
+         ms_embeddings_n = F.normalize(ms_embeddings, p=2, dim=-1)
+         dot_similarity = mol_embeddings_n @ ms_embeddings_n.t()
+
+         if n == -1 or n > len(mol_embeddings):
+             n = len(mol_embeddings)
+
+         values, indices = torch.topk(dot_similarity.squeeze(0), n)
+
+         # note: indices index into mol_embeddings; this assumes every SMILES
+         # in smis was featurized successfully (failed ones are skipped above)
+         matchsmis = [smis[idx] for idx in indices]
+
+         return matchsmis, values.to('cpu').data.numpy()*100, indices.to('cpu').data.numpy()
+
+ def calc(models, datal, cfg, saveout=True):
+     dicall = {}
+     coridxd = {}
+
+     for idx, model in enumerate(models):
+         for ni, data in enumerate(datal):
+             print(f'Calculating {ni}-th MS...')
+             #smipool = [d[1] for d in data['candidates'][:50]]
+             smipool = [d[1] for d in data['candidates']]
+
+             try:
+                 smis, scores, indices = find_matches(model, data['ms'], smipool, cfg, 50)
+             except Exception as e:
+                 print(131, e)
+                 continue
+
+             dic = {}
+             for n, smi in enumerate(smis):
+                 if smi in dic:
+                     dic[smi]['score'] += scores[n]
+                     dic[smi]['iscor'] = data['candidates'][indices[n]][-1]
+                     dic[smi]['idx'] = data['candidates'][indices[n]][0]
+                 else:
+                     dic[smi] = {'score': scores[n], 'iscor': data['candidates'][indices[n]][-1], 'idx': data['candidates'][indices[n]][0]}
+
+             ikey = data['ikey']
+             if ikey in dicall:
+                 for k, v in dic.items():
+                     if k in dicall[ikey]:
+                         dicall[ikey][k]['score'] += v['score']
+                     else:
+                         dicall[ikey][k] = v
+             else:
+                 dicall[ikey] = dic
+
+     for ikey, dic in dicall.items():
+         smis = [k for k in dic.keys()]
+         scorel = [d['score'] for d in dic.values()]
+         iscorl = [d['iscor'] for d in dic.values()]
+         indexl = [d['idx'] for d in dic.values()]
+
+         scoretsor = torch.tensor(scorel)
+         n = 100
+         if n > len(scorel):
+             n = len(scorel)
+
+         values, indices = torch.topk(scoretsor, n)
+
+         scorel = values
+         smis = [smis[i] for i in indices]
+         iscorl = [iscorl[i] for i in indices]
+         indexl = [indexl[i] for i in indices]
+
+         try:
+             i = iscorl.index(True)
+             k = 'Hit %.3d' % (i+1)
+             if k in coridxd:
+                 coridxd[k] += 1
+             else:
+                 coridxd[k] = 1
+         except ValueError:
+             pass
+
+     ks = sorted(list(coridxd.keys()))
+     dc = {}
+     sumtop3 = 0
+
+     for k in ks:
+         dc[k] = [coridxd[k]]
+         if k in ['Hit 001', 'Hit 002', 'Hit 003']:
+             sumtop3 += coridxd[k]
+
+     for i in range(100):
+         k = 'Hit %.3d' % (i+1)
+         if k not in dc:
+             dc[k] = [0]
+
+     # disabled CSV export (references variables not collected in this version)
+     '''if saveout:
+         df0 = pd.DataFrame(dc)
+         df0.to_csv('summary.csv', index=False)
+
+         df = pd.DataFrame({
+             'MSFn': ikeysl,
+             'Item': iteml,
+             'Index': smisidl,
+             'Smiles': smis,
+             'Score': scoresl,
+             'IsCorrect': iscorl})
+
+         df.to_csv('predicted.csv', index=False)'''
+
+     return sumtop3, dc, dicall
+
+ def test(modelfnl, datal, datafn=''):
+     maxtop3 = 0
+     maxoutt = ''
+
+     for fn in modelfnl:
+         d = torch.load(fn)
+         CFG.load(d['config'])
+         print(d['config'])
+         CFG.save('', True)
+
+         model = FragSimiModel(CFG).to(CFG.device)
+         model.load_state_dict(d['state_dict'])
+         model.to(CFG.device)
+
+         sumtop3, dc, dicall = calc([model], datal, CFG, saveout=False)
+
+         sumtop10 = 0
+         for k in ['Hit %.3d' % (i+1) for i in range(10)]:
+             if k in dc:
+                 sumtop10 += dc[k][0]
+
+         sumtop50 = 0
+         for k in ['Hit %.3d' % (i+1) for i in range(50)]:
+             if k in dc:
+                 sumtop50 += dc[k][0]
+
+         tops = {}
+         for i in range(100):
+             k = 'Hit %.3d' % (i+1)
+             key = k.replace('Hit', 'Top')
+             if key not in tops:
+                 tops[key] = [0]
+             if k in dc:
+                 for n in range(i+1):
+                     kk = 'Hit %.3d' % (n+1)
+                     if kk in dc:
+                         tops[key][0] += dc[kk][0]
+
+         outt = f'Top1: {dc.setdefault("Hit 001", [0])[0]}, top3: {sumtop3}, top10: {sumtop10}, top50: {sumtop50} of {len(datal)}'
+
+         if sumtop3 > maxtop3:
+             maxtop3 = sumtop3
+             maxoutt = outt
+
+         dicall['testdata'] = datafn
+         dicall['testrlt'] = outt
+         pickle.dump(dicall, open(fn.replace('.pth', f'-{os.path.basename(datafn).split(".")[0]}-tstrlt.pkl'), 'wb'))
+
+         df = pd.DataFrame(tops)
+         df.to_csv(fn.replace('.pth', f'-{os.path.basename(datafn).split(".")[0]}-tstrlt.csv'), index=False)
+
+     return maxoutt, maxtop3
+
+ def main(datafn, fnl):
+     outl = []
+
+     datal = json.load(open(datafn))
+     logfn = 'predict_results.csv'
+
+     if not os.path.exists(logfn):
+         open(logfn, 'w').write('Index,Results,Model,Data\n')
+
+     for n, fn in enumerate(fnl):
+         out, _ = test([fn], datal, datafn)
+         print(out, os.path.basename(fn))
+         outl.append(out)
+         open(logfn, 'a').write(f'{n},"{out}",{fn},{datafn}\n')
+
+     print(outl)
+
+ if __name__ == '__main__':
+     import time
+     t0 = time.time()
+     main(sys.argv[1], sys.argv[2:])
+     print(300, time.time()-t0)
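predict.py is driven from the command line: the first argument is a JSON test set (records carrying 'ms', 'ikey' and a 'candidates' list), the remaining arguments are .pth checkpoints. A hypothetical programmatic equivalent (paths are placeholders):

# same as: python predict.py data/test-pos.json data/train-001/model-best.pth
import json
from predict import test

datal = json.load(open('data/test-pos.json'))
out, top3 = test(['data/train-001/model-best.pth'], datal, 'data/test-pos.json')
print(out)   # e.g. 'Top1: ..., top3: ..., top10: ..., top50: ... of N'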
code/separate_posneg.py ADDED
@@ -0,0 +1,29 @@
+ import json
+ from tqdm import tqdm
+
+ if __name__ == '__main__':
+     import sys
+     fn = sys.argv[1]
+     d = json.load(open(fn))
+
+     lpos = []
+     lneg = []
+
+     for n, it in enumerate(d):
+         print(f'processing {n+1}th...')
+
+         try:
+             if it['Ion_Mode'].strip().lower() == 'negative':
+                 lneg.append(it)
+             else:
+                 lpos.append(it)
+         except KeyError:
+             # fall back to the adduct/species string when Ion_Mode is missing
+             if it['species'].strip().endswith('-'):
+                 lneg.append(it)
+             else:
+                 lpos.append(it)
+
+     print(f'Len lpos = {len(lpos)}, len lneg = {len(lneg)}, sum = {len(lpos)+len(lneg)}')
+
+     json.dump(lpos, open(fn.replace('.json', '-pos.json'), 'w'), indent=2)
+     json.dump(lneg, open(fn.replace('.json', '-neg.json'), 'w'), indent=2)
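A hypothetical run of the splitter (the path is a placeholder); it writes -pos.json and -neg.json siblings next to the input:

import json, subprocess

subprocess.run(['python', 'separate_posneg.py', 'data/dataset.json'])
neg = json.load(open('data/dataset-neg.json'))
print(len(neg))   # records whose Ion_Mode is 'negative' (or whose species string ends with '-')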
code/train.py ADDED
@@ -0,0 +1,251 @@
+ from utils import *
+ from modules import *
+ import os, sys
+ import numpy as np
+ from tqdm import tqdm
+ import random
+ import torch
+ from torch import nn
+ from config import CFG
+ from dataset import *
+ import torch.utils.data
+ import copy, json, pickle
+ import itertools as it
+
+ def make_next_record_dir(basedir, prefix=''):
+     path = '%s/%%s001/' % basedir
+     n = 2
+     while os.path.exists(path % prefix):
+         path = '%s/%%s%.3d/' % (basedir, n)
+         n += 1
+
+     pth = path % prefix
+     os.makedirs(pth)
+     return pth
+
+ def setup_seed(seed):
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+     torch.backends.cudnn.deterministic = True
+
+ def my_collate(batch):
+     batch = list(filter(lambda x: (x is not None), batch))
+     msbinl, molfpl, molfml, vl, al, msl = [], [], [], [], [], []
+     bat = {}
+
+     for b in batch:
+         if 'ms_bins' in b:
+             msbinl.append(b['ms_bins'])
+         if 'mol_fps' in b:
+             molfpl.append(b['mol_fps'])
+         if 'mol_fmvec' in b:
+             molfml.append(b['mol_fmvec'])
+         if 'V' in b:
+             vl.append(b['V'])
+         if 'A' in b:
+             al.append(b['A'])
+         if 'mol_size' in b:
+             msl.append(b['mol_size'])
+
+     if msbinl:
+         bat['ms_bins'] = torch.stack(msbinl)
+     if molfpl:
+         bat['mol_fps'] = torch.stack(molfpl)
+     if molfml:
+         bat['mol_fmvec'] = torch.stack(molfml)
+     if vl and al and msl:
+         # pad all molecule graphs in the batch to the largest atom count
+         max_n = max(map(lambda x: x.shape[0], vl))
+         vl1, al1 = [], []
+         for v in vl:
+             vl1.append(pad_V(v, max_n))
+         for a in al:
+             al1.append(pad_A(a, max_n))
+
+         bat['V'] = torch.stack(vl1)
+         bat['A'] = torch.stack(al1)
+         bat['mol_size'] = torch.cat(msl, dim=0)
+
+     #return torch.utils.data.dataloader.default_collate(batch)
+     return bat
+
+ def make_train_valid(data, valid_ratio, seed=1234):
+     idxs = np.arange(len(data))
+     np.random.seed(seed)
+     np.random.shuffle(idxs)
+
+     lenval = int(valid_ratio*len(data))
+
+     valid_set = [data[i] for i in idxs[:lenval]]
+     train_set = [data[i] for i in idxs[lenval:]]
+
+     return train_set, valid_set
+
+ def build_loaders(inp, mode, cfg, num_workers):
+     if type(inp[0]) is dict:
+         dataset = Dataset(inp, cfg)
+     else:
+         dataset = PathDataset(inp, cfg)
+     dataloader = torch.utils.data.DataLoader(
+         dataset,
+         batch_size=cfg.batch_size,
+         num_workers=num_workers,
+         shuffle=True if mode == "train" else False,
+         collate_fn=my_collate
+     )
+     return dataloader
+
+ def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
+     loss_meter = AvgMeter()
+     tqdm_object = tqdm(train_loader, total=len(train_loader))
+
+     for batch in tqdm_object:
+         for k, v in batch.items():
+             batch[k] = v.to(CFG.device)
+
+         loss = model(batch)
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+         if step == "batch":
+             lr_scheduler.step()
+
+         count = batch["ms_bins"].size(0)
+         loss_meter.update(loss.item(), count)
+
+         tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
+     return loss_meter
+
+ def valid_epoch(model, valid_loader):
+     loss_meter = AvgMeter()
+
+     tqdm_object = tqdm(valid_loader, total=len(valid_loader))
+     for batch in tqdm_object:
+         for k, v in batch.items():
+             batch[k] = v.to(CFG.device)
+
+         loss = model(batch)
+
+         count = batch["ms_bins"].size(0)
+         loss_meter.update(loss.item(), count)
+
+         tqdm_object.set_postfix(valid_loss=loss_meter.avg)
+
+     return loss_meter
+
+ def main(data, cfg=CFG, savedir='data/train', encmodel=None, ratio=1):
+     setup_seed(cfg.seed)
+
+     train_set, valid_set = make_train_valid(data, valid_ratio=cfg.valid_ratio, seed=cfg.seed)
+
+     n = len(train_set)
+     if ratio < 1:
+         train_set = random.sample(train_set, int(n*ratio))
+         print(f'Ratio {ratio}, lenall {n}, newtrainset {len(train_set)}')
+
+     train_loader = build_loaders(train_set, "train", cfg, 10)
+     valid_loader = build_loaders(valid_set, "valid", cfg, 10)
+
+     step = "epoch"
+
+     best_loss = float('inf')
+     best_model_fn = ''
+     best_model_fns = []
+
+     model = FragSimiModel(cfg).to(cfg.device)
+
+     if encmodel is not None:
+         model.mol_gnn_encoder.load_state_dict(encmodel.mol_gnn_encoder.state_dict())
+         # freeze the mol_gnn_encoder weights (currently disabled)
+         '''for name, param in model.named_parameters():
+             if 'mol_gnn_encoder' in name:
+                 print(152, 'freeze mol_gnn_encoder weights')
+                 param.requires_grad = False'''
+
+     print(model)
+
+     optimizer = torch.optim.AdamW(
+         model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
+     )
+
+     lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+         optimizer, mode="min", patience=cfg.patience, factor=cfg.factor
+     )
+
+     for epoch in range(cfg.epochs):
+         print(f"Epoch: {epoch + 1}/{cfg.epochs}")
+         model.train()
+         train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
+         model.eval()
+         with torch.no_grad():
+             valid_loss = valid_epoch(model, valid_loader)
+
+         if valid_loss.avg < best_loss:
+             best_loss = valid_loss.avg
+             best_model_fn = f"{savedir}/model-tloss{round(train_loss.avg, 3)}-vloss{round(valid_loss.avg, 3)}-epoch{epoch}.pth"
+             best_model_fn_base = best_model_fn.replace('.pth', '')
+             n = 1
+             while os.path.exists(best_model_fn):
+                 best_model_fn = best_model_fn_base + f'-{n}.pth'
+                 n += 1
+
+             checkpoint = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'config': dict(CFG)}
+             best_model_fns.append(best_model_fn)
+             torch.save(checkpoint, best_model_fn)
+             print("Saved Best Model!")
+
+     best_model_fnl = []
+     for fn in best_model_fns:
+         if os.path.exists(fn):
+             best_model_fnl.append(fn)
+
+     for fn in best_model_fnl[:-cfg.keep_best_models_num]:
+         os.remove(fn)
+
+     best_model_fnl = best_model_fnl[-cfg.keep_best_models_num:]
+
+     print(best_model_fnl, best_loss)
+     return best_model_fnl, best_loss
+
+ if __name__ == "__main__":
+     try:
+         conffn = sys.argv[1]
+         if conffn.endswith('.json'):
+             CFG.load(sys.argv[1])
+         elif conffn.endswith('.pth'):
+             dpath = CFG.dataset_path
+             d = torch.load(conffn)
+             CFG.load(d['config'])
+             CFG.dataset_path = dpath
+             print('Use config from', conffn)
+     except Exception:
+         pass
+
+     try:
+         savedir = sys.argv[2]
+     except IndexError:
+         savedir = 'data/'
+
+     os.system('mkdir -p %s' % savedir)
+
+     mg = None
+
+     print(CFG)
+
+     if os.path.isdir(CFG.dataset_path):
+         data = [os.path.join(CFG.dataset_path, i) for i in os.listdir(CFG.dataset_path) if i.endswith('mgf')]
+     elif os.path.isfile(CFG.dataset_path):
+         if CFG.dataset_path.endswith('.pkl'):
+             data = pickle.load(open(CFG.dataset_path, 'rb'))
+         else:
+             data = json.load(open(CFG.dataset_path))
+             pklfn = CFG.dataset_path.replace('.json', '.pkl')
+             if not os.path.exists(pklfn):
+                 pickle.dump(data, open(pklfn, 'wb'))
+
+     subdir = make_next_record_dir(savedir, 'train-')
+     os.system(f'cp -a *py {subdir}; cp -a GNN {subdir}')
+     CFG.save(f'{subdir}/config.json')
+
+     modelfnl, _ = main(data, CFG, subdir, mg)
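A small sketch of the entry points above (toy records, placeholder save dir); the commented line is how a full run with loaders, AdamW and ReduceLROnPlateau would start:

# same as: python train.py config.json data/train
from config import CFG
from train import make_train_valid, main

data = [{'smiles': 'CCO', 'ms': [(31.02, 100.0)]} for _ in range(10)]   # toy records
train_set, valid_set = make_train_valid(data, valid_ratio=0.1, seed=CFG.seed)
print(len(train_set), len(valid_set))   # 9 1
# main(data, CFG, savedir='data/train-demo')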
code/utils.py ADDED
@@ -0,0 +1,370 @@
+ from rdkit import Chem
+ from rdkit.Chem import AllChem, MACCSkeys
+ from rdkit.Chem.rdmolops import FastFindRings
+ from rdkit.Chem.rdMolDescriptors import CalcMolFormula
+ import torch
+ import numpy as np
+ import scipy
+ import scipy.sparse as ss
+ import scipy.sparse.linalg
+ import math
+ import json
+ import itertools as it
+ import re
+ from GNN import featurizer as ft
+
+ import rdkit.RDLogger as rkl
+ logger = rkl.logger()
+ logger.setLevel(rkl.ERROR)
+
+ import rdkit.rdBase as rkrb
+ rkrb.DisableLog('rdApp.error')
+
+ # Morgan fingerprint bits with relative abundance > 5% across ~500,000 metabolites
+ FPBitIdx = [1, 5, 13, 41, 69, 80, 84, 94, 114, 117, 118, 119, 125, 133, 145,
+             147, 191, 192, 197, 202, 222, 227, 231, 249, 283, 294, 310, 314,
+             322, 333, 352, 361, 378, 387, 389, 392, 401, 406, 441, 478, 486,
+             489, 519, 521, 524, 555, 561, 591, 598, 599, 610, 622, 650, 656,
+             667, 675, 677, 679, 680, 694, 695, 715, 718, 722, 729, 736, 739,
+             745, 750, 760, 775, 781, 787, 794, 798, 802, 807, 811, 823, 835,
+             841, 849, 869, 872, 874, 875, 881, 890, 896, 926, 935, 980, 991,
+             1004, 1009, 1017, 1019, 1027, 1028, 1035, 1037, 1039, 1057, 1060,
+             1066, 1070, 1077, 1088, 1097, 1114, 1126, 1136, 1142, 1143, 1145,
+             1152, 1154, 1160, 1162, 1171, 1181, 1195, 1199, 1202, 1218, 1234,
+             1236, 1243, 1257, 1267, 1274, 1279, 1283, 1292, 1294, 1309, 1313,
+             1323, 1325, 1349, 1356, 1357, 1366, 1380, 1381, 1385, 1386, 1391,
+             1399, 1436, 1440, 1441, 1444, 1452, 1454, 1457, 1475, 1476, 1477,
+             1480, 1487, 1516, 1536, 1544, 1558, 1564, 1573, 1599, 1602, 1604,
+             1607, 1619, 1648, 1670, 1683, 1693, 1716, 1722, 1737, 1738, 1745,
+             1747, 1750, 1754, 1755, 1764, 1781, 1803, 1808, 1810, 1816, 1838,
+             1844, 1847, 1855, 1860, 1866, 1873, 1905, 1911, 1917, 1921, 1923,
+             1928, 1933, 1950, 1951, 1970, 1977, 1980, 1984, 1991, 2002, 2033, 2034, 2038]
+
+ class ConfigDict(dict):
+     '''
+     Makes a dictionary behave like an object, with attribute-style access.
+     '''
+     def __getattr__(self, name):
+         try:
+             return self[name]
+         except KeyError:
+             raise AttributeError(name)
+
+     def __setattr__(self, name, value):
+         self[name] = value
+
+     def save(self, fn):
+         json.dump(self, open(fn, 'w'), indent=2)
+
+     def load_dict(self, dic):
+         for k, v in dic.items():
+             self[k] = v
+
+     def load(self, fn):
+         try:
+             d = json.load(open(fn, 'r'))
+             self.load_dict(d)
+         except Exception as e:
+             print(e)
+
+ def conv_out_dim(length_in, kernel, stride, padding, dilation):
+     length_out = (length_in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
+     return length_out
+
+ def filter_ms(ms, thr=0.05, max_mz=2000):
+     mz = []
+     intn = []
+     maxi = 0
+     for m, i in ms:
+         if m < max_mz and i > maxi:
+             maxi = i
+
+     for m, i in ms:
+         if m < max_mz and i/maxi > thr:
+             mz.append(m)
+             intn.append(round(i/maxi*100, 2))
+
+     return mz, intn
+
+ def calc_nls(ms, thr=0.05, max_mz=2000):
+     mz, intn = filter_ms(ms, thr=thr, max_mz=max_mz)
+
+     nlmass = []
+     nlintn = []
+     for a, b in it.combinations(mz[::-1], 2):
+         nl = a - b
+         if 0 < nl < 200:
+             nlmass.append(round(nl, 5))
+             idxa = mz.index(a)
+             idxb = mz.index(b)
+             nlintn.append(round((intn[idxa]+intn[idxb])/2., 5))
+
+     nls = sorted(list(zip(nlmass, nlintn)))
+     return nls
+
+ def ms_binner(ms, nls=[], min_mz=20, max_mz=2000, bin_size=0.05, add_nl=False, binary_intn=False):
+     """
+     Convert the given spectrum to a dense binned torch vector.
+
+     Parameters
+     ----------
+     ms : list of (mz, intensity) pairs
+         The peaks of the spectrum to be converted.
+     nls : list of (mass, intensity) pairs, optional
+         Precomputed neutral losses; computed from `ms` when empty and `add_nl` is set.
+     min_mz, max_mz : float
+         The m/z range covered by the fragment bins.
+     bin_size : float
+         The bin width in m/z.
+     add_nl : bool
+         Prepend a binned neutral-loss vector covering 0-200 Da.
+     binary_intn : bool
+         Use 0/1 intensities instead of normalized ones.
+
+     Returns
+     -------
+     torch.FloatTensor
+         The binned spectrum vector (neutral-loss bins first when add_nl is True).
+     """
+     if add_nl and not nls:
+         nls = calc_nls(ms, max_mz=max_mz)
+
+     nltensor = None
+     mz, intn = filter_ms(ms)
+
+     if add_nl:
+         nlmass = []
+         nlintn = []
+
+         if not nls:
+             nls = calc_nls(ms, max_mz=max_mz)
+
+         for m, i in nls:
+             if m < 200:
+                 if binary_intn:
+                     i = 1
+                 nlmass.append(m)
+                 nlintn.append(i)
+
+         nlmass = np.array(nlmass)
+         nlintn = np.array(nlintn)
+         if len(nlintn) > 0:
+             nlintn = nlintn/nlintn.max()
+         num_nlbins = math.ceil(200 / bin_size)
+         #print('num_nlbins', num_nlbins)
+         nlbins = (nlmass / bin_size).astype(np.int32)
+
+         if len(nlmass) > 0:
+             vecnl = ss.csr_matrix(
+                 (nlintn,
+                  (np.repeat(0, len(nlintn)), nlbins)),
+                 shape=(1, num_nlbins),
+                 dtype=np.float32)
+
+             vecnl = (vecnl / scipy.sparse.linalg.norm(vecnl)*100)
+             nltensor = torch.FloatTensor(vecnl.todense()).view(-1)
+         else:
+             nltensor = torch.zeros(num_nlbins)
+
+     mz = np.array(mz)
+     keepidx = (mz <= max_mz)
+     mz = mz[keepidx]
+     intn = np.array(intn)
+     intn = intn[keepidx]
+
+     if binary_intn:
+         intn[intn > 0] = 1.0
+     elif len(intn) > 0:
+         intn = intn/intn.max()
+
+     num_bins = math.ceil((max_mz - min_mz) / bin_size)
+     #print('num_bins', num_bins)
+     bins = ((mz - min_mz) / bin_size).astype(np.int32)
+
+     #print(num_bins, intn, bins)
+
+     if len(mz) > 0:
+         vec = ss.csr_matrix(
+             (intn,
+              (np.repeat(0, len(intn)), bins)),
+             shape=(1, num_bins),
+             dtype=np.float32)
+
+         if not binary_intn:
+             vec = (vec / scipy.sparse.linalg.norm(vec)*100)
+
+         mstensor = torch.FloatTensor(vec.todense()).view(-1)
+     else:
+         mstensor = torch.zeros(num_bins)
+
+     if nltensor is not None:
+         return torch.cat([nltensor, mstensor], dim=0)
+
+     return mstensor
+
+ def formula2vec(formula, elements=['C', 'H', 'O', 'N', 'P', 'S', 'P', 'F', 'Cl', 'Br']):
+     # note: 'P' is listed twice; list.index() returns the first occurrence, so the
+     # second slot is never populated (the vector length stays 10, matching the
+     # +10 added to mol_embedding_dim in modules.py)
+     formula_p = re.findall(r'([A-Z][a-z]*)(\d*)', formula)
+     vec = np.zeros(len(elements))
+     for i in range(len(formula_p)):
+         ele = formula_p[i][0]
+         num = formula_p[i][1]
+         if num == '':
+             num = 1
+         else:
+             num = int(num)
+         if ele in elements:
+             vec[elements.index(ele)] += num
+     return np.array(vec)
+
+ def mol_fp_encoder0(smiles, tp='rdkit', nbits=2048):
+     mol = Chem.MolFromSmiles(smiles)
+     if mol is None:
+         mol = Chem.MolFromSmiles(smiles, sanitize=False)
+         if mol is not None:
+             mol.UpdatePropertyCache()
+             FastFindRings(mol)
+
+     if mol is None:
+         return None, None
+
+     if tp == 'morgan':
+         fp_vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=nbits)
+         fp = np.frombuffer(fp_vec.ToBitString().encode(), 'u1') - ord('0')
+         fp = fp.tolist()
+     elif tp == 'morgan1':
+         # 2048-bit Morgan fingerprint reduced to the frequent bits in FPBitIdx
+         fp_vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
+         fp = np.frombuffer(fp_vec.ToBitString().encode(), 'u1') - ord('0')
+         fp = fp[FPBitIdx].tolist()
+     elif tp == 'macc':
+         # MACCS keys
+         fp_vec = MACCSkeys.GenMACCSKeys(mol)
+         fp = np.frombuffer(fp_vec.ToBitString().encode(), 'u1') - ord('0')
+         fp = fp.tolist()
+     elif tp == 'rdkit':
+         fp_vec = Chem.RDKFingerprint(mol, nBitsPerHash=1)
+         fp = np.frombuffer(fp_vec.ToBitString().encode(), 'u1') - ord('0')
+         fp = fp.tolist()
+     else:
+         raise ValueError(f'unknown fingerprint type: {tp}')
+
+     return torch.FloatTensor(fp), mol
+
+ def mol_fp_encoder(smiles, tp='rdkit', nbits=2048):
+     fpenc, _ = mol_fp_encoder0(smiles, tp, nbits)
+     return fpenc
+
+ def mol_fp_fm_encoder(smiles, tp='rdkit', nbits=2048):
+     fmenc = None
+     fpenc, mol = mol_fp_encoder0(smiles, tp, nbits)
+     if mol is not None:
+         fm = CalcMolFormula(mol)
+         fmenc = torch.FloatTensor(formula2vec(fm))
+     return fpenc, fmenc
+
+ def smi2fmvec(smiles):
+     mol = Chem.MolFromSmiles(smiles)
+     if mol is None:
+         return None
+     fm = CalcMolFormula(mol)
+     fmenc = torch.FloatTensor(formula2vec(fm))
+
+     return fmenc
+
+ def mol_graph_featurizer(smiles):
+     # mol_graph = {V, A, mol_size}
+     '''mol_graph = ft.calc_data_from_smile(smiles,
+                                            addh=True,
+                                            with_ring_conj=True,
+                                            with_atom_feats=True,
+                                            with_submol_fp=True,
+                                            radius=2)
+     '''
+     mol_graph = ft.calc_data_from_smile(smiles,
+                                         addh=False,
+                                         with_ring_conj=True,
+                                         with_atom_feats=True,
+                                         with_submol_fp=False,
+                                         radius=2)
+     return mol_graph
+
+ def pad_V(V, max_n):
+     N, C = V.shape
+     if max_n > N:
+         zeros = torch.zeros(max_n-N, C)
+         V = torch.cat([V, zeros], dim=0)
+     return V
+
+ def pad_A(A, max_n):
+     N, L, _ = A.shape
+     if max_n > N:
+         zeros = torch.zeros(N, L, max_n-N)
+         A = torch.cat([A, zeros], dim=-1)
+         zeros = torch.zeros(max_n-N, L, max_n)
+         A = torch.cat([A, zeros], dim=0)
+     return A
+
+ class AvgMeter:
+     def __init__(self, name="Metric"):
+         self.name = name
+         self.reset()
+
+     def reset(self):
+         self.avg, self.sum, self.count = [0] * 3
+
+     def update(self, val, count=1):
+         self.count += count
+         self.sum += val * count
+         self.avg = self.sum / self.count
+
+     def __repr__(self):
+         text = f"{self.name}: {self.avg:.4f}"
+         return text
+
+ def get_lr(optimizer):
+     for param_group in optimizer.param_groups:
+         return param_group["lr"]
+
+ def segment_max(x, size_list):
+     size_list = [int(i) for i in size_list]
+     return torch.stack([torch.max(v, 0).values for v in torch.split(x, size_list)])
+
+ def segment_sum(x, size_list):
+     size_list = [int(i) for i in size_list]
+     return torch.stack([torch.sum(v, 0) for v in torch.split(x, size_list)])
+
+ def segment_softmax(gate, size_list):
+     segmax = segment_max(gate, size_list)
+     # expand segmax to the shape of gate
+     segmax_expand = torch.cat([segmax[i].repeat(n, 1) for i, n in enumerate(size_list)], dim=0)
+     subtract = gate - segmax_expand
+     exp = torch.exp(subtract)
+     segsum = segment_sum(exp, size_list)
+     # expand segsum to the shape of gate
+     segsum_expand = torch.cat([segsum[i].repeat(n, 1) for i, n in enumerate(size_list)], dim=0)
+     attention = exp / (segsum_expand + 1e-16)
+
+     return attention
+
+ def pad_ms_list(ms_list, thr=0.05, min_mz=20, max_mz=2000):
+     thr = thr*100
+     mslst = []
+     for ms in ms_list:
+         ms = np.array(ms)
+         ms[:, 1] = ms[:, 1]/ms[:, 1].max()*100
+
+         if thr > 0:
+             ms = ms[(ms[:, 1] >= thr)]
+
+         ms = ms[(ms[:, 0] >= min_mz)]
+         ms = ms[(ms[:, 0] <= max_mz)]
+
+         mslst.append(ms)
+
+     size_list = [ms.shape[0] for ms in mslst]
+     maxlen = max(size_list)
+
+     l = []
+     for ms in mslst:
+         extn = maxlen-len(ms)
+         if extn > 0:
+             l.append(np.concatenate([ms, [[0, 0]]*extn], axis=0))
+         else:
+             l.append(ms)
+
+     return torch.FloatTensor(np.stack(l)), torch.IntTensor(size_list)
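As a dimension sanity check on ms_binner with the config.py defaults: ceil((2000 - 20) / 0.05) = 39600 fragment bins, and add_nl prepends ceil(200 / 0.05) = 4000 neutral-loss bins, i.e. 43600 in total, which is exactly what ConfigDict.calc_ms_embedding_dim computes. A toy spectrum (made-up peaks):

import utils

ms = [(81.07, 30.0), (109.10, 100.0), (137.13, 55.0)]
vec = utils.ms_binner(ms, min_mz=20, max_mz=2000, bin_size=0.05, add_nl=True)
print(vec.shape)          # torch.Size([43600])
print((vec != 0).sum())   # a handful of bins: 3 peaks plus their pairwise neutral losses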