""" ProteinTTT Demo Space Customize ESMFold with test-time training for improved protein structure prediction. """ import importlib import importlib.util import logging import os import subprocess import sys import tempfile import time import traceback def _pip_install(*args): """Run pip install as a subprocess and stream output for debugging.""" cmd = [sys.executable, "-m", "pip", "install", *args] print(f"[INSTALL] {' '.join(cmd)}") subprocess.check_call(cmd) importlib.invalidate_caches() # --- openfold --- try: import openfold # noqa: F401 print("[INFO] openfold already installed.") except ImportError: print("[INFO] Installing openfold (one-time, may take a few minutes)...") import re import shutil # 1) dllogger (pure-Python, no patching needed) _pip_install("--no-build-isolation", "dllogger @ git+https://github.com/NVIDIA/dllogger.git") # 2) Clone openfold so we can patch its setup.py before building. _of_dir = os.path.join(tempfile.gettempdir(), "_openfold_build") if os.path.exists(_of_dir): shutil.rmtree(_of_dir) subprocess.check_call([ "git", "clone", "--filter=blob:none", "--quiet", "https://github.com/aqlaboratory/openfold.git", _of_dir, ]) subprocess.check_call( ["git", "checkout", "-q", "4b41059694619831a7db195b7e0988fc4ff3a307"], cwd=_of_dir, ) # 3) Patch setup.py _setup_py = os.path.join(_of_dir, "setup.py") with open(_setup_py) as _f: _src = _f.read() _src = _src.replace( "from scripts.utils import get_nvidia_cc", "# from scripts.utils import get_nvidia_cc # patched out", ) _src = re.sub( r"compute_capabilities\s*=\s*set\(\[.*?set\(\[compute_capability\]\)", "compute_capabilities = set([(7, 0), (8, 0), (8, 6), (9, 0)])", _src, flags=re.DOTALL, ) _src = _src.replace("-std=c++14", "-std=c++17") with open(_setup_py, "w") as _f: _f.write(_src) # 4) CUDA architectures os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0;8.0;8.6;9.0" # 5) Install the patched openfold _pip_install("--no-build-isolation", _of_dir) shutil.rmtree(_of_dir, ignore_errors=True) # Verify import openfold # noqa: F401 print("[INFO] openfold installed and verified.") # --- lora-diffusion --- if importlib.util.find_spec("lora_diffusion") is None: print("[INFO] Installing lora-diffusion (--no-deps --no-build-isolation)...") _pip_install("--no-deps", "--no-build-isolation", "lora-diffusion @ git+https://github.com/cloneofsimo/lora.git") _spec = importlib.util.find_spec("lora_diffusion") _init_path = os.path.join(os.path.dirname(_spec.origin), "__init__.py") with open(_init_path, "w") as _f: _f.write("from .lora import *\n") from lora_diffusion.lora import inject_trainable_lora as _check # noqa: F401 del _check print("[INFO] lora-diffusion installed, patched, and verified.") else: print("[INFO] lora-diffusion already installed.") import spaces import gradio as gr import torch print("[INFO] Loading ESMFold model on CPU (this may take a minute)...") import esm as _esm_module # noqa: E402 _BASE_MODEL = _esm_module.pretrained.esmfold_v1() _BASE_MODEL = _BASE_MODEL.eval() _BASE_MODEL.set_chunk_size(128) print("[INFO] ESMFold model loaded on CPU.") def get_plddt_from_output(output) -> float: """Extract mean pLDDT from ESMFold output.""" return output["mean_plddt"].item() def create_3dmol_html(pdb_content: str, title: str = "") -> str: """Create 3Dmol.js visualization HTML with pLDDT coloring.""" import base64 import html as html_lib # Properly escape PDB content for JavaScript template literal pdb_escaped = pdb_content.replace("\\", "\\\\").replace("`", "\\`").replace("$", "\\$") # Escape title for HTML title_escaped = html_lib.escape(title) html_content = f"""
{title_escaped}
pLDDT
≥ 90 (very high)
70 - 90 (confident)
50 - 70 (low)
< 50 (very low)
""" # Use base64 encoding for data URI (more reliable in sandboxed environments) html_bytes = html_content.encode('utf-8') html_base64 = base64.b64encode(html_bytes).decode('ascii') return f'' # Main Prediction Logic def _progress_html(fraction: float, desc: str) -> str: """Return a small HTML progress bar.""" pct = int(fraction * 100) return f"""
{desc} — {pct}%
""" @spaces.GPU(duration=120) def predict_structure( sequence: str, ttt_steps: int = 10, learning_rate: float = 1e-3, batch_size: int = 1, lora_rank: int = 4, lora_alpha: int = 8, seed: int = 0, ): """ Predict protein structure with and without ProteinTTT. Returns baseline pLDDT, improved pLDDT, improvement ratio, and 3D visualizations for both. """ baseline_plddt_text = "" ttt_plddt_text = "" improvement_text = "" baseline_html = "" ttt_html = "" baseline_pdb_file = None ttt_pdb_file = None progress_status = "" logs = [] ttt_model = None def current_result(): return ( baseline_plddt_text, ttt_plddt_text, improvement_text, baseline_html, ttt_html, baseline_pdb_file, ttt_pdb_file, progress_status, logs, ) def add_log(stage: str, message: str): timestamp = time.strftime("%H:%M:%S") logs.append([timestamp, stage, message]) print(f"[{timestamp}] [{stage}] {message}") try: if not sequence or len(sequence.strip()) < 10: raise gr.Error("Please enter a valid protein sequence (at least 10 amino acids).") # Clean sequence sequence = sequence.strip().upper() valid_aa = set("ACDEFGHIKLMNPQRSTVWY") if not all(aa in valid_aa for aa in sequence): raise gr.Error("Sequence contains invalid amino acids. Use only standard 20 amino acids.") if len(sequence) > 400: raise gr.Error("Sequence too long. Maximum 400 amino acids supported for this demo.") device = "cuda" if torch.cuda.is_available() else "cpu" add_log("Init", f"Using device: {device}") progress_status = _progress_html(0.1, "Moving model to GPU…") base_model = _BASE_MODEL # Move model to GPU for inference base_model.to(device) add_log("Init", f"Model moved to {device}.") yield current_result() # ---- ESMFold Prediction ---- progress_status = _progress_html(0.2, "Running ESMFold prediction…") add_log("Baseline", "Running ESMFold prediction...") yield current_result() with torch.no_grad(): baseline_output = base_model.infer(sequence) baseline_pdb = base_model.output_to_pdb(baseline_output)[0] baseline_plddt = get_plddt_from_output(baseline_output) baseline_plddt_text = f"{baseline_plddt:.2f}" add_log("Baseline", f"Baseline pLDDT: {baseline_plddt:.2f}") # ---- ESMFold + ProteinTTT ---- progress_status = _progress_html(0.4, f"Loading ESMFold model…") add_log("TTT", f"Loading ESMFold model...") yield current_result() from proteinttt.models.esmfold import ESMFoldTTT, DEFAULT_ESMFOLD_TTT_CFG # Configure TTT ttt_cfg = DEFAULT_ESMFOLD_TTT_CFG ttt_cfg.steps = ttt_steps ttt_cfg.lr = learning_rate ttt_cfg.batch_size = batch_size ttt_cfg.lora_rank = lora_rank ttt_cfg.lora_alpha = lora_alpha ttt_cfg.seed = 0 # ProteinTTT upstream bug: eval_each_step=False can trigger # UnboundLocalError("eval_step_preds referenced before assignment"). ttt_cfg.eval_each_step = True add_log("TTT", f"Config: LR={learning_rate}, Batch={batch_size}, LoRA rank={lora_rank}, alpha={lora_alpha}") # Apply TTT ttt_model = ESMFoldTTT.ttt_from_pretrained( base_model, ttt_cfg=ttt_cfg, esmfold_config=base_model.cfg ) progress_status = _progress_html(0.5, "Running test-time training…") add_log("ProteinTTT", "Running test-time training...") yield current_result() # Attach a handler to capture step-by-step log messages from ttt() captured_logs = [] class _ListHandler(logging.Handler): """Handler that collects formatted log records into a list.""" def emit(self, record): try: captured_logs.append(self.format(record)) except Exception: pass list_handler = _ListHandler() list_handler.setLevel(logging.INFO) list_handler.setFormatter(logging.Formatter("%(message)s")) ttt_model.ttt_logger.addHandler(list_handler) # Run TTT synchronously (must stay on the main thread for ZeroGPU) ttt_result = ttt_model.ttt(sequence) ttt_model.ttt_logger.removeHandler(list_handler) # Now yield all captured step logs so the UI updates with each step for i, msg in enumerate(captured_logs, 1): frac = 0.5 + 0.3 * min(i / max(ttt_steps, 1), 1.0) progress_status = _progress_html(frac, f"ProteinTTT step {i}/{ttt_steps}") add_log("ProteinTTT", msg) yield current_result() add_log("ProteinTTT", "Test-time training finished.") yield current_result() progress_status = _progress_html(0.8, "Running ESMFold+ProteinTTT prediction…") add_log("ProteinTTT", "Running ESMFold+ProteinTTT prediction...") yield current_result() with torch.no_grad(): ttt_output = ttt_model.infer(sequence) ttt_pdb = ttt_model.output_to_pdb(ttt_output)[0] ttt_plddt = get_plddt_from_output(ttt_output) ttt_plddt_text = f"{ttt_plddt:.2f}" add_log("ProteinTTT", f"ProteinTTT pLDDT: {ttt_plddt:.2f}") # Prepare Results progress_status = _progress_html(0.9, "Preparing visualization…") add_log("Result", "Preparing visualizations and artifacts...") improvement_ratio = ttt_plddt / baseline_plddt if baseline_plddt > 0 else 0 improvement_text = f"{improvement_ratio:.2f}x" # Create visualization try: baseline_html = create_3dmol_html(baseline_pdb, f"ESMFold (pLDDT: {baseline_plddt:.1f})") print(f"[DEBUG] Created baseline HTML, length: {len(baseline_html)}") print(f"[DEBUG] Baseline HTML starts with: {baseline_html[:100]}") except Exception as e: print(f"[ERROR] Failed to create baseline HTML: {e}") import traceback traceback.print_exc() baseline_html = f"
Error creating visualization: {str(e)}
" # Save PDB files for download baseline_pdb_path = tempfile.NamedTemporaryFile(mode='w', suffix='_esmfold.pdb', delete=False) baseline_pdb_path.write(baseline_pdb) baseline_pdb_path.close() baseline_pdb_file = baseline_pdb_path.name add_log("Baseline", "ESMFold PDB saved.") yield current_result() # Create 3D visualizations try: ttt_html = create_3dmol_html(ttt_pdb, f"ESMFold + ProteinTTT (pLDDT: {ttt_plddt:.1f})") print(f"[DEBUG] Created ProteinTTT HTML, length: {len(ttt_html)}") except Exception as e: print(f"[ERROR] Failed to create ProteinTTT HTML: {e}") ttt_html = f"
Error creating visualization: {str(e)}
" ttt_pdb_path = tempfile.NamedTemporaryFile(mode='w', suffix='_esmfold_proteinttt.pdb', delete=False) ttt_pdb_path.write(ttt_pdb) ttt_pdb_path.close() ttt_pdb_file = ttt_pdb_path.name progress_status = "" # Clear progress bar on completion add_log("Done", "Prediction completed successfully.") # Debug output print(f"[DEBUG] baseline_plddt_text: {baseline_plddt_text}") print(f"[DEBUG] ttt_plddt_text: {ttt_plddt_text}") print(f"[DEBUG] improvement_text: {improvement_text}") print(f"[DEBUG] baseline_html length: {len(baseline_html) if baseline_html else 0}") print(f"[DEBUG] ttt_html length: {len(ttt_html) if ttt_html else 0}") print(f"[DEBUG] baseline_pdb_file: {baseline_pdb_file}") print(f"[DEBUG] ttt_pdb_file: {ttt_pdb_file}") final_result = current_result() yield final_result return final_result except gr.Error: raise # Re-raise Gradio errors as-is except Exception as e: error_msg = f"{type(e).__name__}: {str(e)}" error_tb = traceback.format_exc() print(f"[ERROR] {error_msg}\n{error_tb}") add_log("Error", error_msg) yield current_result() raise gr.Error(f"Prediction failed: {error_msg}") finally: # Reset TTT state and release memory. if ttt_model is not None: try: ttt_model.ttt_reset() except Exception: pass del ttt_model # Move model back to CPU so it survives GPU release. try: _BASE_MODEL.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() except Exception: pass # Gradio UI DEMO_SEQUENCE = "GIHLGELGLLPSTVLAIGYFENLVNIICESLNMLPKLEVSGKEYKKFKFTIVIPKDLDANIKKRAKIYFKQKSLIEIEIPTSSRNYPIHIQFDENSTDDILHLYDMPTTIGGIDKAIEMFMRKGHIGKTDQQKLLEERELRNFKTTLENLIATDAFAKEMVEVIIEE" with gr.Blocks( theme=gr.themes.Soft( primary_hue="green", font=["Source Sans Pro", "ui-sans-serif", "system-ui", "-apple-system", "BlinkMacSystemFont", "Segoe UI", "Roboto", "Helvetica Neue", "Arial", "sans-serif"], font_mono=["IBM Plex Mono", "ui-monospace", "Consolas", "monospace"], ) ) as demo: gr.Markdown(""" # ProteinTTT ### Customize ESMFold with Test-Time Training for Enhanced Structure Prediction ProteinTTT enables customizing protein language models to one protein at a time for enhanced performance on challenging targets. 📄 [Paper](https://arxiv.org/abs/2411.02109) | 💻 [GitHub](https://github.com/anton-bushuiev/ProteinTTT) | Google Colab """) with gr.Row(): with gr.Column(scale=2): sequence_input = gr.Textbox( label="Protein Sequence", placeholder="Enter amino acid sequence (e.g., AFRQALQLAASGLAGGSAAVLFSAVAVGKPRAGGD...)", lines=4, info="Standard 20 amino acids only. Max 400 residues for this demo." ) with gr.Accordion("⚙️ Advanced Settings", open=False): gr.Markdown("Fine-tune ProteinTTT training parameters") with gr.Row(): ttt_steps = gr.Slider( minimum=5, maximum=30, value=10, step=1, label="ProteinTTT Steps", info="More steps = better results but slower" ) learning_rate = gr.Number( label="Learning Rate", value=4e-4, minimum=1e-6, maximum=1e-1, info="Step size for parameter updates" ) with gr.Row(): batch_size = gr.Slider( label="Batch Size", minimum=1, maximum=8, value=4, step=1, info="Number of samples per training batch" ) lora_rank = gr.Slider( label="LoRA Rank", minimum=1, maximum=32, value=8, step=1, info="Rank of LoRA adapter matrices" ) with gr.Row(): lora_alpha = gr.Slider( label="LoRA Alpha", minimum=1, maximum=64, value=32, step=1, info="Scaling factor for LoRA" ) seed = gr.Number( label="Seed", value=0, minimum=0, maximum=2147483647, precision=0, info="Random seed for reproducibility" ) predict_btn = gr.Button("Predict Structure", variant="primary", size="lg") gr.Examples( examples=[ [DEMO_SEQUENCE, 10, 4e-4, 4, 8, 32.0, 0], ], inputs=[sequence_input, ttt_steps, learning_rate, batch_size, lora_rank, lora_alpha, seed], label="Example Sequences (click to use)" ) gr.Markdown("## Results") progress_bar = gr.HTML(elem_id="progress-bar") with gr.Row(): baseline_plddt = gr.Textbox(label="ESMFold pLDDT", interactive=False) ttt_plddt = gr.Textbox(label="ProteinTTT pLDDT", interactive=False) improvement = gr.Textbox(label="Improvement", interactive=False) with gr.Row(): with gr.Column(): gr.Markdown("#### ESMFold") baseline_viz = gr.HTML(label="ESMFold Structure") with gr.Column(): gr.Markdown("#### ESMFold + ProteinTTT") ttt_viz = gr.HTML(label="ESMFold + ProteinTTT Structure") with gr.Row(): baseline_file = gr.File(label="Download ESMFold PDB") ttt_file = gr.File(label="Download ESMFold + ProteinTTT PDB") with gr.Accordion("Log", open=False): log_table = gr.Dataframe( headers=["Time", "Stage", "Message"], datatype=["str", "str", "str"], row_count=(0, "dynamic"), col_count=(3, "fixed"), interactive=False, wrap=True, label="Run Log" ) predict_btn.click( fn=predict_structure, inputs=[sequence_input, ttt_steps, learning_rate, batch_size, lora_rank, lora_alpha, seed], outputs=[ baseline_plddt, ttt_plddt, improvement, baseline_viz, ttt_viz, baseline_file, ttt_file, progress_bar, log_table, ] ) if __name__ == "__main__": demo.launch()