Spaces:
Runtime error
Runtime error
| import os | |
| import tempfile | |
| from typing import Sequence | |
| import torch | |
| from torch.utils.data import DataLoader | |
| import numpy as np | |
| import pandas as pd | |
| import seaborn as sns | |
# Apply the seaborn grid style and colour palette at import time, before any
# figure is created by plot_pair().
sns.set_style("whitegrid")
sns.set_palette("deep")
| import streamlit as st | |
| import matplotlib.pyplot as plt | |
| from matplotlib.container import StemContainer | |
| from matchms import Spectrum | |
| from rdkit import Chem | |
| from rdkit.Chem import Draw | |
| from type import TokenizerConfig | |
| from data import Tokenizer, TestDataset | |
| from model import SiameseModel | |
| from tester import ModelTester | |
| from utils import top_k_indices, cosine_similarity, read_raw_spectra | |
# NOTE(review): presumably a workaround that keeps Streamlit's module watcher
# from failing while it inspects ``torch.classes`` -- confirm before removing.
torch.classes.__path__ = [os.path.join(torch.__path__[0], torch.classes.__file__)]

# Default number of match rows per result page. main() looks this value up
# with CANDIDATE_PAGE.index(), so it must be one of the entries listed there.
PAGE_SIZE = 5
# Chunk length used by gen_start_end_seq() for both embedding and matching.
BATCH_SIZE = 64
# Batch size of the DataLoader that embedding() builds for each chunk.
LOADER_BATCH_SIZE = 32
# Page sizes offered by the "page size" selector in the result view.
CANDIDATE_PAGE = [2, 5, 10, 20]
# Forwarded to the tokenizer config, the tokenizer and the model tester.
SHOW_PROGRESS_BAR = False
# All inference runs on the CPU; the checkpoint is mapped onto it as well.
device = torch.device("cpu")

# NOTE(review): tokenizer_config is not read anywhere in the visible code;
# the Tokenizer below is built from positional arguments instead -- confirm
# whether the config object was meant to be passed to it.
tokenizer_config = TokenizerConfig(
    max_len=100,
    show_progress_bar=SHOW_PROGRESS_BAR
)
tokenizer = Tokenizer(100, SHOW_PROGRESS_BAR)

# NOTE(review): these hyper-parameters presumably have to match the ones the
# checkpoint was trained with, otherwise load_state_dict() will fail -- verify
# against the training configuration.
model = SiameseModel(
    embedding_dim=512,
    n_head=16,
    n_layer=4,
    dim_feedward=512,
    dim_target=512,
    feedward_activation="selu"
)
# The checkpoint path is relative to the working directory of the process.
model_state = torch.load("model.ckpt", map_location=device)
model.load_state_dict(model_state)
# Shared tester instance handed to embedding() by main().
tester = ModelTester(model, device, SHOW_PROGRESS_BAR)
def custom_stemcontainer(stem_container: StemContainer):
    """Restyle a stem plot in place so that only the stems remain visible.

    The head markers are removed and the baseline is given the colour
    ``"none"`` and an alpha of 0.5. Nothing is returned.
    """
    baseline = stem_container.baseline
    # An empty marker string removes the marker at the tip of every stem.
    stem_container.markerline.set_marker("")
    baseline.set_color("none")
    baseline.set_alpha(0.5)
def draw_mol(smiles: str):
    """Render a SMILES string as a molecule image.

    Args:
        smiles: SMILES string of the molecule; may be empty or unparsable.

    Returns:
        The image produced by ``Draw.MolToImage``. When ``smiles`` is empty
        or cannot be parsed, an empty molecule is drawn instead, so the
        caller receives a blank image rather than an exception.
    """
    # MolFromSmiles signals a parse failure by returning None instead of
    # raising; handing that None to MolToImage would abort the whole page.
    mol = Chem.MolFromSmiles(smiles) if smiles else None
    if mol is None:
        mol = Chem.Mol()
    return Draw.MolToImage(mol)
