"""Gradio UI entrypoint for the FlowProt Hugging Face Docker Space."""
from __future__ import annotations
import base64
import csv
import html
import json
import logging
import os
import shutil
import traceback
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
import gradio as gr
import pandas as pd
from inference import (
ArtifactResolutionError,
FlowProtInferenceService,
InferenceError,
InferenceResult,
ModelLoadError,
)
from self_consistency import (
FlowProtSelfConsistencyService,
SelfConsistencyError,
)
from viewer_alignment import AlignmentError, align_folded_to_backbone_overlay
# Configure logging at import time. The level comes from FLOWPROT_LOG_LEVEL
# (upper-cased) and falls back to INFO when the variable is unset.
logging.basicConfig(
    level=os.getenv("FLOWPROT_LOG_LEVEL", "INFO").upper(),
    format="%(asctime)s %(levelname)s %(name)s - %(message)s",
)
LOGGER = logging.getLogger(__name__)
# Service singletons created once at import and shared by the callbacks below.
SERVICE = FlowProtInferenceService()
SC_SERVICE = FlowProtSelfConsistencyService()
# Pixel cap used by _viewer_height_css().
VIEWER_HEIGHT = 620
# Column headers for the self-consistency tables (summary, per-sequence
# metrics, and the two leaderboards).
SC_SUMMARY_HEADERS = ["Metric", "Value"]
SC_METRICS_HEADERS = ["Sample", "Seq len", "scTM", "scRMSD", "ESMFold mean pLDDT"]
SC_TOP_TM_HEADERS = ["Rank", "Sample", "Seq len", "scTM"]
SC_TOP_RMSD_HEADERS = ["Rank", "Sample", "Seq len", "scRMSD"]
# Saved example case: lives in examples/flowprot_space_example next to this file.
EXAMPLE_CASE_DIR = (
    Path(__file__).resolve().parent / "examples" / "flowprot_space_example"
).resolve()
EXAMPLE_CASE_SAMPLE_DIR = EXAMPLE_CASE_DIR / "sample"
EXAMPLE_CASE_MANIFEST = EXAMPLE_CASE_DIR / "manifest.json"
# Root directory for generated outputs, next to this file.
SPACE_OUTPUTS_DIR = (Path(__file__).resolve().parent / "space_outputs").resolve()
# NOTE(review): the run-history path/limit are consumed outside this section —
# presumably by the run-history writer; confirm there.
RUN_HISTORY_PATH = SPACE_OUTPUTS_DIR / "run_history.json"
RUN_HISTORY_LIMIT = 25
def _env_flag(name: str, default: bool = False) -> bool:
    """Read environment variable *name* as a boolean switch.

    Args:
        name: Environment variable to look up.
        default: Value returned when the variable is not set at all.

    Returns:
        True when the trimmed, lower-cased value is one of ``1``, ``true``,
        ``yes`` or ``on``; False for any other value that is set.
    """
    value = os.environ.get(name)
    if value is not None:
        return str(value).strip().lower() in {"1", "true", "yes", "on"}
    return default
# Default for the "view example on startup" option, read once at import from
# FLOWPROT_VIEW_EXAMPLE_ON_STARTUP (off when the variable is unset).
VIEW_EXAMPLE_ON_STARTUP_DEFAULT = _env_flag("FLOWPROT_VIEW_EXAMPLE_ON_STARTUP", default=False)
# Leaderboard size used when saved self-consistency results are re-rendered.
SC_DEFAULT_TOP_N = 3
# Metric bounds: used as the fixed scatter-plot axis ranges and as the
# "filter inactive" thresholds in the filter summary line.
SC_TM_MIN = 0.0
SC_TM_MAX = 1.0
SC_RMSD_MIN = 0.0
SC_RMSD_MAX = 10.0
# NOTE(review): a 0..1 range implies pLDDT is stored as a fraction, not 0..100 — confirm.
SC_PLDDT_MIN = 0.0
SC_PLDDT_MAX = 1.0
# Custom stylesheet for the UI: page width/padding, the "fp-*" helper classes,
# and two breakpoints (980px, 1280px) that stack rows vertically.
# NOTE(review): where this is handed to Gradio is outside this section.
UI_CSS = """
.gradio-container {
max-width: 1400px !important;
margin: 0 auto !important;
padding-left: 14px !important;
padding-right: 14px !important;
}
.fp-card {
border: 1px solid var(--block-border-color);
border-radius: 14px;
padding: 14px;
margin-bottom: 12px;
background: var(--block-background-fill);
box-shadow: 0 1px 2px rgba(15, 23, 42, 0.06);
}
.fp-tight {
gap: 12px !important;
}
.fp-status textarea {
min-height: 100px !important;
}
.fp-debug-note {
color: var(--body-text-color-subdued);
font-size: 0.92rem;
}
.fp-subnote {
color: var(--body-text-color-subdued);
font-size: 0.9rem;
}
.fp-main-row {
align-items: stretch;
}
.fp-plot-row {
align-items: stretch;
}
@media (max-width: 980px) {
.gradio-container {
padding-left: 10px !important;
padding-right: 10px !important;
}
.fp-main-row {
flex-direction: column !important;
}
.fp-status textarea {
min-height: 88px !important;
}
}
@media (max-width: 1280px) {
.fp-plot-row {
flex-direction: column !important;
}
}
"""
def _viewer_height_css() -> str:
    """Return a CSS ``min()`` length: 72% of the viewport height, capped at VIEWER_HEIGHT px."""
    pixel_cap = VIEWER_HEIGHT
    return f"min(72vh, {pixel_cap}px)"
def _molstar_placeholder_html(message: Optional[str] = None) -> str:
    """Return placeholder markup for the viewer area when no structure is loaded.

    Args:
        message: Text to show. When None or empty, a default prompt asking the
            user to generate and select a sample is used. The text is passed
            through ``html.escape`` before being embedded.
    """
    placeholder_message = message or "Generate a sample and select it to load in Mol*."
    return (
        f"
        "
        f"{html.escape(placeholder_message)}"
        "
        "
    )
def _mode_choices() -> List[str]:
    """List the generation modes to offer.

    ``unconditional`` is always present; ``classifier`` and ``conditional``
    follow, in that order, when the inference service reports them enabled.
    """
    optional_modes = (
        ("classifier", SERVICE.classifier_enabled),
        ("conditional", SERVICE.conditional_enabled),
    )
    return ["unconditional"] + [mode for mode, enabled in optional_modes if enabled]
def _empty_sc_plot_df() -> pd.DataFrame:
    """Return a zero-row plot frame with the expected columns and dtypes.

    Columns are created with explicit dtypes (string labels, float64 metrics)
    so that numeric columns stay numeric even when there are no rows.
    """
    label_columns = ("sample", "plot_color")
    metric_columns = ("scTM", "scRMSD", "esmfold_mean_plddt")
    columns = {name: pd.Series(dtype="string") for name in label_columns}
    columns.update({name: pd.Series(dtype="float64") for name in metric_columns})
    return pd.DataFrame(columns)
def _empty_scatter_plot_html(title: str, message: str) -> str:
    """Return placeholder markup for a scatter plot that has nothing to draw.

    Args:
        title: Plot title; embedded after ``html.escape``.
        message: Explanation shown in place of the plot; embedded after
            ``html.escape``.
    """
    return (
        ""
        f"
        {html.escape(title)}
        "
        "
        "
        f"{html.escape(message)}"
        "
        "
        "
        "
    )
def _empty_sc_outputs() -> Tuple[
    str,
    List[List[str]],
    List[List[object]],
    str,
    str,
    List[str],
    List[List[object]],
    List[List[object]],
]:
    """Return blank values for the eight self-consistency display outputs.

    Order: status text, summary rows, metric rows, tradeoff plot HTML,
    confidence plot HTML, artifact list, top-scTM rows, top-scRMSD rows.
    Both plot slots carry placeholder markup explaining how to populate them.
    """
    tradeoff_placeholder = _empty_scatter_plot_html(
        title="Self-consistency tradeoff (scTM vs scRMSD)",
        message="Run self-consistency to render plot data.",
    )
    confidence_placeholder = _empty_scatter_plot_html(
        title="ESMFold confidence vs scTM",
        message="Run self-consistency with folding enabled to render confidence plot.",
    )
    return ("", [], [], tradeoff_placeholder, confidence_placeholder, [], [], [])
def _empty_sc_state() -> Tuple[str, Dict[str, object], List[Dict[str, object]], List[str]]:
    """Return blank self-consistency state: (base status, summary, metrics, artifacts).

    Fresh containers are created on every call so callers never share state.
    """
    base_status: str = ""
    summary: Dict[str, object] = {}
    metrics: List[Dict[str, object]] = []
    artifacts: List[str] = []
    return base_status, summary, metrics, artifacts
def _empty_example_view_outputs() -> Tuple[
    str,
    List[str],
    List[str],
    Dict[str, object],
    str,
    str,
    List[List[str]],
    List[List[object]],
    str,
    str,
    List[str],
    List[List[object]],
    List[List[object]],
    str,
    Dict[str, object],
    List[Dict[str, object]],
    List[str],
    str,
]:
    """Return blank values for all eighteen example-view outputs.

    Layout (matching what ``load_saved_example`` returns): status text, two
    file lists, the sample-selector update, viewer HTML, the eight
    self-consistency display outputs, the four self-consistency state values,
    and the example status text.

    The two plot slots are annotated as ``str`` because ``_empty_sc_outputs``
    supplies HTML markup for them, not DataFrames.
    """
    return (
        "",
        [],
        [],
        # Disabled, empty selector: there is nothing to choose from.
        gr.update(choices=[], value=[], interactive=False),
        _molstar_placeholder_html(),
        *_empty_sc_outputs(),
        *_empty_sc_state(),
        "",
    )
def _resolve_saved_example_sample() -> Optional[str]:
    """Locate the saved example sample file.

    The manifest's ``example_sample_path`` is preferred when the manifest can
    be read and the recorded file exists. Any failure while reading the
    manifest is logged as a warning and the lookup falls back to
    ``sample.pdb`` inside the example sample directory.

    Returns:
        The resolved POSIX-style path of the sample, or None when neither
        location has one.
    """
    if EXAMPLE_CASE_MANIFEST.exists():
        try:
            manifest = json.loads(EXAMPLE_CASE_MANIFEST.read_text(encoding="utf-8"))
            recorded = manifest.get("example_sample_path")
            recorded_path = Path(str(recorded)) if recorded else None
            if recorded_path is not None and recorded_path.exists():
                return recorded_path.resolve().as_posix()
        except Exception:
            LOGGER.warning("Failed to parse example manifest at %s", EXAMPLE_CASE_MANIFEST)
    default_sample = EXAMPLE_CASE_SAMPLE_DIR / "sample.pdb"
    return default_sample.resolve().as_posix() if default_sample.exists() else None
def _find_latest_generated_sample() -> Optional[Path]:
    """Return the most recently modified generated ``sample.pdb``, if any.

    Searches ``space_outputs/space_*/length_*/sample_*/sample.pdb`` next to
    this file and compares modification times.

    Returns:
        Path of the newest sample, or None when the output directory is
        missing or holds no matching file.
    """
    output_root = (Path(__file__).resolve().parent / "space_outputs").resolve()
    if not output_root.exists():
        return None
    # Only the newest entry is needed, so take a single-pass max instead of
    # sorting every candidate; `default` covers the no-match case.
    return max(
        output_root.glob("space_*/length_*/sample_*/sample.pdb"),
        key=lambda path: path.stat().st_mtime,
        default=None,
    )
def _normalize_selected_samples(selected_samples: Optional[object]) -> List[str]:
    """Coerce a selector value into a list of non-empty, trimmed strings.

    Args:
        selected_samples: None, a single value, or a list/tuple/set of values.

    Returns:
        For collections, each non-None entry converted with ``str`` and
        stripped, with blanks dropped (input order kept). Any other value
        yields a one-element list, or an empty list when it is None or blank.
    """
    if selected_samples is None:
        return []
    if isinstance(selected_samples, (list, tuple, set)):
        trimmed = (str(entry).strip() for entry in selected_samples if entry is not None)
        return [text for text in trimmed if text]
    # Plain strings and any other scalar take the same path.
    text = str(selected_samples).strip()
    return [text] if text else []
def _sample_label_from_path(sample_path: str, metric: Optional[Dict[str, object]] = None) -> str:
    """Build a ``sample_<id>_<index>`` display label for a sample file.

    Args:
        sample_path: Path of the sample file; only its parent directory name
            is used, with one leading ``sample_`` prefix removed.
        metric: Optional metric row; its ``index`` value becomes the label
            suffix. Without a usable index the suffix is ``x``.
    """
    folder_name = Path(sample_path).parent.name
    prefix = "sample_"
    sample_id = folder_name[len(prefix):] if folder_name.startswith(prefix) else folder_name
    sequence_index = None if metric is None else metric.get("index")
    suffix: object = "x" if sequence_index is None else sequence_index
    return f"sample_{sample_id}_{suffix}"
