# revise streamlit temp files (commit aab3f3d, author: tong)
import streamlit as st
import pandas as pd
import time
import numpy as np
import sys
import os
import json
import torch
sys.path.append(os.path.dirname(__file__))
#
# import pdb;pdb.set_trace()
from server_infer import retrieve_topk, PATH_CONFIG
from server_train import (
load_dual_tower_model,
recompute_candidate_vectors,
continuous_train
)
# 1.Core Model Interfaces
class DualTowerModel:
    """Facade over the dual-tower retrieval / training backend.

    Thin wrapper around the ``server_infer`` / ``server_train`` entry
    points. All user-tunable configuration is kept in
    ``st.session_state`` so it survives Streamlit's top-to-bottom
    script reruns.
    """

    def __init__(self):
        """
        Initialize Dual tower model.

        Seeds session-state config keys only when absent, so reruns keep
        whatever the user already selected in the UI.
        """
        # check config
        if 'checkpoint_path' not in st.session_state:
            st.session_state['checkpoint_path'] = None
        if 'protein_model_path' not in st.session_state:
            st.session_state['protein_model_path'] = './SaProt_650M_AF2'
        if 'molecule_model_path' not in st.session_state:
            st.session_state['molecule_model_path'] = './ChemBERTa-zinc-base-v1'
        if 'vectors_ready' not in st.session_state:
            st.session_state['vectors_ready'] = self._check_vectors_exist()
        st.session_state['model_loaded'] = True

    def _check_vectors_exist(self):
        """Return True when both precomputed candidate vector files exist."""
        protein_pt = PATH_CONFIG['protein']['pt']
        molecule_pt = PATH_CONFIG['molecule']['pt']
        return os.path.exists(protein_pt) and os.path.exists(molecule_pt)

    def retrieve_molecules_for_protein(self, protein_query, top_k=5):
        """
        Task 1: Protein -> Molecule.

        Args:
            protein_query: Protein sequence / UniProt ID / gene symbol.
            top_k: Number of candidates to return.

        Returns:
            List of dicts with ``smiles``/``score``/``name``/
            ``drugbank_id``/``cas`` keys; empty list on failure
            (the error is surfaced to the user via st.error).
        """
        try:
            # Call retrieve_topk from server_infer
            results = retrieve_topk(
                input_type="protein",
                input_query=protein_query,
                topk=top_k,
                device='cpu'  # Using CPU for web interface
            )
            # Reshape backend records into the flat dicts the UI expects.
            return [
                {
                    "smiles": res['id'],
                    "score": res['score'],
                    "name": res['info'].get('compound__name', 'Unknown'),
                    "drugbank_id": res['info'].get('compound__drugbank_id', 'N/A'),
                    "cas": res['info'].get('compound__cas', 'N/A'),
                }
                for res in results
            ]
        except Exception as e:
            st.error(f"Retrieval failed: {str(e)}")
            return []

    def retrieve_proteins_for_molecule(self, molecule_query, top_k=5):
        """
        Task 2: Molecule -> Protein.

        Args:
            molecule_query: Molecule SMILES / compound name / DrugBank ID.
            top_k: Number of candidates to return.

        Returns:
            List of dicts with ``protein_id``/``desc``/``score``/
            ``class``/``full_id`` keys; empty list on failure.
        """
        try:
            results = retrieve_topk(
                input_type="molecule",
                input_query=molecule_query,
                topk=top_k,
                device='cpu'
            )
            # protein_id and full_id currently carry the same value; the
            # UI uses full_id for the detail view.
            return [
                {
                    "protein_id": res['id'],
                    "desc": res['info'].get('target__gene', 'Unknown'),
                    "score": res['score'],
                    "class": res['info'].get('target__class', 'N/A'),
                    "full_id": res['id'],
                }
                for res in results
            ]
        except Exception as e:
            st.error(f"Retrieval failed: {str(e)}")
            return []

    def fine_tune(self, uploaded_file, epochs, learning_rate, batch_size):
        """
        Training Interface: Fine-tune model with uploaded data.

        Args:
            uploaded_file: Streamlit UploadedFile holding the dataset.
            epochs: Number of training epochs.
            learning_rate: Optimizer learning rate.
            batch_size: Training batch size.

        Returns:
            The (truthy) result of ``continuous_train`` on success,
            False when any step raises.
        """
        import tempfile
        temp_dataset_path = None
        try:
            # Persist the upload to a real file because continuous_train
            # expects a filesystem path, not a file-like object.
            # NOTE(review): the suffix is always ".csv" even for JSON
            # uploads — confirm continuous_train sniffs the format.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
                tmp.write(uploaded_file.getbuffer())
                temp_dataset_path = tmp.name
            # NOTE(review): the progress bar is created but never advanced;
            # it stays at 0% for the whole run.
            progress_bar = st.progress(0)
            status_text = st.empty()
            status_text.text("Initializing training environment...")
            # Ensure the checkpoint output directory exists.
            model_save_dir = 'Dual_Tower_Model/customized_checkpoints'
            os.makedirs(model_save_dir, exist_ok=True)
            checkpoint_path = st.session_state.get('checkpoint_path', None)
            status_text.text("Starting actual training process...")
            # call the actual training function
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            if device == 'cpu':
                st.warning("⚠️ Detected CPU environment. Training might be very slow.")
            success = continuous_train(
                dataset_path=temp_dataset_path,
                model_save_dir=model_save_dir,
                best_model_path=checkpoint_path,
                epochs=epochs,
                lr=learning_rate,
                batch_size=batch_size,
                device=device
            )
            if success:
                status_text.success("βœ… Training completed successfully!")
            return success
        except Exception as e:
            st.error(f"Training failed: {str(e)}")
            return False
        finally:
            # Fix: remove the temporary dataset even when training raises
            # (the original only cleaned it up on the non-exception path,
            # leaking a temp file per failed run).
            if temp_dataset_path and os.path.exists(temp_dataset_path):
                os.remove(temp_dataset_path)

    def update_vectors(self, protein_json, molecule_json, output_dir, checkpoint_path=None):
        """
        Update candidate vector library.

        Args:
            protein_json: Path to the protein candidate JSON.
            molecule_json: Path to the molecule candidate JSON.
            output_dir: Directory where new vector files are written.
            checkpoint_path: Optional trained checkpoint; None means the
                default backbones are used.

        Returns:
            True on success (and marks vectors_ready in session state),
            False when recomputation raises.
        """
        try:
            recompute_candidate_vectors(
                protein_json_path=protein_json,
                molecule_json_path=molecule_json,
                output_dir=output_dir,
                pt_path=checkpoint_path,
                batch_size=32,
                device='cpu'
            )
            st.session_state['vectors_ready'] = True
            return True
        except Exception as e:
            st.error(f"Vector update failed: {str(e)}")
            return False
