{ "cells": [ { "cell_type": "markdown", "id": "88a064af", "metadata": {}, "source": [ "Given a spectrum and candidates, can we use the peak-to-node scores to extract key substructures?\n", "\n", "- for each spectrum, look at top K candidates\n", "- compute peak-to-node score matrix\n", "- run DFS prioritizing node with high score to any peak, with predefined threshold T\n", "- how often are they substructures?\n", "- Find two examples where substructures extracted make up the target structure\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "c60529bb", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import pickle\n", "\n", "import torch\n", "import numpy as np\n", "import plotly.graph_objects as go\n", "from plotly.subplots import make_subplots\n", "from rdkit import Chem\n", "from rdkit.Chem import rdDepictor\n", "from rdkit.Chem.Draw import rdMolDraw2D\n", "import pickle\n", "import copy\n", "from rdkit.Chem import Draw" ] }, { "cell_type": "markdown", "id": "515da6c7", "metadata": {}, "source": [ "# Load data and model" ] }, { "cell_type": "markdown", "id": "7207164b", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": 2, "id": "de653fea", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Data path: /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv\n", "Processing formula spectra\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 231104/231104 [00:16<00:00, 14307.79it/s]\n", "/data/yzhouc01/FILIP-MS/flare/data/datasets.py:221: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", " tmp_df['spec'] = tmp_df.apply(lambda row: data_utils.make_tmp_subformula_spectra(row), axis=1)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loaded Model from checkpoint\n" ] } ], "source": [ "# model and dataset\n", "import sys\n", "sys.path.insert(0, \"/data/yzhouc01/MassSpecGym\")\n", "sys.path.insert(0, \"/data/yzhouc01/FILIP-MS\")\n", "\n", "from rdkit import RDLogger\n", "import pytorch_lightning as pl\n", "from pytorch_lightning import Trainer\n", "from massspecgym.models.base import Stage\n", "import os\n", "\n", "from flare.utils.data import get_spec_featurizer, get_mol_featurizer, get_ms_dataset\n", "from flare.utils.models import get_model\n", "\n", "from flare.definitions import TEST_RESULTS_DIR\n", "import yaml\n", "from functools import partial\n", "# Suppress RDKit warnings and errors\n", "lg = RDLogger.logger()\n", "lg.setLevel(RDLogger.CRITICAL)\n", "\n", "# Load model and data\n", "\n", "param_pth = '/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model/lightning_logs/version_0/hparams.yaml'\n", "with open(param_pth) as f:\n", " params = yaml.load(f, Loader=yaml.FullLoader)\n", "\n", "spec_featurizer = get_spec_featurizer(params['spectra_view'], params)\n", "mol_featurizer = get_mol_featurizer(params['molecule_view'], params)\n", "dataset = get_ms_dataset(params['spectra_view'], params['molecule_view'], spec_featurizer, mol_featurizer, params)\n", "\n", "\n", "# load model\n", "import torch \n", "checkpoint_pth = \"/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model/epoch=1993-train_loss=0.10.ckpt\"\n", "params['checkpoint_pth'] = checkpoint_pth\n", "model = get_model(params['model'], params)" ] }, { "cell_type": "code", "execution_count": 3, "id": "267e2d12", "metadata": {}, "outputs": [], "source": [ "# annotation result\n", "with open(\"/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model/result_MassSpecGym_retrieval_candidates_formula.pkl\", 'rb') as f:\n", " result = pickle.load(f)" ] }, { "cell_type": "code", "execution_count": null, "id": "e1b82ba3", "metadata": {}, "outputs": [], "source": [ "def get_target(candidates, labels):\n", " return np.array(candidates)[labels][0]\n", "\n", "def sorted_candidates(candidates, scores):\n", " return np.array(candidates)[np.argsort(scores)[::-1]].tolist()\n", "\n", "def get_n_heavy_atoms(smiles):\n", " mol = Chem.MolFromSmiles(smiles)\n", " return mol.GetNumHeavyAtoms()\n", "\n", "def get_sorted_scores(scores):\n", " return np.array(scores)[np.argsort(scores)[::-1]].tolist()\n", "\n", "result['target'] = result.apply(lambda x: get_target(x['candidates'], x['labels']), axis=1)\n", "\n", "result['sorted_candidates'] = result.apply(lambda x: sorted_candidates(x['candidates'], x['scores']), axis=1)\n", "\n", "result['n_heavy_atoms'] = result['target'].apply(get_n_heavy_atoms)\n", "\n", "result['sorted_scores'] = result['scores'].apply(get_sorted_scores)" ] }, { "cell_type": "code", "execution_count": 90, "id": "40e1b7e7", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| \n", " | identifier | \n", "candidates | \n", "scores | \n", "labels | \n", "rank | \n", "target | \n", "sorted_candidates | \n", "n_heavy_atoms | \n", "sorted_scores | \n", "
|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", "MassSpecGymID0000201 | \n", "[CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N(... | \n", "[0.17369578778743744, 0.12611594796180725, 0.2... | \n", "[True, False, False, False, False, False, Fals... | \n", "17 | \n", "CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N([... | \n", "[COCCCN1C(=O)COc2ccc(N(C(=O)[C@H]3CN(C(=O)OC(C... | \n", "57 | \n", "[0.2598780691623688, 0.2579679787158966, 0.249... | \n", "
| 1 | \n", "MassSpecGymID0000202 | \n", "[CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N(... | \n", "[0.05142267048358917, 0.07289629429578781, 0.1... | \n", "[True, False, False, False, False, False, Fals... | \n", "24 | \n", "CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N([... | \n", "[COC(=O)/C(C)=C\\CC1(O)C(=O)C2CC(C(C)C)C13Oc1c(... | \n", "57 | \n", "[0.2371954619884491, 0.21642719209194183, 0.20... | \n", "
| 2 | \n", "MassSpecGymID0000203 | \n", "[CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N(... | \n", "[0.09354929625988007, 0.0947718694806099, 0.10... | \n", "[True, False, False, False, False, False, Fals... | \n", "23 | \n", "CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N([... | \n", "[C=CCOC12Oc3ccc(OC(=O)NCC)cc3C3C(CCCCO)C(CCCCO... | \n", "57 | \n", "[0.2382681667804718, 0.22565233707427979, 0.21... | \n", "