"""Dataset classes for spectrum/molecule retrieval training and evaluation.

Wraps ``massspecgym``'s :class:`MassSpecDataset` with optional fingerprint,
consensus-spectrum and neutral-loss-spectrum side inputs, peak-subformula
annotations, contrastive sampling, and candidate-expansion for retrieval tests.

NOTE(review): this module was reconstructed from a whitespace-collapsed
source; statement nesting was inferred from syntax and data-flow and should
be diffed against the original formatting where available.
"""

# stdlib
import itertools
import json
import math  # kept: may be used by code outside this view
import pickle
import typing as T
from collections import defaultdict
from pathlib import Path

# third-party
import dgl
import matchms
import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem import AllChem
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.dataloader import default_collate
from torch.utils.data.dataset import Dataset

# project
import flare.utils.data as data_utils
import massspecgym.utils as utils
from magma.run_magma import run_magma
from massspecgym.data.datasets import MassSpecDataset
from massspecgym.data.transforms import SpecTransform, MolTransform, MolToInChIKey
from massspecgym.models.base import Stage


class JESTR1_MassSpecDataset(MassSpecDataset):
    """MassSpecDataset extended with optional per-molecule side information.

    Optional inputs (each enabled by passing a path):
      * pickled SMILES -> RDKit fingerprint mapping,
      * pickled consensus spectra (one per SMILES),
      * pickled neutral-loss (NL) spectra (one per spectrum index).
    """

    def __init__(
        self,
        spectra_view: str,
        fp_dir_pth: T.Optional[str] = None,
        cons_spec_dir_pth: T.Optional[str] = None,
        NL_spec_dir_pth: T.Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.use_fp = False
        self.use_cons_spec = False
        self.use_NL_spec = False
        self.spectra_view = spectra_view
        # load fingerprints
        self._load_fp(fp_dir_pth)
        # load consensus
        self._load_cons_spec(cons_spec_dir_pth)
        # load NL specs
        self._load_NL_spec(NL_spec_dir_pth)

    def _load_fp(self, fp_dir_pth: T.Optional[str]) -> None:
        """Load the pickled SMILES -> fingerprint dict, if a path is given."""
        if fp_dir_pth is not None:
            self.use_fp = True
            # An explicitly-empty (falsy) path still enables fingerprints but
            # starts from an empty mapping.
            if fp_dir_pth:
                with open(fp_dir_pth, 'rb') as f:
                    self.smiles_to_fp = pickle.load(f)
            else:
                self.smiles_to_fp = {}

    def _load_cons_spec(self, cons_spec_dir_pth: T.Optional[str]) -> None:
        """Load pickled consensus spectra and index them by SMILES."""
        if cons_spec_dir_pth is not None:
            self.use_cons_spec = True
            with open(cons_spec_dir_pth, 'rb') as f:
                cons_specs = pickle.load(f)
            # Convert spectra to matchms spectra.
            # NOTE(review): cons_specs is assumed to be a DataFrame with a
            # 'smiles' column — confirm against the pickle producer.
            matchMS_preparer = data_utils.PrepMatchMS(self.spectra_view)
            spectra = cons_specs.apply(matchMS_preparer.prepare, axis=1)
            self.cons_specs = dict(zip(cons_specs['smiles'].tolist(), spectra))

    def _load_NL_spec(self, NL_spec_dir_pth: T.Optional[str]) -> None:
        """Load pickled neutral-loss spectra, keyed positionally by index."""
        if NL_spec_dir_pth is not None:
            self.use_NL_spec = True
            with open(NL_spec_dir_pth, 'rb') as f:
                NL_specs = pickle.load(f)
            # Convert spectra to matchms spectra
            matchMS_preparer = data_utils.PrepMatchMS(self.spectra_view)
            self.NL_specs = NL_specs.apply(matchMS_preparer.prepare, axis=1)

    def __getitem__(self, i, transform_spec: bool = True, transform_mol: bool = True):
        """Return one item dict with (optionally transformed) spectrum and molecule.

        When ``spec_transform``/``mol_transform`` are dicts, one output key is
        produced per transform; otherwise the results land under 'spec'/'mol'.
        """
        spec = self.spectra[i]
        metadata = self.metadata.iloc[i]
        # Fall back to the identifier when no SMILES column exists (external data).
        mol = metadata["smiles"] if 'smiles' in metadata else metadata["identifier"]
        # Apply all transformations to the spectrum
        item = {}
        if transform_spec and self.spec_transform:
            if isinstance(self.spec_transform, dict):
                for key, transform in self.spec_transform.items():
                    item[key] = transform(spec) if transform is not None else spec
            else:
                item["spec"] = self.spec_transform(spec)
        if self.return_mol_freq:
            item["mol_freq"] = metadata["mol_freq"]
        if self.return_identifier:
            item["identifier"] = metadata["identifier"]
        if self.use_fp and self.smiles_to_fp:
            item['fp'] = torch.Tensor(self.smiles_to_fp[mol].ToList())
        if self.use_cons_spec:
            # Consensus/NL spectra always use the transform registered for the
            # configured spectra view (spec_transform must be a dict here).
            item['cons_spec'] = self.spec_transform[self.spectra_view](self.cons_specs[mol])
        if self.use_NL_spec:
            item['NL_spec'] = self.spec_transform[self.spectra_view](self.NL_specs[i])
        # Apply all transformations to the molecule
        if transform_mol and self.mol_transform:
            if isinstance(self.mol_transform, dict):
                for key, transform in self.mol_transform.items():
                    item[key] = transform(mol) if transform is not None else mol
            else:
                item["mol"] = self.mol_transform(mol)
        else:
            # Untransformed path: expose the raw SMILES/identifier string.
            item["mol"] = mol
        return item


class MassSpecDataset_PeakFormulas(JESTR1_MassSpecDataset):
    """Dataset variant whose spectra carry per-peak subformula annotations.

    Builds its own ``metadata``/``spectra`` from a TSV file plus a subformula
    store, so it deliberately does NOT call ``super().__init__``.
    """

    def __init__(
        self,
        spectra_view: str,
        spec_transform: T.Optional[T.Union[SpecTransform, T.Dict[str, SpecTransform]]],
        mol_transform: T.Optional[T.Union[MolTransform, T.Dict[str, MolTransform]]],
        pth: T.Optional[Path],
        subformula_dir_pth: str,
        fp_dir_pth: T.Optional[str] = None,
        NL_spec_dir_pth: T.Optional[str] = None,
        cons_spec_dir_pth: T.Optional[str] = None,
        return_mol_freq: bool = False,
        return_identifier: bool = True,
        dtype: T.Type = torch.float32,
        formula_source='default',
        stage: Stage = Stage.TRAIN,
    ):
        """
        Args:
            spectra_view: name of the spectrum representation to prepare.
            pth: TSV metadata file (string paths are converted to Path).
            subformula_dir_pth: directory holding precomputed subformulas.
            stage: split selector (currently only used by commented-out code
                in ``_load_id_to_spec``).
        """
        self.pth = pth
        self.spec_transform = spec_transform
        self.mol_transform = mol_transform
        self.return_mol_freq = return_mol_freq
        self.pred_fp = False
        self.use_fp = False
        self.use_cons_spec = False
        self.use_NL_spec = False
        self.spectra_view = spectra_view
        self.formula_source = formula_source
        self.subformula_dir_pth = subformula_dir_pth
        if isinstance(self.pth, str):
            self.pth = Path(self.pth)
        self.spectra_view = spectra_view
        print("Data path: ", self.pth)
        self.metadata = pd.read_csv(self.pth, sep="\t")
        # load subformulas
        id_to_spec = self._load_id_to_spec(stage)
        # load fingerprints
        self._load_fp(fp_dir_pth)
        # load consensus spectra
        self._load_cons_spec(cons_spec_dir_pth)
        # load NL specs
        self._load_NL_spec(NL_spec_dir_pth)
        # Keep only rows that have subformula spectra, then attach them.
        self.metadata = self.metadata[self.metadata['identifier'].isin(id_to_spec)]
        formula_df = pd.DataFrame.from_dict(id_to_spec, orient='index').reset_index().rename(columns={'index': 'identifier'})
        self.metadata = self.metadata.merge(formula_df, on='identifier')
        # create matchms spectra
        matchMS_preparer = data_utils.PrepMatchMS(spectra_view=spectra_view)
        self.spectra = self.metadata.apply(matchMS_preparer.prepare, axis=1)
        if self.return_mol_freq:
            if "inchikey" not in self.metadata.columns:
                self.metadata["inchikey"] = self.metadata["smiles"].apply(utils.smiles_to_inchi_key)
            # mol_freq = how many spectra share this molecule (by InChIKey).
            self.metadata["mol_freq"] = self.metadata.groupby("inchikey")["inchikey"].transform("count")
        self.return_identifier = return_identifier
        self.dtype = dtype

    def __getitem__(self, i, transform_spec: bool = True, transform_mol: bool = True):
        """Fetch the base item with mol untransformed, then transform mol here."""
        item = super().__getitem__(i, transform_spec, transform_mol=False)
        mol = item['mol']  # smiles
        # transform mol
        if transform_mol:
            if isinstance(self.mol_transform, dict):
                for key, transform in self.mol_transform.items():
                    item[key] = transform(mol) if transform is not None else mol
            else:
                item["mol"] = self.mol_transform(mol)
        return item

    def _load_id_to_spec(self, stage):
        """Load identifier -> subformula-spectrum mapping for all metadata rows.

        Identifiers without precomputed subformulas get a temporary spectrum
        built from the raw metadata row.
        """
        # if stage == Stage.TRAIN:
        #     self.metadata = self.metadata[self.metadata['fold'] != Stage.TEST.value]
        # else:
        #     self.metadata = self.metadata[self.metadata['fold'] == Stage.TEST.value]
        all_spec_ids = self.metadata['identifier'].tolist()
        self.subformulaLoader = data_utils.Subformula_Loader(spectra_view=self.spectra_view, dir_path=self.subformula_dir_pth, formula_source=self.formula_source)
        form_list = self.metadata['formula'].tolist()
        prec_mz_list = self.metadata['precursor_mz'].tolist()
        id_to_spec = self.subformulaLoader(all_spec_ids, form_list, prec_mz_list)
        # create subformula spectra if no subformula is available
        tmp_ids = [spec_id for spec_id in all_spec_ids if spec_id not in id_to_spec]
        # FIX: operate on an explicit copy — assigning a new column into a
        # boolean-mask slice of self.metadata triggers pandas' chained-
        # assignment warning and is not guaranteed to behave consistently.
        tmp_df = self.metadata[self.metadata['identifier'].isin(tmp_ids)].copy()
        tmp_df['spec'] = tmp_df.apply(lambda row: data_utils.make_tmp_subformula_spectra(row), axis=1)
        id_to_spec.update(dict(zip(tmp_df['identifier'].tolist(), tmp_df['spec'].tolist())))
        return id_to_spec


class ContrastiveDataset(Dataset):
    """One item per unique SMILES; cycles through that molecule's spectra.

    ``spec_mol_data`` is expected to be a ``torch.utils.data.Subset``-like
    object exposing ``indices`` and ``dataset.metadata``.
    """

    def __init__(
        self,
        spec_mol_data,
    ):
        super().__init__()
        indices = spec_mol_data.indices
        self.spec_mol_data = spec_mol_data
        # SMILES -> array of positions (within the subset) sharing that molecule.
        self.smiles_to_specmol_ids = spec_mol_data.dataset.metadata.loc[indices].groupby('smiles').indices
        # Per-SMILES round-robin counter (name kept for external compatibility,
        # despite the 'couter' typo).
        self.smiles_to_spec_couter = defaultdict(int)
        self.smiles_list = list(self.smiles_to_specmol_ids.keys())

    def __len__(self) -> int:
        return len(self.smiles_list)

    def __getitem__(self, i: int) -> dict:
        mol = self.smiles_list[i]
        # select spectrum (iterate through list of spectra)
        specmol_ids = self.smiles_to_specmol_ids[mol]
        counter = self.smiles_to_spec_couter[mol]
        specmol_id = specmol_ids[counter % len(specmol_ids)]
        item = self.spec_mol_data.__getitem__(specmol_id)
        self.smiles_to_spec_couter[mol] = counter + 1
        # item['smiles'] = mol
        # item['spec_id'] = specmol_id
        return item

    @staticmethod
    def collate_fn(batch: T.Iterable[dict], spec_enc: str, spectra_view: str, stage=None, batch_mol: bool = True) -> dict:
        """Collate a list of item dicts into one batch dict.

        Graphs are batched with dgl; variable-length peak/formula sequences
        are padded; everything else goes through ``default_collate``.
        """
        mol_key = 'cand' if stage == Stage.TEST else 'mol'
        non_standard_collate = ['mol', 'cand', 'aug_cands', 'cons_spec', 'aug_cands_fp', 'NL_spec']
        require_pad = False
        if 'Formula' in spectra_view or 'Tokens' in spectra_view:
            require_pad = True
            # -5 marks padding for the transformer-style encoders; 0 otherwise.
            padding_value = -5 if spec_enc in ('Transformer_Formula', 'Formula_BinnedSpec', 'Transformer_MzInt') else 0
            non_standard_collate.append(spectra_view)
        else:
            # Fixed-size views: cons/NL spectra can be default-collated.
            non_standard_collate.remove('cons_spec')
            non_standard_collate.remove('NL_spec')
        collated_batch = {}
        # standard collate
        for k in batch[0].keys():
            if k not in non_standard_collate:
                try:
                    collated_batch[k] = default_collate([item[k] for item in batch])
                except Exception:  # FIX: was a bare except; still re-raised below
                    print(f"Error in collating key {k}")
                    raise
        # batch graphs
        if batch_mol:
            # FIX: local list renamed from 'batch_mol', which shadowed the
            # boolean parameter of the same name.
            graphs = []
            batch_mol_nodes = []
            for item in batch:
                graphs.append(item[mol_key])
                batch_mol_nodes.append(item[mol_key].num_nodes())
            collated_batch[mol_key] = dgl.batch(graphs)
            collated_batch['mol_n_nodes'] = batch_mol_nodes
        # pad peaks/formulas
        if require_pad:
            peaks = []
            n_peaks = []
            for item in batch:
                peaks.append(item[spectra_view])
                n_peaks.append(len(item[spectra_view]))
            collated_batch[spectra_view] = pad_sequence(peaks, batch_first=True, padding_value=padding_value)
            collated_batch['n_peaks'] = n_peaks
            # cons/NL padding lives inside this branch: padding_value is only
            # defined when require_pad is True.
            if 'cons_spec' in batch[0]:
                peaks = []
                n_peaks = []
                for item in batch:
                    peaks.append(item['cons_spec'])
                    n_peaks.append(len(item['cons_spec']))
                collated_batch['cons_spec'] = pad_sequence(peaks, batch_first=True, padding_value=padding_value)
                collated_batch['cons_n_peaks'] = n_peaks
            if 'NL_spec' in batch[0]:
                peaks = []
                n_peaks = []
                for item in batch:
                    peaks.append(item['NL_spec'])
                    n_peaks.append(len(item['NL_spec']))
                collated_batch['NL_spec'] = pad_sequence(peaks, batch_first=True, padding_value=padding_value)
                collated_batch['NL_n_peaks'] = n_peaks
        return collated_batch


class ExpandedRetrievalDataset:
    '''Used for testing only Assumes 'fold' column defines the split'''

    def __init__(self,
                 use_formulas: bool = True,
                 # NOTE(review): a shared transform instance as default is fine
                 # only if MolToInChIKey is stateless — confirm.
                 mol_label_transform: MolTransform = MolToInChIKey(),
                 candidates_pth: T.Optional[T.Union[Path, str]] = None,
                 fp_size: T.Optional[int] = None,
                 fp_radius: T.Optional[int] = None,
                 use_magma=False,
                 **kwargs):
        self.use_magma = use_magma
        # Attribute reads not found on this wrapper (use_fp, metadata,
        # spec_transform, ...) are delegated to self.instance via __getattr__.
        self.instance = MassSpecDataset_PeakFormulas(**kwargs, return_mol_freq=False, stage=Stage.TEST) if use_formulas else JESTR1_MassSpecDataset(**kwargs, return_mol_freq=False)
        if self.use_fp:
            self.fpgen = AllChem.GetMorganGenerator(radius=fp_radius, fpSize=fp_size)
        self.candidates_pth = candidates_pth
        self.mol_label_transform = mol_label_transform
        # Read candidates_pth from json to dict: SMILES -> respective candidate SMILES
        with open(self.candidates_pth, "r") as file:
            candidates = json.load(file)
        self.candidates = {}
        for s, cand in candidates.items():
            clean_cands = []
            for c in cand:
                try:
                    # Drop multi-fragment candidates (SMILES containing '.').
                    if '.' not in c:
                        clean_cands.append(c)
                except Exception:  # FIX: was a bare except (e.g. c not iterable)
                    print(f"Error in processing candidate {c} for smiles {s}")
                    pass
            self.candidates[s] = clean_cands
        self.spec_cand = []  # (spec index, cand_smiles, true_label)
        # use for external dataset where target smiles is not known
        # self.candidates should be a dict of identifier to candidates
        if 'smiles' not in self.metadata.columns:
            if not isinstance(self.metadata.iloc[0]['identifier'], str):
                self.metadata['smiles'] = self.metadata['identifier'].apply(str)
            else:
                self.metadata['smiles'] = self.metadata['identifier']
        # keep datapoints where there are candidates
        # (this assignment shadows instance.metadata on the wrapper)
        self.metadata = self.metadata[self.metadata['smiles'].isin(self.candidates.keys())]
        test_smiles = self.metadata[self.metadata['fold'] == "test"]['smiles'].tolist()
        test_ms_id = self.metadata[self.metadata['fold'] == "test"]['identifier'].tolist()
        # NOTE(review): stored index labels are later used with .iloc on the
        # instance — this assumes labels equal positions; verify upstream.
        self.spec_id_to_index = dict(zip(self.metadata['identifier'], self.metadata.index))
        for spec_id, s in zip(test_ms_id, test_smiles):
            candidates = self.candidates[s]
            # mol_label = self.mol_label_transform(s)
            # labels = [self.mol_label_transform(c) == mol_label for c in candidates]
            labels = [c == s for c in candidates]
            if len(candidates) == 0:
                print(f"Skipping {spec_id}; empty candidate set")
                continue
            if not any(labels):
                # print(f"Target smiles not in candidate set")
                pass
            self.spec_cand.extend([(self.spec_id_to_index[spec_id], candidates[j], k) for j, k in enumerate(labels)])

    def __getattr__(self, name):
        # Delegate unknown attributes to the wrapped dataset instance.
        return self.instance.__getattribute__(name)

    def __len__(self):
        return len(self.spec_cand)

    def __getitem__(self, i):
        """Return the spectrum item paired with one candidate and its label."""
        spec_i = self.spec_cand[i][0]
        cand_smiles = self.spec_cand[i][1]
        label = self.spec_cand[i][2]
        if self.use_magma:
            # Re-annotate the raw peaks against THIS candidate with MAGMa and
            # build the spectrum from the resulting subformulas.
            item = self.instance.__getitem__(spec_i, transform_mol=False, transform_spec=False)
            mzs = np.array([float(x) for x in self.metadata.iloc[spec_i]['mzs'].split(',')])
            intensities = np.array([float(x) for x in self.metadata.iloc[spec_i]['intensities'].split(',')])
            adduct = self.metadata.iloc[spec_i]['adduct']
            precursor_mz = self.metadata.iloc[spec_i]['precursor_mz']
            formula = self.metadata.iloc[spec_i]['formula']
            spec_data = run_magma(i, mzs, intensities, cand_smiles, adduct)
            spec = self.subformulaLoader.load_magma_data(spec_data, formula, precursor_mz)
            spec = matchms.Spectrum(
                mz=np.array(spec['formula_mzs']),
                intensities=np.array(spec['formula_intensities']),
                metadata={'precursor_mz': precursor_mz, 'formulas': np.array(spec['formulas'])})
            if isinstance(self.spec_transform, dict):
                for key, transform in self.spec_transform.items():
                    item[key] = transform(spec) if transform is not None else spec
            else:
                item["spec"] = self.spec_transform(spec)
        else:
            item = self.instance.__getitem__(spec_i, transform_mol=False)
        item['cand'] = self.mol_transform(cand_smiles)
        item['cand_smiles'] = cand_smiles
        item['label'] = label
        if self.use_fp:
            # NOTE(review): MolFromSmiles returns None for unparsable SMILES,
            # which would raise here — candidates are assumed pre-validated.
            item['fp'] = torch.Tensor(self.fpgen.GetFingerprint(Chem.MolFromSmiles(cand_smiles)).ToList())
        return item


class MassSpecDataset_Candidates:
    """Training wrapper that augments each item with decoy candidates.

    Decoys come from a pickled SMILES -> candidate-list mapping and are served
    round-robin via ``itertools.cycle``; molecules with <=1 usable candidate
    fall back to a shuffled list of all target SMILES.
    """

    def __init__(self, use_formulas: bool, aug_cands_dir_pth: str, aug_cands_size: int, **kwargs):
        self.aug_cands_size = aug_cands_size
        self.instance = MassSpecDataset_PeakFormulas(**kwargs, return_mol_freq=False) if use_formulas else JESTR1_MassSpecDataset(**kwargs, return_mol_freq=False)
        with open(aug_cands_dir_pth, 'rb') as f:
            aug_cands = pickle.load(f)
        if self.use_fp:
            self.fpgen = AllChem.GetMorganGenerator(radius=5, fpSize=1024)
        self.aug_cands = {}
        targets = np.array(list(aug_cands.keys()))
        for smiles, cands in aug_cands.items():
            # sort candidates by tanimoto similarity
            # NOTE(review): key=x[1] plus the string test below implies the
            # pickled candidate entries are indexable; confirm element type.
            cands.sort(key=lambda x: x[1], reverse=True)
            cands = [c for c in cands if '.' not in c]
            # assert(len(cands) >0)
            if len(cands) <= 1:
                # if no candidates, shuffle from target list
                np.random.shuffle(targets)
                cands = targets
            self.aug_cands[smiles] = itertools.cycle(cands)

    def __getattr__(self, name):
        # Delegate unknown attributes to the wrapped dataset instance.
        return self.instance.__getattribute__(name)

    def __getitem__(self, i):
        item = self.instance.__getitem__(i, transform_mol=False)
        # Draw the next aug_cands_size decoys for this molecule (round-robin).
        aug_cands = [next(self.aug_cands[item['mol']]) for _ in range(self.aug_cands_size)]
        item['aug_cands_fp'] = [self.fpgen.GetFingerprint(Chem.MolFromSmiles(c)).ToList() for c in aug_cands]
        item["aug_cands"] = [self.mol_transform(c) for c in aug_cands]
        item["mol"] = self.mol_transform(item["mol"])
        return item