import streamlit as st
import pandas as pd
import time
import numpy as np
import sys
import os
import json
import torch

sys.path.append(os.path.dirname(__file__))

from server_infer import retrieve_topk, PATH_CONFIG
from server_train import (
    load_dual_tower_model,
    recompute_candidate_vectors,
    continuous_train
)

# FIX: st.set_page_config must be the FIRST Streamlit command executed.
# The original called st.spinner() (model init) before it, which raises
# StreamlitAPIException on page load.
st.set_page_config(page_title="Bio-Retrieval Service", page_icon="🧬", layout="wide")


# 1. Core Model Interfaces
class DualTowerModel:
    """Thin facade over server_infer / server_train, holding UI state in st.session_state."""

    def __init__(self):
        """
        Initialize dual-tower model configuration.

        Seeds default paths and flags into st.session_state on first run only,
        so user edits survive Streamlit reruns.
        """
        # check config
        if 'checkpoint_path' not in st.session_state:
            st.session_state['checkpoint_path'] = None
        if 'protein_model_path' not in st.session_state:
            st.session_state['protein_model_path'] = './SaProt_650M_AF2'
        if 'molecule_model_path' not in st.session_state:
            st.session_state['molecule_model_path'] = './ChemBERTa-zinc-base-v1'
        if 'vectors_ready' not in st.session_state:
            st.session_state['vectors_ready'] = self._check_vectors_exist()
        st.session_state['model_loaded'] = True

    def _check_vectors_exist(self):
        """Return True iff both candidate vector files (.pt) already exist on disk."""
        protein_pt = PATH_CONFIG['protein']['pt']
        molecule_pt = PATH_CONFIG['molecule']['pt']
        return os.path.exists(protein_pt) and os.path.exists(molecule_pt)

    def retrieve_molecules_for_protein(self, protein_query, top_k=5):
        """
        Task 1: Protein -> Molecule.

        Args:
            protein_query: protein sequence, UniProt ID, or gene symbol.
            top_k: number of candidates to return.

        Returns:
            List of dicts with smiles/score/name/drugbank_id/cas, or [] on failure
            (the error is surfaced to the UI via st.error).
        """
        try:
            # Call retrieve_topk from server_infer
            results = retrieve_topk(
                input_type="protein",
                input_query=protein_query,
                topk=top_k,
                device='cpu'  # Using CPU for web interface
            )
            # Format results for frontend
            formatted_results = []
            for res in results:
                formatted_results.append({
                    "smiles": res['id'],
                    "score": res['score'],
                    "name": res['info'].get('compound__name', 'Unknown'),
                    "drugbank_id": res['info'].get('compound__drugbank_id', 'N/A'),
                    "cas": res['info'].get('compound__cas', 'N/A')
                })
            return formatted_results
        except Exception as e:
            st.error(f"Retrieval failed: {str(e)}")
            return []

    def retrieve_proteins_for_molecule(self, molecule_query, top_k=5):
        """
        Task 2: Molecule -> Protein.

        Args:
            molecule_query: molecule SMILES, compound name, or DrugBank ID.
            top_k: number of candidates to return.

        Returns:
            List of dicts with protein_id/desc/score/class/full_id, or [] on failure.
        """
        try:
            results = retrieve_topk(
                input_type="molecule",
                input_query=molecule_query,
                topk=top_k,
                device='cpu'
            )
            # Format results
            formatted_results = []
            for res in results:
                formatted_results.append({
                    "protein_id": res['id'],  # Truncated display
                    "desc": res['info'].get('target__gene', 'Unknown'),
                    "score": res['score'],
                    "class": res['info'].get('target__class', 'N/A'),
                    "full_id": res['id']
                })
            return formatted_results
        except Exception as e:
            st.error(f"Retrieval failed: {str(e)}")
            return []

    def fine_tune(self, uploaded_file, epochs, learning_rate, batch_size):
        """
        Training interface: fine-tune the model with an uploaded dataset.

        The uploaded file is spooled to a temp CSV on disk (continuous_train
        expects a path), then always removed, even when training raises.

        Returns:
            The truthy/falsy result of continuous_train, or False on exception.
        """
        import tempfile
        temp_dataset_path = None
        try:
            # Spool the upload to a real file; getbuffer() is position-independent,
            # so a preview read elsewhere does not affect it.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
                tmp.write(uploaded_file.getbuffer())
                temp_dataset_path = tmp.name

            progress_bar = st.progress(0)
            status_text = st.empty()
            status_text.text("Initializing training environment...")

            # FIX: exist_ok=True already handles the pre-existing directory;
            # the original wrapped this in a redundant os.path.exists() check.
            model_save_dir = 'Dual_Tower_Model/customized_checkpoints'
            os.makedirs(model_save_dir, exist_ok=True)

            checkpoint_path = st.session_state.get('checkpoint_path', None)
            status_text.text("Starting actual training process...")

            # call the actual training function
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            if device == 'cpu':
                st.warning("⚠️ Detected CPU environment. Training might be very slow.")

            success = continuous_train(
                dataset_path=temp_dataset_path,
                model_save_dir=model_save_dir,
                best_model_path=checkpoint_path,
                epochs=epochs,
                lr=learning_rate,
                batch_size=batch_size,
                device=device
            )
            if success:
                status_text.success("✅ Training completed successfully!")
            return success
        except Exception as e:
            st.error(f"Training failed: {str(e)}")
            return False
        finally:
            # FIX: always clean up the temp dataset — the original leaked it
            # whenever continuous_train raised.
            if temp_dataset_path and os.path.exists(temp_dataset_path):
                os.remove(temp_dataset_path)

    def update_vectors(self, protein_json, molecule_json, output_dir, checkpoint_path=None):
        """
        Recompute and persist the candidate vector library.

        Args:
            protein_json / molecule_json: candidate metadata files.
            output_dir: destination directory for the .pt vector files.
            checkpoint_path: optional fine-tuned checkpoint; None uses default backbones.

        Returns:
            True on success (and marks vectors_ready), False on failure.
        """
        try:
            recompute_candidate_vectors(
                protein_json_path=protein_json,
                molecule_json_path=molecule_json,
                output_dir=output_dir,
                pt_path=checkpoint_path,
                batch_size=32,
                device='cpu'
            )
            st.session_state['vectors_ready'] = True
            return True
        except Exception as e:
            st.error(f"Vector update failed: {str(e)}")
            return False


