| """Gradio UI entrypoint for the FlowProt Hugging Face Docker Space."""
|
|
|
| from __future__ import annotations
|
|
|
| import base64
|
| import csv
|
| import html
|
| import json
|
| import logging
|
| import os
|
| import shutil
|
| import traceback
|
| from datetime import datetime
|
| from pathlib import Path
|
| from typing import Dict, List, Optional, Sequence, Tuple
|
|
|
| import gradio as gr
|
| import pandas as pd
|
|
|
| from inference import (
|
| ArtifactResolutionError,
|
| FlowProtInferenceService,
|
| InferenceError,
|
| InferenceResult,
|
| ModelLoadError,
|
| )
|
| from self_consistency import (
|
| FlowProtSelfConsistencyService,
|
| SelfConsistencyError,
|
| )
|
| from viewer_alignment import AlignmentError, align_folded_to_backbone_overlay
|
|
|
# Configure root logging once at import time. The level comes from the
# FLOWPROT_LOG_LEVEL environment variable and defaults to INFO.
logging.basicConfig(
    level=os.getenv("FLOWPROT_LOG_LEVEL", "INFO").upper(),
    format="%(asctime)s %(levelname)s %(name)s - %(message)s",
)
LOGGER = logging.getLogger(__name__)

# Module-level service instances, created at import time and shared by the
# functions in this file.
SERVICE = FlowProtInferenceService()
SC_SERVICE = FlowProtSelfConsistencyService()
# Pixel cap for the Mol* viewer height (consumed by _viewer_height_css).
VIEWER_HEIGHT = 620
# Column headers for the self-consistency tables; their lengths match the row
# layouts produced by _format_summary_rows (2 columns), _format_metrics_rows
# (5 columns) and _format_leaderboard_rows (4 columns).
SC_SUMMARY_HEADERS = ["Metric", "Value"]
SC_METRICS_HEADERS = ["Sample", "Seq len", "scTM", "scRMSD", "ESMFold mean pLDDT"]
SC_TOP_TM_HEADERS = ["Rank", "Sample", "Seq len", "scTM"]
SC_TOP_RMSD_HEADERS = ["Rank", "Sample", "Seq len", "scRMSD"]
# Location of the saved example case, resolved relative to this file.
EXAMPLE_CASE_DIR = (
    Path(__file__).resolve().parent / "examples" / "flowprot_space_example"
).resolve()
EXAMPLE_CASE_SAMPLE_DIR = EXAMPLE_CASE_DIR / "sample"
EXAMPLE_CASE_MANIFEST = EXAMPLE_CASE_DIR / "manifest.json"
# Root directory scanned for generated samples (see _find_latest_generated_sample).
SPACE_OUTPUTS_DIR = (Path(__file__).resolve().parent / "space_outputs").resolve()
# NOTE(review): the run-history consumers are not visible in this part of the
# file; presumably the limit caps the number of persisted entries - confirm.
RUN_HISTORY_PATH = SPACE_OUTPUTS_DIR / "run_history.json"
RUN_HISTORY_LIMIT = 25
|
|
|
|
|
def _env_flag(name: str, default: bool = False) -> bool:
    """Interpret the environment variable *name* as a boolean switch.

    Returns *default* when the variable is unset. Otherwise the result is
    True only if the trimmed, lower-cased value is "1", "true", "yes" or "on".
    """
    truthy_tokens = {"1", "true", "yes", "on"}
    value = os.getenv(name)
    return default if value is None else str(value).strip().lower() in truthy_tokens
|
|
|
|
|
# Whether the saved example case should be loaded when the app starts; read
# once at import time from FLOWPROT_VIEW_EXAMPLE_ON_STARTUP (off by default).
VIEW_EXAMPLE_ON_STARTUP_DEFAULT = _env_flag("FLOWPROT_VIEW_EXAMPLE_ON_STARTUP", default=False)
# Leaderboard size used when rendering saved self-consistency results.
SC_DEFAULT_TOP_N = 3
# Bounds for the self-consistency metrics. They are the fixed axis ranges of the
# scatter plots (_render_sc_outputs) and the "no filter" thresholds in
# _filters_summary_line.
SC_TM_MIN = 0.0
SC_TM_MAX = 1.0
SC_RMSD_MIN = 0.0
SC_RMSD_MAX = 10.0
# NOTE(review): a 0-1 pLDDT range assumes the metric is stored as a fraction;
# values on a 0-100 scale would fall outside the plot range - confirm.
SC_PLDDT_MIN = 0.0
SC_PLDDT_MAX = 1.0
# Stylesheet text for the UI: page width/padding, the "fp-*" helper classes, and
# responsive tweaks at 980px and 1280px. Presumably passed to the Gradio app;
# the usage is not visible in this part of the file.
UI_CSS = """
.gradio-container {
max-width: 1400px !important;
margin: 0 auto !important;
padding-left: 14px !important;
padding-right: 14px !important;
}
.fp-card {
border: 1px solid var(--block-border-color);
border-radius: 14px;
padding: 14px;
margin-bottom: 12px;
background: var(--block-background-fill);
box-shadow: 0 1px 2px rgba(15, 23, 42, 0.06);
}
.fp-tight {
gap: 12px !important;
}
.fp-status textarea {
min-height: 100px !important;
}
.fp-debug-note {
color: var(--body-text-color-subdued);
font-size: 0.92rem;
}
.fp-subnote {
color: var(--body-text-color-subdued);
font-size: 0.9rem;
}
.fp-main-row {
align-items: stretch;
}
.fp-plot-row {
align-items: stretch;
}
@media (max-width: 980px) {
.gradio-container {
padding-left: 10px !important;
padding-right: 10px !important;
}
.fp-main-row {
flex-direction: column !important;
}
.fp-status textarea {
min-height: 88px !important;
}
}
@media (max-width: 1280px) {
.fp-plot-row {
flex-direction: column !important;
}
}
"""
|
|
|
|
|
def _viewer_height_css() -> str:
    """Return the CSS height expression for the viewer: 72vh, capped at VIEWER_HEIGHT pixels."""
    height_cap_px = VIEWER_HEIGHT
    return f"min(72vh, {height_cap_px}px)"
|
|
|
|
|
def _molstar_placeholder_html(message: Optional[str] = None) -> str:
    """Build the bordered placeholder box shown when no structure is loaded.

    Args:
        message: Text to display; a default hint is used when empty or None.

    Returns:
        An HTML ``<div>`` sized like the viewer, with the HTML-escaped message
        centred inside it.
    """
    text = message if message else "Generate a sample and select it to load in Mol*."
    container_style = (
        f"height: {_viewer_height_css()}; border: 1px solid var(--block-border-color, #d9d9d9); border-radius: 8px; "
        "display: flex; align-items: center; justify-content: center; color: var(--body-text-color-subdued, #666666); background: var(--block-background-fill, #ffffff);"
    )
    return f'<div style="{container_style}">{html.escape(text)}</div>'
|
|
|
|
|
def _mode_choices() -> List[str]:
    """List the modes to offer: always "unconditional", plus each mode SERVICE reports as enabled."""
    optional_modes = (
        ("classifier", SERVICE.classifier_enabled),
        ("conditional", SERVICE.conditional_enabled),
    )
    return ["unconditional"] + [mode for mode, enabled in optional_modes if enabled]
|
|
|
|
|
def _empty_sc_plot_df() -> pd.DataFrame:
    """Return a zero-row DataFrame with the typed columns used for self-consistency plot data."""
    column_dtypes = (
        ("sample", "string"),
        ("plot_color", "string"),
        ("scTM", "float64"),
        ("scRMSD", "float64"),
        ("esmfold_mean_plddt", "float64"),
    )
    return pd.DataFrame({name: pd.Series(dtype=dtype) for name, dtype in column_dtypes})
|
|
|
|
|
def _empty_scatter_plot_html(title: str, message: str) -> str:
    """Render a plot-sized placeholder card with a title and a centred message.

    Both *title* and *message* are HTML-escaped before being embedded.
    """
    fragments = [
        '<div style="border: 1px solid var(--block-border-color, #d9d9d9); border-radius: 10px; ',
        'background: var(--block-background-fill, #ffffff); min-height: 420px; padding: 10px 12px;">',
        '<div style="font-weight: 600; margin-bottom: 6px;">' + html.escape(title) + "</div>",
        '<div style="height: 360px; display:flex; align-items:center; justify-content:center; ',
        'color: var(--body-text-color-subdued, #666666);">',
        html.escape(message),
        "</div>",
        "</div>",
    ]
    return "".join(fragments)
|
|
|
|
|
def _empty_sc_outputs() -> Tuple[
    str,
    List[List[str]],
    List[List[object]],
    str,
    str,
    List[str],
    List[List[object]],
    List[List[object]],
]:
    """Return blank values for the eight self-consistency outputs.

    The tuple has the same layout as the one returned by ``_render_sc_outputs``:
    status text, summary rows, metric rows, two plot HTML placeholders,
    artifact paths, and the two leaderboard tables.
    """
    tradeoff_placeholder = _empty_scatter_plot_html(
        title="Self-consistency tradeoff (scTM vs scRMSD)",
        message="Run self-consistency to render plot data.",
    )
    confidence_placeholder = _empty_scatter_plot_html(
        title="ESMFold confidence vs scTM",
        message="Run self-consistency with folding enabled to render confidence plot.",
    )
    return ("", [], [], tradeoff_placeholder, confidence_placeholder, [], [], [])
|
|
|
|
|
def _empty_sc_state() -> Tuple[str, Dict[str, object], List[Dict[str, object]], List[str]]:
    """Return fresh blank state values: (base status, summary, metrics, artifacts)."""
    base_status: str = ""
    summary: Dict[str, object] = {}
    metrics: List[Dict[str, object]] = []
    artifacts: List[str] = []
    return base_status, summary, metrics, artifacts
|
|
|
|
|
def _empty_example_view_outputs() -> Tuple[
    str,
    List[str],
    List[str],
    Dict[str, object],
    str,
    str,
    List[List[str]],
    List[List[object]],
    str,
    str,
    List[str],
    List[List[object]],
    List[List[object]],
    str,
    Dict[str, object],
    List[Dict[str, object]],
    List[str],
    str,
]:
    """Return the blank 18-element output tuple for the example-view callbacks.

    Layout (same as ``load_saved_example``): status text, two list-valued
    outputs, a disabled/empty selector update, the Mol* placeholder HTML, the
    eight values of ``_empty_sc_outputs``, the four values of
    ``_empty_sc_state``, and a trailing example-status string.

    The two plot slots hold HTML strings (see ``_empty_sc_outputs``), so they
    are annotated as ``str`` rather than DataFrames.
    """
    return (
        "",
        [],
        [],
        gr.update(choices=[], value=[], interactive=False),
        _molstar_placeholder_html(),
        *_empty_sc_outputs(),
        *_empty_sc_state(),
        "",
    )
|
|
|
|
|
def _resolve_saved_example_sample() -> Optional[str]:
    """Locate the saved example sample file, if any.

    Resolution order:
      1. The ``example_sample_path`` recorded in the manifest, if it exists.
      2. A file with the same name inside ``EXAMPLE_CASE_SAMPLE_DIR``. The
         manifest stores an absolute path, so this keeps the example usable
         after the directory has been moved or copied to another machine.
      3. ``EXAMPLE_CASE_SAMPLE_DIR / "sample.pdb"``.

    Returns:
        The resolved POSIX path of the sample, or None when nothing is found.
        Manifest problems are logged and never raised.
    """
    if EXAMPLE_CASE_MANIFEST.exists():
        try:
            payload = json.loads(EXAMPLE_CASE_MANIFEST.read_text(encoding="utf-8"))
            # A manifest that is valid JSON but not an object carries no path.
            manifest_sample = payload.get("example_sample_path") if isinstance(payload, dict) else None
            if manifest_sample:
                recorded = Path(str(manifest_sample))
                for candidate in (recorded, EXAMPLE_CASE_SAMPLE_DIR / recorded.name):
                    if candidate.exists():
                        return candidate.resolve().as_posix()
        except Exception:
            # Best-effort: a broken manifest must not prevent the fallback below,
            # but keep the traceback so the cause is diagnosable.
            LOGGER.warning(
                "Failed to parse example manifest at %s",
                EXAMPLE_CASE_MANIFEST,
                exc_info=True,
            )

    fallback_sample = EXAMPLE_CASE_SAMPLE_DIR / "sample.pdb"
    if fallback_sample.exists():
        return fallback_sample.resolve().as_posix()
    return None
|
|
|
|
|
def _find_latest_generated_sample() -> Optional[Path]:
    """Return the most recently modified generated ``sample.pdb``, or None.

    Scans ``SPACE_OUTPUTS_DIR`` for ``space_*/length_*/sample_*/sample.pdb``.
    """
    if not SPACE_OUTPUTS_DIR.exists():
        return None

    def _mtime(path: Path) -> float:
        # A file may disappear between the glob and the stat; rank it last
        # instead of failing the whole lookup.
        try:
            return path.stat().st_mtime
        except OSError:
            return float("-inf")

    return max(
        SPACE_OUTPUTS_DIR.glob("space_*/length_*/sample_*/sample.pdb"),
        key=_mtime,
        default=None,
    )
|
|
|
|
|
def _normalize_selected_samples(selected_samples: Optional[object]) -> List[str]:
    """Coerce a selector value into a list of non-empty, trimmed strings.

    Accepts None (gives ``[]``), a list/tuple/set (None entries and entries
    that are blank after trimming are dropped), or any other single value,
    which is stringified and trimmed.
    """
    if selected_samples is None:
        return []
    if isinstance(selected_samples, (list, tuple, set)):
        trimmed = (str(entry).strip() for entry in selected_samples if entry is not None)
        return [text for text in trimmed if text]
    # Strings and any other scalar are treated as a single candidate value.
    text = str(selected_samples).strip()
    return [text] if text else []
|
|
|
|
|
def _sample_label_from_path(sample_path: str, metric: Optional[Dict[str, object]] = None) -> str:
    """Build the display label ``sample_<id>_<index>`` for a sample file.

    ``<id>`` is the parent directory name with one leading ``sample_`` removed.
    ``<index>`` is ``metric["index"]`` when present and not None, otherwise
    the literal ``x``.
    """
    prefix = "sample_"
    parent_name = Path(sample_path).parent.name
    sample_id = parent_name[len(prefix):] if parent_name.startswith(prefix) else parent_name
    sequence_index = None if metric is None else metric.get("index")
    suffix = "x" if sequence_index is None else sequence_index
    return f"sample_{sample_id}_{suffix}"
