yzhouchen001 commited on
Commit
7b7a7b6
·
1 Parent(s): 0b51da1

magma runner

Browse files
magma/README.md ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MAGMa
2
+
3
+ This code directory is adapted from https://github.com/samgoldman97/mist/tree/nmi_paper_v1/src/mist/magma
4
+
5
+
6
+ MAGMa is an algorithm which takes as input a molecule and provides as output a list of fragment molecules of the parent.
7
+
8
+ In this project, MAGMa is used to label the fragment peaks of spectra datasets
9
+ with chemical formulae and corresponding smiles, to be used as an extra
10
+ training signal for models. The fragmentation code utilized is heavily inspired
11
+ by the [original source code](https://github.com/NLeSC/MAGMa).
12
+
13
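+ `run_magma.py` can be run directly. The upstream MIST version documented here takes the arguments listed below; note that the argument parser in this adapted `run_magma.py` instead defines `--data_pth`, `--output_dir` and `--workers`: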
14
+
15
+ - **--spectra-dir**: The directory path containing the SIRIUS program outputs.
16
+ To subset spectra, we use only peaks that have been preserved by SIRIUS as
17
+ an initial cleaning step. The program can be adapted to use other spectra
18
+ input sources.
19
+ - **--output-dir**: The chosen output directory path in which to save the MAGMa output files.
20
+ - **--lowest-penalty-filter**: If this flag is set, then when selecting candidate chemical formulae and SMILES to label spectra peaks, only candidates with the lowest penalty score (as assigned by the MAGMa fragmentation engine) will be selected.
21
+ - **--spec-labels**: TSV file containing all the smiles for the spectra being used.
magma/fragmentation.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """fragmentation.py
2
+
3
+ Code snippets taken from the MAGMa github project
4
+
5
+ https://github.com/NLeSC/MAGMa
6
+
7
+ """
8
+
9
+
10
+ import numpy
11
+ from rdkit import Chem
12
+
13
# Penalty weight of breaking a bond, keyed by RDKit bond type.
_BOND_TYPES = Chem.rdchem.BondType.names
typew = {
    _BOND_TYPES["SINGLE"]: 1.0,
    _BOND_TYPES["DOUBLE"]: 2.0,
    _BOND_TYPES["TRIPLE"]: 3.0,
    _BOND_TYPES["AROMATIC"]: 3.0,
}
# Multiplier applied to the bond weight; the key is True when at least one
# end atom of the bond is not carbon.
heterow = {True: 1, False: 2}
# Base penalty; fragments scoring `missingfragmentpenalty + 5` or more are
# discarded by FragmentEngine.generate_fragments.
missingfragmentpenalty = 10
21
+
22
+
23
# Monoisotopic masses (Da) of the most abundant isotope of each supported
# element. Fragment masses are compared against measured m/z values, so
# average atomic weights must not be used here.
mims = {
    "H": 1.0078250321,
    "C": 12.0000000,
    "N": 14.0030740052,
    "O": 15.9949146221,
    "F": 18.99840320,
    "Na": 22.9897692809,
    "P": 30.97376151,
    "S": 31.97207069,
    "Cl": 34.96885271,
    "K": 38.96370668,
    "Br": 78.9183376,
    "I": 126.904468,
    "Si": 27.9769265,  # 28Si
    "B": 11.0093054,  # 11B
    "Se": 79.9165218,  # 80Se
    "Fe": 55.9349363,  # 56Fe
    "Co": 58.9331943,  # 59Co
    "As": 74.9215946,  # 75As
}

# Mass of hydrogen atom
Hmass = mims["H"]
# Electron mass (Da)
elmass = 0.0005486

# Mass contribution of each supported adduct ion, keyed first by polarity
# (1 = positive mode, -1 = negative mode) and then by the ion label.
ionmasses = {
    1: {
        "+H": mims["H"],
        "+NH4": mims["N"] + 4 * mims["H"],
        "+Na": mims["Na"],
        "+K": mims["K"],
    },
    -1: {"-H": -mims["H"], "+Cl": mims["Cl"]},
}
57
+
58
+
59
class FragmentEngine(object):
    """Enumerate substructure fragments of one molecule as integer bitmasks.

    A fragment is an ``int`` whose bit ``i`` is set when atom ``i`` of the
    parent molecule belongs to the fragment. A bond is stored as the mask
    with the bits of its two end atoms set. ``fragment_info`` holds one
    ``[fragment, score, bondbreaks]`` entry per fragment and
    ``fragment_masses`` holds the matching flattened rows of candidate
    masses; both start with an all-zero placeholder entry.
    """

    # Order in which element symbols are written into a fragment formula.
    # Elements not listed here are appended afterwards in sorted order.
    _FORMULA_ORDER = (
        "C", "H", "N", "O", "F", "P", "S", "Cl", "Br", "I",
        "Si", "B", "Se", "Fe", "Co", "As", "Na", "K",
    )

    def __init__(
        self,
        smiles,
        max_broken_bonds,
        max_water_losses,
        ionisation_mode,
        skip_fragmentation,
        molcharge,
    ):
        """Parse ``smiles`` and precompute per-atom and per-bond tables.

        If the SMILES cannot be parsed, ``self.accept`` is set to False and
        no other attribute is initialised; check ``accepted()`` before
        calling any other method.

        Args:
            smiles: SMILES string of the parent molecule.
            max_broken_bonds: Number of atom-removal rounds, and the upper
                bound on broken bonds of a retained fragment.
            max_water_losses: Number of neutral-loss rounds.
            ionisation_mode: Offset (in hydrogens) applied to the mass
                ranges together with ``molcharge``.
            skip_fragmentation: If truthy, only the intact molecule is
                registered.
            molcharge: Charge of the molecule used in the mass offset.
        """
        try:
            self.mol = Chem.MolFromSmiles(smiles)
        except Exception:
            self.mol = None
        # MolFromSmiles signals an unparsable SMILES by returning None.
        if self.mol is None:
            self.accept = False
            return
        self.accept = True
        self.natoms = self.mol.GetNumAtoms()

        self.max_broken_bonds = max_broken_bonds
        self.max_water_losses = max_water_losses
        self.ionisation_mode = ionisation_mode
        self.skip_fragmentation = skip_fragmentation
        self.molcharge = molcharge
        self.atom_masses = []
        self.atomHs = []
        self.neutral_loss_atoms = []
        self.bonded_atoms = []  # [[list of atom numbers]]
        self.bonds = set()
        self.bondscore = {}
        self.new_fragment = 0
        self.template_fragment = 0
        # Placeholder first row/entry (all zeros).
        self.fragment_masses = ((max_broken_bonds + max_water_losses) * 2 + 1) * [0]
        self.fragment_info = [[0, 0, 0]]
        self.avg_score = None

        for x in range(self.natoms):
            self.bonded_atoms.append([])
            atom = self.mol.GetAtomWithIdx(x)
            self.atomHs.append(atom.GetNumImplicitHs() + atom.GetNumExplicitHs())
            # Hydrogens are folded into the mass of their heavy atom.
            self.atom_masses.append(mims[atom.GetSymbol()] + Hmass * (self.atomHs[x]))
            # Terminal -OH and -NH2 groups are candidates for neutral loss.
            if (
                atom.GetSymbol() == "O"
                and self.atomHs[x] == 1
                and len(atom.GetBonds()) == 1
            ):
                self.neutral_loss_atoms.append(x)
            if (
                atom.GetSymbol() == "N"
                and self.atomHs[x] == 2
                and len(atom.GetBonds()) == 1
            ):
                self.neutral_loss_atoms.append(x)

        for bond in self.mol.GetBonds():
            a1, a2 = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            self.bonded_atoms[a1].append(a2)
            self.bonded_atoms[a2].append(a1)
            bondbits = 1 << a1 | 1 << a2
            bondscore = (
                typew[bond.GetBondType()]
                * heterow[
                    bond.GetBeginAtom().GetSymbol() != "C"
                    or bond.GetEndAtom().GetSymbol() != "C"
                ]
            )
            self.bonds.add(bondbits)
            self.bondscore[bondbits] = bondscore

    def extend(self, atom):
        """Grow ``self.new_fragment`` to the connected part reachable from ``atom``.

        Only atoms present in ``self.template_fragment`` are added. The walk
        is iterative so that large molecules cannot exhaust the recursion
        limit.
        """
        pending = [atom]
        while pending:
            current = pending.pop()
            for a in self.bonded_atoms[current]:
                atombit = 1 << a
                if atombit & self.template_fragment and not atombit & self.new_fragment:
                    self.new_fragment = self.new_fragment | atombit
                    pending.append(a)

    def generate_fragments(self):
        """Enumerate fragments and build ``fragment_masses_np``.

        Returns:
            The number of entries in ``fragment_info`` (including the
            all-zero placeholder entry).
        """
        frag = (1 << self.natoms) - 1
        all_fragments = {frag}
        total_fragments = {frag}
        current_fragments = {frag}
        new_fragments = {frag}
        self.add_fragment(frag, self.calc_fragment_mass(frag), 0, 0)

        if self.skip_fragmentation:
            self.convert_fragments_table()
            return len(self.fragment_info)

        # generate fragments for max_broken_bond steps
        for step in range(self.max_broken_bonds):
            # loop over all fragments to be fragmented
            for fragment in current_fragments:
                # loop over all atoms
                for atom in range(self.natoms):
                    # in the fragment
                    if (1 << atom) & fragment:
                        # remove the atom
                        self.template_fragment = fragment ^ (1 << atom)
                        list_ext_atoms = set()
                        extended_fragments = set()
                        # find all its neighbor atoms
                        for a in self.bonded_atoms[atom]:
                            # present in the fragment
                            if (1 << a) & self.template_fragment:
                                list_ext_atoms.add(a)
                        # in case of one bonded atom, the new fragment is
                        # the remainder of the old fragment
                        if len(list_ext_atoms) == 1:
                            extended_fragments.add(self.template_fragment)
                        else:
                            # otherwise extend each neighbor atom to a
                            # complete fragment
                            for a in list_ext_atoms:
                                # except when deleted atom is in a ring and a
                                # previous extended fragment already contains
                                # this neighbor atom, then calculate fragment
                                # only once
                                for frag in extended_fragments:
                                    if (1 << a) & frag:
                                        break
                                else:
                                    # extend atom to complete fragment
                                    self.new_fragment = 1 << a
                                    self.extend(a)
                                    extended_fragments.add(self.new_fragment)
                        for frag in extended_fragments:
                            # add extended fragments, if not yet present, to
                            # the collection
                            if frag not in all_fragments:
                                all_fragments.add(frag)
                                bondbreaks, score = self.score_fragment(frag)
                                if bondbreaks <= self.max_broken_bonds and score < (
                                    missingfragmentpenalty + 5
                                ):
                                    new_fragments.add(frag)
                                    total_fragments.add(frag)
                                    self.add_fragment(
                                        frag,
                                        self.calc_fragment_mass(frag),
                                        score,
                                        bondbreaks,
                                    )
            current_fragments = new_fragments
            new_fragments = set()
        # number of OH losses
        for step in range(self.max_water_losses):
            # loop of all fragments; add_fragment appends to fragment_info
            # while it is being iterated, so newly added entries are visited
            # too (behaviour kept from the original implementation)
            for fi in self.fragment_info:
                # on which to apply neutral loss rules
                if fi[2] == self.max_broken_bonds + step:
                    fragment = fi[0]
                    # loop over all atoms in the fragment
                    for atom in self.neutral_loss_atoms:
                        if (1 << atom) & fragment:
                            frag = fragment ^ (1 << atom)
                            # add extended fragments, if not yet present, to
                            # the collection
                            if frag not in total_fragments:
                                total_fragments.add(frag)
                                bondbreaks, score = self.score_fragment(frag)
                                if score < (missingfragmentpenalty + 5):
                                    self.add_fragment(
                                        frag,
                                        self.calc_fragment_mass(frag),
                                        score,
                                        bondbreaks,
                                    )
        self.convert_fragments_table()
        return len(self.fragment_info)

    def score_fragment(self, fragment):
        """Return ``(bondbreaks, score)`` for ``fragment``.

        A bond counts as broken when exactly one of its two atoms is in the
        fragment; ``score`` is the sum of the weights of the broken bonds.
        """
        score = 0
        bondbreaks = 0
        for bond in self.bonds:
            if 0 < (fragment & bond) < bond:
                score += self.bondscore[bond]
                bondbreaks += 1
        if score == 0:
            print("score=0: ", fragment, bondbreaks)
        return bondbreaks, score

    def score_fragment_rel2parent(self, fragment, parent):
        """Return the summed weight of bonds cut between ``fragment`` and ``parent``."""
        score = 0
        for bond in self.bonds:
            if 0 < (fragment & bond) < (bond & parent):
                score += self.bondscore[bond]
        return score

    def calc_fragment_mass(self, fragment):
        """Return the summed mass (heavy atoms plus their hydrogens) of ``fragment``."""
        fragment_mass = 0.0
        for atom in range(self.natoms):
            if fragment & (1 << atom):
                fragment_mass += self.atom_masses[atom]
        return fragment_mass

    def add_fragment(self, fragment, fragmentmass, score, bondbreaks):
        """Register a fragment and append its row of candidate masses.

        The row has ``(max_broken_bonds + max_water_losses) * 2 + 1`` slots:
        a centred run of ``2 * bondbreaks + 1`` masses that differ by one
        hydrogen mass each, zero-padded on both sides.
        """
        padding = (self.max_broken_bonds + self.max_water_losses - bondbreaks) * [0]
        h_offset = self.ionisation_mode * (1 - self.molcharge)
        mass_range = (
            padding
            + list(
                numpy.arange(-bondbreaks + h_offset, bondbreaks + h_offset + 1) * Hmass
                + fragmentmass
            )
            + padding
        )
        if bondbreaks == 0:
            # make sure that fragmentmass is included
            mass_range[
                self.max_broken_bonds + self.max_water_losses - self.ionisation_mode
            ] = fragmentmass
        self.fragment_masses += mass_range
        self.fragment_info.append([fragment, score, bondbreaks])

    def convert_fragments_table(self):
        """Reshape the flat mass list into one row per ``fragment_info`` entry."""
        self.fragment_masses_np = numpy.array(self.fragment_masses).reshape(
            len(self.fragment_info),
            (self.max_broken_bonds + self.max_water_losses) * 2 + 1,
        )

    def calc_avg_score(self):
        """Store the average of ``self.scores`` in ``self.avg_score``."""
        # NOTE(review): nothing in this class assigns ``self.scores``, so this
        # raises AttributeError unless a caller sets it first - confirm the
        # intended source of the scores.
        self.avg_score = numpy.average(self.scores)

    def get_avg_score(self):
        """Return the value last stored by ``calc_avg_score`` (None initially)."""
        return self.avg_score

    def find_fragments(self, mass, parent, precision, mz_precision_abs):
        """Return fragments having a candidate mass within tolerance of ``mass``.

        The window is the wider of the relative (``precision``) and absolute
        (``mz_precision_abs``) tolerances. ``parent`` is not used. Each hit is
        ``fragment_info[fid]`` extended with a reference mass taken from the
        row and the hydrogen offset of the matching column.
        """
        upper = max(mass * precision, mass + mz_precision_abs)
        lower = min(mass / precision, mass - mz_precision_abs)
        result = numpy.where(
            numpy.where(self.fragment_masses_np < upper, self.fragment_masses_np, 0)
            > lower
        )
        centre = self.max_broken_bonds + self.max_water_losses
        h_offset = self.ionisation_mode * (1 - self.molcharge)
        fragment_set = []
        for fid, column in zip(result[0], result[1]):
            fragment_set.append(
                self.fragment_info[fid]
                + [self.fragment_masses_np[fid][centre - h_offset]]
                + [h_offset + column - centre]
            )
        return fragment_set

    @classmethod
    def _format_formula(cls, elements):
        """Render an element-count mapping as a formula string.

        Elements are written in ``_FORMULA_ORDER`` followed by any other
        symbols in sorted order; counts of 1 are omitted and non-positive
        counts are skipped.
        """
        ordered = [el for el in cls._FORMULA_ORDER if el in elements]
        ordered += sorted(el for el in elements if el not in cls._FORMULA_ORDER)
        parts = []
        for el in ordered:
            nel = elements[el]
            if nel > 0:
                parts.append(el)
                if nel > 1:
                    parts.append(str(nel))
        return "".join(parts)

    def get_fragment_info(self, fragment, deltaH):
        """Describe ``fragment``.

        Args:
            fragment: Fragment bitmask.
            deltaH: Not used by this implementation.

        Returns:
            ``(atomstring, atomlist, formula, smiles)`` where ``atomstring``
            is the comma-joined ``atomlist`` and ``smiles`` comes from
            ``fragment2smiles``.
        """
        atomlist = []
        # Counted dynamically so that every element accepted by __init__
        # (e.g. Na, K) is handled instead of raising KeyError.
        elements = {}
        for atom in range(self.natoms):
            if (1 << atom) & fragment:
                atomlist.append(atom)
                symbol = self.mol.GetAtomWithIdx(atom).GetSymbol()
                elements[symbol] = elements.get(symbol, 0) + 1
                elements["H"] = elements.get("H", 0) + self.atomHs[atom]
        formula = self._format_formula(elements)
        atomstring = ",".join(str(a) for a in atomlist)
        return atomstring, atomlist, formula, fragment2smiles(self.mol, atomlist)

    def get_natoms(self):
        """Return the number of atoms of the parsed molecule."""
        return self.natoms

    def accepted(self):
        """Return True when the SMILES was parsed successfully."""
        return self.accept
368
+
369
+
370
def fragment2smiles(mol, atomlist):
    """Return the SMILES of the substructure of ``mol`` made of ``atomlist``.

    Args:
        mol: Parent RDKit molecule.
        atomlist: Indices of the atoms to keep.

    Returns:
        SMILES string produced by ``Chem.MolToSmiles`` for the molecule that
        remains after removing every atom not in ``atomlist``.
    """
    keep = set(atomlist)  # O(1) membership instead of scanning the list per atom
    emol = Chem.EditableMol(mol)
    # Remove from the highest index down so the remaining indices stay valid.
    for atom in reversed(range(mol.GetNumAtoms())):
        if atom not in keep:
            emol.RemoveAtom(atom)
    frag = emol.GetMol()
    return Chem.MolToSmiles(frag)
magma/magma_utils.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ magma_utils.py
2
+
3
+ Additional utility file to assist with fingerprinting.
4
+
5
+ """
6
+
7
+ import os
8
+ from ast import literal_eval
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+
13
+ from rdkit import Chem
14
+ from rdkit.Chem import AllChem, DataStructs
15
+ import re
16
+
17
# Matches the signed ion inside an adduct label, e.g. "+H" in "[M+H]+".
_ADDUCT_ION_PATTERN = re.compile(r"\[M([+-][^\]]+)\]")


def extract_adduct_ion(adduct, default='+H'):
    """Return the signed ion part of an adduct label.

    Args:
        adduct: Adduct label such as ``"[M+H]+"`` or ``"[M-H]-"``.
        default: Value returned when no ion can be extracted.

    Returns:
        The text between ``[M`` and ``]`` (e.g. ``"+Na"``), or ``default``
        when ``adduct`` is not a string (None, NaN from a table) or does not
        contain the ``[M...]`` pattern.
    """
    if not isinstance(adduct, str):
        return default
    match = _ADDUCT_ION_PATTERN.search(adduct)
    if match:
        return match.group(1)
    return default
23
+
24
+
25
def get_magma_fingerprint(smile):
    """Return the 2048-bit radius-2 Morgan fingerprint of ``smile`` as a numpy array.

    The molecule is sanitized with every operation except kekulization.
    """
    sanitize_ops = (
        Chem.SanitizeFlags.SANITIZE_ALL ^ Chem.SanitizeFlags.SANITIZE_KEKULIZE
    )
    mol = Chem.MolFromSmiles(smile, sanitize=False)
    Chem.SanitizeMol(mol, sanitizeOps=sanitize_ops)

    bit_vect = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)

    # ConvertToNumpyArray fills the array passed as its second argument.
    fingerprint = np.zeros((0,), dtype=np.uint8)
    DataStructs.ConvertToNumpyArray(bit_vect, fingerprint)
    return fingerprint
38
+
39
+
40
def get_magma_fingerprint_bits(smile):
    """Return the list of indices of the set bits in the fingerprint of ``smile``."""
    bits = np.array(list(get_magma_fingerprint(smile)))
    (on_positions,) = np.where(bits == 1)
    return list(on_positions)
45
+
46
+
47
def read_magma_file(magma_frag_file):
    """Load a tab-separated MAGMa fragment file into a DataFrame.

    When the path is None, missing or empty, or when the file lacks the
    "smiles" / "chemical_formula" columns, an empty DataFrame with the
    expected columns is returned instead.

    NOTE(review): the nesting of the early ``return`` was reconstructed from
    a copy of this file without indentation - confirm that files missing the
    label columns are meant to fall through to the empty frame.
    """
    usable = (
        magma_frag_file is not None
        and os.path.exists(magma_frag_file)
        and os.path.getsize(magma_frag_file) > 0
    )
    if usable:
        frag_table = pd.read_csv(magma_frag_file, index_col=0, sep="\t")
        has_labels = (
            "smiles" in frag_table.columns
            and "chemical_formula" in frag_table.columns
        )
        if has_labels:
            # The list-valued columns are stored as their string repr.
            for label_column in ("smiles", "chemical_formula"):
                frag_table = _convert_str_to_list(frag_table, label_column)
            if "mass_to_charge" not in frag_table.columns:
                frag_table["mass_to_charge"] = frag_table["mz"]
            return frag_table

    expected_columns = [
        "mass_to_charge",
        "intensity",
        "chemical_formula",
        "smiles",
        "molecule_peak",
    ]
    return pd.DataFrame(columns=expected_columns)
80
+
81
+
82
+ def _convert_str_to_list(df, column):
83
+ """_convert_str_to_list"""
84
+ df.loc[:, column] = df.loc[:, column].apply(
85
+ lambda x: literal_eval(x) if x != "NAN" and not pd.isna(x) else []
86
+ )
87
+ return df
magma/run_magma.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ run_magma.py
2
+
3
+ Accept input processed spectra and make subformula peak assignments
4
+ accordingly.
5
+
6
+ """
7
+ import logging
8
+ from pathlib import Path
9
+ import numpy as np
10
+ import pandas as pd
11
+ import argparse
12
+ import sys
13
+ from multiprocessing import Pool
14
+ from tqdm import tqdm
15
+ from collections import defaultdict
16
+ import json
17
+
18
+ # add parent path
19
+ import os
20
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
21
+
22
+ # Custom import
23
+ from magma.fragmentation import FragmentEngine, ionmasses
24
+ from magma import magma_utils
25
+ from magma.fragmentation import ionmasses
26
+
27
+ # Define basic logger
28
logging.basicConfig(
    format="%(asctime)s %(levelname)s: %(message)s",
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
)


# Settings handed to FragmentEngine. The insertion order matters to code that
# unpacks ``.values()`` positionally, so keep it in constructor order.
FRAGMENT_ENGINE_PARAMS = dict(
    max_broken_bonds=3,
    max_water_losses=1,
    ionisation_mode=1,
    skip_fragmentation=0,
    molcharge=0,
)

# Settings used when matching fragments to peaks.
PEAK_ASSIGNMENT_PARAMS = dict(
    lowest_penalty_filter=True,
    tolerance=1,
)
49
+
50
def get_args():
    """Parse the command-line arguments of the MAGMa runner.

    Returns:
        argparse.Namespace with ``data_pth``, ``output_dir`` and ``workers``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_pth',
        required=True,
        help=(
            "Tab-separated input file with the columns identifier, mzs, "
            "intensities, smiles and adduct (mzs and intensities are "
            "comma-separated numbers)"
        ),
    )
    parser.add_argument(
        "--output_dir",
        required=True,
        help="Output directory to save MAGMA files",
    )
    parser.add_argument(
        "--workers", default=30, action="store", type=int, help="Num workers"
    )
    return parser.parse_args()
66
+
67
+
68
+
69
def get_matching_fragment(
    fragment_df, mass_comparison_vector, lowest_penalty_filter: bool
):
    """Return the index labels of the fragments that match one peak.

    Args:
        fragment_df: Fragment table with a "score" column.
        mass_comparison_vector: Boolean mask over the rows of ``fragment_df``
            selecting the fragments whose mass range covers the peak.
        lowest_penalty_filter: If True and several fragments match, keep only
            those sharing the minimum score.

    Returns:
        List of index labels of the selected rows (empty when none match).
    """
    candidates = fragment_df[mass_comparison_vector]

    # Nothing covers this peak.
    if candidates.shape[0] == 0:
        return []

    # Keep only the cheapest candidates when asked to.
    if lowest_penalty_filter and candidates.shape[0] > 1:
        best_score = candidates["score"].min()
        candidates = candidates[candidates["score"] == best_score]

    return list(candidates.index)
100
+
101
+
102
def get_fragment_mass_range(fragment_engine, fragment_df, tolerance):
    """Add "min_mass" and "max_mass" columns to ``fragment_df``.

    For every row of ``fragment_engine.fragment_masses_np`` the range runs
    from the first non-zero entry minus ``tolerance`` to the largest entry
    plus ``tolerance``; a row summing to zero gets the range (0, 0).

    Args:
        fragment_engine: Object exposing ``fragment_masses_np`` (2-D array).
        fragment_df: Table with one row per fragment; modified in place.
        tolerance: Margin added on both sides of the range.

    Returns:
        ``fragment_df`` with the two new columns.
    """
    mass_table = fragment_engine.fragment_masses_np

    bounds = []
    for row in mass_table:
        if np.sum(row) == 0:
            bounds.append((0, 0))
            continue
        first_nonzero = np.nonzero(row)[0][0]
        lower = row[first_nonzero] - tolerance
        upper = max(row) + tolerance
        bounds.append((lower, upper))

    bounds = np.array(bounds)
    fragment_df["min_mass"] = bounds[:, 0]
    fragment_df["max_mass"] = bounds[:, 1]

    return fragment_df
139
+
140
def run_magma_wrapper(args):
    """Run ``run_magma(*args)`` unless its output file already exists.

    ``args`` is the positional-argument tuple of ``run_magma``; its last item
    is the output path. Returns None when the spectrum is skipped.
    """
    output_path = args[-1]
    already_processed = os.path.exists(output_path)
    return None if already_processed else run_magma(*args)
144
+
145
def run_magma(identifier, mzs, intensities, smiles, adduct, save_filename=''):
    """Fragment one molecule, assign its fragments to peaks and save the result.

    Args:
        identifier: Spectrum identifier (only used in log messages).
        mzs: Peak m/z values.
        intensities: Peak intensities, aligned with ``mzs``.
        smiles: SMILES of the molecule.
        adduct: Adduct label, e.g. ``[M+H]+``.
        save_filename: If non-empty, the assignments are written to this path
            as JSON and None is returned.

    Returns:
        The assignment dict when ``save_filename`` is empty; None when the
        result was written to disk or when fragmentation failed.
    """
    # Step 1 - Load fragmentation engine and generate fragments.
    # Parameters are looked up by name so the call does not depend on the
    # insertion order of FRAGMENT_ENGINE_PARAMS.
    try:
        engine = FragmentEngine(
            smiles=smiles,
            max_broken_bonds=FRAGMENT_ENGINE_PARAMS["max_broken_bonds"],
            max_water_losses=FRAGMENT_ENGINE_PARAMS["max_water_losses"],
            ionisation_mode=FRAGMENT_ENGINE_PARAMS["ionisation_mode"],
            skip_fragmentation=FRAGMENT_ENGINE_PARAMS["skip_fragmentation"],
            molcharge=FRAGMENT_ENGINE_PARAMS["molcharge"],
        )
        if not engine.accepted():
            raise ValueError(f"could not parse SMILES {smiles!r}")
        engine.generate_fragments()
    except Exception as e:
        # Best effort: a molecule that cannot be fragmented is skipped.
        logging.info(f"Error for spec {identifier}")
        logging.error("%s", e)
        return None

    # Step 2 - Assign fragments to peaks
    assignment_dict = peak_fragment_assignment(
        engine,
        mzs,
        intensities,
        adduct,
    )

    # Step 3 - Save assignments
    if save_filename:
        with open(save_filename, 'w') as f:
            json.dump(assignment_dict, f)
        return None
    return assignment_dict
187
+
188
def peak_fragment_assignment(fragment_engine, mzs, intensities, adduct):
    """Assign candidate fragments of ``fragment_engine`` to every peak.

    Args:
        fragment_engine: FragmentEngine on which ``generate_fragments`` has
            been called.
        mzs: np array of mz values
        intensities: Peak intensities, aligned with ``mzs``.
        adduct: str eg. [M+H]+ [M+Na]+

    Returns:
        Dict of parallel lists with the keys "mz", "intensities",
        "subformulas" and "substructures"; the last two hold, per peak, the
        sorted unique formulae / SMILES of the matching fragments.

    Raises:
        KeyError: If the adduct ion is not listed in ``ionmasses``.
    """
    fragments_info = fragment_engine.fragment_info

    fragment_df = pd.DataFrame(
        fragment_engine.fragment_info, columns=["id", "score", "bond_breaks"]
    )
    fragment_df = get_fragment_mass_range(
        fragment_engine, fragment_df, tolerance=PEAK_ASSIGNMENT_PARAMS['tolerance']
    )

    # Need to build comparison values here
    min_fragment_mass = fragment_df["min_mass"].values
    max_fragment_mass = fragment_df["max_mass"].values

    adduct_ion = magma_utils.extract_adduct_ion(adduct)
    # The sign of the ion is not the polarity of the adduct ("+Cl" belongs to
    # negative mode), so look the ion up in both polarity tables, trying the
    # sign-implied one first.
    polarities = (1, -1) if adduct_ion.startswith('+') else (-1, 1)
    for charge in polarities:
        if adduct_ion in ionmasses[charge]:
            break
    else:
        raise KeyError(f"Unsupported adduct ion {adduct_ion!r} (from {adduct!r})")

    # NOTE(review): the ion mass is *added* to the measured m/z here, exactly
    # as in the original code - confirm this is the intended direction of the
    # correction relative to the engine's mass ranges.
    exact_masses = np.asarray(mzs, dtype=float) + ionmasses[charge][adduct_ion]

    # Rows are fragments, columns are peaks.
    mass_comparison_matrix = np.logical_and(
        exact_masses[None, :] >= min_fragment_mass[:, None],
        exact_masses[None, :] <= max_fragment_mass[:, None],
    )

    # Iterate over each peak to find a match
    assignments = defaultdict(list)  # {mz, intensity, subformulas, candidates}
    for k, (m, i) in enumerate(zip(mzs, intensities)):
        mass_comparison_vector = mass_comparison_matrix[:, k]

        matched_fragment_idxs = get_matching_fragment(
            fragment_df,
            mass_comparison_vector,
            lowest_penalty_filter=PEAK_ASSIGNMENT_PARAMS['lowest_penalty_filter'],
        )

        # Save selected fragments info
        subformulas = set()
        substructures = set()
        for idx in matched_fragment_idxs:
            fragment_info = fragment_engine.get_fragment_info(fragments_info[idx][0], 0)

            subformulas.add(fragment_info[2])
            substructures.add(fragment_info[3])

        assignments['mz'].append(m)
        assignments['intensities'].append(i)
        # Sorted so that the saved output is deterministic.
        assignments['subformulas'].append(sorted(subformulas))
        assignments['substructures'].append(sorted(substructures))
    return assignments
248
+
249
+
250
if __name__ == "__main__":
    # Script entry point: read the spectra table, run MAGMa on every row in a
    # worker pool and write one JSON file per spectrum into the output dir.
    import time

    start_time = time.time()
    args = get_args()

    os.makedirs(args.output_dir, exist_ok=True)

    df = pd.read_csv(args.data_pth, sep='\t')
    # str() so that identifiers parsed as numbers can still form a file name.
    df['save_filename'] = df['identifier'].apply(
        lambda x: os.path.join(args.output_dir, str(x) + '.json')
    )

    df['mzs'] = df['mzs'].apply(lambda x: np.array([float(m) for m in x.split(',')]))
    df['intensities'] = df['intensities'].apply(
        lambda x: np.array([float(i) for i in x.split(',')])
    )

    # Column order must match the positional parameters of run_magma.
    df = df[['identifier', 'mzs', 'intensities', 'smiles', 'adduct', 'save_filename']]

    tasks = list(df.itertuples(index=False, name=None))

    with Pool(processes=args.workers) as pool:
        results = list(tqdm(pool.imap_unordered(run_magma_wrapper, tasks), total=len(tasks)))

    end_time = time.time()
    print(f"Program finished in: {end_time - start_time} seconds")