Spaces:
Running on Zero
Running on Zero
pimenol
Update Google Colab link in app.py to use HTML anchor tag for better accessibility and user experience.
3f6c226 | """ | |
| ProteinTTT Demo Space | |
| Customize ESMFold with test-time training for improved protein structure prediction. | |
| """ | |
| import importlib | |
| import importlib.util | |
| import logging | |
| import os | |
| import subprocess | |
| import sys | |
| import tempfile | |
| import time | |
| import traceback | |
def _pip_install(*args):
    """Install packages with ``pip`` in a child process of this interpreter.

    pip's output is not captured, so it appears in the Space logs. A non-zero
    exit status raises ``subprocess.CalledProcessError``. Import caches are
    invalidated afterwards so newly installed modules can be imported.
    """
    pip_command = [sys.executable, "-m", "pip", "install"]
    pip_command.extend(args)
    print("[INSTALL] " + " ".join(pip_command))
    subprocess.check_call(pip_command)
    importlib.invalidate_caches()
# --- openfold ---
# Install openfold on first start from a pinned git commit, patching its
# setup.py before building.
try:
    import openfold  # noqa: F401
    print("[INFO] openfold already installed.")
except ImportError:
    print("[INFO] Installing openfold (one-time, may take a few minutes)...")
    import re
    import shutil

    # 1) dllogger (pure-Python, no patching needed)
    _pip_install("--no-build-isolation",
                 "dllogger @ git+https://github.com/NVIDIA/dllogger.git")

    # 2) Clone openfold at a pinned commit so setup.py can be patched before building.
    _of_dir = os.path.join(tempfile.gettempdir(), "_openfold_build")
    if os.path.exists(_of_dir):
        shutil.rmtree(_of_dir)
    try:
        subprocess.check_call([
            "git", "clone", "--filter=blob:none", "--quiet",
            "https://github.com/aqlaboratory/openfold.git", _of_dir,
        ])
        subprocess.check_call(
            ["git", "checkout", "-q", "4b41059694619831a7db195b7e0988fc4ff3a307"],
            cwd=_of_dir,
        )

        # 3) Patch setup.py:
        #    - comment out the get_nvidia_cc import and replace the computed
        #      compute_capabilities set with a fixed list, so the build no longer
        #      depends on that helper (presumably it probes the local GPU, which
        #      may be absent at install time -- TODO confirm);
        #    - compile with -std=c++17 instead of -std=c++14.
        #    Each substitution is checked, so a patch that no longer matches shows
        #    up as a warning in the logs instead of as an obscure build failure.
        _setup_py = os.path.join(_of_dir, "setup.py")
        with open(_setup_py, encoding="utf-8") as _f:
            _src = _f.read()

        _cc_import = "from scripts.utils import get_nvidia_cc"
        if _cc_import not in _src:
            print("[WARN] openfold setup.py: get_nvidia_cc import not found; not patched.")
        _src = _src.replace(
            _cc_import,
            "# from scripts.utils import get_nvidia_cc # patched out",
        )

        _src, _n_cc = re.subn(
            r"compute_capabilities\s*=\s*set\(\[.*?set\(\[compute_capability\]\)",
            "compute_capabilities = set([(7, 0), (8, 0), (8, 6), (9, 0)])",
            _src,
            flags=re.DOTALL,
        )
        if _n_cc == 0:
            print("[WARN] openfold setup.py: compute_capabilities block not found; not patched.")

        _src = _src.replace("-std=c++14", "-std=c++17")
        with open(_setup_py, "w", encoding="utf-8") as _f:
            _f.write(_src)

        # 4) CUDA architectures to build for; same list as the patched
        #    compute_capabilities above. Set process-wide so the pip build sees it.
        os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0;8.0;8.6;9.0"

        # 5) Build and install the patched checkout.
        _pip_install("--no-build-isolation", _of_dir)
    finally:
        # Remove the build tree even when cloning, patching or building failed.
        shutil.rmtree(_of_dir, ignore_errors=True)

    # Verify: raises ImportError if the installation did not succeed.
    import openfold  # noqa: F401
    print("[INFO] openfold installed and verified.")
