from collections.abc import Sequence

import numpy as np
from tqdm import tqdm
from matchms import Spectrum
from torch.utils.data import Dataset

from type import Peak, MetaData, TokenSequence
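
# The local `type` module is not shown here; it is assumed to provide dict-like
# structures roughly equivalent to the following TypedDict sketch, inferred
# from how the objects are indexed below:
#
#   class Peak(TypedDict):
#       mz: np.ndarray | float        # m/z value
#       intensity: float              # relative intensity
#
#   class MetaData(TypedDict):
#       smiles: str
#       peaks: list[Peak]
#
#   class TokenSequence(TypedDict):
#       mz: np.ndarray                # (max_len,) float32
#       intensity: np.ndarray         # (max_len,) float32
#       mask: np.ndarray              # (max_len,) bool, True on padded positions
#       smiles: str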

# Padding value used to fill mz/intensity arrays up to max_len.
SpecialToken = {
    "PAD": 0,
}


class TestDataset(Dataset):
    """Dataset wrapper that yields one (mz, intensity, mask) tuple per tokenized spectrum."""

    def __init__(self, sequences: list[TokenSequence]) -> None:
        super().__init__()
        self._sequences = sequences
        self.length = len(sequences)

    def __len__(self):
        return self.length

    def __getitem__(self, index: int):
        sequence = self._sequences[index]
        return sequence["mz"], sequence["intensity"], sequence["mask"]


class Tokenizer:
    def __init__(self, max_len: int, show_progress_bar: bool = True) -> None:
        """
        Tokenizer for mass spectrometry data.

        Parameters
        ---
        - max_len: maximum number of tokens per spectrum (the precursor plus up to
          max_len - 1 fragment peaks); shorter spectra are padded to this length
        - show_progress_bar: whether to display a progress bar during tokenization
        """
        self.max_len = max_len
        self.show_progress_bar = show_progress_bar

    def tokenize(self, s: Spectrum):
        """
        Tokenize a single spectrum into fixed-length mz/intensity arrays
        plus a boolean padding mask.
        """
        metadata = self.get_metadata(s)
        mz = []
        intensity = []
        for peak in metadata["peaks"]:
            mz.append(peak["mz"])
            intensity.append(peak["intensity"])
        mz = np.array(mz)
        intensity = np.array(intensity)
        # True marks padded positions so downstream models can ignore them.
        mask = np.zeros((self.max_len,), dtype=bool)
        if len(mz) < self.max_len:
            mask[len(mz):] = True
            mz = np.pad(
                mz, (0, self.max_len - len(mz)),
                mode='constant', constant_values=SpecialToken["PAD"]
            )
            intensity = np.pad(
                intensity, (0, self.max_len - len(intensity)),
                mode='constant', constant_values=SpecialToken["PAD"]
            )
        return TokenSequence(
            mz=np.array(mz, np.float32),
            intensity=np.array(intensity, np.float32),
            mask=mask,
            smiles=metadata["smiles"]
        )

    def tokenize_sequence(self, spectra: Sequence[Spectrum]):
        sequences: list[TokenSequence] = []
        pbar = spectra
        if self.show_progress_bar:
            pbar = tqdm(spectra, total=len(spectra), desc="tokenization")
        for s in pbar:
            sequences.append(self.tokenize(s))
        return sequences

    def get_metadata(self, s: Spectrum):
        """
        Extract the metadata needed for tokenization from a spectrum:
        - smiles
        - peaks: the precursor (from precursor_mz) followed by the top
          max_len - 1 fragment peaks, intensities normalised to the base peak
        """
        precursor_mz = s.get("precursor_mz")
        smiles = s.get("smiles")
        peaks = np.array(s.peaks.to_numpy, np.float32)
        intensity = peaks[:, 1]
        # Keep the max_len - 1 most intense peaks, then restore ascending m/z order.
        argmaxsort_index = np.sort(
            np.argsort(intensity)[::-1][:self.max_len - 1]
        )
        peaks = peaks[argmaxsort_index]
        # Normalise intensities to the most intense retained peak.
        peaks[:, 1] = peaks[:, 1] / peaks[:, 1].max()
        # The precursor is prepended as a pseudo-peak with intensity 2,
        # above the normalised range of the fragment peaks.
        packaged_peaks: list[Peak] = [
            Peak(
                mz=np.array(precursor_mz, np.float32),
                intensity=2
            )
        ]
        for mz, peak_intensity in peaks:
            packaged_peaks.append(
                Peak(
                    mz=mz,
                    intensity=peak_intensity
                )
            )
        metadata = MetaData(
            smiles=smiles,
            peaks=packaged_peaks
        )
        return metadata
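

# Hedged usage sketch (not part of the original module): assumes the `type`
# module's Peak/MetaData/TokenSequence behave like the TypedDicts sketched
# above, and uses a tiny hand-built matchms Spectrum purely for illustration.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    spectrum = Spectrum(
        mz=np.array([100.0, 150.0, 200.0], dtype=float),
        intensities=np.array([0.1, 0.5, 1.0], dtype=float),
        metadata={"precursor_mz": 222.1, "smiles": "CCO"},
    )
    tokenizer = Tokenizer(max_len=8, show_progress_bar=False)
    sequences = tokenizer.tokenize_sequence([spectrum])
    dataset = TestDataset(sequences)
    loader = DataLoader(dataset, batch_size=1)
    for mz, intensity, mask in loader:
        # Each tensor has shape (batch, max_len); mask is True on padded slots.
        print(mz.shape, intensity.shape, mask.shape)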