# FLARE / app_utils/viz_utils.py
# Last change: commit 19a4dfc ("update") by yzhouchen001
import torch.nn.functional as F
import torch
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from rdkit import Chem
from rdkit.Chem import rdDepictor
from flare.subformula_assign.utils.spectra_utils import assign_subforms
import matchms
def mol_to_graph_coords(mol):
    """Compute a 2D depiction of ``mol`` and return its drawing geometry.

    ``rdDepictor.Compute2DCoords`` is invoked on ``mol`` itself, so the
    molecule is modified in place (it gains a computed 2D conformer).

    Args:
        mol: an RDKit molecule.

    Returns:
        A ``(coords, bonds)`` pair:
          * ``coords`` maps every atom index to the position object returned
            by ``conformer.GetAtomPosition`` (read via ``.x`` / ``.y``);
          * ``bonds`` is a list of ``(begin_atom_idx, end_atom_idx)`` tuples,
            one per bond, in the order ``mol.GetBonds()`` yields them.
    """
    rdDepictor.Compute2DCoords(mol)
    conformer = mol.GetConformer()

    positions = {}
    for atom_idx in range(mol.GetNumAtoms()):
        positions[atom_idx] = conformer.GetAtomPosition(atom_idx)

    edges = []
    for bond in mol.GetBonds():
        edges.append((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()))

    return positions, edges
def interactive_attention_visualization(
    spectral_embeds,
    graph_embeds,
    peak_mzs,
    peak_intensities,
    peak_formulas,
    mol,
):
    """Build the base Plotly figure plus the peak/node similarity matrix.

    The figure has two panels: the spectrum (left, one marker per peak) and a
    2D depiction of the molecule (right, one marker per graph node, with bonds
    drawn as lines). Every marker starts light gray; Streamlit handles clicks
    and recolors markers using the returned ``sim_norm``.

    Args:
        spectral_embeds: torch tensor of per-peak embeddings. Assumed shape
            (num_peaks, dim) -- the similarity matrix is unpacked as 2D.
        graph_embeds: torch tensor of per-node embeddings. Assumed shape
            (num_nodes, dim) with num_nodes <= mol.GetNumAtoms(); TODO confirm
            the node order matches RDKit atom indices.
        peak_mzs: per-peak m/z values (x axis of the spectrum panel).
        peak_intensities: per-peak intensities (y axis of the spectrum panel).
        peak_formulas: per-peak formula labels shown in the hover text.
        mol: RDKit molecule; 2D coordinates are computed on it in place.

    Returns:
        ``(fig, sim_norm)``: the Plotly figure and a numpy array of shape
        (num_peaks, num_nodes) holding cosine similarities min-max scaled
        to [0, 1].
    """
    # --- Similarity matrix ---
    # L2-normalizing both sides turns the matmul into cosine similarity.
    spectral_embeds = F.normalize(spectral_embeds, p=2, dim=-1)
    graph_embeds = F.normalize(graph_embeds, p=2, dim=-1)
    similarity = torch.matmul(spectral_embeds, graph_embeds.T).detach().cpu().numpy()
    # Min-max scale into [0, 1] to match the marker color range (cmin=0, cmax=1);
    # the epsilon avoids division by zero when every entry is equal.
    sim_norm = (similarity - similarity.min()) / (
        similarity.max() - similarity.min() + 1e-8
    )
    num_peaks, num_nodes = similarity.shape

    # --- Molecule graph ---
    coords, bonds = mol_to_graph_coords(mol)
    # One marker per graph node: slice the labels so text, x and y stay aligned.
    atom_labels = [a.GetSymbol() for a in mol.GetAtoms()][:num_nodes]
    atom_x = [coords[i].x for i in range(num_nodes)]
    atom_y = [coords[i].y for i in range(num_nodes)]

    # --- Spectrum trace ---
    spectrum_trace = go.Scatter(
        x=peak_mzs,
        y=peak_intensities,
        mode="markers",  # markers (not lines) so individual peaks are clickable
        name="peak",
        marker=dict(
            size=12,
            color="lightgray",
            colorscale="Viridis",
            cmin=0,
            cmax=1,
            colorbar=dict(title="Similarity", len=0.8, y=0.5),
        ),
        # Plotly hover labels are HTML: "<br>" breaks the line, "\n" does not.
        hovertext=[
            f"{f}<br>({m:,.2f}, {i:.2})"
            for f, m, i in zip(peak_formulas, peak_mzs, peak_intensities)
        ],
        hoverinfo="text",
        customdata=list(range(num_peaks)),  # row index into sim_norm per peak
    )

    # --- Graph nodes ---
    graph_nodes = go.Scatter(
        x=atom_x,
        y=atom_y,
        mode="markers+text",
        name="node",
        text=atom_labels,
        textposition="middle center",
        marker=dict(
            size=20,
            color="lightgray",
            colorscale="Viridis",
            cmin=0,
            cmax=1,
            colorbar=dict(title="Similarity", len=0.8, y=0.5),
        ),
        # Column index into sim_norm per node; exactly one entry per marker.
        customdata=list(range(num_nodes)),
    )

    # --- Graph bonds ---
    # A single line trace; None entries break the line between bonds.
    edge_x, edge_y = [], []
    for i, j in bonds:
        edge_x += [coords[i].x, coords[j].x, None]
        edge_y += [coords[i].y, coords[j].y, None]
    graph_edges = go.Scatter(
        x=edge_x,
        y=edge_y,
        mode="lines",
        line=dict(color="gray", width=2),
        hoverinfo="none",
        showlegend=False,
    )

    # --- Subplots ---
    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("Spectrum", "Molecule"),
        column_widths=[0.6, 0.4],
    )
    fig.add_trace(spectrum_trace, row=1, col=1)
    # Edges are added before nodes so node markers are drawn on top of bonds.
    fig.add_trace(graph_edges, row=1, col=2)
    fig.add_trace(graph_nodes, row=1, col=2)
    fig.update_xaxes(title="m/z", row=1, col=1)
    fig.update_yaxes(title="Intensity", row=1, col=1)
    fig.update_xaxes(visible=False, row=1, col=2)
    fig.update_yaxes(visible=False, row=1, col=2)
    fig.update_layout(showlegend=False)
    return fig, sim_norm
