import os
import tarfile
import tempfile
from pathlib import Path
from typing import List, Tuple

import gradio as gr
import pandas as pd
import plotly.express as px
from huggingface_hub import snapshot_download

import spaces
import twinbooster

# Lazily-initialized TwinBooster model singleton (populated by get_model()).
tb = None

HF_TEXT_REPO = os.environ.get("HF_TEXT_REPO", "mschuh/PubChemDeBERTa-augmented")
# Store models where TwinBooster looks for them by default (~/.cache/twinbooster)
MODEL_DIR = Path.home() / ".cache" / "twinbooster"
WEIGHTS_SRC = Path(__file__).parent / "weights"


def ensure_models() -> str:
    """Download weights from Hugging Face LFS; fallback to twinbooster if already cached.

    Idempotent: every step skips work whose output directory already exists.

    Returns:
        A human-readable status message listing what was actually fetched/extracted.
    """
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    os.environ.setdefault("HF_HOME", str(MODEL_DIR))
    os.environ.setdefault("HF_HUB_CACHE", str(MODEL_DIR))

    downloaded: List[str] = []

    def grab(repo_id: str, subdir: str) -> bool:
        """Best-effort snapshot download into MODEL_DIR/subdir.

        Returns True only when a fresh download happened; False when already
        cached or when the download failed (twinbooster.download_models()
        below can still fill any gaps, so failures are non-fatal).
        """
        local_dir = MODEL_DIR / subdir
        if local_dir.exists():
            return False
        try:
            # NOTE: local_dir_use_symlinks is deprecated and ignored by recent
            # huggingface_hub; local_dir now always materializes real files.
            snapshot_download(repo_id=repo_id, local_dir=str(local_dir))
            return True
        except Exception:
            return False

    def extract_local(archive: Path, subdir: str) -> bool:
        """Extract a bundled tar archive into MODEL_DIR/subdir if not present.

        Returns True when an extraction actually took place.

        Raises:
            RuntimeError: if the archive contains a path-traversal entry.
        """
        target = MODEL_DIR / subdir
        if target.exists() or not archive.exists():
            return False
        target.mkdir(parents=True, exist_ok=True)
        with tarfile.open(archive, "r:*") as tf:
            # Reject members that would escape the target directory
            # (CVE-2007-4559-style "../" or absolute-path entries).
            base = target.resolve()
            for member in tf.getmembers():
                if not (base / member.name).resolve().is_relative_to(base):
                    raise RuntimeError(
                        f"Unsafe path in archive {archive.name}: {member.name}"
                    )
            tf.extractall(target)
        return True

    if grab(HF_TEXT_REPO, "PubChemDeBERTa-augmented"):
        downloaded.append(HF_TEXT_REPO)
    # Only report the local archives when they were actually extracted;
    # unconditional appends made the "already present" message unreachable.
    if extract_local(WEIGHTS_SRC / "lgbm_model.tar.xz", "lgbm_model"):
        downloaded.append("local lgbm_model.tar.xz")
    if extract_local(WEIGHTS_SRC / "bt_model.tar.xz", "bt_model"):
        downloaded.append("local bt_model.tar.xz")

    # Ensure any missing pieces are resolved via package helper (will skip if already present)
    twinbooster.download_models()

    if downloaded:
        return "Downloaded from Hugging Face: " + ", ".join(downloaded)
    return "Models already present in cache."