# Fix: st.set_page_config() must be the FIRST Streamlit UI command of a
# script run; the original called st.spinner() (model init) before it,
# which raises StreamlitAPIException.
st.set_page_config(page_title="Bio-Retrieval Service", page_icon="🧬", layout="wide")
# 2. Page Layout (UI Layout)
# Initialize model (Singleton pattern): build once per session, reuse on reruns.
if 'model' not in st.session_state:
    with st.spinner('Initializing model...'):
        st.session_state['model'] = DualTowerModel()
model = st.session_state['model']
# Sidebar Configuration
# `mode` selected here drives which page section renders below; the radio
# value is compared by exact string in the if/elif chain further down.
with st.sidebar:
    st.title("🧬 Dual-Tower Service")
    st.markdown("---")
    mode = st.radio(
        "Select Mode:",
        ["πŸ” Inference (Retrieval)", "βš™οΈ Training (Fine-tune)", "πŸ“Š Configuration", "πŸ‘₯ Team Info"]
    )
    st.markdown("---")
    st.info("**Project Info**\n\nBidirectional Protein-Molecule Retrieval.\nBased on SaProt & Molecular Encoder.")
# Header Section (shown on every page, independent of the selected mode)
st.title("🧬 Bidirectional Protein-Molecule Retrieval")
st.markdown("""
This service implements a **Dual-Tower Architecture** to bridge the gap between Proteins and molecules.
""")
# ==========================================
# 3. Inference Mode
# ==========================================
if mode == "πŸ” Inference (Retrieval)":
    st.subheader("Interactive Retrieval System")
    # Check if vectors are ready: retrieval needs the precomputed candidate
    # vectors; st.stop() halts this script run so no tabs render without them.
    if not st.session_state.get('vectors_ready', False):
        st.error("❌ Vectors not ready. Please generate them in 'Configuration' first.")
        st.stop()
    # Use Tabs for tasks (one tab per retrieval direction)
    tab1, tab2 = st.tabs(["Protein β†’ Molecule", "Molecule β†’ Protein"])
    # --- Task 1: Protein to Molecule ---
    with tab1:
        st.markdown("### Recommend molecules for a specific protein target")
        # Quick Fill for Protein: example buttons pre-fill the text area via
        # session_state so the value survives the rerun the click triggers.
        st.markdown("**Quick Fill Examples:**")
        q_col1, q_col2, q_col3 = st.columns([1, 1, 2])
        if q_col1.button("Example: ROS1", key="btn_ros1"):
            st.session_state['p_input_val'] = "ROS1"
        if q_col2.button("Example: P08922", key="btn_p08922"):
            st.session_state['p_input_val'] = "P08922"
        col1, col2 = st.columns([2, 1])
        with col1:
            protein_input = st.text_area(
                "Enter Protein Sequence (FASTA format or raw sequence):",
                height=150,
                value=st.session_state.get('p_input_val', ""),
                placeholder="e.g., ROS1, P08922, or full sequence...",
                help="Supports Gene Symbols, UniProt IDs, or raw sequences"
            )
            top_k_p2m = st.slider("Top K Results", 1, 20, 5, key="slider_p2m")
        with col2:
            st.markdown("#### Configuration")
            st.write(f"Encoder: **SaProt (650M)**")
            st.write("Search Space: **ChEMBL / ZINC15**")
            search_btn_p2m = st.button("πŸš€ Retrieve Molecules", type="primary")
        # Run retrieval only when the button was clicked AND input is non-empty.
        if search_btn_p2m and protein_input:
            with st.spinner("Encoding protein and searching vector database..."):
                results = model.retrieve_molecules_for_protein(protein_input, top_k_p2m)
                if results:
                    st.success(f"Found {len(results)} potential molecules.")
                    for idx, res in enumerate(results):
                        # Only the top-ranked hit starts expanded.
                        with st.expander(
                            f"πŸ† Rank {idx+1}: {res['name']} | Similarity: {res['score']:.4f}",
                            expanded=(idx < 1)
                        ):
                            col_a, col_b = st.columns([3, 2])
                            with col_a:
                                st.markdown("**SMILES Structure:**")
                                st.code(res['smiles'], language="text")
                            with col_b:
                                st.markdown("**Compound Info:**")
                                st.markdown(f"- **Name:** {res['name']}")
                                st.markdown(f"- **DrugBank ID:** {res.get('drugbank_id', 'N/A')}")
                                st.markdown(f"- **CAS No.:** {res.get('cas', 'N/A')}")
                            st.markdown(f"**Similarity Score:**")
                            # Ensure score is within [0, 1] for st.progress
                            display_score = max(0.0, min(1.0, float(res['score'])))
                            st.progress(display_score)
                else:
                    st.warning("No matches found. Please check your input.")
    # --- Task 2: Molecule to Protein ---
    with tab2:
        st.markdown("### Retrieve potential protein targets for a specific molecule")
        # Quick Fill for Molecule: buttons pre-fill the input through
        # session_state so the value survives the button-click rerun.
        st.markdown("**Quick Fill Examples:**")
        m_q_col1, m_q_col2 = st.columns([1, 3])
        if m_q_col1.button("Example: Lorlatinib", key="btn_lor"):
            st.session_state['m_input_val'] = "Lorlatinib"
        if m_q_col2.button("Example: SMILES", key="btn_smiles"):
            st.session_state['m_input_val'] = "C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2"
        col1, col2 = st.columns([2, 1])
        with col1:
            mol_input = st.text_input(
                "Enter Molecule SMILES:",
                value=st.session_state.get('m_input_val', ""),
                placeholder="e.g., Lorlatinib, DB12130, or SMILES string",
                help="Supports Compound Names, DrugBank IDs, or SMILES strings"
            )
            top_k_m2p = st.slider("Top K Results", 1, 20, 5, key="slider_m2p")
        with col2:
            st.markdown("#### Configuration")
            st.write(f"Encoder: **ChemBERTa**")
            st.write("Search Space: **UniProt / PDB**")
            search_btn_m2p = st.button("πŸš€ Retrieve Proteins", type="primary")
        # Run retrieval only when the button was clicked AND input is non-empty.
        if search_btn_m2p and mol_input:
            with st.spinner("Encoding molecule and searching vector database..."):
                results = model.retrieve_proteins_for_molecule(mol_input, top_k_m2p)
                if results:
                    st.success(f"Found {len(results)} potential protein targets.")
                    st.markdown("---")
                    st.markdown("#### πŸ† Retrieval Results")
                    for idx, res in enumerate(results):
                        # Only the top-ranked target starts expanded.
                        with st.expander(
                            f"πŸ† Rank {idx+1}: {res['desc']} | Similarity: {res['score']:.4f}",
                            expanded=(idx < 1)
                        ):
                            st.markdown(f"**Full ID:** `{res['full_id']}`")
                            st.markdown(f"**Gene Symbol:** {res['desc']}")
                            st.markdown(f"**Class:** {res.get('class', 'N/A')}")
                            st.markdown(f"**Similarity Score:**")
                            # Ensure score is within [0, 1] for st.progress
                            display_score = max(0.0, min(1.0, float(res['score'])))
                            st.progress(display_score)
                else:
                    st.warning("No matches found. Please check your input.")