def _enrich_metric_sample_fields(metric: Dict[str, object], sample_path: str) -> Dict[str, object]:
    """Return a copy of *metric* tagged with its source sample and display label.

    Adds ``sample_source_path`` and ``sample_display``; ``sample`` mirrors the
    display label because existing views still read that legacy key. The input
    mapping is not modified.
    """
    display_label = _sample_label_from_path(sample_path, metric=metric)
    return {
        **metric,
        "sample_source_path": sample_path,
        "sample_display": display_label,
        "sample": display_label,
    }
def _metric_sample_display(metric: Dict[str, object]) -> str:
    """Return the label for a metric row.

    Uses ``sample_display`` when it is truthy, otherwise ``sample``; the
    result is ``"unknown"`` when the chosen value is None.
    """
    label = metric.get("sample_display")
    if not label:
        label = metric.get("sample")
    return "unknown" if label is None else str(label)
def _metric_sequence_length(metric: Dict[str, object]) -> Optional[int]:
    """Return the character length of the row's ``sequence``, or None when absent."""
    sequence = metric.get("sequence")
    return None if sequence is None else len(str(sequence))
def _normalize_optional_text(value: Optional[object]) -> str:
    """Convert *value* to stripped text, mapping None to an empty string."""
    return "" if value is None else str(value).strip()
def _parse_optional_seed(seed_override: Optional[object]) -> Optional[int]:
    """Parse the optional seed field.

    Returns:
        None when the field is empty or None, otherwise the integer value.

    Raises:
        InferenceError: If the text is not a valid integer literal.
    """
    seed_text = _normalize_optional_text(seed_override)
    if seed_text == "":
        return None
    try:
        seed = int(seed_text)
    except ValueError as exc:
        raise InferenceError("Seed override must be an integer value.") from exc
    return seed
def _actionable_error_hint(exc: object) -> str:
    """Return a user-facing hint for well-known failure messages.

    The lower-cased message text is matched against three categories, checked
    in this order: GPU memory exhaustion, missing model artifacts, and ESMFold
    problems.

    Returns:
        A hint beginning with a space (so it can be appended directly to an
        error message), or an empty string when nothing matches.
    """
    lowered = str(exc).lower()

    # "and" binds tighter than "or"; the parentheses only make that explicit.
    gpu_out_of_memory = "out of memory" in lowered or ("cuda" in lowered and "memory" in lowered)
    if gpu_out_of_memory:
        return (
            " Hint: the GPU ran out of memory. Try a shorter length, fewer samples, "
            "or fewer sampling timesteps."
        )

    missing_artifacts = "no model artifact source" in lowered or "artifact" in lowered
    if missing_artifacts:
        return (
            " Hint: configure a checkpoint via FLOWPROT_CKPT_PATH, FLOWPROT_CKPT_DIR, "
            "or FLOWPROT_HF_REPO_ID before running inference."
        )

    folding_problem = "esmfold" in lowered or "transformers" in lowered
    if folding_problem:
        return " Hint: ESMFold could not run. Disable folding or verify the ESMFold model is available."

    return ""
def _short_error_text(message: object, max_chars: int = 180) -> str:
    """Condense an error message to its first line, truncated to *max_chars*.

    Args:
        message: Any object; its ``str()`` form is used.
        max_chars: Maximum length of the returned text. Values below the
            length of the ellipsis are honoured by shortening the ellipsis.

    Returns:
        ``"unknown error"`` for blank input; otherwise the stripped first
        line, ending in ``...`` when it had to be cut.
    """
    raw = str(message).strip()
    if not raw:
        return "unknown error"
    first_line = raw.splitlines()[0].strip()
    if len(first_line) <= max_chars:
        return first_line
    limit = max(max_chars, 0)
    if limit <= 3:
        # A negative slice bound here would keep almost the whole line and
        # overshoot the limit; return only as much ellipsis as fits.
        return "..."[:limit]
    return first_line[: limit - 3].rstrip() + "..."
def _build_sc_filter_sample_choices(metrics: Sequence[Dict[str, object]]) -> List[str]:
    """Return the distinct sample labels found in *metrics*, in first-seen order."""
    # dict keys keep insertion order, which gives an ordered de-duplication.
    return list(dict.fromkeys(_metric_sample_display(entry) for entry in metrics))
def _build_folded_choices(metrics: Sequence[Dict[str, object]]) -> List[Tuple[str, str]]:
    """Build ``(label, folded_path)`` pairs for rows that have a folded structure.

    Rows with a blank ``folded_sample_path`` are skipped. The label names the
    sample and sequence index, and also carries scTM/scRMSD (four decimals)
    when both values are present.
    """
    options: List[Tuple[str, str]] = []
    for entry in metrics:
        folded_path = _normalize_optional_text(entry.get("folded_sample_path"))
        if folded_path == "":
            continue
        tm_score = entry.get("scTM")
        rmsd_score = entry.get("scRMSD")
        seq_index = entry.get("index")
        label = f"{_metric_sample_display(entry)} | seq#{seq_index}"
        if tm_score is not None and rmsd_score is not None:
            label += f" | scTM={float(tm_score):.4f} | scRMSD={float(rmsd_score):.4f}"
        options.append((label, folded_path))
    return options
def _build_sc_summary(num_selected_samples: int, num_processed_samples: int, metrics: List[Dict[str, object]]) -> Dict[str, object]:
    """Aggregate per-sequence metrics into a summary mapping.

    The three count fields are always present. Mean/min/max entries for scTM
    and scRMSD, and the mean ESMFold pLDDT, are added only when at least one
    row carries a non-None value for that metric.
    """

    def _present_values(key: str) -> List[float]:
        return [float(entry[key]) for entry in metrics if entry.get(key) is not None]

    def _mean(values: List[float]) -> float:
        return float(sum(values) / len(values))

    summary: Dict[str, object] = {
        "num_samples_selected": num_selected_samples,
        "num_samples_processed": num_processed_samples,
        "num_sequences_evaluated": len(metrics),
    }

    tm_scores = _present_values("scTM")
    rmsd_scores = _present_values("scRMSD")
    plddt_scores = _present_values("esmfold_mean_plddt")

    if tm_scores:
        summary.update(
            mean_scTM=_mean(tm_scores),
            max_scTM=max(tm_scores),
            min_scTM=min(tm_scores),
        )
    if rmsd_scores:
        summary.update(
            mean_scRMSD=_mean(rmsd_scores),
            min_scRMSD=min(rmsd_scores),
            max_scRMSD=max(rmsd_scores),
        )
    if plddt_scores:
        summary["mean_esmfold_plddt"] = _mean(plddt_scores)
    return summary
def _safe_float(value: Optional[str]) -> Optional[float]:
    """Parse *value* as a float, returning None for None, blank or unparseable text."""
    text = "" if value is None else str(value).strip()
    if not text:
        return None
    try:
        parsed = float(text)
    except ValueError:
        return None
    return parsed
def _load_sc_metrics_from_csv(csv_path: Path) -> List[Dict[str, object]]:
    """Read saved self-consistency rows from a results CSV.

    Each row becomes a dict with ``index``, ``header``, ``sequence``,
    ``scTM``, ``scRMSD``, ``esmfold_mean_plddt`` and ``folded_sample_path``.
    The three score columns are parsed with ``_safe_float`` (None when blank
    or invalid); ``index`` becomes an int when it parses as one and is
    otherwise kept as read.
    """

    def _coerce_index(raw: object) -> object:
        if raw in (None, ""):
            return raw
        try:
            return int(str(raw))
        except ValueError:
            return raw

    with csv_path.open("r", encoding="utf-8", newline="") as handle:
        return [
            {
                "index": _coerce_index(record.get("index")),
                "header": record.get("header"),
                "sequence": record.get("sequence"),
                "scTM": _safe_float(record.get("scTM")),
                "scRMSD": _safe_float(record.get("scRMSD")),
                "esmfold_mean_plddt": _safe_float(record.get("esmfold_mean_plddt")),
                "folded_sample_path": record.get("folded_sample_path"),
            }
            for record in csv.DictReader(handle)
        ]
def _collect_loaded_sc_artifacts(run_dir: Path, metrics_csv: Path, metrics: List[Dict[str, object]]) -> List[str]:
    """List the files of a saved self-consistency run that exist on disk.

    Order: ``parsed_pdbs.jsonl``, ``seqs/sample.fa``, the folded structures,
    then the metrics CSV. Folded structures come from ``esmf/*.pdb`` (sorted)
    when that directory exists; otherwise from each row's
    ``folded_sample_path`` that still exists.

    Returns:
        Resolved POSIX-style paths.
    """

    def _as_posix(path: Path) -> str:
        return path.resolve().as_posix()

    fixed_outputs = (run_dir / "parsed_pdbs.jsonl", run_dir / "seqs" / "sample.fa")
    collected: List[str] = [_as_posix(path) for path in fixed_outputs if path.exists()]

    folded_dir = run_dir / "esmf"
    if folded_dir.exists():
        collected.extend(_as_posix(path) for path in sorted(folded_dir.glob("*.pdb")))
    else:
        recorded = (
            Path(str(entry.get("folded_sample_path")))
            for entry in metrics
            if entry.get("folded_sample_path")
        )
        collected.extend(_as_posix(path) for path in recorded if path.exists())

    if metrics_csv.exists():
        collected.append(_as_posix(metrics_csv))
    return collected
def _hydrate_sc_outputs_for_samples(selected_samples: Optional[object]) -> Tuple[
    str,
    List[List[str]],
    List[List[object]],
    str,
    str,
    List[str],
    List[List[object]],
    List[List[object]],
    str,
    Dict[str, object],
    List[Dict[str, object]],
    List[str],
]:
    """Rebuild self-consistency outputs from results saved next to each sample.

    For every selected sample, the ``self_consistency`` directory beside the
    sample file is scanned newest-run-first, and the first run whose
    ``sc_results.csv`` yields at least one row is used.

    Args:
        selected_samples: Selector value; normalised with
            ``_normalize_selected_samples``.

    Returns:
        The eight display outputs followed by the four state values (base
        status, summary, metrics, artifacts). Both plot slots are HTML
        strings. Blank outputs are returned when nothing is selected or no
        saved metrics are found.
    """
    sample_values = _normalize_selected_samples(selected_samples)
    if not sample_values:
        return (*_empty_sc_outputs(), *_empty_sc_state())
    combined_metrics: List[Dict[str, object]] = []
    combined_artifacts: List[str] = []
    processed_samples = 0
    for sample_value in sample_values:
        sc_root = Path(sample_value).parent / "self_consistency"
        # is_dir() rather than exists(): a stray file with this name must not
        # reach iterdir(), which would raise.
        if not sc_root.is_dir():
            continue
        run_dirs = sorted(
            [path for path in sc_root.iterdir() if path.is_dir()],
            key=lambda path: path.stat().st_mtime,
            reverse=True,
        )
        for run_dir in run_dirs:
            metrics_csv = run_dir / "sc_results.csv"
            if not metrics_csv.exists():
                continue
            try:
                metrics = _load_sc_metrics_from_csv(metrics_csv)
                if not metrics:
                    continue
                enriched = [
                    _enrich_metric_sample_fields(metric=item, sample_path=sample_value)
                    for item in metrics
                ]
                run_artifacts = _collect_loaded_sc_artifacts(
                    run_dir=run_dir, metrics_csv=metrics_csv, metrics=metrics
                )
            except Exception as exc:  # pragma: no cover
                LOGGER.warning("Failed to load saved self-consistency from %s: %s", metrics_csv, exc)
                continue
            # Commit only after the whole run loaded cleanly, so a failing run
            # cannot leave partial rows behind before an older run is tried.
            combined_metrics.extend(enriched)
            combined_artifacts.extend(run_artifacts)
            processed_samples += 1
            break
    if not combined_metrics:
        return (*_empty_sc_outputs(), *_empty_sc_state())
    # Keep artifact list stable and duplicate-free.
    artifacts = list(dict.fromkeys(combined_artifacts))
    summary = _build_sc_summary(
        num_selected_samples=len(sample_values),
        num_processed_samples=processed_samples,
        metrics=combined_metrics,
    )
    base_status = (
        f"Loaded saved self-consistency for {processed_samples}/{len(sample_values)} selected sample(s). "
        f"Evaluated={len(combined_metrics)}."
    )
    rendered = _render_sc_outputs(
        base_status=base_status,
        summary=summary,
        metrics=combined_metrics,
        artifacts=artifacts,
        top_n=SC_DEFAULT_TOP_N,
    )
    return (*rendered, base_status, summary, combined_metrics, artifacts)