|
|
|
|
|
def _enrich_metric_sample_fields(metric: Dict[str, object], sample_path: str) -> Dict[str, object]:
    """Return a copy of *metric* tagged with its source sample.

    Adds ``sample_source_path`` plus the display label, stored under both
    ``sample_display`` and ``sample``. The input dict is not modified.
    """
    label = _sample_label_from_path(sample_path, metric=metric)
    return {
        **metric,
        "sample_source_path": sample_path,
        "sample_display": label,
        "sample": label,
    }
|
|
|
|
|
def _metric_sample_display(metric: Dict[str, object]) -> str:
    """Return the label to show for *metric*.

    Prefers a truthy ``sample_display``, then falls back to ``sample``; gives
    "unknown" when the chosen value is None.
    """
    primary = metric.get("sample_display")
    label = primary if primary else metric.get("sample")
    return "unknown" if label is None else str(label)
|
|
|
|
|
def _metric_sequence_length(metric: Dict[str, object]) -> Optional[int]:
    """Return the length of the metric's ``sequence`` as a string, or None when it is absent."""
    sequence = metric.get("sequence")
    return None if sequence is None else len(str(sequence))
|
|
|
|
|
def _normalize_optional_text(value: Optional[object]) -> str:
    """Stringify and trim *value*; None becomes the empty string."""
    return "" if value is None else str(value).strip()
|
|
|
|
|
def _parse_optional_seed(seed_override: Optional[object]) -> Optional[int]:
    """Parse an optional seed override.

    Returns:
        None when the override is None or blank, otherwise its integer value.

    Raises:
        InferenceError: If the text is not a valid integer literal.
    """
    seed_text = _normalize_optional_text(seed_override)
    if seed_text == "":
        return None
    try:
        seed = int(seed_text)
    except ValueError as exc:
        raise InferenceError("Seed override must be an integer value.") from exc
    return seed
|
|
|
|
|
def _actionable_error_hint(exc: object) -> str:
    """Map an error message to a short user-facing hint.

    Matching is case-insensitive substring matching on ``str(exc)``. The first
    matching category wins: GPU memory, then model artifacts, then ESMFold.
    Returns an empty string when nothing matches.
    """
    message = str(exc).lower()

    def mentions(fragment: str) -> bool:
        return fragment in message

    gpu_memory = mentions("out of memory") or (mentions("cuda") and mentions("memory"))
    missing_artifact = mentions("no model artifact source") or mentions("artifact")
    folding_failure = mentions("esmfold") or mentions("transformers")

    if gpu_memory:
        return (
            " Hint: the GPU ran out of memory. Try a shorter length, fewer samples, "
            "or fewer sampling timesteps."
        )
    if missing_artifact:
        return (
            " Hint: configure a checkpoint via FLOWPROT_CKPT_PATH, FLOWPROT_CKPT_DIR, "
            "or FLOWPROT_HF_REPO_ID before running inference."
        )
    if folding_failure:
        return " Hint: ESMFold could not run. Disable folding or verify the ESMFold model is available."
    return ""
|
|
|
|
|
def _short_error_text(message: object, max_chars: int = 180) -> str:
    """Condense an error message to its first line, at most *max_chars* long.

    Args:
        message: Any object; it is stringified and trimmed.
        max_chars: Maximum length of the returned text.

    Returns:
        "unknown error" for a blank message; the first line when it fits;
        otherwise the first line truncated and suffixed with "...". When
        *max_chars* leaves no room for text plus the ellipsis, the ellipsis
        itself is shortened so the result never exceeds the limit.
    """
    raw = str(message).strip()
    if not raw:
        return "unknown error"
    first_line = raw.splitlines()[0].strip()
    if len(first_line) <= max_chars:
        return first_line
    ellipsis = "..."
    if max_chars <= len(ellipsis):
        # Slicing with a negative bound would keep almost the whole line;
        # clamp to the requested budget instead.
        return ellipsis[: max(max_chars, 0)]
    return first_line[: max_chars - len(ellipsis)].rstrip() + ellipsis
|
|
|
|
|
def _build_sc_filter_sample_choices(metrics: Sequence[Dict[str, object]]) -> List[str]:
    """Return the distinct sample display labels of *metrics*, in first-seen order."""
    # dict.fromkeys removes duplicates while keeping insertion order.
    return list(dict.fromkeys(_metric_sample_display(entry) for entry in metrics))
|
|
|
|
|
def _build_folded_choices(metrics: Sequence[Dict[str, object]]) -> List[Tuple[str, str]]:
    """Build ``(label, folded_path)`` pairs for metrics that have a folded structure path.

    The label is ``<sample> | seq#<index>``, followed by the scTM and scRMSD
    values (4 decimals) when both are available. Metrics without a
    ``folded_sample_path`` are skipped.
    """
    options: List[Tuple[str, str]] = []
    for entry in metrics:
        path_text = _normalize_optional_text(entry.get("folded_sample_path"))
        if path_text:
            tm_score = entry.get("scTM")
            rmsd_score = entry.get("scRMSD")
            seq_index = entry.get("index")
            label = f"{_metric_sample_display(entry)} | seq#{seq_index}"
            if tm_score is not None and rmsd_score is not None:
                label = f"{label} | scTM={float(tm_score):.4f} | scRMSD={float(rmsd_score):.4f}"
            options.append((label, path_text))
    return options
|
|
|
|
|
def _build_sc_summary(num_selected_samples: int, num_processed_samples: int, metrics: List[Dict[str, object]]) -> Dict[str, object]:
    """Aggregate self-consistency metrics into a summary dict.

    The counts are always present. Mean/min/max of scTM and scRMSD and the mean
    ESMFold pLDDT are added only when at least one metric carries that value.
    """
    summary: Dict[str, object] = {
        "num_samples_selected": num_selected_samples,
        "num_samples_processed": num_processed_samples,
        "num_sequences_evaluated": len(metrics),
    }
    if not metrics:
        return summary

    def present(key: str) -> List[float]:
        # Values that are missing or None do not take part in the statistics.
        return [float(entry[key]) for entry in metrics if entry.get(key) is not None]

    tm_scores = present("scTM")
    rmsd_scores = present("scRMSD")
    plddt_scores = present("esmfold_mean_plddt")

    if tm_scores:
        summary["mean_scTM"] = float(sum(tm_scores) / len(tm_scores))
        summary["max_scTM"] = max(tm_scores)
        summary["min_scTM"] = min(tm_scores)
    if rmsd_scores:
        summary["mean_scRMSD"] = float(sum(rmsd_scores) / len(rmsd_scores))
        summary["min_scRMSD"] = min(rmsd_scores)
        summary["max_scRMSD"] = max(rmsd_scores)
    if plddt_scores:
        summary["mean_esmfold_plddt"] = float(sum(plddt_scores) / len(plddt_scores))
    return summary
|
|
|
|
|
def _safe_float(value: Optional[str]) -> Optional[float]:
    """Parse *value* as a float; return None for None, blank or unparsable text."""
    text = "" if value is None else str(value).strip()
    if text == "":
        return None
    try:
        return float(text)
    except ValueError:
        return None
|
|
|
|
|
def _load_sc_metrics_from_csv(csv_path: Path) -> List[Dict[str, object]]:
    """Read a self-consistency results CSV into a list of metric dicts.

    Each row yields the keys ``index``, ``header``, ``sequence``, ``scTM``,
    ``scRMSD``, ``esmfold_mean_plddt`` and ``folded_sample_path``. ``index`` is
    converted to int when possible and otherwise kept as read; the numeric
    columns go through ``_safe_float``, so unparsable values become None.
    """

    def coerce_index(raw: object) -> object:
        if raw in (None, ""):
            return raw
        try:
            return int(str(raw))
        except ValueError:
            return raw

    with csv_path.open("r", encoding="utf-8", newline="") as handle:
        return [
            {
                "index": coerce_index(row.get("index")),
                "header": row.get("header"),
                "sequence": row.get("sequence"),
                "scTM": _safe_float(row.get("scTM")),
                "scRMSD": _safe_float(row.get("scRMSD")),
                "esmfold_mean_plddt": _safe_float(row.get("esmfold_mean_plddt")),
                "folded_sample_path": row.get("folded_sample_path"),
            }
            for row in csv.DictReader(handle)
        ]
|
|
|
|
|
def _collect_loaded_sc_artifacts(run_dir: Path, metrics_csv: Path, metrics: List[Dict[str, object]]) -> List[str]:
    """List the artifact files of a saved self-consistency run that exist on disk.

    Order: ``parsed_pdbs.jsonl``, ``seqs/sample.fa``, the folded structures,
    then the metrics CSV. Folded structures are the sorted ``esmf/*.pdb`` files
    when ``run_dir/esmf`` exists, otherwise the existing ``folded_sample_path``
    entries of *metrics*. All paths are returned resolved, in POSIX form.
    """

    def posix(path: Path) -> str:
        return path.resolve().as_posix()

    collected: List[str] = [
        posix(candidate)
        for candidate in (run_dir / "parsed_pdbs.jsonl", run_dir / "seqs" / "sample.fa")
        if candidate.exists()
    ]

    esmfold_dir = run_dir / "esmf"
    if esmfold_dir.exists():
        collected.extend(posix(pdb_file) for pdb_file in sorted(esmfold_dir.glob("*.pdb")))
    else:
        for entry in metrics:
            recorded = entry.get("folded_sample_path")
            if recorded:
                recorded_path = Path(str(recorded))
                if recorded_path.exists():
                    collected.append(posix(recorded_path))

    if metrics_csv.exists():
        collected.append(posix(metrics_csv))
    return collected
|
|
|
|
|
def _hydrate_sc_outputs_for_samples(selected_samples: Optional[object]) -> Tuple[
    str,
    List[List[str]],
    List[List[object]],
    str,
    str,
    List[str],
    List[List[object]],
    List[List[object]],
    str,
    Dict[str, object],
    List[Dict[str, object]],
    List[str],
]:
    """Rebuild the self-consistency outputs from results saved next to the samples.

    For every selected sample, the newest run directory under
    ``<sample dir>/self_consistency`` that contains a non-empty, readable
    ``sc_results.csv`` is loaded; older runs are only tried when newer ones
    are missing, empty or unreadable.

    Returns:
        A 12-tuple: the eight values of ``_render_sc_outputs`` (the two plot
        slots are HTML strings), followed by the base status text, the summary
        dict, the combined metrics and the de-duplicated artifact paths. When
        nothing can be loaded, the empty outputs and empty state are returned.
    """
    sample_values = _normalize_selected_samples(selected_samples)
    if not sample_values:
        return (*_empty_sc_outputs(), *_empty_sc_state())

    combined_metrics: List[Dict[str, object]] = []
    combined_artifacts: List[str] = []
    processed_samples = 0

    for sample_value in sample_values:
        sc_root = Path(sample_value).parent / "self_consistency"
        # is_dir() rather than exists(): iterdir() below fails on a plain file.
        if not sc_root.is_dir():
            continue

        run_dirs = sorted(
            [path for path in sc_root.iterdir() if path.is_dir()],
            key=lambda path: path.stat().st_mtime,
            reverse=True,
        )
        for run_dir in run_dirs:
            metrics_csv = run_dir / "sc_results.csv"
            if not metrics_csv.exists():
                continue
            try:
                metrics = _load_sc_metrics_from_csv(metrics_csv)
                if not metrics:
                    continue
                enriched = [
                    _enrich_metric_sample_fields(metric=item, sample_path=sample_value)
                    for item in metrics
                ]
                run_artifacts = _collect_loaded_sc_artifacts(
                    run_dir=run_dir, metrics_csv=metrics_csv, metrics=metrics
                )
            except Exception as exc:
                LOGGER.warning("Failed to load saved self-consistency from %s: %s", metrics_csv, exc)
                continue
            # Commit only after the whole run loaded, so a failure part-way
            # through cannot leave metrics from two runs of the same sample.
            combined_metrics.extend(enriched)
            combined_artifacts.extend(run_artifacts)
            processed_samples += 1
            break

    if not combined_metrics:
        return (*_empty_sc_outputs(), *_empty_sc_state())

    # Drop duplicate artifact paths while keeping their first-seen order.
    artifacts = list(dict.fromkeys(combined_artifacts))
    summary = _build_sc_summary(
        num_selected_samples=len(sample_values),
        num_processed_samples=processed_samples,
        metrics=combined_metrics,
    )
    base_status = (
        f"Loaded saved self-consistency for {processed_samples}/{len(sample_values)} selected sample(s). "
        f"Evaluated={len(combined_metrics)}."
    )
    rendered = _render_sc_outputs(
        base_status=base_status,
        summary=summary,
        metrics=combined_metrics,
        artifacts=artifacts,
        top_n=SC_DEFAULT_TOP_N,
    )
    return (*rendered, base_status, summary, combined_metrics, artifacts)
|
|
|
|
|
def save_example_case(selected_samples: Optional[object]) -> str:
    """Copy a sample directory into the example-case location and write its manifest.

    Args:
        selected_samples: Selector value; the first selected sample is used.
            When nothing is selected, the most recently generated sample is
            used instead.

    Returns:
        A human-readable status message. Failures are reported in the message
        and logged; they are not raised.
    """
    # Imported locally so this block does not depend on the module's import list.
    from datetime import timezone

    selected_values = _normalize_selected_samples(selected_samples)
    selected_sample = selected_values[0] if selected_values else None
    if not selected_sample:
        latest = _find_latest_generated_sample()
        if latest is None:
            return "No generated sample found. Run inference or select a sample before saving an example case."
        selected_sample = latest.resolve().as_posix()

    sample_path = Path(selected_sample)
    if not sample_path.exists():
        return f"Selected sample does not exist: {sample_path}"

    source_dir = sample_path.parent
    try:
        # The destination is wiped before copying. If the selected sample
        # already lives inside it (e.g. the loaded example is saved again),
        # wiping would delete the source, so stop before touching anything.
        resolved_source = source_dir.resolve()
        resolved_target = EXAMPLE_CASE_SAMPLE_DIR.resolve()
        if resolved_source == resolved_target or resolved_target in resolved_source.parents:
            return (
                f"Selected sample is already part of the saved example case at {EXAMPLE_CASE_SAMPLE_DIR}; "
                "nothing to copy."
            )

        EXAMPLE_CASE_DIR.mkdir(parents=True, exist_ok=True)
        if EXAMPLE_CASE_SAMPLE_DIR.exists():
            shutil.rmtree(EXAMPLE_CASE_SAMPLE_DIR)
        shutil.copytree(source_dir, EXAMPLE_CASE_SAMPLE_DIR)

        saved_sample_path = EXAMPLE_CASE_SAMPLE_DIR / sample_path.name
        if not saved_sample_path.exists():
            return (
                "Example case copy completed, but sample file was missing in copied directory. "
                "Please verify the source sample."
            )

        # Same "naive UTC ISO timestamp + Z" format as before, produced without
        # the deprecated datetime.utcnow().
        saved_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
        manifest = {
            "saved_at_utc": saved_at,
            "source_sample_path": sample_path.resolve().as_posix(),
            "example_sample_path": saved_sample_path.resolve().as_posix(),
        }
        EXAMPLE_CASE_MANIFEST.write_text(json.dumps(manifest, indent=2), encoding="utf-8")
        return (
            f"Saved example case to {EXAMPLE_CASE_SAMPLE_DIR}. "
            "Set FLOWPROT_VIEW_EXAMPLE_ON_STARTUP=true to auto-load it at app startup."
        )
    except Exception as exc:
        LOGGER.exception("Failed to save example case.")
        return f"Failed to save example case: {exc}"