def plot_pair(q: Spectrum, r: Spectrum):
    """Draw a mirror plot of a query spectrum against a reference spectrum.

    The query peaks are drawn upwards and the reference peaks are drawn
    downwards (negated second column) on the same axes.

    Args:
        q: Query spectrum.
        r: Reference spectrum.

    Returns:
        The matplotlib figure. It is already detached from pyplot's global
        figure registry, so the caller does not have to close it.
    """
    # ``to_numpy`` is read as an attribute and indexed as a 2-D array:
    # column 0 goes on the x axis, column 1 on the y axis
    # (presumably m/z and intensity -- confirm against matchms).
    q_peaks = q.peaks.to_numpy
    r_peaks = r.peaks.to_numpy
    fig, ax = plt.subplots(1, 1, figsize=(5, 2.7), dpi=300)
    ax.text(0.8, 0.8, "query", transform=ax.transAxes)
    ax.text(0.8, 0.2, "reference", transform=ax.transAxes)
    container1 = ax.stem(q_peaks[:, 0], q_peaks[:, 1])
    custom_stemcontainer(container1)
    container2 = ax.stem(r_peaks[:, 0], -r_peaks[:, 1])
    custom_stemcontainer(container2)
    # plt.subplots() registers the figure with pyplot, which keeps it alive
    # until it is closed. One figure is created per result row on every
    # rerun, so without this they pile up. The Figure object itself stays
    # usable for rendering after it has been closed.
    plt.close(fig)
    return fig
def generate_result():
    """Build the downloadable CSV of the current match result.

    Reads ``ref_smiles`` and ``match_indices`` from the Streamlit session
    state and stores the UTF-8 encoded CSV bytes in
    ``st.session_state.result``. The CSV has the columns ``ID`` (1-based row
    number) and ``Smiles`` (the SMILES of the matched reference entry).
    """
    ref_smiles = st.session_state.ref_smiles
    match_indices = st.session_state.match_indices
    matched_smiles = [ref_smiles[index] for index in match_indices]
    # Build the frame in one step: appending through df.loc[len(df)] re-grows
    # the frame for every row, which is quadratic in the number of matches.
    df = pd.DataFrame(
        {
            "ID": range(1, len(matched_smiles) + 1),
            "Smiles": matched_smiles,
        },
        columns=["ID", "Smiles"],
    )
    st.session_state.result = df.to_csv(index=False).encode("utf8")
def get_smiles(spectra: Sequence[Spectrum]):
    """Collect the ``"smiles"`` entry of every spectrum into a numpy array.

    A spectrum without that entry contributes an empty string. The output
    order follows the order of ``spectra``.
    """
    collected = []
    for spectrum in spectra:
        collected.append(spectrum.get("smiles", ""))
    return np.array(collected)
def batch_match(
    progress_bar,
    query_embedding,
    ref_embedding
):
    """Find the best matching reference entry for every query embedding.

    The queries are processed in the chunks produced by
    ``gen_start_end_seq`` so that the progress bar can be advanced between
    chunks.

    Args:
        progress_bar: Object with a ``progress(fraction)`` method; it
            receives the fraction of queries processed so far.
        query_embedding: Sliceable collection of query embeddings.
        ref_embedding: Reference embeddings every chunk is compared against.

    Returns:
        1-D array with one reference index per query (column 0 of the
        ``top_k_indices(..., 1)`` result). Empty when there are no queries.
    """
    length = len(query_embedding)
    if length == 0:
        # np.concatenate() refuses an empty list; there is nothing to match.
        return np.empty(0, dtype=int)
    start_seq, end_seq = gen_start_end_seq(length)
    indices = []
    for start, end in zip(start_seq, end_seq):
        batch_embedding = query_embedding[start:end]
        cosine_scores = cosine_similarity(batch_embedding, ref_embedding)
        indices.append(top_k_indices(cosine_scores, 1))
        # The last chunk boundary may overshoot ``length``; clamp it so the
        # reported fraction is exactly the share of processed queries.
        progress_bar.progress(min(end, length) / length)
    return np.concatenate(indices, axis=0)[:, 0]
