PepTron / app.py
cfisicaro's picture
feat: add structure viz
c456acd
#!/usr/bin/env python
"""
Gradio interface for PepTron protein ensemble generation, with a 3Dmol.js viewer.
Multi-MODEL PDBs are shown as trajectory frames via viewer.addModelsAsFrames(),
which is the recommended way to visualize an ensemble in 3Dmol.js.
"""
import glob
import hashlib
import html
import json
import logging
import os
import shutil
import subprocess
import tempfile

import gradio as gr
from gradio import InputHTMLAttributes
# Cap OpenMP threads for this process; the pipeline subprocess receives a
# copy of this environment.
os.environ.update(OMP_NUM_THREADS="8")
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

# --- Configuration ---
INFERENCE_STEPS: int = 10  # forwarded as --config.inference.steps
MAX_BATCH_SIZE: int = 1  # forwarded as --config.inference.max_batch_size
NUM_WORKERS: int = 8  # forwarded as --config.inference.num_workers
CHECKPOINT_DIR: str = "/data/checkpoints/peptron-checkpoint"  # model checkpoint
REPO_DIR: str = "/app/PepTron"  # PepTron checkout; working dir of the pipeline
RESULTS_DIR: str = "/data/results"  # optional persistent copy of results
OUTPUTS_DIR: str = "outputs"  # relative to cwd -> /app/outputs in HF Spaces container
def cleanup_outputs(directory=None):
    """Empty the outputs directory, creating it first if it does not exist.

    Best-effort: entries that cannot be removed are logged and skipped, so a
    stubborn file never prevents the app from starting.

    Args:
        directory: Directory to empty. Defaults to ``OUTPUTS_DIR``.
    """
    if directory is None:
        directory = OUTPUTS_DIR
    os.makedirs(directory, exist_ok=True)
    # os.listdir (unlike glob "*") also returns dot-files.
    for name in os.listdir(directory):
        p = os.path.join(directory, name)
        try:
            # shutil.rmtree refuses symlinks: unlink a symlinked directory
            # instead of recursing into (and failing on) its target.
            if os.path.isdir(p) and not os.path.islink(p):
                shutil.rmtree(p)
            else:
                os.remove(p)
        except OSError as e:
            logger.warning(f"Could not delete {p}: {e}")