# --- lora-diffusion ---
# Install lora-diffusion from git on first start (without its dependencies),
# then replace the package __init__ so that only the .lora module is imported.
if importlib.util.find_spec("lora_diffusion") is not None:
    print("[INFO] lora-diffusion already installed.")
else:
    print("[INFO] Installing lora-diffusion (--no-deps --no-build-isolation)...")
    _pip_install(
        "--no-deps",
        "--no-build-isolation",
        "lora-diffusion @ git+https://github.com/cloneofsimo/lora.git",
    )
    # Overwrite the installed package's __init__.py with a bare re-export of
    # .lora. Presumably the upstream __init__ imports modules whose dependencies
    # are skipped by --no-deps -- TODO confirm.
    _lora_pkg_dir = os.path.dirname(importlib.util.find_spec("lora_diffusion").origin)
    with open(os.path.join(_lora_pkg_dir, "__init__.py"), "w") as _init_file:
        _init_file.write("from .lora import *\n")
    # Smoke test: fails at startup if the patched package cannot be imported.
    from lora_diffusion.lora import inject_trainable_lora as _check  # noqa: F401
    del _check
    print("[INFO] lora-diffusion installed, patched, and verified.")
| import spaces | |
| import gradio as gr | |
| import torch | |
print("[INFO] Loading ESMFold model on CPU (this may take a minute)...")
import esm as _esm_module  # noqa: E402

# Load ESMFold once at import time, in eval mode, and keep it on the CPU.
# predict_structure moves it to the selected device per request and back to
# the CPU afterwards.
_BASE_MODEL = _esm_module.pretrained.esmfold_v1().eval()
# Chunk size 128 for ESMFold's chunked computation -- presumably trades some
# speed for lower peak memory; see the ESMFold documentation.
_BASE_MODEL.set_chunk_size(128)
print("[INFO] ESMFold model loaded on CPU.")
def get_plddt_from_output(output) -> float:
    """Return the mean pLDDT of an ESMFold output dict as a Python number.

    ``output["mean_plddt"]`` must be a single-element tensor; ``.item()``
    raises for tensors holding more than one value.
    """
    mean_plddt_tensor = output["mean_plddt"]
    return mean_plddt_tensor.item()
