yzhouchen001 committed on
Commit 2c0063e · 1 Parent(s): 7d8e998

cleaned up

Files changed (39)
  1. {mvp → flare}/__init__.py +0 -0
  2. flare/data/__init__.py +1 -0
  3. {mvp → flare}/data/data_module.py +1 -1
  4. {mvp → flare}/data/datasets.py +2 -21
  5. {mvp → flare}/data/transforms.py +1 -1
  6. {mvp → flare}/data_preprocess.py +1 -1
  7. {mvp → flare}/definitions.py +0 -0
  8. {mvp → flare}/models/__init__.py +0 -0
  9. {mvp → flare}/models/contrastive.py +360 -471
  10. {mvp → flare}/models/encoders.py +0 -0
  11. {mvp → flare}/models/mol_encoder.py +0 -0
  12. {mvp → flare}/models/spec_encoder.py +7 -9
  13. {mvp → flare}/params_binnedSpec.yaml +0 -0
  14. {mvp → flare}/params_formSpec.yaml +42 -46
  15. {mvp → flare}/params_jestr.yaml +0 -0
  16. {mvp → flare}/params_tmp.yaml +0 -0
  17. flare/run.sh +3 -0
  18. {mvp → flare}/subformula_assign/__init__.py +0 -0
  19. {mvp → flare}/subformula_assign/assign_subformulae.py +0 -0
  20. {mvp → flare}/subformula_assign/run.sh +0 -0
  21. {mvp → flare}/subformula_assign/utils/__init__.py +0 -0
  22. {mvp → flare}/subformula_assign/utils/chem_utils.py +0 -0
  23. {mvp → flare}/subformula_assign/utils/parallel_utils.py +0 -0
  24. {mvp → flare}/subformula_assign/utils/parse_utils.py +0 -0
  25. {mvp → flare}/subformula_assign/utils/spectra_utils.py +0 -0
  26. {mvp → flare}/test.py +5 -5
  27. {mvp → flare}/train.py +6 -6
  28. {mvp → flare}/tune.py +5 -5
  29. {mvp → flare}/utils/__init__.py +0 -0
  30. {mvp → flare}/utils/data.py +11 -30
  31. {mvp → flare}/utils/debug.py +0 -0
  32. {mvp → flare}/utils/eval.py +10 -89
  33. flare/utils/general.py +186 -0
  34. {mvp → flare}/utils/loss.py +0 -0
  35. {mvp → flare}/utils/models.py +4 -12
  36. {mvp → flare}/utils/preprocessing.py +1 -1
  37. mvp/data/__init__.py +0 -3
  38. mvp/run.sh +0 -3
  39. mvp/utils/general.py +0 -87
{mvp → flare}/__init__.py RENAMED
File without changes
flare/data/__init__.py ADDED
@@ -0,0 +1 @@
+
{mvp → flare}/data/data_module.py RENAMED
@@ -1,6 +1,6 @@
 from torch.utils.data.dataloader import DataLoader
 from massspecgym.data.data_module import MassSpecDataModule
-from mvp.data.datasets import ContrastiveDataset
+from flare.data.datasets import ContrastiveDataset
 from functools import partial
 from massspecgym.models.base import Stage
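For context, the wiring here is the usual functools.partial pattern: the data module binds the dataset-specific arguments of ContrastiveDataset.collate_fn once and hands the result to a PyTorch DataLoader. A minimal sketch, assuming the collate_fn signature shown in datasets.py below (the batch_size value is illustrative, not from this commit):

    from functools import partial
    from torch.utils.data.dataloader import DataLoader

    from flare.data.datasets import ContrastiveDataset
    from massspecgym.models.base import Stage

    def make_loader(dataset, spec_enc, spectra_view, stage):
        # Bind the static collate arguments up front; the DataLoader then
        # only supplies the list of per-item dicts at iteration time.
        collate = partial(
            ContrastiveDataset.collate_fn,
            spec_enc=spec_enc,
            spectra_view=spectra_view,
            stage=stage,
        )
        return DataLoader(dataset, batch_size=64, collate_fn=collate)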
{mvp → flare}/data/datasets.py RENAMED
@@ -11,7 +11,7 @@ import dgl
 from collections import defaultdict
 from massspecgym.data.transforms import SpecTransform, MolTransform, MolToInChIKey
 from massspecgym.data.datasets import MassSpecDataset
-import mvp.utils.data as data_utils
+import flare.utils.data as data_utils
 from torch.nn.utils.rnn import pad_sequence
 from massspecgym.models.base import Stage
 import pickle
@@ -254,7 +254,7 @@ class ContrastiveDataset(Dataset):
         return item
 
     @staticmethod
-    def collate_fn(batch: T.Iterable[dict], spec_enc: str, spectra_view: str, stage=None, mask_peak_ratio: float = 0.0, aug_cands: bool = False) -> dict:
+    def collate_fn(batch: T.Iterable[dict], spec_enc: str, spectra_view: str, stage=None) -> dict:
         mol_key = 'cand' if stage == Stage.TEST else 'mol'
         non_standard_collate = ['mol', 'cand', 'aug_cands', 'cons_spec', 'aug_cands_fp', 'NL_spec']
         require_pad = False
@@ -314,25 +314,6 @@ class ContrastiveDataset(Dataset):
             n_peaks.append(len(item['NL_spec']))
         collated_batch['NL_spec'] = pad_sequence(peaks, batch_first=True, padding_value=padding_value)
         collated_batch['NL_n_peaks'] = n_peaks
-
-        # mask peaks
-        if mask_peak_ratio > 0.0 and stage == Stage.TRAIN:
-            n_mask_peaks = [math.floor(n_peak * mask_peak_ratio) for n_peak in n_peaks]
-            mask_peak_idx = [np.random.choice(n_peak, n_mask, replace=False) for n_peak, n_mask in zip(n_peaks, n_mask_peaks)]
-            for i, peaks in enumerate(collated_batch[spectra_view]):
-                peaks[mask_peak_idx[i]] = -5.0
-
-        # batch candidates
-        if aug_cands:
-            candidates = \
-                sum([item["aug_cands"] for item in batch], start=[])
-            collated_batch['aug_cands'] = dgl.batch(candidates)
-
-            if 'aug_cands_fp' in batch[0]:
-                cand_fp = [item['aug_cands_fp'] for item in batch]
-                collated_batch['aug_cands_fp'] = torch.flatten(torch.Tensor(cand_fp), end_dim=1)
-
         return collated_batch
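The peak padding used throughout collate_fn is the standard pad_sequence idiom: stack variable-length peak tensors into one padded batch and keep the true lengths so the encoder can rebuild a padding mask. A small self-contained sketch (tensor shapes are illustrative; the -5.0 sentinel mirrors the repadding value used in spec_encoder.py):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Three spectra with 2, 4, and 3 peaks; each peak is a (feature_dim,) row.
    specs = [torch.randn(2, 8), torch.randn(4, 8), torch.randn(3, 8)]
    n_peaks = [s.shape[0] for s in specs]

    # Shape (3, 4, 8): shorter spectra are filled with the sentinel value.
    padded = pad_sequence(specs, batch_first=True, padding_value=-5.0)

    # Boolean mask that is True at padded positions, in the layout a
    # transformer src_key_padding_mask expects.
    mask = torch.arange(padded.shape[1])[None, :] >= torch.tensor(n_peaks)[:, None]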
{mvp → flare}/data/transforms.py RENAMED
@@ -3,7 +3,7 @@ import torch
 import matchms
 from typing import Optional
 from rdkit.Chem import AllChem as Chem
-from mvp.definitions import CHEM_ELEMS_SMALL
+from flare.definitions import CHEM_ELEMS_SMALL
 from massspecgym.data.transforms import MolTransform, SpecTransform, default_matchms_transforms
 from massspecgym.data.transforms import SpecBinner
 import dgllife.utils as chemutils
{mvp → flare}/data_preprocess.py RENAMED
@@ -1,5 +1,5 @@
 import argparse
-from mvp.utils.preprocessing import generate_cons_spec_formulas, generate_cons_spec
+from flare.utils.preprocessing import generate_cons_spec_formulas, generate_cons_spec
 import os
 import pickle
 import pandas as pd
{mvp → flare}/definitions.py RENAMED
File without changes
{mvp → flare}/models/__init__.py RENAMED
File without changes
{mvp → flare}/models/contrastive.py RENAMED
@@ -10,11 +10,11 @@ from massspecgym.models.base import Stage
 from massspecgym import utils
 from torch.nn.utils.rnn import pad_sequence
 
-from mvp.utils.loss import contrastive_loss, cand_spec_sim_loss, fp_loss, cons_spec_loss, filip_loss_with_mask
-import mvp.utils.models as model_utils
-from mvp.utils.general import pad_graph_nodes, filip_similarity_batch
+from flare.utils.loss import contrastive_loss, cand_spec_sim_loss, fp_loss, cons_spec_loss, filip_loss_with_mask
+import flare.utils.models as model_utils
+from flare.utils.general import pad_graph_nodes, filip_similarity_batch
 
-from mvp.models.encoders import CrossAttention
+from flare.models.encoders import CrossAttention
 import torch.nn.functional as F
 
 from torch_geometric.nn import global_mean_pool
@@ -32,62 +32,21 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
         if 'use_NL_spec' not in self.hparams:
             self.hparams.use_NL_spec = False
 
-        # if 'loss_strategy' not in self.hparams:
-        #     self.hparams.loss_strategy = 'static'
-        #     self.hparams.contr_wt = 1.0
-        #     self.hparams.use_contr = True
 
         self.spec_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
         self.mol_enc_model = model_utils.get_mol_encoder(self.hparams.mol_enc, self.hparams)
-
-        # setup loss strategy
-        if self.hparams.model == 'contrastive':
-            self._loss_setup()
-        if self.hparams.pred_fp:
-            self.fp_loss = fp_loss(self.hparams.fp_loss_type)
-            self.fp_pred_model = model_utils.get_fp_pred_model(self.hparams)
-        if self.hparams.use_cons_spec:
-            self.cons_spec_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
-            self.cons_loss = cons_spec_loss(self.hparams.cons_loss_type)
-
+
         self.spec_view = self.hparams.spectra_view
 
         # result storage for testing results
         self.result_dct = defaultdict(lambda: defaultdict(list))
-
-        # def _loss_setup(self):
-        #     self.loss_wts = {}
-        #     self.loss_updates = {}
-
-        #     for p, loss in zip(['use_contr','pred_fp', 'use_cons_spec', 'aug_cands'], ['contr_wt','fp_wt','cons_spec_wt' ,'aug_cands_wt']):
-        #         if p not in self.hparams:
-        #             self.hparams[p] = False
-        #         if self.hparams[p]:
-        #             if self.hparams.loss_strategy == 'linear':
-        #                 start_wt = self.hparams[loss+'_update']['start']
-        #                 end_wt = self.hparams[loss+'_update']['end']
-        #                 change = (end_wt - start_wt)/self.hparams.max_epochs
-        #                 self.loss_updates[loss] = change
-        #                 self.loss_wts[loss] = start_wt
-        #             elif self.hparams.loss_strategy == 'manual':
-        #                 self.loss_updates[loss] = self.hparams[loss+'_update']
-        #                 self.loss_wts[loss] = self.hparams[loss]
-        #             else:
-        #                 self.loss_wts[loss] = self.hparams[loss]
 
     def forward(self, batch, stage):
         g = batch['cand'] if stage == Stage.TEST else batch['mol']
 
-        if self.hparams.use_cons_spec and stage != Stage.TEST:
-            spec = batch['cons_spec']
-            n_peaks = batch['cons_n_peaks'] if 'cons_n_peaks' in batch else None
-            spec_enc = self.cons_spec_enc_model(spec, n_peaks)
-        else:
-            spec = batch[self.spec_view]
-            n_peaks = batch['n_peaks'] if 'n_peaks' in batch else None
-            spec_enc = self.spec_enc_model(spec, n_peaks)
+        spec = batch[self.spec_view]
+        n_peaks = batch['n_peaks'] if 'n_peaks' in batch else None
+        spec_enc = self.spec_enc_model(spec, n_peaks)
 
         fp = batch['fp'] if self.hparams.use_fp else None
         mol_enc = self.mol_enc_model(g, fp=fp)
@@ -98,26 +57,24 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
         loss = 0
         losses = {}
         contr_loss, _, _ = contrastive_loss(spec_enc, mol_enc, self.hparams.contr_temp)
-        # contr_loss = self.loss_wts['contr_wt'] * contr_loss
+
         losses['contr_loss'] = contr_loss.detach().item()
-        # losses['cong_loss'] = cong_loss.detach().item()
-        # losses['noncong_loss'] = noncong_loss.detach().item()
-
+
         loss+=contr_loss
-        if self.hparams.pred_fp:
-            fp_loss_val = self.loss_wts['fp_wt'] * self.fp_loss(output['fp'], batch['fp'])
-            loss+= fp_loss_val
-            losses['fp_loss'] = fp_loss_val.detach().item()
-
-        if 'aug_cand_enc' in output:
-            aug_cand_loss = self.loss_wts['aug_cand_wt'] * cand_spec_sim_loss(spec_enc, output['aug_cand_enc'])
-            loss+= aug_cand_loss
-            losses['aug_cand_loss'] = aug_cand_loss.detach().item()
+        # if self.hparams.pred_fp:
+        #     fp_loss_val = self.loss_wts['fp_wt'] * self.fp_loss(output['fp'], batch['fp'])
+        #     loss+= fp_loss_val
+        #     losses['fp_loss'] = fp_loss_val.detach().item()
+
+        # if 'aug_cand_enc' in output:
+        #     aug_cand_loss = self.loss_wts['aug_cand_wt'] * cand_spec_sim_loss(spec_enc, output['aug_cand_enc'])
+        #     loss+= aug_cand_loss
+        #     losses['aug_cand_loss'] = aug_cand_loss.detach().item()
 
-        if 'ind_spec' in output:
-            spec_loss = self.loss_wts['cons_spec_wt'] * self.cons_loss(spec_enc, output['ind_spec'])
-            loss+=spec_loss
-            losses['cons_spec_loss'] = spec_loss.detach().item()
+        # if 'ind_spec' in output:
+        #     spec_loss = self.loss_wts['cons_spec_wt'] * self.cons_loss(spec_enc, output['ind_spec'])
+        #     loss+=spec_loss
+        #     losses['cons_spec_loss'] = spec_loss.detach().item()
 
         losses['loss'] = loss
@@ -158,62 +115,6 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
             on_epoch=True,
             # on_step=True
         )
-
-        # contr loss
-        if self.hparams.use_contr:
-            self.log(
-                f'{stage.to_pref()}contr_loss',
-                outputs['contr_loss'],
-                batch_size=len(batch['identifier']),
-                sync_dist=True,
-                prog_bar=False,
-                on_epoch=True,
-                # on_step=True
-            )
-
-        # noncongruent pairs
-        self.log(
-            f'{stage.to_pref()}noncong_loss',
-            outputs['noncong_loss'],
-            batch_size=len(batch['identifier']),
-            sync_dist=True,
-            prog_bar=False,
-            on_epoch=True,
-            # on_step=True
-        )
-
-        # congruent pairs
-        self.log(
-            f'{stage.to_pref()}cong_loss',
-            outputs['cong_loss'],
-            batch_size=len(batch['identifier']),
-            sync_dist=True,
-            prog_bar=False,
-            on_epoch=True,
-            # on_step=True
-        )
-
-        if self.hparams.pred_fp:
-            self.log(
-                f'{stage.to_pref()}_fp_loss',
-                outputs['fp_loss'],
-                batch_size=len(batch['identifier']),
-                sync_dist=True,
-                prog_bar=False,
-                on_epoch=True,
-            )
-
-        if self.hparams.use_cons_spec:
-            self.log(
-                f'{stage.to_pref()}cons_loss',
-                outputs['cons_spec_loss'],
-                batch_size=len(batch['identifier']),
-                sync_dist=True,
-                prog_bar=False,
-                on_epoch=True,
-            )
 
     def test_step(self, batch, batch_idx):
         # Unpack inputs
@@ -275,172 +176,160 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
             {"monitor": f"{Stage.VAL.to_pref()}loss", "mode": "min", "early_stopping": False},  # monitor val loss
         ]
         return monitors
-
-    # def _update_loss_weights(self) -> None:
-    #     if self.hparams.loss_strategy == 'linear':
-    #         for loss in self.loss_wts:
-    #             self.loss_wts[loss] += self.loss_updates[loss]
-    #     elif self.hparams.loss_strategy == 'manual':
-    #         for loss in self.loss_wts:
-    #             if self.current_epoch in self.loss_updates[loss]:
-    #                 self.loss_wts[loss] = self.loss_updates[loss][self.current_epoch]
-
-    # def on_train_epoch_end(self) -> None:
-    #     self._update_loss_weights()
-
-class MultiViewContrastive(ContrastiveModel):
-
-    def __init__(self,
-                 **kwargs):
-
-        super().__init__(**kwargs)
-
-        # build fingerprint encoder model
-        if self.hparams.use_fp:
-            self.fp_enc_model = model_utils.get_fp_enc_model(self.hparams)
-
-        # build NL encoder model
-        if self.hparams.use_NL_spec:
-            self.NL_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
-
-    def forward(self, batch, stage):
-        g = batch['cand'] if stage == Stage.TEST else batch['mol']
-
-        spec = batch[self.spec_view]
-        n_peaks = batch['n_peaks'] if 'n_peaks' in batch else None
-
-        spec_enc = self.spec_enc_model(spec, n_peaks)
-        mol_enc = self.mol_enc_model(g)
-        views = {'spec_enc': spec_enc, 'mol_enc': mol_enc}
-
-        if self.hparams.use_fp:
-            fp_enc = self.fp_enc_model(batch['fp'])
-            views['fp_enc'] = fp_enc
-
-        if self.hparams.use_cons_spec:
-            spec = batch['cons_spec']
-            n_peaks = batch['cons_n_peaks'] if 'cons_n_peaks' in batch else None
-            spec_enc = self.cons_spec_enc_model(spec, n_peaks)
-            views['cons_spec_enc'] = spec_enc
-
-        if self.hparams.use_NL_spec:
-            spec = batch['NL_spec']
-            n_peaks = batch['NL_n_peaks'] if 'NL_n_peaks' in batch else None
-            spec_enc = self.NL_enc_model(spec, n_peaks)
-            views['NL_spec_enc'] = spec_enc
-        return views
-
-    def step(
-        self, batch: dict, stage= Stage.NONE):
-
-        # Compute spectra and mol encoding
-        views = self.forward(batch, stage)
-
-        if stage == Stage.TEST:
-            return views
-
-        # Calculate loss
-        losses = self.compute_loss(batch, views)
-
-        return losses
-
-    def compute_loss(self, batch: dict, views: dict):
-        loss = 0
-        losses = {}
-        for v1, v2 in self.hparams.contr_views:
-            contr_loss, cong_loss, noncong_loss = contrastive_loss(views[v1], views[v2], self.hparams.contr_temp)
-            loss+=contr_loss
-
-            losses[f'{v1[:-4]}-{v2[:-4]}_contr_loss'] = contr_loss.detach().item()
-            losses[f'{v1[:-4]}-{v2[:-4]}_cong_loss'] = cong_loss.detach().item()
-            losses[f'{v1[:-4]}-{v2[:-4]}_noncong_loss'] = noncong_loss.detach().item()
-
-        losses['loss'] = loss
-
-        return losses
-
-    def on_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage) -> None:
-        # total loss
-        self.log(
-            f'{stage.to_pref()}loss',
-            outputs['loss'],
-            batch_size=len(batch['identifier']),
-            sync_dist=True,
-            prog_bar=True,
-            on_epoch=True,
-            # on_step=True
-        )
-
-        for v1, v2 in self.hparams.contr_views:
-            self.log(
-                f'{stage.to_pref()}{v1[:-4]}-{v2[:-4]}_contr_loss',
-                outputs[f'{v1[:-4]}-{v2[:-4]}_contr_loss'],
-                batch_size=len(batch['identifier']),
-                sync_dist=True,
-                on_epoch=True,
-            )
-            self.log(
-                f'{stage.to_pref()}{v1[:-4]}-{v2[:-4]}_cong_loss',
-                outputs[f'{v1[:-4]}-{v2[:-4]}_cong_loss'],
-                batch_size=len(batch['identifier']),
-                sync_dist=True,
-                on_epoch=True,
-            )
-            self.log(
-                f'{stage.to_pref()}{v1[:-4]}-{v2[:-4]}_noncong_loss',
-                outputs[f'{v1[:-4]}-{v2[:-4]}_noncong_loss'],
-                batch_size=len(batch['identifier']),
-                sync_dist=True,
-                on_epoch=True,
-            )
-
-    def test_step(self, batch):
-        # Unpack inputs
-        identifiers = batch['identifier']
-        cand_smiles = batch['cand_smiles']
-        id_to_ct = defaultdict(int)
-        for i in identifiers: id_to_ct[i]+=1
-        batch_ptr = torch.tensor(list(id_to_ct.values()))
-
-        outputs = self.step(batch, stage=Stage.TEST)
-        scores = {}
-        for v1, v2 in self.hparams.contr_views:
-            # if 'cons_spec_enc' in (v1, v2):
-            #     continue
-            v1_enc = outputs[v1]
-            v2_enc = outputs[v2]
-
-            s = nn.functional.cosine_similarity(v1_enc, v2_enc)
-            scores[f'{v1[:-4]}-{v2[:-4]}_scores'] = torch.split(s, list(id_to_ct.values()))
-
-        indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
-
-        cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
-        labels = utils.unbatch_list(batch['label'], indexes)
-
-        return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)
-
-    def on_test_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage = Stage.TEST) -> None:
-
-        # save scores
-        for i, cands, l in zip(outputs['identifiers'], outputs['cand_smiles'], outputs['labels']):
-            self.result_dct[i]['candidates'].extend(cands)
-            self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])
-
-        for v1, v2 in self.hparams.contr_views:
-            for i, scores in zip(outputs['identifiers'], outputs['scores'][f'{v1[:-4]}-{v2[:-4]}_scores']):
-                self.result_dct[i][f'{v1[:-4]}-{v2[:-4]}_scores'].extend(scores.cpu().tolist())
-
-    def on_test_epoch_end(self) -> None:
-
-        self.df_test = pd.DataFrame.from_dict(self.result_dct, orient='index').reset_index().rename(columns={'index': 'identifier'})
-
-        # Compute rank
-        for v1, v2 in self.hparams.contr_views:
-            self.df_test[f'{v1[:-4]}-{v2[:-4]}_rank'] = self.df_test.apply(lambda row: self._compute_rank(row[f'{v1[:-4]}-{v2[:-4]}_scores'], row['labels']), axis=1)
-
-        self.df_test.to_pickle(self.df_test_path)
+
+# class MultiViewContrastive(ContrastiveModel):
+#     ... (the class body removed above, retained line-for-line as comments)
 
 class FilipContrastive(ContrastiveModel):
     def __init__(self,
@@ -492,7 +381,7 @@ class FilipContrastive(ContrastiveModel):
         # Calculate scores
         indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
 
-        scores = filip_similarity_batch(spec_enc, mol_enc, spec_mask, mol_mask)
+        scores = filip_similarity_batch(spec_enc, mol_enc, spec_mask, mol_mask, reduction='geom', temperature=0.05)
         scores = torch.split(scores, list(id_to_ct.values()))
 
         cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
@@ -500,248 +389,248 @@ class FilipContrastive(ContrastiveModel):
 
         return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)
 
-class MultiViewFineTuning(MultiViewContrastive):
-    def __init__(self,
-                 **kwargs):
-        super().__init__(**kwargs)
-
-        # load pretrained spec, mol, fp encoders
-        checkpoint = torch.load(self.hparams.partial_checkpoint)
-        state_dict = {k[len("spec_enc_model."):]: v for k, v in checkpoint['state_dict'].items() if k.startswith("spec_enc_model")}
-        self.spec_enc_model.load_state_dict(state_dict)  # trained on consensus spectra
-
-        state_dict = {k[len("mol_enc_model."):]: v for k, v in checkpoint['state_dict'].items() if k.startswith("mol_enc_model")}
-        self.mol_enc_model.load_state_dict(state_dict)
-
-        state_dict = {k[len("fp_enc_model."):]: v for k, v in checkpoint['state_dict'].items() if k.startswith("fp_enc_model")}
-        self.fp_enc_model.load_state_dict(state_dict)
-
-        self.encoding_views = ['spec_enc', 'mol_enc', 'fp_enc']
-        self.loss_fn = nn.BCELoss()
-
-        # freeze encoders
-        for param in self.mol_enc_model.parameters():
-            param.requires_grad = False
-        for param in self.spec_enc_model.parameters():
-            param.requires_grad = False
-        for param in self.fp_enc_model.parameters():
-            param.requires_grad = False
-        for param in self.cons_spec_enc_model.parameters():
-            param.requires_grad = False
-
-        # n_views = 2
-        # if self.hparams.use_fp:
-        #     n_views+=1
-
-        # in_dim = self.hparams.final_embedding_dim*n_views
-        in_dim = self.hparams.final_embedding_dim * 2 + 2
-
-        self.classifier_model = nn.Sequential(
-            nn.Linear(in_dim, 512),
-            nn.ReLU(),
-            nn.BatchNorm1d(512),
-            nn.Dropout(0.3),
-            nn.Linear(512, 256),
-            nn.ReLU(),
-            nn.BatchNorm1d(256),
-            nn.Dropout(0.3),
-            nn.Linear(256, 1),
-            nn.Sigmoid()
-        )
-        self.noise_std = 0.01
-
-    def _add_noise(self, x):
-        noise = torch.randn_like(x) * self.noise_std
-        return x + noise
-
-    def forward(self, batch, stage):
-
-        matching_views = super().forward(batch, stage)
-        # matching_enc = torch.concat((matching_views['spec_enc'], matching_views['mol_enc'], matching_views['fp_enc']), dim=-1)
-        # enc1 = matching_views['spec_enc'] - matching_views['mol_enc']
-        # enc2 = matching_views['spec_enc'] - matching_views['fp_enc']
-        # matching_enc = torch.concat((enc1, enc2), dim=-1)
-        view1 = matching_views['spec_enc']
-        view2 = matching_views['mol_enc']
-        view3 = matching_views['fp_enc']
-
-        if stage == Stage.TRAIN:
-            view1, view2, view3 = map(self._add_noise, (view1, view2, view3))
-
-        pairwise_diffs = torch.cat([
-            torch.abs(view1 - view2),
-            torch.abs(view1 - view3),
-        ], dim=-1)
-
-        pairwise_sims = torch.cat([
-            (view1 * view2).sum(dim=-1, keepdim=True),
-            (view1 * view3).sum(dim=-1, keepdim=True),
-        ], dim=-1)
-
-        matching_enc = torch.cat([pairwise_diffs, pairwise_sims], dim=-1)
-        matching_scores = self.classifier_model(matching_enc)
-
-        if stage == Stage.TEST:
-            return dict(matching_scores=matching_scores)
-
-        view1 = view1.repeat_interleave(self.hparams.aug_cands_size, dim=0)
-        view2 = self.mol_enc_model(batch['aug_cands'])
-        view3 = self.fp_enc_model(batch['aug_cands_fp'])
-        if stage == Stage.TRAIN:
-            view1, view2, view3 = map(self._add_noise, (view1, view2, view3))
-
-        pairwise_diffs = torch.cat([
-            torch.abs(view1 - view2),
-            torch.abs(view1 - view3),
-        ], dim=-1)
-
-        pairwise_sims = torch.cat([
-            (view1 * view2).sum(dim=-1, keepdim=True),
-            (view1 * view3).sum(dim=-1, keepdim=True),
-        ], dim=-1)
-
-        nonmatching_enc = torch.cat([pairwise_diffs, pairwise_sims], dim=-1)
-
-        nonmatching_scores = self.classifier_model(nonmatching_enc)
-
-        return dict(matching_scores=matching_scores, nonmatching_scores=nonmatching_scores)
-
-    def compute_loss(self, matching_scores, nonmatching_scores):
-
-        matching_loss = self.loss_fn(matching_scores, torch.ones_like(matching_scores).to(matching_scores.device))
-        nonmatching_loss = self.loss_fn(nonmatching_scores, torch.zeros_like(nonmatching_scores).to(nonmatching_scores.device))
-
-        loss = matching_loss + (1/self.hparams.aug_cands_size)*nonmatching_loss
-
-        return dict(loss=loss)
-
-    def step(
-        self, batch: dict, stage= Stage.NONE):
-
-        output = self.forward(batch, stage)
-
-        if stage == Stage.TEST:
-            return output
-
-        # Calculate loss
-        losses = self.compute_loss(output['matching_scores'], output['nonmatching_scores'])
-
-        return losses
-
-    def test_step(self, batch):
-        # Unpack inputs
-        identifiers = batch['identifier']
-        cand_smiles = batch['cand_smiles']
-        id_to_ct = defaultdict(int)
-        for i in identifiers: id_to_ct[i]+=1
-        batch_ptr = torch.tensor(list(id_to_ct.values()))
-
-        outputs = self.step(batch, stage=Stage.TEST)
-        scores = outputs['matching_scores']
-
-        indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
-
-        cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
-        labels = utils.unbatch_list(batch['label'], indexes)
-
-        return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)
-
-    def on_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage) -> None:
-        # total loss
-        self.log(
-            f'{stage.to_pref()}loss',
-            outputs['loss'],
-            batch_size=len(batch['identifier']),
-            sync_dist=True,
-            prog_bar=True,
-            on_epoch=True,
-            # on_step=True
-        )
-
-    def on_test_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage = Stage.TEST) -> None:
-        ContrastiveModel.on_test_batch_end(self, outputs, batch, batch_idx, stage)
-
-    def on_test_epoch_end(self):
-        self.df_test = pd.DataFrame.from_dict(self.result_dct, orient='index').reset_index().rename(columns={'index': 'identifier'})
-        print(self.df_test_path)
-        self.df_test.to_pickle(self.df_test_path)
-        # ContrastiveModel.on_test_epoch_end(self)
-
-    def get_checkpoint_monitors(self) -> T.List[dict]:
-        monitors = [
-            {"monitor": f"{Stage.VAL.to_pref()}loss", "mode": "min", "early_stopping": True}
-        ]
-        return monitors
-
-    def configure_optimizers(self):
-        return torch.optim.Adam(
-            self.classifier_model.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
-        )
-
-class IndSpecEncoder(ContrastiveModel):
-    """Trains a spectra encoder that maps to a pretrained spec encoder"""
-    def __init__(
-        self,
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-
-        # initialize ind_spec_encoder and loss
-        self.ind_spec_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
-        self.cons_loss = cons_spec_loss(self.hparams.cons_loss_type)
-
-        # load pretrained spec and mol encoders
-        checkpoint = torch.load(self.hparams.partial_checkpoint)
-        state_dict = {k[len("spec_enc_model."):]: v for k, v in checkpoint['state_dict'].items() if k.startswith("spec_enc_model")}
-        self.spec_enc_model.load_state_dict(state_dict)  # trained on consensus spectra
-
-        state_dict = {k[len("mol_enc_model."):]: v for k, v in checkpoint['state_dict'].items() if k.startswith("mol_enc_model")}
-        self.mol_enc_model.load_state_dict(state_dict)
-
-        # freeze cons spec and mol encoders
-        for param in self.mol_enc_model.parameters():
-            param.requires_grad = False
-        for param in self.spec_enc_model.parameters():
-            param.requires_grad = False
-
-    def forward(self, batch, stage):
-
-        spec = batch[self.spec_view]
-        n_peaks = batch['n_peaks']
-        spec_enc = self.ind_spec_enc_model(spec, n_peaks)
-
-        return spec_enc
-
-    def compute_loss(self, spec_enc, cons_spec_enc):
-        loss = self.cons_loss(spec_enc, cons_spec_enc)
-        return dict(loss=loss)
-
-    def step(self, batch: dict, stage=Stage.NONE):
-        self.spec_enc_model.eval()
-        self.mol_enc_model.eval()
-
-        spec_enc = self.forward(batch, stage)
-
-        if stage == Stage.TEST:
-            mol_enc = self.mol_enc_model(batch['cand'])
-            return dict(spec_enc=spec_enc, mol_enc=mol_enc)
-
-        cons_spec_enc = self.spec_enc_model(batch['cons_spec'], batch['cons_n_peaks'])
-
-        losses = self.compute_loss(spec_enc, cons_spec_enc)
-
-        return losses
-
-    def configure_optimizers(self):
-        return torch.optim.Adam(
-            self.ind_spec_enc_model.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
-        )
-
-    def get_checkpoint_monitors(self) -> T.List[dict]:
-        monitors = [
-            {"monitor": f"{Stage.VAL.to_pref()}loss", "mode": "min", "early_stopping": True}
-        ]
-        return monitors
+
+# class MultiViewFineTuning(MultiViewContrastive):
+#     ... (the class body removed above, retained line-for-line as comments)
+
+# class IndSpecEncoder(ContrastiveModel):
+#     ... (the class body removed above, retained line-for-line as comments)
 
 class CrossAttenContrastive(ContrastiveModel):
     def __init__(
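The objective that survives this cleanup is the symmetric contrastive loss over spectrum/molecule embedding pairs. flare.utils.loss.contrastive_loss itself is not part of this diff, so the sketch below is an assumption about its shape rather than the repository's implementation: a temperature-scaled cosine-similarity matrix whose diagonal holds the congruent (matching) pairs, returning the three values the calling code unpacks:

    import torch
    import torch.nn.functional as F

    def contrastive_loss_sketch(spec_enc, mol_enc, temperature):
        # Cosine similarity between every spectrum and every molecule in the
        # batch, scaled by the temperature.
        spec = F.normalize(spec_enc, dim=-1)
        mol = F.normalize(mol_enc, dim=-1)
        logits = spec @ mol.T / temperature

        # The i-th spectrum matches the i-th molecule (diagonal targets);
        # averaging both directions makes the loss symmetric.
        targets = torch.arange(logits.shape[0], device=logits.device)
        loss = 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets))

        # Mean congruent vs. non-congruent similarity, for logging only.
        cong = logits.diag().mean()
        noncong = (logits.sum() - logits.diag().sum()) / (logits.numel() - logits.shape[0])
        return loss, cong, noncong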
{mvp → flare}/models/encoders.py RENAMED
File without changes
{mvp → flare}/models/mol_encoder.py RENAMED
File without changes
{mvp → flare}/models/spec_encoder.py RENAMED
@@ -1,6 +1,6 @@
 import torch.nn as nn
 import torch
-from mvp.models.encoders import MLP
+from flare.models.encoders import MLP
 from torch_geometric.nn import global_mean_pool
 
 
@@ -41,14 +41,14 @@ class SpecMzIntTokenTransformer(nn.Module):
         if args.model in ('crossAttenContrastive', 'filipContrastive'):
             self.returnEmb = True
             assert(args.use_cls == False)
+        else:
+            self.specEncoder = nn.Sequential(nn.Linear(args.hidden_dims[-1], args.final_embedding_dim), nn.Dropout(args.fc_dropout))
 
         self.use_cls = args.use_cls
         if self.use_cls:
             self.cls_embed = torch.nn.Embedding(1, args.hidden_dims[-1])
         encoder_layer = nn.TransformerEncoderLayer(d_model=args.hidden_dims[-1], nhead=2, batch_first=True)
         self.tokenTransformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
-
-        self.specEncoder = nn.Sequential(nn.Linear(args.hidden_dims[-1], args.final_embedding_dim), nn.Dropout(args.fc_dropout))
 
     def forward(self, spec, n_peaks=None):
         h = self.tokenEnc(spec)
@@ -61,11 +61,10 @@ class SpecMzIntTokenTransformer(nn.Module):
             pad = torch.concat((torch.tensor(False).repeat(pad.shape[0], 1).to(spec.device), pad), dim=1)
             h = self.tokenTransformer(h, src_key_padding_mask=pad)
             h = h[:,0,:]
-        else:
 
+        else:
             # mean
             h = self.tokenTransformer(h, src_key_padding_mask=pad)
-
         if self.returnEmb:
             # repad h
             h[pad] = -5
@@ -123,11 +122,10 @@ class SpecFormulaTransformer(nn.Module):
             self.cls_embed = torch.nn.Embedding(1, args.formula_dims[-1])
         encoder_layer = nn.TransformerEncoderLayer(d_model=args.formula_dims[-1], nhead=args.formula_attn_heads, batch_first=True)
         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=args.formula_transformer_layers)
-
-        if not out_dim:
-            out_dim = args.final_embedding_dim
-
+
         if not self.returnEmb:
+            if not out_dim:
+                out_dim = args.final_embedding_dim
             self.fc = nn.Linear(args.formula_dims[-1], out_dim)
 
     def forward(self, spec, n_peaks):
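The returnEmb path above keeps per-peak token embeddings (repadded with -5) instead of a pooled vector, which is what filip_similarity_batch in contrastive.py consumes. That function's implementation is not part of this commit, so the following is only a sketch of FILIP-style late interaction under assumed (batch, tokens, dim) inputs and boolean padding masks, reading reduction='geom' as a geometric mean over per-token maxima (the softplus positivity trick is likewise an assumption, not the repository's code):

    import torch
    import torch.nn.functional as F

    def filip_similarity_sketch(spec_emb, mol_emb, spec_mask, mol_mask, temperature=0.05):
        # spec_emb: (B, S, D) peak tokens; mol_emb: (B, M, D) atom tokens.
        # Masks are True at padded positions.
        sim = torch.einsum('bsd,bmd->bsm', spec_emb, mol_emb)
        sim = sim.masked_fill(mol_mask[:, None, :], float('-inf'))

        # Late interaction: each peak token is scored by its best atom token.
        per_peak = sim.max(dim=-1).values / temperature
        per_peak = per_peak.masked_fill(spec_mask, 0.0)

        # Geometric mean over the real (unpadded) peaks, computed in log space.
        n_real = (~spec_mask).sum(dim=-1).clamp(min=1)
        logs = F.softplus(per_peak).log().masked_fill(spec_mask, 0.0)
        return (logs.sum(dim=-1) / n_real).exp()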
{mvp → flare}/params_binnedSpec.yaml RENAMED
File without changes
{mvp β†’ flare}/params_formSpec.yaml RENAMED
@@ -1,11 +1,11 @@
1
  # Experiment setup
2
  job_key: ''
3
- run_name: 'optimized_flare'
4
  run_details: ""
5
  project_name: ''
6
  wandb_entity_name: 'mass-spec-ml'
7
  no_wandb: True
8
- seed: 0
9
  debug: False
10
  checkpoint_pth:
11
 
@@ -19,27 +19,23 @@ val_check_interval: 1.0
19
  # Data paths
20
  candidates_pth: /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_mass.json # "../data/MassSpecGym/data/molecules/MassSpecGym_retrieval_candidates_formula.json"
21
  dataset_pth: /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv # /data/yzhouc01/MVP/data/sample/data.tsv #/r/hassounlab/spectra_data/msgym/MassSpecGym.tsv #/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex.tsv # /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv # "../data/MassSpecGym/data/sample_data.tsv"
22
- subformula_dir_pth: /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default # /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default # /data/yzhouc01/FILIP-MS/data/magma # /r/hassounlab/msgym_sirius # /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default #/data/yzhouc01/spectra_data/subformulae #"../data/MassSpecGym/data/subformulae_default"
23
  split_pth:
24
  fp_dir_pth:
25
- cons_spec_dir_pth:
26
- NL_spec_dir_pth: ""
27
  partial_checkpoint: ""
28
 
29
  # General hyperparameters
30
- batch_size: 32 #64
31
- lr: 7.092216555765765e-05 #2.881339661302105e-05 # 5.0e-05
32
  weight_decay: 1.8376229667330708e-05
33
- contr_temp: 0.043339030104611806 # 0.022772534845886608 # 0.05
34
- early_stopping_patience: 300
35
- loss_strategy: 'static'
36
  num_workers: 50
37
 
38
 
39
  ############################## Data transforms ##############################
40
  # - Spectra
41
  spectra_view: SpecFormula #SpecMzIntTokens #SpecFormula
42
- formula_source: 'default' # magma_1, magma_all, sirius, default
43
  # 1. Binner
44
  max_mz: 1000
45
  bin_width: 1
@@ -48,7 +44,6 @@ mask_peak_ratio: 0.00
48
  # 2. SpecFormula
49
  element_list: ['H', 'C', 'O', 'N', 'P', 'S', 'Cl', 'F', 'Br', 'I', 'B', 'As', 'Si', 'Se']
50
  add_intensities: True
51
- mask_precursor: False
52
 
53
  # - Molecule
54
  molecule_view: "MolGraph"
@@ -58,34 +53,34 @@ bond_feature: 'full'
58
 
59
  ############################## Views ##############################
60
  # contrastive
61
- use_contr: False
62
- contr_wt: 1
63
- contr_wt_update: {}
64
 
65
  # consensus spectra
66
- use_cons_spec: False
67
- cons_spec_wt: 3
68
- cons_spec_wt_update: {}
69
- cons_loss_type: 'l2' # cosine, l2
70
 
71
  # fp prediction/usage
72
- pred_fp: False
73
- use_fp: False
74
- fp_loss_type: 'cosine' #cosine, bce
75
- fp_wt: 3
76
- fp_wt_update: {}
77
- fp_size: 1024
78
- fp_radius: 5
79
- fp_dropout: 0.4
80
 
81
  # candidates
82
- aug_cands: False
83
- aug_cands_wt: 0.1
84
- aug_cands_update: {}
- aug_cands_size: 3
 
 # neutral loss
- use_NL: False
 
 
 ############################## Task and model ##############################
@@ -93,33 +88,34 @@ task: 'retrieval'
 spec_enc: Transformer_Formula # Transformer_MzInt #Transformer_Formula
 mol_enc: "GNN"
 model: filipContrastive # "MultiviewContrastive"
- contr_views: [['spec_enc', 'mol_enc']] #[['spec_enc', 'mol_enc'], ['spec_enc', 'NL_spec_enc'], ['mol_enc', 'NL_spec_enc']] #[['spec_enc', 'mol_enc'], ['mol_enc', 'cons_spec_enc'], ['cons_spec_enc', 'spec_enc'], ['fp_enc', 'mol_enc'], ['fp_enc', 'spec_enc'], ['fp_enc', 'cons_spec_enc']]
 log_only_loss_at_stages: []
 df_test_path: ""
 
- # - Spectra encoder
- final_embedding_dim: 512
- fc_dropout: 0.4
-
- # - Spectra Token encoder
- hidden_dims: [64, 128]
-
 
 # - Formula-based spec encoders
 formula_dropout: 0.2
- formula_dims: [256,512,256] #[512, 256, 512] #[64, 128, 256]
 cross_attn_heads: 2
 use_cls: False
 peak_dropout: 0.2
 formula_attn_heads: 4 # 2
- formula_transformer_layers: 1 #2
 
 # -- GAT params
 attn_heads: [12,12,12]
 
 # - Molecule encoder (GNN)
- gnn_channels: [128, 256, 256] #[64,128,512]
 gnn_type: "gcn"
- num_gnn_layers: 3
- gnn_hidden_dim: 512
- gnn_dropout: 0.157104273477570 #0.23234950970370824 #0.3
 
 
 # Experiment setup
 job_key: ''
+ run_name: 'flare_sirius_labels_42'
 run_details: ""
 project_name: ''
 wandb_entity_name: 'mass-spec-ml'
 no_wandb: True
+ seed: 42
 debug: False
 checkpoint_pth:
 
 # Data paths
 candidates_pth: /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_mass.json # "../data/MassSpecGym/data/molecules/MassSpecGym_retrieval_candidates_formula.json"
 dataset_pth: /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv # /data/yzhouc01/MVP/data/sample/data.tsv #/r/hassounlab/spectra_data/msgym/MassSpecGym.tsv #/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex.tsv # /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv # "../data/MassSpecGym/data/sample_data.tsv"
+ subformula_dir_pth: /r/hassounlab/msgym_sirius # /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default # /data/yzhouc01/FILIP-MS/data/magma # /r/hassounlab/msgym_sirius # /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default #/data/yzhouc01/spectra_data/subformulae #"../data/MassSpecGym/data/subformulae_default"
 split_pth:
 fp_dir_pth:
 partial_checkpoint: ""
 
 # General hyperparameters
+ batch_size: 64 #64
+ lr: 2.881339661302105e-05 # 5.0e-05
 weight_decay: 1.8376229667330708e-05
+ contr_temp: 0.022772534845886608 # 0.022772534845886608 # 0.05
 num_workers: 50
 
 
 ############################## Data transforms ##############################
 # - Spectra
 spectra_view: SpecFormula #SpecMzIntTokens #SpecFormula
+ formula_source: 'sirius' # magma_1, magma_all, sirius, default
 # 1. Binner
 max_mz: 1000
 bin_width: 1
 
 # 2. SpecFormula
 element_list: ['H', 'C', 'O', 'N', 'P', 'S', 'Cl', 'F', 'Br', 'I', 'B', 'As', 'Si', 'Se']
 add_intensities: True
 
 # - Molecule
 molecule_view: "MolGraph"
 
 ############################## Views ##############################
 # contrastive
+ # use_contr: False
+ # contr_wt: 1
+ # contr_wt_update: {}
 
 # consensus spectra
+ # use_cons_spec: False
+ # cons_spec_wt: 3
+ # cons_spec_wt_update: {}
+ # cons_loss_type: 'l2' # cosine, l2
 
 # fp prediction/usage
+ # pred_fp: False
+ # use_fp: False
+ # fp_loss_type: 'cosine' #cosine, bce
+ # fp_wt: 3
+ # fp_wt_update: {}
+ # fp_size: 1024
+ # fp_radius: 5
+ # fp_dropout: 0.4
 
 # candidates
+ # aug_cands: False
+ # aug_cands_wt: 0.1
+ # aug_cands_update: {}
+ # aug_cands_size: 3
 
 # neutral loss
+ # use_NL: False
 
 
 ############################## Task and model ##############################
 spec_enc: Transformer_Formula # Transformer_MzInt #Transformer_Formula
 mol_enc: "GNN"
 model: filipContrastive # "MultiviewContrastive"
+ contr_views: [['spec_enc', 'mol_enc']]
 log_only_loss_at_stages: []
 df_test_path: ""
 
 # - Formula-based spec encoders
 formula_dropout: 0.2
+ formula_dims: [512,256,512] #[512, 256, 512] #[64, 128, 256]
 cross_attn_heads: 2
 use_cls: False
 peak_dropout: 0.2
 formula_attn_heads: 4 # 2
+ formula_transformer_layers: 2 #2
 
 # -- GAT params
 attn_heads: [12,12,12]
 
 # - Molecule encoder (GNN)
+ gnn_channels: [128, 256, 512] #[64,128,512]
 gnn_type: "gcn"
+ # num_gnn_layers: 3
+ # gnn_hidden_dim: 512
+ gnn_dropout: 0.23234950970370824 #0.3
+
+
+ # - Spectra encoder (cross attention model)
+ # final_embedding_dim: 512
+ # fc_dropout: 0.4
+
+ # - Spectra Token encoder (mz-int token model)
+ # hidden_dims: [64, 256]
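For context, the entry points below (train.py, test.py, tune.py) read a params file like this one into a plain dict, presumably via yaml.safe_load, so nested values such as contr_views arrive as Python lists. A minimal loading sketch, assuming PyYAML and a local copy of the file (the path is illustrative):

import yaml

# Load the experiment configuration into a plain dict (hypothetical local path).
with open('params_formSpec.yaml') as f:
    params = yaml.safe_load(f)

# Scalars and nested lists parse directly, e.g.:
# params['model'] == 'filipContrastive'
# params['contr_views'] == [['spec_enc', 'mol_enc']]
print(params['lr'], params['formula_dims'])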
{mvp β†’ flare}/params_jestr.yaml RENAMED
File without changes
{mvp β†’ flare}/params_tmp.yaml RENAMED
File without changes
flare/run.sh ADDED
@@ -0,0 +1,3 @@
+ # python train.py
+ python test.py --param_pth ../hparams.yaml
+ # python test.py --candidates_pth /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_formula.json
{mvp β†’ flare}/subformula_assign/__init__.py RENAMED
File without changes
{mvp β†’ flare}/subformula_assign/assign_subformulae.py RENAMED
File without changes
{mvp β†’ flare}/subformula_assign/run.sh RENAMED
File without changes
{mvp β†’ flare}/subformula_assign/utils/__init__.py RENAMED
File without changes
{mvp β†’ flare}/subformula_assign/utils/chem_utils.py RENAMED
File without changes
{mvp β†’ flare}/subformula_assign/utils/parallel_utils.py RENAMED
File without changes
{mvp β†’ flare}/subformula_assign/utils/parse_utils.py RENAMED
File without changes
{mvp β†’ flare}/subformula_assign/utils/spectra_utils.py RENAMED
File without changes
{mvp β†’ flare}/test.py RENAMED
@@ -10,12 +10,12 @@ from pytorch_lightning import Trainer
 from massspecgym.models.base import Stage
 import os
 
- from mvp.data.data_module import TestDataModule
- from mvp.data.datasets import ContrastiveDataset
- from mvp.utils.data import get_spec_featurizer, get_mol_featurizer, get_test_ms_dataset
- from mvp.utils.models import get_model
 
- from mvp.definitions import TEST_RESULTS_DIR
 import yaml
 from functools import partial
 # Suppress RDKit warnings and errors
 
 from massspecgym.models.base import Stage
 import os
 
+ from flare.data.data_module import TestDataModule
+ from flare.data.datasets import ContrastiveDataset
+ from flare.utils.data import get_spec_featurizer, get_mol_featurizer, get_test_ms_dataset
+ from flare.utils.models import get_model
 
+ from flare.definitions import TEST_RESULTS_DIR
 import yaml
 from functools import partial
 # Suppress RDKit warnings and errors
{mvp β†’ flare}/train.py RENAMED
@@ -11,15 +11,15 @@ from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 
 
- from mvp.data.data_module import ContrastiveDataModule
 
- from mvp.definitions import TEST_RESULTS_DIR
 import yaml
- from mvp.data.datasets import ContrastiveDataset
 from functools import partial
 
- from mvp.utils.data import get_ms_dataset, get_spec_featurizer, get_mol_featurizer
- from mvp.utils.models import get_model
 # Suppress RDKit warnings and errors
 lg = RDLogger.logger()
 lg.setLevel(RDLogger.CRITICAL)
@@ -43,7 +43,7 @@ def main(params):
 dataset = get_ms_dataset(params['spectra_view'], params['molecule_view'], spec_featurizer, mol_featurizer, params)
 
 # Init data module
- collate_fn = partial(ContrastiveDataset.collate_fn, spec_enc=params['spec_enc'], spectra_view=params['spectra_view'], mask_peak_ratio=params['mask_peak_ratio'], aug_cands=params['aug_cands'])
 data_module = ContrastiveDataModule(
 dataset=dataset,
 collate_fn=collate_fn,
 
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 
 
+ from flare.data.data_module import ContrastiveDataModule
 
+ from flare.definitions import TEST_RESULTS_DIR
 import yaml
+ from flare.data.datasets import ContrastiveDataset
 from functools import partial
 
+ from flare.utils.data import get_ms_dataset, get_spec_featurizer, get_mol_featurizer
+ from flare.utils.models import get_model
 # Suppress RDKit warnings and errors
 lg = RDLogger.logger()
 lg.setLevel(RDLogger.CRITICAL)
 
 dataset = get_ms_dataset(params['spectra_view'], params['molecule_view'], spec_featurizer, mol_featurizer, params)
 
 # Init data module
+ collate_fn = partial(ContrastiveDataset.collate_fn, spec_enc=params['spec_enc'], spectra_view=params['spectra_view'])
 data_module = ContrastiveDataModule(
 dataset=dataset,
 collate_fn=collate_fn,
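The collate_fn change above drops the mask_peak_ratio and aug_cands keywords; functools.partial freezes the remaining ones so the DataLoader only has to supply the batch. A toy sketch of the same pattern (the collate body is a stand-in, not the repo's):

from functools import partial

def collate(batch, spec_enc, spectra_view):
    # Stand-in for ContrastiveDataset.collate_fn: the real one pads spectra
    # and batches molecule graphs; here we just echo the frozen kwargs.
    return {'n_items': len(batch), 'spec_enc': spec_enc, 'spectra_view': spectra_view}

collate_fn = partial(collate, spec_enc='Transformer_Formula', spectra_view='SpecFormula')
print(collate_fn([0, 1, 2]))  # {'n_items': 3, 'spec_enc': 'Transformer_Formula', 'spectra_view': 'SpecFormula'}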
{mvp β†’ flare}/tune.py RENAMED
@@ -15,11 +15,11 @@ from pytorch_lightning import Trainer
 from optuna.integration import PyTorchLightningPruningCallback
 from pytorch_lightning.callbacks import Callback
 
- from mvp.data.data_module import ContrastiveDataModule
- from mvp.data.datasets import ContrastiveDataset
- from mvp.utils.data import get_ms_dataset, get_spec_featurizer, get_mol_featurizer
- from mvp.utils.models import get_model
- from mvp.definitions import TEST_RESULTS_DIR
 from functools import partial
 from rdkit import RDLogger
 from massspecgym.models.base import Stage
 
 from optuna.integration import PyTorchLightningPruningCallback
 from pytorch_lightning.callbacks import Callback
 
+ from flare.data.data_module import ContrastiveDataModule
+ from flare.data.datasets import ContrastiveDataset
+ from flare.utils.data import get_ms_dataset, get_spec_featurizer, get_mol_featurizer
+ from flare.utils.models import get_model
+ from flare.definitions import TEST_RESULTS_DIR
 from functools import partial
 from rdkit import RDLogger
 from massspecgym.models.base import Stage
{mvp β†’ flare}/utils/__init__.py RENAMED
File without changes
{mvp β†’ flare}/utils/data.py RENAMED
@@ -2,12 +2,12 @@ import os
 import json
 import numpy as np
 
- from mvp.data.transforms import SpecBinner, SpecBinnerLog, SpecFormulaFeaturizer, SpecFormulaMzFeaturizer, SpecMzIntTokenizer
 from massspecgym.data.transforms import SpecTransform, MolTransform
- from mvp.data.transforms import MolToGraph
- import mvp.data.datasets as jestr_datasets
 import typing as T
- from mvp.definitions import MSGYM_FORMULA_VECTOR_NORM, MSGYM_STANDARD_MH, PRECURSOR_INTENSITY
 import matchms
 import tqdm
 
@@ -42,9 +42,6 @@ class Subformula_Loader:
 '''MIST subformula format:https://github.com/samgoldman97/mist/blob/main_v2/src/mist/utils/spectra_utils.py
 '''
 try:
- # file = os.path.join(self.dir_path, spec_id+".json")
- # with open(file) as f:
- # data = json.load(f)
 mzs = np.array(data['output_tbl']['mz'])
 formulas = np.array(data['output_tbl']['formula'])
 intensities = np.array(data['output_tbl']['ms2_inten'])
@@ -271,12 +268,12 @@ def get_test_ms_dataset(spectra_view: T.Union[str, T.List[str]],
 dataset_params.update({'subformula_dir_pth': params['subformula_dir_pth'], 'use_magma': params['formula_source'].startswith('magma'), 'formula_source':params['formula_source']})
 use_formulas = True
 
- if params['use_cons_spec']:
- dataset_params.update({'cons_spec_dir_pth': params['cons_spec_dir_pth']})
- if 'use_NL_spec' in params and params['use_NL_spec']:
- dataset_params.update({'NL_spec_dir_pth': params['NL_spec_dir_pth']})
- if params['pred_fp'] or params['use_fp']:
- dataset_params.update({'fp_dir_pth': '', 'fp_size': params['fp_size'], 'fp_radius': params['fp_radius']})
 
 return jestr_datasets.ExpandedRetrievalDataset(use_formulas=use_formulas, **dataset_params)
 
@@ -294,24 +291,8 @@ def get_ms_dataset(spectra_view: str,
 dataset_params.update({'subformula_dir_pth': params['subformula_dir_pth'], 'formula_source': params['formula_source']})
 use_formulas = True
 
- if params['pred_fp'] or params['use_fp']:
- dataset_params.update({'fp_dir_pth': params['fp_dir_pth']})
-
- if params['aug_cands']:
- dataset_params.update({'aug_cands_dir_pth': params['aug_cands_dir_pth'],
- 'use_formulas':use_formulas,
- "aug_cands_size": params['aug_cands_size']})
-
- if params['use_cons_spec']:
- dataset_params.update({'cons_spec_dir_pth': params['cons_spec_dir_pth']})
-
- if 'use_NL_spec' in params and params['use_NL_spec']:
- dataset_params.update({'NL_spec_dir_pth': params['NL_spec_dir_pth']})
-
 # select dataset
- if params['aug_cands']:
- return jestr_datasets.MassSpecDataset_Candidates(**dataset_params)
- elif use_formulas:
 return jestr_datasets.MassSpecDataset_PeakFormulas(**dataset_params)
 
 return jestr_datasets.JESTR1_MassSpecDataset(**dataset_params)
 
 import json
 import numpy as np
 
+ from flare.data.transforms import SpecBinner, SpecBinnerLog, SpecFormulaFeaturizer, SpecFormulaMzFeaturizer, SpecMzIntTokenizer
 from massspecgym.data.transforms import SpecTransform, MolTransform
+ from flare.data.transforms import MolToGraph
+ import flare.data.datasets as jestr_datasets
 import typing as T
+ from flare.definitions import MSGYM_FORMULA_VECTOR_NORM, MSGYM_STANDARD_MH, PRECURSOR_INTENSITY
 import matchms
 import tqdm
 
 '''MIST subformula format:https://github.com/samgoldman97/mist/blob/main_v2/src/mist/utils/spectra_utils.py
 '''
 try:
 mzs = np.array(data['output_tbl']['mz'])
 formulas = np.array(data['output_tbl']['formula'])
 intensities = np.array(data['output_tbl']['ms2_inten'])
 
 dataset_params.update({'subformula_dir_pth': params['subformula_dir_pth'], 'use_magma': params['formula_source'].startswith('magma'), 'formula_source':params['formula_source']})
 use_formulas = True
 
+ # if params['use_cons_spec']:
+ # dataset_params.update({'cons_spec_dir_pth': params['cons_spec_dir_pth']})
+ # if 'use_NL_spec' in params and params['use_NL_spec']:
+ # dataset_params.update({'NL_spec_dir_pth': params['NL_spec_dir_pth']})
+ # if params['pred_fp'] or params['use_fp']:
+ # dataset_params.update({'fp_dir_pth': '', 'fp_size': params['fp_size'], 'fp_radius': params['fp_radius']})
 
 return jestr_datasets.ExpandedRetrievalDataset(use_formulas=use_formulas, **dataset_params)
 
 dataset_params.update({'subformula_dir_pth': params['subformula_dir_pth'], 'formula_source': params['formula_source']})
 use_formulas = True
 
 # select dataset
+ if use_formulas:
 return jestr_datasets.MassSpecDataset_PeakFormulas(**dataset_params)
 
 return jestr_datasets.JESTR1_MassSpecDataset(**dataset_params)
{mvp β†’ flare}/utils/debug.py RENAMED
File without changes
{mvp β†’ flare}/utils/eval.py RENAMED
@@ -1,8 +1,8 @@
- from MassSpecGym.massspecgym.utils import MyopicMCES
 import numpy as np
 import tqdm
 from multiprocessing import Pool
-
 import os
 import pandas as pd
 
@@ -51,29 +51,6 @@ class Compute_Myopic_MCES_timeout:
 
 return results
 
-
- def get_result_files(exp_dir, spec_type, views_type):
- files = os.listdir(exp_dir)
- mass_result = ''
- form_result = ''
-
- for f in files:
- try:
- _, s, views = f.split('_')
- except:
- continue
-
- if s == spec_type and views == views_type:
- print(exp_dir / f)
-
- files = os.listdir(exp_dir / f)
- for fr in files:
- if 'mass_result' in fr:
- mass_result = exp_dir / f / fr
- elif 'result' in fr:
- form_result = exp_dir / f/ fr
-
- return mass_result, form_result
 
 # get target
 def get_target(candidates, labels):
@@ -85,73 +62,17 @@ def get_top_cand(candidates, scores):
 
 # split into hit rates
 def convert_rank_to_hit_rates(row, rank_col ,top_k=[1,5,20]):
- top_k_hits ={}
 rank = row[rank_col]
 for k in top_k:
 if rank <= k:
- top_k_hits[f'{rank_col}-hit_rate@{k}'] = 1
 else:
- top_k_hits[f'{rank_col}-hit_rate@{k}'] = 0
- return pd.Series(top_k_hits)
-
- #################### Rank aggregation #######################
- from collections import defaultdict
- import numpy as np
- from scipy.stats import rankdata
-
- def borda_count(candidates, score_lists, target):
- scores = defaultdict(int)
- N = len(candidates)
- for score_list in score_lists:
- ranked_list = sorted(zip(candidates, score_list), key=lambda x: x[1], reverse=True)
- for rank, (mol, _) in enumerate(ranked_list, start=1):
- scores[mol] += N - rank + 1
- ranked_candidates = [mol for mol, _ in sorted(scores.items(), key=lambda x: x[1], reverse=True)]
- return ranked_candidates.index(target) + 1 if target in ranked_candidates else None
-
- def average_rank(candidates, score_lists, target):
- rank_sums = defaultdict(list)
- for score_list in score_lists:
- ranked_list = sorted(zip(candidates, score_list), key=lambda x: x[1], reverse=True)
- for rank, (mol, _) in enumerate(ranked_list, start=1):
- rank_sums[mol].append(rank)
- avg_ranks = {mol: np.mean(ranks) for mol, ranks in rank_sums.items()}
- ranked_candidates = [mol for mol, _ in sorted(avg_ranks.items(), key=lambda x: x[1])]
- return ranked_candidates.index(target) + 1 if target in ranked_candidates else None
-
- def reciprocal_rank_aggregation(candidates, score_lists, target):
- scores = defaultdict(float)
- for score_list in score_lists:
- ranked_list = sorted(zip(candidates, score_list), key=lambda x: x[1], reverse=True)
- for rank, (mol, _) in enumerate(ranked_list, start=1):
- scores[mol] += 1 / rank
- ranked_candidates = [mol for mol, _ in sorted(scores.items(), key=lambda x: x[1], reverse=True)]
- return ranked_candidates.index(target) + 1 if target in ranked_candidates else None
-
- def weighted_voting(candidates, score_lists, weights, target):
- scores = defaultdict(float)
- for weight, score_list in zip(weights, score_lists):
- ranked_list = sorted(zip(candidates, score_list), key=lambda x: x[1], reverse=True)
- for rank, (mol, _) in enumerate(ranked_list, start=1):
- scores[mol] += weight / rank
- ranked_candidates = [mol for mol, _ in sorted(scores.items(), key=lambda x: x[1], reverse=True)]
- return ranked_candidates.index(target) + 1 if target in ranked_candidates else None
 
- def median_rank(candidates, score_lists, target):
- rank_sums = defaultdict(list)
- for score_list in score_lists:
- ranked_list = sorted(zip(candidates, score_list), key=lambda x: x[1], reverse=True)
- for rank, (mol, _) in enumerate(ranked_list, start=1):
- rank_sums[mol].append(rank)
- median_ranks = {mol: np.median(ranks) for mol, ranks in rank_sums.items()}
- ranked_candidates = [mol for mol, _ in sorted(median_ranks.items(), key=lambda x: x[1])]
- return ranked_candidates.index(target) + 1 if target in ranked_candidates else None
 
- def score_based_aggregation(candidates, score_lists, target):
- scores = defaultdict(list)
- for score_list in score_lists:
- for mol, score in zip(candidates, score_list):
- scores[mol].append(score)
- avg_scores = {mol: np.mean(vals) for mol, vals in scores.items()}
- ranked_candidates = [mol for mol, _ in sorted(avg_scores.items(), key=lambda x: x[1], reverse=True)]
- return ranked_candidates.index(target) + 1 if target in ranked_candidates else None
 
+ from massspecgym.utils import MyopicMCES
 import numpy as np
 import tqdm
 from multiprocessing import Pool
+ from scipy.stats import bootstrap
 import os
 import pandas as pd
 
 
 return results
 
 
 # get target
 def get_target(candidates, labels):
 
 # split into hit rates
 def convert_rank_to_hit_rates(row, rank_col ,top_k=[1,5,20]):
+ top_k_hits = []
 rank = row[rank_col]
 for k in top_k:
 if rank <= k:
+ top_k_hits.append(1)
 else:
+ top_k_hits.append(0)
+ return top_k_hits
 
 
+ def get_ci(col_vals, confidence_level=0.999, n_resamples=20_000, seed=0):
+ res = bootstrap((col_vals,), np.mean, confidence_level=confidence_level, n_resamples=n_resamples, random_state=seed)
+ ci = res.confidence_interval
+ return f'{ci.low:.2f}-{ci.high:.2f}'
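The new get_ci helper wraps scipy.stats.bootstrap to attach a bootstrap confidence interval to the mean of a metric column, e.g. the 0/1 hit indicators produced by convert_rank_to_hit_rates. A usage sketch on synthetic data (the hit values are made up):

import numpy as np
from scipy.stats import bootstrap

# Synthetic per-spectrum hit@1 indicators (illustrative only).
hits = np.random.default_rng(0).binomial(1, 0.3, size=500).astype(float)

res = bootstrap((hits,), np.mean, confidence_level=0.999,
                n_resamples=20_000, random_state=0)
ci = res.confidence_interval
print(f'hit_rate@1: {hits.mean():.2f} (99.9% CI {ci.low:.2f}-{ci.high:.2f})')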
flare/utils/general.py ADDED
@@ -0,0 +1,186 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ def pad_graph_nodes(mol_enc, g_n_nodes):
+ """
+ Args:
+ mol_enc: 2D tensor of shape (sum_nodes, D)
+ Node embeddings for each molecule.
+ g_n_nodes: list[int] Number of nodes per graph (len = B)
+
+ Returns:
+ padded: (B, max_nodes, D) tensor
+ mask: (B, max_nodes) bool tensor, True for valid nodes
+ """
+
+ # Already concatenated: shape (sum_nodes, D)
+ B = len(g_n_nodes)
+ D = mol_enc.shape[1]
+ max_nodes = max(g_n_nodes)
+ padded = mol_enc.new_zeros((B, max_nodes, D))
+ mask = torch.zeros((B, max_nodes), dtype=torch.bool, device=mol_enc.device)
+
+ idx = 0
+ for i, n in enumerate(g_n_nodes):
+ padded[i, :n] = mol_enc[idx:idx+n]
+ mask[i, :n] = True
+ idx += n
+ return padded, mask
+
+ import torch
+ import torch.nn.functional as F
+
+ import torch
+ import torch.nn.functional as F
+
+ def filip_similarity_batch(
+ image_tokens,
+ text_tokens,
+ mask_image,
+ mask_text,
+ reduction="mean", # "mean", "topk", "softmax", or "geom"
+ k=5,
+ temperature=0.05,
+ eps=1e-6
+ ):
+ """
+ Compute FILIP similarity for batches of image and text token embeddings.
+
+ Args:
+ image_tokens: (B, N_img, D) float tensor
+ text_tokens: (B, N_text, D) float tensor
+ mask_image: (B, N_img) bool tensor
+ mask_text: (B, N_text) bool tensor
+ reduction: str, aggregation strategy: "mean", "topk", "softmax", or "geom"
+ k: int, used if reduction == "topk"
+ temperature: float, used if reduction == "softmax"
+ eps: float, small constant for numerical stability
+
+ Returns:
+ similarities: (B,) float tensor of similarity scores
+ """
+ B, N_img, D = image_tokens.shape
+ N_text = text_tokens.shape[1]
+
+ # Normalize tokens
+ image_norm = F.normalize(image_tokens, p=2, dim=-1)
+ text_norm = F.normalize(text_tokens, p=2, dim=-1)
+
+ # Compute cosine similarity matrices
+ sim_matrix = torch.bmm(image_norm, text_norm.transpose(1, 2))
+
+ # Expand masks
+ mask_image_exp = mask_image.unsqueeze(2)
+ mask_text_exp = mask_text.unsqueeze(1)
+ valid_mask = mask_image_exp & mask_text_exp
+
+ # Mask invalid positions
+ sim_matrix_masked = sim_matrix.masked_fill(~valid_mask, float('-inf'))
+
+ # Max per image/text token
+ max_sim_img, _ = sim_matrix_masked.max(dim=2)
+ max_sim_text, _ = sim_matrix_masked.max(dim=1)
+
+ # Replace -inf with zeros
+ max_sim_img[max_sim_img == float('-inf')] = 0
+ max_sim_text[max_sim_text == float('-inf')] = 0
+
+ # Helper: aggregate with chosen strategy
+ def aggregate(max_sim, mask):
+ count = mask.sum(dim=1).clamp(min=1).float()
+
+ if reduction == "mean":
+ return (max_sim * mask).sum(dim=1) / count
+
+ elif reduction == "topk":
+ k_eff = min(k, max_sim.size(1))
+ # Mask invalid tokens to large negative before topk
+ masked_vals = max_sim.masked_fill(~mask, float('-inf'))
+ topk_vals, _ = torch.topk(masked_vals, k_eff, dim=1)
+ topk_vals[topk_vals == float('-inf')] = 0
+ return topk_vals.sum(dim=1) / k_eff
+
+ elif reduction == "softmax":
+ masked_vals = max_sim.masked_fill(~mask, float('-inf'))
+ weights = torch.softmax(masked_vals / temperature, dim=1)
+ weights = weights * mask
+ weights = weights / weights.sum(dim=1, keepdim=True).clamp(min=eps)
+ return (weights * max_sim).sum(dim=1)
+
+ elif reduction == "geom":
+ # Use log-sum-exp trick for geometric mean stability
+ masked_vals = (max_sim * mask).clamp(min=eps)
+ log_vals = torch.log(masked_vals)
+ geom_mean = torch.exp((log_vals.sum(dim=1)) / count)
+ return geom_mean
+
+ else:
+ raise ValueError(f"Unknown reduction type: {reduction}")
+
+ # Aggregate both sides
+ avg_img = aggregate(max_sim_img, mask_image)
+ avg_text = aggregate(max_sim_text, mask_text)
+
+ # Final similarity
+ similarity = (avg_img + avg_text) / 2
+ return similarity
+
+
+
+ # def filip_similarity_batch(image_tokens, text_tokens, mask_image, mask_text):
+ # """
+ # Compute FILIP similarity for batches of image and text token embeddings.
+
+ # Args:
+ # image_tokens: (B, N_img, D) float tensor
+ # text_tokens: (B, N_text, D) float tensor
+ # mask_image: (B, N_img) bool tensor
+ # mask_text: (B, N_text) bool tensor
+
+ # Returns:
+ # similarities: (B,) float tensor of similarity scores
+ # """
+ # B, N_img, D = image_tokens.shape
+ # N_text = text_tokens.shape[1]
+
+ # # Normalize tokens
+ # image_norm = F.normalize(image_tokens, p=2, dim=-1) # (B, N_img, D)
+ # text_norm = F.normalize(text_tokens, p=2, dim=-1) # (B, N_text, D)
+
+ # # Compute batched cosine similarity matrices
+ # # Result shape: (B, N_img, N_text)
+ # sim_matrix = torch.bmm(image_norm, text_norm.transpose(1, 2))
+
+ # # Expand masks for broadcasting
+ # mask_image_exp = mask_image.unsqueeze(2) # (B, N_img, 1)
+ # mask_text_exp = mask_text.unsqueeze(1) # (B, 1, N_text)
+ # valid_mask = mask_image_exp & mask_text_exp # (B, N_img, N_text)
+
+ # # Mask invalid positions by setting them to -inf
+ # sim_matrix_masked = sim_matrix.masked_fill(~valid_mask, float('-inf'))
+
+ # # Max over text tokens per image token: (B, N_img)
+ # max_sim_img, _ = sim_matrix_masked.max(dim=2)
+
+ # # Max over image tokens per text token: (B, N_text)
+ # max_sim_text, _ = sim_matrix_masked.max(dim=1)
+
+ # # Replace -inf (no valid tokens) with zeros to avoid NaNs
+ # max_sim_img[max_sim_img == float('-inf')] = 0
+ # max_sim_text[max_sim_text == float('-inf')] = 0
+
+ # # Sum over valid tokens and divide by number of valid tokens (avoid division by zero)
+ # sum_img = (max_sim_img * mask_image).sum(dim=1)
+ # count_img = mask_image.sum(dim=1).clamp(min=1).float()
+
+ # sum_text = (max_sim_text * mask_text).sum(dim=1)
+ # count_text = mask_text.sum(dim=1).clamp(min=1).float()
+
+ # avg_img = sum_img / count_img
+ # avg_text = sum_text / count_text
+
+ # # Final similarity per batch element
+ # similarity = (avg_img + avg_text) / 2
+
+ # return similarity
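The two helpers in this new module compose: pad_graph_nodes turns concatenated per-node embeddings (as a batched molecule graph yields them) into a padded batch plus validity mask, which then serves as one side of the token-wise FILIP similarity. The four reduction modes trade robustness for sharpness: 'mean' averages all per-token maxima, 'topk' keeps only the k strongest matches, 'softmax' re-weights them by temperature, and 'geom' takes a geometric mean. A small shape-level sketch with random tensors (all sizes are arbitrary):

import torch
from flare.utils.general import pad_graph_nodes, filip_similarity_batch

# Two molecules with 3 and 5 atoms, node embeddings of dim 8, concatenated.
mol_enc = torch.randn(8, 8)
padded_mol, mol_mask = pad_graph_nodes(mol_enc, g_n_nodes=[3, 5])  # (2, 5, 8), (2, 5)

# Spectrum-side peak tokens with their own padding mask.
spec_tokens = torch.randn(2, 4, 8)
spec_mask = torch.tensor([[True, True, True, False],
                          [True, True, True, True]])

sim = filip_similarity_batch(spec_tokens, padded_mol, spec_mask, mol_mask,
                             reduction='mean')
print(sim.shape)  # torch.Size([2]): one spectrum-molecule score per pair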
{mvp β†’ flare}/utils/loss.py RENAMED
File without changes
{mvp β†’ flare}/utils/models.py RENAMED
@@ -1,7 +1,7 @@
1
- from mvp.models.spec_encoder import SpecEncMLP_BIN, SpecFormulaEncMLP, SpecFormulaTransformer,SpecFormula_mz_Encoder, SpecMzIntTokenTransformer
2
- from mvp.models.mol_encoder import MolEnc
3
- from mvp.models.encoders import MLP
4
- from mvp.models.contrastive import ContrastiveModel, CrossAttenContrastive, IndSpecEncoder, MultiViewContrastive, MultiViewFineTuning, FilipContrastive
5
 
6
  def get_spec_encoder(spec_enc:str, args):
7
  return {"MLP_BIN": SpecEncMLP_BIN,
@@ -26,14 +26,6 @@ def get_model(model:str,
26
  model= ContrastiveModel(**params)
27
  elif model =='crossAttenContrastive':
28
  model = CrossAttenContrastive(**params)
29
- elif model == 'IndSpecEncoder':
30
- params['pred_fp'] = False
31
- params['use_cons_spec'] = False
32
- model = IndSpecEncoder(**params)
33
- elif model == "MultiviewContrastive":
34
- model = MultiViewContrastive(**params)
35
- elif model == "MultiViewFineTuning":
36
- model = MultiViewFineTuning(**params)
37
  elif model == "filipContrastive":
38
  model = FilipContrastive(**params)
39
  else:
 
1
+ from flare.models.spec_encoder import SpecEncMLP_BIN, SpecFormulaEncMLP, SpecFormulaTransformer,SpecFormula_mz_Encoder, SpecMzIntTokenTransformer
2
+ from flare.models.mol_encoder import MolEnc
3
+ from flare.models.encoders import MLP
4
+ from flare.models.contrastive import ContrastiveModel, CrossAttenContrastive, FilipContrastive
5
 
6
  def get_spec_encoder(spec_enc:str, args):
7
  return {"MLP_BIN": SpecEncMLP_BIN,
 
26
  model= ContrastiveModel(**params)
27
  elif model =='crossAttenContrastive':
28
  model = CrossAttenContrastive(**params)
 
 
 
 
 
 
 
 
29
  elif model == "filipContrastive":
30
  model = FilipContrastive(**params)
31
  else:
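After this cleanup, get_model dispatches over just three contrastive variants, while get_spec_encoder already uses a dict lookup. The same table-driven style could replace the elif chain; a hedged sketch with stand-in classes (not the repo's constructors):

class _Stub:
    def __init__(self, **params):
        self.params = params

class ContrastiveModel(_Stub): pass
class CrossAttenContrastive(_Stub): pass
class FilipContrastive(_Stub): pass

_MODELS = {
    'contrastive': ContrastiveModel,
    'crossAttenContrastive': CrossAttenContrastive,
    'filipContrastive': FilipContrastive,
}

def get_model(name, **params):
    # Table-driven dispatch mirrors get_spec_encoder's dict lookup.
    if name not in _MODELS:
        raise ValueError(f'Unknown model: {name}')
    return _MODELS[name](**params)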
{mvp β†’ flare}/utils/preprocessing.py RENAMED
@@ -1,7 +1,7 @@
 import pandas as pd
 import pickle
 import numpy as np
- import mvp.utils.data as data_utils
 import collections
 import os
 import requests
 
 import pandas as pd
 import pickle
 import numpy as np
+ import flare.utils.data as data_utils
 import collections
 import os
 import requests
mvp/data/__init__.py DELETED
@@ -1,3 +0,0 @@
- import sys
- sys.path.insert(0, "/data/yzhouc01/MassSpecGym")
- from massspecgym.data import *
mvp/run.sh DELETED
@@ -1,3 +0,0 @@
- python train.py
- python test.py
- python test.py --candidates_pth /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_formula.json
mvp/utils/general.py DELETED
@@ -1,87 +0,0 @@
- import torch
- from torch import nn
- import torch.nn.functional as F
-
- def pad_graph_nodes(mol_enc, g_n_nodes):
- """
- Args:
- mol_enc: 2D tensor of shape (sum_nodes, D)
- Node embeddings for each molecule.
- g_n_nodes: list[int] Number of nodes per graph (len = B)
-
- Returns:
- padded: (B, max_nodes, D) tensor
- mask: (B, max_nodes) bool tensor, True for valid nodes
- """
-
- # Already concatenated: shape (sum_nodes, D)
- B = len(g_n_nodes)
- D = mol_enc.shape[1]
- max_nodes = max(g_n_nodes)
- padded = mol_enc.new_zeros((B, max_nodes, D))
- mask = torch.zeros((B, max_nodes), dtype=torch.bool, device=mol_enc.device)
-
- idx = 0
- for i, n in enumerate(g_n_nodes):
- padded[i, :n] = mol_enc[idx:idx+n]
- mask[i, :n] = True
- idx += n
- return padded, mask
-
-
- def filip_similarity_batch(image_tokens, text_tokens, mask_image, mask_text):
- """
- Compute FILIP similarity for batches of image and text token embeddings.
-
- Args:
- image_tokens: (B, N_img, D) float tensor
- text_tokens: (B, N_text, D) float tensor
- mask_image: (B, N_img) bool tensor
- mask_text: (B, N_text) bool tensor
-
- Returns:
- similarities: (B,) float tensor of similarity scores
- """
- B, N_img, D = image_tokens.shape
- N_text = text_tokens.shape[1]
-
- # Normalize tokens
- image_norm = F.normalize(image_tokens, p=2, dim=-1) # (B, N_img, D)
- text_norm = F.normalize(text_tokens, p=2, dim=-1) # (B, N_text, D)
-
- # Compute batched cosine similarity matrices
- # Result shape: (B, N_img, N_text)
- sim_matrix = torch.bmm(image_norm, text_norm.transpose(1, 2))
-
- # Expand masks for broadcasting
- mask_image_exp = mask_image.unsqueeze(2) # (B, N_img, 1)
- mask_text_exp = mask_text.unsqueeze(1) # (B, 1, N_text)
- valid_mask = mask_image_exp & mask_text_exp # (B, N_img, N_text)
-
- # Mask invalid positions by setting them to -inf
- sim_matrix_masked = sim_matrix.masked_fill(~valid_mask, float('-inf'))
-
- # Max over text tokens per image token: (B, N_img)
- max_sim_img, _ = sim_matrix_masked.max(dim=2)
-
- # Max over image tokens per text token: (B, N_text)
- max_sim_text, _ = sim_matrix_masked.max(dim=1)
-
- # Replace -inf (no valid tokens) with zeros to avoid NaNs
- max_sim_img[max_sim_img == float('-inf')] = 0
- max_sim_text[max_sim_text == float('-inf')] = 0
-
- # Sum over valid tokens and divide by number of valid tokens (avoid division by zero)
- sum_img = (max_sim_img * mask_image).sum(dim=1)
- count_img = mask_image.sum(dim=1).clamp(min=1).float()
-
- sum_text = (max_sim_text * mask_text).sum(dim=1)
- count_text = mask_text.sum(dim=1).clamp(min=1).float()
-
- avg_img = sum_img / count_img
- avg_text = sum_text / count_text
-
- # Final similarity per batch element
- similarity = (avg_img + avg_text) / 2
-
- return similarity