Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import time | |
| import numpy as np | |
| import sys | |
| import os | |
| import json | |
| import torch | |
| sys.path.append(os.path.dirname(__file__)) | |
| # | |
| # import pdb;pdb.set_trace() | |
| from server_infer import retrieve_topk, PATH_CONFIG | |
| from server_train import ( | |
| load_dual_tower_model, | |
| recompute_candidate_vectors, | |
| continuous_train | |
| ) | |
| # 1.Core Model Interfaces | |
class DualTowerModel:
    """Streamlit-facing facade over the dual-tower retrieval and training helpers.

    Retrieval is delegated to ``server_infer.retrieve_topk``; training and
    vector rebuilding are delegated to ``server_train``. This class seeds
    default values in ``st.session_state``, reshapes raw retrieval hits into
    flat dicts for the UI, and reports failures through ``st.error`` instead
    of raising.
    """

    def __init__(self):
        """Seed session-state defaults used by the UI.

        Keys are only written when absent, so values set during earlier reruns
        (e.g. a user-selected ``checkpoint_path``) are preserved.
        ``model_loaded`` is always set to True.
        """
        if 'checkpoint_path' not in st.session_state:
            st.session_state['checkpoint_path'] = None
        if 'protein_model_path' not in st.session_state:
            st.session_state['protein_model_path'] = './SaProt_650M_AF2'
        if 'molecule_model_path' not in st.session_state:
            st.session_state['molecule_model_path'] = './ChemBERTa-zinc-base-v1'
        if 'vectors_ready' not in st.session_state:
            # Only hit the filesystem when the flag is not known yet.
            st.session_state['vectors_ready'] = self._check_vectors_exist()
        st.session_state['model_loaded'] = True

    def _check_vectors_exist(self):
        """Return True if both candidate vector files named in PATH_CONFIG exist."""
        protein_pt = PATH_CONFIG['protein']['pt']
        molecule_pt = PATH_CONFIG['molecule']['pt']
        return os.path.exists(protein_pt) and os.path.exists(molecule_pt)

    @staticmethod
    def _info(res):
        """Return the metadata mapping of one retrieval hit, or ``{}`` if it has none."""
        # NOTE(review): assumes each hit is a dict with 'id', 'score' and an
        # optional 'info' dict, as the key accesses in this class imply.
        return res.get('info') or {}

    def retrieve_molecules_for_protein(self, protein_query, top_k=5):
        """Task 1: protein -> molecules.

        Args:
            protein_query: protein sequence, UniProt ID or gene symbol; passed
                unchanged to ``retrieve_topk``.
            top_k: number of hits to request.

        Returns:
            A list of dicts with keys ``smiles``, ``score``, ``name``,
            ``drugbank_id`` and ``cas`` (in ``retrieve_topk`` order), or ``[]``
            if retrieval raised (the error is shown via ``st.error``).
        """
        try:
            results = retrieve_topk(
                input_type="protein",
                input_query=protein_query,
                topk=top_k,
                device='cpu',  # the web interface runs inference on CPU
            )
            formatted_results = []
            for res in results:
                info = self._info(res)
                formatted_results.append({
                    "smiles": res['id'],
                    "score": res['score'],
                    "name": info.get('compound__name', 'Unknown'),
                    "drugbank_id": info.get('compound__drugbank_id', 'N/A'),
                    "cas": info.get('compound__cas', 'N/A'),
                })
            return formatted_results
        except Exception as e:
            st.error(f"Retrieval failed: {str(e)}")
            return []

    def retrieve_proteins_for_molecule(self, molecule_query, top_k=5):
        """Task 2: molecule -> proteins.

        Args:
            molecule_query: SMILES string, compound name or ID; passed
                unchanged to ``retrieve_topk``.
            top_k: number of hits to request.

        Returns:
            A list of dicts with keys ``protein_id``, ``desc`` (gene symbol),
            ``score``, ``class`` and ``full_id``, or ``[]`` if retrieval raised
            (the error is shown via ``st.error``).
        """
        try:
            results = retrieve_topk(
                input_type="molecule",
                input_query=molecule_query,
                topk=top_k,
                device='cpu',
            )
            formatted_results = []
            for res in results:
                info = self._info(res)
                formatted_results.append({
                    # protein_id and full_id currently hold the same value; no
                    # truncation is applied here.
                    "protein_id": res['id'],
                    "desc": info.get('target__gene', 'Unknown'),
                    "score": res['score'],
                    "class": info.get('target__class', 'N/A'),
                    "full_id": res['id'],
                })
            return formatted_results
        except Exception as e:
            st.error(f"Retrieval failed: {str(e)}")
            return []

    def fine_tune(self, uploaded_file, epochs, learning_rate, batch_size):
        """Fine-tune the dual-tower model on an uploaded dataset.

        The upload is written to a named temporary file (``continuous_train``
        takes a path). Training runs on CUDA when available, otherwise on CPU.
        The temporary file is removed afterwards, whether training succeeds,
        fails or raises.

        Args:
            uploaded_file: Streamlit ``UploadedFile`` (needs ``getbuffer()``;
                its ``name``, if present, supplies the temp-file extension).
            epochs: forwarded to ``continuous_train`` as ``epochs``.
            learning_rate: forwarded as ``lr``.
            batch_size: forwarded as ``batch_size``.

        Returns:
            Whatever ``continuous_train`` returns (truthy on success), or
            ``False`` if an exception was raised (shown via ``st.error``).
        """
        import tempfile

        temp_dataset_path = None
        try:
            # Keep the upload's own extension (.csv / .json) so a loader that
            # dispatches on the suffix sees the right format.
            upload_name = getattr(uploaded_file, 'name', None) or ''
            suffix = os.path.splitext(upload_name)[1].lower() or '.csv'
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                tmp.write(uploaded_file.getbuffer())
                temp_dataset_path = tmp.name

            progress_bar = st.progress(0)
            status_text = st.empty()
            status_text.text("Initializing training environment...")

            model_save_dir = 'Dual_Tower_Model/customized_checkpoints'
            os.makedirs(model_save_dir, exist_ok=True)
            # None means "start from the pre-trained backbones".
            checkpoint_path = st.session_state.get('checkpoint_path', None)

            status_text.text("Starting actual training process...")
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            if device == 'cpu':
                st.warning("β οΈ Detected CPU environment. Training might be very slow.")
            success = continuous_train(
                dataset_path=temp_dataset_path,
                model_save_dir=model_save_dir,
                best_model_path=checkpoint_path,
                epochs=epochs,
                lr=learning_rate,
                batch_size=batch_size,
                device=device,
            )
            if success:
                progress_bar.progress(100)
                status_text.success("β Training completed successfully!")
            return success
        except Exception as e:
            st.error(f"Training failed: {str(e)}")
            return False
        finally:
            # Best-effort cleanup (the file was created with delete=False). A
            # failure to delete must not mask the training result or error.
            if temp_dataset_path is not None:
                try:
                    os.remove(temp_dataset_path)
                except OSError:
                    pass

    def update_vectors(self, protein_json, molecule_json, output_dir, checkpoint_path=None,
                       batch_size=32, device='cpu'):
        """Rebuild the candidate vector library via ``recompute_candidate_vectors``.

        Args:
            protein_json: protein candidate file, passed as ``protein_json_path``.
            molecule_json: molecule candidate file, passed as ``molecule_json_path``.
            output_dir: passed through as ``output_dir``.
            checkpoint_path: optional trained checkpoint, passed as ``pt_path``
                (None is passed through unchanged).
            batch_size: encoding batch size (defaults to the previous fixed 32).
            device: torch device string (defaults to the previous fixed 'cpu').

        Returns:
            True on success (and sets ``st.session_state['vectors_ready']``),
            False if recomputation raised (shown via ``st.error``).
        """
        try:
            recompute_candidate_vectors(
                protein_json_path=protein_json,
                molecule_json_path=molecule_json,
                output_dir=output_dir,
                pt_path=checkpoint_path,
                batch_size=batch_size,
                device=device,
            )
            st.session_state['vectors_ready'] = True
            return True
        except Exception as e:
            st.error(f"Vector update failed: {str(e)}")
            return False
