CSU-EP / app.py
Tingxie's picture
Update app.py
8372297 verified
import torch
from model_finetune import CSUEP_finetune
from modular_csuep import CsuepConfig
from matchms.exporting import save_as_mgf
import numpy as np
import os
from tqdm import tqdm
from matchms.importing import load_from_mgf,load_from_msp
import matplotlib.pyplot as plt
import matchms.filtering as msfilters
import torch.nn.functional as F
from scipy.sparse import csr_matrix, save_npz,load_npz
import gradio as gr
from typing import Iterable
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
import time
import sqlite3
import hnswlib
import pickle
from rdkit.Chem import Draw
from rdkit import Chem
import base64
import csv
import simsimd
from huggingface_hub import hf_hub_download
from PIL import Image
# --- Embedded images ---------------------------------------------------------
# Both pictures are read once at import time and kept as base64 text so they
# can be inlined into HTML snippets as data URIs.
with open("logo.jpg", "rb") as logo_file:
    img_base64 = base64.b64encode(logo_file.read()).decode("utf-8")

with open("example_res.png", "rb") as example_file:
    img_base64_2 = base64.b64encode(example_file.read()).decode("utf-8")

# --- Hugging Face dataset that hosts the reference databases -----------------
dataset_repo = "Tingxie/CSU-EP-DB"
token = os.environ.get("HF_TOKEN")  # None when the variable is not set

# File names of the SQLite databases inside the dataset repository.
db_path = 'csu-ep.db'
metabodb_path = 'metabodb.db'

# Display name of each database -> file name in the dataset repository.
db_paths = {
    "CSU-EP-DB": db_path,
    "MetaboDB": metabodb_path,
}

# In-memory filter indexes, keyed by database display name; they start empty
# and are filled by preload_mass_index().
mass_index = {}
formula_index = {}
def preload_mass_index():
    """Fill ``mass_index`` and ``formula_index`` for every configured database.

    For each entry of ``db_paths`` the SQLite file is fetched from the dataset
    repository and the ``id``, ``MonoisotopicMass`` and ``Formula`` columns are
    read in ``id`` order. The results are stored as NumPy arrays so that mass
    and formula filters can later be evaluated without touching the database:

    * ``mass_index[name]    = {"ids": ..., "masses": float32 array}``
    * ``formula_index[name] = {"ids": ..., "formulas": object array}``
    """
    for name, filename_in_repo in db_paths.items():
        print(f"Loading mass & formula index for {name}...")
        downloaded_path = hf_hub_download(
            repo_id=dataset_repo,
            filename=filename_in_repo,
            repo_type="dataset",
            token=token,
        )
        # The CSU-EP database keeps its records in "CsuepDB"; any other
        # database uses a table named after its file stem.
        table = "CsuepDB" if "CSU" in name else filename_in_repo.split('.')[0]

        conn = sqlite3.connect(downloaded_path)
        try:
            # The table name comes from the constants above, never from user
            # input, so interpolating it into the statement is safe here.
            data = conn.execute(
                f"""
                SELECT id, MonoisotopicMass, Formula
                FROM {table}
                ORDER BY id
                """
            ).fetchall()
        finally:
            # Release the handle even when the query fails.
            conn.close()

        # Build the id array once; both indexes only ever read from it.
        ids = np.array([r[0] for r in data])
        mass_index[name] = {
            "ids": ids,
            "masses": np.array([r[1] for r in data], dtype='float32'),
        }
        formula_index[name] = {
            "ids": ids,
            "formulas": np.array([r[2] for r in data], dtype=object),
        }


preload_mass_index()
def load_simsimd_matrix(db_filename):
    """Load every stored embedding of a database into one float32 matrix.

    Args:
        db_filename: File name of the SQLite database inside the dataset
            repository. The table queried is the file stem (text before the
            first ``.``).

    Returns:
        ``numpy.ndarray`` of dtype float32 with one entry per database row,
        ordered by ``id``.
    """
    print(f"Loading {db_filename} for SimSIMD...")
    path = hf_hub_download(
        repo_id=dataset_repo,
        filename=db_filename,
        repo_type="dataset",
        token=token,
    )
    table_name = db_filename.split('.')[0]

    temp_conn = sqlite3.connect(path)
    try:
        rows = temp_conn.execute(
            f"SELECT embedding FROM {table_name} ORDER BY id"
        ).fetchall()
    finally:
        # Close the handle even if the query raises.
        temp_conn.close()

    # NOTE(security): pickle.loads can execute arbitrary code. It is only
    # acceptable here because the blobs come from the project's own dataset
    # repository; never point this at an untrusted database.
    return np.array([pickle.loads(r[0]) for r in rows]).astype('float32')
# Embedding matrix of MetaboDB, used for brute-force cosine search.
metabo_matrix = load_simsimd_matrix(metabodb_path)

# --- Spectrum encoder (inference only, CPU) ----------------------------------
model_path = "checkpoints/model.pth"
device = torch.device('cpu')
state_dict = torch.load(model_path, map_location=device)
model = CSUEP_finetune()
model.load_state_dict(state_dict)
model.to(device)
model.eval()

# --- Approximate nearest-neighbour index for CSU-EP-DB -----------------------
# Reuse the repository id and token defined with the other settings instead of
# repeating the literal and the environment lookup, so they cannot drift apart.
index_path = hf_hub_download(
    repo_id=dataset_repo,
    filename="references_index.bin",
    repo_type="dataset",
    token=token,
)
# The index dimensionality (768) has to match the embeddings the index file
# was built from.
p = hnswlib.Index(space='l2', dim=768)
p.load_index(index_path)
# Header logo, inlined as a data URI. The source file is logo.jpg, so the
# payload is declared as JPEG rather than PNG.
logo = f"""
<center><img src="data:image/jpeg;base64,{img_base64}"
style="width:350px; margin-bottom:2px"></center>
"""
# Page heading. The wrapping <div> is closed explicitly so the enlarged font
# styling cannot leak into whatever is rendered after the title.
title = r"""
<div style="font-size:50px; font-weight:bold;"><h1 align="center">CSU-EP: a framework to produce domain-invariant spectral embeddings for accurate compound identification</h1></div>
"""
class Seafoam(Base):
    """Gradio theme for the app: emerald/blue hues with gradient buttons."""

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.emerald,
        secondary_hue: colors.Color | str = colors.blue,
        neutral_hue: colors.Color | str = colors.blue,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font
        | str
        | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Quicksand"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: fonts.Font
        | str
        | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        # Hand the palette, sizing and font choices to the base theme.
        base_options = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "spacing_size": spacing_size,
            "radius_size": radius_size,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**base_options)

        # Theme variable overrides applied on top of the base theme.
        overrides = {
            #"body_background_fill": "repeating-linear-gradient(45deg, *primary_200, *primary_200 10px, *primary_50 10px, *primary_50 20px)",
            "body_background_fill_dark": "repeating-linear-gradient(45deg, *primary_800, *primary_800 10px, *primary_900 10px, *primary_900 20px)",
            "button_primary_background_fill": "linear-gradient(90deg, *primary_300, *secondary_400)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *primary_200, *secondary_300)",
            "button_primary_text_color": "white",
            "button_primary_background_fill_dark": "linear-gradient(90deg, *primary_600, *secondary_800)",
            "slider_color": "*secondary_300",
            "slider_color_dark": "*secondary_600",
            "block_title_text_weight": "600",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
            "button_large_padding": "17px",
            "body_text_color": "#000000",
        }
        super().set(**overrides)


