Spaces:
Sleeping
Sleeping
| import torch.nn.functional as F | |
| import torch | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
| from rdkit import Chem | |
| from rdkit.Chem import rdDepictor | |
| from flare.subformula_assign.utils.spectra_utils import assign_subforms | |
| import matchms | |
def mol_to_graph_coords(mol):
    """Compute a 2D depiction of ``mol`` and return its atom layout and bonds.

    ``rdDepictor.Compute2DCoords`` stores the generated conformer on ``mol``
    itself, so the input molecule is modified in place.

    Returns:
        A tuple ``(coords, bonds)``:
          * ``coords`` maps each atom index to its conformer position
            (an object exposing ``.x`` and ``.y``);
          * ``bonds`` lists ``(begin_atom_idx, end_atom_idx)`` pairs in
            ``mol.GetBonds()`` order.
    """
    rdDepictor.Compute2DCoords(mol)
    conformer = mol.GetConformer()

    coords = dict()
    for atom_idx in range(mol.GetNumAtoms()):
        coords[atom_idx] = conformer.GetAtomPosition(atom_idx)

    bonds = []
    for bond in mol.GetBonds():
        bonds.append((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()))

    return coords, bonds
def interactive_attention_visualization(
    spectral_embeds,
    graph_embeds,
    peak_mzs,
    peak_intensities,
    peak_formulas,
    mol
):
    """Build the base Plotly figure and the peak/node similarity matrix.

    The figure has two panels: the spectrum (one clickable marker per peak)
    on the left and the 2D molecule graph on the right. All markers start
    light grey; the Streamlit front end handles clicks and recolouring using
    the returned ``sim_norm``.

    Args:
        spectral_embeds: torch tensor of per-peak embeddings. Its rows are
            compared against the rows of ``graph_embeds``. Presumably shaped
            ``(num_peaks, dim)``; TODO confirm with the model.
        graph_embeds: torch tensor of per-node embeddings. Presumably shaped
            ``(num_nodes, dim)``, with node ``i`` corresponding to atom ``i``
            of ``mol``; TODO confirm.
        peak_mzs: per-peak m/z values (spectrum x axis, hover text).
        peak_intensities: per-peak intensities (spectrum y axis, hover text).
        peak_formulas: per-peak formula strings (hover text).
        mol: RDKit molecule. 2D coordinates are computed on it in place.

    Returns:
        ``(fig, sim_norm)``: the Plotly figure, and a NumPy array of shape
        ``(num_peaks, num_nodes)`` holding cosine similarities min-max scaled
        to ``[0, 1]`` over the whole matrix.
    """
    # --- Similarity matrix ---
    # L2-normalising both sides turns the dot product into cosine similarity.
    spectral_embeds = F.normalize(spectral_embeds, p=2, dim=-1)
    graph_embeds = F.normalize(graph_embeds, p=2, dim=-1)
    similarity = torch.matmul(spectral_embeds, graph_embeds.T).detach().cpu().numpy()
    # Global min-max scaling to [0, 1], matching cmin/cmax of the colour
    # scales below. The epsilon avoids division by zero for a constant matrix.
    sim_norm = (similarity - similarity.min()) / (similarity.max() - similarity.min() + 1e-8)
    num_peaks, num_nodes = similarity.shape

    # --- Molecule graph ---
    coords, bonds = mol_to_graph_coords(mol)
    # Derive every per-node array from the same index range so that x, y,
    # text and customdata always have one entry per embedded node.
    node_ids = list(range(num_nodes))
    atom_labels = [atom.GetSymbol() for atom in mol.GetAtoms()][:num_nodes]
    atom_x = [coords[i].x for i in node_ids]
    atom_y = [coords[i].y for i in node_ids]

    # --- Spectrum trace ---
    # Plotly renders hover labels as HTML, so a line break must be "<br>".
    # A plain "\n" is not shown as a new line.
    hover_labels = [
        f"{f}<br>({m:,.2f}, {i:.2})"
        for f, m, i in zip(peak_formulas, peak_mzs, peak_intensities)
    ]
    spectrum_trace = go.Scatter(
        x=peak_mzs,
        y=peak_intensities,
        mode='markers',  # crucial for clickable peaks
        name="peak",
        marker=dict(
            size=12,
            color="lightgray",
            colorscale="Viridis",
            cmin=0,
            cmax=1,
            colorbar=dict(title="Similarity", len=0.8, y=0.5),
        ),
        hovertext=hover_labels,
        hoverinfo='text',
        customdata=list(range(num_peaks)),  # actual peak indices
    )

    # --- Graph nodes ---
    graph_nodes = go.Scatter(
        x=atom_x,
        y=atom_y,
        mode="markers+text",
        name="node",
        text=atom_labels,
        textposition="middle center",
        marker=dict(
            size=20,
            color="lightgray",
            colorscale="Viridis",
            cmin=0,
            cmax=1,
            # NOTE(review): same placement as the spectrum trace's colorbar,
            # so the two overlap whenever both are displayed.
            colorbar=dict(title="Similarity", len=0.8, y=0.5),
        ),
        customdata=node_ids,  # one node index per marker
    )

    # --- Graph bonds ---
    # Each bond is one line segment; a None between segments breaks the line
    # so all bonds can share a single trace.
    edge_x, edge_y = [], []
    for i, j in bonds:
        edge_x += [coords[i].x, coords[j].x, None]
        edge_y += [coords[i].y, coords[j].y, None]
    graph_edges = go.Scatter(
        x=edge_x,
        y=edge_y,
        mode="lines",
        line=dict(color="gray", width=2),
        hoverinfo="none",
        showlegend=False,
    )

    # --- Subplots ---
    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("Spectrum", "Molecule"),
        column_widths=[0.6, 0.4],
    )
    fig.add_trace(spectrum_trace, row=1, col=1)
    # Edges are added before nodes so the node markers are drawn on top.
    fig.add_trace(graph_edges, row=1, col=2)
    fig.add_trace(graph_nodes, row=1, col=2)
    fig.update_xaxes(title="m/z", row=1, col=1)
    fig.update_yaxes(title="Intensity", row=1, col=1)
    fig.update_xaxes(visible=False, row=1, col=2)
    fig.update_yaxes(visible=False, row=1, col=2)
    fig.update_layout(showlegend=False)
    return fig, sim_norm