|
|
|
|
|
def load_saved_example(
    prefix_status: str = "Loaded saved example case.",
    prefix_example_status: str = "Example case loaded.",
) -> Tuple[
    str,
    List[str],
    List[str],
    Dict[str, object],
    str,
    str,
    List[List[str]],
    List[List[object]],
    str,
    str,
    List[str],
    List[List[object]],
    List[List[object]],
    str,
    Dict[str, object],
    List[Dict[str, object]],
    List[str],
    str,
]:
    """Load the saved example case into the viewer and self-consistency outputs.

    Args:
        prefix_status: Leading text of the main status message.
        prefix_example_status: Leading text of the example status message.

    Returns:
        An 18-tuple with the same layout as ``_empty_example_view_outputs``:
        status, two list-valued outputs, the sample selector update, the Mol*
        viewer HTML, the 12 values of ``_hydrate_sc_outputs_for_samples`` (the
        plot slots are HTML strings), and the example status message. When no
        saved example exists, the empty outputs are returned with an
        explanatory final message.
    """
    sample_path = _resolve_saved_example_sample()
    if not sample_path:
        empty = _empty_example_view_outputs()
        return (*empty[:-1], "No saved example found in examples/flowprot_space_example. Save one first.")

    selector_update = gr.update(choices=[sample_path], value=[sample_path], interactive=True)
    sc_payload = _hydrate_sc_outputs_for_samples([sample_path])
    # The payload ends with (base_status, summary, metrics, artifacts), so the
    # combined metrics are the second-to-last element.
    loaded_metrics = sc_payload[-2]
    sc_note = (
        f" Loaded self-consistency ({len(loaded_metrics)} sequence(s))."
        if loaded_metrics
        else " No saved self-consistency metrics found for this example."
    )
    return (
        f"{prefix_status} Source: {sample_path}",
        [sample_path],
        [],
        selector_update,
        _build_molstar_iframe_html(sample_path),
        *sc_payload,
        f"{prefix_example_status} Source: {sample_path}.{sc_note}",
    )
|
|
|
|
|
def maybe_load_saved_example_on_startup(
    view_example_on_startup: bool,
) -> Tuple[
    object,
    object,
    object,
    object,
    object,
    object,
    object,
    object,
    object,
    object,
    object,
    object,
    object,
    object,
    object,
    object,
    object,
    str,
]:
    """Load the saved example at startup when requested.

    When *view_example_on_startup* is falsy, the first 17 outputs are
    ``gr.skip()`` values and the final status is an empty string. Otherwise
    the result of ``load_saved_example`` with startup-specific message
    prefixes is returned.
    """
    if bool(view_example_on_startup):
        return load_saved_example(
            prefix_status="Loaded saved example case on startup.",
            prefix_example_status="Startup example loaded.",
        )
    skipped_output_count = 17
    return (*(gr.skip() for _ in range(skipped_output_count)), "")
|
|
|
|
|
def _format_summary_rows(summary: Dict[str, object]) -> List[List[str]]:
    """Turn the summary dict into ``[label, value]`` rows for the summary table.

    Only the mean scTM, mean scRMSD and mean pLDDT entries are shown, in that
    order, and only when present. Float values are formatted with four
    decimals; anything else is stringified as is.
    """
    preferred = (
        ("mean_scTM", "scTM score"),
        ("mean_scRMSD", "scRMSD score"),
        ("mean_esmfold_plddt", "Mean pLDDT"),
    )
    rows: List[List[str]] = []
    for key, label in preferred:
        if key not in summary:
            continue
        value = summary[key]
        # Every metric uses the same precision; the earlier per-key
        # conditional selected 4 in both branches.
        rows.append([label, f"{value:.4f}" if isinstance(value, float) else str(value)])
    return rows
|
|
|
|
|
def _format_metrics_rows(metrics: List[Dict[str, object]]) -> List[List[object]]:
    """Build the per-sequence table rows.

    Each row is ``[sample label, sequence length, scTM, scRMSD, mean pLDDT]``;
    numeric cells are rounded to four decimals and are None when missing.
    """

    def rounded(entry: Dict[str, object], key: str) -> Optional[float]:
        raw = entry.get(key)
        return None if raw is None else round(float(raw), 4)

    return [
        [
            _metric_sample_display(entry),
            _metric_sequence_length(entry),
            rounded(entry, "scTM"),
            rounded(entry, "scRMSD"),
            rounded(entry, "esmfold_mean_plddt"),
        ]
        for entry in metrics
    ]
|
|
|
|
|
def _format_metrics_plot_df(metrics: List[Dict[str, object]]) -> pd.DataFrame:
    """Build the plotting DataFrame from metrics that have both scTM and scRMSD.

    Returns the empty typed frame from ``_empty_sc_plot_df`` when no metric
    qualifies. The three numeric columns are coerced with ``pd.to_numeric``,
    so unparsable values become NaN.
    """
    numeric_columns = ("scTM", "scRMSD", "esmfold_mean_plddt")
    records = []
    for entry in metrics:
        if entry.get("scTM") is None or entry.get("scRMSD") is None:
            continue
        record: Dict[str, object] = {
            "sample": _metric_sample_display(entry),
            "plot_color": "metric",
        }
        for column in numeric_columns:
            record[column] = entry.get(column)
        records.append(record)

    if not records:
        return _empty_sc_plot_df()

    frame = pd.DataFrame(records)
    for column in numeric_columns:
        frame[column] = pd.to_numeric(frame[column], errors="coerce")
    return frame
|
|
|
|
|
def _format_leaderboard_rows(
    metrics: List[Dict[str, object]],
    top_n: int,
) -> Tuple[List[List[object]], List[List[object]]]:
    """Build the "best scTM" and "best scRMSD" leaderboard tables.

    Args:
        metrics: Per-sequence metric dicts.
        top_n: Maximum number of rows per leaderboard; values below 1 give
            empty tables.

    Returns:
        ``(top_tm_rows, top_rmsd_rows)``. Each row is
        ``[rank, sample label, sequence length, score]`` with the score rounded
        to four decimals. scTM is ranked highest-first, scRMSD lowest-first.
        Metrics whose score is missing or None are left out of that
        leaderboard instead of raising, since saved results can contain
        unparsed scores.
    """
    limit = max(int(top_n), 0)
    if not metrics or limit == 0:
        return [], []

    def ranked_rows(key: str, highest_first: bool) -> List[List[object]]:
        scored = [item for item in metrics if item.get(key) is not None]
        # list.sort is stable, so ties keep their original relative order.
        scored.sort(key=lambda item: float(item[key]), reverse=highest_first)
        return [
            [
                rank,
                _metric_sample_display(item),
                _metric_sequence_length(item),
                round(float(item[key]), 4),
            ]
            for rank, item in enumerate(scored[:limit], start=1)
        ]

    return ranked_rows("scTM", True), ranked_rows("scRMSD", False)
|
|
|
|
|
def _filter_metrics(
    metrics: Sequence[Dict[str, object]],
    sample_labels: Optional[object],
    sample_query: str,
    min_tm: float,
    max_rmsd: float,
    min_plddt: float,
) -> List[Dict[str, object]]:
    """Return copies of the metrics that pass every active filter.

    Args:
        metrics: Per-sequence metric dicts.
        sample_labels: Selected sample labels; empty or None keeps all samples.
        sample_query: Case-insensitive substring to look for in the sample
            label; blank or None disables the text filter.
        min_tm: Minimum scTM; metrics without scTM are dropped.
        max_rmsd: Maximum scRMSD; metrics without scRMSD are dropped.
        min_plddt: Minimum mean pLDDT; metrics without a pLDDT value are kept.
    """
    selected_sample_labels = set(_normalize_selected_samples(sample_labels))
    # Tolerate None so a missing query behaves like an empty one.
    query = (sample_query or "").strip().lower()
    filtered: List[Dict[str, object]] = []

    for item in metrics:
        sample_display = _metric_sample_display(item)
        if selected_sample_labels and sample_display not in selected_sample_labels:
            continue
        if query and query not in sample_display.lower():
            continue

        tm_value = item.get("scTM")
        if tm_value is None or float(tm_value) < float(min_tm):
            continue
        rmsd_value = item.get("scRMSD")
        if rmsd_value is None or float(rmsd_value) > float(max_rmsd):
            continue
        # A missing pLDDT passes the threshold: only known values are tested.
        plddt_value = item.get("esmfold_mean_plddt")
        if plddt_value is not None and float(plddt_value) < float(min_plddt):
            continue
        filtered.append(dict(item))
    return filtered
|
|
|
|
|
def _filters_summary_line(
    sample_labels: Optional[object],
    sample_query: str,
    min_tm: float,
    max_rmsd: float,
    min_plddt: float,
) -> Optional[str]:
    """Describe the active filters as one status line.

    A filter counts as active when sample labels are selected, the query is
    non-blank, or a threshold is tighter than its default bound
    (``SC_TM_MIN``, ``SC_RMSD_MAX``, ``SC_PLDDT_MIN``).

    Returns:
        ``"Filters active: ..."`` or None when no filter is active.
    """
    labels = _normalize_selected_samples(sample_labels)
    parts: List[str] = []
    if labels:
        parts.append(f"samples={len(labels)}")
    # Tolerate None so a missing query behaves like an empty one.
    query = (sample_query or "").strip()
    if query:
        parts.append(f'text="{query}"')
    if float(min_tm) > SC_TM_MIN:
        parts.append(f"scTM>={float(min_tm):.2f}")
    if float(max_rmsd) < SC_RMSD_MAX:
        parts.append(f"scRMSD<={float(max_rmsd):.2f}")
    if float(min_plddt) > SC_PLDDT_MIN:
        parts.append(f"pLDDT>={float(min_plddt):.2f}")
    if not parts:
        return None
    return "Filters active: " + ", ".join(parts)