# Theme instance handed to gr.Blocks.
seafoam = Seafoam()
# Raw CSS handed to gr.Blocks(css=...). The value must be plain CSS: wrapping
# it in <style>...</style> turns the tags into part of the first selector,
# which invalidates that rule, so the tags are intentionally not included.
custom_css = """
.file-upload-height {
height:320px !important;
display: none;
}
.file-upload-height2 {
height:190px !important;
}
.gallery-height {
height: 350px !important;
}
#custom_plot {
height: 600px !important;
}
"""
def spectrum_processing(s):
    """Run the standard clean-up filters on a spectrum and return the result.

    The filters are applied in this order: intensity normalisation, m/z
    window 0-1000, minimum intensity 0.001, and a requirement of at least
    two peaks.
    """
    # (filter, keyword arguments) pairs, applied one after another.
    pipeline = (
        (msfilters.normalize_intensities, {}),
        (msfilters.select_by_mz, {"mz_from": 0, "mz_to": 1000}),
        (msfilters.select_by_intensity, {"intensity_from": 0.001}),
        (msfilters.require_minimum_number_of_peaks, {"n_required": 2}),
    )
    for apply_filter, options in pipeline:
        s = apply_filter(s, **options)
    return s
def draw_mass_spectrum(peak_data_path):
    """Plot the first spectrum of an uploaded MSP file as a stick diagram.

    Args:
        peak_data_path: Uploaded file object; its ``.name`` attribute is used
            as the path of the MSP file.

    Returns:
        A matplotlib ``Figure``. The figure is created without pyplot, so it
        is not kept alive in pyplot's global registry (the previous
        ``plt.figure()`` call left one open figure behind per request) and it
        cannot be confused with a figure drawn by a concurrent request.

    Raises:
        gr.Error: If the file holds no spectrum or the spectrum does not
            survive ``spectrum_processing``.
    """
    from matplotlib.figure import Figure

    # Only the first record is needed, so stop parsing after it.
    ms = next(iter(load_from_msp(peak_data_path.name)), None)
    if ms is None:
        raise gr.Error("No spectrum found in the uploaded MSP file.")
    ms = spectrum_processing(ms)
    if ms is None:
        raise gr.Error("The spectrum was discarded during preprocessing.")

    fig = Figure(figsize=(8.5, 5))
    ax = fig.subplots()
    # ymin/ymax of axvline are axes fractions; this matches the intensities
    # because spectrum_processing normalises them before plotting.
    for mz, intensity in zip(np.array(ms.mz), np.array(ms.intensities)):
        ax.axvline(x=mz, ymin=0, ymax=intensity, c='red')
    ax.set_xlabel("m/z")
    ax.set_ylabel("Intensity")
    ax.set_title("Mass Spectrum")
    return fig
def MS2Embedding(spectrum):
    """Encode one spectrum into an L2-normalised embedding.

    Every peak becomes a token: the m/z value is truncated to an integer and
    shifted by one, so integer m/z 0..1000 map to token ids 1..1001 (ids 1002
    and 1003 are the ``[PAD]`` and ``[MASK]`` entries of the vocabulary and are
    never produced here). This is the same mapping the former per-call
    dictionary implemented, computed directly instead of rebuilding a
    1001-entry dictionary on every request.

    Args:
        spectrum: Object exposing ``mz`` and ``intensities`` arrays of equal
            length. The spectrum is used as given; no preprocessing is done
            here.

    Returns:
        ``numpy.ndarray`` holding the normalised embedding of the spectrum,
        with a leading batch dimension of 1.

    Raises:
        ValueError: If a peak lies outside the supported 0-1000 m/z range
            (the dictionary lookup used to fail with an opaque ``KeyError``).
    """
    spec_mz = np.asarray(spectrum.mz)
    spec_intens = np.asarray(spectrum.intensities)

    # Truncation toward zero mirrors the int() call of the lookup it replaces.
    mz_bins = spec_mz.astype(np.int64)
    if mz_bins.size and (mz_bins.min() < 0 or mz_bins.max() > 1000):
        raise ValueError("Spectrum contains m/z values outside the supported range 0-1000.")

    input_ids = torch.from_numpy(mz_bins + 1).long().to(device)
    intensities = torch.from_numpy(spec_intens).float().to(device)
    # Every peak is a real token, so nothing is masked out.
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        outputs = model.text_encoder(
            input_ids=input_ids.unsqueeze(0),  # Add batch dimension
            intensities=intensities.unsqueeze(0),
            attention_mask=attention_mask.unsqueeze(0),
            return_dict=True,
            output_attentions=True,
        )
        output_feats = outputs.last_hidden_state
        output_aggr_feats = model.pooler(output_feats, attention_mask)
        output_aggr_feats = model.proj(output_aggr_feats)
        spectrum_embeddings = F.normalize(output_aggr_feats, dim=1)
    return spectrum_embeddings.detach().cpu().numpy()
def _make_hit(smiles, formula, inchikey, short_key, mass, score, spec_blob):
    """Build one result record in the shape the gallery and plots consume."""
    return {
        "SMILES": smiles,
        "Formula": formula,
        "InChIKey": inchikey,
        "ShortInChIKey": short_key,
        "Mass": round(mass, 4),
        "Score": round(float(score), 4),
        # NOTE(security): pickle.loads can execute arbitrary code. It is only
        # acceptable because the blobs come from the project's own dataset
        # repository; never use this on an untrusted database.
        "PredictedSpectrum": pickle.loads(spec_blob),
    }


