Tingxie commited on
Commit
2d211da
·
1 Parent(s): bf8e921

Upload 2 files

Browse files
Files changed (2) hide show
  1. infer.py +105 -0
  2. model.py +283 -0
infer.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Thu Sep 15 16:22:05 2022
4
+
5
+ @author: ZNDX002
6
+ """
7
+ from model import ModelCLR
8
+ import yaml
9
+ import os
10
+ import torch
11
+ import numpy as np
12
+ import re
13
+ from torch_geometric.data import Data, Batch
14
+ from dataloader.dataset_wrapper import MolToGraph
15
+ from rdkit import Chem
16
+
17
+ class ModelInference(object):
18
+ def __init__(self, config_path, pretrain_model_path, device):
19
+ assert (config_path is not None, "config_path is None")
20
+ assert (pretrain_model_path is not None, "pretrain_model_path is None")
21
+
22
+ if device is None:
23
+ self.device = torch.device(
24
+ "cuda" if torch.cuda.is_available() else "cpu")
25
+ else:
26
+ self.device = torch.device(device)
27
+
28
+ self.config = yaml.load(open(config_path, "r"), Loader=yaml.FullLoader)
29
+ self.model = ModelCLR(**self.config["model_config"]).to(self.device)
30
+ state_dict = torch.load(pretrain_model_path)
31
+ self.model.load_state_dict(state_dict)
32
+ self.model.eval()
33
+
34
+
35
+ def smiles_encode(self, smiles_str):
36
+ with torch.no_grad():
37
+ if isinstance(smiles_str, str):
38
+ #single smiles
39
+ v_d = MolToGraph(smiles_str)
40
+ v_d = v_d.to(self.device)
41
+ smiles_tensor = self.model.smiles_encoder(v_d)
42
+ smiles_tensor=self.model.smi_esa(smiles_tensor,v_d.batch)
43
+ smiles_tensor = self.model.smi_proj(smiles_tensor)
44
+ smiles_tensor = smiles_tensor/smiles_tensor.norm(dim=-1, keepdim=True)
45
+ return smiles_tensor
46
+ else:
47
+ #smiles list
48
+ graphs=[]
49
+ for smi in smiles_str:
50
+ v_d = MolToGraph(smi)
51
+ graphs.append(v_d)
52
+ v_ds = Batch.from_data_list(graphs)
53
+ v_ds = v_ds.to(self.device)
54
+ smiles_tensor = self.model.smiles_encoder(v_ds)
55
+ smiles_tensor=self.model.smi_esa(smiles_tensor,v_ds.batch)
56
+ smiles_tensor = self.model.smi_proj(smiles_tensor)
57
+ smiles_tensor = smiles_tensor/smiles_tensor.norm(dim=-1, keepdim=True)
58
+ return smiles_tensor
59
+
60
+ def ms2_encode(self, ms2_list):
61
+ with torch.no_grad():
62
+ if not isinstance(ms2_list, list):
63
+ #single ms2
64
+ spec_mz = ms2_list.mz
65
+ spec_intens = ms2_list.intensities
66
+ num_peak = len(spec_mz)
67
+ spec_mz = np.around(spec_mz, decimals=4)
68
+ spec_mz = np.pad(spec_mz, (0, 300 - len(spec_mz)), mode='constant', constant_values=0)
69
+ spec_intens = np.pad(spec_intens, (0, 300 - len(spec_intens)), mode='constant', constant_values=0)
70
+ spec_mz= torch.tensor(spec_mz).float().unsqueeze(0)
71
+ spec_intens= torch.tensor(spec_intens).float().unsqueeze(0)
72
+ num_peak = torch.LongTensor(num_peak).unsqueeze(0)
73
+ spec_tensor,spec_mask = self.model.ms_encoder(spec_mz,spec_intens,num_peak)
74
+ spec_tensor=self.model.spec_esa(spec_tensor,spec_mask)
75
+ spec_tensor = self.model.spec_proj(spec_tensor)
76
+ spec_tensor = spec_tensor/spec_tensor.norm(dim=-1, keepdim=True)
77
+ return spec_tensor
78
+ else:
79
+ # batch ms2
80
+ spec_mzs = [spec.mz for spec in ms2_list]
81
+ spec_intens = [spec.intensities for spec in ms2_list]
82
+ num_peaks = [len(i) for i in spec_mzs]
83
+ spec_mzs = [np.around(spec_mz, decimals=4) for spec_mz in spec_mzs]
84
+ num_peaks = torch.LongTensor(num_peaks)
85
+ mzs = [torch.from_numpy(spec_mz).float() for spec_mz in spec_mzs]
86
+ intens = [torch.from_numpy(spec_intens).float() for spec_intens in spec_intens]
87
+ mzs_tensors = torch.nn.utils.rnn.pad_sequence(
88
+ mzs, batch_first=True, padding_value=0
89
+ )
90
+ intens_tensors = torch.nn.utils.rnn.pad_sequence(
91
+ intens, batch_first=True, padding_value=0
92
+ )
93
+ mzs_tensors=mzs_tensors.to(self.device)
94
+ intens_tensors=intens_tensors.to(self.device)
95
+ num_peaks=num_peaks.to(self.device)
96
+
97
+ spec_tensor,spec_mask = self.model.ms_encoder(mzs_tensors,intens_tensors,num_peaks)
98
+ spec_tensor=self.model.spec_esa(spec_tensor,spec_mask)
99
+ spec_tensor = self.model.spec_proj(spec_tensor)
100
+ spec_tensor = spec_tensor/spec_tensor.norm(dim=-1, keepdim=True)
101
+ return spec_tensor
102
+
103
+ def get_cos_distance(self, input_1, input_2):
104
+ with torch.no_grad():
105
+ return input_1 @ input_2.t()
model.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ import torch
4
+ import math
5
+ import numpy as np
6
+ from torch_geometric.nn import MessagePassing
7
+ from torch_geometric.utils import add_self_loops
8
+ from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
9
+ import nn_utils as nn_utils
10
+ num_atom_type = 119 # including the extra mask tokens
11
+ num_chirality_tag = 4
12
+ num_hybrid_type = 8
13
+ num_valence_tag = 6
14
+ num_degree_tag = 5
15
+
16
+ num_bond_type = 5 # including aromatic and self-loop edge
17
+ num_bond_direction = 3
18
+ num_bond_configuration = 6
19
+ class GINEConv(MessagePassing):
20
+ def __init__(self, emb_dim):
21
+ super(GINEConv, self).__init__()
22
+ self.mlp = nn.Sequential(
23
+ nn.Linear(emb_dim, 2*emb_dim),
24
+ nn.ReLU(),
25
+ nn.Linear(2*emb_dim, emb_dim)
26
+ )
27
+ self.edge_embedding1 = nn.Embedding(num_bond_type, emb_dim)
28
+ self.edge_embedding2 = nn.Embedding(num_bond_direction, emb_dim)
29
+ #self.edge_embedding3 = nn.Embedding(num_bond_configuration, emb_dim)
30
+ nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
31
+ nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
32
+ #nn.init.xavier_uniform_(self.edge_embedding3.weight.data)
33
+
34
+ def forward(self, x, edge_index, edge_attr):
35
+ # add self loops in the edge space
36
+ edge_index = add_self_loops(edge_index, num_nodes=x.size(0))[0]
37
+
38
+ # add features corresponding to self-loop edges.
39
+ self_loop_attr = torch.zeros(x.size(0), 2)
40
+ self_loop_attr[:,0] = 4 #bond type for self-loop edge
41
+ self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
42
+ edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)
43
+
44
+ edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])
45
+
46
+ return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)
47
+
48
+ def message(self, x_j, edge_attr):
49
+ return x_j + edge_attr
50
+
51
+ def update(self, aggr_out):
52
+ return self.mlp(aggr_out)
53
+
54
+
55
+ class SmilesModel(nn.Module):
56
+ """
57
+ Args:
58
+ num_layer (int): the number of GNN layers
59
+ emb_dim (int): dimensionality of embeddings
60
+ max_pool_layer (int): the layer from which we use max pool rather than add pool for neighbor aggregation
61
+ drop_ratio (float): dropout rate
62
+ gnn_type: gin, gcn, graphsage, gat
63
+ Output:
64
+ node representations
65
+ """
66
+ def __init__(self, num_layer=5, emb_dim=300, feat_dim=256, drop_ratio=0, pool='mean'):
67
+ super(SmilesModel, self).__init__()
68
+ self.num_layer = num_layer
69
+ self.emb_dim = emb_dim
70
+ self.feat_dim = feat_dim
71
+ self.drop_ratio = drop_ratio
72
+
73
+ self.x_embedding1 = nn.Embedding(num_atom_type, emb_dim)
74
+ self.x_embedding2 = nn.Embedding(num_chirality_tag, emb_dim)
75
+ self.x_embedding3 = nn.Embedding(num_hybrid_type, emb_dim)
76
+ self.x_embedding4 = nn.Embedding(num_valence_tag, emb_dim)
77
+ self.x_embedding5 = nn.Embedding(num_degree_tag, emb_dim)
78
+
79
+ nn.init.xavier_uniform_(self.x_embedding1.weight.data)
80
+ nn.init.xavier_uniform_(self.x_embedding2.weight.data)
81
+ nn.init.xavier_uniform_(self.x_embedding3.weight.data)
82
+ nn.init.xavier_uniform_(self.x_embedding4.weight.data)
83
+ nn.init.xavier_uniform_(self.x_embedding5.weight.data)
84
+
85
+ # List of MLPs
86
+ self.gnns = nn.ModuleList()
87
+ for layer in range(num_layer):
88
+ self.gnns.append(GINEConv(emb_dim))
89
+
90
+ # List of batchnorms
91
+ self.batch_norms = nn.ModuleList()
92
+ for layer in range(num_layer):
93
+ self.batch_norms.append(nn.BatchNorm1d(emb_dim))
94
+
95
+ if pool == 'mean':
96
+ self.pool = global_mean_pool
97
+ elif pool == 'max':
98
+ self.pool = global_max_pool
99
+ elif pool == 'add':
100
+ self.pool = global_add_pool
101
+
102
+ self.feat_lin = nn.Linear(self.emb_dim, self.feat_dim)
103
+
104
+ self.out_lin = nn.Sequential(
105
+ nn.Linear(self.feat_dim, self.feat_dim),
106
+ nn.ReLU(inplace=True),
107
+ nn.Linear(self.feat_dim, self.feat_dim//2)
108
+ )
109
+
110
+ def forward(self, data):
111
+ x = data.x
112
+ edge_index = data.edge_index
113
+ edge_attr = data.edge_attr
114
+
115
+ h = self.x_embedding1(x[:,0]) + self.x_embedding2(x[:,1]) + self.x_embedding3(x[:,2]) + self.x_embedding4(x[:,3]) + self.x_embedding5(x[:,4])
116
+
117
+ for layer in range(self.num_layer):
118
+ h = self.gnns[layer](h, edge_index, edge_attr)
119
+ h = self.batch_norms[layer](h)
120
+ if layer == self.num_layer - 1:
121
+ h = F.dropout(h, self.drop_ratio, training=self.training)
122
+ else:
123
+ h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
124
+
125
+ '''h = self.pool(h, data.batch)
126
+ h = self.feat_lin(h)
127
+ out = self.out_lin(h)'''
128
+
129
+ return h
130
+
131
+ class FourierEmbedder(nn.Module):
132
+ """Embed a set of mz float values using frequencies"""
133
+
134
+ def __init__(self, spec_embed_dim, logmin=-2.5, logmax=3.3):
135
+ super().__init__()
136
+ self.d = spec_embed_dim
137
+ self.logmin = logmin
138
+ self.logmax = logmax
139
+
140
+ lambda_min = np.power(10, -logmin)
141
+ lambda_max = np.power(10, logmax)
142
+ index = torch.arange(np.ceil(self.d / 2))
143
+ exp = torch.pow(lambda_max / lambda_min, (2 * index) / (self.d - 2))
144
+ freqs = 2 * np.pi * (lambda_min * exp) ** (-1)
145
+
146
+ self.freqs = nn.Parameter(freqs, requires_grad=False)
147
+
148
+ # Turn off requires grad for freqs
149
+ self.freqs.requires_grad = False
150
+
151
+ def forward(self, mz: torch.FloatTensor):
152
+ """forward
153
+
154
+ Args:
155
+ mz: FloatTensor of shape (batch_size, mz values)
156
+
157
+ Returns:
158
+ FloatTensor of shape (batch_size, peak len, mz )
159
+ """
160
+ freq_input = torch.einsum("bi,j->bij", mz, self.freqs)
161
+ embedded = torch.cat([torch.sin(freq_input), torch.cos(freq_input)], -1)
162
+ embedded = embedded[:, :, : self.d]
163
+ return embedded
164
+
165
+ class MSModel(nn.Module):
166
+ def __init__(self, spec_embed_dim,dropout,layers):
167
+ super(MSModel,self).__init__()
168
+ self.mz_embedder = FourierEmbedder(spec_embed_dim)
169
+ self.input_compress = nn.Linear(spec_embed_dim+1, spec_embed_dim)
170
+ peak_attn_layer = nn_utils.TransformerEncoderLayer(
171
+ d_model=spec_embed_dim,
172
+ nhead=8,
173
+ dim_feedforward=spec_embed_dim * 4,
174
+ dropout=dropout,
175
+ additive_attn=False,
176
+ pairwise_featurization=False)
177
+ self.peak_attn_layers = nn_utils.get_clones(peak_attn_layer,layers)
178
+ self.pooling_layer = nn.AdaptiveAvgPool1d(1)
179
+ self.output_layer = nn.Linear(spec_embed_dim, spec_embed_dim)
180
+
181
+ def forward(self,mzs,intens,num_peaks):
182
+ embedded_mz = self.mz_embedder(mzs)
183
+ cat_vec = [embedded_mz, intens[:, :, None]]
184
+ peak_tensor = torch.cat(cat_vec, -1)
185
+ peak_tensor = self.input_compress(peak_tensor)
186
+ peak_dim = peak_tensor.shape[1]
187
+ peaks_aranged = torch.arange(peak_dim).to(mzs.device)
188
+
189
+ # batch x num peaks
190
+ attn_mask = ~(peaks_aranged[None, :] < num_peaks[:, None])
191
+
192
+ # Transpose to peaks x batch x features
193
+ peak_tensor = peak_tensor.transpose(0, 1)
194
+ for peak_attn_layer in self.peak_attn_layers:
195
+ peak_tensor, pairwise_features = peak_attn_layer(
196
+ peak_tensor,
197
+ src_key_padding_mask=attn_mask,
198
+ )
199
+
200
+ peak_tensor = peak_tensor.transpose(0, 1)
201
+
202
+ # Get only the class token
203
+ #h0 = peak_tensor[:, 0, :]
204
+
205
+ #output = self.output_layer(h0)
206
+
207
+ '''pooled_embeddings = self.pooling_layer(peak_tensor.permute(0, 2, 1)).squeeze(dim=-1)
208
+ output = self.output_layer(pooled_embeddings)'''
209
+ return peak_tensor,attn_mask
210
+
211
+ class ESA_SMILES(nn.Module):
212
+ def __init__(self, feature_dim, out_dim):
213
+ super().__init__()
214
+ self.ln_f = nn.LayerNorm(feature_dim)
215
+ self.linear = nn.Linear(feature_dim, out_dim)
216
+ self.linear1 = nn.Linear(out_dim, out_dim)
217
+
218
+ def forward(self, hidden_states,data_batch):
219
+ B = data_batch.max().item() + 1 # batch_num
220
+ node_counts = torch.bincount(data_batch) # node_num
221
+ N = node_counts.max().item() # max_node_num
222
+ C = hidden_states.shape[1] # feat_dim
223
+ result = torch.zeros((B, N, C)).to(hidden_states.device)
224
+ for i in range(B):
225
+ indices = torch.where(data_batch == i)[0]
226
+ result[i, :len(indices), :] = hidden_states[indices]
227
+ attention_mask = (result != 0).any(dim=-1).float()
228
+ logits = self.ln_f(result) # (B, N, C)
229
+ cap_embes = self.linear(logits) # Q
230
+ features_in = self.linear1(cap_embes) # M
231
+ mask = attention_mask.unsqueeze(-1) # (B, N, 1)
232
+ features_in = features_in.masked_fill(mask == 0, -1e4) # (B, N, C)
233
+ features_k_softmax = nn.Softmax(dim=1)(features_in)
234
+ attn = features_k_softmax.masked_fill(mask == 0, 0)
235
+ smi_feature = torch.sum(attn * cap_embes, dim=1) # (B, C)
236
+ return smi_feature
237
+
238
+ class ESA_SPEC(nn.Module):
239
+ def __init__(self, feature_dim, out_dim):
240
+ super().__init__()
241
+ self.ln_f = nn.LayerNorm(feature_dim)
242
+ self.linear = nn.Linear(feature_dim, out_dim)
243
+ self.linear1 = nn.Linear(out_dim, out_dim)
244
+
245
+ def forward(self, hidden_states,attention_mask):
246
+ logits = self.ln_f(hidden_states) # (B, N, C)
247
+ cap_embes = self.linear(logits) # Q
248
+ features_in = self.linear1(cap_embes) # M
249
+ mask = attention_mask.unsqueeze(-1) # (B, N, 1)
250
+ features_in = features_in.masked_fill(mask == 0, -1e4) # (B, N, C)
251
+ features_k_softmax = nn.Softmax(dim=1)(features_in)
252
+ attn = features_k_softmax.masked_fill(mask == 0, 0)
253
+ spec_feature = torch.sum(attn * cap_embes, dim=1) # (B, C)
254
+ return spec_feature
255
+
256
+ class ModelCLR(nn.Module):
257
+ def __init__(self, num_layer, emb_dim, feat_dim, drop_ratio, pool,spec_embed_dim,dropout,layers,embed_dim):
258
+ super().__init__()
259
+
260
+ self.Smiles_model = SmilesModel(num_layer, emb_dim, feat_dim, drop_ratio, pool)
261
+ self.MS_model = MSModel(spec_embed_dim,dropout,layers)
262
+ self.smi_esa = ESA_SMILES(emb_dim, embed_dim)
263
+ self.spec_esa = ESA_SPEC(spec_embed_dim, embed_dim)
264
+ self.smi_proj = nn.Linear(embed_dim, embed_dim)
265
+ self.spec_proj = nn.Linear(embed_dim, embed_dim)
266
+
267
+ def smiles_encoder(self, xis):
268
+ x = self.Smiles_model(xis)
269
+ return x
270
+
271
+ def ms_encoder(self, mzs,intens,num_peaks):
272
+ out_emb = self.MS_model(mzs,intens,num_peaks)
273
+ return out_emb
274
+
275
+ def forward(self, xis, mzs,intens,num_peaks):
276
+ zis = self.smiles_encoder(xis)
277
+ zls,attn_mask = self.ms_encoder(mzs,intens,num_peaks)
278
+ zis_feat=self.smi_esa(zis,xis.batch)
279
+ zls_feat=self.spec_esa(zls,attn_mask)
280
+ zis_feat=self.smi_proj(zis_feat)
281
+ zls_feat=self.spec_proj(zls_feat)
282
+ return zis_feat, zls_feat
283
+