# Initialize model (Singleton pattern) — survives Streamlit reruns via session_state.
if 'model' not in st.session_state:
    with st.spinner('Initializing model...'):
        st.session_state['model'] = DualTowerModel()
model = st.session_state['model']

# 2. UI Layout

# Sidebar Configuration
with st.sidebar:
    st.title("🧬 Dual-Tower Service")
    st.markdown("---")
    mode = st.radio(
        "Select Mode:",
        ["🔍 Inference (Retrieval)", "⚙️ Training (Fine-tune)", "📊 Configuration", "👥 Team Info"]
    )
    st.markdown("---")
    st.info("**Project Info**\n\nBidirectional Protein-Molecule Retrieval.\nBased on SaProt & Molecular Encoder.")

# Header Section
st.title("🧬 Bidirectional Protein-Molecule Retrieval")
st.markdown("""
This service implements a **Dual-Tower Architecture** to bridge the gap between Proteins and molecules.
""")

# ==========================================
# 3. Inference Mode
# ==========================================
if mode == "🔍 Inference (Retrieval)":
    st.subheader("Interactive Retrieval System")

    # Check if vectors are ready
    if not st.session_state.get('vectors_ready', False):
        st.error("❌ Vectors not ready. Please generate them in 'Configuration' first.")
        st.stop()

    # Use Tabs for tasks
    tab1, tab2 = st.tabs(["Protein → Molecule", "Molecule → Protein"])

    # --- Task 1: Protein to Molecule ---
    with tab1:
        st.markdown("### Recommend molecules for a specific protein target")

        # Quick Fill for Protein (buttons sit above the widget, so the rerun
        # triggered by a click updates the text_area's default value)
        st.markdown("**Quick Fill Examples:**")
        q_col1, q_col2, q_col3 = st.columns([1, 1, 2])
        if q_col1.button("Example: ROS1", key="btn_ros1"):
            st.session_state['p_input_val'] = "ROS1"
        if q_col2.button("Example: P08922", key="btn_p08922"):
            st.session_state['p_input_val'] = "P08922"

        col1, col2 = st.columns([2, 1])
        with col1:
            protein_input = st.text_area(
                "Enter Protein Sequence (FASTA format or raw sequence):",
                height=150,
                value=st.session_state.get('p_input_val', ""),
                placeholder="e.g., ROS1, P08922, or full sequence...",
                help="Supports Gene Symbols, UniProt IDs, or raw sequences"
            )
            top_k_p2m = st.slider("Top K Results", 1, 20, 5, key="slider_p2m")
        with col2:
            st.markdown("#### Configuration")
            st.write(f"Encoder: **SaProt (650M)**")
            st.write("Search Space: **ChEMBL / ZINC15**")

        search_btn_p2m = st.button("🚀 Retrieve Molecules", type="primary")

        if search_btn_p2m and protein_input:
            with st.spinner("Encoding protein and searching vector database..."):
                results = model.retrieve_molecules_for_protein(protein_input, top_k_p2m)

            if results:
                st.success(f"Found {len(results)} potential molecules.")
                for idx, res in enumerate(results):
                    with st.expander(
                        f"🏆 Rank {idx+1}: {res['name']} | Similarity: {res['score']:.4f}",
                        expanded=(idx < 1)
                    ):
                        col_a, col_b = st.columns([3, 2])
                        with col_a:
                            st.markdown("**SMILES Structure:**")
                            st.code(res['smiles'], language="text")
                        with col_b:
                            st.markdown("**Compound Info:**")
                            st.markdown(f"- **Name:** {res['name']}")
                            st.markdown(f"- **DrugBank ID:** {res.get('drugbank_id', 'N/A')}")
                            st.markdown(f"- **CAS No.:** {res.get('cas', 'N/A')}")
                        st.markdown(f"**Similarity Score:**")
                        # Ensure score is within [0, 1] for st.progress
                        display_score = max(0.0, min(1.0, float(res['score'])))
                        st.progress(display_score)
            else:
                st.warning("No matches found. Please check your input.")

    # --- Task 2: Molecule to Protein ---
    with tab2:
        st.markdown("### Retrieve potential protein targets for a specific molecule")

        # Quick Fill for Molecule
        st.markdown("**Quick Fill Examples:**")
        m_q_col1, m_q_col2 = st.columns([1, 3])
        if m_q_col1.button("Example: Lorlatinib", key="btn_lor"):
            st.session_state['m_input_val'] = "Lorlatinib"
        if m_q_col2.button("Example: SMILES", key="btn_smiles"):
            st.session_state['m_input_val'] = "C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2"

        col1, col2 = st.columns([2, 1])
        with col1:
            mol_input = st.text_input(
                "Enter Molecule SMILES:",
                value=st.session_state.get('m_input_val', ""),
                placeholder="e.g., Lorlatinib, DB12130, or SMILES string",
                help="Supports Compound Names, DrugBank IDs, or SMILES strings"
            )
            top_k_m2p = st.slider("Top K Results", 1, 20, 5, key="slider_m2p")
        with col2:
            st.markdown("#### Configuration")
            st.write(f"Encoder: **ChemBERTa**")
            st.write("Search Space: **UniProt / PDB**")

        search_btn_m2p = st.button("🚀 Retrieve Proteins", type="primary")

        if search_btn_m2p and mol_input:
            with st.spinner("Encoding molecule and searching vector database..."):
                results = model.retrieve_proteins_for_molecule(mol_input, top_k_m2p)

            if results:
                st.success(f"Found {len(results)} potential protein targets.")
                st.markdown("---")
                st.markdown("#### 🏆 Retrieval Results")
                for idx, res in enumerate(results):
                    with st.expander(
                        f"🏆 Rank {idx+1}: {res['desc']} | Similarity: {res['score']:.4f}",
                        expanded=(idx < 1)
                    ):
                        st.markdown(f"**Full ID:** `{res['full_id']}`")
                        st.markdown(f"**Gene Symbol:** {res['desc']}")
                        st.markdown(f"**Class:** {res.get('class', 'N/A')}")
                        st.markdown(f"**Similarity Score:**")
                        # Ensure score is within [0, 1] for st.progress
                        display_score = max(0.0, min(1.0, float(res['score'])))
                        st.progress(display_score)
            else:
                st.warning("No matches found. Please check your input.")

