Tingxie committed (verified)
Commit 5146751 · 1 Parent(s): 5a8c2f5

Update model.py

Files changed (1):
  1. model.py +283 -283

model.py CHANGED
@@ -1,283 +1,283 @@
-import torch.nn as nn
-import torch.nn.functional as F
-import torch
-import math
-import numpy as np
-from torch_geometric.nn import MessagePassing
-from torch_geometric.utils import add_self_loops
-from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
-import nn_utils as nn_utils
-num_atom_type = 119 # including the extra mask tokens
-num_chirality_tag = 4
-num_hybrid_type = 8
-num_valence_tag = 6
-num_degree_tag = 5
-
-num_bond_type = 5 # including aromatic and self-loop edge
-num_bond_direction = 3
-num_bond_configuration = 6
-class GINEConv(MessagePassing):
-    def __init__(self, emb_dim):
-        super(GINEConv, self).__init__()
-        self.mlp = nn.Sequential(
-            nn.Linear(emb_dim, 2*emb_dim),
-            nn.ReLU(),
-            nn.Linear(2*emb_dim, emb_dim)
-        )
-        self.edge_embedding1 = nn.Embedding(num_bond_type, emb_dim)
-        self.edge_embedding2 = nn.Embedding(num_bond_direction, emb_dim)
-        #self.edge_embedding3 = nn.Embedding(num_bond_configuration, emb_dim)
-        nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
-        nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
-        #nn.init.xavier_uniform_(self.edge_embedding3.weight.data)
-
-    def forward(self, x, edge_index, edge_attr):
-        # add self loops in the edge space
-        edge_index = add_self_loops(edge_index, num_nodes=x.size(0))[0]
-
-        # add features corresponding to self-loop edges.
-        self_loop_attr = torch.zeros(x.size(0), 2)
-        self_loop_attr[:,0] = 4 #bond type for self-loop edge
-        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
-        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)
-
-        edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])
-
-        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)
-
-    def message(self, x_j, edge_attr):
-        return x_j + edge_attr
-
-    def update(self, aggr_out):
-        return self.mlp(aggr_out)
-
-
-class SmilesModel(nn.Module):
-    """
-    Args:
-        num_layer (int): the number of GNN layers
-        emb_dim (int): dimensionality of embeddings
-        max_pool_layer (int): the layer from which we use max pool rather than add pool for neighbor aggregation
-        drop_ratio (float): dropout rate
-        gnn_type: gin, gcn, graphsage, gat
-    Output:
-        node representations
-    """
-    def __init__(self, num_layer=5, emb_dim=300, feat_dim=256, drop_ratio=0, pool='mean'):
-        super(SmilesModel, self).__init__()
-        self.num_layer = num_layer
-        self.emb_dim = emb_dim
-        self.feat_dim = feat_dim
-        self.drop_ratio = drop_ratio
-
-        self.x_embedding1 = nn.Embedding(num_atom_type, emb_dim)
-        self.x_embedding2 = nn.Embedding(num_chirality_tag, emb_dim)
-        self.x_embedding3 = nn.Embedding(num_hybrid_type, emb_dim)
-        self.x_embedding4 = nn.Embedding(num_valence_tag, emb_dim)
-        self.x_embedding5 = nn.Embedding(num_degree_tag, emb_dim)
-
-        nn.init.xavier_uniform_(self.x_embedding1.weight.data)
-        nn.init.xavier_uniform_(self.x_embedding2.weight.data)
-        nn.init.xavier_uniform_(self.x_embedding3.weight.data)
-        nn.init.xavier_uniform_(self.x_embedding4.weight.data)
-        nn.init.xavier_uniform_(self.x_embedding5.weight.data)
-
-        # List of MLPs
-        self.gnns = nn.ModuleList()
-        for layer in range(num_layer):
-            self.gnns.append(GINEConv(emb_dim))
-
-        # List of batchnorms
-        self.batch_norms = nn.ModuleList()
-        for layer in range(num_layer):
-            self.batch_norms.append(nn.BatchNorm1d(emb_dim))
-
-        if pool == 'mean':
-            self.pool = global_mean_pool
-        elif pool == 'max':
-            self.pool = global_max_pool
-        elif pool == 'add':
-            self.pool = global_add_pool
-
-        self.feat_lin = nn.Linear(self.emb_dim, self.feat_dim)
-
-        self.out_lin = nn.Sequential(
-            nn.Linear(self.feat_dim, self.feat_dim),
-            nn.ReLU(inplace=True),
-            nn.Linear(self.feat_dim, self.feat_dim//2)
-        )
-
-    def forward(self, data):
-        x = data.x
-        edge_index = data.edge_index
-        edge_attr = data.edge_attr
-
-        h = self.x_embedding1(x[:,0]) + self.x_embedding2(x[:,1]) + self.x_embedding3(x[:,2]) + self.x_embedding4(x[:,3]) + self.x_embedding5(x[:,4])
-
-        for layer in range(self.num_layer):
-            h = self.gnns[layer](h, edge_index, edge_attr)
-            h = self.batch_norms[layer](h)
-            if layer == self.num_layer - 1:
-                h = F.dropout(h, self.drop_ratio, training=self.training)
-            else:
-                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
-
-        '''h = self.pool(h, data.batch)
-        h = self.feat_lin(h)
-        out = self.out_lin(h)'''
-
-        return h
-
-class FourierEmbedder(nn.Module):
-    """Embed a set of mz float values using frequencies"""
-
-    def __init__(self, spec_embed_dim, logmin=-2.5, logmax=3.3):
-        super().__init__()
-        self.d = spec_embed_dim
-        self.logmin = logmin
-        self.logmax = logmax
-
-        lambda_min = np.power(10, -logmin)
-        lambda_max = np.power(10, logmax)
-        index = torch.arange(np.ceil(self.d / 2))
-        exp = torch.pow(lambda_max / lambda_min, (2 * index) / (self.d - 2))
-        freqs = 2 * np.pi * (lambda_min * exp) ** (-1)
-
-        self.freqs = nn.Parameter(freqs, requires_grad=False)
-
-        # Turn off requires grad for freqs
-        self.freqs.requires_grad = False
-
-    def forward(self, mz: torch.FloatTensor):
-        """forward
-
-        Args:
-            mz: FloatTensor of shape (batch_size, mz values)
-
-        Returns:
-            FloatTensor of shape (batch_size, peak len, mz )
-        """
-        freq_input = torch.einsum("bi,j->bij", mz, self.freqs)
-        embedded = torch.cat([torch.sin(freq_input), torch.cos(freq_input)], -1)
-        embedded = embedded[:, :, : self.d]
-        return embedded
-
-class MSModel(nn.Module):
-    def __init__(self, spec_embed_dim,dropout,layers):
-        super(MSModel,self).__init__()
-        self.mz_embedder = FourierEmbedder(spec_embed_dim)
-        self.input_compress = nn.Linear(spec_embed_dim+1, spec_embed_dim)
-        peak_attn_layer = nn_utils.TransformerEncoderLayer(
-            d_model=spec_embed_dim,
-            nhead=8,
-            dim_feedforward=spec_embed_dim * 4,
-            dropout=dropout,
-            additive_attn=False,
-            pairwise_featurization=False)
-        self.peak_attn_layers = nn_utils.get_clones(peak_attn_layer,layers)
-        self.pooling_layer = nn.AdaptiveAvgPool1d(1)
-        self.output_layer = nn.Linear(spec_embed_dim, spec_embed_dim)
-
-    def forward(self,mzs,intens,num_peaks):
-        embedded_mz = self.mz_embedder(mzs)
-        cat_vec = [embedded_mz, intens[:, :, None]]
-        peak_tensor = torch.cat(cat_vec, -1)
-        peak_tensor = self.input_compress(peak_tensor)
-        peak_dim = peak_tensor.shape[1]
-        peaks_aranged = torch.arange(peak_dim).to(mzs.device)
-
-        # batch x num peaks
-        attn_mask = ~(peaks_aranged[None, :] < num_peaks[:, None])
-
-        # Transpose to peaks x batch x features
-        peak_tensor = peak_tensor.transpose(0, 1)
-        for peak_attn_layer in self.peak_attn_layers:
-            peak_tensor, pairwise_features = peak_attn_layer(
-                peak_tensor,
-                src_key_padding_mask=attn_mask,
-            )
-
-        peak_tensor = peak_tensor.transpose(0, 1)
-
-        # Get only the class token
-        #h0 = peak_tensor[:, 0, :]
-
-        #output = self.output_layer(h0)
-
-        '''pooled_embeddings = self.pooling_layer(peak_tensor.permute(0, 2, 1)).squeeze(dim=-1)
-        output = self.output_layer(pooled_embeddings)'''
-        return peak_tensor,attn_mask
-
-class ESA_SMILES(nn.Module):
-    def __init__(self, feature_dim, out_dim):
-        super().__init__()
-        self.ln_f = nn.LayerNorm(feature_dim)
-        self.linear = nn.Linear(feature_dim, out_dim)
-        self.linear1 = nn.Linear(out_dim, out_dim)
-
-    def forward(self, hidden_states,data_batch):
-        B = data_batch.max().item() + 1 # batch_num
-        node_counts = torch.bincount(data_batch) # node_num
-        N = node_counts.max().item() # max_node_num
-        C = hidden_states.shape[1] # feat_dim
-        result = torch.zeros((B, N, C)).to(hidden_states.device)
-        for i in range(B):
-            indices = torch.where(data_batch == i)[0]
-            result[i, :len(indices), :] = hidden_states[indices]
-        attention_mask = (result != 0).any(dim=-1).float()
-        logits = self.ln_f(result) # (B, N, C)
-        cap_embes = self.linear(logits) # Q
-        features_in = self.linear1(cap_embes) # M
-        mask = attention_mask.unsqueeze(-1) # (B, N, 1)
-        features_in = features_in.masked_fill(mask == 0, -1e4) # (B, N, C)
-        features_k_softmax = nn.Softmax(dim=1)(features_in)
-        attn = features_k_softmax.masked_fill(mask == 0, 0)
-        smi_feature = torch.sum(attn * cap_embes, dim=1) # (B, C)
-        return smi_feature
-
-class ESA_SPEC(nn.Module):
-    def __init__(self, feature_dim, out_dim):
-        super().__init__()
-        self.ln_f = nn.LayerNorm(feature_dim)
-        self.linear = nn.Linear(feature_dim, out_dim)
-        self.linear1 = nn.Linear(out_dim, out_dim)
-
-    def forward(self, hidden_states,attention_mask):
-        logits = self.ln_f(hidden_states) # (B, N, C)
-        cap_embes = self.linear(logits) # Q
-        features_in = self.linear1(cap_embes) # M
-        mask = attention_mask.unsqueeze(-1) # (B, N, 1)
-        features_in = features_in.masked_fill(mask == 0, -1e4) # (B, N, C)
-        features_k_softmax = nn.Softmax(dim=1)(features_in)
-        attn = features_k_softmax.masked_fill(mask == 0, 0)
-        spec_feature = torch.sum(attn * cap_embes, dim=1) # (B, C)
-        return spec_feature
-
-class ModelCLR(nn.Module):
-    def __init__(self, num_layer, emb_dim, feat_dim, drop_ratio, pool,spec_embed_dim,dropout,layers,embed_dim):
-        super().__init__()
-
-        self.Smiles_model = SmilesModel(num_layer, emb_dim, feat_dim, drop_ratio, pool)
-        self.MS_model = MSModel(spec_embed_dim,dropout,layers)
-        self.smi_esa = ESA_SMILES(emb_dim, embed_dim)
-        self.spec_esa = ESA_SPEC(spec_embed_dim, embed_dim)
-        self.smi_proj = nn.Linear(embed_dim, embed_dim)
-        self.spec_proj = nn.Linear(embed_dim, embed_dim)
-
-    def smiles_encoder(self, xis):
-        x = self.Smiles_model(xis)
-        return x
-
-    def ms_encoder(self, mzs,intens,num_peaks):
-        out_emb = self.MS_model(mzs,intens,num_peaks)
-        return out_emb
-
-    def forward(self, xis, mzs,intens,num_peaks):
-        zis = self.smiles_encoder(xis)
-        zls,attn_mask = self.ms_encoder(mzs,intens,num_peaks)
-        zis_feat=self.smi_esa(zis,xis.batch)
-        zls_feat=self.spec_esa(zls,attn_mask)
-        zis_feat=self.smi_proj(zis_feat)
-        zls_feat=self.spec_proj(zls_feat)
-        return zis_feat, zls_feat
-
 
+import torch.nn as nn
+import torch.nn.functional as F
+import torch
+import math
+import numpy as np
+from torch_geometric.nn import MessagePassing
+from torch_geometric.utils import add_self_loops
+from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
+import nn_utils as nn_utils
+num_atom_type = 119 # including the extra mask tokens
+num_chirality_tag = 4
+num_hybrid_type = 8
+num_valence_tag = 8
+num_degree_tag = 5
+
+num_bond_type = 5 # including aromatic and self-loop edge
+num_bond_direction = 3
+num_bond_configuration = 6
+class GINEConv(MessagePassing):
+    def __init__(self, emb_dim):
+        super(GINEConv, self).__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(emb_dim, 2*emb_dim),
+            nn.ReLU(),
+            nn.Linear(2*emb_dim, emb_dim)
+        )
+        self.edge_embedding1 = nn.Embedding(num_bond_type, emb_dim)
+        self.edge_embedding2 = nn.Embedding(num_bond_direction, emb_dim)
+        #self.edge_embedding3 = nn.Embedding(num_bond_configuration, emb_dim)
+        nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
+        nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
+        #nn.init.xavier_uniform_(self.edge_embedding3.weight.data)
+
+    def forward(self, x, edge_index, edge_attr):
+        # add self loops in the edge space
+        edge_index = add_self_loops(edge_index, num_nodes=x.size(0))[0]
+
+        # add features corresponding to self-loop edges.
+        self_loop_attr = torch.zeros(x.size(0), 2)
+        self_loop_attr[:,0] = 4 #bond type for self-loop edge
+        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
+        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)
+
+        edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])
+
+        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)
+
+    def message(self, x_j, edge_attr):
+        return x_j + edge_attr
+
+    def update(self, aggr_out):
+        return self.mlp(aggr_out)
+
+
+class SmilesModel(nn.Module):
+    """
+    Args:
+        num_layer (int): the number of GNN layers
+        emb_dim (int): dimensionality of embeddings
+        max_pool_layer (int): the layer from which we use max pool rather than add pool for neighbor aggregation
+        drop_ratio (float): dropout rate
+        gnn_type: gin, gcn, graphsage, gat
+    Output:
+        node representations
+    """
+    def __init__(self, num_layer=5, emb_dim=300, feat_dim=256, drop_ratio=0, pool='mean'):
+        super(SmilesModel, self).__init__()
+        self.num_layer = num_layer
+        self.emb_dim = emb_dim
+        self.feat_dim = feat_dim
+        self.drop_ratio = drop_ratio
+
+        self.x_embedding1 = nn.Embedding(num_atom_type, emb_dim)
+        self.x_embedding2 = nn.Embedding(num_chirality_tag, emb_dim)
+        self.x_embedding3 = nn.Embedding(num_hybrid_type, emb_dim)
+        self.x_embedding4 = nn.Embedding(num_valence_tag, emb_dim)
+        self.x_embedding5 = nn.Embedding(num_degree_tag, emb_dim)
+
+        nn.init.xavier_uniform_(self.x_embedding1.weight.data)
+        nn.init.xavier_uniform_(self.x_embedding2.weight.data)
+        nn.init.xavier_uniform_(self.x_embedding3.weight.data)
+        nn.init.xavier_uniform_(self.x_embedding4.weight.data)
+        nn.init.xavier_uniform_(self.x_embedding5.weight.data)
+
+        # List of MLPs
+        self.gnns = nn.ModuleList()
+        for layer in range(num_layer):
+            self.gnns.append(GINEConv(emb_dim))
+
+        # List of batchnorms
+        self.batch_norms = nn.ModuleList()
+        for layer in range(num_layer):
+            self.batch_norms.append(nn.BatchNorm1d(emb_dim))
+
+        if pool == 'mean':
+            self.pool = global_mean_pool
+        elif pool == 'max':
+            self.pool = global_max_pool
+        elif pool == 'add':
+            self.pool = global_add_pool
+
+        self.feat_lin = nn.Linear(self.emb_dim, self.feat_dim)
+
+        self.out_lin = nn.Sequential(
+            nn.Linear(self.feat_dim, self.feat_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.feat_dim, self.feat_dim//2)
+        )
+
+    def forward(self, data):
+        x = data.x
+        edge_index = data.edge_index
+        edge_attr = data.edge_attr
+
+        h = self.x_embedding1(x[:,0]) + self.x_embedding2(x[:,1]) + self.x_embedding3(x[:,2]) + self.x_embedding4(x[:,3]) + self.x_embedding5(x[:,4])
+
+        for layer in range(self.num_layer):
+            h = self.gnns[layer](h, edge_index, edge_attr)
+            h = self.batch_norms[layer](h)
+            if layer == self.num_layer - 1:
+                h = F.dropout(h, self.drop_ratio, training=self.training)
+            else:
+                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
+
+        '''h = self.pool(h, data.batch)
+        h = self.feat_lin(h)
+        out = self.out_lin(h)'''
+
+        return h
+
+class FourierEmbedder(nn.Module):
+    """Embed a set of mz float values using frequencies"""
+
+    def __init__(self, spec_embed_dim, logmin=2.5, logmax=3.3):
+        super().__init__()
+        self.d = spec_embed_dim
+        self.logmin = logmin
+        self.logmax = logmax
+
+        lambda_min = np.power(10, -logmin)
+        lambda_max = np.power(10, logmax)
+        index = torch.arange(np.ceil(self.d / 2))
+        exp = torch.pow(lambda_max / lambda_min, (2 * index) / (self.d - 2))
+        freqs = 2 * np.pi * (lambda_min * exp) ** (-1)
+
+        self.freqs = nn.Parameter(freqs, requires_grad=False)
+
+        # Turn off requires grad for freqs
+        self.freqs.requires_grad = False
+
+    def forward(self, mz: torch.FloatTensor):
+        """forward
+
+        Args:
+            mz: FloatTensor of shape (batch_size, mz values)
+
+        Returns:
+            FloatTensor of shape (batch_size, peak len, mz )
+        """
+        freq_input = torch.einsum("bi,j->bij", mz, self.freqs)
+        embedded = torch.cat([torch.sin(freq_input), torch.cos(freq_input)], -1)
+        embedded = embedded[:, :, : self.d]
+        return embedded
+
+class MSModel(nn.Module):
+    def __init__(self, spec_embed_dim,dropout,layers):
+        super(MSModel,self).__init__()
+        self.mz_embedder = FourierEmbedder(spec_embed_dim)
+        self.input_compress = nn.Linear(spec_embed_dim+1, spec_embed_dim)
+        peak_attn_layer = nn_utils.TransformerEncoderLayer(
+            d_model=spec_embed_dim,
+            nhead=8,
+            dim_feedforward=spec_embed_dim * 4,
+            dropout=dropout,
+            additive_attn=False,
+            pairwise_featurization=False)
+        self.peak_attn_layers = nn_utils.get_clones(peak_attn_layer,layers)
+        self.pooling_layer = nn.AdaptiveAvgPool1d(1)
+        self.output_layer = nn.Linear(spec_embed_dim, spec_embed_dim)
+
+    def forward(self,mzs,intens,num_peaks):
+        embedded_mz = self.mz_embedder(mzs)
+        cat_vec = [embedded_mz, intens[:, :, None]]
+        peak_tensor = torch.cat(cat_vec, -1)
+        peak_tensor = self.input_compress(peak_tensor)
+        peak_dim = peak_tensor.shape[1]
+        peaks_aranged = torch.arange(peak_dim).to(mzs.device)
+
+        # batch x num peaks
+        attn_mask = ~(peaks_aranged[None, :] < num_peaks[:, None])
+
+        # Transpose to peaks x batch x features
+        peak_tensor = peak_tensor.transpose(0, 1)
+        for peak_attn_layer in self.peak_attn_layers:
+            peak_tensor, attn_weights, pairwise_features = peak_attn_layer(
+                peak_tensor,
+                src_key_padding_mask=attn_mask,
+            )
+
+        peak_tensor = peak_tensor.transpose(0, 1)
+
+        # Get only the class token
+        #h0 = peak_tensor[:, 0, :]
+
+        #output = self.output_layer(h0)
+
+        '''pooled_embeddings = self.pooling_layer(peak_tensor.permute(0, 2, 1)).squeeze(dim=-1)
+        output = self.output_layer(pooled_embeddings)'''
+        return peak_tensor,attn_mask
+
+class ESA_SMILES(nn.Module):
+    def __init__(self, feature_dim, out_dim):
+        super().__init__()
+        self.ln_f = nn.LayerNorm(feature_dim)
+        self.linear = nn.Linear(feature_dim, out_dim)
+        self.linear1 = nn.Linear(out_dim, out_dim)
+
+    def forward(self, hidden_states,data_batch):
+        B = data_batch.max().item() + 1 # batch_num
+        node_counts = torch.bincount(data_batch) # node_num
+        N = node_counts.max().item() # max_node_num
+        C = hidden_states.shape[1] # feat_dim
+        result = torch.zeros((B, N, C)).to(hidden_states.device)
+        for i in range(B):
+            indices = torch.where(data_batch == i)[0]
+            result[i, :len(indices), :] = hidden_states[indices]
+        attention_mask = (result != 0).any(dim=-1).float()
+        logits = self.ln_f(result) # (B, N, C)
+        cap_embes = self.linear(logits) # Q
+        features_in = self.linear1(cap_embes) # M
+        mask = attention_mask.unsqueeze(-1) # (B, N, 1)
+        features_in = features_in.masked_fill(mask == 0, -1e4) # (B, N, C)
+        features_k_softmax = nn.Softmax(dim=1)(features_in)
+        attn = features_k_softmax.masked_fill(mask == 0, 0)
+        smi_feature = torch.sum(attn * cap_embes, dim=1) # (B, C)
+        return smi_feature
+
+class ESA_SPEC(nn.Module):
+    def __init__(self, feature_dim, out_dim):
+        super().__init__()
+        self.ln_f = nn.LayerNorm(feature_dim)
+        self.linear = nn.Linear(feature_dim, out_dim)
+        self.linear1 = nn.Linear(out_dim, out_dim)
+
+    def forward(self, hidden_states,attention_mask):
+        logits = self.ln_f(hidden_states) # (B, N, C)
+        cap_embes = self.linear(logits) # Q
+        features_in = self.linear1(cap_embes) # M
+        mask = attention_mask.unsqueeze(-1) # (B, N, 1)
+        features_in = features_in.masked_fill(mask == 1, -1e4) # (B, N, C)
+        features_k_softmax = nn.Softmax(dim=1)(features_in)
+        attn = features_k_softmax.masked_fill(mask == 1, 0)
+        spec_feature = torch.sum(attn * cap_embes, dim=1) # (B, C)
+        return spec_feature
+
+class ModelCLR(nn.Module):
+    def __init__(self, num_layer, emb_dim, feat_dim, drop_ratio, pool,spec_embed_dim,dropout,layers,embed_dim):
+        super().__init__()
+
+        self.Smiles_model = SmilesModel(num_layer, emb_dim, feat_dim, drop_ratio, pool)
+        self.MS_model = MSModel(spec_embed_dim,dropout,layers)
+        self.smi_esa = ESA_SMILES(emb_dim, embed_dim)
+        self.spec_esa = ESA_SPEC(spec_embed_dim, embed_dim)
+        self.smi_proj = nn.Linear(embed_dim, embed_dim)
+        self.spec_proj = nn.Linear(embed_dim, embed_dim)
+
+    def smiles_encoder(self, xis):
+        x = self.Smiles_model(xis)
+        return x
+
+    def ms_encoder(self, mzs,intens,num_peaks):
+        out_emb = self.MS_model(mzs,intens,num_peaks)
+        return out_emb
+
+    def forward(self, xis, mzs,intens,num_peaks):
+        zis = self.smiles_encoder(xis)
+        zls,attn_mask = self.ms_encoder(mzs,intens,num_peaks)
+        zis_feat=self.smi_esa(zis,xis.batch)
+        zls_feat=self.spec_esa(zls,attn_mask)
+        zis_feat=self.smi_proj(zis_feat)
+        zls_feat=self.spec_proj(zls_feat)
+        return zis_feat, zls_feat
+
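
For orientation, below is a minimal usage sketch of the ModelCLR defined in this file; it is not part of the commit. It assumes model.py and its nn_utils helpers (TransformerEncoderLayer, get_clones) are importable and torch_geometric is installed; the hyperparameters and the toy graph/spectrum tensors are arbitrary placeholders chosen only to match the shapes the model expects (five categorical atom features, two categorical bond features, padded m/z and intensity arrays, and per-sample peak counts).

# Illustrative sketch only; all inputs below are fabricated toy data.
import torch
from torch_geometric.data import Batch, Data
from model import ModelCLR

model = ModelCLR(num_layer=5, emb_dim=300, feat_dim=256, drop_ratio=0.1, pool='mean',
                 spec_embed_dim=256, dropout=0.1, layers=4, embed_dim=128)
model.eval()  # keep BatchNorm/dropout deterministic for the demo

# Toy molecular graph: 3 atoms x 5 categorical features, 2 bonds stored in both directions.
graph = Data(
    x=torch.zeros(3, 5, dtype=torch.long),           # atom type, chirality, hybridisation, valence, degree indices
    edge_index=torch.tensor([[0, 1, 1, 2],
                             [1, 0, 2, 1]]),
    edge_attr=torch.zeros(4, 2, dtype=torch.long),    # bond type, bond direction indices
)
xis = Batch.from_data_list([graph, graph])            # batch of two graphs

# Toy spectra: 2 samples padded to 6 peaks; num_peaks gives the real (unpadded) counts.
mzs = torch.rand(2, 6) * 500.0
intens = torch.rand(2, 6)
num_peaks = torch.tensor([6, 4])

with torch.no_grad():
    zis_feat, zls_feat = model(xis, mzs, intens, num_peaks)
print(zis_feat.shape, zls_feat.shape)                 # torch.Size([2, 128]) for both modalities

The two returned tensors are the projected SMILES-graph and spectrum embeddings that a contrastive objective would compare row-by-row.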