OliXio commited on
Commit
5946936
·
verified ·
1 Parent(s): d65669c

Upload 5 files

Browse files

Use the new loss function and fix a bug in predict.py.

Files changed (5) hide show
  1. code/dataset.py +4 -1
  2. code/loss.py +77 -0
  3. code/modules.py +8 -5
  4. code/predict.py +130 -139
  5. code/train.py +105 -36
code/dataset.py CHANGED
@@ -44,6 +44,9 @@ class Dataset(torch.utils.data.Dataset):
44
  def __getitem__(self, idx):
45
  item = {}
46
  try:
 
 
 
47
  if 'nls' in self.data[idx]:
48
  nls = self.data[idx]['nls']
49
  else:
@@ -111,7 +114,7 @@ class PathDataset(torch.utils.data.Dataset):
111
  item = calc_feats(smi, ms, nls, self.cfg)
112
 
113
  except Exception as e:
114
- print('='*50, idx, str(e))
115
  return None
116
 
117
  return item
 
44
  def __getitem__(self, idx):
45
  item = {}
46
  try:
47
+ if 'ms_bins' in self.data[idx]:
48
+ return self.data[idx]
49
+
50
  if 'nls' in self.data[idx]:
51
  nls = self.data[idx]['nls']
52
  else:
 
114
  item = calc_feats(smi, ms, nls, self.cfg)
115
 
116
  except Exception as e:
117
+ #print('='*50, idx, str(e))
118
  return None
119
 
120
  return item
code/loss.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ def infoNCE_loss1(mol_features, ms_features, temperature=0.1, norm=True):
6
+ # Normalize features
7
+ if norm:
8
+ mol_features = F.normalize(mol_features, p=2, dim=1)
9
+ ms_features = F.normalize(ms_features, p=2, dim=1)
10
+
11
+ # Compute similarity matrix
12
+ logits = torch.mm(mol_features, ms_features.T) / temperature
13
+
14
+ # Labels: positive pairs are on the diagonal
15
+ batch_size = mol_features.size(0)
16
+ labels = torch.arange(batch_size, device=mol_features.device)
17
+
18
+ # Cross entropy loss
19
+ loss_mol = F.cross_entropy(logits, labels)
20
+ loss_trans = F.cross_entropy(logits.T, labels)
21
+ loss = (loss_mol + loss_trans) / 2
22
+
23
+ return loss
24
+
25
+ def infoNCE_loss2(mol_features, ms_features, temperature=0.1, alpha=0.75, norm=True):
26
+ """
27
+ 使用更合适的temperature (0.07是CLIP中常用的值)
28
+ 添加更多的数值稳定性措施
29
+ """
30
+ if norm:
31
+ mol_features = F.normalize(mol_features, p=2, dim=1)
32
+ ms_features = F.normalize(ms_features, p=2, dim=1)
33
+
34
+ batch_size = mol_features.size(0)
35
+
36
+ # 计算相似度矩阵
37
+ logits_ab = torch.matmul(mol_features, ms_features.T) / temperature
38
+ logits_ba = torch.matmul(ms_features, mol_features.T) / temperature
39
+
40
+ # 创建标签
41
+ labels = torch.arange(batch_size, device=mol_features.device)
42
+
43
+ # 计算损失
44
+ loss_ab = F.cross_entropy(logits_ab, labels)
45
+ loss_ba = F.cross_entropy(logits_ba, labels)
46
+
47
+ return alpha * loss_ab + (1 - alpha) * loss_ba
48
+
49
# Contrastive loss augmented with hard-negative mining.
def contrastive_loss_with_hard_negatives(features1, features2, margin=1.0, hard_negative_ratio=0.3):
    """Contrastive loss with hard-negative mining.

    Positive pairs are the diagonal of the similarity matrix; for each row
    the top-k most similar *non-matching* entries are treated as hard
    negatives and penalized when they exceed ``margin``.

    Args:
        features1: (batch, dim) embeddings (assumed pre-normalized if
            cosine semantics are desired — TODO confirm with callers).
        features2: (batch, dim) embeddings, row-aligned with features1.
        margin: negatives with similarity above this are penalized.
        hard_negative_ratio: fraction of the batch mined as hard negatives.

    Returns:
        Scalar loss tensor (mean over the batch).
    """
    batch_size = features1.shape[0]

    # Full pairwise similarity; positives on the diagonal.
    similarity = torch.matmul(features1, features2.t())
    positive_similarity = torch.diag(similarity)

    # With a single sample there are no negatives to mine.
    if batch_size < 2:
        return (1 - positive_similarity).mean()

    # Off-diagonal entries are the negative candidates.  Build the mask on
    # the same device as the features so GPU inputs work.
    mask = ~torch.eye(batch_size, dtype=torch.bool, device=similarity.device)
    negative_similarities = similarity[mask].view(batch_size, batch_size - 1)

    # Keep the top-k hardest negatives per row.  Clamp k into
    # [1, batch_size - 1]: the original int(batch_size * ratio) could be 0
    # for small batches, making topk return an empty tensor and the mean NaN.
    k = max(1, min(batch_size - 1, int(batch_size * hard_negative_ratio)))
    hard_negatives, _ = torch.topk(negative_similarities, k=k, dim=1)

    # Vectorized form of the original per-sample loop:
    #   pos_loss = 1 - sim(positive);  neg_loss = mean(relu(hard - margin))
    pos_loss = 1 - positive_similarity
    neg_loss = torch.clamp(hard_negatives - margin, min=0).mean(dim=1)

    return (pos_loss + neg_loss).mean()
