Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| """ | |
| Gradio interface for PepTron protein ensemble generation (+ 3Dmol.js viewer). | |
| This version treats multi-MODEL PDBs as trajectory frames via viewer.addModelsAsFrames(), | |
| which is the recommended way to visualize an ensemble in 3Dmol.js. [web:42] | |
| """ | |
| import gradio as gr | |
| from gradio import InputHTMLAttributes | |
| import os | |
| import subprocess | |
| import tempfile | |
| import logging | |
| import shutil | |
| import glob | |
| import hashlib | |
# Pin the OpenMP thread count for this process. The inference subprocess is
# started with a copy of os.environ, so it inherits the same setting.
os.environ.update(OMP_NUM_THREADS="8")

# Root logging at INFO, plus the module-scoped logger used throughout the app.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Configuration ---
# Sampler settings, forwarded to the bash pipeline as positional arguments.
INFERENCE_STEPS: int = 10
MAX_BATCH_SIZE: int = 1
NUM_WORKERS: int = 8

# Filesystem layout inside the Space container.
CHECKPOINT_DIR: str = "/data/checkpoints/peptron-checkpoint"  # must exist and be non-empty
REPO_DIR: str = "/app/PepTron"  # PepTron checkout; used as the subprocess cwd
RESULTS_DIR: str = "/data/results"  # optional persistent copy of generated PDBs
OUTPUTS_DIR: str = "outputs"  # relative to the app's cwd -> /app/outputs in HF Spaces container
def cleanup_outputs():
    """Empty OUTPUTS_DIR, creating it first if it does not exist.

    Every entry directly inside OUTPUTS_DIR is removed, including dot-files.
    Sub-directories are deleted recursively. Symlinks (to files or
    directories) are unlinked rather than followed, so a link never causes
    its target to be deleted.

    This is best-effort: an entry that cannot be removed is logged as a
    warning and the sweep continues with the remaining entries.
    """
    os.makedirs(OUTPUTS_DIR, exist_ok=True)
    # Snapshot the listing first so entries are not removed mid-iteration.
    with os.scandir(OUTPUTS_DIR) as it:
        entries = list(it)
    for entry in entries:
        try:
            # follow_symlinks=False: a symlinked directory is treated as a
            # plain link (os.remove), since shutil.rmtree refuses symlinks.
            if entry.is_dir(follow_symlinks=False):
                shutil.rmtree(entry.path)
            else:
                os.remove(entry.path)
        except OSError as e:
            logger.warning(f"Could not delete {entry.path}: {e}")