# 2. Page Layout (UI Layout)
# st.set_page_config must be the first Streamlit command of every run, so it
# is issued before the model-initialisation spinner below.
st.set_page_config(page_title="Bio-Retrieval Service", page_icon="🧬", layout="wide")

# Create the model facade once per browser session; later reruns reuse the
# instance cached in st.session_state.
if 'model' not in st.session_state:
    with st.spinner('Initializing model...'):
        st.session_state['model'] = DualTowerModel()
model = st.session_state['model']

# Sidebar Configuration
with st.sidebar:
    st.title("🧬 Dual-Tower Service")
    st.markdown("---")
    # The selected label drives the if/elif page dispatch further down.
    mode = st.radio(
        "Select Mode:",
        ["π Inference (Retrieval)", "βοΈ Training (Fine-tune)", "π Configuration", "π₯ Team Info"],
    )
    st.markdown("---")
    st.info("**Project Info**\n\nBidirectional Protein-Molecule Retrieval.\nBased on SaProt & Molecular Encoder.")

# Header Section
st.title("🧬 Bidirectional Protein-Molecule Retrieval")
st.markdown("""
This service implements a **Dual-Tower Architecture** to bridge the gap between Proteins and molecules.
""")
# ==========================================
# 3. Inference Mode
# ==========================================
if mode == "π Inference (Retrieval)":
    st.subheader("Interactive Retrieval System")

    # Retrieval searches precomputed candidate vectors; without them there is
    # nothing to search, so stop rendering this page.
    if not st.session_state.get('vectors_ready', False):
        st.error("β Vectors not ready. Please generate them in 'Configuration' first.")
        st.stop()

    def _show_similarity(score):
        """Render a similarity score as a progress bar.

        st.progress only accepts values in [0, 1], so the score is clamped;
        the exact value is shown in the expander title.
        """
        st.markdown("**Similarity Score:**")
        st.progress(max(0.0, min(1.0, float(score))))

    # Use Tabs for tasks
    tab1, tab2 = st.tabs(["Protein β Molecule", "Molecule β Protein"])

    # --- Task 1: Protein to Molecule ---
    with tab1:
        st.markdown("### Recommend molecules for a specific protein target")

        # Quick-fill buttons store an example in session state; the text area
        # below uses it as its initial value on the resulting rerun.
        st.markdown("**Quick Fill Examples:**")
        q_col1, q_col2, _ = st.columns([1, 1, 2])
        if q_col1.button("Example: ROS1", key="btn_ros1"):
            st.session_state['p_input_val'] = "ROS1"
        if q_col2.button("Example: P08922", key="btn_p08922"):
            st.session_state['p_input_val'] = "P08922"

        col1, col2 = st.columns([2, 1])
        with col1:
            protein_input = st.text_area(
                "Enter Protein Sequence (FASTA format or raw sequence):",
                height=150,
                value=st.session_state.get('p_input_val', ""),
                placeholder="e.g., ROS1, P08922, or full sequence...",
                help="Supports Gene Symbols, UniProt IDs, or raw sequences",
            )
            top_k_p2m = st.slider("Top K Results", 1, 20, 5, key="slider_p2m")
        with col2:
            st.markdown("#### Configuration")
            st.write("Encoder: **SaProt (650M)**")
            st.write("Search Space: **ChEMBL / ZINC15**")
            search_btn_p2m = st.button("π Retrieve Molecules", type="primary")

        # Pasted IDs/sequences often carry stray spaces or a trailing newline.
        protein_query = protein_input.strip()
        if search_btn_p2m and not protein_query:
            st.warning("Please enter a protein query first.")
        elif search_btn_p2m:
            with st.spinner("Encoding protein and searching vector database..."):
                results = model.retrieve_molecules_for_protein(protein_query, top_k_p2m)
            if results:
                st.success(f"Found {len(results)} potential molecules.")
                for idx, res in enumerate(results):
                    # Only the top hit starts expanded.
                    with st.expander(
                        f"π Rank {idx+1}: {res['name']} | Similarity: {res['score']:.4f}",
                        expanded=(idx == 0),
                    ):
                        col_a, col_b = st.columns([3, 2])
                        with col_a:
                            st.markdown("**SMILES Structure:**")
                            st.code(res['smiles'], language="text")
                        with col_b:
                            st.markdown("**Compound Info:**")
                            st.markdown(f"- **Name:** {res['name']}")
                            st.markdown(f"- **DrugBank ID:** {res.get('drugbank_id', 'N/A')}")
                            st.markdown(f"- **CAS No.:** {res.get('cas', 'N/A')}")
                            _show_similarity(res['score'])
            else:
                st.warning("No matches found. Please check your input.")

    # --- Task 2: Molecule to Protein ---
    with tab2:
        st.markdown("### Retrieve potential protein targets for a specific molecule")

        # Quick Fill for Molecule (same session-state mechanism as Task 1).
        st.markdown("**Quick Fill Examples:**")
        m_q_col1, m_q_col2 = st.columns([1, 3])
        if m_q_col1.button("Example: Lorlatinib", key="btn_lor"):
            st.session_state['m_input_val'] = "Lorlatinib"
        if m_q_col2.button("Example: SMILES", key="btn_smiles"):
            st.session_state['m_input_val'] = "C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2"

        col1, col2 = st.columns([2, 1])
        with col1:
            mol_input = st.text_input(
                "Enter Molecule SMILES:",
                value=st.session_state.get('m_input_val', ""),
                placeholder="e.g., Lorlatinib, DB12130, or SMILES string",
                help="Supports Compound Names, DrugBank IDs, or SMILES strings",
            )
            top_k_m2p = st.slider("Top K Results", 1, 20, 5, key="slider_m2p")
        with col2:
            st.markdown("#### Configuration")
            st.write("Encoder: **ChemBERTa**")
            st.write("Search Space: **UniProt / PDB**")
            search_btn_m2p = st.button("π Retrieve Proteins", type="primary")

        # Surrounding whitespace is never part of a SMILES string or an ID.
        mol_query = mol_input.strip()
        if search_btn_m2p and not mol_query:
            st.warning("Please enter a molecule query first.")
        elif search_btn_m2p:
            with st.spinner("Encoding molecule and searching vector database..."):
                results = model.retrieve_proteins_for_molecule(mol_query, top_k_m2p)
            if results:
                st.success(f"Found {len(results)} potential protein targets.")
                st.markdown("---")
                st.markdown("#### π Retrieval Results")
                for idx, res in enumerate(results):
                    with st.expander(
                        f"π Rank {idx+1}: {res['desc']} | Similarity: {res['score']:.4f}",
                        expanded=(idx == 0),
                    ):
                        st.markdown(f"**Full ID:** `{res['full_id']}`")
                        st.markdown(f"**Gene Symbol:** {res['desc']}")
                        st.markdown(f"**Class:** {res.get('class', 'N/A')}")
                        _show_similarity(res['score'])
            else:
                st.warning("No matches found. Please check your input.")
