Anton Bushuiev committed on
Commit
fc34019
·
1 Parent(s): bbf2542

First version, spectral library matching

Browse files
Files changed (1) hide show
  1. app.py +131 -50
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  import urllib.request
3
  import os
 
4
  from functools import partial
5
  import matplotlib.pyplot as plt
6
  import matplotlib
@@ -14,15 +15,17 @@ from rdkit.Chem import Draw
14
  from rdkit.Chem.Draw import rdMolDraw2D
15
  import base64
16
  from io import BytesIO
 
 
17
  import dreams.utils.spectra as su
18
- import dreams.utils.io as io
19
  from dreams.utils.spectra import PeakListModifiedCosine
20
  from dreams.utils.data import MSData
21
  from dreams.api import dreams_embeddings
22
  from dreams.definitions import *
23
 
24
 
25
- def smiles_to_html_img(smiles, svg_size=1500):
26
  """
27
  Convert SMILES to HTML image string for display in Gradio dataframe
28
  """
@@ -31,21 +34,38 @@ def smiles_to_html_img(smiles, svg_size=1500):
31
  if mol is None:
32
  return f"<div style='text-align: center; color: red;'>Invalid SMILES</div>"
33
 
34
- # Create SVG drawing
35
- d2d = rdMolDraw2D.MolDraw2DSVG(svg_size, svg_size)
36
  opts = d2d.drawOptions()
37
  opts.clearBackground = False
 
 
38
  d2d.DrawMolecule(mol)
39
  d2d.FinishDrawing()
40
- svg_str = d2d.GetDrawingText()
41
 
42
- # Convert to base64 for HTML embedding
43
- buffered = BytesIO()
44
- buffered.write(str.encode(svg_str))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  img_str = base64.b64encode(buffered.getvalue())
46
- img_str = f"data:image/svg+xml;base64,{repr(img_str)[2:-1]}"
47
 
48
- return f"<img src='{img_str}' style='width: {svg_size}px; height: {svg_size}px;' title='{smiles}' />"
49
  except Exception as e:
50
  return f"<div style='text-align: center; color: red;'>Error: {str(e)}</div>"
51
 
@@ -58,19 +78,36 @@ def spectrum_to_html_img(spec1, spec2, img_size=1500):
58
  matplotlib.use('Agg') # Use non-interactive backend
59
 
60
  # Create the plot using the existing function
61
- su.plot_spectrum(spec=spec1, mirror_spec=spec2, figsize=(8, 4))
62
 
63
- # Save the current figure to a buffer
64
  buffered = BytesIO()
65
- plt.savefig(buffered, format='png', bbox_inches='tight', dpi=100)
66
  buffered.seek(0)
67
- img_str = base64.b64encode(buffered.getvalue())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  img_str = f"data:image/png;base64,{repr(img_str)[2:-1]}"
69
 
70
  # Close the figure to free memory
71
  plt.close()
72
 
73
- return f"<img src='{img_str}' style='width: {img_size}px; height: auto;' title='Spectrum comparison' />"
74
  except Exception as e:
75
  return f"<div style='text-align: center; color: red;'>Error: {str(e)}</div>"
76
 
@@ -84,6 +121,13 @@ def setup():
84
  if not target_path.exists():
85
  urllib.request.urlretrieve(url, target_path)
86
 
 
 
 
 
 
 
 
87
  # Run simple example as a test and to download weights
88
  example_url = 'https://raw.githubusercontent.com/pluskal-lab/DreaMS/cc806fa6fea281c1e57dd81fc512f71de9290017/data/examples/example_5_spectra.mgf'
89
  example_path = Path('./data/example_5_spectra.mgf')
@@ -94,39 +138,54 @@ def setup():
94
  print("Setup complete")
95
 
96
 
97
- def predict(lib_pth, in_pth):
98
  in_pth = Path(in_pth)
99
  # # in_pth = Path('DreaMS/data/MSV000086206/peak/mzml/S_N1.mzML') # Example dataset