# Clear OUTPUTS_DIR once when the module is imported, so files left behind by
# a previous process (e.g. before a Space restart or rebuild) are removed.
cleanup_outputs()
| # Bash script template | |
| BASH_SCRIPT_CONTENT = """#!/bin/bash | |
| set -e | |
| CKPT_PATH="$1" | |
| RESULTS_PATH="$2" | |
| CSV_FILE="$3" | |
| NUM_SAMPLES="$4" | |
| STEPS="$5" | |
| BATCH_SIZE="$6" | |
| WORKERS="$7" | |
| export NCCL_TIMEOUT=3600 | |
| export TORCH_NCCL_ENABLE_MONITORING=0 | |
| export TORCHDYNAMO_SUPPRESS_ERRORS=1 | |
| export CUDA_LAUNCH_BLOCKING=1 | |
| # PYTHONPATH=. tells Python to look for modules in the current directory (PepTron) | |
| export PYTHONPATH=/app/PepTron | |
| echo ">>> Inference..." | |
| python -m peptron.infer \\ | |
| --config.inference.num_nodes 1 \\ | |
| --config.inference.checkpoint_path "$CKPT_PATH" \\ | |
| --config.inference.chains_path "$CSV_FILE" \\ | |
| --config.inference.results_path "$RESULTS_PATH" \\ | |
| --config.inference.num_gpus 1 \\ | |
| --config.inference.max_batch_size "$BATCH_SIZE" \\ | |
| --config.inference.num_workers "$WORKERS" \\ | |
| --config.inference.samples "$NUM_SAMPLES" \\ | |
| --config.inference.steps "$STEPS" | |
| echo ">>> Convert..." | |
| python -m peptron.pt_to_structure -i "$RESULTS_PATH" \\ | |
| -o "$RESULTS_PATH/ensembles" \\ | |
| -p $(($(nproc) / 2)) | |
| echo ">>> Filter..." | |
| mkdir -p "$RESULTS_PATH/physical_ensembles" | |
| for trajectory_file in "$RESULTS_PATH/ensembles/"*.pdb; do | |
| [ -e "$trajectory_file" ] || continue | |
| base_name=$(basename "$trajectory_file" .pdb) | |
| output_file="$RESULTS_PATH/physical_ensembles/${base_name}_filtered.pdb" | |
| python -m peptron.utils.filter_unphysical_traj --trajectory "$trajectory_file" --outfile "$output_file" | |
| done | |
| """ | |
| def _read_text(path: str) -> str: | |
| with open(path, "r", encoding="utf-8", errors="ignore") as f: | |
| return f.read() | |
def pdb_to_3dmol_iframe(pdb_path: str, height: int = 520, autoplay: bool = False) -> str:
    """Build an ``<iframe srcdoc="...">`` that renders a PDB file with 3Dmol.js.

    Multi-MODEL PDBs are loaded with ``viewer.addModelsAsFrames()``, so each
    MODEL becomes one frame of a single model. A toolbar provides Play/Pause
    buttons and a frame slider.

    Args:
        pdb_path: Path of the PDB file to display.
        height: Height in pixels of the molecule canvas. The iframe is 46 px
            taller to leave room for the toolbar.
        autoplay: If true, start the frame animation when the page loads.

    Returns:
        An HTML string containing a single ``<iframe>`` element (for ``gr.HTML``).
    """
    import html
    import json

    # Embed the PDB as a JSON string literal, which is also valid JavaScript.
    # Backticks, "${" and backslashes in the file therefore cannot break the
    # script. Every "<" is written as \u003c so text such as "</script>" or
    # "<!--" inside the PDB cannot terminate or confuse the <script> element.
    pdb_js = json.dumps(_read_text(pdb_path)).replace("<", "\\u003c")
    autoplay_js = "true" if autoplay else "false"

    html_doc = f"""<!DOCTYPE html>
<html>
<head>
  <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
  <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
  <style>
    html, body {{ margin: 0; padding: 0; }}
    .toolbar {{
      font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial;
      font-size: 12px;
      padding: 8px 10px;
      border-bottom: 1px solid #eee;
      display: flex;
      gap: 10px;
      align-items: center;
    }}
    .mol-container {{ width: 100%; height: {height}px; position: relative; }}
    input[type="range"] {{ width: 240px; }}
  </style>
</head>
<body>
  <div class="toolbar">
    <button id="btn_play">Play</button>
    <button id="btn_pause">Pause</button>
    <span>Frame:</span>
    <input id="frame" type="range" min="0" max="0" value="0" step="1" />
    <span id="frame_label">0</span>
  </div>
  <div id="container" class="mol-container"></div>
  <script>
    const pdb = {pdb_js};
    const element = document.getElementById("container");
    const viewer = $3Dmol.createViewer(element, {{ backgroundColor: "white" }});
    // Multi-model PDB -> frames of a single model (ensemble/trajectory).
    const model = viewer.addModelsAsFrames(pdb, "pdb");
    viewer.setStyle({{}}, {{ cartoon: {{ color: "steelblue" }} }});
    viewer.zoomTo();
    viewer.render();

    // Frame controls
    const slider = document.getElementById("frame");
    const label = document.getElementById("frame_label");
    const nframes = (model && model.getNumFrames) ? model.getNumFrames() : 1;
    slider.max = Math.max(0, nframes - 1);
    slider.value = 0;
    label.textContent = "0";
    slider.addEventListener("input", (e) => {{
      const f = parseInt(e.target.value, 10);
      label.textContent = String(f);
      if (model && model.setFrame) model.setFrame(f);
      viewer.render();
    }});
    document.getElementById("btn_play").addEventListener("click", () => {{
      viewer.animate({{ loop: "forward" }});
    }});
    document.getElementById("btn_pause").addEventListener("click", () => {{
      viewer.stopAnimate();
    }});
    if ({autoplay_js}) {{
      viewer.animate({{ loop: "forward" }});
    }}
  </script>
</body>
</html>
"""
    # srcdoc is an HTML attribute value: escape &, <, >, " and ' so the
    # browser's attribute decoding yields exactly html_doc again.
    srcdoc = html.escape(html_doc, quote=True)
    return f'<iframe style="width:100%; height:{height + 46}px; border:0;" srcdoc="{srcdoc}"></iframe>'
