yzhouchen001 commited on
Commit
326e019
·
1 Parent(s): 135849c

visuals added

Browse files
app.py CHANGED
@@ -1,6 +1,12 @@
1
  import streamlit as st
2
  import pandas as pd
3
  import io
 
 
 
 
 
 
4
 
5
  st.set_page_config(page_title="Spectra Tool Demo", layout="wide")
6
 
@@ -8,68 +14,70 @@ st.title("Spectra Visualization Tool")
8
 
9
  st.markdown("Provide inputs below or load one of the example datasets.")
10
 
11
- # ------------------------
12
- # Example presets
13
- # ------------------------
14
- examples = {
15
- "Example 1": {
16
- "spectra": """mz,intensity
17
- 100,10
18
- 150,50
19
- 200,80
20
- 250,40
21
- 300,20
22
- """,
23
- "smiles": "CCO", # ethanol
24
- "formula": "C2H6O",
25
- "adduct": "[M+H]+",
26
- },
27
- "Example 2": {
28
- "spectra": """mz,intensity
29
- 120,15
30
- 180,60
31
- 240,30
32
- 300,70
33
- 360,25
34
- """,
35
- "smiles": "C6H6", # benzene
36
- "formula": "C6H6",
37
- "adduct": "[M+Na]+",
38
- },
39
- }
40
 
41
  # ------------------------
42
  # Session state defaults
43
  # ------------------------
44
- if "spectra" not in st.session_state:
45
- st.session_state.spectra = ""
46
- if "smiles" not in st.session_state:
47
- st.session_state.smiles = ""
48
- if "formula" not in st.session_state:
49
- st.session_state.formula = ""
50
- if "adduct" not in st.session_state:
51
- st.session_state.adduct = ""
 
 
 
 
 
 
 
 
52
 
53
  # ------------------------
54
- # Example loader buttons
55
  # ------------------------
