# Source: FLARE / flare/data/datasets.py
# Last commit: "update" (19a4dfc), author: yzhouchen001
import pandas as pd
import json
import typing as T
import numpy as np
import torch
import massspecgym.utils as utils
from pathlib import Path
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import default_collate
import dgl
from collections import defaultdict
from massspecgym.data.transforms import SpecTransform, MolTransform, MolToInChIKey
from massspecgym.data.datasets import MassSpecDataset
import flare.utils.data as data_utils
from torch.nn.utils.rnn import pad_sequence
from massspecgym.models.base import Stage
import pickle
import math
import itertools
from rdkit.Chem import AllChem
from rdkit import Chem
from magma.run_magma import run_magma
import matchms
class JESTR1_MassSpecDataset(MassSpecDataset):
    """``MassSpecDataset`` with optional fingerprints, consensus and NL spectra.

    All keyword arguments that are not listed below are forwarded unchanged to
    ``MassSpecDataset.__init__``. Depending on which paths are supplied, each
    item returned by ``__getitem__`` may additionally contain:

    * ``fp``        -- fingerprint looked up by the item's SMILES,
    * ``cons_spec`` -- transformed spectrum looked up by the item's SMILES,
    * ``NL_spec``   -- transformed spectrum looked up by the item's index.

    Args:
        spectra_view: Name of the spectrum representation. It is handed to
            ``data_utils.PrepMatchMS`` and used as the key into
            ``spec_transform`` when that is a dict.
        fp_dir_pth: Pickle file with a SMILES -> fingerprint mapping. ``None``
            disables fingerprints; an empty string enables ``use_fp`` without
            loading anything (no ``fp`` key is produced by this class then).
        cons_spec_dir_pth: Pickle file with a table that has a ``smiles``
            column and whatever columns ``PrepMatchMS.prepare`` reads.
        NL_spec_dir_pth: Pickle file with a table of NL spectra
            (presumably neutral-loss spectra -- TODO confirm).
    """

    def __init__(
        self,
        spectra_view: str,
        fp_dir_pth: T.Optional[str] = None,
        cons_spec_dir_pth: T.Optional[str] = None,
        NL_spec_dir_pth: T.Optional[str] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.use_fp = False
        self.use_cons_spec = False
        self.use_NL_spec = False
        self.spectra_view = spectra_view

        # Each loader flips its own ``use_*`` flag when a path is given.
        self._load_fp(fp_dir_pth)
        self._load_cons_spec(cons_spec_dir_pth)
        self._load_NL_spec(NL_spec_dir_pth)

    def _load_fp(self, fp_dir_pth):
        """Load the SMILES -> fingerprint mapping, if a path was given."""
        # Always define the mapping so later reads never hit a missing attribute.
        self.smiles_to_fp = {}
        if fp_dir_pth is None:
            return

        self.use_fp = True
        if fp_dir_pth:
            # NOTE(security): pickle can execute arbitrary code while loading;
            # only point this at trusted files.
            with open(fp_dir_pth, 'rb') as f:
                self.smiles_to_fp = pickle.load(f)

    def _load_cons_spec(self, cons_spec_dir_pth):
        """Load consensus spectra and index them by SMILES."""
        if cons_spec_dir_pth is None:
            return

        self.use_cons_spec = True
        # NOTE(security): pickle -- trusted files only.
        with open(cons_spec_dir_pth, 'rb') as f:
            cons_specs = pickle.load(f)

        # Convert every row of the table to a matchms spectrum.
        matchMS_preparer = data_utils.PrepMatchMS(self.spectra_view)
        spectra = cons_specs.apply(matchMS_preparer.prepare, axis=1)
        self.cons_specs = dict(zip(cons_specs['smiles'].tolist(), spectra))

    def _load_NL_spec(self, NL_spec_dir_pth):
        """Load NL spectra; they are later looked up with the item index."""
        if NL_spec_dir_pth is None:
            return

        self.use_NL_spec = True
        # NOTE(security): pickle -- trusted files only.
        with open(NL_spec_dir_pth, 'rb') as f:
            NL_specs = pickle.load(f)

        # Convert every row of the table to a matchms spectrum.
        matchMS_preparer = data_utils.PrepMatchMS(self.spectra_view)
        self.NL_specs = NL_specs.apply(matchMS_preparer.prepare, axis=1)

    def _view_spec_transform(self):
        """Return the transform used for the consensus / NL spectra.

        With a dict of transforms this is the entry stored under
        ``self.spectra_view``; a single transform is used as is, so the
        auxiliary spectra also work when ``spec_transform`` is not a dict.
        """
        if isinstance(self.spec_transform, dict):
            return self.spec_transform[self.spectra_view]
        return self.spec_transform

    def __getitem__(self, i, transform_spec: bool = True, transform_mol: bool = True):
        """Build the sample dict for index ``i``.

        Args:
            i: Index into ``self.spectra`` / row position in ``self.metadata``.
            transform_spec: Apply ``self.spec_transform`` to the main spectrum.
                The consensus and NL spectra are transformed regardless.
            transform_mol: Apply ``self.mol_transform``; when False (or when no
                molecule transform is set) ``item["mol"]`` is the raw SMILES /
                identifier string.

        Returns:
            dict with the transformed spectrum (one entry per key when
            ``spec_transform`` is a dict, otherwise ``"spec"``), the molecule
            entries, and the optional ``mol_freq``, ``identifier``, ``fp``,
            ``cons_spec`` and ``NL_spec`` entries.
        """
        spec = self.spectra[i]
        metadata = self.metadata.iloc[i]
        # Datasets without a known structure fall back to the identifier.
        mol = metadata["smiles"] if 'smiles' in metadata else metadata["identifier"]

        item = {}

        # Apply all transformations to the spectrum
        if transform_spec and self.spec_transform:
            if isinstance(self.spec_transform, dict):
                for key, transform in self.spec_transform.items():
                    item[key] = transform(spec) if transform is not None else spec
            else:
                item["spec"] = self.spec_transform(spec)

        if self.return_mol_freq:
            item["mol_freq"] = metadata["mol_freq"]

        if self.return_identifier:
            item["identifier"] = metadata["identifier"]

        # An empty mapping means fingerprints are produced elsewhere.
        if self.use_fp and self.smiles_to_fp:
            item['fp'] = torch.Tensor(self.smiles_to_fp[mol].ToList())

        if self.use_cons_spec:
            item['cons_spec'] = self._view_spec_transform()(self.cons_specs[mol])

        if self.use_NL_spec:
            item['NL_spec'] = self._view_spec_transform()(self.NL_specs[i])

        # Apply all transformations to the molecule
        if transform_mol and self.mol_transform:
            if isinstance(self.mol_transform, dict):
                for key, transform in self.mol_transform.items():
                    item[key] = transform(mol) if transform is not None else mol
            else:
                item["mol"] = self.mol_transform(mol)
        else:
            item["mol"] = mol

        return item
class MassSpecDataset_PeakFormulas(JESTR1_MassSpecDataset):
    """Dataset whose spectra carry per-peak sub-formula annotations.

    Unlike its parent, this class does not call ``super().__init__``: it reads
    the metadata table itself, attaches the sub-formula data returned by
    ``data_utils.Subformula_Loader`` and builds the matchms spectra from the
    merged table.
    """

    def __init__(
        self,
        spectra_view: str,
        spec_transform: T.Optional[T.Union[SpecTransform, T.Dict[str, SpecTransform]]],
        mol_transform: T.Optional[T.Union[MolTransform, T.Dict[str, MolTransform]]],
        pth: T.Optional[Path],
        subformula_dir_pth: str,
        fp_dir_pth: str = None,
        NL_spec_dir_pth: str = None,
        cons_spec_dir_pth: str = None,
        return_mol_freq: bool = False,
        return_identifier: bool = True,
        dtype: T.Type = torch.float32,
        formula_source = 'default',
        stage: Stage = Stage.TRAIN
    ):
        """
        Args:
            spectra_view: Spectrum representation name, passed to the
                sub-formula loader and to ``data_utils.PrepMatchMS``.
            spec_transform: Spectrum transform, or dict of named transforms.
            mol_transform: Molecule transform, or dict of named transforms.
            pth: Tab-separated metadata file. The code reads the columns
                ``identifier``, ``formula`` and ``precursor_mz`` (and
                ``smiles`` / ``inchikey`` when ``return_mol_freq`` is set).
            subformula_dir_pth: Location handed to ``Subformula_Loader``.
            fp_dir_pth: See ``JESTR1_MassSpecDataset``.
            NL_spec_dir_pth: See ``JESTR1_MassSpecDataset``.
            cons_spec_dir_pth: See ``JESTR1_MassSpecDataset``.
            return_mol_freq: Add a ``mol_freq`` column (rows per InChIKey) and
                return it with every item.
            return_identifier: Return the spectrum identifier with every item.
            dtype: Stored on the instance; not used inside this class.
            formula_source: Passed through to ``Subformula_Loader``.
            stage: Forwarded to ``_load_id_to_spec``, which currently ignores it.
        """
        self.pth = pth
        self.spec_transform = spec_transform
        self.mol_transform = mol_transform
        self.return_mol_freq = return_mol_freq
        self.return_identifier = return_identifier
        self.dtype = dtype
        self.pred_fp = False
        self.use_fp = False
        self.use_cons_spec = False
        self.use_NL_spec = False
        self.spectra_view = spectra_view
        self.formula_source = formula_source
        self.subformula_dir_pth = subformula_dir_pth

        if isinstance(self.pth, str):
            self.pth = Path(self.pth)

        print("Data path: ", self.pth)
        self.metadata = pd.read_csv(self.pth, sep="\t")

        # load subformulas
        id_to_spec = self._load_id_to_spec(stage)

        # load fingerprints, consensus spectra and NL spectra
        self._load_fp(fp_dir_pth)
        self._load_cons_spec(cons_spec_dir_pth)
        self._load_NL_spec(NL_spec_dir_pth)

        # Keep only spectra that have sub-formula data and attach that data as
        # extra columns. ``merge`` also gives the table a fresh 0..n-1 index.
        self.metadata = self.metadata[self.metadata['identifier'].isin(id_to_spec)]
        formula_df = (
            pd.DataFrame.from_dict(id_to_spec, orient='index')
            .reset_index()
            .rename(columns={'index': 'identifier'})
        )
        self.metadata = self.metadata.merge(formula_df, on='identifier')

        # create matchms spectra
        matchMS_preparer = data_utils.PrepMatchMS(spectra_view=spectra_view)
        self.spectra = self.metadata.apply(matchMS_preparer.prepare, axis=1)

        if self.return_mol_freq:
            if "inchikey" not in self.metadata.columns:
                self.metadata["inchikey"] = self.metadata["smiles"].apply(utils.smiles_to_inchi_key)
            self.metadata["mol_freq"] = self.metadata.groupby("inchikey")["inchikey"].transform("count")

    def __getitem__(self, i, transform_spec: bool = True, transform_mol: bool = True):
        """Return the sample for index ``i``.

        The parent builds the item with the raw SMILES under ``"mol"``; the
        molecule transform(s) are then applied here. Without a molecule
        transform the raw SMILES is kept, matching the parent's behaviour.
        """
        item = super().__getitem__(i, transform_spec, transform_mol=False)
        mol = item['mol']  # smiles

        # ``mol_transform`` is Optional: only transform when one is configured.
        if transform_mol and self.mol_transform is not None:
            if isinstance(self.mol_transform, dict):
                for key, transform in self.mol_transform.items():
                    item[key] = transform(mol) if transform is not None else mol
            else:
                item["mol"] = self.mol_transform(mol)

        return item

    def _load_id_to_spec(self, stage):
        """Map every spectrum identifier to its sub-formula spectrum data.

        Identifiers for which the loader returns nothing get a placeholder
        built by ``data_utils.make_tmp_subformula_spectra`` from their
        metadata row.

        Args:
            stage: Currently unused (fold-based filtering is disabled).

        Returns:
            dict: identifier -> sub-formula spectrum data.
        """
        all_spec_ids = self.metadata['identifier'].tolist()
        self.subformulaLoader = data_utils.Subformula_Loader(
            spectra_view=self.spectra_view,
            dir_path=self.subformula_dir_pth,
            formula_source=self.formula_source,
        )

        form_list = self.metadata['formula'].tolist()
        prec_mz_list = self.metadata['precursor_mz'].tolist()
        id_to_spec = self.subformulaLoader(all_spec_ids, form_list, prec_mz_list)

        # create subformula spectra if no subformula is available
        missing_ids = [spec_id for spec_id in all_spec_ids if spec_id not in id_to_spec]
        missing_rows = self.metadata[self.metadata['identifier'].isin(missing_ids)]

        # Iterate rows explicitly rather than ``apply`` + column assignment:
        # nothing is written to a filtered slice, and an empty selection is
        # simply a no-op (``apply(axis=1)`` returns a DataFrame on empty input).
        for spec_id, (_, row) in zip(missing_rows['identifier'].tolist(), missing_rows.iterrows()):
            id_to_spec[spec_id] = data_utils.make_tmp_subformula_spectra(row)

        return id_to_spec
class ContrastiveDataset(Dataset):
    """One sample per unique molecule, cycling through that molecule's spectra.

    ``spec_mol_data`` must expose ``indices`` and ``dataset.metadata`` (a table
    with a ``smiles`` column) and support ``__getitem__`` by position, as a
    ``torch.utils.data.Subset`` does. Every access to the same molecule
    returns its next spectrum, wrapping around at the end.
    """

    def __init__(
        self,
        spec_mol_data,
    ):
        super().__init__()
        indices = spec_mol_data.indices
        self.spec_mol_data = spec_mol_data

        # ``groupby(...).indices`` holds row positions inside the selected
        # sub-table, i.e. positions within ``spec_mol_data`` itself.
        subset_metadata = spec_mol_data.dataset.metadata.loc[indices]
        self.smiles_to_specmol_ids = subset_metadata.groupby('smiles').indices

        # Number of times each molecule has been served so far.
        self.smiles_to_spec_couter = defaultdict(int)
        self.smiles_list = list(self.smiles_to_specmol_ids.keys())

    def __len__(self) -> int:
        """Number of unique molecules."""
        return len(self.smiles_list)

    def __getitem__(self, i: int) -> dict:
        """Return the next spectrum/molecule item for the ``i``-th molecule."""
        mol = self.smiles_list[i]

        # select spectrum (iterate through list of spectra)
        specmol_ids = self.smiles_to_specmol_ids[mol]
        counter = self.smiles_to_spec_couter[mol]
        specmol_id = specmol_ids[counter % len(specmol_ids)]

        item = self.spec_mol_data.__getitem__(specmol_id)
        self.smiles_to_spec_couter[mol] = counter + 1

        return item

    @staticmethod
    def _pad_peaks(batch, key, padding_value):
        """Pad ``item[key]`` of every batch item to a common length.

        Returns:
            (padded, lengths): the ``pad_sequence`` result (batch first) and
            the list of original lengths.
        """
        sequences = [item[key] for item in batch]
        lengths = [len(seq) for seq in sequences]
        padded = pad_sequence(sequences, batch_first=True, padding_value=padding_value)
        return padded, lengths

    @staticmethod
    def collate_fn(batch: T.Iterable[dict], spec_enc: str, spectra_view: str, stage=None, batch_mol: bool = True) -> dict:
        """Collate a list of items into one batch dict.

        Args:
            batch: Items as produced by the dataset; indexed as a sequence.
            spec_enc: Spectrum encoder name; selects the padding value
                (-5 for 'Transformer_Formula', 'Formula_BinnedSpec' and
                'Transformer_MzInt', otherwise 0).
            spectra_view: Key of the main spectrum. Peak padding is used when
                it contains 'Formula' or 'Tokens'.
            stage: With ``Stage.TEST`` the graph is read from ``'cand'``,
                otherwise from ``'mol'``.
            batch_mol: Batch the graphs with ``dgl.batch`` and record their
                node counts under ``'mol_n_nodes'``.

        Returns:
            dict with default-collated entries, the batched graph, and (when
            padding applies) padded spectra with their peak counts under
            ``'n_peaks'``, ``'cons_n_peaks'`` and ``'NL_n_peaks'``.
        """
        mol_key = 'cand' if stage == Stage.TEST else 'mol'
        require_pad = 'Formula' in spectra_view or 'Tokens' in spectra_view
        padding_value = -5 if spec_enc in ('Transformer_Formula', 'Formula_BinnedSpec', 'Transformer_MzInt') else 0

        # Keys that default_collate must not touch.
        custom_keys = {'mol', 'cand', 'aug_cands', 'aug_cands_fp'}
        if require_pad:
            # Variable-length spectra are padded below instead.
            custom_keys.update((spectra_view, 'cons_spec', 'NL_spec'))

        collated_batch = {}

        # standard collate
        for k in batch[0].keys():
            if k in custom_keys:
                continue
            try:
                collated_batch[k] = default_collate([item[k] for item in batch])
            except Exception:
                print(f"Error in collating key {k}")
                raise

        # batch graphs
        if batch_mol:
            graphs = [item[mol_key] for item in batch]
            collated_batch[mol_key] = dgl.batch(graphs)
            collated_batch['mol_n_nodes'] = [graph.num_nodes() for graph in graphs]

        # pad peaks/formulas
        if require_pad:
            padded, n_peaks = ContrastiveDataset._pad_peaks(batch, spectra_view, padding_value)
            collated_batch[spectra_view] = padded
            collated_batch['n_peaks'] = n_peaks

            for spec_key, count_key in (('cons_spec', 'cons_n_peaks'), ('NL_spec', 'NL_n_peaks')):
                if spec_key in batch[0]:
                    padded, n_peaks = ContrastiveDataset._pad_peaks(batch, spec_key, padding_value)
                    collated_batch[spec_key] = padded
                    collated_batch[count_key] = n_peaks

        return collated_batch
class ExpandedRetrievalDataset:
    '''Retrieval dataset with one item per (test spectrum, candidate) pair.

    Used for testing only. Assumes the 'fold' column defines the split.

    The object wraps a spectrum dataset (``self.instance``); attributes that
    are not set on the wrapper itself are looked up on that dataset. Note that
    ``self.metadata`` is re-bound on the wrapper to a filtered table that keeps
    the original index labels, while ``self.instance.metadata`` stays
    unfiltered.
    '''

    def __init__(self,
                 use_formulas: bool = True,
                 mol_label_transform: MolTransform = MolToInChIKey(),
                 candidates_pth: T.Optional[T.Union[Path, str]] = None,
                 fp_size: int = None,
                 fp_radius: int = None,
                 use_magma = False,
                 **kwargs):
        '''
        Args:
            use_formulas: Wrap ``MassSpecDataset_PeakFormulas`` (True) or
                ``JESTR1_MassSpecDataset`` (False).
            mol_label_transform: Stored on the instance; not used by the
                current label computation (labels compare SMILES strings).
            candidates_pth: JSON file mapping a SMILES (or an identifier, for
                data without known structures) to its candidate SMILES list.
            fp_size: Morgan fingerprint size, used when fingerprints are on.
            fp_radius: Morgan fingerprint radius, used when fingerprints are on.
            use_magma: Re-annotate every spectrum per candidate via
                ``run_magma`` in ``__getitem__``.
            **kwargs: Forwarded to the wrapped dataset's constructor.
        '''
        self.use_magma = use_magma
        if use_formulas:
            self.instance = MassSpecDataset_PeakFormulas(**kwargs, return_mol_freq=False, stage=Stage.TEST)
        else:
            self.instance = JESTR1_MassSpecDataset(**kwargs, return_mol_freq=False)

        if self.use_fp:
            self.fpgen = AllChem.GetMorganGenerator(radius=fp_radius, fpSize=fp_size)

        self.candidates_pth = candidates_pth
        self.mol_label_transform = mol_label_transform

        # Read candidates_pth from json to dict: SMILES -> respective candidate SMILES
        with open(self.candidates_pth, "r") as file:
            candidates = json.load(file)

        # Drop candidates whose SMILES contains '.'.
        self.candidates = {}
        for s, cand in candidates.items():
            clean_cands = []
            for c in cand:
                try:
                    if '.' not in c:
                        clean_cands.append(c)
                except TypeError:
                    # Entry does not support ``in`` (e.g. null): report and skip it.
                    print(f"Error in processing candidate {c} for smiles {s}")
            self.candidates[s] = clean_cands

        self.spec_cand = []  # (spec index, cand_smiles, true_label)

        # use for external dataset where target smiles is not known
        # self.candidates should be a dict of identifier to candidates
        if 'smiles' not in self.metadata.columns:
            if not isinstance(self.metadata.iloc[0]['identifier'], str):
                self.metadata['smiles'] = self.metadata['identifier'].apply(str)
            else:
                self.metadata['smiles'] = self.metadata['identifier']

        # keep datapoints where there are candidates
        # (this binds a filtered copy on the wrapper; index labels are kept)
        self.metadata = self.metadata[self.metadata['smiles'].isin(self.candidates.keys())]

        test_rows = self.metadata[self.metadata['fold'] == "test"]
        test_smiles = test_rows['smiles'].tolist()
        test_ms_id = test_rows['identifier'].tolist()

        # identifier -> index label, which is what the wrapped dataset's
        # ``__getitem__`` is called with.
        self.spec_id_to_index = dict(zip(self.metadata['identifier'], self.metadata.index))

        for spec_id, s in zip(test_ms_id, test_smiles):
            candidates = self.candidates[s]
            if len(candidates) == 0:
                print(f"Skipping {spec_id}; empty candidate set")
                continue

            # A candidate is the positive when its SMILES equals the target's.
            # The target may be absent from the set; such spectra are kept.
            spec_index = self.spec_id_to_index[spec_id]
            self.spec_cand.extend((spec_index, c, c == s) for c in candidates)

    def __getattr__(self, name):
        # Only called when normal lookup fails. Without this guard a missing
        # ``instance`` (e.g. while unpickling or copying) recurses forever.
        if name == 'instance':
            raise AttributeError(name)
        return getattr(self.instance, name)

    def __len__(self):
        '''Number of (spectrum, candidate) pairs.'''
        return len(self.spec_cand)

    def __getitem__(self, i):
        '''Return the item for pair ``i``.

        On top of the wrapped dataset's item this adds ``cand`` (transformed
        candidate), ``cand_smiles`` and ``label``; with fingerprints enabled,
        ``fp`` is the candidate's Morgan fingerprint.
        '''
        spec_i, cand_smiles, label = self.spec_cand[i]

        if self.use_magma:
            item = self.instance.__getitem__(spec_i, transform_mol=False, transform_spec=False)

            # ``spec_i`` is an index label and ``self.metadata`` is the
            # filtered table, so the row must be selected by label, not position.
            row = self.metadata.loc[spec_i]
            mzs = np.array([float(x) for x in row['mzs'].split(',')])
            intensities = np.array([float(x) for x in row['intensities'].split(',')])
            adduct = row['adduct']
            precursor_mz = row['precursor_mz']
            formula = row['formula']

            spec_data = run_magma(i, mzs, intensities, cand_smiles, adduct)
            spec = self.subformulaLoader.load_magma_data(spec_data, formula, precursor_mz)
            spec = matchms.Spectrum(
                mz=np.array(spec['formula_mzs']),
                intensities=np.array(spec['formula_intensities']),
                metadata={'precursor_mz': precursor_mz, 'formulas': np.array(spec['formulas'])})

            if isinstance(self.spec_transform, dict):
                for key, transform in self.spec_transform.items():
                    item[key] = transform(spec) if transform is not None else spec
            else:
                item["spec"] = self.spec_transform(spec)
        else:
            item = self.instance.__getitem__(spec_i, transform_mol=False)

        item['cand'] = self.mol_transform(cand_smiles)
        item['cand_smiles'] = cand_smiles
        item['label'] = label

        if self.use_fp:
            # Replaces any fingerprint set by the wrapped dataset.
            item['fp'] = torch.Tensor(self.fpgen.GetFingerprint(Chem.MolFromSmiles(cand_smiles)).ToList())

        return item
class MassSpecDataset_Candidates:
    """Spectrum dataset that adds augmentation candidates to every item.

    The object wraps a spectrum dataset (``self.instance``); attributes that
    are not set on the wrapper itself are looked up on that dataset. Each item
    gains ``aug_cands`` (``aug_cands_size`` transformed candidates drawn
    round-robin for the item's molecule) and, when fingerprints are enabled,
    ``aug_cands_fp`` with their Morgan fingerprints.
    """

    def __init__(self,
                 use_formulas: bool,
                 aug_cands_dir_pth: str,
                 aug_cands_size: int,
                 fp_size: int = 1024,
                 fp_radius: int = 5,
                 **kwargs):
        """
        Args:
            use_formulas: Wrap ``MassSpecDataset_PeakFormulas`` (True) or
                ``JESTR1_MassSpecDataset`` (False).
            aug_cands_dir_pth: Pickle file mapping a target SMILES to its list
                of augmentation candidates.
            aug_cands_size: Number of candidates attached to each item.
            fp_size: Morgan fingerprint size (default: the former constant).
            fp_radius: Morgan fingerprint radius (default: the former constant).
            **kwargs: Forwarded to the wrapped dataset's constructor.
        """
        self.aug_cands_size = aug_cands_size
        if use_formulas:
            self.instance = MassSpecDataset_PeakFormulas(**kwargs, return_mol_freq=False)
        else:
            self.instance = JESTR1_MassSpecDataset(**kwargs, return_mol_freq=False)

        # NOTE(security): pickle can execute arbitrary code while loading;
        # only point this at trusted files.
        with open(aug_cands_dir_pth, 'rb') as f:
            aug_cands = pickle.load(f)

        if self.use_fp:
            self.fpgen = AllChem.GetMorganGenerator(radius=fp_radius, fpSize=fp_size)

        self.aug_cands = {}
        targets = np.array(list(aug_cands.keys()))
        for smiles, cands in aug_cands.items():
            # sort candidates by tanimoto similarity
            # NOTE(review): the sort key reads ``x[1]`` while the filter below
            # and ``__getitem__`` treat each entry as a SMILES string -- confirm
            # the stored entry format.
            cands.sort(key=lambda x: x[1], reverse=True)
            cands = [c for c in cands if '.' not in c]

            if len(cands) <= 1:
                # Too few candidates: fall back to the targets in random order.
                # ``permutation`` returns a fresh array, so a later fallback
                # cannot reorder what an earlier cycle is still reading.
                cands = np.random.permutation(targets)

            self.aug_cands[smiles] = itertools.cycle(cands)

    def __getattr__(self, name):
        # Only called when normal lookup fails. Without this guard a missing
        # ``instance`` (e.g. while unpickling or copying) recurses forever.
        if name == 'instance':
            raise AttributeError(name)
        return getattr(self.instance, name)

    def __getitem__(self, i):
        """Return item ``i`` of the wrapped dataset plus its candidates."""
        item = self.instance.__getitem__(i, transform_mol=False)

        # The cycle keeps its position between calls, so repeated accesses
        # walk through the candidate list.
        cand_iter = self.aug_cands[item['mol']]
        aug_cands = [next(cand_iter) for _ in range(self.aug_cands_size)]

        # The generator only exists when fingerprints are enabled.
        if self.use_fp:
            item['aug_cands_fp'] = [self.fpgen.GetFingerprint(Chem.MolFromSmiles(c)).ToList() for c in aug_cands]

        item["aug_cands"] = [self.mol_transform(c) for c in aug_cands]
        item["mol"] = self.mol_transform(item["mol"])
        return item