def save_example_case(selected_samples: Optional[object]) -> str:
    """Copy a sample directory into the example-case folder and write its manifest.

    The first selected sample is used; when nothing is selected, the most
    recently generated sample is used instead.

    Args:
        selected_samples: Selector value; normalised with
            ``_normalize_selected_samples``.

    Returns:
        A status message describing success or the reason for failure. Errors
        during the copy are logged and reported in the message, not raised.
    """
    # Only the `datetime` class is imported at module level.
    from datetime import timezone

    selected_values = _normalize_selected_samples(selected_samples)
    selected_sample = selected_values[0] if selected_values else None
    if not selected_sample:
        latest = _find_latest_generated_sample()
        if latest is None:
            return "No generated sample found. Run inference or select a sample before saving an example case."
        selected_sample = latest.resolve().as_posix()
    sample_path = Path(selected_sample)
    if not sample_path.exists():
        return f"Selected sample does not exist: {sample_path}"
    source_dir = sample_path.parent
    try:
        EXAMPLE_CASE_DIR.mkdir(parents=True, exist_ok=True)
        # If the selected sample already lives in the example directory (e.g.
        # the loaded example is selected), removing the destination first
        # would delete the source itself. Skip the copy in that case.
        if source_dir.resolve() != EXAMPLE_CASE_SAMPLE_DIR.resolve():
            if EXAMPLE_CASE_SAMPLE_DIR.exists():
                shutil.rmtree(EXAMPLE_CASE_SAMPLE_DIR)
            shutil.copytree(source_dir, EXAMPLE_CASE_SAMPLE_DIR)
        saved_sample_path = EXAMPLE_CASE_SAMPLE_DIR / sample_path.name
        if not saved_sample_path.exists():
            return (
                "Example case copy completed, but sample file was missing in copied directory. "
                "Please verify the source sample."
            )
        # datetime.utcnow() is deprecated; take an aware UTC time and drop the
        # tzinfo so the stored text keeps its "<naive ISO time>Z" shape.
        saved_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
        manifest = {
            "saved_at_utc": saved_at,
            "source_sample_path": sample_path.resolve().as_posix(),
            "example_sample_path": saved_sample_path.resolve().as_posix(),
        }
        EXAMPLE_CASE_MANIFEST.write_text(json.dumps(manifest, indent=2), encoding="utf-8")
        return (
            f"Saved example case to {EXAMPLE_CASE_SAMPLE_DIR}. "
            "Set FLOWPROT_VIEW_EXAMPLE_ON_STARTUP=true to auto-load it at app startup."
        )
    except Exception as exc:  # pragma: no cover
        LOGGER.exception("Failed to save example case.")
        return f"Failed to save example case: {exc}"
def load_saved_example(
    prefix_status: str = "Loaded saved example case.",
    prefix_example_status: str = "Example case loaded.",
) -> Tuple[
    str,
    List[str],
    List[str],
    Dict[str, object],
    str,
    str,
    List[List[str]],
    List[List[object]],
    str,
    str,
    List[str],
    List[List[object]],
    List[List[object]],
    str,
    Dict[str, object],
    List[Dict[str, object]],
    List[str],
    str,
]:
    """Load the saved example case into every example-view output.

    Args:
        prefix_status: Leading text of the main status message.
        prefix_example_status: Leading text of the example status message.

    Returns:
        Eighteen values: status, sample file list, an empty second file list,
        selector update, viewer HTML, the twelve values produced by
        ``_hydrate_sc_outputs_for_samples`` (plots are HTML strings), and the
        example status. When no example is saved, blank outputs are returned
        with an explanatory example status.
    """
    sample_path = _resolve_saved_example_sample()
    if not sample_path:
        empty = _empty_example_view_outputs()
        return (*empty[:-1], "No saved example found in examples/flowprot_space_example. Save one first.")
    selector_update = gr.update(choices=[sample_path], value=[sample_path], interactive=True)
    sc_payload = _hydrate_sc_outputs_for_samples([sample_path])
    # The payload ends with (..., metrics, artifacts); unpack by position from
    # the end instead of relying on an absolute index.
    *_, loaded_metrics, _loaded_artifacts = sc_payload
    sc_note = (
        f" Loaded self-consistency ({len(loaded_metrics)} sequence(s))."
        if loaded_metrics
        else " No saved self-consistency metrics found for this example."
    )
    return (
        f"{prefix_status} Source: {sample_path}",
        [sample_path],
        [],
        selector_update,
        _build_molstar_iframe_html(sample_path),
        *sc_payload,
        f"{prefix_example_status} Source: {sample_path}.{sc_note}",
    )
def maybe_load_saved_example_on_startup(
    view_example_on_startup: bool,
) -> Tuple[
    object,
    object,
    object,
    object,
    object,
    object,
    object,
    object,
    object,
    object,
    object,
    object,
    object,
    object,
    object,
    object,
    object,
    str,
]:
    """Load the saved example at startup when the option is enabled.

    Returns:
        The result of ``load_saved_example`` with startup-specific status
        prefixes when *view_example_on_startup* is truthy. Otherwise seventeen
        ``gr.skip()`` markers followed by an empty example status string.
    """
    if bool(view_example_on_startup):
        return load_saved_example(
            prefix_status="Loaded saved example case on startup.",
            prefix_example_status="Startup example loaded.",
        )
    untouched_outputs = tuple(gr.skip() for _ in range(17))
    return (*untouched_outputs, "")
def _format_summary_rows(summary: Dict[str, object]) -> List[List[str]]:
    """Turn the summary mapping into ``[label, value]`` table rows.

    Only the mean scTM, mean scRMSD and mean pLDDT entries are shown, in that
    order, and only when present. Float values use four decimals; anything
    else is rendered with ``str``.
    """
    labelled_keys = (
        ("mean_scTM", "scTM score"),
        ("mean_scRMSD", "scRMSD score"),
        ("mean_esmfold_plddt", "Mean pLDDT"),
    )
    table: List[List[str]] = []
    for key, label in labelled_keys:
        if key in summary:
            value = summary[key]
            rendered = f"{value:.4f}" if isinstance(value, float) else str(value)
            table.append([label, rendered])
    return table
def _format_metrics_rows(metrics: List[Dict[str, object]]) -> List[List[object]]:
    """Build per-sequence table rows.

    Each row holds: sample label, sequence length, scTM, scRMSD and ESMFold
    mean pLDDT. Scores are rounded to four decimals; missing scores stay None.
    """

    def _rounded(entry: Dict[str, object], key: str) -> Optional[float]:
        value = entry.get(key)
        return None if value is None else round(float(value), 4)

    return [
        [
            _metric_sample_display(entry),
            _metric_sequence_length(entry),
            _rounded(entry, "scTM"),
            _rounded(entry, "scRMSD"),
            _rounded(entry, "esmfold_mean_plddt"),
        ]
        for entry in metrics
    ]
def _format_metrics_plot_df(metrics: List[Dict[str, object]]) -> pd.DataFrame:
    """Build the plotting frame from metric rows.

    Rows lacking either scTM or scRMSD are left out. The three score columns
    are coerced to numeric (unparseable values become NaN). When no row
    qualifies, the typed empty frame from ``_empty_sc_plot_df`` is returned.
    """
    records = []
    for entry in metrics:
        if entry.get("scTM") is None or entry.get("scRMSD") is None:
            continue
        records.append(
            {
                "sample": _metric_sample_display(entry),
                "plot_color": "metric",
                "scTM": entry.get("scTM"),
                "scRMSD": entry.get("scRMSD"),
                "esmfold_mean_plddt": entry.get("esmfold_mean_plddt"),
            }
        )
    if not records:
        return _empty_sc_plot_df()
    frame = pd.DataFrame(records)
    numeric_columns = ("scTM", "scRMSD", "esmfold_mean_plddt")
    for column in numeric_columns:
        frame[column] = pd.to_numeric(frame[column], errors="coerce")
    return frame
def _format_leaderboard_rows(
    metrics: List[Dict[str, object]],
    top_n: int,
) -> Tuple[List[List[object]], List[List[object]]]:
    """Build the top-N leaderboards for scTM (highest first) and scRMSD (lowest first).

    Rows without a value for the ranked metric are left out of that
    leaderboard, so runs that produced no scores (for example with folding
    disabled) yield empty tables instead of raising.

    Args:
        metrics: Per-sequence metric rows.
        top_n: Maximum number of rows per leaderboard; negative values are
            treated as zero.

    Returns:
        ``(top_tm_rows, top_rmsd_rows)``; each row is
        ``[rank, sample label, sequence length, score rounded to 4 decimals]``.
    """
    limit = max(int(top_n), 0)

    def _ranked_rows(key: str, highest_first: bool) -> List[List[object]]:
        scored = [item for item in metrics if item.get(key) is not None]
        ranked = sorted(
            scored,
            key=lambda item: float(item[key]),
            reverse=highest_first,
        )[:limit]
        return [
            [
                rank,
                _metric_sample_display(item),
                _metric_sequence_length(item),
                round(float(item[key]), 4),
            ]
            for rank, item in enumerate(ranked, start=1)
        ]

    return _ranked_rows("scTM", True), _ranked_rows("scRMSD", False)
def _filter_metrics(
    metrics: Sequence[Dict[str, object]],
    sample_labels: Optional[object],
    sample_query: str,
    min_tm: float,
    max_rmsd: float,
    min_plddt: float,
) -> List[Dict[str, object]]:
    """Return copies of the metric rows that pass every active filter.

    A row is kept when its label is among *sample_labels* (if any are given),
    contains *sample_query* case-insensitively (if non-blank), has scTM at or
    above *min_tm*, and has scRMSD at or below *max_rmsd*. Rows missing scTM
    or scRMSD are dropped. A missing pLDDT passes; a present one must not be
    below *min_plddt*.
    """
    allowed_labels = set(_normalize_selected_samples(sample_labels))
    needle = sample_query.strip().lower()

    def _passes(entry: Dict[str, object]) -> bool:
        label = _metric_sample_display(entry)
        if allowed_labels and label not in allowed_labels:
            return False
        if needle and needle not in label.lower():
            return False
        tm_score = entry.get("scTM")
        if tm_score is None or float(tm_score) < float(min_tm):
            return False
        rmsd_score = entry.get("scRMSD")
        if rmsd_score is None or float(rmsd_score) > float(max_rmsd):
            return False
        plddt_score = entry.get("esmfold_mean_plddt")
        # Written as "not below" (rather than ">=") so NaN values are kept,
        # exactly as a "<" rejection test would keep them.
        return plddt_score is None or not (float(plddt_score) < float(min_plddt))

    return [dict(entry) for entry in metrics if _passes(entry)]
def _filters_summary_line(
    sample_labels: Optional[object],
    sample_query: str,
    min_tm: float,
    max_rmsd: float,
    min_plddt: float,
) -> Optional[str]:
    """Describe which filters are currently active.

    A threshold counts as active only when it is tighter than the matching
    ``SC_*`` bound.

    Returns:
        ``"Filters active: ..."`` listing the active filters, or None when
        no filter is in effect.
    """
    active: List[str] = []

    chosen_labels = _normalize_selected_samples(sample_labels)
    if chosen_labels:
        active.append(f"samples={len(chosen_labels)}")

    text = sample_query.strip()
    if text:
        active.append(f'text="{text}"')

    tm_floor = float(min_tm)
    if tm_floor > SC_TM_MIN:
        active.append(f"scTM>={tm_floor:.2f}")

    rmsd_ceiling = float(max_rmsd)
    if rmsd_ceiling < SC_RMSD_MAX:
        active.append(f"scRMSD<={rmsd_ceiling:.2f}")

    plddt_floor = float(min_plddt)
    if plddt_floor > SC_PLDDT_MIN:
        active.append(f"pLDDT>={plddt_floor:.2f}")

    return ("Filters active: " + ", ".join(active)) if active else None