# 4. Training Mode
elif mode == "βš™οΈ Training (Fine-tune)":
    st.subheader("Fine-tune the Dual-Tower Model")
    st.warning("This module allows you to input custom data to fine-tune the encoders.")
    # File Upload
    st.markdown("### πŸ“ Upload Training Data")
    uploaded_file = st.file_uploader(
        "Upload dataset (CSV/JSON)",
        type=["csv", "json"],
        help="Dataset should include: compound__smiles, target__foldseek_seq, outcome_potency_pxc50"
    )
    if uploaded_file:
        st.success("βœ… File uploaded successfully")
        st.markdown("**Data Preview:**")
        # NOTE(review): the preview always parses as CSV even though the
        # uploader also accepts JSON; JSON uploads fall into the except
        # branch below. Consider branching on the file extension.
        try:
            df = pd.read_csv(uploaded_file)
            st.dataframe(df.head(10), use_container_width=True)
            st.info(f"πŸ“Š Dataset Size: {len(df)} records, {len(df.columns)} fields")
        except Exception as e:
            st.error(f"⚠️ Parsing failed: {str(e)}. Ensure format is correct.")
    # Hyperparameters (forwarded verbatim to model.fine_tune below)
    st.markdown("---")
    st.markdown("### βš™οΈ Hyperparameters")
    col1, col2, col3 = st.columns(3)
    with col1:
        epochs = st.number_input("Epochs", min_value=1, max_value=100, value=5)
    with col2:
        lr = st.number_input("Learning Rate", min_value=0.0001, max_value=0.1, value=0.001, format="%.4f")
    with col3:
        batch_size = st.selectbox("Batch Size", [16, 32, 64, 128], index=1)
    # Checkpoint Config: fine_tune() reads the chosen path back out of
    # session_state rather than taking it as an argument.
    st.markdown("---")
    st.markdown("### πŸ“¦ Model Checkpoint")
    checkpoint_option = st.radio(
        "Start training from:",
        ["Pre-trained Backbones", "Existing Checkpoint"],
        help="Use backbones for first run, or checkpoint for incremental training"
    )
    if checkpoint_option == "Existing Checkpoint":
        checkpoint_path = st.text_input(
            "Checkpoint Path:",
            value="Dual_Tower_Model/output_checkpoints_ddp/model_epoch_7_acc_0.3259.pt",
            help="Full path to .pt file"
        )
        st.session_state['checkpoint_path'] = checkpoint_path
    else:
        st.session_state['checkpoint_path'] = None
        st.info("Using default SaProt & ChemBERTa backbones.")
    st.markdown("---")
    start_train = st.button("πŸ”₯ Start Fine-tuning", type="primary", use_container_width=True)
    if start_train:
        if uploaded_file is None:
            st.error("❌ Please upload a dataset first.")
        else:
            st.markdown("---")
            st.markdown("### πŸ“ˆ Training Progress")
            success = model.fine_tune(uploaded_file, epochs, lr, batch_size)
            if success:
                st.balloons()
                st.success("πŸŽ‰ Training pipeline initialized!")
                st.info("""
                **Notes:**
                - Real training requires a GPU-enabled environment.
                - Remember to update vector libraries after training.
                - Manage models in 'Configuration' section.
                """)