|
|
|
|
|
def _build_scatter_plot_html(
    plot_df: pd.DataFrame,
    x_col: str,
    y_col: str,
    title: str,
    x_title: str,
    y_title: str,
    fallback_x_range: Tuple[float, float],
    fallback_y_range: Tuple[float, float],
) -> str:
    """Render a scatter plot of two DataFrame columns as a self-contained HTML/SVG card.

    Args:
        plot_df: Source data; must contain *x_col* and *y_col*. An optional
            ``sample`` column supplies the per-point tooltip label.
        x_col: Column plotted on the horizontal axis.
        y_col: Column plotted on the vertical axis.
        title: Card heading, also used as the SVG ``aria-label``.
        x_title: Horizontal axis caption; the part before " (" is used in tooltips.
        y_title: Vertical axis caption; the part before " (" is used in tooltips.
        fallback_x_range: ``(min, max)`` used as the fixed horizontal axis range.
        fallback_y_range: ``(min, max)`` used as the fixed vertical axis range.

    Returns:
        An HTML string. When there is nothing to draw (empty frame, no numeric
        points, or no points inside the axis ranges) the placeholder from
        ``_empty_scatter_plot_html`` is returned with a matching message.
    """
    if plot_df.empty:
        return _empty_scatter_plot_html(title=title, message="No points available for this view.")

    # Keep only the columns needed for drawing and coerce both axes to numbers;
    # rows that are not numeric on either axis are discarded.
    keep_cols = [x_col, y_col]
    if "sample" in plot_df.columns:
        keep_cols.append("sample")
    rows = plot_df[keep_cols].copy()
    rows[x_col] = pd.to_numeric(rows[x_col], errors="coerce")
    rows[y_col] = pd.to_numeric(rows[y_col], errors="coerce")
    rows = rows.dropna(subset=[x_col, y_col])
    if rows.empty:
        return _empty_scatter_plot_html(title=title, message="No numeric points available for this view.")

    # The axes are fixed to the supplied ranges: points outside them are
    # dropped, not clamped.
    x_min, x_max = fallback_x_range
    y_min, y_max = fallback_y_range
    rows = rows[(rows[x_col] >= x_min) & (rows[x_col] <= x_max) & (rows[y_col] >= y_min) & (rows[y_col] <= y_max)]
    if rows.empty:
        return _empty_scatter_plot_html(title=title, message="No points within the fixed axis range for this view.")

    x_vals = rows[x_col].astype(float).tolist()
    y_vals = rows[y_col].astype(float).tolist()
    labels = rows["sample"].astype(str).tolist() if "sample" in rows.columns else [""] * len(x_vals)
    # Guard against a zero-width range so the scaling below never divides by zero.
    x_span = max(x_max - x_min, 1e-9)
    y_span = max(y_max - y_min, 1e-9)

    # SVG canvas size and margins, in viewBox units.
    width = 760
    height = 420
    left = 74
    right = 26
    top = 36
    bottom = 58
    plot_w = width - left - right
    plot_h = height - top - bottom

    def sx(value: float) -> float:
        # Data x -> SVG x inside the plot area.
        return left + ((value - x_min) / x_span) * plot_w

    def sy(value: float) -> float:
        # Data y -> SVG y; inverted because SVG y grows downward.
        return top + (1.0 - ((value - y_min) / y_span)) * plot_h

    # One circle per point, with a <title> child as a hover tooltip.
    points_svg = []
    for xv, yv, label in zip(x_vals, y_vals, labels):
        tooltip = html.escape(
            f"{label} | {x_title.split(' (')[0]}={xv:.3f} | {y_title.split(' (')[0]}={yv:.3f}"
        )
        points_svg.append(
            f"<circle cx=\"{sx(xv):.2f}\" cy=\"{sy(yv):.2f}\" r=\"4.5\" "
            "fill=\"#2563eb\" fill-opacity=\"0.9\" stroke=\"#ffffff\" stroke-width=\"1.1\">"
            f"<title>{tooltip}</title></circle>"
        )

    # Evenly spaced tick marks and numeric labels along both axes.
    x_ticks = []
    y_ticks = []
    tick_count = 5
    for idx in range(tick_count + 1):
        ratio = idx / tick_count
        xv = x_min + ratio * x_span
        yv = y_min + ratio * y_span
        tx = left + ratio * plot_w
        ty = top + (1.0 - ratio) * plot_h
        x_ticks.append(
            f"<line x1=\"{tx:.2f}\" y1=\"{top + plot_h:.2f}\" x2=\"{tx:.2f}\" y2=\"{top + plot_h + 6:.2f}\" stroke=\"#94a3b8\" />"
            f"<text x=\"{tx:.2f}\" y=\"{top + plot_h + 20:.2f}\" text-anchor=\"middle\" font-size=\"11\" fill=\"#64748b\">{xv:.2f}</text>"
        )
        y_ticks.append(
            f"<line x1=\"{left - 6:.2f}\" y1=\"{ty:.2f}\" x2=\"{left:.2f}\" y2=\"{ty:.2f}\" stroke=\"#94a3b8\" />"
            f"<text x=\"{left - 10:.2f}\" y=\"{ty + 4:.2f}\" text-anchor=\"end\" font-size=\"11\" fill=\"#64748b\">{yv:.2f}</text>"
        )

    # Assemble the card: heading, plot frame, horizontal grid lines, ticks,
    # points, then the axis captions (the y caption is rotated 90 degrees).
    return (
        "<div style=\"border: 1px solid var(--block-border-color, #d9d9d9); border-radius: 10px; "
        "background: var(--block-background-fill, #ffffff); padding: 10px 12px;\">"
        f"<div style=\"font-weight: 600; margin-bottom: 6px;\">{html.escape(title)}</div>"
        f"<svg viewBox=\"0 0 {width} {height}\" width=\"100%\" role=\"img\" aria-label=\"{html.escape(title)}\">"
        f"<rect x=\"{left}\" y=\"{top}\" width=\"{plot_w}\" height=\"{plot_h}\" fill=\"none\" stroke=\"#cbd5e1\" />"
        + "".join(
            f"<line x1=\"{left}\" y1=\"{top + (i / tick_count) * plot_h:.2f}\" x2=\"{left + plot_w}\" y2=\"{top + (i / tick_count) * plot_h:.2f}\" "
            "stroke=\"#eef2f7\" />"
            for i in range(tick_count + 1)
        )
        + "".join(x_ticks)
        + "".join(y_ticks)
        + "".join(points_svg)
        + f"<text x=\"{left + plot_w / 2:.2f}\" y=\"{height - 14}\" text-anchor=\"middle\" font-size=\"12\" fill=\"#475569\">{html.escape(x_title)}</text>"
        + f"<text x=\"18\" y=\"{top + plot_h / 2:.2f}\" text-anchor=\"middle\" font-size=\"12\" fill=\"#475569\" transform=\"rotate(-90 18 {top + plot_h / 2:.2f})\">{html.escape(y_title)}</text>"
        + "</svg>"
        + "</div>"
    )
|
|
|
|
|
def _render_sc_outputs(
    base_status: str,
    summary: Dict[str, object],
    metrics: List[Dict[str, object]],
    artifacts: List[str],
    top_n: int,
    extra_status_line: Optional[str] = None,
) -> Tuple[
    str,
    List[List[str]],
    List[List[object]],
    str,
    str,
    List[str],
    List[List[object]],
    List[List[object]],
]:
    """Render the eight self-consistency outputs from metrics and their summary.

    Returns:
        ``(status, summary_rows, metrics_rows, tm_vs_rmsd_html,
        confidence_html, artifacts, top_tm_rows, top_rmsd_rows)``. The status
        joins the non-empty *base_status* and *extra_status_line* with
        newlines and, when there are metrics, appends a count line.
    """
    lines = [text for text in (base_status, extra_status_line) if text]
    if metrics:
        lines.append(f"Showing {len(metrics)} evaluated sequence(s).")
    status_text = "\n".join(lines)

    summary_table = _format_summary_rows(summary)
    metrics_table = _format_metrics_rows(metrics)

    # Both plots are drawn from the same frame with fixed axis ranges.
    frame = _format_metrics_plot_df(metrics)
    tradeoff_plot = _build_scatter_plot_html(
        plot_df=frame,
        x_col="scTM",
        y_col="scRMSD",
        title="Self-consistency tradeoff (scTM vs scRMSD)",
        x_title="scTM (higher is better)",
        y_title="scRMSD (lower is better)",
        fallback_x_range=(SC_TM_MIN, SC_TM_MAX),
        fallback_y_range=(SC_RMSD_MIN, SC_RMSD_MAX),
    )
    confidence_plot = _build_scatter_plot_html(
        plot_df=frame,
        x_col="esmfold_mean_plddt",
        y_col="scTM",
        title="ESMFold confidence vs scTM",
        x_title="ESMFold mean pLDDT",
        y_title="scTM",
        fallback_x_range=(SC_PLDDT_MIN, SC_PLDDT_MAX),
        fallback_y_range=(SC_TM_MIN, SC_TM_MAX),
    )

    best_tm, best_rmsd = _format_leaderboard_rows(metrics, int(top_n))
    return (
        status_text,
        summary_table,
        metrics_table,
        tradeoff_plot,
        confidence_plot,
        artifacts,
        best_tm,
        best_rmsd,
    )
|
|
|
|
|
def refresh_health() -> Dict[str, object]:
    """Return SERVICE's health report with SC_SERVICE's report nested under "self_consistency"."""
    report = SERVICE.health_check()
    sc_report = SC_SERVICE.health_check()
    report["self_consistency"] = sc_report
    return report
|
|
|
|
|
def _resolve_viewer_value(sample_path: Optional[str]) -> Optional[str]:
    """Return the resolved POSIX path of *sample_path* if it exists, otherwise None."""
    if sample_path:
        candidate = Path(sample_path)
        if candidate.exists():
            return candidate.resolve().as_posix()
    return None
|
|
|
|
|
def _build_molstar_iframe_html(
    sample_path: Optional[str],
    viewer_message: Optional[str] = None,
) -> str:
    """Build an ``<iframe srcdoc=...>`` that renders one PDB file with Mol*.

    The PDB text is embedded into the iframe document as base64, so the browser
    never has to fetch the structure from the server.

    Args:
        sample_path: Path of the PDB file to show. When it is empty or does not
            exist, the placeholder from ``_molstar_placeholder_html`` is returned.
        viewer_message: Message forwarded to the placeholder. It is only used on
            the placeholder path; it is not rendered when a structure is shown.

    Returns:
        An HTML snippet: the placeholder, an inline error box when the file
        cannot be read, or the iframe hosting the Mol* viewer.
    """
    viewer_value = _resolve_viewer_value(sample_path)
    if not viewer_value:
        return _molstar_placeholder_html(message=viewer_message)

    try:
        pdb_text = Path(viewer_value).read_text(encoding="utf-8", errors="replace")
    except Exception as exc:
        LOGGER.warning("Failed to read PDB for Mol*: %s", exc)
        return (
            f"<div style=\"height: {_viewer_height_css()}; border: 1px solid #d9d9d9; border-radius: 8px; "
            "display: flex; align-items: center; justify-content: center; color: #b42318; background: #ffffff;\">"
            "Failed to load structure for Mol*."
            "</div>"
        )

    pdb_base64 = base64.b64encode(pdb_text.encode("utf-8")).decode("ascii")
    # Two deliberate choices in the script below:
    # * The viewer is constructed inside the try block, so a missing/blocked
    #   Mol* bundle (``rcsbMolstar`` undefined) is reported in the status banner
    #   instead of leaving a silent blank frame.
    # * The base64 payload is decoded to raw bytes before building the Blob;
    #   handing the ``atob`` string straight to Blob would re-encode every
    #   non-ASCII byte and corrupt the UTF-8 text.
    srcdoc = f"""
    <!doctype html>
    <html>
      <head>
        <meta charset="utf-8" />
        <script src="https://cdn.jsdelivr.net/npm/@rcsb/rcsb-molstar/build/dist/viewer/rcsb-molstar.js"></script>
        <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/@rcsb/rcsb-molstar/build/dist/viewer/rcsb-molstar.css" />
        <style>
          html, body {{
            margin: 0;
            width: 100%;
            height: 100%;
            overflow: hidden;
            background: #ffffff;
          }}
          #protein-viewer {{
            width: 100%;
            height: 100%;
          }}
          #viewer-status {{
            position: absolute;
            top: 10px;
            left: 10px;
            right: 10px;
            z-index: 10;
            display: none;
            font-family: sans-serif;
            border-radius: 6px;
            border: 1px solid #fecaca;
            background: #fff1f2;
            color: #b42318;
            padding: 8px 10px;
            font-size: 13px;
          }}
          #viewer-shell {{
            position: relative;
            width: 100%;
            height: 100%;
          }}
        </style>
      </head>
      <body>
        <div id="viewer-shell">
          <div id="viewer-status"></div>
          <div id="protein-viewer"></div>
        </div>
        <script>
          (async function() {{
            const statusEl = document.getElementById("viewer-status");
            let url = null;
            try {{
              const viewer = new rcsbMolstar.Viewer("protein-viewer", {{
                layoutShowControls: true,
                layoutShowSequence: true,
                viewportShowExpand: false,
                collapseLeftPanel: false
              }});
              const pdbBase64 = {json.dumps(pdb_base64)};
              const pdbBytes = Uint8Array.from(atob(pdbBase64), (ch) => ch.charCodeAt(0));
              const blob = new Blob([pdbBytes], {{ type: "text/plain" }});
              url = URL.createObjectURL(blob);
              await viewer.loadStructureFromUrl(url, "pdb");
              statusEl.style.display = "none";
              statusEl.textContent = "";
            }} catch (error) {{
              const reason = error && error.message ? error.message : String(error);
              statusEl.style.display = "block";
              statusEl.textContent = "Mol* failed to render this structure: " + reason;
              console.error("Error loading structure:", error);
            }} finally {{
              if (url) {{
                URL.revokeObjectURL(url);
              }}
            }}
          }})();
        </script>
      </body>
    </html>
    """
    # The whole document is attribute-escaped so it can live inside srcdoc="...".
    escaped_srcdoc = html.escape(srcdoc, quote=True)
    return (
        "<iframe "
        "title=\"Mol* Viewer\" "
        f"style=\"width: 100%; height: {_viewer_height_css()}; border: 1px solid #d9d9d9; border-radius: 8px; overflow: hidden; background: #ffffff;\" "
        "srcdoc=\""
        + escaped_srcdoc
        + "\"></iframe>"
    )
|
|
|
|
|
def update_selected_molstar(selected_samples: Optional[object]):
    """Refresh the Mol* viewer and the self-consistency outputs for a new selection.

    The first normalized sample (if any) is rendered in the viewer; the full
    normalized selection is passed to ``_hydrate_sc_outputs_for_samples`` and its
    values are appended after the viewer HTML.
    """
    samples = _normalize_selected_samples(selected_samples)
    first_sample = samples[0] if samples else None
    viewer_html = _build_molstar_iframe_html(first_sample)
    sc_outputs = _hydrate_sc_outputs_for_samples(samples)
    return (viewer_html, *sc_outputs)
|
|
|
|
|
def _update_backbone_selector(selected_samples: Optional[object]) -> Dict[str, object]:
    """Build the Gradio update for the backbone selector from the current selection.

    Choices are the normalized samples, the first one is pre-selected, and the
    control is interactive only when there is at least one choice.
    """
    samples = _normalize_selected_samples(selected_samples)
    has_samples = bool(samples)
    preselected = samples[0] if has_samples else None
    return gr.update(choices=samples, value=preselected, interactive=has_samples)
|
|
|
|
|
def _update_sc_filter_sample_selector(metrics: Optional[List[Dict[str, object]]]) -> Dict[str, object]:
    """Build the Gradio update for the sample filter: new choices, cleared value."""
    options = _build_sc_filter_sample_choices(metrics if metrics else [])
    can_interact = bool(options)
    return gr.update(choices=options, value=[], interactive=can_interact)
|
|
|
|
|
def _update_folded_selector(metrics: Optional[List[Dict[str, object]]]) -> Dict[str, object]:
    """Build the Gradio update for the folded-structure selector.

    Choices come from ``_build_folded_choices``; the second element of the first
    choice is used as the pre-selected value, or ``None`` when there are no choices.
    """
    options = _build_folded_choices(metrics if metrics else [])
    if options:
        preselected = options[0][1]
    else:
        preselected = None
    return gr.update(choices=options, value=preselected, interactive=bool(options))
|
|
|
|
|
def load_aligned_compare_view(
    backbone_sample_path: Optional[str],
    folded_sample_path: Optional[str],
) -> Tuple[str, str]:
    """Align a folded structure onto a backbone sample and show the overlay.

    Args:
        backbone_sample_path: Path of the backbone sample.
        folded_sample_path: Path of the folded structure to align onto it.

    Returns:
        ``(viewer_html, status_text)``. Without a backbone the viewer is the
        placeholder; without a folded structure, or when alignment fails, the
        viewer falls back to the plain backbone and the status explains why.
    """
    backbone = _normalize_optional_text(backbone_sample_path)
    folded = _normalize_optional_text(folded_sample_path)

    if not backbone:
        placeholder = _molstar_placeholder_html(message="Select a backbone sample to compare.")
        return placeholder, "Select a backbone sample before loading aligned compare view."

    if not folded:
        backbone_only = _build_molstar_iframe_html(backbone)
        return backbone_only, "Select a folded structure from self-consistency rows before comparison."

    # Viewer construction stays inside the try so its failures also fall back
    # to the backbone-only view.
    try:
        overlay = align_folded_to_backbone_overlay(
            backbone_sample_path=backbone,
            folded_sample_path=folded,
        )
        overlay_viewer = _build_molstar_iframe_html(
            overlay.output_path,
            viewer_message="Aligned backbone + ESMFold comparison.",
        )
        status_parts = [
            "Loaded aligned compare view. ",
            f"CA pairs={overlay.num_ca_pairs}, ",
            f"RMSD before={overlay.rmsd_before:.4f}, after={overlay.rmsd_after:.4f}. ",
            f"Overlay: {overlay.output_path}",
        ]
        return overlay_viewer, "".join(status_parts)
    except AlignmentError as exc:
        LOGGER.error("Alignment compare failed: %s", exc)
        fallback_viewer = _build_molstar_iframe_html(backbone)
        return fallback_viewer, f"Failed to build aligned compare view: {exc}"
    except Exception as exc:
        LOGGER.exception("Unexpected compare-view failure.")
        fallback_viewer = _build_molstar_iframe_html(backbone)
        return fallback_viewer, f"Unexpected compare-view failure: {exc}"