| # 4. Training Mode | |
| elif mode == "βοΈ Training (Fine-tune)": | |
| st.subheader("Fine-tune the Dual-Tower Model") | |
| st.warning("This module allows you to input custom data to fine-tune the encoders.") | |
| # File Upload | |
| st.markdown("### π Upload Training Data") | |
| uploaded_file = st.file_uploader( | |
| "Upload dataset (CSV/JSON)", | |
| type=["csv", "json"], | |
| help="Dataset should include: compound__smiles, target__foldseek_seq, outcome_potency_pxc50" | |
| ) | |
| if uploaded_file: | |
| st.success("β File uploaded successfully") | |
| st.markdown("**Data Preview:**") | |
| try: | |
| df = pd.read_csv(uploaded_file) | |
| st.dataframe(df.head(10), use_container_width=True) | |
| st.info(f"π Dataset Size: {len(df)} records, {len(df.columns)} fields") | |
| except Exception as e: | |
| st.error(f"β οΈ Parsing failed: {str(e)}. Ensure format is correct.") | |
| # Hyperparameters | |
| st.markdown("---") | |
| st.markdown("### βοΈ Hyperparameters") | |
| col1, col2, col3 = st.columns(3) | |
| with col1: | |
| epochs = st.number_input("Epochs", min_value=1, max_value=100, value=5) | |
| with col2: | |
| lr = st.number_input("Learning Rate", min_value=0.0001, max_value=0.1, value=0.001, format="%.4f") | |
| with col3: | |
| batch_size = st.selectbox("Batch Size", [16, 32, 64, 128], index=1) | |
| # Checkpoint Config | |
| st.markdown("---") | |
| st.markdown("### π¦ Model Checkpoint") | |
| checkpoint_option = st.radio( | |
| "Start training from:", | |
| ["Pre-trained Backbones", "Existing Checkpoint"], | |
| help="Use backbones for first run, or checkpoint for incremental training" | |
| ) | |
| if checkpoint_option == "Existing Checkpoint": | |
| checkpoint_path = st.text_input( | |
| "Checkpoint Path:", | |
| value="Dual_Tower_Model/output_checkpoints_ddp/model_epoch_7_acc_0.3259.pt", | |
| help="Full path to .pt file" | |
| ) | |
| st.session_state['checkpoint_path'] = checkpoint_path | |
| else: | |
| st.session_state['checkpoint_path'] = None | |
| st.info("Using default SaProt & ChemBERTa backbones.") | |
| st.markdown("---") | |
| start_train = st.button("π₯ Start Fine-tuning", type="primary", use_container_width=True) | |
| if start_train: | |
| if uploaded_file is None: | |
| st.error("β Please upload a dataset first.") | |
| else: | |
| st.markdown("---") | |
| st.markdown("### π Training Progress") | |
| success = model.fine_tune(uploaded_file, epochs, lr, batch_size) | |
| if success: | |
| st.balloons() | |
| st.success("π Training pipeline initialized!") | |
| st.info(""" | |
| **Notes:** | |
| - Real training requires a GPU-enabled environment. | |
| - Remember to update vector libraries after training. | |
| - Manage models in 'Configuration' section. | |
| """) | |
| # 5. Configuration Mode | |
| elif mode == "π Configuration": | |
| st.subheader("System Configuration") | |
| # Vector Library Management | |
| st.markdown("### ποΈ Vector Management") | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| st.markdown("#### Protein Vector Library") | |
| protein_json = PATH_CONFIG['protein']['json'] | |
| protein_pt = PATH_CONFIG['protein']['pt'] | |
| if os.path.exists(protein_pt): | |
| st.success("β Ready") | |
| data = torch.load(protein_pt, map_location='cpu') | |
| st.metric("Total Vectors", len(data['ids'])) | |
| else: | |
| st.warning("β οΈ Missing") | |
| with col2: | |
| st.markdown("#### Molecule Vector Library") | |
| molecule_json = PATH_CONFIG['molecule']['json'] | |
| molecule_pt = PATH_CONFIG['molecule']['pt'] | |
| if os.path.exists(molecule_pt): | |
| st.success("β Ready") | |
| data = torch.load(molecule_pt, map_location='cpu') | |
| st.metric("Total Vectors", len(data['ids'])) | |
| else: | |
| st.warning("β οΈ Missing") | |
| st.markdown("---") | |
| # Recompute Vectors | |
| st.markdown("### π Recompute Vector Library") | |
| st.warning("β οΈ This will overwrite existing vector files.") | |
| checkpoint_for_vector = st.text_input( | |
| "Model Checkpoint (optional):", | |
| placeholder="Leave empty for default backbones", | |
| help="Specify path to a trained checkpoint if available" | |
| ) | |
| if st.button("π Start Vector Recomputation", type="primary"): | |
| with st.spinner("Recomputing vectors, this may take a few minutes..."): | |
| try: | |
| success = model.update_vectors( | |
| protein_json=protein_json, | |
| molecule_json=molecule_json, | |
| output_dir='drug_target_activity/candidates', | |
| checkpoint_path=checkpoint_for_vector if checkpoint_for_vector else None | |
| ) | |
| if success: | |
| st.success("β Vector library updated!") | |
| st.rerun() | |
| except Exception as e: | |
| st.error(f"β Recomputation failed: {str(e)}") | |
| st.markdown("---") | |
| # Model Info | |
| st.markdown("### π€ Model Pipeline Info") | |
| info_col1, info_col2 = st.columns(2) | |
| with info_col1: | |
| st.markdown("**Protein Encoder**") | |
| st.markdown("- Model: SaProt 650M") | |
| st.markdown(f"- Path: `{st.session_state.get('protein_model_path', 'N/A')}`") | |
| with info_col2: | |
| st.markdown("**Molecule Encoder**") | |
| st.markdown("- Model: ChemBERTa ZINC Base") | |
| st.markdown(f"- Path: `{st.session_state.get('molecule_model_path', 'N/A')}`") | |
| # 6. Team Info Mode | |
| elif mode == "π₯ Team Info": | |
| st.subheader("π Project Information") | |
| st.markdown(""" | |
| **Course:** Recommender System | |
| **Final Project:** Bidirectional Protein-Molecule Retrieval (Choice Two) | |
| **Date:** January 2026 | |
| --- | |
| #### π¨βπ¬ Members: | |
| - **Zhiyun Jiang** | |
| - **Xinyang Tong** | |
| --- | |
| #### π οΈ Technical Contributions: | |
| - **Architecture Design:** Dual-Tower Encoder Integration | |
| - **Model Training:** SaProt & ChemBERTa Fine-tuning | |
| - **Frontend:** Streamlit Academic UI Development | |
| """) | |
# Footer: shown on every page, below the mode-specific content.
_FOOTER_CAPTION = "Final Project | Dual-Tower Architecture Implementation | Zhiyun Jiang, Xinyang Tong"
st.markdown("---")  # horizontal rule separating the page body from the credits
st.caption(_FOOTER_CAPTION)