Spaces:
Sleeping
Sleeping
| import gradio as gr # type: ignore | |
| import os | |
| from gradio_molecule3d import Molecule3D #type: ignore | |
| from Bio.PDB import PDBParser #type: ignore | |
| import time | |
| # Import your custom modules from the /scripts folder | |
| from scripts.download import download_and_clean_pdb | |
| from scripts.generator import run_broteinshake_generator | |
| from scripts.refine import polish_design, process_results | |
| from scripts.visualize import create_design_plot | |
| from scripts.foldprotein import fold_protein_sequence | |
| # --- HELPER FUNCTIONS --- | |
def get_pdb_chains(pdb_file):
    """Return the sorted, de-duplicated chain IDs found in a PDB file.

    Chain IDs are gathered from every model in the parsed structure. A falsy
    path or a path that does not exist yields ``[]`` without parsing; any
    exception raised while parsing is printed and also yields ``[]``.
    """
    if not (pdb_file and os.path.exists(pdb_file)):
        return []
    try:
        structure = PDBParser(QUIET=True).get_structure("temp", pdb_file)
        # A set comprehension collapses IDs repeated across models.
        return sorted({chain.id for model in structure for chain in model})
    except Exception as e:
        print(f"Error extracting chains: {e}")
        return []
def load_pdb_and_extract_chains(pdb_id):
    """Download a PDB entry and populate the chain-selection widgets.

    Args:
        pdb_id: PDB identifier typed by the user; surrounding whitespace is
            ignored.

    Returns:
        A 4-tuple ``(chain_checkbox_update, status_message,
        generate_button_update, chains)`` where ``chains`` is the sorted list
        of chain IDs found in the structure (``[]`` on any failure).
    """
    pdb_id = (pdb_id or "").strip()
    if not pdb_id:
        return gr.update(choices=[], value=[]), "⚠️ Please enter a PDB ID", gr.update(interactive=False), []
    label = pdb_id.upper()
    try:
        # Download the PDB
        pdb_path = download_and_clean_pdb(pdb_id, data_dir="data")
        # Extract chains
        chains = get_pdb_chains(pdb_path)
        if not chains:
            return gr.update(choices=[], value=[]), f"⚠️ No chains found in {label}", gr.update(interactive=False), []

        # Single-chain proteins are supported (will use different ProteinMPNN command).
        # The only chain is auto-selected for redesign and generation is enabled.
        if len(chains) == 1:
            status_msg = f"Loaded {label}: Single-chain protein - will redesign chain {chains[0]}"
            return gr.update(choices=chains, value=chains), status_msg, gr.update(interactive=True), chains

        status_msg = f"Loaded {label}: Found {len(chains)} chain(s) - {', '.join(chains)}"
        # Multi-chain: start with nothing selected. Pre-selecting every chain
        # would leave no fixed chain, which the selection validator rejects,
        # so the user picks the chains to redesign and the button stays
        # disabled until then.
        return gr.update(choices=chains, value=[]), status_msg, gr.update(interactive=False), chains
    except Exception as e:
        error_msg = f"Error loading {label}: {str(e)}"
        print(error_msg)
        return gr.update(choices=[], value=[]), error_msg, gr.update(interactive=False), []
def validate_chain_selection(selected_chains, available_chains_state):
    """Enable or disable the Generate button based on the chain selection.

    Rules:
      * nothing selected -> disabled;
      * single-chain protein -> its only chain may be redesigned;
      * multi-chain protein -> at least one chain must stay unselected (fixed).

    Returns:
        ``(button_update, message, available_chains_state)``; the state is
        passed back unchanged.
    """
    if not selected_chains:
        return gr.update(interactive=False), "Please select at least one chain to redesign", available_chains_state

    known = available_chains_state or []
    if len(known) == 1:
        message = f"Single-chain protein: Will redesign chain {known[0]}"
        enabled = True
    elif known and len(selected_chains) >= len(known):
        # Selecting every chain would leave nothing fixed.
        message = f"Cannot select all chains - at least one chain must remain fixed. Selected: {', '.join(selected_chains)}"
        enabled = False
    else:
        message = f"{len(selected_chains)} chain(s) selected for redesign: {', '.join(selected_chains)}"
        enabled = True
    return gr.update(interactive=enabled), message, available_chains_state
