# SpecEmbedding/src/data.py
# (scraped repo metadata: author "xp", commit "init commit", hash 6039b52)
from typing import Sequence
from collections.abc import Sequence
import numpy as np
from tqdm import tqdm
from matchms import Spectrum
from torch.utils.data import Dataset
from type import Peak, MetaData, TokenSequence
# Special token values used when padding token sequences to a fixed length.
# PAD (0) fills both the m/z and the intensity channel of padded positions;
# the boolean mask in TokenSequence marks which positions are padding.
SpecialToken = {
    "PAD": 0,
}
class TestDataset(Dataset):
    """Dataset wrapping pre-tokenized spectra for inference/evaluation.

    Each item is the ``(mz, intensity, mask)`` triple of one tokenized
    spectrum; the SMILES field of the sequence is intentionally not
    returned here.
    """

    def __init__(self, sequences: "list[TokenSequence]") -> None:
        """
        Parameters:
        ---
        - sequences: tokenized spectra as produced by ``Tokenizer``
        """
        # Modern zero-argument super(); the annotation is quoted so it is
        # evaluated lazily rather than at class-definition time.
        super().__init__()
        self._sequences = sequences
        # Cache the length once; the sequence list is not mutated afterwards.
        self.length = len(sequences)

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, index: int):
        """Return the (mz, intensity, mask) arrays of item *index*."""
        sequence = self._sequences[index]
        return sequence["mz"], sequence["intensity"], sequence["mask"]
class Tokenizer:
    """Convert matchms spectra into fixed-length, padded token sequences.

    Each spectrum is reduced to its ``max_len - 1`` most intense peaks
    (one slot is reserved for a precursor pseudo-peak) and padded with
    ``SpecialToken["PAD"]`` up to ``max_len``.
    """

    def __init__(self, max_len: int, show_progress_bar: bool = True) -> None:
        """
        Tokenization of mass spectrometry data

        Parameters:
        ---
        - max_len: Maximum number of peaks to extract (precursor pseudo-peak included)
        - show_progress_bar: Whether to display a progress bar
        """
        self.max_len = max_len
        self.show_progress_bar = show_progress_bar

    def tokenize(self, s: Spectrum) -> "TokenSequence":
        """Tokenize one spectrum into padded float32 m/z and intensity arrays.

        Returns a ``TokenSequence`` whose ``mz`` and ``intensity`` have
        length ``max_len``, whose boolean ``mask`` is True on padding
        positions, plus the spectrum's SMILES string.
        """
        metadata = self.get_metadata(s)
        peaks = metadata["peaks"]
        # Build float32 arrays directly instead of float64 lists that were
        # previously re-converted at the end.
        mz = np.array([p["mz"] for p in peaks], np.float32)
        intensity = np.array([p["intensity"] for p in peaks], np.float32)
        n_real = len(mz)
        # mask is True where the position is padding, so downstream
        # attention can ignore those slots.
        mask = np.zeros((self.max_len, ), dtype=bool)
        if n_real < self.max_len:
            mask[n_real:] = True
            pad_width = (0, self.max_len - n_real)
            mz = np.pad(
                mz, pad_width,
                mode='constant', constant_values=SpecialToken["PAD"]
            )
            intensity = np.pad(
                intensity, pad_width,
                mode='constant', constant_values=SpecialToken["PAD"]
            )
        return TokenSequence(
            mz=mz,
            intensity=intensity,
            mask=mask,
            smiles=metadata["smiles"]
        )

    def tokenize_sequence(self, spectra: Sequence[Spectrum]) -> "list[TokenSequence]":
        """Tokenize every spectrum in *spectra*, optionally with a tqdm bar."""
        iterable = spectra
        if self.show_progress_bar:
            iterable = tqdm(spectra, total=len(spectra), desc="tokenization")
        return [self.tokenize(s) for s in iterable]

    def get_metadata(self, s: Spectrum):
        """
        get the metadata from spectrum
        - smiles
        - precursor_mz
        - peaks (precursor pseudo-peak first, then top peaks in m/z order)
        """
        precursor_mz = s.get("precursor_mz")
        smiles = s.get("smiles")
        # matchms exposes the (n_peaks, 2) [mz, intensity] array via the
        # ``to_numpy`` property (not a method call).
        peaks = np.array(s.peaks.to_numpy, np.float32)
        # Keep the max_len - 1 most intense peaks (one slot is reserved for
        # the precursor), then restore ascending m/z order by re-sorting
        # the selected indices.
        top_index = np.sort(
            np.argsort(peaks[:, 1])[::-1][:self.max_len - 1]
        )
        peaks = peaks[top_index]
        if len(peaks):
            # Normalize intensities to (0, 1]; ndarray.max() replaces the
            # builtin max(), and empty spectra no longer raise ValueError.
            # NOTE(review): an all-zero intensity spectrum would still
            # divide by zero here — confirm upstream filtering.
            peaks[:, 1] = peaks[:, 1] / peaks[:, 1].max()
        # The precursor is injected as a pseudo-peak with sentinel
        # intensity 2, which no normalized real peak can reach.
        packaged_peaks: "list[Peak]" = [
            Peak(
                mz=np.array(precursor_mz, np.float32),
                intensity=2
            )
        ]
        for peak_mz, peak_intensity in peaks:
            packaged_peaks.append(
                Peak(
                    mz=peak_mz,
                    intensity=peak_intensity
                )
            )
        return MetaData(
            smiles=smiles,
            peaks=packaged_peaks
        )