DreaMS / app.py
roman-bushuiev's picture
Fix filtering of spectra
2a8b408
"""
DreaMS Gradio Web Application
This module provides a web interface for the DreaMS (Deep Representations Empowering
the Annotation of Mass Spectra) tool using Gradio. It allows users to upload MS/MS
files and perform library matching with DreaMS embeddings.
Author: DreaMS Team
License: MIT
"""
import base64
import io
import shutil
import threading
import urllib.request
from datetime import datetime
from io import BytesIO
from pathlib import Path
from textwrap import wrap
from typing import Any, Optional, Sequence, Tuple, Union
import gradio as gr
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
import multiprocessing
import numpy as np
import os
from concurrent.futures import ProcessPoolExecutor
import pandas as pd
import spaces
from PIL import Image
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D
from tqdm import tqdm
import dreams.utils.io as dio
import dreams.utils.spectra as su
from dreams.api import DreaMSSearch, dreams_embeddings
from dreams.definitions import CHARGE, PRECURSOR_MZ, SPECTRUM, DREAMS_EMBEDDING, IONMODE
from dreams.utils.data import MSData
from dreams.utils.dformats import assign_dformat
# =============================================================================
# CONSTANTS AND CONFIGURATION
# =============================================================================
# Optimized image sizes for better performance
# Default ``img_size`` of smiles_to_html_img (used there as both canvas width and height).
SMILES_IMG_SIZE = 200
# Default ``img_size`` of the spectrum helpers; the rendering code in this file
# does not read it (figure size and DPI below control the output instead).
SPECTRUM_IMG_SIZE = 800 # Reduced from 1500 for faster generation
# Figure size and DPI used for the shared spectrum figure and the fallback renderer.
SPECTRUM_FIGSIZE = (1.6, 0.8)
SPECTRUM_DPI = 70
# Supported input formats
# Compared against the lower-cased file suffix in _validate_input_file.
SUPPORTED_INPUT_EXTENSIONS = {'.mgf', '.mzml', '.mzxml', '.hdf5'}
# Library and data paths
LIBRARY_PATH = Path("DreaMS/data/MassSpecGym_DreaMS.hdf5")
# NOTE(review): DATA_PATH is not referenced anywhere in the visible code.
DATA_PATH = Path("./DreaMS/data")
EXAMPLE_PATH = Path("./data")
# (download URL, local target path, human-readable description) triples that
# setup() passes to _download_file.
EXAMPLE_FILES: Tuple[Tuple[str, Path, str], ...] = (
    (
        "https://huggingface.co/datasets/roman-bushuiev/GeMS/resolve/main/data/auxiliary/Piper55-Leaf-r2_1uL_damiani2023.mzML",
        EXAMPLE_PATH / "Piper55-Leaf-r2_1uL_damiani2023.mzML",
        "PiperNET example spectra",
    ),
    (
        "https://huggingface.co/datasets/roman-bushuiev/GeMS/resolve/main/data/auxiliary/example_5_drugs_zhao2025.mgf",
        EXAMPLE_PATH / "example_5_drugs_zhao2025.mgf",
        "Drug analogs example spectra",
    ),
)
# Display schema of the results table, in display order. "name" is the column
# name after _rename_columns_for_display; "header", "datatype" and "width" are
# what _extract_dataframe_config hands to the Gradio dataframe. A "\n" in a
# header is rendered as a line break thanks to DATAFRAME_CSS (white-space: pre-line).
DATAFRAME_COLUMNS: Tuple[dict[str, str], ...] = (
    {"name": "Row", "header": "Row", "datatype": "number", "width": "50px"},
    {"name": "Scan number", "header": "Scan\nnumber", "datatype": "number", "width": "85px"},
    {"name": "Precursor m/z", "header": "Precursor\nm/z", "datatype": "number", "width": "130px"},
    # {"name": "Adduct", "header": "Adduct", "datatype": "str", "width": "150px"},
    {"name": "RT", "header": "RT", "datatype": "number", "width": "60px"},
    {"name": "Molecule", "header": "Molecule", "datatype": "html", "width": "150px"},
    {"name": "Name", "header": "Name", "datatype": "str", "width": "120px"},
    {"name": "Spectrum", "header": "Spectrum", "datatype": "html", "width": "150px"},
    {"name": "Ref. precursor m/z", "header": "Ref. precursor\nm/z", "datatype": "html", "width": "130px"},
    # {"name": "Ref. adduct", "header": "Ref. adduct", "datatype": "str", "width": "150px"},
    {"name": "Ref. RT", "header": "Ref.\nRT", "datatype": "number", "width": "60px"},
    {"name": "Ref. molecule", "header": "Ref. molecule", "datatype": "html", "width": "135px"},
    {"name": "Ref. name", "header": "Ref. name", "datatype": "str", "width": "150px"},
    {"name": "Ref. scan number", "header": "Ref. scan\nnumber", "datatype": "number", "width": "85px"},
    {"name": "Ref. ID", "header": "Ref.\nID", "datatype": "str", "width": "130px"},
    {"name": "DreaMS similarity", "header": "DreaMS\nsimilarity", "datatype": "number", "width": "110px"},
    {"name": "Modified cos. sim.", "header": "Modified\ncos. sim.", "datatype": "number", "width": "140px"},
)
# CSS passed to gr.Blocks; it targets the dataframe created with
# elem_id="results-dataframe" and lets multi-line headers wrap.
DATAFRAME_CSS = """
#results-dataframe {
overflow-x: auto;
}
#results-dataframe table th,
#results-dataframe table th * {
white-space: pre-line !important;
overflow-wrap: anywhere !important;
word-break: break-word !important;
text-overflow: clip;
}
"""
def _extract_dataframe_config(
    existing_columns: Optional[Sequence[str]] = None,
) -> Tuple[list[str], list[str], list[str]]:
    """Derive Gradio dataframe settings from ``DATAFRAME_COLUMNS``.

    Args:
        existing_columns: Column names to keep. When ``None`` every entry of
            ``DATAFRAME_COLUMNS`` is used; otherwise only entries whose
            ``"name"`` is contained in ``existing_columns`` are used. The
            schema order is preserved either way.

    Returns:
        Three parallel lists: headers (falling back to the column name when an
        entry has no ``"header"``), datatypes and column widths.
    """
    if existing_columns is None:
        selected = DATAFRAME_COLUMNS
    else:
        selected = [spec for spec in DATAFRAME_COLUMNS if spec["name"] in existing_columns]

    headers: list[str] = []
    datatypes: list[str] = []
    widths: list[str] = []
    for spec in selected:
        headers.append(spec.get("header", spec["name"]))
    for spec in selected:
        datatypes.append(spec["datatype"])
    for spec in selected:
        widths.append(spec["width"])
    return headers, datatypes, widths