def get_all_sequences(fasta_file: str) -> str:
    """Return every designed record from a FASTA file as FASTA text.

    Records whose header lacks ``"sample"`` (the original native sequence,
    written as the first entry) are skipped. Records are delimited by header
    lines starting with ``>``. A sequence wrapped over several lines is
    joined back into one line, and a header missing its sequence cannot
    shift the header/sequence pairing of later records.

    Args:
        fasta_file: Path to the FASTA file.

    Returns:
        The designed records as ``"header\\nsequence"`` blocks joined by a
        blank line.

    Raises:
        ValueError: If the file contains no designed record.
    """
    with open(fasta_file, "r", encoding="utf-8") as f:
        lines = [line.strip() for line in f if line.strip()]

    # Group lines into (header, [sequence lines...]) records.
    records = []
    for line in lines:
        if line.startswith(">"):
            records.append((line, []))
        elif records:
            records[-1][1].append(line)

    sequences = []
    for header, seq_lines in records:
        sequence = "".join(seq_lines)
        # Skip the native sequence and any header without residues.
        if "sample" not in header or not sequence:
            continue
        sequences.append(f"{header}\n{sequence}")

    if not sequences:
        raise ValueError(f"No valid designs found in {fasta_file}")
    return "\n\n".join(sequences)
def extract_best_sequence(fasta_file: str) -> str:
    """Return the designed record with the lowest ``score`` from a FASTA file.

    Only records whose header contains ``"sample"`` are considered; the
    original native sequence (first entry) is skipped. The score comes from
    the comma-separated header field whose key is exactly ``score``, e.g.
    ``"T=0.1, sample=3, score=0.7647, global_score=0.81"``. ``global_score``
    is deliberately not matched. Records without a parseable ``score`` are
    ignored. Records are delimited by ``>`` header lines, so wrapped
    sequences are joined into one line.

    Args:
        fasta_file: Path to the FASTA file.

    Returns:
        The best record as ``"header\\nsequence"``.

    Raises:
        ValueError: If no designed record has a parseable score.
    """
    with open(fasta_file, "r", encoding="utf-8") as f:
        lines = [line.strip() for line in f if line.strip()]

    # Group lines into (header, [sequence lines...]) records.
    records = []
    for line in lines:
        if line.startswith(">"):
            records.append((line, []))
        elif records:
            records[-1][1].append(line)

    best_score = float("inf")
    best_header = ""
    best_seq = ""
    for header, seq_lines in records:
        sequence = "".join(seq_lines)
        if "sample" not in header or not sequence:
            continue
        score = None
        for field in header.split(","):
            key, _, value = field.partition("=")
            # Exact key match so "global_score" is never mistaken for "score".
            if key.strip() == "score":
                try:
                    score = float(value)
                except ValueError:
                    score = None
                break
        if score is not None and score < best_score:
            best_score = score
            best_header = header
            best_seq = sequence

    if not best_seq:
        raise ValueError(f"No valid designs found in {fasta_file}")
    return f"{best_header}\n{best_seq}"