def create_3dmol_html(pdb_content: str, title: str = "") -> str:
    """Build an embeddable 3Dmol.js cartoon viewer for a PDB structure.

    Residues are coloured from each atom's ``b`` field, the PDB B-factor column.
    That column is assumed to hold pLDDT on a 0-100 scale (TODO confirm). The
    bands are <50, 50-70, 70-90 and >=90, and a matching legend is drawn over
    the viewer.

    Args:
        pdb_content: PDB-format text to display.
        title: Caption shown above the viewer. It is HTML-escaped before insertion.

    Returns:
        An ``<iframe>`` tag whose ``src`` is a base64 ``data:text/html`` URI.
        The URI holds a complete page that loads 3Dmol.js from
        3Dmol.csb.pitt.edu.
    """
    import base64
    import html as html_lib
    import json

    # json.dumps produces a valid JavaScript string literal: quotes,
    # backslashes and newlines are escaped, and with the default
    # ensure_ascii=True so are U+2028/U+2029. "</" is additionally written as
    # "<\/" so a "</script>" sequence in the data cannot end the inline script
    # early ("\/" decodes to "/").
    pdb_js = json.dumps(pdb_content).replace("</", "<\\/")
    title_escaped = html_lib.escape(title)

    html_content = f"""<!DOCTYPE html>
<html>
<head>
  <meta charset="UTF-8">
  <style>
    body {{
      margin: 0;
      padding: 0;
      font-family: "IBM Plex Sans", "Inter", "Helvetica Neue", Arial, sans-serif;
    }}
    .mol-container {{ width: 100%; height: 450px; position: relative; }}
    .title {{ text-align: center; font-size: 14px; font-weight: bold; padding: 5px; background: #f0f0f0; }}
    .legend {{
      position: absolute;
      right: 12px;
      bottom: 12px;
      z-index: 10;
      background: rgba(255, 255, 255, 0.92);
      border: 1px solid #d9d9d9;
      border-radius: 8px;
      padding: 8px 10px;
      font-size: 12px;
      color: #333;
      line-height: 1.25;
      box-shadow: 0 1px 4px rgba(0, 0, 0, 0.15);
    }}
    .legend-title {{
      font-weight: 700;
      margin-bottom: 6px;
    }}
    .legend-row {{
      display: flex;
      align-items: center;
      gap: 6px;
      margin: 2px 0;
      white-space: nowrap;
    }}
    .legend-color {{
      width: 12px;
      height: 12px;
      border-radius: 2px;
      border: 1px solid rgba(0, 0, 0, 0.12);
      flex: none;
    }}
  </style>
  <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
  <div class="title">{title_escaped}</div>
  <div id="container" class="mol-container">
    <div class="legend">
      <div class="legend-title">pLDDT</div>
      <div class="legend-row"><span class="legend-color" style="background:#0d57d3;"></span><span>≥ 90 (very high)</span></div>
      <div class="legend-row"><span class="legend-color" style="background:#6acbf1;"></span><span>70 - 90 (confident)</span></div>
      <div class="legend-row"><span class="legend-color" style="background:#fed936;"></span><span>50 - 70 (low)</span></div>
      <div class="legend-row"><span class="legend-color" style="background:#fd7d4d;"></span><span>&lt; 50 (very low)</span></div>
    </div>
  </div>
  <script>
    let pdb = {pdb_js};
    let viewer = $3Dmol.createViewer("container", {{backgroundColor: "white"}});
    viewer.addModel(pdb, "pdb");
    viewer.setStyle({{}}, {{
      cartoon: {{
        colorfunc: function(atom) {{
          if (atom.b < 50) return "#fd7d4d";  // very low
          if (atom.b < 70) return "#fed936";  // low
          if (atom.b < 90) return "#6acbf1";  // confident
          return "#0d57d3";                   // very high
        }}
      }}
    }});
    viewer.zoomTo();
    viewer.render();
  </script>
</body>
</html>"""

    # A base64 data URI is more reliable than inline srcdoc in sandboxed iframes.
    html_base64 = base64.b64encode(html_content.encode("utf-8")).decode("ascii")
    return f'<iframe style="width: 100%; height: 500px; border: 1px solid #ddd; border-radius: 8px;" src="data:text/html;base64,{html_base64}"></iframe>'
