import os import tempfile from typing import Sequence import torch from torch.utils.data import DataLoader import numpy as np import pandas as pd import seaborn as sns sns.set_style("whitegrid") sns.set_palette("deep") import streamlit as st import matplotlib.pyplot as plt from matplotlib.container import StemContainer from matchms import Spectrum from rdkit import Chem from rdkit.Chem import Draw from type import TokenizerConfig from data import Tokenizer, TestDataset from model import SiameseModel from tester import ModelTester from utils import top_k_indices, cosine_similarity, read_raw_spectra torch.classes.__path__ = [os.path.join(torch.__path__[0], torch.classes.__file__)] PAGE_SIZE = 5 BATCH_SIZE = 64 LOADER_BATCH_SIZE = 32 CANDIDATE_PAGE = [2, 5, 10, 20] SHOW_PROGRESS_BAR = False device = torch.device("cpu") tokenizer_config = TokenizerConfig( max_len=100, show_progress_bar=SHOW_PROGRESS_BAR ) tokenizer = Tokenizer(100, SHOW_PROGRESS_BAR) model = SiameseModel( embedding_dim=512, n_head=16, n_layer=4, dim_feedward=512, dim_target=512, feedward_activation="selu" ) model_state = torch.load("model.ckpt", map_location=device) model.load_state_dict(model_state) tester = ModelTester(model, device, SHOW_PROGRESS_BAR) def custom_stemcontainer(stem_container: StemContainer): stem_container.markerline.set_marker("") stem_container.baseline.set_color("none") stem_container.baseline.set_alpha(0.5) def draw_mol(smiles: str): mol = Chem.MolFromSmiles(smiles) image = Draw.MolToImage(mol) return image def plot_pair(q: Spectrum, r: Spectrum): q_peaks = q.peaks.to_numpy r_peaks = r.peaks.to_numpy fig, ax = plt.subplots(1, 1, figsize=(5, 2.7), dpi=300) ax.text(0.8, 0.8, "query", transform=ax.transAxes) ax.text(0.8, 0.2, "reference", transform=ax.transAxes) container1 = ax.stem(q_peaks[:, 0], q_peaks[:, 1]) custom_stemcontainer(container1) container2 = ax.stem(r_peaks[:, 0], -r_peaks[:, 1]) custom_stemcontainer(container2) return fig def generate_result(): ref_smiles = st.session_state.ref_smiles match_indices = st.session_state.match_indices df = pd.DataFrame(columns=["ID", "Smiles"]) for i, index in enumerate(match_indices): df.loc[len(df)] = [i + 1, ref_smiles[index]] st.session_state.result = df.to_csv(index=False).encode("utf8") def get_smiles(spectra: Sequence[Spectrum]): smiles_seq = [ s.get("smiles", "") for s in spectra ] return np.array(smiles_seq) def batch_match( progress_bar, query_embedding, ref_embedding ): length = len(query_embedding) start_seq, end_seq = gen_start_end_seq(length) indices = [] progress = 0 for start, end in zip(start_seq, end_seq): batch_embedding = query_embedding[start:end] cosine_scores = cosine_similarity(batch_embedding, ref_embedding) batch_indices = top_k_indices(cosine_scores, 1) indices.append(batch_indices) if progress + BATCH_SIZE >= length: progress = length - 1 else: progress += BATCH_SIZE progress_bar.progress((progress + 1) / length) return np.concatenate(indices, axis=0)[:, 0] def init_session_state(): if "query_path" not in st.session_state: st.session_state.query_path = None if "ref_path" not in st.session_state: st.session_state.ref_path = None if "data_len" not in st.session_state: st.session_state.data_len = None if "query_embedding" not in st.session_state: st.session_state.query_embedding = None if "ref_embedding" not in st.session_state: st.session_state.ref_embedding = None if "query_smiles" not in st.session_state: st.session_state.query_smiles = None if "ref_smiles" not in st.session_state: st.session_state.ref_smiles = None if "query_spectra" not in st.session_state: st.session_state.query_spectra = None if "ref_spectra" not in st.session_state: st.session_state.ref_spectra = None if "match_indices" not in st.session_state: st.session_state.match_indices = None if "current_page" not in st.session_state: st.session_state.current_page = None if "last_page" not in st.session_state: st.session_state.last_page = None if "page_size" not in st.session_state: st.session_state.page_size = PAGE_SIZE def previous_page(): current_page = st.session_state.current_page if current_page != 1: st.session_state.current_page -= 1 def next_page(): current_page = st.session_state.current_page last_page = st.session_state.last_page if current_page != last_page: st.session_state.current_page += 1 def select_page(): st.session_state.current_page = int(st.session_state.page_selector) def set_page_size(): st.session_state.current_page = 1 page_size = int(st.session_state.page_size_selector) st.session_state.page_size = page_size cal_page_num(st.session_state.data_len, page_size) def cal_page_num( length: int, page_size: int ): page_num, rest = divmod(length, page_size) if rest != 0: page_num += 1 st.session_state.last_page = page_num def gen_start_end_seq( length: int, ): start_seq = range(0, length, BATCH_SIZE) end_seq = range(BATCH_SIZE, length + BATCH_SIZE, BATCH_SIZE) return start_seq, end_seq def embedding( progress_bar, tester: ModelTester, tokenizer: Tokenizer, spectra: Sequence[Spectrum], ): sequences = tokenizer.tokenize_sequence(spectra) start_seq, end_seq = gen_start_end_seq(len(spectra)) progress = 0 embedding = [] for start, end in zip(start_seq, end_seq): test_dataset = TestDataset(sequences[start:end]) test_dataloader = DataLoader( test_dataset, LOADER_BATCH_SIZE, False ) step_embedding = tester.test(test_dataloader) if progress + BATCH_SIZE >= len(spectra): progress = len(spectra) - 1 else: progress += BATCH_SIZE embedding.append(step_embedding) progress_bar.progress((progress + 1) / len(spectra)) embedding = np.concatenate(embedding, axis=0) return embedding def main(): st.set_page_config(layout="wide") st.title("SpecEmbedding") tab1, tab2, tab3 = st.tabs(["upload query file", "upload reference/library file", "library match"]) with tab1: st.header("Upload query spectra file(positive mode)") query_file = st.file_uploader( "upload the query spectra file", type=["msp", "mgf", "mzxml"], key="query_file", accept_multiple_files=False ) query_embedding_btn = st.button("Embedding", "query_embedding_btn") query_status_box = st.empty() if query_embedding_btn: if query_file is not None: with tempfile.NamedTemporaryFile(delete=True, suffix="." + query_file.name.split(".")[-1]) as tmp_file: tmp_file.write(query_file.getvalue()) query_spectra = read_raw_spectra(tmp_file.name) progress_bar = st.progress(0, text="Embedding...") st.session_state.data_len = len(query_spectra) st.session_state.query_spectra = query_spectra st.session_state.query_smiles = get_smiles(query_spectra) query_embedding = embedding( progress_bar, tester, tokenizer, query_spectra, ) st.session_state.query_embedding = query_embedding query_status_box.success("Embedding Success ✅") else: query_status_box.error("Please upload the spectra file") with tab2: st.header("Upload reference/library spectra file(positive mode)") ref_file = st.file_uploader( "upload the reference/library spectra file", type=["msp", "mgf", "mzxml"], key="ref_file", accept_multiple_files=False ) ref_embedding_btn = st.button("Embedding", "ref_embedding_btn") ref_status_box = st.empty() if ref_embedding_btn: if ref_file is not None: progress_bar = st.progress(0, text="Embedding...") with tempfile.NamedTemporaryFile(delete=True, suffix="." + ref_file.name.split(".")[-1]) as tmp_file: tmp_file.write(ref_file.getvalue()) ref_spectra = read_raw_spectra(tmp_file.name) st.session_state.ref_spectra = ref_spectra st.session_state.ref_smiles = get_smiles(ref_spectra) ref_embedding = embedding( progress_bar, tester, tokenizer, ref_spectra, ) st.session_state.ref_embedding = ref_embedding ref_status_box.success("Embedding Success ✅") else: ref_status_box.error("Please upload the spectra file") with tab3: st.header("Start to match") launch_btn = st.button("Launch", key="launch_btn") match_status_box = st.empty() if launch_btn: query_embedding = st.session_state.query_embedding ref_embedding = st.session_state.ref_embedding if query_embedding is None: match_status_box.error("No query embedding") elif ref_embedding is None: match_status_box.error("No reference embedding") else: progress_bar = st.progress(0, "Match...") match_indices = batch_match(progress_bar, query_embedding, ref_embedding) st.session_state.match_indices = match_indices st.session_state.current_page = 1 generate_result() cal_page_num(st.session_state.data_len, st.session_state.page_size) match_status_box.success("match success") if st.session_state.match_indices is not None: st.subheader(f"Match Result") current_page = st.session_state.current_page last_page = st.session_state.last_page ref_smiles = st.session_state.ref_smiles query_spectra = st.session_state.query_spectra ref_spectra = st.session_state.ref_spectra page_size = st.session_state.page_size indices = st.session_state.match_indices start = (current_page - 1) * page_size end = start + page_size if current_page == last_page: end = indices.shape[0] col1, col2, _ = st.columns([1, 1, 5]) col1.selectbox( "page size", CANDIDATE_PAGE, key="page_size_selector", disabled=False, label_visibility="collapsed", index=CANDIDATE_PAGE.index(page_size), on_change=set_page_size, ) col2.download_button( label="download result", data=st.session_state.result, file_name="data.csv", mime="text/csv" ) pre_btn, current, next_btn, page_selector, _ = st.columns([1, 1, 1, 1, 2]) pre_btn.button("previous page", key="pre_btn", on_click=previous_page) current.subheader(f"current page: {current_page}") next_btn.button("next page", key="next_btn", on_click=next_page) page_selector.selectbox( label="target page", key="page_selector", options=range(1, last_page + 1), disabled=False, index=current_page - 1, label_visibility="collapsed", on_change=select_page, ) col1, col2, col3, col4 = st.columns([1, 4, 6, 4]) col1.subheader("Index") col2.subheader("Smiles") col3.subheader("MS/MS Spectra Pair") col4.subheader("Molecular Structure") for i in range(start, end): query_index = i ref_index = indices[i] id_label, smiles_label, pair_viewer, mol_viewer = st.columns([2, 4, 6, 4]) id_label.subheader(i + 1) smiles_label.text(ref_smiles[ref_index]) pair_fig = plot_pair(query_spectra[query_index], ref_spectra[ref_index]) pair_viewer.pyplot(pair_fig, use_container_width=True) mol_image = draw_mol(ref_smiles[ref_index]) mol_viewer.image(mol_image, use_container_width=True) if __name__ == "__main__": init_session_state() main()