import matchms
import numpy as np
import plotly.graph_objects as go
import torch
import torch.nn.functional as F
from plotly.subplots import make_subplots
from rdkit import Chem
from rdkit.Chem import rdDepictor

from flare.subformula_assign.utils.spectra_utils import assign_subforms


def mol_to_graph_coords(mol):
    """Return 2D atom coordinates and the bond list for a molecule.

    Note: ``rdDepictor.Compute2DCoords`` writes a 2D conformer onto ``mol``
    itself, so the caller's molecule object is modified.

    Args:
        mol: RDKit molecule.

    Returns:
        coords: dict mapping atom index -> RDKit atom position (has ``.x``/``.y``).
        bonds: list of ``(begin_atom_idx, end_atom_idx)`` tuples, one per bond.
    """
    rdDepictor.Compute2DCoords(mol)
    conf = mol.GetConformer()
    coords = {i: conf.GetAtomPosition(i) for i in range(mol.GetNumAtoms())}
    bonds = [(b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in mol.GetBonds()]
    return coords, bonds


def interactive_attention_visualization(
    spectral_embeds, graph_embeds, peak_mzs, peak_intensities, peak_formulas, mol
):
    """Build the base Plotly figure and the normalized similarity matrix.

    The figure shows the spectrum (left subplot) and the 2D molecular graph
    (right subplot), with all markers initially light gray. Streamlit is
    expected to handle clicks and recolor the markers using ``sim_norm``.

    Args:
        spectral_embeds: tensor of per-peak embeddings, assumed
            ``(num_peaks, dim)`` in the same order as ``peak_mzs``
            (TODO confirm against the model).
        graph_embeds: tensor of per-node embeddings, assumed
            ``(num_nodes, dim)`` with node ``i`` being atom ``i`` of ``mol``
            (TODO confirm: no extra/virtual nodes; if ``num_nodes`` exceeds
            the atom count, the coordinate lookup below raises ``KeyError``).
        peak_mzs: peak m/z values (x axis of the spectrum).
        peak_intensities: peak intensities (y axis of the spectrum).
        peak_formulas: per-peak formula strings, shown in the hover text.
        mol: RDKit molecule; 2D coordinates are computed on it in place.

    Returns:
        ``(fig, sim_norm)``: the Plotly figure and a ``(num_peaks, num_nodes)``
        numpy array of cosine similarities min-max scaled to ``[0, 1]``.
    """
    # --- Similarity matrix ---
    # Cosine similarity: L2-normalize each row, then take a plain matrix product.
    spectral_embeds = F.normalize(spectral_embeds, p=2, dim=-1)
    graph_embeds = F.normalize(graph_embeds, p=2, dim=-1)
    similarity = torch.matmul(spectral_embeds, graph_embeds.T).detach().cpu().numpy()
    # Min-max scale to [0, 1] to match the markers' cmin/cmax. The epsilon
    # avoids division by zero when every similarity is identical.
    sim_norm = (similarity - similarity.min()) / (
        similarity.max() - similarity.min() + 1e-8
    )
    num_peaks, num_nodes = similarity.shape

    # --- Molecule graph ---
    coords, bonds = mol_to_graph_coords(mol)
    atom_labels = [a.GetSymbol() for a in mol.GetAtoms()]
    atom_x = [coords[i].x for i in range(num_nodes)]
    atom_y = [coords[i].y for i in range(num_nodes)]

    # --- Spectrum trace ---
    spectrum_trace = go.Scatter(
        x=peak_mzs,
        y=peak_intensities,
        mode="markers",  # markers (not lines) so individual peaks are clickable
        name="peak",
        marker=dict(
            size=12,
            color="lightgray",
            colorscale="Viridis",
            cmin=0,
            cmax=1,
            colorbar=dict(title="Similarity", len=0.8, y=0.5),
        ),
        # Plotly hover labels break lines with <br>; a literal "\n" is not rendered.
        hovertext=[
            f"{f}<br>({m:,.2f}, {i:.2})"
            for f, m, i in zip(peak_formulas, peak_mzs, peak_intensities)
        ],
        hoverinfo="text",
        customdata=list(range(num_peaks)),  # row index into sim_norm
    )

    # --- Graph nodes ---
    graph_nodes = go.Scatter(
        x=atom_x,
        y=atom_y,
        mode="markers+text",
        name="node",
        text=atom_labels,
        textposition="middle center",
        marker=dict(
            size=20,
            color="lightgray",
            colorscale="Viridis",
            cmin=0,
            cmax=1,
            colorbar=dict(title="Similarity", len=0.8, y=0.5),
        ),
        # Exactly one entry per plotted node: column index into sim_norm.
        customdata=list(range(num_nodes)),
    )

    # --- Graph bonds ---
    # One polyline segment per bond; None separates the segments.
    edge_x, edge_y = [], []
    for i, j in bonds:
        edge_x += [coords[i].x, coords[j].x, None]
        edge_y += [coords[i].y, coords[j].y, None]

    graph_edges = go.Scatter(
        x=edge_x,
        y=edge_y,
        mode="lines",
        line=dict(color="gray", width=2),
        hoverinfo="none",
        showlegend=False,
    )

    # --- Subplots ---
    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("Spectrum", "Molecule"),
        column_widths=[0.6, 0.4],
    )
    fig.add_trace(spectrum_trace, row=1, col=1)
    # Edges first so the node markers are drawn on top of the bonds.
    fig.add_trace(graph_edges, row=1, col=2)
    fig.add_trace(graph_nodes, row=1, col=2)

    fig.update_xaxes(title="m/z", row=1, col=1)
    fig.update_yaxes(title="Intensity", row=1, col=1)
    fig.update_xaxes(visible=False, row=1, col=2)
    fig.update_yaxes(visible=False, row=1, col=2)
    fig.update_layout(showlegend=False)

    return fig, sim_norm