def _build_empty_results_dataframe() -> pd.DataFrame:
    """Create a zero-row dataframe with one object-dtype column per display column."""
    empty_columns = {}
    for spec in DATAFRAME_COLUMNS:
        empty_columns[spec["name"]] = pd.Series(dtype="object")
    return pd.DataFrame(empty_columns)
# Styling for analog hits indicator rendered alongside reference precursor m/z
# Inline CSS for the "Analog hit" pill emitted by _format_ref_precursor_mz_value.
_ANALOG_TAG_STYLE = (
    "display:inline-block;padding:2px 8px;border-radius:999px;"
    "background-color:#f25d64;color:#fff;font-size:12px;font-weight:600;"
    "line-height:1;"
)
# Inline CSS for the outer container and the numeric value of the same cell.
_REF_MZ_CONTAINER_STYLE = "display:inline-flex;align-items:center;gap:6px;"
_REF_MZ_VALUE_STYLE = "font-variant-numeric:tabular-nums;font-weight:500;"
# Cache for SMILES images to avoid regeneration
# Keys are "<smiles>_<img_size>" strings, values are the rendered HTML snippets
# (see smiles_to_html_img).
_smiles_cache = {}
# Shared matplotlib figure/canvas/axes reused by _render_spectrum_image_fast;
# _spectrum_lock serializes access to them within this process.
_spectrum_lock = threading.Lock()
_spectrum_fig = plt.figure(figsize=SPECTRUM_FIGSIZE, dpi=SPECTRUM_DPI)
_spectrum_canvas = FigureCanvasAgg(_spectrum_fig)
_spectrum_ax = _spectrum_fig.add_subplot(111)
def clear_smiles_cache() -> None:
    """Empty the module-level SMILES image cache in place and log that it happened."""
    # The dict is mutated, not rebound, so no ``global`` declaration is needed.
    _smiles_cache.clear()
    print("SMILES image cache cleared")
# =============================================================================
# UTILITY FUNCTIONS FOR IMAGE CONVERSION
# =============================================================================
def _validate_input_file(file_path: Union[str, Path]) -> bool:
"""Return True when the user-supplied input file path is valid."""
if not file_path or not Path(file_path).exists():
return False
file_ext = Path(file_path).suffix.lower()
return file_ext in SUPPORTED_INPUT_EXTENSIONS
def _convert_pil_to_base64(img: Image.Image, format: str = 'PNG') -> str:
    """Serialize a PIL image into a ``data:`` URI.

    Args:
        img: Image to encode.
        format: Image format name passed to ``Image.save``; its lower-cased
            form is also used as the MIME subtype of the returned URI.

    Returns:
        A string of the form ``data:image/<format>;base64,<payload>``.
    """
    buffer = io.BytesIO()
    img.save(buffer, format=format, optimize=True)
    # Decode the base64 bytes explicitly instead of slicing their repr().
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return f"data:image/{format.lower()};base64,{encoded}"
def _crop_transparent_edges(img: Image.Image) -> Image.Image:
    """Crop fully transparent borders from a PIL image.

    The image is converted to RGBA when necessary. The bounding box is taken
    from the alpha channel only, so the result is defined purely by
    transparency and does not depend on how the installed Pillow version's
    ``Image.getbbox()`` treats the colour channels of RGBA images.

    Returns:
        The cropped RGBA image, or the (possibly converted) input image when it
        contains no non-transparent pixel.
    """
    if img.mode != 'RGBA':
        img = img.convert('RGBA')
    # Bounding box of all pixels whose alpha is non-zero; None if there are none.
    bbox = img.getchannel('A').getbbox()
    if bbox:
        img = img.crop(bbox)
    return img
