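# BroteinShake / app.py
# Hosting-page residue kept for provenance (not part of the program):
#   uploader avatar caption: "42Cummer's picture"
#   commit: "Upload app.py" (178c45c, verified)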
import gradio as gr # type: ignore
import os
from gradio_molecule3d import Molecule3D #type: ignore
from Bio.PDB import PDBParser #type: ignore
import time
# Import your custom modules from the /scripts folder
from scripts.download import download_and_clean_pdb
from scripts.generator import run_broteinshake_generator
from scripts.refine import polish_design, process_results
from scripts.visualize import create_design_plot
from scripts.foldprotein import fold_protein_sequence
# --- HELPER FUNCTIONS ---
def get_pdb_chains(pdb_file):
    """Return the sorted, de-duplicated chain IDs found in a PDB file.

    An empty list is returned when the path is falsy, when the file does not
    exist, or when parsing fails (the failure is printed, not raised).
    """
    path_is_usable = bool(pdb_file) and os.path.exists(pdb_file)
    if not path_is_usable:
        return []

    try:
        structure = PDBParser(QUIET=True).get_structure("temp", pdb_file)
        # Walk every model so chains are collected across all of them.
        unique_ids = set()
        for model in structure:
            for chain in model:
                unique_ids.add(chain.id)
        return sorted(unique_ids)
    except Exception as e:
        print(f"Error extracting chains: {e}")
        return []
def load_pdb_and_extract_chains(pdb_id):
    """Download a PDB entry and populate the chain selector with its chains.

    Args:
        pdb_id: PDB identifier typed by the user; surrounding whitespace is
            ignored.

    Returns:
        A 4-tuple in the order of the outputs wired to the "Load PDB" button:
        (checkbox-group update, status message, generate-button update,
        list of available chain IDs stored in state).
    """
    if not pdb_id or not pdb_id.strip():
        return gr.update(choices=[], value=[]), "⚠️ Please enter a PDB ID", gr.update(interactive=False), []

    # Normalise once so the download and every status message use the same ID.
    pdb_id = pdb_id.strip()
    label = pdb_id.upper()
    try:
        # Download the PDB
        pdb_path = download_and_clean_pdb(pdb_id, data_dir="data")
        # Extract chains
        chains = get_pdb_chains(pdb_path)
        if not chains:
            return gr.update(choices=[], value=[]), f"⚠️ No chains found in {label}", gr.update(interactive=False), []

        # Single-chain proteins are supported (will use different ProteinMPNN command).
        # The only chain is auto-selected and generation is enabled immediately.
        if len(chains) == 1:
            status_msg = f"Loaded {label}: Single-chain protein - will redesign chain {chains[0]}"
            return gr.update(choices=chains, value=chains), status_msg, gr.update(interactive=True), chains

        status_msg = f"Loaded {label}: Found {len(chains)} chain(s) - {', '.join(chains)}"
        # Multi-chain: start with nothing checked and the button disabled, so the
        # user must pick the chains to redesign. Pre-checking every chain would
        # start in the one state validate_chain_selection rejects ("cannot
        # select all chains").
        return gr.update(choices=chains, value=[]), status_msg, gr.update(interactive=False), chains
    except Exception as e:
        error_msg = f"Error loading {label}: {str(e)}"
        print(error_msg)
        return gr.update(choices=[], value=[]), error_msg, gr.update(interactive=False), []
def validate_chain_selection(selected_chains, available_chains_state):
    """Decide whether the current chain selection may be sent to generation.

    Returns a 3-tuple: (generate-button update, message for the user, the
    available-chains state passed through unchanged). The button is enabled
    only when at least one chain is selected and, for multi-chain structures,
    at least one chain is left unselected (i.e. stays fixed).
    """
    known_chains = available_chains_state or []

    if not selected_chains:
        enabled = False
        note = "Please select at least one chain to redesign"
    elif len(known_chains) == 1:
        # Single-chain proteins: the only chain may be redesigned.
        enabled = True
        note = f"Single-chain protein: Will redesign chain {known_chains[0]}"
    elif known_chains and len(selected_chains) >= len(known_chains):
        # Multi-chain: selecting everything would leave no fixed chain.
        enabled = False
        note = f"Cannot select all chains - at least one chain must remain fixed. Selected: {', '.join(selected_chains)}"
    else:
        enabled = True
        note = f"{len(selected_chains)} chain(s) selected for redesign: {', '.join(selected_chains)}"

    return gr.update(interactive=enabled), note, available_chains_state
def get_all_sequences(fasta_file: str) -> str:
    """Return every designed sequence in a FASTA file as one text block.

    Records are delimited by ``>`` header lines, so a sequence that wraps over
    several lines is joined back together instead of shifting every later
    header/sequence pair. Only records whose header contains ``"sample"`` are
    kept; this skips the original native sequence (first entry).

    Args:
        fasta_file: Path to the FASTA file to read.

    Returns:
        The kept records formatted as ``"<header>\\n<sequence>"`` and separated
        by blank lines.

    Raises:
        ValueError: If the file contains no designed records.
        OSError: If the file cannot be opened.
    """
    # Each record is (header_line, [sequence_chunk, ...]).
    records = []
    with open(fasta_file, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith(">"):
                records.append((line, []))
            elif records:
                # Continuation of the current record's sequence.
                records[-1][1].append(line)

    sequences = [
        f"{header}\n{''.join(chunks)}"
        for header, chunks in records
        if "sample" in header and chunks
    ]
    if not sequences:
        raise ValueError(f"No valid designs found in {fasta_file}")
    return "\n\n".join(sequences)
def extract_best_sequence(fasta_file: str) -> str:
    """Return the designed record with the lowest ``score`` in a FASTA file.

    Records are delimited by ``>`` header lines (wrapped sequences are joined).
    Only headers containing ``"sample"`` are considered, which skips the
    original native sequence. The score is read from the header field whose
    key is exactly ``score`` (e.g. ``score=0.7647``), so a ``global_score``
    field is never mistaken for it regardless of field order. If no exact
    ``score`` key exists, the first field containing ``"score"`` is used, as
    before.

    Args:
        fasta_file: Path to the FASTA file to read.

    Returns:
        ``"<header>\\n<sequence>"`` for the best-scoring design.

    Raises:
        ValueError: If no design with a parseable score is found.
        OSError: If the file cannot be opened.
    """
    # Each record is (header_line, [sequence_chunk, ...]).
    records = []
    with open(fasta_file, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith(">"):
                records.append((line, []))
            elif records:
                records[-1][1].append(line)

    best_score = float('inf')
    best_header = ""
    best_seq = ""
    for header, chunks in records:
        # Skip the original native sequence and records with no residues.
        if "sample" not in header or not chunks:
            continue
        fields = header.lstrip(">").split(',')
        exact = [p for p in fields if p.split('=', 1)[0].strip() == "score"]
        candidates = exact or [p for p in fields if 'score' in p]
        try:
            score = float(candidates[0].split('=')[1])
        except (IndexError, ValueError):
            # Header without a usable score: ignore this record.
            continue
        if score < best_score:
            best_score = score
            best_header = header
            best_seq = "".join(chunks)

    if best_seq:
        return f"{best_header}\n{best_seq}"
    raise ValueError(f"No valid designs found in {fasta_file}")
def _score_from_design_header(header):
    """Parse the ``score=`` value from a design header line.

    Prefers the field whose key is exactly ``score``; falls back to the first
    field containing ``"score"``. Raises ValueError if nothing is parseable.
    """
    fields = header.lstrip(">").split(',')
    exact = [p for p in fields if p.split('=', 1)[0].strip() == "score"]
    candidates = exact or [p for p in fields if 'score' in p]
    try:
        return float(candidates[0].split('=')[1])
    except (IndexError, ValueError) as err:
        raise ValueError(f"Could not parse score from header: {header}") from err


def run_part1(pdb_id, fixed_chains, variable_chains, temperature=0.1, selected_chains=None):
    """Download the PDB and run the ProteinMPNN design step.

    Args:
        pdb_id: PDB identifier; surrounding whitespace is ignored.
        fixed_chains: Chain IDs to keep fixed, as a string of chain letters.
            Only used when ``selected_chains`` is empty.
        variable_chains: Chain IDs to redesign, as a string of chain letters.
            Only used when ``selected_chains`` is empty.
        temperature: Sampling temperature passed to the generator.
        selected_chains: Chain IDs ticked in the chain selector. When given,
            these become the variable chains and all other chains are fixed.

    Returns:
        ``(all_sequences, evolution_plot, status_message)``. On any failure the
        tuple is ``("", None, "Error in Part 1: <reason>")``; nothing is raised.
    """
    try:
        if not pdb_id or not pdb_id.strip():
            raise ValueError("PDB ID is required.")
        # Use the same normalised ID for the download and the output path.
        pdb_id = pdb_id.strip()

        # Step 1: Secure the template
        pdb_path = download_and_clean_pdb(pdb_id, data_dir="data")
        all_chains = get_pdb_chains(pdb_path)
        if not all_chains:
            raise ValueError(f"No chains found in {pdb_id.upper()}")
        is_single_chain = len(all_chains) == 1

        if selected_chains:
            # Selected chains = variable chains, rest = fixed
            variable_chains = "".join(selected_chains)
            fixed_chains = "".join(c for c in all_chains if c not in selected_chains)
            # For single-chain: no fixed chains (will use different ProteinMPNN command).
            # For multi-chain: at least one chain must stay fixed.
            if not is_single_chain and not fixed_chains:
                raise ValueError(f"Cannot redesign all chains - at least one chain must remain fixed. Selected: {', '.join(selected_chains)}, Available: {', '.join(all_chains)}")
            if is_single_chain:
                print(f"Single-chain mode: Redesigning chain {variable_chains}")
            else:
                print(f"Using chain selector: Fixed={fixed_chains}, Variable={variable_chains}")
        else:
            # No chains selected: fall back to the text inputs.
            if is_single_chain and not variable_chains:
                variable_chains = all_chains[0]
                fixed_chains = ""
            elif not is_single_chain and variable_chains and set(all_chains) <= set(variable_chains):
                # Only reject when the *variable* chains cover every chain.
                # Fixed + variable together covering the whole structure is the
                # normal case (e.g. A fixed, B variable) and must be allowed.
                raise ValueError("Cannot redesign all chains - at least one chain must remain fixed.")
            print(f"Using text inputs: Fixed={fixed_chains}, Variable={variable_chains}")

        # Step 2: Generate Optimized Sequences
        # This creates the .fa files you need for the ESM Atlas
        print(f"Temperature: {temperature}")
        print(f"Parameters: Fixed chains={fixed_chains}, Variable chains={variable_chains}, Temp={temperature}")
        run_broteinshake_generator(pdb_path, fixed_chains, variable_chains, num_seqs=20, temp=temperature)

        # Get all sequences and the best one
        design_id = pdb_id.lower()
        fa_file = os.path.join("generated", design_id, "seqs", f"{design_id}_clones.fa")
        all_sequences = get_all_sequences(fa_file)
        best_sequence = extract_best_sequence(fa_file)

        # Generate the dashboard plot
        evolution_plot = create_design_plot(fa_file)

        # Parse score from the best design's header for the status message
        best_score = _score_from_design_header(best_sequence.split('\n')[0])
        # Count number of designs (records are separated by blank lines)
        num_designs = len([s for s in all_sequences.split('\n\n') if s.strip()])

        status_message = (
            f"Design Complete! {num_designs} designs generated.\n\n"
            f"Lead Candidate (Best Score: {best_score:.4f}):\n"
            f"{best_sequence}\n\n"
        )
        return all_sequences, evolution_plot, status_message
    except Exception as e:
        # UI boundary: report the failure in the status panel instead of raising.
        return "", None, f"Error in Part 1: {str(e)}"
def run_part2(pdb_id, sequence):
    """Fold a sequence and align the fold to the target PDB.

    1. Folds the input sequence via ``fold_protein_sequence``.
    2. Aligns the folded structure to the target PDB (``polish_design``).
    3. Builds the report with ``process_results``.

    Args:
        pdb_id: Target PDB identifier; surrounding whitespace is ignored.
        sequence: Protein sequence as pasted by the user. FASTA header lines
            (starting with ``>``) and all whitespace are removed before
            folding, so a record copied from Tab 1 can be pasted as-is.

    Returns:
        ``(final_pdb_path, report)`` on success, or ``(None, error_message)``
        on failure. Nothing is raised. The temporary raw-fold file is removed
        on every exit path.
    """
    raw_fold_path = None
    try:
        # --- 1. Validate Inputs ---
        if not sequence or not sequence.strip():
            return None, "Error: Please enter a protein sequence."
        if not pdb_id or not pdb_id.strip():
            return None, "Error: PDB ID is required."
        pdb_id = pdb_id.strip()

        # Drop FASTA headers and whitespace so only residues reach the folder.
        residue_lines = [ln for ln in sequence.splitlines() if not ln.lstrip().startswith(">")]
        sequence = "".join("".join(ln.split()) for ln in residue_lines)
        if not sequence:
            return None, "Error: Please enter a protein sequence."

        print(f"Starting Pipeline for {pdb_id}...")
        print(f" - Sequence Length: {len(sequence)} residues")

        # --- 2. Fold Sequence (Automated) ---
        # Calls the script in scripts/foldprotein.py
        pdb_content = fold_protein_sequence(sequence)
        if not pdb_content:
            return None, "Error: Protein folding failed. The API might be down or the sequence is invalid."

        # Save the raw folded structure to a temp file. Only alphanumeric
        # characters of the ID go into the filename so user input cannot
        # inject path separators.
        safe_id = "".join(ch for ch in pdb_id if ch.isalnum()) or "design"
        raw_fold_path = f"temp_fold_{safe_id}_{int(time.time())}.pdb"
        with open(raw_fold_path, "w") as f:
            f.write(pdb_content)
        if not os.path.exists(raw_fold_path):
            return None, f"Error: Could not save folded PDB to {raw_fold_path}"

        # --- 3. Align & Polish (Existing Logic) ---
        print(f"Aligning folded structure to target {pdb_id}...")
        final_pdb_path, global_rmsd, core_rmsd, high_conf_rmsd = polish_design(pdb_id, raw_fold_path)

        # --- 4. Validate Alignment Output ---
        if not final_pdb_path or not os.path.exists(final_pdb_path):
            return None, f"Error: Alignment failed - output file not created: {final_pdb_path}"
        if high_conf_rmsd is None:
            return None, "Error: Alignment failed - RMSD calculation returned None"

        def _fmt_rmsd(value):
            # Only high_conf_rmsd is checked for None above; keep the log line
            # from crashing a successful run if another RMSD is missing.
            return "n/a" if value is None else f"{value:.3f}"

        print(f"Success: Global RMSD={_fmt_rmsd(global_rmsd)}A | Core RMSD={_fmt_rmsd(core_rmsd)}A")

        # --- 5. Generate Report ---
        report = process_results(pdb_id, final_pdb_path, global_rmsd, high_conf_rmsd)
        return final_pdb_path, report
    except Exception as e:
        # UI boundary: log the traceback and hand the message back to the caller.
        error_msg = f"Unexpected Error: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return None, error_msg
    finally:
        # Clean up the raw unaligned fold on success AND on every failure path.
        if raw_fold_path and os.path.exists(raw_fold_path):
            try:
                os.remove(raw_fold_path)
            except OSError as cleanup_err:
                print(f"Warning: could not remove {raw_fold_path}: {cleanup_err}")
# --- GRADIO INTERFACE ---
# 1. Simple Dark Theme with Blue and Emerald Accents
# Starts from gr.themes.Base with blue/emerald/slate hues, then .set()
# overrides the individual background, text and primary-button colours.
dark_biohub = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="emerald",
    neutral_hue="slate",
).set(
    body_background_fill="#0f172a",
    block_background_fill="#1e293b",
    body_text_color="#f1f5f9",
    button_primary_background_fill="#10b981",
    button_primary_text_color="#ffffff",
)
# 2. Targeted CSS for the 3D Viewer & Header
# Passed to gr.Blocks(css=...). The selectors target the elem_id
# "molecule-viewer" given to the Molecule3D component and the
# "biohub-header" div emitted by the header HTML.
biohub_css = """
/* Remove the footer for a clean portfolio look */
footer {display: none !important;}
/* Fix the 3D viewer background to match the dark theme */
#molecule-viewer {
background-color: #111827 !important;
border: 1px solid #374151 !important;
border-radius: 12px;
}
/* Header Styling */
#biohub-header {
background: linear-gradient(135deg, #064e3b 0%, #1e40af 100%);
padding: 1.5rem;
border-radius: 12px;
border: 1px solid #10b981;
margin-bottom: 1rem;
}
"""
with gr.Blocks(theme=dark_biohub, css=biohub_css) as demo:
    # NOTE(review): this file arrived with its indentation stripped, so the
    # nesting of the layout containers (Column / Row) below is reconstructed
    # from statement order — confirm against the original layout.

    # Header
    gr.HTML("""
<div id='biohub-header'>
<h1 style='color: white; margin: 0;'>BroteinShake</h1>
</div>
""")
    with gr.Tabs():
        # TAB 1: GENERATIVE DESIGN
        with gr.Tab("1. Sequence Generation"):
            gr.Markdown("Enter a PDB ID to 'repaint' its binder interface using ProteinMPNN.")
            pdb_input = gr.Textbox(label="Target PDB ID", placeholder="e.g., 3kas", value="")
            load_pdb_btn = gr.Button("Load PDB", variant="secondary")
            pdb_status = gr.Markdown("Enter a PDB ID and click 'Load PDB' to begin")
            with gr.Column():
                gr.Markdown("### Design Parameters")
                # Temperature (T) is the most critical knob for sequence recovery
                sampling_temp = gr.Slider(
                    minimum=0.05, maximum=1.0, value=0.1, step=0.05,
                    label="Sampling Temperature (T)",
                    info="T=0.1 for high-fidelity; T=0.3 for natural diversity"
                )
                # Dynamic Chain Handling
                chain_options = gr.CheckboxGroup(
                    choices=[],
                    label="Chains to Redesign",
                    info="Identify which chains ProteinMPNN should modify (will populate after loading PDB)"
                )
                chain_warning = gr.Markdown("Select at least one chain to enable generation", visible=True)
                # Hidden state holding the chain IDs found in the loaded PDB
                pdb_state = gr.State()
                # Legacy text inputs (hidden but kept for backward compatibility)
                with gr.Row(visible=False):
                    f_chains = gr.Textbox(label="Fixed Chains (Lock)", value="A")
                    v_chains = gr.Textbox(label="Variable Chains (Key)", value="B")
                # Generate button (initially disabled)
                gen_btn = gr.Button("Generate Optimized Sequences", variant="primary", interactive=False)
            # Load PDB and extract chains when button is clicked
            load_pdb_btn.click(
                fn=load_pdb_and_extract_chains,
                inputs=[pdb_input],
                outputs=[chain_options, pdb_status, gen_btn, pdb_state]
            )
            # Validate chain selection and update button state
            chain_options.change(
                fn=validate_chain_selection,
                inputs=[chain_options, pdb_state],
                outputs=[gen_btn, chain_warning, pdb_state]
            )
            # Stack components vertically
            fa_output = gr.Code(
                label="Designed Sequences",
                language="markdown",
                lines=10,
                max_lines=20  # Prevents infinite growth when showing all 20 candidates
            )
            plot_output = gr.Plot(
                label="Design Evolution Dashboard"
            )
            gr.Markdown("### System Status")
            status1 = gr.Markdown()
            gen_btn.click(run_part1, inputs=[pdb_input, f_chains, v_chains, sampling_temp, chain_options],
                          outputs=[fa_output, plot_output, status1])
        # TAB 2: STRUCTURAL VALIDATION
        with gr.Tab("2. Structural Validation"):
            gr.Markdown("### Final Structure Preview")
            # Updated REPS for local dev visibility
            REPS = [
                {
                    "model": 0,
                    "style": "cartoon",
                    "color": "spectrum",
                    "opacity": 1.0
                }
            ]
            # 3D Viewer component
            protein_view = Molecule3D(label="3D Structure Viewer (Refined Shuttle)", reps=REPS, elem_id="molecule-viewer")
            with gr.Row():
                # Text area for the sequence (replaces the earlier file upload)
                sequence_input = gr.Textbox(
                    label="Paste Protein Sequence",
                    placeholder="Paste the sequence generated in Tab 1 here (e.g., MKTII...)",
                    lines=4,
                    max_lines=8
                )
                refined_download = gr.File(label="Download Aligned Lead (.pdb)")
            validate_btn = gr.Button("Run Structural Alignment", variant="primary")
            status2 = gr.Textbox(label="Validation Report", interactive=False, lines=5)

            def run_validation_with_view(pdb_id, sequence):
                """Run run_part2 and fan its result out to the three outputs.

                Always returns a 3-tuple (download path, report text, viewer
                path) so the number of values matches the outputs wired to the
                button, including when validation fails.
                """
                try:
                    # Call the run_part2 logic (API Fold -> Align)
                    final_pdb, report = run_part2(pdb_id, sequence)
                    if final_pdb is None:
                        # run_part2 failed: its second value is the error
                        # message, which must still reach the report box.
                        return None, report, None
                    if not os.path.exists(final_pdb):
                        return None, f"Error: Aligned structure not found: {final_pdb}", None
                    # Verify we're using the refined shuttle
                    if "Refined_Shuttle.pdb" in final_pdb:
                        print(f"Visualizing refined shuttle: {final_pdb}")
                    else:
                        print(f"Warning: Expected Refined_Shuttle.pdb but got: {final_pdb}")
                    # Return path for download, report text, and path for 3D viewer
                    return final_pdb, report, final_pdb
                except Exception as e:
                    error_msg = f"Error generating 3D view: {str(e)}"
                    print(error_msg)
                    import traceback
                    traceback.print_exc()
                    return None, error_msg, None

            # Inputs: the PDB ID from Tab 1 and the pasted sequence
            validate_btn.click(
                run_validation_with_view,
                inputs=[pdb_input, sequence_input],
                outputs=[refined_download, status2, protein_view]
            )
# Launch the app
# Only runs when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    # Docker deployment for HuggingFace Spaces:
    # bind to all interfaces (0.0.0.0) on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)