# Empty OUTPUTS_DIR once, at import time (i.e. on every process start), so
# stale files from a previous process are removed after a Space restart / rebuild.
cleanup_outputs()
# Bash pipeline run by peptron_inference(): inference -> PDB conversion ->
# filtering of unphysical frames. Positional arguments:
#   $1 checkpoint dir, $2 results dir, $3 input CSV, $4 samples,
#   $5 steps, $6 max batch size, $7 num workers.
BASH_SCRIPT_CONTENT = """#!/bin/bash
set -e
CKPT_PATH="$1"
RESULTS_PATH="$2"
CSV_FILE="$3"
NUM_SAMPLES="$4"
STEPS="$5"
BATCH_SIZE="$6"
WORKERS="$7"
export NCCL_TIMEOUT=3600
export TORCH_NCCL_ENABLE_MONITORING=0
export TORCHDYNAMO_SUPPRESS_ERRORS=1
export CUDA_LAUNCH_BLOCKING=1
# Put the PepTron repo on the module search path for the python -m calls below.
export PYTHONPATH=/app/PepTron
echo ">>> Inference..."
python -m peptron.infer \\
--config.inference.num_nodes 1 \\
--config.inference.checkpoint_path "$CKPT_PATH" \\
--config.inference.chains_path "$CSV_FILE" \\
--config.inference.results_path "$RESULTS_PATH" \\
--config.inference.num_gpus 1 \\
--config.inference.max_batch_size "$BATCH_SIZE" \\
--config.inference.num_workers "$WORKERS" \\
--config.inference.samples "$NUM_SAMPLES" \\
--config.inference.steps "$STEPS"
echo ">>> Convert..."
# Use half of the available cores, but never fewer than one
# (half of a single core rounds down to zero).
CONVERT_PROCS=$(( $(nproc) / 2 ))
if [ "$CONVERT_PROCS" -lt 1 ]; then CONVERT_PROCS=1; fi
python -m peptron.pt_to_structure -i "$RESULTS_PATH" \\
-o "$RESULTS_PATH/ensembles" \\
-p "$CONVERT_PROCS"
echo ">>> Filter..."
mkdir -p "$RESULTS_PATH/physical_ensembles"
for trajectory_file in "$RESULTS_PATH/ensembles/"*.pdb; do
[ -e "$trajectory_file" ] || continue
base_name=$(basename "$trajectory_file" .pdb)
output_file="$RESULTS_PATH/physical_ensembles/${base_name}_filtered.pdb"
python -m peptron.utils.filter_unphysical_traj --trajectory "$trajectory_file" --outfile "$output_file"
done
"""
def _read_text(path: str) -> str:
with open(path, "r", encoding="utf-8", errors="ignore") as f:
return f.read()
def pdb_to_3dmol_iframe(pdb_path: str, height: int = 520, autoplay: bool = False) -> str:
    """Build an ``<iframe srcdoc="...">`` that renders a PDB file with 3Dmol.js.

    Multi-MODEL PDBs are loaded as frames of one model via
    ``viewer.addModelsAsFrames()``. A toolbar with Play/Pause buttons and a
    frame slider steps through or animates the ensemble. Gradio has no native
    PDB viewer, hence the HTML + iframe + 3Dmol.js approach.

    Args:
        pdb_path: Path of the (possibly multi-MODEL) PDB file to display.
        height: Height of the viewer canvas in pixels. The iframe adds 46 px,
            presumably to fit the toolbar above the canvas.
        autoplay: Start the frame animation as soon as the page loads.

    Returns:
        The iframe markup as a string, suitable for ``gr.HTML``.
    """
    pdb = _read_text(pdb_path)
    # Embed the PDB as a JSON string literal instead of a JS template literal.
    # Backticks, "${", backslashes or control characters in the file can then
    # no longer break (or inject into) the script. Escaping "</" keeps a stray
    # "</script>" in the text from closing the <script> element early.
    pdb_js = json.dumps(pdb).replace("</", "<\\/")
    html_doc = f"""
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
<style>
html, body {{ margin:0; padding:0; }}
.toolbar {{
font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial;
font-size: 12px;
padding: 8px 10px;
border-bottom: 1px solid #eee;
display: flex;
gap: 10px;
align-items: center;
}}
.mol-container {{ width: 100%; height: {height}px; position: relative; }}
input[type="range"] {{ width: 240px; }}
</style>
</head>
<body>
<div class="toolbar">
<button id="btn_play">Play</button>
<button id="btn_pause">Pause</button>
<span>Frame:</span>
<input id="frame" type="range" min="0" max="0" value="0" step="1" />
<span id="frame_label">0</span>
</div>
<div id="container" class="mol-container"></div>
<script>
const pdb = {pdb_js};
const element = document.getElementById("container");
const viewer = $3Dmol.createViewer(element, {{ backgroundColor: "white" }});
// Multi-model PDB -> frames in one model (ensemble/trajectory). [web:42]
const model = viewer.addModelsAsFrames(pdb, "pdb");
viewer.setStyle({{}}, {{ cartoon: {{ color: "steelblue" }} }});
viewer.zoomTo();
viewer.render();
// Frame controls
const slider = document.getElementById("frame");
const label = document.getElementById("frame_label");
const nframes = (model && model.getNumFrames) ? model.getNumFrames() : 1;
slider.max = Math.max(0, nframes - 1);
slider.value = 0;
label.textContent = "0";
slider.addEventListener("input", (e) => {{
const f = parseInt(e.target.value);
label.textContent = String(f);
if (model && model.setFrame) model.setFrame(f);
viewer.render();
}});
document.getElementById("btn_play").addEventListener("click", () => {{
viewer.animate({{loop: "forward"}}); // animate frames forward [web:42]
}});
document.getElementById("btn_pause").addEventListener("click", () => {{
viewer.stopAnimate();
}});
if ({str(autoplay).lower()}) {{
viewer.animate({{loop: "forward"}});
}}
</script>
</body>
</html>
"""
    # srcdoc is an HTML attribute value. Escape &, <, >, " and ' so the browser
    # decodes it back to exactly html_doc. Replacing only '"' would let '&...'
    # sequences in the document be decoded as character references.
    srcdoc = html.escape(html_doc, quote=True)
    return f'<iframe style="width:100%; height:{height+46}px; border:0;" srcdoc="{srcdoc}"></iframe>'