# ------------------------
# Model set up
# ------------------------
def _add_precursor_and_sort(formulas, mzs, intensities, formula, precursor_mz, precursor_intensity):
    """Return ``(mzs, intensities, formulas)`` with the precursor peak ensured, sorted by m/z.

    If ``formula`` already labels one or more peaks, their intensities are set
    to ``precursor_intensity``. Otherwise a new peak (``precursor_mz``,
    ``precursor_intensity``, ``formula``) is appended. The input arrays are
    not modified; new arrays are returned.
    """
    if formula in formulas:
        intensities = intensities.copy()
        intensities[formulas == formula] = precursor_intensity
    else:
        mzs = np.concatenate([mzs, [float(precursor_mz)]])
        formulas = np.concatenate([formulas, [formula]])
        intensities = np.concatenate([intensities, [float(precursor_intensity)]])
    order = np.argsort(mzs)
    return mzs[order], intensities[order], formulas[order]


def run(
    ms,
    smiles,
    formula,
    precursor_mz,
    adduct,
    spec_featurizer,
    mol_featurizer,
    model,
    mass_diff_thresh=20,
    precursor_intensity=1.1,
):
    """Assign subformulas to a spectrum, embed spectrum and molecule, and build the figure.

    Pipeline:
      1. label the peaks of ``ms`` with subformulas of ``formula`` via
         ``assign_subforms``, make sure a precursor peak is present at
         ``precursor_intensity``, sort by m/z and wrap in a ``matchms.Spectrum``;
      2. featurize the spectrum with ``spec_featurizer['SpecFormula']``;
      3. featurize ``smiles`` with ``mol_featurizer``;
      4. run ``model`` on CPU in eval mode under ``torch.no_grad()``;
      5. build the interactive figure with ``interactive_attention_visualization``.

    Args:
        ms: raw peak list, converted with ``np.array`` before assignment
            (presumably rows of (m/z, intensity) -- confirm against
            ``assign_subforms``).
        smiles: SMILES string of the molecule.
        formula: molecular formula used for subformula assignment and as the
            label of the precursor peak.
        precursor_mz: precursor m/z; stored in the spectrum metadata and used
            as the m/z of an added precursor peak.
        adduct: adduct passed through to ``assign_subforms``.
        spec_featurizer: mapping whose ``'SpecFormula'`` entry featurizes a
            ``matchms.Spectrum``.
        mol_featurizer: callable that featurizes a SMILES string.
        model: torch module returning ``(spec_embed, mol_embed)``. Note that it
            is moved to CPU and switched to eval mode, which affects the
            caller's object.
        mass_diff_thresh: forwarded to ``assign_subforms``.
        precursor_intensity: intensity given to the precursor peak.

    Returns:
        ``(fig, sim_norm)`` from ``interactive_attention_visualization``, or
        ``(None, None)`` when ``assign_subforms`` produced no output table.

    Raises:
        ValueError: if RDKit cannot parse ``smiles``.
    """
    # step 1 - label peaks with formulas, set up the matchms spectrum
    x = assign_subforms(formula, np.array(ms), adduct, mass_diff_thresh=mass_diff_thresh)
    table = x['output_tbl']
    if table is None:
        return None, None
    # dtype=str keeps an empty table string-typed (a bare np.array([]) would be
    # float64 and could not be concatenated with the precursor formula).
    # Formulas are assumed to be strings.
    formulas = np.array(table['formula'], dtype=str)
    mzs = np.array([float(m) for m in table['mz']])
    intensities = np.array([float(i) for i in table['ms2_inten']])
    mzs, intensities, formulas = _add_precursor_and_sort(
        formulas, mzs, intensities, formula, precursor_mz, precursor_intensity
    )
    spectrum = matchms.Spectrum(
        mz=mzs,
        intensities=intensities,
        metadata={'precursor_mz': precursor_mz, 'formulas': formulas},
    )

    # step 2 - featurize the spectrum
    spectrum_encoding = spec_featurizer['SpecFormula'](spectrum)

    # step 3 - featurize the molecule. Chem.MolFromSmiles returns None on a
    # parse failure; check it here so bad input fails clearly before inference
    # instead of crashing inside the 2D depiction step.
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Could not parse SMILES: {smiles!r}")
    molecule_encoding = mol_featurizer(smiles)

    # step 4 - embed spectrum & molecule
    model_input = {'mol': molecule_encoding, 'SpecFormula': spectrum_encoding}
    model = model.to(torch.device('cpu'))
    model.eval()
    with torch.no_grad():
        spec_embed, mol_embed = model.forward(model_input, stage='test')

    # step 5 - visualization
    fig, sim_norm = interactive_attention_visualization(
        spec_embed, mol_embed, mzs, intensities, formulas, mol
    )
    return fig, sim_norm