File size: 3,490 Bytes
6039b52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from typing import Sequence
from collections.abc import Sequence

import numpy as np
from tqdm import tqdm
from matchms import Spectrum
from torch.utils.data import Dataset

from type import Peak, MetaData, TokenSequence

# Reserved token values; "PAD" is the fill value written into the unused
# tail of padded m/z and intensity arrays.
SpecialToken: dict[str, int] = {"PAD": 0}

class TestDataset(Dataset):
    """
        Map-style dataset over already tokenized sequences

        Each item is the (mz, intensity, mask) triple read from one
        sequence; any other keys of the sequence are not returned.
    """

    def __init__(self, sequences: list[TokenSequence]) -> None:
        """
            Parameters:
            ---
            -   sequences: Token sequences to serve, indexed by position
        """
        super().__init__()
        self._sequences = sequences
        # The size is captured once here; __len__ reports this stored value.
        self.length = len(self._sequences)

    def __len__(self):
        """
            Number of sequences recorded at construction time
        """
        return self.length

    def __getitem__(self, index: int):
        """
            Return the (mz, intensity, mask) entries of the sequence at `index`
        """
        item = self._sequences[index]
        return tuple(item[key] for key in ("mz", "intensity", "mask"))

class Tokenizer:
    """
        Convert spectra into fixed-length, padded token sequences

        Every sequence starts with a precursor token (the precursor m/z
        paired with the marker intensity 2) followed by at most
        `max_len - 1` of the most intense peaks, kept in their original
        row order and scaled so that the strongest kept peak has intensity 1.
    """

    def __init__(self, max_len: int, show_progress_bar: bool = True) -> None:
        """
            Tokenization of mass spectrometry data

            Parameters:
            ---
            -   max_len: Length of every token sequence. One slot is used by
                the precursor token, so at most `max_len - 1` peaks are kept.
            -   show_progress_bar: Whether to display a progress bar

            Raises:
            ---
            -   ValueError: if `max_len` is smaller than 1
        """
        # With max_len < 1 the slice `[:max_len - 1]` would turn negative and
        # the number of tokens would no longer fit the mask length.
        if max_len < 1:
            raise ValueError(f"max_len must be at least 1, got {max_len}")
        self.max_len = max_len
        self.show_progress_bar = show_progress_bar

    def tokenize(self, s: Spectrum) -> TokenSequence:
        """
            Tokenization of a single spectrum

            Parameters:
            ---
            -   s: Spectrum to tokenize

            Returns:
            ---
            A TokenSequence with
            -   mz: float32 array of length `max_len`
            -   intensity: float32 array of length `max_len`
            -   mask: bool array of length `max_len`, True on padded positions
            -   smiles: the value returned by `s.get("smiles")`
        """
        metadata = self.get_metadata(s)
        tokens = metadata["peaks"]
        mz = np.array([token["mz"] for token in tokens], np.float32)
        intensity = np.array(
            [token["intensity"] for token in tokens], np.float32
        )

        used = len(mz)
        mask = np.zeros((self.max_len, ), dtype=bool)
        mask[used:] = True

        # get_metadata never returns more than max_len tokens, so the pad
        # width is never negative.
        pad_width = self.max_len - used
        if pad_width > 0:
            mz = np.pad(
                mz, (0, pad_width),
                mode='constant', constant_values=SpecialToken["PAD"]
            )
            intensity = np.pad(
                intensity, (0, pad_width),
                mode='constant', constant_values=SpecialToken["PAD"]
            )

        return TokenSequence(
            mz=mz,
            intensity=intensity,
            mask=mask,
            smiles=metadata["smiles"]
        )

    def tokenize_sequence(self, spectra: Sequence[Spectrum]) -> list[TokenSequence]:
        """
            Tokenization of several spectra

            Parameters:
            ---
            -   spectra: Spectra to tokenize; `len()` is taken when the
                progress bar is enabled

            Returns:
            ---
            One TokenSequence per spectrum, in input order
        """
        iterator = spectra
        if self.show_progress_bar:
            iterator = tqdm(spectra, total=len(spectra), desc="tokenization")
        return [self.tokenize(s) for s in iterator]

    def get_metadata(self, s: Spectrum) -> MetaData:
        """
            get the metadata from spectrum

            -   smiles
            -   peaks: the precursor token followed by the kept peaks

            The precursor m/z is read from the spectrum and stored as the
            first peak with intensity 2. A spectrum without usable peaks
            yields only the precursor token.
        """
        precursor_mz = s.get("precursor_mz")
        smiles = s.get("smiles")
        # reshape keeps the (n, 2) layout even when the spectrum has no peaks
        peaks = np.array(s.peaks.to_numpy, np.float32).reshape(-1, 2)

        # Pick the `max_len - 1` most intense peaks, then sort the chosen
        # indices so the kept peaks stay in their original row order.
        strongest_first = np.argsort(peaks[:, 1])[::-1]
        keep = np.sort(strongest_first[:self.max_len - 1])
        peaks = peaks[keep]

        # Normalise to the strongest kept peak. Skipped when nothing was kept
        # (empty spectrum or max_len == 1) or when every intensity is zero,
        # which would otherwise raise or produce NaN values.
        if peaks.shape[0] > 0:
            top = peaks[:, 1].max()
            if top > 0:
                peaks[:, 1] = peaks[:, 1] / top

        # NOTE(review): a spectrum without "precursor_mz" appears to end up
        # with a NaN precursor token here rather than an error - confirm that
        # this is acceptable for callers.
        packaged_peaks: list[Peak] = [
            Peak(
                mz=np.array(precursor_mz, np.float32),
                # 2 lies above the normalised peak range, marking the precursor
                intensity=2
            )
        ]
        packaged_peaks.extend(
            Peak(mz=mz, intensity=intensity) for mz, intensity in peaks
        )
        return MetaData(
            smiles=smiles,
            peaks=packaged_peaks
        )