| # Main Prediction Logic | |
| def _progress_html(fraction: float, desc: str) -> str: | |
| """Return a small HTML progress bar.""" | |
| pct = int(fraction * 100) | |
| return f""" | |
| <div style="padding: 8px 0;"> | |
| <div style="font-size: 14px; margin-bottom: 4px; color: #ccc;">{desc} — {pct}%</div> | |
| <div style="background: #333; border-radius: 6px; height: 18px; width: 100%; overflow: hidden;"> | |
| <div style="background: linear-gradient(90deg, #22c55e, #16a34a); height: 100%; width: {pct}%; | |
| border-radius: 6px; transition: width 0.3s ease;"></div> | |
| </div> | |
| </div>""" | |
def predict_structure(
    sequence: str,
    ttt_steps: int = 10,
    learning_rate: float = 1e-3,
    batch_size: int = 1,
    lora_rank: int = 4,
    lora_alpha: int = 8,
    seed: int = 0,
):
    """Fold *sequence* with plain ESMFold and with ESMFold + ProteinTTT.

    This is a generator. After each stage it yields a 9-tuple so the Gradio UI
    can update incrementally. The tuple is in the order of the predict button's
    ``outputs`` list:

    1. baseline pLDDT text
    2. ProteinTTT pLDDT text
    3. improvement-ratio text
    4. baseline viewer HTML
    5. ProteinTTT viewer HTML
    6. baseline PDB path
    7. ProteinTTT PDB path
    8. progress-bar HTML
    9. log rows ``[time, stage, message]``

    Args:
        sequence: Amino-acid sequence. Whitespace is removed and case is ignored.
        ttt_steps: Number of test-time-training steps (``ttt_cfg.steps``).
        learning_rate: TTT learning rate (``ttt_cfg.lr``).
        batch_size: TTT batch size (``ttt_cfg.batch_size``).
        lora_rank: LoRA adapter rank (``ttt_cfg.lora_rank``).
        lora_alpha: LoRA scaling factor (``ttt_cfg.lora_alpha``).
        seed: Random seed forwarded to ``ttt_cfg.seed``.

    Raises:
        gr.Error: if the input is invalid, or if prediction fails (the original
            exception is chained as the cause).
    """
    # NOTE(review): `spaces` is imported at module level, but no @spaces.GPU
    # decorator is visible on this function. On ZeroGPU hardware a GPU is
    # normally only attached inside decorated functions -- confirm this is
    # intentional.
    import copy

    baseline_plddt_text = ""
    ttt_plddt_text = ""
    improvement_text = ""
    baseline_html = ""
    ttt_html = ""
    baseline_pdb_file = None
    ttt_pdb_file = None
    progress_status = ""
    logs = []
    ttt_model = None

    def current_result():
        # Snapshot of every UI output, in the order of predict_btn.click(outputs=...).
        return (
            baseline_plddt_text,
            ttt_plddt_text,
            improvement_text,
            baseline_html,
            ttt_html,
            baseline_pdb_file,
            ttt_pdb_file,
            progress_status,
            logs,
        )

    def add_log(stage: str, message: str):
        # Append a row to the UI log table and mirror it to stdout.
        timestamp = time.strftime("%H:%M:%S")
        logs.append([timestamp, stage, message])
        print(f"[{timestamp}] [{stage}] {message}")

    try:
        sequence = _clean_sequence(sequence)
        # Convert UI values to plain Python numbers. Gradio components can
        # deliver floats (the example row passes 32.0 for LoRA alpha), and the
        # step, batch and rank settings are integer quantities.
        ttt_steps = int(ttt_steps)
        batch_size = int(batch_size)
        lora_rank = int(lora_rank)
        lora_alpha = int(lora_alpha)
        seed = int(seed)
        learning_rate = float(learning_rate)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        device_label = "GPU" if device == "cuda" else "CPU"
        add_log("Init", f"Using device: {device}")
        progress_status = _progress_html(0.1, f"Moving model to {device_label}…")
        base_model = _BASE_MODEL
        base_model.to(device)
        add_log("Init", f"Model moved to {device}.")
        yield current_result()

        # ---- ESMFold baseline ----
        progress_status = _progress_html(0.2, "Running ESMFold prediction…")
        add_log("Baseline", "Running ESMFold prediction...")
        yield current_result()
        with torch.no_grad():
            baseline_output = base_model.infer(sequence)
            baseline_pdb = base_model.output_to_pdb(baseline_output)[0]
        baseline_plddt = get_plddt_from_output(baseline_output)
        baseline_plddt_text = f"{baseline_plddt:.2f}"
        add_log("Baseline", f"Baseline pLDDT: {baseline_plddt:.2f}")

        # ---- ESMFold + ProteinTTT ----
        progress_status = _progress_html(0.4, "Loading ESMFold model…")
        add_log("TTT", "Loading ESMFold model...")
        yield current_result()

        from proteinttt.models.esmfold import ESMFoldTTT, DEFAULT_ESMFOLD_TTT_CFG

        # Configure a private copy. Assigning fields on DEFAULT_ESMFOLD_TTT_CFG
        # itself would mutate module-level state shared by every request.
        ttt_cfg = copy.deepcopy(DEFAULT_ESMFOLD_TTT_CFG)
        ttt_cfg.steps = ttt_steps
        ttt_cfg.lr = learning_rate
        ttt_cfg.batch_size = batch_size
        ttt_cfg.lora_rank = lora_rank
        ttt_cfg.lora_alpha = lora_alpha
        ttt_cfg.seed = seed  # honour the "Seed" field exposed in the UI
        # ProteinTTT upstream bug: eval_each_step=False can trigger
        # UnboundLocalError("eval_step_preds referenced before assignment").
        ttt_cfg.eval_each_step = True
        add_log(
            "TTT",
            f"Config: LR={learning_rate}, Batch={batch_size}, LoRA rank={lora_rank}, "
            f"alpha={lora_alpha}, seed={seed}",
        )

        ttt_model = ESMFoldTTT.ttt_from_pretrained(
            base_model,
            ttt_cfg=ttt_cfg,
            esmfold_config=base_model.cfg,
        )

        progress_status = _progress_html(0.5, "Running test-time training…")
        add_log("ProteinTTT", "Running test-time training...")
        yield current_result()

        # Capture the messages ttt() logs so they can be replayed into the UI.
        captured_logs = []
        list_handler = _ListHandler(captured_logs)
        ttt_model.ttt_logger.addHandler(list_handler)
        try:
            # Run TTT synchronously (must stay on the main thread for ZeroGPU).
            ttt_model.ttt(sequence)
        finally:
            # Always detach, so a failed run does not leave the handler attached.
            ttt_model.ttt_logger.removeHandler(list_handler)

        # Replay the captured messages one by one so the UI shows each step.
        for step, msg in enumerate(captured_logs, 1):
            frac = 0.5 + 0.3 * min(step / max(ttt_steps, 1), 1.0)
            progress_status = _progress_html(frac, f"ProteinTTT step {step}/{ttt_steps}")
            add_log("ProteinTTT", msg)
            yield current_result()
        add_log("ProteinTTT", "Test-time training finished.")
        yield current_result()

        progress_status = _progress_html(0.8, "Running ESMFold+ProteinTTT prediction…")
        add_log("ProteinTTT", "Running ESMFold+ProteinTTT prediction...")
        yield current_result()
        with torch.no_grad():
            ttt_output = ttt_model.infer(sequence)
            ttt_pdb = ttt_model.output_to_pdb(ttt_output)[0]
        ttt_plddt = get_plddt_from_output(ttt_output)
        ttt_plddt_text = f"{ttt_plddt:.2f}"
        add_log("ProteinTTT", f"ProteinTTT pLDDT: {ttt_plddt:.2f}")

        # ---- Results ----
        progress_status = _progress_html(0.9, "Preparing visualization…")
        add_log("Result", "Preparing visualizations and artifacts...")
        improvement_ratio = ttt_plddt / baseline_plddt if baseline_plddt > 0 else 0
        improvement_text = f"{improvement_ratio:.2f}x"

        baseline_html = _render_structure(baseline_pdb, f"ESMFold (pLDDT: {baseline_plddt:.1f})")
        baseline_pdb_file = _write_temp_pdb(baseline_pdb, "_esmfold.pdb")
        add_log("Baseline", "ESMFold PDB saved.")
        yield current_result()

        ttt_html = _render_structure(ttt_pdb, f"ESMFold + ProteinTTT (pLDDT: {ttt_plddt:.1f})")
        ttt_pdb_file = _write_temp_pdb(ttt_pdb, "_esmfold_proteinttt.pdb")

        progress_status = ""  # clear the progress bar on completion
        add_log("Done", "Prediction completed successfully.")
        yield current_result()
    except gr.Error:
        raise  # validation errors are already user-facing
    except Exception as e:
        error_msg = f"{type(e).__name__}: {e}"
        print(f"[ERROR] {error_msg}\n{traceback.format_exc()}")
        add_log("Error", error_msg)
        yield current_result()  # push the error row to the log table first
        raise gr.Error(f"Prediction failed: {error_msg}") from e
    finally:
        # Reset TTT state and release memory. Cleanup is best-effort, but
        # failures are reported instead of silently ignored.
        if ttt_model is not None:
            try:
                ttt_model.ttt_reset()
            except Exception as reset_err:
                print(f"[WARN] ttt_reset() failed: {reset_err!r}")
            ttt_model = None
        # Move the shared model back to the CPU so it survives GPU release.
        try:
            _BASE_MODEL.to("cpu")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception as cleanup_err:
            print(f"[WARN] Failed to move model back to CPU / free GPU memory: {cleanup_err!r}")


