# MVP: mvp/data/datasets.py
# Author: yzhouchen001 (commit d9df210, "model code")
import pandas as pd
import json
import typing as T
import numpy as np
import torch
import massspecgym.utils as utils
from pathlib import Path
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import default_collate
import dgl
from collections import defaultdict
from massspecgym.data.transforms import SpecTransform, MolTransform, MolToInChIKey
from massspecgym.data.datasets import MassSpecDataset
import mvp.utils.data as data_utils
from torch.nn.utils.rnn import pad_sequence
from massspecgym.models.base import Stage
import pickle
import math
import itertools
from rdkit.Chem import AllChem
from rdkit import Chem
class JESTR1_MassSpecDataset(MassSpecDataset):
    """``MassSpecDataset`` with optional fingerprint, consensus-spectrum and
    neutral-loss (NL) spectrum side inputs.

    Args:
        spectra_view: Name of the spectrum representation. It is passed to
            ``data_utils.PrepMatchMS`` and, when ``spec_transform`` is a dict,
            selects the transform applied to consensus / NL spectra.
        fp_dir_pth: Path to a pickled mapping SMILES -> fingerprint object
            (each value must provide ``ToList()``). ``None`` disables
            fingerprints; an empty string sets ``use_fp`` but keeps an empty
            mapping, so ``__getitem__`` adds no ``'fp'`` entry.
        cons_spec_dir_pth: Path to a pickled DataFrame of consensus spectra
            with a ``smiles`` column. ``None`` disables consensus spectra.
        NL_spec_dir_pth: Path to a pickled DataFrame of NL spectra.
            ``None`` disables NL spectra.
        **kwargs: Forwarded unchanged to ``MassSpecDataset.__init__``.

    NOTE: side-input files are read with ``pickle``, which can execute
    arbitrary code on load; only use files from trusted sources.
    """

    def __init__(
        self,
        spectra_view: str,
        fp_dir_pth: str = None,
        cons_spec_dir_pth: str = None,
        NL_spec_dir_pth: str = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.use_fp = False
        self.use_cons_spec = False
        self.use_NL_spec = False
        self.spectra_view = spectra_view
        self._load_fp(fp_dir_pth)
        self._load_cons_spec(cons_spec_dir_pth)
        self._load_NL_spec(NL_spec_dir_pth)

    def _load_fp(self, fp_dir_pth):
        """Load the SMILES -> fingerprint mapping and set ``use_fp``.

        ``self.smiles_to_fp`` is always defined afterwards (empty when no
        file is given), so it can be read safely even if fingerprints are off.
        """
        self.smiles_to_fp = {}
        if fp_dir_pth is None:
            return
        self.use_fp = True
        if fp_dir_pth:
            with open(fp_dir_pth, 'rb') as f:
                self.smiles_to_fp = pickle.load(f)

    def _load_cons_spec(self, cons_spec_dir_pth):
        """Load consensus spectra into ``self.cons_specs`` (SMILES -> matchms spectrum)."""
        if cons_spec_dir_pth is None:
            return
        self.use_cons_spec = True
        with open(cons_spec_dir_pth, 'rb') as f:
            cons_specs = pickle.load(f)
        # Convert each row to a matchms spectrum.
        matchms_preparer = data_utils.PrepMatchMS(self.spectra_view)
        spectra = cons_specs.apply(matchms_preparer.prepare, axis=1)
        self.cons_specs = dict(zip(cons_specs['smiles'].tolist(), spectra))

    def _load_NL_spec(self, NL_spec_dir_pth):
        """Load NL spectra into ``self.NL_specs`` (a Series of matchms spectra).

        NOTE(review): ``__getitem__`` looks these up as ``self.NL_specs[i]``,
        which assumes the pickled frame's index lines up with dataset
        positions; TODO confirm.
        """
        if NL_spec_dir_pth is None:
            return
        self.use_NL_spec = True
        with open(NL_spec_dir_pth, 'rb') as f:
            NL_specs = pickle.load(f)
        matchms_preparer = data_utils.PrepMatchMS(self.spectra_view)
        self.NL_specs = NL_specs.apply(matchms_preparer.prepare, axis=1)

    def _transform_aux_spec(self, spec):
        """Apply the ``spectra_view`` transform to a consensus/NL spectrum.

        With a dict of transforms, the entry keyed by ``spectra_view`` is used
        (a missing key raises ``KeyError``). A single transform is applied
        directly, and ``None`` leaves the spectrum unchanged.
        """
        transform = self.spec_transform
        if isinstance(transform, dict):
            transform = transform[self.spectra_view]
        return spec if transform is None else transform(spec)

    def __getitem__(self, i, transform_spec: bool = True, transform_mol: bool = True):
        """Return the item dict for spectrum ``i``.

        Keys: the spectrum (``'spec'``, or one key per entry of a dict
        ``spec_transform``), optionally ``'mol_freq'``, ``'identifier'``,
        ``'fp'``, ``'cons_spec'``, ``'NL_spec'``, and the molecule (``'mol'``
        or one key per entry of a dict ``mol_transform``).
        """
        spec = self.spectra[i]
        metadata = self.metadata.iloc[i]
        mol = metadata["smiles"]

        item = {}
        # Spectrum transforms.
        if transform_spec and self.spec_transform:
            if isinstance(self.spec_transform, dict):
                for key, transform in self.spec_transform.items():
                    item[key] = transform(spec) if transform is not None else spec
            else:
                item["spec"] = self.spec_transform(spec)
        else:
            item["spec"] = spec

        if self.return_mol_freq:
            item["mol_freq"] = metadata["mol_freq"]
        if self.return_identifier:
            item["identifier"] = metadata["identifier"]
        if self.use_fp and self.smiles_to_fp:
            item['fp'] = torch.Tensor(self.smiles_to_fp[mol].ToList())
        if self.use_cons_spec:
            item['cons_spec'] = self._transform_aux_spec(self.cons_specs[mol])
        if self.use_NL_spec:
            item['NL_spec'] = self._transform_aux_spec(self.NL_specs[i])

        # Molecule transforms (input is the SMILES string).
        if transform_mol and self.mol_transform:
            if isinstance(self.mol_transform, dict):
                for key, transform in self.mol_transform.items():
                    item[key] = transform(mol) if transform is not None else mol
            else:
                item["mol"] = self.mol_transform(mol)
        else:
            item["mol"] = mol
        return item
class MassSpecDataset_PeakFormulas(JESTR1_MassSpecDataset):
    """JESTR1 dataset whose spectra are built from subformula annotations.

    The parent constructors are intentionally not called. Metadata is read
    from the tab-separated file at ``pth``, and spectra are assembled from
    subformula data loaded by ``data_utils.Subformula_Loader``.

    Args:
        spectra_view: Spectrum representation name, passed to
            ``data_utils.Subformula_Loader`` and ``data_utils.PrepMatchMS``.
        spec_transform: Spectrum transform, or dict of named transforms.
        mol_transform: Molecule transform (or dict of them) applied to the
            SMILES string; ``None`` keeps the raw SMILES under ``'mol'``.
        pth: Tab-separated metadata file with at least ``identifier`` and
            ``smiles`` columns.
        subformula_dir_pth: Location of subformula annotations, passed as
            ``dir_path`` to ``data_utils.Subformula_Loader``.
        fp_dir_pth, NL_spec_dir_pth, cons_spec_dir_pth: Optional side inputs
            handled by the inherited ``_load_fp`` / ``_load_NL_spec`` /
            ``_load_cons_spec``; ``None`` disables each.
        return_mol_freq: Add ``mol_freq`` (number of spectra sharing the
            same InChIKey) to each item.
        return_identifier: Add ``identifier`` to each item.
        dtype: Stored as ``self.dtype``.
    """

    def __init__(
        self,
        spectra_view: str,
        spec_transform: T.Optional[T.Union[SpecTransform, T.Dict[str, SpecTransform]]],
        mol_transform: T.Optional[T.Union[MolTransform, T.Dict[str, MolTransform]]],
        pth: T.Optional[Path],
        subformula_dir_pth: str,
        fp_dir_pth: str = None,
        NL_spec_dir_pth: str = None,
        cons_spec_dir_pth: str = None,
        return_mol_freq: bool = False,
        return_identifier: bool = True,
        dtype: T.Type = torch.float32
    ):
        self.pth = Path(pth) if isinstance(pth, str) else pth
        self.spec_transform = spec_transform
        self.mol_transform = mol_transform
        self.return_mol_freq = return_mol_freq
        self.pred_fp = False
        self.use_fp = False
        self.use_cons_spec = False
        self.use_NL_spec = False
        self.spectra_view = spectra_view

        print("Data path: ", self.pth)
        self.metadata = pd.read_csv(self.pth, sep="\t")
        # To train on consensus spectra instead, load a pickled DataFrame
        # from ``pth`` and set metadata['identifier'] = metadata['smiles'].

        id_to_spec = self._load_subformula_spectra(spectra_view, subformula_dir_pth)

        self._load_fp(fp_dir_pth)
        self._load_cons_spec(cons_spec_dir_pth)
        self._load_NL_spec(NL_spec_dir_pth)

        # Attach the subformula spectrum fields to the metadata rows.
        self.metadata = self.metadata[self.metadata['identifier'].isin(list(id_to_spec))]
        formula_df = (
            pd.DataFrame.from_dict(id_to_spec, orient='index')
            .reset_index()
            .rename(columns={'index': 'identifier'})
        )
        self.metadata = self.metadata.merge(formula_df, on='identifier')

        # Build matchms spectra from the merged rows.
        matchms_preparer = data_utils.PrepMatchMS(spectra_view=spectra_view)
        self.spectra = self.metadata.apply(matchms_preparer.prepare, axis=1)

        if self.return_mol_freq:
            if "inchikey" not in self.metadata.columns:
                self.metadata["inchikey"] = self.metadata["smiles"].apply(utils.smiles_to_inchi_key)
            self.metadata["mol_freq"] = self.metadata.groupby("inchikey")["inchikey"].transform("count")

        self.return_identifier = return_identifier
        self.dtype = dtype

    def _load_subformula_spectra(self, spectra_view, subformula_dir_pth):
        """Return a dict mapping each metadata identifier to its subformula spectrum.

        Identifiers the loader has no entry for get a placeholder built by
        ``data_utils.make_tmp_subformula_spectra`` from their metadata row.
        """
        all_spec_ids = self.metadata['identifier'].tolist()
        loader = data_utils.Subformula_Loader(spectra_view=spectra_view, dir_path=subformula_dir_pth)
        id_to_spec = loader(all_spec_ids)

        missing_ids = {spec_id for spec_id in all_spec_ids if spec_id not in id_to_spec}
        if missing_ids:
            # Iterate the rows directly instead of assigning a new column on a
            # slice of self.metadata (SettingWithCopyWarning). This also avoids
            # DataFrame.apply(axis=1) on an empty frame, which infers the result
            # by calling the function on a dummy row and can raise.
            missing_rows = self.metadata[self.metadata['identifier'].isin(missing_ids)]
            for _, row in missing_rows.iterrows():
                id_to_spec[row['identifier']] = data_utils.make_tmp_subformula_spectra(row)
        return id_to_spec

    def __getitem__(self, i, transform_spec: bool = True, transform_mol: bool = True):
        """Return the item for spectrum ``i``.

        ``item['mol']`` always starts as the SMILES string. When
        ``mol_transform`` is a dict, its outputs are added under their own
        keys and the SMILES stays in ``'mol'``.
        """
        item = super().__getitem__(i, transform_spec, transform_mol=False)
        mol = item['mol']  # raw SMILES, because transform_mol=False above

        # Same guard as the parent: a None transform keeps the raw SMILES.
        if transform_mol and self.mol_transform:
            if isinstance(self.mol_transform, dict):
                for key, transform in self.mol_transform.items():
                    item[key] = transform(mol) if transform is not None else mol
            else:
                item["mol"] = self.mol_transform(mol)
        return item
class ContrastiveDataset(Dataset):
    """Molecule-level view over a subset of spectrum/molecule pairs.

    Each index is one unique SMILES in the subset. Repeated access to the
    same index cycles round-robin through that molecule's spectra.

    Args:
        spec_mol_data: ``Subset``-like object exposing ``indices``,
            ``dataset.metadata`` (with a ``smiles`` column) and item access.

    NOTE(review): ``metadata.loc[indices]`` assumes the metadata index labels
    equal dataset positions (e.g. a RangeIndex); TODO confirm. With
    multi-worker DataLoaders, each worker holds its own copy of the
    round-robin counters.
    """

    def __init__(
        self,
        spec_mol_data,
    ):
        super().__init__()
        subset_indices = spec_mol_data.indices
        self.spec_mol_data = spec_mol_data
        # groupby(...).indices yields *positions* within the subset frame,
        # i.e. indices that are valid for spec_mol_data[...].
        subset_meta = spec_mol_data.dataset.metadata.loc[subset_indices]
        self.smiles_to_specmol_ids = subset_meta.groupby('smiles').indices
        # Per-molecule access counters (attribute name kept for compatibility).
        self.smiles_to_spec_couter = defaultdict(int)
        self.smiles_list = list(self.smiles_to_specmol_ids.keys())

    def __len__(self) -> int:
        return len(self.smiles_list)

    def __getitem__(self, i: int) -> dict:
        mol = self.smiles_list[i]
        specmol_ids = self.smiles_to_specmol_ids[mol]
        counter = self.smiles_to_spec_couter[mol]
        item = self.spec_mol_data[specmol_ids[counter % len(specmol_ids)]]
        self.smiles_to_spec_couter[mol] = counter + 1
        return item

    @staticmethod
    def collate_fn(batch: T.Sequence[dict], spec_enc: str, spectra_view: str, stage=None, mask_peak_ratio: float = 0.0, aug_cands: bool = False) -> dict:
        """Collate a list of item dicts into one batch dict.

        Args:
            batch: Items produced by the dataset.
            spec_enc: Encoder name. 'Transformer_Formula', 'Formula_BinnedSpec'
                and 'Transformer_MzInt' pad with -5; all others pad with 0.
            spectra_view: Key of the spectrum in each item. Views whose name
                contains 'Formula' or 'Tokens' are padded with
                ``pad_sequence``, along with 'cons_spec' / 'NL_spec' when
                present. Other views go through ``default_collate``.
            stage: ``Stage.TEST`` batches graphs from 'cand', otherwise 'mol'.
                Peak masking only happens for ``Stage.TRAIN``.
            mask_peak_ratio: For padded views during training, this fraction
                (floored) of each spectrum's real peaks is overwritten with -5.0.
            aug_cands: If true, concatenate every item's 'aug_cands' graph list
                into one batched graph, plus 'aug_cands_fp' when present.

        Returns:
            Dict with the default-collated fields, the batched graph under the
            molecule key, 'mol_n_nodes', and the padded fields with their
            per-item lengths ('n_peaks', 'cons_n_peaks', 'NL_n_peaks').
        """
        mol_key = 'cand' if stage == Stage.TEST else 'mol'
        non_standard_collate = {'mol', 'cand', 'aug_cands', 'cons_spec', 'aug_cands_fp', 'NL_spec'}
        require_pad = 'Formula' in spectra_view or 'Tokens' in spectra_view
        if require_pad:
            padding_value = -5 if spec_enc in ('Transformer_Formula', 'Formula_BinnedSpec', 'Transformer_MzInt') else 0
            non_standard_collate.add(spectra_view)
        else:
            # Unpadded views: side spectra can be stacked by default_collate.
            non_standard_collate.remove('cons_spec')
            non_standard_collate.remove('NL_spec')

        def _pad(key):
            """Pad ``item[key]`` across the batch; return (tensor, lengths)."""
            seqs = [item[key] for item in batch]
            lengths = [len(seq) for seq in seqs]
            return pad_sequence(seqs, batch_first=True, padding_value=padding_value), lengths

        collated_batch = {}
        for k in batch[0].keys():
            if k not in non_standard_collate:
                collated_batch[k] = default_collate([item[k] for item in batch])

        # Batch molecule graphs.
        graphs = [item[mol_key] for item in batch]
        n_nodes = [g.num_nodes() for g in graphs]
        collated_batch[mol_key] = dgl.batch(graphs)
        collated_batch['mol_n_nodes'] = n_nodes

        if require_pad:
            collated_batch[spectra_view], spec_n_peaks = _pad(spectra_view)
            collated_batch['n_peaks'] = spec_n_peaks
            if 'cons_spec' in batch[0]:
                collated_batch['cons_spec'], collated_batch['cons_n_peaks'] = _pad('cons_spec')
            if 'NL_spec' in batch[0]:
                collated_batch['NL_spec'], collated_batch['NL_n_peaks'] = _pad('NL_spec')

            # Mask peaks of the main view. Use its own lengths; the side-spectrum
            # lengths previously overwrote them and gave wrong mask indices.
            if mask_peak_ratio > 0.0 and stage == Stage.TRAIN:
                n_mask_peaks = [math.floor(n_peak * mask_peak_ratio) for n_peak in spec_n_peaks]
                mask_peak_idx = [np.random.choice(n_peak, n_mask, replace=False)
                                 for n_peak, n_mask in zip(spec_n_peaks, n_mask_peaks)]
                for row, idx in zip(collated_batch[spectra_view], mask_peak_idx):
                    row[idx] = -5.0  # in-place on a view of the padded batch tensor

        # Batch augmentation candidates (linear-time concatenation).
        if aug_cands:
            candidates = list(itertools.chain.from_iterable(item["aug_cands"] for item in batch))
            collated_batch['aug_cands'] = dgl.batch(candidates)
            if 'aug_cands_fp' in batch[0]:
                cand_fp = [item['aug_cands_fp'] for item in batch]
                collated_batch['aug_cands_fp'] = torch.flatten(torch.Tensor(cand_fp), end_dim=1)
        return collated_batch
class ExpandedRetrievalDataset:
    """Test-time retrieval dataset of (spectrum, candidate) pairs.

    Wraps a ``MassSpecDataset_PeakFormulas`` (``use_formulas=True``) or a
    ``JESTR1_MassSpecDataset`` built from ``**kwargs`` with
    ``return_mol_freq=False``. Attribute lookups that fail on this object are
    forwarded to the wrapped dataset. Only rows whose ``fold`` column equals
    ``"test"`` are expanded.

    Args:
        use_formulas: Choose the subformula-based dataset.
        mol_label_transform: Stored as ``self.mol_label_transform``; not
            otherwise used by this class.
        candidates_pth: JSON file mapping target SMILES to candidate SMILES
            lists. For external test sets the keys are spectrum identifiers.
            Candidates containing '.' (multi-fragment SMILES) are dropped.
        fp_size, fp_radius: Morgan generator settings, used when the wrapped
            dataset has ``use_fp`` set. NOTE(review): both default to None,
            so callers presumably must supply them when fingerprints are
            enabled; confirm.
        external_test: Target SMILES unknown. All labels are False and the
            ``smiles`` column is filled from ``identifier``.
        **kwargs: Forwarded to the wrapped dataset's constructor.
    """

    def __init__(self,
                 use_formulas: bool = True,
                 mol_label_transform: MolTransform = MolToInChIKey(),
                 candidates_pth: T.Optional[T.Union[Path, str]] = None,
                 fp_size: int = None,
                 fp_radius: int = None,
                 external_test: bool = False,
                 **kwargs):
        self.external_test = external_test
        dataset_cls = MassSpecDataset_PeakFormulas if use_formulas else JESTR1_MassSpecDataset
        self.instance = dataset_cls(**kwargs, return_mol_freq=False)
        if self.use_fp:
            self.fpgen = AllChem.GetMorganGenerator(radius=fp_radius, fpSize=fp_size)
        self.candidates_pth = candidates_pth
        self.mol_label_transform = mol_label_transform

        with open(self.candidates_pth, "r", encoding="utf-8") as file:
            raw_candidates = json.load(file)
        self.candidates = {
            s: [c for c in cand if '.' not in c] for s, cand in raw_candidates.items()
        }

        self.spec_cand = []  # (spec index, cand_smiles, true_label)

        # Same DataFrame object as the wrapped dataset's metadata: column
        # assignments below modify it in place.
        metadata = self.metadata
        if self.external_test or 'smiles' not in metadata.columns:
            # Target SMILES unknown: key candidate lookups by identifier.
            if not isinstance(metadata.iloc[0]['identifier'], str):
                metadata['smiles'] = metadata['identifier'].apply(str)
            else:
                metadata['smiles'] = metadata['identifier']

        test_rows = metadata[metadata['fold'] == "test"]
        spec_id_to_index = dict(zip(metadata['identifier'], metadata.index))
        for spec_id, s in zip(test_rows['identifier'].tolist(), test_rows['smiles'].tolist()):
            candidates = self.candidates[s]
            if self.external_test:
                labels = [False] * len(candidates)
            else:
                if len(candidates) == 0:
                    print(f"Skipping {spec_id}; empty candidate set")
                    continue
                labels = [c == s for c in candidates]
                if not any(labels):
                    print("Target smiles not in candidate set")
            self.spec_cand.extend(
                (spec_id_to_index[spec_id], cand, label) for cand, label in zip(candidates, labels)
            )

    def __getattr__(self, name):
        # Only called when normal lookup fails. Read `instance` straight from
        # __dict__: if it is not set yet (e.g. while pickle rebuilds the object
        # in a DataLoader worker, or under copy), `self.instance` would re-enter
        # __getattr__ and recurse forever.
        try:
            instance = self.__dict__['instance']
        except KeyError:
            raise AttributeError(name) from None
        return getattr(instance, name)

    def __len__(self):
        return len(self.spec_cand)

    def __getitem__(self, i):
        """Return the wrapped item for the pair's spectrum, extended with
        'cand' (``mol_transform`` of the candidate), 'cand_smiles' and
        'label'. When ``use_fp`` is set, 'fp' is replaced by the candidate's
        Morgan fingerprint."""
        spec_i, cand_smiles, label = self.spec_cand[i]
        item = self.instance.__getitem__(spec_i, transform_mol=False)
        item['cand'] = self.mol_transform(cand_smiles)
        item['cand_smiles'] = cand_smiles
        item['label'] = label
        if self.use_fp:
            item['fp'] = torch.Tensor(self.fpgen.GetFingerprint(Chem.MolFromSmiles(cand_smiles)).ToList())
        return item