def _build_scatter_plot_html(
    plot_df: pd.DataFrame,
    x_col: str,
    y_col: str,
    title: str,
    x_title: str,
    y_title: str,
    fallback_x_range: Tuple[float, float],
    fallback_y_range: Tuple[float, float],
) -> str:
    """Render a scatter plot of two columns of *plot_df* as markup.

    Args:
        plot_df: Source frame; an optional ``sample`` column supplies point labels.
        x_col: Column plotted on the x axis.
        y_col: Column plotted on the y axis.
        title: Plot title (HTML-escaped).
        x_title: X-axis title; the part before `` (`` is used in tooltips.
        y_title: Y-axis title; the part before `` (`` is used in tooltips.
        fallback_x_range: ``(min, max)`` used as the fixed x-axis range.
        fallback_y_range: ``(min, max)`` used as the fixed y-axis range.

    Returns:
        Plot markup, or the empty-plot placeholder when the frame is empty,
        has no numeric points, or has no points inside the axis ranges.
    """
    if plot_df.empty:
        return _empty_scatter_plot_html(title=title, message="No points available for this view.")
    keep_cols = [x_col, y_col]
    if "sample" in plot_df.columns:
        keep_cols.append("sample")
    rows = plot_df[keep_cols].copy()
    # Coerce both axes to numbers and drop rows where either is missing.
    rows[x_col] = pd.to_numeric(rows[x_col], errors="coerce")
    rows[y_col] = pd.to_numeric(rows[y_col], errors="coerce")
    rows = rows.dropna(subset=[x_col, y_col])
    if rows.empty:
        return _empty_scatter_plot_html(title=title, message="No numeric points available for this view.")
    x_min, x_max = fallback_x_range
    y_min, y_max = fallback_y_range
    # Axes are fixed, so points outside the ranges are discarded rather than rescaling.
    rows = rows[(rows[x_col] >= x_min) & (rows[x_col] <= x_max) & (rows[y_col] >= y_min) & (rows[y_col] <= y_max)]
    if rows.empty:
        return _empty_scatter_plot_html(title=title, message="No points within the fixed axis range for this view.")
    x_vals = rows[x_col].astype(float).tolist()
    y_vals = rows[y_col].astype(float).tolist()
    labels = rows["sample"].astype(str).tolist() if "sample" in rows.columns else [""] * len(x_vals)
    # Guard against a zero-width range, which would divide by zero below.
    x_span = max(x_max - x_min, 1e-9)
    y_span = max(y_max - y_min, 1e-9)
    # Canvas size and margins; the plotting area is what remains inside them.
    width = 760
    height = 420
    left = 74
    right = 26
    top = 36
    bottom = 58
    plot_w = width - left - right
    plot_h = height - top - bottom
    def sx(value: float) -> float:
        # Map a data x value to a horizontal canvas coordinate.
        return left + ((value - x_min) / x_span) * plot_w
    def sy(value: float) -> float:
        # Map a data y value to a vertical canvas coordinate (larger values sit higher).
        return top + (1.0 - ((value - y_min) / y_span)) * plot_h
    points_svg = []
    for xv, yv, label in zip(x_vals, y_vals, labels):
        tooltip = html.escape(
            f"{label} | {x_title.split(' (')[0]}={xv:.3f} | {y_title.split(' (')[0]}={yv:.3f}"
        )
        points_svg.append(
            f""
            f"{tooltip}"
        )
    x_ticks = []
    y_ticks = []
    tick_count = 5
    # tick_count + 1 evenly spaced ticks per axis, including both ends.
    for idx in range(tick_count + 1):
        ratio = idx / tick_count
        xv = x_min + ratio * x_span
        yv = y_min + ratio * y_span
        tx = left + ratio * plot_w
        ty = top + (1.0 - ratio) * plot_h
        x_ticks.append(
            f""
            f"{xv:.2f}"
        )
        y_ticks.append(
            f""
            f"{yv:.2f}"
        )
    return (
        ""
        f"
        {html.escape(title)}
        "
        f"
        "
        + "
        "
    )
def _render_sc_outputs(
    base_status: str,
    summary: Dict[str, object],
    metrics: List[Dict[str, object]],
    artifacts: List[str],
    top_n: int,
    extra_status_line: Optional[str] = None,
) -> Tuple[
    str,
    List[List[str]],
    List[List[object]],
    str,
    str,
    List[str],
    List[List[object]],
    List[List[object]],
]:
    """Render the eight self-consistency display outputs.

    Args:
        base_status: First status line.
        summary: Aggregated summary for the summary table.
        metrics: Per-sequence rows used for tables, plots and leaderboards.
        artifacts: File paths passed through unchanged.
        top_n: Leaderboard size.
        extra_status_line: Optional second status line.

    Returns:
        Status text, summary rows, metric rows, scTM-vs-scRMSD plot HTML,
        pLDDT-vs-scTM plot HTML, artifacts, top-scTM rows, top-scRMSD rows.
    """
    showing_line = f"Showing {len(metrics)} evaluated sequence(s)." if metrics else None
    status = "\n".join(
        line for line in (base_status, extra_status_line, showing_line) if line
    )

    summary_rows = _format_summary_rows(summary)
    metrics_rows = _format_metrics_rows(metrics)
    plot_frame = _format_metrics_plot_df(metrics)

    tradeoff_html = _build_scatter_plot_html(
        plot_df=plot_frame,
        x_col="scTM",
        y_col="scRMSD",
        title="Self-consistency tradeoff (scTM vs scRMSD)",
        x_title="scTM (higher is better)",
        y_title="scRMSD (lower is better)",
        fallback_x_range=(SC_TM_MIN, SC_TM_MAX),
        fallback_y_range=(SC_RMSD_MIN, SC_RMSD_MAX),
    )
    confidence_html = _build_scatter_plot_html(
        plot_df=plot_frame,
        x_col="esmfold_mean_plddt",
        y_col="scTM",
        title="ESMFold confidence vs scTM",
        x_title="ESMFold mean pLDDT",
        y_title="scTM",
        fallback_x_range=(SC_PLDDT_MIN, SC_PLDDT_MAX),
        fallback_y_range=(SC_TM_MIN, SC_TM_MAX),
    )

    leaderboards = _format_leaderboard_rows(metrics, int(top_n))
    best_tm_rows, best_rmsd_rows = leaderboards
    return (
        status,
        summary_rows,
        metrics_rows,
        tradeoff_html,
        confidence_html,
        artifacts,
        best_tm_rows,
        best_rmsd_rows,
    )
def refresh_health() -> Dict[str, object]:
    """Return the inference health report with the self-consistency report nested inside.

    The self-consistency service's report is stored under the
    ``self_consistency`` key of the inference service's report.
    """
    report = SERVICE.health_check()
    report["self_consistency"] = SC_SERVICE.health_check()
    return report
def _resolve_viewer_value(sample_path: Optional[str]) -> Optional[str]:
    """Return the resolved POSIX-style path of *sample_path* when it exists.

    Returns:
        None for an empty/None path or a path that does not exist on disk.
    """
    if sample_path:
        candidate = Path(sample_path)
        if candidate.exists():
            # Forward slashes keep the value safe for frontend URL serialization.
            return candidate.resolve().as_posix()
    return None
def _build_molstar_iframe_html(
    sample_path: Optional[str],
    viewer_message: Optional[str] = None,
) -> str:
    """Return viewer markup for the structure at *sample_path*.

    Args:
        sample_path: Path of the structure file to show.
        viewer_message: Message used for the placeholder when the path is
            missing or does not exist.

    Returns:
        The placeholder markup when no usable path is given, a failure
        message when the file cannot be read, otherwise markup built from an
        HTML-escaped ``srcdoc`` document.
    """
    viewer_value = _resolve_viewer_value(sample_path)
    if not viewer_value:
        return _molstar_placeholder_html(message=viewer_message)
    try:
        # Undecodable bytes are replaced rather than raising.
        pdb_text = Path(viewer_value).read_text(encoding="utf-8", errors="replace")
    except Exception as exc:  # pragma: no cover
        LOGGER.warning("Failed to read PDB for Mol*: %s", exc)
        return (
            f""
            "Failed to load structure for Mol*."
            "
            "
        )
    # The structure text is base64-encoded to ASCII.
    # NOTE(review): presumably embedded in the srcdoc below — confirm.
    pdb_base64 = base64.b64encode(pdb_text.encode("utf-8")).decode("ascii")
    srcdoc = f"""
    """
    # Escape quotes as well, since the document is placed inside an attribute value.
    escaped_srcdoc = html.escape(srcdoc, quote=True)
    return (
        ""
    )
def update_selected_molstar(selected_samples: Optional[object]):
    """Refresh the viewer and saved self-consistency outputs for a new selection.

    The first selected sample is shown in the viewer; saved self-consistency
    results are loaded for all selected samples.

    Returns:
        Viewer HTML followed by the twelve values from
        ``_hydrate_sc_outputs_for_samples``.
    """
    chosen = _normalize_selected_samples(selected_samples)
    primary_sample = chosen[0] if chosen else None
    viewer_html = _build_molstar_iframe_html(primary_sample)
    sc_outputs = _hydrate_sc_outputs_for_samples(chosen)
    return (viewer_html, *sc_outputs)
def _update_backbone_selector(selected_samples: Optional[object]) -> Dict[str, object]:
    """Build the update for the backbone selector from the current selection.

    The selected samples become the choices, the first one becomes the value,
    and the selector is interactive only when there is at least one choice.
    """
    options = _normalize_selected_samples(selected_samples)
    has_options = bool(options)
    return gr.update(
        choices=options,
        value=options[0] if has_options else None,
        interactive=has_options,
    )
def _update_sc_filter_sample_selector(metrics: Optional[List[Dict[str, object]]]) -> Dict[str, object]:
    """Build the update for the sample-filter selector.

    Choices are the distinct sample labels in *metrics*; nothing is
    preselected, and the selector is interactive only when choices exist.
    """
    rows = metrics if metrics else []
    labels = _build_sc_filter_sample_choices(rows)
    return gr.update(choices=labels, value=[], interactive=bool(labels))
def _update_folded_selector(metrics: Optional[List[Dict[str, object]]]) -> Dict[str, object]:
    """Build the update for the folded-structure selector.

    Choices are ``(label, folded_path)`` pairs; the first pair's path is
    preselected, and the selector is interactive only when choices exist.
    """
    rows = metrics if metrics else []
    options = _build_folded_choices(rows)
    first_path = options[0][1] if options else None
    return gr.update(
        choices=options,
        value=first_path,
        interactive=bool(options),
    )
def load_aligned_compare_view(
    backbone_sample_path: Optional[str],
    folded_sample_path: Optional[str],
) -> Tuple[str, str]:
    """Align a folded structure onto a backbone sample and show the overlay.

    Args:
        backbone_sample_path: Path of the generated backbone sample.
        folded_sample_path: Path of the folded structure to align.

    Returns:
        ``(viewer_html, details)``. When an input is missing or the alignment
        fails, the viewer falls back to the backbone alone (or a placeholder
        when no backbone is given) and *details* explains why.
    """
    backbone = _normalize_optional_text(backbone_sample_path)
    folded = _normalize_optional_text(folded_sample_path)

    if backbone == "":
        placeholder = _molstar_placeholder_html(message="Select a backbone sample to compare.")
        return placeholder, "Select a backbone sample before loading aligned compare view."

    if folded == "":
        backbone_only = _build_molstar_iframe_html(backbone)
        return backbone_only, "Select a folded structure from self-consistency rows before comparison."

    try:
        aligned = align_folded_to_backbone_overlay(
            backbone_sample_path=backbone,
            folded_sample_path=folded,
        )
        viewer_html = _build_molstar_iframe_html(
            aligned.output_path,
            viewer_message="Aligned backbone + ESMFold comparison.",
        )
        report_parts = [
            "Loaded aligned compare view. ",
            f"CA pairs={aligned.num_ca_pairs}, ",
            f"RMSD before={aligned.rmsd_before:.4f}, after={aligned.rmsd_after:.4f}. ",
            f"Overlay: {aligned.output_path}",
        ]
        return viewer_html, "".join(report_parts)
    except AlignmentError as exc:
        LOGGER.error("Alignment compare failed: %s", exc)
        fallback_html = _build_molstar_iframe_html(backbone)
        return fallback_html, f"Failed to build aligned compare view: {exc}"
    except Exception as exc:  # pragma: no cover
        LOGGER.exception("Unexpected compare-view failure.")
        fallback_html = _build_molstar_iframe_html(backbone)
        return fallback_html, f"Unexpected compare-view failure: {exc}"
def preload_model() -> Tuple[str, Dict[str, object]]:
    """Load the model ahead of the first request.

    Returns:
        ``(status message, health report)``. Failures are logged and turned
        into a status message; the health report is refreshed either way.
    """
    known_failures = (ArtifactResolutionError, ModelLoadError, InferenceError)
    try:
        SERVICE.preload_model()
        status = "Model loaded successfully."
        return status, refresh_health()
    except known_failures as exc:
        LOGGER.error("Model preload failed: %s", exc)
        status = f"Model preload failed: {exc}"
        return status, refresh_health()
    except Exception as exc:  # pragma: no cover
        LOGGER.exception("Unexpected preload failure.")
        status = f"Unexpected preload failure: {exc}"
        return status, refresh_health()