def search_library(spectrum, db_choice="CSU-EP-DB", mass_filter_val=None,formula_filter_val=None):
    """Rank library compounds by embedding similarity to ``spectrum``.

    Two search modes exist:

    * With a mass and/or formula filter, candidate ids are selected from the
      in-memory ``mass_index`` / ``formula_index``, their embeddings are read
      from SQLite and scored by cosine distance.
    * Without a filter, the 500 nearest neighbours come from the HNSW index
      (CSU-EP-DB) or from a brute-force cosine search over ``metabo_matrix``
      (any other choice).

    Hits are de-duplicated by InChIKey, best score first.

    Args:
        spectrum: Preprocessed query spectrum (passed to ``MS2Embedding``).
        db_choice: ``"CSU-EP-DB"`` selects the CSU-EP database; every other
            value selects MetaboDB.
        mass_filter_val: Monoisotopic mass; candidates are kept within
            +/-0.01 of it. ``None`` or a non-positive value disables the filter.
        formula_filter_val: Exact molecular formula to match. ``None`` or a
            blank string disables the filter.

    Returns:
        Up to 10 dicts with keys ``SMILES``, ``Formula``, ``InChIKey``,
        ``ShortInChIKey``, ``Mass``, ``Score``, ``PredictedSpectrum`` and
        ``Rank`` (1-based). Empty list when nothing matches.
    """
    max_hits = 10          # size of the returned ranking
    n_neighbors = 500      # neighbours requested in unfiltered mode
    max_rescored = 100     # best filtered candidates considered for output
    sql_chunk_size = 900   # stays below SQLite's bound-parameter limit
    mass_tolerance = 0.01

    query_emb = MS2Embedding(spectrum).astype('float32')

    if db_choice == "CSU-EP-DB":
        repo_file, table_name = db_path, "CsuepDB"
    else:
        repo_file, table_name = metabodb_path, "metabodb"

    # --- Resolve the optional filters against the in-memory indexes ---------
    candidate_ids = None
    if mass_filter_val is not None and mass_filter_val > 0:
        m_data = mass_index[db_choice]
        in_window = (
            (m_data["masses"] >= mass_filter_val - mass_tolerance)
            & (m_data["masses"] <= mass_filter_val + mass_tolerance)
        )
        candidate_ids = m_data["ids"][in_window]

    # Compare the stripped formula, so stray whitespace from the text box
    # does not silently produce zero matches.
    formula_query = formula_filter_val.strip() if formula_filter_val is not None else ""
    if formula_query:
        f_data = formula_index[db_choice]
        formula_ids = f_data["ids"][f_data["formulas"] == formula_query]
        candidate_ids = (
            formula_ids if candidate_ids is None
            else np.intersect1d(candidate_ids, formula_ids)
        )

    if candidate_ids is not None and len(candidate_ids) == 0:
        return []

    # `token=` matches the other hf_hub_download calls in this file; the
    # `use_auth_token=` spelling is deprecated in huggingface_hub.
    local_db = hf_hub_download(repo_id=dataset_repo, filename=repo_file, repo_type="dataset", token=token)

    results = []
    seen_inchikeys = set()
    target_conn = sqlite3.connect(local_db, check_same_thread=False)
    try:
        cur = target_conn.cursor()
        if candidate_ids is not None:
            # Fetch candidates in chunks: one huge IN (...) list can exceed
            # the number of bound parameters SQLite accepts per statement.
            id_list = [int(i) for i in candidate_ids]
            rows = []
            for start in range(0, len(id_list), sql_chunk_size):
                chunk = id_list[start:start + sql_chunk_size]
                placeholders = ",".join("?" * len(chunk))
                cur.execute(
                    f"""
                    SELECT id, embedding, SMILES, Formula, InChIKey, ShortInChIKey,
                    MonoisotopicMass, PredictedSpectrum
                    FROM {table_name}
                    WHERE id IN ({placeholders})
                    """,
                    chunk,
                )
                rows.extend(cur.fetchall())
            if not rows:
                return []

            filtered_matrix = np.array([pickle.loads(r[1]) for r in rows]).astype('float32')
            dist_tensor = simsimd.cdist(query_emb, filtered_matrix, metric="cosine", threads=0)
            dists = np.asarray(dist_tensor).flatten()
            for i in np.argsort(dists)[:max_rescored]:
                _row_id, _emb, smiles, formula, inchikey, short_key, mass, spec_blob = rows[i]
                if inchikey in seen_inchikeys:
                    continue
                seen_inchikeys.add(inchikey)
                results.append(_make_hit(smiles, formula, inchikey, short_key, mass, 1 - dists[i], spec_blob))
                if len(results) >= max_hits:
                    break
        else:
            if db_choice == "CSU-EP-DB":
                labels, distances = p.knn_query(query_emb, k=n_neighbors)
                neighbor_ids = labels[0]
                # The index uses space='l2' (squared L2 distance); for
                # unit-length vectors cos = 1 - d^2 / 2.
                similarities = 1 - distances[0] / 2
            else:
                dist_tensor = simsimd.cdist(query_emb, metabo_matrix, metric="cosine", threads=0)
                all_dists = np.asarray(dist_tensor).flatten()
                # argpartition needs kth < size, so clamp for small libraries.
                k = min(n_neighbors, all_dists.size)
                if k == 0:
                    return []
                neighbor_ids = np.argpartition(all_dists, k - 1)[:k]
                similarities = 1 - all_dists[neighbor_ids]

            # NOTE(review): neighbour labels are used directly as `id` values,
            # which assumes ids equal the 0-based position used when the
            # index / matrix was built -- confirm against the database.
            for pos in np.argsort(-similarities, kind="stable"):
                cur.execute(
                    f"SELECT SMILES, Formula, InChIKey, ShortInChIKey, MonoisotopicMass, PredictedSpectrum FROM {table_name} WHERE id=?;",
                    (int(neighbor_ids[pos]),),
                )
                row = cur.fetchone()
                if not row:
                    continue
                smiles, formula, inchikey, short_key, mass, spec_blob = row
                # De-duplicate on the InChIKey. The previous code tested
                # row[1] (the Formula), which discarded every isomer after the
                # first one sharing a molecular formula.
                if inchikey in seen_inchikeys:
                    continue
                seen_inchikeys.add(inchikey)
                results.append(_make_hit(smiles, formula, inchikey, short_key, mass, similarities[pos], spec_blob))
                # Neighbours are visited best-first, so the ranking is
                # complete once enough unique compounds are collected.
                if len(results) >= max_hits:
                    break
    finally:
        # Always release the connection, including on early return or error.
        target_conn.close()

    results.sort(key=lambda x: x["Score"], reverse=True)
    for n, item in enumerate(results, 1):
        item["Rank"] = n
    return results[:max_hits]