_STANDARD_AMINO_ACIDS = frozenset("ACDEFGHIKLMNPQRSTVWY")


def _clean_sequence(sequence):
    """Normalise and validate a user-supplied protein sequence.

    All whitespace is removed (including line breaks from a pasted multi-line
    sequence) and the result is upper-cased.

    Raises:
        gr.Error: if the cleaned sequence is shorter than 10 residues, contains
            characters outside the 20 standard amino acids, or is longer than
            400 residues.
    """
    cleaned = "".join((sequence or "").split()).upper()
    if len(cleaned) < 10:
        raise gr.Error("Please enter a valid protein sequence (at least 10 amino acids).")
    if not set(cleaned) <= _STANDARD_AMINO_ACIDS:
        raise gr.Error("Sequence contains invalid amino acids. Use only standard 20 amino acids.")
    if len(cleaned) > 400:
        raise gr.Error("Sequence too long. Maximum 400 amino acids supported for this demo.")
    return cleaned


class _ListHandler(logging.Handler):
    """Logging handler that appends each record's formatted message to a list."""

    def __init__(self, sink, level=logging.INFO):
        super().__init__(level)
        self.sink = sink
        self.setFormatter(logging.Formatter("%(message)s"))

    def emit(self, record):
        try:
            self.sink.append(self.format(record))
        except Exception:
            # Standard logging behaviour: report the problem, never raise.
            self.handleError(record)