100
 
 
101
  msdata_lib = MSData.load(lib_pth)
102
  embs_lib = msdata_lib[DREAMS_EMBEDDING]
103
  print('Shape of the library embeddings:', embs_lib.shape)
104
 
 
105
  msdata = MSData.load(in_pth)
 
 
106
  embs = dreams_embeddings(msdata)
107
  print('Shape of the query embeddings:', embs.shape)
108
 
 
109
  sims = cosine_similarity(embs, embs_lib)
110
  print('Shape of the similarity matrix:', sims.shape)
111
 
112
- k = 5
113
  topk_cands = np.argsort(sims, axis=1)[:, -k:][:, ::-1]
114
  topk_cands.shape
115
 
 
 
116
  # Construct a DataFrame with the top-k candidates for each spectrum and their corresponding similarities
 
117
  df = []
118
  cos_sim = su.PeakListModifiedCosine()
119
- for i, topk in enumerate(tqdm(topk_cands)):
 
 
 
120
  for n, j in enumerate(topk):
121
  smiles = msdata_lib.get_smiles(j)
122
  spec1 = msdata.get_spectra(i)
123
  spec2 = msdata_lib.get_spectra(j)
124
  df.append({
125
  'feature_id': i + 1,
 
 
126
  'topk': n + 1,
127
  'library_j': j,
128
  'library_SMILES': smiles_to_html_img(smiles),
 
129
  'Spectrum': spectrum_to_html_img(spec1, spec2),
 
130
  'library_ID': msdata_lib.get_values('IDENTIFIER', j),
131
  'DreaMS_similarity': sims[i, j],
132
  'Modified_cosine_similarity': cos_sim(
@@ -137,28 +196,52 @@ def predict(lib_pth, in_pth):
137
  ),
138
  'i': i,
139
  'j': j,
 
140
  })
141
  df = pd.DataFrame(df)
142
 
143
- # # TODO Add some (random) name to the output file
144
- df_path = io.append_to_stem(in_pth, 'MassSpecGym_hits').with_suffix('.csv')
145
- df.to_csv(df_path, index=False)
146
-
147
- # i = df_top1['i'].iloc[25]
148
- # df_i = df[df['i'] == i]
149
- # for _, row in df_i.iterrows():
150
- # i, j = row['i'], row['j']
151
- # print(f'Library ID: {row["library_ID"]} (top {row["topk"]} hit)')
152
- # print(f'Query precursor m/z: {msdata.get_prec_mzs(i)}, Library precursor m/z: {msdata_lib.get_prec_mzs(j)}')
153
- # print('DreaMS similarity:', row['DreaMS_similarity'])
154
- # print('Modified cosine similarity:', row['Modified_cosine_similarity'])
155
- # su.plot_spectrum(spec=msdata.get_spectra(i), mirror_spec=msdata_lib.get_spectra(j))
156
- # display(Chem.MolFromSmiles(row['library_SMILES']))
157
-
158
  # Sort hits by DreaMS similarity
159
  df_top1 = df[df['topk'] == 1].sort_values('DreaMS_similarity', ascending=False)
160
  df = df.set_index('feature_id').loc[df_top1['feature_id'].values].reset_index()
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  return df, str(df_path)
163
 
164
 
@@ -171,23 +254,21 @@ with app:
171
  gr.Image("https://raw.githubusercontent.com/pluskal-lab/DreaMS/cc806fa6fea281c1e57dd81fc512f71de9290017/assets/dreams_background.png", label="DreaMS")