# 5. Configuration Mode
elif mode == "πŸ“Š Configuration":
    st.subheader("System Configuration")
    # Vector Library Management: show readiness and size of both libraries.
    st.markdown("### πŸ—„οΈ Vector Management")
    col1, col2 = st.columns(2)
    with col1:
        st.markdown("#### Protein Vector Library")
        # Paths come from server_infer.PATH_CONFIG and are reused by the
        # recompute button further down this page.
        protein_json = PATH_CONFIG['protein']['json']
        protein_pt = PATH_CONFIG['protein']['pt']
        if os.path.exists(protein_pt):
            st.success("βœ… Ready")
            # NOTE(review): torch.load on a pickle can execute arbitrary
            # code if the file is untrusted; these files are generated
            # locally, but consider weights_only=True where supported.
            data = torch.load(protein_pt, map_location='cpu')
            st.metric("Total Vectors", len(data['ids']))
        else:
            st.warning("⚠️ Missing")
    with col2:
        st.markdown("#### Molecule Vector Library")
        molecule_json = PATH_CONFIG['molecule']['json']
        molecule_pt = PATH_CONFIG['molecule']['pt']
        if os.path.exists(molecule_pt):
            st.success("βœ… Ready")
            data = torch.load(molecule_pt, map_location='cpu')
            st.metric("Total Vectors", len(data['ids']))
        else:
            st.warning("⚠️ Missing")
    st.markdown("---")
    # Recompute Vectors (overwrites the files checked above)
    st.markdown("### πŸ”„ Recompute Vector Library")
    st.warning("⚠️ This will overwrite existing vector files.")
    checkpoint_for_vector = st.text_input(
        "Model Checkpoint (optional):",
        placeholder="Leave empty for default backbones",
        help="Specify path to a trained checkpoint if available"
    )
    if st.button("πŸš€ Start Vector Recomputation", type="primary"):
        with st.spinner("Recomputing vectors, this may take a few minutes..."):
            # update_vectors already catches and reports its own errors;
            # this try/except is a second safety net around the UI calls.
            try:
                success = model.update_vectors(
                    protein_json=protein_json,
                    molecule_json=molecule_json,
                    output_dir='drug_target_activity/candidates',
                    checkpoint_path=checkpoint_for_vector if checkpoint_for_vector else None
                )
                if success:
                    st.success("βœ… Vector library updated!")
                    st.rerun()
            except Exception as e:
                st.error(f"❌ Recomputation failed: {str(e)}")
    st.markdown("---")
    # Model Info (static display of configured encoder paths)
    st.markdown("### πŸ€– Model Pipeline Info")
    info_col1, info_col2 = st.columns(2)
    with info_col1:
        st.markdown("**Protein Encoder**")
        st.markdown("- Model: SaProt 650M")
        st.markdown(f"- Path: `{st.session_state.get('protein_model_path', 'N/A')}`")
    with info_col2:
        st.markdown("**Molecule Encoder**")
        st.markdown("- Model: ChemBERTa ZINC Base")
        st.markdown(f"- Path: `{st.session_state.get('molecule_model_path', 'N/A')}`")
# 6. Team Info Mode: static project/author credits, no interactive widgets.
elif mode == "πŸ‘₯ Team Info":
    st.subheader("πŸŽ“ Project Information")
    st.markdown("""
    **Course:** Recommender System
    **Final Project:** Bidirectional Protein-Molecule Retrieval (Choice Two)
    **Date:** January 2026
    ---
    #### πŸ‘¨β€πŸ”¬ Members:
    - **Zhiyun Jiang**
    - **Xinyang Tong**
    ---
    #### πŸ› οΈ Technical Contributions:
    - **Architecture Design:** Dual-Tower Encoder Integration
    - **Model Training:** SaProt & ChemBERTa Fine-tuning
    - **Frontend:** Streamlit Academic UI Development
    """)
# Footer: rendered on every page regardless of the selected mode.
st.markdown("---")
st.caption("Final Project | Dual-Tower Architecture Implementation | Zhiyun Jiang, Xinyang Tong")