def run_part1(pdb_id, fixed_chains, variable_chains, temperature=0.1, selected_chains=None):
    """Download the PDB template and run ProteinMPNN sequence design.

    Chain roles come from ``selected_chains`` (the UI checkbox group) when it
    is non-empty: the selected chains are redesigned and every other chain is
    kept fixed. Otherwise the legacy ``fixed_chains`` / ``variable_chains``
    strings are used. A single-chain structure always redesigns its only
    chain with nothing fixed.

    Args:
        pdb_id: PDB identifier; surrounding whitespace is ignored.
        fixed_chains: Legacy string of chain IDs to keep fixed (e.g. ``"A"``).
        variable_chains: Legacy string of chain IDs to redesign (e.g. ``"B"``).
        temperature: Sampling temperature passed to the generator.
        selected_chains: Chain IDs picked in the UI, or ``None``.

    Returns:
        ``(all_sequences, evolution_plot, status_message)``. On any error,
        ``("", None, "Error in Part 1: <reason>")``.
    """
    try:
        pdb_id = (pdb_id or "").strip()
        if not pdb_id:
            raise ValueError("PDB ID is required.")

        # Step 1: Secure the template
        pdb_path = download_and_clean_pdb(pdb_id, data_dir="data")
        all_chains = get_pdb_chains(pdb_path)
        if not all_chains:
            raise ValueError(f"No chains found in {pdb_id.upper()}")
        is_single_chain = len(all_chains) == 1

        if selected_chains:
            # Selected chains are redesigned; every remaining chain stays fixed.
            variable_chains = "".join(selected_chains)
            fixed_chains = "".join(c for c in all_chains if c not in selected_chains)
            # Multi-chain designs need at least one fixed chain. Single-chain
            # designs legitimately have none (different ProteinMPNN command).
            if not is_single_chain and not fixed_chains:
                raise ValueError(
                    "Cannot redesign all chains - at least one chain must remain fixed. "
                    f"Selected: {', '.join(selected_chains)}, Available: {', '.join(all_chains)}"
                )
            if is_single_chain:
                print(f"Single-chain mode: Redesigning chain {variable_chains}")
            else:
                print(f"Using chain selector: Fixed={fixed_chains}, Variable={variable_chains}")
        else:
            # Legacy text inputs (backward compatibility).
            if is_single_chain:
                # The only chain is the only valid design target. Ignore text
                # inputs that may name chains this structure does not have.
                variable_chains = all_chains[0]
                fixed_chains = ""
            elif variable_chains and set(variable_chains) >= set(all_chains):
                # Redesigning every chain would leave nothing fixed.
                raise ValueError("Cannot redesign all chains - at least one chain must remain fixed.")
            print(f"Using text inputs: Fixed={fixed_chains}, Variable={variable_chains}")

        # Step 2: Generate optimized sequences
        print(f"Parameters: Fixed chains={fixed_chains}, Variable chains={variable_chains}, Temp={temperature}")
        run_broteinshake_generator(pdb_path, fixed_chains, variable_chains, num_seqs=20, temp=temperature)

        # Presumably the output location written by run_broteinshake_generator — verify if that changes.
        fa_file = os.path.join("generated", pdb_id.lower(), "seqs", f"{pdb_id.lower()}_clones.fa")
        all_sequences = get_all_sequences(fa_file)
        best_sequence = extract_best_sequence(fa_file)

        # Generate the dashboard plot
        evolution_plot = create_design_plot(fa_file)

        # Read the exact "score" field from the lead's header (not "global_score").
        best_score = None
        for field in best_sequence.split("\n", 1)[0].split(","):
            key, _, value = field.partition("=")
            if key.strip() == "score":
                best_score = float(value)
                break
        score_text = f"{best_score:.4f}" if best_score is not None else "n/a"

        num_designs = sum(1 for record in all_sequences.split("\n\n") if record.strip())

        status_message = (
            f"Design Complete! {num_designs} designs generated.\n\n"
            f"Lead Candidate (Best Score: {score_text}):\n"
            f"{best_sequence}\n\n"
        )
        return all_sequences, evolution_plot, status_message
    except Exception as e:
        return "", None, f"Error in Part 1: {str(e)}"
def run_part2(pdb_id, sequence):
    """Fold a designed sequence and align it to the target structure.

    1. Folds the input sequence using ESM Atlas (API) via
       ``fold_protein_sequence``.
    2. Aligns the folded structure to the target PDB (``polish_design``).

    The sequence may be raw residues or a FASTA record copied from Tab 1.
    Lines starting with ``>`` are dropped and all whitespace is removed
    before folding. The intermediate unaligned fold file is deleted on every
    exit path.

    Args:
        pdb_id: Target PDB identifier; surrounding whitespace is ignored.
        sequence: Protein sequence text.

    Returns:
        ``(final_pdb_path, report)`` on success; ``(None, error_message)``
        otherwise.
    """
    raw_fold_path = None
    final_pdb_path = None
    try:
        # --- 1. Validate Inputs ---
        if not sequence or not sequence.strip():
            return None, "Error: Please enter a protein sequence."
        if not pdb_id or not pdb_id.strip():
            return None, "Error: PDB ID is required."
        pdb_id = pdb_id.strip()

        # Users are told to paste the output of Tab 1, which includes FASTA
        # header lines; keep only residue characters.
        sequence = "".join(
            "".join(line.split())
            for line in sequence.splitlines()
            if not line.lstrip().startswith(">")
        )
        if not sequence:
            return None, "Error: Please enter a protein sequence."

        print(f"Starting Pipeline for {pdb_id}...")
        print(f" - Sequence Length: {len(sequence)} residues")

        # --- 2. Fold Sequence (Automated) ---
        pdb_content = fold_protein_sequence(sequence)
        if not pdb_content:
            return None, "Error: Protein folding failed. The API might be down or the sequence is invalid."

        # Save the raw folded structure. pdb_id is user input, so only safe
        # filename characters are used to build the path.
        safe_id = "".join(ch for ch in pdb_id if ch.isalnum() or ch in "-_") or "target"
        raw_fold_path = f"temp_fold_{safe_id}_{int(time.time())}.pdb"
        with open(raw_fold_path, "w") as f:
            f.write(pdb_content)
        if not os.path.exists(raw_fold_path):
            return None, f"Error: Could not save folded PDB to {raw_fold_path}"

        # --- 3. Align & Polish ---
        print(f"Aligning folded structure to target {pdb_id}...")
        final_pdb_path, global_rmsd, core_rmsd, high_conf_rmsd = polish_design(pdb_id, raw_fold_path)

        # --- 4. Validate Alignment Output ---
        if not final_pdb_path or not os.path.exists(final_pdb_path):
            return None, f"Error: Alignment failed - output file not created: {final_pdb_path}"
        if high_conf_rmsd is None:
            return None, "Error: Alignment failed - RMSD calculation returned None"

        # global/core RMSD are not validated above and may be None.
        def _fmt_rmsd(value):
            return "n/a" if value is None else f"{value:.3f}A"

        print(f"Success: Global RMSD={_fmt_rmsd(global_rmsd)} | Core RMSD={_fmt_rmsd(core_rmsd)}")

        # --- 5. Generate Report ---
        report = process_results(pdb_id, final_pdb_path, global_rmsd, high_conf_rmsd)
        return final_pdb_path, report
    except Exception as e:
        error_msg = f"Unexpected Error: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return None, error_msg
    finally:
        # Remove the raw unaligned fold on success AND failure, but never the
        # aligned output itself.
        if raw_fold_path and raw_fold_path != final_pdb_path and os.path.exists(raw_fold_path):
            try:
                os.remove(raw_fold_path)
            except OSError as cleanup_err:
                print(f"Warning: could not remove {raw_fold_path}: {cleanup_err}")