def _write_temp_pdb(pdb_text, suffix):
    """Write *pdb_text* to a new temporary file and return its path.

    The file is created with ``delete=False`` and is not removed here, because
    the caller hands the path to Gradio for download.
    """
    with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False) as handle:
        handle.write(pdb_text)
    return handle.name


def _render_structure(pdb_text, title):
    """Return viewer HTML for *pdb_text*, or an escaped error ``<div>`` on failure."""
    import html as html_lib

    try:
        return create_3dmol_html(pdb_text, title)
    except Exception as err:
        print(f"[ERROR] Failed to create visualization for {title!r}: {err}")
        traceback.print_exc()
        return f"<div>Error creating visualization: {html_lib.escape(str(err))}</div>"
| # Gradio UI | |
# Example target for the "Example Sequences" row; standard amino acids only.
DEMO_SEQUENCE = (
    "GIHLGELGLLPSTVLAIGYFENLVNIICESLNMLPKLEVSGKEYKKFKFTIVIPKDLDAN"
    "IKKRAKIYFKQKSLIEIEIPTSSRNYPIHIQFDENSTDDILHLYDMPTTIGGIDKAIEMF"
    "MRKGHIGKTDQQKLLEERELRNFKTTLENLIATDAFAKEMVEVIIEE"
)
# Gradio page layout:
#   - header text;
#   - input column: sequence, collapsible advanced ProteinTTT settings, predict
#     button and a one-click example;
#   - results: pLDDT scores, side-by-side 3D viewers, PDB downloads and a run log.
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="green",
        font=["Source Sans Pro", "ui-sans-serif", "system-ui", "-apple-system", "BlinkMacSystemFont", "Segoe UI", "Roboto", "Helvetica Neue", "Arial", "sans-serif"],
        font_mono=["IBM Plex Mono", "ui-monospace", "Consolas", "monospace"],
    )
) as demo:
    # Header: title, one-paragraph description and links (paper, code, Colab).
    gr.Markdown("""
# ProteinTTT
### Customize ESMFold with Test-Time Training for Enhanced Structure Prediction
ProteinTTT enables customizing protein language models to one protein at a time for enhanced
performance on challenging targets.
📄 [Paper](https://arxiv.org/abs/2411.02109) | 💻 [GitHub](https://github.com/anton-bushuiev/ProteinTTT) | <a href="https://colab.research.google.com/drive/1l_h7cw82SQpW9PvYzSQeYS4TXzS1QJ8o" target="_blank"> Google Colab</a>
""")
    with gr.Row():
        with gr.Column(scale=2):
            sequence_input = gr.Textbox(
                label="Protein Sequence",
                placeholder="Enter amino acid sequence (e.g., AFRQALQLAASGLAGGSAAVLFSAVAVGKPRAGGD...)",
                lines=4,
                info="Standard 20 amino acids only. Max 400 residues for this demo."
            )
            # Collapsed by default. Each control maps to one ProteinTTT
            # hyperparameter argument of predict_structure.
            # NOTE(review): these UI defaults (lr 4e-4, batch 4, rank 8,
            # alpha 32) differ from predict_structure's Python defaults
            # (1e-3, 1, 4, 8). The click handler always passes the UI values.
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                gr.Markdown("Fine-tune ProteinTTT training parameters")
                with gr.Row():
                    ttt_steps = gr.Slider(
                        minimum=5,
                        maximum=30,
                        value=10,
                        step=1,
                        label="ProteinTTT Steps",
                        info="More steps = better results but slower"
                    )
                    learning_rate = gr.Number(
                        label="Learning Rate",
                        value=4e-4,
                        minimum=1e-6,
                        maximum=1e-1,
                        info="Step size for parameter updates"
                    )
                with gr.Row():
                    batch_size = gr.Slider(
                        label="Batch Size",
                        minimum=1,
                        maximum=8,
                        value=4,
                        step=1,
                        info="Number of samples per training batch"
                    )
                    lora_rank = gr.Slider(
                        label="LoRA Rank",
                        minimum=1,
                        maximum=32,
                        value=8,
                        step=1,
                        info="Rank of LoRA adapter matrices"
                    )
                with gr.Row():
                    lora_alpha = gr.Slider(
                        label="LoRA Alpha",
                        minimum=1,
                        maximum=64,
                        value=32,
                        step=1,
                        info="Scaling factor for LoRA"
                    )
                    # precision=0 -> integer seed; upper bound is 2**31 - 1.
                    seed = gr.Number(
                        label="Seed",
                        value=0,
                        minimum=0,
                        maximum=2147483647,
                        precision=0,
                        info="Random seed for reproducibility"
                    )
            predict_btn = gr.Button("Predict Structure", variant="primary", size="lg")
            # One example row whose values mirror the control defaults above.
            # Column order must match `inputs`.
            gr.Examples(
                examples=[
                    [DEMO_SEQUENCE, 10, 4e-4, 4, 8, 32.0, 0],
                ],
                inputs=[sequence_input, ttt_steps, learning_rate, batch_size, lora_rank, lora_alpha, seed],
                label="Example Sequences (click to use)"
            )
    gr.Markdown("## Results")
    # predict_structure streams progress-bar HTML here while it runs.
    progress_bar = gr.HTML(elem_id="progress-bar")
    with gr.Row():
        baseline_plddt = gr.Textbox(label="ESMFold pLDDT", interactive=False)
        ttt_plddt = gr.Textbox(label="ProteinTTT pLDDT", interactive=False)
        improvement = gr.Textbox(label="Improvement", interactive=False)
    # Side-by-side 3D viewers (iframes produced by create_3dmol_html).
    with gr.Row():
        with gr.Column():
            gr.Markdown("#### ESMFold")
            baseline_viz = gr.HTML(label="ESMFold Structure")
        with gr.Column():
            gr.Markdown("#### ESMFold + ProteinTTT")
            ttt_viz = gr.HTML(label="ESMFold + ProteinTTT Structure")
    with gr.Row():
        baseline_file = gr.File(label="Download ESMFold PDB")
        ttt_file = gr.File(label="Download ESMFold + ProteinTTT PDB")
    # Rows are [time, stage, message], as appended by predict_structure's add_log.
    with gr.Accordion("Log", open=False):
        log_table = gr.Dataframe(
            headers=["Time", "Stage", "Message"],
            datatype=["str", "str", "str"],
            row_count=(0, "dynamic"),
            col_count=(3, "fixed"),
            interactive=False,
            wrap=True,
            label="Run Log"
        )
    # `inputs` are passed positionally to predict_structure(sequence, ttt_steps,
    # learning_rate, batch_size, lora_rank, lora_alpha, seed). `outputs` receive
    # the 9-tuple that predict_structure yields, in the same order.
    predict_btn.click(
        fn=predict_structure,
        inputs=[sequence_input, ttt_steps, learning_rate, batch_size, lora_rank, lora_alpha, seed],
        outputs=[
            baseline_plddt,
            ttt_plddt,
            improvement,
            baseline_viz,
            ttt_viz,
            baseline_file,
            ttt_file,
            progress_bar,
            log_table,
        ]
    )
# Running this file directly starts the Gradio server.
if __name__ == "__main__":
    demo.launch()