|
|
|
|
|
def preload_model() -> Tuple[str, Dict[str, object]]:
    """Load the model eagerly and report the outcome.

    Returns:
        ``(message, health)`` where ``health`` comes from ``refresh_health()``.
        Failures are logged and described in ``message`` rather than raised.
    """
    try:
        SERVICE.preload_model()
        outcome = ("Model loaded successfully.", refresh_health())
    except (ArtifactResolutionError, ModelLoadError, InferenceError) as exc:
        LOGGER.error("Model preload failed: %s", exc)
        outcome = (f"Model preload failed: {exc}", refresh_health())
    except Exception as exc:
        LOGGER.exception("Unexpected preload failure.")
        outcome = (f"Unexpected preload failure: {exc}", refresh_health())
    return outcome
|
|
|
|
|
def _collect_trajectory_artifacts(
    include_trajectory_artifacts: bool,
    result: InferenceResult,
) -> List[str]:
    """List the trajectory files of a run, without duplicates.

    Returns an empty list when ``include_trajectory_artifacts`` is false.
    Otherwise ``result.trajectory_files`` followed by
    ``result.x0_trajectory_files``, keeping the first occurrence of each path.
    """
    if include_trajectory_artifacts:
        # A dict keeps insertion order, so it de-duplicates while preserving order.
        seen: Dict[str, None] = {}
        for path in result.trajectory_files:
            seen.setdefault(path, None)
        for path in result.x0_trajectory_files:
            seen.setdefault(path, None)
        return list(seen)
    return []
|
|
|
|
|
def run_generation(
    mode: str,
    length: int,
    num_samples: int,
    seed_override: Optional[object],
    include_trajectory_artifacts: bool,
    guidance_scale: float,
    target_class: int,
    num_timesteps: Optional[object],
    reference_pdb: Optional[object],
    fixed_residues_text: Optional[object],
    use_classifier_guidance: bool,
    progress=gr.Progress(track_tqdm=False),
):
    """Run one generation request and produce the values for the UI outputs.

    Returns a tuple of: status text, sample files, trajectory artifacts, health
    report, sample-selector update, viewer HTML, followed by the values of
    ``_empty_sc_outputs()`` and ``_empty_sc_state()``. Failures never raise;
    they yield the same tuple shape with an error message and reset widgets.
    """

    def _failure_outputs(message: str):
        # Same arity as the success tuple, with every widget reset.
        return (
            message,
            [],
            [],
            refresh_health(),
            gr.update(choices=[], value=[], interactive=False),
            _molstar_placeholder_html(),
            *_empty_sc_outputs(),
            *_empty_sc_state(),
        )

    try:
        progress(0.05, desc="Loading model and preparing inputs...")
        seed = _parse_optional_seed(seed_override)

        timesteps_text = _normalize_optional_text(num_timesteps)
        timesteps = int(float(timesteps_text)) if timesteps_text else None

        if reference_pdb is None:
            reference_path = None
        else:
            # Prefer the object's ``name`` attribute; otherwise use its string form.
            reference_path = getattr(reference_pdb, "name", None) or str(reference_pdb)

        fixed_residues = None
        if mode == "conditional":
            fixed_residues = _parse_fixed_residues(fixed_residues_text)

        # Guidance scale and target class are only forwarded for these modes.
        forwards_classifier_inputs = mode in ("classifier", "conditional")

        progress(0.2, desc=f"Sampling {int(num_samples)} structure(s)...")
        result = SERVICE.generate(
            mode=mode,
            length=int(length),
            num_samples=int(num_samples),
            seed=seed,
            guidance_scale=float(guidance_scale) if forwards_classifier_inputs else None,
            target_class=int(target_class) if forwards_classifier_inputs else None,
            num_timesteps=timesteps,
            reference_pdb_path=reference_path,
            fixed_residues=fixed_residues,
            use_classifier_guidance=bool(use_classifier_guidance),
        )

        progress(0.9, desc="Writing artifacts...")
        _append_run_history(result)
        trajectory_artifacts = _collect_trajectory_artifacts(
            include_trajectory_artifacts=bool(include_trajectory_artifacts),
            result=result,
        )

        if include_trajectory_artifacts:
            trajectory_message = f"Trajectory artifacts ready: {len(trajectory_artifacts)} file(s)."
        else:
            trajectory_message = "Trajectory artifacts hidden; enable the checkbox to download traj/x0 files."

        if result.mode == "classifier":
            mode_details = f" Guidance scale={result.guidance_scale}, target_class={result.target_class}."
        elif result.mode == "conditional":
            mode_details = f" Fixed residues={result.fixed_residue_count}."
        else:
            mode_details = ""

        steps_details = f" Timesteps={result.num_timesteps}." if result.num_timesteps else ""

        summary = "".join(
            [
                f"Generated {len(result.sample_files)} sample(s) with mode={result.mode}. ",
                f"Seed={result.seed}.{steps_details} Output dir: {result.run_dir}. ",
                f"Artifacts source: {result.artifacts_source}.{mode_details} ",
                f"{trajectory_message}",
            ]
        )

        first_sample = result.sample_files[0] if result.sample_files else None
        selector_update = gr.update(
            choices=result.sample_files,
            value=[first_sample] if first_sample else [],
            interactive=bool(result.sample_files),
        )
        return (
            summary,
            result.sample_files,
            trajectory_artifacts,
            refresh_health(),
            selector_update,
            _build_molstar_iframe_html(first_sample),
            *_empty_sc_outputs(),
            *_empty_sc_state(),
        )
    except (ArtifactResolutionError, ModelLoadError, InferenceError) as exc:
        LOGGER.error("Inference request failed: %s", exc)
        return _failure_outputs(f"Inference failed: {exc}{_actionable_error_hint(exc)}")
    except Exception as exc:
        LOGGER.exception("Unhandled inference failure.")
        trace = traceback.format_exc(limit=1)
        return _failure_outputs(f"Unexpected error: {exc}{_actionable_error_hint(exc)}\n{trace}")
|
|
|
|
|
def run_self_consistency(
    selected_samples: Optional[object],
    num_seq_per_target: int,
    run_folding: bool,
    top_n: int,
    progress=gr.Progress(track_tqdm=False),
):
    """Run the self-consistency service on every selected sample and render the results.

    Samples are processed one at a time. A ``SelfConsistencyError`` on one sample
    is logged and recorded, and the remaining samples still run. Any other
    exception aborts the whole request and is reported in the status text.

    Returns the values of ``_render_sc_outputs(...)`` followed by the four state
    values ``(base_status, summary, metrics, artifacts)``. On failure the status
    text is followed by the blank outputs and ``_empty_sc_state()``.
    """

    def _blank_result(status_text: str):
        # Keep the status slot; blank every other output and reset the state.
        blank = _empty_sc_outputs()
        return (status_text, *blank[1:], *_empty_sc_state())

    sample_values = _normalize_selected_samples(selected_samples)
    if not sample_values:
        return _blank_result("Select one or more samples before running self-consistency.")

    try:
        collected_metrics: List[Dict[str, object]] = []
        collected_artifacts: List[str] = []
        failures: List[str] = []
        succeeded = 0
        total_samples = len(sample_values)
        folding_state = "on" if run_folding else "off"

        for sample_idx, sample_value in enumerate(sample_values):
            sample_label = _sample_label_from_path(sample_value)
            progress(
                sample_idx / max(total_samples, 1),
                desc=f"ProteinMPNN + folding ({sample_idx + 1}/{total_samples}): {sample_label}",
            )

            # Defaults bind this iteration's index and label to the callback.
            def _fold_progress(done: int, total: int, _idx=sample_idx, _label=sample_label) -> None:
                if total > 0:
                    overall = (_idx + done / total) / max(total_samples, 1)
                    progress(overall, desc=f"Folding {_label}: {done}/{total} sequences")

            try:
                result = SC_SERVICE.run(
                    sample_path=sample_value,
                    num_seq_per_target=int(num_seq_per_target),
                    run_folding=bool(run_folding),
                    progress_callback=_fold_progress,
                )
                collected_metrics.extend(
                    _enrich_metric_sample_fields(metric=entry, sample_path=sample_value)
                    for entry in result.per_sequence_metrics
                )
                collected_artifacts.append(result.parsed_pdbs_jsonl)
                collected_artifacts.append(result.mpnn_fasta_path)
                collected_artifacts.extend(result.folded_pdb_paths)
                if result.metrics_csv_path:
                    collected_artifacts.append(result.metrics_csv_path)
                succeeded += 1
            except SelfConsistencyError as exc:
                LOGGER.error("Self-consistency failed for %s: %s", sample_value, exc)
                failures.append(f"{_sample_label_from_path(sample_value)} ({_short_error_text(exc)})")

        if not collected_metrics:
            if failures:
                # Show at most five failures, then a count of the rest.
                preview_lines = [f"- {failure}" for failure in failures[:5]]
                hidden = len(failures) - 5
                if hidden > 0:
                    preview_lines.append(f"... and {hidden} more failure(s).")
                fail_context = "\n".join(preview_lines)
            else:
                fail_context = "No metrics were generated."
            return _blank_result(
                "Self-consistency failed across selected samples.\n"
                f"Requested sequences/sample={int(num_seq_per_target)}, folding={folding_state}.\n"
                f"{fail_context}"
            )

        # Order-preserving de-duplication of artifact paths.
        artifacts = list(dict.fromkeys(collected_artifacts))
        summary = _build_sc_summary(
            num_selected_samples=len(sample_values),
            num_processed_samples=succeeded,
            metrics=collected_metrics,
        )

        status_lines = [
            f"Self-consistency completed for {succeeded}/{len(sample_values)} selected sample(s).",
            f"Evaluated={len(collected_metrics)} sequences. Requested sequences/sample={int(num_seq_per_target)}.",
            f"Folding={folding_state}. Artifacts={len(artifacts)}.",
        ]
        if failures:
            shown = "; ".join(failures[:3])
            more = f"; +{len(failures) - 3} more" if len(failures) > 3 else ""
            status_lines.append(f"Failed samples: {len(failures)} ({shown}{more})")
        base_status = "\n".join(status_lines)

        rendered = _render_sc_outputs(
            base_status=base_status,
            summary=summary,
            metrics=collected_metrics,
            artifacts=artifacts,
            top_n=top_n,
        )
        return (*rendered, base_status, summary, collected_metrics, artifacts)
    except Exception as exc:
        LOGGER.exception("Unexpected self-consistency failure.")
        trace = traceback.format_exc(limit=1)
        return _blank_result(f"Unexpected self-consistency error: {exc}\n{trace}")
|
|
|
|
|
def apply_sc_view_settings(
    base_status: str,
    summary: Dict[str, object],
    metrics: List[Dict[str, object]],
    artifacts: List[str],
    top_n: int,
    filter_sample_labels: Optional[object],
    filter_sample_query: str,
    filter_tm_min: float,
    filter_rmsd_max: float,
    filter_plddt_min: float,
):
    """Re-render the self-consistency outputs with the current filters applied.

    The metrics are filtered with ``_filter_metrics``; the filter description
    from ``_filters_summary_line`` is passed on as the extra status line.
    """
    # Both helpers take the same filter arguments, so build them once.
    filter_kwargs = {
        "sample_labels": filter_sample_labels,
        "sample_query": filter_sample_query,
        "min_tm": filter_tm_min,
        "max_rmsd": filter_rmsd_max,
        "min_plddt": filter_plddt_min,
    }
    visible_metrics = _filter_metrics(metrics=metrics, **filter_kwargs)
    filters_line = _filters_summary_line(**filter_kwargs)
    return _render_sc_outputs(
        base_status=base_status,
        summary=summary,
        metrics=visible_metrics,
        artifacts=artifacts,
        top_n=top_n,
        extra_status_line=filters_line,
    )
|
|
|
|
|
def reset_sc_view_settings(
    base_status: str,
    summary: Dict[str, object],
    metrics: List[Dict[str, object]],
    artifacts: List[str],
):
    """Reset the view controls to their defaults and re-render without filters.

    Returns six ``gr.update`` objects (top-N, sample labels, sample query, min
    scTM, max scRMSD, min pLDDT) followed by the values of ``_render_sc_outputs``.
    """
    rendered = _render_sc_outputs(
        base_status=base_status,
        summary=summary,
        metrics=metrics,
        artifacts=artifacts,
        top_n=SC_DEFAULT_TOP_N,
    )
    # One default per control, in output order.
    control_defaults = (SC_DEFAULT_TOP_N, [], "", SC_TM_MIN, SC_RMSD_MAX, SC_PLDDT_MIN)
    control_resets = [gr.update(value=default) for default in control_defaults]
    return (*control_resets, *rendered)
|
|
|
|
|
def clear_compare_status() -> str:
    """Return the empty string used to blank the compare status text."""
    cleared = ""
    return cleared
