import pandas as pd
import json
import typing as T
import numpy as np
import torch
import matchms
import massspecgym.utils as utils
from pathlib import Path
from rdkit import Chem
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import default_collate
from matchms.importing import load_from_mgf
from massspecgym.data.transforms import SpecTransform, MolTransform, MolToInChIKey
class MassSpecDataset(Dataset):
    """
    Dataset containing mass spectra and their corresponding molecular structures. This class is
    responsible for loading the data from disk and applying transformation steps to the spectra and
    molecules.
    """

    def __init__(
        self,
        spec_transform: T.Optional[T.Union[SpecTransform, T.Dict[str, SpecTransform]]] = None,
        mol_transform: T.Optional[T.Union[MolTransform, T.Dict[str, MolTransform]]] = None,
        pth: T.Optional[Path] = None,
        return_mol_freq: bool = True,
        return_identifier: bool = True,
        dtype: T.Type = torch.float32
    ):
        """
        Args:
            spec_transform (Optional[Union[SpecTransform, Dict[str, SpecTransform]]], optional):
                Transformation (or dict of named transformations) applied to each spectrum in
                ``__getitem__``. A dict produces one item key per transformation name.
            mol_transform (Optional[Union[MolTransform, Dict[str, MolTransform]]], optional):
                Transformation (or dict of named transformations) applied to each molecule
                (SMILES string) in ``__getitem__``.
            pth (Optional[Path], optional): Path to the .tsv or .mgf file containing the mass spectra.
                Default is None, in which case the MassSpecGym dataset is downloaded from HuggingFace Hub.
            return_mol_freq (bool, optional): If True, each item contains ``mol_freq``, the number
                of spectra sharing the molecule's InChIKey.
            return_identifier (bool, optional): If True, each item contains the spectrum
                ``identifier`` from the metadata.
            dtype (Type, optional): Torch dtype used when converting numeric item values to tensors.
        """
        self.pth = pth
        self.spec_transform = spec_transform
        self.mol_transform = mol_transform
        self.return_mol_freq = return_mol_freq

        if self.pth is None:
            self.pth = utils.hugging_face_download("MassSpecGym.tsv")

        if isinstance(self.pth, str):
            self.pth = Path(self.pth)

        if self.pth.suffix == ".tsv":
            # The .tsv format stores peaks as comma-separated "mzs" / "intensities" columns;
            # parse them into matchms Spectrum objects and drop the raw string columns.
            self.metadata = pd.read_csv(self.pth, sep="\t")
            self.spectra = self.metadata.apply(
                lambda row: matchms.Spectrum(
                    mz=np.array([float(m) for m in row["mzs"].split(",")]),
                    intensities=np.array(
                        [float(i) for i in row["intensities"].split(",")]
                    ),
                    metadata={"precursor_mz": row["precursor_mz"]},
                ),
                axis=1,
            )
            self.metadata = self.metadata.drop(columns=["mzs", "intensities"])
        elif self.pth.suffix == ".mgf":
            self.spectra = list(load_from_mgf(str(self.pth)))
            self.metadata = pd.DataFrame([s.metadata for s in self.spectra])
        else:
            raise ValueError(f"{self.pth.suffix} file format not supported.")

        if self.return_mol_freq:
            if "inchikey" not in self.metadata.columns:
                # Derive InChIKeys from SMILES when the file does not provide them.
                self.metadata["inchikey"] = self.metadata["smiles"].apply(utils.smiles_to_inchi_key)
            self.metadata["mol_freq"] = self.metadata.groupby("inchikey")["inchikey"].transform("count")

        self.return_identifier = return_identifier
        self.dtype = dtype

    def __len__(self) -> int:
        return len(self.spectra)

    def __getitem__(
        self, i: int, transform_spec: bool = True, transform_mol: bool = True
    ) -> dict:
        """Return the i-th sample as a dict of (possibly transformed) spectrum, molecule, and metadata."""
        spec = self.spectra[i]
        metadata = self.metadata.iloc[i]
        mol = metadata["smiles"]

        # Apply all transformations to the spectrum
        item = {}
        if transform_spec and self.spec_transform:
            if isinstance(self.spec_transform, dict):
                for key, transform in self.spec_transform.items():
                    item[key] = transform(spec) if transform is not None else spec
            else:
                item["spec"] = self.spec_transform(spec)
        else:
            item["spec"] = spec

        # Apply all transformations to the molecule
        if transform_mol and self.mol_transform:
            if isinstance(self.mol_transform, dict):
                for key, transform in self.mol_transform.items():
                    item[key] = transform(mol) if transform is not None else mol
            else:
                item["mol"] = self.mol_transform(mol)
        else:
            item["mol"] = mol

        if self.return_mol_freq:
            item["mol_freq"] = metadata["mol_freq"]

        if self.return_identifier:
            item["identifier"] = metadata["identifier"]

        # TODO: this should be refactored
        # Best-effort tensor conversion: values that torch cannot convert (e.g. matchms
        # Spectrum objects, RDKit molecules) are kept as-is. Catch the specific conversion
        # errors instead of a bare `except:` so real bugs (e.g. KeyboardInterrupt) propagate.
        for k, v in item.items():
            if not isinstance(v, str):
                try:
                    item[k] = torch.as_tensor(v, dtype=self.dtype)
                except (TypeError, ValueError, RuntimeError):
                    continue

        return item

    @staticmethod
    def collate_fn(batch: T.Iterable[dict]) -> dict:
        """
        Custom collate function to handle the outputs of __getitem__.
        """
        return default_collate(batch)
