ProteinTTT / app.py
pimenol
Update Google Colab link in app.py to use HTML anchor tag for better accessibility and user experience.
3f6c226
"""
ProteinTTT Demo Space
Customize ESMFold with test-time training for improved protein structure prediction.
"""
import importlib
import importlib.util
import logging
import os
import subprocess
import sys
import tempfile
import time
import traceback
def _pip_install(*args):
    """Run ``pip install`` for the current interpreter in a subprocess.

    pip's output goes straight to this process's stdout/stderr, so it shows up
    in the Space logs for debugging.

    Args:
        *args: Arguments forwarded verbatim to ``pip install`` (flags and
            requirement specifiers).

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    cmd = [sys.executable, "-m", "pip", "install", *args]
    # flush=True: the child writes directly to the inherited stdout fd, so an
    # unflushed header could appear *after* pip's output when stdout is a pipe.
    print(f"[INSTALL] {' '.join(cmd)}", flush=True)
    subprocess.check_call(cmd)
    # Let the import system see packages that were just installed in this process.
    importlib.invalidate_caches()
# --- openfold ---
# Install openfold on first start-up: clone a pinned commit, patch its setup.py
# so it builds for a fixed set of CUDA architectures, then pip-install it.
try:
    import openfold  # noqa: F401
    print("[INFO] openfold already installed.")
except ImportError:
    print("[INFO] Installing openfold (one-time, may take a few minutes)...", flush=True)
    import re
    import shutil

    # 1) dllogger (pure-Python, no patching needed)
    _pip_install("--no-build-isolation",
                 "dllogger @ git+https://github.com/NVIDIA/dllogger.git")

    # 2) Clone openfold so we can patch its setup.py before building.
    _of_dir = os.path.join(tempfile.gettempdir(), "_openfold_build")
    if os.path.exists(_of_dir):
        shutil.rmtree(_of_dir)
    try:
        subprocess.check_call([
            "git", "clone", "--filter=blob:none", "--quiet",
            "https://github.com/aqlaboratory/openfold.git", _of_dir,
        ])
        # Pin an exact commit so the setup.py patches below stay valid.
        subprocess.check_call(
            ["git", "checkout", "-q", "4b41059694619831a7db195b7e0988fc4ff3a307"],
            cwd=_of_dir,
        )

        # 3) Patch setup.py: drop the get_nvidia_cc import and replace the
        #    detected compute capabilities with a fixed set (7.0/8.0/8.6/9.0).
        _setup_py = os.path.join(_of_dir, "setup.py")
        with open(_setup_py, encoding="utf-8") as _f:
            _src = _f.read()
        _src = _src.replace(
            "from scripts.utils import get_nvidia_cc",
            "# from scripts.utils import get_nvidia_cc # patched out",
        )
        _src, _n_cc_subs = re.subn(
            r"compute_capabilities\s*=\s*set\(\[.*?set\(\[compute_capability\]\)",
            "compute_capabilities = set([(7, 0), (8, 0), (8, 6), (9, 0)])",
            _src,
            flags=re.DOTALL,
        )
        if _n_cc_subs == 0:
            # Surface a silent no-op patch instead of failing obscurely later.
            print("[WARN] compute_capabilities patch did not apply to openfold "
                  "setup.py; the build may fail.", flush=True)
        _src = _src.replace("-std=c++14", "-std=c++17")
        with open(_setup_py, "w", encoding="utf-8") as _f:
            _f.write(_src)

        # 4) CUDA architectures
        os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0;8.0;8.6;9.0"

        # 5) Install the patched openfold
        _pip_install("--no-build-isolation", _of_dir)
    finally:
        # Remove the build tree whether or not the build succeeded.
        shutil.rmtree(_of_dir, ignore_errors=True)

    # Verify
    import openfold  # noqa: F401
    print("[INFO] openfold installed and verified.")
# --- lora-diffusion ---
# Install lora-diffusion without its dependencies, then overwrite the package
# __init__ so that importing it only loads the ``.lora`` submodule.
# NOTE(review): presumably the upstream __init__ imports modules whose
# dependencies are skipped by --no-deps — confirm before changing this.
if importlib.util.find_spec("lora_diffusion") is not None:
    print("[INFO] lora-diffusion already installed.")
else:
    print("[INFO] Installing lora-diffusion (--no-deps --no-build-isolation)...")
    _pip_install(
        "--no-deps",
        "--no-build-isolation",
        "lora-diffusion @ git+https://github.com/cloneofsimo/lora.git",
    )
    _lora_pkg_dir = os.path.dirname(importlib.util.find_spec("lora_diffusion").origin)
    with open(os.path.join(_lora_pkg_dir, "__init__.py"), "w") as _f:
        _f.write("from .lora import *\n")
    # Smoke-test that the symbol the app relies on now imports cleanly.
    from lora_diffusion.lora import inject_trainable_lora as _check  # noqa: F401
    del _check
    print("[INFO] lora-diffusion installed, patched, and verified.")
import spaces
import gradio as gr
import torch
print("[INFO] Loading ESMFold model on CPU (this may take a minute)...")
import esm as _esm_module  # noqa: E402

# Build ESMFold once at import time, in eval mode, and keep it on the CPU.
# predict_structure moves it to the GPU for a request and back to CPU afterwards.
_BASE_MODEL = _esm_module.pretrained.esmfold_v1().eval()
# NOTE(review): chunk size presumably bounds peak memory of the folding trunk
# on long sequences — confirm against the esm documentation.
_BASE_MODEL.set_chunk_size(128)
print("[INFO] ESMFold model loaded on CPU.")
def get_plddt_from_output(output) -> float:
    """Return the ``"mean_plddt"`` entry of an ESMFold output as a Python float.

    The entry must be a single-element tensor; ``.item()`` raises otherwise.
    """
    mean_plddt = output["mean_plddt"]
    return mean_plddt.item()
def create_3dmol_html(pdb_content: str, title: str = "") -> str:
    """Wrap a PDB structure in a self-contained 3Dmol.js viewer iframe.

    The viewer page is embedded as a base64 ``data:`` URI, which is more
    reliable in sandboxed environments. Cartoon residues are colored by each
    atom's B-factor (``atom.b``) using the pLDDT bands shown in the legend
    (<50, 50-70, 70-90, >=90). This assumes the PDB's B-factor column holds
    pLDDT on a 0-100 scale.

    Args:
        pdb_content: PDB-format text. It is injected into a JavaScript
            template literal, so backslashes, backticks and ``$`` are escaped.
        title: Caption shown above the viewer (HTML-escaped).

    Returns:
        An ``<iframe>`` HTML string.
    """
    import base64
    import html as html_lib

    # One pass over the text, escaping every character that is special inside
    # a JS template literal (``$`` would otherwise start a ``${...}`` substitution).
    template_literal_escapes = str.maketrans({"\\": "\\\\", "`": "\\`", "$": "\\$"})
    pdb_escaped = pdb_content.translate(template_literal_escapes)
    title_escaped = html_lib.escape(title)

    page = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body {{
margin: 0;
padding: 0;
font-family: "IBM Plex Sans", "Inter", "Helvetica Neue", Arial, sans-serif;
}}
.mol-container {{ width: 100%; height: 450px; position: relative; }}
.title {{ text-align: center; font-size: 14px; font-weight: bold; padding: 5px; background: #f0f0f0; }}
.legend {{
position: absolute;
right: 12px;
bottom: 12px;
z-index: 10;
background: rgba(255, 255, 255, 0.92);
border: 1px solid #d9d9d9;
border-radius: 8px;
padding: 8px 10px;
font-size: 12px;
color: #333;
line-height: 1.25;
box-shadow: 0 1px 4px rgba(0, 0, 0, 0.15);
}}
.legend-title {{
font-weight: 700;
margin-bottom: 6px;
}}
.legend-row {{
display: flex;
align-items: center;
gap: 6px;
margin: 2px 0;
white-space: nowrap;
}}
.legend-color {{
width: 12px;
height: 12px;
border-radius: 2px;
border: 1px solid rgba(0, 0, 0, 0.12);
flex: none;
}}
</style>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div class="title">{title_escaped}</div>
<div id="container" class="mol-container">
<div class="legend">
<div class="legend-title">pLDDT</div>
<div class="legend-row"><span class="legend-color" style="background:#0d57d3;"></span><span>&ge; 90 (very high)</span></div>
<div class="legend-row"><span class="legend-color" style="background:#6acbf1;"></span><span>70 - 90 (confident)</span></div>
<div class="legend-row"><span class="legend-color" style="background:#fed936;"></span><span>50 - 70 (low)</span></div>
<div class="legend-row"><span class="legend-color" style="background:#fd7d4d;"></span><span>&lt; 50 (very low)</span></div>
</div>
</div>
<script>
let pdb = `{pdb_escaped}`;
let viewer = $3Dmol.createViewer("container", {{backgroundColor: "white"}});
viewer.addModel(pdb, "pdb");
viewer.setStyle({{}}, {{
cartoon: {{
colorfunc: function(atom) {{
if (atom.b < 50) return "#fd7d4d"; // very low
if (atom.b < 70) return "#fed936"; // low
if (atom.b < 90) return "#6acbf1"; // confident
return "#0d57d3"; // very high
}}
}}
}});
viewer.zoomTo();
viewer.render();
</script>
</body>
</html>"""

    # Embed the page as a base64 data URI rather than via srcdoc.
    encoded_page = base64.b64encode(page.encode('utf-8')).decode('ascii')
    frame_style = "width: 100%; height: 500px; border: 1px solid #ddd; border-radius: 8px;"
    return f'<iframe style="{frame_style}" src="data:text/html;base64,{encoded_page}"></iframe>'