code/modules.py CHANGED
@@ -8,9 +8,6 @@ import numpy as np
8
  from cliplayers import QuickGELU, Transformer as MSTsfmEncoder
9
  from GNN import layers as gly
10
 
11
- loss_func_ms = nn.CrossEntropyLoss()
12
- loss_func = nn.CrossEntropyLoss()
13
-
14
  class MolGNNEncoder(nn.Module):
15
  def __init__(self,
16
  outdim,
@@ -144,10 +141,16 @@ class FragSimiModel(nn.Module):
144
  ms_embeddings = self.ms_projection(ms_features)
145
  mol_embeddings = self.mol_projection(mol_features)
146
 
 
 
 
 
 
 
147
  # Calculating the Loss
148
  #logits = (mol_embeddings @ ms_embeddings.t())
149
  #logit_scale = self.logit_scale.exp()
150
- logits = mol_embeddings @ ms_embeddings.t()
151
 
152
  ground_truth = torch.arange(ms_features.shape[0], dtype=torch.long, device=self.cfg.device)
153
 
@@ -155,4 +158,4 @@ class FragSimiModel(nn.Module):
155
  mol_loss = loss_func(logits.t(), ground_truth)
156
  loss = (ms_loss + mol_loss) / 2.0 # shape: (batch_size)
157
 
158
- return loss.mean()
 
8
  from cliplayers import QuickGELU, Transformer as MSTsfmEncoder
9
  from GNN import layers as gly
10
 
 
 
 
11
  class MolGNNEncoder(nn.Module):
12
  def __init__(self,
13
  outdim,
 
141
  ms_embeddings = self.ms_projection(ms_features)
142
  mol_embeddings = self.mol_projection(mol_features)
143
 
144
+ # Normalize the projected embeddings
145
+ mol_embeddings = F.normalize(mol_embeddings, p=2, dim=1)
146
+ ms_embeddings = F.normalize(ms_embeddings, p=2, dim=1)
147
+
148
+ return mol_embeddings, ms_embeddings
149
+
150
  # Calculating the Loss
151
  #logits = (mol_embeddings @ ms_embeddings.t())
152
  #logit_scale = self.logit_scale.exp()
153
+ '''logits = mol_embeddings @ ms_embeddings.t()
154
 
155
  ground_truth = torch.arange(ms_features.shape[0], dtype=torch.long, device=self.cfg.device)
156
 
 
158
  mol_loss = loss_func(logits.t(), ground_truth)
159
  loss = (ms_loss + mol_loss) / 2.0 # shape: (batch_size)
160
 
161
+ return loss.mean()'''
code/predict.py CHANGED
@@ -9,105 +9,55 @@ import utils
9
  import json
10
  import pandas as pd
11
  import pickle
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- MolFeatsCached = {}
14
-
15
- def calc_mol_embeddings0(model, smis, cfg):
16
- model.eval()
17
-
18
- valid_mol_embeddings = []
19
- with torch.no_grad():
20
- for smi in smis:
21
- try:
22
- mol_features = utils.mol_fp_encoder(smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim).to(cfg.device)
23
- mol_embeddings = model.mol_projection(mol_features.unsqueeze(0))
24
- valid_mol_embeddings.append(mol_embeddings.squeeze(0))
25
- except Exception as e:
26
- print(smi, e)
27
- continue
28
-
29
- return torch.stack(valid_mol_embeddings)
30
-
31
- def calc_mol_embeddings1(model, smis, cfg):
32
- model.eval()
33
- mol_embeddings = []
34
-
35
- with torch.no_grad():
36
- for smi in smis:
37
- try:
38
- if cfg.mol_encoder == 'fp':
39
- k = hash(smi + f'fp-{cfg.fptype}-{cfg.mol_embedding_dim}')
40
- if k in MolFeatsCached:
41
- feats = MolFeatsCached[k]
42
- else:
43
- feats = utils.mol_fp_encoder(smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim).to(cfg.device)
44
- MolFeatsCached[k] = feats
45
- me = model.mol_projection(feats.unsqueeze(0))
46
- mol_embeddings.append(me.squeeze(0))
47
- elif cfg.mol_encoder == 'gnn':
48
- k = hash(smi + 'gnn')
49
- if k in MolFeatsCached:
50
- gfeats = MolFeatsCached[k]
51
- else:
52
- gfeats = utils.mol_graph_featurizer(smi)
53
- MolFeatsCached[k] = gfeats
54
-
55
- bat = {'A': gfeats['A'].unsqueeze(0).to(cfg.device),
56
- 'V': gfeats['V'].unsqueeze(0).to(cfg.device),
57
- 'mol_size': gfeats['mol_size'].unsqueeze(0).to(cfg.device)}
58
-
59
- feats = model.mol_gnn_encoder(bat)
60
- me = model.mol_projection(feats)
61
- mol_embeddings.append(me.squeeze(0))
62
- except Exception as e:
63
- print(smi, e)
64
- continue
65
-
66
- return torch.stack(mol_embeddings)
67
 
68
  def calc_mol_embeddings(model, smis, cfg):
69
  model.eval()
70
  fp_featsl = []
71
  gnn_featsl = []
72
  fm_featsl = []
 
73
 
74
- for smi in smis:
 
75
  try:
76
  if 'gnn' in cfg.mol_encoder:
77
- k = hash(smi + 'gnn')
78
- if k in MolFeatsCached:
79
- gnn_feats = MolFeatsCached[k]
80
- if gnn_feats is None:
81
- continue
82
- else:
83
- gnn_feats = utils.mol_graph_featurizer(smi)
84
- MolFeatsCached[k] = gnn_feats
85
- if gnn_feats is None:
86
- continue
87
  gnn_featsl.append(gnn_feats)
88
  if 'fp' in cfg.mol_encoder:
89
- k = hash(smi + f'fp-{cfg.fptype}-{cfg.mol_embedding_dim}')
90
- if k in MolFeatsCached:
91
- fp_feats = MolFeatsCached[k]
92
- if fp_feats is None:
93
- continue
94
- else:
95
- fp_feats = utils.mol_fp_encoder(smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim).to(cfg.device)
96
- MolFeatsCached[k] = fp_feats
97
  fp_featsl.append(fp_feats)
98
  if 'fm' in cfg.mol_encoder:
99
- k = hash(smi + f'fm-{cfg.fptype}-{cfg.mol_embedding_dim}')
100
- if k in MolFeatsCached:
101
- fm_feats = MolFeatsCached[k]
102
- if fm_feats is None:
103
- continue
104
- else:
105
- fm_feats = utils.smi2fmvec(smi).to(cfg.device)
106
- MolFeatsCached[k] = fm_feats
107
  fm_featsl.append(fm_feats)
 
108
  except Exception as e:
109
  print(smi, e)
110
- MolFeatsCached[k] = None
111
  continue
112
 
113
  mol_feat_list = []
@@ -136,6 +86,8 @@ def calc_mol_embeddings(model, smis, cfg):
136
 
137
  mol_feat_list.append(model.mol_gnn_encoder(bat))
138
 
 
 
139
  if 'fp' in cfg.mol_encoder:
140
  mol_feat_list.append(torch.stack(fp_featsl).to(cfg.device))
141
 
@@ -150,62 +102,85 @@ def calc_mol_embeddings(model, smis, cfg):
150
  with torch.no_grad():
151
  mol_embeddings = model.mol_projection(mol_features)
152
 
153
- return mol_embeddings
154
 
155
- def find_matches(model, ms, smis, cfg, n=10):
 
 
156
  model.eval()
157
  with torch.no_grad():
158
  ms_features = utils.ms_binner(ms, min_mz=cfg.min_mz, max_mz=cfg.max_mz, bin_size=cfg.bin_size, add_nl=cfg.add_nl, binary_intn=cfg.binary_intn).to(cfg.device)
159
  ms_features = ms_features.unsqueeze(0)
160
- ms_embeddings = model.ms_projection(ms_features).squeeze(0)
161
-
162
- #print(43, ms_features.shape, ms_embeddings.shape)
163
 
164
- mol_embeddings = calc_mol_embeddings(model, smis, cfg)
165
-
166
- mol_embeddings_n = F.normalize(mol_embeddings, p=2, dim=-1)
167
- ms_embeddings_n = F.normalize(ms_embeddings, p=2, dim=-1)
168
- dot_similarity = mol_embeddings_n @ ms_embeddings_n.t()
169
 
170
- if n == -1 or n > len(mol_embeddings):
171
- n = len(mol_embeddings)
172
-
173
- values, indices = torch.topk(dot_similarity.squeeze(0), n)
174
-
175
- matchsmis = [smis[idx] for idx in indices]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
- return matchsmis, values.to('cpu').data.numpy()*100, indices.to('cpu').data.numpy()
178
 
179
- def calc(models, datal, cfg, saveout=True):
180
  dicall = {}
181
  coridxd = {}
182
 
183
  for idx, model in enumerate(models):
184
  for nn, data in enumerate(datal):
185
  print(f'Calculating {nn}-th MS...')
186
- #smipool = [d[1] for d in data['candidates'][:50]]
187
- smipool = [d[1] for d in data['candidates']]
188
 
189
  try:
190
- smis, scores, indices = find_matches(model, data['ms'], smipool, cfg, 50)
191
  except Exception as e:
192
  print(131, e)
193
  continue
194
 
195
  dic = {}
196
- for n, smi in enumerate(smis):
 
197
  if smi in dic:
198
- dic[smi]['score'] += scores[n]
199
- dic[smi]['iscor'] = data['candidates'][indices[n]][-1]
200
- dic[smi]['idx'] = data['candidates'][indices[n]][0]
201
  else:
202
- dic[smi] = {'score': scores[n], 'iscor': data['candidates'][indices[n]][-1], 'idx': data['candidates'][indices[n]][0]}
 
 
 
 
 
203
 
204
- ikey = data['ikey']
205
  if ikey in dicall:
206
  for k, v in dic.items():
207
  if k in dicall[ikey]:
208
  dicall[ikey][k]['score'] += v['score']
 
209
  else:
210
  dicall[ikey][k] = v
211
  else:
@@ -223,11 +198,14 @@ def calc(models, datal, cfg, saveout=True):
223
  n = len(scorel)
224
 
225
  values, indices = torch.topk(scoretsor, n)
226
-
227
- scorel = values
228
- smis = [smis[i] for i in indices]
229
- iscorl = [iscorl[i] for i in indices]
230
- indexl = [indexl[i] for i in indices]
 
 
 
231
 
232
  try:
233
  i = iscorl.index(True)
@@ -253,23 +231,42 @@ def calc(models, datal, cfg, saveout=True):
253
  if not k in dc:
254
  dc[k] = [0]
255
 
256
- '''if saveout:
257
- df0 = pd.DataFrame(dc)
258
- df0.to_csv('summary.csv', index=False)
259
 
260
- df = pd.DataFrame({
261
- 'MSFn': ikeysl,
262
- 'Item': iteml,
263
- 'Index': smisidl,
264
- 'Smiles': smis,
265
- 'Score': scoresl,
266
- 'IsCorrect': iscorl})
267
 
268
- df.to_csv('predicted.csv', index=False)'''
 
 
 
 
269
 
270
- return sumtop3, dc, dicall
 
 
 
 
 
 
 
 
 
 
271
 
272
- def test(modelfnl, datal, datafn=''):
 
 
 
 
 
 
 
 
 
 
 
 
273
  maxtop3 = 0
274
  maxoutt = ''
275
 
@@ -281,9 +278,8 @@ def test(modelfnl, datal, datafn=''):
281
 
282
  model = FragSimiModel(CFG).to(CFG.device)
283
  model.load_state_dict(d['state_dict'])
284
- model.to(CFG.device)
285
 
286
- sumtop3, dc, dicall = calc([model], datal, CFG, saveout=False)
287
 
288
  sumtop10 = 0
289
  for k in ['Hit %.3d' %(i+1) for i in range(10)]:
@@ -313,12 +309,12 @@ def test(modelfnl, datal, datafn=''):
313
  maxtop3 = sumtop3
314
  maxoutt = outt
315
 
316
- dicall['testdata'] = datafn
317
- dicall['testrlt'] = outt
318
- pickle.dump(dicall, open(fn.replace('.pth', f'-{os.path.basename(datafn).split(".")[0]}-tstrlt.pkl'), 'wb'))
319
 
320
  df = pd.DataFrame(tops)
321
- df.to_csv(fn.replace('.pth', f'-{os.path.basename(datafn).split(".")[0]}-tstrlt.csv'), index=False)
322
 
323
  return maxoutt, maxtop3
324
 
@@ -326,17 +322,12 @@ def main(datafn, fnl):
326
  outl = []
327
 
328
  datal = json.load(open(datafn))
329
- logfn = f'predict_results.csv'
330
-
331
- if not os.path.exists(logfn):
332
- open(logfn, 'w').write('Index,Results,Model,Data\n')
333
 
334
  n = 0
335
  for n, fn in enumerate(fnl):
336
- out, _ = test([fn], datal, datafn)
337
  print(out, os.path.basename(fn))
338
  outl.append(out)
339
- open(logfn, 'a').write(f'{n},"{out}",{fn},{datafn}\n')
340
 
341
  print(outl)
342
 
 
9
  import json
10
  import pandas as pd
11
  import pickle
12
+ from rdkit import Chem
13
+ from rdkit.Chem import inchi
14
+
15
def smiles_to_inchikey(smiles, nostereo=True):
    """Convert a SMILES string to an InChIKey.

    When ``nostereo`` is true the InChI is generated with the "-SNon"
    option so stereochemistry is ignored (MS generally cannot distinguish
    stereoisomers).  Returns None if parsing or conversion fails.
    """
    try:
        # Parse the SMILES into an RDKit molecule object.
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None

        # Drop stereo layers when requested, then derive the InChI string.
        if nostereo:
            inchi_string = inchi.MolToInchi(mol, options="-SNon")
        else:
            inchi_string = inchi.MolToInchi(mol)

        if not inchi_string:
            return None

        return inchi.InchiToInchiKey(inchi_string)

    except Exception as e:
        print(f"转换失败: {e}")
        return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  def calc_mol_embeddings(model, smis, cfg):
40
  model.eval()
41
  fp_featsl = []
42
  gnn_featsl = []
43
  fm_featsl = []
44
+ valid_smis = []
45
 
46
+ for smil in smis:
47
+ smi = smil[1]
48
  try:
49
  if 'gnn' in cfg.mol_encoder:
50
+ gnn_feats = utils.mol_graph_featurizer(smi)
 
 
 
 
 
 
 
 
 
51
  gnn_featsl.append(gnn_feats)
52
  if 'fp' in cfg.mol_encoder:
53
+ fp_feats = utils.mol_fp_encoder(smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim).to(cfg.device)
 
 
 
 
 
 
 
54
  fp_featsl.append(fp_feats)
55
  if 'fm' in cfg.mol_encoder:
56
+ fm_feats = utils.smi2fmvec(smi).to(cfg.device)
 
 
 
 
 
 
 
57
  fm_featsl.append(fm_feats)
58
+ valid_smis.append(smil)
59
  except Exception as e:
60
  print(smi, e)
 
61
  continue
62
 
63
  mol_feat_list = []
 
86
 
87
  mol_feat_list.append(model.mol_gnn_encoder(bat))
88
 
89
+ del bat
90
+
91
  if 'fp' in cfg.mol_encoder:
92
  mol_feat_list.append(torch.stack(fp_featsl).to(cfg.device))
93
 
 
102
  with torch.no_grad():
103
  mol_embeddings = model.mol_projection(mol_features)
104
 
105
+ del mol_features, mol_feat_list
106
 
107
+ return mol_embeddings, valid_smis
108
+
109
def find_matches(model, ms, smis, cfg, n=10, batch_size=64):
    """Rank candidate molecules against one mass spectrum.

    Embeds the spectrum once, embeds candidates in mini-batches (skipping
    ones that fail featurization), and returns the top-n by cosine
    similarity.

    Args:
        model: trained FragSimiModel (uses ms_projection + mol encoders).
        ms: raw spectrum, passed to utils.ms_binner.
        smis: candidate records; each is a sequence whose [1] element is
            the SMILES string (see calc_mol_embeddings).
        cfg: config with binning / device settings.
        n: number of matches to return; -1 (or n > #valid) means all.
        batch_size: candidates embedded per forward pass.

    Returns:
        (matchsmis, scores, indices): top-n candidate records, their
        similarities scaled by 100 (numpy array), and their indices into
        the valid-candidate list.
    """
    model.eval()
    with torch.no_grad():
        ms_features = utils.ms_binner(ms, min_mz=cfg.min_mz, max_mz=cfg.max_mz, bin_size=cfg.bin_size, add_nl=cfg.add_nl, binary_intn=cfg.binary_intn).to(cfg.device)
        ms_features = ms_features.unsqueeze(0)
        ms_embeddings = model.ms_projection(ms_features)
        ms_embeddings_n = F.normalize(ms_embeddings, p=2, dim=1)

        # Embed candidates batch-by-batch; featurization failures are
        # dropped inside calc_mol_embeddings, so track the survivors.
        all_valid_smis = []
        all_embeddings = []
        for i in tqdm(range(0, len(smis), batch_size)):
            batch_smis = smis[i:i + batch_size]
            batch_embeddings, valid_smis = calc_mol_embeddings(model, batch_smis, cfg)
            all_embeddings.append(batch_embeddings)
            all_valid_smis.extend(valid_smis)

            del batch_embeddings

        # Normalize globally.  (cosine_similarity re-normalizes, so this is
        # numerically redundant but harmless; kept for clarity.)
        all_embeddings = torch.cat(all_embeddings, dim=0)
        all_embeddings_n = F.normalize(all_embeddings, p=2, dim=1)

        # Similarity of every candidate to the (broadcast) spectrum.
        similarities = F.cosine_similarity(all_embeddings_n, ms_embeddings_n, dim=1)

        if n == -1 or n > len(all_valid_smis):
            n = len(all_valid_smis)

        values, topk_indices = torch.topk(similarities, n)

        topk_indices_list = topk_indices.cpu().tolist()
        matchsmis = [all_valid_smis[idx] for idx in topk_indices_list]

    return matchsmis, values.cpu().numpy() * 100, topk_indices_list
 
150
+ def calc(models, datal, cfg):
151
  dicall = {}
152
  coridxd = {}
153
 
154
  for idx, model in enumerate(models):
155
  for nn, data in enumerate(datal):
156
  print(f'Calculating {nn}-th MS...')
 
 
157
 
158
  try:
159
+ smis, scores, indices = find_matches(model, data['ms'], data['candidates'], cfg, 50)
160
  except Exception as e:
161
  print(131, e)
162
  continue
163
 
164
  dic = {}
165
+ for n, smil in enumerate(smis):
166
+ smi = smil[1]
167
  if smi in dic:
168
+ dic[smi]['score'] = scores[n]
169
+ dic[smi]['iscor'] = smis[n][-1]
170
+ dic[smi]['idx'] = smis[n][0]
171
  else:
172
+ dic[smi] = {'score': scores[n], 'iscor': smis[n][-1], 'idx': smis[n][0]}
173
+
174
+ # 计算去除立体构型分子的inchikey,由于质谱很难区分立体构型,我们认为分子的不同立体构型都算正确匹配
175
+ ikey = smiles_to_inchikey(data['smiles'], True)
176
+ if ikey is None:
177
+ ikey = data['ikey']
178
 
 
179
  if ikey in dicall:
180
  for k, v in dic.items():
181
  if k in dicall[ikey]:
182
  dicall[ikey][k]['score'] += v['score']
183
+ dicall[ikey][k]['score'] /= 2
184
  else:
185
  dicall[ikey][k] = v
186
  else:
 
198
  n = len(scorel)
199
 
200
  values, indices = torch.topk(scoretsor, n)
201
+
202
+ # 修复:将张量转换为Python列表
203
+ indices_list = indices.cpu().tolist()
204
+
205
+ scorel = values.cpu().numpy()
206
+ smis = [smis[i] for i in indices_list]
207
+ iscorl = [iscorl[i] for i in indices_list]
208
+ indexl = [indexl[i] for i in indices_list]
209
 
210
  try:
211
  i = iscorl.index(True)
 
231
  if not k in dc:
232
  dc[k] = [0]
233
 
234
+ return sumtop3, dc, dicall
 
 
235
 
236
def calc_rank(dicall):
    """Rank candidates per query and report the hit position of the truth.

    For every query (keyed by InChIKey) the candidates are sorted by score
    and at most the top 100 are kept.  Queries whose correct candidate is
    present get an entry {'Hit': 1-based rank, 'Rank': list of
    "score:smiles:inchikey" strings}; queries whose correct candidate is
    absent are omitted from the result.

    Args:
        dicall: {ikey: {smiles: {'score': float, 'iscor': bool, 'idx': ...}}}

    Returns:
        {ikey: {'Hit': int, 'Rank': [str, ...]}}
    """
    rankd = {}

    for ikey, dic in dicall.items():
        smis = list(dic.keys())
        scorel = [d['score'] for d in dic.values()]
        iscorl = [d['iscor'] for d in dic.values()]

        # Keep at most the 100 best-scoring candidates.
        k = min(100, len(scorel))
        values, indices = torch.topk(torch.tensor(scorel), k)

        # Convert tensors to plain Python values: the original interpolated
        # 0-dim tensors into the f-string below, rendering "tensor(…)".
        scorel = values.tolist()
        order = indices.tolist()
        smis = [smis[i] for i in order]
        iscorl = [iscorl[i] for i in order]

        sl = []
        for rank, smi in enumerate(smis):
            sl.append(f'{scorel[rank]}:{smi}:{smiles_to_inchikey(smi)}')

        try:
            i = iscorl.index(True)  # first (best-ranked) correct candidate
            rankd[ikey] = {'Hit': i + 1, 'Rank': sl}
        except ValueError:
            # Correct molecule not among the kept candidates: skip query.
            pass

    return rankd
268
+
269
+ def predict(modelfnl, datal, datafn=''):
270
  maxtop3 = 0
271
  maxoutt = ''
272
 
 
278
 
279
  model = FragSimiModel(CFG).to(CFG.device)
280
  model.load_state_dict(d['state_dict'])
 
281
 
282
+ sumtop3, dc, dicall = calc([model], datal, CFG)
283
 
284
  sumtop10 = 0
285
  for k in ['Hit %.3d' %(i+1) for i in range(10)]:
 
309
  maxtop3 = sumtop3
310
  maxoutt = outt
311
 
312
+ basefn = fn.replace('.pth', f'-{os.path.basename(datafn).split(".")[0]}')
313
+ rank = calc_rank(dicall)
314
+ json.dump(rank, open(basefn + '-predict-rank.json', 'w'), indent=2)
315
 
316
  df = pd.DataFrame(tops)
317
+ df.to_csv(basefn + '-predict-summary.csv', index=False)
318
 
319
  return maxoutt, maxtop3
320
 
 
322
  outl = []
323
 
324
  datal = json.load(open(datafn))
 
 
 
 
325
 
326
  n = 0
327
  for n, fn in enumerate(fnl):
328
+ out, _ = predict([fn], datal, datafn)
329
  print(out, os.path.basename(fn))
330
  outl.append(out)
 
331
 
332
  print(outl)
333
 
code/train.py CHANGED
@@ -11,6 +11,9 @@ from dataset import *
11
  import torch.utils.data
12
  import copy, json, pickle
13
  import itertools as it
 
 
 
14
 
15
  def make_next_record_dir(basedir, prefix=''):
16
  path = '%s/%%s001/' %basedir
@@ -97,55 +100,91 @@ def build_loaders(inp, mode, cfg, num_workers):
97
  return dataloader
98
 
99
  def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
 
100
  loss_meter = AvgMeter()
101
  tqdm_object = tqdm(train_loader, total=len(train_loader))
 
102
 
103
  for batch in tqdm_object:
104
  for k, v in batch.items():
105
  batch[k] = v.to(CFG.device)
106
 
107
- loss = model(batch)
108
  optimizer.zero_grad()
 
 
 
 
 
109
  loss.backward()
110
  optimizer.step()
 
 
 
 
 
 
 
 
111
  if step == "batch":
112
  lr_scheduler.step()
113
 
114
  count = batch["ms_bins"].size(0)
115
  loss_meter.update(loss.item(), count)
116
 
117
- tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
118
- return loss_meter
 
 
 
 
 
 
 
119
 
120
  def valid_epoch(model, valid_loader):
 
121
  loss_meter = AvgMeter()
 
122
 
123
- tqdm_object = tqdm(valid_loader, total=len(valid_loader))
124
- for batch in tqdm_object:
125
- for k, v in batch.items():
126
- batch[k] = v.to(CFG.device)
 
127
 
128
- loss = model(batch)
129
 
130
- count = batch["ms_bins"].size(0)
131
- loss_meter.update(loss.item(), count)
132
 
133
- tqdm_object.set_postfix(valid_loss=loss_meter.avg)
 
 
 
134
 
135
- return loss_meter
136
 
137
- def main(data, cfg=CFG, savedir='data/train', encmodel=None, ratio=1):
 
 
 
 
 
 
 
 
138
  setup_seed(cfg.seed)
139
 
140
  train_set, valid_set = make_train_valid(data, valid_ratio=cfg.valid_ratio, seed=cfg.seed)
141
 
 
 
142
  n = len(train_set)
143
  if ratio < 1:
144
  train_set = random.sample(train_set, int(n*ratio))
145
  print(f'Ratio {ratio}, lenall {n}, newtrainset {len(train_set)}')
146
 
147
- train_loader = build_loaders(train_set, "train", cfg, 10)
148
- valid_loader = build_loaders(valid_set, "valid", cfg, 10)
149
 
150
  step = "epoch"
151
 
@@ -155,14 +194,6 @@ def main(data, cfg=CFG, savedir='data/train', encmodel=None, ratio=1):
155
 
156
  model = FragSimiModel(cfg).to(cfg.device)
157
 
158
- if not encmodel is None:
159
- model.mol_gnn_encoder.load_state_dict(encmodel.mol_gnn_encoder.state_dict())
160
- # fraze mol_gnn_encoder weights
161
- '''for name, param in model.named_parameters():
162
- if 'mol_gnn_encoder' in name:
163
- print(152, 'fraze mol_gnn_encoder weights')
164
- param.requires_grad = False'''
165
-
166
  print(model)
167
 
168
  optimizer = torch.optim.AdamW(
@@ -173,27 +204,59 @@ def main(data, cfg=CFG, savedir='data/train', encmodel=None, ratio=1):
173
  optimizer, mode="min", patience=cfg.patience, factor=cfg.factor
174
  )
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  for epoch in range(cfg.epochs):
177
  print(f"Epoch: {epoch + 1}/{cfg.epochs}")
178
- model.train()
179
- train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
180
- model.eval()
181
- with torch.no_grad():
182
- valid_loss = valid_epoch(model, valid_loader)
 
183
 
184
- if valid_loss.avg < best_loss:
185
  best_loss = valid_loss.avg
186
- best_model_fn = f"{savedir}/model-tloss{round(train_loss.avg, 3)}-vloss{round(valid_loss.avg, 3)}-epoch{epoch}.pth"
187
  best_model_fn_base = best_model_fn.replace('.pth', '')
188
  n = 1
189
  while os.path.exists(best_model_fn):
190
  best_model_fn = best_model_fn_base + f'-{n}.pth'
191
  n += 1
192
 
193
- checkpoint = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'config': dict(CFG)}
194
  best_model_fns.append(best_model_fn)
195
- torch.save(checkpoint, best_model_fn)
196
- print("Saved Best Model!")
 
 
 
 
 
 
 
 
 
 
197
 
198
  best_model_fnl = []
199
  for fn in best_model_fns:
@@ -209,6 +272,8 @@ def main(data, cfg=CFG, savedir='data/train', encmodel=None, ratio=1):
209
  return best_model_fnl, best_loss
210
 
211
  if __name__ == "__main__":
 
 
212
  try:
213
  conffn = sys.argv[1]
214
  if conffn.endswith('.json'):
@@ -229,7 +294,10 @@ if __name__ == "__main__":
229
 
230
  os.system('mkdir -p %s' %savedir)
231
 
232
- mg = None
 
 
 
233
 
234
  print(CFG)
235
 
@@ -237,6 +305,7 @@ if __name__ == "__main__":
237
  data = [os.path.join(CFG.dataset_path, i) for i in os.listdir(CFG.dataset_path) if i.endswith('mgf')]
238
  elif os.path.isfile(CFG.dataset_path):
239
  if CFG.dataset_path.endswith('.pkl'):
 
240
  data = pickle.load(open(CFG.dataset_path, 'rb'))
241
  else:
242
  data = json.load(open(CFG.dataset_path))
@@ -244,8 +313,8 @@ if __name__ == "__main__":
244
  if not os.path.exists(pklfn):
245
  pickle.dump(data, open(pklfn, 'wb'))
246
 
247
- subdir = make_next_record_dir(savedir, f'train-')
248
  os.system(f'cp -a *py {subdir}; cp -a GNN {subdir}')
249
  CFG.save(f'{subdir}/config.json')
250
 
251
- modelfnl, _ = main(data, CFG, subdir, mg)
 
11
  import torch.utils.data
12
  import copy, json, pickle
13
  import itertools as it
14
+ import loss
15
+
16
+ loss_func = loss.infoNCE_loss2
17
 
18
  def make_next_record_dir(basedir, prefix=''):
19
  path = '%s/%%s001/' %basedir
 
100
  return dataloader
101
 
102
def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
    """Run one training epoch.

    Returns:
        (loss_meter, mean_cos_sim): the running AvgMeter of the contrastive
        loss and the batch-averaged cosine similarity of matched mol/ms
        embedding pairs (a training-progress diagnostic).
    """
    model.train()
    loss_meter = AvgMeter()
    tqdm_object = tqdm(train_loader, total=len(train_loader))
    total_cos_sim = 0.0

    for batch in tqdm_object:
        # Move every tensor in the batch to the configured device.
        for k, v in batch.items():
            batch[k] = v.to(CFG.device)

        optimizer.zero_grad()

        # The model returns the two projected embeddings; the contrastive
        # loss is computed externally via the module-level loss_func.
        mol_features, ms_features = model(batch)

        loss = loss_func(mol_features, ms_features)

        loss.backward()
        optimizer.step()

        # Diagnostic only: cosine similarity of matched pairs.
        with torch.no_grad():
            cos_sim = F.cosine_similarity(
                mol_features.detach(),
                ms_features.detach()
            ).mean().item()
            total_cos_sim += cos_sim

        if step == "batch":
            lr_scheduler.step()

        count = batch["ms_bins"].size(0)
        loss_meter.update(loss.item(), count)

        tqdm_object.set_postfix(train_loss=loss_meter.avg, train_cos_sim=round(cos_sim, 4), lr=get_lr(optimizer))

        # Release tensors eagerly to reduce peak GPU memory between batches.
        del mol_features, ms_features, loss, cos_sim

        for k in list(batch.keys()):
            del batch[k]
        del batch

    # Guard against an empty loader so we never divide by zero.
    return loss_meter, total_cos_sim / max(len(train_loader), 1)
143
 
144
def valid_epoch(model, valid_loader):
    """Run one validation epoch (no gradients).

    Returns:
        (loss_meter, mean_cos_sim): running AvgMeter of the contrastive
        loss and the batch-averaged cosine similarity of matched pairs.
    """
    model.eval()
    loss_meter = AvgMeter()
    total_cos_sim = 0.0

    with torch.no_grad():
        tqdm_object = tqdm(valid_loader, total=len(valid_loader))
        for batch in tqdm_object:
            # Move every tensor in the batch to the configured device.
            for k, v in batch.items():
                batch[k] = v.to(CFG.device)

            mol_features, ms_features = model(batch)

            loss = loss_func(mol_features, ms_features)

            count = batch["ms_bins"].size(0)
            loss_meter.update(loss.item(), count)
            cos_sim = F.cosine_similarity(mol_features.detach(), ms_features.detach()).mean().item()
            total_cos_sim += cos_sim

            tqdm_object.set_postfix(valid_loss=loss_meter.avg, valid_cos_sim=round(cos_sim, 4))

            # Release tensors eagerly to reduce peak GPU memory.
            del mol_features, ms_features, loss, cos_sim

            for k in list(batch.keys()):
                del batch[k]
            del batch

    # Guard against an empty loader so we never divide by zero.
    return loss_meter, total_cos_sim / max(len(valid_loader), 1)
173
+
174
+ def main(data, cfg=CFG, savedir='data/train', model_path=None, ratio=1):
175
  setup_seed(cfg.seed)
176
 
177
  train_set, valid_set = make_train_valid(data, valid_ratio=cfg.valid_ratio, seed=cfg.seed)
178
 
179
+ log_file = f'{savedir}/trainlog.txt'
180
+
181
  n = len(train_set)
182
  if ratio < 1:
183
  train_set = random.sample(train_set, int(n*ratio))
184
  print(f'Ratio {ratio}, lenall {n}, newtrainset {len(train_set)}')
185
 
186
+ train_loader = build_loaders(train_set, "train", cfg, 1)
187
+ valid_loader = build_loaders(valid_set, "valid", cfg, 1)
188
 
189
  step = "epoch"
190
 
 
194
 
195
  model = FragSimiModel(cfg).to(cfg.device)
196
 
 
 
 
 
 
 
 
 
197
  print(model)
198
 
199
  optimizer = torch.optim.AdamW(
 
204
  optimizer, mode="min", patience=cfg.patience, factor=cfg.factor
205
  )
206
 
207
+ # Load pre-trained model if path is provided
208
+ if model_path and os.path.exists(model_path):
209
+ print(f"Loading model from {model_path}")
210
+ checkpoint = torch.load(model_path, map_location=cfg.device)
211
+ model.load_state_dict(checkpoint['state_dict'])
212
+
213
+ '''if 'optimizer' in checkpoint:
214
+ optimizer.load_state_dict(checkpoint['optimizer'])
215
+ print("Loaded optimizer state")'''
216
+
217
+ print(f"Resuming training")
218
+ del checkpoint
219
+
220
+ # write training log
221
+ with open(log_file, 'a', encoding='utf8') as f:
222
+ f.write(f'Start training:\n')
223
+ f.write(f'Data path: {cfg.dataset_path}, valid ratio: {cfg.valid_ratio}\n')
224
+ if model_path:
225
+ f.write(f'Resuming from: {model_path}\n')
226
+ print(model, file=f)
227
+ f.write(f'\n')
228
+
229
  for epoch in range(cfg.epochs):
230
  print(f"Epoch: {epoch + 1}/{cfg.epochs}")
231
+ train_loss, t_cos_sim = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
232
+ valid_loss, v_cos_sim = valid_epoch(model, valid_loader)
233
+
234
+ txt = f"Train Loss: {train_loss.avg:.4f} | Val Loss: {valid_loss.avg:.4f} | Train cos sim: {t_cos_sim:.4f} | Val cos sim: {v_cos_sim:.4f}"
235
+ print(txt)
236
+ open(log_file, 'a').write(f"Epoch {epoch + 1}/{cfg.epochs}: {txt}\n")
237
 
238
+ if True: #valid_loss.avg < best_loss:
239
  best_loss = valid_loss.avg
240
+ best_model_fn = f"{savedir}/model-tloss{round(train_loss.avg, 3)}-vloss{round(valid_loss.avg, 3)}-tcos{round(t_cos_sim, 3)}-vcos{round(v_cos_sim, 3)}-epoch{epoch}.pth"
241
  best_model_fn_base = best_model_fn.replace('.pth', '')
242
  n = 1
243
  while os.path.exists(best_model_fn):
244
  best_model_fn = best_model_fn_base + f'-{n}.pth'
245
  n += 1
246
 
 
247
  best_model_fns.append(best_model_fn)
248
+
249
+ torch.save({'epoch': epoch,
250
+ 'state_dict': model.state_dict(),
251
+ 'optimizer': optimizer.state_dict(),
252
+ 'config': dict(CFG),
253
+ 'train_loss': train_loss.avg,
254
+ 'valid_loss': valid_loss.avg,
255
+ 'train_cos_sim': t_cos_sim,
256
+ 'val_cos_sim': v_cos_sim
257
+ }, best_model_fn)
258
+
259
+ print("Saved new best model!")
260
 
261
  best_model_fnl = []
262
  for fn in best_model_fns:
 
272
  return best_model_fnl, best_loss
273
 
274
  if __name__ == "__main__":
275
+ import pickle
276
+ from tqdm import tqdm
277
  try:
278
  conffn = sys.argv[1]
279
  if conffn.endswith('.json'):
 
294
 
295
  os.system('mkdir -p %s' %savedir)
296
 
297
+ try:
298
+ prev_model_pth = sys.argv[3]
299
+ except:
300
+ prev_model_pth = None
301
 
302
  print(CFG)
303
 
 
305
  data = [os.path.join(CFG.dataset_path, i) for i in os.listdir(CFG.dataset_path) if i.endswith('mgf')]
306
  elif os.path.isfile(CFG.dataset_path):
307
  if CFG.dataset_path.endswith('.pkl'):
308
+ print(f'loading data from {CFG.dataset_path} ...')
309
  data = pickle.load(open(CFG.dataset_path, 'rb'))
310
  else:
311
  data = json.load(open(CFG.dataset_path))
 
313
  if not os.path.exists(pklfn):
314
  pickle.dump(data, open(pklfn, 'wb'))
315
 
316
+ subdir = make_next_record_dir(savedir, f'train-neg-')
317
  os.system(f'cp -a *py {subdir}; cp -a GNN {subdir}')
318
  CFG.save(f'{subdir}/config.json')
319
 
320
+ modelfnl, _ = main(data, CFG, subdir, prev_model_pth)