# One-letter codes of the 20 standard amino acids accepted as input.
_VALID_AMINO_ACIDS = frozenset("ACDEFGHIKLMNPQRSTVWY")
# Upper bound on how much pipeline output is echoed into the status box.
_MAX_ERROR_CHARS = 4000
# Wall-clock limit, in seconds, for the whole bash pipeline.
_PIPELINE_TIMEOUT_S = 3600


def _clean_sequence(sequence) -> str:
    """Return *sequence* upper-cased with all whitespace removed ("" for None)."""
    return "".join((sequence or "").split()).upper()


def peptron_inference(sequence, num_samples):
    """Run the PepTron pipeline for one sequence and prepare results for the UI.

    Args:
        sequence: Amino-acid sequence in one-letter codes. Whitespace anywhere
            (spaces, tabs, Windows/Unix line breaks) is ignored and case is
            normalized. None is treated as empty.
        num_samples: Number of ensemble members to generate; coerced to int
            and required to be >= 1.

    Returns:
        A ``(status, viewer_html, pdb_path)`` tuple matching the Gradio outputs.
        On any failure ``viewer_html`` and ``pdb_path`` are None and ``status``
        starts with "❌".
    """
    # 1. Validate environment. isdir (rather than exists) keeps os.listdir
    # from raising when the path is a regular file.
    if not os.path.isdir(CHECKPOINT_DIR) or not os.listdir(CHECKPOINT_DIR):
        return f"❌ Error: Checkpoint not found in {CHECKPOINT_DIR}.", None, None
    if not os.path.isdir(REPO_DIR):
        return (
            f"❌ Error: '{REPO_DIR}' directory not found. Please upload the PepTron folder to the Space.",
            None,
            None,
        )

    # 2. Validate input
    clean_seq = _clean_sequence(sequence)
    if not clean_seq or not set(clean_seq) <= _VALID_AMINO_ACIDS:
        return "❌ Error: Invalid amino acid sequence.", None, None
    # Sliders may deliver floats (5.0); the CLI expects an integer count.
    try:
        n_samples = int(num_samples)
    except (TypeError, ValueError, OverflowError):
        n_samples = 0
    if n_samples < 1:
        return "❌ Error: Number of samples must be a positive integer.", None, None

    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            input_csv = os.path.join(tmpdir, "input.csv")
            out_dir = os.path.join(tmpdir, "out")
            script_path = os.path.join(tmpdir, "run.sh")
            os.makedirs(out_dir, exist_ok=True)
            with open(input_csv, "w") as f:
                f.write("name,seqres\ntarget," + clean_seq + "\n")
            with open(script_path, "w") as f:
                f.write(BASH_SCRIPT_CONTENT)
            os.chmod(script_path, 0o755)

            # 3. Run inference (argument order must match the script's $1..$7).
            cmd = [
                "bash",
                script_path,
                CHECKPOINT_DIR,
                out_dir,
                input_csv,
                str(n_samples),
                str(INFERENCE_STEPS),
                str(MAX_BATCH_SIZE),
                str(NUM_WORKERS),
            ]
            logger.info("Running inference...")
            res = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=_PIPELINE_TIMEOUT_S,
                cwd=os.path.abspath(REPO_DIR),
                env=os.environ.copy(),
            )
            if res.returncode != 0:
                logger.error(
                    "PepTron pipeline failed (exit code %d):\n%s", res.returncode, res.stderr
                )
                # Show only the tail in the UI; the full output is in the log.
                details = (res.stderr or res.stdout or "")[-_MAX_ERROR_CHARS:]
                return f"❌ Failed:\n{details}", None, None

            # 4. Collect outputs: prefer filtered ensembles, fall back to raw ones.
            phys_dir = os.path.join(out_dir, "physical_ensembles")
            raw_dir = os.path.join(out_dir, "ensembles")
            found = sorted(glob.glob(os.path.join(phys_dir, "*.pdb"))) or sorted(
                glob.glob(os.path.join(raw_dir, "*.pdb"))
            )
            if not found:
                return "❌ No PDB generated", None, None

            # Name the file after the sequence content and length.
            seq_hash = hashlib.sha256(clean_seq.encode("utf-8")).hexdigest()
            output_filename = f"ensemble_{seq_hash[:16]}_L{len(clean_seq)}.pdb"
            final_out = os.path.join(OUTPUTS_DIR, output_filename)
            os.makedirs(OUTPUTS_DIR, exist_ok=True)
            # Copy out of tmpdir before it is deleted at the end of the with-block.
            shutil.copy(found[0], final_out)

            # Optional persistent copy (only succeeds if RESULTS_DIR is writable).
            try:
                os.makedirs(RESULTS_DIR, exist_ok=True)
                shutil.copy(found[0], os.path.join(RESULTS_DIR, output_filename))
            except OSError as e:
                logger.warning(f"Could not copy to RESULTS_DIR: {e}")

            # 5. Viewer HTML (multi-model PDB handled as frames).
            viewer_html = pdb_to_3dmol_iframe(final_out, height=520, autoplay=False)
            return f"✅ Done! ({n_samples} samples)", viewer_html, final_out
    except subprocess.TimeoutExpired:
        logger.error("PepTron pipeline timed out after %d s", _PIPELINE_TIMEOUT_S)
        return f"❌ Timed out after {_PIPELINE_TIMEOUT_S} s.", None, None
    except Exception as e:
        # Top-level UI boundary: report the error instead of crashing the event.
        logger.exception("Unexpected error during PepTron inference")
        return f"❌ System Error: {str(e)}", None, None