|
|
|
|
|
def _status_badge_html() -> str:
    """Render the status strip: model pill, self-consistency pill, device pill.

    Reads ``SERVICE.health_check()`` for ``model_loaded`` and ``device`` and
    ``SC_SERVICE.health_check()`` for the two ProteinMPNN availability flags.
    Labels and the device name are HTML-escaped.
    """
    health = SERVICE.health_check()
    model_loaded = bool(health.get("model_loaded"))
    device = str(health.get("device", "auto"))
    sc_health = SC_SERVICE.health_check()
    pmpnn_ok = all(
        bool(sc_health.get(flag)) for flag in ("pmpnn_available", "pmpnn_weights_available")
    )

    def pill(label: str, ok: bool) -> str:
        # Green palette when ok, red otherwise.
        color, bg = ("#16a34a", "#dcfce7") if ok else ("#b91c1c", "#fee2e2")
        return (
            '<span style="display:inline-flex; align-items:center; gap:6px; '
            f"padding:3px 10px; border-radius:999px; background:{bg}; color:{color}; "
            'font-size:0.82rem; font-weight:600; margin-right:8px;">'
            f"● {html.escape(label)}</span>"
        )

    model_pill = pill("Model loaded" if model_loaded else "Model not loaded", model_loaded)
    sc_pill = pill("Self-consistency ready" if pmpnn_ok else "ProteinMPNN missing", pmpnn_ok)
    device_pill = (
        '<span style="display:inline-flex; align-items:center; padding:3px 10px; '
        "border-radius:999px; background:var(--block-background-fill,#f1f5f9); "
        "border:1px solid var(--block-border-color,#e2e8f0); font-size:0.82rem; "
        f'margin-right:8px;">Device: {html.escape(device)}</span>'
    )
    return "".join(
        [
            '<div style="display:flex; flex-wrap:wrap; align-items:center; gap:4px; padding:4px 0;">',
            model_pill,
            sc_pill,
            device_pill,
            "</div>",
        ]
    )
|
|
|
|
|
def refresh_status_badge() -> str:
    """Return freshly rendered status-badge HTML (thin wrapper for UI callbacks)."""
    badge_html = _status_badge_html()
    return badge_html
|
|
|
|
|
def _parse_fixed_residues(text: Optional[object]) -> Optional[List[int]]:
    """Parse a fixed-residue specification such as ``"10-40,55,A:60-62"``.

    Entries are separated by commas or semicolons. Each entry is either a single
    integer or an inclusive ``start-end`` range (a reversed range is swapped).
    Anything before the first ``:`` of an entry (e.g. a chain label) is dropped.

    Args:
        text: Raw user input; normalized with ``_normalize_optional_text``.

    Returns:
        The sorted, de-duplicated residue numbers, or ``None`` when the input is
        empty or contains no entries.

    Raises:
        ValueError: If an entry is not an integer or a ``start-end`` range, or if
            a single range spans more residues than the safety limit. The message
            names the offending entry.
    """
    raw = _normalize_optional_text(text)
    if not raw:
        return None

    # Guard against typos such as "1-99999999999" that would otherwise try to
    # materialize billions of integers from a free-text field.
    max_range_span = 100_000

    residues = set()
    for chunk in raw.replace(";", ",").split(","):
        entry = chunk.strip()
        if not entry:
            continue
        spec = entry.split(":", 1)[1].strip() if ":" in entry else entry
        try:
            if "-" in spec:
                start_str, end_str = spec.split("-", 1)
                start, end = int(start_str.strip()), int(end_str.strip())
            else:
                start = end = int(spec)
        except ValueError as exc:
            raise ValueError(
                f"Invalid fixed-residue entry {entry!r}: expected an integer or a "
                "'start-end' range (e.g. 10-40,55,60-62)."
            ) from exc
        if end < start:
            start, end = end, start
        if end - start >= max_range_span:
            raise ValueError(
                f"Fixed-residue range {entry!r} is too large "
                f"(more than {max_range_span} residues)."
            )
        residues.update(range(start, end + 1))

    return sorted(residues) or None
|
|
|
|
|
def _resolve_trajectory_path(sample_path: Optional[str], kind: str) -> Optional[str]:
    """Locate the trajectory file stored next to the first selected sample.

    ``kind == "Backbone trajectory"`` selects ``bb_traj.pdb``; any other value
    selects ``x0_traj.pdb``. Returns the resolved POSIX path, or ``None`` when
    nothing is selected or the file does not exist.
    """
    samples = _normalize_selected_samples(sample_path)
    if not samples:
        return None
    traj_name = "x0_traj.pdb"
    if kind == "Backbone trajectory":
        traj_name = "bb_traj.pdb"
    traj_file = Path(samples[0]).parent / traj_name
    if traj_file.exists():
        return traj_file.resolve().as_posix()
    return None
|
|
|
|
|
def _build_trajectory_iframe_html(traj_path: Optional[str], title: str) -> str:
    """Build an iframe that plays back a multi-model PDB trajectory in Mol*.

    The PDB text is embedded as base64. In the browser it is split into frames on
    ``MODEL``/``ENDMDL`` records (lines before a ``MODEL`` record are dropped); a
    file without ``MODEL`` records is treated as a single frame. A slider and a
    Play/Pause button step through the frames.

    Args:
        traj_path: Path of the trajectory PDB; a falsy value yields the placeholder.
        title: Heading shown above the iframe (HTML-escaped).

    Returns:
        The placeholder HTML when there is no file or it cannot be read,
        otherwise the heading plus the iframe.
    """
    if not traj_path:
        return _molstar_placeholder_html(
            message="No trajectory found for this sample. Generate a new sample to view its flow trajectory."
        )
    try:
        pdb_text = Path(traj_path).read_text(encoding="utf-8", errors="replace")
    except Exception as exc:
        LOGGER.warning("Failed to read trajectory PDB: %s", exc)
        return _molstar_placeholder_html(message="Failed to read trajectory file.")

    pdb_base64 = base64.b64encode(pdb_text.encode("utf-8")).decode("ascii")
    # Two deliberate choices in the script below:
    # * Frame loads are serialized. The play timer and slider events can fire
    #   faster than Mol* finishes a load; requests that arrive meanwhile are
    #   coalesced to the latest index instead of interleaving clear()/load calls.
    # * A trajectory with no parsable frame disables the controls rather than
    #   computing ``% 0`` and loading ``frames[undefined]``.
    srcdoc = f"""
    <!doctype html>
    <html>
      <head>
        <meta charset="utf-8" />
        <script src="https://cdn.jsdelivr.net/npm/@rcsb/rcsb-molstar/build/dist/viewer/rcsb-molstar.js"></script>
        <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/@rcsb/rcsb-molstar/build/dist/viewer/rcsb-molstar.css" />
        <style>
          html, body {{ margin: 0; width: 100%; height: 100%; overflow: hidden; background: #ffffff; font-family: sans-serif; }}
          #protein-viewer {{ width: 100%; height: calc(100% - 52px); }}
          #controls {{ height: 52px; display: flex; align-items: center; gap: 10px; padding: 6px 10px; box-sizing: border-box; }}
          #controls button {{ cursor: pointer; border: 1px solid #cbd5e1; background: #f8fafc; border-radius: 6px; padding: 4px 12px; font-size: 13px; }}
          #frame-slider {{ flex: 1; }}
          #frame-label {{ font-size: 12px; color: #475569; min-width: 96px; text-align: right; }}
        </style>
      </head>
      <body>
        <div id="protein-viewer"></div>
        <div id="controls">
          <button id="play-btn">Play</button>
          <input id="frame-slider" type="range" min="0" max="0" value="0" />
          <span id="frame-label">Frame 1 / 1</span>
        </div>
        <script>
          (async function() {{
            const viewer = new rcsbMolstar.Viewer("protein-viewer", {{
              layoutShowControls: false,
              layoutShowSequence: false,
              viewportShowExpand: false
            }});
            const pdbText = atob({json.dumps(pdb_base64)});
            const lines = pdbText.split(/\\r?\\n/);
            let frames = [];
            let current = [];
            let hasModels = false;
            for (const line of lines) {{
              if (line.startsWith("MODEL")) {{ hasModels = true; current = []; continue; }}
              if (line.startsWith("ENDMDL")) {{ frames.push(current.join("\\n")); current = []; continue; }}
              current.push(line);
            }}
            if (!hasModels) {{ frames = [pdbText]; }}
            if (frames.length === 0 && current.length) {{ frames = [current.join("\\n")]; }}

            const slider = document.getElementById("frame-slider");
            const label = document.getElementById("frame-label");
            const playBtn = document.getElementById("play-btn");
            if (frames.length === 0) {{
              label.textContent = "No frames";
              slider.disabled = true;
              playBtn.disabled = true;
              return;
            }}
            slider.max = String(frames.length - 1);

            let loading = false;
            let queued = null;
            async function loadFrame(idx) {{
              if (loading) {{ queued = idx; return; }}
              loading = true;
              let target = idx;
              while (target !== null) {{
                queued = null;
                const body = "MODEL 1\\n" + frames[target] + "\\nENDMDL\\nEND\\n";
                const blob = new Blob([body], {{ type: "text/plain" }});
                const url = URL.createObjectURL(blob);
                try {{
                  await viewer.clear();
                  await viewer.loadStructureFromUrl(url, "pdb");
                }} catch (e) {{ console.error(e); }} finally {{ URL.revokeObjectURL(url); }}
                label.textContent = "Frame " + (target + 1) + " / " + frames.length;
                target = queued;
              }}
              loading = false;
            }}

            let timer = null;
            slider.addEventListener("input", () => loadFrame(parseInt(slider.value, 10)));
            playBtn.addEventListener("click", () => {{
              if (timer) {{ clearInterval(timer); timer = null; playBtn.textContent = "Play"; return; }}
              playBtn.textContent = "Pause";
              timer = setInterval(() => {{
                let next = (parseInt(slider.value, 10) + 1) % frames.length;
                slider.value = String(next);
                loadFrame(next);
                if (next === frames.length - 1) {{ clearInterval(timer); timer = null; playBtn.textContent = "Play"; }}
              }}, 350);
            }});
            await loadFrame(0);
          }})();
        </script>
      </body>
    </html>
    """
    # The whole document is attribute-escaped so it can live inside srcdoc="...".
    escaped_srcdoc = html.escape(srcdoc, quote=True)
    return (
        f"<div style=\"font-weight:600; margin-bottom:6px;\">{html.escape(title)}</div>"
        "<iframe title=\"Mol* Trajectory\" "
        f"style=\"width: 100%; height: {_viewer_height_css()}; border: 1px solid #d9d9d9; border-radius: 8px; overflow: hidden; background: #ffffff;\" "
        "srcdoc=\""
        + escaped_srcdoc
        + "\"></iframe>"
    )
|
|
|
|
|
def load_trajectory_view(selected_samples: Optional[object], kind: str) -> str:
    """Return the trajectory playback HTML for the current selection and trajectory kind."""
    resolved_path = _resolve_trajectory_path(selected_samples, kind)
    heading = f"{kind} playback"
    return _build_trajectory_iframe_html(resolved_path, title=heading)
|
|
|
|
|
def update_selection_label(selected_samples: Optional[object]) -> str:
    """Describe the current selection as Markdown.

    Names the first selected sample (the one shown in Mol*) and the number of
    samples selected as self-consistency targets, or a hint when nothing is
    selected.
    """
    samples = _normalize_selected_samples(selected_samples)
    if samples:
        shown_label = _sample_label_from_path(samples[0])
        target_count = len(samples)
        return (
            f"**Viewing in Mol*:** `{shown_label}` | "
            f"**Self-consistency targets:** {target_count} sample(s)"
        )
    return "_No sample selected. Generate or load a sample to begin._"
|
|
|
|
|
def _append_run_history(result: InferenceResult) -> None:
    """Record a finished run at the top of the JSON run history (best effort).

    An existing entry with the same ``run_id`` is replaced and the history is
    truncated to ``RUN_HISTORY_LIMIT`` entries. Any failure is logged as a
    warning and never propagated to the caller.

    Args:
        result: The finished run; ``run_dir``, ``mode``, ``seed``,
            ``sample_files`` and ``num_timesteps`` are recorded.
    """
    # Function-scope import: the module only imports ``datetime`` itself.
    from datetime import timezone

    try:
        SPACE_OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)
        previous: List[Dict[str, object]] = []
        if RUN_HISTORY_PATH.exists():
            try:
                loaded = json.loads(RUN_HISTORY_PATH.read_text(encoding="utf-8"))
            except Exception:
                loaded = []
            # Keep only well-formed entries. Without this, a file holding valid
            # JSON of the wrong shape would make every later append fail.
            if isinstance(loaded, list):
                previous = [item for item in loaded if isinstance(item, dict)]

        # Same "<naive ISO timestamp>Z" format as before, without the
        # deprecated ``datetime.utcnow()``.
        created_utc = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
        entry = {
            "run_id": Path(result.run_dir).name,
            "mode": result.mode,
            "seed": result.seed,
            "num_samples": len(result.sample_files),
            "num_timesteps": result.num_timesteps,
            "sample_files": result.sample_files,
            "created_utc": created_utc,
        }
        history = [entry] + [h for h in previous if h.get("run_id") != entry["run_id"]]
        history = history[:RUN_HISTORY_LIMIT]

        # Write to a sibling temp file and swap it in, so readers never see a
        # partially written history file.
        tmp_path = RUN_HISTORY_PATH.with_name(RUN_HISTORY_PATH.name + ".tmp")
        tmp_path.write_text(json.dumps(history, indent=2), encoding="utf-8")
        os.replace(tmp_path, RUN_HISTORY_PATH)
    except Exception:
        LOGGER.warning("Failed to append run history.", exc_info=True)
|
|
|
|
|
def _run_history_choices() -> List[Tuple[str, str]]:
    """Build ``(label, run_id)`` dropdown choices from the run-history file.

    Returns an empty list when the file is missing, unreadable, or does not
    contain a JSON list. Entries that are not JSON objects are skipped, so a
    damaged history file cannot raise inside a UI callback.
    """
    if not RUN_HISTORY_PATH.exists():
        return []
    try:
        history = json.loads(RUN_HISTORY_PATH.read_text(encoding="utf-8"))
    except Exception:
        return []
    if not isinstance(history, list):
        return []

    choices: List[Tuple[str, str]] = []
    for entry in history:
        if not isinstance(entry, dict):
            continue
        run_id = str(entry.get("run_id", "run"))
        mode = str(entry.get("mode", "?"))
        n = entry.get("num_samples", 0)
        seed = entry.get("seed", "?")
        label = f"{run_id} | {mode} | {n} sample(s) | seed={seed}"
        choices.append((label, run_id))
    return choices
|
|
|
|
|
def refresh_run_history() -> Dict[str, object]:
    """Build the Gradio update that reloads the run-history dropdown with no selection."""
    history_choices = _run_history_choices()
    has_history = bool(history_choices)
    return gr.update(choices=history_choices, value=None, interactive=has_history)