def smiles_to_html_img(smiles, img_size=SMILES_IMG_SIZE):
    """
    Convert SMILES string to HTML image for display in Gradio dataframe.
    Uses caching to avoid regenerating the same molecule images.
    Ensures the molecule is shown on a white background (not transparent), but does not
    change (expand) the molecule's bounding box.

    The SMILES string and any exception text are HTML-escaped before being
    embedded in the returned markup, because SMILES values originate from
    user-supplied files.

    Args:
        smiles: SMILES string representation of molecule
        img_size: Size of the output image (default: SMILES_IMG_SIZE)
    Returns:
        str: HTML img tag with base64 encoded image, or a red error ``div`` when the
        SMILES cannot be parsed or rendering fails
    """
    from html import escape

    # Check cache first
    cache_key = f"{smiles}_{img_size}"
    cached = _smiles_cache.get(cache_key)
    if cached is not None:
        return cached

    try:
        # Parse SMILES to RDKit molecule
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            result = "<div style='text-align: center; color: red;'>Invalid SMILES</div>"
        else:
            # Create PNG drawing with Cairo backend for better control
            d2d = rdMolDraw2D.MolDraw2DCairo(img_size, img_size)
            opts = d2d.drawOptions()
            opts.clearBackground = False
            opts.padding = 0.05  # Minimal padding
            opts.bondLineWidth = 1.5  # Reduced from 2.0 for smaller images

            # Draw the molecule
            d2d.DrawMolecule(mol)
            d2d.FinishDrawing()

            # Get PNG data and convert to PIL Image (will have transparency)
            img = Image.open(io.BytesIO(d2d.GetDrawingText()))

            # Crop transparent edges FIRST (keeps molecule framing identical to original)
            img = _crop_transparent_edges(img)

            # Composite onto a white background of the same size as the cropped image
            if img.mode in ("RGBA", "LA") or (img.mode == "P" and "transparency" in img.info):
                bg = Image.new("RGBA", img.size, (255, 255, 255, 255))
                bg.paste(img, mask=img.split()[-1])  # Use alpha channel as mask
                img = bg.convert("RGB")
            else:
                img = img.convert("RGB")

            img_str = _convert_pil_to_base64(img)
            # quote=True also escapes the single quote that delimits the attribute.
            title = escape(str(smiles), quote=True)
            result = f"<img src='{img_str}' style='max-width: 100%; height: auto;' title='{title}' />"
    except Exception as e:
        result = f"<div style='text-align: center; color: red;'>Error: {escape(str(e))}</div>"

    # Cache successes and failures alike so a bad SMILES is not re-rendered
    _smiles_cache[cache_key] = result
    return result
def _format_ref_precursor_mz_value(value: Any, analog_hit: bool) -> str:
    """Build the HTML cell for a reference precursor m/z.

    The value is printed with up to four decimals (trailing zeros and a
    trailing decimal point stripped); values that cannot be converted to
    ``float`` are shown via ``str(value)``. When ``analog_hit`` is true an
    "Analog hit" tag is appended on its own line below the value. All text is
    left-aligned.
    """
    try:
        mz_text = f"{float(value):.4f}".rstrip('0').rstrip('.') or "0"
    except (TypeError, ValueError):
        mz_text = str(value)

    pieces = [
        f"<div style='{_REF_MZ_VALUE_STYLE}; width: 100%; text-align: left;'>{mz_text}</div>"
    ]
    if analog_hit:
        pieces.append(
            f"<div style='width: 100%; text-align: left; margin-top: 8px;'>"
            f"<span style='{_ANALOG_TAG_STYLE}'>Analog hit</span>"
            f"</div>"
        )
    inner_html = "".join(pieces)

    opening_tag = f"<div style='{_REF_MZ_CONTAINER_STYLE}; display: block; width: 100%; text-align: left;'>"
    return opening_tag + f"{inner_html}" + "</div>"
class _SpectrumRenderFallback(Exception):
"""Internal sentinel used to fallback to slow rendering path."""
def _render_spectrum_image_fast(spec1: Any, spec2: Any) -> Image.Image:
    """Draw ``spec1`` (mirrored against ``spec2``) on the shared figure and return it as RGBA.

    The module-level axes are cleared and reused under ``_spectrum_lock``.

    Raises:
        _SpectrumRenderFallback: If ``su.plot_spectrum`` raises ``TypeError``.
    """
    with _spectrum_lock:
        _spectrum_ax.clear()
        plot_kwargs = {
            "spec": spec1,
            "mirror_spec": spec2,
            "ax": _spectrum_ax,
            "figsize": SPECTRUM_FIGSIZE,
        }
        try:
            su.plot_spectrum(**plot_kwargs)
        except TypeError as exc:
            # Older versions of su.plot_spectrum may not accept an axis argument.
            raise _SpectrumRenderFallback from exc

        _spectrum_canvas.draw()
        n_cols, n_rows = _spectrum_canvas.get_width_height()
        # Copy the RGBA pixel buffer to avoid referencing mutable matplotlib memory
        flat_pixels = np.frombuffer(_spectrum_canvas.buffer_rgba(), dtype=np.uint8).copy()
        pixels = flat_pixels.reshape((n_rows, n_cols, 4))
        return Image.fromarray(pixels, mode='RGBA')
