Anton Bushuiev committed on
Commit ·
c765e79
1
Parent(s): e934ea3
Major refactor and code clean-up
Browse files
app.py
CHANGED
|
@@ -1,7 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import spaces
|
| 3 |
import urllib.request
|
| 4 |
-
import torch
|
| 5 |
from datetime import datetime
|
| 6 |
from functools import partial
|
| 7 |
import matplotlib.pyplot as plt
|
|
@@ -9,10 +19,8 @@ import matplotlib
|
|
| 9 |
import pandas as pd
|
| 10 |
import numpy as np
|
| 11 |
from pathlib import Path
|
| 12 |
-
from tqdm import tqdm
|
| 13 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 14 |
from rdkit import Chem
|
| 15 |
-
from rdkit.Chem import Draw
|
| 16 |
from rdkit.Chem.Draw import rdMolDraw2D
|
| 17 |
import base64
|
| 18 |
from io import BytesIO
|
|
@@ -20,214 +28,332 @@ from PIL import Image
|
|
| 20 |
import io
|
| 21 |
import dreams.utils.spectra as su
|
| 22 |
import dreams.utils.io as dio
|
| 23 |
-
from dreams.utils.spectra import PeakListModifiedCosine
|
| 24 |
from dreams.utils.data import MSData
|
| 25 |
from dreams.api import dreams_embeddings
|
| 26 |
from dreams.definitions import *
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
"""
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
"""
|
| 33 |
try:
|
|
|
|
| 34 |
mol = Chem.MolFromSmiles(smiles)
|
| 35 |
if mol is None:
|
| 36 |
return f"<div style='text-align: center; color: red;'>Invalid SMILES</div>"
|
| 37 |
|
| 38 |
-
#
|
| 39 |
d2d = rdMolDraw2D.MolDraw2DCairo(img_size, img_size)
|
| 40 |
opts = d2d.drawOptions()
|
| 41 |
opts.clearBackground = False
|
| 42 |
opts.padding = 0.05 # Minimal padding
|
| 43 |
opts.bondLineWidth = 2.0 # Make bonds more visible
|
|
|
|
|
|
|
| 44 |
d2d.DrawMolecule(mol)
|
| 45 |
d2d.FinishDrawing()
|
| 46 |
|
| 47 |
-
# Get PNG data
|
| 48 |
png_data = d2d.GetDrawingText()
|
| 49 |
-
|
| 50 |
-
# Convert PNG data to PIL Image for cropping
|
| 51 |
img = Image.open(io.BytesIO(png_data))
|
| 52 |
|
| 53 |
-
#
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
# Get the bounding box of non-transparent pixels
|
| 58 |
-
bbox = img.getbbox()
|
| 59 |
-
if bbox:
|
| 60 |
-
# Crop the image to remove transparent space
|
| 61 |
-
img = img.crop(bbox)
|
| 62 |
-
|
| 63 |
-
# Convert back to base64
|
| 64 |
-
buffered = io.BytesIO()
|
| 65 |
-
img.save(buffered, format='PNG')
|
| 66 |
-
img_str = base64.b64encode(buffered.getvalue())
|
| 67 |
-
img_str = f"data:image/png;base64,{repr(img_str)[2:-1]}"
|
| 68 |
|
| 69 |
return f"<img src='{img_str}' style='max-width: 100%; height: auto;' title='{smiles}' />"
|
|
|
|
| 70 |
except Exception as e:
|
| 71 |
return f"<div style='text-align: center; color: red;'>Error: {str(e)}</div>"
|
| 72 |
|
| 73 |
|
| 74 |
-
def spectrum_to_html_img(spec1, spec2, img_size=
|
| 75 |
"""
|
| 76 |
-
Convert spectrum plot to HTML image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
"""
|
| 78 |
try:
|
| 79 |
-
|
|
|
|
| 80 |
|
| 81 |
-
# Create the plot using
|
| 82 |
su.plot_spectrum(spec=spec1, mirror_spec=spec2, figsize=(2, 1))
|
| 83 |
|
| 84 |
-
# Save
|
| 85 |
buffered = BytesIO()
|
| 86 |
plt.savefig(buffered, format='png', bbox_inches='tight', dpi=100, transparent=True)
|
| 87 |
buffered.seek(0)
|
| 88 |
|
| 89 |
-
# Convert to PIL Image
|
| 90 |
img = Image.open(buffered)
|
|
|
|
|
|
|
| 91 |
|
| 92 |
-
#
|
| 93 |
-
if img.mode != 'RGBA':
|
| 94 |
-
img = img.convert('RGBA')
|
| 95 |
-
|
| 96 |
-
# Get the bounding box of non-transparent pixels
|
| 97 |
-
bbox = img.getbbox()
|
| 98 |
-
if bbox:
|
| 99 |
-
# Crop the image to remove transparent space
|
| 100 |
-
img = img.crop(bbox)
|
| 101 |
-
|
| 102 |
-
# Convert back to base64
|
| 103 |
-
buffered_cropped = BytesIO()
|
| 104 |
-
img.save(buffered_cropped, format='PNG')
|
| 105 |
-
img_str = base64.b64encode(buffered_cropped.getvalue())
|
| 106 |
-
img_str = f"data:image/png;base64,{repr(img_str)[2:-1]}"
|
| 107 |
-
|
| 108 |
-
# Close the figure to free memory
|
| 109 |
plt.close()
|
| 110 |
|
| 111 |
return f"<img src='{img_str}' style='max-width: 100%; height: auto;' title='Spectrum comparison' />"
|
|
|
|
| 112 |
except Exception as e:
|
| 113 |
return f"<div style='text-align: center; color: red;'>Error: {str(e)}</div>"
|
| 114 |
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
if not target_path.exists():
|
|
|
|
|
|
|
| 123 |
urllib.request.urlretrieve(url, target_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
-
# Download example file
|
| 126 |
-
# example_url = 'https://huggingface.co/datasets/titodamiani/PiperNET/resolve/main/lcms/rawfiles/202312_147_P55-Leaf-r2_1uL.mzML'
|
| 127 |
-
# example_path = Path('./data/202312_147_P55-Leaf-r2_1uL.mzML')
|
| 128 |
-
example_url = 'https://huggingface.co/datasets/roman-bushuiev/GeMS/resolve/main/data/auxiliary/example_piper_2k_spectra.mgf'
|
| 129 |
-
example_path = Path('./data/example_piper_2k_spectra.mgf')
|
| 130 |
-
example_path.parent.mkdir(parents=True, exist_ok=True)
|
| 131 |
-
if not example_path.exists():
|
| 132 |
-
urllib.request.urlretrieve(example_url, example_path)
|
| 133 |
-
|
| 134 |
-
# Run simple example as a test and to download weights
|
| 135 |
-
example_url = 'https://raw.githubusercontent.com/pluskal-lab/DreaMS/cc806fa6fea281c1e57dd81fc512f71de9290017/data/examples/example_5_spectra.mgf'
|
| 136 |
-
example_path = Path('./data/example_5_spectra.mgf')
|
| 137 |
-
example_path.parent.mkdir(parents=True, exist_ok=True)
|
| 138 |
-
if not example_path.exists():
|
| 139 |
-
urllib.request.urlretrieve(example_url, example_path)
|
| 140 |
-
embs = dreams_embeddings(example_path)
|
| 141 |
-
print("Setup complete")
|
| 142 |
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
@spaces.GPU
|
| 145 |
def _predict_gpu(in_pth, progress):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
progress(0.1, desc="Loading spectra data...")
|
| 147 |
msdata = MSData.load(in_pth)
|
|
|
|
| 148 |
progress(0.2, desc="Computing DreaMS embeddings...")
|
| 149 |
embs = dreams_embeddings(msdata)
|
| 150 |
-
print('Shape of the query embeddings:
|
|
|
|
| 151 |
return embs
|
| 152 |
|
| 153 |
|
| 154 |
-
def
|
| 155 |
-
"""
|
| 156 |
-
|
| 157 |
-
# # in_pth = Path('DreaMS/data/MSV000086206/peak/mzml/S_N1.mzML') # Example dataset
|
| 158 |
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
-
# TODO This is loaded for the 2nd time here, optimize
|
| 175 |
-
msdata = MSData.load(in_pth)
|
| 176 |
-
print(msdata.columns())
|
| 177 |
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
cos_sim = su.PeakListModifiedCosine()
|
| 182 |
-
total_spectra = len(topk_cands)
|
| 183 |
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
'feature_id': i + 1,
|
| 192 |
-
'precursor_mz': msdata.get_prec_mzs(i),
|
| 193 |
-
# 'RT': msdata.get_values('RTINSECONDS', i),
|
| 194 |
-
'topk': n + 1,
|
| 195 |
-
'library_j': j,
|
| 196 |
-
'library_SMILES': smiles_to_html_img(smiles),
|
| 197 |
-
'library_SMILES_raw': smiles,
|
| 198 |
-
'Spectrum': spectrum_to_html_img(spec1, spec2),
|
| 199 |
-
'Spectrum_raw': su.unpad_peak_list(spec1),
|
| 200 |
-
'library_ID': msdata_lib.get_values('IDENTIFIER', j),
|
| 201 |
-
'DreaMS_similarity': sims[i, j],
|
| 202 |
-
'Modified_cosine_similarity': cos_sim(
|
| 203 |
-
spec1=spec1,
|
| 204 |
-
prec_mz1=msdata.get_prec_mzs(i),
|
| 205 |
-
spec2=spec2,
|
| 206 |
-
prec_mz2=msdata_lib.get_prec_mzs(j),
|
| 207 |
-
),
|
| 208 |
-
'i': i,
|
| 209 |
-
'j': j,
|
| 210 |
-
'DreaMS_embedding': embs[i],
|
| 211 |
-
})
|
| 212 |
-
df = pd.DataFrame(df)
|
| 213 |
-
|
| 214 |
# Sort hits by DreaMS similarity
|
| 215 |
df_top1 = df[df['topk'] == 1].sort_values('DreaMS_similarity', ascending=False)
|
| 216 |
df = df.set_index('feature_id').loc[df_top1['feature_id'].values].reset_index()
|
| 217 |
-
|
| 218 |
-
progress(0.9, desc="Post-processing results...")
|
| 219 |
# Remove unnecessary columns and round similarity scores
|
| 220 |
df = df.drop(columns=['i', 'j', 'library_j'])
|
| 221 |
df['DreaMS_similarity'] = df['DreaMS_similarity'].astype(float).round(4)
|
| 222 |
df['Modified_cosine_similarity'] = df['Modified_cosine_similarity'].astype(float).round(4)
|
| 223 |
df['precursor_mz'] = df['precursor_mz'].astype(float).round(4)
|
| 224 |
-
|
| 225 |
-
|
|
|
|
| 226 |
'topk': 'Top k',
|
| 227 |
'library_ID': 'Library ID',
|
| 228 |
"feature_id": "Feature ID",
|
| 229 |
"precursor_mz": "Precursor m/z",
|
| 230 |
-
# "RT": "RT",
|
| 231 |
"library_SMILES": "Molecule",
|
| 232 |
"library_SMILES_raw": "SMILES",
|
| 233 |
"Spectrum": "Spectrum",
|
|
@@ -235,97 +361,224 @@ def _predict_core(lib_pth, in_pth, progress):
|
|
| 235 |
"DreaMS_similarity": "DreaMS similarity",
|
| 236 |
"Modified_cosine_similarity": "Modified cos similarity",
|
| 237 |
"DreaMS_embedding": "DreaMS embedding",
|
| 238 |
-
}
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
# Save full
|
| 242 |
-
|
|
|
|
| 243 |
df_to_save = df.drop(columns=['Molecule', 'Spectrum', 'Top k'])
|
| 244 |
df_to_save.to_csv(df_path, index=False)
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
# Postprocess to only show most relevant hits
|
| 248 |
df = df.drop(columns=['DreaMS embedding', "SMILES", "Input Spectrum"])
|
| 249 |
df = df[df['Top k'] == 1].sort_values('DreaMS similarity', ascending=False)
|
| 250 |
df = df.drop(columns=['Top k'])
|
| 251 |
-
df = df[df["DreaMS similarity"] >=
|
| 252 |
-
|
|
|
|
| 253 |
df.insert(0, 'Row', range(1, len(df) + 1))
|
| 254 |
|
| 255 |
-
progress(1.0, desc=f"Predictions complete! Found {len(df)} high-confidence matches.")
|
| 256 |
-
|
| 257 |
return df, str(df_path)
|
| 258 |
|
| 259 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
def predict(lib_pth, in_pth, progress=gr.Progress(track_tqdm=True)):
|
| 261 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
return _predict_core(lib_pth, in_pth, progress)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
except Exception as e:
|
| 265 |
-
|
| 266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
|
| 268 |
-
# Set up
|
| 269 |
-
setup()
|
| 270 |
|
| 271 |
-
#
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
const url = new URL(window.location);
|
| 275 |
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
}
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
# gr.Markdown(value="""# DreaMS""")
|
| 287 |
-
gr.Image("https://raw.githubusercontent.com/pluskal-lab/DreaMS/cc806fa6fea281c1e57dd81fc512f71de9290017/assets/dreams_background.png", label="DreaMS")
|
| 288 |
-
gr.Markdown(value="""
|
| 289 |
-
DreaMS (Deep Representations Empowering the Annotation of Mass Spectra) is a transformer-based
|
| 290 |
-
neural network designed to interpret tandem mass spectrometry (MS/MS) data (<a href="https://www.nature.com/articles/s41587-025-02663-3">Bushuiev et al., Nature Biotechnology, 2025</a>).
|
| 291 |
-
This website provides an easy access to perform library matching with DreaMS. Please upload
|
| 292 |
-
your MS/MS file and click on the "Run DreaMS" button. Predictions may currently take up to 10 minutes for files with several thousands of spectra.
|
| 293 |
-
""")
|
| 294 |
-
with gr.Row(equal_height=True):
|
| 295 |
-
in_pth = gr.File(
|
| 296 |
-
file_count="single",
|
| 297 |
-
label="Input MS/MS file (.mgf or .mzML)",
|
| 298 |
-
)
|
| 299 |
-
lib_pth = Path('DreaMS/data/MassSpecGym_DreaMS.hdf5') # MassSpecGym library
|
| 300 |
-
examples = gr.Examples(
|
| 301 |
-
examples=["./data/example_5_spectra.mgf", "./data/example_piper_2k_spectra.mgf"],
|
| 302 |
-
inputs=[in_pth],
|
| 303 |
-
label="Examples (click on a file to load as input)",
|
| 304 |
-
)
|
| 305 |
-
|
| 306 |
-
# Predict GUI
|
| 307 |
-
predict_button = gr.Button(value="Run DreaMS", variant="primary")
|
| 308 |
-
|
| 309 |
-
# Output GUI
|
| 310 |
-
gr.Markdown("## Predictions")
|
| 311 |
-
df_file = gr.File(label="Download predictions as .csv", interactive=False, visible=True)
|
| 312 |
-
df = gr.Dataframe(
|
| 313 |
-
headers=["Row", "Feature ID", "Precursor m/z", "Molecule", "Spectrum", "Library ID", "DreaMS similarity", "Modified cosine similarity"],
|
| 314 |
-
datatype=["number", "number", "number", "html", "html", "str", "number", "number"],
|
| 315 |
-
col_count=(8, "fixed"),
|
| 316 |
-
# wrap=True,
|
| 317 |
-
column_widths=["25px", "25px", "28px", "60px", "60px", "50px", "40px", "40px"],
|
| 318 |
-
max_height=1000,
|
| 319 |
-
show_fullscreen_button=True,
|
| 320 |
-
show_row_numbers=False,
|
| 321 |
-
show_search='filter',
|
| 322 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
|
| 324 |
-
# Main logic
|
| 325 |
-
inputs = [in_pth]
|
| 326 |
-
outputs = [df, df_file]
|
| 327 |
-
predict = partial(predict, lib_pth)
|
| 328 |
-
predict_button.click(predict, inputs=inputs, outputs=outputs, show_progress="first")
|
| 329 |
|
|
|
|
|
|
|
|
|
|
| 330 |
|
| 331 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DreaMS Gradio Web Application
|
| 3 |
+
|
| 4 |
+
This module provides a web interface for the DreaMS (Deep Representations Empowering
|
| 5 |
+
the Annotation of Mass Spectra) tool using Gradio. It allows users to upload MS/MS
|
| 6 |
+
files and perform library matching with DreaMS embeddings.
|
| 7 |
+
|
| 8 |
+
Author: DreaMS Team
|
| 9 |
+
License: MIT
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
import gradio as gr
|
| 13 |
import spaces
|
| 14 |
import urllib.request
|
|
|
|
| 15 |
from datetime import datetime
|
| 16 |
from functools import partial
|
| 17 |
import matplotlib.pyplot as plt
|
|
|
|
| 19 |
import pandas as pd
|
| 20 |
import numpy as np
|
| 21 |
from pathlib import Path
|
|
|
|
| 22 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 23 |
from rdkit import Chem
|
|
|
|
| 24 |
from rdkit.Chem.Draw import rdMolDraw2D
|
| 25 |
import base64
|
| 26 |
from io import BytesIO
|
|
|
|
| 28 |
import io
|
| 29 |
import dreams.utils.spectra as su
|
| 30 |
import dreams.utils.io as dio
|
|
|
|
| 31 |
from dreams.utils.data import MSData
|
| 32 |
from dreams.api import dreams_embeddings
|
| 33 |
from dreams.definitions import *
|
| 34 |
|
| 35 |
+
# =============================================================================
# CONSTANTS AND CONFIGURATION
# =============================================================================

# Default image sizes passed to the molecule / spectrum renderers
SMILES_IMG_SIZE = 200
SPECTRUM_IMG_SIZE = 1500

# Directory layout: DreaMS data directory, the spectral library inside it,
# and the directory holding the example input files
DATA_PATH = Path('./DreaMS/data')
LIBRARY_PATH = DATA_PATH / 'MassSpecGym_DreaMS.hdf5'
EXAMPLE_PATH = Path('./data')

# DreaMS similarity cut-off used when filtering results
SIMILARITY_THRESHOLD = 0.75
|
| 50 |
+
|
| 51 |
+
# =============================================================================
|
| 52 |
+
# UTILITY FUNCTIONS FOR IMAGE CONVERSION
|
| 53 |
+
# =============================================================================
|
| 54 |
+
|
| 55 |
+
def _validate_input_file(file_path):
|
| 56 |
"""
|
| 57 |
+
Validate that the input file exists and has a supported format
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
file_path: Path to the input file
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
bool: True if file is valid, False otherwise
|
| 64 |
+
"""
|
| 65 |
+
if not file_path or not Path(file_path).exists():
|
| 66 |
+
return False
|
| 67 |
+
|
| 68 |
+
supported_extensions = ['.mgf', '.mzML', '.mzml']
|
| 69 |
+
file_ext = Path(file_path).suffix.lower()
|
| 70 |
+
|
| 71 |
+
return file_ext in supported_extensions
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _convert_pil_to_base64(img, format='PNG'):
|
| 75 |
+
"""
|
| 76 |
+
Convert a PIL Image to base64 encoded string
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
img: PIL Image object
|
| 80 |
+
format: Image format (default: 'PNG')
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
str: Base64 encoded image string
|
| 84 |
+
"""
|
| 85 |
+
buffered = io.BytesIO()
|
| 86 |
+
img.save(buffered, format=format)
|
| 87 |
+
img_str = base64.b64encode(buffered.getvalue())
|
| 88 |
+
return f"data:image/{format.lower()};base64,{repr(img_str)[2:-1]}"
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _crop_transparent_edges(img):
|
| 92 |
+
"""
|
| 93 |
+
Crop transparent edges from a PIL Image
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
img: PIL Image object (should be RGBA)
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
PIL Image: Cropped image
|
| 100 |
+
"""
|
| 101 |
+
# Convert to RGBA if not already
|
| 102 |
+
if img.mode != 'RGBA':
|
| 103 |
+
img = img.convert('RGBA')
|
| 104 |
+
|
| 105 |
+
# Get the bounding box of non-transparent pixels
|
| 106 |
+
bbox = img.getbbox()
|
| 107 |
+
if bbox:
|
| 108 |
+
# Crop the image to remove transparent space
|
| 109 |
+
img = img.crop(bbox)
|
| 110 |
+
|
| 111 |
+
return img
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def smiles_to_html_img(smiles, img_size=SMILES_IMG_SIZE):
    """
    Convert SMILES string to HTML image for display in Gradio dataframe

    Args:
        smiles: SMILES string representation of molecule
        img_size: Width and height of the drawing canvas (default: SMILES_IMG_SIZE)

    Returns:
        str: HTML img tag with a base64 encoded PNG, or a red HTML div with an
            error message when the SMILES cannot be parsed or drawing fails
    """
    import html  # local import so the block does not depend on a file-level change

    try:
        # Parse SMILES to RDKit molecule
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return "<div style='text-align: center; color: red;'>Invalid SMILES</div>"

        # Create PNG drawing with Cairo backend; background is left uncleared
        # so the empty margin can be cropped away afterwards
        d2d = rdMolDraw2D.MolDraw2DCairo(img_size, img_size)
        opts = d2d.drawOptions()
        opts.clearBackground = False
        opts.padding = 0.05  # Minimal padding
        opts.bondLineWidth = 2.0  # Make bonds more visible

        # Draw the molecule
        d2d.DrawMolecule(mol)
        d2d.FinishDrawing()

        # Get PNG data and convert to PIL Image
        png_data = d2d.GetDrawingText()
        img = Image.open(io.BytesIO(png_data))

        # Crop transparent edges and convert to base64
        img = _crop_transparent_edges(img)
        img_str = _convert_pil_to_base64(img)

        # The SMILES ends up inside a single-quoted HTML attribute, so escape it
        # (quotes included) to keep the markup well-formed
        title = html.escape(str(smiles), quote=True)
        return f"<img src='{img_str}' style='max-width: 100%; height: auto;' title='{title}' />"

    except Exception as e:
        # Best-effort rendering: report the failure in the table cell instead of
        # aborting the whole results table; the message is escaped because it is
        # inserted into HTML
        return f"<div style='text-align: center; color: red;'>Error: {html.escape(str(e))}</div>"
|
| 154 |
|
| 155 |
|
| 156 |
+
def spectrum_to_html_img(spec1, spec2, img_size=SPECTRUM_IMG_SIZE):
    """
    Convert spectrum plot to HTML image for display in Gradio dataframe

    Args:
        spec1: First spectrum data
        spec2: Second spectrum data (for mirror plot)
        img_size: Size of the output image (default: SPECTRUM_IMG_SIZE).
            NOTE(review): the body does not read this value; the figure size
            is fixed by the ``figsize``/``dpi`` arguments below.

    Returns:
        str: HTML img tag with base64 encoded spectrum plot, or a red HTML div
            with an error message when plotting fails
    """
    import html  # local import so the block does not depend on a file-level change

    try:
        # Use non-interactive matplotlib backend
        matplotlib.use('Agg')

        try:
            # Create the spectrum plot using DreaMS utility function
            # (presumably leaves the new figure as pyplot's current figure,
            # which savefig/close below operate on)
            su.plot_spectrum(spec=spec1, mirror_spec=spec2, figsize=(2, 1))

            # Save figure to buffer with transparent background
            buffered = BytesIO()
            plt.savefig(buffered, format='png', bbox_inches='tight', dpi=100, transparent=True)
            buffered.seek(0)

            # Convert to PIL Image, crop edges, and convert to base64
            img = Image.open(buffered)
            img = _crop_transparent_edges(img)
            img_str = _convert_pil_to_base64(img)
        finally:
            # Always release the figure, including when plotting or encoding
            # fails; otherwise every failing row leaks a matplotlib figure
            plt.close()

        return f"<img src='{img_str}' style='max-width: 100%; height: auto;' title='Spectrum comparison' />"

    except Exception as e:
        # Best-effort rendering: show the failure in the cell; the message is
        # escaped because it is inserted into HTML
        return f"<div style='text-align: center; color: red;'>Error: {html.escape(str(e))}</div>"
|
| 192 |
|
| 193 |
|
| 194 |
+
# =============================================================================
|
| 195 |
+
# DATA DOWNLOAD AND SETUP FUNCTIONS
|
| 196 |
+
# =============================================================================
|
| 197 |
+
|
| 198 |
+
def _download_file(url, target_path, description):
|
| 199 |
+
"""
|
| 200 |
+
Download a file from URL if it doesn't exist
|
| 201 |
+
|
| 202 |
+
Args:
|
| 203 |
+
url: Source URL
|
| 204 |
+
target_path: Target file path
|
| 205 |
+
description: Description for logging
|
| 206 |
+
"""
|
| 207 |
if not target_path.exists():
|
| 208 |
+
print(f"Downloading {description}...")
|
| 209 |
+
target_path.parent.mkdir(parents=True, exist_ok=True)
|
| 210 |
urllib.request.urlretrieve(url, target_path)
|
| 211 |
+
print(f"Downloaded {description} to {target_path}")
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def setup():
    """
    Prepare the application: fetch required files and run a smoke test

    Steps:
        - Download the MassSpecGym spectral library to LIBRARY_PATH
        - Download the two example MS/MS files into EXAMPLE_PATH
        - Compute DreaMS embeddings for the 5-spectra example as a check

    Raises:
        Exception: Any failure of the steps above is logged and re-raised
    """
    banner = "=" * 60
    print(banner)
    print("Setting up DreaMS application...")
    print(banner)

    try:
        # Spectral library
        _download_file(
            'https://huggingface.co/datasets/roman-bushuiev/GeMS/resolve/main/data/auxiliary/MassSpecGym_DreaMS.hdf5',
            LIBRARY_PATH,
            "MassSpecGym spectral library",
        )

        # Example inputs: (url, destination, log description)
        downloads = (
            (
                'https://huggingface.co/datasets/roman-bushuiev/GeMS/resolve/main/data/auxiliary/example_piper_2k_spectra.mgf',
                EXAMPLE_PATH / 'example_piper_2k_spectra.mgf',
                "PiperNET example spectra",
            ),
            (
                'https://raw.githubusercontent.com/pluskal-lab/DreaMS/cc806fa6fea281c1e57dd81fc512f71de9290017/data/examples/example_5_spectra.mgf',
                EXAMPLE_PATH / 'example_5_spectra.mgf',
                "DreaMS example spectra",
            ),
        )
        for source_url, destination, label in downloads:
            _download_file(source_url, destination, label)

        # Smoke test on the small example file
        print("\nTesting DreaMS embeddings...")
        smoke_embs = dreams_embeddings(EXAMPLE_PATH / 'example_5_spectra.mgf')
        print(f"✓ Setup complete - DreaMS embeddings test successful (shape: {smoke_embs.shape})")
        print(banner)

    except Exception as exc:
        print(f"✗ Setup failed: {exc}")
        print("The application may not work properly. Please check your internet connection and try again.")
        raise
|
| 258 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
|
| 260 |
+
# =============================================================================
|
| 261 |
+
# CORE PREDICTION FUNCTIONS
|
| 262 |
+
# =============================================================================
|
| 263 |
|
| 264 |
@spaces.GPU
def _predict_gpu(in_pth, progress):
    """
    Load the query spectra and compute their DreaMS embeddings

    Runs under the ``spaces.GPU`` decorator.

    Args:
        in_pth: Input file path, passed to ``MSData.load``
        progress: Gradio progress tracker, called as ``progress(fraction, desc=...)``

    Returns:
        The embeddings returned by ``dreams_embeddings`` (an array-like with a
        ``shape`` attribute)
    """
    progress(0.1, desc="Loading spectra data...")
    query_data = MSData.load(in_pth)

    progress(0.2, desc="Computing DreaMS embeddings...")
    query_embs = dreams_embeddings(query_data)
    print(f'Shape of the query embeddings: {query_embs.shape}')

    return query_embs
|
| 284 |
|
| 285 |
|
| 286 |
+
def _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs):
    """
    Assemble the record describing one query/library hit

    Args:
        i: Query spectrum index
        j: Library spectrum index
        n: Zero-based rank of the hit (stored as ``n + 1``)
        msdata: Query MS data
        msdata_lib: Library MS data
        sims: Similarity matrix indexed as ``sims[i, j]``
        cos_sim: Callable computing the modified cosine similarity
        embs: Query embeddings indexed by ``i``

    Returns:
        dict: Result row data
    """
    library_smiles = msdata_lib.get_smiles(j)
    query_spec = msdata.get_spectra(i)
    library_spec = msdata_lib.get_spectra(j)

    # Values are computed in the same order the row lists them
    query_prec_mz = msdata.get_prec_mzs(i)
    molecule_html = smiles_to_html_img(library_smiles)
    spectrum_html = spectrum_to_html_img(query_spec, library_spec)
    query_peaks = su.unpad_peak_list(query_spec)
    library_id = msdata_lib.get_values('IDENTIFIER', j)
    dreams_similarity = sims[i, j]
    modified_cosine = cos_sim(
        spec1=query_spec,
        prec_mz1=msdata.get_prec_mzs(i),
        spec2=library_spec,
        prec_mz2=msdata_lib.get_prec_mzs(j),
    )

    row = {}
    row['feature_id'] = i + 1
    row['precursor_mz'] = query_prec_mz
    row['topk'] = n + 1
    row['library_j'] = j
    row['library_SMILES'] = molecule_html
    row['library_SMILES_raw'] = library_smiles
    row['Spectrum'] = spectrum_html
    row['Spectrum_raw'] = query_peaks
    row['library_ID'] = library_id
    row['DreaMS_similarity'] = dreams_similarity
    row['Modified_cosine_similarity'] = modified_cosine
    row['i'] = i
    row['j'] = j
    row['DreaMS_embedding'] = embs[i]
    return row
|
| 328 |
|
|
|
|
|
|
|
|
|
|
| 329 |
|
| 330 |
+
def _process_results_dataframe(df, in_pth):
|
| 331 |
+
"""
|
| 332 |
+
Process and clean the results DataFrame
|
|
|
|
|
|
|
| 333 |
|
| 334 |
+
Args:
|
| 335 |
+
df: Raw results DataFrame
|
| 336 |
+
in_pth: Input file path for CSV export
|
| 337 |
+
|
| 338 |
+
Returns:
|
| 339 |
+
tuple: (processed_df, csv_path)
|
| 340 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
# Sort hits by DreaMS similarity
|
| 342 |
df_top1 = df[df['topk'] == 1].sort_values('DreaMS_similarity', ascending=False)
|
| 343 |
df = df.set_index('feature_id').loc[df_top1['feature_id'].values].reset_index()
|
| 344 |
+
|
|
|
|
| 345 |
# Remove unnecessary columns and round similarity scores
|
| 346 |
df = df.drop(columns=['i', 'j', 'library_j'])
|
| 347 |
df['DreaMS_similarity'] = df['DreaMS_similarity'].astype(float).round(4)
|
| 348 |
df['Modified_cosine_similarity'] = df['Modified_cosine_similarity'].astype(float).round(4)
|
| 349 |
df['precursor_mz'] = df['precursor_mz'].astype(float).round(4)
|
| 350 |
+
|
| 351 |
+
# Rename columns for display
|
| 352 |
+
column_mapping = {
|
| 353 |
'topk': 'Top k',
|
| 354 |
'library_ID': 'Library ID',
|
| 355 |
"feature_id": "Feature ID",
|
| 356 |
"precursor_mz": "Precursor m/z",
|
|
|
|
| 357 |
"library_SMILES": "Molecule",
|
| 358 |
"library_SMILES_raw": "SMILES",
|
| 359 |
"Spectrum": "Spectrum",
|
|
|
|
| 361 |
"DreaMS_similarity": "DreaMS similarity",
|
| 362 |
"Modified_cosine_similarity": "Modified cos similarity",
|
| 363 |
"DreaMS_embedding": "DreaMS embedding",
|
| 364 |
+
}
|
| 365 |
+
df = df.rename(columns=column_mapping)
|
| 366 |
+
|
| 367 |
+
# Save full results to CSV
|
| 368 |
+
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
| 369 |
+
df_path = dio.append_to_stem(in_pth, f"MassSpecGym_hits_{timestamp}").with_suffix('.csv')
|
| 370 |
df_to_save = df.drop(columns=['Molecule', 'Spectrum', 'Top k'])
|
| 371 |
df_to_save.to_csv(df_path, index=False)
|
| 372 |
+
|
| 373 |
+
# Filter and prepare final display DataFrame
|
|
|
|
| 374 |
df = df.drop(columns=['DreaMS embedding', "SMILES", "Input Spectrum"])
|
| 375 |
df = df[df['Top k'] == 1].sort_values('DreaMS similarity', ascending=False)
|
| 376 |
df = df.drop(columns=['Top k'])
|
| 377 |
+
df = df[df["DreaMS similarity"] >= SIMILARITY_THRESHOLD]
|
| 378 |
+
|
| 379 |
+
# Add row numbers
|
| 380 |
df.insert(0, 'Row', range(1, len(df) + 1))
|
| 381 |
|
|
|
|
|
|
|
| 382 |
return df, str(df_path)
|
| 383 |
|
| 384 |
|
| 385 |
+
def _predict_core(lib_pth, in_pth, progress):
    """
    Run the full library-matching pipeline for one input file

    Loads the library, embeds the query spectra, ranks library entries by
    cosine similarity of the embeddings, builds one row per hit and hands the
    table to ``_process_results_dataframe``.

    Args:
        lib_pth: Library file path
        in_pth: Input file path
        progress: Gradio progress tracker

    Returns:
        tuple: (results_dataframe, csv_file_path)
    """
    in_pth = Path(in_pth)

    # Library spectra and their stored embeddings
    progress(0, desc="Loading library data...")
    msdata_lib = MSData.load(lib_pth)
    lib_embs = msdata_lib[DREAMS_EMBEDDING]
    print(f'Shape of the library embeddings: {lib_embs.shape}')

    # Query embeddings
    query_embs = _predict_gpu(in_pth, progress)

    # Query-vs-library similarity matrix
    progress(0.4, desc="Computing similarity matrix...")
    sims = cosine_similarity(query_embs, lib_embs)
    print(f'Shape of the similarity matrix: {sims.shape}')

    # Indices of the k most similar library entries per query, best first
    k = 1
    ascending_order = np.argsort(sims, axis=1)
    topk_cands = ascending_order[:, -k:][:, ::-1]

    # Query data is needed again to build the rows
    msdata = MSData.load(in_pth)
    print(f'Available columns: {msdata.columns()}')

    # One row per (query, hit) pair
    progress(0.5, desc="Constructing results table...")
    rows = []
    cos_sim = su.PeakListModifiedCosine()
    total_spectra = len(topk_cands)

    for i, topk in enumerate(topk_cands):
        progress(0.5 + 0.4 * (i / total_spectra),
                 desc=f"Processing hits for spectrum {i+1}/{total_spectra}...")
        rows.extend(
            _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, query_embs)
            for n, j in enumerate(topk)
        )

    results = pd.DataFrame(rows)

    # Clean-up, CSV export and filtering
    progress(0.9, desc="Post-processing results...")
    results, csv_path = _process_results_dataframe(results, in_pth)

    progress(1.0, desc=f"Predictions complete! Found {len(results)} high-confidence matches.")

    return results, csv_path
|
| 444 |
+
|
| 445 |
+
|
| 446 |
def predict(lib_pth, in_pth, progress=gr.Progress(track_tqdm=True)):
    """
    Main prediction entry point with input validation and error handling.

    Checks the input file with `_validate_input_file`, verifies that the
    library path exists, and then delegates to `_predict_core`. Any failure
    is surfaced to the UI as a `gr.Error`.

    Args:
        lib_pth: Library file path
        in_pth: Input file path
        progress: Gradio progress tracker

    Returns:
        tuple: (results_dataframe, csv_file_path)

    Raises:
        gr.Error: If prediction fails or input is invalid
    """
    try:
        # Validate input file
        if not _validate_input_file(in_pth):
            raise gr.Error("Invalid input file. Please provide a valid .mgf or .mzML file.")

        # Check if library exists
        if not Path(lib_pth).exists():
            raise gr.Error("Spectral library not found. Please ensure the library file exists.")

        return _predict_core(lib_pth, in_pth, progress)

    except gr.Error:
        # Re-raise Gradio errors as-is so their user-facing message is not re-wrapped
        raise
    except Exception as e:
        error_msg = str(e)
        # The CUDA check comes first on purpose: CUDA failures are typically
        # RuntimeErrors and should get the more specific message.
        if "cuda" in error_msg.lower():
            # NOTE(review): the message mentions a CPU fallback, but this handler
            # only reports the error; confirm the fallback happens elsewhere.
            error_msg = f"GPU/CUDA error: {error_msg}. The app is falling back to CPU mode."
        elif isinstance(e, RuntimeError) or "RuntimeError" in error_msg:
            # Classify by exception type: str(e) normally does not contain the
            # class name, so a message-only check would miss real RuntimeErrors.
            error_msg = f"Runtime error: {error_msg}. This may be due to memory or device issues."
        else:
            error_msg = f"Error: {error_msg}"

        print(f"Prediction failed: {error_msg}")
        # Chain the original exception so the server-side traceback is preserved
        raise gr.Error(error_msg) from e
|
| 486 |
|
|
|
|
|
|
|
| 487 |
|
| 488 |
+
# =============================================================================
|
| 489 |
+
# GRADIO INTERFACE SETUP
|
| 490 |
+
# =============================================================================
|
|
|
|
| 491 |
|
| 492 |
+
def _create_gradio_interface():
    """
    Build the Gradio Blocks application for DreaMS library matching.

    The layout consists of a header image, a description, a file input with
    example files, a run button, a CSV download slot and a results table.
    The run button is wired to `predict` with the library path pre-bound.

    Returns:
        gr.Blocks: Configured Gradio app
    """
    # JavaScript run on page load: redirect to the same URL with __theme=light
    # unless that query parameter is already set.
    force_light_theme_js = """
    function refresh() {
        const url = new URL(window.location);
        if (url.searchParams.get('__theme') !== 'light') {
            url.searchParams.set('__theme', 'light');
            window.location.href = url.href;
        }
    }
    """

    custom_theme = gr.themes.Default(primary_hue="yellow", secondary_hue="pink")

    with gr.Blocks(theme=custom_theme, js=force_light_theme_js) as app:
        # Header and description
        gr.Image("https://raw.githubusercontent.com/pluskal-lab/DreaMS/cc806fa6fea281c1e57dd81fc512f71de9290017/assets/dreams_background.png",
                 label="DreaMS")

        gr.Markdown(value="""
        DreaMS (Deep Representations Empowering the Annotation of Mass Spectra) is a transformer-based
        neural network designed to interpret tandem mass spectrometry (MS/MS) data (<a href="https://www.nature.com/articles/s41587-025-02663-3">Bushuiev et al., Nature Biotechnology, 2025</a>).
        This website provides an easy access to perform library matching with DreaMS. Please upload
        your MS/MS file and click on the "Run DreaMS" button. Predictions may currently take up to 10 minutes for files with several thousands of spectra.
        """)

        # Input section
        with gr.Row(equal_height=True):
            input_file = gr.File(
                file_count="single",
                label="Input MS/MS file (.mgf or .mzML)",
            )

        # Example files
        # NOTE(review): placed outside the Row above — confirm the intended nesting.
        gr.Examples(
            examples=["./data/example_5_spectra.mgf", "./data/example_piper_2k_spectra.mgf"],
            inputs=[input_file],
            label="Examples (click on a file to load as input)",
        )

        # Prediction button
        run_button = gr.Button(value="Run DreaMS", variant="primary")

        # Output section
        gr.Markdown("## Predictions")
        csv_download = gr.File(label="Download predictions as .csv", interactive=False, visible=True)

        # Results table: eight fixed columns, molecule and spectrum rendered as HTML
        results_table = gr.Dataframe(
            headers=["Row", "Feature ID", "Precursor m/z", "Molecule", "Spectrum",
                     "Library ID", "DreaMS similarity", "Modified cosine similarity"],
            datatype=["number", "number", "number", "html", "html", "str", "number", "number"],
            col_count=(8, "fixed"),
            column_widths=["25px", "25px", "28px", "60px", "60px", "50px", "40px", "40px"],
            max_height=1000,
            show_fullscreen_button=True,
            show_row_numbers=False,
            show_search='filter',
        )

        # Connect prediction logic: the library path is bound up front, so the
        # click handler only receives the uploaded file.
        run_button.click(
            partial(predict, LIBRARY_PATH),
            inputs=[input_file],
            outputs=[results_table, csv_download],
            show_progress="first",
        )

    return app
|
| 569 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 570 |
|
| 571 |
+
# =============================================================================
|
| 572 |
+
# MAIN EXECUTION
|
| 573 |
+
# =============================================================================
|
| 574 |
|
| 575 |
+
# Setup runs both when executed as a script and when imported as a module.
setup()

if __name__ == "__main__":
    # Script execution only: create and launch the Gradio interface
    app = _create_gradio_interface()
    app.launch(allowed_paths=['./assets'])
|