yzhouchen001 committed on
Commit 42f26af · 1 Parent(s): 514233d
mvp/__init__.py ADDED
File without changes
mvp/data/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ import sys
2
+ sys.path.insert(0, "/data/yzhouc01/MassSpecGym")
3
+ from massspecgym.data import *
mvp/data/data_module.py ADDED
@@ -0,0 +1,84 @@
1
+ from torch.utils.data.dataloader import DataLoader
2
+ from massspecgym.data.data_module import MassSpecDataModule
3
+ from mvp.data.datasets import ContrastiveDataset
4
+ from functools import partial
5
+ from massspecgym.models.base import Stage
6
+
7
+ class TestDataModule(MassSpecDataModule):
8
+ def __init__(
9
+ self,
10
+ collate_fn,
11
+ **kwargs
12
+ ):
13
+ super().__init__(**kwargs)
14
+ self.collate_fn = collate_fn
15
+
16
+ def prepare_data(self):
17
+ pass
18
+
19
+ def setup(self, stage=None):
20
+ if stage == "test":
21
+ self.test_dataset = self.dataset
22
+ else:
23
+ raise Exception("Data module supports test set only")
24
+
25
+ def test_dataloader(self):
26
+ return DataLoader(
27
+ self.test_dataset,
28
+ batch_size=self.batch_size,
29
+ shuffle=False,
30
+ num_workers=self.num_workers,
31
+ persistent_workers=self.persistent_workers,
32
+ drop_last=False,
33
+ collate_fn=self.collate_fn,
34
+ )
35
+
36
+ def train_dataloader(self):
37
+ return None
38
+
39
+ def val_dataloader(self):
40
+ return None
41
+
42
+ class ContrastiveDataModule(MassSpecDataModule):
43
+ def __init__(
44
+ self,
45
+ collate_fn,
46
+ **kwargs
47
+ ):
48
+ super().__init__(**kwargs)
49
+ self.collate_fn = collate_fn
50
+ self.regularization_flag = False
51
+
52
+ def train_dataloader(self):
53
+ self.train_contrastive_dataset = ContrastiveDataset(self.train_dataset)
54
+
55
+ return DataLoader(self.train_contrastive_dataset,
56
+ batch_size=self.batch_size,
57
+ shuffle=True,
58
+ num_workers=self.num_workers,
59
+ persistent_workers=self.persistent_workers,
60
+ drop_last=False,
61
+ collate_fn=partial(self.collate_fn, stage=Stage.TRAIN),
62
+ )
63
+
64
+ def val_dataloader(self):
65
+ self.val_contrastive_dataset = ContrastiveDataset(self.val_dataset)
66
+
67
+ return DataLoader(self.val_contrastive_dataset,
68
+ batch_size=self.batch_size,
69
+ shuffle=False,
70
+ num_workers=self.num_workers,
71
+ persistent_workers=self.persistent_workers,
72
+ drop_last=False,
73
+ collate_fn=partial(self.collate_fn, stage=Stage.VAL))
74
+
75
+ def test_dataloader(self):
76
+ return DataLoader(
77
+ self.test_dataset,
78
+ batch_size=self.batch_size,
79
+ shuffle=False,
80
+ num_workers=self.num_workers,
81
+ persistent_workers=self.persistent_workers,
82
+ drop_last=False,
83
+ collate_fn=self.dataset.collate_fn,
84
+ )
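Note (illustrative sketch, not part of the commit): the data modules above bind the training/validation stage into the collate function with functools.partial before passing it to the DataLoader. A minimal self-contained sketch of that pattern, using a toy dataset in place of the MassSpecGym datasets:

from functools import partial
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    def __len__(self):
        return 8
    def __getitem__(self, i):
        return {"spec": torch.rand(4), "label": i}

def toy_collate(batch, stage=None):
    # `stage` is fixed ahead of time via partial, mirroring partial(self.collate_fn, stage=Stage.TRAIN) above
    return {
        "spec": torch.stack([item["spec"] for item in batch]),
        "label": torch.tensor([item["label"] for item in batch]),
        "stage": stage,
    }

loader = DataLoader(ToyDataset(), batch_size=4, collate_fn=partial(toy_collate, stage="train"))
batch = next(iter(loader))  # batch["spec"].shape == (4, 4), batch["stage"] == "train"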
mvp/data/datasets.py ADDED
@@ -0,0 +1,430 @@
1
+ import pandas as pd
2
+ import json
3
+ import typing as T
4
+ import numpy as np
5
+ import torch
6
+ import massspecgym.utils as utils
7
+ from pathlib import Path
8
+ from torch.utils.data.dataset import Dataset
9
+ from torch.utils.data.dataloader import default_collate
10
+ import dgl
11
+ from collections import defaultdict
12
+ from massspecgym.data.transforms import SpecTransform, MolTransform, MolToInChIKey
13
+ from massspecgym.data.datasets import MassSpecDataset
14
+ import mvp.utils.data as data_utils
15
+ from torch.nn.utils.rnn import pad_sequence
16
+ from massspecgym.models.base import Stage
17
+ import pickle
18
+ import math
19
+ import itertools
20
+ from rdkit.Chem import AllChem
21
+ from rdkit import Chem
22
+ class JESTR1_MassSpecDataset(MassSpecDataset):
23
+ def __init__(
24
+ self,
25
+ spectra_view: str,
26
+ fp_dir_pth: str = None,
27
+ cons_spec_dir_pth: str = None,
28
+ NL_spec_dir_pth: str = None,
29
+ **kwargs
30
+ ):
31
+ super().__init__(**kwargs)
32
+
33
+ self.use_fp = False
34
+ self.use_cons_spec = False
35
+ self.use_NL_spec = False
36
+ self.spectra_view = spectra_view
37
+
38
+ # load fingerprints
39
+ self._load_fp(fp_dir_pth)
40
+
41
+ # load consensus
42
+ self._load_cons_spec(cons_spec_dir_pth)
43
+
44
+ # load NL specs
45
+ self._load_NL_spec(NL_spec_dir_pth)
46
+
47
+ def _load_fp(self, fp_dir_pth):
48
+ if fp_dir_pth is not None:
49
+ self.use_fp = True
50
+ if fp_dir_pth:
51
+ with open(fp_dir_pth, 'rb') as f:
52
+ self.smiles_to_fp = pickle.load(f)
53
+ else:
54
+ self.smiles_to_fp = {}
55
+
56
+ def _load_cons_spec(self, cons_spec_dir_pth):
57
+ if cons_spec_dir_pth is not None:
58
+ self.use_cons_spec = True
59
+ with open(cons_spec_dir_pth, 'rb') as f:
60
+ cons_specs = pickle.load(f)
61
+
62
+ # Convert spectra to matchms spectra
63
+ matchMS_preparer = data_utils.PrepMatchMS(self.spectra_view)
64
+ spectra = cons_specs.apply(matchMS_preparer.prepare,axis=1)
65
+
66
+ self.cons_specs = dict(zip(cons_specs['smiles'].tolist(), spectra))
67
+
68
+ def _load_NL_spec(self, NL_spec_dir_pth):
69
+ if NL_spec_dir_pth is not None:
70
+ self.use_NL_spec = True
71
+ with open(NL_spec_dir_pth, 'rb') as f:
72
+ NL_specs = pickle.load(f)
73
+
74
+ # Convert spectra to matchms spectra
75
+ matchMS_preparer = data_utils.PrepMatchMS(self.spectra_view)
76
+ self.NL_specs = NL_specs.apply(matchMS_preparer.prepare,axis=1)
77
+
78
+
79
+ def __getitem__(self, i, transform_spec: bool = True, transform_mol: bool = True):
80
+
81
+ spec = self.spectra[i]
82
+ metadata = self.metadata.iloc[i]
83
+ mol = metadata["smiles"]
84
+
85
+ # Apply all transformations to the spectrum
86
+ item = {}
87
+ if transform_spec and self.spec_transform:
88
+ if isinstance(self.spec_transform, dict):
89
+ for key, transform in self.spec_transform.items():
90
+ item[key] = transform(spec) if transform is not None else spec
91
+ else:
92
+ item["spec"] = self.spec_transform(spec)
93
+ else:
94
+ item["spec"] = spec
95
+
96
+ if self.return_mol_freq:
97
+ item["mol_freq"] = metadata["mol_freq"]
98
+
99
+ if self.return_identifier:
100
+ item["identifier"] = metadata["identifier"]
101
+
102
+ if self.use_fp and self.smiles_to_fp:
103
+ item['fp'] = torch.Tensor(self.smiles_to_fp[mol].ToList())
104
+
105
+ if self.use_cons_spec:
106
+ item['cons_spec'] = self.spec_transform[self.spectra_view](self.cons_specs[mol])
107
+
108
+ if self.use_NL_spec:
109
+ item['NL_spec'] = self.spec_transform[self.spectra_view](self.NL_specs[i])
110
+
111
+ # Apply all transformations to the molecule
112
+ if transform_mol and self.mol_transform:
113
+ if isinstance(self.mol_transform, dict):
114
+ for key, transform in self.mol_transform.items():
115
+ item[key] = transform(mol) if transform is not None else mol
116
+ else:
117
+ item["mol"] = self.mol_transform(mol)
118
+ else:
119
+ item["mol"] = mol
120
+ return item
121
+
122
+ class MassSpecDataset_PeakFormulas(JESTR1_MassSpecDataset):
123
+ def __init__(
124
+ self,
125
+ spectra_view: str,
126
+ spec_transform: T.Optional[T.Union[SpecTransform, T.Dict[str, SpecTransform]]],
127
+ mol_transform: T.Optional[T.Union[MolTransform, T.Dict[str, MolTransform]]],
128
+ pth: T.Optional[Path],
129
+ subformula_dir_pth: str,
130
+ fp_dir_pth: str = None,
131
+ NL_spec_dir_pth: str = None,
132
+ cons_spec_dir_pth: str = None,
133
+ return_mol_freq: bool = False,
134
+ return_identifier: bool = True,
135
+ dtype: T.Type = torch.float32
136
+ ):
137
+ """
138
+ Args:
139
+ """
140
+ self.pth = pth
141
+ self.spec_transform = spec_transform
142
+ self.mol_transform = mol_transform
143
+ self.return_mol_freq = return_mol_freq
144
+ self.pred_fp = False
145
+ self.use_fp = False
146
+ self.use_cons_spec = False
147
+ self.use_NL_spec = False
148
+ self.spectra_view = spectra_view
149
+
150
+ if isinstance(self.pth, str):
151
+ self.pth = Path(self.pth)
152
+
153
+ self.spectra_view = spectra_view
154
+ print("Data path: ", self.pth)
155
+ self.metadata = pd.read_csv(self.pth, sep="\t")
156
+
157
+ # Used for training on consensus spectra
158
+ # with open(self.pth, 'rb') as f:
159
+ # self.metadata = pickle.load(f)
160
+ # self.metadata['identifier'] = self.metadata['smiles'].tolist()
161
+
162
+ # load subformulas
163
+ all_spec_ids = self.metadata['identifier'].tolist()
164
+ subformulaLoader = data_utils.Subformula_Loader(spectra_view=spectra_view, dir_path=subformula_dir_pth)
165
+ id_to_spec = subformulaLoader(all_spec_ids)
166
+
167
+ # create subformula spectra if no subformula is available
168
+ tmp_ids = [spec_id for spec_id in all_spec_ids if spec_id not in id_to_spec]
169
+ tmp_df = self.metadata[self.metadata['identifier'].isin(tmp_ids)].copy()  # copy to avoid SettingWithCopyWarning below
170
+ tmp_df['spec'] = tmp_df.apply(lambda row: data_utils.make_tmp_subformula_spectra(row), axis=1)
171
+ id_to_spec.update(dict(zip(tmp_df['identifier'].tolist(), tmp_df['spec'].tolist())))
172
+
173
+
174
+ # load fingerprints
175
+ self._load_fp(fp_dir_pth)
176
+
177
+ # load consensus spectra
178
+ self._load_cons_spec(cons_spec_dir_pth)
179
+
180
+ # load NL specs
181
+ self._load_NL_spec(NL_spec_dir_pth)
182
+
183
+ self.metadata = self.metadata[self.metadata['identifier'].isin(id_to_spec)]
184
+ formula_df = pd.DataFrame.from_dict(id_to_spec, orient='index').reset_index().rename(columns={'index': 'identifier'})
185
+ self.metadata = self.metadata.merge(formula_df, on='identifier')
186
+
187
+ # create matchms spectra
188
+ matchMS_preparer = data_utils.PrepMatchMS(spectra_view=spectra_view)
189
+ self.spectra = self.metadata.apply(matchMS_preparer.prepare,axis=1)
190
+
191
+ if self.return_mol_freq:
192
+ if "inchikey" not in self.metadata.columns:
193
+ self.metadata["inchikey"] = self.metadata["smiles"].apply(utils.smiles_to_inchi_key)
194
+ self.metadata["mol_freq"] = self.metadata.groupby("inchikey")["inchikey"].transform("count")
195
+
196
+ self.return_identifier = return_identifier
197
+ self.dtype = dtype
198
+
199
+ def __getitem__(self, i, transform_spec: bool = True, transform_mol: bool = True):
200
+ item = super().__getitem__(i, transform_spec, transform_mol = False)
201
+ mol = item['mol'] #smiles
202
+
203
+ # transform mol
204
+ if transform_mol:
205
+ if isinstance(self.mol_transform, dict):
206
+ for key, transform in self.mol_transform.items():
207
+ item[key] = transform(mol) if transform is not None else mol
208
+ else:
209
+ item["mol"] = self.mol_transform(mol)
210
+
211
+ return item
212
+
213
+ class ContrastiveDataset(Dataset):
214
+ def __init__(
215
+ self,
216
+ spec_mol_data,
217
+ ):
218
+ super().__init__()
219
+
220
+ indices = spec_mol_data.indices
221
+ self.spec_mol_data = spec_mol_data
222
+ self.smiles_to_specmol_ids = spec_mol_data.dataset.metadata.loc[indices].groupby('smiles').indices
223
+ self.smiles_to_spec_counter = defaultdict(int)
224
+ self.smiles_list = list(self.smiles_to_specmol_ids.keys())
225
+
226
+ def __len__(self) -> int:
227
+ return len(self.smiles_list)
228
+
229
+ def __getitem__(self, i:int) -> dict:
230
+ mol = self.smiles_list[i]
231
+
232
+ # select spectrum (iterate through list of spectra)
233
+ specmol_ids = self.smiles_to_specmol_ids[mol]
234
+ counter = self.smiles_to_spec_counter[mol]
235
+ specmol_id = specmol_ids[counter % len(specmol_ids)]
236
+
237
+ item = self.spec_mol_data.__getitem__(specmol_id)
238
+ self.smiles_to_spec_counter[mol] = counter + 1
239
+ # item['smiles'] = mol
240
+ # item['spec_id'] = specmol_id
241
+ return item
242
+
243
+ @staticmethod
244
+ def collate_fn(batch: T.Iterable[dict], spec_enc: str, spectra_view: str, stage=None, mask_peak_ratio: float = 0.0, aug_cands: bool = False) -> dict:
245
+ mol_key = 'cand' if stage == Stage.TEST else 'mol'
246
+ non_standard_collate = ['mol', 'cand', 'aug_cands', 'cons_spec', 'aug_cands_fp', 'NL_spec']
247
+ require_pad = False
248
+ if 'Formula' in spectra_view or 'Tokens' in spectra_view:
249
+ require_pad = True
250
+ padding_value=-5 if spec_enc in ('Transformer_Formula', 'Formula_BinnedSpec', 'Transformer_MzInt') else 0
251
+ non_standard_collate.append(spectra_view)
252
+ else:
253
+ non_standard_collate.remove('cons_spec')
254
+ non_standard_collate.remove('NL_spec')
255
+
256
+ collated_batch = {}
257
+ # standard collate
258
+ for k in batch[0].keys():
259
+ if k not in non_standard_collate:
260
+ collated_batch[k] = default_collate([item[k] for item in batch])
261
+
262
+ # batch graphs
263
+ batch_mol = []
264
+ batch_mol_nodes= []
265
+
266
+ for item in batch:
267
+ batch_mol.append(item[mol_key])
268
+ batch_mol_nodes.append(item[mol_key].num_nodes())
269
+
270
+ collated_batch[mol_key] = dgl.batch(batch_mol)
271
+ collated_batch['mol_n_nodes'] = batch_mol_nodes
272
+
273
+ # pad peaks/formulas
274
+ if require_pad:
275
+ peaks = []
276
+ n_peaks = []
277
+ for item in batch:
278
+ peaks.append(item[spectra_view])
279
+ n_peaks.append(len(item[spectra_view]))
280
+ collated_batch[spectra_view] = pad_sequence(peaks, batch_first=True, padding_value=padding_value)
281
+ collated_batch['n_peaks'] = n_peaks
282
+
283
+ if 'cons_spec' in batch[0]:
284
+ peaks = []
285
+ n_peaks = []
286
+ for item in batch:
287
+ peaks.append(item['cons_spec'])
288
+ n_peaks.append(len(item['cons_spec']))
289
+ collated_batch['cons_spec'] = pad_sequence(peaks, batch_first=True, padding_value=padding_value)
290
+ collated_batch['cons_n_peaks'] = n_peaks
291
+
292
+ if 'NL_spec' in batch[0]:
293
+ peaks = []
294
+ n_peaks = []
295
+ for item in batch:
296
+ peaks.append(item['NL_spec'])
297
+ n_peaks.append(len(item['NL_spec']))
298
+ collated_batch['NL_spec'] = pad_sequence(peaks, batch_first=True, padding_value=padding_value)
299
+ collated_batch['NL_n_peaks'] = n_peaks
300
+
301
+
302
+ # mask peaks
303
+ if mask_peak_ratio > 0.0 and stage == Stage.TRAIN:
304
+ n_mask_peaks = [math.floor(n_peak* mask_peak_ratio) for n_peak in n_peaks]
305
+ mask_peak_idx = [np.random.choice(n_peak, n_mask, replace=False) for n_peak, n_mask in zip(n_peaks, n_mask_peaks)]
306
+ for i, peaks in enumerate(collated_batch[spectra_view]):
307
+ peaks[mask_peak_idx[i]] = -5.0
308
+
309
+ # batch candidates
310
+ if aug_cands:
311
+ candidates = \
312
+ sum([item["aug_cands"] for item in batch], start=[])
313
+ collated_batch['aug_cands'] = dgl.batch(candidates)
314
+
315
+ if 'aug_cands_fp' in batch[0]:
316
+ cand_fp = [item['aug_cands_fp'] for item in batch]
317
+ collated_batch['aug_cands_fp'] = torch.flatten(torch.Tensor(cand_fp), end_dim=1)
318
+
319
+ return collated_batch
320
+
321
+
322
+
323
+ class ExpandedRetrievalDataset:
324
+ '''Used for testing only
325
+ Assumes 'fold' column defines the split'''
326
+ def __init__(self,
327
+ use_formulas: bool = True,
328
+ mol_label_transform: MolTransform = MolToInChIKey(),
329
+ candidates_pth: T.Optional[T.Union[Path, str]] = None,
330
+ fp_size: int = None,
331
+ fp_radius: int = None,
332
+ **kwargs):
333
+
334
+ self.instance = MassSpecDataset_PeakFormulas(**kwargs, return_mol_freq=False) if use_formulas else JESTR1_MassSpecDataset(**kwargs, return_mol_freq=False)
335
+ # super().__init__(**kwargs)
336
+
337
+ if self.use_fp:
338
+ self.fpgen = AllChem.GetMorganGenerator(radius=fp_radius,fpSize=fp_size)
339
+
340
+ self.candidates_pth = candidates_pth
341
+ self.mol_label_transform = mol_label_transform
342
+
343
+ # Read candidates_pth from json to dict: SMILES -> respective candidate SMILES
344
+ with open(self.candidates_pth, "r") as file:
345
+ candidates = json.load(file)
346
+
347
+ self.candidates = {}
348
+ for s, cand in candidates.items():
349
+ self.candidates[s] = [c for c in cand if '.' not in c]
350
+
351
+ self.spec_cand = [] #(spec index, cand_smiles, true_label)
352
+ test_smiles = self.metadata[self.metadata['fold'] == "test"]['smiles'].tolist()
353
+ test_ms_id = self.metadata[self.metadata['fold'] == "test"]['identifier'].tolist()
354
+
355
+ spec_id_to_index = dict(zip(self.metadata['identifier'], self.metadata.index))
356
+ for spec_id, s in zip(test_ms_id, test_smiles):
357
+ candidates = self.candidates[s]
358
+ # mol_label = self.mol_label_transform(s)
359
+ # labels = [self.mol_label_transform(c) == mol_label for c in candidates]
360
+ labels = [c == s for c in candidates]
361
+ if len(candidates) == 0:
362
+ print(f"Skipping {spec_id}; empty candidate set")
363
+ continue
364
+ if not any(labels):
365
+ print(f"Target smiles not in candidate set")
366
+
367
+
368
+ self.spec_cand.extend([(spec_id_to_index[spec_id], candidates[j], k) for j, k in enumerate(labels)])
369
+
370
+ def __getattr__(self, name):
371
+ return self.instance.__getattribute__(name)
372
+
373
+ def __len__(self):
374
+ return len(self.spec_cand)
375
+
376
+ def __getitem__(self, i):
377
+ spec_i = self.spec_cand[i][0]
378
+ cand_smiles = self.spec_cand[i][1]
379
+ label = self.spec_cand[i][2]
380
+
381
+ item = self.instance.__getitem__(spec_i, transform_mol=False)
382
+ item['cand'] = self.mol_transform(cand_smiles)
383
+ item['cand_smiles'] = cand_smiles
384
+ item['label'] = label
385
+
386
+ if self.use_fp:
387
+ item['fp'] = torch.Tensor(self.fpgen.GetFingerprint(Chem.MolFromSmiles(cand_smiles)).ToList())
388
+
389
+ return item
390
+
391
+ class MassSpecDataset_Candidates:
392
+
393
+ def __init__(self,
394
+ use_formulas: bool,
395
+ aug_cands_dir_pth: str,
396
+ aug_cands_size: int,
397
+ **kwargs):
398
+ self.aug_cands_size = aug_cands_size
399
+ self.instance = MassSpecDataset_PeakFormulas(**kwargs, return_mol_freq=False) if use_formulas else JESTR1_MassSpecDataset(**kwargs, return_mol_freq=False)
400
+
401
+ with open(aug_cands_dir_pth, 'rb') as f:
402
+ aug_cands = pickle.load(f)
403
+
404
+ if self.use_fp:
405
+ self.fpgen = AllChem.GetMorganGenerator(radius=5,fpSize=1024)
406
+
407
+ self.aug_cands = {}
408
+ targets = np.array(list(aug_cands.keys()))
409
+ for smiles, cands in aug_cands.items():
410
+ # sort candidates by tanimoto similarity
411
+ cands.sort(key=lambda x: x[1], reverse=True)
412
+ cands = [c for c in cands if '.' not in c]
413
+ # assert(len(cands) >0)
414
+ if len(cands) <=1: # if no candidates, shuffle from target list
415
+ np.random.shuffle(targets)
416
+ cands = targets
417
+ self.aug_cands[smiles] = itertools.cycle(cands)
418
+
419
+ def __getattr__(self, name):
420
+ return self.instance.__getattribute__(name)
421
+
422
+ def __getitem__(self, i):
423
+ item = self.instance.__getitem__(i,transform_mol=False)
424
+
425
+ aug_cands = [next(self.aug_cands[item['mol']]) for _ in range(self.aug_cands_size)]
426
+ item['aug_cands_fp'] = [self.fpgen.GetFingerprint(Chem.MolFromSmiles(c)).ToList() for c in aug_cands]
427
+ item["aug_cands"] = [self.mol_transform(c) for c in aug_cands]
428
+ item["mol"] = self.mol_transform(item["mol"])
429
+
430
+ return item
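Note (illustrative sketch, not part of the commit): ContrastiveDataset.__getitem__ above draws one spectrum per molecule at a time by keeping a per-SMILES counter and indexing that molecule's spectra with counter % n_spectra, so repeated epochs cycle through all spectra of a molecule. The selection logic in isolation, with hypothetical index groups:

from collections import defaultdict

smiles_to_spec_ids = {"CCO": [0, 1, 2], "CCN": [3]}  # hypothetical: spectrum indices grouped by SMILES
counters = defaultdict(int)

def next_spectrum(smiles):
    ids = smiles_to_spec_ids[smiles]
    spec_id = ids[counters[smiles] % len(ids)]  # round-robin over this molecule's spectra
    counters[smiles] += 1
    return spec_id

print([next_spectrum("CCO") for _ in range(4)])  # [0, 1, 2, 0]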
mvp/data/transforms.py ADDED
@@ -0,0 +1,298 @@
1
+ import numpy as np
2
+ import torch
3
+ import matchms
4
+ from typing import Optional
5
+ from rdkit.Chem import AllChem as Chem
6
+ from mvp.definitions import CHEM_ELEMS_SMALL
7
+ from massspecgym.data.transforms import MolTransform, SpecTransform, default_matchms_transforms
8
+ from massspecgym.data.transforms import SpecBinner
9
+
10
+ import dgllife.utils as chemutils
11
+ import re
12
+
13
+ class SpecBinnerLog(SpecTransform):
14
+ def __init__(
15
+ self,
16
+ max_mz: float = 1005,
17
+ bin_width: float = 1,
18
+ ) -> None:
19
+ self.max_mz = max_mz
20
+ self.bin_width = bin_width
21
+ if not (max_mz / bin_width).is_integer():
22
+ raise ValueError("`max_mz` must be divisible by `bin_width`.")
23
+
24
+ def matchms_transforms(self, spec: matchms.Spectrum) -> matchms.Spectrum:
25
+ return default_matchms_transforms(spec, mz_to=self.max_mz, n_max_peaks=None)
26
+
27
+ def matchms_to_torch(self, spec: matchms.Spectrum) -> torch.Tensor:
28
+ """
29
+ Bin the spectrum into a fixed number of bins.
30
+ """
31
+ binned_spec = self._bin_mass_spectrum(
32
+ mzs=spec.peaks.mz,
33
+ intensities=spec.peaks.intensities,
34
+ max_mz=self.max_mz,
35
+ bin_width=self.bin_width,
36
+ )
37
+ return torch.from_numpy(binned_spec).to(dtype=torch.float32)
38
+
39
+ def _bin_mass_spectrum(
40
+ self, mzs, intensities, max_mz, bin_width
41
+ ):
42
+
43
+ # Calculate the number of bins
44
+ num_bins = int(np.ceil(max_mz / bin_width))
45
+
46
+ # Calculate the bin indices for each mass
47
+ bin_indices = np.floor((mzs - 1) / bin_width).astype(int)
48
+
49
+ # Filter out mzs that exceed max_mz
50
+ valid_indices = bin_indices[mzs <= max_mz]
51
+ valid_intensities = intensities[mzs <= max_mz]
52
+
53
+ # Clip bin indices to ensure they are within the valid range
54
+ valid_indices = np.clip(valid_indices, 0, num_bins - 1)
55
+
56
+ # Initialize an array to store the binned intensities
57
+ binned_intensities = np.zeros(num_bins)
58
+
59
+ # Use np.add.at to sum intensities in the appropriate bins
60
+ np.add.at(binned_intensities, valid_indices, valid_intensities)
61
+
62
+ binned_intensities = binned_intensities/np.max(binned_intensities) * 999
63
+
64
+ binned_intensities = np.log10(binned_intensities + 1) / 3
65
+
66
+ return binned_intensities
67
+
68
+ class SpecMzIntTokenizer(SpecTransform):
69
+ def __init__(self, max_mz, mz_mean_std=None, mask_precursor=None):
70
+ self.max_mz = max_mz
71
+ self.mz_mean_std = mz_mean_std
72
+ def matchms_transforms(self, spec: matchms.Spectrum):
73
+ return default_matchms_transforms(spec, mz_to=self.max_mz, n_max_peaks=None)
74
+
75
+ def matchms_to_torch(self, spec: matchms.Spectrum):
76
+ mzs = spec.peaks.mz
77
+ intensities = spec.peaks.intensities
78
+ spec = np.zeros((len(mzs), 2))
79
+
80
+ if self.mz_mean_std:
81
+ mz = (mzs-self.mz_mean_std['mz_mean'])/self.mz_mean_std['mz_std']
82
+ else:
83
+ mz = mzs/self.max_mz
84
+
85
+ spec[:, 0] = mz
86
+ spec[:, 1] = intensities
87
+
88
+ return torch.from_numpy(spec.astype(np.float32))
89
+
90
+ class SpecFormulaMzFeaturizer(SpecTransform):
91
+ ''' Uses raw mz and intensities '''
92
+
93
+ def __init__(
94
+ self,
95
+ add_intensities: bool,
96
+ max_mz: float = 1005,
97
+ element_list: list = CHEM_ELEMS_SMALL,
98
+ formula_normalize_vector: Optional[np.ndarray] = None,
99
+ mz_mean_std: dict[str, float] = None,
100
+ mask_precursor: bool = False,
101
+ ) -> None:
102
+ self.max_mz = max_mz
103
+ self.elem_to_pos = {e: i for i, e in enumerate(element_list)}
104
+ if formula_normalize_vector is None:
105
+ formula_normalize_vector = np.ones(len(element_list))
106
+ self.formula_normalize_vector = formula_normalize_vector
107
+ self.CHEM_FORMULA_SIZE = "([A-Z][a-z]*)([0-9]*)"
108
+ self.mz_mean_std = mz_mean_std
109
+ self.add_intensities = add_intensities
110
+ self.mask_precursor = mask_precursor
111
+
112
+ def matchms_transforms(self, spec: matchms.Spectrum):
113
+ return spec
114
+
115
+ def matchms_to_torch(self, spec: matchms.Spectrum) -> torch.Tensor:
116
+ mzs = spec.peaks.mz
117
+ intensities = spec.peaks.intensities
118
+ formulas = spec.metadata['formulas'] # mz to formula dict
119
+
120
+ peak_idx = np.where(mzs <= self.max_mz)[0]
121
+ mzs = mzs[peak_idx]
122
+ intensities = intensities[peak_idx]
123
+ formulas = [formulas.get(mz, "NA") for mz in mzs]  # mzs already filtered by peak_idx above
124
+
125
+ if self.mask_precursor:
126
+ try:
127
+ precursor_i = formulas.index(spec.metadata['precursor_formula'])
128
+ formulas[precursor_i] = 'NA'
129
+ except:
130
+ pass
131
+
132
+ formulas = self._featurize_formula(formulas)
133
+ formulas = formulas/self.formula_normalize_vector
134
+
135
+ if self.mz_mean_std:
136
+ mz = (mzs-self.mz_mean_std['mz_mean'])/self.mz_mean_std['mz_std']
137
+ else:
138
+ mz = mzs/self.max_mz
139
+
140
+ if self.add_intensities:
141
+ spec = np.concatenate((mz.reshape(-1,1), intensities.reshape(-1,1), formulas), axis=1)
142
+ else:
143
+ spec = np.concatenate((mz.reshape(-1,1), formulas), axis=1)
144
+
145
+ return torch.from_numpy(spec)
146
+
147
+ def _featurize_formula(self, formulas):
148
+ formula_vector = np.zeros((len(formulas), len(self.elem_to_pos)))
149
+ for i, f in enumerate(formulas):
150
+ if f == "NA":
151
+ # formula_vector[i] = np.zeros((1, len(self.elem_to_pos)))
152
+ formula_vector[i] = np.ones((1, len(self.elem_to_pos))) * -1
153
+
154
+ else:
155
+ for (e, ct) in re.findall(self.CHEM_FORMULA_SIZE, f):
156
+ ct = 1 if ct == "" else int(ct)
157
+ try:
158
+ formula_vector[i][self.elem_to_pos[e]]+=ct
159
+ except:
160
+ # print(f"Couldn't vectorize {f}, element {e} not supported")
161
+ continue
162
+ return formula_vector
163
+
164
+ class SpecFormulaFeaturizer(SpecTransform):
165
+ ''' Uses processed mz and intensities, excludes m/z values, keeps peaks with formulas only '''
166
+ def __init__(
167
+ self,
168
+ add_intensities: bool,
169
+ max_mz: float = 1005,
170
+ element_list: list = CHEM_ELEMS_SMALL,
171
+ formula_normalize_vector: Optional[np.ndarray] = None
172
+ ) -> None:
173
+ self.max_mz = max_mz
174
+ self.elem_to_pos = {e: i for i, e in enumerate(element_list)}
175
+ self.add_intensities = add_intensities
176
+ if formula_normalize_vector is None:
177
+ formula_normalize_vector = np.ones(len(element_list))
178
+ self.formula_normalize_vector = formula_normalize_vector
179
+ self.CHEM_FORMULA_SIZE = "([A-Z][a-z]*)([0-9]*)"
180
+
181
+ def matchms_transforms(self, spec: matchms.Spectrum):
182
+ return spec
183
+
184
+ def matchms_to_torch(self, spec: matchms.Spectrum) -> torch.Tensor:
185
+ mzs = spec.peaks.mz
186
+ intensities = spec.peaks.intensities
187
+ formulas = spec.metadata['formulas'] # list of formulas
188
+
189
+ peak_idx = np.where(mzs <= self.max_mz)[0]
190
+ intensities = intensities[peak_idx]
191
+ formulas = formulas[peak_idx]
192
+
193
+ spec = self._featurize_formula(formulas)
194
+ spec = spec/self.formula_normalize_vector
195
+
196
+ if self.add_intensities:
197
+ spec = np.concatenate((spec, intensities.reshape(-1,1)), axis=1)
198
+ spec = spec.astype(np.float32)
199
+
200
+ return torch.from_numpy(spec)
201
+
202
+ def _featurize_formula(self, formulas):
203
+ formula_vector = np.zeros((len(formulas), len(self.elem_to_pos)))
204
+ for i, f in enumerate(formulas):
205
+ try:
206
+ for (e, ct) in re.findall(self.CHEM_FORMULA_SIZE, f):
207
+ ct = 1 if ct == "" else int(ct)
208
+ try:
209
+ formula_vector[i][self.elem_to_pos[e]]+=ct
210
+ except:
211
+ print(f"Couldn't vectorize {f}, element {e} not supported")
212
+ continue
213
+ except:
214
+ print(f"Couldn't vectorize {f}, formula not supported")
215
+ continue
216
+ return formula_vector
217
+
218
+ class MolToGraph(MolTransform):
219
+ def __init__ (self, atom_feature: str = "full", bond_feature: str = "full", element_list: list = CHEM_ELEMS_SMALL):
220
+ self.atom_feature = atom_feature
221
+ self.bond_feature = bond_feature
222
+ self.node_featurizer = self._get_atom_featurizer(element_list=element_list)
223
+ self.edge_featurizer = self._get_bond_featurizer()
224
+
225
+ def from_smiles(self, mol:str):
226
+ mol = Chem.MolFromSmiles(mol)
227
+ g = chemutils.mol_to_bigraph(mol, node_featurizer=self.node_featurizer, edge_featurizer=self.edge_featurizer, add_self_loop = True,
228
+ num_virtual_nodes = 0, canonical_atom_order=False)
229
+
230
+ # atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()] # added for visualization
231
+ # g.ndata['atom_id'] = torch.tensor(atom_ids, dtype=torch.long)
232
+
233
+ return g
234
+
235
+ def _get_atom_featurizer(self, element_list) -> dict:
236
+ feature_mode = self.atom_feature
237
+ atom_mass_fun = chemutils.ConcatFeaturizer(
238
+ [chemutils.atom_mass]
239
+ )
240
+ def atom_bond_type_one_hot(atom):
241
+ bs = atom.GetBonds()
242
+ bt = np.array([chemutils.bond_type_one_hot(b) for b in bs])
243
+ return [any(bt[:, i]) for i in range(bt.shape[1])]
244
+
245
+ def atom_type_one_hot(atom):
246
+ return chemutils.atom_type_one_hot(
247
+ atom, allowable_set = element_list, encode_unknown = True
248
+ )
249
+
250
+ if feature_mode == 'light':
251
+ atom_featurizer_funs = chemutils.ConcatFeaturizer([
252
+ chemutils.atom_mass,
253
+ atom_type_one_hot
254
+ ])
255
+ elif feature_mode == 'full':
256
+ atom_featurizer_funs = chemutils.ConcatFeaturizer([
257
+ chemutils.atom_mass,
258
+ atom_type_one_hot,
259
+ atom_bond_type_one_hot,
260
+ chemutils.atom_degree_one_hot,
261
+ chemutils.atom_total_degree_one_hot,
262
+ chemutils.atom_explicit_valence_one_hot,
263
+ chemutils.atom_implicit_valence_one_hot,
264
+ chemutils.atom_hybridization_one_hot,
265
+ chemutils.atom_total_num_H_one_hot,
266
+ chemutils.atom_formal_charge_one_hot,
267
+ chemutils.atom_num_radical_electrons_one_hot,
268
+ chemutils.atom_is_aromatic_one_hot,
269
+ chemutils.atom_is_in_ring_one_hot,
270
+ chemutils.atom_chiral_tag_one_hot
271
+ ])
272
+ elif feature_mode == 'medium':
273
+ atom_featurizer_funs = chemutils.ConcatFeaturizer([
274
+ chemutils.atom_mass,
275
+ atom_type_one_hot,
276
+ atom_bond_type_one_hot,
277
+ chemutils.atom_total_degree_one_hot,
278
+ chemutils.atom_total_num_H_one_hot,
279
+ chemutils.atom_is_aromatic_one_hot,
280
+ chemutils.atom_is_in_ring_one_hot,
281
+ ])
282
+ return chemutils.BaseAtomFeaturizer(
283
+ {"h": atom_featurizer_funs,
284
+ "m": atom_mass_fun}
285
+ )
286
+
287
+ def _get_bond_featurizer(self, self_loop=True) -> dict:
288
+ feature_mode = self.bond_feature
289
+ if feature_mode == 'light':
290
+ return chemutils.BaseBondFeaturizer(
291
+ featurizer_funcs = {'e': chemutils.ConcatFeaturizer([
292
+ chemutils.bond_type_one_hot
293
+ ])}, self_loop = self_loop
294
+ )
295
+ elif feature_mode == 'full':
296
+ return chemutils.CanonicalBondFeaturizer(
297
+ bond_data_field='e', self_loop = self_loop
298
+ )
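Note (illustrative sketch, not part of the commit): the binning performed by SpecBinnerLog sums peak intensities into fixed-width m/z bins, rescales the largest bin to 999, and then log10-compresses the result into roughly [0, 1]. The same arithmetic on a tiny synthetic spectrum, using only NumPy:

import numpy as np

mzs = np.array([101.2, 101.7, 250.4])          # hypothetical peak m/z values
intensities = np.array([0.5, 0.3, 1.0])        # hypothetical peak intensities
max_mz, bin_width = 1005.0, 1.0

num_bins = int(np.ceil(max_mz / bin_width))
bin_idx = np.clip(np.floor((mzs - 1) / bin_width).astype(int), 0, num_bins - 1)
binned = np.zeros(num_bins)
np.add.at(binned, bin_idx, intensities)        # peaks falling in the same bin are summed
binned = binned / binned.max() * 999           # rescale so the strongest bin is 999
binned = np.log10(binned + 1) / 3              # log compression; strongest bin maps to ~1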
mvp/data_preprocess.py ADDED
@@ -0,0 +1,78 @@
1
+ import argparse
2
+ from mvp.utils.preprocessing import generate_cons_spec_formulas, generate_cons_spec
3
+ import os
4
+ import pickle
5
+ import pandas as pd
6
+ from rdkit.Chem import AllChem
7
+ from rdkit import Chem
8
+ from tqdm import tqdm
9
+
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument("--spec_type", choices=('formSpec', 'binnedSpec'), required=True)
12
+ parser.add_argument("--dataset_pth", required=True, help="path to spectra data")
13
+ parser.add_argument("--candidates_pth", required=True, help="path to candidates data")
14
+ parser.add_argument("--output_dir", required=True, help="path to output directory")
15
+ parser.add_argument("--subformula_dir_pth", default='', help="path to subformula directory if using formSpec")
16
+
17
+
18
+ def check_args():
19
+
20
+ # create output directory
21
+ os.makedirs(args.output_dir, exist_ok=True)
22
+
23
+ # check files
24
+ if args.spec_type == 'formSpec':
25
+ assert(os.path.isdir(args.subformula_dir_pth))
26
+
27
+ assert(os.path.exists(args.dataset_pth))
28
+ assert(os.path.exists(args.candidates_pth))
29
+
30
+ def construct_smiles_to_fp(smiles_list, r=5, fp_size=1024):
31
+ fpgen = AllChem.GetMorganGenerator(radius=r,fpSize=fp_size)
32
+ smiles_to_fp = {}
33
+ failed_ct = 0
34
+
35
+ for s in tqdm(smiles_list, total=len(smiles_list)):
36
+ try:
37
+ mol = Chem.MolFromSmiles(s)
38
+ fp = fpgen.GetFingerprint(mol)
39
+ smiles_to_fp[s] = fp
40
+ except:
41
+ failed_ct+=1
42
+ print(f'Failed to generate fingerprints for {failed_ct} smiles')
43
+
44
+ # save smiles_to_fp
45
+ with open(os.path.join(args.output_dir, f'morganfp_r{r}_{fp_size}.pickle'), 'wb') as f:
46
+ pickle.dump(smiles_to_fp, f)
47
+
48
+ def construct_consensus_spectra():
49
+ if args.spec_type == 'formSpec':
50
+ df = generate_cons_spec_formulas(args.dataset_pth, args.subformula_dir_pth, args.output_dir)
51
+ elif args.spec_type == 'binnedSpec':
52
+ df = generate_cons_spec(args.dataset_pth, args.output_dir)
53
+
54
+ # save consensus spectra df
55
+ with open(os.path.join(args.output_dir, f'consensus_{args.spec_type}.pkl'), 'wb') as f:
56
+ pickle.dump(df, f)
57
+
58
+ def main(data):
59
+
60
+ # generate fingerpints
61
+ print("Processing fingerprints...")
62
+ unique_smiles = data['smiles'].unique().tolist()
63
+ construct_smiles_to_fp(unique_smiles)
64
+
65
+ # generate consensus spectra
66
+ print("Processring consensus spectra...")
67
+ construct_consensus_spectra()
68
+
69
+
70
+ if __name__ == '__main__':
71
+ args = parser.parse_args([] if "__file__" not in globals() else None)
72
+
73
+ check_args()
74
+
75
+ # load data
76
+ data = pd.read_csv(args.dataset_pth, sep='\t')
77
+
78
+ main(data)
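Note (illustrative sketch, not part of the commit): a hypothetical invocation of this preprocessing script, followed by the core RDKit fingerprint calls it relies on (mirroring construct_smiles_to_fp above); all paths are placeholders.

# python mvp/data_preprocess.py --spec_type binnedSpec \
#     --dataset_pth data/MassSpecGym.tsv --candidates_pth data/candidates.json \
#     --output_dir data/preprocessed
from rdkit import Chem
from rdkit.Chem import AllChem

fpgen = AllChem.GetMorganGenerator(radius=5, fpSize=1024)   # same generator settings as the script's defaults
fp = fpgen.GetFingerprint(Chem.MolFromSmiles("CCO"))
print(len(fp.ToList()))  # 1024-dimensional bit vector, as stored in the pickle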
mvp/definitions.py ADDED
@@ -0,0 +1,43 @@
1
+ """Global variables used across the package."""
2
+ import pathlib
3
+
4
+ # Dirs
5
+ ROOT_DIR = pathlib.Path(__file__).parent.absolute()
6
+ REPO_DIR = ROOT_DIR.parent
7
+ DATA_DIR = REPO_DIR / 'data'
8
+ TEST_RESULTS_DIR = REPO_DIR / 'experiments'
9
+ ASSETS_DIR = REPO_DIR / 'assets'
10
+
11
+ # Chemical elements
12
+ # CHEM_ELEMS_SMALL = ['H', 'C', 'O', 'N', 'P', 'S', 'Cl', 'F', 'Br', 'I']
13
+ CHEM_ELEMS_SMALL = ['H', 'C', 'O', 'N', 'P', 'S', 'Cl', 'F', 'Br', 'I', 'B', 'As', 'Si', 'Se']
14
+
15
+ MSGYM_FORMULA_VECTOR_NORM = [102.0, 59.0, 25.0, 13.0, 3.0, 6.0, 6.0, 17.0, 4.0, 4.0, 1.0, 1.0, 5.0, 2.0]
16
+ # MSGYM_FORMULA_VECTOR_NORM = [102.0, 59.0, 25.0, 13.0, 3.0, 6.0, 6.0, 17.0, 4.0, 4.0]
17
+ MSGYM_FORMULA_STANDARD = {
18
+ 'formula_norm':[6.53758314e+00, 6.26973237e+00,
19
+ 8.90610447e-01, 4.73889402e-01, 2.31793513e-02, 3.56956333e-02,
20
+ 2.78056172e-02, 3.28356898e-02, 2.19480328e-03, 1.58458297e-03,
21
+ 2.34802165e-05, 1.71127001e-05, 1.71127001e-04, 1.69800435e-05],
22
+ 'formula_norm': [9.68749281e+00, 7.46795232e+00,
23
+ 1.75427539e+00, 9.81685190e-01, 1.52363430e-01, 2.01197446e-01,
24
+ 1.93046421e-01, 2.65309185e-01, 4.82433244e-02, 5.23009413e-02,
25
+ 4.84558202e-03, 4.13671455e-03, 1.51218609e-02, 5.02040474e-03],
26
+ 'formula_max':[102. , 59. ,
27
+ 25. , 13. , 3. , 6. ,
28
+ 6. , 17. , 4. , 4. ,
29
+ 1. , 1. , 5. , 2. ],
30
+ 'formula_min' :[0. , 0. , 0. , 0. ,
31
+ 0. , 0. , 0. , 0. , 0. ,
32
+ 0. , 0. , 0. , 0. , 0. ,
33
+ 0. ]
34
+ }
35
+
36
+ #MSGYM standardization
37
+ MSGYM_STANDARD_MH = {
38
+ 'mz_mean': 195.155185,
39
+ 'mz_std':127.591549
40
+ }
41
+ MSGYM_STANDARD_all = { # got these from Yinkai
42
+ "mz_mean": 80.88304948022557,
43
+ "mz_std" : 197.4588028571758}
mvp/models/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ import sys
2
+ sys.path.insert(0, "/data/yzhouc01//MassSpecGym")
3
+ from massspecgym.models import *
mvp/models/contrastive.py ADDED
@@ -0,0 +1,799 @@
1
+ import typing as T
2
+ import torch
3
+ import torch.nn as nn
4
+ import pandas as pd
5
+ from collections import defaultdict
6
+ import numpy as np
7
+ import os
8
+ from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel
9
+ from massspecgym.models.base import Stage
10
+ from massspecgym import utils
11
+ from torch.nn.utils.rnn import pad_sequence
12
+
13
+ from mvp.utils.loss import contrastive_loss, cand_spec_sim_loss, fp_loss, cons_spec_loss, filip_loss_with_mask
14
+ import mvp.utils.models as model_utils
15
+ from mvp.utils.general import pad_graph_nodes, filip_similarity_batch
16
+
17
+ from mvp.models.encoders import CrossAttention
18
+ import torch.nn.functional as F
19
+
20
+ from torch_geometric.nn import global_mean_pool
21
+
22
+ class ContrastiveModel(RetrievalMassSpecGymModel):
23
+ def __init__(
24
+ self,
25
+ **kwargs
26
+ ):
27
+ super().__init__(**kwargs)
28
+ self.save_hyperparameters()
29
+
30
+ if 'use_fp' not in self.hparams:
31
+ self.hparams.use_fp = False
32
+ if 'use_fp' not in self.hparams:
33
+ self.hparams.use_fp = False
34
+ if 'use_NL_spec' not in self.hparams:
35
+ self.hparams.use_NL_spec = False
36
+
37
+ if 'loss_strategy' not in self.hparams:
38
+ self.hparams.loss_strategy = 'static'
39
+ self.hparams.contr_wt = 1.0
40
+ self.hparams.use_contr = True
41
+
42
+ self.spec_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
43
+ self.mol_enc_model = model_utils.get_mol_encoder(self.hparams.mol_enc, self.hparams)
44
+
45
+ # setup loss strategy
46
+ if self.hparams.model == 'contrastive':
47
+ self._loss_setup()
48
+ if self.hparams.pred_fp:
49
+ self.fp_loss = fp_loss(self.hparams.fp_loss_type)
50
+ self.fp_pred_model = model_utils.get_fp_pred_model(self.hparams)
51
+ if self.hparams.use_cons_spec:
52
+ self.cons_spec_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
53
+ self.cons_loss = cons_spec_loss(self.hparams.cons_loss_type)
54
+
55
+ self.spec_view = self.hparams.spectra_view
56
+
57
+ # result storage for testing results
58
+ self.result_dct = defaultdict(lambda: defaultdict(list))
59
+
60
+
61
+ def _loss_setup(self):
62
+ self.loss_wts = {}
63
+ self.loss_updates = {}
64
+
65
+
66
+ for p, loss in zip(['use_contr','pred_fp', 'use_cons_spec', 'aug_cands'], ['contr_wt','fp_wt','cons_spec_wt' ,'aug_cands_wt']):
67
+ if p not in self.hparams:
68
+ self.hparams[p] = False
69
+ if self.hparams[p]:
70
+ if self.hparams.loss_strategy == 'linear':
71
+ start_wt = self.hparams[loss+'_update']['start']
72
+ end_wt = self.hparams[loss+'_update']['end']
73
+ change = (end_wt - start_wt)/self.hparams.max_epochs
74
+ self.loss_updates[loss] = change
75
+ self.loss_wts[loss] = start_wt
76
+ elif self.hparams.loss_strategy == 'manual':
77
+ self.loss_updates[loss] = self.hparams[loss+'_update']
78
+ self.loss_wts[loss] = self.hparams[loss]
79
+ else:
80
+ self.loss_wts[loss] = self.hparams[loss]
81
+
82
+ def forward(self, batch, stage):
83
+ g = batch['cand'] if stage == Stage.TEST else batch['mol']
84
+
85
+ if self.hparams.use_cons_spec and stage != Stage.TEST:
86
+ spec = batch['cons_spec']
87
+ n_peaks = batch['cons_n_peaks'] if 'cons_n_peaks' in batch else None
88
+ spec_enc = self.cons_spec_enc_model(spec, n_peaks)
89
+ else:
90
+ spec = batch[self.spec_view]
91
+ n_peaks = batch['n_peaks'] if 'n_peaks' in batch else None
92
+ spec_enc = self.spec_enc_model(spec, n_peaks)
93
+
94
+ fp = batch['fp'] if self.hparams.use_fp else None
95
+ mol_enc = self.mol_enc_model(g, fp=fp)
96
+
97
+ return spec_enc, mol_enc
98
+
99
+ def compute_loss(self, batch: dict, spec_enc, mol_enc, output):
100
+ loss = 0
101
+ losses = {}
102
+ contr_loss, cong_loss, noncong_loss = contrastive_loss(spec_enc, mol_enc, self.hparams.contr_temp)
103
+ contr_loss = self.loss_wts['contr_wt'] *contr_loss
104
+ losses['contr_loss'] = contr_loss.detach().item()
105
+ losses['cong_loss'] = cong_loss.detach().item()
106
+ losses['noncong_loss'] = noncong_loss.detach().item()
107
+
108
+ loss+=contr_loss
109
+ if self.hparams.pred_fp:
110
+ fp_loss_val = self.loss_wts['fp_wt'] *self.fp_loss(output['fp'], batch['fp'])
111
+ loss+= fp_loss_val
112
+ losses['fp_loss'] = fp_loss_val.detach().item()
113
+
114
+ if 'aug_cand_enc' in output:
115
+ aug_cand_loss = self.loss_wts['aug_cands_wt'] * cand_spec_sim_loss(spec_enc, output['aug_cand_enc'])
116
+ loss+= aug_cand_loss
117
+ losses['aug_cand_loss'] = aug_cand_loss.detach().item()
118
+
119
+ if 'ind_spec' in output:
120
+ spec_loss = self.loss_wts['cons_spec_wt'] * self.cons_loss(spec_enc, output['ind_spec'])
121
+ loss+=spec_loss
122
+ losses['cons_spec_loss'] = spec_loss.detach().item()
123
+
124
+ losses['loss'] = loss
125
+
126
+ return losses
127
+
128
+ def step(
129
+ self, batch: dict, stage= Stage.NONE):
130
+
131
+ # Compute spectra and mol encoding
132
+ spec_enc, mol_enc = self.forward(batch, stage)
133
+
134
+ if stage == Stage.TEST:
135
+ return dict(spec_enc=spec_enc, mol_enc=mol_enc)
136
+
137
+ # Aux tasks
138
+ output = {}
139
+ if self.hparams.pred_fp:
140
+ output['fp'] = self.fp_pred_model(mol_enc)
141
+
142
+ if self.hparams.use_cons_spec:
143
+ spec = batch[self.spec_view]
144
+ n_peaks = batch['n_peaks'] if 'n_peaks' in batch else None
145
+ output['ind_spec'] = self.spec_enc_model(spec, n_peaks)
146
+
147
+ # Calculate loss
148
+ losses = self.compute_loss(batch, spec_enc, mol_enc, output)
149
+
150
+ return losses
151
+
152
+ def on_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage) -> None:
153
+ # total loss
154
+ self.log(
155
+ f'{stage.to_pref()}loss',
156
+ outputs['loss'],
157
+ batch_size=len(batch['identifier']),
158
+ sync_dist=True,
159
+ prog_bar=True,
160
+ on_epoch=True,
161
+ # on_step=True
162
+ )
163
+
164
+ # contr loss
165
+ if self.hparams.use_contr:
166
+ self.log(
167
+ f'{stage.to_pref()}contr_loss',
168
+ outputs['contr_loss'],
169
+ batch_size=len(batch['identifier']),
170
+ sync_dist=True,
171
+ prog_bar=False,
172
+ on_epoch=True,
173
+ # on_step=True
174
+ )
175
+
176
+ # noncongruent pairs
177
+ self.log(
178
+ f'{stage.to_pref()}noncong_loss',
179
+ outputs['noncong_loss'],
180
+ batch_size=len(batch['identifier']),
181
+ sync_dist=True,
182
+ prog_bar=False,
183
+ on_epoch=True,
184
+ # on_step=True
185
+ )
186
+
187
+ # congruent pairs
188
+ self.log(
189
+ f'{stage.to_pref()}cong_loss',
190
+ outputs['cong_loss'],
191
+ batch_size=len(batch['identifier']),
192
+ sync_dist=True,
193
+ prog_bar=False,
194
+ on_epoch=True,
195
+ # on_step=True
196
+ )
197
+
198
+
199
+ if self.hparams.pred_fp:
200
+
201
+ self.log(
202
+ f'{stage.to_pref()}_fp_loss',
203
+ outputs['fp_loss'],
204
+ batch_size=len(batch['identifier']),
205
+ sync_dist=True,
206
+ prog_bar=False,
207
+ on_epoch=True,
208
+ )
209
+
210
+ if self.hparams.use_cons_spec:
211
+ self.log(
212
+ f'{stage.to_pref()}cons_loss',
213
+ outputs['cons_spec_loss'],
214
+ batch_size=len(batch['identifier']),
215
+ sync_dist=True,
216
+ prog_bar=False,
217
+ on_epoch=True,
218
+ )
219
+
220
+ def test_step(self, batch):
221
+ # Unpack inputs
222
+ identifiers = batch['identifier']
223
+ cand_smiles = batch['cand_smiles']
224
+ id_to_ct = defaultdict(int)
225
+ for i in identifiers: id_to_ct[i]+=1
226
+ batch_ptr = torch.tensor(list(id_to_ct.values()))
227
+
228
+ outputs = self.step(batch, stage=Stage.TEST)
229
+ spec_enc = outputs['spec_enc']
230
+ mol_enc = outputs['mol_enc']
231
+
232
+ # Calculate scores
233
+ indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
234
+
235
+ scores = nn.functional.cosine_similarity(spec_enc, mol_enc)
236
+ scores = torch.split(scores, list(id_to_ct.values()))
237
+
238
+ cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
239
+ labels = utils.unbatch_list(batch['label'], indexes)
240
+
241
+ return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)
242
+
243
+ def on_test_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage = Stage.TEST) -> None:
244
+
245
+ # save scores
246
+ for i, cands, scores, l in zip(outputs['identifiers'], outputs['cand_smiles'], outputs['scores'], outputs['labels']):
247
+ self.result_dct[i]['candidates'].extend(cands)
248
+ self.result_dct[i]['scores'].extend(scores.cpu().tolist())
249
+ self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])
250
+
251
+ def _compute_rank(self, scores, labels):
252
+ if not any(labels):
253
+ return -1
254
+ scores = np.array(scores)
255
+ target_score = scores[labels][0]
256
+ rank = np.count_nonzero(scores >=target_score)
257
+ return rank
258
+
259
+ def on_test_epoch_end(self) -> None:
260
+
261
+ self.df_test = pd.DataFrame.from_dict(self.result_dct, orient='index').reset_index().rename(columns={'index': 'identifier'})
262
+
263
+ # Compute rank
264
+ self.df_test['rank'] = self.df_test.apply(lambda row: self._compute_rank(row['scores'], row['labels']), axis=1)
265
+ if not self.df_test_path:
266
+ self.df_test_path = os.path.join(self.hparams['experiment_dir'], 'result.pkl')
267
+ # self.df_test_path.parent.mkdir(parents=True, exist_ok=True)
268
+ self.df_test.to_pickle(self.df_test_path)
269
+
270
+ def get_checkpoint_monitors(self) -> T.List[dict]:
271
+ monitors = [
272
+ {"monitor": f"{Stage.TRAIN.to_pref()}loss", "mode": "min", "early_stopping": False}, # monitor train loss
273
+ ]
274
+ return monitors
275
+
276
+ def _update_loss_weights(self)-> None:
277
+ if self.hparams.loss_strategy == 'linear':
278
+ for loss in self.loss_wts:
279
+ self.loss_wts[loss] += self.loss_updates[loss]
280
+ elif self.hparams.loss_strategy == 'manual':
281
+ for loss in self.loss_wts:
282
+ if self.current_epoch in self.loss_updates[loss]:
283
+ self.loss_wts[loss] = self.loss_updates[loss][self.current_epoch]
284
+
285
+ def on_train_epoch_end(self) -> None:
286
+ self._update_loss_weights()
287
+
288
+ class MultiViewContrastive(ContrastiveModel):
289
+
290
+ def __init__(self,
291
+ **kwargs):
292
+
293
+ super().__init__(**kwargs)
294
+
295
+ # build fingerprint encoder model
296
+ if self.hparams.use_fp:
297
+ self.fp_enc_model = model_utils.get_fp_enc_model(self.hparams)
298
+
299
+ # build NL encoder model
300
+ if self.hparams.use_NL_spec:
301
+ self.NL_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
302
+
303
+ def forward(self, batch, stage):
304
+ g = batch['cand'] if stage == Stage.TEST else batch['mol']
305
+
306
+ spec = batch[self.spec_view]
307
+ n_peaks = batch['n_peaks'] if 'n_peaks' in batch else None
308
+
309
+ spec_enc = self.spec_enc_model(spec, n_peaks)
310
+ mol_enc = self.mol_enc_model(g)
311
+ views = {'spec_enc': spec_enc, 'mol_enc': mol_enc}
312
+
313
+ if self.hparams.use_fp:
314
+ fp_enc = self.fp_enc_model(batch['fp'])
315
+ views['fp_enc'] = fp_enc
316
+
317
+ if self.hparams.use_cons_spec:
318
+ spec = batch['cons_spec']
319
+ n_peaks = batch['cons_n_peaks'] if 'cons_n_peaks' in batch else None
320
+ spec_enc = self.cons_spec_enc_model(spec, n_peaks)
321
+ views['cons_spec_enc'] = spec_enc
322
+
323
+ if self.hparams.use_NL_spec:
324
+ spec = batch['NL_spec']
325
+ n_peaks = batch['NL_n_peaks'] if 'NL_n_peaks' in batch else None
326
+ spec_enc = self.NL_enc_model(spec, n_peaks)
327
+ views['NL_spec_enc'] = spec_enc
328
+ return views
329
+
330
+ def step(
331
+ self, batch: dict, stage= Stage.NONE):
332
+
333
+ # Compute spectra and mol encoding
334
+ views = self.forward(batch, stage)
335
+
336
+ if stage == Stage.TEST:
337
+ return views
338
+
339
+ # Calculate loss
340
+ losses = self.compute_loss(batch, views)
341
+
342
+ return losses
343
+
344
+ def compute_loss(self, batch: dict, views: dict):
345
+ loss = 0
346
+ losses = {}
347
+ for v1, v2 in self.hparams.contr_views:
348
+ contr_loss, cong_loss, noncong_loss = contrastive_loss(views[v1], views[v2], self.hparams.contr_temp)
349
+ loss+=contr_loss
350
+
351
+ losses[f'{v1[:-4]}-{v2[:-4]}_contr_loss'] = contr_loss.detach().item()
352
+ losses[f'{v1[:-4]}-{v2[:-4]}_cong_loss'] = cong_loss.detach().item()
353
+ losses[f'{v1[:-4]}-{v2[:-4]}_noncong_loss'] = noncong_loss.detach().item()
354
+
355
+ losses['loss'] = loss
356
+
357
+ return losses
358
+
359
+ def on_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage) -> None:
360
+ # total loss
361
+ self.log(
362
+ f'{stage.to_pref()}loss',
363
+ outputs['loss'],
364
+ batch_size=len(batch['identifier']),
365
+ sync_dist=True,
366
+ prog_bar=True,
367
+ on_epoch=True,
368
+ # on_step=True
369
+ )
370
+
371
+ for v1, v2 in self.hparams.contr_views:
372
+ self.log(
373
+ f'{stage.to_pref()}{v1[:-4]}-{v2[:-4]}_contr_loss',
374
+ outputs[f'{v1[:-4]}-{v2[:-4]}_contr_loss'],
375
+ batch_size=len(batch['identifier']),
376
+ sync_dist=True,
377
+ on_epoch=True,
378
+ )
379
+ self.log(
380
+ f'{stage.to_pref()}{v1[:-4]}-{v2[:-4]}_cong_loss',
381
+ outputs[f'{v1[:-4]}-{v2[:-4]}_cong_loss'],
382
+ batch_size=len(batch['identifier']),
383
+ sync_dist=True,
384
+ on_epoch=True,
385
+ )
386
+ self.log(
387
+ f'{stage.to_pref()}{v1[:-4]}-{v2[:-4]}_noncong_loss',
388
+ outputs[f'{v1[:-4]}-{v2[:-4]}_noncong_loss'],
389
+ batch_size=len(batch['identifier']),
390
+ sync_dist=True,
391
+ on_epoch=True,
392
+ )
393
+
394
+ def test_step(self, batch):
395
+ # Unpack inputs
396
+ identifiers = batch['identifier']
397
+ cand_smiles = batch['cand_smiles']
398
+ id_to_ct = defaultdict(int)
399
+ for i in identifiers: id_to_ct[i]+=1
400
+ batch_ptr = torch.tensor(list(id_to_ct.values()))
401
+
402
+ outputs = self.step(batch, stage=Stage.TEST)
403
+ scores = {}
404
+ for v1, v2 in self.hparams.contr_views:
405
+ # if 'cons_spec_enc' in (v1, v2):
406
+ # continue
407
+ v1_enc = outputs[v1]
408
+ v2_enc = outputs[v2]
409
+
410
+ s = nn.functional.cosine_similarity(v1_enc, v2_enc)
411
+ scores[f'{v1[:-4]}-{v2[:-4]}_scores'] = torch.split(s, list(id_to_ct.values()))
412
+
413
+ indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
414
+
415
+ cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
416
+ labels = utils.unbatch_list(batch['label'], indexes)
417
+
418
+ return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)
419
+
420
+ def on_test_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage = Stage.TEST) -> None:
421
+
422
+ # save scores
423
+ for i, cands, l in zip(outputs['identifiers'], outputs['cand_smiles'], outputs['labels']):
424
+ self.result_dct[i]['candidates'].extend(cands)
425
+ self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])
426
+
427
+ for v1, v2 in self.hparams.contr_views:
428
+ for i, scores in zip(outputs['identifiers'], outputs['scores'][f'{v1[:-4]}-{v2[:-4]}_scores']):
429
+ self.result_dct[i][f'{v1[:-4]}-{v2[:-4]}_scores'].extend(scores.cpu().tolist())
430
+
431
+
432
+ def on_test_epoch_end(self) -> None:
433
+
434
+ self.df_test = pd.DataFrame.from_dict(self.result_dct, orient='index').reset_index().rename(columns={'index': 'identifier'})
435
+
436
+ # Compute rank
437
+ for v1, v2 in self.hparams.contr_views:
438
+ self.df_test[f'{v1[:-4]}-{v2[:-4]}_rank'] = self.df_test.apply(lambda row: self._compute_rank(row[f'{v1[:-4]}-{v2[:-4]}_scores'], row['labels']), axis=1)
439
+
440
+ self.df_test.to_pickle(self.df_test_path)
441
+
442
+ class FilipContrastive(ContrastiveModel):
443
+ def __init__(self,
444
+ **kwargs):
445
+
446
+ super().__init__(**kwargs)
447
+
448
+ def compute_loss(self, batch: dict, spec_enc, mol_enc, spec_mask, mol_mask):
449
+ losses = {}
450
+
451
+ loss = filip_loss_with_mask(spec_enc, mol_enc, spec_mask, mol_mask, self.hparams.contr_temp)
452
+
453
+ losses['loss'] = loss
454
+
455
+ return losses
456
+
457
+ def step(
458
+ self, batch: dict, stage= Stage.NONE):
459
+
460
+ # Compute spectra and mol encoding
461
+ spec_enc, mol_enc = self.forward(batch, stage)
462
+
463
+ # pad nodes to max_n_nodes in batch (Spectra are already padded)
464
+ mol_enc, mol_mask = pad_graph_nodes(mol_enc, batch['mol_n_nodes'])
465
+ spec_mask = ~torch.all((spec_enc == -5), dim=-1)
466
+
467
+ if stage == Stage.TEST:
468
+ return dict(spec_enc=spec_enc, mol_enc=mol_enc, spec_mask=spec_mask, mol_mask=mol_mask)
469
+
470
+ # Calculate loss
471
+ losses = self.compute_loss(batch, spec_enc, mol_enc, spec_mask, mol_mask)
472
+
473
+ return losses
474
+
475
+ def test_step(self, batch):
476
+ # Unpack inputs
477
+ identifiers = batch['identifier']
478
+ cand_smiles = batch['cand_smiles']
479
+ id_to_ct = defaultdict(int)
480
+ for i in identifiers: id_to_ct[i]+=1
481
+ batch_ptr = torch.tensor(list(id_to_ct.values()))
482
+
483
+ outputs = self.step(batch, stage=Stage.TEST)
484
+ spec_enc = outputs['spec_enc']
485
+ mol_enc = outputs['mol_enc']
486
+ spec_mask = outputs['spec_mask']
487
+ mol_mask = outputs['mol_mask']
488
+
489
+ # Calculate scores
490
+ indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
491
+
492
+ scores = filip_similarity_batch(spec_enc, mol_enc, spec_mask, mol_mask)
493
+ scores = torch.split(scores, list(id_to_ct.values()))
494
+
495
+ cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
496
+ labels = utils.unbatch_list(batch['label'], indexes)
497
+
498
+ return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)
499
+
500
+ class MultiViewFineTuning(MultiViewContrastive):
501
+ def __init__(self,
502
+ **kwargs):
503
+ super().__init__(**kwargs)
504
+
505
+ # load pretrained spec, mol, fp encoders
506
+ checkpoint = torch.load(self.hparams.partial_checkpoint)
507
+ state_dict = {k[len("spec_enc_model."):]: v for k, v in checkpoint['state_dict'].items() if k.startswith("spec_enc_model")}
508
+ self.spec_enc_model.load_state_dict(state_dict) # trained on consensus spectra
509
+
510
+ state_dict = {k[len("mol_enc_model."):]: v for k, v in checkpoint['state_dict'].items() if k.startswith("mol_enc_model")}
511
+ self.mol_enc_model.load_state_dict(state_dict)
512
+
513
+ state_dict = {k[len("fp_enc_model."):]: v for k, v in checkpoint['state_dict'].items() if k.startswith("fp_enc_model")}
514
+ self.fp_enc_model.load_state_dict(state_dict)
515
+
516
+ self.encoding_views = ['spec_enc', 'mol_enc', 'fp_enc']
517
+ self.loss_fn = nn.BCELoss()
518
+
519
+ # freeze encoders
520
+ for param in self.mol_enc_model.parameters():
521
+ param.requires_grad = False
522
+ for param in self.spec_enc_model.parameters():
523
+ param.requires_grad = False
524
+ for param in self.fp_enc_model.parameters():
525
+ param.requires_grad = False
526
+ for param in self.cons_spec_enc_model.parameters():
527
+ param.requires_grad = False
528
+
529
+ # n_views = 2
530
+ # if self.hparams.use_fp:
531
+ # n_views+=1
532
+
533
+ # in_dim = self.hparams.final_embedding_dim*n_views
534
+ in_dim = self.hparams.final_embedding_dim *2 + 2
535
+
536
+ self.classifier_model = nn.Sequential(
537
+ nn.Linear(in_dim, 512),
538
+ nn.ReLU(),
539
+ nn.BatchNorm1d(512),
540
+ nn.Dropout(0.3),
541
+ nn.Linear(512, 256),
542
+ nn.ReLU(),
543
+ nn.BatchNorm1d(256),
544
+ nn.Dropout(0.3),
545
+ nn.Linear(256, 1),
546
+ nn.Sigmoid()
547
+ )
548
+ self.noise_std = 0.01
549
+
550
+ def _add_noise(self, x):
551
+ noise = torch.randn_like(x) * self.noise_std
552
+ return x + noise
553
+
554
+ def forward(self, batch, stage):
555
+
556
+ matching_views = super().forward(batch, stage)
557
+ # matching_enc = torch.concat((matching_views['spec_enc'], matching_views['mol_enc'], matching_views['fp_enc']), dim=-1)
558
+ # enc1 = matching_views['spec_enc'] - matching_views['mol_enc']
559
+ # enc2 = matching_views['spec_enc'] - matching_views['fp_enc']
560
+ # matching_enc = torch.concat((enc1, enc2), dim=-1)
561
+ view1 = matching_views['spec_enc']
562
+ view2 = matching_views['mol_enc']
563
+ view3 = matching_views['fp_enc']
564
+
565
+ if stage == Stage.TRAIN:
566
+ view1, view2, view3 = map(self._add_noise, (view1, view2, view3))
567
+
568
+ pairwise_diffs = torch.cat([
569
+ torch.abs(view1 - view2),
570
+ torch.abs(view1 - view3),
571
+ ], dim=-1)
572
+
573
+ pairwise_sims = torch.cat([
574
+ (view1 * view2).sum(dim=-1, keepdim=True),
575
+ (view1 * view3).sum(dim=-1, keepdim=True),
576
+ ], dim=-1)
577
+
578
+ matching_enc = torch.cat([pairwise_diffs, pairwise_sims], dim=-1)
579
+ matching_scores = self.classifier_model(matching_enc)
580
+
581
+ if stage == Stage.TEST:
582
+ return dict(matching_scores = matching_scores)
583
+
584
+ view1 = view1.repeat_interleave(self.hparams.aug_cands_size, dim=0)
585
+ view2 = self.mol_enc_model(batch['aug_cands'])
586
+ view3= self.fp_enc_model(batch['aug_cands_fp'])
587
+ if stage == Stage.TRAIN:
588
+ view1, view2, view3 = map(self._add_noise, (view1, view2, view3))
589
+
590
+ pairwise_diffs = torch.cat([
591
+ torch.abs(view1 - view2),
592
+ torch.abs(view1 - view3),
593
+ ], dim=-1)
594
+
595
+ pairwise_sims = torch.cat([
596
+ (view1 * view2).sum(dim=-1, keepdim=True),
597
+ (view1 * view3).sum(dim=-1, keepdim=True),
598
+ ], dim=-1)
599
+
600
+ nonmatching_enc = torch.cat([pairwise_diffs, pairwise_sims], dim=-1)
601
+
602
+ nonmatching_scores = self.classifier_model(nonmatching_enc)
603
+
604
+ return dict(matching_scores=matching_scores, nonmatching_scores=nonmatching_scores)
605
+
606
+ def compute_loss(self, matching_scores, nonmatching_scores):
607
+
608
+ matching_loss = self.loss_fn(matching_scores, torch.ones_like(matching_scores).to(matching_scores.device))
609
+ nonmatching_loss = self.loss_fn(nonmatching_scores, torch.zeros_like(nonmatching_scores).to(nonmatching_scores.device))
610
+
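+ # each anchor spectrum is paired with aug_cands_size augmented negative candidates, so the non-matching term is scaled down to keep positives and negatives roughly balanced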
611
+ loss = matching_loss + (1/self.hparams.aug_cands_size)*nonmatching_loss
612
+
613
+ return dict(loss=loss)
614
+
615
+ def step(
616
+ self, batch: dict, stage= Stage.NONE):
617
+
618
+ output = self.forward(batch, stage)
619
+
620
+ if stage == Stage.TEST:
621
+ return output
622
+
623
+ # Calculate loss
624
+ losses = self.compute_loss(output['matching_scores'], output['nonmatching_scores'])
625
+
626
+ return losses
627
+
628
+ def test_step(self, batch):
629
+ # Unpack inputs
630
+ identifiers = batch['identifier']
631
+ cand_smiles = batch['cand_smiles']
632
+ id_to_ct = defaultdict(int)
633
+ for i in identifiers: id_to_ct[i]+=1
634
+ batch_ptr = torch.tensor(list(id_to_ct.values()))
635
+
636
+ outputs = self.step(batch, stage=Stage.TEST)
637
+ scores = outputs['matching_scores']
638
+
639
+ indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
640
+
641
+ cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
642
+ labels = utils.unbatch_list(batch['label'], indexes)
643
+
644
+ return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)
645
+
646
+ def on_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage) -> None:
647
+ # total loss
648
+ self.log(
649
+ f'{stage.to_pref()}loss',
650
+ outputs['loss'],
651
+ batch_size=len(batch['identifier']),
652
+ sync_dist=True,
653
+ prog_bar=True,
654
+ on_epoch=True,
655
+ # on_step=True
656
+ )
657
+
658
+ def on_test_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage = Stage.TEST) -> None:
659
+ ContrastiveModel.on_test_batch_end(self, outputs, batch, batch_idx, stage)
660
+
661
+ def on_test_epoch_end(self):
662
+ self.df_test = pd.DataFrame.from_dict(self.result_dct, orient='index').reset_index().rename(columns={'index': 'identifier'})
663
+ # self.df_test.to_csv(self.hparams.resutl)
664
+ print(self.df_test_path)
665
+ self.df_test.to_pickle(self.df_test_path)
666
+ # ContrastiveModel.on_test_epoch_end(self)
667
+
668
+ def get_checkpoint_monitors(self) -> T.List[dict]:
669
+ monitors = [
670
+ {"monitor": f"{Stage.VAL.to_pref()}loss", "mode": "min", "early_stopping": True}
671
+ ]
672
+ return monitors
673
+ def configure_optimizers(self):
674
+ return torch.optim.Adam(
675
+ self.classifier_model.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
676
+ )
677
+
678
+ class IndSpecEncoder(ContrastiveModel):
679
+ """ Trains a spectra encoder that maps to a pretrained spec encoder"""
680
+ def __init__(
681
+ self,
682
+ **kwargs
683
+ ):
684
+ super().__init__(**kwargs)
685
+
686
+ # initialize ind_spec_encoder and loss
687
+ self.ind_spec_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
688
+ self.cons_loss = cons_spec_loss(self.hparams.cons_loss_type)
689
+
690
+ # load pretrained spec and mol encoders
+ checkpoint = torch.load(self.hparams.partial_checkpoint)
+ state_dict = {k[len("spec_enc_model."):]: v for k, v in checkpoint['state_dict'].items() if k.startswith("spec_enc_model")}
+ self.spec_enc_model.load_state_dict(state_dict) # trained on consensus spectra
+
+ state_dict = {k[len("mol_enc_model."):]: v for k, v in checkpoint['state_dict'].items() if k.startswith("mol_enc_model")}
+ self.mol_enc_model.load_state_dict(state_dict)
697
+
698
+ # freeze cons spec and mol encoders
699
+ for param in self.mol_enc_model.parameters():
700
+ param.requires_grad = False
701
+ for param in self.spec_enc_model.parameters():
702
+ param.requires_grad = False
703
+
704
+ def forward(self, batch, stage):
705
+
706
+ spec = batch[self.spec_view]
707
+ n_peaks = batch['n_peaks']
708
+ spec_enc = self.ind_spec_enc_model(spec, n_peaks)
709
+
710
+ return spec_enc
711
+
712
+ def compute_loss(self, spec_enc, cons_spec_enc):
713
+ loss = self.cons_loss(spec_enc, cons_spec_enc)
714
+ return dict(loss=loss)
715
+
716
+ def step(self, batch: dict, stage=Stage.NONE):
717
+ self.spec_enc_model.eval()
718
+ self.mol_enc_model.eval()
719
+
720
+ spec_enc = self.forward(batch, stage)
721
+
722
+ if stage == Stage.TEST:
723
+ mol_enc = self.mol_enc_model(batch['cand'])
724
+ return dict(spec_enc=spec_enc, mol_enc=mol_enc)
725
+
726
+ cons_spec_enc = self.spec_enc_model(batch['cons_spec'], batch['cons_n_peaks'])
727
+
728
+ losses = self.compute_loss(spec_enc, cons_spec_enc)
729
+
730
+ return losses
731
+
732
+
733
+ def configure_optimizers(self):
734
+ return torch.optim.Adam(
735
+ self.ind_spec_enc_model.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
736
+ )
737
+ def get_checkpoint_monitors(self) -> T.List[dict]:
738
+ monitors = [
739
+ {"monitor": f"{Stage.VAL.to_pref()}loss", "mode": "min", "early_stopping": True}
740
+ ]
741
+ return monitors
742
+
743
+ class CrossAttenContrastive(ContrastiveModel):
744
+ def __init__(
745
+ self,
746
+ **kwargs
747
+ ):
748
+ super(CrossAttenContrastive, self).__init__(**kwargs)
749
+ self.specMolCrossAttentionModel = CrossAttention(self.hparams.formula_dims[-1], self.hparams.gnn_channels[-1], self.hparams.cross_attn_heads, dim_out=self.hparams.final_embedding_dim, dropout=0.3)
750
+ self.molSpecCrossAttentionModel = CrossAttention(self.hparams.gnn_channels[-1], self.hparams.formula_dims[-1], self.hparams.cross_attn_heads, dim_out=self.hparams.final_embedding_dim, dropout=0.3)
751
+
752
+ def forward(self, batch, stage) -> tuple[torch.Tensor, torch.Tensor]:
753
+ # Unpack inputs
754
+ spec = batch[self.spec_view]
755
+ spec_n_forms = batch['n_peaks']
756
+ g = batch['cand'] if stage == Stage.TEST else batch['mol']
757
+ g_n_nodes = batch['mol_n_nodes']
758
+
759
+ # encode peaks and nodes
760
+ spec_enc = self.spec_enc_model(spec)
761
+ mol_enc = self.mol_enc_model(g)
762
+
763
+ # pad mol_enc and spec_enc to have the same length
764
+ max_nodes = max(g_n_nodes)
765
+ max_forms = max(spec_n_forms)
766
+
767
+ if max_forms > max_nodes: ## pad mol_enc
768
+ mol_enc = torch.cat((mol_enc, torch.rand(max_forms, self.hparams.gnn_channels[-1]).to(spec.device)))
769
+ mol_enc = torch.split(mol_enc, g_n_nodes+[max_forms])
770
+ mol_enc = pad_sequence(mol_enc, batch_first=True, padding_value=-5)[:-1,:,:]
771
+
772
+ elif max_nodes > max_forms: ## pad spec_enc
773
+ dim_diff = max_nodes - max_forms
774
+ spec_enc = F.pad(spec_enc, (0,0,0,dim_diff, 0,0), value=-5)
775
+ mol_enc = torch.split(mol_enc, g_n_nodes)
776
+ mol_enc = pad_sequence(mol_enc, batch_first=True, padding_value=-5)
777
+ else:
778
+ mol_enc = torch.split(mol_enc, g_n_nodes)
779
+ mol_enc = pad_sequence(mol_enc, batch_first=True, padding_value=-5)
780
+
781
+ spec_pad = torch.all((spec_enc == -5), -1)
782
+ mol_pad = torch.all((mol_enc == -5), -1)
783
+
784
+ # cross attention
785
+ tmp_spec_enc = spec_enc * 1.0
786
+ spec_enc = self.specMolCrossAttentionModel(spec_enc, mol_enc, mol_enc, mask=mol_pad)
787
+ mol_enc = self.molSpecCrossAttentionModel(mol_enc, tmp_spec_enc, tmp_spec_enc, mask=spec_pad)
788
+
789
+ # pool
790
+ spec_indices = torch.tensor([i for i, count in enumerate(spec_n_forms) for _ in range(count)]).to(spec_enc.device)
+ mol_indices = torch.tensor([i for i, count in enumerate(g_n_nodes) for _ in range(count)]).to(mol_enc.device)
+
+ spec_enc = spec_enc[~spec_pad].reshape(-1, spec_enc.shape[-1])
+ mol_enc = mol_enc[~mol_pad].reshape(-1, mol_enc.shape[-1])
+
+ spec_enc = global_mean_pool(spec_enc, spec_indices)
+ mol_enc = global_mean_pool(mol_enc, mol_indices)
798
+
799
+ return spec_enc, mol_enc
mvp/models/encoders.py ADDED
@@ -0,0 +1,74 @@
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
+ class MLP(nn.Module):
5
+ def __init__(self, in_dim, hidden_dims, dropout=0.1, final_activation=None):
6
+ super(MLP, self).__init__()
7
+
8
+ self.dropout = nn.Dropout(dropout)
9
+ self.has_final_activation = False
10
+ layers = [nn.Linear(in_dim, hidden_dims[0])]
11
+ for d1, d2 in zip(hidden_dims[:-1], hidden_dims[1:]):
12
+ layers.append(nn.Linear(d1, d2))
13
+ self.layers = nn.ModuleList(layers)
14
+ if final_activation is not None:
15
+ self.has_final_activation = True
16
+
17
+ self.final_activation = {'relu': F.relu,
+ 'sigmoid': F.sigmoid,
+ 'softmax': lambda x: F.softmax(x, dim=-1)}[final_activation]
20
+
21
+ def forward(self, x):
22
+ for i, layer in enumerate(self.layers):
23
+ x = layer(x)
24
+ if i < len(self.layers) -1:
25
+ x = F.relu(x)
26
+ x = self.dropout(x)
27
+ elif self.has_final_activation:
28
+ x = self.final_activation(x)
29
+ return x
30
+
31
+ class CrossAttention(nn.Module):
32
+ def __init__(self, embed_dim_q, embed_dim_kv, num_heads, dim_out, dropout=0.0):
33
+ """
34
+ Args:
35
+ embed_dim_q (int): Dimension of query embeddings.
36
+ embed_dim_kv (int): Dimension of key/value embeddings.
37
+ num_heads (int): Number of attention heads.
38
+ dropout (float): Dropout probability for attention weights.
39
+ """
40
+ super(CrossAttention, self).__init__()
41
+
42
+ # Ensure the embedding dimensions are divisible by the number of heads
43
+ assert embed_dim_q % num_heads == 0, "embed_dim_q must be divisible by num_heads"
44
+ assert embed_dim_kv % num_heads == 0, "embed_dim_kv must be divisible by num_heads"
45
+
46
+ self.query_proj = nn.Linear(embed_dim_q, embed_dim_q)
47
+ self.key_proj = nn.Linear(embed_dim_kv, embed_dim_q) # Match dimensions with queries
48
+ self.value_proj = nn.Linear(embed_dim_kv, embed_dim_q)
49
+
50
+ self.attention = nn.MultiheadAttention(embed_dim=embed_dim_q, num_heads=num_heads, dropout=dropout, batch_first=True)
51
+ self.out_proj = nn.Linear(embed_dim_q, dim_out)
52
+
53
+ def forward(self, queries, keys, values, mask=None):
54
+ """
55
+ Args:
56
+ queries (Tensor): Shape (batch_size, len_q, embed_dim_q)
57
+ keys (Tensor): Shape (batch_size, len_k, embed_dim_kv)
58
+ values (Tensor): Shape (batch_size, len_v, embed_dim_kv)
59
+ mask (Tensor, optional): Shape (batch_size, len_q, len_k), 1 for valid positions and 0 for masked.
60
+
61
+ Returns:
62
+ Tensor: Shape (batch_size, len_q, embed_dim_q)
63
+ """
64
+ # Project inputs to the required dimensions
65
+ queries = self.query_proj(queries) # (batch_size, len_q, embed_dim_q)
66
+ keys = self.key_proj(keys) # (batch_size, len_k, embed_dim_q)
67
+ values = self.value_proj(values) # (batch_size, len_v, embed_dim_q)
68
+
69
+ # Compute attention
70
+ attn_output, _ = self.attention(queries, keys, values, key_padding_mask=mask)
71
+
72
+ # Apply output projection
73
+ output = self.out_proj(attn_output)
74
+ return output
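A minimal usage sketch of the CrossAttention module above (an illustrative example, assuming the default dimensions from the params files: formula_dims[-1] = gnn_channels[-1] = 256, cross_attn_heads = 2, final_embedding_dim = 512). The mask follows PyTorch's key_padding_mask convention, where True marks padded key positions:

import torch

attn = CrossAttention(embed_dim_q=256, embed_dim_kv=256, num_heads=2, dim_out=512, dropout=0.3)
spec_tokens = torch.randn(4, 10, 256)            # 10 peak/formula embeddings per spectrum
mol_tokens = torch.randn(4, 17, 256)             # 17 atom embeddings per molecule
mol_pad = torch.zeros(4, 17, dtype=torch.bool)   # True marks padded atom positions
out = attn(spec_tokens, mol_tokens, mol_tokens, mask=mol_pad)   # shape (4, 10, 512)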
mvp/models/mol_encoder.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import dgl
4
+ from dgllife.model import GCN, GAT
5
+
6
+ class MolEnc(nn.Module):
7
+
8
+ def __init__(self,
9
+ args,
10
+ in_dim,):
11
+ super().__init__()
12
+
13
+ self.return_emb = False
14
+
15
+ if args.model in ('crossAttenContrastive', 'filipContrastive'):
16
+ self.return_emb = True
17
+
18
+ dropout = [args.gnn_dropout for _ in range(len(args.gnn_channels))]
19
+ batchnorm = [True for _ in range(len(args.gnn_channels))]
20
+ gnn_map = {
21
+ "gcn": GCN(in_dim, args.gnn_channels, batchnorm = batchnorm, dropout = dropout),
22
+ "gat": GAT(in_dim, args.gnn_channels, args.attn_heads)
23
+ }
24
+ self.GNN = gnn_map[args.gnn_type]
25
+ self.pool = dgl.nn.pytorch.glob.MaxPooling()
26
+
27
+ if not self.return_emb:
28
+ self.fc1_graph = nn.Linear(args.gnn_channels[len(args.gnn_channels) - 1], args.gnn_hidden_dim * 2)
29
+ self.fc2_graph = nn.Linear(args.gnn_hidden_dim * 2, args.final_embedding_dim)
30
+
31
+ self.dropout = nn.Dropout(args.fc_dropout)
32
+ self.relu = nn.ReLU()
33
+
34
+ def forward(self, g, fp=None) -> torch.Tensor:
35
+ g1 = g
36
+ f1 = g.ndata['h']
37
+
38
+ f = self.GNN(g1, f1)
39
+ if self.return_emb:
40
+ return f
41
+ h = self.pool(g1, f)
42
+ if fp is not None:
43
+ h = torch.concat((h, fp), dim=-1)
44
+ h1 = self.relu(self.fc1_graph(h))
45
+ h1 = self.dropout(h1)
46
+ h1 = self.fc2_graph(h1)
47
+ h1 = self.dropout(h1)
48
+
49
+ return h1
50
+
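A minimal sketch of running MolEnc on a single molecule, assuming a dgllife featurization that stores atom features under g.ndata['h']; the Namespace values are hypothetical and mirror params_binnedSpec.yaml, and the repo's actual MolGraph transform may differ:

from argparse import Namespace
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer

featurizer = CanonicalAtomFeaturizer(atom_data_field='h')
g = smiles_to_bigraph("CCO", node_featurizer=featurizer)   # ethanol

args = Namespace(model="MultiviewContrastive", gnn_type="gcn", gnn_channels=[64, 128, 256],
                 attn_heads=[12, 12, 12], gnn_dropout=0.3, gnn_hidden_dim=512,
                 final_embedding_dim=512, fc_dropout=0.4)
enc = MolEnc(args, in_dim=g.ndata['h'].shape[-1])
emb = enc(g)   # graph-level embedding of shape (1, final_embedding_dim)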
mvp/models/spec_encoder.py ADDED
@@ -0,0 +1,182 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ from mvp.models.encoders import MLP
4
+ from torch_geometric.nn import global_mean_pool
5
+
6
+
7
+ class SpecEncMLP_BIN(nn.Module):
8
+ def __init__(self, args, out_dim=None):
9
+ super(SpecEncMLP_BIN, self).__init__()
10
+
11
+ if not out_dim:
12
+ out_dim = args.final_embedding_dim
13
+
14
+ bin_size = int(args.max_mz / args.bin_width)
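+ # e.g. max_mz=1000 and bin_width=1 (the values in params_binnedSpec.yaml) give a 1000-dimensional binned input spectrum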
15
+ self.dropout = nn.Dropout(args.fc_dropout)
16
+ self.mz_fc1 = nn.Linear(bin_size, out_dim * 2)
17
+ self.mz_fc2 = nn.Linear(out_dim* 2, out_dim * 2)
18
+ self.mz_fc3 = nn.Linear(out_dim * 2, out_dim)
19
+ self.relu = nn.ReLU()
20
+
21
+ def forward(self, mzi_b, n_peaks=None):
22
+
23
+ h1 = self.mz_fc1(mzi_b)
24
+ h1 = self.relu(h1)
25
+ h1 = self.dropout(h1)
26
+ h1 = self.mz_fc2(h1)
27
+ h1 = self.relu(h1)
28
+ h1 = self.dropout(h1)
29
+ mz_vec = self.mz_fc3(h1)
30
+ mz_vec = self.dropout(mz_vec)
31
+
32
+ return mz_vec
33
+
34
+ class SpecMzIntTokenTransformer(nn.Module):
35
+ def __init__(self, args):
36
+ super(SpecMzIntTokenTransformer, self).__init__()
37
+ in_dim = 2
38
+ self.tokenEnc = MLP(in_dim, args.hidden_dims, dropout=args.peak_dropout)
39
+
40
+ self.returnEmb = False
41
+ if args.model in ('crossAttenContrastive', 'filipContrastive'):
42
+ self.returnEmb = True
43
+ assert(args.use_cls == False)
44
+
45
+ self.use_cls = args.use_cls
46
+ if self.use_cls:
47
+ self.cls_embed = torch.nn.Embedding(1,args.hidden_dims[-1])
48
+ encoder_layer = nn.TransformerEncoderLayer(d_model=args.hidden_dims[-1], nhead=2, batch_first=True)
49
+ self.tokenTransformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
50
+
51
+ self.specEncoder = nn.Sequential(nn.Linear(args.hidden_dims[-1], args.final_embedding_dim), nn.Dropout(args.fc_dropout))
52
+
53
+ def forward(self, spec, n_peaks=None):
54
+ h = self.tokenEnc(spec)
55
+ pad = (spec == -5)
56
+ pad = torch.all(pad, -1)
57
+
58
+ if self.use_cls:
59
+ cls_embed = self.cls_embed(torch.tensor(0).to(spec.device))
60
+ h = torch.concat((cls_embed.repeat(spec.shape[0], 1).unsqueeze(1), h), dim=1)
61
+ pad = torch.concat((torch.tensor(False).repeat(pad.shape[0],1).to(spec.device), pad), dim=1)
62
+ h = self.tokenTransformer(h, src_key_padding_mask=pad)
63
+ h = h[:,0,:]
64
+ else:
65
+
66
+ # mean
67
+ h = self.tokenTransformer(h, src_key_padding_mask=pad)
68
+
69
+ if self.returnEmb:
70
+ # repad h
71
+ h[pad] = -5
72
+ return h
73
+ n_peaks_indices = torch.tensor([i for i, count in enumerate(n_peaks) for _ in range(count)]).to(spec.device)
74
+ h = h[~pad].reshape(-1, h.shape[-1])
75
+ h = global_mean_pool(h, n_peaks_indices)
76
+
77
+ h = self.specEncoder(h)
78
+ return h
79
+
80
+
81
+ class SpecFormulaEncMLP(nn.Module):
82
+ def __init__(self, args, out_dim=None):
83
+ super(SpecFormulaEncMLP, self).__init__()
84
+ in_dim = len(args.element_list)
85
+ if args.add_intensities:
86
+ in_dim+=1
87
+ if args.spectra_view == "SpecFormulaMz": #mz
88
+ in_dim+=1
89
+
90
+ self.formulaEnc = MLP(in_dim, args.formula_dims, dropout=args.formula_dropout)
91
+
92
+ if not out_dim:
93
+ out_dim = args.final_embedding_dim
94
+ self.mz_fc1 = nn.Linear(args.formula_dims[-1], out_dim)
95
+ self.dropout = nn.Dropout(args.fc_dropout)
96
+
97
+ def forward(self, spec, n_peaks):
98
+ h = self.formulaEnc(spec)
99
+ h = torch.sum(h, axis=1)
100
+
101
+ h = self.mz_fc1(h)
102
+ h = self.dropout(h)
103
+ return h
104
+
105
+ class SpecFormulaTransformer(nn.Module):
106
+ def __init__(self, args, out_dim=None):
107
+ super(SpecFormulaTransformer, self).__init__()
108
+ in_dim = len(args.element_list)
109
+ if args.add_intensities: # intensity
110
+ in_dim+=1
111
+ if args.spectra_view == "SpecFormulaMz": #mz
112
+ in_dim+=1
113
+
114
+ self.returnEmb = False
115
+ if args.model in ('crossAttenContrastive', 'filipContrastive'):
116
+ self.returnEmb = True
117
+ assert(args.use_cls == False)
118
+
119
+ self.formulaEnc = MLP(in_dim=in_dim, hidden_dims=args.formula_dims, dropout=args.formula_dropout)
120
+
121
+ self.use_cls = args.use_cls
122
+ if args.use_cls:
123
+ self.cls_embed = torch.nn.Embedding(1,args.formula_dims[-1])
124
+ encoder_layer = nn.TransformerEncoderLayer(d_model=args.formula_dims[-1], nhead=2, batch_first=True)
125
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
126
+
127
+ if not out_dim:
128
+ out_dim = args.final_embedding_dim
129
+ self.fc = nn.Linear(args.formula_dims[-1], out_dim)
130
+
131
+ def forward(self, spec, n_peaks):
132
+ h = self.formulaEnc(spec)
133
+ pad = (spec == -5)
134
+ pad = torch.all(pad, -1)
135
+
136
+ if self.use_cls:
137
+ cls_embed = self.cls_embed(torch.tensor(0).to(spec.device))
138
+ h = torch.concat((cls_embed.repeat(spec.shape[0], 1).unsqueeze(1), h), dim=1)
139
+ pad = torch.concat((torch.tensor(False).repeat(pad.shape[0],1).to(spec.device), pad), dim=1)
140
+ h = self.transformer(h, src_key_padding_mask=pad)
141
+ h = h[:,0,:]
142
+ else:
143
+ h = self.transformer(h, src_key_padding_mask=pad)
144
+
145
+ if self.returnEmb:
146
+ # repad h
147
+ h[pad] = -5
148
+ return h
149
+
150
+ h = h[~pad].reshape(-1, h.shape[-1])
151
+ indecies = torch.tensor([i for i, count in enumerate(n_peaks) for _ in range(count)]).to(h.device)
152
+ h = global_mean_pool(h, indecies)
153
+
154
+ h = self.fc(h)
155
+
156
+ return h
157
+
158
+ class SpecFormula_mz_Encoder(nn.Module):
159
+ '''
160
+ Encodes formula and mz_int
161
+ '''
162
+
163
+ def __init__(self, args):
164
+
165
+ super(SpecFormula_mz_Encoder, self).__init__()
166
+
167
+ self.formula_encoder = SpecFormulaTransformer(args, out_dim=args.final_embedding_dim//4)
168
+ self.mz_encoder = SpecEncMLP_BIN(args, out_dim=args.final_embedding_dim//4)
169
+
170
+ self.fc = nn.Sequential(nn.Linear(args.final_embedding_dim //2, args.final_embedding_dim), nn.ReLU(),
171
+ )
172
+
173
+ def forward(self, formulas, binned_mzs):
174
+ h_formula = self.formula_encoder(formulas)
175
+ h_bin = self.mz_encoder(binned_mzs)
176
+
177
+ h_spec = torch.concat((h_formula, h_bin), axis=1)
178
+ h = self.fc(h_spec)
179
+
180
+ return h
181
+
182
+
mvp/params_binnedSpec.yaml ADDED
@@ -0,0 +1,122 @@
1
+
2
+ # Experiment setup
3
+ job_key: ''
4
+ run_name: 'binnedSpec_experiment'
5
+ run_details: ""
6
+ project_name: ''
7
+ wandb_entity_name: 'mass-spec-ml'
8
+ no_wandb: True
9
+ seed: 0
10
+ debug: False
11
+ checkpoint_pth: ""
12
+
13
+ # Training setup
14
+ max_epochs: 1000
15
+ accelerator: 'gpu'
16
+ devices: [1]
17
+ log_every_n_steps: 250
18
+ val_check_interval: 1.0
19
+
20
+ # Data paths
21
+ candidates_pth: ../data/sample/candidates_mass.json
22
+ dataset_pth: "../data/sample/data.tsv"
23
+ subformula_dir_pth: ""
24
+ split_pth:
25
+ fp_dir_pth: '../data/sample/morganfp_r5_1024.pickle'
26
+ cons_spec_dir_pth: "../data/sample/consensus_binnedSpec.pkl"
27
+ NL_spec_dir_pth: ""
28
+ partial_checkpoint: ""
29
+
30
+ # General hyperparameters
31
+ batch_size: 64
32
+ lr: 5.0e-4
33
+ weight_decay: 0
34
+ contr_temp: 0.05
35
+ early_stopping_patience: 300
36
+ loss_strategy: 'static' # static, linear, manual
37
+ num_workers: 50
38
+
39
+
40
+ ############################## Data transforms ##############################
41
+ # - Spectra
42
+ spectra_view: SpecBinnerLog
43
+ max_mz: 1000
44
+ bin_width: 1
45
+ mask_peak_ratio: 0.00
46
+
47
+ # 2. SpecFormula
48
+ element_list: ['H', 'C', 'O', 'N', 'P', 'S', 'Cl', 'F', 'Br', 'I', 'B', 'As', 'Si', 'Se']
49
+ add_intensities: True
50
+ mask_precursor: False
51
+
52
+ # - Molecule
53
+ molecule_view: "MolGraph"
54
+ atom_feature: 'full'
55
+ bond_feature: 'full'
56
+
57
+
58
+ ############################## Views ##############################
59
+ # contrastive
60
+ use_contr: True
61
+ contr_wt: 1
62
+ contr_wt_update: {}
63
+
64
+ # consensus spectra
65
+ use_cons_spec: False
66
+ cons_spec_wt: 3
67
+ cons_spec_wt_update: {}
68
+ cons_loss_type: 'l2' # cosine, l2
69
+
70
+ # fp prediction/usage
71
+ pred_fp: False
72
+ use_fp: False
73
+ fp_loss_type: 'cosine' #cosine, bce
74
+ fp_wt: 3
75
+ fp_wt_update: {}
76
+ fp_size: 1024
77
+ fp_radius: 5
78
+ fp_dropout: 0.4
79
+
80
+ # candidates
81
+ aug_cands: False
82
+ aug_cands_wt: 0.1
83
+ aug_cands_update: {}
84
+ aug_cands_size: 3
85
+
86
+ # neutral loss
87
+ use_NL: False
88
+
89
+
90
+
91
+ ############################## Task and model ##############################
92
+ task: 'retrieval'
93
+ spec_enc: MLP_BIN
94
+ mol_enc: "GNN"
95
+ model: "MultiviewContrastive"
96
+ contr_views: [['spec_enc', 'mol_enc']]
97
+ log_only_loss_at_stages: []
98
+ df_test_path: ""
99
+
100
+ # - Spectra encoder
101
+ final_embedding_dim: 512
102
+ fc_dropout: 0.4
103
+
104
+ # - Spectra Token encoder
105
+ hidden_dims: [64, 128]
106
+ peak_dropout: 0.2
107
+
108
+ # - Formula-based spec encoders
109
+ formula_dropout: 0.2
110
+ formula_dims: [64, 128, 256]
111
+ cross_attn_heads: 2
112
+ use_cls: True
113
+
114
+ # -- GAT params
115
+ attn_heads: [12,12,12]
116
+
117
+ # - Molecule encoder (GNN)
118
+ gnn_channels: [64,128,256]
119
+ gnn_type: "gcn"
120
+ num_gnn_layers: 3
121
+ gnn_hidden_dim: 512
122
+ gnn_dropout: 0.3
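These YAML files are passed to the training/testing scripts via --param_pth (see mvp/run.sh further down). The parsing code is not part of this commit; a minimal sketch of loading such a config into an attribute-style namespace, assuming a plain yaml.safe_load is sufficient:

import yaml
from argparse import Namespace

with open("mvp/params_binnedSpec.yaml") as f:
    args = Namespace(**yaml.safe_load(f))

print(args.spec_enc, args.model, args.final_embedding_dim)   # MLP_BIN MultiviewContrastive 512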
mvp/params_formSpec.yaml ADDED
@@ -0,0 +1,121 @@
1
+ # Experiment setup
2
+ job_key: ''
3
+ run_name: 'filip_quick_test'
4
+ run_details: ""
5
+ project_name: ''
6
+ wandb_entity_name: 'mass-spec-ml'
7
+ no_wandb: True
8
+ seed: 0
9
+ debug: False
10
+ checkpoint_pth: #'../pretrained_models/msgym_formSpec.ckpt'
11
+
12
+ # Training setup
13
+ max_epochs: 2000
14
+ accelerator: 'gpu'
15
+ devices: [1]
16
+ log_every_n_steps: 250
17
+ val_check_interval: 1.0
18
+
19
+ # Data paths
20
+ candidates_pth: /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_mass.json # "../data/MassSpecGym/data/molecules/MassSpecGym_retrieval_candidates_formula.json"
21
+ dataset_pth: /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv #/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex.tsv # /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv # "../data/MassSpecGym/data/sample_data.tsv"
22
+ subformula_dir_pth: /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default #/data/yzhouc01/spectra_data/subformulae #"../data/MassSpecGym/data/subformulae_default"
23
+ split_pth:
24
+ fp_dir_pth: '../data/MassSpecGym/data/morganfp_r5_1024.pickle'
25
+ cons_spec_dir_pth: "../data/MassSpecGym/data/sample_consensus_formSpec.pkl"
26
+ NL_spec_dir_pth: ""
27
+ partial_checkpoint: ""
28
+
29
+ # General hyperparameters
30
+ batch_size: 64
31
+ lr: 5.0e-05
32
+ weight_decay: 0
33
+ contr_temp: 0.05
34
+ early_stopping_patience: 300
35
+ loss_strategy: 'static'
36
+ num_workers: 50
37
+
38
+
39
+ ############################## Data transforms ##############################
40
+ # - Spectra
41
+ spectra_view: SpecFormula #SpecMzIntTokens #SpecFormula
42
+ # 1. Binner
43
+ max_mz: 1000
44
+ bin_width: 1
45
+ mask_peak_ratio: 0.00
46
+
47
+ # 2. SpecFormula
48
+ element_list: ['H', 'C', 'O', 'N', 'P', 'S', 'Cl', 'F', 'Br', 'I', 'B', 'As', 'Si', 'Se']
49
+ add_intensities: True
50
+ mask_precursor: False
51
+
52
+ # - Molecule
53
+ molecule_view: "MolGraph"
54
+ atom_feature: 'full'
55
+ bond_feature: 'full'
56
+
57
+
58
+ ############################## Views ##############################
59
+ # contrastive
60
+ use_contr: False
61
+ contr_wt: 1
62
+ contr_wt_update: {}
63
+
64
+ # consensus spectra
65
+ use_cons_spec: False
66
+ cons_spec_wt: 3
67
+ cons_spec_wt_update: {}
68
+ cons_loss_type: 'l2' # cosine, l2
69
+
70
+ # fp prediction/usage
71
+ pred_fp: False
72
+ use_fp: False
73
+ fp_loss_type: 'cosine' #cosine, bce
74
+ fp_wt: 3
75
+ fp_wt_update: {}
76
+ fp_size: 1024
77
+ fp_radius: 5
78
+ fp_dropout: 0.4
79
+
80
+ # candidates
81
+ aug_cands: False
82
+ aug_cands_wt: 0.1
83
+ aug_cands_update: {}
84
+ aug_cands_size: 3
85
+
86
+ # neutral loss
87
+ use_NL: False
88
+
89
+
90
+ ############################## Task and model ##############################
91
+ task: 'retrieval'
92
+ spec_enc: Transformer_Formula # Transformer_MzInt #Transformer_Formula
93
+ mol_enc: "GNN"
94
+ model: filipContrastive # "MultiviewContrastive"
95
+ contr_views: [['spec_enc', 'mol_enc']] #[['spec_enc', 'mol_enc'], ['spec_enc', 'NL_spec_enc'], ['mol_enc', 'NL_spec_enc']] #[['spec_enc', 'mol_enc'], ['mol_enc', 'cons_spec_enc'], ['cons_spec_enc', 'spec_enc'], ['fp_enc', 'mol_enc'], ['fp_enc', 'spec_enc'], ['fp_enc', 'cons_spec_enc']]
96
+ log_only_loss_at_stages: []
97
+ df_test_path: ""
98
+
99
+ # - Spectra encoder
100
+ final_embedding_dim: 512
101
+ fc_dropout: 0.4
102
+
103
+ # - Spectra Token encoder
104
+ hidden_dims: [64, 128]
105
+ peak_dropout: 0.2
106
+
107
+ # - Formula-based spec encoders
108
+ formula_dropout: 0.2
109
+ formula_dims: [64, 128, 256]
110
+ cross_attn_heads: 2
111
+ use_cls: False
112
+
113
+ # -- GAT params
114
+ attn_heads: [12,12,12]
115
+
116
+ # - Molecule encoder (GNN)
117
+ gnn_channels: [64,128,256]
118
+ gnn_type: "gcn"
119
+ num_gnn_layers: 3
120
+ gnn_hidden_dim: 512
121
+ gnn_dropout: 0.3
mvp/params_jestr.yaml ADDED
@@ -0,0 +1,122 @@
1
+
2
+ # Experiment setup
3
+ job_key: ''
4
+ run_name: 'combined_d_1024dim_100bs'
5
+ run_details: ""
6
+ project_name: ''
7
+ wandb_entity_name: 'mass-spec-ml'
8
+ no_wandb: True
9
+ seed: 3
10
+ debug: False
11
+ checkpoint_pth:
12
+
13
+ # Training setup
14
+ max_epochs: 2000
15
+ accelerator: 'gpu'
16
+ devices: [1]
17
+ log_every_n_steps: 250
18
+ val_check_interval: 1.0
19
+
20
+ # Data paths
21
+ candidates_pth: "/r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_mass.json"
22
+ dataset_pth: '/r/hassounlab/spectra_data/msgym/MassSpecGym.tsv' # '/r/hassounlab/spectra_data/msgym/MassSpecGym.tsv' #"/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex.tsv"
23
+ subformula_dir_pth: ""
24
+ split_pth:
25
+ fp_dir_pth: ''
26
+ cons_spec_dir_pth:
27
+ NL_spec_dir_pth: ""
28
+ partial_checkpoint: ""
29
+
30
+ # General hyperparameters
31
+ batch_size: 100
32
+ lr: 5.0e-4
33
+ weight_decay: 0
34
+ contr_temp: 0.05
35
+ early_stopping_patience: 300
36
+ loss_strategy: 'static' # static, linear, manual
37
+ num_workers: 50
38
+
39
+
40
+ ############################## Data transforms ##############################
41
+ # - Spectra
42
+ spectra_view: SpecBinnerLog
43
+ max_mz: 1000
44
+ bin_width: 1
45
+ mask_peak_ratio: 0.00
46
+
47
+ # 2. SpecFormula
48
+ element_list: ['H', 'C', 'O', 'N', 'P', 'S', 'Cl', 'F', 'Br', 'I', 'B', 'As', 'Si', 'Se']
49
+ add_intensities: True
50
+ mask_precursor: False
51
+
52
+ # - Molecule
53
+ molecule_view: "MolGraph"
54
+ atom_feature: 'full'
55
+ bond_feature: 'full'
56
+
57
+
58
+ ############################## Views ##############################
59
+ # contrastive
60
+ use_contr: True
61
+ contr_wt: 1
62
+ contr_wt_update: {}
63
+
64
+ # consensus spectra
65
+ use_cons_spec: False
66
+ cons_spec_wt: 3
67
+ cons_spec_wt_update: {}
68
+ cons_loss_type: 'l2' # cosine, l2
69
+
70
+ # fp prediction/usage
71
+ pred_fp: False
72
+ use_fp: False
73
+ fp_loss_type: 'cosine' #cosine, bce
74
+ fp_wt: 3
75
+ fp_wt_update: {}
76
+ fp_size: 1024
77
+ fp_radius: 5
78
+ fp_dropout: 0.4
79
+
80
+ # candidates
81
+ aug_cands: False
82
+ aug_cands_wt: 0.1
83
+ aug_cands_update: {}
84
+ aug_cands_size: 3
85
+
86
+ # neutral loss
87
+ use_NL: False
88
+
89
+
90
+
91
+ ############################## Task and model ##############################
92
+ task: 'retrieval'
93
+ spec_enc: MLP_BIN
94
+ mol_enc: "GNN"
95
+ model: "MultiviewContrastive"
96
+ contr_views: [['spec_enc', 'mol_enc']]
97
+ log_only_loss_at_stages: []
98
+ df_test_path: ""
99
+
100
+ # - Spectra encoder
101
+ final_embedding_dim: 1024
102
+ fc_dropout: 0.4
103
+
104
+ # - Spectra Token encoder
105
+ hidden_dims: [64, 128]
106
+ peak_dropout: 0.2
107
+
108
+ # - Formula-based spec encoders
109
+ formula_dropout: 0.2
110
+ formula_dims: [64, 128, 256]
111
+ cross_attn_heads: 2
112
+ use_cls: True
113
+
114
+ # -- GAT params
115
+ attn_heads: [12,12,12]
116
+
117
+ # - Molecule encoder (GNN)
118
+ gnn_channels: [64,128,256]
119
+ gnn_type: "gcn"
120
+ num_gnn_layers: 3
121
+ gnn_hidden_dim: 1024
122
+ gnn_dropout: 0.3
mvp/run.sh ADDED
@@ -0,0 +1,12 @@
1
+ # 1. preprocess data (subformula labels should be obtained through MIST)
2
+ # python data_preprocess.py --spec_type formSpec --dataset_pth ../data/sample/data.tsv --candidates_pth ../data/sample/candidates_mass.json --subformula_dir_pth ../data/sample/subformulae_default/ --output_dir ../data/sample/
3
+
4
+ # 2. train model on msgym
5
+ # python train.py --param_pth params_formSpec.yaml
6
+
7
+ # 3. test model on msgym
8
+ # python train.py --param_pth params_binnedSpec.yaml
9
+
10
+ # python train.py
11
+ python test.py
12
+ python test.py --candidates_pth /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_formula.json
mvp/subformula_assign/assign_subformulae.py ADDED
@@ -0,0 +1,214 @@
1
+ """ assign_subformulae.py
2
+
3
+ Copied from https://github.com/samgoldman97/mist/blob/main_v2/src/mist/subformulae/assign_subformulae.py
4
+
5
+ Given a set of spectra and candidates from a labels file, assign subformulae and save to JSON files.
6
+
7
+ """
8
+
9
+ from pathlib import Path
10
+ import argparse
11
+ from functools import partial
12
+ import numpy as np
13
+ import pandas as pd
14
+ import json
15
+
16
+ from tqdm import tqdm
17
+ import utils
18
+
19
+
20
+ def get_args():
21
+ """get args"""
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument(
24
+ "--feature-id",
25
+ default="ID",
26
+ help="ID key in mgf input"
27
+ )
28
+ parser.add_argument(
29
+ "--spec-files",
30
+ default="data/paired_spectra/canopus_train/spec_files/",
31
+ help="Spec files; either MGF or directory.",
32
+ )
33
+ parser.add_argument("--output-dir", default=None,
34
+ help="Name of output dir.")
35
+ parser.add_argument(
36
+ "--labels-file",
37
+ default="data/paired_spectra/canopus_train/labels.tsv",
38
+ help="Labels file",
39
+ )
40
+ parser.add_argument(
41
+ "--debug", action="store_true", default=False, help="Debug flag."
42
+ )
43
+ parser.add_argument(
44
+ "--mass-diff-type",
45
+ default="ppm",
46
+ type=str,
47
+ help="Type of mass difference - absolute differece (abs) or relative difference (ppm).",
48
+ )
49
+ parser.add_argument(
50
+ "--mass-diff-thresh",
51
+ action="store",
52
+ default=20,
53
+ type=float,
54
+ help="Threshold of mass difference.",
55
+ )
56
+ parser.add_argument(
57
+ "--inten-thresh",
58
+ action="store",
59
+ default=0.001,
60
+ type=float,
61
+ help="Threshold of MS2 subpeak intensity (normalized to 1).",
62
+ )
63
+ parser.add_argument(
64
+ "--max-formulae",
65
+ action="store",
66
+ default=50,
67
+ type=int,
68
+ help="Max number of peaks to keep",
69
+ )
70
+ parser.add_argument(
71
+ "--num-workers", action="store", default=32, type=int, help="num workers"
72
+ )
73
+ return parser.parse_args()
74
+
75
+
76
+ def process_spec_file(spec_name: str, spec_files: str, max_inten=0.001, max_peaks=60):
77
+ """_summary_
78
+
79
+ Args:
80
+ spec_name (str): _description_
81
+ spec_files (str): _description_
82
+ max_inten (float, optional): _description_. Defaults to 0.001.
83
+ max_peaks (int, optional): _description_. Defaults to 60.
84
+
85
+ Returns:
86
+ _type_: _description_
87
+ """
88
+ spec_file = Path(spec_files) / f"{spec_name}.ms"
89
+
90
+ meta, tuples = utils.parse_spectra(spec_file)
91
+ spec = utils.process_spec_file(meta, tuples)
92
+ return spec_name, spec
93
+
94
+
95
+ def assign_subforms(spec_files, labels_file,
96
+ mass_diff_thresh: int = 20,
97
+ mass_diff_type: str = "ppm",
98
+ inten_thresh: float = 0.001,
99
+ output_dir=None,
100
+ num_workers: int = 32,
101
+ feature_id="ID",
102
+ max_formulae: int = 50,
103
+ debug=False):
104
+ """_summary_
105
+
106
+ Args:
107
+ spec_files (_type_): _description_
108
+ labels_file (_type_): _description_
109
+ mass_diff_thresh (int, optional): _description_. Defaults to 20.
110
+ mass_diff_type (str, optional): _description_. Defaults to "ppm".
111
+ inten_thresh (float, optional): _description_. Defaults to 0.001.
112
+ output_dir (_type_, optional): _description_. Defaults to None.
113
+ num_workers (int, optional): _description_. Defaults to 32.
114
+ feature_id (str, optional): _description_. Defaults to "ID".
115
+ max_formulae (int, optional): _description_. Defaults to 50.
116
+ debug (bool, optional): _description_. Defaults to False.
117
+
118
+ Raises:
119
+ ValueError: _description_
120
+ """
121
+ spec_files = Path(spec_files)
122
+ label_path = Path(labels_file)
123
+
124
+ # Read in labels
125
+ labels_df = pd.read_csv(label_path, sep="\t").astype(str)
126
+ if spec_files.suffix == ".tsv": # YZC msgym-like data
127
+ labels_df.rename(columns={'identifier': 'spec',
128
+ 'adduct': 'ionization'}, inplace=True)
129
+
130
+ if debug:
131
+ labels_df = labels_df[:50]
132
+
133
+ # Define output directory name
134
+ output_dir = Path(output_dir)
135
+ if output_dir is None:
136
+ subform_dir = label_path.parent / "subformulae"
137
+ output_dir_name = f"subform_{max_formulae}"
138
+ output_dir = subform_dir / output_dir_name
139
+
140
+ output_dir.mkdir(exist_ok=True, parents=True)
141
+
142
+ if spec_files.suffix == ".mgf":
143
+ # Input specs
144
+ parsed_specs = utils.parse_spectra_mgf(spec_files)
145
+ input_specs = [utils.process_spec_file(*i) for i in parsed_specs]
146
+ spec_names = [i[0][feature_id] for i in parsed_specs]
147
+ input_specs = list(zip(spec_names, input_specs))
148
+ elif spec_files.is_dir():
149
+ spec_fn_lst = labels_df["spec"].to_list()
150
+ proc_spec_full = partial(
151
+ process_spec_file,
152
+ spec_files=spec_files,
153
+ max_inten=inten_thresh,
154
+ max_peaks=max_formulae,
155
+ )
156
+ # input_specs = [proc_spec_full(i) for i in tqdm(spec_fn_lst)]
157
+ input_specs = utils.chunked_parallel(
158
+ spec_fn_lst, proc_spec_full, chunks=100, max_cpu=max(num_workers, 1)
159
+ )
160
+
161
+ elif spec_files.suffix == '.tsv':
162
+ parsed_specs = utils.parse_spectra_msgym(labels_df)
163
+ input_specs = [utils.process_spec_file(*i) for i in parsed_specs]
164
+ spec_names = [i[0][feature_id] for i in parsed_specs]
165
+ input_specs = list(zip(spec_names, input_specs))
166
+ else:
167
+ raise ValueError(f"Spec files arg {spec_files} is not a dir or mgf")
168
+
169
+
170
+ # input_specs contains a list of tuples (spec, subpeak tuple array)
171
+ input_specs_dict = {tup[0]: tup[1] for tup in input_specs}
172
+ export_dicts, spec_names = [], []
173
+ for _, row in labels_df.iterrows():
174
+ spec = str(row["spec"])
175
+ new_entry = {
176
+ "spec": input_specs_dict[spec],
177
+ "form": row["formula"],
178
+ "mass_diff_type": mass_diff_type,
179
+ "spec_name": spec,
180
+ "mass_diff_thresh": mass_diff_thresh,
181
+ "ion_type": row["ionization"],
182
+ }
183
+ spec_names.append(spec)
184
+ export_dicts.append(new_entry)
185
+
186
+ # Build dicts
187
+ print(f"There are {len(export_dicts)} spec-cand pairs this spec files")
188
+ def export_wrapper(x): return utils.get_output_dict(**x)
189
+ if debug:
190
+ output_dict_lst = [export_wrapper(i) for i in export_dicts[:10]]
191
+ else:
192
+ output_dict_lst = utils.chunked_parallel(
193
+ export_dicts, export_wrapper, chunks=100, max_cpu=max(num_workers, 1)
194
+ )
195
+ assert len(export_dicts) == len(output_dict_lst)
196
+
197
+ # Write all output jsons to files
198
+ for output_dict, spec_name in tqdm(zip(output_dict_lst, spec_names)):
199
+ with open(output_dir / f"{spec_name}.json", "w") as f:
200
+ json.dump(output_dict, f, indent=4)
201
+ f.close()
202
+
203
+ if __name__ == "__main__":
204
+ args = get_args()
205
+ assign_subforms(spec_files=args.spec_files,
206
+ labels_file=args.labels_file,
207
+ mass_diff_thresh=args.mass_diff_thresh,
208
+ mass_diff_type=args.mass_diff_type,
209
+ inten_thresh=args.inten_thresh,
210
+ output_dir=args.output_dir,
211
+ num_workers=args.num_workers,
212
+ feature_id=args.feature_id,
213
+ max_formulae=args.max_formulae,
214
+ debug=args.debug)
mvp/subformula_assign/run.sh ADDED
@@ -0,0 +1,6 @@
1
+ SPEC_FILES="/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex.tsv"
2
+ OUTPUT_DIR="/data/yzhouc01/spectra_data/subformulae"
3
+ MAX_FORMULAE=60
4
+ LABELS_FILE="/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex.tsv"
5
+
6
+ python assign_subformulae.py --spec-files $SPEC_FILES --output-dir $OUTPUT_DIR --max-formulae $MAX_FORMULAE --labels-file $LABELS_FILE
mvp/subformula_assign/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+
2
+ from .parse_utils import *
3
+ from .chem_utils import *
4
+ from .parallel_utils import *
5
+ from .spectra_utils import *
mvp/subformula_assign/utils/chem_utils.py ADDED
@@ -0,0 +1,612 @@
1
+ """chem_utils.py"""
2
+
3
+ import re
4
+ import numpy as np
5
+ import pandas as pd
6
+ import json
7
+ from functools import reduce
8
+ from collections import defaultdict
9
+
10
+ import torch
11
+ from rdkit import Chem
12
+ from rdkit.Chem import Atom
13
+ from rdkit.Chem.rdMolDescriptors import CalcMolFormula
14
+ from rdkit.Chem.Descriptors import ExactMolWt
15
+ from rdkit.Chem.MolStandardize import rdMolStandardize
16
+
17
+ P_TBL = Chem.GetPeriodicTable()
18
+
19
+ ROUND_FACTOR = 4
20
+
21
+ ELECTRON_MASS = 0.00054858
22
+ CHEM_FORMULA_SIZE = "([A-Z][a-z]*)([0-9]*)"
23
+
24
+ VALID_ELEMENTS = [
25
+ "C",
26
+ "H",
27
+ "As",
28
+ "B",
29
+ "Br",
30
+ "Cl",
31
+ "Co",
32
+ "F",
33
+ "Fe",
34
+ "I",
35
+ "K",
36
+ "N",
37
+ "Na",
38
+ "O",
39
+ "P",
40
+ "S",
41
+ "Se",
42
+ "Si",
43
+ ]
44
+ VALID_ATOM_NUM = [Atom(i).GetAtomicNum() for i in VALID_ELEMENTS]
45
+
46
+
47
+ CHEM_ELEMENT_NUM = len(VALID_ELEMENTS)
48
+
49
+ ATOM_NUM_TO_ONEHOT = torch.zeros((max(VALID_ATOM_NUM) + 1, CHEM_ELEMENT_NUM))
50
+
51
+ # Convert to onehot
52
+ ATOM_NUM_TO_ONEHOT[VALID_ATOM_NUM, torch.arange(CHEM_ELEMENT_NUM)] = 1
53
+
54
+ VALID_MONO_MASSES = np.array(
55
+ [P_TBL.GetMostCommonIsotopeMass(i) for i in VALID_ELEMENTS]
56
+ )
57
+ CHEM_MASSES = VALID_MONO_MASSES[:, None]
58
+
59
+ ELEMENT_VECTORS = np.eye(len(VALID_ELEMENTS))
60
+ ELEMENT_VECTORS_MASS = np.hstack([ELEMENT_VECTORS, CHEM_MASSES])
61
+ ELEMENT_TO_MASS = dict(zip(VALID_ELEMENTS, CHEM_MASSES.squeeze()))
62
+
63
+ ELEMENT_DIM_MASS = len(ELEMENT_VECTORS_MASS[0])
64
+ ELEMENT_DIM = len(ELEMENT_VECTORS[0])
65
+
66
+ # Reasonable normalization vector for elements
67
+ # Estimated by max counts (+ 1 when zero)
68
+ NORM_VEC = np.array([81, 158, 2, 1, 3, 10, 1, 17, 1, 6, 1, 19, 2, 34, 6, 6, 2, 6])
69
+
70
+ NORM_VEC_MASS = np.array(NORM_VEC.tolist() + [1471])
71
+
72
+ # Assume 64 is the highest repeat of any 1 atom
73
+ MAX_ELEMENT_NUM = 64
74
+
75
+ element_to_ind = dict(zip(VALID_ELEMENTS, np.arange(len(VALID_ELEMENTS))))
76
+ element_to_position = dict(zip(VALID_ELEMENTS, ELEMENT_VECTORS))
77
+ element_to_position_mass = dict(zip(VALID_ELEMENTS, ELEMENT_VECTORS_MASS))
78
+
79
+ ION_LST = [
80
+ "[M+H]+",
81
+ "[M+Na]+",
82
+ "[M+K]+",
83
+ "[M-H2O+H]+",
84
+ "[M+H3N+H]+",
85
+ "[M]+",
86
+ "[M-H4O2+H]+",
87
+ "[M-H]-"
88
+ ]
89
+
90
+ ion_remap = dict(zip(ION_LST, ION_LST))
91
+ ion_remap.update(
92
+ {
93
+ "[M+NH4]+": "[M+H3N+H]+",
94
+ "M+H": "[M+H]+",
95
+ "M+Na": "[M+Na]+",
96
+ "M+H-H2O": "[M-H2O+H]+",
97
+ "M-H2O+H": "[M-H2O+H]+",
98
+ "M+NH4": "[M+H3N+H]+",
99
+ "M-2H2O+H": "[M-H4O2+H]+",
100
+ "[M-2H2O+H]+": "[M-H4O2+H]+",
101
+ "[M-H]-": "[M-H]-",
102
+ }
103
+ )
104
+
105
+ ion_to_idx = dict(zip(ION_LST, np.arange(len(ION_LST))))
106
+
107
+ ion_to_mass = {
108
+ "[M+H]+": ELEMENT_TO_MASS["H"] - ELECTRON_MASS,
109
+ "[M+Na]+": ELEMENT_TO_MASS["Na"] - ELECTRON_MASS,
110
+ "[M+K]+": ELEMENT_TO_MASS["K"] - ELECTRON_MASS,
111
+ "[M-H2O+H]+": -ELEMENT_TO_MASS["O"] - ELEMENT_TO_MASS["H"] - ELECTRON_MASS,
112
+ "[M+H3N+H]+": ELEMENT_TO_MASS["N"] + ELEMENT_TO_MASS["H"] * 4 - ELECTRON_MASS,
113
+ "[M]+": 0 - ELECTRON_MASS,
114
+ "[M-H4O2+H]+": -ELEMENT_TO_MASS["O"] * 2 - ELEMENT_TO_MASS["H"] * 3 - ELECTRON_MASS,
115
+ "[M-H]-": ELEMENT_TO_MASS["H"] + ELECTRON_MASS,
116
+ }
117
+
118
+ ion_to_add_vec = {
119
+ "[M+H]+": element_to_position["H"],
120
+ "[M+Na]+": element_to_position["Na"],
121
+ "[M+K]+": element_to_position["K"],
122
+ "[M-H2O+H]+": -element_to_position["O"] - element_to_position["H"],
123
+ "[M+H3N+H]+": element_to_position["N"] + element_to_position["H"] * 4,
124
+ "[M]+": np.zeros_like(element_to_position["H"]),
125
+ "[M-H4O2+H]+": -element_to_position["O"] * 2 - element_to_position["H"] * 3,
126
+ }
127
+
128
+ instrument_to_type = defaultdict(lambda : "unknown")
129
+ instrument_to_type.update({
130
+ "Thermo Finnigan Velos Orbitrap": "orbitrap",
131
+ "Thermo Finnigan Elite Orbitrap": "orbitrap",
132
+ "Orbitrap Fusion Lumos": "orbitrap",
133
+ "Q-ToF (LCMS)": "qtof",
134
+ "Unknown (LCMS)": "unknown",
135
+ "ion trap": "iontrap",
136
+ "FTICR (LCMS)": "fticr",
137
+ "Bruker Q-ToF (LCMS)": "qtof",
138
+ "Orbitrap (LCMS)": "orbitrap",
139
+ })
140
+
141
+ instruments = sorted(list(set(instrument_to_type.values())))
142
+ max_instr_idx = len(instruments) + 1
143
+ instrument_to_idx = dict(zip(instruments, np.arange(len(instruments))))
144
+
145
+
146
+ # Define rdbe mult
147
+ rdbe_mult = np.zeros_like(ELEMENT_VECTORS[0])
148
+ els = ["C", "N", "P", "H", "Cl", "Br", "I", "F"]
149
+ weights = [2, 1, 1, -1, -1, -1, -1, -1]
150
+ for k, v in zip(els, weights):
151
+ rdbe_mult[element_to_ind[k]] = v
152
+
153
+
154
+ def get_ion_idx(ionization: str) -> int:
155
+ """map ionization to its index in one hot encoding"""
156
+ return ion_to_idx[ionization]
157
+
158
+
159
+ def get_instr_idx(instrument: str) -> int:
160
+ """map instrument to its index in one hot encoding"""
161
+ inst = instrument_to_type.get(instrument, "unknown")
162
+ return instrument_to_idx[inst]
163
+
164
+
165
+ def has_valid_els(chem_formula: str) -> bool:
166
+ """has_valid_els"""
167
+ for (chem_symbol, num) in re.findall(CHEM_FORMULA_SIZE, chem_formula):
168
+ if chem_symbol not in VALID_ELEMENTS:
169
+ return False
170
+ return True
171
+
172
+
173
+ def formula_to_dense(chem_formula: str) -> np.ndarray:
174
+ """formula_to_dense.
175
+
176
+ Args:
177
+ chem_formula (str): Input chemical formal
178
+ Return:
179
+ np.ndarray of vector
180
+
181
+ """
182
+ total_onehot = []
183
+ for (chem_symbol, num) in re.findall(CHEM_FORMULA_SIZE, chem_formula):
184
+ # Convert num to int
185
+ num = 1 if num == "" else int(num)
186
+ one_hot = element_to_position[chem_symbol].reshape(1, -1)
187
+ one_hot_repeats = np.repeat(one_hot, repeats=num, axis=0)
188
+ total_onehot.append(one_hot_repeats)
189
+
190
+ # Check if null
191
+ if len(total_onehot) == 0:
192
+ dense_vec = np.zeros(len(element_to_position))
193
+ else:
194
+ dense_vec = np.vstack(total_onehot).sum(0)
195
+ return dense_vec
196
+
197
+
198
+ def cross_sum(x, y):
199
+ """cross_sum."""
200
+ return (np.expand_dims(x, 0) + np.expand_dims(y, 1)).reshape(-1, y.shape[-1])
201
+
202
+
203
+ def get_all_subsets_dense(
204
+ dense_formula: str, element_vectors
205
+ ) -> (np.ndarray, np.ndarray):
206
+ """_summary_
207
+
208
+ Args:
209
+ dense_formula (str, element_vectors): _description_
210
+ np (_type_): _description_
211
+
212
+ Returns:
213
+ _type_: _description_
214
+ """
215
+
216
+ non_zero = np.argwhere(dense_formula > 0).flatten()
217
+
218
+ vectorized_formula = []
219
+ for nonzero_ind in non_zero:
220
+ temp = element_vectors[nonzero_ind] * np.arange(
221
+ 0, dense_formula[nonzero_ind] + 1
222
+ ).reshape(-1, 1)
223
+ vectorized_formula.append(temp)
224
+
225
+ zero_vec = np.zeros((1, element_vectors.shape[-1]))
226
+ cross_prod = reduce(cross_sum, vectorized_formula, zero_vec)
227
+
228
+ cross_prod_inds = rdbe_filter(cross_prod)
229
+ cross_prod = cross_prod[cross_prod_inds]
230
+ all_masses = cross_prod.dot(VALID_MONO_MASSES)
231
+ return cross_prod, all_masses
232
+
233
+
234
+ def get_all_subsets(chem_formula: str):
235
+ dense_formula = formula_to_dense(chem_formula)
236
+ return get_all_subsets_dense(dense_formula, element_vectors=ELEMENT_VECTORS)
237
+
238
+
239
+ def rdbe_filter(cross_prod):
240
+ """rdbe_filter.
241
+ Args:
242
+ cross_prod:
243
+ """
244
+ rdbe_total = 1 + 0.5 * cross_prod.dot(rdbe_mult)
245
+ filter_inds = np.argwhere(rdbe_total >= 0).flatten()
246
+ return filter_inds
247
+
248
+
249
+ def formula_to_dense(chem_formula: str) -> np.ndarray:
250
+ """formula_to_dense.
251
+
252
+ Args:
253
+ chem_formula (str): Input chemical formal
254
+ Return:
255
+ np.ndarray of vector
256
+
257
+ """
258
+ total_onehot = []
259
+ for (chem_symbol, num) in re.findall(CHEM_FORMULA_SIZE, chem_formula):
260
+ # Convert num to int
261
+ num = 1 if num == "" else int(num)
262
+ one_hot = element_to_position[chem_symbol].reshape(1, -1)
263
+ one_hot_repeats = np.repeat(one_hot, repeats=num, axis=0)
264
+ total_onehot.append(one_hot_repeats)
265
+
266
+ # Check if null
267
+ if len(total_onehot) == 0:
268
+ dense_vec = np.zeros(len(element_to_position))
269
+ else:
270
+ dense_vec = np.vstack(total_onehot).sum(0)
271
+
272
+ return dense_vec
273
+
274
+
275
+ def formula_to_dense_mass(chem_formula: str) -> np.ndarray:
276
+ """formula_to_dense_mass.
277
+
278
+ Return formula including full compound mass
279
+
280
+ Args:
281
+ chem_formula (str): Input chemical formal
282
+ Return:
283
+ np.ndarray of vector
284
+
285
+ """
286
+ total_onehot = []
287
+ for (chem_symbol, num) in re.findall(CHEM_FORMULA_SIZE, chem_formula):
288
+ # Convert num to int
289
+ num = 1 if num == "" else int(num)
290
+ one_hot = element_to_position_mass[chem_symbol].reshape(1, -1)
291
+ one_hot_repeats = np.repeat(one_hot, repeats=num, axis=0)
292
+ total_onehot.append(one_hot_repeats)
293
+
294
+ # Check if null
295
+ if len(total_onehot) == 0:
296
+ dense_vec = np.zeros(len(element_to_position_mass["H"]))
297
+ else:
298
+ dense_vec = np.vstack(total_onehot).sum(0)
299
+
300
+ return dense_vec
301
+
302
+
303
+ def formula_to_dense_mass_norm(chem_formula: str) -> np.ndarray:
304
+ """formula_to_dense_mass_norm.
305
+
306
+ Return formula including full compound mass and normalized
307
+
308
+ Args:
309
+ chem_formula (str): Input chemical formal
310
+ Return:
311
+ np.ndarray of vector
312
+
313
+ """
314
+ dense_vec = formula_to_dense_mass(chem_formula)
315
+ dense_vec = dense_vec / NORM_VEC_MASS
316
+
317
+ return dense_vec
318
+
319
+
320
+ def formula_mass(chem_formula: str) -> float:
321
+ """get formula mass"""
322
+ mass = 0
323
+ for (chem_symbol, num) in re.findall(CHEM_FORMULA_SIZE, chem_formula):
324
+ # Convert num to int
325
+ num = 1 if num == "" else int(num)
326
+ mass += ELEMENT_TO_MASS[chem_symbol] * num
327
+ return mass
328
+
329
+
330
+ def electron_correct(mass: float) -> float:
331
+ """subtract the rest mass of an electron"""
332
+ return mass - ELECTRON_MASS
333
+
334
+
335
+ def formula_difference(formula_1, formula_2):
336
+ """formula_1 - formula_2"""
337
+ form_1 = {
338
+ chem_symbol: (int(num) if num != "" else 1)
339
+ for chem_symbol, num in re.findall(CHEM_FORMULA_SIZE, formula_1)
340
+ }
341
+ form_2 = {
342
+ chem_symbol: (int(num) if num != "" else 1)
343
+ for chem_symbol, num in re.findall(CHEM_FORMULA_SIZE, formula_2)
344
+ }
345
+
346
+ for k, v in form_2.items():
347
+ form_1[k] = form_1[k] - form_2[k]
348
+ out_formula = "".join([f"{k}{v}" for k, v in form_1.items() if v > 0])
349
+ return out_formula
350
+
351
+
352
+ def get_mol_from_structure_string(structure_string, structure_type):
353
+ if structure_type == "InChI":
354
+ mol = Chem.MolFromInchi(structure_string)
355
+ else:
356
+ mol = Chem.MolFromSmiles(structure_string)
357
+ return mol
358
+
359
+
360
+ def vec_to_formula(form_vec):
361
+ """vec_to_formula."""
362
+ build_str = ""
363
+ for i in np.argwhere(form_vec > 0).flatten():
364
+ el = VALID_ELEMENTS[i]
365
+ ct = int(form_vec[i])
366
+ new_item = f"{el}{ct}" if ct > 1 else f"{el}"
367
+ build_str = build_str + new_item
368
+ return build_str
369
+
370
+
371
+ def standardize_form(i):
372
+ """standardize_form."""
373
+ return vec_to_formula(formula_to_dense(i))
374
+
375
+
376
+ def standardize_adduct(adduct):
377
+ """standardize_adduct."""
378
+ adduct = adduct.replace(" ", "")
379
+ adduct = ion_remap.get(adduct, adduct)
380
+ if adduct not in ION_LST:
381
+ raise ValueError(f"Adduct {adduct} not in ION_LST")
382
+ return adduct
383
+
384
+
385
+ def calc_structure_string_type(structure_string):
386
+ """calc_structure_string_type.
387
+
388
+ Args:
389
+ structure_string:
390
+ """
391
+ structure_type = None
392
+ if pd.isna(structure_string):
393
+ structure_type = "empty"
394
+ elif structure_string.startswith("InChI="):
395
+ structure_type = "InChI"
396
+ elif Chem.MolFromSmiles(structure_string) is not None:
397
+ structure_type = "Smiles"
398
+ return structure_type
399
+
400
+
401
+ def uncharged_formula(mol, mol_type="mol") -> str:
402
+ """Compute uncharged formula"""
403
+ if mol_type == "mol":
404
+ chem_formula = CalcMolFormula(mol)
405
+ elif mol_type == "smiles":
406
+ mol = Chem.MolFromSmiles(mol)
407
+ if mol is None:
408
+ return None
409
+ chem_formula = CalcMolFormula(mol)
410
+ else:
411
+ raise ValueError()
412
+
413
+ return re.findall(r"^([^\+,^\-]*)", chem_formula)[0]
414
+
415
+
416
+ def form_from_smi(smi: str) -> str:
417
+ """form_from_smi.
418
+
419
+ Args:
420
+ smi (str): smi
421
+
422
+ Return:
423
+ str
424
+ """
425
+ mol = Chem.MolFromSmiles(smi)
426
+ if mol is None:
427
+ return ""
428
+ else:
429
+ return CalcMolFormula(mol)
430
+
431
+
432
+ def inchikey_from_smiles(smi: str) -> str:
433
+ """inchikey_from_smiles.
434
+
435
+ Args:
436
+ smi (str): smi
437
+
438
+ Returns:
439
+ str:
440
+ """
441
+ mol = Chem.MolFromSmiles(smi)
442
+ if mol is None:
443
+ return ""
444
+ else:
445
+ return Chem.MolToInchiKey(mol)
446
+
447
+
448
+ def contains_metals(formula: str) -> bool:
449
+ """returns true if formula contains metals"""
450
+ METAL_RE = "(Fe|Co|Zn|Rh|Pt|Li)"
451
+ return len(re.findall(METAL_RE, formula)) > 0
452
+
453
+
454
+ class SmilesStandardizer(object):
455
+ """Standardize smiles"""
456
+
457
+ def __init__(self, *args, **kwargs):
458
+ self.fragment_standardizer = rdMolStandardize.LargestFragmentChooser()
459
+ self.charge_standardizer = rdMolStandardize.Uncharger()
460
+
461
+ def standardize_smiles(self, smi):
462
+ """Standardize smiles string"""
463
+ mol = Chem.MolFromSmiles(smi)
464
+ out_smi = self.standardize_mol(mol)
465
+ return out_smi
466
+
467
+ def standardize_mol(self, mol) -> str:
468
+ """Standardize smiles string"""
469
+ mol = self.fragment_standardizer.choose(mol)
470
+ mol = self.charge_standardizer.uncharge(mol)
471
+
472
+ # Round trip to and from inchi to tautomer correct
473
+ # Also standardize tautomer in the middle
474
+ output_smi = Chem.MolToSmiles(mol, isomericSmiles=False)
475
+ return output_smi
476
+
477
+
478
+ def mass_from_smi(smi: str) -> float:
479
+ """mass_from_smi.
480
+
481
+ Args:
482
+ smi (str): smi
483
+
484
+ Return:
485
+ str
486
+ """
487
+ mol = Chem.MolFromSmiles(smi)
488
+ if mol is None:
489
+ return 0
490
+ else:
491
+ return ExactMolWt(mol)
492
+
493
+
494
+ def min_formal_from_smi(smi: str):
495
+ mol = Chem.MolFromSmiles(smi)
496
+ if mol is None:
497
+ return 0
498
+ else:
499
+ formal = np.array([j.GetFormalCharge() for j in mol.GetAtoms()])
500
+ return formal.min()
501
+
502
+
503
+ def max_formal_from_smi(smi: str):
504
+ mol = Chem.MolFromSmiles(smi)
505
+ if mol is None:
506
+ return 0
507
+ else:
508
+ formal = np.array([j.GetFormalCharge() for j in mol.GetAtoms()])
509
+ return formal.max()
510
+
511
+
512
+ def atoms_from_smi(smi: str) -> int:
513
+ """atoms_from_smi.
514
+
515
+ Args:
516
+ smi (str): smi
517
+
518
+ Return:
519
+ int
520
+ """
521
+ mol = Chem.MolFromSmiles(smi)
522
+ if mol is None:
523
+ return 0
524
+ else:
525
+ return mol.GetNumAtoms()
526
+
527
+
528
+ def has_valid_els(chem_formula: str) -> bool:
529
+ """has_valid_els"""
530
+ for (chem_symbol, num) in re.findall(CHEM_FORMULA_SIZE, chem_formula):
531
+ if chem_symbol not in VALID_ELEMENTS:
532
+ return False
533
+ return True
534
+
535
+
536
+ def add_ion(form: str, ion: str):
537
+ """add_ion.
538
+ Args:
539
+ form (str): form
540
+ ion (str): ion
541
+ """
542
+ ion_vec = ion_to_add_vec[ion]
543
+ form_vec = formula_to_dense(form)
544
+ return vec_to_formula(form_vec + ion_vec)
545
+
546
+
547
+ def achiral_smi(smi: str) -> str:
548
+ """achiral_smi.
549
+
550
+ Return:
551
+ smiles string with stereochemistry removed
552
+
553
+ """
554
+ try:
555
+ mol = Chem.MolFromSmiles(smi)
556
+ if mol is not None:
557
+ smi = Chem.MolToSmiles(mol, isomericSmiles=False)
558
+ return smi
559
+ else:
560
+ return ""
561
+ except:
562
+ return ""
563
+
564
+
565
+ def npclassifer_query(inputs):
566
+ """npclassifier_query.
567
+
568
+ Args:
569
+ input: Tuple of name, molecule
570
+ Return:
571
+ Dict of name to molecule
572
+ """
573
+ import requests
574
+
575
+ spec = inputs[0]
576
+ endpoint = "https://npclassifier.ucsd.edu/classify"
577
+ req_data = {"smiles": inputs[1]}
578
+ out = requests.get(endpoint, params=req_data)  # send the SMILES as a query parameter
579
+ out.raise_for_status()
580
+ out_json = out.json()
581
+ return {spec: out_json}
582
+
583
+
584
+ def clipped_ppm(mass_diff: np.ndarray, parentmass: np.ndarray) -> np.ndarray:
585
+ """clipped_ppm.
586
+
587
+ Args:
588
+ mass_diff (np.ndarray): mass_diff
589
+ parentmass (np.ndarray): parentmass
590
+
591
+ Returns:
592
+ np.ndarray:
593
+ """
594
+ parentmass_copy = parentmass * 1
595
+ parentmass_copy[parentmass < 200] = 200
596
+ ppm = mass_diff / parentmass_copy * 1e6
597
+ return ppm
598
+
599
+
600
+ def clipped_ppm_single(
601
+ cls_mass_diff: float,
602
+ parentmass: float,
603
+ ):
604
+ """clipped_ppm_single.
605
+
606
+ Args:
607
+ cls_mass_diff (float): cls_mass_diff
608
+ parentmass (float): parentmass
609
+ """
610
+ div_factor = 200 if parentmass < 200 else parentmass
611
+ cls_ppm = cls_mass_diff / div_factor * 1e6
612
+ return cls_ppm
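
A small worked example of the clipped ppm error (fabricated numbers): parent masses below 200 Da are replaced by 200, so light precursors do not inflate the relative error.

```python
import numpy as np

# 0.002 Da error at parent mass 150 is divided by 200, not 150: 0.002 / 200 * 1e6 = 10 ppm
print(clipped_ppm_single(0.002, 150.0))                                   # 10.0

# The array version behaves the same way element-wise
print(clipped_ppm(np.array([0.002, 0.004]), np.array([150.0, 400.0])))    # [10. 10.]
```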
mvp/subformula_assign/utils/parallel_utils.py ADDED
@@ -0,0 +1,84 @@
1
+ """parallel_utils.py"""
2
+ import logging
3
+ from multiprocess.context import TimeoutError
4
+ from pathos import multiprocessing as mp
5
+ from tqdm import tqdm
6
+
7
+
8
+ def simple_parallel(
9
+ input_list, function, max_cpu=16, timeout=4000, max_retries=3, use_ray: bool = False
10
+ ):
11
+ """Simple parallelization.
12
+
13
+ Maps the function over a pathos multiprocessing pool. The timeout and max_retries arguments are currently unused placeholders.
14
+
15
+ input_list: Input list to op on
16
+ function: Fn to apply
17
+ max_cpu: Num cpus
18
+ timeout: Length of timeout
19
+ max_retries: Num times to retry this
20
+ use_ray
21
+
22
+ """
23
+ # If ray is required. Set to false.
24
+ if use_ray and False:
25
+ import ray
26
+
27
+ @ray.remote
28
+ def ray_func(x):
29
+ return function(x)
30
+
31
+ return ray.get([ray_func.remote(x) for x in input_list])
32
+
33
+ from multiprocess.context import TimeoutError
34
+ from pathos import multiprocessing as mp
35
+
36
+ cpus = min(mp.cpu_count(), max_cpu)
37
+ pool = mp.Pool(processes=cpus)
38
+ results = pool.map(function, input_list)
39
+ pool.close()
40
+ pool.join()
41
+ return results
42
+
43
+
44
+ def chunked_parallel(
45
+ input_list, function, chunks=100, max_cpu=16, timeout=4000, max_retries=3
46
+ ):
47
+ """chunked_parallel.
48
+
49
+ Args:
50
+ input_list : list of objects to apply function
51
+ function : Callable with 1 input and returning a single value
52
+ chunks: number of chunks
53
+ max_cpu: Max num cpus
54
+ timeout: Length of timeout
55
+ max_retries: Num times to retry this
56
+ """
57
+
58
+ # Adding it here fixes some setting disrupted elsewhere
59
+
60
+ def batch_func(list_inputs):
61
+ outputs = []
62
+ for i in list_inputs:
63
+ outputs.append(function(i))
64
+ return outputs
65
+
66
+ list_len = len(input_list)
67
+ num_chunks = min(list_len, chunks)
68
+ step_size = len(input_list) // num_chunks + 1
69
+
70
+ chunked_list = [
71
+ input_list[i : i + step_size] for i in range(0, len(input_list), step_size)
72
+ ]
73
+
74
+ list_outputs = simple_parallel(
75
+ chunked_list,
76
+ batch_func,
77
+ max_cpu=max_cpu,
78
+ timeout=timeout,
79
+ max_retries=max_retries,
80
+ )
81
+ # Unroll
82
+ full_output = [j for i in list_outputs for j in i]
83
+
84
+ return full_output
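
A sketch of how these helpers might be invoked; the worker function is a toy placeholder.

```python
# Toy worker; any picklable single-argument callable works here
def square(x):
    return x * x

inputs = list(range(1000))

# simple_parallel maps the function directly across a pathos pool
res_simple = simple_parallel(inputs, square, max_cpu=4)

# chunked_parallel batches inputs first, which reduces per-item dispatch overhead
res_chunked = chunked_parallel(inputs, square, chunks=10, max_cpu=4)

assert res_simple == res_chunked == [x * x for x in inputs]
```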
mvp/subformula_assign/utils/parse_utils.py ADDED
@@ -0,0 +1,295 @@
1
+ """ parse_utils.py """
2
+ from pathlib import Path
3
+ from typing import Tuple, List, Optional
4
+ from itertools import groupby
5
+
6
+ from tqdm import tqdm
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
+
11
+ def parse_spectra(spectra_file: str) -> Tuple[dict, List[Tuple[str, np.ndarray]]]:
12
+ """parse_spectra.
13
+
14
+ Parses spectra in the SIRIUS file format.
15
+
16
+ Args:
17
+ spectra_file (str): Name of spectra file to parse
18
+ Return:
19
+ Tuple[dict, List[Tuple[str, np.ndarray]]]: metadata and list of spectra
20
+ tuples containing name and array
21
+ """
22
+ lines = [i.strip() for i in open(spectra_file, "r").readlines()]
23
+
24
+ group_num = 0
25
+ metadata = {}
26
+ spectras = []
27
+ my_iterator = groupby(
28
+ lines, lambda line: line.startswith(">") or line.startswith("#")
29
+ )
30
+
31
+ for index, (start_line, lines) in enumerate(my_iterator):
32
+ group_lines = list(lines)
33
+ subject_lines = list(next(my_iterator)[1])
34
+ # Get spectra
35
+ if group_num > 0:
36
+ spectra_header = group_lines[0].split(">")[1]
37
+ peak_data = [
38
+ [float(x) for x in peak.split()[:2]]
39
+ for peak in subject_lines
40
+ if peak.strip()
41
+ ]
42
+ # Check if spectra is empty
43
+ if len(peak_data):
44
+ peak_data = np.vstack(peak_data)
45
+ # Add new tuple
46
+ spectras.append((spectra_header, peak_data))
47
+ # Get meta data
48
+ else:
49
+ entries = {}
50
+ for i in group_lines:
51
+ if " " not in i:
52
+ continue
53
+ elif i.startswith("#INSTRUMENT TYPE"):
54
+ key = "#INSTRUMENT TYPE"
55
+ val = i.split(key)[1].strip()
56
+ entries[key[1:]] = val
57
+ else:
58
+ start, end = i.split(" ", 1)
59
+ start = start[1:]
60
+ while start in entries:
61
+ start = f"{start}'"
62
+ entries[start] = end
63
+
64
+ metadata.update(entries)
65
+ group_num += 1
66
+
67
+ metadata["_FILE_PATH"] = spectra_file
68
+ metadata["_FILE"] = Path(spectra_file).stem
69
+ return metadata, spectras
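
A sketch of the SIRIUS-style text this parser consumes; the field names (including the ">ms2peaks" header) and peak values below are fabricated for illustration.

```python
# Write a tiny SIRIUS-style .ms file and parse it back (illustrative values)
example = """>compound demo
>parentmass 180.063
>ionization [M+H]+

>ms2peaks
85.028 120.0
163.060 540.0
"""
with open("demo.ms", "w") as f:
    f.write(example)

meta, spectra = parse_spectra("demo.ms")
print(meta["parentmass"])    # '180.063'
print(spectra[0][0])         # 'ms2peaks'
print(spectra[0][1].shape)   # (2, 2)
```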
70
+
71
+
72
+ def spec_to_ms_str(
73
+ spec: List[Tuple[str, np.ndarray]], essential_keys: dict, comments: dict = {}
74
+ ) -> str:
75
+ """spec_to_ms_str.
76
+
77
+ Turn spec ars and info dicts into str for output file
78
+
79
+
80
+ Args:
81
+ spec (List[Tuple[str, np.ndarray]]): spec
82
+ essential_keys (dict): essential_keys
83
+ comments (dict): comments
84
+
85
+ Returns:
86
+ str:
87
+ """
88
+
89
+ def pair_rows(rows):
90
+ return "\n".join([f"{i} {j}" for i, j in rows])
91
+
92
+ header = "\n".join(f">{k} {v}" for k, v in essential_keys.items())
93
+ comments = "\n".join(f"#{k} {v}" for k, v in comments.items())
94
+ spec_strs = [f">{name}\n{pair_rows(ar)}" for name, ar in spec]
95
+ spec_str = "\n\n".join(spec_strs)
96
+ output = f"{header}\n{comments}\n\n{spec_str}"
97
+ return output
98
+
99
+
100
+ def build_mgf_str(
101
+ meta_spec_list: List[Tuple[dict, List[Tuple[str, np.ndarray]]]],
102
+ merge_charges=True,
103
+ parent_mass_keys=["PEPMASS", "parentmass", "PRECURSOR_MZ"],
104
+ ) -> str:
105
+ """build_mgf_str.
106
+
107
+ Args:
108
+ meta_spec_list (List[Tuple[dict, List[Tuple[str, np.ndarray]]]]): meta_spec_list
109
+
110
+ Returns:
111
+ str:
112
+ """
113
+ entries = []
114
+ for meta, spec in tqdm(meta_spec_list):
115
+ str_rows = ["BEGIN IONS"]
116
+
117
+ # Try to add precursor mass
118
+ for i in parent_mass_keys:
119
+ if i in meta:
120
+ pep_mass = float(meta.get(i, -100))
121
+ str_rows.append(f"PEPMASS={pep_mass}")
122
+ break
123
+
124
+ for k, v in meta.items():
125
+ str_rows.append(f"{k.upper().replace(' ', '_')}={v}")
126
+
127
+ if merge_charges:
128
+ spec_ar = np.vstack([i[1] for i in spec])
129
+ spec_ar = np.vstack([i for i in sorted(spec_ar, key=lambda x: x[0])])
130
+ else:
131
+ raise NotImplementedError()
132
+ str_rows.extend([f"{i} {j}" for i, j in spec_ar])
133
+ str_rows.append("END IONS")
134
+
135
+ str_out = "\n".join(str_rows)
136
+ entries.append(str_out)
137
+
138
+ full_out = "\n\n".join(entries)
139
+ return full_out
140
+
141
+
142
+ def parse_spectra_msp(
143
+ mgf_file: str, max_num: Optional[int] = None
144
+ ) -> List[Tuple[dict, List[Tuple[str, np.ndarray]]]]:
145
+ """parse_spectra_msp.
146
+
147
+ Parses spectra in the MSP file format
148
+
149
+ Args:
150
+ mgf_file (str) : str
151
+ max_num (Optional[int]): If set, only parse this many
152
+ Return:
153
+ List[Tuple[dict, List[Tuple[str, np.ndarray]]]]: metadata and list of spectra
154
+ tuples containing name and array
155
+ """
156
+
157
+ key = lambda x: x.strip().startswith("PEPMASS")
158
+ parsed_spectra = []
159
+ with open(mgf_file, "r", encoding="utf-8") as fp:
160
+ for (is_header, group) in tqdm(groupby(fp, key)):
161
+
162
+ if is_header:
163
+ continue
164
+ meta = dict()
165
+ spectra = []
166
+ # Note: Sometimes we have multiple scans
167
+ # This mgf has them collapsed
168
+ cur_spectra_name = "spec"
169
+ cur_spectra = []
170
+ group = list(group)
171
+ for line in group:
172
+ line = line.strip()
173
+ if not line:
174
+ pass
175
+ elif ":" in line:
176
+ k, v = [i.strip() for i in line.split(":", 1)]
177
+ meta[k] = v
178
+ else:
179
+ mz, intens = line.split()
180
+ cur_spectra.append((float(mz), float(intens)))
181
+
182
+ if len(cur_spectra) > 0:
183
+ cur_spectra = np.vstack(cur_spectra)
184
+ spectra.append((cur_spectra_name, cur_spectra))
185
+ parsed_spectra.append((meta, spectra))
186
+ else:
187
+ pass
188
+ # print("no spectra found for group: ", "".join(group))
189
+
190
+ if max_num is not None and len(parsed_spectra) > max_num:
191
+ # print("Breaking")
192
+ break
193
+ return parsed_spectra
194
+
195
+
196
+ def parse_spectra_mgf(
197
+ mgf_file: str, max_num: Optional[int] = None
198
+ ) -> List[Tuple[dict, List[Tuple[str, np.ndarray]]]]:
199
+ """parse_spectra_mgf.
200
+
201
+ Parses spectra in the MGF file format.
202
+
203
+ Args:
204
+ mgf_file (str) : str
205
+ max_num (Optional[int]): If set, only parse this many
206
+ Return:
207
+ List[Tuple[dict, List[Tuple[str, np.ndarray]]]]: metadata and list of spectra
208
+ tuples containing name and array
209
+ """
210
+
211
+ key = lambda x: x.strip() == "BEGIN IONS"
212
+ parsed_spectra = []
213
+ with open(mgf_file, "r") as fp:
214
+
215
+ for (is_header, group) in tqdm(groupby(fp, key)):
216
+
217
+ if is_header:
218
+ continue
219
+
220
+ meta = dict()
221
+ spectra = []
222
+ # Note: Sometimes we have multiple scans
223
+ # This mgf has them collapsed
224
+ cur_spectra_name = "spec"
225
+ cur_spectra = []
226
+ group = list(group)
227
+ for line in group:
228
+ line = line.strip()
229
+ if not line:
230
+ pass
231
+ elif line == "END IONS" or line == "BEGIN IONS":
232
+ pass
233
+ elif "=" in line:
234
+ k, v = [i.strip() for i in line.split("=", 1)]
235
+ meta[k] = v
236
+ else:
237
+ mz, intens = line.split()
238
+ cur_spectra.append((float(mz), float(intens)))
239
+
240
+ if len(cur_spectra) > 0:
241
+ cur_spectra = np.vstack(cur_spectra)
242
+ spectra.append((cur_spectra_name, cur_spectra))
243
+ parsed_spectra.append((meta, spectra))
244
+ else:
245
+ pass
246
+ # print("no spectra found for group: ", "".join(group))
247
+
248
+ if max_num is not None and len(parsed_spectra) > max_num:
249
+ # print("Breaking")
250
+ break
251
+ return parsed_spectra
252
+
253
+
254
+ def parse_tsv_spectra(spectra_file: str) -> List[Tuple[str, np.ndarray]]:
255
+ """parse_tsv_spectra.
256
+
257
+ Parses spectra returned from sirius fragmentation tree
258
+
259
+ Args:
260
+ spectra_file (str): Name of spectra tsv file to parse
261
+ Return:
262
+ List[Tuple[str, np.ndarray]]]: list of spectra
263
+ tuples containing name and array. This is used to maintain
264
+ consistency with the parse_spectra output
265
+ """
266
+ output_spec = []
267
+ with open(spectra_file, "r") as fp:
268
+ for index, line in enumerate(fp):
269
+ if index == 0:
270
+ continue
271
+ line = line.strip().split("\t")
272
+ intensity = float(line[1])
273
+ exact_mass = float(line[3])
274
+ output_spec.append([exact_mass, intensity])
275
+
276
+ output_spec = np.array(output_spec)
277
+ return_obj = [("sirius_spec", output_spec)]
278
+ return return_obj
279
+
280
+ # YZC parse msgym-like formatted data
281
+ def parse_spectra_msgym(df):
282
+
283
+ parsed_spectra = []
284
+ for _, row in df.iterrows():
285
+ mzs = [float(m) for m in row['mzs'].split(',')]
286
+ intensities = [float(i) for i in row['intensities'].split(',')]
287
+ cur_spectra = [(m, i) for m, i in zip(mzs, intensities)]
288
+ cur_spectra = np.vstack(cur_spectra)
289
+ cur_spectra_name = row['spec']
290
+ meta = {'ID': cur_spectra_name,
291
+ 'parentmass': row['parent_mass']}
292
+ parsed_spectra.append((meta, [(cur_spectra_name, cur_spectra)]))
293
+ return parsed_spectra
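
A sketch of the MassSpecGym-style dataframe this helper expects, using a single fabricated row with the columns the code reads ('spec', 'parent_mass', 'mzs', 'intensities').

```python
import pandas as pd

df = pd.DataFrame([{
    "spec": "demo_spec_001",           # fabricated identifier
    "parent_mass": 180.063,
    "mzs": "85.028,163.060",            # comma-separated strings, as in the TSV
    "intensities": "120.0,540.0",
}])

parsed = parse_spectra_msgym(df)
meta, spectra = parsed[0]
print(meta)              # {'ID': 'demo_spec_001', 'parentmass': 180.063}
print(spectra[0][1])     # 2x2 array of (mz, intensity) rows
```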
294
+
295
+
mvp/subformula_assign/utils/spectra_utils.py ADDED
@@ -0,0 +1,325 @@
1
+ """ spectra_utils.py"""
2
+ import logging
3
+ import numpy as np
4
+ from typing import List
5
+
6
+
7
+ from .chem_utils import (
8
+ vec_to_formula,
9
+ get_all_subsets,
10
+ ion_to_mass,
11
+ ION_LST,
12
+ clipped_ppm,
13
+ )
14
+
15
+
16
+ def bin_spectra(
17
+ spectras: List[np.ndarray], num_bins: int = 2000, upper_limit: int = 1000
18
+ ) -> np.ndarray:
19
+ """bin_spectra.
20
+
21
+ Args:
22
+ spectras (List[np.ndarray]): Input list of spectra tuples
23
+ [(header, spec array)]
24
+ num_bins (int): Number of discrete bins from [0, upper_limit)
25
+ upper_limit (int): Max m/z to consider featurizing
26
+
27
+ Return:
28
+ np.ndarray of shape [channels, num_bins]
29
+ """
30
+ bins = np.linspace(0, upper_limit, num=num_bins)
31
+ binned_spec = np.zeros((len(spectras), len(bins)))
32
+ for spec_index, spec in enumerate(spectras):
33
+
34
+ # Convert to digitized spectra
35
+ digitized_mz = np.digitize(spec[:, 0], bins=bins)
36
+
37
+ # Remove all spectral peaks out of range
38
+ in_range = digitized_mz < len(bins)
39
+ digitized_mz, spec = digitized_mz[in_range], spec[in_range, :]
40
+
41
+ # Add the current peaks to the spectra
42
+ # Use a loop rather than vectorize because certain bins have conflicts
43
+ # based upon resolution
44
+ for bin_index, spec_val in zip(digitized_mz, spec[:, 1]):
45
+ binned_spec[spec_index, bin_index] += spec_val
46
+
47
+ return binned_spec
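
A small sketch of the binning: with num_bins=2000 over [0, 1000) each bin is roughly 0.5 m/z wide, and intensities that land in the same bin are summed (values fabricated).

```python
import numpy as np

spec = np.array([
    [100.20, 0.5],
    [100.30, 0.3],   # falls into the same ~0.5 m/z bin as the peak above
    [500.10, 1.0],
])
binned = bin_spectra([spec], num_bins=2000, upper_limit=1000)
print(binned.shape)      # (1, 2000)
print(binned[0].sum())   # 1.8 -- intensities are summed, not normalized
```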
48
+
49
+
50
+ def merge_norm_spectra(spec_tuples, precision=4) -> np.ndarray:
51
+ """merge_norm_spectra.
52
+
53
+ Take a list of (mz, intensity) tuple arrays and merge them at the given decimal precision (default 4)
54
+
55
+ Note this uses _max_ merging
56
+
57
+ """
58
+ mz_to_inten_pair = {}
59
+ for i in spec_tuples:
60
+ for tup in i:
61
+ mz, inten = tup
62
+ mz_ind = np.round(mz, precision)
63
+ cur_pair = mz_to_inten_pair.get(mz_ind)
64
+ if cur_pair is None:
65
+ mz_to_inten_pair[mz_ind] = tup
66
+ elif inten > cur_pair[1]:
67
+ mz_to_inten_pair[mz_ind] = (mz_ind, inten)
68
+ else:
69
+ pass
70
+
71
+ merged_spec = np.vstack([v for k, v in mz_to_inten_pair.items()])
72
+ merged_spec[:, 1] = merged_spec[:, 1] / merged_spec[:, 1].max()
73
+ return merged_spec
74
+
75
+
76
+ def norm_spectrum(binned_spec: np.ndarray) -> np.ndarray:
77
+ """norm_spectrum.
78
+
79
+ Normalizes each spectral channel to have norm 1
80
+ This change is made in place
81
+
82
+ Args:
83
+ binned_spec (np.ndarray) : Vector of spectras
84
+
85
+ Return:
86
+ np.ndarray where each channel has max(1)
87
+ """
88
+
89
+ spec_maxes = binned_spec.max(1)
90
+
91
+ non_zero_max = spec_maxes > 0
92
+
93
+ spec_maxes = spec_maxes[non_zero_max]
94
+ binned_spec[non_zero_max] = binned_spec[non_zero_max] / spec_maxes.reshape(-1, 1)
95
+
96
+ return binned_spec
97
+
98
+
99
+ def process_spec_file(meta, tuples, precision=4, max_inten=0.001, max_peaks=60):
100
+ """process_spec_file."""
101
+
102
+ if "parentmass" in meta:
103
+ parentmass = meta.get("parentmass", None)
104
+ elif "PARENTMASS" in meta:
105
+ parentmass = meta.get("PARENTMASS", None)
106
+ elif "PEPMASS" in meta:
107
+ parentmass = meta.get("PEPMASS", None)
108
+ else:
109
+ logging.debug(f"missing parentmass for spec")
110
+ parentmass = 1000000
111
+
112
+ parentmass = float(parentmass)
113
+
114
+ # First norm spectra
115
+ fused_tuples = [x for _, x in tuples if x.size > 0]
116
+
117
+ if len(fused_tuples) == 0:
118
+ return
119
+
120
+ mz_to_inten_pair = {}
121
+ new_tuples = []
122
+ for i in fused_tuples:
123
+ for tup in i:
124
+ mz, inten = tup
125
+ mz_ind = np.round(mz, precision)
126
+ cur_pair = mz_to_inten_pair.get(mz_ind)
127
+ if cur_pair is None:
128
+ mz_to_inten_pair[mz_ind] = tup
129
+ new_tuples.append(tup)
130
+ elif inten > cur_pair[1]:
131
+ cur_pair[1] = inten
132
+ else:
133
+ pass
134
+
135
+ merged_spec = np.vstack(new_tuples)
136
+ merged_spec = merged_spec[merged_spec[:, 0] <= (parentmass + 1)] # could end up removing all peaks?
137
+ try:
138
+ merged_spec[:, 1] = merged_spec[:, 1] / merged_spec[:, 1].max()
139
+ except:
140
+ return
141
+
142
+ # Sqrt intensities here
143
+ merged_spec[:, 1] = np.sqrt(merged_spec[:, 1])
144
+
145
+ merged_spec = max_inten_spec(
146
+ merged_spec, max_num_inten=max_peaks, inten_thresh=max_inten
147
+ )
148
+ return merged_spec
149
+
150
+
151
+ def max_inten_spec(spec, max_num_inten: int = 60, inten_thresh: float = 0):
152
+ """max_inten_spec.
153
+
154
+ Args:
155
+ spec: 2D spectra array
156
+ max_num_inten: Max number of peaks
157
+ inten_thresh: Min intensity to allow in returned peak
158
+
159
+ Return:
160
+ Spec filtered down
161
+
162
+
163
+ """
164
+ spec_masses, spec_intens = spec[:, 0], spec[:, 1]
165
+
166
+ # Make sure to only take max of each formula
167
+ # Sort by intensity and select top subpeaks
168
+ new_sort_order = np.argsort(spec_intens)[::-1]
169
+ if max_num_inten is not None:
170
+ new_sort_order = new_sort_order[:max_num_inten]
171
+
172
+ spec_masses = spec_masses[new_sort_order]
173
+ spec_intens = spec_intens[new_sort_order]
174
+
175
+ spec_mask = spec_intens > inten_thresh
176
+ spec_masses = spec_masses[spec_mask]
177
+ spec_intens = spec_intens[spec_mask]
178
+ spec = np.vstack([spec_masses, spec_intens]).transpose(1, 0)
179
+ return spec
180
+
181
+
182
+ def max_thresh_spec(spec: np.ndarray, max_peaks=100, inten_thresh=0.003):
183
+ """max_thresh_spec.
184
+
185
+ Args:
186
+ spec (np.ndarray): spec
187
+ max_peaks: Max num peaks to keep
188
+ inten_thresh: Min inten to keep
189
+ """
190
+
191
+ spec_masses, spec_intens = spec[:, 0], spec[:, 1]
192
+
193
+ # Make sure to only take max of each formula
194
+ # Sort by intensity and select top subpeaks
195
+ new_sort_order = np.argsort(spec_intens)[::-1]
196
+ new_sort_order = new_sort_order[:max_peaks]
197
+
198
+ spec_masses = spec_masses[new_sort_order]
199
+ spec_intens = spec_intens[new_sort_order]
200
+
201
+ spec_mask = spec_intens > inten_thresh
202
+ spec_masses = spec_masses[spec_mask]
203
+ spec_intens = spec_intens[spec_mask]
204
+ out_ar = np.vstack([spec_masses, spec_intens]).transpose(1, 0)
205
+ return out_ar
206
+
207
+
208
+ def assign_subforms(form, spec, ion_type, mass_diff_thresh=15):
209
+ """Assign candidate subformulas to spectrum peaks.
210
+
211
+ Args:
212
+ form (str): candidate precursor formula whose subformulas are enumerated
213
+ spec (np.ndarray): 2D array of (mz, intensity) peaks
214
+ ion_type (str): adduct type used to shift the subformula masses
215
+ mass_diff_thresh (int, optional): maximum allowed mass error in ppm. Defaults to 15.
216
+
217
+ Returns:
218
+ dict: candidate formula, ion type, and a peak-to-subformula assignment table (None on failure)
219
+ """
220
+ try:
221
+ cross_prod, masses = get_all_subsets(form)
222
+ spec_masses, spec_intens = spec[:, 0], spec[:, 1]
223
+
224
+ ion_masses = ion_to_mass[ion_type]
225
+ masses_with_ion = masses + ion_masses
226
+ ion_types = np.array([ion_type] * len(masses_with_ion))
227
+
228
+ mass_diffs = np.abs(spec_masses[:, None] - masses_with_ion[None, :])
229
+
230
+ formula_inds = mass_diffs.argmin(-1)
231
+ min_mass_diff = mass_diffs[np.arange(len(mass_diffs)), formula_inds]
232
+ rel_mass_diff = clipped_ppm(min_mass_diff, spec_masses)
233
+
234
+ # Filter by mass diff threshold (ppm)
235
+ valid_mask = rel_mass_diff < mass_diff_thresh
236
+ spec_masses = spec_masses[valid_mask]
237
+ spec_intens = spec_intens[valid_mask]
238
+ min_mass_diff = min_mass_diff[valid_mask]
239
+ rel_mass_diff = rel_mass_diff[valid_mask]
240
+ formula_inds = formula_inds[valid_mask]
241
+
242
+ formulas = np.array([vec_to_formula(j) for j in cross_prod[formula_inds]])
243
+ formula_masses = masses_with_ion[formula_inds]
244
+ ion_types = ion_types[formula_inds]
245
+
246
+ # Build mask for uniqueness on formula and ionization
247
+ # note that ionization are all the same for one subformula assignment
248
+ # hence we only need to consider the uniqueness of the formula
249
+ formula_idx_dict = {}
250
+ uniq_mask = []
251
+ for idx, formula in enumerate(formulas):
252
+ uniq_mask.append(formula not in formula_idx_dict)
253
+ gather_ind = formula_idx_dict.get(formula, None)
254
+ if gather_ind is None:
255
+ # first occurrence of this formula: remember its index
+ formula_idx_dict[formula] = idx
256
+ else:
257
+ # duplicate formula: merge its intensity into the first occurrence
+ spec_intens[gather_ind] += spec_intens[idx]
258
+
259
+ spec_masses = spec_masses[uniq_mask]
260
+ spec_intens = spec_intens[uniq_mask]
261
+ min_mass_diff = min_mass_diff[uniq_mask]
262
+ rel_mass_diff = rel_mass_diff[uniq_mask]
263
+ formula_masses = formula_masses[uniq_mask]
264
+ formulas = formulas[uniq_mask]
265
+ ion_types = ion_types[uniq_mask]
266
+
267
+ # To calculate explained intensity, preserve the original normalized
268
+ # intensity
269
+ if spec_intens.size == 0:
270
+ output_tbl = None
271
+ else:
272
+ output_tbl = {
273
+ "mz": list(spec_masses),
274
+ "ms2_inten": list(spec_intens),
275
+ "mono_mass": list(formula_masses),
276
+ "abs_mass_diff": list(min_mass_diff),
277
+ "mass_diff": list(rel_mass_diff),
278
+ "formula": list(formulas),
279
+ "ions": list(ion_types),
280
+ }
281
+ except Exception:
282
+ output_tbl = None
283
+ print(f"failed to process formula {form}")
285
+ output_dict = {
286
+ "cand_form": form,
287
+ "cand_ion": ion_type,
288
+ "output_tbl": output_tbl,
289
+ }
290
+ return output_dict
291
+
292
+
293
+ def get_output_dict(
294
+ spec_name: str,
295
+ spec: np.ndarray,
296
+ form: str,
297
+ mass_diff_type: str,
298
+ mass_diff_thresh: float,
299
+ ion_type: str,
300
+ ) -> dict:
301
+ """Assign formula subsets to the subpeaks of one spectrum.
302
+
303
+ This function attempts to take an array of mass intensity values and assign
304
+ formula subsets to subpeaks
305
+
306
+ Args:
307
+ spec_name (str): spectrum identifier (not used in the computation)
308
+ spec (np.ndarray): 2D array of (mz, intensity) peaks, or None
309
+ form (str): candidate precursor formula
310
+ mass_diff_type (str): must be "ppm"
311
+ mass_diff_thresh (float): maximum allowed mass error in ppm
312
+ ion_type (str): adduct type; must be in ION_LST
313
+
314
+ Returns:
315
+ dict: output of assign_subforms, or a dict with output_tbl set to None
316
+ """
317
+ assert mass_diff_type == "ppm"
318
+ # spec is None for some erroneous MS2 files for which process_spec_file returns None
319
+ # (all MS2 subpeaks in those files have mz larger than the parent mass)
320
+ output_dict = {"cand_form": form, "cand_ion": ion_type, "output_tbl": None}
321
+ if spec is not None and ion_type in ION_LST:
322
+ output_dict = assign_subforms(
323
+ form, spec, ion_type, mass_diff_thresh=mass_diff_thresh
324
+ )
325
+ return output_dict
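
A usage sketch for the assignment entry point, assuming "[M+H]+" is one of the adducts in ION_LST (the actual names live in chem_utils, which is not shown here); the spectrum and formula are fabricated.

```python
import numpy as np

spec = np.array([
    [85.028, 0.2],
    [163.060, 1.0],
])

out = get_output_dict(
    spec_name="demo_spec",     # identifier only; not used in the computation
    spec=spec,
    form="C6H12O6",            # candidate precursor formula
    mass_diff_type="ppm",      # the only supported mode
    mass_diff_thresh=15,
    ion_type="[M+H]+",         # assumed to be present in ION_LST
)
print(out["cand_form"], out["cand_ion"])
print(out["output_tbl"] is None or sorted(out["output_tbl"].keys()))
```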
mvp/test.py ADDED
@@ -0,0 +1,118 @@
1
+ import argparse
2
+ import datetime
3
+ import sys
4
+ sys.path.insert(0, "/data/yzhouc01/MassSpecGym")
5
+ sys.path.insert(0, "/data/yzhouc01/MVP")
6
+
7
+ from rdkit import RDLogger
8
+ import pytorch_lightning as pl
9
+ from pytorch_lightning import Trainer
10
+ from massspecgym.models.base import Stage
11
+ import os
12
+
13
+ from mvp.data.data_module import TestDataModule
14
+ from mvp.data.datasets import ContrastiveDataset
15
+ from mvp.utils.data import get_spec_featurizer, get_mol_featurizer, get_test_ms_dataset
16
+ from mvp.utils.models import get_model
17
+
18
+ from mvp.definitions import TEST_RESULTS_DIR
19
+ import yaml
20
+ from functools import partial
21
+ # Suppress RDKit warnings and errors
22
+ lg = RDLogger.logger()
23
+ lg.setLevel(RDLogger.CRITICAL)
24
+
25
+ parser = argparse.ArgumentParser()
26
+ parser.add_argument("--param_pth", type=str, default="params_formSpec.yaml")
27
+ parser.add_argument('--checkpoint_pth', type=str, default='')
28
+ parser.add_argument('--checkpoint_choice', type=str, default='train', choices=['train', 'val'])
29
+ parser.add_argument('--df_test_pth', type=str, help='result file name')
30
+ parser.add_argument('--exp_dir', type=str)
31
+ parser.add_argument('--candidates_pth', type=str)
32
+ def main(params):
33
+ # Seed everything
34
+ pl.seed_everything(params['seed'])
35
+
36
+ # Init paths to data files
37
+ if params['debug']:
38
+ params['dataset_pth'] = "../data/sample/data.tsv"
39
+ params['split_pth']=None
40
+ params['df_test_path'] = os.path.join(params['experiment_dir'], 'debug_result.pkl')
41
+
42
+ # Load dataset
43
+ spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
44
+ mol_featurizer = get_mol_featurizer(params['molecule_view'], params)
45
+ dataset = get_test_ms_dataset(params['spectra_view'], params['molecule_view'], spec_featurizer, mol_featurizer, params)
46
+
47
+ # Init data module
48
+ collate_fn = partial(ContrastiveDataset.collate_fn, spec_enc=params['spec_enc'], spectra_view=params['spectra_view'], stage=Stage.TEST)
49
+ data_module = TestDataModule(
50
+ dataset=dataset,
51
+ collate_fn=collate_fn,
52
+ split_pth=params['split_pth'],
53
+ batch_size=params['batch_size'],
54
+ num_workers=params['num_workers']
55
+ )
56
+
57
+ model = get_model(params['model'], params)
58
+ model.df_test_path = params['df_test_path']
59
+
60
+ # Init trainer
61
+ trainer = Trainer(
62
+ accelerator=params['accelerator'],
63
+ devices=params['devices'],
64
+ default_root_dir=params['experiment_dir']
65
+ )
66
+
67
+ # Prepare data module to test
68
+ data_module.prepare_data()
69
+ data_module.setup(stage="test")
70
+
71
+ # Test
72
+ trainer.test(model, datamodule=data_module)
73
+
74
+
75
+ if __name__ == "__main__":
76
+ args = parser.parse_args([] if "__file__" not in globals() else None)
77
+
78
+ # Load
79
+ with open(args.param_pth) as f:
80
+ params = yaml.load(f, Loader=yaml.FullLoader)
81
+
82
+ # Experiment directory
83
+ if args.exp_dir:
84
+ exp_dir = args.exp_dir
85
+ else:
86
+ run_name = params['run_name']
+ exp_dir = ''  # fall back to creating a new directory below if no existing run matches
87
+ for exp in os.listdir(TEST_RESULTS_DIR): # find exp dir with matching run_name
88
+ if exp.endswith("_"+run_name):
89
+ exp_dir = str(TEST_RESULTS_DIR / exp)
90
+ break
91
+ if not exp_dir:
92
+ now = datetime.datetime.now().strftime("%Y%m%d")
93
+ exp_dir = str(TEST_RESULTS_DIR / f"{now}_{params['run_name']}")
94
+ os.makedirs(exp_dir, exist_ok=True)
95
+ print("EXPERIMENT directory: ",exp_dir)
96
+ params['experiment_dir'] = exp_dir
97
+
98
+ # Checkpoint path
99
+ if args.checkpoint_pth:
100
+ params['checkpoint_pth'] = args.checkpoint_pth
101
+
102
+ if not params['checkpoint_pth']:
103
+ print("No checkpoint provided. Using the checkpoint in the experiment directory")
104
+ for f in os.listdir(exp_dir):
105
+ if f.endswith("ckpt") and f.startswith("epoch") and args.checkpoint_choice in f:
106
+ checkpoint_path = os.path.join(exp_dir, f)
107
+ params['checkpoint_pth'] = checkpoint_path
108
+ break
109
+ assert(params['checkpoint_pth'] != '')
110
+
111
+ if args.candidates_pth:
112
+ params['candidates_pth'] = args.candidates_pth
113
+ if args.df_test_pth:
114
+ params['df_test_path'] = os.path.join(exp_dir, args.df_test_pth)
115
+ if not params['df_test_path']:
116
+ params['df_test_path'] = os.path.join(exp_dir, f"result_{params['candidates_pth'].split('/')[-1].split('.')[0]}.pkl")
117
+
118
+ main(params)
mvp/train.py ADDED
@@ -0,0 +1,137 @@
1
+ import argparse
2
+ import datetime
3
+
4
+ import os
5
+ import sys
6
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
7
+
8
+ from rdkit import RDLogger
9
+ import pytorch_lightning as pl
10
+ from pytorch_lightning import Trainer
11
+ from pytorch_lightning.callbacks.early_stopping import EarlyStopping
12
+
13
+
14
+ from mvp.data.data_module import ContrastiveDataModule
15
+
16
+ from mvp.definitions import TEST_RESULTS_DIR
17
+ import yaml
18
+ from mvp.data.datasets import ContrastiveDataset
19
+ from functools import partial
20
+
21
+ from mvp.utils.data import get_ms_dataset, get_spec_featurizer, get_mol_featurizer
22
+ from mvp.utils.models import get_model
23
+ # Suppress RDKit warnings and errors
24
+ lg = RDLogger.logger()
25
+ lg.setLevel(RDLogger.CRITICAL)
26
+
27
+ parser = argparse.ArgumentParser()
28
+ parser.add_argument("--param_pth", type=str, default="params_formSpec.yaml")
29
+
30
+ def main(params):
31
+ # Seed everything
32
+ pl.seed_everything(params['seed'])
33
+
34
+ # Init paths to data files
35
+ if params['debug']:
36
+ params['dataset_pth'] = "../data/sample/data.tsv"
37
+ params['candidates_pth'] =None
38
+ params['split_pth']=None
39
+
40
+ # Load dataset
41
+ spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
42
+ mol_featurizer = get_mol_featurizer(params['molecule_view'], params)
43
+ dataset = get_ms_dataset(params['spectra_view'], params['molecule_view'], spec_featurizer, mol_featurizer, params)
44
+
45
+ # Init data module
46
+ collate_fn = partial(ContrastiveDataset.collate_fn, spec_enc=params['spec_enc'], spectra_view=params['spectra_view'], mask_peak_ratio=params['mask_peak_ratio'], aug_cands=params['aug_cands'])
47
+ data_module = ContrastiveDataModule(
48
+ dataset=dataset,
49
+ collate_fn=collate_fn,
50
+ split_pth=params['split_pth'],
51
+ batch_size=params['batch_size'],
52
+ num_workers=params['num_workers'],
53
+ )
54
+
55
+ model = get_model(params['model'], params)
56
+
57
+ # Init logger
58
+ if params['no_wandb']:
59
+ logger = None
60
+ else:
61
+ logger = pl.loggers.WandbLogger(
62
+ save_dir=params['experiment_dir'],
63
+ dir=params['experiment_dir'],
64
+ log_dir=params['experiment_dir'],
65
+ name=params['run_name'],
66
+ project=params['project_name'],
67
+ log_model=False,
68
+ config=model.hparams
69
+ )
70
+
71
+ # Init callbacks for checkpointing and early stopping
72
+ callbacks = [pl.callbacks.ModelCheckpoint(save_last=False) ]
73
+ for i, monitor in enumerate(model.get_checkpoint_monitors()):
74
+ monitor_name = monitor['monitor']
75
+ checkpoint = pl.callbacks.ModelCheckpoint(
76
+ monitor=monitor_name,
77
+ save_top_k=1,
78
+ mode=monitor['mode'],
79
+ dirpath=params['experiment_dir'],
80
+ filename=f'{{epoch}}-{{{monitor_name}:.2f}}',
81
+ # filename='{epoch}-{val_loss:.2f}-{train_loss:.2f}',
82
+ auto_insert_metric_name=True,
83
+ save_last=(i == 0)
84
+ )
85
+ callbacks.append(checkpoint)
86
+ if monitor.get('early_stopping', False):
87
+ early_stopping = EarlyStopping(
88
+ monitor=monitor_name,
89
+ mode=monitor['mode'],
90
+ verbose=True,
91
+ patience=params['early_stopping_patience'],
92
+ )
93
+ callbacks.append(early_stopping)
94
+
95
+ # Init trainer
96
+ trainer = Trainer(
97
+ accelerator=params['accelerator'],
98
+ devices=params['devices'],
99
+ max_epochs=params['max_epochs'],
100
+ logger=logger,
101
+ log_every_n_steps=params['log_every_n_steps'],
102
+ val_check_interval=params['val_check_interval'],
103
+ callbacks=callbacks,
104
+ default_root_dir=params['experiment_dir'],
105
+ )
106
+
107
+ # Prepare data module to validate or test before training
108
+ data_module.prepare_data()
109
+ data_module.setup()
110
+
111
+
112
+ # Validate before training
113
+ trainer.validate(model, datamodule=data_module)
114
+
115
+ # Train
116
+ trainer.fit(model, datamodule=data_module)
117
+
118
+
119
+
120
+ if __name__ == "__main__":
121
+ args = parser.parse_args([] if "__file__" not in globals() else None)
122
+
123
+ # Get current time
124
+ now = datetime.datetime.now()
125
+ now_formatted = now.strftime("%Y%m%d")
126
+
127
+ # Load
128
+ with open(args.param_pth) as f:
129
+ params = yaml.load(f, Loader=yaml.FullLoader)
130
+
131
+ experiment_dir = str(TEST_RESULTS_DIR / f"{now_formatted}_{params['run_name']}")
132
+ params['experiment_dir'] = experiment_dir
133
+
134
+ if not params['df_test_path']:
135
+ params['df_test_path'] = os.path.join(experiment_dir, "result.pkl")
136
+
137
+ main(params)
mvp/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ import sys
2
+ sys.path.insert(0, "/data/yzhouc01/MassSpecGym")
3
+ from massspecgym.utils import *
mvp/utils/data.py ADDED
@@ -0,0 +1,226 @@
1
+ import os
2
+ import json
3
+ import numpy as np
4
+
5
+ from mvp.data.transforms import SpecBinner, SpecBinnerLog, SpecFormulaFeaturizer, SpecFormulaMzFeaturizer, SpecMzIntTokenizer
6
+ from massspecgym.data.transforms import SpecTransform, MolTransform
7
+ from mvp.data.transforms import MolToGraph
8
+ import mvp.data.datasets as jestr_datasets
9
+ import typing as T
10
+ from mvp.definitions import MSGYM_FORMULA_VECTOR_NORM, MSGYM_STANDARD_MH
11
+ import matchms
12
+
13
+ class Subformula_Loader:
14
+ def __init__(self, spectra_view, dir_path) -> None:
15
+
16
+ self.dir_path = dir_path
17
+ if spectra_view == 'SpecFormula':
18
+ self.load = self.load_subformula_data
19
+ elif spectra_view == "SpecFormulaMz":
20
+ self.load = self.load_subformula_dict
21
+ else:
22
+ raise Exception("Spectra view is not supported.")
23
+
24
+ def __call__(self, ids):
25
+ id_to_form_spec = {}
26
+ for id in ids:
27
+ data = self.load(id)
28
+ if data:
29
+ id_to_form_spec[id] = data
30
+
31
+ return id_to_form_spec
32
+
33
+ def load_subformula_data(self, spec_id: str):
34
+ '''MIST subformula format:https://github.com/samgoldman97/mist/blob/main_v2/src/mist/utils/spectra_utils.py
35
+ '''
36
+ try:
37
+ file = os.path.join(self.dir_path, spec_id+".json")
38
+ with open(file) as f:
39
+ data = json.load(f)
40
+ mzs = np.array(data['output_tbl']['mz'])
41
+ formulas = np.array(data['output_tbl']['formula'])
42
+ intensities = np.array(data['output_tbl']['ms2_inten'])
43
+
44
+ # sort by mzs
45
+ ind = mzs.argsort()
46
+ mzs = mzs[ind]
47
+ formulas = formulas[ind]
48
+ intensities = intensities[ind]
49
+ return {'formulas': formulas, 'formula_mzs': mzs, 'formula_intensities': intensities}
50
+ except:
51
+ return None
52
+
53
+ def load_subformula_dict(self, spec_id: str):
54
+ '''MIST subformula format:https://github.com/samgoldman97/mist/blob/main_v2/src/mist/utils/spectra_utils.py
55
+ '''
56
+ try:
57
+ file = os.path.join(self.dir_path, spec_id+".json")
58
+ with open(file) as f:
59
+ data = json.load(f)
60
+ mzs = np.array(data['output_tbl']['mz'])
61
+ formulas = np.array(data['output_tbl']['formula'])
62
+ intensities = np.array(data['output_tbl']['ms2_inten'])
63
+
64
+ mz_to_formulas = {mz: f for mz, f in zip(mzs, formulas)}
67
+
68
+ ind = mzs.argsort()
69
+ mzs = mzs[ind]
70
+ formulas = formulas[ind]
71
+ intensities = intensities[ind]
72
+ return {'formulas': mz_to_formulas, 'formula_mzs': mzs, 'formula_intensities': intensities}
73
+ except:
74
+ return None
75
+
76
+ def make_tmp_subformula_spectra(row):
77
+ return {'formulas':[row['formula']], 'formula_mzs':[float(row['precursor_mz'])], 'formula_intensities':[1.0]}
78
+
79
+ def get_spec_featurizer(spectra_view: T.Union[str, list[str]],
80
+ params) -> T.Union[SpecTransform, T.Dict[str, SpecTransform]]:
81
+
82
+ featurizers = {"BinnedSpectra": SpecBinner,
83
+ "SpecBinnerLog": SpecBinnerLog,
84
+ "SpecFormula": SpecFormulaFeaturizer,
85
+ "SpecFormulaMz": SpecFormulaMzFeaturizer,
86
+ 'SpecMzIntTokens': SpecMzIntTokenizer}
87
+
88
+ spectra_featurizer = {}
89
+
90
+ if isinstance(spectra_view, str):
91
+ spectra_view = [spectra_view]
92
+
93
+ for view in spectra_view:
94
+ featurizer_params = {'max_mz': params['max_mz']}
95
+ if view in ["BinnedSpectra", "SpecBinnerLog"]:
96
+ featurizer_params.update({'bin_width': params['bin_width']})
97
+ elif view in ["SpecFormula", "SpecFormulaMz"]:
98
+ featurizer_params.update({'element_list': params['element_list'], 'add_intensities': params['add_intensities'], 'formula_normalize_vector': MSGYM_FORMULA_VECTOR_NORM})
99
+
100
+ if view in ("SpecFormulaMz", 'SpecMzIntTokens'):
101
+ featurizer_params.update({'mz_mean_std': MSGYM_STANDARD_MH, 'mask_precursor': params['mask_precursor']})
102
+ # featurizer_params.update({'mask_precursor': params['mask_precursor']})
103
+
104
+ spectra_featurizer[view] = featurizers[view](**featurizer_params)
105
+
106
+ return spectra_featurizer
107
+
108
+ def get_mol_featurizer(molecule_view: T.Union[str, T.List[str]], params) -> MolTransform:
109
+ featurizes = {'MolGraph':MolToGraph}
110
+ mol_featurizer = {}
111
+
112
+ if isinstance(molecule_view, str):
113
+ molecule_view = [molecule_view]
114
+ for view in molecule_view:
115
+ featurizer_params = {}
116
+ if view in ('MolGraph',):
117
+ featurizer_params.update({'atom_feature': params['atom_feature'], 'bond_feature': params['bond_feature'], 'element_list': params['element_list']})
118
+
119
+ if len(molecule_view) == 1:
120
+ return featurizes[view](**featurizer_params)
121
+
122
+ mol_featurizer[view] = featurizes[view](**featurizer_params)
123
+
124
+ return mol_featurizer
125
+
126
+ def get_test_ms_dataset(spectra_view: T.Union[str, T.List[str]],
127
+ mol_view: T.Union[str, T.List[str]],
128
+ spectra_featurizer: SpecTransform,
129
+ mol_featurizer: MolTransform,
130
+ params):
131
+
132
+ use_formulas = False
133
+
134
+ views = []
135
+ for v in [spectra_view, mol_view]:
136
+ if isinstance(v, str):
137
+ views.append(v)
138
+ else: views.extend(v)
139
+ views = frozenset(views)
140
+
141
+ dataset_params = {'spectra_view': spectra_view, 'pth': params['dataset_pth'], 'spec_transform': spectra_featurizer, 'mol_transform': mol_featurizer, "candidates_pth": params['candidates_pth']}
142
+ if "SpecFormula" in views or "SpecFormulaMz" in views:
143
+ dataset_params.update({'subformula_dir_pth': params['subformula_dir_pth']})
144
+ use_formulas = True
145
+
146
+ if params['use_cons_spec']:
147
+ dataset_params.update({'cons_spec_dir_pth': params['cons_spec_dir_pth']})
148
+ if 'use_NL_spec' in params and params['use_NL_spec']:
149
+ dataset_params.update({'NL_spec_dir_pth': params['NL_spec_dir_pth']})
150
+ if params['pred_fp'] or params['use_fp']:
151
+ dataset_params.update({'fp_dir_pth': '', 'fp_size': params['fp_size'], 'fp_radius': params['fp_radius']})
152
+
153
+ return jestr_datasets.ExpandedRetrievalDataset(use_formulas=use_formulas, **dataset_params)
154
+
155
+ def get_ms_dataset(spectra_view: str,
156
+ mol_view: str,
157
+ spectra_featurizer: SpecTransform,
158
+ mol_featurizer: MolTransform,
159
+ params):
160
+
161
+
162
+ # set up dataset_parameters
163
+ dataset_params = {'pth': params['dataset_pth'], 'spec_transform': spectra_featurizer, 'mol_transform': mol_featurizer, 'spectra_view': spectra_view}
164
+ use_formulas = False
165
+ if "SpecFormula" in spectra_view:
166
+ dataset_params.update({'subformula_dir_pth': params['subformula_dir_pth']})
167
+ use_formulas = True
168
+
169
+ if params['pred_fp'] or params['use_fp']:
170
+ dataset_params.update({'fp_dir_pth': params['fp_dir_pth']})
171
+
172
+ if params['aug_cands']:
173
+ dataset_params.update({'aug_cands_dir_pth': params['aug_cands_dir_pth'],
174
+ 'use_formulas':use_formulas,
175
+ "aug_cands_size": params['aug_cands_size']})
176
+
177
+ if params['use_cons_spec']:
178
+ dataset_params.update({'cons_spec_dir_pth': params['cons_spec_dir_pth']})
179
+
180
+ if 'use_NL_spec' in params and params['use_NL_spec']:
181
+ dataset_params.update({'NL_spec_dir_pth': params['NL_spec_dir_pth']})
182
+
183
+ # select dataset
184
+ if params['aug_cands']:
185
+ return jestr_datasets.MassSpecDataset_Candidates(**dataset_params)
186
+ elif use_formulas:
187
+ return jestr_datasets.MassSpecDataset_PeakFormulas(**dataset_params)
188
+
189
+ return jestr_datasets.JESTR1_MassSpecDataset(**dataset_params)
190
+
191
+ class PrepMatchMS:
192
+ def __init__(self, spectra_view) -> None:
193
+
194
+ if spectra_view == 'SpecFormula':
195
+ self.prepare = self.specFormula
196
+ elif spectra_view == "SpecFormulaMz":
197
+ self.prepare = self.specFormulaMz
198
+ elif spectra_view in ('SpecBinnerLog', 'BinnedSpectra', 'SpecMzIntTokenizer'):
199
+ self.prepare = self.specMzInt
200
+ else:
201
+ raise Exception("Spectra view is not supported.")
202
+
203
+ def specFormulaMz(self, row):
204
+
205
+ return matchms.Spectrum(
206
+ mz = np.array([float(m) for m in row["mzs"].split(",")]),
207
+ intensities = np.array(
208
+ [float(i) for i in row["intensities"].split(",")]
209
+ ),
210
+ metadata = {'precursor_mz': row['precursor_mz'], 'formulas': row['formulas']}
211
+ )
212
+
213
+ def specFormula(self, row):
214
+
215
+ return matchms.Spectrum(
216
+ mz = np.array(row['formula_mzs']),
217
+ intensities = np.array(row['formula_intensities']),
218
+ metadata = {'precursor_mz': row['precursor_mz'], 'formulas': np.array(row['formulas']), 'precursor_formula': row['precursor_formula']}
219
+ )
220
+
221
+ def specMzInt(self, row):
222
+ return matchms.Spectrum(
223
+ mz = row['mzs'],
224
+ intensities = row['intensities'],
225
+ metadata = {'precursor_mz': row['precursor_mz']}
226
+ )
mvp/utils/debug.py ADDED
@@ -0,0 +1,19 @@
1
+ import torch
2
+
3
+ def nan_hook(self,inp, output):
4
+
5
+ nan_mask = torch.isnan(output)
6
+
7
+ if nan_mask.any():
8
+
9
+ print("In", self.__class__.__name__)
10
+
11
+ raise RuntimeError(f"Found NaN in output at indices: {nan_mask.nonzero()}")
12
+
13
+ inf_mask = torch.isinf(output)
14
+
15
+ if inf_mask.any():
16
+
17
+ print("In", self.__class__.__name__)
18
+
19
+ raise RuntimeError(f"Found Inf in output at indices: {inf_mask.nonzero()}")
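
A sketch of how this hook is typically attached, using PyTorch forward hooks on every submodule of a model.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))

# Attach the hook to every submodule; it runs after each forward pass
for module in model.modules():
    module.register_forward_hook(nan_hook)

out = model(torch.randn(4, 8))   # raises RuntimeError if any output contains NaN or Inf
```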
mvp/utils/eval.py ADDED
@@ -0,0 +1,157 @@
1
+ from MassSpecGym.massspecgym.utils import MyopicMCES
2
+ import numpy as np
3
+ import tqdm
4
+ from multiprocessing import Pool
5
+
6
+ import os
7
+ import pandas as pd
8
+
9
+ class Compute_Myopic_MCES:
10
+ mces_compute = MyopicMCES()
11
+
12
+
13
+ @staticmethod
+ def compute_mces(tar_cand):
14
+ target, cand = tar_cand
15
+
16
+ dist = Compute_Myopic_MCES.mces_compute(target, cand)
17
+ return (tar_cand, dist)
18
+
19
+ @staticmethod
+ def compute_mces_parallel(target_cand_list, n_processes=25):
20
+
21
+
22
+ with Pool(processes=n_processes) as pool:
23
+ results = list(tqdm.tqdm(pool.imap(Compute_Myopic_MCES.compute_mces, target_cand_list), total=len(target_cand_list)))
24
+ return results
25
+
26
+ class Compute_Myopic_MCES_timeout:
27
+ mces_compute = MyopicMCES()
28
+
29
+ @staticmethod
30
+ def compute_mces(tar_cand):
31
+ target, cand = tar_cand
32
+ dist = Compute_Myopic_MCES.mces_compute(target, cand)
33
+ return (tar_cand, dist)
34
+
35
+ @staticmethod
36
+ def compute_mces_parallel(target_cand_list, n_processes=35, timeout=60): # timeout in seconds
37
+ results = []
38
+
39
+ with Pool(processes=n_processes) as pool:
40
+ async_results = [
41
+ pool.apply_async(Compute_Myopic_MCES.compute_mces, args=(tar_cand,))
42
+ for tar_cand in target_cand_list
43
+ ]
44
+ for async_res in tqdm.tqdm(async_results, total=len(target_cand_list)):
45
+ try:
46
+ result = async_res.get(timeout=timeout)
47
+ except Exception as e:
48
+ # You can log the error or return a default value
49
+ result = (None, f"Timeout or error")
50
+ results.append(result)
51
+
52
+ return results
53
+
54
+
55
+ def get_result_files(exp_dir, spec_type, views_type):
56
+ files = os.listdir(exp_dir)
57
+ mass_result = ''
58
+ form_result = ''
59
+
60
+ for f in files:
61
+ try:
62
+ _, s, views = f.split('_')
63
+ except:
64
+ continue
65
+
66
+ if s == spec_type and views == views_type:
67
+ print(exp_dir / f)
68
+
69
+ files = os.listdir(exp_dir / f)
70
+ for fr in files:
71
+ if 'mass_result' in fr:
72
+ mass_result = exp_dir / f / fr
73
+ elif 'result' in fr:
74
+ form_result = exp_dir / f/ fr
75
+
76
+ return mass_result, form_result
77
+
78
+ # get target
79
+ def get_target(candidates, labels):
80
+ return np.array(candidates)[labels][0]
81
+
82
+ # get mol rank at 1
83
+ def get_top_cand(candidates, scores):
84
+ return candidates[np.argmax(scores)]
85
+
86
+ # split into hit rates
87
+ def convert_rank_to_hit_rates(row, rank_col ,top_k=[1,5,20]):
88
+ top_k_hits ={}
89
+ rank = row[rank_col]
90
+ for k in top_k:
91
+ if rank <= k:
92
+ top_k_hits[f'{rank_col}-hit_rate@{k}'] = 1
93
+ else:
94
+ top_k_hits[f'{rank_col}-hit_rate@{k}'] = 0
95
+ return pd.Series(top_k_hits)
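
A small sketch of expanding ranks into hit rates on a result dataframe; the column name 'rank' and the rank values are illustrative.

```python
import pandas as pd

df = pd.DataFrame({"rank": [1, 3, 50]})
hits = df.apply(convert_rank_to_hit_rates, axis=1, rank_col="rank", top_k=[1, 5, 20])
print(hits)
#    rank-hit_rate@1  rank-hit_rate@5  rank-hit_rate@20
# 0                1                1                 1
# 1                0                1                 1
# 2                0                0                 0
```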
96
+
97
+ #################### Rank aggregation #######################
98
+ from collections import defaultdict
99
+ import numpy as np
100
+ from scipy.stats import rankdata
101
+
102
+ def borda_count(candidates, score_lists, target):
103
+ scores = defaultdict(int)
104
+ N = len(candidates)
105
+ for score_list in score_lists:
106
+ ranked_list = sorted(zip(candidates, score_list), key=lambda x: x[1], reverse=True)
107
+ for rank, (mol, _) in enumerate(ranked_list, start=1):
108
+ scores[mol] += N - rank + 1
109
+ ranked_candidates = [mol for mol, _ in sorted(scores.items(), key=lambda x: x[1], reverse=True)]
110
+ return ranked_candidates.index(target) + 1 if target in ranked_candidates else None
111
+
112
+ def average_rank(candidates, score_lists, target):
113
+ rank_sums = defaultdict(list)
114
+ for score_list in score_lists:
115
+ ranked_list = sorted(zip(candidates, score_list), key=lambda x: x[1], reverse=True)
116
+ for rank, (mol, _) in enumerate(ranked_list, start=1):
117
+ rank_sums[mol].append(rank)
118
+ avg_ranks = {mol: np.mean(ranks) for mol, ranks in rank_sums.items()}
119
+ ranked_candidates = [mol for mol, _ in sorted(avg_ranks.items(), key=lambda x: x[1])]
120
+ return ranked_candidates.index(target) + 1 if target in ranked_candidates else None
121
+
122
+ def reciprocal_rank_aggregation(candidates, score_lists, target):
123
+ scores = defaultdict(float)
124
+ for score_list in score_lists:
125
+ ranked_list = sorted(zip(candidates, score_list), key=lambda x: x[1], reverse=True)
126
+ for rank, (mol, _) in enumerate(ranked_list, start=1):
127
+ scores[mol] += 1 / rank
128
+ ranked_candidates = [mol for mol, _ in sorted(scores.items(), key=lambda x: x[1], reverse=True)]
129
+ return ranked_candidates.index(target) + 1 if target in ranked_candidates else None
130
+
131
+ def weighted_voting(candidates, score_lists, weights, target):
132
+ scores = defaultdict(float)
133
+ for weight, score_list in zip(weights, score_lists):
134
+ ranked_list = sorted(zip(candidates, score_list), key=lambda x: x[1], reverse=True)
135
+ for rank, (mol, _) in enumerate(ranked_list, start=1):
136
+ scores[mol] += weight / rank
137
+ ranked_candidates = [mol for mol, _ in sorted(scores.items(), key=lambda x: x[1], reverse=True)]
138
+ return ranked_candidates.index(target) + 1 if target in ranked_candidates else None
139
+
140
+ def median_rank(candidates, score_lists, target):
141
+ rank_sums = defaultdict(list)
142
+ for score_list in score_lists:
143
+ ranked_list = sorted(zip(candidates, score_list), key=lambda x: x[1], reverse=True)
144
+ for rank, (mol, _) in enumerate(ranked_list, start=1):
145
+ rank_sums[mol].append(rank)
146
+ median_ranks = {mol: np.median(ranks) for mol, ranks in rank_sums.items()}
147
+ ranked_candidates = [mol for mol, _ in sorted(median_ranks.items(), key=lambda x: x[1])]
148
+ return ranked_candidates.index(target) + 1 if target in ranked_candidates else None
149
+
150
+ def score_based_aggregation(candidates, score_lists, target):
151
+ scores = defaultdict(list)
152
+ for score_list in score_lists:
153
+ for mol, score in zip(candidates, score_list):
154
+ scores[mol].append(score)
155
+ avg_scores = {mol: np.mean(vals) for mol, vals in scores.items()}
156
+ ranked_candidates = [mol for mol, _ in sorted(avg_scores.items(), key=lambda x: x[1], reverse=True)]
157
+ return ranked_candidates.index(target) + 1 if target in ranked_candidates else None
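
A toy example of combining two score lists over three candidates; the names and scores are fabricated, and all three aggregators happen to rank the target first here.

```python
candidates = ["mol_a", "mol_b", "mol_c"]
scores_model_1 = [0.9, 0.5, 0.1]   # ranks: a > b > c
scores_model_2 = [0.2, 0.8, 0.4]   # ranks: b > c > a
target = "mol_b"

print(borda_count(candidates, [scores_model_1, scores_model_2], target))              # 1
print(average_rank(candidates, [scores_model_1, scores_model_2], target))             # 1
print(score_based_aggregation(candidates, [scores_model_1, scores_model_2], target))  # 1
```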
mvp/utils/general.py ADDED
@@ -0,0 +1,87 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+ def pad_graph_nodes(mol_enc, g_n_nodes):
6
+ """
7
+ Args:
8
+ mol_enc: 2D tensor of shape (sum_nodes, D)
9
+ Node embeddings for each molecule.
10
+ g_n_nodes: list[int] Number of nodes per graph (len = B)
11
+
12
+ Returns:
13
+ padded: (B, max_nodes, D) tensor
14
+ mask: (B, max_nodes) bool tensor, True for valid nodes
15
+ """
16
+
17
+ # Already concatenated: shape (sum_nodes, D)
18
+ B = len(g_n_nodes)
19
+ D = mol_enc.shape[1]
20
+ max_nodes = max(g_n_nodes)
21
+ padded = mol_enc.new_zeros((B, max_nodes, D))
22
+ mask = torch.zeros((B, max_nodes), dtype=torch.bool, device=mol_enc.device)
23
+
24
+ idx = 0
25
+ for i, n in enumerate(g_n_nodes):
26
+ padded[i, :n] = mol_enc[idx:idx+n]
27
+ mask[i, :n] = True
28
+ idx += n
29
+ return padded, mask
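
A quick sketch with two graphs of 2 and 3 nodes; the shapes of the outputs are the main point.

```python
import torch

mol_enc = torch.arange(20, dtype=torch.float32).reshape(5, 4)  # 5 nodes total, dim 4
g_n_nodes = [2, 3]

padded, mask = pad_graph_nodes(mol_enc, g_n_nodes)
print(padded.shape)   # torch.Size([2, 3, 4])
print(mask)           # [[True, True, False], [True, True, True]]
```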
30
+
31
+
32
+ def filip_similarity_batch(image_tokens, text_tokens, mask_image, mask_text):
33
+ """
34
+ Compute FILIP similarity for batches of image and text token embeddings.
35
+
36
+ Args:
37
+ image_tokens: (B, N_img, D) float tensor
38
+ text_tokens: (B, N_text, D) float tensor
39
+ mask_image: (B, N_img) bool tensor
40
+ mask_text: (B, N_text) bool tensor
41
+
42
+ Returns:
43
+ similarities: (B,) float tensor of similarity scores
44
+ """
45
+ B, N_img, D = image_tokens.shape
46
+ N_text = text_tokens.shape[1]
47
+
48
+ # Normalize tokens
49
+ image_norm = F.normalize(image_tokens, p=2, dim=-1) # (B, N_img, D)
50
+ text_norm = F.normalize(text_tokens, p=2, dim=-1) # (B, N_text, D)
51
+
52
+ # Compute batched cosine similarity matrices
53
+ # Result shape: (B, N_img, N_text)
54
+ sim_matrix = torch.bmm(image_norm, text_norm.transpose(1, 2))
55
+
56
+ # Expand masks for broadcasting
57
+ mask_image_exp = mask_image.unsqueeze(2) # (B, N_img, 1)
58
+ mask_text_exp = mask_text.unsqueeze(1) # (B, 1, N_text)
59
+ valid_mask = mask_image_exp & mask_text_exp # (B, N_img, N_text)
60
+
61
+ # Mask invalid positions by setting them to -inf
62
+ sim_matrix_masked = sim_matrix.masked_fill(~valid_mask, float('-inf'))
63
+
64
+ # Max over text tokens per image token: (B, N_img)
65
+ max_sim_img, _ = sim_matrix_masked.max(dim=2)
66
+
67
+ # Max over image tokens per text token: (B, N_text)
68
+ max_sim_text, _ = sim_matrix_masked.max(dim=1)
69
+
70
+ # Replace -inf (no valid tokens) with zeros to avoid NaNs
71
+ max_sim_img[max_sim_img == float('-inf')] = 0
72
+ max_sim_text[max_sim_text == float('-inf')] = 0
73
+
74
+ # Sum over valid tokens and divide by number of valid tokens (avoid division by zero)
75
+ sum_img = (max_sim_img * mask_image).sum(dim=1)
76
+ count_img = mask_image.sum(dim=1).clamp(min=1).float()
77
+
78
+ sum_text = (max_sim_text * mask_text).sum(dim=1)
79
+ count_text = mask_text.sum(dim=1).clamp(min=1).float()
80
+
81
+ avg_img = sum_img / count_img
82
+ avg_text = sum_text / count_text
83
+
84
+ # Final similarity per batch element
85
+ similarity = (avg_img + avg_text) / 2
86
+
87
+ return similarity
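
A shape-level sketch of the masked FILIP similarity; random tensors stand in for peak and node embeddings, and the padding pattern is arbitrary.

```python
import torch

B, N_img, N_text, D = 4, 7, 5, 16
image_tokens = torch.randn(B, N_img, D)   # e.g. spectrum peak embeddings
text_tokens = torch.randn(B, N_text, D)   # e.g. molecule node embeddings

mask_image = torch.ones(B, N_img, dtype=torch.bool)
mask_image[:, -2:] = False                # last two image tokens are padding
mask_text = torch.ones(B, N_text, dtype=torch.bool)
mask_text[:, -1] = False                  # last text token is padding

sim = filip_similarity_batch(image_tokens, text_tokens, mask_image, mask_text)
print(sim.shape)   # torch.Size([4]) -- one similarity score per paired sample
```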
mvp/utils/loss.py ADDED
@@ -0,0 +1,156 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ def contrastive_loss(v1, v2, tau=1.0) -> torch.Tensor:
6
+ v1_norm = torch.norm(v1, dim=1, keepdim=True)
7
+ v2_norm = torch.norm(v2, dim=1, keepdim=True)
8
+
9
+ v2T = torch.transpose(v2, 0, 1)
10
+
11
+ inner_prod = torch.matmul(v1, v2T)
12
+
13
+ v2_normT = torch.transpose(v2_norm, 0, 1)
14
+
15
+ norm_mat = torch.matmul(v1_norm, v2_normT)
16
+
17
+ loss_mat = torch.div(inner_prod, norm_mat)
18
+
19
+ loss_mat = loss_mat * (1/tau)
20
+
21
+ loss_mat = torch.exp(loss_mat)
22
+
23
+ numerator = torch.diagonal(loss_mat)
24
+ numerator = torch.unsqueeze(numerator, 0)
25
+
26
+ Lv1_v2_denom = torch.sum(loss_mat, dim=1, keepdim=True)
27
+ Lv1_v2_denom = torch.transpose(Lv1_v2_denom, 0, 1)
28
+ #Lv1_v2_denom = Lv1_v2_denom - numerator
29
+
30
+ Lv2_v1_denom = torch.sum(loss_mat, dim=0, keepdim=True)
31
+ #Lv2_v1_denom = Lv2_v1_denom - numerator
32
+
33
+ Lv1_v2 = torch.div(numerator, Lv1_v2_denom)
34
+
35
+ Lv1_v2 = -1 * torch.log(Lv1_v2)
36
+ Lv1_v2 = torch.mean(Lv1_v2)
37
+
38
+ Lv2_v1 = torch.div(numerator, Lv2_v1_denom)
39
+
40
+ Lv2_v1 = -1 * torch.log(Lv2_v1)
41
+ Lv2_v1 = torch.mean(Lv2_v1)
42
+
43
+ return Lv1_v2 + Lv2_v1 , torch.mean(numerator), torch.mean(Lv1_v2_denom+Lv2_v1_denom)
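
A minimal sketch of the pairwise contrastive loss on random embeddings; matching rows of v1 and v2 are treated as positive pairs.

```python
import torch

torch.manual_seed(0)
v1 = torch.randn(8, 32)   # e.g. spectrum embeddings
v2 = torch.randn(8, 32)   # e.g. molecule embeddings for the same 8 pairs

loss, pos_term, denom_term = contrastive_loss(v1, v2, tau=0.1)
print(loss.item())        # scalar InfoNCE-style loss summed over both directions
```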
44
+
45
+ def cand_spec_sim_loss(spec_enc, cand_enc):
46
+ cand_enc = torch.transpose(cand_enc, 0, 1) # C x B x d
47
+ spec_enc = spec_enc.unsqueeze(0) # 1 x B x d
48
+
49
+ sim = nn.functional.cosine_similarity(spec_enc, cand_enc, dim=2)
50
+ loss = torch.mean(sim)
51
+
52
+ return loss
53
+
54
+ class cons_spec_loss:
55
+ def __init__(self, loss_type) -> None:
56
+ self.loss_compute = {'cosine': self.cos_loss,
57
+ 'l2':torch.nn.MSELoss()}[loss_type]
58
+ def __call__(self,cons_spec, ind_spec):
59
+ return self.loss_compute(cons_spec, ind_spec)
60
+
61
+ def cos_loss(self, cons_spec, ind_spec):
62
+ sim = nn.functional.cosine_similarity(cons_spec, ind_spec)
63
+ loss = 1-torch.mean(sim)
64
+ return loss
65
+
66
+ class fp_loss:
67
+ def __init__(self, loss_type) -> None:
68
+ self.loss_compute = {'cosine': self.fp_loss_cos,
69
+ 'bce': nn.BCELoss()}[loss_type]
70
+
71
+ def __call__(self, predicted_fp, target_fp):
72
+ return self.loss_compute(predicted_fp, target_fp)
73
+
74
+ def fp_loss_cos(self, predicted_fp, target_fp):
75
+ sim = nn.functional.cosine_similarity(predicted_fp, target_fp)
76
+ return 1 - torch.mean(sim)
77
+
78
+
79
+ import torch
80
+ import torch.nn.functional as F
81
+ import torch.distributed as dist
82
+
+ # ---------- Utility ----------
+ def _safe_divide(num, denom, eps=1e-8):
+ return num / (denom + eps)
+
+
+ # ---------- Single-GPU masked FILIP ----------
+ def filip_loss_with_mask(a_tokens, b_tokens, mask_a, mask_b, temperature=0.07):
+ """
+ Single-GPU FILIP loss for modality A (spectra peaks) and modality B (graph nodes),
+ accounting for padding masks.
+
+ Args:
+ a_tokens: (B, N_a, D) float tensor (will be normalized to unit vectors)
+ b_tokens: (B, N_b, D)
+ mask_a: (B, N_a) bool or byte tensor (True=valid)
+ mask_b: (B, N_b) bool or byte tensor
+ temperature: scalar or 0-dim tensor (learnable ok)
+
+ Returns:
+ scalar loss
+ """
+ device = a_tokens.device
+ B, N_a, D = a_tokens.shape
+ N_b = b_tokens.shape[1]
+
+ # normalize to cos sim
+ a = F.normalize(a_tokens, dim=-1)
+ b = F.normalize(b_tokens, dim=-1)
+
+ # Expand to compute all pairwise (batch-wise) similarities:
+ # sim shape: (B, B, N_a, N_b) where sim[i,j,k,l] = dot(a[i,k], b[j,l])
+ a_exp = a.unsqueeze(1).expand(-1, B, -1, -1) # (B, B, N_a, D)
+ b_exp = b.unsqueeze(0).expand(B, -1, -1, -1) # (B, B, N_b, D)
+ sim = torch.einsum('bijd,bitd->bijt', a_exp, b_exp) # (B, B, N_a, N_b)
+
+ # Expand masks to (B, B, N_a) and (B, B, N_b)
+ mask_a_exp = mask_a.unsqueeze(1).expand(-1, B, -1) # (B, B, N_a)
+ mask_b_exp = mask_b.unsqueeze(0).expand(B, -1, -1) # (B, B, N_b)
+
+ # ---- A -> B similarity (s_a2b) ----
+ # For every a-token we need max over valid b-tokens.
+ # Set invalid positions in sim to -inf before max.
+ sim_a2b = sim.clone()
+ invalid_b = ~mask_b_exp.unsqueeze(2).expand(-1, -1, sim_a2b.size(2), -1) # (B, B, N_a, N_b)
+ sim_a2b[invalid_b] = float('-inf')
+
+ # max over b tokens -> (B, B, N_a)
+ max_over_b = sim_a2b.max(dim=3).values
+
+ # zero-out padded a-tokens then average over valid tokens
+ max_over_b = max_over_b * mask_a_exp # padded a tokens get zero
+ denom_a = mask_a_exp.sum(dim=2).clamp(min=1).to(sim.dtype) # (B, B)
+ s_a2b = max_over_b.sum(dim=2) / denom_a # (B, B)
+
+ # ---- B -> A similarity (s_b2a) ----
+ sim_b2a = sim.clone()
+ invalid_a = ~mask_a_exp.unsqueeze(3).expand(-1, -1, -1, sim_b2a.size(3)) # (B, B, N_a, N_b)
+ sim_b2a[invalid_a] = float('-inf')
+
+ max_over_a = sim_b2a.max(dim=2).values # (B, B, N_b)
+ max_over_a = max_over_a * mask_b_exp
+ denom_b = mask_b_exp.sum(dim=2).clamp(min=1).to(sim.dtype)
+ s_b2a = max_over_a.sum(dim=2) / denom_b # (B, B)
+
+ # logits and loss
+ logits_a2b = s_a2b / temperature
+ logits_b2a = s_b2a / temperature
+
+ labels = torch.arange(B, device=device, dtype=torch.long)
+ loss_a2b = F.cross_entropy(logits_a2b, labels)
+ loss_b2a = F.cross_entropy(logits_b2a, labels)
+
+ return 0.5 * (loss_a2b + loss_b2a)
+
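A minimal usage sketch for the two losses defined in this file (not part of the commit). It assumes the mvp package is importable, e.g. with the repository root on sys.path, and uses random tensors purely to illustrate the expected shapes:

import torch
from mvp.utils.loss import contrastive_loss, filip_loss_with_mask

B, D, N_a, N_b = 4, 16, 6, 5
spec_emb, mol_emb = torch.randn(B, D), torch.randn(B, D)
loss, pos_sim, denom = contrastive_loss(spec_emb, mol_emb, tau=0.1)  # loss plus two monitoring scalars

peak_tokens = torch.randn(B, N_a, D)   # e.g. per-peak embeddings
node_tokens = torch.randn(B, N_b, D)   # e.g. per-atom embeddings
mask_a = torch.ones(B, N_a, dtype=torch.bool)
mask_b = torch.ones(B, N_b, dtype=torch.bool)
mask_a[:, -2:] = False                 # pretend the last two peaks are padding
filip = filip_loss_with_mask(peak_tokens, node_tokens, mask_a, mask_b, temperature=0.07)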
mvp/utils/models.py ADDED
@@ -0,0 +1,51 @@
+ from mvp.models.spec_encoder import SpecEncMLP_BIN, SpecFormulaEncMLP, SpecFormulaTransformer, SpecFormula_mz_Encoder, SpecMzIntTokenTransformer
+ from mvp.models.mol_encoder import MolEnc
+ from mvp.models.encoders import MLP
+ from mvp.models.contrastive import ContrastiveModel, CrossAttenContrastive, IndSpecEncoder, MultiViewContrastive, MultiViewFineTuning, FilipContrastive
+
+ def get_spec_encoder(spec_enc: str, args):
+ return {"MLP_BIN": SpecEncMLP_BIN,
+ "MLP_Formula": SpecFormulaEncMLP,
+ "Transformer_Formula": SpecFormulaTransformer,
+ "Formula_BinnedSpec": SpecFormula_mz_Encoder,
+ "Transformer_MzInt": SpecMzIntTokenTransformer}[spec_enc](args)
+
+ def get_mol_encoder(mol_enc: str, args):
+ return {'GNN': MolEnc}[mol_enc](args, in_dim=78)
+
+ def get_fp_pred_model(args):
+ return MLP(in_dim=args.final_embedding_dim, hidden_dims=[args.fp_size], final_activation='sigmoid', dropout=args.fp_dropout)
+
+ def get_fp_enc_model(args):
+ return MLP(in_dim=args.fp_size, hidden_dims=[args.final_embedding_dim, args.final_embedding_dim*2, args.final_embedding_dim], final_activation=None, dropout=0.0)
+
+ def get_model(model: str,
+ params):
+
+ if model == 'contrastive':
+ model = ContrastiveModel(**params)
+ elif model == 'crossAttenContrastive':
+ model = CrossAttenContrastive(**params)
+ elif model == 'IndSpecEncoder':
+ params['pred_fp'] = False
+ params['use_cons_spec'] = False
+ model = IndSpecEncoder(**params)
+ elif model == "MultiviewContrastive":
+ model = MultiViewContrastive(**params)
+ elif model == "MultiViewFineTuning":
+ model = MultiViewFineTuning(**params)
+ elif model == "filipContrastive":
+ model = FilipContrastive(**params)
+ else:
+ raise Exception(f"Model {model} not implemented.")
+
+ # If checkpoint path is provided, load the model from the checkpoint instead
+ if params['checkpoint_pth'] is not None and params['checkpoint_pth'] != "":
+ model = type(model).load_from_checkpoint(
+ params['checkpoint_pth'],
+ log_only_loss_at_stages=params['log_only_loss_at_stages'],
+ df_test_path=params['df_test_path']
+ )
+ print("Loaded Model from checkpoint")
+
+ return model
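A minimal sketch of how the fingerprint helpers above are called (not part of the commit). types.SimpleNamespace stands in for the real args/config object, which is an assumption; only the attributes these two helpers actually read are filled in, and the values are placeholders:

from types import SimpleNamespace
from mvp.utils.models import get_fp_pred_model, get_fp_enc_model

args = SimpleNamespace(final_embedding_dim=256, fp_size=2048, fp_dropout=0.1)
fp_head = get_fp_pred_model(args)  # MLP: joint embedding -> fingerprint bits (sigmoid output)
fp_enc = get_fp_enc_model(args)    # MLP: fingerprint -> embedding space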
mvp/utils/preprocessing.py ADDED
@@ -0,0 +1,149 @@
+ import pandas as pd
+ import pickle
+ import numpy as np
+ import mvp.utils.data as data_utils
+ import collections
+ import os
+ import requests
+ from multiprocessing import Pool
+ from urllib.parse import quote
+ from tqdm import tqdm
+
+ class NPClassProcess:
+ @staticmethod
+ def process_smiles(smiles):
+ try:
+ encoded_smiles = quote(smiles)
+ url = f"https://npclassifier.gnps2.org/classify?smiles={encoded_smiles}"
+ r = requests.get(url)
+ return (smiles, r.json())
+ except Exception:
+ return (smiles, None)
+
+ def NPclass_from_smiles(pth, output_dir, n_processes=20):
+
+ data = pd.read_csv(pth, sep='\t')
+ unique_smiles = data['smiles'].unique().tolist()
+
+ items = unique_smiles
+
+ with Pool(processes=n_processes) as pool:
+ results = list(tqdm(pool.imap(NPClassProcess.process_smiles, items), total=len(items)))
+
+ failed_ct = 0
+ smiles_to_class = {}
+ for s, out in results:
+ if out is None:
+ smiles_to_class[s] = 'NA'
+ failed_ct += 1
+ else:
+ smiles_to_class[s] = out
+ file_pth = os.path.join(output_dir, 'SMILES_TO_CLASS.pkl')
+ with open(file_pth, 'wb') as f:
+ pickle.dump(smiles_to_class, f)
+ print(f'Failed to process {failed_ct} SMILES')
+ print(f'Result file saved to {file_pth}')
+ return file_pth
+
+
+
+ def construct_NL_spec(pth, output_dir):
+ def _get_spec(row):
+ mzs = np.array([float(m) for m in row["mzs"].split(",")], dtype=np.float32)
+ intensities = np.array([float(i) for i in row["intensities"].split(",")], dtype=np.float32)
+ mzs = float(row['precursor_mz']) - mzs
+ valid_idx = np.where(mzs > 1.0)
+ mzs = mzs[valid_idx]
+ intensities = intensities[valid_idx]
+
+ sorted_idx = np.argsort(mzs)
+ mzs = np.concatenate((mzs[sorted_idx], [float(row['precursor_mz'])]))
+ intensities = np.concatenate((intensities[sorted_idx], [1.0]))
+
+ return mzs, intensities
+
+ spec_data = pd.read_csv(pth, sep='\t')
+ spec_data[['mzs', 'intensities']] = spec_data.apply(lambda row: _get_spec(row), axis=1, result_type='expand')
+
+ file_pth = os.path.join(output_dir, 'NL_spec.pkl')
+ with open(file_pth, 'wb') as f:
+ pickle.dump(spec_data, f)
+ return file_pth
+
+ def generate_cons_spec(pth, output_dir):
+ spec_data = pd.read_csv(pth, sep='\t')
+ data_by_smiles = spec_data[['identifier', 'smiles', 'mzs', 'intensities', 'fold']].groupby('smiles').agg({'identifier': list, 'mzs': lambda x: ','.join(x), 'intensities': lambda x: ','.join(x), 'fold': list})
+ smiles_to_fold = dict(zip(data_by_smiles.index.tolist(), data_by_smiles['fold'].tolist()))
+
+ consensus_spectra = {}
+ for idx, row in tqdm(data_by_smiles.iterrows(), total=len(data_by_smiles)):
+ mzs = np.array([float(m) for m in row["mzs"].split(",")], dtype=np.float32)
+ intensities = np.array([float(i) for i in row["intensities"].split(",")], dtype=np.float32)
+
+ sorted_idx = np.argsort(mzs)
+ mzs = mzs[sorted_idx]
+ intensities = intensities[sorted_idx]
+ smiles = row.name
+
+ consensus_spectra[smiles] = {'mzs': mzs, 'intensities': intensities, 'precursor_mz': 10000.0,
+ 'fold': smiles_to_fold[smiles][0]}
+
+ df = pd.DataFrame.from_dict(consensus_spectra, orient='index')
+ df = df.rename_axis('smiles').reset_index()
+
+ return df
+
+
+ def generate_cons_spec_formulas(pth, subformula_dir, output_dir=''):
+ # load tsv file
+ spec_data = pd.read_csv(pth, sep='\t')
+
+ # group spectra by SMILES
+ data_by_smiles = spec_data[['identifier', 'smiles', 'fold', 'precursor_mz', 'formula', 'adduct']].groupby('smiles').agg({'identifier': list, 'fold': list, 'formula': list, 'precursor_mz': "max", 'adduct': list})
+ smiles_to_id = dict(zip(data_by_smiles.index.tolist(), data_by_smiles['identifier'].tolist()))
+ smiles_to_fold = dict(zip(data_by_smiles.index.tolist(), data_by_smiles['fold'].tolist()))
+ smiles_to_precursorMz = dict(zip(data_by_smiles.index.tolist(), data_by_smiles['precursor_mz'].tolist()))
+ smiles_to_precursorFormula = dict(zip(data_by_smiles.index.tolist(), data_by_smiles['formula'].tolist()))
+ # load subformulas
+ subformulaLoader = data_utils.Subformula_Loader(spectra_view='SpecFormula', dir_path=subformula_dir)
+ id_to_spec = subformulaLoader(spec_data['identifier'].tolist())
+
+ # combine spectra
+ consensus_spectra = {}
+ for smiles, ids in tqdm(smiles_to_id.items(), total=len(data_by_smiles)):
+ cons_spec = collections.defaultdict(list)
+ for id in ids:
+ if id in id_to_spec:
+ for k, v in id_to_spec[id].items():
+ cons_spec[k].extend(v)
+ cons_spec = pd.DataFrame(cons_spec)
+
+ assert len(set(smiles_to_fold[smiles])) == 1
+
+ # keep the max m/z and max intensity per formula
+ try:
+ cons_spec = cons_spec.groupby('formulas').agg({'formula_mzs': "max", 'formula_intensities': "max"})
+ cons_spec.reset_index(inplace=True)
+ except Exception: # no subformula annotations for this SMILES; fall back to the precursor alone
+ d = {
+ 'formulas': [smiles_to_precursorFormula[smiles][0]],
+ 'formula_mzs': [smiles_to_precursorMz[smiles]],
+ 'formula_intensities': [1.0]
+ }
+ cons_spec = pd.DataFrame(d)
+
+ cons_spec = cons_spec.sort_values(by='formula_mzs').reset_index(drop=True)
+ cons_spec = {'formulas': cons_spec['formulas'].tolist(),
+ 'formula_mzs': cons_spec['formula_mzs'].tolist(),
+ 'formula_intensities': cons_spec['formula_intensities'].tolist(),
+ 'precursor_mz': smiles_to_precursorMz[smiles],
+ 'fold': smiles_to_fold[smiles][0],
+ 'precursor_formula': smiles_to_precursorFormula[smiles][0]} # formula without adduct...
+
+ consensus_spectra[smiles] = cons_spec
+
+ # assemble the consensus spectra into a dataframe
+ df = pd.DataFrame.from_dict(consensus_spectra, orient='index')
+ df = df.rename_axis('smiles').reset_index()
+
+ return df
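A small worked example (not part of the commit) of the neutral-loss transform inside construct_NL_spec: fragment m/z values become precursor_mz minus m/z, losses of 1 Da or less are dropped, the remainder is sorted, and the precursor itself is appended with intensity 1.0. The numbers below are made up for illustration:

import numpy as np

precursor_mz = 300.0
mzs = np.array([100.0, 250.0, 299.5], dtype=np.float32)
intensities = np.array([0.2, 1.0, 0.5], dtype=np.float32)

nl = precursor_mz - mzs                     # [200.0, 50.0, 0.5]
keep = nl > 1.0                             # drop the 0.5 Da loss
nl, intensities = nl[keep], intensities[keep]
order = np.argsort(nl)
nl = np.concatenate((nl[order], [precursor_mz]))           # [50.0, 200.0, 300.0]
intensities = np.concatenate((intensities[order], [1.0]))  # [1.0, 0.2, 1.0]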