def _collect_trajectory_artifacts(
    include_trajectory_artifacts: bool,
    result: InferenceResult,
) -> List[str]:
    """Return the result's trajectory files followed by its x0 trajectory files.

    Duplicates are removed while keeping first-seen order. An empty list is
    returned when *include_trajectory_artifacts* is falsy.
    """
    if include_trajectory_artifacts:
        ordered_unique: Dict[str, None] = {}
        for path in (*result.trajectory_files, *result.x0_trajectory_files):
            ordered_unique.setdefault(path, None)
        return list(ordered_unique)
    return []
def _generation_failure_outputs(status: str) -> Tuple[object, ...]:
    """Build the output tuple for a failed generation, led by *status*."""
    return (
        status,
        [],
        [],
        refresh_health(),
        gr.update(choices=[], value=[], interactive=False),
        _molstar_placeholder_html(),
        *_empty_sc_outputs(),
        *_empty_sc_state(),
    )


def run_generation(
    mode: str,
    length: int,
    num_samples: int,
    seed_override: Optional[object],
    include_trajectory_artifacts: bool,
    guidance_scale: float,
    target_class: int,
    num_timesteps: Optional[object],
    reference_pdb: Optional[object],
    fixed_residues_text: Optional[object],
    use_classifier_guidance: bool,
    progress=gr.Progress(track_tqdm=False),
):
    """Run structure generation and assemble every UI output.

    Args:
        mode: Generation mode passed to the inference service.
        length: Sequence length to sample.
        num_samples: Number of samples to generate.
        seed_override: Optional seed; blank means no explicit seed.
        include_trajectory_artifacts: Whether trajectory files are returned.
        guidance_scale: Forwarded only in classifier/conditional mode.
        target_class: Forwarded only in classifier/conditional mode.
        num_timesteps: Optional sampling step count; blank means unset.
        reference_pdb: Optional reference structure (its ``name`` attribute is
            used when present, otherwise its string form).
        fixed_residues_text: Parsed only in conditional mode.
        use_classifier_guidance: Forwarded to the inference service.
        progress: Gradio progress reporter.

    Returns:
        Summary text, sample files, trajectory files, health report, selector
        update and viewer HTML, followed by blank self-consistency outputs and
        state. On failure the same shape is returned with the error message
        (plus a hint where one applies) as the summary text.
    """
    try:
        progress(0.05, desc="Loading model and preparing inputs...")
        seed = _parse_optional_seed(seed_override)
        timesteps = None
        timesteps_text = _normalize_optional_text(num_timesteps)
        if timesteps_text:
            try:
                timesteps = int(float(timesteps_text))
            except (ValueError, OverflowError) as exc:
                # Report bad input as a validation failure rather than letting
                # it surface as an "unexpected error" with a traceback.
                raise InferenceError("Sampling timesteps must be a numeric value.") from exc
        reference_path = None
        if reference_pdb is not None:
            reference_path = getattr(reference_pdb, "name", None) or str(reference_pdb)
        fixed_residues = _parse_fixed_residues(fixed_residues_text) if mode == "conditional" else None
        progress(0.2, desc=f"Sampling {int(num_samples)} structure(s)...")
        guided_mode = mode in ("classifier", "conditional")
        result = SERVICE.generate(
            mode=mode,
            length=int(length),
            num_samples=int(num_samples),
            seed=seed,
            guidance_scale=float(guidance_scale) if guided_mode else None,
            target_class=int(target_class) if guided_mode else None,
            num_timesteps=timesteps,
            reference_pdb_path=reference_path,
            fixed_residues=fixed_residues,
            use_classifier_guidance=bool(use_classifier_guidance),
        )
        progress(0.9, desc="Writing artifacts...")
        _append_run_history(result)
        trajectory_artifacts = _collect_trajectory_artifacts(
            include_trajectory_artifacts=bool(include_trajectory_artifacts),
            result=result,
        )
        trajectory_message = (
            f"Trajectory artifacts ready: {len(trajectory_artifacts)} file(s)."
            if include_trajectory_artifacts
            else "Trajectory artifacts hidden; enable the checkbox to download traj/x0 files."
        )
        classifier_details = ""
        if result.mode == "classifier":
            classifier_details = (
                f" Guidance scale={result.guidance_scale}, "
                f"target_class={result.target_class}."
            )
        elif result.mode == "conditional":
            classifier_details = f" Fixed residues={result.fixed_residue_count}."
        steps_details = f" Timesteps={result.num_timesteps}." if result.num_timesteps else ""
        summary = (
            f"Generated {len(result.sample_files)} sample(s) with mode={result.mode}. "
            f"Seed={result.seed}.{steps_details} Output dir: {result.run_dir}. "
            f"Artifacts source: {result.artifacts_source}.{classifier_details} "
            f"{trajectory_message}"
        )
        default_sample = result.sample_files[0] if result.sample_files else None
        selector_update = gr.update(
            choices=result.sample_files,
            value=[default_sample] if default_sample else [],
            interactive=bool(result.sample_files),
        )
        return (
            summary,
            result.sample_files,
            trajectory_artifacts,
            refresh_health(),
            selector_update,
            _build_molstar_iframe_html(default_sample),
            # A new generation invalidates earlier self-consistency results.
            *_empty_sc_outputs(),
            *_empty_sc_state(),
        )
    except (ArtifactResolutionError, ModelLoadError, InferenceError) as exc:
        LOGGER.error("Inference request failed: %s", exc)
        return _generation_failure_outputs(
            f"Inference failed: {exc}{_actionable_error_hint(exc)}"
        )
    except Exception as exc:  # pragma: no cover
        LOGGER.exception("Unhandled inference failure.")
        trace = traceback.format_exc(limit=1)
        return _generation_failure_outputs(
            f"Unexpected error: {exc}{_actionable_error_hint(exc)}\n{trace}"
        )
def run_self_consistency(
    selected_samples: Optional[object],
    num_seq_per_target: int,
    run_folding: bool,
    top_n: int,
    progress=gr.Progress(track_tqdm=False),
):
    """Run the self-consistency service over every selected sample.

    Each sample is processed independently: a ``SelfConsistencyError`` for one
    sample is logged and recorded in the status text, and the remaining samples
    still run. Any other exception aborts the whole request and is reported in
    the status field.

    Returns a tuple of the rendered display values followed by the four state
    values (base status, summary, metrics, artifacts). On the early-exit and
    failure paths the display values are the empty defaults with the first
    element replaced by a message, followed by the empty state.
    """
    targets = _normalize_selected_samples(selected_samples)
    if not targets:
        blank_outputs = _empty_sc_outputs()
        return (
            "Select one or more samples before running self-consistency.",
            *blank_outputs[1:],
            *_empty_sc_state(),
        )

    try:
        collected_metrics: List[Dict[str, object]] = []
        collected_artifacts: List[str] = []
        succeeded = 0
        failures: List[str] = []
        target_count = len(targets)
        # Guard against a zero denominator when computing progress fractions.
        denominator = max(target_count, 1)

        for position, target in enumerate(targets):
            label = _sample_label_from_path(target)
            progress(
                position / denominator,
                desc=f"ProteinMPNN + folding ({position + 1}/{target_count}): {label}",
            )

            # Defaults pin the current sample so the callback does not see
            # later loop values.
            def _fold_progress(done: int, total: int, _idx=position, _label=label) -> None:
                if total > 0:
                    progress(
                        (_idx + done / total) / denominator,
                        desc=f"Folding {_label}: {done}/{total} sequences",
                    )

            try:
                outcome = SC_SERVICE.run(
                    sample_path=target,
                    num_seq_per_target=int(num_seq_per_target),
                    run_folding=bool(run_folding),
                    progress_callback=_fold_progress,
                )
                for record in outcome.per_sequence_metrics:
                    collected_metrics.append(
                        _enrich_metric_sample_fields(metric=record, sample_path=target)
                    )
                collected_artifacts.append(outcome.parsed_pdbs_jsonl)
                collected_artifacts.append(outcome.mpnn_fasta_path)
                collected_artifacts.extend(outcome.folded_pdb_paths)
                if outcome.metrics_csv_path:
                    collected_artifacts.append(outcome.metrics_csv_path)
                succeeded += 1
            except SelfConsistencyError as exc:
                LOGGER.error("Self-consistency failed for %s: %s", target, exc)
                failures.append(
                    f"{_sample_label_from_path(target)} ({_short_error_text(exc)})"
                )

        folding_state = "on" if run_folding else "off"

        if not collected_metrics:
            blank_outputs = _empty_sc_outputs()
            if failures:
                # Show at most five failures, then a count of the rest.
                context = "\n".join(f"- {entry}" for entry in failures[:5])
                hidden = len(failures) - 5
                if hidden > 0:
                    context += f"\n... and {hidden} more failure(s)."
            else:
                context = "No metrics were generated."
            return (
                "Self-consistency failed across selected samples.\n"
                f"Requested sequences/sample={int(num_seq_per_target)}, folding={folding_state}.\n"
                f"{context}",
                *blank_outputs[1:],
                *_empty_sc_state(),
            )

        # Order-preserving de-duplication of artifact paths.
        artifacts = list(dict.fromkeys(collected_artifacts))
        summary = _build_sc_summary(
            num_selected_samples=target_count,
            num_processed_samples=succeeded,
            metrics=collected_metrics,
        )

        status_lines = [
            f"Self-consistency completed for {succeeded}/{target_count} selected sample(s).",
            f"Evaluated={len(collected_metrics)} sequences. Requested sequences/sample={int(num_seq_per_target)}.",
            f"Folding={folding_state}. Artifacts={len(artifacts)}.",
        ]
        if failures:
            shown = "; ".join(failures[:3])
            extra = len(failures) - 3
            tail = f"; +{extra} more" if extra > 0 else ""
            status_lines.append(f"Failed samples: {len(failures)} ({shown}{tail})")
        base_status = "\n".join(status_lines)

        rendered = _render_sc_outputs(
            base_status=base_status,
            summary=summary,
            metrics=collected_metrics,
            artifacts=artifacts,
            top_n=top_n,
        )
        return (*rendered, base_status, summary, collected_metrics, artifacts)
    except Exception as exc:  # pragma: no cover
        LOGGER.exception("Unexpected self-consistency failure.")
        trace = traceback.format_exc(limit=1)
        blank_outputs = _empty_sc_outputs()
        return (
            f"Unexpected self-consistency error: {exc}\n{trace}",
            *blank_outputs[1:],
            *_empty_sc_state(),
        )
def apply_sc_view_settings(
    base_status: str,
    summary: Dict[str, object],
    metrics: List[Dict[str, object]],
    artifacts: List[str],
    top_n: int,
    filter_sample_labels: Optional[object],
    filter_sample_query: str,
    filter_tm_min: float,
    filter_rmsd_max: float,
    filter_plddt_min: float,
):
    """Re-render the self-consistency outputs using the current view filters.

    The stored metrics are filtered, a one-line description of the active
    filters is produced, and both are passed to ``_render_sc_outputs`` together
    with the unchanged base status, summary and artifacts.
    """
    # The filter and its summary line take the same five criteria.
    criteria = {
        "sample_labels": filter_sample_labels,
        "sample_query": filter_sample_query,
        "min_tm": filter_tm_min,
        "max_rmsd": filter_rmsd_max,
        "min_plddt": filter_plddt_min,
    }
    visible_metrics = _filter_metrics(metrics=metrics, **criteria)
    criteria_line = _filters_summary_line(**criteria)
    return _render_sc_outputs(
        base_status=base_status,
        summary=summary,
        metrics=visible_metrics,
        artifacts=artifacts,
        top_n=top_n,
        extra_status_line=criteria_line,
    )
def reset_sc_view_settings(
    base_status: str,
    summary: Dict[str, object],
    metrics: List[Dict[str, object]],
    artifacts: List[str],
):
    """Restore the view controls to their defaults and re-render unfiltered.

    Returns six control updates (top-N, sample labels, sample query, minimum
    scTM, maximum scRMSD, minimum pLDDT) followed by the rendered outputs for
    the full metrics list at the default top-N.
    """
    rendered = _render_sc_outputs(
        base_status=base_status,
        summary=summary,
        metrics=metrics,
        artifacts=artifacts,
        top_n=SC_DEFAULT_TOP_N,
    )
    # Default values, in the order the controls are wired as outputs.
    default_values = (
        SC_DEFAULT_TOP_N,
        [],
        "",
        SC_TM_MIN,
        SC_RMSD_MAX,
        SC_PLDDT_MIN,
    )
    control_resets = [gr.update(value=default) for default in default_values]
    return (*control_resets, *rendered)
def clear_compare_status() -> str:
    """Return an empty string, used to blank the compare status textbox."""
    cleared = ""
    return cleared