def draw_molecule_gallery(results):
    """Render one captioned structure image per search hit.

    Args:
        results: Hit dicts with at least ``SMILES``, ``Rank``, ``Mass`` and
            ``Score`` (``Formula`` is optional).

    Returns:
        List of ``(image, caption)`` tuples with exactly one entry per hit, in
        the same order as ``results``. A hit whose SMILES cannot be parsed gets
        a blank placeholder image instead of being dropped: the gallery's
        select handler looks hits up by gallery position, so dropping an entry
        would make every later tile point at the wrong candidate.
    """
    tile_size = (250, 250)
    mol_images = []
    for r in results:
        caption = (
            f"Rank {r['Rank']} | "
            f"SMILES {r['SMILES']} | "
            f"Formula: {r.get('Formula', 'N/A')} | "
            f"Mass: {r['Mass']} | "
            f"Score: {r['Score']}"
        )
        mol = Chem.MolFromSmiles(r["SMILES"])
        if mol:
            img = Draw.MolToImage(mol, size=tile_size)
        else:
            # Unparsable structure: keep the slot so indices stay aligned.
            img = Image.new("RGB", tile_size, "white")
        mol_images.append((img, caption))
    return mol_images
def plot_comparison_spectrum(exp_spec, pred_spec, rank):
    """Draw a mirror plot: experimental peaks up (red), predicted down (blue).

    Args:
        exp_spec: Experimental spectrum exposing ``mz`` and ``intensities``.
        pred_spec: Mapping with ``"mz"`` and ``"intensity"`` sequences.
        rank: Rank of the candidate, shown in the title.

    Returns:
        A matplotlib ``Figure``. The figure is built without pyplot: the
        previous ``plt.figure()`` call left one figure open in pyplot's global
        registry per request, and returning the ``plt`` module meant the
        "current figure" could belong to a concurrent request by the time it
        was rendered.
    """
    from matplotlib.figure import Figure

    mz_exp = np.asarray(exp_spec.mz)
    intens_exp = np.asarray(exp_spec.intensities)
    mz_pred = np.asarray(pred_spec["mz"])
    intens_pred = np.asarray(pred_spec["intensity"])

    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    # One vectorised call per series instead of one call per peak.
    ax.vlines(mz_exp, 0, intens_exp, colors="red")
    ax.vlines(mz_pred, 0, -intens_pred, colors="blue")
    ax.set_xlabel("m/z")
    ax.set_ylabel("Intensity")
    ax.set_title(f"Comparison: Exp (Red) vs Pred (Blue) - Rank {rank}")
    ax.axhline(0, color="black", linewidth=1)
    fig.tight_layout()
    return fig
# NOTE: a function with this same name is defined again further down in this
# file; that later definition rebinds the name and is the one the UI calls.
def visualize_from_batch(selected_idx, cache):
    """Show gallery and top-1 comparison plot for one spectrum of a batch run.

    Args:
        selected_idx: Index of the spectrum in the batch (int or numeric str).
        cache: List of ``(results, processed_spectrum)`` tuples produced by the
            batch run; skipped spectra are stored as ``(None, None)``.

    Returns:
        ``(gallery_items, comparison_plot, results, spectrum)``.

    Raises:
        gr.Error: If there is no batch data, the index is invalid, the
            spectrum was skipped, or the search produced no candidates.
    """
    if cache is None:
        raise gr.Error("No batch data found.")
    try:
        res, spec = cache[int(selected_idx)]
    except (TypeError, ValueError, IndexError) as err:
        raise gr.Error("Invalid spectrum index.") from err
    if res is None:
        raise gr.Error("Spectrum was skipped.")
    if not res:
        # An empty hit list would otherwise fail on res[0] below.
        raise gr.Error("No candidates were found for this spectrum.")
    return draw_molecule_gallery(res), plot_comparison_spectrum(spec, res[0]["PredictedSpectrum"], 1), res, spec
def run_pipeline(msp_file, db_choice, use_mass, mass_val,use_formula, formula_val):
    """Identify the first spectrum of an MSP file.

    Args:
        msp_file: Uploaded file object; ``.name`` is the MSP path.
        db_choice: Database name forwarded to ``search_library``.
        use_mass: Whether ``mass_val`` is applied as a mass filter.
        mass_val: Monoisotopic mass for the filter.
        use_formula: Whether ``formula_val`` is applied as a formula filter.
        formula_val: Molecular formula for the filter.

    Returns:
        ``(gallery_items, comparison_plot, results, spectrum)``.

    Raises:
        gr.Error: If no file was given, the file holds no spectrum, the
            spectrum is discarded by preprocessing, or no candidate matches
            (these cases used to surface as raw AttributeError / IndexError).
    """
    if msp_file is None:
        raise gr.Error("Please upload a .msp file.")
    # Only the first record is used, so stop parsing after it.
    spectrum = next(iter(load_from_msp(msp_file.name)), None)
    if spectrum is None:
        raise gr.Error("No spectrum found in the uploaded MSP file.")
    spectrum = spectrum_processing(spectrum)
    if spectrum is None:
        raise gr.Error("The spectrum was discarded during preprocessing.")

    m_val = mass_val if use_mass else None
    f_val = formula_val if use_formula else None
    results = search_library(spectrum, db_choice, mass_filter_val=m_val, formula_filter_val=f_val)
    if not results:
        raise gr.Error("No candidates found. Try relaxing the mass/formula filters.")

    gallery_items = draw_molecule_gallery(results)
    pred_plot = plot_comparison_spectrum(spectrum, results[0]["PredictedSpectrum"], 1)
    return gallery_items, pred_plot, results, spectrum
