"""Gradio UI entrypoint for the FlowProt Hugging Face Docker Space.""" from __future__ import annotations import base64 import csv import html import json import logging import os import shutil import traceback from datetime import datetime from pathlib import Path from typing import Dict, List, Optional, Sequence, Tuple import gradio as gr import pandas as pd from inference import ( ArtifactResolutionError, FlowProtInferenceService, InferenceError, InferenceResult, ModelLoadError, ) from self_consistency import ( FlowProtSelfConsistencyService, SelfConsistencyError, ) from viewer_alignment import AlignmentError, align_folded_to_backbone_overlay logging.basicConfig( level=os.getenv("FLOWPROT_LOG_LEVEL", "INFO").upper(), format="%(asctime)s %(levelname)s %(name)s - %(message)s", ) LOGGER = logging.getLogger(__name__) SERVICE = FlowProtInferenceService() SC_SERVICE = FlowProtSelfConsistencyService() VIEWER_HEIGHT = 620 SC_SUMMARY_HEADERS = ["Metric", "Value"] SC_METRICS_HEADERS = ["Sample", "Seq len", "scTM", "scRMSD", "ESMFold mean pLDDT"] SC_TOP_TM_HEADERS = ["Rank", "Sample", "Seq len", "scTM"] SC_TOP_RMSD_HEADERS = ["Rank", "Sample", "Seq len", "scRMSD"] EXAMPLE_CASE_DIR = ( Path(__file__).resolve().parent / "examples" / "flowprot_space_example" ).resolve() EXAMPLE_CASE_SAMPLE_DIR = EXAMPLE_CASE_DIR / "sample" EXAMPLE_CASE_MANIFEST = EXAMPLE_CASE_DIR / "manifest.json" SPACE_OUTPUTS_DIR = (Path(__file__).resolve().parent / "space_outputs").resolve() RUN_HISTORY_PATH = SPACE_OUTPUTS_DIR / "run_history.json" RUN_HISTORY_LIMIT = 25 def _env_flag(name: str, default: bool = False) -> bool: raw = os.getenv(name) if raw is None: return default return str(raw).strip().lower() in {"1", "true", "yes", "on"} VIEW_EXAMPLE_ON_STARTUP_DEFAULT = _env_flag("FLOWPROT_VIEW_EXAMPLE_ON_STARTUP", default=False) SC_DEFAULT_TOP_N = 3 SC_TM_MIN = 0.0 SC_TM_MAX = 1.0 SC_RMSD_MIN = 0.0 SC_RMSD_MAX = 10.0 SC_PLDDT_MIN = 0.0 SC_PLDDT_MAX = 1.0 UI_CSS = """ .gradio-container { max-width: 1400px !important; margin: 0 auto !important; padding-left: 14px !important; padding-right: 14px !important; } .fp-card { border: 1px solid var(--block-border-color); border-radius: 14px; padding: 14px; margin-bottom: 12px; background: var(--block-background-fill); box-shadow: 0 1px 2px rgba(15, 23, 42, 0.06); } .fp-tight { gap: 12px !important; } .fp-status textarea { min-height: 100px !important; } .fp-debug-note { color: var(--body-text-color-subdued); font-size: 0.92rem; } .fp-subnote { color: var(--body-text-color-subdued); font-size: 0.9rem; } .fp-main-row { align-items: stretch; } .fp-plot-row { align-items: stretch; } @media (max-width: 980px) { .gradio-container { padding-left: 10px !important; padding-right: 10px !important; } .fp-main-row { flex-direction: column !important; } .fp-status textarea { min-height: 88px !important; } } @media (max-width: 1280px) { .fp-plot-row { flex-direction: column !important; } } """ def _viewer_height_css() -> str: return f"min(72vh, {VIEWER_HEIGHT}px)" def _molstar_placeholder_html(message: Optional[str] = None) -> str: placeholder_message = message or "Generate a sample and select it to load in Mol*." return ( f"
" f"{html.escape(placeholder_message)}" "
" ) def _mode_choices() -> List[str]: choices = ["unconditional"] if SERVICE.classifier_enabled: choices.append("classifier") if SERVICE.conditional_enabled: choices.append("conditional") return choices def _empty_sc_plot_df() -> pd.DataFrame: # Keep numeric columns explicitly typed so Gradio/Altair infer quantitative axes. return pd.DataFrame( { "sample": pd.Series(dtype="string"), "plot_color": pd.Series(dtype="string"), "scTM": pd.Series(dtype="float64"), "scRMSD": pd.Series(dtype="float64"), "esmfold_mean_plddt": pd.Series(dtype="float64"), } ) def _empty_scatter_plot_html(title: str, message: str) -> str: return ( "
" f"
{html.escape(title)}
" "
" f"{html.escape(message)}" "
" "
" ) def _empty_sc_outputs() -> Tuple[ str, List[List[str]], List[List[object]], str, str, List[str], List[List[object]], List[List[object]], ]: return ( "", [], [], _empty_scatter_plot_html( title="Self-consistency tradeoff (scTM vs scRMSD)", message="Run self-consistency to render plot data.", ), _empty_scatter_plot_html( title="ESMFold confidence vs scTM", message="Run self-consistency with folding enabled to render confidence plot.", ), [], [], [], ) def _empty_sc_state() -> Tuple[str, Dict[str, object], List[Dict[str, object]], List[str]]: return "", {}, [], [] def _empty_example_view_outputs() -> Tuple[ str, List[str], List[str], Dict[str, object], str, str, List[List[str]], List[List[object]], pd.DataFrame, pd.DataFrame, List[str], List[List[object]], List[List[object]], str, Dict[str, object], List[Dict[str, object]], List[str], str, ]: empty_sc = _empty_sc_outputs() empty_state = _empty_sc_state() return ( "", [], [], gr.update(choices=[], value=[], interactive=False), _molstar_placeholder_html(), *empty_sc, *empty_state, "", ) def _resolve_saved_example_sample() -> Optional[str]: if EXAMPLE_CASE_MANIFEST.exists(): try: payload = json.loads(EXAMPLE_CASE_MANIFEST.read_text(encoding="utf-8")) manifest_sample = payload.get("example_sample_path") if manifest_sample: path = Path(str(manifest_sample)) if path.exists(): return path.resolve().as_posix() except Exception: LOGGER.warning("Failed to parse example manifest at %s", EXAMPLE_CASE_MANIFEST) fallback_sample = EXAMPLE_CASE_SAMPLE_DIR / "sample.pdb" if fallback_sample.exists(): return fallback_sample.resolve().as_posix() return None def _find_latest_generated_sample() -> Optional[Path]: output_root = (Path(__file__).resolve().parent / "space_outputs").resolve() if not output_root.exists(): return None candidates = sorted( output_root.glob("space_*/length_*/sample_*/sample.pdb"), key=lambda path: path.stat().st_mtime, reverse=True, ) return candidates[0] if candidates else None def _normalize_selected_samples(selected_samples: Optional[object]) -> List[str]: if selected_samples is None: return [] if isinstance(selected_samples, str): value = selected_samples.strip() return [value] if value else [] if isinstance(selected_samples, (list, tuple, set)): normalized: List[str] = [] for raw in selected_samples: if raw is None: continue value = str(raw).strip() if value: normalized.append(value) return normalized value = str(selected_samples).strip() return [value] if value else [] def _sample_label_from_path(sample_path: str, metric: Optional[Dict[str, object]] = None) -> str: path = Path(sample_path) sample_dir_name = path.parent.name sample_id = sample_dir_name.replace("sample_", "", 1) if sample_dir_name.startswith("sample_") else sample_dir_name esmfold_id: object = "x" if metric is not None and metric.get("index") is not None: esmfold_id = metric.get("index") return f"sample_{sample_id}_{esmfold_id}" def _enrich_metric_sample_fields(metric: Dict[str, object], sample_path: str) -> Dict[str, object]: enriched = dict(metric) enriched["sample_source_path"] = sample_path enriched["sample_display"] = _sample_label_from_path(sample_path, metric=metric) # Keep legacy key expected by existing views/tooltips. enriched["sample"] = enriched["sample_display"] return enriched def _metric_sample_display(metric: Dict[str, object]) -> str: value = metric.get("sample_display") or metric.get("sample") return str(value) if value is not None else "unknown" def _metric_sequence_length(metric: Dict[str, object]) -> Optional[int]: sequence = metric.get("sequence") if sequence is None: return None return len(str(sequence)) def _normalize_optional_text(value: Optional[object]) -> str: if value is None: return "" return str(value).strip() def _parse_optional_seed(seed_override: Optional[object]) -> Optional[int]: text = _normalize_optional_text(seed_override) if not text: return None try: return int(text) except ValueError as exc: raise InferenceError("Seed override must be an integer value.") from exc def _actionable_error_hint(exc: object) -> str: text = str(exc).lower() if "out of memory" in text or "cuda" in text and "memory" in text: return ( " Hint: the GPU ran out of memory. Try a shorter length, fewer samples, " "or fewer sampling timesteps." ) if "no model artifact source" in text or "artifact" in text: return ( " Hint: configure a checkpoint via FLOWPROT_CKPT_PATH, FLOWPROT_CKPT_DIR, " "or FLOWPROT_HF_REPO_ID before running inference." ) if "esmfold" in text or "transformers" in text: return " Hint: ESMFold could not run. Disable folding or verify the ESMFold model is available." return "" def _short_error_text(message: object, max_chars: int = 180) -> str: raw = str(message).strip() if not raw: return "unknown error" first_line = raw.splitlines()[0].strip() if len(first_line) <= max_chars: return first_line return first_line[: max_chars - 3].rstrip() + "..." def _build_sc_filter_sample_choices(metrics: Sequence[Dict[str, object]]) -> List[str]: ordered_unique: List[str] = [] seen = set() for item in metrics: label = _metric_sample_display(item) if label in seen: continue seen.add(label) ordered_unique.append(label) return ordered_unique def _build_folded_choices(metrics: Sequence[Dict[str, object]]) -> List[Tuple[str, str]]: choices: List[Tuple[str, str]] = [] for item in metrics: folded_path = _normalize_optional_text(item.get("folded_sample_path")) if not folded_path: continue sc_tm = item.get("scTM") sc_rmsd = item.get("scRMSD") sequence_index = item.get("index") sample_display = _metric_sample_display(item) choice_label = ( f"{sample_display} | seq#{sequence_index} | " f"scTM={float(sc_tm):.4f} | scRMSD={float(sc_rmsd):.4f}" if sc_tm is not None and sc_rmsd is not None else f"{sample_display} | seq#{sequence_index}" ) choices.append((choice_label, folded_path)) return choices def _build_sc_summary(num_selected_samples: int, num_processed_samples: int, metrics: List[Dict[str, object]]) -> Dict[str, object]: summary: Dict[str, object] = { "num_samples_selected": num_selected_samples, "num_samples_processed": num_processed_samples, "num_sequences_evaluated": len(metrics), } if metrics: tm_values = [float(item["scTM"]) for item in metrics if item.get("scTM") is not None] rmsd_values = [float(item["scRMSD"]) for item in metrics if item.get("scRMSD") is not None] plddt_values = [ float(item["esmfold_mean_plddt"]) for item in metrics if item.get("esmfold_mean_plddt") is not None ] if tm_values: summary["mean_scTM"] = float(sum(tm_values) / len(tm_values)) summary["max_scTM"] = max(tm_values) summary["min_scTM"] = min(tm_values) if rmsd_values: summary["mean_scRMSD"] = float(sum(rmsd_values) / len(rmsd_values)) summary["min_scRMSD"] = min(rmsd_values) summary["max_scRMSD"] = max(rmsd_values) if plddt_values: summary["mean_esmfold_plddt"] = float(sum(plddt_values) / len(plddt_values)) return summary def _safe_float(value: Optional[str]) -> Optional[float]: if value is None: return None text = str(value).strip() if not text: return None try: return float(text) except ValueError: return None def _load_sc_metrics_from_csv(csv_path: Path) -> List[Dict[str, object]]: metrics: List[Dict[str, object]] = [] with csv_path.open("r", encoding="utf-8", newline="") as handle: reader = csv.DictReader(handle) for row in reader: index_raw = row.get("index") index_value: object = index_raw if index_raw not in (None, ""): try: index_value = int(str(index_raw)) except ValueError: index_value = index_raw metrics.append( { "index": index_value, "header": row.get("header"), "sequence": row.get("sequence"), "scTM": _safe_float(row.get("scTM")), "scRMSD": _safe_float(row.get("scRMSD")), "esmfold_mean_plddt": _safe_float(row.get("esmfold_mean_plddt")), "folded_sample_path": row.get("folded_sample_path"), } ) return metrics def _collect_loaded_sc_artifacts(run_dir: Path, metrics_csv: Path, metrics: List[Dict[str, object]]) -> List[str]: artifacts: List[str] = [] parsed_jsonl = run_dir / "parsed_pdbs.jsonl" mpnn_fasta = run_dir / "seqs" / "sample.fa" if parsed_jsonl.exists(): artifacts.append(parsed_jsonl.resolve().as_posix()) if mpnn_fasta.exists(): artifacts.append(mpnn_fasta.resolve().as_posix()) esmf_dir = run_dir / "esmf" if esmf_dir.exists(): folded = sorted(esmf_dir.glob("*.pdb")) artifacts.extend(path.resolve().as_posix() for path in folded) else: for item in metrics: folded_path = item.get("folded_sample_path") if not folded_path: continue path = Path(str(folded_path)) if path.exists(): artifacts.append(path.resolve().as_posix()) if metrics_csv.exists(): artifacts.append(metrics_csv.resolve().as_posix()) return artifacts def _hydrate_sc_outputs_for_samples(selected_samples: Optional[object]) -> Tuple[ str, List[List[str]], List[List[object]], pd.DataFrame, pd.DataFrame, List[str], List[List[object]], List[List[object]], str, Dict[str, object], List[Dict[str, object]], List[str], ]: sample_values = _normalize_selected_samples(selected_samples) if not sample_values: return (*_empty_sc_outputs(), *_empty_sc_state()) combined_metrics: List[Dict[str, object]] = [] combined_artifacts: List[str] = [] processed_samples = 0 for sample_value in sample_values: sample_path = Path(sample_value) sample_dir = sample_path.parent sc_root = sample_dir / "self_consistency" if not sc_root.exists(): continue run_dirs = sorted( [path for path in sc_root.iterdir() if path.is_dir()], key=lambda path: path.stat().st_mtime, reverse=True, ) for run_dir in run_dirs: metrics_csv = run_dir / "sc_results.csv" if not metrics_csv.exists(): continue try: metrics = _load_sc_metrics_from_csv(metrics_csv) if not metrics: continue for item in metrics: combined_metrics.append(_enrich_metric_sample_fields(metric=item, sample_path=sample_value)) combined_artifacts.extend( _collect_loaded_sc_artifacts(run_dir=run_dir, metrics_csv=metrics_csv, metrics=metrics) ) processed_samples += 1 break except Exception as exc: # pragma: no cover LOGGER.warning("Failed to load saved self-consistency from %s: %s", metrics_csv, exc) continue if not combined_metrics: return (*_empty_sc_outputs(), *_empty_sc_state()) # Keep artifact list stable and duplicate-free. artifacts = list(dict.fromkeys(combined_artifacts)) summary = _build_sc_summary( num_selected_samples=len(sample_values), num_processed_samples=processed_samples, metrics=combined_metrics, ) base_status = ( f"Loaded saved self-consistency for {processed_samples}/{len(sample_values)} selected sample(s). " f"Evaluated={len(combined_metrics)}." ) rendered = _render_sc_outputs( base_status=base_status, summary=summary, metrics=combined_metrics, artifacts=artifacts, top_n=SC_DEFAULT_TOP_N, ) return (*rendered, base_status, summary, combined_metrics, artifacts) def save_example_case(selected_samples: Optional[object]) -> str: selected_values = _normalize_selected_samples(selected_samples) selected_sample = selected_values[0] if selected_values else None if not selected_sample: latest = _find_latest_generated_sample() if latest is None: return "No generated sample found. Run inference or select a sample before saving an example case." selected_sample = latest.resolve().as_posix() sample_path = Path(selected_sample) if not sample_path.exists(): return f"Selected sample does not exist: {sample_path}" source_dir = sample_path.parent try: EXAMPLE_CASE_DIR.mkdir(parents=True, exist_ok=True) if EXAMPLE_CASE_SAMPLE_DIR.exists(): shutil.rmtree(EXAMPLE_CASE_SAMPLE_DIR) shutil.copytree(source_dir, EXAMPLE_CASE_SAMPLE_DIR) saved_sample_path = EXAMPLE_CASE_SAMPLE_DIR / sample_path.name if not saved_sample_path.exists(): return ( "Example case copy completed, but sample file was missing in copied directory. " "Please verify the source sample." ) manifest = { "saved_at_utc": datetime.utcnow().isoformat() + "Z", "source_sample_path": sample_path.resolve().as_posix(), "example_sample_path": saved_sample_path.resolve().as_posix(), } EXAMPLE_CASE_MANIFEST.write_text(json.dumps(manifest, indent=2), encoding="utf-8") return ( f"Saved example case to {EXAMPLE_CASE_SAMPLE_DIR}. " "Set FLOWPROT_VIEW_EXAMPLE_ON_STARTUP=true to auto-load it at app startup." ) except Exception as exc: # pragma: no cover LOGGER.exception("Failed to save example case.") return f"Failed to save example case: {exc}" def load_saved_example( prefix_status: str = "Loaded saved example case.", prefix_example_status: str = "Example case loaded.", ) -> Tuple[ str, List[str], List[str], Dict[str, object], str, str, List[List[str]], List[List[object]], pd.DataFrame, pd.DataFrame, List[str], List[List[object]], List[List[object]], str, Dict[str, object], List[Dict[str, object]], List[str], str, ]: sample_path = _resolve_saved_example_sample() if not sample_path: empty = _empty_example_view_outputs() return (*empty[:-1], "No saved example found in examples/flowprot_space_example. Save one first.") selector_update = gr.update(choices=[sample_path], value=[sample_path], interactive=True) sc_payload = _hydrate_sc_outputs_for_samples([sample_path]) loaded_metrics = sc_payload[10] sc_note = ( f" Loaded self-consistency ({len(loaded_metrics)} sequence(s))." if loaded_metrics else " No saved self-consistency metrics found for this example." ) return ( f"{prefix_status} Source: {sample_path}", [sample_path], [], selector_update, _build_molstar_iframe_html(sample_path), *sc_payload, f"{prefix_example_status} Source: {sample_path}.{sc_note}", ) def maybe_load_saved_example_on_startup( view_example_on_startup: bool, ) -> Tuple[ object, object, object, object, object, object, object, object, object, object, object, object, object, object, object, object, object, str, ]: if not bool(view_example_on_startup): return ( gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), "", ) return load_saved_example( prefix_status="Loaded saved example case on startup.", prefix_example_status="Startup example loaded.", ) def _format_summary_rows(summary: Dict[str, object]) -> List[List[str]]: preferred = [ ("mean_scTM", "scTM score"), ("mean_scRMSD", "scRMSD score"), ("mean_esmfold_plddt", "Mean pLDDT"), ] rows: List[List[str]] = [] for key, label in preferred: if key not in summary: continue value = summary[key] if isinstance(value, float): precision = 4 if "plddt" not in key else 4 rows.append([label, f"{value:.{precision}f}"]) else: rows.append([label, str(value)]) return rows def _format_metrics_rows(metrics: List[Dict[str, object]]) -> List[List[object]]: rows: List[List[object]] = [] for item in metrics: rows.append( [ _metric_sample_display(item), _metric_sequence_length(item), round(float(item.get("scTM")), 4) if item.get("scTM") is not None else None, round(float(item.get("scRMSD")), 4) if item.get("scRMSD") is not None else None, round(float(item.get("esmfold_mean_plddt")), 4) if item.get("esmfold_mean_plddt") is not None else None, ] ) return rows def _format_metrics_plot_df(metrics: List[Dict[str, object]]) -> pd.DataFrame: plot_rows = [ { "sample": _metric_sample_display(item), "plot_color": "metric", "scTM": item.get("scTM"), "scRMSD": item.get("scRMSD"), "esmfold_mean_plddt": item.get("esmfold_mean_plddt"), } for item in metrics if item.get("scTM") is not None and item.get("scRMSD") is not None ] if not plot_rows: return _empty_sc_plot_df() frame = pd.DataFrame(plot_rows) for col in ("scTM", "scRMSD", "esmfold_mean_plddt"): frame[col] = pd.to_numeric(frame[col], errors="coerce") return frame def _format_leaderboard_rows( metrics: List[Dict[str, object]], top_n: int, ) -> Tuple[List[List[object]], List[List[object]]]: if not metrics: return [], [] ranked_tm = sorted( metrics, key=lambda item: float(item["scTM"]), reverse=True, )[: int(top_n)] top_tm_rows: List[List[object]] = [] for rank, item in enumerate(ranked_tm, start=1): top_tm_rows.append( [ rank, _metric_sample_display(item), _metric_sequence_length(item), round(float(item["scTM"]), 4), ] ) ranked_rmsd = sorted( metrics, key=lambda item: float(item["scRMSD"]), )[: int(top_n)] top_rmsd_rows: List[List[object]] = [] for rank, item in enumerate(ranked_rmsd, start=1): top_rmsd_rows.append( [ rank, _metric_sample_display(item), _metric_sequence_length(item), round(float(item["scRMSD"]), 4), ] ) return top_tm_rows, top_rmsd_rows def _filter_metrics( metrics: Sequence[Dict[str, object]], sample_labels: Optional[object], sample_query: str, min_tm: float, max_rmsd: float, min_plddt: float, ) -> List[Dict[str, object]]: selected_sample_labels = set(_normalize_selected_samples(sample_labels)) query = sample_query.strip().lower() filtered: List[Dict[str, object]] = [] for item in metrics: sample_display = _metric_sample_display(item) if selected_sample_labels and sample_display not in selected_sample_labels: continue if query and query not in sample_display.lower(): continue tm_value = item.get("scTM") if tm_value is None or float(tm_value) < float(min_tm): continue rmsd_value = item.get("scRMSD") if rmsd_value is None or float(rmsd_value) > float(max_rmsd): continue plddt_value = item.get("esmfold_mean_plddt") if plddt_value is not None and float(plddt_value) < float(min_plddt): continue filtered.append(dict(item)) return filtered def _filters_summary_line( sample_labels: Optional[object], sample_query: str, min_tm: float, max_rmsd: float, min_plddt: float, ) -> Optional[str]: labels = _normalize_selected_samples(sample_labels) parts: List[str] = [] if labels: parts.append(f"samples={len(labels)}") query = sample_query.strip() if query: parts.append(f'text="{query}"') if float(min_tm) > SC_TM_MIN: parts.append(f"scTM>={float(min_tm):.2f}") if float(max_rmsd) < SC_RMSD_MAX: parts.append(f"scRMSD<={float(max_rmsd):.2f}") if float(min_plddt) > SC_PLDDT_MIN: parts.append(f"pLDDT>={float(min_plddt):.2f}") if not parts: return None return "Filters active: " + ", ".join(parts) def _build_scatter_plot_html( plot_df: pd.DataFrame, x_col: str, y_col: str, title: str, x_title: str, y_title: str, fallback_x_range: Tuple[float, float], fallback_y_range: Tuple[float, float], ) -> str: if plot_df.empty: return _empty_scatter_plot_html(title=title, message="No points available for this view.") keep_cols = [x_col, y_col] if "sample" in plot_df.columns: keep_cols.append("sample") rows = plot_df[keep_cols].copy() rows[x_col] = pd.to_numeric(rows[x_col], errors="coerce") rows[y_col] = pd.to_numeric(rows[y_col], errors="coerce") rows = rows.dropna(subset=[x_col, y_col]) if rows.empty: return _empty_scatter_plot_html(title=title, message="No numeric points available for this view.") x_min, x_max = fallback_x_range y_min, y_max = fallback_y_range rows = rows[(rows[x_col] >= x_min) & (rows[x_col] <= x_max) & (rows[y_col] >= y_min) & (rows[y_col] <= y_max)] if rows.empty: return _empty_scatter_plot_html(title=title, message="No points within the fixed axis range for this view.") x_vals = rows[x_col].astype(float).tolist() y_vals = rows[y_col].astype(float).tolist() labels = rows["sample"].astype(str).tolist() if "sample" in rows.columns else [""] * len(x_vals) x_span = max(x_max - x_min, 1e-9) y_span = max(y_max - y_min, 1e-9) width = 760 height = 420 left = 74 right = 26 top = 36 bottom = 58 plot_w = width - left - right plot_h = height - top - bottom def sx(value: float) -> float: return left + ((value - x_min) / x_span) * plot_w def sy(value: float) -> float: return top + (1.0 - ((value - y_min) / y_span)) * plot_h points_svg = [] for xv, yv, label in zip(x_vals, y_vals, labels): tooltip = html.escape( f"{label} | {x_title.split(' (')[0]}={xv:.3f} | {y_title.split(' (')[0]}={yv:.3f}" ) points_svg.append( f"" f"{tooltip}" ) x_ticks = [] y_ticks = [] tick_count = 5 for idx in range(tick_count + 1): ratio = idx / tick_count xv = x_min + ratio * x_span yv = y_min + ratio * y_span tx = left + ratio * plot_w ty = top + (1.0 - ratio) * plot_h x_ticks.append( f"" f"{xv:.2f}" ) y_ticks.append( f"" f"{yv:.2f}" ) return ( "
" f"
{html.escape(title)}
" f"" f"" + "".join( f"" for i in range(tick_count + 1) ) + "".join(x_ticks) + "".join(y_ticks) + "".join(points_svg) + f"{html.escape(x_title)}" + f"{html.escape(y_title)}" + "" + "
" ) def _render_sc_outputs( base_status: str, summary: Dict[str, object], metrics: List[Dict[str, object]], artifacts: List[str], top_n: int, extra_status_line: Optional[str] = None, ) -> Tuple[ str, List[List[str]], List[List[object]], str, str, List[str], List[List[object]], List[List[object]], ]: status_lines = [line for line in [base_status, extra_status_line] if line] if metrics: status_lines.append(f"Showing {len(metrics)} evaluated sequence(s).") status = "\n".join(status_lines) summary_rows = _format_summary_rows(summary) metrics_rows = _format_metrics_rows(metrics) plot_df = _format_metrics_plot_df(metrics) tm_vs_rmsd_html = _build_scatter_plot_html( plot_df=plot_df, x_col="scTM", y_col="scRMSD", title="Self-consistency tradeoff (scTM vs scRMSD)", x_title="scTM (higher is better)", y_title="scRMSD (lower is better)", fallback_x_range=(SC_TM_MIN, SC_TM_MAX), fallback_y_range=(SC_RMSD_MIN, SC_RMSD_MAX), ) confidence_html = _build_scatter_plot_html( plot_df=plot_df, x_col="esmfold_mean_plddt", y_col="scTM", title="ESMFold confidence vs scTM", x_title="ESMFold mean pLDDT", y_title="scTM", fallback_x_range=(SC_PLDDT_MIN, SC_PLDDT_MAX), fallback_y_range=(SC_TM_MIN, SC_TM_MAX), ) top_tm_rows, top_rmsd_rows = _format_leaderboard_rows(metrics, int(top_n)) return status, summary_rows, metrics_rows, tm_vs_rmsd_html, confidence_html, artifacts, top_tm_rows, top_rmsd_rows def refresh_health() -> Dict[str, object]: health = SERVICE.health_check() health["self_consistency"] = SC_SERVICE.health_check() return health def _resolve_viewer_value(sample_path: Optional[str]) -> Optional[str]: if not sample_path: return None normalized = Path(sample_path) if normalized.exists(): # Use POSIX-style separators for robust frontend URL serialization. return normalized.resolve().as_posix() return None def _build_molstar_iframe_html( sample_path: Optional[str], viewer_message: Optional[str] = None, ) -> str: viewer_value = _resolve_viewer_value(sample_path) if not viewer_value: return _molstar_placeholder_html(message=viewer_message) try: pdb_text = Path(viewer_value).read_text(encoding="utf-8", errors="replace") except Exception as exc: # pragma: no cover LOGGER.warning("Failed to read PDB for Mol*: %s", exc) return ( f"
" "Failed to load structure for Mol*." "
" ) pdb_base64 = base64.b64encode(pdb_text.encode("utf-8")).decode("ascii") srcdoc = f"""
""" escaped_srcdoc = html.escape(srcdoc, quote=True) return ( "" ) def update_selected_molstar(selected_samples: Optional[object]): selected_values = _normalize_selected_samples(selected_samples) viewer_sample = selected_values[0] if selected_values else None return ( _build_molstar_iframe_html(viewer_sample), *_hydrate_sc_outputs_for_samples(selected_values), ) def _update_backbone_selector(selected_samples: Optional[object]) -> Dict[str, object]: selected_values = _normalize_selected_samples(selected_samples) return gr.update( choices=selected_values, value=selected_values[0] if selected_values else None, interactive=bool(selected_values), ) def _update_sc_filter_sample_selector(metrics: Optional[List[Dict[str, object]]]) -> Dict[str, object]: choices = _build_sc_filter_sample_choices(metrics or []) return gr.update(choices=choices, value=[], interactive=bool(choices)) def _update_folded_selector(metrics: Optional[List[Dict[str, object]]]) -> Dict[str, object]: choices = _build_folded_choices(metrics or []) return gr.update( choices=choices, value=choices[0][1] if choices else None, interactive=bool(choices), ) def load_aligned_compare_view( backbone_sample_path: Optional[str], folded_sample_path: Optional[str], ) -> Tuple[str, str]: backbone_value = _normalize_optional_text(backbone_sample_path) folded_value = _normalize_optional_text(folded_sample_path) if not backbone_value: return ( _molstar_placeholder_html(message="Select a backbone sample to compare."), "Select a backbone sample before loading aligned compare view.", ) if not folded_value: return ( _build_molstar_iframe_html(backbone_value), "Select a folded structure from self-consistency rows before comparison.", ) try: overlay = align_folded_to_backbone_overlay( backbone_sample_path=backbone_value, folded_sample_path=folded_value, ) viewer = _build_molstar_iframe_html( overlay.output_path, viewer_message="Aligned backbone + ESMFold comparison.", ) details = ( "Loaded aligned compare view. " f"CA pairs={overlay.num_ca_pairs}, " f"RMSD before={overlay.rmsd_before:.4f}, after={overlay.rmsd_after:.4f}. " f"Overlay: {overlay.output_path}" ) return viewer, details except AlignmentError as exc: LOGGER.error("Alignment compare failed: %s", exc) return ( _build_molstar_iframe_html(backbone_value), f"Failed to build aligned compare view: {exc}", ) except Exception as exc: # pragma: no cover LOGGER.exception("Unexpected compare-view failure.") return ( _build_molstar_iframe_html(backbone_value), f"Unexpected compare-view failure: {exc}", ) def preload_model() -> Tuple[str, Dict[str, object]]: try: SERVICE.preload_model() return "Model loaded successfully.", refresh_health() except (ArtifactResolutionError, ModelLoadError, InferenceError) as exc: LOGGER.error("Model preload failed: %s", exc) return f"Model preload failed: {exc}", refresh_health() except Exception as exc: # pragma: no cover LOGGER.exception("Unexpected preload failure.") return f"Unexpected preload failure: {exc}", refresh_health() def _collect_trajectory_artifacts( include_trajectory_artifacts: bool, result: InferenceResult, ) -> List[str]: if not include_trajectory_artifacts: return [] combined = [*result.trajectory_files, *result.x0_trajectory_files] return list(dict.fromkeys(combined)) def run_generation( mode: str, length: int, num_samples: int, seed_override: Optional[object], include_trajectory_artifacts: bool, guidance_scale: float, target_class: int, num_timesteps: Optional[object], reference_pdb: Optional[object], fixed_residues_text: Optional[object], use_classifier_guidance: bool, progress=gr.Progress(track_tqdm=False), ): try: progress(0.05, desc="Loading model and preparing inputs...") seed = _parse_optional_seed(seed_override) timesteps = None timesteps_text = _normalize_optional_text(num_timesteps) if timesteps_text: timesteps = int(float(timesteps_text)) reference_path = None if reference_pdb is not None: reference_path = getattr(reference_pdb, "name", None) or str(reference_pdb) fixed_residues = _parse_fixed_residues(fixed_residues_text) if mode == "conditional" else None progress(0.2, desc=f"Sampling {int(num_samples)} structure(s)...") result = SERVICE.generate( mode=mode, length=int(length), num_samples=int(num_samples), seed=seed, guidance_scale=float(guidance_scale) if mode in ("classifier", "conditional") else None, target_class=int(target_class) if mode in ("classifier", "conditional") else None, num_timesteps=timesteps, reference_pdb_path=reference_path, fixed_residues=fixed_residues, use_classifier_guidance=bool(use_classifier_guidance), ) progress(0.9, desc="Writing artifacts...") _append_run_history(result) trajectory_artifacts = _collect_trajectory_artifacts( include_trajectory_artifacts=bool(include_trajectory_artifacts), result=result, ) trajectory_message = ( f"Trajectory artifacts ready: {len(trajectory_artifacts)} file(s)." if include_trajectory_artifacts else "Trajectory artifacts hidden; enable the checkbox to download traj/x0 files." ) classifier_details = "" if result.mode == "classifier": classifier_details = ( f" Guidance scale={result.guidance_scale}, " f"target_class={result.target_class}." ) elif result.mode == "conditional": classifier_details = f" Fixed residues={result.fixed_residue_count}." steps_details = f" Timesteps={result.num_timesteps}." if result.num_timesteps else "" summary = ( f"Generated {len(result.sample_files)} sample(s) with mode={result.mode}. " f"Seed={result.seed}.{steps_details} Output dir: {result.run_dir}. " f"Artifacts source: {result.artifacts_source}.{classifier_details} " f"{trajectory_message}" ) default_sample = result.sample_files[0] if result.sample_files else None selector_update = gr.update( choices=result.sample_files, value=[default_sample] if default_sample else [], interactive=bool(result.sample_files), ) return ( summary, result.sample_files, trajectory_artifacts, refresh_health(), selector_update, _build_molstar_iframe_html(default_sample), *_empty_sc_outputs(), *_empty_sc_state(), ) except (ArtifactResolutionError, ModelLoadError, InferenceError) as exc: LOGGER.error("Inference request failed: %s", exc) return ( f"Inference failed: {exc}{_actionable_error_hint(exc)}", [], [], refresh_health(), gr.update(choices=[], value=[], interactive=False), _molstar_placeholder_html(), *_empty_sc_outputs(), *_empty_sc_state(), ) except Exception as exc: # pragma: no cover LOGGER.exception("Unhandled inference failure.") trace = traceback.format_exc(limit=1) return ( f"Unexpected error: {exc}{_actionable_error_hint(exc)}\n{trace}", [], [], refresh_health(), gr.update(choices=[], value=[], interactive=False), _molstar_placeholder_html(), *_empty_sc_outputs(), *_empty_sc_state(), ) def run_self_consistency( selected_samples: Optional[object], num_seq_per_target: int, run_folding: bool, top_n: int, progress=gr.Progress(track_tqdm=False), ): sample_values = _normalize_selected_samples(selected_samples) if not sample_values: empty_outputs = _empty_sc_outputs() return ( "Select one or more samples before running self-consistency.", *empty_outputs[1:], *_empty_sc_state(), ) try: all_metrics: List[Dict[str, object]] = [] all_artifacts: List[str] = [] processed_samples = 0 failed_samples: List[str] = [] total_samples = len(sample_values) for sample_idx, sample_value in enumerate(sample_values): sample_label = _sample_label_from_path(sample_value) progress( sample_idx / max(total_samples, 1), desc=f"ProteinMPNN + folding ({sample_idx + 1}/{total_samples}): {sample_label}", ) def _fold_progress(done: int, total: int, _idx=sample_idx, _label=sample_label) -> None: if total <= 0: return frac = (_idx + done / total) / max(total_samples, 1) progress(frac, desc=f"Folding {_label}: {done}/{total} sequences") try: result = SC_SERVICE.run( sample_path=sample_value, num_seq_per_target=int(num_seq_per_target), run_folding=bool(run_folding), progress_callback=_fold_progress, ) for item in result.per_sequence_metrics: all_metrics.append(_enrich_metric_sample_fields(metric=item, sample_path=sample_value)) all_artifacts.extend( [ result.parsed_pdbs_jsonl, result.mpnn_fasta_path, *result.folded_pdb_paths, ] ) if result.metrics_csv_path: all_artifacts.append(result.metrics_csv_path) processed_samples += 1 except SelfConsistencyError as exc: LOGGER.error("Self-consistency failed for %s: %s", sample_value, exc) failed_samples.append( f"{_sample_label_from_path(sample_value)} ({_short_error_text(exc)})" ) if not all_metrics: empty_outputs = _empty_sc_outputs() fail_preview = "\n".join(f"- {sample}" for sample in failed_samples[:5]) if failed_samples and len(failed_samples) > 5: fail_preview += f"\n... and {len(failed_samples) - 5} more failure(s)." fail_context = fail_preview if failed_samples else "No metrics were generated." return ( "Self-consistency failed across selected samples.\n" f"Requested sequences/sample={int(num_seq_per_target)}, folding={'on' if run_folding else 'off'}.\n" f"{fail_context}", *empty_outputs[1:], *_empty_sc_state(), ) artifacts = list(dict.fromkeys(all_artifacts)) summary = _build_sc_summary( num_selected_samples=len(sample_values), num_processed_samples=processed_samples, metrics=all_metrics, ) base_status_lines = [ f"Self-consistency completed for {processed_samples}/{len(sample_values)} selected sample(s).", f"Evaluated={len(all_metrics)} sequences. Requested sequences/sample={int(num_seq_per_target)}.", f"Folding={'on' if run_folding else 'off'}. Artifacts={len(artifacts)}.", ] if failed_samples: preview = "; ".join(failed_samples[:3]) suffix = f"; +{len(failed_samples) - 3} more" if len(failed_samples) > 3 else "" base_status_lines.append(f"Failed samples: {len(failed_samples)} ({preview}{suffix})") base_status = "\n".join(base_status_lines) rendered = _render_sc_outputs( base_status=base_status, summary=summary, metrics=all_metrics, artifacts=artifacts, top_n=top_n, ) return (*rendered, base_status, summary, all_metrics, artifacts) except Exception as exc: # pragma: no cover LOGGER.exception("Unexpected self-consistency failure.") trace = traceback.format_exc(limit=1) empty_outputs = _empty_sc_outputs() return ( f"Unexpected self-consistency error: {exc}\n{trace}", *empty_outputs[1:], *_empty_sc_state(), ) def apply_sc_view_settings( base_status: str, summary: Dict[str, object], metrics: List[Dict[str, object]], artifacts: List[str], top_n: int, filter_sample_labels: Optional[object], filter_sample_query: str, filter_tm_min: float, filter_rmsd_max: float, filter_plddt_min: float, ): filtered_metrics = _filter_metrics( metrics=metrics, sample_labels=filter_sample_labels, sample_query=filter_sample_query, min_tm=filter_tm_min, max_rmsd=filter_rmsd_max, min_plddt=filter_plddt_min, ) filters_line = _filters_summary_line( sample_labels=filter_sample_labels, sample_query=filter_sample_query, min_tm=filter_tm_min, max_rmsd=filter_rmsd_max, min_plddt=filter_plddt_min, ) return _render_sc_outputs( base_status=base_status, summary=summary, metrics=filtered_metrics, artifacts=artifacts, top_n=top_n, extra_status_line=filters_line, ) def reset_sc_view_settings( base_status: str, summary: Dict[str, object], metrics: List[Dict[str, object]], artifacts: List[str], ): rendered = _render_sc_outputs( base_status=base_status, summary=summary, metrics=metrics, artifacts=artifacts, top_n=SC_DEFAULT_TOP_N, ) return ( gr.update(value=SC_DEFAULT_TOP_N), gr.update(value=[]), gr.update(value=""), gr.update(value=SC_TM_MIN), gr.update(value=SC_RMSD_MAX), gr.update(value=SC_PLDDT_MIN), *rendered, ) def clear_compare_status() -> str: return "" def _status_badge_html() -> str: health = SERVICE.health_check() model_loaded = bool(health.get("model_loaded")) device = str(health.get("device", "auto")) sc_health = SC_SERVICE.health_check() pmpnn_ok = bool(sc_health.get("pmpnn_available")) and bool(sc_health.get("pmpnn_weights_available")) def pill(label: str, ok: bool) -> str: color = "#16a34a" if ok else "#b91c1c" bg = "#dcfce7" if ok else "#fee2e2" dot = "●" return ( f"{dot} {html.escape(label)}" ) device_pill = ( f"Device: {html.escape(device)}" ) return ( "
" + pill("Model loaded" if model_loaded else "Model not loaded", model_loaded) + pill("Self-consistency ready" if pmpnn_ok else "ProteinMPNN missing", pmpnn_ok) + device_pill + "
" ) def refresh_status_badge() -> str: return _status_badge_html() def _parse_fixed_residues(text: Optional[object]) -> Optional[List[int]]: raw = _normalize_optional_text(text) if not raw: return None residues: List[int] = [] for token in raw.replace(";", ",").split(","): token = token.strip() if not token: continue if ":" in token: token = token.split(":", 1)[1].strip() if "-" in token: start_str, end_str = token.split("-", 1) start, end = int(start_str.strip()), int(end_str.strip()) if end < start: start, end = end, start residues.extend(range(start, end + 1)) else: residues.append(int(token)) deduped = sorted(dict.fromkeys(residues)) return deduped or None def _resolve_trajectory_path(sample_path: Optional[str], kind: str) -> Optional[str]: sample_values = _normalize_selected_samples(sample_path) if not sample_values: return None parent = Path(sample_values[0]).parent filename = "bb_traj.pdb" if kind == "Backbone trajectory" else "x0_traj.pdb" candidate = parent / filename return candidate.resolve().as_posix() if candidate.exists() else None def _build_trajectory_iframe_html(traj_path: Optional[str], title: str) -> str: if not traj_path: return _molstar_placeholder_html( message="No trajectory found for this sample. Generate a new sample to view its flow trajectory." ) try: pdb_text = Path(traj_path).read_text(encoding="utf-8", errors="replace") except Exception as exc: # pragma: no cover LOGGER.warning("Failed to read trajectory PDB: %s", exc) return _molstar_placeholder_html(message="Failed to read trajectory file.") pdb_base64 = base64.b64encode(pdb_text.encode("utf-8")).decode("ascii") srcdoc = f"""
Frame 1 / 1
""" escaped_srcdoc = html.escape(srcdoc, quote=True) return ( f"
{html.escape(title)}
" "" ) def load_trajectory_view(selected_samples: Optional[object], kind: str) -> str: traj_path = _resolve_trajectory_path(selected_samples, kind) return _build_trajectory_iframe_html(traj_path, title=f"{kind} playback") def update_selection_label(selected_samples: Optional[object]) -> str: values = _normalize_selected_samples(selected_samples) if not values: return "_No sample selected. Generate or load a sample to begin._" viewing = _sample_label_from_path(values[0]) return ( f"**Viewing in Mol*:** `{viewing}`  |  " f"**Self-consistency targets:** {len(values)} sample(s)" ) def _append_run_history(result: InferenceResult) -> None: try: SPACE_OUTPUTS_DIR.mkdir(parents=True, exist_ok=True) history: List[Dict[str, object]] = [] if RUN_HISTORY_PATH.exists(): try: history = json.loads(RUN_HISTORY_PATH.read_text(encoding="utf-8")) except Exception: history = [] entry = { "run_id": Path(result.run_dir).name, "mode": result.mode, "seed": result.seed, "num_samples": len(result.sample_files), "num_timesteps": result.num_timesteps, "sample_files": result.sample_files, "created_utc": datetime.utcnow().isoformat() + "Z", } history = [entry] + [h for h in history if h.get("run_id") != entry["run_id"]] history = history[:RUN_HISTORY_LIMIT] RUN_HISTORY_PATH.write_text(json.dumps(history, indent=2), encoding="utf-8") except Exception: # pragma: no cover - history is best-effort. LOGGER.warning("Failed to append run history.", exc_info=True) def _run_history_choices() -> List[Tuple[str, str]]: if not RUN_HISTORY_PATH.exists(): return [] try: history = json.loads(RUN_HISTORY_PATH.read_text(encoding="utf-8")) except Exception: return [] choices: List[Tuple[str, str]] = [] for entry in history: run_id = str(entry.get("run_id", "run")) mode = str(entry.get("mode", "?")) n = entry.get("num_samples", 0) seed = entry.get("seed", "?") label = f"{run_id} | {mode} | {n} sample(s) | seed={seed}" choices.append((label, run_id)) return choices def refresh_run_history() -> Dict[str, object]: choices = _run_history_choices() return gr.update(choices=choices, value=None, interactive=bool(choices)) def _history_samples_for_run(run_id: str) -> List[str]: if not RUN_HISTORY_PATH.exists(): return [] try: history = json.loads(RUN_HISTORY_PATH.read_text(encoding="utf-8")) except Exception: return [] for entry in history: if str(entry.get("run_id")) == str(run_id): return [p for p in entry.get("sample_files", []) if Path(p).exists()] return [] def load_run_from_history(run_id: Optional[str]): sample_files = _history_samples_for_run(run_id) if run_id else [] if not sample_files: empty = _empty_sc_outputs() return ( "Selected run is unavailable (artifacts may have been cleared).", gr.update(), _molstar_placeholder_html(message="Run artifacts not found."), *empty, *_empty_sc_state(), ) default_sample = sample_files[0] selector_update = gr.update(choices=sample_files, value=[default_sample], interactive=True) return ( f"Loaded run {run_id} with {len(sample_files)} sample(s).", selector_update, _build_molstar_iframe_html(default_sample), *_hydrate_sc_outputs_for_samples([default_sample]), ) def export_sc_metrics_csv( metrics: List[Dict[str, object]], filter_sample_labels: Optional[object], filter_sample_query: str, filter_tm_min: float, filter_rmsd_max: float, filter_plddt_min: float, ) -> Optional[str]: if not metrics: return None filtered = _filter_metrics( metrics=metrics, sample_labels=filter_sample_labels, sample_query=filter_sample_query, min_tm=filter_tm_min, max_rmsd=filter_rmsd_max, min_plddt=filter_plddt_min, ) if not filtered: return None SPACE_OUTPUTS_DIR.mkdir(parents=True, exist_ok=True) stamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S") out_path = SPACE_OUTPUTS_DIR / f"sc_metrics_export_{stamp}.csv" with out_path.open("w", encoding="utf-8", newline="") as handle: writer = csv.writer(handle) writer.writerow(["sample", "seq_len", "scTM", "scRMSD", "esmfold_mean_plddt"]) for item in filtered: writer.writerow( [ _metric_sample_display(item), _metric_sequence_length(item), item.get("scTM"), item.get("scRMSD"), item.get("esmfold_mean_plddt"), ] ) return out_path.resolve().as_posix() def build_run_zip(selected_samples: Optional[object]) -> Optional[str]: values = _normalize_selected_samples(selected_samples) if not values: return None # The run directory is space_outputs//length_*/sample_*/sample.pdb -> 3 levels up. sample_dir = Path(values[0]).parent run_dir = sample_dir.parent.parent if not run_dir.exists(): return None SPACE_OUTPUTS_DIR.mkdir(parents=True, exist_ok=True) stamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S") archive_base = SPACE_OUTPUTS_DIR / f"{run_dir.name}_bundle_{stamp}" archive_path = shutil.make_archive(str(archive_base), "zip", root_dir=str(run_dir)) return Path(archive_path).resolve().as_posix() def on_sc_metrics_select( metrics: List[Dict[str, object]], backbone_choices_value: Optional[str], select_data: gr.SelectData, ): if not metrics or select_data is None: return gr.update(), gr.update(), gr.update(), gr.update() row_index = select_data.index[0] if isinstance(select_data.index, (list, tuple)) else select_data.index if row_index is None or row_index >= len(metrics): return gr.update(), gr.update(), gr.update(), gr.update() metric = metrics[row_index] folded_path = _normalize_optional_text(metric.get("folded_sample_path")) backbone_path = _normalize_optional_text(metric.get("sample_source_path")) if not folded_path or not backbone_path: return gr.update(), gr.update(), gr.update(), gr.update() viewer_html, status = load_aligned_compare_view(backbone_path, folded_path) return ( gr.update(value=backbone_path), gr.update(value=folded_path), viewer_html, status, ) with gr.Blocks( title="FlowProt Protein Designer", theme=gr.themes.Soft(), css=UI_CSS, ) as demo: gr.Markdown( """ # FlowProt Protein Designer Generate novel protein backbones with flow matching, explore them in 3D, and validate their designability with ProteinMPNN and ESMFold self-consistency. """ ) status_badge = gr.HTML(value=_status_badge_html()) with gr.Accordion("New here? Three steps to your first design", open=True, elem_classes=["fp-card"]): gr.Markdown( "1. **Generate** a backbone in the Generate tab (or click *Load demo example* below).\n" "2. **View** it in interactive 3D and watch the generation trajectory in the View tab.\n" "3. **Analyze** designability (scTM / scRMSD) in the Analyze tab." ) demo_cta_button = gr.Button("Load demo example", variant="primary") sc_base_status_state = gr.State("") sc_summary_state = gr.State({}) sc_metrics_state = gr.State([]) sc_artifacts_state = gr.State([]) with gr.Tabs(): with gr.Tab("Generate"): with gr.Group(elem_classes=["fp-card"]): gr.Markdown("### Inference settings") with gr.Row(elem_classes=["fp-tight"]): mode = gr.Dropdown( label="Inference mode", choices=_mode_choices(), value=SERVICE.mvp_mode if SERVICE.mvp_mode in _mode_choices() else "unconditional", info="Unconditional designs from scratch; classifier-guided steers toward a target class.", ) length = gr.Slider(label="Protein length (residues)", minimum=32, maximum=1024, value=128, step=1) num_samples = gr.Slider(label="Samples per request", minimum=1, maximum=4, value=1, step=1) with gr.Row(elem_classes=["fp-tight"]): seed_override = gr.Textbox( label="Seed override (optional)", value="", placeholder="Leave empty to use config inference.seed", ) num_timesteps = gr.Slider( label="Sampling timesteps", minimum=10, maximum=500, value=100, step=10, info="More steps can improve quality at the cost of speed.", ) include_trajectory_artifacts = gr.Checkbox( label="Expose trajectory downloads", value=False, ) with gr.Group(visible=False) as classifier_group: gr.Markdown("**Classifier guidance**", elem_classes=["fp-subnote"]) with gr.Row(elem_classes=["fp-tight"]): guidance_scale = gr.Slider( label="Classifier guidance scale", minimum=0.0, maximum=1.0, value=0.2, step=0.05, info="Higher values steer more strongly toward the target class.", ) target_class = gr.Dropdown( label="Classifier target class", choices=[0, 1], value=1, info="Binary classifier label to steer toward (0 or 1).", ) with gr.Group(visible=False) as conditional_group: gr.Markdown( "**Conditional design** keeps the selected residues of a reference structure fixed " "while generating the rest. Protein length is taken from the uploaded PDB.", elem_classes=["fp-subnote"], ) with gr.Row(elem_classes=["fp-tight"]): reference_pdb = gr.File(label="Reference PDB (chain A)", file_types=[".pdb"]) fixed_residues_text = gr.Textbox( label="Fixed residues", value="", placeholder="e.g. 10-40,55,60-62 (empty = fix all)", ) use_classifier_guidance = gr.Checkbox( label="Apply classifier guidance", value=False, ) run_button = gr.Button("Run inference", variant="primary") with gr.Group(elem_classes=["fp-card"]): gr.Markdown("### Outputs") output_files = gr.Files(label="Generated sample.pdb files") trajectory_output_files = gr.Files(label="Optional trajectory artifacts (traj + x0)") with gr.Row(elem_classes=["fp-tight"]): zip_button = gr.Button("Bundle current run as zip", variant="secondary") zip_file = gr.File(label="Run bundle (.zip)") with gr.Tab("View"): with gr.Group(elem_classes=["fp-card"]): sample_selector = gr.Dropdown( label="Select generated sample(s)", choices=[], value=[], multiselect=True, interactive=False, info="The first selected sample is shown in 3D; all selected samples feed self-consistency.", ) selection_label = gr.Markdown(update_selection_label(None)) molstar_viewer = gr.HTML(value=_molstar_placeholder_html()) with gr.Accordion("Generation trajectory playback", open=False, elem_classes=["fp-card"]): gr.Markdown( "Step through the flow-matching trajectory of the selected sample. " "Backbone trajectory shows the integrated path; x0 trajectory shows the model's denoised prediction.", elem_classes=["fp-subnote"], ) with gr.Row(elem_classes=["fp-tight"]): traj_source = gr.Dropdown( label="Trajectory", choices=["Backbone trajectory", "x0 trajectory"], value="Backbone trajectory", ) load_traj_button = gr.Button("Load trajectory", variant="secondary") trajectory_viewer = gr.HTML( value=_molstar_placeholder_html(message="Select a sample and load a trajectory to play it back.") ) with gr.Accordion("Aligned compare (backbone vs ESMFold)", open=False, elem_classes=["fp-card"]): gr.Markdown( "Overlay a generated backbone with an ESMFold-predicted structure from self-consistency. " "Tip: clicking a row in the Analyze metrics table loads the overlay automatically.", elem_classes=["fp-subnote"], ) with gr.Row(elem_classes=["fp-tight"]): backbone_compare_selector = gr.Dropdown( label="Backbone sample for aligned compare", choices=[], value=None, interactive=False, ) folded_compare_selector = gr.Dropdown( label="Folded structure for compare", choices=[], value=None, interactive=False, ) load_compare_button = gr.Button("Load aligned compare view", variant="secondary") molstar_compare_status = gr.Textbox( label="3D compare status", lines=2, interactive=False, ) with gr.Tab("Analyze"): with gr.Group(elem_classes=["fp-card"]): gr.Markdown( "### Self-consistency (ProteinMPNN + ESMFold)\n" "Runs across the samples selected in the View tab.", ) with gr.Row(elem_classes=["fp-tight"]): sc_num_sequences = gr.Slider( label="ProteinMPNN sequences per sample", minimum=1, maximum=16, value=4, step=1, ) sc_run_folding = gr.Checkbox( label="Run ESMFold and compute scTM/scRMSD", value=True, ) sc_run_button = gr.Button("Run self-consistency", variant="primary") with gr.Accordion("View settings", open=False): with gr.Row(elem_classes=["fp-tight"]): sc_top_n = gr.Slider( label="Leaderboard top-N", minimum=1, maximum=10, value=SC_DEFAULT_TOP_N, step=1, ) sc_filter_samples = gr.Dropdown( label="Filter by sample label", choices=[], value=[], multiselect=True, interactive=False, ) sc_filter_sample_query = gr.Textbox( label="Sample text contains", value="", placeholder="Optional text search on sample labels", ) with gr.Row(elem_classes=["fp-tight"]): sc_filter_tm_min = gr.Slider( label="Minimum scTM", minimum=SC_TM_MIN, maximum=SC_TM_MAX, value=SC_TM_MIN, step=0.01, ) sc_filter_rmsd_max = gr.Slider( label="Maximum scRMSD", minimum=SC_RMSD_MIN, maximum=SC_RMSD_MAX, value=SC_RMSD_MAX, step=0.1, ) sc_filter_plddt_min = gr.Slider( label="Minimum ESMFold mean pLDDT", minimum=SC_PLDDT_MIN, maximum=SC_PLDDT_MAX, value=SC_PLDDT_MIN, step=0.01, ) with gr.Row(elem_classes=["fp-tight"]): sc_apply_view_settings_button = gr.Button("Apply view settings", variant="secondary") sc_reset_view_settings_button = gr.Button("Reset view settings") sc_status = gr.Textbox( label="Self-consistency status", lines=4, elem_classes=["fp-status"], ) sc_summary = gr.Dataframe( headers=SC_SUMMARY_HEADERS, datatype=["str", "str"], value=[], label="Self-consistency summary", interactive=False, wrap=True, ) sc_metrics = gr.Dataframe( headers=SC_METRICS_HEADERS, value=[], label="Per-sequence metrics table (click a row to compare in 3D)", interactive=False, wrap=True, ) with gr.Row(elem_classes=["fp-tight"]): sc_export_button = gr.Button("Export current metrics as CSV", variant="secondary") sc_export_file = gr.File(label="Filtered metrics (.csv)") with gr.Row(elem_classes=["fp-tight", "fp-plot-row"]): with gr.Column(scale=1, min_width=560): sc_tm_rmsd_plot = gr.HTML( value=_empty_scatter_plot_html( title="Self-consistency tradeoff (scTM vs scRMSD)", message="Run self-consistency to render plot data.", ) ) with gr.Column(scale=1, min_width=560): sc_confidence_plot = gr.HTML( value=_empty_scatter_plot_html( title="ESMFold confidence vs scTM", message="Run self-consistency with folding enabled to render confidence plot.", ) ) with gr.Row(elem_classes=["fp-tight"]): sc_top_tm = gr.Dataframe( headers=SC_TOP_TM_HEADERS, value=[], label="Top by highest scTM", interactive=False, wrap=True, ) sc_top_rmsd = gr.Dataframe( headers=SC_TOP_RMSD_HEADERS, value=[], label="Top by lowest scRMSD", interactive=False, wrap=True, ) sc_artifacts = gr.Files(label="Self-consistency artifacts") with gr.Tab("Advanced"): with gr.Group(elem_classes=["fp-card"]): gr.Markdown("### Recent runs") with gr.Row(elem_classes=["fp-tight"]): history_dropdown = gr.Dropdown( label="Run history", choices=_run_history_choices(), value=None, interactive=True, ) history_refresh_button = gr.Button("Refresh history") history_load_button = gr.Button("Load selected run", variant="secondary") with gr.Group(elem_classes=["fp-card"]): gr.Markdown("### Saved example case") with gr.Row(elem_classes=["fp-tight"]): save_example_button = gr.Button("Save selected sample as example", variant="secondary") load_example_button = gr.Button("Load saved example", variant="secondary") view_example_on_startup = gr.Checkbox( label="View saved example on startup", value=VIEW_EXAMPLE_ON_STARTUP_DEFAULT, ) with gr.Group(elem_classes=["fp-card"]): gr.Markdown("### Diagnostics") with gr.Row(elem_classes=["fp-tight"]): preload_button = gr.Button("Preload model") health_button = gr.Button("Refresh health") with gr.Accordion("Status and debug details", open=False): status = gr.Textbox( label="Run status", lines=4, elem_classes=["fp-status"], ) example_status = gr.Textbox( label="Example case status", lines=3, interactive=False, ) health = gr.JSON(label="Health check") generation_outputs = [ status, output_files, trajectory_output_files, health, sample_selector, molstar_viewer, sc_status, sc_summary, sc_metrics, sc_tm_rmsd_plot, sc_confidence_plot, sc_artifacts, sc_top_tm, sc_top_rmsd, sc_base_status_state, sc_summary_state, sc_metrics_state, sc_artifacts_state, ] example_load_outputs = [ status, output_files, trajectory_output_files, sample_selector, molstar_viewer, sc_status, sc_summary, sc_metrics, sc_tm_rmsd_plot, sc_confidence_plot, sc_artifacts, sc_top_tm, sc_top_rmsd, sc_base_status_state, sc_summary_state, sc_metrics_state, sc_artifacts_state, example_status, ] mode.change( fn=lambda selected: ( gr.update(visible=selected == "classifier"), gr.update(visible=selected == "conditional"), ), inputs=[mode], outputs=[classifier_group, conditional_group], ) run_event = run_button.click( fn=run_generation, inputs=[ mode, length, num_samples, seed_override, include_trajectory_artifacts, guidance_scale, target_class, num_timesteps, reference_pdb, fixed_residues_text, use_classifier_guidance, ], outputs=generation_outputs, ) run_event.then(fn=update_selection_label, inputs=[sample_selector], outputs=[selection_label]) run_event.then(fn=refresh_status_badge, inputs=None, outputs=[status_badge]) run_event.then(fn=refresh_run_history, inputs=None, outputs=[history_dropdown]) sample_selector.change( fn=update_selected_molstar, inputs=[sample_selector], outputs=[ molstar_viewer, sc_status, sc_summary, sc_metrics, sc_tm_rmsd_plot, sc_confidence_plot, sc_artifacts, sc_top_tm, sc_top_rmsd, sc_base_status_state, sc_summary_state, sc_metrics_state, sc_artifacts_state, ], ) sample_selector.change(fn=update_selection_label, inputs=[sample_selector], outputs=[selection_label]) sample_selector.change(fn=_update_backbone_selector, inputs=[sample_selector], outputs=[backbone_compare_selector]) sample_selector.change(fn=clear_compare_status, inputs=None, outputs=[molstar_compare_status]) load_traj_button.click( fn=load_trajectory_view, inputs=[sample_selector, traj_source], outputs=[trajectory_viewer], ) zip_button.click(fn=build_run_zip, inputs=[sample_selector], outputs=[zip_file]) save_example_button.click(fn=save_example_case, inputs=[sample_selector], outputs=[example_status]) load_example_button.click(fn=load_saved_example, inputs=None, outputs=example_load_outputs) demo_cta_button.click(fn=load_saved_example, inputs=None, outputs=example_load_outputs) load_example_button.click(fn=update_selection_label, inputs=[sample_selector], outputs=[selection_label]) demo_cta_button.click(fn=update_selection_label, inputs=[sample_selector], outputs=[selection_label]) sc_run_button.click( fn=run_self_consistency, inputs=[ sample_selector, sc_num_sequences, sc_run_folding, sc_top_n, ], outputs=[ sc_status, sc_summary, sc_metrics, sc_tm_rmsd_plot, sc_confidence_plot, sc_artifacts, sc_top_tm, sc_top_rmsd, sc_base_status_state, sc_summary_state, sc_metrics_state, sc_artifacts_state, ], ) sc_apply_view_settings_button.click( fn=apply_sc_view_settings, inputs=[ sc_base_status_state, sc_summary_state, sc_metrics_state, sc_artifacts_state, sc_top_n, sc_filter_samples, sc_filter_sample_query, sc_filter_tm_min, sc_filter_rmsd_max, sc_filter_plddt_min, ], outputs=[ sc_status, sc_summary, sc_metrics, sc_tm_rmsd_plot, sc_confidence_plot, sc_artifacts, sc_top_tm, sc_top_rmsd, ], ) sc_reset_view_settings_button.click( fn=reset_sc_view_settings, inputs=[ sc_base_status_state, sc_summary_state, sc_metrics_state, sc_artifacts_state, ], outputs=[ sc_top_n, sc_filter_samples, sc_filter_sample_query, sc_filter_tm_min, sc_filter_rmsd_max, sc_filter_plddt_min, sc_status, sc_summary, sc_metrics, sc_tm_rmsd_plot, sc_confidence_plot, sc_artifacts, sc_top_tm, sc_top_rmsd, ], ) sc_export_button.click( fn=export_sc_metrics_csv, inputs=[ sc_metrics_state, sc_filter_samples, sc_filter_sample_query, sc_filter_tm_min, sc_filter_rmsd_max, sc_filter_plddt_min, ], outputs=[sc_export_file], ) sc_metrics_state.change( fn=_update_sc_filter_sample_selector, inputs=[sc_metrics_state], outputs=[sc_filter_samples], ) sc_metrics_state.change( fn=_update_folded_selector, inputs=[sc_metrics_state], outputs=[folded_compare_selector], ) sc_metrics.select( fn=on_sc_metrics_select, inputs=[sc_metrics_state, backbone_compare_selector], outputs=[backbone_compare_selector, folded_compare_selector, molstar_viewer, molstar_compare_status], ) load_compare_button.click( fn=load_aligned_compare_view, inputs=[backbone_compare_selector, folded_compare_selector], outputs=[molstar_viewer, molstar_compare_status], ) history_refresh_button.click(fn=refresh_run_history, inputs=None, outputs=[history_dropdown]) history_event = history_load_button.click( fn=load_run_from_history, inputs=[history_dropdown], outputs=[ status, sample_selector, molstar_viewer, sc_status, sc_summary, sc_metrics, sc_tm_rmsd_plot, sc_confidence_plot, sc_artifacts, sc_top_tm, sc_top_rmsd, sc_base_status_state, sc_summary_state, sc_metrics_state, sc_artifacts_state, ], ) history_event.then(fn=update_selection_label, inputs=[sample_selector], outputs=[selection_label]) preload_button.click(fn=preload_model, inputs=None, outputs=[status, health]).then( fn=refresh_status_badge, inputs=None, outputs=[status_badge] ) health_button.click(fn=refresh_health, inputs=None, outputs=health).then( fn=refresh_status_badge, inputs=None, outputs=[status_badge] ) demo.load(fn=refresh_health, inputs=None, outputs=health) demo.load(fn=refresh_status_badge, inputs=None, outputs=[status_badge]) demo.load(fn=refresh_run_history, inputs=None, outputs=[history_dropdown]) demo.load( fn=maybe_load_saved_example_on_startup, inputs=[view_example_on_startup], outputs=example_load_outputs, ) if __name__ == "__main__": demo.queue(default_concurrency_limit=1, max_size=16) demo.launch( server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")), show_error=True, )