def init_session_state():
    """Seed the Streamlit session state with the keys the app relies on.

    Keys that already exist are left untouched, so values survive reruns.
    Every key starts as ``None`` except ``page_size``, which starts at
    ``PAGE_SIZE``.
    """
    defaults = {
        "query_path": None,
        "ref_path": None,
        "data_len": None,
        "query_embedding": None,
        "ref_embedding": None,
        "query_smiles": None,
        "ref_smiles": None,
        "query_spectra": None,
        "ref_spectra": None,
        "match_indices": None,
        "current_page": None,
        "last_page": None,
        "page_size": PAGE_SIZE,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
def previous_page():
    """Step the result view back by one page unless it is already on page 1."""
    if st.session_state.current_page == 1:
        return
    st.session_state.current_page -= 1
def next_page():
    """Advance the result view by one page unless it is on the last page."""
    state = st.session_state
    if state.current_page == state.last_page:
        return
    state.current_page += 1
def select_page():
    """Jump to the page currently chosen in the ``page_selector`` widget."""
    chosen = st.session_state.page_selector
    st.session_state.current_page = int(chosen)
def set_page_size():
    """Apply the page size chosen in the ``page_size_selector`` widget.

    The view is reset to page 1 and the number of pages is recomputed from
    ``data_len`` and the new page size.
    """
    state = st.session_state
    state.current_page = 1
    state.page_size = int(state.page_size_selector)
    cal_page_num(state.data_len, state.page_size)
def cal_page_num(
    length: int,
    page_size: int
):
    """Store the number of pages needed for ``length`` rows in the session.

    The result is ``length / page_size`` rounded up, written to
    ``st.session_state.last_page``.
    """
    # Negating twice around floor division yields ceiling division.
    st.session_state.last_page = -(-length // page_size)
def gen_start_end_seq(length: int, batch_size=None):
    """Return the chunk boundaries that cover ``length`` items.

    Args:
        length: Number of items to split into chunks.
        batch_size: Chunk length. ``None`` (the default) uses the
            module-level ``BATCH_SIZE``, which keeps existing callers
            unchanged.

    Returns:
        A ``(start_seq, end_seq)`` pair of ranges of equal length. The last
        end value may exceed ``length``; callers use the pairs as slice
        bounds, where the overshoot is harmless.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    start_seq = range(0, length, batch_size)
    end_seq = range(batch_size, length + batch_size, batch_size)
    return start_seq, end_seq
def embedding(
    progress_bar,
    tester: ModelTester,
    tokenizer: Tokenizer,
    spectra: Sequence[Spectrum],
):
    """Embed spectra chunk by chunk while reporting progress.

    All spectra are tokenized up front; the token sequences are then fed to
    the tester in the chunks produced by ``gen_start_end_seq`` so that the
    progress bar can be advanced between chunks.

    Args:
        progress_bar: Object with a ``progress(fraction)`` method; it
            receives the fraction of spectra embedded so far.
        tester: Runs the model on a ``DataLoader`` through ``tester.test``.
        tokenizer: Converts the spectra through ``tokenize_sequence``.
        spectra: Spectra to embed.

    Returns:
        The per-chunk results of ``tester.test`` concatenated along axis 0.

    Raises:
        ValueError: If ``spectra`` is empty.
    """
    total = len(spectra)
    if total == 0:
        # np.concatenate() would raise the same exception type further down,
        # but with a message that does not point at the real cause.
        raise ValueError("no spectra to embed")
    sequences = tokenizer.tokenize_sequence(spectra)
    start_seq, end_seq = gen_start_end_seq(total)
    chunks = []
    for start, end in zip(start_seq, end_seq):
        test_dataset = TestDataset(sequences[start:end])
        test_dataloader = DataLoader(
            test_dataset,
            batch_size=LOADER_BATCH_SIZE,
            shuffle=False,
        )
        chunks.append(tester.test(test_dataloader))
        # The last chunk boundary may overshoot ``total``; clamp it so the
        # reported fraction is exactly the share of processed spectra.
        progress_bar.progress(min(end, total) / total)
    return np.concatenate(chunks, axis=0)
def _read_uploaded_spectra(uploaded_file):
    """Parse an uploaded spectra file through a temporary on-disk copy."""
    suffix = "." + uploaded_file.name.split(".")[-1]
    with tempfile.NamedTemporaryFile(delete=True, suffix=suffix) as tmp_file:
        tmp_file.write(uploaded_file.getvalue())
        # The reader opens the file again by name, so the buffered tail of
        # the upload has to reach the disk first; otherwise the end of the
        # file (or all of a small file) is missing when it is parsed.
        tmp_file.flush()
        return read_raw_spectra(tmp_file.name)


def _render_upload_tab(kind, header, uploader_label):
    """Render one upload tab and embed the uploaded spectra on request.

    ``kind`` is ``"query"`` or ``"ref"``; it selects the widget keys and the
    session-state keys (``<kind>_spectra``, ``<kind>_smiles``,
    ``<kind>_embedding``) that are written.
    """
    st.header(header)
    uploaded_file = st.file_uploader(
        uploader_label,
        type=["msp", "mgf", "mzxml"],
        key=f"{kind}_file",
        accept_multiple_files=False,
    )
    embedding_btn = st.button("Embedding", f"{kind}_embedding_btn")
    status_box = st.empty()
    if not embedding_btn:
        return
    if uploaded_file is None:
        status_box.error("Please upload the spectra file")
        return

    spectra = _read_uploaded_spectra(uploaded_file)
    if len(spectra) == 0:
        status_box.error("No spectra found in the uploaded file")
        return

    progress_bar = st.progress(0, text="Embedding...")
    vectors = embedding(
        progress_bar,
        tester,
        tokenizer,
        spectra,
    )
    # Commit everything only after the embedding succeeded, so spectra,
    # SMILES and embedding in the session always belong to the same file.
    if kind == "query":
        st.session_state.data_len = len(spectra)
    st.session_state[f"{kind}_spectra"] = spectra
    st.session_state[f"{kind}_smiles"] = get_smiles(spectra)
    st.session_state[f"{kind}_embedding"] = vectors
    # A previous match refers to the old spectra by position; showing it
    # against the new data would give wrong rows or an IndexError.
    st.session_state.match_indices = None
    status_box.success("Embedding Success ✅")


def _render_match_results():
    """Render the paginated table of the current match result."""
    st.subheader("Match Result")
    current_page = st.session_state.current_page
    last_page = st.session_state.last_page
    ref_smiles = st.session_state.ref_smiles
    query_spectra = st.session_state.query_spectra
    ref_spectra = st.session_state.ref_spectra
    page_size = st.session_state.page_size
    indices = st.session_state.match_indices

    start = (current_page - 1) * page_size
    # The last page may be shorter than page_size.
    end = min(start + page_size, len(indices))

    col1, col2, _ = st.columns([1, 1, 5])
    col1.selectbox(
        "page size",
        CANDIDATE_PAGE,
        key="page_size_selector",
        disabled=False,
        label_visibility="collapsed",
        index=CANDIDATE_PAGE.index(page_size),
        on_change=set_page_size,
    )
    col2.download_button(
        label="download result",
        data=st.session_state.result,
        file_name="data.csv",
        mime="text/csv"
    )

    pre_btn, current, next_btn, page_selector, _ = st.columns([1, 1, 1, 1, 2])
    pre_btn.button("previous page", key="pre_btn", on_click=previous_page)
    current.subheader(f"current page: {current_page}")
    next_btn.button("next page", key="next_btn", on_click=next_page)
    page_selector.selectbox(
        label="target page",
        key="page_selector",
        options=range(1, last_page + 1),
        disabled=False,
        index=current_page - 1,
        label_visibility="collapsed",
        on_change=select_page,
    )

    # One layout for the header row and the data rows keeps the column
    # titles aligned with the content below them.
    row_layout = [2, 4, 6, 4]
    col1, col2, col3, col4 = st.columns(row_layout)
    col1.subheader("Index")
    col2.subheader("Smiles")
    col3.subheader("MS/MS Spectra Pair")
    col4.subheader("Molecular Structure")
    for i in range(start, end):
        query_index = i
        ref_index = indices[i]
        id_label, smiles_label, pair_viewer, mol_viewer = st.columns(row_layout)
        id_label.subheader(i + 1)
        smiles_label.text(ref_smiles[ref_index])
        pair_fig = plot_pair(query_spectra[query_index], ref_spectra[ref_index])
        pair_viewer.pyplot(pair_fig, use_container_width=True)
        mol_image = draw_mol(ref_smiles[ref_index])
        mol_viewer.image(mol_image, use_container_width=True)


def main():
    """Render the SpecEmbedding page.

    The page has three tabs: upload and embed the query spectra, upload and
    embed the reference/library spectra, and match every query against the
    references and browse the paginated result. All intermediate data lives
    in the Streamlit session state.
    """
    st.set_page_config(layout="wide")
    st.title("SpecEmbedding")
    tab1, tab2, tab3 = st.tabs(["upload query file", "upload reference/library file", "library match"])
    with tab1:
        _render_upload_tab(
            "query",
            "Upload query spectra file(positive mode)",
            "upload the query spectra file",
        )
    with tab2:
        _render_upload_tab(
            "ref",
            "Upload reference/library spectra file(positive mode)",
            "upload the reference/library spectra file",
        )
    with tab3:
        st.header("Start to match")
        launch_btn = st.button("Launch", key="launch_btn")
        match_status_box = st.empty()
        if launch_btn:
            query_embedding = st.session_state.query_embedding
            ref_embedding = st.session_state.ref_embedding
            if query_embedding is None:
                match_status_box.error("No query embedding")
            elif ref_embedding is None:
                match_status_box.error("No reference embedding")
            else:
                progress_bar = st.progress(0, "Match...")
                match_indices = batch_match(progress_bar, query_embedding, ref_embedding)
                st.session_state.match_indices = match_indices
                # Paging is driven by the number of matched rows.
                st.session_state.data_len = len(match_indices)
                st.session_state.current_page = 1
                generate_result()
                cal_page_num(st.session_state.data_len, st.session_state.page_size)
                match_status_box.success("match success")
        if st.session_state.match_indices is not None:
            _render_match_results()
# Script entry point: seed the session-state defaults first, because main()
# reads keys such as ``match_indices`` and ``page_size`` while rendering.
if __name__ == "__main__":
    init_session_state()
    main()