def run_pipeline_from_spectrum(spectrum, db_choice, use_mass, mass_val,use_formula,formula_val):
    """Identify an already-loaded spectrum.

    Args:
        spectrum: Raw spectrum object; it is preprocessed here.
        db_choice: Database name forwarded to ``search_library``.
        use_mass: Whether ``mass_val`` is applied as a mass filter.
        mass_val: Monoisotopic mass for the filter.
        use_formula: Whether ``formula_val`` is applied as a formula filter.
        formula_val: Molecular formula for the filter.

    Returns:
        ``(gallery_items, comparison_plot, results, spectrum)`` where
        ``spectrum`` is the preprocessed spectrum.

    Raises:
        gr.Error: If the spectrum is discarded by preprocessing or no
            candidate matches (previously a raw AttributeError / IndexError).
    """
    spectrum = spectrum_processing(spectrum)
    if spectrum is None:
        raise gr.Error("The spectrum was discarded during preprocessing.")

    m_val = mass_val if use_mass else None
    f_val = formula_val if use_formula else None
    results = search_library(spectrum, db_choice, mass_filter_val=m_val, formula_filter_val=f_val)
    if not results:
        raise gr.Error("No candidates found. Try relaxing the mass/formula filters.")

    gallery_items = draw_molecule_gallery(results)
    pred_plot = plot_comparison_spectrum(spectrum, results[0]["PredictedSpectrum"], 1)
    return gallery_items, pred_plot, results, spectrum
def update_plot(rank, results, spectrum):
    """Redraw the comparison plot for the candidate at a given rank.

    Args:
        rank: 1-based rank of the candidate; may arrive as a float or ``None``
            from a numeric input widget.
        results: Hit list from the last retrieval, or ``None``.
        spectrum: Query spectrum from the last retrieval, or ``None``.

    Returns:
        The plot object produced by ``plot_comparison_spectrum``.

    Raises:
        gr.Error: If no retrieval has been run or the rank is out of range.
    """
    if results is None or spectrum is None:
        raise gr.Error("Please run retrieval first!")
    # A cleared number box yields None, which cannot be compared to an int.
    if rank is None or rank < 1 or rank > len(results):
        raise gr.Error("Invalid rank number!")
    # Normalise once so the list index and the plot title agree
    # (e.g. "Rank 2" rather than "Rank 2.0").
    rank = int(rank)
    selected_pred = results[rank - 1]["PredictedSpectrum"]
    return plot_comparison_spectrum(spectrum, selected_pred, rank)
def batch_process(file_path, db_choice):
    """Identify every spectrum of an MSP file and write the hits to a CSV.

    Args:
        file_path: Uploaded file object; ``.name`` is the MSP path.
        db_choice: Database name forwarded to ``search_library``.

    Returns:
        Path of the CSV file with columns ``Spectrum_Index``, ``Rank`` and
        ``SMILES``. Each call writes its own temporary file; the previous
        fixed file name in the working directory was shared by all concurrent
        users, who could overwrite or download each other's results.

    Raises:
        gr.Error: If no file was given or it holds more than 100 spectra.
    """
    import tempfile

    if file_path is None:
        raise gr.Error("Please upload a .msp file.")
    spectra = list(load_from_msp(file_path.name))
    if len(spectra) > 100:
        raise gr.Error("Batch processing limited to a maximum of 100 spectra.")

    csv_data = [["Spectrum_Index", "Rank", "SMILES"]]
    for idx, spec in enumerate(tqdm(spectra, desc="Batch Processing")):
        processed_spec = spectrum_processing(spec)
        if processed_spec is None:
            # Spectrum was discarded by preprocessing; it gets no rows.
            continue
        csv_data.extend(
            [idx, r["Rank"], r["SMILES"]]
            for r in search_library(processed_spec, db_choice)
        )

    with tempfile.NamedTemporaryFile(
        mode='w',
        newline='',
        encoding='utf-8',
        prefix="batch_retrieval_results_",
        suffix=".csv",
        delete=False,
    ) as f:
        csv.writer(f).writerows(csv_data)
        output_csv = f.name
    return output_csv
def batch_process_and_prepare_viz(file_path, db_choice, mass_file=None,formula_file=None):
    """Run batch identification and prepare the per-spectrum viewer.

    Args:
        file_path: Uploaded MSP file object (``.name`` is the path).
        db_choice: Database name forwarded to ``search_library``.
        mass_file: Optional text file with one mass per non-blank line; the
            number of values must equal the number of spectra.
        formula_file: Optional text file with one formula per line (a blank
            line means "no formula filter" for that spectrum); the number of
            lines must equal the number of spectra.

    Returns:
        ``(csv_path, viewer_row_update, dropdown_update, all_data)`` where
        ``all_data`` holds one ``(results, processed_spectrum)`` tuple per
        input spectrum (``(None, None)`` for spectra discarded by
        preprocessing). The dropdown only offers spectra that have at least
        one candidate, since only those can be visualised.

    Raises:
        gr.Error: If no MSP file was given, it holds more than 100 spectra,
            a filter file cannot be read/parsed, or its length does not match
            the number of spectra.
    """
    import tempfile

    if file_path is None:
        raise gr.Error("Please upload a .msp file.")
    spectra = list(load_from_msp(file_path.name))
    # Same limit as batch_process and as advertised in the UI.
    if len(spectra) > 100:
        raise gr.Error("Batch processing limited to a maximum of 100 spectra.")

    mass_list = [None] * len(spectra)
    if mass_file is not None:
        # Only the read/parse step is wrapped: the previous bare `except`
        # also caught the count-mismatch error below and replaced its message.
        try:
            with open(mass_file.name, 'r') as f:
                mass_list = [float(line.strip()) for line in f if line.strip()]
        except (OSError, ValueError) as err:
            raise gr.Error("Invalid mass .txt file.") from err
        if len(mass_list) != len(spectra):
            raise gr.Error("Mismatch between .msp and .txt counts.")

    formula_list = [None] * len(spectra)
    if formula_file is not None:
        try:
            with open(formula_file.name, 'r') as f:
                formula_list = [line.strip() if line.strip() else None for line in f]
        except (OSError, ValueError) as err:
            raise gr.Error("Invalid formula list file.") from err
        # Blank lines inside the file are placeholders, but surplus blank
        # lines at the very end are just editor noise: drop them.
        while len(formula_list) > len(spectra) and formula_list[-1] is None:
            formula_list.pop()
        if len(formula_list) != len(spectra):
            raise gr.Error("Mismatch between formula list and spectra count.")

    all_data = []
    csv_rows = [["Spectrum_Index", "Rank", "SMILES", "Mass_Filter", "Formula_Filter"]]
    for idx, (s, m, fml) in enumerate(tqdm(zip(spectra, mass_list, formula_list), total=len(spectra), desc="Batch")):
        processed = spectrum_processing(s)
        if processed is None:
            all_data.append((None, None))
            continue
        res = search_library(
            processed,
            db_choice,
            mass_filter_val=m,
            formula_filter_val=fml,
        )
        all_data.append((res, processed))
        csv_rows.extend(
            [idx, r["Rank"], r["SMILES"], m if m else "N/A", fml if fml is not None else "N/A"]
            for r in res
        )

    # One temporary file per request, so concurrent users cannot overwrite or
    # download each other's results.
    with tempfile.NamedTemporaryFile(
        mode='w',
        newline='',
        encoding='utf-8',
        prefix="batch_retrieval_results_",
        suffix=".csv",
        delete=False,
    ) as f:
        csv.writer(f).writerows(csv_rows)
        csv_path = f.name

    viewable = [str(i) for i, (res, _) in enumerate(all_data) if res]
    return (
        csv_path,
        gr.update(visible=bool(viewable)),
        gr.update(choices=viewable, value=viewable[0] if viewable else None),
        all_data,
    )