172
  gr.Markdown(value="""
173
  DreaMS (Deep Representations Empowering the Annotation of Mass Spectra) is a transformer-based
174
- neural network designed to interpret tandem mass spectrometry (MS/MS) data. Pre-trained in a
175
- self-supervised way on millions of unannotated spectra from our new GeMS (GNPS Experimental
176
- Mass Spectra) dataset, DreaMS acquires rich molecular representations by predicting masked
177
- spectral peaks and chromatographic retention orders. When fine-tuned for tasks such as spectral
178
- similarity, chemical properties prediction, and fluorine detection, DreaMS achieves state-of-the-art
179
- performance across various mass spectrometry interpretation tasks (<a href="https://www.nature.com/articles/s41587-025-02663-3">Bushuiev et al., Nature Biotechnology, 2025</a>).
180
  """)
181
  with gr.Row(equal_height=True):
182
  in_pth = gr.File(
183
  file_count="single",
184
- label=".mzML file (TODO Extend to other formats)",
185
  )
186
  lib_pth = Path('DreaMS/data/MassSpecGym_DreaMS.hdf5') # MassSpecGym library
187
  examples = gr.Examples(
188
- examples=["./data/S_N1.mzML", "./data/example_5_spectra.mgf"],
 
189
  inputs=[in_pth],
190
- label="Examples (click on a line to pre-fill the inputs)",
191
  # TODO
192
  # cache_examples=True
193
  # outputs=[df, df_file],
@@ -195,20 +276,20 @@ with app:
195
  )
196
 
197
  # Predict GUI
198
- predict_button = gr.Button(value="Run library matching", variant="primary")
199
 
200
  # Output GUI
201
  gr.Markdown("## Predictions")
202
  df_file = gr.File(label="Download predictions as .csv", interactive=False, visible=True)
203
  df = gr.Dataframe(
204
- headers=["feature_id", "topk", "library_j", "library_SMILES", "Spectrum", "library_ID", "DreaMS_similarity", "Modified_cosine_similarity", "i", "j"],
205
- datatype=["number", "number", "number", "html", "html", "str", "number", "number", "number", "number"],
206
- col_count=(10, "fixed"),
207
- wrap=True,
208
- column_widths=["80px", "60px", "80px", "400px", "800px", "120px", "120px", "150px", "60px", "60px"],
209
  max_height=1000,
210
  show_fullscreen_button=True,
211
- show_row_numbers=True,
212
  show_search='filter',
213
  # pinned_columns= # TODO
214
  )
@@ -217,7 +298,7 @@ with app:
217
  inputs = [in_pth]
218
  outputs = [df, df_file]
219
  predict = partial(predict, lib_pth)
220
- predict_button.click(predict, inputs=inputs, outputs=outputs)
221
 
222
 
223
  app.launch(allowed_paths=['./assets'])
 
1
  import gradio as gr
2
  import urllib.request
3
  import os
4
+ from datetime import datetime
5
  from functools import partial
6
  import matplotlib.pyplot as plt
7
  import matplotlib
 
15
  from rdkit.Chem.Draw import rdMolDraw2D
16
  import base64
17
  from io import BytesIO
18
+ from PIL import Image
19
+ import io
20
  import dreams.utils.spectra as su
21
+ import dreams.utils.io as dio
22
  from dreams.utils.spectra import PeakListModifiedCosine
23
  from dreams.utils.data import MSData
24
  from dreams.api import dreams_embeddings
25
  from dreams.definitions import *
26
 
27
 
28
+ def smiles_to_html_img(smiles, img_size=200):
29
  """
30
  Convert SMILES to HTML image string for display in Gradio dataframe
31
  """
 
34
  if mol is None:
35
  return f"<div style='text-align: center; color: red;'>Invalid SMILES</div>"
36
 
37
+ # Use PNG drawing for better control over cropping
38
+ d2d = rdMolDraw2D.MolDraw2DCairo(img_size, img_size)
39
  opts = d2d.drawOptions()
40
  opts.clearBackground = False
41
+ opts.padding = 0.05 # Minimal padding
42
+ opts.bondLineWidth = 2.0 # Make bonds more visible
43
  d2d.DrawMolecule(mol)
44
  d2d.FinishDrawing()
 
45
 