def _render_spectrum_png_fallback(spec1: Any, spec2: Any) -> BytesIO:
    """Fallback rendering path that mirrors the legacy behaviour.

    Lets ``su.plot_spectrum`` draw without an ``ax`` argument, saves the
    current pyplot figure as PNG and closes that figure.

    Returns:
        An in-memory PNG buffer positioned at offset 0. The caller owns it and
        must close it.

    Raises:
        Exception: Whatever ``su.plot_spectrum`` or ``savefig`` raises. The
            figure and the buffer are released before the error propagates.
    """
    su.plot_spectrum(spec=spec1, mirror_spec=spec2, figsize=SPECTRUM_FIGSIZE)
    fig = plt.gcf()
    buffer = BytesIO()
    try:
        # Make sure the figure is backed by an Agg canvas before rasterizing.
        canvas = fig.canvas
        if not isinstance(canvas, FigureCanvasAgg):
            canvas = FigureCanvasAgg(fig)
            fig.set_canvas(canvas)
        fig.savefig(
            buffer,
            format='png',
            bbox_inches='tight',
            dpi=SPECTRUM_DPI,
            transparent=False,
        )
    except Exception:
        buffer.close()
        raise
    finally:
        # Always close the figure; it used to leak when savefig failed.
        plt.close(fig)
    buffer.seek(0)
    return buffer
def _spectrum_to_html_img_single(
    spec1: Any,
    spec2: Any,
    img_size: int = SPECTRUM_IMG_SIZE,
) -> str:
    """Render one query/reference spectrum pair as an HTML ``<img>`` tag.

    The fast shared-figure renderer is tried first; if it raises
    ``_SpectrumRenderFallback`` the PNG fallback is used instead. Any other
    failure is reported as a red error ``<div>`` rather than raised.
    ``img_size`` is accepted for interface compatibility and is not read.
    """
    png_buffer: Optional[BytesIO] = None
    rendered: Optional[Image.Image] = None
    try:
        try:
            rendered = _render_spectrum_image_fast(spec1, spec2)
        except _SpectrumRenderFallback:
            png_buffer = _render_spectrum_png_fallback(spec1, spec2)
            with Image.open(png_buffer) as decoded:
                decoded.load()
                rendered = decoded.convert('RGBA')

        cropped = _crop_transparent_edges(rendered)
        try:
            data_uri = _convert_pil_to_base64(cropped)
        finally:
            # Cropping may hand back the very same object; that one is closed below.
            if cropped is not rendered:
                cropped.close()
        return f"<img src='{data_uri}' style='max-width: 100%; height: auto;' title='Spectrum comparison' />"
    except Exception as e:
        return f"<div style='text-align: center; color: red;'>Error: {str(e)}</div>"
    finally:
        if png_buffer is not None:
            png_buffer.close()
        if rendered is not None:
            rendered.close()
def spectrum_to_html_img(
    spec1: Any,
    spec2: Any,
    img_size: int = SPECTRUM_IMG_SIZE,
) -> str:
    """Public wrapper: render ``spec1`` mirrored against ``spec2`` as embeddable HTML."""
    html_snippet = _spectrum_to_html_img_single(spec1, spec2, img_size)
    return html_snippet
def _spectrum_to_html_img_worker(args: Tuple[Any, Any]) -> str:
    """Process-pool entry point: unpack a ``(query, reference)`` pair and render it."""
    query_spec, ref_spec = args
    return _spectrum_to_html_img_single(query_spec, ref_spec)
def _render_spectra_parallel(pairs: Sequence[Tuple[Any, Any]]) -> list[str]:
"""Render spectra in parallel using a process pool, falling back to sequential rendering on failure."""
total = len(pairs)
if total == 0:
return []
cpu_count = os.cpu_count() or 1
max_workers = max(1, min(cpu_count - 1, 2))
print(f"Using {max_workers} workers for parallel spectrum rendering")
ctx = multiprocessing.get_context("spawn")
env_flag = "DREAMS_SKIP_SETUP_ON_IMPORT"
previous_flag = os.environ.get(env_flag)
os.environ[env_flag] = "1"
try:
with ProcessPoolExecutor(max_workers=max_workers, mp_context=ctx) as executor:
iterator = executor.map(_spectrum_to_html_img_worker, pairs, chunksize=5)
return [
result
for result in tqdm(iterator, total=total, desc="Painting spectra", leave=False)
]
except Exception as exc:
print(f"Parallel spectrum rendering failed ({exc}); falling back to sequential mode.")
return [
_spectrum_to_html_img_single(spec1, spec2)
for spec1, spec2 in tqdm(pairs, total=total, desc="Painting spectra (fallback)", leave=False)
]
finally:
if previous_flag is None:
os.environ.pop(env_flag, None)
else:
os.environ[env_flag] = previous_flag
# =============================================================================
# DATA DOWNLOAD AND SETUP FUNCTIONS
# =============================================================================
def _download_file(url: str, target_path: Path, description: str) -> None:
"""Download a file from URL if it does not already exist."""
if not target_path.exists():
print(f"Downloading {description}...")
target_path.parent.mkdir(parents=True, exist_ok=True)
urllib.request.urlretrieve(url, target_path)
print(f"Downloaded {description} to {target_path}")
def setup() -> None:
    """Prepare the app: fetch the library and example files, then smoke-test embeddings.

    Raises:
        Exception: Any download or embedding failure is logged and re-raised.
    """
    banner = "=" * 60
    print(banner)
    print("Setting up DreaMS application...")
    print(banner)

    # Start from an empty SMILES image cache
    clear_smiles_cache()

    try:
        # Spectral library first, then every bundled example file
        _download_file(
            'https://huggingface.co/datasets/roman-bushuiev/GeMS/resolve/main/data/auxiliary/MassSpecGym_DreaMS.hdf5',
            LIBRARY_PATH,
            "MassSpecGym spectral library",
        )
        for example_url, example_path, example_desc in EXAMPLE_FILES:
            _download_file(example_url, example_path, example_desc)

        # Compute embeddings for one example file to verify the model works
        print("\nTesting DreaMS embeddings...")
        embs = dreams_embeddings(EXAMPLE_PATH / "example_5_drugs_zhao2025.mgf")
        print(f"✓ Setup complete - DreaMS embeddings test successful (shape: {embs.shape})")
        print(banner)
    except Exception as e:
        print(f"✗ Setup failed: {e}")
        print("The application may not work properly. Please check your internet connection and try again.")
        raise