def _status_badge_html() -> str:
# Build the status strip text from the inference and self-consistency health checks.
# NOTE(review): the string literals below contain no HTML tags and `color`/`bg` are
# computed but never used, so the surrounding markup appears to be missing from this
# view of the file — confirm against the original before editing these strings.
health = SERVICE.health_check()
model_loaded = bool(health.get("model_loaded"))
device = str(health.get("device", "auto"))
sc_health = SC_SERVICE.health_check()
# ProteinMPNN counts as ready only when both the code and its weights are reported available.
pmpnn_ok = bool(sc_health.get("pmpnn_available")) and bool(sc_health.get("pmpnn_weights_available"))
# Render one labelled indicator; the label is HTML-escaped, `ok` selects the colour pair.
def pill(label: str, ok: bool) -> str:
color = "#16a34a" if ok else "#b91c1c"
bg = "#dcfce7" if ok else "#fee2e2"
dot = "●"
return (
f"{dot} {html.escape(label)}"
)
# The device name comes from the health check and is escaped before interpolation.
device_pill = (
f"Device: {html.escape(device)}"
)
# Concatenation order: model state, self-consistency state, device.
return (
""
+ pill("Model loaded" if model_loaded else "Model not loaded", model_loaded)
+ pill("Self-consistency ready" if pmpnn_ok else "ProteinMPNN missing", pmpnn_ok)
+ device_pill
+ "
"
)
def refresh_status_badge() -> str:
    """Return freshly built status badge HTML for use as an event handler."""
    badge_markup = _status_badge_html()
    return badge_markup
def _parse_fixed_residues(text: Optional[object]) -> Optional[List[int]]:
    """Parse a fixed-residue specification into sorted, unique residue indices.

    Tokens are separated by commas or semicolons. Each token is either a single
    integer (``55``) or an inclusive range (``10-40``); a reversed range
    (``40-10``) is normalised. Text before the first colon in a token is
    discarded (presumably a chain prefix such as ``A:10-40``).

    Args:
        text: Raw user input; normalised through ``_normalize_optional_text``.

    Returns:
        The sorted, de-duplicated indices, or ``None`` when the input is empty
        or contains no tokens.

    Raises:
        ValueError: If a token is neither an integer nor an integer range. The
            message names the offending token instead of the bare ``int()``
            error.
    """
    raw = _normalize_optional_text(text)
    if not raw:
        return None

    residues = set()
    for piece in raw.replace(";", ",").split(","):
        original_token = piece.strip()
        if not original_token:
            continue
        token = original_token
        if ":" in token:
            token = token.split(":", 1)[1].strip()
        try:
            if "-" in token:
                start_str, end_str = token.split("-", 1)
                start, end = int(start_str.strip()), int(end_str.strip())
                if end < start:
                    start, end = end, start
                residues.update(range(start, end + 1))
            else:
                residues.add(int(token))
        except ValueError as exc:
            raise ValueError(
                f"Invalid fixed-residue token {original_token!r}: "
                "expected an integer or a range such as 10-40."
            ) from exc

    return sorted(residues) or None
def _resolve_trajectory_path(sample_path: Optional[str], kind: str) -> Optional[str]:
    """Locate the trajectory PDB that sits next to the first selected sample.

    ``kind == "Backbone trajectory"`` selects ``bb_traj.pdb``; any other value
    selects ``x0_traj.pdb``. Returns the resolved POSIX path, or ``None`` when
    nothing is selected or the file does not exist.
    """
    selected = _normalize_selected_samples(sample_path)
    if not selected:
        return None

    traj_name = "x0_traj.pdb"
    if kind == "Backbone trajectory":
        traj_name = "bb_traj.pdb"

    traj_file = Path(selected[0]).parent / traj_name
    if not traj_file.exists():
        return None
    return traj_file.resolve().as_posix()
def _build_trajectory_iframe_html(traj_path: Optional[str], title: str) -> str:
# Build the HTML for the trajectory playback viewer from a trajectory PDB file.
# Returns a placeholder when no path is given or the file cannot be read.
# NOTE(review): `srcdoc` and the final return literal contain no markup here, and
# `pdb_base64` / `escaped_srcdoc` are computed but never interpolated, so the HTML
# appears to be missing from this view of the file — confirm against the original
# before editing these strings.
if not traj_path:
return _molstar_placeholder_html(
message="No trajectory found for this sample. Generate a new sample to view its flow trajectory."
)
try:
# Undecodable bytes are replaced rather than raising.
pdb_text = Path(traj_path).read_text(encoding="utf-8", errors="replace")
except Exception as exc: # pragma: no cover
LOGGER.warning("Failed to read trajectory PDB: %s", exc)
return _molstar_placeholder_html(message="Failed to read trajectory file.")
# The PDB text is base64-encoded to ASCII.
pdb_base64 = base64.b64encode(pdb_text.encode("utf-8")).decode("ascii")
srcdoc = f"""
Frame 1 / 1
"""
# Escaped with quotes included.
escaped_srcdoc = html.escape(srcdoc, quote=True)
return (
f"{html.escape(title)}
"
""
)
def load_trajectory_view(selected_samples: Optional[object], kind: str) -> str:
    """Return playback HTML for the chosen trajectory kind of the selected sample."""
    return _build_trajectory_iframe_html(
        _resolve_trajectory_path(selected_samples, kind),
        title=f"{kind} playback",
    )
def update_selection_label(selected_samples: Optional[object]) -> str:
    """Describe the current selection as Markdown.

    Names the first selected sample as the one shown in Mol* and reports how
    many samples are selected; returns a hint when nothing is selected.
    """
    chosen = _normalize_selected_samples(selected_samples)
    if chosen:
        first_label = _sample_label_from_path(chosen[0])
        return (
            f"**Viewing in Mol*:** `{first_label}` | "
            f"**Self-consistency targets:** {len(chosen)} sample(s)"
        )
    return "_No sample selected. Generate or load a sample to begin._"
def _append_run_history(result: InferenceResult) -> None:
    """Record ``result`` at the top of the run-history file (best effort).

    The history is a JSON list of run entries, newest first, capped at
    ``RUN_HISTORY_LIMIT``. An existing entry with the same run id is replaced.
    Failures are logged and never propagated.

    A history file that cannot be read or parsed, or that does not hold a list
    of objects, is treated as empty and rebuilt. Without the shape check such a
    file would make every later append fail inside the outer handler.
    """
    try:
        SPACE_OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)

        previous: List[Dict[str, object]] = []
        if RUN_HISTORY_PATH.exists():
            try:
                loaded = json.loads(RUN_HISTORY_PATH.read_text(encoding="utf-8"))
            except Exception:
                # Unreadable or corrupt history: start over.
                loaded = []
            if isinstance(loaded, list):
                # Keep only well-formed entries so the de-duplication below cannot fail.
                previous = [item for item in loaded if isinstance(item, dict)]

        entry = {
            "run_id": Path(result.run_dir).name,
            "mode": result.mode,
            "seed": result.seed,
            "num_samples": len(result.sample_files),
            "num_timesteps": result.num_timesteps,
            "sample_files": result.sample_files,
            "created_utc": datetime.utcnow().isoformat() + "Z",
        }
        # Newest first; drop any older record of the same run.
        history = [entry] + [h for h in previous if h.get("run_id") != entry["run_id"]]
        history = history[:RUN_HISTORY_LIMIT]
        RUN_HISTORY_PATH.write_text(json.dumps(history, indent=2), encoding="utf-8")
    except Exception:  # pragma: no cover - history is best-effort.
        LOGGER.warning("Failed to append run history.", exc_info=True)
def _run_history_choices() -> List[Tuple[str, str]]:
    """Return ``(label, run_id)`` dropdown choices from the run-history file.

    Returns an empty list when the file is missing, unreadable, not valid JSON,
    or not a JSON list. Entries that are not objects are skipped.

    This function is called while the UI is being built, to supply the history
    dropdown's initial choices, so it must not raise on a malformed file.
    """
    if not RUN_HISTORY_PATH.exists():
        return []
    try:
        history = json.loads(RUN_HISTORY_PATH.read_text(encoding="utf-8"))
    except Exception:
        return []
    if not isinstance(history, list):
        return []

    choices: List[Tuple[str, str]] = []
    for entry in history:
        if not isinstance(entry, dict):
            continue
        run_id = str(entry.get("run_id", "run"))
        mode = str(entry.get("mode", "?"))
        n = entry.get("num_samples", 0)
        seed = entry.get("seed", "?")
        label = f"{run_id} | {mode} | {n} sample(s) | seed={seed}"
        choices.append((label, run_id))
    return choices
def refresh_run_history() -> Dict[str, object]:
    """Return a dropdown update with the current history choices.

    The selection is cleared, and the dropdown is interactive only when at
    least one run is listed.
    """
    available_runs = _run_history_choices()
    has_runs = bool(available_runs)
    return gr.update(choices=available_runs, value=None, interactive=has_runs)
def _history_samples_for_run(run_id: str) -> List[str]:
    """Return the recorded sample files of ``run_id`` that still exist on disk.

    Run ids are compared as strings and the first matching entry wins. Returns
    an empty list when the history file is missing, unreadable or malformed,
    when no entry matches, or when the matching entry has no usable
    ``sample_files`` list.
    """
    if not RUN_HISTORY_PATH.exists():
        return []
    try:
        history = json.loads(RUN_HISTORY_PATH.read_text(encoding="utf-8"))
    except Exception:
        return []
    if not isinstance(history, list):
        return []

    wanted = str(run_id)
    for entry in history:
        if not isinstance(entry, dict) or str(entry.get("run_id")) != wanted:
            continue
        sample_files = entry.get("sample_files")
        if not isinstance(sample_files, list):
            return []
        # An empty string would resolve to the current directory and pass
        # exists(), so require a non-empty string path.
        return [
            p for p in sample_files
            if isinstance(p, str) and p and Path(p).exists()
        ]
    return []
def load_run_from_history(run_id: Optional[str]):
    """Restore a previous run's samples into the viewer and selector.

    On success returns a status message, a selector update listing the run's
    surviving samples with the first one selected, the viewer HTML for that
    sample, and the hydrated self-consistency outputs for it. When the run has
    no surviving samples, or no run is chosen, it returns an explanatory
    message, a no-op selector update, a placeholder viewer, and the empty
    self-consistency outputs and state.
    """
    restored = _history_samples_for_run(run_id) if run_id else []
    if restored:
        first_sample = restored[0]
        return (
            f"Loaded run {run_id} with {len(restored)} sample(s).",
            gr.update(choices=restored, value=[first_sample], interactive=True),
            _build_molstar_iframe_html(first_sample),
            *_hydrate_sc_outputs_for_samples([first_sample]),
        )

    blank_outputs = _empty_sc_outputs()
    return (
        "Selected run is unavailable (artifacts may have been cleared).",
        gr.update(),
        _molstar_placeholder_html(message="Run artifacts not found."),
        *blank_outputs,
        *_empty_sc_state(),
    )
def export_sc_metrics_csv(
    metrics: List[Dict[str, object]],
    filter_sample_labels: Optional[object],
    filter_sample_query: str,
    filter_tm_min: float,
    filter_rmsd_max: float,
    filter_plddt_min: float,
) -> Optional[str]:
    """Write the currently filtered metrics to a timestamped CSV file.

    Returns the resolved POSIX path of the file under ``SPACE_OUTPUTS_DIR``, or
    ``None`` when there are no metrics or the filters exclude every row.
    """
    if not metrics:
        return None

    rows_to_export = _filter_metrics(
        metrics=metrics,
        sample_labels=filter_sample_labels,
        sample_query=filter_sample_query,
        min_tm=filter_tm_min,
        max_rmsd=filter_rmsd_max,
        min_plddt=filter_plddt_min,
    )
    if not rows_to_export:
        return None

    SPACE_OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
    csv_path = SPACE_OUTPUTS_DIR / f"sc_metrics_export_{timestamp}.csv"
    header = ["sample", "seq_len", "scTM", "scRMSD", "esmfold_mean_plddt"]

    with csv_path.open("w", encoding="utf-8", newline="") as stream:
        out = csv.writer(stream)
        out.writerow(header)
        # One row per metric, in filtered order.
        out.writerows(
            [
                _metric_sample_display(record),
                _metric_sequence_length(record),
                record.get("scTM"),
                record.get("scRMSD"),
                record.get("esmfold_mean_plddt"),
            ]
            for record in rows_to_export
        )
    return csv_path.resolve().as_posix()