46
+ # Get PNG data
47
+ png_data = d2d.GetDrawingText()
48
+
49
+ # Convert PNG data to PIL Image for cropping
50
+ img = Image.open(io.BytesIO(png_data))
51
+
52
+ # Convert to RGBA if not already
53
+ if img.mode != 'RGBA':
54
+ img = img.convert('RGBA')
55
+
56
+ # Get the bounding box of non-transparent pixels
57
+ bbox = img.getbbox()
58
+ if bbox:
59
+ # Crop the image to remove transparent space
60
+ img = img.crop(bbox)
61
+
62
+ # Convert back to base64
63
+ buffered = io.BytesIO()
64
+ img.save(buffered, format='PNG')
65
  img_str = base64.b64encode(buffered.getvalue())
66
+ img_str = f"data:image/png;base64,{repr(img_str)[2:-1]}"
67
 
68
+ return f"<img src='{img_str}' style='max-width: 100%; height: auto;' title='{smiles}' />"
69
  except Exception as e:
70
  return f"<div style='text-align: center; color: red;'>Error: {str(e)}</div>"
71
 
 
78
  matplotlib.use('Agg') # Use non-interactive backend
79
 
80
  # Create the plot using the existing function
81
+ su.plot_spectrum(spec=spec1, mirror_spec=spec2, figsize=(2, 1))
82
 
83
+ # Save the current figure to a buffer with transparent background
84
  buffered = BytesIO()
85
+ plt.savefig(buffered, format='png', bbox_inches='tight', dpi=100, transparent=True)
86
  buffered.seek(0)
87
+
88
+ # Convert to PIL Image for cropping
89
+ img = Image.open(buffered)
90
+
91
+ # Convert to RGBA if not already
92
+ if img.mode != 'RGBA':
93
+ img = img.convert('RGBA')
94
+
95
+ # Get the bounding box of non-transparent pixels
96
+ bbox = img.getbbox()
97
+ if bbox:
98
+ # Crop the image to remove transparent space
99
+ img = img.crop(bbox)
100
+
101
+ # Convert back to base64
102
+ buffered_cropped = BytesIO()
103
+ img.save(buffered_cropped, format='PNG')
104
+ img_str = base64.b64encode(buffered_cropped.getvalue())
105
  img_str = f"data:image/png;base64,{repr(img_str)[2:-1]}"
106
 
107
  # Close the figure to free memory
108
  plt.close()
109
 
110
+ return f"<img src='{img_str}' style='max-width: 100%; height: auto;' title='Spectrum comparison' />"
111
  except Exception as e:
112
  return f"<div style='text-align: center; color: red;'>Error: {str(e)}</div>"
113
 
 
121
  if not target_path.exists():
122
  urllib.request.urlretrieve(url, target_path)
123
 
124
+ # Download example file
125
+ example_url = 'https://huggingface.co/datasets/titodamiani/PiperNET/resolve/main/lcms/rawfiles/202312_147_P55-Leaf-r2_1uL.mzML'
126
+ example_path = Path('./data/202312_147_P55-Leaf-r2_1uL.mzML')
127
+ example_path.parent.mkdir(parents=True, exist_ok=True)
128
+ if not example_path.exists():
129
+ urllib.request.urlretrieve(example_url, example_path)
130
+
131
  # Run simple example as a test and to download weights
132
  example_url = 'https://raw.githubusercontent.com/pluskal-lab/DreaMS/cc806fa6fea281c1e57dd81fc512f71de9290017/data/examples/example_5_spectra.mgf'
133
  example_path = Path('./data/example_5_spectra.mgf')
 
138
  print("Setup complete")
139
 
140
 
141
+ def predict(lib_pth, in_pth, progress=gr.Progress(track_tqdm=True)):
142
  in_pth = Path(in_pth)
143
  # # in_pth = Path('DreaMS/data/MSV000086206/peak/mzml/S_N1.mzML') # Example dataset
144
 
145
+ progress(0, desc="Loading library data...")
146
  msdata_lib = MSData.load(lib_pth)