# =============================================================================
# CORE PREDICTION FUNCTIONS
# =============================================================================
@spaces.GPU
def _predict_gpu(
    msdata: MSData,
    lib_msdata: MSData,
    similarity_threshold: float,
    progress: gr.Progress,
) -> pd.DataFrame:
    """Run the DreaMS library search and return its raw result.

    A ``DreaMSSearch`` is built over ``lib_msdata`` and queried with ``msdata``
    for the single best hit per spectrum (``k=1``), using
    ``similarity_threshold`` as the DreaMS similarity cut-off and requesting
    embeddings in the output. Progress is reported at 30% and 60%.
    """
    progress(0.3, desc="Initializing DreaMS search engine...")
    search_engine = DreaMSSearch(ref_spectra=lib_msdata)

    progress(0.6, desc="Preparing input spectra for search...")
    return search_engine.query(
        query_spectra=msdata,
        k=1,
        dreams_sim_thld=similarity_threshold,
        out_embs=True,
    )
def _rename_columns_for_display(df: pd.DataFrame) -> pd.DataFrame:
"""Apply human-friendly column names for presentation."""
columns = df.columns.tolist()
columns = [
c.replace('ref_', 'Ref._')
.replace('smiles', 'SMILES')
.replace('precursor_mz', 'precursor_m/z')
.replace('IDENTIFIER', 'ID')
# .replace('scan_number', 'feature_ID')
.replace('SMILES', 'molecule')
.replace('_', ' ')
for c in columns
]
def capitalize_first(s):
return s[0].upper() + s[1:] if s else s
columns = [capitalize_first(c) for c in columns]
df.columns = columns
return df
def _reformat_columns_for_display(df: pd.DataFrame) -> pd.DataFrame:
"""Format numeric columns for readability in the results table."""
for col in df.columns:
if col.endswith('mz'):
df[col] = df[col].astype(float).round(4)
elif col.endswith('rt'):
df[col] = df[col].astype(float).round(2)
elif col.endswith('similarity'):
df[col] = df[col].astype(float).round(4)
elif col.endswith('RT'):
df[col] = (df[col] / 60).round(2) # Seconds to minutes
elif col.endswith('Modified_cos._sim.'):
df[col] = df[col].astype(float).round(4)
return df
def _filter_input_data(
    pth: Path,
    only_single_charge: bool = True,
    only_high_quality_spectra: bool = True,
) -> Path:
    """Write a filtered copy of the MSData file at ``pth`` and return its path.

    Args:
        pth: Input file loaded with ``MSData.load``.
        only_single_charge: When true and the data has a ``CHARGE`` column,
            drop spectra whose charge is greater than 1 or less than -1.
        only_high_quality_spectra: When true, keep only spectra for which
            ``assign_dformat`` returns ``'A'``.

    Returns:
        Path of the filtered file: ``pth`` with its suffix replaced by
        ``.filtered.hdf5``.

    Raises:
        ValueError: If no spectrum passes the filters. This is raised before
            the filtered file is written, so no empty file is left behind.
    """
    msdata = MSData.load(pth, in_mem=True)
    print(f"Original number of rows in {pth.name}: {len(msdata)}")

    # The set of columns does not change while iterating; look it up once.
    check_charge = only_single_charge and CHARGE in msdata.columns()

    idx = []
    for i in tqdm(range(len(msdata)), desc=f"Filtering dataset {pth.name}"):
        if check_charge:
            charge = msdata.get_values(CHARGE, idx=i)
            if charge > 1 or charge < -1:  # -1 is often used for unknown charge?
                continue
        if only_high_quality_spectra:
            peak_list = su.unpad_peak_list(msdata.get_values(SPECTRUM, i))
            if assign_dformat(peak_list, msdata.get_values(PRECURSOR_MZ, i)) != 'A':
                continue
        idx.append(i)

    pth_filtered = pth.with_suffix('.filtered.hdf5')
    if not idx:
        print(f"Filtered number of rows in {pth_filtered.name}: 0")
        raise ValueError(
            "No spectra passed the quality filters. Please disable 'Only predict on high-quality input "
            "spectra' or check your input file."
        )

    msdata_filtered = msdata.form_subset(idx=idx, out_pth=pth_filtered)
    print(f"Filtered number of rows in {pth_filtered.name}: {len(msdata_filtered)}")
    return pth_filtered