# Main Prediction Logic
def _progress_html(fraction: float, desc: str) -> str:
"""Return a small HTML progress bar."""
pct = int(fraction * 100)
return f"""
<div style="padding: 8px 0;">
<div style="font-size: 14px; margin-bottom: 4px; color: #ccc;">{desc}{pct}%</div>
<div style="background: #333; border-radius: 6px; height: 18px; width: 100%; overflow: hidden;">
<div style="background: linear-gradient(90deg, #22c55e, #16a34a); height: 100%; width: {pct}%;
border-radius: 6px; transition: width 0.3s ease;"></div>
</div>
</div>"""
@spaces.GPU(duration=120)
def predict_structure(
    sequence: str,
    ttt_steps: int = 10,
    learning_rate: float = 1e-3,
    batch_size: int = 1,
    lora_rank: int = 4,
    lora_alpha: int = 8,
    seed: int = 0,
):
    """
    Predict protein structure with and without ProteinTTT.

    This is a generator. Each ``yield`` emits the full tuple of UI outputs, so
    the page updates progressively: baseline pLDDT, improved pLDDT, improvement
    ratio, 3D visualizations for both, downloadable PDB paths, the progress
    bar, and the run log.

    Args:
        sequence: Amino-acid sequence. Whitespace (e.g. line breaks from a
            pasted multi-line sequence) is removed and letters are upper-cased.
            Must be 10-400 residues from the standard 20 amino acids.
        ttt_steps: Number of test-time-training steps (coerced to ``int``).
        learning_rate: TTT learning rate; must be positive.
        batch_size: TTT batch size (coerced to ``int``).
        lora_rank: LoRA adapter rank (coerced to ``int``).
        lora_alpha: LoRA scaling factor.
        seed: Random seed written to the TTT config (``None`` -> 0).

    Yields:
        ``(baseline_plddt_text, ttt_plddt_text, improvement_text,
        baseline_html, ttt_html, baseline_pdb_file, ttt_pdb_file,
        progress_status, logs)``.

    Raises:
        gr.Error: On invalid input, or wrapping any failure during prediction.
    """
    import copy
    import html as html_lib

    baseline_plddt_text = ""
    ttt_plddt_text = ""
    improvement_text = ""
    baseline_html = ""
    ttt_html = ""
    baseline_pdb_file = None
    ttt_pdb_file = None
    progress_status = ""
    logs = []
    ttt_model = None

    def current_result():
        """Snapshot all outputs, in the order of the click handler's outputs list."""
        return (
            baseline_plddt_text,
            ttt_plddt_text,
            improvement_text,
            baseline_html,
            ttt_html,
            baseline_pdb_file,
            ttt_pdb_file,
            progress_status,
            logs,
        )

    def add_log(stage: str, message: str):
        """Append a ``[time, stage, message]`` row to the UI log and echo it to stdout."""
        timestamp = time.strftime("%H:%M:%S")
        logs.append([timestamp, stage, message])
        print(f"[{timestamp}] [{stage}] {message}")

    def write_temp_pdb(pdb_text: str, suffix: str) -> str:
        """Write ``pdb_text`` to a temp file that outlives this call; return its path."""
        with tempfile.NamedTemporaryFile(mode='w', suffix=suffix, delete=False) as handle:
            handle.write(pdb_text)
        return handle.name

    try:
        # ---- Input validation ----
        # Remove all whitespace so multi-line pastes of a sequence are accepted.
        sequence = "".join((sequence or "").split()).upper()
        if len(sequence) < 10:
            raise gr.Error("Please enter a valid protein sequence (at least 10 amino acids).")
        valid_aa = set("ACDEFGHIKLMNPQRSTVWY")
        if not all(aa in valid_aa for aa in sequence):
            raise gr.Error("Sequence contains invalid amino acids. Use only standard 20 amino acids.")
        if len(sequence) > 400:
            raise gr.Error("Sequence too long. Maximum 400 amino acids supported for this demo.")
        if learning_rate is None or learning_rate <= 0:
            raise gr.Error("Learning rate must be a positive number.")
        # Gradio sliders/number boxes may deliver floats; counts must be ints.
        ttt_steps = int(ttt_steps)
        batch_size = int(batch_size)
        lora_rank = int(lora_rank)
        seed = int(seed) if seed is not None else 0

        device = "cuda" if torch.cuda.is_available() else "cpu"
        add_log("Init", f"Using device: {device}")
        progress_status = _progress_html(0.1, "Moving model to GPU…")
        base_model = _BASE_MODEL
        # Move model to GPU for inference (moved back to CPU in `finally`).
        base_model.to(device)
        add_log("Init", f"Model moved to {device}.")
        yield current_result()

        # ---- ESMFold Prediction ----
        progress_status = _progress_html(0.2, "Running ESMFold prediction…")
        add_log("Baseline", "Running ESMFold prediction...")
        yield current_result()
        with torch.no_grad():
            baseline_output = base_model.infer(sequence)
            baseline_pdb = base_model.output_to_pdb(baseline_output)[0]
        baseline_plddt = get_plddt_from_output(baseline_output)
        baseline_plddt_text = f"{baseline_plddt:.2f}"
        add_log("Baseline", f"Baseline pLDDT: {baseline_plddt:.2f}")

        # ---- ESMFold + ProteinTTT ----
        progress_status = _progress_html(0.4, "Loading ESMFold model…")
        add_log("TTT", "Loading ESMFold model...")
        yield current_result()
        from proteinttt.models.esmfold import ESMFoldTTT, DEFAULT_ESMFOLD_TTT_CFG

        # Configure TTT on a private copy: mutating the library-level default
        # would leak this request's settings into the shared module object.
        ttt_cfg = copy.deepcopy(DEFAULT_ESMFOLD_TTT_CFG)
        ttt_cfg.steps = ttt_steps
        ttt_cfg.lr = learning_rate
        ttt_cfg.batch_size = batch_size
        ttt_cfg.lora_rank = lora_rank
        ttt_cfg.lora_alpha = lora_alpha
        ttt_cfg.seed = seed  # honour the user's seed (was hard-coded to 0)
        # ProteinTTT upstream bug: eval_each_step=False can trigger
        # UnboundLocalError("eval_step_preds referenced before assignment").
        ttt_cfg.eval_each_step = True
        add_log(
            "TTT",
            f"Config: LR={learning_rate}, Batch={batch_size}, LoRA rank={lora_rank}, "
            f"alpha={lora_alpha}, Seed={seed}",
        )

        # Apply TTT
        ttt_model = ESMFoldTTT.ttt_from_pretrained(
            base_model,
            ttt_cfg=ttt_cfg,
            esmfold_config=base_model.cfg
        )
        progress_status = _progress_html(0.5, "Running test-time training…")
        add_log("ProteinTTT", "Running test-time training...")
        yield current_result()

        # Attach a handler to capture step-by-step log messages from ttt()
        captured_logs = []

        class _ListHandler(logging.Handler):
            """Handler that collects formatted log records into a list."""

            def emit(self, record):
                try:
                    captured_logs.append(self.format(record))
                except Exception:
                    # Standard logging behaviour: report, but never raise.
                    self.handleError(record)

        list_handler = _ListHandler()
        list_handler.setLevel(logging.INFO)
        list_handler.setFormatter(logging.Formatter("%(message)s"))
        ttt_model.ttt_logger.addHandler(list_handler)
        try:
            # Run TTT synchronously (must stay on the main thread for ZeroGPU)
            ttt_model.ttt(sequence)
        finally:
            # Always detach, so a failed run doesn't leave the handler attached.
            ttt_model.ttt_logger.removeHandler(list_handler)

        # Replay captured messages so the UI shows each one with progress.
        total_steps = max(ttt_steps, 1)
        for i, msg in enumerate(captured_logs, 1):
            frac = 0.5 + 0.3 * min(i / total_steps, 1.0)
            # There may be more log lines than steps; never show "step 12/10".
            shown_step = min(i, total_steps)
            progress_status = _progress_html(frac, f"ProteinTTT step {shown_step}/{ttt_steps}")
            add_log("ProteinTTT", msg)
            yield current_result()
        add_log("ProteinTTT", "Test-time training finished.")
        yield current_result()

        progress_status = _progress_html(0.8, "Running ESMFold+ProteinTTT prediction…")
        add_log("ProteinTTT", "Running ESMFold+ProteinTTT prediction...")
        yield current_result()
        with torch.no_grad():
            ttt_output = ttt_model.infer(sequence)
            ttt_pdb = ttt_model.output_to_pdb(ttt_output)[0]
        ttt_plddt = get_plddt_from_output(ttt_output)
        ttt_plddt_text = f"{ttt_plddt:.2f}"
        add_log("ProteinTTT", f"ProteinTTT pLDDT: {ttt_plddt:.2f}")

        # ---- Prepare Results ----
        progress_status = _progress_html(0.9, "Preparing visualization…")
        add_log("Result", "Preparing visualizations and artifacts...")
        improvement_ratio = ttt_plddt / baseline_plddt if baseline_plddt > 0 else 0
        improvement_text = f"{improvement_ratio:.2f}x"

        try:
            baseline_html = create_3dmol_html(baseline_pdb, f"ESMFold (pLDDT: {baseline_plddt:.1f})")
            print(f"[DEBUG] Created baseline HTML, length: {len(baseline_html)}")
            print(f"[DEBUG] Baseline HTML starts with: {baseline_html[:100]}")
        except Exception as e:
            print(f"[ERROR] Failed to create baseline HTML: {e}")
            traceback.print_exc()
            baseline_html = f"<div>Error creating visualization: {html_lib.escape(str(e))}</div>"

        # Save PDB files for download
        baseline_pdb_file = write_temp_pdb(baseline_pdb, '_esmfold.pdb')
        add_log("Baseline", "ESMFold PDB saved.")
        yield current_result()

        try:
            ttt_html = create_3dmol_html(ttt_pdb, f"ESMFold + ProteinTTT (pLDDT: {ttt_plddt:.1f})")
            print(f"[DEBUG] Created ProteinTTT HTML, length: {len(ttt_html)}")
        except Exception as e:
            print(f"[ERROR] Failed to create ProteinTTT HTML: {e}")
            traceback.print_exc()
            ttt_html = f"<div>Error creating visualization: {html_lib.escape(str(e))}</div>"
        ttt_pdb_file = write_temp_pdb(ttt_pdb, '_esmfold_proteinttt.pdb')

        progress_status = ""  # Clear progress bar on completion
        add_log("Done", "Prediction completed successfully.")

        # Debug output
        print(f"[DEBUG] baseline_plddt_text: {baseline_plddt_text}")
        print(f"[DEBUG] ttt_plddt_text: {ttt_plddt_text}")
        print(f"[DEBUG] improvement_text: {improvement_text}")
        print(f"[DEBUG] baseline_html length: {len(baseline_html) if baseline_html else 0}")
        print(f"[DEBUG] ttt_html length: {len(ttt_html) if ttt_html else 0}")
        print(f"[DEBUG] baseline_pdb_file: {baseline_pdb_file}")
        print(f"[DEBUG] ttt_pdb_file: {ttt_pdb_file}")
        yield current_result()
    except gr.Error:
        raise  # Re-raise Gradio errors as-is
    except Exception as e:
        error_msg = f"{type(e).__name__}: {e}"
        error_tb = traceback.format_exc()
        print(f"[ERROR] {error_msg}\n{error_tb}")
        add_log("Error", error_msg)
        # Push the error row to the log table before failing the request.
        yield current_result()
        raise gr.Error(f"Prediction failed: {error_msg}") from e
    finally:
        # Reset TTT state and release memory (best-effort; never mask the real error).
        if ttt_model is not None:
            try:
                ttt_model.ttt_reset()
            except Exception as reset_err:
                print(f"[WARN] ttt_reset failed: {reset_err}")
        del ttt_model
        # Move model back to CPU so it survives GPU release.
        try:
            _BASE_MODEL.to("cpu")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception as cleanup_err:
            print(f"[WARN] GPU cleanup failed: {cleanup_err}")