147
  embs_lib = msdata_lib[DREAMS_EMBEDDING]
148
  print('Shape of the library embeddings:', embs_lib.shape)
149
 
150
+ progress(0.1, desc="Loading spectra data...")
151
  msdata = MSData.load(in_pth)
152
+
153
+ progress(0.2, desc="Computing spectra embeddings with DreaMS...")
154
  embs = dreams_embeddings(msdata)
155
  print('Shape of the query embeddings:', embs.shape)
156
 
157
+ progress(0.4, desc="Computing similarity matrix...")
158
  sims = cosine_similarity(embs, embs_lib)
159
  print('Shape of the similarity matrix:', sims.shape)
160
 
161
+ k = 1
162
  topk_cands = np.argsort(sims, axis=1)[:, -k:][:, ::-1]
163
  topk_cands.shape
164
 
165
+ print(msdata.columns())
166
+
167
  # Construct a DataFrame with the top-k candidates for each spectrum and their corresponding similarities
168
+ progress(0.5, desc="Constructing results table...")
169
  df = []
170
  cos_sim = su.PeakListModifiedCosine()
171
+ total_spectra = len(topk_cands)
172
+
173
+ for i, topk in enumerate(topk_cands):
174
+ progress(0.5 + 0.4 * (i / total_spectra), desc=f"Processing hits for spectrum {i+1}/{total_spectra}...")
175
  for n, j in enumerate(topk):
176
  smiles = msdata_lib.get_smiles(j)
177
  spec1 = msdata.get_spectra(i)
178
  spec2 = msdata_lib.get_spectra(j)
179
  df.append({
180
  'feature_id': i + 1,
181
+ 'precursor_mz': msdata.get_prec_mzs(i),
182
+ # 'RT': msdata.get_values('RTINSECONDS', i),
183
  'topk': n + 1,
184
  'library_j': j,
185
  'library_SMILES': smiles_to_html_img(smiles),
186
+ 'library_SMILES_raw': smiles,
187
  'Spectrum': spectrum_to_html_img(spec1, spec2),
188
+ 'Spectrum_raw': spec1,
189
  'library_ID': msdata_lib.get_values('IDENTIFIER', j),
190
  'DreaMS_similarity': sims[i, j],
191
  'Modified_cosine_similarity': cos_sim(
 
196
  ),
197
  'i': i,
198
  'j': j,
199
+ 'DreaMS_embedding': ' '.join(embs[i].astype(str)),
200
  })
201
  df = pd.DataFrame(df)
202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  # Sort hits by DreaMS similarity
204
  df_top1 = df[df['topk'] == 1].sort_values('DreaMS_similarity', ascending=False)
205
  df = df.set_index('feature_id').loc[df_top1['feature_id'].values].reset_index()
206
 
207
+ progress(0.9, desc="Post-processing results...")
208
+ # Remove unnecessary columns and round similarity scores
209
+ df = df.drop(columns=['i', 'j', 'library_j'])
210
+ df['DreaMS_similarity'] = df['DreaMS_similarity'].round(4)
211
+ df['Modified_cosine_similarity'] = df['Modified_cosine_similarity'].round(4)
212
+ df['precursor_mz'] = df['precursor_mz'].round(4)
213
+ # df['RT'] = df['RT'].round(1)
214
+ df = df.rename(columns={
215
+ 'topk': 'Top k',
216
+ 'library_ID': 'Library ID',
217
+ "feature_id": "Feature ID",
218
+ "precursor_mz": "Precursor m/z",
219
+ # "RT": "RT",
220
+ "library_SMILES": "Molecule",
221
+ "library_SMILES_raw": "SMILES",
222
+ "Spectrum": "Spectrum",
223
+ "Spectrum_raw": "Input Spectrum",
224
+ "DreaMS_similarity": "DreaMS similarity",
225
+ "Modified_cosine_similarity": "Modified cos similarity",
226
+ "DreaMS_embedding": "DreaMS embedding",
227
+ })
228
+
229
+ progress(0.95, desc="Saving results to CSV...")
230
+ # Save full df to .csv
231
+ df_path = dio.append_to_stem(in_pth, f"MassSpecGym_hits_{datetime.now().strftime('%Y%m%d_%H%M%S')}").with_suffix('.csv')
232
+ df.to_csv(df_path, index=False)
233
+
234
+ progress(0.98, desc="Filtering and sorting results...")
235
+ # Postprocess to only show most relevant hits
236
+ df = df.drop(columns=['DreaMS embedding', "SMILES", "Input Spectrum"])
237
+ df = df[df['Top k'] == 1].sort_values('DreaMS similarity', ascending=False)
238
+ df = df.drop(columns=['Top k'])
239
+ df = df[df["DreaMS similarity"] >= 0.75]
240
+ # Add row numbers as first column
241
+ df.insert(0, 'Row', range(len(df)))
242
+
243
+ progress(1.0, desc=f"Predictions complete! Found {len(df)} high-confidence matches.")
244
+
245
  return df, str(df_path)