def _predict_core(
    lib_pth: Union[str, Path],
    in_pth: Union[str, Path],
    similarity_threshold: float,
    calculate_modified_cosine: bool,
    progress: gr.Progress,
    only_high_quality_input: bool,
) -> Tuple[pd.DataFrame, Optional[str]]:
    """Coordinate the full library search pipeline for DreaMS predictions.

    Steps: copy library and input to temporary files, filter the input,
    run the search, optionally add modified cosine similarity, save the raw
    hits as TSV next to the input, then turn the hits into a display table
    (molecule and spectrum images, rounded numbers, analog-hit tags, display
    column names).

    Args:
        lib_pth: Spectral library file.
        in_pth: Query MS/MS file.
        similarity_threshold: DreaMS similarity cut-off forwarded to the search.
        calculate_modified_cosine: Whether to add a ``Modified_cos._sim.`` column.
        progress: Gradio progress callback.
        only_high_quality_input: Forwarded to ``_filter_input_data``.

    Returns:
        ``(display_dataframe, tsv_path)``. When the search yields nothing, an
        empty dataframe with the display schema and ``None``.

    Every temporary file created here (library copy, input copy and the
    filtered input) is removed on exit, whether the run succeeds or fails.
    """
    in_pth = Path(in_pth)
    lib_pth = Path(lib_pth)

    # Clear cache at start to prevent memory buildup
    clear_smiles_cache()

    # Create temporary copies of library and input files to allow multiple processes.
    # Microsecond resolution keeps requests that start within the same second
    # from sharing (and deleting) each other's temporary files.
    progress(0, desc="Creating temporary file copies...")
    temp_tag = f"temp_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"
    temp_lib_path = dio.append_to_stem(lib_pth, temp_tag)
    temp_in_path = dio.append_to_stem(in_pth, temp_tag)
    temp_paths = [temp_lib_path, temp_in_path]

    try:
        shutil.copy2(lib_pth, temp_lib_path)
        shutil.copy2(in_pth, temp_in_path)

        filtered_in_path = _filter_input_data(
            temp_in_path,
            only_single_charge=True,
            only_high_quality_spectra=only_high_quality_input,
        )
        # Track the filtered file separately so the unfiltered copy is cleaned up too
        temp_paths.append(filtered_in_path)

        df = _predict_gpu(filtered_in_path, temp_lib_path, similarity_threshold, progress)

        if df is None or (hasattr(df, "empty") and df.empty):
            progress(1.0, desc="No matches found.")
            return _build_empty_results_dataframe(), None

        # Add modified cosine similarity only if enabled
        if calculate_modified_cosine:
            cos_sims = []
            modified_cosine_sim = su.PeakListModifiedCosine()
            for i in tqdm(range(len(df)), desc="Calculating modified cosine similarity"):
                cos_sims.append(modified_cosine_sim(
                    spec1=df[SPECTRUM].iloc[i],
                    prec_mz1=df[PRECURSOR_MZ].iloc[i],
                    spec2=df[f'ref_{SPECTRUM}'].iloc[i],
                    prec_mz2=df[f'ref_{PRECURSOR_MZ}'].iloc[i],
                ))
            df['Modified_cos._sim.'] = cos_sims

        # Add row number for display
        if 'Row' not in df.columns:
            df.insert(0, 'Row', list(range(1, len(df) + 1)))

        # A hit is an analog when query and reference precursor m/z differ by >= 0.01
        df['analog_hit'] = (df[PRECURSOR_MZ] - df[f'ref_{PRECURSOR_MZ}']).round(2).abs() >= 0.01

        # Store results to TSV
        progress(0.7, desc="Saving results to TSV...")
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        df_path = dio.append_to_stem(in_pth, f"{lib_pth.stem}_hits_{timestamp}").with_suffix('.tsv')

        # Convert spectrum to lists before saving to TSV
        df_to_save = df.copy()
        for col in df_to_save.columns:
            if col.endswith(SPECTRUM):
                df_to_save[col] = df_to_save[col].apply(lambda x: su.unpad_peak_list(x).tolist())
        df_to_save.to_csv(df_path, index=False, sep='\t')

        # Warn once if any ion-mode column contains negative-mode spectra
        for col in df_to_save.columns:
            if col.endswith(IONMODE):
                if '-' in df_to_save[col].tolist():
                    gr.Warning(
                        "Negative mode spectra found. Please note that the current version of DreaMS was "
                        "trained on positive mode spectra only. This may lead to unexpected results.",
                        duration=30
                    )
                    break

        # Subsequent code is performed after saving to TSV, for display dataframe only
        progress(0.85, desc="Painting molecules...")
        for col in df.columns:
            if col.endswith('smiles'):
                rendered_smiles = []
                for idx, smiles in tqdm(enumerate(df[col], start=1), desc="Painting molecules", total=len(df[col])):
                    rendered_smiles.append(smiles_to_html_img(smiles))
                    # Bound the cache size while painting large result sets
                    if idx % 100 == 0:
                        clear_smiles_cache()
                df[col] = rendered_smiles

        progress(0.9, desc="Painting spectra...")
        spectrum_pairs = list(zip(df[SPECTRUM], df[f'ref_{SPECTRUM}']))
        if len(spectrum_pairs) > 1000:
            df[SPECTRUM] = _render_spectra_parallel(spectrum_pairs)
        else:
            df[SPECTRUM] = [
                spectrum_to_html_img(query, ref)
                for query, ref in tqdm(spectrum_pairs, desc="Painting spectra", total=len(spectrum_pairs))
            ]

        print('Columns:')
        print(df.columns)
        df = _reformat_columns_for_display(df)

        # Replace the reference precursor m/z by its HTML form (value plus optional
        # "Analog hit" tag); the helper column is only needed for this step.
        analog_flags = df['analog_hit'].fillna(False).astype(bool).tolist()
        ref_precursor_values = df[f'ref_{PRECURSOR_MZ}'].tolist()
        df[f'ref_{PRECURSOR_MZ}'] = [
            _format_ref_precursor_mz_value(value, is_analog)
            for value, is_analog in zip(ref_precursor_values, analog_flags)
        ]
        df = df.drop(columns=['analog_hit'])

        df = _rename_columns_for_display(df)
        print('Renamed columns:')
        print(df.columns)

        # Keep only the columns of the display schema, in schema order
        df = df[[c['name'] for c in DATAFRAME_COLUMNS if c['name'] in df.columns]]

        progress(1.0, desc=f"Predictions complete! Found {len(df)} high-confidence matches.")
        return df, str(df_path)
    finally:
        # Clean up temporary files
        for temp_path in temp_paths:
            if temp_path.exists():
                temp_path.unlink()
