Spaces:
Runtime error
Runtime error
xp
committed on
Commit
·
6039b52
1
Parent(s):
748dc69
init commit
Browse files- Dockerfile +2 -2
- requirements.txt +9 -3
- src/app.py +372 -0
- src/data.py +118 -0
- src/model.ckpt +3 -0
- src/model.py +274 -0
- src/streamlit_app.py +0 -40
- src/tester.py +30 -0
- src/type.py +34 -0
- src/utils.py +41 -0
Dockerfile
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
FROM python:3.
|
| 2 |
|
| 3 |
WORKDIR /app
|
| 4 |
|
|
@@ -18,4 +18,4 @@ EXPOSE 8501
|
|
| 18 |
|
| 19 |
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 20 |
|
| 21 |
-
ENTRYPOINT ["streamlit", "run", "src/
|
|
|
|
| 1 |
+
FROM python:3.12-slim
|
| 2 |
|
| 3 |
WORKDIR /app
|
| 4 |
|
|
|
|
| 18 |
|
| 19 |
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 20 |
|
| 21 |
+
ENTRYPOINT ["streamlit", "run", "src/app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
requirements.txt
CHANGED
|
@@ -1,3 +1,9 @@
|
|
| 1 |
-
|
| 2 |
-
pandas
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
matchms==0.27.0
|
| 2 |
+
pandas==2.2.3
|
| 3 |
+
matplotlib==3.7.2
|
| 4 |
+
numba==0.59.1
|
| 5 |
+
numpy==1.26.4
|
| 6 |
+
rdkit==2024.9.6
|
| 7 |
+
seaborn==0.13.2
|
| 8 |
+
streamlit==1.44.1
|
| 9 |
+
torch==2.2.0
|
src/app.py
ADDED
|
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import tempfile
|
| 3 |
+
from typing import Sequence
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import seaborn as sns
|
| 10 |
+
sns.set_style("whitegrid")
|
| 11 |
+
sns.set_palette("deep")
|
| 12 |
+
import streamlit as st
|
| 13 |
+
import matplotlib.pyplot as plt
|
| 14 |
+
from matplotlib.container import StemContainer
|
| 15 |
+
from matchms import Spectrum
|
| 16 |
+
from rdkit import Chem
|
| 17 |
+
from rdkit.Chem import Draw
|
| 18 |
+
|
| 19 |
+
from type import TokenizerConfig
|
| 20 |
+
from data import Tokenizer, TestDataset
|
| 21 |
+
from model import SiameseModel
|
| 22 |
+
from tester import ModelTester
|
| 23 |
+
from utils import top_k_indices, cosine_similarity, read_raw_spectra
|
| 24 |
+
|
| 25 |
+
# --- Module-level setup: built once at import so Streamlit reruns reuse them. ---

# Workaround for the Streamlit file-watcher crashing while probing torch.classes
# (it inspects __path__ on every module; torch.classes lacks a usable one).
torch.classes.__path__ = [os.path.join(torch.__path__[0], torch.classes.__file__)]

PAGE_SIZE = 5           # default number of match rows shown per page
BATCH_SIZE = 64         # chunk size for the embedding / matching loops
LOADER_BATCH_SIZE = 32  # DataLoader batch size inside each chunk
CANDIDATE_PAGE = [2, 5, 10, 20]  # page sizes offered in the selector
SHOW_PROGRESS_BAR = False        # tqdm bars are useless inside Streamlit

device = torch.device("cpu")
# NOTE(review): tokenizer_config is constructed but never used — Tokenizer below
# is built from the raw values instead; confirm whether it can be removed.
tokenizer_config = TokenizerConfig(
    max_len=100,
    show_progress_bar=SHOW_PROGRESS_BAR
)
tokenizer = Tokenizer(100, SHOW_PROGRESS_BAR)
model = SiameseModel(
    embedding_dim=512,
    n_head=16,
    n_layer=4,
    dim_feedward=512,
    dim_target=512,
    feedward_activation="selu"
)
# Checkpoint keys must match SiameseModel's attribute names exactly.
model_state = torch.load("model.ckpt", map_location=device)
model.load_state_dict(model_state)
tester = ModelTester(model, device, SHOW_PROGRESS_BAR)
|
| 50 |
+
|
| 51 |
+
def custom_stemcontainer(stem_container: StemContainer):
    """Restyle a stem plot in place: hide peak markers, fade out the baseline."""
    markerline = stem_container.markerline
    baseline = stem_container.baseline
    markerline.set_marker("")
    baseline.set_color("none")
    baseline.set_alpha(0.5)
|
| 55 |
+
|
| 56 |
+
def draw_mol(smiles: str):
    """Render a SMILES string as a 2-D molecule picture (PIL image)."""
    return Draw.MolToImage(Chem.MolFromSmiles(smiles))
|
| 60 |
+
|
| 61 |
+
def plot_pair(q: Spectrum, r: Spectrum):
    """Mirror plot: query spectrum pointing up, reference pointing down.

    Returns the matplotlib figure for st.pyplot.
    """
    fig, ax = plt.subplots(1, 1, figsize=(5, 2.7), dpi=300)
    ax.text(0.8, 0.8, "query", transform=ax.transAxes)
    ax.text(0.8, 0.2, "reference", transform=ax.transAxes)
    # Draw both spectra with the same styling; the reference is negated so it
    # hangs below the shared axis.
    for spectrum, sign in ((q, 1.0), (r, -1.0)):
        peaks = spectrum.peaks.to_numpy
        container = ax.stem(peaks[:, 0], sign * peaks[:, 1])
        custom_stemcontainer(container)
    return fig
|
| 72 |
+
|
| 73 |
+
def generate_result():
    """Serialize the current match table to CSV bytes in st.session_state.result.

    Each row pairs the 1-based query index ("ID") with the SMILES of its
    best-matching reference spectrum ("Smiles").
    """
    ref_smiles = st.session_state.ref_smiles
    match_indices = st.session_state.match_indices
    # Build the frame in one shot; the old row-by-row `df.loc[len(df)] = ...`
    # append was O(n^2) and produced object-dtype columns.
    df = pd.DataFrame({
        "ID": range(1, len(match_indices) + 1),
        "Smiles": [ref_smiles[index] for index in match_indices],
    })
    st.session_state.result = df.to_csv(index=False).encode("utf8")
|
| 80 |
+
|
| 81 |
+
def get_smiles(spectra: Sequence[Spectrum]):
    """Collect each spectrum's 'smiles' metadata into a numpy array.

    Spectra with no smiles entry contribute an empty string.
    """
    return np.array([spectrum.get("smiles", "") for spectrum in spectra])
|
| 87 |
+
|
| 88 |
+
def batch_match(
    progress_bar,
    query_embedding,
    ref_embedding
):
    """Return, per query embedding, the index of its best-scoring reference.

    Processes queries in BATCH_SIZE chunks: each chunk is scored against all
    references by cosine similarity and the top-1 reference index is kept.
    `progress_bar` is advanced after every chunk.
    """
    length = len(query_embedding)
    start_seq, end_seq = gen_start_end_seq(length)
    indices = []

    progress = 0
    for start, end in zip(start_seq, end_seq):
        chunk_scores = cosine_similarity(query_embedding[start:end], ref_embedding)
        indices.append(top_k_indices(chunk_scores, 1))
        # Clamp so the bar lands exactly on 100% at the final chunk.
        progress = min(progress + BATCH_SIZE, length - 1)
        progress_bar.progress((progress + 1) / length)

    return np.concatenate(indices, axis=0)[:, 0]
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def init_session_state():
    """Seed st.session_state with every key the app reads, if not already set.

    All keys default to None except page_size, which starts at PAGE_SIZE.
    Existing values survive Streamlit reruns untouched.
    """
    # One defaults table instead of thirteen copy-pasted if-blocks.
    defaults = {
        "query_path": None,
        "ref_path": None,
        "data_len": None,
        "query_embedding": None,
        "ref_embedding": None,
        "query_smiles": None,
        "ref_smiles": None,
        "query_spectra": None,
        "ref_spectra": None,
        "match_indices": None,
        "current_page": None,
        "last_page": None,
        "page_size": PAGE_SIZE,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
|
| 151 |
+
|
| 152 |
+
def previous_page():
    """Step the pager back by one; page 1 is the floor."""
    page = st.session_state.current_page
    if page != 1:
        st.session_state.current_page = page - 1
|
| 156 |
+
|
| 157 |
+
def next_page():
    """Advance the pager by one; the last page is the ceiling."""
    page = st.session_state.current_page
    if page != st.session_state.last_page:
        st.session_state.current_page = page + 1
|
| 162 |
+
|
| 163 |
+
def select_page():
    """Jump straight to the page picked in the page-selector widget."""
    chosen = st.session_state.page_selector
    st.session_state.current_page = int(chosen)
|
| 165 |
+
|
| 166 |
+
def set_page_size():
    """Apply a new page size: rewind to page 1 and recompute the page count."""
    st.session_state.current_page = 1
    size = int(st.session_state.page_size_selector)
    st.session_state.page_size = size
    cal_page_num(st.session_state.data_len, size)
|
| 171 |
+
|
| 172 |
+
def cal_page_num(
    length: int,
    page_size: int
):
    """Store in st.session_state.last_page how many pages `length` items need.

    Ceiling division: a partially filled final page still counts as a page.
    """
    # -(-a // b) is integer ceiling division; replaces the divmod-plus-branch.
    st.session_state.last_page = -(-length // page_size)
|
| 180 |
+
|
| 181 |
+
def gen_start_end_seq(
    length: int,
):
    """Chunk [0, length) into BATCH_SIZE slices; returns (starts, ends).

    The final `end` may overrun `length`; callers rely on Python slice clamping.
    """
    step = BATCH_SIZE
    starts = range(0, length, step)
    ends = range(step, length + step, step)
    return starts, ends
|
| 187 |
+
|
| 188 |
+
def embedding(
    progress_bar,
    tester: ModelTester,
    tokenizer: Tokenizer,
    spectra: Sequence[Spectrum],
):
    """Embed spectra with the model, chunked by BATCH_SIZE; returns an (N, D) array.

    Each chunk is tokenized, wrapped in a DataLoader, run through the tester,
    and `progress_bar` is advanced as chunks finish.
    """
    sequences = tokenizer.tokenize_sequence(spectra)
    total = len(spectra)
    start_seq, end_seq = gen_start_end_seq(total)

    chunks = []
    progress = 0
    for start, end in zip(start_seq, end_seq):
        loader = DataLoader(
            TestDataset(sequences[start:end]),
            LOADER_BATCH_SIZE,
            False
        )
        chunks.append(tester.test(loader))
        # Clamp so the bar lands exactly on 100% at the final chunk.
        progress = min(progress + BATCH_SIZE, total - 1)
        progress_bar.progress((progress + 1) / total)

    return np.concatenate(chunks, axis=0)
|
| 216 |
+
|
| 217 |
+
def main():
    """Build the three-tab SpecEmbedding UI: upload query, upload reference, match.

    Tabs 1 and 2 embed uploaded spectra files and stash the arrays in
    st.session_state; tab 3 matches query embeddings against the reference
    library and renders a paginated result table.
    """
    st.set_page_config(layout="wide")
    st.title("SpecEmbedding")
    tab1, tab2, tab3 = st.tabs(["upload query file", "upload reference/library file", "library match"])

    # --- Tab 1: upload + embed the query spectra -------------------------------
    with tab1:
        st.header("Upload query spectra file(positive mode)")
        query_file = st.file_uploader(
            "upload the query spectra file",
            type=["msp", "mgf", "mzxml"],
            key="query_file",
            accept_multiple_files=False
        )
        query_embedding_btn = st.button("Embedding", "query_embedding_btn")
        query_status_box = st.empty()
        if query_embedding_btn:
            if query_file is not None:
                # The parser needs a real path, so spool the upload to a temp
                # file that keeps its original extension.
                with tempfile.NamedTemporaryFile(delete=True, suffix="." + query_file.name.split(".")[-1]) as tmp_file:
                    tmp_file.write(query_file.getvalue())
                    query_spectra = read_raw_spectra(tmp_file.name)

                progress_bar = st.progress(0, text="Embedding...")
                # data_len drives pagination in tab 3.
                st.session_state.data_len = len(query_spectra)
                st.session_state.query_spectra = query_spectra
                st.session_state.query_smiles = get_smiles(query_spectra)
                query_embedding = embedding(
                    progress_bar,
                    tester,
                    tokenizer,
                    query_spectra,
                )
                st.session_state.query_embedding = query_embedding
                query_status_box.success("Embedding Success ✅")
            else:
                query_status_box.error("Please upload the spectra file")

    # --- Tab 2: upload + embed the reference library ---------------------------
    with tab2:
        st.header("Upload reference/library spectra file(positive mode)")
        ref_file = st.file_uploader(
            "upload the reference/library spectra file",
            type=["msp", "mgf", "mzxml"],
            key="ref_file",
            accept_multiple_files=False
        )
        ref_embedding_btn = st.button("Embedding", "ref_embedding_btn")
        ref_status_box = st.empty()
        if ref_embedding_btn:
            if ref_file is not None:
                progress_bar = st.progress(0, text="Embedding...")
                with tempfile.NamedTemporaryFile(delete=True, suffix="." + ref_file.name.split(".")[-1]) as tmp_file:
                    tmp_file.write(ref_file.getvalue())
                    ref_spectra = read_raw_spectra(tmp_file.name)

                st.session_state.ref_spectra = ref_spectra
                st.session_state.ref_smiles = get_smiles(ref_spectra)
                ref_embedding = embedding(
                    progress_bar,
                    tester,
                    tokenizer,
                    ref_spectra,
                )
                st.session_state.ref_embedding = ref_embedding
                ref_status_box.success("Embedding Success ✅")
            else:
                ref_status_box.error("Please upload the spectra file")

    # --- Tab 3: run the match and page through results -------------------------
    with tab3:
        st.header("Start to match")
        launch_btn = st.button("Launch", key="launch_btn")
        match_status_box = st.empty()
        if launch_btn:
            query_embedding = st.session_state.query_embedding
            ref_embedding = st.session_state.ref_embedding
            if query_embedding is None:
                match_status_box.error("No query embedding")
            elif ref_embedding is None:
                match_status_box.error("No reference embedding")
            else:
                progress_bar = st.progress(0, "Match...")
                match_indices = batch_match(progress_bar, query_embedding, ref_embedding)
                st.session_state.match_indices = match_indices
                st.session_state.current_page = 1
                generate_result()
                cal_page_num(st.session_state.data_len, st.session_state.page_size)
                match_status_box.success("match success")

        # Result table persists across reruns once a match has been computed.
        if st.session_state.match_indices is not None:
            st.subheader(f"Match Result")
            current_page = st.session_state.current_page
            last_page = st.session_state.last_page

            ref_smiles = st.session_state.ref_smiles
            query_spectra = st.session_state.query_spectra
            ref_spectra = st.session_state.ref_spectra
            page_size = st.session_state.page_size

            indices = st.session_state.match_indices
            # Half-open row window [start, end) for the current page.
            start = (current_page - 1) * page_size
            end = start + page_size

            if current_page == last_page:
                # The final page may be short; clamp to the real row count.
                end = indices.shape[0]

            col1, col2, _ = st.columns([1, 1, 5])

            col1.selectbox(
                "page size",
                CANDIDATE_PAGE,
                key="page_size_selector",
                disabled=False,
                label_visibility="collapsed",
                index=CANDIDATE_PAGE.index(page_size),
                on_change=set_page_size,
            )

            col2.download_button(
                label="download result",
                data=st.session_state.result,
                file_name="data.csv",
                mime="text/csv"
            )

            # Pager controls: previous / indicator / next / direct page jump.
            pre_btn, current, next_btn, page_selector, _ = st.columns([1, 1, 1, 1, 2])
            pre_btn.button("previous page", key="pre_btn", on_click=previous_page)
            current.subheader(f"current page: {current_page}")
            next_btn.button("next page", key="next_btn", on_click=next_page)
            page_selector.selectbox(
                label="target page",
                key="page_selector",
                options=range(1, last_page + 1),
                disabled=False,
                index=current_page - 1,
                label_visibility="collapsed",
                on_change=select_page,
            )

            # Table header row.
            col1, col2, col3, col4 = st.columns([1, 4, 6, 4])
            col1.subheader("Index")
            col2.subheader("Smiles")
            col3.subheader("MS/MS Spectra Pair")
            col4.subheader("Molecular Structure")

            # One row per query on this page: index, matched SMILES, mirror
            # plot of the spectra pair, and the matched molecule drawing.
            for i in range(start, end):
                query_index = i
                ref_index = indices[i]
                id_label, smiles_label, pair_viewer, mol_viewer = st.columns([2, 4, 6, 4])
                id_label.subheader(i + 1)
                smiles_label.text(ref_smiles[ref_index])
                pair_fig = plot_pair(query_spectra[query_index], ref_spectra[ref_index])
                pair_viewer.pyplot(pair_fig, use_container_width=True)
                mol_image = draw_mol(ref_smiles[ref_index])
                mol_viewer.image(mol_image, use_container_width=True)
|
| 369 |
+
|
| 370 |
+
# Script entry point: seed session-state defaults before building the UI.
if __name__ == "__main__":
    init_session_state()
    main()
|
src/data.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Sequence
|
| 2 |
+
from collections.abc import Sequence
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from matchms import Spectrum
|
| 7 |
+
from torch.utils.data import Dataset
|
| 8 |
+
|
| 9 |
+
from type import Peak, MetaData, TokenSequence
|
| 10 |
+
|
| 11 |
+
SpecialToken = {
|
| 12 |
+
"PAD": 0,
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
class TestDataset(Dataset):
    """Inference-time dataset over pre-tokenized spectra.

    Each item is the (mz, intensity, mask) triple of one TokenSequence, in a
    form DataLoader can collate into batches.
    """

    def __init__(self, sequences: list[TokenSequence]) -> None:
        super(TestDataset, self).__init__()
        self._sequences = sequences
        # Cached once; sequences is not expected to change afterwards.
        self.length = len(sequences)

    def __len__(self):
        return self.length

    def __getitem__(self, index: int):
        sequence = self._sequences[index]
        return sequence["mz"], sequence["intensity"], sequence["mask"]
|
| 27 |
+
|
| 28 |
+
class Tokenizer:
    def __init__(self, max_len: int, show_progress_bar: bool = True) -> None:
        """
        Tokenization of mass spectrometry data.

        Parameters:
        ---
        - max_len: Maximum number of peaks kept per spectrum, including the
          synthetic precursor peak prepended by `get_metadata`
        - show_progress_bar: Whether to display a tqdm progress bar
        """
        self.max_len = max_len
        self.show_progress_bar = show_progress_bar

    def tokenize(self, s: Spectrum):
        """
        Convert one spectrum into a fixed-length TokenSequence.

        Shorter spectra are padded to max_len with SpecialToken["PAD"]; `mask`
        is True at padded positions (key-padding-mask convention).
        """
        metadata = self.get_metadata(s)
        mz = []
        intensity = []
        for peak in metadata["peaks"]:
            mz.append(peak["mz"])
            intensity.append(peak["intensity"])

        mz = np.array(mz)
        intensity = np.array(intensity)
        mask = np.zeros((self.max_len, ), dtype=bool)
        if len(mz) < self.max_len:
            # Mark the tail as padding, then pad both channels to max_len.
            mask[len(mz):] = True
            mz = np.pad(
                mz, (0, self.max_len - len(mz)),
                mode='constant', constant_values=SpecialToken["PAD"]
            )

            intensity = np.pad(
                intensity, (0, self.max_len - len(intensity)),
                mode='constant', constant_values=SpecialToken["PAD"]
            )

        return TokenSequence(
            mz=np.array(mz, np.float32),
            intensity=np.array(intensity, np.float32),
            mask=mask,
            smiles=metadata["smiles"]
        )

    def tokenize_sequence(self, spectra: Sequence[Spectrum]):
        """Tokenize every spectrum, optionally behind a tqdm progress bar."""
        sequences: list[TokenSequence] = []
        pbar = spectra
        if self.show_progress_bar:
            pbar = tqdm(spectra, total=len(spectra), desc="tokenization")
        for s in pbar:
            sequences.append(self.tokenize(s))

        return sequences

    def get_metadata(self, s: Spectrum):
        """
        Extract smiles and a packaged peak list from a spectrum.

        Keeps the max_len - 1 most intense peaks (restored to their original
        order), normalizes intensities to the strongest kept peak, and
        prepends a synthetic precursor peak with sentinel intensity 2 —
        deliberately above the [0, 1] range of real normalized peaks.
        """
        precursor_mz = s.get("precursor_mz")
        smiles = s.get("smiles")
        peaks = np.array(s.peaks.to_numpy, np.float32)
        intensity = peaks[:, 1]
        # Top-(max_len - 1) intensities, re-sorted ascending so the kept peaks
        # stay in their original (m/z) order.
        argmaxsort_index = np.sort(
            np.argsort(intensity)[::-1][:self.max_len - 1]
        )
        peaks = peaks[argmaxsort_index]
        # NOTE(review): raises on a spectrum with zero peaks — confirm inputs
        # are pre-filtered upstream.
        peaks[:, 1] = peaks[:, 1] / max(peaks[:, 1])
        packaged_peaks: list[Peak] = [
            Peak(
                mz=np.array(precursor_mz, np.float32),
                intensity=2
            )
        ]
        for mz, intensity in peaks:
            packaged_peaks.append(
                Peak(
                    mz=mz,
                    intensity=intensity
                )
            )
        # NOTE(review): precursor_mz is folded into peaks[0] rather than stored
        # separately on MetaData, despite the docstring mentioning it.
        metadata = MetaData(
            smiles=smiles,
            peaks=packaged_peaks
        )
        return metadata
|
src/model.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0ca0aa002a0d061a95410f7a4055e82c7fcb428d0ba04b5714ac3a4e7f0f5cca
|
| 3 |
+
size 31572706
|
src/model.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Literal, Union, Iterable, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
| 7 |
+
|
| 8 |
+
LAMBDA_MIN = math.pow(10, -3.0)
|
| 9 |
+
LAMBDA_MAX = math.pow(10, 3.0)
|
| 10 |
+
|
| 11 |
+
class MultiFeedForwardModule(nn.Module):
    """A stack of Linear layers with activation + dropout between them.

    Layer widths run input_size -> *hidden_size -> output_size. Every hidden
    Linear is followed by the activation and dropout; the final Linear gets a
    trailing dropout only when dropout_last_layer is True.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: Union[int, Iterable[int]],
        output_size: int,
        *,
        activation: Literal['relu', 'selu', 'gelu'] = 'relu',
        dropout: float = 0.1,
        dropout_last_layer: bool = True
    ):
        super(MultiFeedForwardModule, self).__init__()
        if activation == 'relu':
            self._activation = nn.ReLU()
        elif activation == 'selu':
            self._activation = nn.SELU()
        elif activation == 'gelu':
            self._activation = nn.GELU()
        else:
            # BUG FIX: the old message omitted 'gelu' even though it is accepted.
            raise ValueError('activation must be relu, selu or gelu')

        # Normalize hidden_size to a list. BUG FIX: the old code concatenated
        # `[input_size] + hidden_size`, which raised TypeError for any iterable
        # that was not already a list (e.g. a tuple), despite the annotation
        # advertising Iterable[int].
        if hasattr(hidden_size, '__iter__'):
            hidden_size = list(hidden_size)
        elif hidden_size is None:
            hidden_size = [output_size]
        else:
            hidden_size = [hidden_size]

        layers = []
        layer_dims = [input_size] + hidden_size + [output_size]

        # Hidden layers: Linear -> activation -> dropout.
        for i in range(1, len(layer_dims) - 1):
            layers.append(nn.Linear(layer_dims[i - 1], layer_dims[i]))
            layers.append(self._activation)
            layers.append(nn.Dropout(dropout))

        # Output projection, optionally followed by dropout.
        layers.append(nn.Linear(layer_dims[-2], layer_dims[-1]))
        if dropout_last_layer:
            layers.append(nn.Dropout(dropout))
        self._layers = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the full Linear stack to x (last dim must be input_size)."""
        return self._layers(x)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class SinusodialMz(nn.Module):
    """Sinusoidal encoding of m/z values with interleaved sin/cos channels.

    Wavelengths are log-spaced between lambda_min and lambda_max, in the
    style of transformer positional encodings.
    """

    def __init__(self, embedding_dim: int, *, lambda_params: Tuple[float, float] = (LAMBDA_MIN, LAMBDA_MAX)) -> None:
        super(SinusodialMz, self).__init__()
        self.lambda_min, self.lambda_max = lambda_params
        self.lambda_div_value = self.lambda_max / self.lambda_min
        # One angular frequency per sin/cos channel pair (embedding_dim / 2 of
        # them), geometrically spaced across [lambda_min, lambda_max].
        channel = torch.arange(0, embedding_dim, 2)
        wavelength = self.lambda_min * self.lambda_div_value ** (channel / (embedding_dim - 2))
        self.x = 2 * math.pi * wavelength ** -1

    def forward(self, mz: torch.Tensor):
        """Encode mz of shape (batch, length) into (batch, length, embedding_dim)."""
        # Lazily migrate the frequency table to the input's device.
        self.x = self.x.to(mz.device)
        angles = torch.einsum('bl,d->bld', mz, self.x)
        b, l, d = angles.shape
        out = torch.zeros(b, l, 2 * d, dtype=mz.dtype, device=mz.device)
        out[:, :, ::2] = torch.sin(angles)   # even channels: sine
        out[:, :, 1::2] = torch.cos(angles)  # odd channels: cosine
        return out
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class SinusodialMzEmbedding(nn.Module):
    """Sinusoidal m/z encoding followed by a learned feed-forward projection."""

    def __init__(
        self,
        embedding_dim: int,
        *,
        lambda_params: Tuple[float, float] = (LAMBDA_MIN, LAMBDA_MAX),
        feedward_activation: Literal['relu', 'selu', 'gelu'] = 'relu',
        dropout: float = 0.1,
        dropout_last_layer: bool = True
    ):
        super(SinusodialMzEmbedding, self).__init__()
        # Sin/cos channels come in interleaved pairs, so the dim must be even.
        if embedding_dim % 2 != 0:
            raise ValueError('embedding_dim must be even')
        self.embedding = SinusodialMz(
            embedding_dim, lambda_params=lambda_params)
        # dim-preserving MLP applied on top of the fixed sinusoidal features.
        self.feedward_layers = MultiFeedForwardModule(
            embedding_dim, embedding_dim, embedding_dim,
            activation=feedward_activation, dropout=dropout, dropout_last_layer=dropout_last_layer
        )

    def forward(self, mz: torch.Tensor):
        """Map mz (batch, length) to embeddings (batch, length, embedding_dim)."""
        x = self.embedding(mz)
        x = self.feedward_layers(x)
        return x
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class PeaksEmbedding(nn.Module):
    """Embed (m/z, intensity) peak pairs into embedding_dim vectors.

    m/z goes through the sinusoidal embedding; the raw intensity is then
    concatenated as one extra channel and projected back to embedding_dim.
    """

    def __init__(
        self,
        embedding_dim: int,
        *,
        lambda_params: Tuple[float, float] = (LAMBDA_MIN, LAMBDA_MAX),
        feedward_activation: Literal['relu', 'selu', 'gelu'] = 'relu',
        dropout: float = 0.1,
        dropout_last_layer: bool = False
    ) -> None:
        super(PeaksEmbedding, self).__init__()
        self.mz_embedding = SinusodialMzEmbedding(
            embedding_dim,
            lambda_params=lambda_params,
            feedward_activation=feedward_activation,
            dropout=dropout,
            dropout_last_layer=dropout_last_layer
        )
        # +1 input channel for the scalar intensity appended per position.
        self.intensity_embedding = MultiFeedForwardModule(
            embedding_dim + 1, embedding_dim, embedding_dim,
            activation=feedward_activation,
            dropout=dropout,
            dropout_last_layer=dropout_last_layer
        )

    def forward(self, mz: torch.Tensor, intensity: torch.Tensor):
        """Fuse mz and intensity (both (batch, length)) into (batch, length, dim)."""
        mz_tensor = self.mz_embedding(mz)
        intensity_tensor = torch.unsqueeze(intensity, dim=-1)
        x = self.intensity_embedding(
            torch.cat([mz_tensor, intensity_tensor], dim=-1))
        return x
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class SiameseModel(nn.Module):
|
| 142 |
+
def __init__(
|
| 143 |
+
self,
|
| 144 |
+
embedding_dim: int,
|
| 145 |
+
n_head: int,
|
| 146 |
+
n_layer: int,
|
| 147 |
+
dim_feedward: int,
|
| 148 |
+
dim_target: int,
|
| 149 |
+
*,
|
| 150 |
+
lambda_params: Tuple[float, float] = (LAMBDA_MIN, LAMBDA_MAX),
|
| 151 |
+
feedward_activation: Literal['relu', 'selu', 'gelu'] = 'relu',
|
| 152 |
+
dropout: float = 0.1,
|
| 153 |
+
dropout_last_layer: bool = False,
|
| 154 |
+
norm_first: bool = True
|
| 155 |
+
) -> None:
|
| 156 |
+
super(SiameseModel, self).__init__()
|
| 157 |
+
if embedding_dim % n_head != 0:
|
| 158 |
+
raise ValueError('embedding must be divisible by n_head')
|
| 159 |
+
|
| 160 |
+
self.embedding = PeaksEmbedding(
|
| 161 |
+
embedding_dim,
|
| 162 |
+
lambda_params=lambda_params,
|
| 163 |
+
feedward_activation=feedward_activation,
|
| 164 |
+
dropout=dropout,
|
| 165 |
+
dropout_last_layer=dropout_last_layer
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
if feedward_activation == 'selu':
|
| 169 |
+
# transformer encoder activation
|
| 170 |
+
# only gelu or relu
|
| 171 |
+
self.activation = 'gelu'
|
| 172 |
+
else:
|
| 173 |
+
self.activation = feedward_activation
|
| 174 |
+
|
| 175 |
+
if feedward_activation == 'relu':
|
| 176 |
+
self._activation = nn.ReLU()
|
| 177 |
+
elif feedward_activation == 'selu':
|
| 178 |
+
self._activation = nn.SELU()
|
| 179 |
+
elif feedward_activation == 'gelu':
|
| 180 |
+
self._activation = nn.GELU()
|
| 181 |
+
else:
|
| 182 |
+
raise ValueError('activation must be relu or selu or gelu')
|
| 183 |
+
|
| 184 |
+
encoder_layer = TransformerEncoderLayer(
|
| 185 |
+
embedding_dim,
|
| 186 |
+
n_head,
|
| 187 |
+
dim_feedforward=dim_feedward,
|
| 188 |
+
dropout=dropout,
|
| 189 |
+
activation=self.activation,
|
| 190 |
+
batch_first=True,
|
| 191 |
+
norm_first=norm_first
|
| 192 |
+
)
|
| 193 |
+
self._encoder = TransformerEncoder(
|
| 194 |
+
encoder_layer,
|
| 195 |
+
n_layer,
|
| 196 |
+
enable_nested_tensor=False
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
self._decoder = MultiFeedForwardModule(
|
| 200 |
+
embedding_dim,
|
| 201 |
+
dim_feedward,
|
| 202 |
+
dim_target,
|
| 203 |
+
activation=feedward_activation,
|
| 204 |
+
dropout=dropout,
|
| 205 |
+
dropout_last_layer=dropout_last_layer
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
    def forward(self, mz: torch.Tensor, intensity: torch.Tensor, mask: torch.Tensor):
        """Encode a batch of spectra into fixed-size embedding vectors.

        Args:
            mz: peak m/z inputs, one sequence per spectrum.
            intensity: peak intensities aligned with ``mz``.
            mask: key-padding mask passed to the transformer encoder —
                presumably True marks padding positions (the
                ``src_key_padding_mask`` convention); TODO confirm against
                the dataloader that builds it.

        Returns:
            One pooled, decoded, activated vector per spectrum
            (dimension set by ``dim_target`` via the decoder).
        """
        # Embed (m/z, intensity) peak pairs into token vectors.
        x = self.embedding(mz, intensity)
        # Padded positions are excluded from attention via the mask.
        x = self._encoder(x, src_key_padding_mask=mask)
        # mean pooling or cls position vector
        # NOTE(review): the mean runs over EVERY sequence position,
        # including padded ones, so padding still dilutes the average —
        # consider a mask-weighted mean if that is unintended.
        x = torch.mean(x, dim=1)
        x = self._activation(self._decoder(x))
        return x
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# class MambaSiameseModel(nn.Module):
|
| 218 |
+
# def __init__(
|
| 219 |
+
# self,
|
| 220 |
+
# embedding_dim: int,
|
| 221 |
+
# n_layer: int,
|
| 222 |
+
# dim_feedward: int,
|
| 223 |
+
# dim_target: int,
|
| 224 |
+
# *,
|
| 225 |
+
# lambda_params: Tuple[float, float] = (LAMBDA_MIN, LAMBDA_MAX),
|
| 226 |
+
# feedward_activation: Literal['relu', 'selu', 'gelu'] = 'relu',
|
| 227 |
+
# dropout: float = 0.1,
|
| 228 |
+
# dropout_last_layer: bool = False,
|
| 229 |
+
# ):
|
| 230 |
+
# super(MambaSiameseModel, self).__init__()
|
| 231 |
+
|
| 232 |
+
# self.embedding = PeaksEmbedding(
|
| 233 |
+
# embedding_dim,
|
| 234 |
+
# lambda_params=lambda_params,
|
| 235 |
+
# feedward_activation=feedward_activation,
|
| 236 |
+
# dropout=dropout,
|
| 237 |
+
# dropout_last_layer=dropout_last_layer
|
| 238 |
+
# )
|
| 239 |
+
|
| 240 |
+
# if feedward_activation == 'relu':
|
| 241 |
+
# self._activation = nn.ReLU()
|
| 242 |
+
# elif feedward_activation == 'selu':
|
| 243 |
+
# self._activation = nn.SELU()
|
| 244 |
+
# elif feedward_activation == 'gelu':
|
| 245 |
+
# self._activation = nn.GELU()
|
| 246 |
+
# else:
|
| 247 |
+
# raise ValueError('activation must be relu or selu or gelu')
|
| 248 |
+
|
| 249 |
+
# self._encoder = nn.Sequential(*[
|
| 250 |
+
# Mamba2(
|
| 251 |
+
# d_model=embedding_dim,
|
| 252 |
+
# d_state=64,
|
| 253 |
+
# d_conv=4,
|
| 254 |
+
# expand=2
|
| 255 |
+
# )
|
| 256 |
+
# for _ in range(n_layer)
|
| 257 |
+
# ])
|
| 258 |
+
|
| 259 |
+
# self._decoder = MultiFeedForwardModule(
|
| 260 |
+
# embedding_dim,
|
| 261 |
+
# dim_feedward,
|
| 262 |
+
# dim_target,
|
| 263 |
+
# activation=feedward_activation,
|
| 264 |
+
# dropout=dropout,
|
| 265 |
+
# dropout_last_layer=dropout_last_layer
|
| 266 |
+
# )
|
| 267 |
+
|
| 268 |
+
# def forward(self, mz: torch.Tensor, intensity: torch.Tensor, mask: torch.Tensor):
|
| 269 |
+
# x = self.embedding(mz, intensity)
|
| 270 |
+
# x = self._encoder(x)
|
| 271 |
+
# # mean pooling or cls position vector
|
| 272 |
+
# x = torch.mean(x, dim=1)
|
| 273 |
+
# x = self._activation(self._decoder(x))
|
| 274 |
+
# return x
|
src/streamlit_app.py
DELETED
|
@@ -1,40 +0,0 @@
|
|
| 1 |
-
import altair as alt
|
| 2 |
-
import numpy as np
|
| 3 |
-
import pandas as pd
|
| 4 |
-
import streamlit as st
|
| 5 |
-
|
| 6 |
-
"""
|
| 7 |
-
# Welcome to Streamlit!
|
| 8 |
-
|
| 9 |
-
Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
|
| 10 |
-
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
| 11 |
-
forums](https://discuss.streamlit.io).
|
| 12 |
-
|
| 13 |
-
In the meantime, below is an example of what you can do with just a few lines of code:
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
|
| 17 |
-
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
|
| 18 |
-
|
| 19 |
-
indices = np.linspace(0, 1, num_points)
|
| 20 |
-
theta = 2 * np.pi * num_turns * indices
|
| 21 |
-
radius = indices
|
| 22 |
-
|
| 23 |
-
x = radius * np.cos(theta)
|
| 24 |
-
y = radius * np.sin(theta)
|
| 25 |
-
|
| 26 |
-
df = pd.DataFrame({
|
| 27 |
-
"x": x,
|
| 28 |
-
"y": y,
|
| 29 |
-
"idx": indices,
|
| 30 |
-
"rand": np.random.randn(num_points),
|
| 31 |
-
})
|
| 32 |
-
|
| 33 |
-
st.altair_chart(alt.Chart(df, height=700, width=700)
|
| 34 |
-
.mark_point(filled=True)
|
| 35 |
-
.encode(
|
| 36 |
-
x=alt.X("x", axis=None),
|
| 37 |
-
y=alt.Y("y", axis=None),
|
| 38 |
-
color=alt.Color("idx", legend=None, scale=alt.Scale()),
|
| 39 |
-
size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
|
| 40 |
-
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/tester.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.nn import Module
|
| 3 |
+
from torch.utils.data import DataLoader
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
class ModelTester:
    """Runs a model in inference mode over a dataloader and stacks outputs.

    Each batch is moved to the target device, forwarded through the model
    under ``torch.no_grad()``, and the per-batch predictions are
    concatenated into one numpy array.
    """

    def __init__(
        self,
        model: Module,
        device: torch.device,
        show_prgress_bar: bool = True
    ) -> None:
        """
        Args:
            model: module to evaluate.
            device: device every input tensor is moved to before the call.
            show_prgress_bar: wrap the dataloader in a tqdm progress bar.
                (Parameter name kept as-is, typo included, since callers
                may pass it by keyword.)
        """
        self.model = model
        self.device = device
        self.show_prgress_bar = show_prgress_bar

    def test(self, dataloader: DataLoader):
        """Run the model over every batch and return the stacked outputs.

        Args:
            dataloader: yields batches as sequences of tensors, passed
                positionally to the model.

        Returns:
            numpy array: all batch outputs concatenated along axis 0.
        """
        self.model.eval()

        iterator = dataloader
        if self.show_prgress_bar:
            iterator = tqdm(dataloader, total=len(dataloader),
                            desc="embedding")

        chunks = []
        with torch.no_grad():
            for batch in iterator:
                tensors = [t.to(self.device) for t in batch]
                prediction: torch.Tensor = self.model(*tensors)
                chunks.append(prediction.cpu().numpy())
        return np.concatenate(chunks, axis=0)
|
src/type.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TypedDict, Sequence, Callable, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch import device
|
| 6 |
+
import numpy as np
|
| 7 |
+
import numpy.typing as npt
|
| 8 |
+
|
| 9 |
+
# A batch as produced by the dataloaders: a sequence of tensors
# passed positionally to the model.
BatchType = Sequence[torch.Tensor]
# Signature of one training step. Slots, judging by the types: two
# modules, a device, a batch, and an optional callable returning int.
# NOTE(review): the role of each positional slot (model vs. loss fn,
# what the optional callable counts) is assumed — confirm against the
# trainer that consumes these aliases.
StepTrain = Callable[[nn.Module, nn.Module, device,
                      BatchType, Optional[Callable[..., int]]], Sequence[torch.Tensor]]
# Validation step: identical shape to StepTrain.
StepVal = Callable[[nn.Module, nn.Module, device,
                    BatchType, Optional[Callable[..., int]]], Sequence[torch.Tensor]]
|
| 14 |
+
|
| 15 |
+
class Peak(TypedDict):
    """Peak data for a single spectrum."""
    # NOTE(review): `str` looks inconsistent — the sibling field and
    # TokenSequence.mz are numpy arrays; this is likely meant to be
    # npt.NDArray. Confirm with whatever constructs Peak before changing.
    mz: str
    intensity: npt.NDArray
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class MetaData(TypedDict):
    """Metadata for one compound: its peak list and its structure."""
    peaks: Sequence[Peak]
    # SMILES string of the molecule.
    smiles: str
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class TokenSequence(TypedDict):
    """One spectrum after tokenization, ready for the model."""
    # Integer m/z token ids.
    mz: npt.NDArray[np.int32]
    # Peak intensities aligned with `mz`.
    intensity: npt.NDArray[np.float32]
    # Boolean mask aligned with `mz` — presumably True marks padding
    # positions (src_key_padding_mask convention); TODO confirm in the
    # tokenizer.
    mask: npt.NDArray[np.bool_]
    # SMILES string of the parent molecule.
    smiles: str
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class TokenizerConfig(TypedDict):
    """Options controlling tokenization."""
    # Maximum sequence length — presumably spectra are padded/truncated
    # to this; confirm in the tokenizer implementation.
    max_len: int
    show_progress_bar: bool
|
src/utils.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import numpy.typing as npt
|
| 5 |
+
from numba import prange, njit
|
| 6 |
+
from matchms.importing import load_from_mgf, load_from_msp, load_from_mzxml
|
| 7 |
+
from matchms.filtering import default_filters, normalize_intensities
|
| 8 |
+
|
| 9 |
+
def read_raw_spectra(path: str):
    """Load, filter and intensity-normalize spectra from a mass-spec file.

    Supported formats, chosen by file extension (case-insensitive):
    ``.mgf``, ``.msp``, ``.mzxml``.

    Args:
        path: path to the spectra file.

    Returns:
        List of matchms spectra with ``default_filters`` applied and
        intensities normalized.

    Raises:
        ValueError: if the file extension is not one of the supported ones.
    """
    # Lower-case the suffix so the common on-disk capitalizations
    # (".mzXML", ".MGF", ...) are accepted — the exact-match comparison
    # previously rejected them.
    suffix = Path(path).suffix.lower()
    if suffix == ".mgf":
        spectra = list(load_from_mgf(path))
    elif suffix == ".msp":
        spectra = list(load_from_msp(path))
    elif suffix == ".mzxml":
        spectra = list(load_from_mzxml(path))
    else:
        raise ValueError(f"Not support the {suffix} format")

    spectra = [default_filters(s) for s in spectra]
    spectra = [normalize_intensities(s) for s in spectra]
    return spectra
|
| 23 |
+
|
| 24 |
+
@njit
def cosine_similarity(A: npt.NDArray, B: npt.NDArray):
    """Pairwise cosine similarity between the rows of two 2-D arrays.

    Compiled with numba's ``@njit``; the body deliberately sticks to the
    numpy subset numba supports.

    Args:
        A: array of shape (n, d), one vector per row.
        B: array of shape (m, d), one vector per row.

    Returns:
        (n, m) array where entry (i, j) is the cosine similarity of
        A[i] and B[j].
    """
    # 1e-8 guards against division by zero for all-zero rows.
    norm_A = np.sqrt(np.sum(A ** 2, axis=1)) + 1e-8
    norm_B = np.sqrt(np.sum(B ** 2, axis=1)) + 1e-8
    # Broadcast each row's norm across its components.
    normalize_A = A / norm_A[:, np.newaxis]
    normalize_B = B / norm_B[:, np.newaxis]
    # After row normalization the dot product IS the cosine similarity.
    scores = np.dot(normalize_A, normalize_B.T)
    return scores
|
| 32 |
+
|
| 33 |
+
@njit(parallel=True)
def top_k_indices(score, top_k):
    """Column indices of the ``top_k`` highest scores in each row.

    Rows are processed in parallel via numba's ``prange``.

    Args:
        score: 2-D score array (e.g. output of ``cosine_similarity``).
        top_k: number of indices to keep per row — assumed to be at most
            ``score.shape[1]``; TODO confirm callers guarantee this.

    Returns:
        (rows, top_k) int64 array; row i holds the column indices of
        ``score[i]`` in descending score order.
    """
    rows, cols = score.shape
    indices = np.empty((rows, top_k), dtype=np.int64)
    for i in prange(rows):
        row = score[i]
        # Ascending argsort, reversed -> descending order. Tie order is
        # whatever the (non-stable) sort produces.
        sorted_idx = np.argsort(row)[::-1]
        indices[i] = sorted_idx[:top_k]
    return indices
|