def visualize_from_batch(selected_idx, cache):
    """Show gallery and top-1 comparison plot for one spectrum of a batch run.

    Args:
        selected_idx: Index of the spectrum in the batch (int or numeric str).
        cache: List of ``(results, processed_spectrum)`` tuples produced by the
            batch run; spectra discarded by preprocessing are stored as
            ``(None, None)``.

    Returns:
        ``(gallery_items, comparison_plot, results, spectrum)``.

    Raises:
        gr.Error: If there is no batch data, nothing/invalid index is
            selected, the spectrum was skipped, or it has no candidates.
            Without these checks a skipped or hit-less spectrum failed with a
            raw TypeError / IndexError.
    """
    if cache is None:
        raise gr.Error("No data found.")
    try:
        res, spec = cache[int(selected_idx)]
    except (TypeError, ValueError, IndexError) as err:
        raise gr.Error("Invalid spectrum index.") from err
    if res is None:
        raise gr.Error("Spectrum was skipped.")
    if not res:
        raise gr.Error("No candidates were found for this spectrum.")
    return draw_molecule_gallery(res), plot_comparison_spectrum(spec, res[0]["PredictedSpectrum"], 1), res, spec
# --- User-facing instruction texts (rendered as Markdown/HTML in the UI) ------
# Steps for the batch mode of the retrieval tab.
batch_instruction_des = r"""
❗️❗️❗️[<b>Important</b>] Batch Processing:<br>
1️⃣ Select <b>Batch Processing</b> in the retrieval tab.<br>
2️⃣ Choose the target <b>Database</b> (CSU-EP-DB, or MetaboDB).<br>
3️⃣ Upload a <b>.msp</b> file containing multiple spectra (max 100).<br>
4️⃣ (Optional) Enable <b>Monoisotopic Mass Filtering</b> and upload a <b>.txt</b> file (one exact mass per line).<br>
5️⃣ Click <b>Run Batch Identification</b> to process all spectra at once.<br>
6️⃣ Download the generated <b>CSV file</b> containing the top-10 candidates for each spectrum."""
# Steps for single-spectrum retrieval.
retrieval_des1 = r"""
❗️❗️❗️[<b>Important</b>] How to use:<br>
1️⃣ Choose the target <b>Database</b> for matching.<br>
2️⃣ Upload an EI-MS spectrum in MSP format or manual input.<br>
3️⃣ (Optional) Enable <b>Monoisotopic Mass Filtering</b> and input the exact mass (e.g., 125.0477) to narrow down candidates.<br>
4️⃣ Click the <b>CSU-EP Retrieval</b> button to start matching.<br>
5️⃣ View the top-10 candidate retrieval results by clicking the candidate structure in the gallery; the corresponding predicted spectrum is displayed and compared with the experimental spectrum:"""
# Explanation of the monoisotopic mass filter.
mass_filter_des = r"""
❗️❗️❗️[<b>Important</b>] Monoisotopic Mass Filtering:<br>
1️⃣ This function filters candidates within a tolerance of <b>±0.01 Da</b> of the input mass.<br>
2️⃣ <b>Single Mode</b>: Manually enter the mass in the input box.<br>
3️⃣ <b>Batch Mode</b>: Upload a .txt file. The number of mass values must match the number of spectra in the .msp file."""
# Sub-heading under the page title. The <b> element is closed right after the
# paper title; it used to be left open, relying on browser error recovery.
description = r"""
<b>Official 🤗 interactive demo for the paper "CSU-EP: Contrastive Learning between Experimental and Predicted Electron Ionization Spectra for Efficient In-silico Library Matching"</b>
<a title="Github" href="https://github.com/tingxiecsu/CSU-EP" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/badge/Github-Repo-blue">
</a>
"""
# Contact note shown at the bottom of the instructions tab.
contact = r"""
📧 **Contact**
<br>
If you have any questions, please feel free to reach out to me at <b>212307003@csu.edu.cn</b>.
"""
# General usage tip shown in the input instructions.
des2 = r"""
❗️❗️❗️[<b>Important</b>] Usage tips: Users can perform EI-MS identification using the CSU-EP method by either uploading EI-MS spectra in .msp format or entering the spectral data manually.
"""
# HTML snippet that inlines the example-result picture (the base64 text of
# example_res.png) as a data URI, centred and 600px wide.
logopicture2 = f"""
<center><img src="data:image/png;base64,{img_base64_2}"
style="width:600px; margin-top:10px; margin-bottom:5px;"></center>
"""
# Example peak list wrapped in a <pre> element for display as HTML.
# Each line is "<m/z> <intensity>", with the largest intensity equal to 1.0.
example_spectrum = """
<pre style="line-height: 1.5; font-size: 17px;">
12.0 0.02702702702702703
13.0 0.044044044044044044
14.0 0.07807807807807808
15.0 1.0
19.0 0.02702702702702703
20.0 0.014014014014014014
26.0 0.022022022022022022
27.0 0.15815815815815815
28.0 0.7957957957957958
29.0 0.056056056056056056
33.0 0.043043043043043044
46.0 0.08408408408408409
47.0 0.036036036036036036
52.0 0.004004004004004004
66.0 0.055055055055055056
67.0 0.3833833833833834
68.0 0.005005005005005005
</pre>
"""
# Plain-text example peak list, one "<m/z> <intensity>" pair per line with no
# markup, so it can be pasted directly into a text input.
example_spectrum2 = """12.0 0.006006006006006006
13.0 0.01001001001001001
14.0 0.007007007007007007
15.0 1.0
16.0 0.005005005005005005
26.0 0.018018018018018018
27.0 0.0890890890890891
28.0 0.5685685685685685
29.0 0.19019019019019018
30.0 0.056056056056056056
31.0 0.004004004004004004
35.0 0.001001001001001001
36.0 0.02002002002002002
38.0 0.024024024024024024
39.0 0.022022022022022022
40.0 0.0960960960960961
41.0 0.06906906906906907
42.0 0.3763763763763764
43.0 0.2802802802802803
44.0 0.4874874874874875
45.0 0.013013013013013013
52.0 0.01001001001001001
53.0 0.05005005005005005
54.0 0.04904904904904905
55.0 0.1841841841841842
56.0 0.1941941941941942
57.0 0.021021021021021023
58.0 0.008008008008008008
66.0 0.007007007007007007
67.0 0.12512512512512514
68.0 0.005005005005005005
69.0 0.0980980980980981
70.0 0.03303303303303303
71.0 0.35235235235235235
72.0 0.008008008008008008
73.0 0.002002002002002002
79.0 0.016016016016016016
80.0 0.03303303303303303
81.0 0.04104104104104104
82.0 0.14814814814814814
83.0 0.03803803803803804
84.0 0.003003003003003003
85.0 0.04104104104104104
93.0 0.002002002002002002
95.0 0.001001001001001001
96.0 0.07307307307307308
97.0 0.006006006006006006
105.0 0.003003003003003003
106.0 0.04004004004004004
107.0 0.004004004004004004
108.0 0.04104104104104104
109.0 0.13213213213213212
110.0 0.003003003003003003
120.0 0.008008008008008008
122.0 0.01001001001001001
123.0 0.004004004004004004
125.0 0.38238238238238237
126.0 0.022022022022022022
133.0 0.031031031031031032
134.0 0.08608608608608609
135.0 0.03403403403403404
136.0 0.08508508508508508
137.0 0.002002002002002002
147.0 0.008008008008008008
149.0 0.007007007007007007
161.0 0.001001001001001001
162.0 0.03403403403403404
163.0 0.004004004004004004
164.0 0.03303303303303303
165.0 0.004004004004004004
174.0 0.002002002002002002
175.0 0.0890890890890891
176.0 0.008008008008008008
177.0 0.08808808808808809
178.0 0.008008008008008008
189.0 0.14614614614614616
190.0 0.013013013013013013
191.0 0.14214214214214213
192.0 0.012012012012012012
203.0 0.007007007007007007
204.0 0.18618618618618618
205.0 0.01901901901901902
206.0 0.18118118118118118
207.0 0.012012012012012012"""
# --- Gradio user interface ----------------------------------------------------
# NOTE(review): this file carried no indentation, so the nesting of the layout
# containers below is reconstructed from the order of the statements -- please
# confirm it against the intended page layout.
with gr.Blocks(theme=seafoam, css=custom_css) as demo:
    gr.HTML(logo)
    gr.Markdown(title)
    gr.Markdown(f"<div style='font-size:20px;'>{description}</div>")
    #gr.Markdown('<div style="font-size:50px; font-weight:bold;">🔍 CSU-EP web server </div>')
    with gr.Tabs():
        # ---------------- Instructions tab ----------------
        with gr.TabItem("📄 Instructions"):
            gr.Markdown(r"""
### Welcome to our online tool for unknown EI-MS annotation! 😊
<div style='font-size:20px;'> instructions are available for the functionality when users click on a tab below 👇.</div>
""")
            with gr.Accordion("📄 Input Instruction"):
                gr.Markdown(f"<div style='font-size:20px;'>{des2}</div>")
                gr.Markdown("<div style='font-size:20px;'>EI-MS spectra can be uploaded in MSP format or manually as follow:</div>")
                with gr.Accordion("EI-MS spectrum"):
                    gr.HTML(example_spectrum)
                gr.Markdown("<div style='font-size:20px;'>We have provided an <b>example</b> of an EI-MS that can be directly used as input in the <b>Retrieval</b> tab.</div>")
            with gr.Accordion("📄 Retrieval Instruction"):
                gr.Markdown(f"<div style='font-size:20px;'>{retrieval_des1}</div>")
                gr.HTML(logopicture2)
            with gr.Accordion("📄 Batch Processing Instruction"):
                gr.Markdown(f"<div style='font-size:20px;'>{batch_instruction_des}</div>")
            with gr.Accordion("📄 Monoisotopic Mass Filtering Instruction"):
                gr.Markdown(f"<div style='font-size:20px;'>{mass_filter_des}</div>")
            gr.Markdown(f"<div style='font-size:20px;'>{contact}</div>")

        # ---------------- Retrieval tab ----------------
        with gr.TabItem("🔍 Retrieval"):
            db_selector = gr.Radio(choices=["CSU-EP-DB", "MetaboDB"], value="CSU-EP-DB", label="Database")
            input_method = gr.Radio(choices=["Upload", "Manual", "Batch Processing"], value="Upload", label="Input Method")

            # One input row per input method; only the active one is visible.
            upload_row = gr.Row(visible=True)
            with upload_row:
                peak_data = gr.File(label="Upload MSP file")
                gr.Examples(examples=[["test.msp"]], inputs=[peak_data])
            manual_row = gr.Row(visible=False)
            with manual_row:
                manual_input = gr.Textbox(lines=8, label="m/z intensity list")
                gr.Examples(examples=[[example_spectrum2]], inputs=[manual_input], label="Manual Input Example")
            batch_row = gr.Row(visible=False)
            with batch_row:
                batch_file = gr.File(label="Upload Batch MSP (Max 100)")
                gr.Examples(examples=[["test_batch.msp"]], inputs=[batch_file])

            # Optional mass / formula filters (single value or per-spectrum file).
            with gr.Group():
                enable_mass_filter = gr.Checkbox(label="Enable Monoisotopic Mass Filtering", value=False)
                single_mass_box = gr.Number(label="Exact Mass (e.g. 125.0477)", value=None, visible=False)
                enable_formula_filter = gr.Checkbox(label="Enable Molecular Formula Filtering", value=False)
                single_formula_box = gr.Textbox(
                    label="Molecular Formula (e.g. C8H10O2)",
                    placeholder="C8H10O2",
                    visible=False
                )
                with gr.Group(visible=False) as batch_mass_ui_group:
                    batch_mass_file = gr.File(label="Mass List (.txt)")
                    gr.Examples(examples=[["test_batch_mass.txt"]], inputs=[batch_mass_file])
                with gr.Group(visible=False) as batch_formula_ui_group:
                    batch_formula_file = gr.File(label="Formula List (.txt)")
                    gr.Examples(examples=[["test_batch_formula.txt"]], inputs=[batch_formula_file])

            lib_button = gr.Button("CSU-EP Retrieval", variant="primary")
            with gr.Column(visible=False) as batch_execution_area:
                batch_btn = gr.Button("Run Batch Identification", variant="primary")
                batch_output = gr.File(label="Download CSV Results")
                with gr.Row(visible=False) as batch_viz_ctrl:
                    batch_select = gr.Dropdown(label="Select Index", choices=[])
                    viz_btn = gr.Button("Visualize")

            lib_gallery = gr.Gallery(columns=4, label='Top Candidates', elem_classes="gallery-height")
            comparison_plot = gr.Plot(label="Comparison", elem_id="custom_plot")
            # Per-session state: last hit list, last query spectrum, batch cache.
            results_state = gr.State()
            spectrum_state = gr.State()
            batch_data_cache = gr.State()

            def update_mass_ui(method, mass_enabled, formula_enabled):
                """Show the single-value or the batch variant of each enabled filter."""
                is_batch = (method == "Batch Processing")
                return {
                    single_mass_box: gr.update(visible=(mass_enabled and not is_batch)),
                    batch_mass_ui_group: gr.update(visible=(mass_enabled and is_batch)),
                    single_formula_box: gr.update(visible=(formula_enabled and not is_batch)),
                    batch_formula_ui_group: gr.update(visible=(formula_enabled and is_batch))
                }

            # The same refresh runs whenever the input method or a filter
            # checkbox changes.
            filter_ui_inputs = [input_method, enable_mass_filter, enable_formula_filter]
            filter_ui_outputs = [single_mass_box, batch_mass_ui_group, single_formula_box, batch_formula_ui_group]
            for trigger in (input_method, enable_mass_filter, enable_formula_filter):
                trigger.change(update_mass_ui, filter_ui_inputs, filter_ui_outputs)

            def toggle_input(method):
                """Show the widgets that belong to the selected input method."""
                is_batch = (method == "Batch Processing")
                return {
                    upload_row: gr.update(visible=(method == "Upload")),
                    manual_row: gr.update(visible=(method == "Manual")),
                    batch_row: gr.update(visible=is_batch),
                    lib_button: gr.update(visible=not is_batch),
                    batch_execution_area: gr.update(visible=is_batch),
                    batch_viz_ctrl: gr.update(visible=False)
                }

            input_method.change(toggle_input, input_method, [upload_row, manual_row, batch_row, lib_button, batch_execution_area, batch_viz_ctrl])

            def handle_input(upload_file, manual_text, db_choice, use_mass, mass_val,use_formula, formula_val, current_method):
                """Run single-spectrum retrieval from an uploaded file or typed peaks.

                Returns ``(gallery_items, comparison_plot, results, spectrum)``.
                Raises ``gr.Error`` with a readable message when the input is
                missing or malformed, the spectrum does not survive
                preprocessing, or no candidate matches; these cases used to
                surface as raw AttributeError / IndexError / AssertionError.
                """
                from matchms import Spectrum

                if current_method == "Upload":
                    if upload_file is None:
                        raise gr.Error("Please upload a .msp file.")
                    # Only the first record is used, so stop parsing after it.
                    spec = next(iter(load_from_msp(upload_file.name)), None)
                    if spec is None:
                        raise gr.Error("No spectrum found in the uploaded MSP file.")
                else:
                    # Keep only lines made of exactly two whitespace-separated fields.
                    lines = [ln.split() for ln in (manual_text or "").strip().splitlines() if len(ln.split())==2]
                    if not lines:
                        raise gr.Error("Please enter at least one 'm/z intensity' pair.")
                    try:
                        mz = np.array([float(l[0]) for l in lines])
                        intensities = np.array([float(l[1]) for l in lines])
                    except ValueError as err:
                        raise gr.Error("Manual input must contain numeric m/z and intensity values.") from err
                    # matchms expects peaks ordered by m/z; sort so that peaks
                    # typed in any order are accepted.
                    order = np.argsort(mz, kind="stable")
                    spec = Spectrum(mz=mz[order], intensities=intensities[order])

                spec = spectrum_processing(spec)
                if spec is None:
                    raise gr.Error("The spectrum was discarded during preprocessing.")

                m_val = mass_val if use_mass else None
                f_val = formula_val if use_formula else None
                results = search_library(
                    spec,
                    db_choice,
                    mass_filter_val=m_val,
                    formula_filter_val=f_val
                )
                if not results:
                    raise gr.Error("No candidates found. Try relaxing the mass/formula filters.")
                return draw_molecule_gallery(results), plot_comparison_spectrum(spec, results[0]["PredictedSpectrum"], 1), results, spec

            lib_button.click(
                handle_input,
                [
                    peak_data,
                    manual_input,
                    db_selector,
                    enable_mass_filter,
                    single_mass_box,
                    enable_formula_filter,
                    single_formula_box,
                    input_method
                ],
                [lib_gallery, comparison_plot, results_state, spectrum_state]
            )
            batch_btn.click(batch_process_and_prepare_viz, [batch_file, db_selector, batch_mass_file, batch_formula_file], [batch_output, batch_viz_ctrl, batch_select, batch_data_cache])
            viz_btn.click(visualize_from_batch, [batch_select, batch_data_cache], [lib_gallery, comparison_plot, results_state, spectrum_state])

            def on_gallery_select(results, spectrum, evt: gr.SelectData):
                """Redraw the comparison plot for the gallery tile that was clicked."""
                if results is None or spectrum is None:
                    return None
                selected_index = evt.index
                # Ignore selections that do not map onto a stored hit.
                if not isinstance(selected_index, int) or not 0 <= selected_index < len(results):
                    return None
                pred_spec = results[selected_index]["PredictedSpectrum"]
                return plot_comparison_spectrum(spectrum, pred_spec, selected_index + 1)

            lib_gallery.select(
                fn=on_gallery_select,
                inputs=[results_state, spectrum_state],
                outputs=comparison_plot
            )

demo.launch()