"""Gradio demo of the CardioSafe ion-channel safety model. The Space repo intentionally does not vendor the CardioSafe source. At startup we clone the upstream repos for CardioSafe and MolGpKa, then prepend them to sys.path. Cache survives across requests; only the first build pays the clone cost. """ from __future__ import annotations import os import subprocess import sys from pathlib import Path # --------------------------------------------------------------------------- # Boot: clone CardioSafe + MolGpKa, wire imports # --------------------------------------------------------------------------- CARDIOSAFE_REPO = "https://github.com/AppliedScientific/CardioSafe-benchmark.git" MOLGPKA_REPO = "https://github.com/Xundrug/MolGpKa.git" WORK = Path(os.environ.get("CARDIOSAFE_SPACE_WORK", Path.home() / "_cardiosafe")).resolve() CARDIOSAFE_DIR = WORK / "CardioSafe-benchmark" MOLGPKA_DIR = WORK / "MolGpKa" def _ensure_clone(url: str, dest: Path) -> None: """Clone the repo, or fast-forward an existing clone to origin/HEAD so the Space picks up upstream fixes without a full container rebuild.""" if dest.exists() and (dest / ".git").exists(): print(f" pulling {dest}") subprocess.run(["git", "-C", str(dest), "fetch", "--depth", "1", "origin"], check=False) subprocess.run(["git", "-C", str(dest), "reset", "--hard", "origin/HEAD"], check=False) return dest.parent.mkdir(parents=True, exist_ok=True) print(f" cloning {url} -> {dest}") subprocess.run(["git", "clone", "--depth", "1", url, str(dest)], check=True) print("Bootstrapping CardioSafe...") _ensure_clone(CARDIOSAFE_REPO, CARDIOSAFE_DIR) _ensure_clone(MOLGPKA_REPO, MOLGPKA_DIR) os.environ["MOLGPKA_SRC"] = str(MOLGPKA_DIR / "src") sys.path.insert(0, str(CARDIOSAFE_DIR)) def _patch_molgpka_smarts_path() -> None: """MolGpKa's utils/ionization_group.py computes `root = abspath(dirname(__file__))` at module load. On HF Spaces, __file__ resolves cwd-relative for sys.path imports, so `root` and `smarts_file` end up rooted at `/app/utils/` (which doesn't exist) and every pKa prediction silently falls back to sentinel values. We force-set them to the absolute MolGpKa path after the module loads. `get_ionization_aid` reads `smarts_file` from the module global on every call, so the patch sticks.""" import inference.featurize as _fz # noqa _fz._setup_molgpka() # triggers utils.* imports under MolGpKa/src import utils.ionization_group as _uig correct_root = str((MOLGPKA_DIR / "src" / "utils").resolve()) _uig.root = correct_root _uig.smarts_file = str(Path(correct_root) / "smarts_pattern.tsv") print(f" patched MolGpKa smarts_file -> {_uig.smarts_file}") _patch_molgpka_smarts_path() # --------------------------------------------------------------------------- # Heavy imports (after sys.path is set) # --------------------------------------------------------------------------- import gradio as gr import pandas as pd import torch from inference.ensemble import load_ensemble, load_l1000_encoder, predict # noqa: E402 from inference.featurize import featurize_batch # noqa: E402 from model.chemberta_encoder import ChemBERTaEncoder # noqa: E402 from model.cross_attn import ALL_HEADS # noqa: E402 DEVICE = torch.device("cpu") MAX_SMILES = 50 HEAD_LABELS = { "herg_pchembl": "hERG pIC50", "herg_blocker_10um": "hERG blocker CO (10 µM)", "herg_blocker_1um": "hERG blocker CO (1 µM)", "nav15_pchembl": "Nav1.5 pIC50", "nav15_blocker": "Nav1.5 blocker CO", "cav12_pchembl": "Cav1.2 pIC50", "cav12_blocker": "Cav1.2 blocker CO", "iks_blocker": "IKs blocker CO", } # --------------------------------------------------------------------------- # Eager-load v1.1 ensemble at boot; v1.0 lazy-loaded on first request # --------------------------------------------------------------------------- print("Loading ChemBERTa-77M-MTR...") CHEMBERTA = ChemBERTaEncoder() print("Loading L1000 encoder...") L1000 = load_l1000_encoder(device=DEVICE) print("Loading CardioSafe v1.1 ensemble (5 seeds)...") ENSEMBLES: dict[str, list] = {"v1.1": load_ensemble("v1.1", device=DEVICE)} print("Ready.") def _get_ensemble(version: str): if version not in ENSEMBLES: print(f"Loading CardioSafe {version} ensemble (5 seeds)...") ENSEMBLES[version] = load_ensemble(version, device=DEVICE) return ENSEMBLES[version] def run(smiles_text: str, version: str) -> pd.DataFrame: smiles = [s.strip() for s in smiles_text.splitlines() if s.strip()] if not smiles: raise gr.Error("Enter at least one SMILES.") if len(smiles) > MAX_SMILES: raise gr.Error(f"Max {MAX_SMILES} SMILES per request. Got {len(smiles)}.") ensemble = _get_ensemble(version) batch = featurize_batch(smiles, chemberta_encoder=CHEMBERTA, l1000_encoder=L1000) preds = predict(batch, ensemble=ensemble, device=DEVICE) rows = [] for i, smi in enumerate(smiles): row: dict = {"SMILES": smi} for h in ALL_HEADS: row[HEAD_LABELS[h]] = round(float(preds[h][i]), 3) rows.append(row) return pd.DataFrame(rows) EXAMPLE_INPUT = """CC(C)(C)c1ccc(cc1)C(O)CCCN2CCC(CC2)C(O)(c3ccccc3)c4ccccc4 CC(C)(C(=O)O)c1ccc(cc1)C(O)CCCN2CCC(CC2)C(O)(c3ccccc3)c4ccccc4 COc1ccc(CCN2CCC(CC2)Nc3nc4ccccc4n3Cc5ccc(F)cc5)cc1 COc1cc(N)c(Cl)cc1C(=O)N[C@@H]1CCN(CCCOc2ccc(F)cc2)C[C@@H]1OC CN(CCOc1ccc(NS(=O)(=O)C)cc1)CCc2ccc(NS(=O)(=O)C)cc2 C=C[C@H]1CN2CC[C@H]1C[C@@H]2[C@@H](O)c1ccnc2ccc(OC)cc12 COc1ccc(CCN(C)CCCC(C#N)(c2ccc(OC)c(OC)c2)C(C)C)cc1OC""" INTRO_MD = """# CardioSafe — cardiac ion-channel safety predictions Paste SMILES below (one per line, up to 50) and get predictions for the four CiPA channels: **hERG, Nav1.5, Cav1.2, IKs** — blocker classification output (CO; sigmoid in [0, 1], not a calibrated probability — the underlying classes are heavily imbalanced), plus pIC50 for hERG / Nav1.5 / Cav1.2 (IKs has no regression head — n = 115 labelled compounds). This is the paper-snapshot ensemble from [Jovanović et al. 2026 (bioRxiv)](https://www.biorxiv.org/content/10.64898/2026.05.06.723181v1). Weights on [HF](https://huggingface.co/appliedscientific/cardiosafe), source on [GitHub](https://github.com/AppliedScientific/CardioSafe-benchmark). The continually-updated production model is at [platform.appliedscientific.ai/cardiosafe](https://platform.appliedscientific.ai/cardiosafe). """ FOOTER_MD = """--- **v1.1** is the recommended retrain; it differs from v1.0 by 2 force-routed analogs in the cardiac-cliff cluster (see [Note S3](https://github.com/AppliedScientific/CardioSafe-benchmark/blob/main/data/supplementary/note_s3_v1_1_audit_correction.md)). Test fold and headline metrics are unchanged. **v1.0** is the preprint snapshot — use it when reproducing paper numbers. Per-checkpoint normalization, MolGpKa-based pKa descriptors, ChemBERTa-77M-MTR embeddings, and a learned L1000 expression encoder are all applied automatically. First request after a cold start may take ~30 s while ChemBERTa is downloaded. """ with gr.Blocks(title="CardioSafe", theme=gr.themes.Soft()) as demo: gr.Markdown(INTRO_MD) with gr.Row(): with gr.Column(scale=2): smiles_in = gr.Textbox( label="SMILES (one per line)", value=EXAMPLE_INPUT, lines=10, ) version_in = gr.Radio( ["v1.1", "v1.0"], value="v1.1", label="Ensemble (v1.1 recommended)", ) btn = gr.Button("Predict", variant="primary") with gr.Column(scale=3): out = gr.DataFrame(label="Predictions", interactive=False, wrap=True) gr.Markdown(FOOTER_MD) btn.click(run, inputs=[smiles_in, version_in], outputs=out) if __name__ == "__main__": demo.launch()