# Synthetic example sequence: 85 residues (the "100ish" name is approximate),
# using only the 20 standard one-letter amino-acid codes.
EXAMPLE_SEQ_100ISH = (
    "MSTNPKPQRKTKRNTNRRPQDVKPGGKKQTKKGDSAENLQKL"
    "RDNLVQRLKNNGVSVEKVTKELGADKVEEMLAKLGADVVVVES"
)
def load_example():
    """Return the bundled example sequence (used by the "Load example" button)."""
    example = EXAMPLE_SEQ_100ISH
    return example
| CSS = """ | |
| #sequence_box textarea { | |
| font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, | |
| "Liberation Mono", "DejaVu Sans Mono", "Courier New", monospace !important; | |
| letter-spacing: 0.04em; | |
| font-variant-ligatures: none; | |
| } | |
| """ | |
# --- User interface ---
# Left column: sequence input, buttons and sample-count slider.
# Right column: status text, 3Dmol.js viewer and the generated PDB file.
with gr.Blocks(theme=gr.themes.Soft(), css=CSS) as demo:
    gr.Markdown("# 🧬 PepTron - Multidomain protein ensemble generation")
    with gr.Row():
        with gr.Column():
            seq = gr.Textbox(
                label="Sequence",
                lines=4,
                placeholder="Sequence...",
                elem_id="sequence_box",  # targeted by the CSS rules
                # Turn off browser text assistance, which could alter residue strings.
                html_attributes=InputHTMLAttributes(
                    spellcheck=False,
                    autocorrect="off",
                    autocapitalize="off",
                    autocomplete="off",
                ),
            )
            with gr.Row():
                example_btn = gr.Button("Load example")
                btn = gr.Button("Generate", variant="primary")
            samps = gr.Slider(minimum=1, maximum=50, value=5, step=1, label="Samples")
        with gr.Column():
            stat = gr.Textbox(label="Status")
            mol = gr.HTML(label="3D view")
            out = gr.File(label="PDB")

    # Event wiring
    example_btn.click(fn=load_example, inputs=None, outputs=seq)
    btn.click(fn=peptron_inference, inputs=[seq, samps], outputs=[stat, mol, out])
if __name__ == "__main__":
    # Listen on all interfaces; 7860 is Gradio's default port.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)