|
|
|
|
|
def _history_samples_for_run(run_id: str) -> List[str]:
    """Return the recorded sample files of ``run_id`` that still exist on disk.

    The first history entry whose ``run_id`` matches (compared as strings) is
    used. Returns an empty list when the history file is missing, unreadable or
    malformed, when the run is unknown, or when none of its files exist.
    """
    if not RUN_HISTORY_PATH.exists():
        return []
    try:
        history = json.loads(RUN_HISTORY_PATH.read_text(encoding="utf-8"))
    except Exception:
        return []
    if not isinstance(history, list):
        return []

    wanted = str(run_id)
    for entry in history:
        # Skip anything that is not a JSON object instead of failing on .get().
        if not isinstance(entry, dict) or str(entry.get("run_id")) != wanted:
            continue
        recorded = entry.get("sample_files")
        if not isinstance(recorded, list):
            return []
        # Non-string items cannot be paths; ignore them rather than raise.
        return [p for p in recorded if isinstance(p, str) and Path(p).exists()]
    return []
|
|
|
|
|
def load_run_from_history(run_id: Optional[str]):
    """Restore a previous run from history into the sample selector and viewer.

    Returns a tuple of: status text, sample-selector update, viewer HTML, then
    the self-consistency values. When the run has no surviving sample files (or
    no run is given), the selector is left unchanged, the viewer shows a
    placeholder, and the blank self-consistency outputs and state are returned.
    """
    sample_files = _history_samples_for_run(run_id) if run_id else []

    if sample_files:
        first_sample = sample_files[0]
        selector_update = gr.update(choices=sample_files, value=[first_sample], interactive=True)
        return (
            f"Loaded run {run_id} with {len(sample_files)} sample(s).",
            selector_update,
            _build_molstar_iframe_html(first_sample),
            *_hydrate_sc_outputs_for_samples([first_sample]),
        )

    blank_outputs = _empty_sc_outputs()
    return (
        "Selected run is unavailable (artifacts may have been cleared).",
        gr.update(),
        _molstar_placeholder_html(message="Run artifacts not found."),
        *blank_outputs,
        *_empty_sc_state(),
    )
|
|
|
|
|
def export_sc_metrics_csv(
    metrics: List[Dict[str, object]],
    filter_sample_labels: Optional[object],
    filter_sample_query: str,
    filter_tm_min: float,
    filter_rmsd_max: float,
    filter_plddt_min: float,
) -> Optional[str]:
    """Write the currently filtered metrics to a timestamped CSV file.

    Returns the resolved POSIX path of the new file under ``SPACE_OUTPUTS_DIR``,
    or ``None`` when there are no metrics or the filters leave no rows.
    """
    if not metrics:
        return None
    rows_to_export = _filter_metrics(
        metrics=metrics,
        sample_labels=filter_sample_labels,
        sample_query=filter_sample_query,
        min_tm=filter_tm_min,
        max_rmsd=filter_rmsd_max,
        min_plddt=filter_plddt_min,
    )
    if not rows_to_export:
        return None

    SPACE_OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
    csv_path = SPACE_OUTPUTS_DIR / f"sc_metrics_export_{timestamp}.csv"
    header = ["sample", "seq_len", "scTM", "scRMSD", "esmfold_mean_plddt"]
    with csv_path.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        writer.writerows(
            [
                _metric_sample_display(metric),
                _metric_sequence_length(metric),
                metric.get("scTM"),
                metric.get("scRMSD"),
                metric.get("esmfold_mean_plddt"),
            ]
            for metric in rows_to_export
        )
    return csv_path.resolve().as_posix()
|
|
|
|
|
def build_run_zip(selected_samples: Optional[object]) -> Optional[str]:
    """Zip the run directory of the first selected sample.

    The run directory is taken to be two levels above the directory containing
    the sample file. Returns the resolved POSIX path of the archive written
    under ``SPACE_OUTPUTS_DIR``, or ``None`` when nothing is selected or that
    directory does not exist.
    """
    samples = _normalize_selected_samples(selected_samples)
    if not samples:
        return None

    # <run_dir>/<...>/<sample_dir>/<sample file>
    run_dir = Path(samples[0]).parent.parent.parent
    if not run_dir.exists():
        return None

    SPACE_OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
    bundle_base = SPACE_OUTPUTS_DIR / f"{run_dir.name}_bundle_{timestamp}"
    created_archive = shutil.make_archive(str(bundle_base), "zip", root_dir=str(run_dir))
    return Path(created_archive).resolve().as_posix()
|
|
|
|
|
def on_sc_metrics_select(
    metrics: List[Dict[str, object]],
    backbone_choices_value: Optional[str],
    select_data: gr.SelectData,
):
    """Open the aligned compare view for the metrics row the user clicked.

    Returns four values: backbone-selector update, folded-selector update, viewer
    HTML, and status text. Four no-op updates are returned when there is nothing
    to act on (no metrics, no selection, row out of range, or a row without both
    paths).
    """
    # NOTE(review): the row index is applied to ``metrics`` as passed in. If the
    # table shows a filtered or re-ordered subset, the index may point at a
    # different row. Confirm against the event wiring.

    def _no_change():
        return gr.update(), gr.update(), gr.update(), gr.update()

    if not metrics or select_data is None:
        return _no_change()

    selected_index = select_data.index
    # A list/tuple index is treated as (row, ...); only the row is needed.
    row_index = selected_index[0] if isinstance(selected_index, (list, tuple)) else selected_index
    if row_index is None or row_index >= len(metrics):
        return _no_change()

    chosen = metrics[row_index]
    folded_path = _normalize_optional_text(chosen.get("folded_sample_path"))
    backbone_path = _normalize_optional_text(chosen.get("sample_source_path"))
    if not (folded_path and backbone_path):
        return _no_change()

    viewer_html, status = load_aligned_compare_view(backbone_path, folded_path)
    backbone_update = gr.update(value=backbone_path)
    folded_update = gr.update(value=folded_path)
    return backbone_update, folded_update, viewer_html, status