def build_run_zip(selected_samples: Optional[object]) -> Optional[str]:
    """Zip the run directory that contains the first selected sample.

    The archive is written to ``SPACE_OUTPUTS_DIR`` as
    ``<run dir name>_bundle_<UTC timestamp>.zip``.

    Returns:
        The resolved POSIX path of the archive, or ``None`` when nothing is
        selected, the derived run directory is not a directory, or the derived
        directory is ``SPACE_OUTPUTS_DIR`` itself or one of its ancestors.
    """
    values = _normalize_selected_samples(selected_samples)
    if not values:
        return None

    # Expected layout: space_outputs/<run>/length_*/sample_*/sample.pdb, so the
    # run directory is three levels above the sample file.
    run_dir = Path(values[0]).parent.parent.parent
    if not run_dir.is_dir():
        return None

    # The archive is created inside SPACE_OUTPUTS_DIR. If the derived directory
    # is that directory or contains it (a shallow or relative sample path), the
    # archive would be written into the tree being zipped, so refuse.
    resolved_run_dir = run_dir.resolve()
    if resolved_run_dir == SPACE_OUTPUTS_DIR or resolved_run_dir in SPACE_OUTPUTS_DIR.parents:
        LOGGER.warning(
            "Refusing to bundle %s: it is not a run directory under %s.",
            resolved_run_dir,
            SPACE_OUTPUTS_DIR,
        )
        return None

    SPACE_OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)
    stamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
    archive_base = SPACE_OUTPUTS_DIR / f"{run_dir.name}_bundle_{stamp}"
    archive_path = shutil.make_archive(str(archive_base), "zip", root_dir=str(run_dir))
    return Path(archive_path).resolve().as_posix()
def on_sc_metrics_select(
    metrics: List[Dict[str, object]],
    backbone_choices_value: Optional[str],
    select_data: gr.SelectData,
):
    """Load the aligned compare view for the clicked metrics-table row.

    Returns updates for the backbone selector, the folded selector, the viewer
    HTML and the compare status. All four are no-op updates when there are no
    metrics, the row index is missing or out of range, or the metric lacks
    either structure path.

    ``backbone_choices_value`` is accepted for the event wiring but not read.

    NOTE(review): the row index is applied to the ``metrics`` list passed in.
    If the displayed table is filtered or reordered relative to that list, the
    index may point at a different metric — confirm against the caller.
    """

    def _no_change():
        return gr.update(), gr.update(), gr.update(), gr.update()

    if not metrics or select_data is None:
        return _no_change()

    # The event index may be a (row, column) pair or a bare row index.
    raw_index = select_data.index
    if isinstance(raw_index, (list, tuple)):
        row_index = raw_index[0]
    else:
        row_index = raw_index
    if row_index is None or row_index >= len(metrics):
        return _no_change()

    chosen = metrics[row_index]
    folded_path = _normalize_optional_text(chosen.get("folded_sample_path"))
    backbone_path = _normalize_optional_text(chosen.get("sample_source_path"))
    if not (folded_path and backbone_path):
        return _no_change()

    viewer_html, status = load_aligned_compare_view(backbone_path, folded_path)
    backbone_update = gr.update(value=backbone_path)
    folded_update = gr.update(value=folded_path)
    return backbone_update, folded_update, viewer_html, status