56
- cols = st.columns(len(examples))
57
- for i, (ex_name, ex_data) in enumerate(examples.items()):
58
- if cols[i].button(f"Load {ex_name}"):
59
- st.session_state.spectra = ex_data["spectra"]
60
- st.session_state.smiles = ex_data["smiles"]
61
- st.session_state.formula = ex_data["formula"]
62
- st.session_state.adduct = ex_data["adduct"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  # ------------------------
65
  # Inputs
66
  # ------------------------
67
- st.subheader("Spectra (two-column CSV: mz, intensity)")
68
- spectra_text = st.text_area(
69
- "Paste spectra data here:",
70
- value=st.session_state.spectra,
71
- height=150,
72
- placeholder="mz,intensity\n100,10\n150,50\n..."
 
 
 
 
 
73
  )
74
 
75
  st.subheader("SMILES")
@@ -81,22 +89,101 @@ formula_input = st.text_input("Enter molecular formula:", value=st.session_state
81
  st.subheader("Adduct")
82
  adduct_input = st.text_input("Enter adduct:", value=st.session_state.adduct)
83
 
84
- # ------------------------
85
- # Run button
86
- # ------------------------
 
87
  if st.button("Run"):
88
- st.subheader("Results")
89
 
90
- # Try parsing spectra
91
- try:
92
- spectra_df = pd.read_csv(io.StringIO(spectra_text))
93
- st.write("Spectra Preview:")
94
- st.dataframe(spectra_df.head())
95
- except Exception as e:
96
- st.error(f"Could not parse spectra: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
- st.write("**SMILES:**", smiles_input)
99
- st.write("**Formula:**", formula_input)
100
- st.write("**Adduct:**", adduct_input)
 
 
 
 
 
101
 
102
- st.info("🔬 Interactive visualization will be displayed here.")
 
1
  import streamlit as st
2
  import pandas as pd
3
  import io
4
+ from app_utils.model_utils import load_model_components
5
+ from app_utils.viz_utils import run
6
+ from app_utils.examples import EXAMPLES
7
+ import numpy as np
8
+ from streamlit_plotly_events import plotly_events
9
+
10
 
11
  st.set_page_config(page_title="Spectra Tool Demo", layout="wide")
12
 
 
14
 
15
  st.markdown("Provide inputs below or load one of the example datasets.")
16
 
17
# Names of the user-editable inputs that are mirrored in st.session_state.
FIELDS = ['mzs', 'intensities', 'smiles', 'formula', 'adduct', 'precursor_mz']


def reset_fields():
    """Set every entry of FIELDS in st.session_state to the empty string."""
    st.session_state.update(dict.fromkeys(FIELDS, ""))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  # ------------------------
23
  # Session state defaults
24
  # ------------------------
25
# Seed session state once per session; keys that already exist are left alone
# so values survive Streamlit's script reruns.
_STATE_DEFAULTS = {
    "run_clicked": False,
    "selected_spectrum_idx": None,
    "selected_node_idx": None,
}
_STATE_DEFAULTS.update(dict.fromkeys(FIELDS, ""))

for _name, _default in _STATE_DEFAULTS.items():
    if _name not in st.session_state:
        st.session_state[_name] = _default

# Load the featurizers and the model only when this session has none yet.
if "model" not in st.session_state:
    (
        st.session_state.spec_featurizer,
        st.session_state.mol_featurizer,
        st.session_state.model,
    ) = load_model_components()
41
 
42
  # ------------------------
43
+ # Example loader dropdown
44
  # ------------------------
45
_NO_SELECTION = "-- Select --"
example_names = list(EXAMPLES.keys())

# Dropdown menu for selecting an example; the first entry is a placeholder.
selected_example = st.selectbox("Choose an example:", [_NO_SELECTION] + example_names)

# The button is always evaluated first; the example is only loaded when a real
# entry (not the placeholder) is selected.
load_clicked = st.button("Load Example")
if load_clicked and selected_example != _NO_SELECTION:
    reset_fields()
    ex_data = EXAMPLES[selected_example]

    # Copy every input field of the chosen example into session state
    # (FIELDS lists them in the order mzs, intensities, smiles, formula,
    # adduct, precursor_mz).
    for _field in FIELDS:
        st.session_state[_field] = ex_data[_field]

    # Reset the graph: hide old results and forget any previous selection.
    st.session_state.run_clicked = False
    st.session_state.selected_spectrum_idx = None
    st.session_state.selected_node_idx = None
66
 
67
  # ------------------------
68
  # Inputs
69
  # ------------------------
70
st.subheader("Spectra")

# Both spectrum inputs are free-text, comma-separated lists that are
# pre-filled from session state.
mz_input = st.text_input(
    label="m/z values (comma-separated):",
    value=st.session_state.mzs,
    placeholder="100,150,200,250,300",
)

intensity_input = st.text_input(
    label="Intensities (comma-separated):",
    value=st.session_state.intensities,
    placeholder="10,50,80,40,20",
)
82
 
83
  st.subheader("SMILES")
 
89
# Adduct and precursor m/z are plain text inputs pre-filled from session state.
st.subheader("Adduct")
adduct_input = st.text_input(
    label="Enter adduct:",
    value=st.session_state.adduct,
)

st.subheader("Precursor mz")
precursor_input = st.text_input(
    label="Enter precursor mz:",
    value=st.session_state.precursor_mz,
)
94
+
95
+ # --- Run button toggles flag ---
96
if st.button("Run"):
    # Validate what is currently typed into the widgets. Session state only
    # holds the values from the last example load or successful run, so
    # checking it here would reject input that was entered by hand.
    submitted = {
        'mzs': mz_input,
        'intensities': intensity_input,
        'smiles': smiles_input,
        'formula': formula_input,
        'adduct': adduct_input,
        'precursor_mz': precursor_input,
    }
    for f in FIELDS:
        if not str(submitted[f]).strip():
            st.error(f"Field {f} is empty.")
            reset_fields()
            st.stop()

    # Persist the submitted values so the inputs keep them across reruns.
    for f in FIELDS:
        st.session_state[f] = submitted[f]

    # Parse the numeric inputs; report bad numbers instead of letting the
    # ValueError surface as a traceback.
    try:
        mz_values = [float(x) for x in str(submitted['mzs']).split(",") if x.strip()]
        intensity_values = [
            float(x) for x in str(submitted['intensities']).split(",") if x.strip()
        ]
        float(submitted['precursor_mz'])  # run() converts this with float() as well
    except ValueError:
        st.error("m/z values, intensities and precursor mz must be numeric.")
        st.session_state.run_clicked = False
        st.stop()

    if not mz_values:
        # e.g. the input consisted of separators only
        st.error("At least one m/z value is required.")
        st.session_state.run_clicked = False
        st.stop()

    if len(mz_values) != len(intensity_values):
        st.error("Number of m/z values must match the number of intensity values")
        reset_fields()
        st.stop()

    # One (mz, intensity) row per peak.
    ms = np.array(list(zip(mz_values, intensity_values)))

    fig, sim_norm = run(
        ms,
        st.session_state.smiles,
        st.session_state.formula,
        st.session_state.precursor_mz,
        st.session_state.adduct,
        st.session_state.spec_featurizer,
        st.session_state.mol_featurizer,
        st.session_state.model,
        mass_diff_thresh=20,
        precursor_intensity=1.1
    )

    # run() returns (None, None) when it cannot build a figure; storing that
    # would crash the results section, which indexes into fig.data.
    if fig is None or sim_norm is None:
        st.error("Could not process the spectrum with the given inputs.")
        st.session_state.run_clicked = False
        st.stop()

    st.session_state.fig = fig
    st.session_state.sim_norm = sim_norm
    st.session_state.selected_spectrum_idx = None
    st.session_state.selected_node_idx = None
    st.session_state.run_clicked = True
136
+
137
+ # --- Main results ---
138
def _apply_highlight(figure, spectrum_idx=None, node_idx=None):
    """Recolour the spectrum (trace 0) and node (trace 2) markers of ``figure``.

    A selected peak colours the nodes by that peak's row of ``sim_norm`` and
    marks the peak red; a selected node does the mirror image using the node's
    column. With neither index given the figure is left untouched.
    """
    if spectrum_idx is not None:
        sim = st.session_state.sim_norm
        figure.data[2].marker.color = sim[spectrum_idx, :]
        figure.data[0].marker.color = [
            "red" if i == spectrum_idx else "lightgray" for i in range(sim.shape[0])
        ]
    elif node_idx is not None:
        sim = st.session_state.sim_norm
        figure.data[0].marker.color = sim[:, node_idx]
        figure.data[2].marker.color = [
            "red" if i == node_idx else "lightgray" for i in range(sim.shape[1])
        ]


if st.session_state.run_clicked:
    st.subheader("Peak-to-Node Similarity")
    st.text("Double click on a peak or node to visualize similarity scores")

    fig = st.session_state.fig  # same object that lives in session state

    # Apply the colouring for the stored selection before rendering.
    _apply_highlight(
        fig,
        spectrum_idx=st.session_state.selected_spectrum_idx,
        node_idx=st.session_state.selected_node_idx,
    )

    # Render figure once with plotly_events
    selected = plotly_events(
        fig,
        click_event=True,
        hover_event=False,
        key="events"
    )

    # Record the clicked point and recolour the stored figure.
    # NOTE(review): this recolouring runs after the figure was already handed
    # to plotly_events above, so it presumably only becomes visible on the
    # next rerun - confirm.
    if selected:
        point = selected[0]
        curve, idx = point["curveNumber"], point["pointIndex"]

        if curve == 0:  # Spectrum clicked
            st.session_state.selected_spectrum_idx = idx
            st.session_state.selected_node_idx = None
            _apply_highlight(fig, spectrum_idx=idx)
        elif curve == 2:  # Node clicked
            st.session_state.selected_node_idx = idx
            st.session_state.selected_spectrum_idx = None
            _apply_highlight(fig, node_idx=idx)
189
 
 
app_utils/__init__.py ADDED
File without changes
app_utils/examples.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
import os

import pandas as pd

# Location of the tab-separated example table. The original absolute path is
# kept as the default; set SPECTRA_EXAMPLES_PATH to point at a copy of the
# file when running on a machine where that path does not exist.
_DEFAULT_EXAMPLES_PATH = "/data/yzhouc01/FILIP-MS/data/sample/data.tsv"
_EXAMPLES_PATH = os.environ.get("SPECTRA_EXAMPLES_PATH", _DEFAULT_EXAMPLES_PATH)

# Columns the app needs from the table; 'identifier' becomes the example name.
_EXAMPLE_COLUMNS = ['identifier', 'mzs', 'intensities', 'smiles', 'formula', 'precursor_mz', 'adduct']

data = pd.read_csv(_EXAMPLES_PATH, sep='\t')
data = data[_EXAMPLE_COLUMNS]

# Mapping of identifier -> {column name: value} for every row of the table.
EXAMPLES = data.set_index('identifier').to_dict('index')
app_utils/model_utils.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import sys

# Put the local MassSpecGym and FILIP-MS checkouts at the front of the import
# path. FILIP-MS comes first, which is the order two successive
# insert(0, ...) calls (MassSpecGym, then FILIP-MS) would leave behind.
sys.path[:0] = ["/data/yzhouc01/FILIP-MS", "/data/yzhouc01/MassSpecGym"]

from rdkit import RDLogger
from mvp.utils.data import get_spec_featurizer, get_mol_featurizer, get_ms_dataset
from mvp.utils.models import get_model

import yaml

# Suppress RDKit warnings and errors: only CRITICAL messages get through.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)
14
+
15
+ # Load model and data
16
+
17
def load_model_components(param_pth=None, checkpoint_pth=None):
    """Build the spectrum featurizer, molecule featurizer and model.

    Args:
        param_pth: Path to the ``hparams.yaml`` file holding the experiment
            parameters. Defaults to the 20250913 optimized FILIP model run.
        checkpoint_pth: Path to the model checkpoint. It is stored in the
            parameters under ``'checkpoint_pth'`` before the model is built.
            Defaults to the checkpoint of the same run.

    Returns:
        A ``(spec_featurizer, mol_featurizer, model)`` tuple.
    """
    if param_pth is None:
        param_pth = '/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model/lightning_logs/version_0/hparams.yaml'
    if checkpoint_pth is None:
        checkpoint_pth = "/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model/epoch=1993-train_loss=0.10.ckpt"

    with open(param_pth) as f:
        # NOTE(review): yaml.load with FullLoader should only be used on
        # trusted files; switch to yaml.safe_load if the hparams file is
        # confirmed to contain plain YAML only.
        params = yaml.load(f, Loader=yaml.FullLoader)

    # The featurizers are selected by the view names stored in the parameters.
    spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
    mol_featurizer = get_mol_featurizer(params['molecule_view'], params)

    # The checkpoint location is handed to get_model through the params dict.
    params['checkpoint_pth'] = checkpoint_pth
    model = get_model(params['model'], params)

    return spec_featurizer, mol_featurizer, model
app_utils/viz_utils.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch.nn.functional as F
3
+ import torch
4
+ import numpy as np
5
+ import plotly.graph_objects as go
6
+ from plotly.subplots import make_subplots
7
+ from rdkit import Chem
8
+ from rdkit.Chem import rdDepictor
9
+ import pandas as pd
10
+
11
def mol_to_graph_coords(mol):
    """Compute a 2D layout for ``mol`` and return its atom positions and bonds.

    Note that the 2D coordinates are computed on ``mol`` itself.

    Returns:
        A ``(coords, bonds)`` tuple: ``coords`` maps each atom index to the
        position reported by the molecule's conformer, and ``bonds`` is a list
        of ``(begin_atom_idx, end_atom_idx)`` pairs, one per bond.
    """
    rdDepictor.Compute2DCoords(mol)
    conformer = mol.GetConformer()

    coords = {}
    for atom_idx in range(mol.GetNumAtoms()):
        coords[atom_idx] = conformer.GetAtomPosition(atom_idx)

    bonds = []
    for bond in mol.GetBonds():
        bonds.append((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()))

    return coords, bonds
18
+
19
+ import torch
20
+ import torch.nn.functional as F
21
+ import plotly.graph_objects as go
22
+ from plotly.subplots import make_subplots
23
+
24
+
25
def interactive_attention_visualization(
    spectral_embeds,
    graph_embeds,
    peak_mzs,
    peak_intensities,
    peak_formulas,
    mol
):
    """
    Build the base Plotly figure and the similarity matrix for Streamlit.

    The figure has two panels: the spectrum (peaks as markers) on the left and
    the molecule graph (bonds as lines, atoms as labelled markers) on the
    right. No recolouring happens here; the caller recolours the markers from
    the returned ``sim_norm``.

    Args:
        spectral_embeds: Torch tensor of peak embeddings (one row per peak).
        graph_embeds: Torch tensor of node embeddings (one row per node).
        peak_mzs: m/z value of each peak.
        peak_intensities: Intensity of each peak.
        peak_formulas: Formula label of each peak, shown in the hover text.
        mol: RDKit molecule whose atoms are drawn as the graph nodes.

    Returns:
        ``(fig, sim_norm)`` where ``sim_norm`` is a NumPy array of shape
        ``(num_peaks, num_nodes)`` holding the cosine similarities min-max
        scaled over the whole matrix.
    """

    # --- Similarity matrix ---
    # L2-normalising both sides turns the dot product into cosine similarity.
    spectral_embeds = F.normalize(spectral_embeds, p=2, dim=-1)
    graph_embeds = F.normalize(graph_embeds, p=2, dim=-1)

    similarity = torch.matmul(spectral_embeds, graph_embeds.T).detach().cpu().numpy()
    # The small epsilon avoids a division by zero when all entries are equal.
    sim_norm = (similarity - similarity.min()) / (similarity.max() - similarity.min() + 1e-8)

    num_peaks, num_nodes = similarity.shape

    # --- Molecule graph ---
    coords, bonds = mol_to_graph_coords(mol)
    atom_labels = [a.GetSymbol() for a in mol.GetAtoms()]
    # Only the first num_nodes atoms are drawn, one per column of the matrix.
    atom_x = [coords[i].x for i in range(num_nodes)]
    atom_y = [coords[i].y for i in range(num_nodes)]

    # --- Spectrum trace ---
    spectrum_trace = go.Scatter(
        x=peak_mzs,
        y=peak_intensities,
        mode='markers',  # crucial for clickable peaks
        name="peak",
        marker=dict(
            size=12,
            color="lightgray",
            colorscale="Viridis",
            cmin=0,
            cmax=1,
            colorbar=dict(title="Similarity", len=0.8, y=0.5),
        ),
        hovertext=[f"{f} \n ({m:,.2f}, {i:.2})" for f, m, i in zip(peak_formulas, peak_mzs, peak_intensities)],
        hoverinfo='text',
        customdata=list(range(num_peaks)),  # actual peak indices
    )

    # --- Graph nodes ---
    graph_nodes = go.Scatter(
        x=atom_x,
        y=atom_y,
        mode="markers+text",
        name="node",
        text=atom_labels,
        textposition="middle center",
        marker=dict(
            size=20,
            color="lightgray",
            colorscale="Viridis",
            cmin=0,
            cmax=1,
            colorbar=dict(title="Similarity", len=0.8, y=0.5),
        ),
        # One index per plotted node (previously one entry too many).
        customdata=list(range(num_nodes)),
    )

    # --- Graph bonds ---
    # Each bond is one line segment; the None entries break the line between
    # consecutive segments.
    edge_x, edge_y = [], []
    for i, j in bonds:
        edge_x += [coords[i].x, coords[j].x, None]
        edge_y += [coords[i].y, coords[j].y, None]
    graph_edges = go.Scatter(
        x=edge_x,
        y=edge_y,
        mode="lines",
        line=dict(color="gray", width=2),
        hoverinfo="none",
        showlegend=False,
    )

    # --- Subplots ---
    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("Spectrum", "Molecule"),
        column_widths=[0.6, 0.4],
    )

    # Trace order is part of the contract: the Streamlit app addresses the
    # spectrum as fig.data[0] and the nodes as fig.data[2].
    fig.add_trace(spectrum_trace, row=1, col=1)
    fig.add_trace(graph_edges, row=1, col=2)
    fig.add_trace(graph_nodes, row=1, col=2)

    fig.update_xaxes(title="m/z", row=1, col=1)
    fig.update_yaxes(title="Intensity", row=1, col=1)
    fig.update_xaxes(visible=False, row=1, col=2)
    fig.update_yaxes(visible=False, row=1, col=2)

    fig.update_layout(showlegend=False)

    return fig, sim_norm
125
+
126
+
127
+ # ------------------------
128
+ # Model set up
129
+ # ------------------------
130
+
131
+ from mvp.subformula_assign.utils.spectra_utils import assign_subforms
132
+ import matchms
133
+
134
def run(ms, smiles, formula, precursor_mz, adduct, spec_featurizer, mol_featurizer,model, mass_diff_thresh=20, precursor_intensity=1.1):
    """Annotate a spectrum, embed it with its molecule, and build the figure.

    Args:
        ms: Peaks of the spectrum; converted with ``np.array`` before
            subformula assignment.
        smiles: SMILES string of the molecule.
        formula: Molecular formula, also used as the precursor's peak label.
        precursor_mz: Precursor m/z; converted with ``float`` when a precursor
            peak has to be appended.
        adduct: Adduct passed to the subformula assignment.
        spec_featurizer: Mapping whose ``'SpecFormula'`` entry encodes the
            spectrum.
        mol_featurizer: Callable that encodes the SMILES string.
        model: Model that embeds both encodings; it is moved to the CPU and
            put in eval mode.
        mass_diff_thresh: Threshold forwarded to ``assign_subforms``.
        precursor_intensity: Intensity given to the precursor peak.

    Returns:
        ``(fig, sim_norm)`` from ``interactive_attention_visualization``, or
        ``(None, None)`` when the subformula assignment yields no table.
    """
    # step 1 - label peaks with formula, setup matchms spectrum
    annotation = assign_subforms(formula, np.array(ms), adduct, mass_diff_thresh=mass_diff_thresh)
    table = annotation['output_tbl']
    if table is None:
        return None, None

    formulas = np.array(table['formula'])
    mzs = np.array([float(value) for value in table['mz']])
    intensities = np.array([float(value) for value in table['ms2_inten']])

    # Make sure the precursor is present: overwrite its intensity when a peak
    # already carries the full formula, otherwise append it as a new peak.
    if formula in formulas:
        precursor_rows = np.where(formulas == formula)[0]
        intensities[precursor_rows] = precursor_intensity
    else:
        mzs = np.concatenate([mzs, [float(precursor_mz)]])
        formulas = np.concatenate([formulas, [formula]])
        intensities = np.concatenate([intensities, [float(precursor_intensity)]])

    # Keep the three arrays aligned while sorting the peaks by m/z.
    order = np.argsort(mzs)
    mzs, intensities, formulas = mzs[order], intensities[order], formulas[order]

    spectrum = matchms.Spectrum(
        mz=mzs,
        intensities=intensities,
        metadata={'precursor_mz': precursor_mz, 'formulas': formulas},
    )

    # step 2 - featurize spectra
    spectrum_encoding = spec_featurizer['SpecFormula'](spectrum)

    # step 3 - featurize molecule
    molecule_encoding = mol_featurizer(smiles)

    # step 4 - embed spectra & molecules on the CPU, without gradients
    model_input = {'mol': molecule_encoding, 'SpecFormula': spectrum_encoding}

    model = model.to(torch.device('cpu'))
    model.eval()
    with torch.no_grad():
        spec_embed, mol_embed = model.forward(model_input, stage='test')

    # step 5 - visualization
    mol = Chem.MolFromSmiles(smiles)
    return interactive_attention_visualization(
        spec_embed, mol_embed, mzs, intensities, formulas, mol
    )