class RetrievalDataset(MassSpecDataset):
    """
    Dataset containing mass spectra and their corresponding molecular structures, with additional
    candidates of molecules for retrieval based on spectral similarity.
    """

    def __init__(
        self,
        mol_label_transform: MolTransform = MolToInChIKey(),
        candidates_pth: T.Optional[T.Union[Path, str]] = None,
        **kwargs,
    ):
        """
        Args:
            mol_label_transform (MolTransform, optional): Transformation used to match the query
                molecule against its candidates when building the pos/neg label mask
                (default: InChIKey equality).
            candidates_pth (Optional[Union[Path, str]], optional): Path to a JSON file mapping
                query SMILES to a list of candidate SMILES. A string that is not an existing
                file is treated as a HuggingFace Hub file name. Default is None, in which case
                the MassSpecGym mass-based candidates are downloaded from HuggingFace Hub.
            **kwargs: Forwarded to ``MassSpecDataset.__init__``.
        """
        super().__init__(**kwargs)
        self.candidates_pth = candidates_pth
        self.mol_label_transform = mol_label_transform

        # Download candidates from HuggingFace Hub if no path to an existing file is passed
        if self.candidates_pth is None:
            self.candidates_pth = utils.hugging_face_download(
                "molecules/MassSpecGym_retrieval_candidates_mass.json"
            )
        elif isinstance(self.candidates_pth, str):
            if Path(self.candidates_pth).is_file():
                self.candidates_pth = Path(self.candidates_pth)
            else:
                self.candidates_pth = utils.hugging_face_download(self.candidates_pth)

        # Read candidates_pth from json to dict: SMILES -> respective candidate SMILES
        with open(self.candidates_pth, "r") as file:
            self.candidates = json.load(file)

    def __getitem__(self, i) -> dict:
        """Return the i-th sample extended with candidate molecules and pos/neg labels.

        Raises:
            ValueError: If the query molecule has no candidate list, or if the query molecule
                does not match any of its candidates under ``mol_label_transform``.
        """
        item = super().__getitem__(i, transform_mol=False)

        # Save the original SMILES representation of the query molecule (for evaluation)
        item["smiles"] = item["mol"]

        # Get candidates
        if item["mol"] not in self.candidates:
            raise ValueError(f'No candidates for the query molecule {item["mol"]}.')
        item["candidates"] = self.candidates[item["mol"]]

        # Save the original SMILES representations of the candidates (for evaluation)
        item["candidates_smiles"] = item["candidates"]

        # Create neg/pos label mask by matching the query molecule with the candidates
        item_label = self.mol_label_transform(item["mol"])
        item["labels"] = [
            self.mol_label_transform(c) == item_label for c in item["candidates"]
        ]
        if not any(item["labels"]):
            raise ValueError(
                f'Query molecule {item["mol"]} not found in the candidates list.'
            )

        # Transform the query and candidate molecules
        item["mol"] = self.mol_transform(item["mol"])
        item["candidates"] = [self.mol_transform(c) for c in item["candidates"]]
        if isinstance(item["mol"], np.ndarray):
            item["mol"] = torch.as_tensor(item["mol"], dtype=self.dtype)

        return item

    @staticmethod
    def collate_fn(batch: T.Iterable[dict]) -> dict:
        """Collate a batch of __getitem__ outputs.

        Candidates, labels, and candidate SMILES are concatenated across samples (each sample
        may have a different number of candidates); ``batch_ptr`` stores the per-sample
        candidate counts so the flat tensors can be split back.
        """
        # Standard collate for everything except candidates and their labels
        collated_batch = {}
        for k in batch[0].keys():
            if k not in ["candidates", "labels", "candidates_smiles"]:
                collated_batch[k] = default_collate([item[k] for item in batch])

        # Collate candidates and labels by concatenating and storing sizes of each list.
        # Flattening comprehensions are used instead of `sum(lists, start=[])`, which is
        # quadratic in the total number of candidates.
        collated_batch["candidates"] = torch.as_tensor(
            np.concatenate([item["candidates"] for item in batch])
        )
        collated_batch["labels"] = torch.as_tensor(
            [label for item in batch for label in item["labels"]]
        )
        collated_batch["batch_ptr"] = torch.as_tensor(
            [len(item["candidates"]) for item in batch]
        )
        collated_batch["candidates_smiles"] = [
            smiles for item in batch for smiles in item["candidates_smiles"]
        ]

        return collated_batch
# TODO: Datasets for unlabeled data.