# Top-level UI definition. Components are created in the order they are laid out,
# and the event wiring further down refers to them by name.
with gr.Blocks(
title="FlowProt Protein Designer",
theme=gr.themes.Soft(),
css=UI_CSS,
) as demo:
gr.Markdown(
"""
# FlowProt Protein Designer
Generate novel protein backbones with flow matching, explore them in 3D, and validate
their designability with ProteinMPNN and ESMFold self-consistency.
"""
)
# Status strip: built once here and refreshed by the handlers wired below.
status_badge = gr.HTML(value=_status_badge_html())
# Onboarding accordion, open by default, with a button that loads the saved example.
with gr.Accordion("New here? Three steps to your first design", open=True, elem_classes=["fp-card"]):
gr.Markdown(
"1. **Generate** a backbone in the Generate tab (or click *Load demo example* below).\n"
"2. **View** it in interactive 3D and watch the generation trajectory in the View tab.\n"
"3. **Analyze** designability (scTM / scRMSD) in the Analyze tab."
)
demo_cta_button = gr.Button("Load demo example", variant="primary")
# State holders for the latest self-consistency run (base status, summary, metrics,
# artifacts). The view-settings handlers read them to re-render without re-running.
sc_base_status_state = gr.State("")
sc_summary_state = gr.State({})
sc_metrics_state = gr.State([])
sc_artifacts_state = gr.State([])
# Tab container; the four tabs are Generate, View, Analyze and Advanced.
with gr.Tabs():
# --- Generate tab: inference settings and downloadable outputs ---
with gr.Tab("Generate"):
with gr.Group(elem_classes=["fp-card"]):
gr.Markdown("### Inference settings")
with gr.Row(elem_classes=["fp-tight"]):
# Falls back to "unconditional" when the service's mode is not among the offered choices.
mode = gr.Dropdown(
label="Inference mode",
choices=_mode_choices(),
value=SERVICE.mvp_mode if SERVICE.mvp_mode in _mode_choices() else "unconditional",
info="Unconditional designs from scratch; classifier-guided steers toward a target class.",
)
length = gr.Slider(label="Protein length (residues)", minimum=32, maximum=1024, value=128, step=1)
num_samples = gr.Slider(label="Samples per request", minimum=1, maximum=4, value=1, step=1)
with gr.Row(elem_classes=["fp-tight"]):
# Free-text seed; an empty value means no override.
seed_override = gr.Textbox(
label="Seed override (optional)",
value="",
placeholder="Leave empty to use config inference.seed",
)
num_timesteps = gr.Slider(
label="Sampling timesteps",
minimum=10,
maximum=500,
value=100,
step=10,
info="More steps can improve quality at the cost of speed.",
)
include_trajectory_artifacts = gr.Checkbox(
label="Expose trajectory downloads",
value=False,
)
# Hidden by default; shown by the mode.change handler when mode == "classifier".
with gr.Group(visible=False) as classifier_group:
gr.Markdown("**Classifier guidance**", elem_classes=["fp-subnote"])
with gr.Row(elem_classes=["fp-tight"]):
guidance_scale = gr.Slider(
label="Classifier guidance scale",
minimum=0.0,
maximum=1.0,
value=0.2,
step=0.05,
info="Higher values steer more strongly toward the target class.",
)
target_class = gr.Dropdown(
label="Classifier target class",
choices=[0, 1],
value=1,
info="Binary classifier label to steer toward (0 or 1).",
)
# Hidden by default; shown by the mode.change handler when mode == "conditional".
with gr.Group(visible=False) as conditional_group:
gr.Markdown(
"**Conditional design** keeps the selected residues of a reference structure fixed "
"while generating the rest. Protein length is taken from the uploaded PDB.",
elem_classes=["fp-subnote"],
)
with gr.Row(elem_classes=["fp-tight"]):
reference_pdb = gr.File(label="Reference PDB (chain A)", file_types=[".pdb"])
# Free-text residue list in the format shown by the placeholder.
fixed_residues_text = gr.Textbox(
label="Fixed residues",
value="",
placeholder="e.g. 10-40,55,60-62 (empty = fix all)",
)
use_classifier_guidance = gr.Checkbox(
label="Apply classifier guidance",
value=False,
)
run_button = gr.Button("Run inference", variant="primary")
# Output area: generated samples, optional trajectory files, and a zip bundle.
with gr.Group(elem_classes=["fp-card"]):
gr.Markdown("### Outputs")
output_files = gr.Files(label="Generated sample.pdb files")
trajectory_output_files = gr.Files(label="Optional trajectory artifacts (traj + x0)")
with gr.Row(elem_classes=["fp-tight"]):
zip_button = gr.Button("Bundle current run as zip", variant="secondary")
zip_file = gr.File(label="Run bundle (.zip)")
# --- View tab: sample selection, 3D viewer, trajectory playback, aligned compare ---
with gr.Tab("View"):
with gr.Group(elem_classes=["fp-card"]):
# Multi-select, non-interactive until a handler supplies choices.
sample_selector = gr.Dropdown(
label="Select generated sample(s)",
choices=[],
value=[],
multiselect=True,
interactive=False,
info="The first selected sample is shown in 3D; all selected samples feed self-consistency.",
)
# Initial label is the text for an empty selection.
selection_label = gr.Markdown(update_selection_label(None))
# Main viewer; also the output target of the aligned-compare handlers wired below.
molstar_viewer = gr.HTML(value=_molstar_placeholder_html())
with gr.Accordion("Generation trajectory playback", open=False, elem_classes=["fp-card"]):
gr.Markdown(
"Step through the flow-matching trajectory of the selected sample. "
"Backbone trajectory shows the integrated path; x0 trajectory shows the model's denoised prediction.",
elem_classes=["fp-subnote"],
)
with gr.Row(elem_classes=["fp-tight"]):
# _resolve_trajectory_path compares against "Backbone trajectory", so that
# choice string must stay in sync with it.
traj_source = gr.Dropdown(
label="Trajectory",
choices=["Backbone trajectory", "x0 trajectory"],
value="Backbone trajectory",
)
load_traj_button = gr.Button("Load trajectory", variant="secondary")
trajectory_viewer = gr.HTML(
value=_molstar_placeholder_html(message="Select a sample and load a trajectory to play it back.")
)
with gr.Accordion("Aligned compare (backbone vs ESMFold)", open=False, elem_classes=["fp-card"]):
gr.Markdown(
"Overlay a generated backbone with an ESMFold-predicted structure from self-consistency. "
"Tip: clicking a row in the Analyze metrics table loads the overlay automatically.",
elem_classes=["fp-subnote"],
)
with gr.Row(elem_classes=["fp-tight"]):
# Both selectors start empty and non-interactive; handlers wired below update them.
backbone_compare_selector = gr.Dropdown(
label="Backbone sample for aligned compare",
choices=[],
value=None,
interactive=False,
)
folded_compare_selector = gr.Dropdown(
label="Folded structure for compare",
choices=[],
value=None,
interactive=False,
)
load_compare_button = gr.Button("Load aligned compare view", variant="secondary")
# Read-only status text for the compare view.
molstar_compare_status = gr.Textbox(
label="3D compare status",
lines=2,
interactive=False,
)
# --- Analyze tab: self-consistency run controls, view filters, tables and plots ---
with gr.Tab("Analyze"):
with gr.Group(elem_classes=["fp-card"]):
gr.Markdown(
"### Self-consistency (ProteinMPNN + ESMFold)\n"
"Runs across the samples selected in the View tab.",
)
with gr.Row(elem_classes=["fp-tight"]):
sc_num_sequences = gr.Slider(
label="ProteinMPNN sequences per sample",
minimum=1,
maximum=16,
value=4,
step=1,
)
sc_run_folding = gr.Checkbox(
label="Run ESMFold and compute scTM/scRMSD",
value=True,
)
sc_run_button = gr.Button("Run self-consistency", variant="primary")
# View settings change how stored results are displayed; their defaults match
# the values that reset_sc_view_settings restores.
with gr.Accordion("View settings", open=False):
with gr.Row(elem_classes=["fp-tight"]):
sc_top_n = gr.Slider(
label="Leaderboard top-N",
minimum=1,
maximum=10,
value=SC_DEFAULT_TOP_N,
step=1,
)
# Starts empty and non-interactive; updated from the sc_metrics_state.change handler.
sc_filter_samples = gr.Dropdown(
label="Filter by sample label",
choices=[],
value=[],
multiselect=True,
interactive=False,
)
sc_filter_sample_query = gr.Textbox(
label="Sample text contains",
value="",
placeholder="Optional text search on sample labels",
)
with gr.Row(elem_classes=["fp-tight"]):
# Slider bounds come from the SC_* constants; each default is the least
# restrictive end of its range.
sc_filter_tm_min = gr.Slider(
label="Minimum scTM",
minimum=SC_TM_MIN,
maximum=SC_TM_MAX,
value=SC_TM_MIN,
step=0.01,
)
sc_filter_rmsd_max = gr.Slider(
label="Maximum scRMSD",
minimum=SC_RMSD_MIN,
maximum=SC_RMSD_MAX,
value=SC_RMSD_MAX,
step=0.1,
)
sc_filter_plddt_min = gr.Slider(
label="Minimum ESMFold mean pLDDT",
minimum=SC_PLDDT_MIN,
maximum=SC_PLDDT_MAX,
value=SC_PLDDT_MIN,
step=0.01,
)
with gr.Row(elem_classes=["fp-tight"]):
sc_apply_view_settings_button = gr.Button("Apply view settings", variant="secondary")
sc_reset_view_settings_button = gr.Button("Reset view settings")
sc_status = gr.Textbox(
label="Self-consistency status",
lines=4,
elem_classes=["fp-status"],
)
sc_summary = gr.Dataframe(
headers=SC_SUMMARY_HEADERS,
datatype=["str", "str"],
value=[],
label="Self-consistency summary",
interactive=False,
wrap=True,
)
# Row selection on this table is wired to on_sc_metrics_select below.
sc_metrics = gr.Dataframe(
headers=SC_METRICS_HEADERS,
value=[],
label="Per-sequence metrics table (click a row to compare in 3D)",
interactive=False,
wrap=True,
)
with gr.Row(elem_classes=["fp-tight"]):
sc_export_button = gr.Button("Export current metrics as CSV", variant="secondary")
sc_export_file = gr.File(label="Filtered metrics (.csv)")
# Two side-by-side plots, initialised with placeholder HTML.
with gr.Row(elem_classes=["fp-tight", "fp-plot-row"]):
with gr.Column(scale=1, min_width=560):
sc_tm_rmsd_plot = gr.HTML(
value=_empty_scatter_plot_html(
title="Self-consistency tradeoff (scTM vs scRMSD)",
message="Run self-consistency to render plot data.",
)
)
with gr.Column(scale=1, min_width=560):
sc_confidence_plot = gr.HTML(
value=_empty_scatter_plot_html(
title="ESMFold confidence vs scTM",
message="Run self-consistency with folding enabled to render confidence plot.",
)
)
# Leaderboards.
with gr.Row(elem_classes=["fp-tight"]):
sc_top_tm = gr.Dataframe(
headers=SC_TOP_TM_HEADERS,
value=[],
label="Top by highest scTM",
interactive=False,
wrap=True,
)
sc_top_rmsd = gr.Dataframe(
headers=SC_TOP_RMSD_HEADERS,
value=[],
label="Top by lowest scRMSD",
interactive=False,
wrap=True,
)
sc_artifacts = gr.Files(label="Self-consistency artifacts")
# --- Advanced tab: run history, saved example case, diagnostics ---
with gr.Tab("Advanced"):
with gr.Group(elem_classes=["fp-card"]):
gr.Markdown("### Recent runs")
with gr.Row(elem_classes=["fp-tight"]):
# Initial choices are read from the history file while the UI is built;
# the dropdown is refreshed again by the handlers wired below.
history_dropdown = gr.Dropdown(
label="Run history",
choices=_run_history_choices(),
value=None,
interactive=True,
)
history_refresh_button = gr.Button("Refresh history")
history_load_button = gr.Button("Load selected run", variant="secondary")
with gr.Group(elem_classes=["fp-card"]):
gr.Markdown("### Saved example case")
with gr.Row(elem_classes=["fp-tight"]):
save_example_button = gr.Button("Save selected sample as example", variant="secondary")
load_example_button = gr.Button("Load saved example", variant="secondary")
# Default comes from the FLOWPROT_VIEW_EXAMPLE_ON_STARTUP environment flag.
view_example_on_startup = gr.Checkbox(
label="View saved example on startup",
value=VIEW_EXAMPLE_ON_STARTUP_DEFAULT,
)
with gr.Group(elem_classes=["fp-card"]):
gr.Markdown("### Diagnostics")
with gr.Row(elem_classes=["fp-tight"]):
preload_button = gr.Button("Preload model")
health_button = gr.Button("Refresh health")
# Collapsed by default: run status text, example-case status, and the health JSON.
with gr.Accordion("Status and debug details", open=False):
status = gr.Textbox(
label="Run status",
lines=4,
elem_classes=["fp-status"],
)
example_status = gr.Textbox(
label="Example case status",
lines=3,
interactive=False,
)
health = gr.JSON(label="Health check")
# Output targets for run_generation. The order mirrors the tuple that handler
# returns: status, sample files, trajectory files, health, selector update,
# viewer HTML, then the self-consistency display outputs and state values.
generation_outputs = [
status,
output_files,
trajectory_output_files,
health,
sample_selector,
molstar_viewer,
sc_status,
sc_summary,
sc_metrics,
sc_tm_rmsd_plot,
sc_confidence_plot,
sc_artifacts,
sc_top_tm,
sc_top_rmsd,
sc_base_status_state,
sc_summary_state,
sc_metrics_state,
sc_artifacts_state,
]
# Output targets shared by the example-loading handlers. Compared with
# generation_outputs this list omits `health` and appends `example_status`.
# NOTE(review): must match the return order of load_saved_example and
# maybe_load_saved_example_on_startup, which are defined elsewhere — confirm.
example_load_outputs = [
status,
output_files,
trajectory_output_files,
sample_selector,
molstar_viewer,
sc_status,
sc_summary,
sc_metrics,
sc_tm_rmsd_plot,
sc_confidence_plot,
sc_artifacts,
sc_top_tm,
sc_top_rmsd,
sc_base_status_state,
sc_summary_state,
sc_metrics_state,
sc_artifacts_state,
example_status,
]
# Show the classifier or conditional option group that matches the selected mode;
# both are hidden for any other mode.
mode.change(
fn=lambda selected: (
gr.update(visible=selected == "classifier"),
gr.update(visible=selected == "conditional"),
),
inputs=[mode],
outputs=[classifier_group, conditional_group],
)
# Input order must match run_generation's positional parameters.
run_event = run_button.click(
fn=run_generation,
inputs=[
mode,
length,
num_samples,
seed_override,
include_trajectory_artifacts,
guidance_scale,
target_class,
num_timesteps,
reference_pdb,
fixed_residues_text,
use_classifier_guidance,
],
outputs=generation_outputs,
)
# Follow-ups registered on the generation event: selection label, status badge, history dropdown.
run_event.then(fn=update_selection_label, inputs=[sample_selector], outputs=[selection_label])
run_event.then(fn=refresh_status_badge, inputs=None, outputs=[status_badge])
run_event.then(fn=refresh_run_history, inputs=None, outputs=[history_dropdown])
# A selector change updates the viewer and writes all self-consistency outputs and state.
sample_selector.change(
fn=update_selected_molstar,
inputs=[sample_selector],
outputs=[
molstar_viewer,
sc_status,
sc_summary,
sc_metrics,
sc_tm_rmsd_plot,
sc_confidence_plot,
sc_artifacts,
sc_top_tm,
sc_top_rmsd,
sc_base_status_state,
sc_summary_state,
sc_metrics_state,
sc_artifacts_state,
],
)
# Further listeners on the same change event: label, backbone compare selector, compare status.
sample_selector.change(fn=update_selection_label, inputs=[sample_selector], outputs=[selection_label])
sample_selector.change(fn=_update_backbone_selector, inputs=[sample_selector], outputs=[backbone_compare_selector])
sample_selector.change(fn=clear_compare_status, inputs=None, outputs=[molstar_compare_status])
load_traj_button.click(
fn=load_trajectory_view,
inputs=[sample_selector, traj_source],
outputs=[trajectory_viewer],
)
zip_button.click(fn=build_run_zip, inputs=[sample_selector], outputs=[zip_file])
save_example_button.click(fn=save_example_case, inputs=[sample_selector], outputs=[example_status])
# Both the Advanced-tab button and the onboarding button load the saved example.
load_example_button.click(fn=load_saved_example, inputs=None, outputs=example_load_outputs)
demo_cta_button.click(fn=load_saved_example, inputs=None, outputs=example_load_outputs)
# NOTE(review): these label updates are separate click listeners, not chained with
# .then() after the load above, so they may read the selector value from before the
# example is loaded — confirm whether chaining was intended.
load_example_button.click(fn=update_selection_label, inputs=[sample_selector], outputs=[selection_label])
demo_cta_button.click(fn=update_selection_label, inputs=[sample_selector], outputs=[selection_label])
# Run self-consistency on the samples selected in the View tab. The outputs are the
# eight display components followed by the four state holders, matching the
# (*rendered, base_status, summary, metrics, artifacts) tuple the handler returns.
sc_run_button.click(
fn=run_self_consistency,
inputs=[
sample_selector,
sc_num_sequences,
sc_run_folding,
sc_top_n,
],
outputs=[
sc_status,
sc_summary,
sc_metrics,
sc_tm_rmsd_plot,
sc_confidence_plot,
sc_artifacts,
sc_top_tm,
sc_top_rmsd,
sc_base_status_state,
sc_summary_state,
sc_metrics_state,
sc_artifacts_state,
],
)
# Re-render from stored state with the current filters. Only display components are
# outputs, so the stored state is left untouched.
sc_apply_view_settings_button.click(
fn=apply_sc_view_settings,
inputs=[
sc_base_status_state,
sc_summary_state,
sc_metrics_state,
sc_artifacts_state,
sc_top_n,
sc_filter_samples,
sc_filter_sample_query,
sc_filter_tm_min,
sc_filter_rmsd_max,
sc_filter_plddt_min,
],
outputs=[
sc_status,
sc_summary,
sc_metrics,
sc_tm_rmsd_plot,
sc_confidence_plot,
sc_artifacts,
sc_top_tm,
sc_top_rmsd,
],
)
# Reset: the first six outputs are the view controls, in the order
# reset_sc_view_settings returns their updates; the rest are the display components.
sc_reset_view_settings_button.click(
fn=reset_sc_view_settings,
inputs=[
sc_base_status_state,
sc_summary_state,
sc_metrics_state,
sc_artifacts_state,
],
outputs=[
sc_top_n,
sc_filter_samples,
sc_filter_sample_query,
sc_filter_tm_min,
sc_filter_rmsd_max,
sc_filter_plddt_min,
sc_status,
sc_summary,
sc_metrics,
sc_tm_rmsd_plot,
sc_confidence_plot,
sc_artifacts,
sc_top_tm,
sc_top_rmsd,
],
)
# Export reads the current filter control values, not the values last applied with
# the "Apply view settings" button.
sc_export_button.click(
fn=export_sc_metrics_csv,
inputs=[
sc_metrics_state,
sc_filter_samples,
sc_filter_sample_query,
sc_filter_tm_min,
sc_filter_rmsd_max,
sc_filter_plddt_min,
],
outputs=[sc_export_file],
)
# When the stored metrics change, update the sample-label filter and the folded-structure selector.
sc_metrics_state.change(
fn=_update_sc_filter_sample_selector,
inputs=[sc_metrics_state],
outputs=[sc_filter_samples],
)
sc_metrics_state.change(
fn=_update_folded_selector,
inputs=[sc_metrics_state],
outputs=[folded_compare_selector],
)
# Clicking a metrics row loads the aligned overlay into the main viewer.
# NOTE(review): the handler receives the stored metrics list, while the table may
# show a filtered subset after "Apply view settings"; the row index may then refer
# to a different metric — confirm.
sc_metrics.select(
fn=on_sc_metrics_select,
inputs=[sc_metrics_state, backbone_compare_selector],
outputs=[backbone_compare_selector, folded_compare_selector, molstar_viewer, molstar_compare_status],
)
# Manual aligned compare; it renders into the same main viewer.
load_compare_button.click(
fn=load_aligned_compare_view,
inputs=[backbone_compare_selector, folded_compare_selector],
outputs=[molstar_viewer, molstar_compare_status],
)
# Run history: refresh the dropdown, or load the selected run into the selector,
# viewer, and self-consistency outputs and state.
history_refresh_button.click(fn=refresh_run_history, inputs=None, outputs=[history_dropdown])
# Output order must match load_run_from_history's return tuple: status, selector,
# viewer, then the self-consistency display outputs and state values.
history_event = history_load_button.click(
fn=load_run_from_history,
inputs=[history_dropdown],
outputs=[
status,
sample_selector,
molstar_viewer,
sc_status,
sc_summary,
sc_metrics,
sc_tm_rmsd_plot,
sc_confidence_plot,
sc_artifacts,
sc_top_tm,
sc_top_rmsd,
sc_base_status_state,
sc_summary_state,
sc_metrics_state,
sc_artifacts_state,
],
)
history_event.then(fn=update_selection_label, inputs=[sample_selector], outputs=[selection_label])
# Diagnostics buttons: each refreshes the status badge after its own handler.
preload_button.click(fn=preload_model, inputs=None, outputs=[status, health]).then(
fn=refresh_status_badge, inputs=None, outputs=[status_badge]
)
health_button.click(fn=refresh_health, inputs=None, outputs=health).then(
fn=refresh_status_badge, inputs=None, outputs=[status_badge]
)
# Page-load handlers: health JSON, status badge, history dropdown, and the saved
# example (the checkbox value is passed to the handler).
demo.load(fn=refresh_health, inputs=None, outputs=health)
demo.load(fn=refresh_status_badge, inputs=None, outputs=[status_badge])
demo.load(fn=refresh_run_history, inputs=None, outputs=[history_dropdown])
demo.load(
fn=maybe_load_saved_example_on_startup,
inputs=[view_example_on_startup],
outputs=example_load_outputs,
)
if __name__ == "__main__":
    # Enable the request queue (concurrency limit 1, at most 16 waiting requests).
    demo.queue(default_concurrency_limit=1, max_size=16)
    # Listen on all interfaces; the port comes from $PORT, defaulting to 7860.
    listen_port = int(os.getenv("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=listen_port, show_error=True)