def peptron_inference(sequence, num_samples):
    """Generate a PepTron conformational ensemble for one sequence.

    Writes a one-row input CSV and the pipeline script (BASH_SCRIPT_CONTENT)
    into a temporary directory and runs the script with cwd=REPO_DIR. It then
    copies the first resulting PDB, in sorted order, into OUTPUTS_DIR and,
    best-effort, into RESULTS_DIR. Filtered ensembles are preferred; the raw
    ensemble is used only when no filtered one exists.

    Args:
        sequence: Amino-acid sequence in one-letter code as typed in the UI.
            Case-insensitive; any whitespace (spaces, tabs, newlines) is ignored.
        num_samples: Number of ensemble members to request (UI slider value).

    Returns:
        ``(status, viewer_html, pdb_path)``. On failure ``status`` is an error
        message starting with "❌" and the other two entries are ``None``.
    """
    # 1. Validate environment. isdir (not exists) so listdir cannot raise on a file.
    if not os.path.isdir(CHECKPOINT_DIR) or not os.listdir(CHECKPOINT_DIR):
        return f"❌ Error: Checkpoint not found in {CHECKPOINT_DIR}.", None, None
    if not os.path.isdir(REPO_DIR):
        return (
            f"❌ Error: '{REPO_DIR}' directory not found. Please upload the PepTron folder to the Space.",
            None,
            None,
        )

    # 2. Validate input. str.split() drops *all* whitespace, including tabs and
    # the "\r" of pasted Windows line endings, not just spaces and "\n".
    clean_seq = "".join((sequence or "").split()).upper()
    valid_chars = set("ACDEFGHIKLMNPQRSTVWY")
    if not clean_seq or set(clean_seq) - valid_chars:
        return "❌ Error: Invalid amino acid sequence.", None, None
    # The slider value may arrive as a float (e.g. 5.0); pass the CLI an integer.
    try:
        n_samples = int(num_samples)
    except (TypeError, ValueError, OverflowError):
        n_samples = 0
    if n_samples < 1:
        return "❌ Error: Number of samples must be a positive integer.", None, None

    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            input_csv = os.path.join(tmpdir, "input.csv")
            out_dir = os.path.join(tmpdir, "out")
            script_path = os.path.join(tmpdir, "run.sh")
            os.makedirs(out_dir, exist_ok=True)

            with open(input_csv, "w", encoding="utf-8") as f:
                f.write("name,seqres\ntarget," + clean_seq + "\n")
            with open(script_path, "w", encoding="utf-8") as f:
                f.write(BASH_SCRIPT_CONTENT)
            os.chmod(script_path, 0o755)

            # 3. Run the pipeline (inference -> conversion -> filtering).
            cmd = [
                "bash",
                script_path,
                CHECKPOINT_DIR,
                out_dir,
                input_csv,
                str(n_samples),
                str(INFERENCE_STEPS),
                str(MAX_BATCH_SIZE),
                str(NUM_WORKERS),
            ]
            logger.info("Running inference...")
            try:
                res = subprocess.run(
                    cmd,
                    capture_output=True,
                    text=True,
                    errors="replace",  # undecodable tool output must not crash us
                    timeout=3600,
                    cwd=os.path.abspath(REPO_DIR),
                    env=os.environ.copy(),
                )
            except subprocess.TimeoutExpired:
                logger.error("PepTron pipeline timed out after 3600 s")
                return "❌ Failed: pipeline timed out after 3600 s.", None, None
            if res.returncode != 0:
                logger.error("PepTron pipeline failed (exit code %d):\n%s", res.returncode, res.stderr)
                return f"❌ Failed:\n{res.stderr}", None, None

            # 4. Collect output: prefer physically filtered ensembles and fall
            # back to the raw ones. Sorted so the choice is deterministic.
            phys_dir = os.path.join(out_dir, "physical_ensembles")
            raw_dir = os.path.join(out_dir, "ensembles")
            found = sorted(glob.glob(os.path.join(phys_dir, "*.pdb"))) or sorted(
                glob.glob(os.path.join(raw_dir, "*.pdb"))
            )
            if not found:
                return "❌ No PDB generated", None, None

            # File name derived from the sequence hash and length.
            seq_hash = hashlib.sha256(clean_seq.encode("utf-8")).hexdigest()
            output_filename = f"ensemble_{seq_hash[:16]}_L{len(clean_seq)}.pdb"
            os.makedirs(OUTPUTS_DIR, exist_ok=True)
            final_out = os.path.join(OUTPUTS_DIR, output_filename)
            # Copy out of tmpdir before the temporary directory is deleted.
            shutil.copy(found[0], final_out)

            # Optional: keep a copy in the persistent results dir (if mounted/writable).
            try:
                os.makedirs(RESULTS_DIR, exist_ok=True)
                shutil.copy(found[0], os.path.join(RESULTS_DIR, output_filename))
            except OSError as e:
                logger.warning(f"Could not copy to RESULTS_DIR: {e}")

            # 5. Viewer HTML (multi-model PDB shown as frames).
            viewer_html = pdb_to_3dmol_iframe(final_out, height=520, autoplay=False)
            return f"✅ Done! ({n_samples} samples)", viewer_html, final_out
    except Exception as e:
        logger.exception("PepTron inference failed")
        return f"❌ System Error: {str(e)}", None, None