# 4. Training Mode
elif mode == "⚙️ Training (Fine-tune)":
    st.subheader("Fine-tune the Dual-Tower Model")
    st.warning("This module allows you to input custom data to fine-tune the encoders.")

    # File Upload
    st.markdown("### 📁 Upload Training Data")
    uploaded_file = st.file_uploader(
        "Upload dataset (CSV/JSON)",
        type=["csv", "json"],
        help="Dataset should include: compound__smiles, target__foldseek_seq, outcome_potency_pxc50"
    )

    if uploaded_file:
        st.success("✅ File uploaded successfully")
        st.markdown("**Data Preview:**")
        try:
            # NOTE(review): preview always parses as CSV even for .json uploads —
            # JSON files will land in the error branch. Confirm whether JSON
            # preview support is intended.
            df = pd.read_csv(uploaded_file)
            st.dataframe(df.head(10), use_container_width=True)
            st.info(f"📊 Dataset Size: {len(df)} records, {len(df.columns)} fields")
        except Exception as e:
            st.error(f"⚠️ Parsing failed: {str(e)}. Ensure format is correct.")

    # Hyperparameters
    st.markdown("---")
    st.markdown("### ⚙️ Hyperparameters")
    col1, col2, col3 = st.columns(3)
    with col1:
        epochs = st.number_input("Epochs", min_value=1, max_value=100, value=5)
    with col2:
        lr = st.number_input("Learning Rate", min_value=0.0001, max_value=0.1, value=0.001, format="%.4f")
    with col3:
        batch_size = st.selectbox("Batch Size", [16, 32, 64, 128], index=1)

    # Checkpoint Config
    st.markdown("---")
    st.markdown("### 📦 Model Checkpoint")
    checkpoint_option = st.radio(
        "Start training from:",
        ["Pre-trained Backbones", "Existing Checkpoint"],
        help="Use backbones for first run, or checkpoint for incremental training"
    )
    if checkpoint_option == "Existing Checkpoint":
        checkpoint_path = st.text_input(
            "Checkpoint Path:",
            value="Dual_Tower_Model/output_checkpoints_ddp/model_epoch_7_acc_0.3259.pt",
            help="Full path to .pt file"
        )
        st.session_state['checkpoint_path'] = checkpoint_path
    else:
        st.session_state['checkpoint_path'] = None
        st.info("Using default SaProt & ChemBERTa backbones.")

    st.markdown("---")
    start_train = st.button("🔥 Start Fine-tuning", type="primary", use_container_width=True)

    if start_train:
        if uploaded_file is None:
            st.error("❌ Please upload a dataset first.")
        else:
            st.markdown("---")
            st.markdown("### 📈 Training Progress")
            success = model.fine_tune(uploaded_file, epochs, lr, batch_size)
            if success:
                st.balloons()
                st.success("🎉 Training pipeline initialized!")
                st.info("""
                **Notes:**
                - Real training requires a GPU-enabled environment.
                - Remember to update vector libraries after training.
                - Manage models in 'Configuration' section.
                """)