# ------------------------
# Model set up
# ------------------------


def _label_peaks(ms, formula, precursor_mz, adduct, mass_diff_thresh, precursor_intensity):
    """Assign subformulas to peaks and make sure the precursor peak is present.

    Returns ``(mzs, intensities, formulas)`` as numpy arrays sorted by m/z, or
    ``None`` when ``assign_subforms`` produced no output table.
    """
    assigned = assign_subforms(
        formula, np.array(ms), adduct, mass_diff_thresh=mass_diff_thresh
    )
    table = assigned["output_tbl"]
    if table is None:
        return None

    mzs = np.array([float(m) for m in table["mz"]])
    intensities = np.array([float(v) for v in table["ms2_inten"]])
    formulas = np.array(table["formula"])
    if formulas.size == 0:
        # np.array([]) is float64; force a string dtype so the precursor
        # membership test and concatenation below operate on strings.
        formulas = np.array([], dtype=str)

    is_precursor = formulas == formula
    if is_precursor.any():
        # Precursor formula already assigned to a peak: pin its intensity.
        intensities[is_precursor] = precursor_intensity
    else:
        # Otherwise append a synthetic precursor peak.
        mzs = np.concatenate([mzs, [float(precursor_mz)]])
        formulas = np.concatenate([formulas, [formula]])
        intensities = np.concatenate([intensities, [float(precursor_intensity)]])

    # Sort by m/z (needed when the precursor was appended). A stable sort keeps
    # tied m/z values in a deterministic order.
    order = np.argsort(mzs, kind="stable")
    return mzs[order], intensities[order], formulas[order]


def run(ms, smiles, formula, precursor_mz, adduct, spec_featurizer, mol_featurizer,
        model, mass_diff_thresh=20, precursor_intensity=1.1):
    """Embed a spectrum/molecule pair and build the interactive similarity figure.

    Steps: label peaks with subformulas, featurize the spectrum and the
    molecule, embed both with ``model``, then build the Plotly figure.

    Args:
        ms: raw peak data passed (as a numpy array) to ``assign_subforms``;
            presumably ``(n_peaks, 2)`` of m/z and intensity — TODO confirm.
        smiles: SMILES string of the candidate molecule.
        formula: precursor molecular formula.
        precursor_mz: precursor m/z, used for the synthetic precursor peak and
            stored in the spectrum metadata.
        adduct: adduct string forwarded to ``assign_subforms``.
        spec_featurizer: mapping whose ``'SpecFormula'`` entry featurizes a
            ``matchms.Spectrum``.
        mol_featurizer: callable that featurizes a SMILES string.
        model: model returning ``(spec_embed, mol_embed)``. Note: it is moved
            to CPU and switched to eval mode as a side effect.
        mass_diff_thresh: forwarded to ``assign_subforms``.
        precursor_intensity: intensity assigned to the precursor peak.

    Returns:
        ``(fig, sim_norm)`` from ``interactive_attention_visualization``, or
        ``(None, None)`` when no subformula table could be assigned.

    Raises:
        ValueError: if ``smiles`` cannot be parsed by RDKit.
    """
    # step 1 - label peaks with formulas, set up the matchms spectrum
    labeled = _label_peaks(
        ms, formula, precursor_mz, adduct, mass_diff_thresh, precursor_intensity
    )
    if labeled is None:
        return None, None
    mzs, intensities, formulas = labeled

    spectrum = matchms.Spectrum(
        mz=mzs,
        intensities=intensities,
        metadata={"precursor_mz": precursor_mz, "formulas": formulas},
    )

    # step 2 - featurize spectrum
    spectrum_encoding = spec_featurizer["SpecFormula"](spectrum)

    # step 3 - featurize molecule
    molecule_encoding = mol_featurizer(smiles)

    # step 4 - embed spectrum & molecule
    model_input = {"mol": molecule_encoding, "SpecFormula": spectrum_encoding}
    model = model.to(torch.device("cpu"))
    model.eval()
    with torch.no_grad():
        # Call the module (not .forward) so registered hooks still run.
        spec_embed, mol_embed = model(model_input, stage="test")

    # step 5 - visualization
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Could not parse SMILES: {smiles!r}")
    fig, sim_norm = interactive_attention_visualization(
        spec_embed, mol_embed, mzs, intensities, formulas, mol
    )
    return fig, sim_norm