# Synthetic 85-residue example sequence (standard one-letter codes only),
# put into the sequence textbox by the "Load example" button.
EXAMPLE_SEQ_100ISH = "".join(
    (
        "MSTNPKPQRKTKRNTNRRPQDVKPGGKKQTKK",
        "GDSAENLQKLRDNLVQRLKNNGVSVEKVTKELG",
        "ADKVEEMLAKLGADVVVVES",
    )
)


def load_example():
    """Return the built-in example sequence."""
    example = EXAMPLE_SEQ_100ISH
    return example
# Render the sequence textbox in a monospace font with slight letter spacing
# and ligatures disabled, so the one-letter residue codes stay legible.
CSS = "\n".join(
    [
        "",
        "#sequence_box textarea {",
        "font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas,",
        '"Liberation Mono", "DejaVu Sans Mono", "Courier New", monospace !important;',
        "letter-spacing: 0.04em;",
        "font-variant-ligatures: none;",
        "}",
        "",
    ]
)
# Two-column UI: sequence input, example/generate buttons and sample-count
# slider on the left; status, 3D viewer and downloadable PDB on the right.
with gr.Blocks(theme=gr.themes.Soft(), css=CSS) as demo:
    gr.Markdown("# 🧬 PepTron - Multidomain protein ensemble generation")
    with gr.Row():
        with gr.Column():
            # elem_id matches the "#sequence_box textarea" selector in CSS.
            seq = gr.Textbox(
                label="Sequence",
                lines=4,
                placeholder="Sequence...",
                elem_id="sequence_box",
                # Turn off browser spellcheck/autocorrect/autocapitalize/
                # autocomplete for the raw residue string.
                html_attributes=InputHTMLAttributes(
                    spellcheck=False,
                    autocorrect="off",
                    autocapitalize="off",
                    autocomplete="off",
                ),
            )
            with gr.Row():
                example_btn = gr.Button("Load example")
                btn = gr.Button("Generate", variant="primary")
            samps = gr.Slider(1, 50, value=5, label="Samples", step=1)
        with gr.Column():
            # Output components, in the order peptron_inference returns
            # (status message, viewer iframe HTML, PDB file path).
            stat = gr.Textbox(label="Status")
            mol = gr.HTML(label="3D view")
            out = gr.File(label="PDB")
    example_btn.click(load_example, inputs=None, outputs=seq)
    btn.click(peptron_inference, [seq, samps], [stat, mol, out])

if __name__ == "__main__":
    # Bind to all interfaces so the server is reachable from outside the container.
    demo.launch(server_name="0.0.0.0", server_port=7860)