| # ------------------------ | |
| # Model setup and inference | |
| # ------------------------ | |
def _label_peaks(ms, formula, precursor_mz, adduct, mass_diff_thresh, precursor_intensity):
    """Assign subformulas to peaks and make sure a precursor peak is present.

    Returns ``(mzs, intensities, formulas)`` as NumPy arrays sorted by m/z,
    or ``None`` when ``assign_subforms`` produced no output table.
    """
    assignment = assign_subforms(formula, np.array(ms), adduct, mass_diff_thresh=mass_diff_thresh)
    table = assignment['output_tbl']
    if table is None:
        return None

    formulas = np.array(table['formula'])
    mzs = np.array([float(m) for m in table['mz']])
    intensities = np.array([float(v) for v in table['ms2_inten']])

    if formula not in formulas:
        # No peak was assigned the precursor formula: append a synthetic
        # precursor peak at precursor_mz.
        mzs = np.concatenate([mzs, [float(precursor_mz)]])
        formulas = np.concatenate([formulas, [formula]])
        intensities = np.concatenate([intensities, [float(precursor_intensity)]])
    else:
        # The precursor formula is already assigned: override its intensity.
        intensities[formulas == formula] = precursor_intensity

    order = np.argsort(mzs)
    return mzs[order], intensities[order], formulas[order]


def run(ms, smiles, formula, precursor_mz, adduct, spec_featurizer, mol_featurizer,model, mass_diff_thresh=20, precursor_intensity=1.1):
    """Embed a spectrum and a molecule, then build the similarity visualization.

    Pipeline: label peaks with subformulas (adding a precursor peak if none
    matches ``formula``), build a ``matchms.Spectrum``, featurize spectrum and
    molecule, embed both with ``model``, and plot peak/atom similarities.

    Args:
        ms: peak list, converted with ``np.array`` and passed to
            ``assign_subforms``. Presumably ``(m/z, intensity)`` pairs; TODO
            confirm.
        smiles: SMILES string of the candidate molecule.
        formula: precursor molecular formula.
        precursor_mz: precursor m/z. Stored as-is in the spectrum metadata,
            and converted with ``float`` if a precursor peak must be added.
        adduct: adduct passed to ``assign_subforms``.
        spec_featurizer: mapping whose ``'SpecFormula'`` entry featurizes a
            ``matchms.Spectrum``.
        mol_featurizer: callable featurizing a SMILES string.
        model: embedding model. It is moved to CPU and switched to eval mode;
            for a ``torch.nn.Module`` both happen in place on the caller's
            object.
        mass_diff_thresh: matching tolerance forwarded to ``assign_subforms``.
            Units are defined there (presumably ppm; TODO confirm).
        precursor_intensity: intensity assigned to the precursor peak.

    Returns:
        ``(fig, sim_norm)`` from ``interactive_attention_visualization``, or
        ``(None, None)`` when no subformula table could be produced.
    """
    # step 1 - label peaks with formula, setup matchms spectrum
    peaks = _label_peaks(ms, formula, precursor_mz, adduct, mass_diff_thresh, precursor_intensity)
    if peaks is None:
        return None, None
    mzs, intensities, formulas = peaks

    spectrum = matchms.Spectrum(
        mz=mzs,
        intensities=intensities,
        metadata={'precursor_mz': precursor_mz, 'formulas': formulas},
    )

    # step 2 - featurize spectra
    spectrum_encoding = spec_featurizer['SpecFormula'](spectrum)

    # step 3 - featurize molecule
    molecule_encoding = mol_featurizer(smiles)

    # step 4 - embed spectra & molecules (inference only, on CPU, no grads)
    model_input = {'mol': molecule_encoding, 'SpecFormula': spectrum_encoding}
    model = model.to(torch.device('cpu'))
    model.eval()
    with torch.no_grad():
        spec_embed, mol_embed = model.forward(model_input, stage='test')

    # step 5 - visualization
    mol = Chem.MolFromSmiles(smiles)
    fig, sim_norm = interactive_attention_visualization(spec_embed, mol_embed, mzs, intensities, formulas, mol)
    return fig, sim_norm