Anton Bushuiev committed on
Commit
c765e79
·
1 Parent(s): e934ea3

Major refactor and code clean-up

Browse files
Files changed (1) hide show
  1. app.py +457 -204
app.py CHANGED
@@ -1,7 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import spaces
3
  import urllib.request
4
- import torch
5
  from datetime import datetime
6
  from functools import partial
7
  import matplotlib.pyplot as plt
@@ -9,10 +19,8 @@ import matplotlib
9
  import pandas as pd
10
  import numpy as np
11
  from pathlib import Path
12
- from tqdm import tqdm
13
  from sklearn.metrics.pairwise import cosine_similarity
14
  from rdkit import Chem
15
- from rdkit.Chem import Draw
16
  from rdkit.Chem.Draw import rdMolDraw2D
17
  import base64
18
  from io import BytesIO
@@ -20,214 +28,332 @@ from PIL import Image
20
  import io
21
  import dreams.utils.spectra as su
22
  import dreams.utils.io as dio
23
- from dreams.utils.spectra import PeakListModifiedCosine
24
  from dreams.utils.data import MSData
25
  from dreams.api import dreams_embeddings
26
  from dreams.definitions import *
27
 
 
 
 
 
 
 
 
28
 
29
- def smiles_to_html_img(smiles, img_size=200):
 
 
 
 
 
 
 
 
 
 
 
 
30
  """
31
- Convert SMILES to HTML image string for display in Gradio dataframe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  """
33
  try:
 
34
  mol = Chem.MolFromSmiles(smiles)
35
  if mol is None:
36
  return f"<div style='text-align: center; color: red;'>Invalid SMILES</div>"
37
 
38
- # Use PNG drawing for better control over cropping
39
  d2d = rdMolDraw2D.MolDraw2DCairo(img_size, img_size)
40
  opts = d2d.drawOptions()
41
  opts.clearBackground = False
42
  opts.padding = 0.05 # Minimal padding
43
  opts.bondLineWidth = 2.0 # Make bonds more visible
 
 
44
  d2d.DrawMolecule(mol)
45
  d2d.FinishDrawing()
46
 
47
- # Get PNG data
48
  png_data = d2d.GetDrawingText()
49
-
50
- # Convert PNG data to PIL Image for cropping
51
  img = Image.open(io.BytesIO(png_data))
52
 
53
- # Convert to RGBA if not already
54
- if img.mode != 'RGBA':
55
- img = img.convert('RGBA')
56
-
57
- # Get the bounding box of non-transparent pixels
58
- bbox = img.getbbox()
59
- if bbox:
60
- # Crop the image to remove transparent space
61
- img = img.crop(bbox)
62
-
63
- # Convert back to base64
64
- buffered = io.BytesIO()
65
- img.save(buffered, format='PNG')
66
- img_str = base64.b64encode(buffered.getvalue())
67
- img_str = f"data:image/png;base64,{repr(img_str)[2:-1]}"
68
 
69
  return f"<img src='{img_str}' style='max-width: 100%; height: auto;' title='{smiles}' />"
 
70
  except Exception as e:
71
  return f"<div style='text-align: center; color: red;'>Error: {str(e)}</div>"
72
 
73
 
74
- def spectrum_to_html_img(spec1, spec2, img_size=1500):
75
  """
76
- Convert spectrum plot to HTML image string for display in Gradio dataframe
 
 
 
 
 
 
 
 
77
  """
78
  try:
79
- matplotlib.use('Agg') # Use non-interactive backend
 
80
 
81
- # Create the plot using the existing function
82
  su.plot_spectrum(spec=spec1, mirror_spec=spec2, figsize=(2, 1))
83
 
84
- # Save the current figure to a buffer with transparent background
85
  buffered = BytesIO()
86
  plt.savefig(buffered, format='png', bbox_inches='tight', dpi=100, transparent=True)
87
  buffered.seek(0)
88
 
89
- # Convert to PIL Image for cropping
90
  img = Image.open(buffered)
 
 
91
 
92
- # Convert to RGBA if not already
93
- if img.mode != 'RGBA':
94
- img = img.convert('RGBA')
95
-
96
- # Get the bounding box of non-transparent pixels
97
- bbox = img.getbbox()
98
- if bbox:
99
- # Crop the image to remove transparent space
100
- img = img.crop(bbox)
101
-
102
- # Convert back to base64
103
- buffered_cropped = BytesIO()
104
- img.save(buffered_cropped, format='PNG')
105
- img_str = base64.b64encode(buffered_cropped.getvalue())
106
- img_str = f"data:image/png;base64,{repr(img_str)[2:-1]}"
107
-
108
- # Close the figure to free memory
109
  plt.close()
110
 
111
  return f"<img src='{img_str}' style='max-width: 100%; height: auto;' title='Spectrum comparison' />"
 
112
  except Exception as e:
113
  return f"<div style='text-align: center; color: red;'>Error: {str(e)}</div>"
114
 
115
 