# --- GRADIO INTERFACE ---
# 1. Theme: Gradio's Base theme with slate neutrals, blue primary and emerald
#    secondary hues, plus dark background/text overrides.
_THEME_HUES = dict(
    primary_hue="blue",
    secondary_hue="emerald",
    neutral_hue="slate",
)
_THEME_OVERRIDES = dict(
    body_background_fill="#0f172a",
    block_background_fill="#1e293b",
    body_text_color="#f1f5f9",
    button_primary_background_fill="#10b981",
    button_primary_text_color="#ffffff",
)
dark_biohub = gr.themes.Base(**_THEME_HUES).set(**_THEME_OVERRIDES)

# 2. Custom CSS: hides the footer and styles the 3D viewer (#molecule-viewer)
#    and the page header (#biohub-header).
biohub_css = """
/* Remove the footer for a clean portfolio look */
footer {display: none !important;}
/* Fix the 3D viewer background to match the dark theme */
#molecule-viewer {
background-color: #111827 !important;
border: 1px solid #374151 !important;
border-radius: 12px;
}
/* Header Styling */
#biohub-header {
background: linear-gradient(135deg, #064e3b 0%, #1e40af 100%);
padding: 1.5rem;
border-radius: 12px;
border: 1px solid #10b981;
margin-bottom: 1rem;
}
"""
# Build the two-tab Gradio app. Components render in the order they are
# created inside each `with` context.
with gr.Blocks(theme=dark_biohub, css=biohub_css) as demo:
    # Header
    gr.HTML("""
<div id='biohub-header'>
<h1 style='color: white; margin: 0;'>BroteinShake</h1>
</div>
""")
    with gr.Tabs():
        # TAB 1: GENERATIVE DESIGN
        with gr.Tab("1. Sequence Generation"):
            gr.Markdown("Enter a PDB ID to 'repaint' its binder interface using ProteinMPNN.")
            # Also read by Tab 2 as the alignment target (see validate_btn.click below).
            pdb_input = gr.Textbox(label="Target PDB ID", placeholder="e.g., 3kas", value="")
            load_pdb_btn = gr.Button("Load PDB", variant="secondary")
            pdb_status = gr.Markdown("Enter a PDB ID and click 'Load PDB' to begin")
            with gr.Column():
                gr.Markdown("### Design Parameters")
                # Temperature (T) is the most critical knob for sequence recovery
                sampling_temp = gr.Slider(
                    minimum=0.05, maximum=1.0, value=0.1, step=0.05,
                    label="Sampling Temperature (T)",
                    info="T=0.1 for high-fidelity; T=0.3 for natural diversity"
                )
                # Dynamic Chain Handling: choices are filled in by load_pdb_and_extract_chains.
                chain_options = gr.CheckboxGroup(
                    choices=[],
                    label="Chains to Redesign",
                    info="Identify which chains ProteinMPNN should modify (will populate after loading PDB)"
                )
                chain_warning = gr.Markdown("Select at least one chain to enable generation", visible=True)
                # Holds the list of chain IDs parsed from the loaded PDB (the 4th
                # value returned by load_pdb_and_extract_chains); read by
                # validate_chain_selection.
                pdb_state = gr.State()
                # Legacy text inputs (hidden but kept for backward compatibility).
                # run_part1 only uses them when no checkbox chain is selected.
                with gr.Row(visible=False):
                    f_chains = gr.Textbox(label="Fixed Chains (Lock)", value="A")
                    v_chains = gr.Textbox(label="Variable Chains (Key)", value="B")
                # Generate button (initially disabled; enabled by the load/validate handlers)
                gen_btn = gr.Button("Generate Optimized Sequences", variant="primary", interactive=False)
                # Load PDB and extract chains when button is clicked.
                # Outputs map 1:1 onto the handler's 4-tuple return value.
                load_pdb_btn.click(
                    fn=load_pdb_and_extract_chains,
                    inputs=[pdb_input],
                    outputs=[chain_options, pdb_status, gen_btn, pdb_state]
                )
                # Validate chain selection and update button state; pdb_state is
                # passed through unchanged.
                chain_options.change(
                    fn=validate_chain_selection,
                    inputs=[chain_options, pdb_state],
                    outputs=[gen_btn, chain_warning, pdb_state]
                )
            # Stack components vertically
            fa_output = gr.Code(
                label="Designed Sequences",
                language="markdown",
                lines=10,
                max_lines=20 # Prevents infinite growth when showing all 20 candidates
            )
            plot_output = gr.Plot(
                label="Design Evolution Dashboard"
            )
            gr.Markdown("### System Status")
            status1 = gr.Markdown()
            # run_part1 returns (all_sequences, plot, status_message) in this order.
            gen_btn.click(run_part1, inputs=[pdb_input, f_chains, v_chains, sampling_temp, chain_options],
                          outputs=[fa_output, plot_output, status1])
        # TAB 2: STRUCTURAL VALIDATION
        with gr.Tab("2. Structural Validation"):
            gr.Markdown("### Final Structure Preview")
            # Molecule3D representation: model 0 drawn as a cartoon with spectrum coloring.
            REPS = [
                {
                    "model": 0,
                    "style": "cartoon",
                    "color": "spectrum",
                    "opacity": 1.0
                }
            ]
            # 3D Viewer component (styled by the #molecule-viewer CSS rule)
            protein_view = Molecule3D(label="3D Structure Viewer (Refined Shuttle)", reps=REPS, elem_id="molecule-viewer")
            with gr.Row():
                # Text area for a pasted sequence (replaces the earlier file upload)
                sequence_input = gr.Textbox(
                    label="Paste Protein Sequence",
                    placeholder="Paste the sequence generated in Tab 1 here (e.g., MKTII...)",
                    lines=4,
                    max_lines=8
                )
                refined_download = gr.File(label="Download Aligned Lead (.pdb)")
            validate_btn = gr.Button("Run Structural Alignment", variant="primary")
            status2 = gr.Textbox(label="Validation Report", interactive=False, lines=5)
            # Wrapper function now accepts 'sequence' instead of 'file'
            def run_validation_with_view(pdb_id, sequence):
                """Run run_part2 and fan its result out to the Tab 2 widgets.

                Returns ``(download_path, report, viewer_path)``; both paths
                are the same aligned PDB (or ``None`` on failure).
                """
                try:
                    # Call the updated run_part2 logic (API Fold -> Align)
                    final_pdb, report = run_part2(pdb_id, sequence)
                    if final_pdb is not None and os.path.exists(final_pdb):
                        # Log whether the output is the expected refined shuttle file.
                        # NOTE(review): the basename check is implied by the substring
                        # check, so the second operand is redundant.
                        if "Refined_Shuttle.pdb" in final_pdb or os.path.basename(final_pdb) == "Refined_Shuttle.pdb":
                            print(f"Visualizing refined shuttle: {final_pdb}")
                        else:
                            print(f"Warning: Expected Refined_Shuttle.pdb but got: {final_pdb}")
                    # Return path for download, report text, and path for 3D viewer
                    return final_pdb, report, final_pdb
                except Exception as e:
                    error_msg = f"Error generating 3D view: {str(e)}"
                    print(error_msg)
                    import traceback
                    traceback.print_exc()
                    return None, error_msg, None
            # The target PDB ID is taken from Tab 1's pdb_input textbox.
            validate_btn.click(
                run_validation_with_view,
                inputs=[pdb_input, sequence_input],
                outputs=[refined_download, status2, protein_view]
            )
# Launch the app
if __name__ == "__main__":
    # Docker deployment for HuggingFace Spaces: bind all interfaces so the app
    # is reachable from outside the container. 7860 is the default port, but
    # Gradio's GRADIO_SERVER_PORT environment variable is honored when set.
    # Passing a hard-coded port would otherwise override that variable.
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
    )