# 5. Configuration Mode
elif mode == "📊 Configuration":
    st.subheader("System Configuration")

    # Vector Library Management
    st.markdown("### 🗄️ Vector Management")
    col1, col2 = st.columns(2)
    with col1:
        st.markdown("#### Protein Vector Library")
        protein_json = PATH_CONFIG['protein']['json']
        protein_pt = PATH_CONFIG['protein']['pt']
        if os.path.exists(protein_pt):
            st.success("✅ Ready")
            data = torch.load(protein_pt, map_location='cpu')
            st.metric("Total Vectors", len(data['ids']))
        else:
            st.warning("⚠️ Missing")
    with col2:
        st.markdown("#### Molecule Vector Library")
        molecule_json = PATH_CONFIG['molecule']['json']
        molecule_pt = PATH_CONFIG['molecule']['pt']
        if os.path.exists(molecule_pt):
            st.success("✅ Ready")
            data = torch.load(molecule_pt, map_location='cpu')
            st.metric("Total Vectors", len(data['ids']))
        else:
            st.warning("⚠️ Missing")

    st.markdown("---")

    # Recompute Vectors
    st.markdown("### 🔄 Recompute Vector Library")
    st.warning("⚠️ This will overwrite existing vector files.")
    checkpoint_for_vector = st.text_input(
        "Model Checkpoint (optional):",
        placeholder="Leave empty for default backbones",
        help="Specify path to a trained checkpoint if available"
    )

    if st.button("🚀 Start Vector Recomputation", type="primary"):
        with st.spinner("Recomputing vectors, this may take a few minutes..."):
            try:
                success = model.update_vectors(
                    protein_json=protein_json,
                    molecule_json=molecule_json,
                    output_dir='drug_target_activity/candidates',
                    checkpoint_path=checkpoint_for_vector if checkpoint_for_vector else None
                )
                if success:
                    st.success("✅ Vector library updated!")
                    st.rerun()
            except Exception as e:
                st.error(f"❌ Recomputation failed: {str(e)}")

    st.markdown("---")

    # Model Info
    st.markdown("### 🤖 Model Pipeline Info")
    info_col1, info_col2 = st.columns(2)
    with info_col1:
        st.markdown("**Protein Encoder**")
        st.markdown("- Model: SaProt 650M")
        st.markdown(f"- Path: `{st.session_state.get('protein_model_path', 'N/A')}`")
    with info_col2:
        st.markdown("**Molecule Encoder**")
        st.markdown("- Model: ChemBERTa ZINC Base")
        st.markdown(f"- Path: `{st.session_state.get('molecule_model_path', 'N/A')}`")

# 6. Team Info Mode
elif mode == "👥 Team Info":
    st.subheader("🎓 Project Information")
    st.markdown("""
    **Course:** Recommender System

    **Final Project:** Bidirectional Protein-Molecule Retrieval (Choice Two)

    **Date:** January 2026

    ---
    #### 👨‍🔬 Members:
    - **Zhiyun Jiang**
    - **Xinyang Tong**

    ---
    #### 🛠️ Technical Contributions:
    - **Architecture Design:** Dual-Tower Encoder Integration
    - **Model Training:** SaProt & ChemBERTa Fine-tuning
    - **Frontend:** Streamlit Academic UI Development
    """)

# Footer
st.markdown("---")
st.caption("Final Project | Dual-Tower Architecture Implementation | Zhiyun Jiang, Xinyang Tong")