def predict(
    lib_pth: Union[str, Path],
    in_pth: Union[str, Path],
    similarity_threshold: float = 0.75,
    calculate_modified_cosine: bool = False,
    only_high_quality_input: bool = True,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> Tuple[Any, Any]:
    """Main prediction entry point with user-facing error handling.

    Validates the inputs, runs ``_predict_core`` and converts its result into
    Gradio updates for the results table and the download component.

    Returns:
        ``(dataframe_update, file_update)``. The file component is hidden when
        there is no TSV to download.

    Raises:
        gr.Error: For invalid inputs and for any failure in the pipeline; the
            original exception is chained as the cause.
    """
    try:
        # Validate input file
        if not _validate_input_file(in_pth):
            raise gr.Error("Invalid input file. Please provide a valid .mgf, .mzML, .mzXML, or .hdf5 file.")

        # Check if library exists (lib_pth is None when the library input was cleared)
        if not lib_pth or not Path(lib_pth).exists():
            raise gr.Error("Spectral library not found. Please ensure the library file exists.")

        df_raw, csv_path = _predict_core(
            lib_pth,
            in_pth,
            similarity_threshold,
            calculate_modified_cosine,
            progress,
            only_high_quality_input,
        )

        # Configure the table for exactly the columns that are present
        headers, datatype, column_widths = _extract_dataframe_config(df_raw.columns)
        df = gr.update(
            value=df_raw,
            headers=headers,
            datatype=datatype,
            column_widths=column_widths,
            col_count=(len(headers), "fixed"),
        )

        if isinstance(df_raw, pd.DataFrame) and df_raw.empty:
            gr.Info("No matches were found. Consider lowering the DreaMS similarity threshold for finding analog matches or checking your input file.")

        if csv_path:
            file_update = gr.update(value=csv_path, visible=True, interactive=False)
        else:
            file_update = gr.update(value=None, visible=False, interactive=False)
        return df, file_update

    except gr.Error:
        # Re-raise Gradio errors as-is
        raise
    except Exception as e:
        error_msg = str(e)
        if "cuda" in error_msg.lower():
            error_msg = f"GPU/CUDA error: {error_msg}. The app is falling back to CPU mode."
        elif isinstance(e, RuntimeError) or "RuntimeError" in error_msg:
            # str(e) rarely contains the class name, so check the type as well.
            error_msg = f"Runtime error: {error_msg}. This may be due to memory or device issues."
        else:
            error_msg = f"Error: {error_msg}"
        print(f"Prediction failed: {error_msg}")
        raise gr.Error(error_msg) from e
# =============================================================================
# GRADIO INTERFACE SETUP
# =============================================================================
def _create_gradio_interface() -> gr.Blocks:
    """Create and configure the Gradio Blocks interface.

    Components are created in display order: header image and description,
    input file with examples, a collapsed settings accordion, the run button,
    the download component and the results table. The run button is wired to
    ``predict``.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    # Page-load script: reloads the page with "__theme=light" in the query
    # string unless that parameter is already set to "light".
    js_func = """
() => {
const url = new URL(window.location.href);
if (url.searchParams.get("__theme") !== "light") {
url.searchParams.set("__theme", "light");
window.location.replace(url.toString());
}
}
"""
    # Create app with custom theme
    app = gr.Blocks(
        theme=gr.themes.Default(primary_hue="yellow", secondary_hue="pink"),
        js=js_func,
        css=DATAFRAME_CSS,
    )
    with app:
        # Header and description
        gr.Image(
            "https://raw.githubusercontent.com/pluskal-lab/DreaMS/cc806fa6fea281c1e57dd81fc512f71de9290017/assets/dreams_background.png",
            label="DreaMS"
        )
        gr.Markdown(
            value=(
                "DreaMS (Deep Representations Empowering the Annotation of Mass Spectra) is a "
                "transformer-based neural network designed to interpret tandem mass spectrometry (MS/MS) "
                "data (<a href=\"https://www.nature.com/articles/s41587-025-02663-3\">Bushuiev et al., Nature Biotechnology, 2025</a>). "
                "This website provides an easy access to perform spectral library searches to identify small molecules or their analogue candidates by querying "
                "<a href=\"https://huggingface.co/datasets/roman-bushuiev/MassSpecGym\">MassSpecGym</a> spectral library (combination of GNPS, MoNA, and Pluskal lab data) "
                "or custom MS/MS datasets. "
                "Please upload your file with MS/MS data and click on the \"Run DreaMS\" button. In case of any issues, questions, or feedback, "
                "please don't hesitate to open an issue on the <a href=\"https://github.com/pluskal-lab/DreaMS/issues\">DreaMS GitHub</a> page."
            )
        )
        # Input section
        with gr.Row(equal_height=True):
            in_pth = gr.File(
                file_count="single",
                label="Input MS/MS file (.mgf, .mzML, .mzXML, .hdf5)",
            )
        # Example files
        # These are the local targets of the EXAMPLE_FILES downloads done in setup().
        # NOTE(review): the ``examples`` handle is not used afterwards.
        examples = gr.Examples(
            examples=[
                "./data/example_5_drugs_zhao2025.mgf",
                "./data/Piper55-Leaf-r2_1uL_damiani2023.mzML"
            ],
            inputs=[in_pth],
            label="Examples (click on a file to load as input)",
        )
        # Settings section
        with gr.Accordion("⚙️ Settings", open=False):
            # Pre-filled with the bundled library; the user may replace it.
            lib_pth = gr.File(
                file_count="single",
                label="Reference MS/MS file or spectral library (.mgf, .mzML, .mzXML, .hdf5)",
                value=str(LIBRARY_PATH),
                interactive=True,
                visible=True,
            )
            # UI default is 0.8; predict()'s own default (0.75) only applies
            # when it is called without this input.
            similarity_threshold = gr.Slider(
                minimum=0.5,
                maximum=1.0,
                value=0.8,
                step=0.01,
                label="Similarity threshold",
                info=(
                    "Only display library matches with DreaMS similarity above this threshold "
                    "(rendering less results also makes calculation faster)"
                ),
            )
            calculate_modified_cosine = gr.Checkbox(
                label="Calculate modified cosine similarity",
                value=False,
                info=(
                    "Enable to also calculate traditional modified cosine similarity scores between "
                    "the input spectra and library hits (a bit slower)"
                ),
            )
            only_high_quality_input = gr.Checkbox(
                label="Only predict on high-quality input spectra",
                value=True,
                info=(
                    "Enable to exclude low-quality input spectra before prediction. MS/MS spectrum is considered "
                    "low-quality if it does not satisfy quality criteria \"A\" as defined in the DreaMS paper "
                    "(<a href='https://www.nature.com/articles/s41587-025-02663-3/figures/2'>Fig. 2b</a>)."
                ),
            )
        # Prediction button
        predict_button = gr.Button(value="Run DreaMS", variant="primary")
        # Results table
        gr.Markdown("## Predictions")
        df_file = gr.File(label="Download predictions as .tsv", interactive=False, visible=True)
        # Initial table configuration covers every column of DATAFRAME_COLUMNS;
        # predict() later narrows it to the columns actually returned.
        headers, datatype, column_widths = _extract_dataframe_config()
        df = gr.Dataframe(
            headers=headers,
            datatype=datatype,
            col_count=(len(headers), "fixed"),
            column_widths=column_widths,
            max_height=1000,
            show_row_numbers=False,
            show_search='filter',
            wrap=True,
            interactive=False,
            pinned_columns=1,
            elem_id="results-dataframe"
        )
        # Connect prediction logic
        # Order must match the positional parameters of predict().
        inputs = [lib_pth, in_pth, similarity_threshold, calculate_modified_cosine, only_high_quality_input]
        outputs = [df, df_file]
        # Function to update dataframe headers based on setting
        # NOTE(review): ``headers`` already contains the "Modified\ncos. sim."
        # entry from DATAFRAME_COLUMNS, so the checked branch appends one more
        # header (with no matching datatype) — confirm this is intended.
        def update_headers(show_cosine):
            if show_cosine:
                return gr.update(headers=headers + ["Modified\ncosine similarity"],
                                 col_count=(len(headers) + 1, "fixed"),
                                 column_widths=column_widths + ["40px"])
            else:
                return gr.update(headers=headers,
                                 col_count=(len(headers), "fixed"),
                                 column_widths=column_widths)
        # Update headers when setting changes
        calculate_modified_cosine.change(
            fn=update_headers,
            inputs=[calculate_modified_cosine],
            outputs=[df]
        )
        predict_button.click(predict, inputs=inputs, outputs=outputs, show_progress="first")
    return app
# =============================================================================
# MAIN EXECUTION
# =============================================================================
if __name__ != "__main__":
    # Imported as a module (e.g. by a spawned rendering worker): run setup
    # unless the importer explicitly opted out through the environment.
    if os.environ.get("DREAMS_SKIP_SETUP_ON_IMPORT") != "1":
        setup()
else:
    # Executed as a script: prepare data, then build and serve the interface.
    setup()
    app = _create_gradio_interface()
    app.launch(allowed_paths=['./assets'])