246
 
247
 
 
254
  gr.Image("https://raw.githubusercontent.com/pluskal-lab/DreaMS/cc806fa6fea281c1e57dd81fc512f71de9290017/assets/dreams_background.png", label="DreaMS")
255
  gr.Markdown(value="""
256
  DreaMS (Deep Representations Empowering the Annotation of Mass Spectra) is a transformer-based
257
+ neural network designed to interpret tandem mass spectrometry (MS/MS) data (<a href="https://www.nature.com/articles/s41587-025-02663-3">Bushuiev et al., Nature Biotechnology, 2025</a>).
258
+ This website provides an easy access to perform library matching with DreaMS. Please upload
259
+ your MS/MS file and click on the "Run DreaMS" button. Predictions may currently take up to 10 minutes for files with several thousands of spectra.
 
 
 
260
  """)
261
  with gr.Row(equal_height=True):
262
  in_pth = gr.File(
263
  file_count="single",
264
+ label="Input MS/MS file (.mgf or .mzML)",
265
  )
266
  lib_pth = Path('DreaMS/data/MassSpecGym_DreaMS.hdf5') # MassSpecGym library
267
  examples = gr.Examples(
268
+ examples=["./data/example_5_spectra.mgf", "./data/202312_147_P55-Leaf-r2_1uL.mzML"],
269
+ # examples=["./data/S_N1.mzML", "./data/example_5_spectra.mgf"],
270
  inputs=[in_pth],
271
+ label="Examples (click on a file to load as input)",
272
  # TODO
273
  # cache_examples=True
274
  # outputs=[df, df_file],
 
276
  )
277
 
278
  # Predict GUI
279
+ predict_button = gr.Button(value="Run DreaMS", variant="primary")
280
 
281
  # Output GUI
282
  gr.Markdown("## Predictions")
283
  df_file = gr.File(label="Download predictions as .csv", interactive=False, visible=True)
284
  df = gr.Dataframe(
285
+ headers=["Row", "Feature ID", "Precursor m/z", "Molecule", "Spectrum", "Library ID", "DreaMS similarity", "Modified cosine similarity"],
286
+ datatype=["number", "number", "number", "html", "html", "str", "number", "number"],
287
+ col_count=(8, "fixed"),
288
+ # wrap=True,
289
+ column_widths=["25px", "25px", "28px", "60px", "60px", "50px", "40px", "40px"],
290
  max_height=1000,
291
  show_fullscreen_button=True,
292
+ show_row_numbers=False,
293
  show_search='filter',
294
  # pinned_columns= # TODO
295
  )
 
298
  inputs = [in_pth]
299
  outputs = [df, df_file]
300
  predict = partial(predict, lib_pth)
301
+ predict_button.click(predict, inputs=inputs, outputs=outputs, show_progress="first")
302
 
303
 
304
  app.launch(allowed_paths=['./assets'])