Spaces:
Sleeping
Sleeping
Commit
·
326e019
1
Parent(s):
135849c
visuals added
Browse files- app.py +153 -66
- app_utils/__init__.py +0 -0
- app_utils/examples.py +6 -0
- app_utils/model_utils.py +31 -0
- app_utils/viz_utils.py +186 -0
app.py
CHANGED
|
@@ -1,6 +1,12 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
import pandas as pd
|
| 3 |
import io
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
st.set_page_config(page_title="Spectra Tool Demo", layout="wide")
|
| 6 |
|
|
@@ -8,68 +14,70 @@ st.title("Spectra Visualization Tool")
|
|
| 8 |
|
| 9 |
st.markdown("Provide inputs below or load one of the example datasets.")
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
"Example 1": {
|
| 16 |
-
"spectra": """mz,intensity
|
| 17 |
-
100,10
|
| 18 |
-
150,50
|
| 19 |
-
200,80
|
| 20 |
-
250,40
|
| 21 |
-
300,20
|
| 22 |
-
""",
|
| 23 |
-
"smiles": "CCO", # ethanol
|
| 24 |
-
"formula": "C2H6O",
|
| 25 |
-
"adduct": "[M+H]+",
|
| 26 |
-
},
|
| 27 |
-
"Example 2": {
|
| 28 |
-
"spectra": """mz,intensity
|
| 29 |
-
120,15
|
| 30 |
-
180,60
|
| 31 |
-
240,30
|
| 32 |
-
300,70
|
| 33 |
-
360,25
|
| 34 |
-
""",
|
| 35 |
-
"smiles": "C6H6", # benzene
|
| 36 |
-
"formula": "C6H6",
|
| 37 |
-
"adduct": "[M+Na]+",
|
| 38 |
-
},
|
| 39 |
-
}
|
| 40 |
|
| 41 |
# ------------------------
|
| 42 |
# Session state defaults
|
| 43 |
# ------------------------
|
| 44 |
-
if "
|
| 45 |
-
st.session_state.
|
| 46 |
-
if "
|
| 47 |
-
st.session_state.
|
| 48 |
-
if "
|
| 49 |
-
st.session_state.
|
| 50 |
-
|
| 51 |
-
st.session_state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
# ------------------------
|
| 54 |
-
# Example loader
|
| 55 |
# ------------------------
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
# ------------------------
|
| 65 |
# Inputs
|
| 66 |
# ------------------------
|
| 67 |
-
st.subheader("Spectra
|
| 68 |
-
|
| 69 |
-
"
|
| 70 |
-
value=st.session_state.
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
)
|
| 74 |
|
| 75 |
st.subheader("SMILES")
|
|
@@ -81,22 +89,101 @@ formula_input = st.text_input("Enter molecular formula:", value=st.session_state
|
|
| 81 |
st.subheader("Adduct")
|
| 82 |
adduct_input = st.text_input("Enter adduct:", value=st.session_state.adduct)
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
|
|
|
| 87 |
if st.button("Run"):
|
| 88 |
-
st.subheader("Results")
|
| 89 |
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
-
st.info("🔬 Interactive visualization will be displayed here.")
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import pandas as pd
|
| 3 |
import io
|
| 4 |
+
from app_utils.model_utils import load_model_components
|
| 5 |
+
from app_utils.viz_utils import run
|
| 6 |
+
from app_utils.examples import EXAMPLES
|
| 7 |
+
import numpy as np
|
| 8 |
+
from streamlit_plotly_events import plotly_events
|
| 9 |
+
|
| 10 |
|
| 11 |
st.set_page_config(page_title="Spectra Tool Demo", layout="wide")
|
| 12 |
|
|
|
|
| 14 |
|
| 15 |
st.markdown("Provide inputs below or load one of the example datasets.")
|
| 16 |
|
| 17 |
+
# Session-state keys that back the user-editable input widgets.
FIELDS = ['mzs', 'intensities', 'smiles', 'formula', 'adduct', 'precursor_mz']


def reset_fields():
    """Overwrite every key listed in ``FIELDS`` with an empty string.

    The values live in ``st.session_state``; nothing else in the session
    state is touched.
    """
    for name in FIELDS:
        st.session_state[name] = ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
# ------------------------
|
| 23 |
# Session state defaults
|
| 24 |
# ------------------------
|
| 25 |
+
# Flags and selections used by the results view. A key is only created when
# it is missing, so values from earlier script runs survive a rerun.
_STATE_DEFAULTS = {
    "run_clicked": False,
    "selected_spectrum_idx": None,
    "selected_node_idx": None,
}
for _state_key, _initial_value in _STATE_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value

# Every input field starts out as an empty string.
for field_name in FIELDS:
    if field_name not in st.session_state:
        st.session_state[field_name] = ""

# The featurizers and the model are loaded once per session and kept in
# session state; the "model" key doubles as the "already loaded" marker.
if "model" not in st.session_state:
    (
        st.session_state.spec_featurizer,
        st.session_state.mol_featurizer,
        st.session_state.model,
    ) = load_model_components()
|
| 41 |
|
| 42 |
# ------------------------
|
| 43 |
+
# Example loader dropdown
|
| 44 |
# ------------------------
|
| 45 |
+
# Entry shown in the dropdown when no example has been chosen.
_NO_EXAMPLE = "-- Select --"

example_names = list(EXAMPLES)

# Dropdown listing the placeholder followed by every example identifier.
selected_example = st.selectbox("Choose an example:", [_NO_EXAMPLE, *example_names])

# The button is always rendered; the example is only loaded when it was
# clicked in this run AND a real example is selected.
load_clicked = st.button("Load Example")
if load_clicked and selected_example != _NO_EXAMPLE:
    reset_fields()
    ex_data = EXAMPLES[selected_example]

    # Copy the example's values into the session-state input fields.
    for example_key in ("mzs", "intensities", "smiles", "formula", "adduct", "precursor_mz"):
        st.session_state[example_key] = ex_data[example_key]

    # A freshly loaded example invalidates any previous result and selection.
    st.session_state.run_clicked = False
    st.session_state.selected_spectrum_idx = None
    st.session_state.selected_node_idx = None
|
| 66 |
|
| 67 |
# ------------------------
|
| 68 |
# Inputs
|
| 69 |
# ------------------------
|
| 70 |
+
st.subheader("Spectra")

# Peaks are entered as two comma-separated lists; both text boxes are
# pre-filled from the matching session-state field.
mz_input = st.text_input(
    label="m/z values (comma-separated):",
    value=st.session_state.mzs,
    placeholder="100,150,200,250,300",
)
intensity_input = st.text_input(
    label="Intensities (comma-separated):",
    value=st.session_state.intensities,
    placeholder="10,50,80,40,20",
)
|
| 82 |
|
| 83 |
st.subheader("SMILES")
|
|
|
|
| 89 |
# Adduct text box, pre-filled from session state.
st.subheader("Adduct")
adduct_input = st.text_input(
    label="Enter adduct:",
    value=st.session_state.adduct,
)

# Precursor m/z text box, pre-filled from session state.
st.subheader("Precursor mz")
precursor_input = st.text_input(
    label="Enter precursor mz:",
    value=st.session_state.precursor_mz,
)
|
| 94 |
+
|
| 95 |
+
# --- Run button toggles flag ---
|
| 96 |
def _abort_run(message):
    """Report *message* as an error, clear the input fields and end this script run.

    Follows the error convention already used by the Run handler:
    ``st.error`` -> ``reset_fields()`` -> ``st.stop()``. ``st.stop()`` does
    not return, so no code after a call to this helper is executed.
    """
    st.error(message)
    reset_fields()
    st.stop()


# Run button: validate the submitted inputs, compute the figure and the
# similarity matrix, and store both in session state for the results view.
if st.button("Run"):
    # Validate what the user actually submitted. The widgets above have no
    # `key`, so their return values (not st.session_state) hold the current
    # text; session state is only refreshed a few lines further down.
    submitted = {
        "mzs": mz_input,
        "intensities": intensity_input,
        "smiles": smiles_input,
        "formula": formula_input,
        "adduct": adduct_input,
        "precursor_mz": precursor_input,
    }
    for f in FIELDS:
        if not str(submitted[f]).strip():
            _abort_run(f"Field {f} is empty.")

    # Persist the submitted values so the widgets keep them on later reruns.
    st.session_state.mzs = mz_input
    st.session_state.intensities = intensity_input
    st.session_state.smiles = smiles_input
    st.session_state.formula = formula_input
    st.session_state.adduct = adduct_input
    st.session_state.precursor_mz = precursor_input

    # Parse the numeric inputs; empty items (e.g. a trailing comma) are skipped.
    try:
        mz_values = [float(x) for x in st.session_state.mzs.split(",") if x.strip()]
        intensity_values = [float(x) for x in st.session_state.intensities.split(",") if x.strip()]
        float(st.session_state.precursor_mz)  # validation only; run() receives the text
    except ValueError:
        _abort_run("m/z values, intensities and precursor mz must be numeric")

    if len(mz_values) != len(intensity_values):
        _abort_run("Number of m/z values must match the number of intensity values")
    if not mz_values:
        _abort_run("At least one peak (m/z and intensity) is required")

    # One row per peak: (m/z, intensity).
    ms = np.array(list(zip(mz_values, intensity_values)))

    fig, sim_norm = run(
        ms,
        st.session_state.smiles,
        st.session_state.formula,
        st.session_state.precursor_mz,
        st.session_state.adduct,
        st.session_state.spec_featurizer,
        st.session_state.mol_featurizer,
        st.session_state.model,
        mass_diff_thresh=20,
        precursor_intensity=1.1,
    )

    # run() signals failure by returning (None, None); do not store that as a
    # result, otherwise the results view would try to render a None figure.
    if fig is None or sim_norm is None:
        _abort_run("Could not compute a result for these inputs")

    st.session_state.fig, st.session_state.sim_norm = fig, sim_norm
    st.session_state.selected_spectrum_idx = None
    st.session_state.selected_node_idx = None
    st.session_state.run_clicked = True
|
| 136 |
+
|
| 137 |
+
# --- Main results ---
|
| 138 |
+
def _color_by_peak(state, idx):
    """Color the node trace by similarity to peak *idx* and highlight that peak.

    Trace 2 (nodes) gets row ``idx`` of ``state.sim_norm`` as marker colors;
    trace 0 (peaks) becomes "red" for the selected peak and "lightgray" otherwise.
    """
    scores = state.sim_norm[idx, :]
    state.fig.data[2].marker.color = scores
    state.fig.data[0].marker.color = [
        "red" if i == idx else "lightgray" for i in range(state.sim_norm.shape[0])
    ]


def _color_by_node(state, idx):
    """Color the peak trace by similarity to node *idx* and highlight that node.

    Trace 0 (peaks) gets column ``idx`` of ``state.sim_norm`` as marker colors;
    trace 2 (nodes) becomes "red" for the selected node and "lightgray" otherwise.
    """
    scores = state.sim_norm[:, idx]
    state.fig.data[0].marker.color = scores
    state.fig.data[2].marker.color = [
        "red" if i == idx else "lightgray" for i in range(state.sim_norm.shape[1])
    ]


# --- Main results: shown once a run has completed ---
if st.session_state.run_clicked:
    st.subheader("Peak-to-Node Similarity")
    st.text("Double click on a peak or node to visualize similarity scores")

    state = st.session_state
    fig = state.fig

    # Re-apply the coloring for the selection remembered from an earlier run.
    if state.selected_spectrum_idx is not None:
        _color_by_peak(state, state.selected_spectrum_idx)
    elif state.selected_node_idx is not None:
        _color_by_node(state, state.selected_node_idx)

    # Render the figure once and collect click events from it.
    selected = plotly_events(
        fig,
        click_event=True,
        hover_event=False,
        key="events",
    )

    # Record the clicked point and recolor the stored figure accordingly.
    # NOTE(review): the figure was already handed to plotly_events above, so
    # this recoloring presumably only becomes visible on the next rerun - confirm.
    if selected:
        point = selected[0]
        curve, idx = point["curveNumber"], point["pointIndex"]

        if curve == 0:  # a spectrum peak was clicked
            state.selected_spectrum_idx = idx
            state.selected_node_idx = None
            _color_by_peak(state, idx)
        elif curve == 2:  # a molecule node was clicked
            state.selected_node_idx = idx
            state.selected_spectrum_idx = None
            _color_by_node(state, idx)
|
| 189 |
|
|
|
app_utils/__init__.py
ADDED
|
File without changes
|
app_utils/examples.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
|
| 3 |
+
# Columns kept from the sample table; 'identifier' becomes the lookup key.
_EXAMPLE_COLUMNS = ['identifier', 'mzs', 'intensities', 'smiles', 'formula', 'precursor_mz', 'adduct']

# Tab-separated sample table, restricted to the columns listed above.
data = pd.read_csv("/data/yzhouc01/FILIP-MS/data/sample/data.tsv", sep='\t')[_EXAMPLE_COLUMNS]

# Maps each identifier to a dict of that row's remaining column values.
EXAMPLES = data.set_index('identifier').to_dict('index')
|
app_utils/model_utils.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.insert(0, "/data/yzhouc01/MassSpecGym")
|
| 3 |
+
sys.path.insert(0, "/data/yzhouc01/FILIP-MS")
|
| 4 |
+
|
| 5 |
+
from rdkit import RDLogger
|
| 6 |
+
from mvp.utils.data import get_spec_featurizer, get_mol_featurizer, get_ms_dataset
|
| 7 |
+
from mvp.utils.models import get_model
|
| 8 |
+
|
| 9 |
+
import yaml
|
| 10 |
+
|
| 11 |
+
# Suppress RDKit warnings and errors
|
| 12 |
+
lg = RDLogger.logger()
|
| 13 |
+
lg.setLevel(RDLogger.CRITICAL)
|
| 14 |
+
|
| 15 |
+
# Load model and data
|
| 16 |
+
|
| 17 |
+
# Default artifact locations; used when load_model_components() is called
# without arguments.
DEFAULT_PARAM_PTH = '/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model/lightning_logs/version_0/hparams.yaml'
DEFAULT_CHECKPOINT_PTH = "/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model/epoch=1993-train_loss=0.10.ckpt"


def load_model_components(param_pth=DEFAULT_PARAM_PTH, checkpoint_pth=DEFAULT_CHECKPOINT_PTH):
    """Build the spectrum featurizer, molecule featurizer and model.

    Args:
        param_pth: Path to the hyper-parameter YAML file. Its 'spectra_view',
            'molecule_view' and 'model' entries select what is built.
        checkpoint_pth: Checkpoint path stored under ``params['checkpoint_pth']``
            before the model is created.

    Returns:
        Tuple ``(spec_featurizer, mol_featurizer, model)``.

    Raises:
        OSError: If ``param_pth`` cannot be opened.
        KeyError: If a required entry is missing from the parameter file.
    """
    with open(param_pth, encoding="utf-8") as f:
        # NOTE: FullLoader can construct more than plain data types; only
        # point param_pth at trusted files.
        params = yaml.load(f, Loader=yaml.FullLoader)

    spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
    mol_featurizer = get_mol_featurizer(params['molecule_view'], params)

    # The checkpoint path is handed to get_model through the params dict.
    # NOTE(review): presumably get_model restores the weights from it - confirm.
    params['checkpoint_pth'] = checkpoint_pth
    model = get_model(params['model'], params)

    return spec_featurizer, mol_featurizer, model
|
app_utils/viz_utils.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import plotly.graph_objects as go
|
| 6 |
+
from plotly.subplots import make_subplots
|
| 7 |
+
from rdkit import Chem
|
| 8 |
+
from rdkit.Chem import rdDepictor
|
| 9 |
+
import pandas as pd
|
| 10 |
+
|
| 11 |
+
def mol_to_graph_coords(mol):
    """Compute a 2D layout for *mol* and return its atom positions and bonds.

    Note that ``rdDepictor.Compute2DCoords`` is called on the molecule that
    is passed in.

    Returns:
        ``(coords, bonds)`` where ``coords`` maps each atom index to the
        position reported by the molecule's conformer, and ``bonds`` is a
        list of ``(begin_atom_idx, end_atom_idx)`` tuples.
    """
    rdDepictor.Compute2DCoords(mol)
    conformer = mol.GetConformer()

    coords = {}
    for atom_idx in range(mol.GetNumAtoms()):
        coords[atom_idx] = conformer.GetAtomPosition(atom_idx)

    bonds = []
    for bond in mol.GetBonds():
        bonds.append((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()))

    return coords, bonds
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
import plotly.graph_objects as go
|
| 22 |
+
from plotly.subplots import make_subplots
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def interactive_attention_visualization(
    spectral_embeds,
    graph_embeds,
    peak_mzs,
    peak_intensities,
    peak_formulas,
    mol
):
    """
    Build the base Plotly figure and the similarity matrix for Streamlit interactivity.

    The embeddings are L2-normalized along their last dimension, so the matrix
    product is a cosine similarity; it is then min-max scaled over the whole
    matrix. Streamlit handles clicks and recoloring using ``sim_norm``.

    Args:
        spectral_embeds: 2-D tensor with one row per peak.
        graph_embeds: 2-D tensor with one row per node.
        peak_mzs: m/z value of each peak (x axis of the spectrum panel).
        peak_intensities: Intensity of each peak (y axis of the spectrum panel).
        peak_formulas: Formula label of each peak, shown in the hover text.
        mol: RDKit molecule drawn in the right-hand panel.

    Returns:
        ``(fig, sim_norm)``. ``sim_norm`` is a NumPy array of shape
        ``(num_peaks, num_nodes)``. The figure's traces are, in order:
        0 = spectrum peaks, 1 = bonds, 2 = atoms/nodes.
    """

    # --- Similarity matrix ---
    spectral_embeds = F.normalize(spectral_embeds, p=2, dim=-1)
    graph_embeds = F.normalize(graph_embeds, p=2, dim=-1)

    similarity = torch.matmul(spectral_embeds, graph_embeds.T).detach().cpu().numpy()
    # The small epsilon avoids a division by zero when all similarities are equal.
    sim_norm = (similarity - similarity.min()) / (similarity.max() - similarity.min() + 1e-8)

    num_peaks, num_nodes = similarity.shape

    # --- Molecule graph ---
    coords, bonds = mol_to_graph_coords(mol)
    atom_labels = [a.GetSymbol() for a in mol.GetAtoms()]
    # NOTE(review): assumes graph_embeds has exactly one row per atom of `mol`,
    # in atom-index order - confirm against the molecule featurizer.
    atom_x = [coords[i].x for i in range(num_nodes)]
    atom_y = [coords[i].y for i in range(num_nodes)]

    # --- Spectrum trace ---
    spectrum_trace = go.Scatter(
        x=peak_mzs,
        y=peak_intensities,
        mode='markers',  # crucial for clickable peaks
        name="peak",
        marker=dict(
            size=12,
            color="lightgray",
            colorscale="Viridis",
            cmin=0,
            cmax=1,
            colorbar=dict(title="Similarity", len=0.8, y=0.5),
        ),
        hovertext=[f"{f} \n ({m:,.2f}, {i:.2})" for f, m, i in zip(peak_formulas, peak_mzs, peak_intensities)],
        hoverinfo='text',
        customdata=list(range(num_peaks)),  # actual peak indices
    )

    # --- Graph nodes ---
    graph_nodes = go.Scatter(
        x=atom_x,
        y=atom_y,
        mode="markers+text",
        name="node",
        text=atom_labels,
        textposition="middle center",
        marker=dict(
            size=20,
            color="lightgray",
            colorscale="Viridis",
            cmin=0,
            cmax=1,
            colorbar=dict(title="Similarity", len=0.8, y=0.5),
        ),
        # One entry per plotted node, mirroring the peak trace above.
        customdata=list(range(num_nodes)),
    )

    # --- Graph bonds ---
    # Each bond is one line segment; None separates consecutive segments.
    edge_x, edge_y = [], []
    for i, j in bonds:
        edge_x += [coords[i].x, coords[j].x, None]
        edge_y += [coords[i].y, coords[j].y, None]
    graph_edges = go.Scatter(
        x=edge_x,
        y=edge_y,
        mode="lines",
        line=dict(color="gray", width=2),
        hoverinfo="none",
        showlegend=False,
    )

    # --- Subplots ---
    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("Spectrum", "Molecule"),
        column_widths=[0.6, 0.4],
    )

    # Trace order matters: callers address peaks as data[0] and nodes as data[2].
    fig.add_trace(spectrum_trace, row=1, col=1)
    fig.add_trace(graph_edges, row=1, col=2)
    fig.add_trace(graph_nodes, row=1, col=2)

    fig.update_xaxes(title="m/z", row=1, col=1)
    fig.update_yaxes(title="Intensity", row=1, col=1)
    fig.update_xaxes(visible=False, row=1, col=2)
    fig.update_yaxes(visible=False, row=1, col=2)

    fig.update_layout(showlegend=False)

    return fig, sim_norm
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# ------------------------
|
| 128 |
+
# Model set up
|
| 129 |
+
# ------------------------
|
| 130 |
+
|
| 131 |
+
from mvp.subformula_assign.utils.spectra_utils import assign_subforms
|
| 132 |
+
import matchms
|
| 133 |
+
|
| 134 |
+
def run(ms, smiles, formula, precursor_mz, adduct, spec_featurizer, mol_featurizer,model, mass_diff_thresh=20, precursor_intensity=1.1):
    """Label the peaks, embed spectrum and molecule, and build the figure.

    Args:
        ms: Peak data passed to ``assign_subforms`` after ``np.array``.
        smiles: SMILES string of the molecule.
        formula: Molecular formula; also the label of the precursor peak.
        precursor_mz: Precursor m/z; converted with ``float`` when a
            precursor peak has to be appended.
        adduct: Adduct passed to ``assign_subforms``.
        spec_featurizer: Mapping whose ``'SpecFormula'`` entry encodes a
            ``matchms.Spectrum``.
        mol_featurizer: Callable that encodes the SMILES string.
        model: Model moved to the CPU and called via ``forward`` with
            ``stage='test'``.
        mass_diff_thresh: Forwarded to ``assign_subforms``.
        precursor_intensity: Intensity given to the precursor peak.

    Returns:
        ``(fig, sim_norm)`` from ``interactive_attention_visualization``, or
        ``(None, None)`` when ``assign_subforms`` yields no output table.
    """
    # Step 1: label peaks with formulas and set up the matchms spectrum.
    assignment = assign_subforms(formula, np.array(ms), adduct, mass_diff_thresh=mass_diff_thresh)
    table = assignment['output_tbl']
    if table is None:
        return None, None

    formulas = np.array(table['formula'])
    raw_mzs = table['mz']
    raw_intensities = table['ms2_inten']
    mzs = np.array([float(value) for value in raw_mzs])
    intensities = np.array([float(value) for value in raw_intensities])

    if formula in formulas:
        # The precursor is already a labelled peak: overwrite its intensity.
        precursor_rows = np.where(formulas==formula)[0]
        intensities[precursor_rows] = precursor_intensity
    else:
        # Otherwise append the precursor as an additional peak.
        mzs = np.concatenate([mzs, [float(precursor_mz)]])
        formulas = np.concatenate([formulas, [formula]])
        intensities = np.concatenate([intensities, [float(precursor_intensity)]])

    # Order all three arrays by ascending m/z.
    order = np.argsort(mzs)
    mzs, intensities, formulas = mzs[order], intensities[order], formulas[order]

    spectrum = matchms.Spectrum(
        mz=mzs,
        intensities=intensities,
        metadata={'precursor_mz': precursor_mz, 'formulas': formulas},
    )

    # Step 2: featurize the spectrum.
    spectrum_encoding = spec_featurizer['SpecFormula'](spectrum)

    # Step 3: featurize the molecule.
    molecule_encoding = mol_featurizer(smiles)

    # Step 4: embed spectrum and molecule on the CPU without gradients.
    model_input = {'mol': molecule_encoding, 'SpecFormula': spectrum_encoding}
    model = model.to(torch.device('cpu'))
    model.eval()
    with torch.no_grad():
        spec_embed, mol_embed = model.forward(model_input, stage='test')

    # Step 5: build the figure and the normalized similarity matrix.
    mol = Chem.MolFromSmiles(smiles)
    fig, sim_norm = interactive_attention_visualization(
        spec_embed, mol_embed, mzs, intensities, formulas, mol
    )
    return fig, sim_norm
|