116
- def setup():
117
- # Download spectral library
118
- data_path = Path('./DreaMS/data')
119
- data_path.mkdir(parents=True, exist_ok=True)
120
- url = 'https://huggingface.co/datasets/roman-bushuiev/GeMS/resolve/main/data/auxiliary/MassSpecGym_DreaMS.hdf5'
121
- target_path = data_path / 'MassSpecGym_DreaMS.hdf5'
 
 
 
 
 
 
 
122
  if not target_path.exists():
 
 
123
  urllib.request.urlretrieve(url, target_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
- # Download example file
126
- # example_url = 'https://huggingface.co/datasets/titodamiani/PiperNET/resolve/main/lcms/rawfiles/202312_147_P55-Leaf-r2_1uL.mzML'
127
- # example_path = Path('./data/202312_147_P55-Leaf-r2_1uL.mzML')
128
- example_url = 'https://huggingface.co/datasets/roman-bushuiev/GeMS/resolve/main/data/auxiliary/example_piper_2k_spectra.mgf'
129
- example_path = Path('./data/example_piper_2k_spectra.mgf')
130
- example_path.parent.mkdir(parents=True, exist_ok=True)
131
- if not example_path.exists():
132
- urllib.request.urlretrieve(example_url, example_path)
133
-
134
- # Run simple example as a test and to download weights
135
- example_url = 'https://raw.githubusercontent.com/pluskal-lab/DreaMS/cc806fa6fea281c1e57dd81fc512f71de9290017/data/examples/example_5_spectra.mgf'
136
- example_path = Path('./data/example_5_spectra.mgf')
137
- example_path.parent.mkdir(parents=True, exist_ok=True)
138
- if not example_path.exists():
139
- urllib.request.urlretrieve(example_url, example_path)
140
- embs = dreams_embeddings(example_path)
141
- print("Setup complete")
142
 
 
 
 
143
 
144
  @spaces.GPU
145
  def _predict_gpu(in_pth, progress):
 
 
 
 
 
 
 
 
 
 
146
  progress(0.1, desc="Loading spectra data...")
147
  msdata = MSData.load(in_pth)
 
148
  progress(0.2, desc="Computing DreaMS embeddings...")
149
  embs = dreams_embeddings(msdata)
150
- print('Shape of the query embeddings:', embs.shape)
 
151
  return embs
152
 
153
 
154
- def _predict_core(lib_pth, in_pth, progress):
155
- """Core prediction function without error handling"""
156
- in_pth = Path(in_pth)
157
- # # in_pth = Path('DreaMS/data/MSV000086206/peak/mzml/S_N1.mzML') # Example dataset
158
 
159
- progress(0, desc="Loading library data...")
160
- msdata_lib = MSData.load(lib_pth)
161
- embs_lib = msdata_lib[DREAMS_EMBEDDING]
162
- print('Shape of the library embeddings:', embs_lib.shape)
 
 
 
 
 
163
 
164
- embs = _predict_gpu(in_pth, progress)
165
-
166
- progress(0.4, desc="Computing similarity matrix...")
167
- sims = cosine_similarity(embs, embs_lib)
168
- print('Shape of the similarity matrix:', sims.shape)
169
-
170
- k = 1
171
- topk_cands = np.argsort(sims, axis=1)[:, -k:][:, ::-1]
172
- topk_cands.shape
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
- # TODO This is loaded for the 2nd time here, otpimize
175
- msdata = MSData.load(in_pth)
176
- print(msdata.columns())
177
 
178
- # Construct a DataFrame with the top-k candidates for each spectrum and their corresponding similarities
179
- progress(0.5, desc="Constructing results table...")
180
- df = []
181
- cos_sim = su.PeakListModifiedCosine()
182
- total_spectra = len(topk_cands)
183
 
184
- for i, topk in enumerate(topk_cands):
185
- progress(0.5 + 0.4 * (i / total_spectra), desc=f"Processing hits for spectrum {i+1}/{total_spectra}...")
186
- for n, j in enumerate(topk):
187
- smiles = msdata_lib.get_smiles(j)
188
- spec1 = msdata.get_spectra(i)
189
- spec2 = msdata_lib.get_spectra(j)
190
- df.append({
191
- 'feature_id': i + 1,
192
- 'precursor_mz': msdata.get_prec_mzs(i),
193
- # 'RT': msdata.get_values('RTINSECONDS', i),
194
- 'topk': n + 1,
195
- 'library_j': j,
196
- 'library_SMILES': smiles_to_html_img(smiles),
197
- 'library_SMILES_raw': smiles,
198
- 'Spectrum': spectrum_to_html_img(spec1, spec2),
199
- 'Spectrum_raw': su.unpad_peak_list(spec1),
200
- 'library_ID': msdata_lib.get_values('IDENTIFIER', j),
201
- 'DreaMS_similarity': sims[i, j],
202
- 'Modified_cosine_similarity': cos_sim(
203
- spec1=spec1,
204
- prec_mz1=msdata.get_prec_mzs(i),
205
- spec2=spec2,
206
- prec_mz2=msdata_lib.get_prec_mzs(j),
207
- ),
208
- 'i': i,
209
- 'j': j,
210
- 'DreaMS_embedding': embs[i],
211
- })
212
- df = pd.DataFrame(df)
213
-
214
  # Sort hits by DreaMS similarity
215
  df_top1 = df[df['topk'] == 1].sort_values('DreaMS_similarity', ascending=False)
216
  df = df.set_index('feature_id').loc[df_top1['feature_id'].values].reset_index()
217
-
218
- progress(0.9, desc="Post-processing results...")
219
  # Remove unnecessary columns and round similarity scores
220
  df = df.drop(columns=['i', 'j', 'library_j'])
221
  df['DreaMS_similarity'] = df['DreaMS_similarity'].astype(float).round(4)
222
  df['Modified_cosine_similarity'] = df['Modified_cosine_similarity'].astype(float).round(4)
223
  df['precursor_mz'] = df['precursor_mz'].astype(float).round(4)
224
- # df['RT'] = df['RT'].round(1)
225
- df = df.rename(columns={
 
226
  'topk': 'Top k',
227
  'library_ID': 'Library ID',
228
  "feature_id": "Feature ID",
229
  "precursor_mz": "Precursor m/z",
230
- # "RT": "RT",
231
  "library_SMILES": "Molecule",
232
  "library_SMILES_raw": "SMILES",
233
  "Spectrum": "Spectrum",
@@ -235,97 +361,224 @@ def _predict_core(lib_pth, in_pth, progress):
235
  "DreaMS_similarity": "DreaMS similarity",
236
  "Modified_cosine_similarity": "Modified cos similarity",
237
  "DreaMS_embedding": "DreaMS embedding",
238
- })
239
-
240
- progress(0.95, desc="Saving results to CSV...")
241
- # Save full df to .csv
242
- df_path = dio.append_to_stem(in_pth, f"MassSpecGym_hits_{datetime.now().strftime('%Y%m%d_%H%M%S')}").with_suffix('.csv')
 
243
  df_to_save = df.drop(columns=['Molecule', 'Spectrum', 'Top k'])
244
  df_to_save.to_csv(df_path, index=False)
245
-
246
- progress(0.98, desc="Filtering and sorting results...")
247
- # Postprocess to only show most relevant hits
248
  df = df.drop(columns=['DreaMS embedding', "SMILES", "Input Spectrum"])
249
  df = df[df['Top k'] == 1].sort_values('DreaMS similarity', ascending=False)
250
  df = df.drop(columns=['Top k'])
251
- df = df[df["DreaMS similarity"] >= 0.75]
252
- # Add row numbers as first column
 
253
  df.insert(0, 'Row', range(1, len(df) + 1))
254
 
255
- progress(1.0, desc=f"Predictions complete! Found {len(df)} high-confidence matches.")
256
-
257
  return df, str(df_path)
258
 
259
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  def predict(lib_pth, in_pth, progress=gr.Progress(track_tqdm=True)):
261
- """Wrapper function with error handling"""
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  try:
 
 
 
 
 
 
 
 
263
  return _predict_core(lib_pth, in_pth, progress)
 
 
 
 
264
  except Exception as e:
265
- raise gr.Error(e)
266
-
 
 
 
 
 
 
 
 
267
 
268
- # Set up
269
- setup()
270
 
271
- # Start the Gradio app
272
- js_func = """
273
- function refresh() {
274
- const url = new URL(window.location);
275
 
276
- if (url.searchParams.get('__theme') !== 'light') {
277
- url.searchParams.set('__theme', 'light');
278
- window.location.href = url.href;
 
 
 
 
 
 
 
 
 
 
 
 
279
  }
280
- }
281
- """
282
- app = gr.Blocks(theme=gr.themes.Default(primary_hue="yellow", secondary_hue="pink"), js=js_func)
283
- with app:
284
-
285
- # Input GUI
286
- # gr.Markdown(value="""# DreaMS""")
287
- gr.Image("https://raw.githubusercontent.com/pluskal-lab/DreaMS/cc806fa6fea281c1e57dd81fc512f71de9290017/assets/dreams_background.png", label="DreaMS")
288
- gr.Markdown(value="""
289
- DreaMS (Deep Representations Empowering the Annotation of Mass Spectra) is a transformer-based
290
- neural network designed to interpret tandem mass spectrometry (MS/MS) data (<a href="https://www.nature.com/articles/s41587-025-02663-3">Bushuiev et al., Nature Biotechnology, 2025</a>).
291
- This website provides an easy access to perform library matching with DreaMS. Please upload
292
- your MS/MS file and click on the "Run DreaMS" button. Predictions may currently take up to 10 minutes for files with several thousands of spectra.
293
- """)
294
- with gr.Row(equal_height=True):
295
- in_pth = gr.File(
296
- file_count="single",
297
- label="Input MS/MS file (.mgf or .mzML)",
298
- )
299
- lib_pth = Path('DreaMS/data/MassSpecGym_DreaMS.hdf5') # MassSpecGym library
300
- examples = gr.Examples(
301
- examples=["./data/example_5_spectra.mgf", "./data/example_piper_2k_spectra.mgf"],
302
- inputs=[in_pth],
303
- label="Examples (click on a file to load as input)",
304
- )
305
-
306
- # Predict GUI
307
- predict_button = gr.Button(value="Run DreaMS", variant="primary")
308
-
309
- # Output GUI
310
- gr.Markdown("## Predictions")
311
- df_file = gr.File(label="Download predictions as .csv", interactive=False, visible=True)
312
- df = gr.Dataframe(
313
- headers=["Row", "Feature ID", "Precursor m/z", "Molecule", "Spectrum", "Library ID", "DreaMS similarity", "Modified cosine similarity"],
314
- datatype=["number", "number", "number", "html", "html", "str", "number", "number"],
315
- col_count=(8, "fixed"),
316
- # wrap=True,
317
- column_widths=["25px", "25px", "28px", "60px", "60px", "50px", "40px", "40px"],
318
- max_height=1000,
319
- show_fullscreen_button=True,
320
- show_row_numbers=False,
321
- show_search='filter',
322
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
 
324
- # Main logic
325
- inputs = [in_pth]
326
- outputs = [df, df_file]
327
- predict = partial(predict, lib_pth)
328
- predict_button.click(predict, inputs=inputs, outputs=outputs, show_progress="first")
329
 
 
 
 
330
 
331
- app.launch(allowed_paths=['./assets'])
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DreaMS Gradio Web Application
3
+
4
+ This module provides a web interface for the DreaMS (Deep Representations Empowering
5
+ the Annotation of Mass Spectra) tool using Gradio. It allows users to upload MS/MS
6
+ files and perform library matching with DreaMS embeddings.
7
+
8
+ Author: DreaMS Team
9
+ License: MIT
10
+ """
11
+
12
  import gradio as gr
13
  import spaces
14
  import urllib.request
 
15
  from datetime import datetime
16
  from functools import partial
17
  import matplotlib.pyplot as plt
 
19
  import pandas as pd
20
  import numpy as np
21
  from pathlib import Path
 
22
  from sklearn.metrics.pairwise import cosine_similarity
23
  from rdkit import Chem
 
24
  from rdkit.Chem.Draw import rdMolDraw2D
25
  import base64
26
  from io import BytesIO
 
28
  import io
29
  import dreams.utils.spectra as su
30
  import dreams.utils.io as dio
 
31
  from dreams.utils.data import MSData
32
  from dreams.api import dreams_embeddings
33
  from dreams.definitions import *
34
 
35
+ # =============================================================================
36
+ # CONSTANTS AND CONFIGURATION
37
+ # =============================================================================
38
+
39
# Default image sizes for different components
# SMILES_IMG_SIZE is the width and height passed to the RDKit Cairo drawer.
# NOTE(review): SPECTRUM_IMG_SIZE is only used as a parameter default in
# spectrum_to_html_img, which does not read it -- confirm whether it is needed.
SMILES_IMG_SIZE = 200
SPECTRUM_IMG_SIZE = 1500

# Library and data paths
# LIBRARY_PATH is derived from DATA_PATH so the directory is spelled only once
# and the two constants cannot drift apart.
DATA_PATH = Path('./DreaMS/data')
LIBRARY_PATH = DATA_PATH / 'MassSpecGym_DreaMS.hdf5'
EXAMPLE_PATH = Path('./data')

# Similarity threshold for filtering results: top-1 hits with a DreaMS
# similarity below this value are dropped from the displayed table
SIMILARITY_THRESHOLD = 0.75
50
+
51
+ # =============================================================================
52
+ # UTILITY FUNCTIONS FOR INPUT VALIDATION AND IMAGE CONVERSION
53
+ # =============================================================================
54
+
55
+ def _validate_input_file(file_path):
56
  """
57
+ Validate that the input file exists and has a supported format
58
+
59
+ Args:
60
+ file_path: Path to the input file
61
+
62
+ Returns:
63
+ bool: True if file is valid, False otherwise
64
+ """
65
+ if not file_path or not Path(file_path).exists():
66
+ return False
67
+
68
+ supported_extensions = ['.mgf', '.mzML', '.mzml']
69
+ file_ext = Path(file_path).suffix.lower()
70
+
71
+ return file_ext in supported_extensions
72
+
73
+
74
+ def _convert_pil_to_base64(img, format='PNG'):
75
+ """
76
+ Convert a PIL Image to base64 encoded string
77
+
78
+ Args:
79
+ img: PIL Image object
80
+ format: Image format (default: 'PNG')
81
+
82
+ Returns:
83
+ str: Base64 encoded image string
84
+ """
85
+ buffered = io.BytesIO()
86
+ img.save(buffered, format=format)
87
+ img_str = base64.b64encode(buffered.getvalue())
88
+ return f"data:image/{format.lower()};base64,{repr(img_str)[2:-1]}"
89
+
90
+
91
+ def _crop_transparent_edges(img):
92
+ """
93
+ Crop transparent edges from a PIL Image
94
+
95
+ Args:
96
+ img: PIL Image object (should be RGBA)
97
+
98
+ Returns:
99
+ PIL Image: Cropped image
100
+ """
101
+ # Convert to RGBA if not already
102
+ if img.mode != 'RGBA':
103
+ img = img.convert('RGBA')
104
+
105
+ # Get the bounding box of non-transparent pixels
106
+ bbox = img.getbbox()
107
+ if bbox:
108
+ # Crop the image to remove transparent space
109
+ img = img.crop(bbox)
110
+
111
+ return img
112
+
113
+
114
def smiles_to_html_img(smiles, img_size=SMILES_IMG_SIZE):
    """
    Convert SMILES string to HTML image for display in Gradio dataframe

    Args:
        smiles: SMILES string representation of molecule
        img_size: Width and height passed to the RDKit Cairo drawer
            (default: SMILES_IMG_SIZE)

    Returns:
        str: HTML img tag with a base64 encoded PNG, or a red HTML div with a
            message when the SMILES cannot be parsed or rendering fails.
            This function does not raise; failures are reported in the
            returned HTML.
    """
    # Local import: keeps this block self-contained without touching the
    # module-level import list
    from html import escape

    try:
        # Parse SMILES to RDKit molecule
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return "<div style='text-align: center; color: red;'>Invalid SMILES</div>"

        # Create PNG drawing with Cairo backend for better control
        d2d = rdMolDraw2D.MolDraw2DCairo(img_size, img_size)
        opts = d2d.drawOptions()
        opts.clearBackground = False  # leave the background unpainted so edges can be cropped
        opts.padding = 0.05  # Minimal padding
        opts.bondLineWidth = 2.0  # Make bonds more visible

        # Draw the molecule
        d2d.DrawMolecule(mol)
        d2d.FinishDrawing()

        # Get PNG data and convert to PIL Image
        png_data = d2d.GetDrawingText()
        img = Image.open(io.BytesIO(png_data))

        # Crop transparent edges and convert to base64
        img = _crop_transparent_edges(img)
        img_str = _convert_pil_to_base64(img)

        # The SMILES comes from library data and lands inside a single-quoted
        # HTML attribute, so it is escaped (quotes included)
        title = escape(str(smiles), quote=True)
        return f"<img src='{img_str}' style='max-width: 100%; height: auto;' title='{title}' />"

    except Exception as e:
        # Deliberate best effort: one molecule that fails to render must not
        # abort the whole results table. The message is escaped because it is
        # embedded in HTML.
        return f"<div style='text-align: center; color: red;'>Error: {escape(str(e))}</div>"
154
 
155
 
156
def spectrum_to_html_img(spec1, spec2, img_size=SPECTRUM_IMG_SIZE):
    """
    Convert spectrum plot to HTML image for display in Gradio dataframe

    Args:
        spec1: First spectrum data
        spec2: Second spectrum data (for mirror plot)
        img_size: Currently not used by the implementation; kept so existing
            callers keep working (default: SPECTRUM_IMG_SIZE)

    Returns:
        str: HTML img tag with base64 encoded spectrum plot, or a red HTML div
            with the error message when plotting fails. This function does
            not raise; failures are reported in the returned HTML.
    """
    # Local import: keeps this block self-contained without touching the
    # module-level import list
    from html import escape

    try:
        # Use non-interactive matplotlib backend
        matplotlib.use('Agg')

        try:
            # Create the spectrum plot using DreaMS utility function
            su.plot_spectrum(spec=spec1, mirror_spec=spec2, figsize=(2, 1))

            # Save figure to buffer with transparent background
            buffered = BytesIO()
            plt.savefig(buffered, format='png', bbox_inches='tight', dpi=100, transparent=True)
            buffered.seek(0)

            # Convert to PIL Image, crop edges, and convert to base64
            img = Image.open(buffered)
            img = _crop_transparent_edges(img)
            img_str = _convert_pil_to_base64(img)
        finally:
            # Always release the current figure, including when plotting or
            # encoding fails; otherwise figures pile up over many rows
            plt.close()

        return f"<img src='{img_str}' style='max-width: 100%; height: auto;' title='Spectrum comparison' />"

    except Exception as e:
        # Deliberate best effort: one failed plot must not abort the whole
        # results table. The message is escaped because it is embedded in HTML.
        return f"<div style='text-align: center; color: red;'>Error: {escape(str(e))}</div>"
192
 
193
 
194
+ # =============================================================================
195
+ # DATA DOWNLOAD AND SETUP FUNCTIONS
196
+ # =============================================================================
197
+
198
+ def _download_file(url, target_path, description):
199
+ """
200
+ Download a file from URL if it doesn't exist
201
+
202
+ Args:
203
+ url: Source URL
204
+ target_path: Target file path
205
+ description: Description for logging
206
+ """
207
  if not target_path.exists():
208
+ print(f"Downloading {description}...")
209
+ target_path.parent.mkdir(parents=True, exist_ok=True)
210
  urllib.request.urlretrieve(url, target_path)
211
+ print(f"Downloaded {description} to {target_path}")
212
+
213
+
214
def setup():
    """
    Prepare the application: fetch required files and run a smoke test

    Steps, in order:
        1. Download the MassSpecGym spectral library to LIBRARY_PATH
        2. Download the two example MS/MS files into EXAMPLE_PATH
        3. Compute DreaMS embeddings for the 5-spectra example as a check
           that the embedding pipeline runs

    Raises:
        Exception: Any failure in the steps above is logged and re-raised
    """
    banner = "=" * 60
    print(banner)
    print("Setting up DreaMS application...")
    print(banner)

    smoke_test_file = EXAMPLE_PATH / 'example_5_spectra.mgf'

    # (url, destination, log label) -- processed top to bottom
    downloads = (
        ('https://huggingface.co/datasets/roman-bushuiev/GeMS/resolve/main/data/auxiliary/MassSpecGym_DreaMS.hdf5',
         LIBRARY_PATH,
         "MassSpecGym spectral library"),
        ('https://huggingface.co/datasets/roman-bushuiev/GeMS/resolve/main/data/auxiliary/example_piper_2k_spectra.mgf',
         EXAMPLE_PATH / 'example_piper_2k_spectra.mgf',
         "PiperNET example spectra"),
        ('https://raw.githubusercontent.com/pluskal-lab/DreaMS/cc806fa6fea281c1e57dd81fc512f71de9290017/data/examples/example_5_spectra.mgf',
         smoke_test_file,
         "DreaMS example spectra"),
    )

    try:
        for source_url, destination, label in downloads:
            _download_file(source_url, destination, label)

        # Run the embedding pipeline once on the small example file
        print("\nTesting DreaMS embeddings...")
        embs = dreams_embeddings(smoke_test_file)
        print(f"✓ Setup complete - DreaMS embeddings test successful (shape: {embs.shape})")
        print(banner)
    except Exception as e:
        # Log a hint for the operator, then let the failure propagate
        print(f"✗ Setup failed: {e}")
        print("The application may not work properly. Please check your internet connection and try again.")
        raise
258
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
+ # =============================================================================
261
+ # CORE PREDICTION FUNCTIONS
262
+ # =============================================================================
263
 
264
@spaces.GPU
def _predict_gpu(in_pth, progress):
    """
    Load the query spectra and compute their DreaMS embeddings

    Wrapped with ``spaces.GPU``.

    Args:
        in_pth: Input file path, passed to ``MSData.load``
        progress: Gradio progress tracker; called at 0.1 and 0.2

    Returns:
        The result of ``dreams_embeddings`` for the loaded data (an object
        with a ``.shape``, presumably a numpy.ndarray -- TODO confirm)
    """
    progress(0.1, desc="Loading spectra data...")
    query_data = MSData.load(in_pth)

    progress(0.2, desc="Computing DreaMS embeddings...")
    query_embs = dreams_embeddings(query_data)

    print(f'Shape of the query embeddings: {query_embs.shape}')
    return query_embs
284
 
285
 
286
def _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs):
    """
    Create a single result row for the DataFrame

    Args:
        i: Query spectrum index (0-based)
        j: Library spectrum index
        n: Top-k rank (0-based; stored 1-based in the row)
        msdata: Query MS data
        msdata_lib: Library MS data
        sims: Similarity matrix, indexed as ``sims[i, j]``
        cos_sim: Callable computing the modified cosine similarity from
            ``spec1``, ``prec_mz1``, ``spec2`` and ``prec_mz2``
        embs: Query embeddings, indexed by query spectrum

    Returns:
        dict: Result row data (raw values plus HTML renderings of the library
            molecule and of the query/library mirror plot)
    """
    smiles = msdata_lib.get_smiles(j)
    spec1 = msdata.get_spectra(i)
    spec2 = msdata_lib.get_spectra(j)
    # Looked up once and reused for both the table column and the
    # modified-cosine computation
    prec_mz = msdata.get_prec_mzs(i)

    return {
        'feature_id': i + 1,  # 1-based id shown to the user
        'precursor_mz': prec_mz,
        'topk': n + 1,  # 1-based rank
        'library_j': j,
        'library_SMILES': smiles_to_html_img(smiles),
        'library_SMILES_raw': smiles,
        'Spectrum': spectrum_to_html_img(spec1, spec2),
        'Spectrum_raw': su.unpad_peak_list(spec1),
        'library_ID': msdata_lib.get_values('IDENTIFIER', j),
        'DreaMS_similarity': sims[i, j],
        'Modified_cosine_similarity': cos_sim(
            spec1=spec1,
            prec_mz1=prec_mz,
            spec2=spec2,
            prec_mz2=msdata_lib.get_prec_mzs(j),
        ),
        'i': i,
        'j': j,
        'DreaMS_embedding': embs[i],
    }
328
 
 
 
 
329
 
330
+ def _process_results_dataframe(df, in_pth):
331
+ """
332
+ Process and clean the results DataFrame
 
 
333
 
334
+ Args:
335
+ df: Raw results DataFrame
336
+ in_pth: Input file path for CSV export
337
+
338
+ Returns:
339
+ tuple: (processed_df, csv_path)
340
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  # Sort hits by DreaMS similarity
342
  df_top1 = df[df['topk'] == 1].sort_values('DreaMS_similarity', ascending=False)
343
  df = df.set_index('feature_id').loc[df_top1['feature_id'].values].reset_index()
344
+
 
345
  # Remove unnecessary columns and round similarity scores
346
  df = df.drop(columns=['i', 'j', 'library_j'])
347
  df['DreaMS_similarity'] = df['DreaMS_similarity'].astype(float).round(4)
348
  df['Modified_cosine_similarity'] = df['Modified_cosine_similarity'].astype(float).round(4)
349
  df['precursor_mz'] = df['precursor_mz'].astype(float).round(4)
350
+
351
+ # Rename columns for display
352
+ column_mapping = {
353
  'topk': 'Top k',
354
  'library_ID': 'Library ID',
355
  "feature_id": "Feature ID",
356
  "precursor_mz": "Precursor m/z",
 
357
  "library_SMILES": "Molecule",
358
  "library_SMILES_raw": "SMILES",
359
  "Spectrum": "Spectrum",
 
361
  "DreaMS_similarity": "DreaMS similarity",
362
  "Modified_cosine_similarity": "Modified cos similarity",
363
  "DreaMS_embedding": "DreaMS embedding",
364
+ }
365
+ df = df.rename(columns=column_mapping)
366
+
367
+ # Save full results to CSV
368
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
369
+ df_path = dio.append_to_stem(in_pth, f"MassSpecGym_hits_{timestamp}").with_suffix('.csv')
370
  df_to_save = df.drop(columns=['Molecule', 'Spectrum', 'Top k'])
371
  df_to_save.to_csv(df_path, index=False)
372
+
373
+ # Filter and prepare final display DataFrame
 
374
  df = df.drop(columns=['DreaMS embedding', "SMILES", "Input Spectrum"])
375
  df = df[df['Top k'] == 1].sort_values('DreaMS similarity', ascending=False)
376
  df = df.drop(columns=['Top k'])
377
+ df = df[df["DreaMS similarity"] >= SIMILARITY_THRESHOLD]
378
+
379
+ # Add row numbers
380
  df.insert(0, 'Row', range(1, len(df) + 1))
381
 
 
 
382
  return df, str(df_path)
383
 
384
 
385
def _predict_core(lib_pth, in_pth, progress):
    """
    Run the full library-matching pipeline for one input file

    Loads the library and its stored embeddings, embeds the query spectra,
    ranks library entries by cosine similarity of the embeddings, builds one
    result row per (query, top-k hit) pair and hands the table to
    ``_process_results_dataframe``.

    Args:
        lib_pth: Library file path
        in_pth: Input file path
        progress: Gradio progress tracker

    Returns:
        tuple: (results_dataframe, csv_file_path)
    """
    in_pth = Path(in_pth)

    # Library spectra and their precomputed embeddings
    progress(0, desc="Loading library data...")
    msdata_lib = MSData.load(lib_pth)
    embs_lib = msdata_lib[DREAMS_EMBEDDING]
    print(f'Shape of the library embeddings: {embs_lib.shape}')

    # Query embeddings
    embs = _predict_gpu(in_pth, progress)

    # Query-vs-library similarities
    progress(0.4, desc="Computing similarity matrix...")
    sims = cosine_similarity(embs, embs_lib)
    print(f'Shape of the similarity matrix: {sims.shape}')

    # Indices of the k most similar library entries per query, best first
    k = 1
    ascending_order = np.argsort(sims, axis=1)
    topk_cands = ascending_order[:, -k:][:, ::-1]

    # The query file is loaded here as well (in addition to the load inside
    # _predict_gpu) to access spectra and precursor m/z values
    msdata = MSData.load(in_pth)
    print(f'Available columns: {msdata.columns()}')

    # One row per (query spectrum, candidate) pair
    progress(0.5, desc="Constructing results table...")
    cos_sim = su.PeakListModifiedCosine()
    total_spectra = len(topk_cands)
    rows = []
    for i, topk in enumerate(topk_cands):
        progress(0.5 + 0.4 * (i / total_spectra),
                 desc=f"Processing hits for spectrum {i+1}/{total_spectra}...")
        rows.extend(
            _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs)
            for n, j in enumerate(topk)
        )
    df = pd.DataFrame(rows)

    # Sorting, renaming, CSV export and filtering
    progress(0.9, desc="Post-processing results...")
    df, csv_path = _process_results_dataframe(df, in_pth)

    progress(1.0, desc=f"Predictions complete! Found {len(df)} high-confidence matches.")
    return df, csv_path
444
+
445
+
446
def predict(lib_pth, in_pth, progress=gr.Progress(track_tqdm=True)):
    """
    Main prediction function with error handling

    Validates the inputs, runs ``_predict_core`` and converts any failure
    into a ``gr.Error`` so that it is shown in the Gradio UI.

    Args:
        lib_pth: Library file path
        in_pth: Input file path
        progress: Gradio progress tracker

    Returns:
        tuple: (results_dataframe, csv_file_path)

    Raises:
        gr.Error: If prediction fails or input is invalid; for unexpected
            failures the original exception is chained as the cause
    """
    try:
        # Validate input file
        if not _validate_input_file(in_pth):
            raise gr.Error("Invalid input file. Please provide a valid .mgf or .mzML file.")

        # Check if library exists
        if not Path(lib_pth).exists():
            raise gr.Error("Spectral library not found. Please ensure the library file exists.")

        return _predict_core(lib_pth, in_pth, progress)

    except gr.Error:
        # Re-raise Gradio errors as-is
        raise
    except Exception as e:
        detail = str(e)
        if "cuda" in detail.lower():
            # No CPU retry happens after a failure here, so the message must
            # not promise one
            error_msg = f"GPU/CUDA error: {detail}"
        elif isinstance(e, RuntimeError):
            # Decide by exception type: str(e) does not contain the class name
            error_msg = f"Runtime error: {detail}. This may be due to memory or device issues."
        else:
            error_msg = f"Error: {detail}"

        print(f"Prediction failed: {error_msg}")
        # Chain the cause so the server log keeps the original traceback
        raise gr.Error(error_msg) from e
486
 
 
 
487
 
488
+ # =============================================================================
489
+ # GRADIO INTERFACE SETUP
490
+ # =============================================================================
 
491
 
492
def _create_gradio_interface():
    """
    Build the Gradio Blocks app: header, file input, examples, run button,
    CSV download and results table, with the button wired to ``predict``

    Returns:
        gr.Blocks: Configured Gradio app (not yet launched)
    """
    # Page-load script: reloads the page with ?__theme=light unless that
    # query parameter is already set to 'light'
    js_func = """
    function refresh() {
        const url = new URL(window.location);
        if (url.searchParams.get('__theme') !== 'light') {
            url.searchParams.set('__theme', 'light');
            window.location.href = url.href;
        }
    }
    """

    # Results table layout, one entry per column: (header, datatype, width)
    table_columns = [
        ("Row", "number", "25px"),
        ("Feature ID", "number", "25px"),
        ("Precursor m/z", "number", "28px"),
        ("Molecule", "html", "60px"),
        ("Spectrum", "html", "60px"),
        ("Library ID", "str", "50px"),
        ("DreaMS similarity", "number", "40px"),
        ("Modified cosine similarity", "number", "40px"),
    ]

    theme = gr.themes.Default(primary_hue="yellow", secondary_hue="pink")
    with gr.Blocks(theme=theme, js=js_func) as app:
        # Header image and introduction
        gr.Image(
            "https://raw.githubusercontent.com/pluskal-lab/DreaMS/cc806fa6fea281c1e57dd81fc512f71de9290017/assets/dreams_background.png",
            label="DreaMS",
        )
        gr.Markdown(value="""
            DreaMS (Deep Representations Empowering the Annotation of Mass Spectra) is a transformer-based
            neural network designed to interpret tandem mass spectrometry (MS/MS) data (<a href="https://www.nature.com/articles/s41587-025-02663-3">Bushuiev et al., Nature Biotechnology, 2025</a>).
            This website provides an easy access to perform library matching with DreaMS. Please upload
            your MS/MS file and click on the "Run DreaMS" button. Predictions may currently take up to 10 minutes for files with several thousands of spectra.
        """)

        # Input: file upload plus clickable example files
        # NOTE(review): the source this was reconstructed from had lost its
        # indentation; Examples is assumed to sit inside the Row -- confirm.
        with gr.Row(equal_height=True):
            in_pth = gr.File(
                file_count="single",
                label="Input MS/MS file (.mgf or .mzML)",
            )
            gr.Examples(
                examples=["./data/example_5_spectra.mgf", "./data/example_piper_2k_spectra.mgf"],
                inputs=[in_pth],
                label="Examples (click on a file to load as input)",
            )

        # Trigger
        predict_button = gr.Button(value="Run DreaMS", variant="primary")

        # Output: CSV download and the results table
        gr.Markdown("## Predictions")
        df_file = gr.File(label="Download predictions as .csv", interactive=False, visible=True)
        df = gr.Dataframe(
            headers=[header for header, _, _ in table_columns],
            datatype=[dtype for _, dtype, _ in table_columns],
            col_count=(8, "fixed"),
            column_widths=[width for _, _, width in table_columns],
            max_height=1000,
            show_fullscreen_button=True,
            show_row_numbers=False,
            show_search='filter',
        )

        # The library path is bound up front, so the click handler only
        # receives the uploaded file
        predict_func = partial(predict, LIBRARY_PATH)
        predict_button.click(
            predict_func,
            inputs=[in_pth],
            outputs=[df, df_file],
            show_progress="first",
        )

    return app
569
 
 
 
 
 
 
570
 
571
+ # =============================================================================
572
+ # MAIN EXECUTION
573
+ # =============================================================================
574
 
575
# setup() runs both when this file is executed as a script and when it is
# imported as a module, so it is called unconditionally
setup()

if __name__ == "__main__":
    # Script execution only: build the Gradio interface and start serving
    app = _create_gradio_interface()
    app.launch(allowed_paths=['./assets'])