xp committed on
Commit
6039b52
·
1 Parent(s): 748dc69

init commit

Browse files
Files changed (10) hide show
  1. Dockerfile +2 -2
  2. requirements.txt +9 -3
  3. src/app.py +372 -0
  4. src/data.py +118 -0
  5. src/model.ckpt +3 -0
  6. src/model.py +274 -0
  7. src/streamlit_app.py +0 -40
  8. src/tester.py +30 -0
  9. src/type.py +34 -0
  10. src/utils.py +41 -0
Dockerfile CHANGED
@@ -1,4 +1,4 @@
1
- FROM python:3.9-slim
2
 
3
  WORKDIR /app
4
 
@@ -18,4 +18,4 @@ EXPOSE 8501
18
 
19
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
20
 
21
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
1
+ FROM python:3.12-slim
2
 
3
  WORKDIR /app
4
 
 
18
 
19
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
20
 
21
+ ENTRYPOINT ["streamlit", "run", "src/app.py", "--server.port=8501", "--server.address=0.0.0.0"]
requirements.txt CHANGED
@@ -1,3 +1,9 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
 
 
1
+ matchms==0.27.0
2
+ pandas==2.2.3
3
+ matplotlib==3.7.2
4
+ numba==0.59.1
5
+ numpy==1.26.4
6
+ rdkit==2024.9.6
7
+ seaborn==0.13.2
8
+ streamlit==1.44.1
9
+ torch==2.2.0
src/app.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ from typing import Sequence
4
+
5
+ import torch
6
+ from torch.utils.data import DataLoader
7
+ import numpy as np
8
+ import pandas as pd
9
+ import seaborn as sns
10
+ sns.set_style("whitegrid")
11
+ sns.set_palette("deep")
12
+ import streamlit as st
13
+ import matplotlib.pyplot as plt
14
+ from matplotlib.container import StemContainer
15
+ from matchms import Spectrum
16
+ from rdkit import Chem
17
+ from rdkit.Chem import Draw
18
+
19
+ from type import TokenizerConfig
20
+ from data import Tokenizer, TestDataset
21
+ from model import SiameseModel
22
+ from tester import ModelTester
23
+ from utils import top_k_indices, cosine_similarity, read_raw_spectra
24
+
25
+ torch.classes.__path__ = [os.path.join(torch.__path__[0], torch.classes.__file__)]
26
+
27
+ PAGE_SIZE = 5
28
+ BATCH_SIZE = 64
29
+ LOADER_BATCH_SIZE = 32
30
+ CANDIDATE_PAGE = [2, 5, 10, 20]
31
+ SHOW_PROGRESS_BAR = False
32
+
33
+ device = torch.device("cpu")
34
+ tokenizer_config = TokenizerConfig(
35
+ max_len=100,
36
+ show_progress_bar=SHOW_PROGRESS_BAR
37
+ )
38
+ tokenizer = Tokenizer(100, SHOW_PROGRESS_BAR)
39
+ model = SiameseModel(
40
+ embedding_dim=512,
41
+ n_head=16,
42
+ n_layer=4,
43
+ dim_feedward=512,
44
+ dim_target=512,
45
+ feedward_activation="selu"
46
+ )
47
+ model_state = torch.load("model.ckpt", map_location=device)
48
+ model.load_state_dict(model_state)
49
+ tester = ModelTester(model, device, SHOW_PROGRESS_BAR)
50
+
51
+ def custom_stemcontainer(stem_container: StemContainer):
52
+ stem_container.markerline.set_marker("")
53
+ stem_container.baseline.set_color("none")
54
+ stem_container.baseline.set_alpha(0.5)
55
+
56
+ def draw_mol(smiles: str):
57
+ mol = Chem.MolFromSmiles(smiles)
58
+ image = Draw.MolToImage(mol)
59
+ return image
60
+
61
+ def plot_pair(q: Spectrum, r: Spectrum):
62
+ q_peaks = q.peaks.to_numpy
63
+ r_peaks = r.peaks.to_numpy
64
+ fig, ax = plt.subplots(1, 1, figsize=(5, 2.7), dpi=300)
65
+ ax.text(0.8, 0.8, "query", transform=ax.transAxes)
66
+ ax.text(0.8, 0.2, "reference", transform=ax.transAxes)
67
+ container1 = ax.stem(q_peaks[:, 0], q_peaks[:, 1])
68
+ custom_stemcontainer(container1)
69
+ container2 = ax.stem(r_peaks[:, 0], -r_peaks[:, 1])
70
+ custom_stemcontainer(container2)
71
+ return fig
72
+
73
+ def generate_result():
74
+ ref_smiles = st.session_state.ref_smiles
75
+ match_indices = st.session_state.match_indices
76
+ df = pd.DataFrame(columns=["ID", "Smiles"])
77
+ for i, index in enumerate(match_indices):
78
+ df.loc[len(df)] = [i + 1, ref_smiles[index]]
79
+ st.session_state.result = df.to_csv(index=False).encode("utf8")
80
+
81
+ def get_smiles(spectra: Sequence[Spectrum]):
82
+ smiles_seq = [
83
+ s.get("smiles", "")
84
+ for s in spectra
85
+ ]
86
+ return np.array(smiles_seq)
87
+
88
+ def batch_match(
89
+ progress_bar,
90
+ query_embedding,
91
+ ref_embedding
92
+ ):
93
+ length = len(query_embedding)
94
+ start_seq, end_seq = gen_start_end_seq(length)
95
+ indices = []
96
+
97
+ progress = 0
98
+ for start, end in zip(start_seq, end_seq):
99
+ batch_embedding = query_embedding[start:end]
100
+ cosine_scores = cosine_similarity(batch_embedding, ref_embedding)
101
+ batch_indices = top_k_indices(cosine_scores, 1)
102
+ indices.append(batch_indices)
103
+ if progress + BATCH_SIZE >= length:
104
+ progress = length - 1
105
+ else:
106
+ progress += BATCH_SIZE
107
+ progress_bar.progress((progress + 1) / length)
108
+
109
+ return np.concatenate(indices, axis=0)[:, 0]
110
+
111
+
112
+ def init_session_state():
113
+ if "query_path" not in st.session_state:
114
+ st.session_state.query_path = None
115
+
116
+ if "ref_path" not in st.session_state:
117
+ st.session_state.ref_path = None
118
+
119
+ if "data_len" not in st.session_state:
120
+ st.session_state.data_len = None
121
+
122
+ if "query_embedding" not in st.session_state:
123
+ st.session_state.query_embedding = None
124
+
125
+ if "ref_embedding" not in st.session_state:
126
+ st.session_state.ref_embedding = None
127
+
128
+ if "query_smiles" not in st.session_state:
129
+ st.session_state.query_smiles = None
130
+
131
+ if "ref_smiles" not in st.session_state:
132
+ st.session_state.ref_smiles = None
133
+
134
+ if "query_spectra" not in st.session_state:
135
+ st.session_state.query_spectra = None
136
+
137
+ if "ref_spectra" not in st.session_state:
138
+ st.session_state.ref_spectra = None
139
+
140
+ if "match_indices" not in st.session_state:
141
+ st.session_state.match_indices = None
142
+
143
+ if "current_page" not in st.session_state:
144
+ st.session_state.current_page = None
145
+
146
+ if "last_page" not in st.session_state:
147
+ st.session_state.last_page = None
148
+
149
+ if "page_size" not in st.session_state:
150
+ st.session_state.page_size = PAGE_SIZE
151
+
152
+ def previous_page():
153
+ current_page = st.session_state.current_page
154
+ if current_page != 1:
155
+ st.session_state.current_page -= 1
156
+
157
+ def next_page():
158
+ current_page = st.session_state.current_page
159
+ last_page = st.session_state.last_page
160
+ if current_page != last_page:
161
+ st.session_state.current_page += 1
162
+
163
+ def select_page():
164
+ st.session_state.current_page = int(st.session_state.page_selector)
165
+
166
+ def set_page_size():
167
+ st.session_state.current_page = 1
168
+ page_size = int(st.session_state.page_size_selector)
169
+ st.session_state.page_size = page_size
170
+ cal_page_num(st.session_state.data_len, page_size)
171
+
172
+ def cal_page_num(
173
+ length: int,
174
+ page_size: int
175
+ ):
176
+ page_num, rest = divmod(length, page_size)
177
+ if rest != 0:
178
+ page_num += 1
179
+ st.session_state.last_page = page_num
180
+
181
+ def gen_start_end_seq(
182
+ length: int,
183
+ ):
184
+ start_seq = range(0, length, BATCH_SIZE)
185
+ end_seq = range(BATCH_SIZE, length + BATCH_SIZE, BATCH_SIZE)
186
+ return start_seq, end_seq
187
+
188
+ def embedding(
189
+ progress_bar,
190
+ tester: ModelTester,
191
+ tokenizer: Tokenizer,
192
+ spectra: Sequence[Spectrum],
193
+ ):
194
+ sequences = tokenizer.tokenize_sequence(spectra)
195
+ start_seq, end_seq = gen_start_end_seq(len(spectra))
196
+ progress = 0
197
+ embedding = []
198
+ for start, end in zip(start_seq, end_seq):
199
+ test_dataset = TestDataset(sequences[start:end])
200
+ test_dataloader = DataLoader(
201
+ test_dataset,
202
+ LOADER_BATCH_SIZE,
203
+ False
204
+ )
205
+ step_embedding = tester.test(test_dataloader)
206
+ if progress + BATCH_SIZE >= len(spectra):
207
+ progress = len(spectra) - 1
208
+ else:
209
+ progress += BATCH_SIZE
210
+
211
+ embedding.append(step_embedding)
212
+ progress_bar.progress((progress + 1) / len(spectra))
213
+
214
+ embedding = np.concatenate(embedding, axis=0)
215
+ return embedding
216
+
217
+ def main():
218
+ st.set_page_config(layout="wide")
219
+ st.title("SpecEmbedding")
220
+ tab1, tab2, tab3 = st.tabs(["upload query file", "upload reference/library file", "library match"])
221
+
222
+ with tab1:
223
+ st.header("Upload query spectra file(positive mode)")
224
+ query_file = st.file_uploader(
225
+ "upload the query spectra file",
226
+ type=["msp", "mgf", "mzxml"],
227
+ key="query_file",
228
+ accept_multiple_files=False
229
+ )
230
+ query_embedding_btn = st.button("Embedding", "query_embedding_btn")
231
+ query_status_box = st.empty()
232
+ if query_embedding_btn:
233
+ if query_file is not None:
234
+ with tempfile.NamedTemporaryFile(delete=True, suffix="." + query_file.name.split(".")[-1]) as tmp_file:
235
+ tmp_file.write(query_file.getvalue())
236
+ query_spectra = read_raw_spectra(tmp_file.name)
237
+
238
+ progress_bar = st.progress(0, text="Embedding...")
239
+ st.session_state.data_len = len(query_spectra)
240
+ st.session_state.query_spectra = query_spectra
241
+ st.session_state.query_smiles = get_smiles(query_spectra)
242
+ query_embedding = embedding(
243
+ progress_bar,
244
+ tester,
245
+ tokenizer,
246
+ query_spectra,
247
+ )
248
+ st.session_state.query_embedding = query_embedding
249
+ query_status_box.success("Embedding Success ✅")
250
+ else:
251
+ query_status_box.error("Please upload the spectra file")
252
+
253
+ with tab2:
254
+ st.header("Upload reference/library spectra file(positive mode)")
255
+ ref_file = st.file_uploader(
256
+ "upload the reference/library spectra file",
257
+ type=["msp", "mgf", "mzxml"],
258
+ key="ref_file",
259
+ accept_multiple_files=False
260
+ )
261
+ ref_embedding_btn = st.button("Embedding", "ref_embedding_btn")
262
+ ref_status_box = st.empty()
263
+ if ref_embedding_btn:
264
+ if ref_file is not None:
265
+ progress_bar = st.progress(0, text="Embedding...")
266
+ with tempfile.NamedTemporaryFile(delete=True, suffix="." + ref_file.name.split(".")[-1]) as tmp_file:
267
+ tmp_file.write(ref_file.getvalue())
268
+ ref_spectra = read_raw_spectra(tmp_file.name)
269
+
270
+ st.session_state.ref_spectra = ref_spectra
271
+ st.session_state.ref_smiles = get_smiles(ref_spectra)
272
+ ref_embedding = embedding(
273
+ progress_bar,
274
+ tester,
275
+ tokenizer,
276
+ ref_spectra,
277
+ )
278
+ st.session_state.ref_embedding = ref_embedding
279
+ ref_status_box.success("Embedding Success ✅")
280
+ else:
281
+ ref_status_box.error("Please upload the spectra file")
282
+
283
+ with tab3:
284
+ st.header("Start to match")
285
+ launch_btn = st.button("Launch", key="launch_btn")
286
+ match_status_box = st.empty()
287
+ if launch_btn:
288
+ query_embedding = st.session_state.query_embedding
289
+ ref_embedding = st.session_state.ref_embedding
290
+ if query_embedding is None:
291
+ match_status_box.error("No query embedding")
292
+ elif ref_embedding is None:
293
+ match_status_box.error("No reference embedding")
294
+ else:
295
+ progress_bar = st.progress(0, "Match...")
296
+ match_indices = batch_match(progress_bar, query_embedding, ref_embedding)
297
+ st.session_state.match_indices = match_indices
298
+ st.session_state.current_page = 1
299
+ generate_result()
300
+ cal_page_num(st.session_state.data_len, st.session_state.page_size)
301
+ match_status_box.success("match success")
302
+
303
+ if st.session_state.match_indices is not None:
304
+ st.subheader(f"Match Result")
305
+ current_page = st.session_state.current_page
306
+ last_page = st.session_state.last_page
307
+
308
+ ref_smiles = st.session_state.ref_smiles
309
+ query_spectra = st.session_state.query_spectra
310
+ ref_spectra = st.session_state.ref_spectra
311
+ page_size = st.session_state.page_size
312
+
313
+ indices = st.session_state.match_indices
314
+ start = (current_page - 1) * page_size
315
+ end = start + page_size
316
+
317
+ if current_page == last_page:
318
+ end = indices.shape[0]
319
+
320
+ col1, col2, _ = st.columns([1, 1, 5])
321
+
322
+ col1.selectbox(
323
+ "page size",
324
+ CANDIDATE_PAGE,
325
+ key="page_size_selector",
326
+ disabled=False,
327
+ label_visibility="collapsed",
328
+ index=CANDIDATE_PAGE.index(page_size),
329
+ on_change=set_page_size,
330
+ )
331
+
332
+ col2.download_button(
333
+ label="download result",
334
+ data=st.session_state.result,
335
+ file_name="data.csv",
336
+ mime="text/csv"
337
+ )
338
+
339
+ pre_btn, current, next_btn, page_selector, _ = st.columns([1, 1, 1, 1, 2])
340
+ pre_btn.button("previous page", key="pre_btn", on_click=previous_page)
341
+ current.subheader(f"current page: {current_page}")
342
+ next_btn.button("next page", key="next_btn", on_click=next_page)
343
+ page_selector.selectbox(
344
+ label="target page",
345
+ key="page_selector",
346
+ options=range(1, last_page + 1),
347
+ disabled=False,
348
+ index=current_page - 1,
349
+ label_visibility="collapsed",
350
+ on_change=select_page,
351
+ )
352
+
353
+ col1, col2, col3, col4 = st.columns([1, 4, 6, 4])
354
+ col1.subheader("Index")
355
+ col2.subheader("Smiles")
356
+ col3.subheader("MS/MS Spectra Pair")
357
+ col4.subheader("Molecular Structure")
358
+
359
+ for i in range(start, end):
360
+ query_index = i
361
+ ref_index = indices[i]
362
+ id_label, smiles_label, pair_viewer, mol_viewer = st.columns([2, 4, 6, 4])
363
+ id_label.subheader(i + 1)
364
+ smiles_label.text(ref_smiles[ref_index])
365
+ pair_fig = plot_pair(query_spectra[query_index], ref_spectra[ref_index])
366
+ pair_viewer.pyplot(pair_fig, use_container_width=True)
367
+ mol_image = draw_mol(ref_smiles[ref_index])
368
+ mol_viewer.image(mol_image, use_container_width=True)
369
+
370
+ if __name__ == "__main__":
371
+ init_session_state()
372
+ main()
src/data.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Sequence
2
+ from collections.abc import Sequence
3
+
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from matchms import Spectrum
7
+ from torch.utils.data import Dataset
8
+
9
+ from type import Peak, MetaData, TokenSequence
10
+
11
+ SpecialToken = {
12
+ "PAD": 0,
13
+ }
14
+
15
+ class TestDataset(Dataset):
16
+ def __init__(self, sequences: list[TokenSequence]) -> None:
17
+ super(TestDataset, self).__init__()
18
+ self._sequences = sequences
19
+ self.length = len(sequences)
20
+
21
+ def __len__(self):
22
+ return self.length
23
+
24
+ def __getitem__(self, index: int):
25
+ sequence = self._sequences[index]
26
+ return sequence["mz"], sequence["intensity"], sequence["mask"]
27
+
28
+ class Tokenizer:
29
+ def __init__(self, max_len: int, show_progress_bar: bool = True) -> None:
30
+ """
31
+ Tokenization of mass spectrometry data
32
+
33
+ Parameters:
34
+ ---
35
+ - max_len: Maximum number of peaks to extract
36
+ - show_progress_bar: Whether to display a progress bar
37
+ """
38
+ self.max_len = max_len
39
+ self.show_progress_bar = show_progress_bar
40
+
41
+ def tokenize(self, s: Spectrum):
42
+ """
43
+ Tokenization of mass spectrometry data
44
+ """
45
+ metadata = self.get_metadata(s)
46
+ mz = []
47
+ intensity = []
48
+ for peak in metadata["peaks"]:
49
+ mz.append(peak["mz"])
50
+ intensity.append(peak["intensity"])
51
+
52
+ mz = np.array(mz)
53
+ intensity = np.array(intensity)
54
+ mask = np.zeros((self.max_len, ), dtype=bool)
55
+ if len(mz) < self.max_len:
56
+ mask[len(mz):] = True
57
+ mz = np.pad(
58
+ mz, (0, self.max_len - len(mz)),
59
+ mode='constant', constant_values=SpecialToken["PAD"]
60
+ )
61
+
62
+ intensity = np.pad(
63
+ intensity, (0, self.max_len - len(intensity)),
64
+ mode='constant', constant_values=SpecialToken["PAD"]
65
+ )
66
+
67
+ return TokenSequence(
68
+ mz=np.array(mz, np.float32),
69
+ intensity=np.array(intensity, np.float32),
70
+ mask=mask,
71
+ smiles=metadata["smiles"]
72
+ )
73
+
74
+ def tokenize_sequence(self, spectra: Sequence[Spectrum]):
75
+ sequences: list[TokenSequence] = []
76
+ pbar = spectra
77
+ if self.show_progress_bar:
78
+ pbar = tqdm(spectra, total=len(spectra), desc="tokenization")
79
+ for s in pbar:
80
+ sequences.append(self.tokenize(s))
81
+
82
+ return sequences
83
+
84
+ def get_metadata(self, s: Spectrum):
85
+ """
86
+ get the metadata from spectrum
87
+
88
+ - smiles
89
+ - precursor_mz
90
+ - peaks
91
+ """
92
+ precursor_mz = s.get("precursor_mz")
93
+ smiles = s.get("smiles")
94
+ peaks = np.array(s.peaks.to_numpy, np.float32)
95
+ intensity = peaks[:, 1]
96
+ argmaxsort_index = np.sort(
97
+ np.argsort(intensity)[::-1][:self.max_len - 1]
98
+ )
99
+ peaks = peaks[argmaxsort_index]
100
+ peaks[:, 1] = peaks[:, 1] / max(peaks[:, 1])
101
+ packaged_peaks: list[Peak] = [
102
+ Peak(
103
+ mz=np.array(precursor_mz, np.float32),
104
+ intensity=2
105
+ )
106
+ ]
107
+ for mz, intensity in peaks:
108
+ packaged_peaks.append(
109
+ Peak(
110
+ mz=mz,
111
+ intensity=intensity
112
+ )
113
+ )
114
+ metadata = MetaData(
115
+ smiles=smiles,
116
+ peaks=packaged_peaks
117
+ )
118
+ return metadata
src/model.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ca0aa002a0d061a95410f7a4055e82c7fcb428d0ba04b5714ac3a4e7f0f5cca
3
+ size 31572706
src/model.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Literal, Union, Iterable, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.nn import TransformerEncoder, TransformerEncoderLayer
7
+
8
+ LAMBDA_MIN = math.pow(10, -3.0)
9
+ LAMBDA_MAX = math.pow(10, 3.0)
10
+
11
+ class MultiFeedForwardModule(nn.Module):
12
+ def __init__(
13
+ self,
14
+ input_size: int,
15
+ hidden_size: Union[int, Iterable[int]],
16
+ output_size: int,
17
+ *,
18
+ activation: Literal['relu', 'selu', 'gelu'] = 'relu',
19
+ dropout: float = 0.1,
20
+ dropout_last_layer: bool = True
21
+ ):
22
+ super(MultiFeedForwardModule, self).__init__()
23
+ if activation == 'relu':
24
+ self._activation = nn.ReLU()
25
+ elif activation == 'selu':
26
+ self._activation = nn.SELU()
27
+ elif activation == 'gelu':
28
+ self._activation = nn.GELU()
29
+ else:
30
+ raise ValueError('activation must be relu or selu')
31
+
32
+ if not hasattr(hidden_size, '__iter__'):
33
+ if hidden_size is None:
34
+ hidden_size = [output_size]
35
+ else:
36
+ hidden_size = [hidden_size]
37
+
38
+ self._layers = []
39
+ layer_dims = [input_size] + hidden_size + [output_size]
40
+
41
+ for i in range(1, len(layer_dims) - 1):
42
+ self._layers.append(nn.Linear(layer_dims[i - 1], layer_dims[i]))
43
+ self._layers.append(self._activation)
44
+ self._layers.append(nn.Dropout(dropout))
45
+
46
+ self._layers.append(nn.Linear(layer_dims[-2], layer_dims[-1]))
47
+
48
+ if dropout_last_layer:
49
+ self._layers.append(nn.Dropout(dropout))
50
+ self._layers = nn.Sequential(*self._layers)
51
+
52
+ def forward(self, x):
53
+ return self._layers(x)
54
+
55
+
56
+ class SinusodialMz(nn.Module):
57
+ def __init__(self, embedding_dim: int, *, lambda_params: Tuple[float, float] = (LAMBDA_MIN, LAMBDA_MAX)) -> None:
58
+ super(SinusodialMz, self).__init__()
59
+ self.lambda_min, self.lambda_max = lambda_params
60
+ self.lambda_div_value = self.lambda_max / self.lambda_min
61
+ self.x = torch.arange(0, embedding_dim, 2)
62
+ self.x = (
63
+ 2 * math.pi *
64
+ (
65
+ self.lambda_min *
66
+ self.lambda_div_value ** (self.x / (embedding_dim - 2))
67
+ ) ** -1
68
+ )
69
+
70
+ def forward(self, mz: torch.Tensor):
71
+ self.x = self.x.to(mz.device)
72
+ x = torch.einsum('bl,d->bld', mz, self.x)
73
+ sin_embedding = torch.sin(x)
74
+ cos_embedding = torch.cos(x)
75
+ b, l, d = sin_embedding.shape
76
+ x = torch.zeros(b, l, 2 * d, dtype=mz.dtype, device=mz.device)
77
+ x[:, :, ::2] = sin_embedding
78
+ x[:, :, 1::2] = cos_embedding
79
+ return x
80
+
81
+
82
+ class SinusodialMzEmbedding(nn.Module):
83
+ def __init__(
84
+ self,
85
+ embedding_dim: int,
86
+ *,
87
+ lambda_params: Tuple[float, float] = (LAMBDA_MIN, LAMBDA_MAX),
88
+ feedward_activation: Literal['relu', 'selu', 'gelu'] = 'relu',
89
+ dropout: float = 0.1,
90
+ dropout_last_layer: bool = True
91
+ ):
92
+ super(SinusodialMzEmbedding, self).__init__()
93
+ if embedding_dim % 2 != 0:
94
+ raise ValueError('embedding_dim must be even')
95
+ self.embedding = SinusodialMz(
96
+ embedding_dim, lambda_params=lambda_params)
97
+ self.feedward_layers = MultiFeedForwardModule(
98
+ embedding_dim, embedding_dim, embedding_dim,
99
+ activation=feedward_activation, dropout=dropout, dropout_last_layer=dropout_last_layer
100
+ )
101
+
102
+ def forward(self, mz: torch.Tensor):
103
+ x = self.embedding(mz)
104
+ x = self.feedward_layers(x)
105
+ return x
106
+
107
+
108
+ class PeaksEmbedding(nn.Module):
109
+ def __init__(
110
+ self,
111
+ embedding_dim: int,
112
+ *,
113
+ lambda_params: Tuple[float, float] = (LAMBDA_MIN, LAMBDA_MAX),
114
+ feedward_activation: Literal['relu', 'selu', 'gelu'] = 'relu',
115
+ dropout: float = 0.1,
116
+ dropout_last_layer: bool = False
117
+ ) -> None:
118
+ super(PeaksEmbedding, self).__init__()
119
+ self.mz_embedding = SinusodialMzEmbedding(
120
+ embedding_dim,
121
+ lambda_params=lambda_params,
122
+ feedward_activation=feedward_activation,
123
+ dropout=dropout,
124
+ dropout_last_layer=dropout_last_layer
125
+ )
126
+ self.intensity_embedding = MultiFeedForwardModule(
127
+ embedding_dim + 1, embedding_dim, embedding_dim,
128
+ activation=feedward_activation,
129
+ dropout=dropout,
130
+ dropout_last_layer=dropout_last_layer
131
+ )
132
+
133
+ def forward(self, mz: torch.Tensor, intensity: torch.Tensor):
134
+ mz_tensor = self.mz_embedding(mz)
135
+ intensity_tensor = torch.unsqueeze(intensity, dim=-1)
136
+ x = self.intensity_embedding(
137
+ torch.cat([mz_tensor, intensity_tensor], dim=-1))
138
+ return x
139
+
140
+
141
+ class SiameseModel(nn.Module):
142
+ def __init__(
143
+ self,
144
+ embedding_dim: int,
145
+ n_head: int,
146
+ n_layer: int,
147
+ dim_feedward: int,
148
+ dim_target: int,
149
+ *,
150
+ lambda_params: Tuple[float, float] = (LAMBDA_MIN, LAMBDA_MAX),
151
+ feedward_activation: Literal['relu', 'selu', 'gelu'] = 'relu',
152
+ dropout: float = 0.1,
153
+ dropout_last_layer: bool = False,
154
+ norm_first: bool = True
155
+ ) -> None:
156
+ super(SiameseModel, self).__init__()
157
+ if embedding_dim % n_head != 0:
158
+ raise ValueError('embedding must be divisible by n_head')
159
+
160
+ self.embedding = PeaksEmbedding(
161
+ embedding_dim,
162
+ lambda_params=lambda_params,
163
+ feedward_activation=feedward_activation,
164
+ dropout=dropout,
165
+ dropout_last_layer=dropout_last_layer
166
+ )
167
+
168
+ if feedward_activation == 'selu':
169
+ # transformer encoder activation
170
+ # only gelu or relu
171
+ self.activation = 'gelu'
172
+ else:
173
+ self.activation = feedward_activation
174
+
175
+ if feedward_activation == 'relu':
176
+ self._activation = nn.ReLU()
177
+ elif feedward_activation == 'selu':
178
+ self._activation = nn.SELU()
179
+ elif feedward_activation == 'gelu':
180
+ self._activation = nn.GELU()
181
+ else:
182
+ raise ValueError('activation must be relu or selu or gelu')
183
+
184
+ encoder_layer = TransformerEncoderLayer(
185
+ embedding_dim,
186
+ n_head,
187
+ dim_feedforward=dim_feedward,
188
+ dropout=dropout,
189
+ activation=self.activation,
190
+ batch_first=True,
191
+ norm_first=norm_first
192
+ )
193
+ self._encoder = TransformerEncoder(
194
+ encoder_layer,
195
+ n_layer,
196
+ enable_nested_tensor=False
197
+ )
198
+
199
+ self._decoder = MultiFeedForwardModule(
200
+ embedding_dim,
201
+ dim_feedward,
202
+ dim_target,
203
+ activation=feedward_activation,
204
+ dropout=dropout,
205
+ dropout_last_layer=dropout_last_layer
206
+ )
207
+
208
+ def forward(self, mz: torch.Tensor, intensity: torch.Tensor, mask: torch.Tensor):
209
+ x = self.embedding(mz, intensity)
210
+ x = self._encoder(x, src_key_padding_mask=mask)
211
+ # mean pooling or cls position vector
212
+ x = torch.mean(x, dim=1)
213
+ x = self._activation(self._decoder(x))
214
+ return x
215
+
216
+
217
+ # class MambaSiameseModel(nn.Module):
218
+ # def __init__(
219
+ # self,
220
+ # embedding_dim: int,
221
+ # n_layer: int,
222
+ # dim_feedward: int,
223
+ # dim_target: int,
224
+ # *,
225
+ # lambda_params: Tuple[float, float] = (LAMBDA_MIN, LAMBDA_MAX),
226
+ # feedward_activation: Literal['relu', 'selu', 'gelu'] = 'relu',
227
+ # dropout: float = 0.1,
228
+ # dropout_last_layer: bool = False,
229
+ # ):
230
+ # super(MambaSiameseModel, self).__init__()
231
+
232
+ # self.embedding = PeaksEmbedding(
233
+ # embedding_dim,
234
+ # lambda_params=lambda_params,
235
+ # feedward_activation=feedward_activation,
236
+ # dropout=dropout,
237
+ # dropout_last_layer=dropout_last_layer
238
+ # )
239
+
240
+ # if feedward_activation == 'relu':
241
+ # self._activation = nn.ReLU()
242
+ # elif feedward_activation == 'selu':
243
+ # self._activation = nn.SELU()
244
+ # elif feedward_activation == 'gelu':
245
+ # self._activation = nn.GELU()
246
+ # else:
247
+ # raise ValueError('activation must be relu or selu or gelu')
248
+
249
+ # self._encoder = nn.Sequential(*[
250
+ # Mamba2(
251
+ # d_model=embedding_dim,
252
+ # d_state=64,
253
+ # d_conv=4,
254
+ # expand=2
255
+ # )
256
+ # for _ in range(n_layer)
257
+ # ])
258
+
259
+ # self._decoder = MultiFeedForwardModule(
260
+ # embedding_dim,
261
+ # dim_feedward,
262
+ # dim_target,
263
+ # activation=feedward_activation,
264
+ # dropout=dropout,
265
+ # dropout_last_layer=dropout_last_layer
266
+ # )
267
+
268
+ # def forward(self, mz: torch.Tensor, intensity: torch.Tensor, mask: torch.Tensor):
269
+ # x = self.embedding(mz, intensity)
270
+ # x = self._encoder(x)
271
+ # # mean pooling or cls position vector
272
+ # x = torch.mean(x, dim=1)
273
+ # x = self._activation(self._decoder(x))
274
+ # return x
src/streamlit_app.py DELETED
@@ -1,40 +0,0 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
- import streamlit as st
5
-
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/tester.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import Module
3
+ from torch.utils.data import DataLoader
4
+ from tqdm import tqdm
5
+ import numpy as np
6
+
7
+ class ModelTester:
8
+ def __init__(
9
+ self,
10
+ model: Module,
11
+ device: torch.device,
12
+ show_prgress_bar: bool = True
13
+ ) -> None:
14
+ self.model = model
15
+ self.device = device
16
+ self.show_prgress_bar = show_prgress_bar
17
+
18
+ def test(self, dataloader: DataLoader):
19
+ self.model.eval()
20
+ result = []
21
+ with torch.no_grad():
22
+ pbar = dataloader
23
+ if self.show_prgress_bar:
24
+ pbar = tqdm(dataloader, total=len(
25
+ dataloader), desc="embedding")
26
+ for x in pbar:
27
+ x = [d.to(self.device) for d in x]
28
+ pred: torch.Tensor = self.model(*x)
29
+ result.append(pred.cpu().numpy())
30
+ return np.concatenate(result, axis=0)
src/type.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TypedDict, Sequence, Callable, Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch import device
6
+ import numpy as np
7
+ import numpy.typing as npt
8
+
9
+ BatchType = Sequence[torch.Tensor]
10
+ StepTrain = Callable[[nn.Module, nn.Module, device,
11
+ BatchType, Optional[Callable[..., int]]], Sequence[torch.Tensor]]
12
+ StepVal = Callable[[nn.Module, nn.Module, device,
13
+ BatchType, Optional[Callable[..., int]]], Sequence[torch.Tensor]]
14
+
15
+ class Peak(TypedDict):
16
+ mz: str
17
+ intensity: npt.NDArray
18
+
19
+
20
+ class MetaData(TypedDict):
21
+ peaks: Sequence[Peak]
22
+ smiles: str
23
+
24
+
25
+ class TokenSequence(TypedDict):
26
+ mz: npt.NDArray[np.int32]
27
+ intensity: npt.NDArray[np.float32]
28
+ mask: npt.NDArray[np.bool_]
29
+ smiles: str
30
+
31
+
32
+ class TokenizerConfig(TypedDict):
33
+ max_len: int
34
+ show_progress_bar: bool
src/utils.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import numpy as np
4
+ import numpy.typing as npt
5
+ from numba import prange, njit
6
+ from matchms.importing import load_from_mgf, load_from_msp, load_from_mzxml
7
+ from matchms.filtering import default_filters, normalize_intensities
8
+
9
+ def read_raw_spectra(path: str):
10
+ suffix = Path(path).suffix
11
+ if suffix == ".mgf":
12
+ spectra = list(load_from_mgf(path))
13
+ elif suffix == ".msp":
14
+ spectra = list(load_from_msp(path))
15
+ elif suffix == ".mzxml":
16
+ spectra = list(load_from_mzxml(path))
17
+ else:
18
+ raise ValueError(f"Not support the {suffix} format")
19
+
20
+ spectra = [default_filters(s) for s in spectra]
21
+ spectra = [normalize_intensities(s) for s in spectra]
22
+ return spectra
23
+
24
+ @njit
25
+ def cosine_similarity(A: npt.NDArray, B: npt.NDArray):
26
+ norm_A = np.sqrt(np.sum(A ** 2, axis=1)) + 1e-8
27
+ norm_B = np.sqrt(np.sum(B ** 2, axis=1)) + 1e-8
28
+ normalize_A = A / norm_A[:, np.newaxis]
29
+ normalize_B = B / norm_B[:, np.newaxis]
30
+ scores = np.dot(normalize_A, normalize_B.T)
31
+ return scores
32
+
33
+ @njit(parallel=True)
34
+ def top_k_indices(score, top_k):
35
+ rows, cols = score.shape
36
+ indices = np.empty((rows, top_k), dtype=np.int64)
37
+ for i in prange(rows):
38
+ row = score[i]
39
+ sorted_idx = np.argsort(row)[::-1]
40
+ indices[i] = sorted_idx[:top_k]
41
+ return indices