# Gradio UI
# Example input offered via gr.Examples below.
DEMO_SEQUENCE = "GIHLGELGLLPSTVLAIGYFENLVNIICESLNMLPKLEVSGKEYKKFKFTIVIPKDLDANIKKRAKIYFKQKSLIEIEIPTSSRNYPIHIQFDENSTDDILHLYDMPTTIGGIDKAIEMFMRKGHIGKTDQQKLLEERELRNFKTTLENLIATDAFAKEMVEVIIEE"

with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="green",
        font=["Source Sans Pro", "ui-sans-serif", "system-ui", "-apple-system", "BlinkMacSystemFont", "Segoe UI", "Roboto", "Helvetica Neue", "Arial", "sans-serif"],
        font_mono=["IBM Plex Mono", "ui-monospace", "Consolas", "monospace"],
    )
) as demo:
    # Header. The blank line keeps the links in their own Markdown paragraph;
    # rel="noopener noreferrer" stops the opened tab from accessing window.opener.
    gr.Markdown("""
# ProteinTTT
### Customize ESMFold with Test-Time Training for Enhanced Structure Prediction
ProteinTTT enables customizing protein language models to one protein at a time for enhanced
performance on challenging targets.

📄 [Paper](https://arxiv.org/abs/2411.02109) | 💻 [GitHub](https://github.com/anton-bushuiev/ProteinTTT) | <a href="https://colab.research.google.com/drive/1l_h7cw82SQpW9PvYzSQeYS4TXzS1QJ8o" target="_blank" rel="noopener noreferrer">Google Colab</a>
""")

    # ---- Inputs ----
    with gr.Row():
        with gr.Column(scale=2):
            sequence_input = gr.Textbox(
                label="Protein Sequence",
                placeholder="Enter amino acid sequence (e.g., AFRQALQLAASGLAGGSAAVLFSAVAVGKPRAGGD...)",
                lines=4,
                info="Standard 20 amino acids only. Max 400 residues for this demo."
            )
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                gr.Markdown("Fine-tune ProteinTTT training parameters")
                with gr.Row():
                    ttt_steps = gr.Slider(
                        minimum=5,
                        maximum=30,
                        value=10,
                        step=1,
                        label="ProteinTTT Steps",
                        info="More steps = better results but slower"
                    )
                    learning_rate = gr.Number(
                        label="Learning Rate",
                        value=4e-4,
                        minimum=1e-6,
                        maximum=1e-1,
                        info="Step size for parameter updates"
                    )
                with gr.Row():
                    batch_size = gr.Slider(
                        label="Batch Size",
                        minimum=1,
                        maximum=8,
                        value=4,
                        step=1,
                        info="Number of samples per training batch"
                    )
                    lora_rank = gr.Slider(
                        label="LoRA Rank",
                        minimum=1,
                        maximum=32,
                        value=8,
                        step=1,
                        info="Rank of LoRA adapter matrices"
                    )
                with gr.Row():
                    lora_alpha = gr.Slider(
                        label="LoRA Alpha",
                        minimum=1,
                        maximum=64,
                        value=32,
                        step=1,
                        info="Scaling factor for LoRA"
                    )
                    seed = gr.Number(
                        label="Seed",
                        value=0,
                        minimum=0,
                        maximum=2147483647,
                        precision=0,
                        info="Random seed for reproducibility"
                    )
            predict_btn = gr.Button("Predict Structure", variant="primary", size="lg")
            # Example values are given in the same order and types as the inputs below
            # (lora_alpha as an int, matching its integer-step slider).
            gr.Examples(
                examples=[
                    [DEMO_SEQUENCE, 10, 4e-4, 4, 8, 32, 0],
                ],
                inputs=[sequence_input, ttt_steps, learning_rate, batch_size, lora_rank, lora_alpha, seed],
                label="Example Sequences (click to use)"
            )

    # ---- Outputs ----
    gr.Markdown("## Results")
    progress_bar = gr.HTML(elem_id="progress-bar")
    with gr.Row():
        baseline_plddt = gr.Textbox(label="ESMFold pLDDT", interactive=False)
        ttt_plddt = gr.Textbox(label="ProteinTTT pLDDT", interactive=False)
        improvement = gr.Textbox(label="Improvement", interactive=False)
    with gr.Row():
        with gr.Column():
            gr.Markdown("#### ESMFold")
            baseline_viz = gr.HTML(label="ESMFold Structure")
        with gr.Column():
            gr.Markdown("#### ESMFold + ProteinTTT")
            ttt_viz = gr.HTML(label="ESMFold + ProteinTTT Structure")
    with gr.Row():
        baseline_file = gr.File(label="Download ESMFold PDB")
        ttt_file = gr.File(label="Download ESMFold + ProteinTTT PDB")
    with gr.Accordion("Log", open=False):
        log_table = gr.Dataframe(
            headers=["Time", "Stage", "Message"],
            datatype=["str", "str", "str"],
            row_count=(0, "dynamic"),
            col_count=(3, "fixed"),
            interactive=False,
            wrap=True,
            label="Run Log"
        )

    # The outputs order must match the tuple yielded by predict_structure.
    predict_btn.click(
        fn=predict_structure,
        inputs=[sequence_input, ttt_steps, learning_rate, batch_size, lora_rank, lora_alpha, seed],
        outputs=[
            baseline_plddt,
            ttt_plddt,
            improvement,
            baseline_viz,
            ttt_viz,
            baseline_file,
            ttt_file,
            progress_bar,
            log_table,
        ]
    )

if __name__ == "__main__":
    demo.launch()