|
|
|
|
|
| with gr.Blocks(
|
| title="FlowProt Protein Designer",
|
| theme=gr.themes.Soft(),
|
| css=UI_CSS,
|
| ) as demo:
|
| gr.Markdown(
|
| """
|
| # FlowProt Protein Designer
|
| Generate novel protein backbones with flow matching, explore them in 3D, and validate
|
| their designability with ProteinMPNN and ESMFold self-consistency.
|
| """
|
| )
|
| status_badge = gr.HTML(value=_status_badge_html())
|
| with gr.Accordion("New here? Three steps to your first design", open=True, elem_classes=["fp-card"]):
|
| gr.Markdown(
|
| "1. **Generate** a backbone in the Generate tab (or click *Load demo example* below).\n"
|
| "2. **View** it in interactive 3D and watch the generation trajectory in the View tab.\n"
|
| "3. **Analyze** designability (scTM / scRMSD) in the Analyze tab."
|
| )
|
| demo_cta_button = gr.Button("Load demo example", variant="primary")
|
|
|
| sc_base_status_state = gr.State("")
|
| sc_summary_state = gr.State({})
|
| sc_metrics_state = gr.State([])
|
| sc_artifacts_state = gr.State([])
|
|
|
| with gr.Tabs():
|
| with gr.Tab("Generate"):
|
| with gr.Group(elem_classes=["fp-card"]):
|
| gr.Markdown("### Inference settings")
|
| with gr.Row(elem_classes=["fp-tight"]):
|
| mode = gr.Dropdown(
|
| label="Inference mode",
|
| choices=_mode_choices(),
|
| value=SERVICE.mvp_mode if SERVICE.mvp_mode in _mode_choices() else "unconditional",
|
| info="Unconditional designs from scratch; classifier-guided steers toward a target class.",
|
| )
|
| length = gr.Slider(label="Protein length (residues)", minimum=32, maximum=1024, value=128, step=1)
|
| num_samples = gr.Slider(label="Samples per request", minimum=1, maximum=4, value=1, step=1)
|
| with gr.Row(elem_classes=["fp-tight"]):
|
| seed_override = gr.Textbox(
|
| label="Seed override (optional)",
|
| value="",
|
| placeholder="Leave empty to use config inference.seed",
|
| )
|
| num_timesteps = gr.Slider(
|
| label="Sampling timesteps",
|
| minimum=10,
|
| maximum=500,
|
| value=100,
|
| step=10,
|
| info="More steps can improve quality at the cost of speed.",
|
| )
|
| include_trajectory_artifacts = gr.Checkbox(
|
| label="Expose trajectory downloads",
|
| value=False,
|
| )
|
|
|
| with gr.Group(visible=False) as classifier_group:
|
| gr.Markdown("**Classifier guidance**", elem_classes=["fp-subnote"])
|
| with gr.Row(elem_classes=["fp-tight"]):
|
| guidance_scale = gr.Slider(
|
| label="Classifier guidance scale",
|
| minimum=0.0,
|
| maximum=1.0,
|
| value=0.2,
|
| step=0.05,
|
| info="Higher values steer more strongly toward the target class.",
|
| )
|
| target_class = gr.Dropdown(
|
| label="Classifier target class",
|
| choices=[0, 1],
|
| value=1,
|
| info="Binary classifier label to steer toward (0 or 1).",
|
| )
|
|
|
| with gr.Group(visible=False) as conditional_group:
|
| gr.Markdown(
|
| "**Conditional design** keeps the selected residues of a reference structure fixed "
|
| "while generating the rest. Protein length is taken from the uploaded PDB.",
|
| elem_classes=["fp-subnote"],
|
| )
|
| with gr.Row(elem_classes=["fp-tight"]):
|
| reference_pdb = gr.File(label="Reference PDB (chain A)", file_types=[".pdb"])
|
| fixed_residues_text = gr.Textbox(
|
| label="Fixed residues",
|
| value="",
|
| placeholder="e.g. 10-40,55,60-62 (empty = fix all)",
|
| )
|
| use_classifier_guidance = gr.Checkbox(
|
| label="Apply classifier guidance",
|
| value=False,
|
| )
|
|
|
| run_button = gr.Button("Run inference", variant="primary")
|
|
|
| with gr.Group(elem_classes=["fp-card"]):
|
| gr.Markdown("### Outputs")
|
| output_files = gr.Files(label="Generated sample.pdb files")
|
| trajectory_output_files = gr.Files(label="Optional trajectory artifacts (traj + x0)")
|
| with gr.Row(elem_classes=["fp-tight"]):
|
| zip_button = gr.Button("Bundle current run as zip", variant="secondary")
|
| zip_file = gr.File(label="Run bundle (.zip)")
|
|
|
| with gr.Tab("View"):
|
| with gr.Group(elem_classes=["fp-card"]):
|
| sample_selector = gr.Dropdown(
|
| label="Select generated sample(s)",
|
| choices=[],
|
| value=[],
|
| multiselect=True,
|
| interactive=False,
|
| info="The first selected sample is shown in 3D; all selected samples feed self-consistency.",
|
| )
|
| selection_label = gr.Markdown(update_selection_label(None))
|
| molstar_viewer = gr.HTML(value=_molstar_placeholder_html())
|
|
|
| with gr.Accordion("Generation trajectory playback", open=False, elem_classes=["fp-card"]):
|
| gr.Markdown(
|
| "Step through the flow-matching trajectory of the selected sample. "
|
| "Backbone trajectory shows the integrated path; x0 trajectory shows the model's denoised prediction.",
|
| elem_classes=["fp-subnote"],
|
| )
|
| with gr.Row(elem_classes=["fp-tight"]):
|
| traj_source = gr.Dropdown(
|
| label="Trajectory",
|
| choices=["Backbone trajectory", "x0 trajectory"],
|
| value="Backbone trajectory",
|
| )
|
| load_traj_button = gr.Button("Load trajectory", variant="secondary")
|
| trajectory_viewer = gr.HTML(
|
| value=_molstar_placeholder_html(message="Select a sample and load a trajectory to play it back.")
|
| )
|
|
|
| with gr.Accordion("Aligned compare (backbone vs ESMFold)", open=False, elem_classes=["fp-card"]):
|
| gr.Markdown(
|
| "Overlay a generated backbone with an ESMFold-predicted structure from self-consistency. "
|
| "Tip: clicking a row in the Analyze metrics table loads the overlay automatically.",
|
| elem_classes=["fp-subnote"],
|
| )
|
| with gr.Row(elem_classes=["fp-tight"]):
|
| backbone_compare_selector = gr.Dropdown(
|
| label="Backbone sample for aligned compare",
|
| choices=[],
|
| value=None,
|
| interactive=False,
|
| )
|
| folded_compare_selector = gr.Dropdown(
|
| label="Folded structure for compare",
|
| choices=[],
|
| value=None,
|
| interactive=False,
|
| )
|
| load_compare_button = gr.Button("Load aligned compare view", variant="secondary")
|
| molstar_compare_status = gr.Textbox(
|
| label="3D compare status",
|
| lines=2,
|
| interactive=False,
|
| )
|
|
|
| with gr.Tab("Analyze"):
|
| with gr.Group(elem_classes=["fp-card"]):
|
| gr.Markdown(
|
| "### Self-consistency (ProteinMPNN + ESMFold)\n"
|
| "Runs across the samples selected in the View tab.",
|
| )
|
| with gr.Row(elem_classes=["fp-tight"]):
|
| sc_num_sequences = gr.Slider(
|
| label="ProteinMPNN sequences per sample",
|
| minimum=1,
|
| maximum=16,
|
| value=4,
|
| step=1,
|
| )
|
| sc_run_folding = gr.Checkbox(
|
| label="Run ESMFold and compute scTM/scRMSD",
|
| value=True,
|
| )
|
| sc_run_button = gr.Button("Run self-consistency", variant="primary")
|
|
|
| with gr.Accordion("View settings", open=False):
|
| with gr.Row(elem_classes=["fp-tight"]):
|
| sc_top_n = gr.Slider(
|
| label="Leaderboard top-N",
|
| minimum=1,
|
| maximum=10,
|
| value=SC_DEFAULT_TOP_N,
|
| step=1,
|
| )
|
| sc_filter_samples = gr.Dropdown(
|
| label="Filter by sample label",
|
| choices=[],
|
| value=[],
|
| multiselect=True,
|
| interactive=False,
|
| )
|
| sc_filter_sample_query = gr.Textbox(
|
| label="Sample text contains",
|
| value="",
|
| placeholder="Optional text search on sample labels",
|
| )
|
| with gr.Row(elem_classes=["fp-tight"]):
|
| sc_filter_tm_min = gr.Slider(
|
| label="Minimum scTM",
|
| minimum=SC_TM_MIN,
|
| maximum=SC_TM_MAX,
|
| value=SC_TM_MIN,
|
| step=0.01,
|
| )
|
| sc_filter_rmsd_max = gr.Slider(
|
| label="Maximum scRMSD",
|
| minimum=SC_RMSD_MIN,
|
| maximum=SC_RMSD_MAX,
|
| value=SC_RMSD_MAX,
|
| step=0.1,
|
| )
|
| sc_filter_plddt_min = gr.Slider(
|
| label="Minimum ESMFold mean pLDDT",
|
| minimum=SC_PLDDT_MIN,
|
| maximum=SC_PLDDT_MAX,
|
| value=SC_PLDDT_MIN,
|
| step=0.01,
|
| )
|
| with gr.Row(elem_classes=["fp-tight"]):
|
| sc_apply_view_settings_button = gr.Button("Apply view settings", variant="secondary")
|
| sc_reset_view_settings_button = gr.Button("Reset view settings")
|
|
|
| sc_status = gr.Textbox(
|
| label="Self-consistency status",
|
| lines=4,
|
| elem_classes=["fp-status"],
|
| )
|
| sc_summary = gr.Dataframe(
|
| headers=SC_SUMMARY_HEADERS,
|
| datatype=["str", "str"],
|
| value=[],
|
| label="Self-consistency summary",
|
| interactive=False,
|
| wrap=True,
|
| )
|
| sc_metrics = gr.Dataframe(
|
| headers=SC_METRICS_HEADERS,
|
| value=[],
|
| label="Per-sequence metrics table (click a row to compare in 3D)",
|
| interactive=False,
|
| wrap=True,
|
| )
|
| with gr.Row(elem_classes=["fp-tight"]):
|
| sc_export_button = gr.Button("Export current metrics as CSV", variant="secondary")
|
| sc_export_file = gr.File(label="Filtered metrics (.csv)")
|
| with gr.Row(elem_classes=["fp-tight", "fp-plot-row"]):
|
| with gr.Column(scale=1, min_width=560):
|
| sc_tm_rmsd_plot = gr.HTML(
|
| value=_empty_scatter_plot_html(
|
| title="Self-consistency tradeoff (scTM vs scRMSD)",
|
| message="Run self-consistency to render plot data.",
|
| )
|
| )
|
| with gr.Column(scale=1, min_width=560):
|
| sc_confidence_plot = gr.HTML(
|
| value=_empty_scatter_plot_html(
|
| title="ESMFold confidence vs scTM",
|
| message="Run self-consistency with folding enabled to render confidence plot.",
|
| )
|
| )
|
| with gr.Row(elem_classes=["fp-tight"]):
|
| sc_top_tm = gr.Dataframe(
|
| headers=SC_TOP_TM_HEADERS,
|
| value=[],
|
| label="Top by highest scTM",
|
| interactive=False,
|
| wrap=True,
|
| )
|
| sc_top_rmsd = gr.Dataframe(
|
| headers=SC_TOP_RMSD_HEADERS,
|
| value=[],
|
| label="Top by lowest scRMSD",
|
| interactive=False,
|
| wrap=True,
|
| )
|
| sc_artifacts = gr.Files(label="Self-consistency artifacts")
|
|
|
| with gr.Tab("Advanced"):
|
| with gr.Group(elem_classes=["fp-card"]):
|
| gr.Markdown("### Recent runs")
|
| with gr.Row(elem_classes=["fp-tight"]):
|
| history_dropdown = gr.Dropdown(
|
| label="Run history",
|
| choices=_run_history_choices(),
|
| value=None,
|
| interactive=True,
|
| )
|
| history_refresh_button = gr.Button("Refresh history")
|
| history_load_button = gr.Button("Load selected run", variant="secondary")
|
|
|
| with gr.Group(elem_classes=["fp-card"]):
|
| gr.Markdown("### Saved example case")
|
| with gr.Row(elem_classes=["fp-tight"]):
|
| save_example_button = gr.Button("Save selected sample as example", variant="secondary")
|
| load_example_button = gr.Button("Load saved example", variant="secondary")
|
| view_example_on_startup = gr.Checkbox(
|
| label="View saved example on startup",
|
| value=VIEW_EXAMPLE_ON_STARTUP_DEFAULT,
|
| )
|
|
|
| with gr.Group(elem_classes=["fp-card"]):
|
| gr.Markdown("### Diagnostics")
|
| with gr.Row(elem_classes=["fp-tight"]):
|
| preload_button = gr.Button("Preload model")
|
| health_button = gr.Button("Refresh health")
|
| with gr.Accordion("Status and debug details", open=False):
|
| status = gr.Textbox(
|
| label="Run status",
|
| lines=4,
|
| elem_classes=["fp-status"],
|
| )
|
| example_status = gr.Textbox(
|
| label="Example case status",
|
| lines=3,
|
| interactive=False,
|
| )
|
| health = gr.JSON(label="Health check")
|
|
|
| generation_outputs = [
|
| status,
|
| output_files,
|
| trajectory_output_files,
|
| health,
|
| sample_selector,
|
| molstar_viewer,
|
| sc_status,
|
| sc_summary,
|
| sc_metrics,
|
| sc_tm_rmsd_plot,
|
| sc_confidence_plot,
|
| sc_artifacts,
|
| sc_top_tm,
|
| sc_top_rmsd,
|
| sc_base_status_state,
|
| sc_summary_state,
|
| sc_metrics_state,
|
| sc_artifacts_state,
|
| ]
|
| example_load_outputs = [
|
| status,
|
| output_files,
|
| trajectory_output_files,
|
| sample_selector,
|
| molstar_viewer,
|
| sc_status,
|
| sc_summary,
|
| sc_metrics,
|
| sc_tm_rmsd_plot,
|
| sc_confidence_plot,
|
| sc_artifacts,
|
| sc_top_tm,
|
| sc_top_rmsd,
|
| sc_base_status_state,
|
| sc_summary_state,
|
| sc_metrics_state,
|
| sc_artifacts_state,
|
| example_status,
|
| ]
|
|
|
| mode.change(
|
| fn=lambda selected: (
|
| gr.update(visible=selected == "classifier"),
|
| gr.update(visible=selected == "conditional"),
|
| ),
|
| inputs=[mode],
|
| outputs=[classifier_group, conditional_group],
|
| )
|
|
|
| run_event = run_button.click(
|
| fn=run_generation,
|
| inputs=[
|
| mode,
|
| length,
|
| num_samples,
|
| seed_override,
|
| include_trajectory_artifacts,
|
| guidance_scale,
|
| target_class,
|
| num_timesteps,
|
| reference_pdb,
|
| fixed_residues_text,
|
| use_classifier_guidance,
|
| ],
|
| outputs=generation_outputs,
|
| )
|
| run_event.then(fn=update_selection_label, inputs=[sample_selector], outputs=[selection_label])
|
| run_event.then(fn=refresh_status_badge, inputs=None, outputs=[status_badge])
|
| run_event.then(fn=refresh_run_history, inputs=None, outputs=[history_dropdown])
|
|
|
| sample_selector.change(
|
| fn=update_selected_molstar,
|
| inputs=[sample_selector],
|
| outputs=[
|
| molstar_viewer,
|
| sc_status,
|
| sc_summary,
|
| sc_metrics,
|
| sc_tm_rmsd_plot,
|
| sc_confidence_plot,
|
| sc_artifacts,
|
| sc_top_tm,
|
| sc_top_rmsd,
|
| sc_base_status_state,
|
| sc_summary_state,
|
| sc_metrics_state,
|
| sc_artifacts_state,
|
| ],
|
| )
|
| sample_selector.change(fn=update_selection_label, inputs=[sample_selector], outputs=[selection_label])
|
| sample_selector.change(fn=_update_backbone_selector, inputs=[sample_selector], outputs=[backbone_compare_selector])
|
| sample_selector.change(fn=clear_compare_status, inputs=None, outputs=[molstar_compare_status])
|
|
|
| load_traj_button.click(
|
| fn=load_trajectory_view,
|
| inputs=[sample_selector, traj_source],
|
| outputs=[trajectory_viewer],
|
| )
|
| zip_button.click(fn=build_run_zip, inputs=[sample_selector], outputs=[zip_file])
|
|
|
| save_example_button.click(fn=save_example_case, inputs=[sample_selector], outputs=[example_status])
|
| load_example_button.click(fn=load_saved_example, inputs=None, outputs=example_load_outputs)
|
| demo_cta_button.click(fn=load_saved_example, inputs=None, outputs=example_load_outputs)
|
| load_example_button.click(fn=update_selection_label, inputs=[sample_selector], outputs=[selection_label])
|
| demo_cta_button.click(fn=update_selection_label, inputs=[sample_selector], outputs=[selection_label])
|
|
|
| sc_run_button.click(
|
| fn=run_self_consistency,
|
| inputs=[
|
| sample_selector,
|
| sc_num_sequences,
|
| sc_run_folding,
|
| sc_top_n,
|
| ],
|
| outputs=[
|
| sc_status,
|
| sc_summary,
|
| sc_metrics,
|
| sc_tm_rmsd_plot,
|
| sc_confidence_plot,
|
| sc_artifacts,
|
| sc_top_tm,
|
| sc_top_rmsd,
|
| sc_base_status_state,
|
| sc_summary_state,
|
| sc_metrics_state,
|
| sc_artifacts_state,
|
| ],
|
| )
|
| sc_apply_view_settings_button.click(
|
| fn=apply_sc_view_settings,
|
| inputs=[
|
| sc_base_status_state,
|
| sc_summary_state,
|
| sc_metrics_state,
|
| sc_artifacts_state,
|
| sc_top_n,
|
| sc_filter_samples,
|
| sc_filter_sample_query,
|
| sc_filter_tm_min,
|
| sc_filter_rmsd_max,
|
| sc_filter_plddt_min,
|
| ],
|
| outputs=[
|
| sc_status,
|
| sc_summary,
|
| sc_metrics,
|
| sc_tm_rmsd_plot,
|
| sc_confidence_plot,
|
| sc_artifacts,
|
| sc_top_tm,
|
| sc_top_rmsd,
|
| ],
|
| )
|
| sc_reset_view_settings_button.click(
|
| fn=reset_sc_view_settings,
|
| inputs=[
|
| sc_base_status_state,
|
| sc_summary_state,
|
| sc_metrics_state,
|
| sc_artifacts_state,
|
| ],
|
| outputs=[
|
| sc_top_n,
|
| sc_filter_samples,
|
| sc_filter_sample_query,
|
| sc_filter_tm_min,
|
| sc_filter_rmsd_max,
|
| sc_filter_plddt_min,
|
| sc_status,
|
| sc_summary,
|
| sc_metrics,
|
| sc_tm_rmsd_plot,
|
| sc_confidence_plot,
|
| sc_artifacts,
|
| sc_top_tm,
|
| sc_top_rmsd,
|
| ],
|
| )
|
| sc_export_button.click(
|
| fn=export_sc_metrics_csv,
|
| inputs=[
|
| sc_metrics_state,
|
| sc_filter_samples,
|
| sc_filter_sample_query,
|
| sc_filter_tm_min,
|
| sc_filter_rmsd_max,
|
| sc_filter_plddt_min,
|
| ],
|
| outputs=[sc_export_file],
|
| )
|
| sc_metrics_state.change(
|
| fn=_update_sc_filter_sample_selector,
|
| inputs=[sc_metrics_state],
|
| outputs=[sc_filter_samples],
|
| )
|
| sc_metrics_state.change(
|
| fn=_update_folded_selector,
|
| inputs=[sc_metrics_state],
|
| outputs=[folded_compare_selector],
|
| )
|
| sc_metrics.select(
|
| fn=on_sc_metrics_select,
|
| inputs=[sc_metrics_state, backbone_compare_selector],
|
| outputs=[backbone_compare_selector, folded_compare_selector, molstar_viewer, molstar_compare_status],
|
| )
|
| load_compare_button.click(
|
| fn=load_aligned_compare_view,
|
| inputs=[backbone_compare_selector, folded_compare_selector],
|
| outputs=[molstar_viewer, molstar_compare_status],
|
| )
|
|
|
| history_refresh_button.click(fn=refresh_run_history, inputs=None, outputs=[history_dropdown])
|
| history_event = history_load_button.click(
|
| fn=load_run_from_history,
|
| inputs=[history_dropdown],
|
| outputs=[
|
| status,
|
| sample_selector,
|
| molstar_viewer,
|
| sc_status,
|
| sc_summary,
|
| sc_metrics,
|
| sc_tm_rmsd_plot,
|
| sc_confidence_plot,
|
| sc_artifacts,
|
| sc_top_tm,
|
| sc_top_rmsd,
|
| sc_base_status_state,
|
| sc_summary_state,
|
| sc_metrics_state,
|
| sc_artifacts_state,
|
| ],
|
| )
|
| history_event.then(fn=update_selection_label, inputs=[sample_selector], outputs=[selection_label])
|
|
|
| preload_button.click(fn=preload_model, inputs=None, outputs=[status, health]).then(
|
| fn=refresh_status_badge, inputs=None, outputs=[status_badge]
|
| )
|
| health_button.click(fn=refresh_health, inputs=None, outputs=health).then(
|
| fn=refresh_status_badge, inputs=None, outputs=[status_badge]
|
| )
|
|
|
| demo.load(fn=refresh_health, inputs=None, outputs=health)
|
| demo.load(fn=refresh_status_badge, inputs=None, outputs=[status_badge])
|
| demo.load(fn=refresh_run_history, inputs=None, outputs=[history_dropdown])
|
| demo.load(
|
| fn=maybe_load_saved_example_on_startup,
|
| inputs=[view_example_on_startup],
|
| outputs=example_load_outputs,
|
| )
|
|
|
|
|
if __name__ == "__main__":
    # Script entrypoint: enable Gradio's request queue first (concurrency
    # limit of 1, queue size capped at 16), then start the web server.
    demo.queue(default_concurrency_limit=1, max_size=16)

    # The listening port comes from the PORT environment variable and falls
    # back to 7860 when it is unset.
    listen_port = int(os.getenv("PORT", "7860"))

    # Bind to all interfaces so the server is reachable from outside the
    # container; show_error=True is passed through to Gradio unchanged.
    demo.launch(server_name="0.0.0.0", server_port=listen_port, show_error=True)
|
|
|