def get_model(): """Lazily load the TwinBooster model.""" global tb if tb is None: ensure_models() tb = twinbooster.TwinBooster() return tb def parse_smiles(smiles_text: str) -> List[str]: smiles = [line.strip() for line in smiles_text.splitlines() if line.strip()] if not smiles: raise gr.Error("Please provide at least one SMILES (one per line).") return smiles @spaces.GPU(duration=120) def _gpu_predict(smiles: List[str], assay: str) -> Tuple[pd.DataFrame, object, str, str]: """GPU-only path: expects prevalidated inputs and ready models.""" model = get_model() try: preds, confs = model.predict(smiles, assay, get_confidence=True) except TypeError: preds = model.predict(smiles, assay) confs = [None] * len(preds) except Exception as exc: # pragma: no cover - shown to user raise gr.Error(f"Inference failed: {exc}") df = pd.DataFrame( { "SMILES": smiles, "Assay": assay, "Prediction": preds, "Confidence": confs, } ) fig = px.bar( df, x="SMILES", y="Prediction", color="Prediction", color_continuous_scale="Blues", range_y=[0, 1], title="TwinBooster predictions", labels={"Prediction": "Predicted activity probability"}, ) fig.update_layout(xaxis_tickangle=-45, height=420) csv_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv") df.to_csv(csv_file.name, index=False) csv_file.close() xlsx_file = tempfile.NamedTemporaryFile(delete=False, suffix=".xlsx") with pd.ExcelWriter(xlsx_file.name, engine="openpyxl") as writer: df.to_excel(writer, index=False) xlsx_file.close() return df, fig, csv_file.name, xlsx_file.name def run_prediction(smiles_text: str, assay_text: str) -> Tuple[pd.DataFrame, object, str, str]: """CPU wrapper: validates inputs and prepares models before GPU allocation.""" assay = assay_text.strip() if not assay: raise gr.Error("Please provide a bioassay description.") smiles = parse_smiles(smiles_text) # Download/refresh weights on CPU to keep ZeroGPU sessions short ensure_models() return _gpu_predict(smiles, assay) def build_demo() -> gr.Blocks: 
example_smiles = "CC1=CC=C(C=C1)C2=CC(=NC3=NC=NC(=C23)N)C4=CC=C(C=C4)F\nCC(=O)C1=CC=C(C=C1)NC(=O)C2=CC3=C(C=C2)N=C(C(=N3)C4=CC=CO4)C5=CC=CO5\nCC1=C(C=C(C=C1)Cl)NC2=C/C(=N\S(=O)(=O)C3=CC=CS3)/C4=CC=CC=C4C2=O\nCC(C)C1=NC2=CC=CC=C2C(=N1)SCC(=O)N3CCCC3 " example_assay = "TR-FRET counterscreen for FAK inhibitors: dose-response biochemical high throughput screening assay to identify inhibitors of Proline-rich tyrosine kinase 2 (Pyk2)" with gr.Blocks(title="TwinBooster") as demo: gr.Markdown( "# TwinBooster zero-shot predictor\n" "Enter SMILES (one per line) and a bioassay description to obtain activity predictions." ) gr.Markdown( "TwinBooster fuses chemical structures and free-text assay descriptions to deliver SOTA zero-shot activity " "predictions—useful for early triage and library prioritization when assay data are scarce. " "Outputs include a table, bar chart, and CSV/XLSX downloads with predictions and confidence.\n\n" "**Reference:** Schuh, M. G.; Boldini, D.; Sieber, S. A. " "_Synergizing Chemical Structures and Bioassay Descriptions for Enhanced Molecular Property Prediction in Drug Discovery._ " "J. Chem. Inf. Model. 2024, 64, 12, 4640–4650. 
" "[JCIM paper](https://doi.org/10.1021/acs.jcim.4c00765)" ) with gr.Row(): smiles_box = gr.Textbox( label="SMILES list", lines=10, value=example_smiles, placeholder="One SMILES per line", ) assay_box = gr.Textbox( label="Bioassay description", lines=8, value=example_assay, placeholder="Describe the assay/task to predict.", ) with gr.Row(): predict_btn = gr.Button("Run prediction", variant="primary") download_btn = gr.Button("Download / refresh models") status = gr.Markdown("") table = gr.DataFrame( label="Predictions", headers=["SMILES", "Assay", "Prediction", "Confidence"], datatype=["str", "str", "number", "number"], interactive=False, ) plot = gr.Plot(label="Prediction chart") with gr.Row(): csv_out = gr.File(label="CSV download") xlsx_out = gr.File(label="Excel download") predict_btn.click( run_prediction, inputs=[smiles_box, assay_box], outputs=[table, plot, csv_out, xlsx_out], ) download_btn.click(ensure_models, outputs=status) return demo if __name__ == "__main__": demo = build_demo() demo.queue(concurrency_count=1) demo.launch()