# TwinBooster / app.py
# Author: Maximilian Schuh
# Commit: 74d2a42 ("changed deps")
import os
import tarfile
import tempfile
from pathlib import Path
from typing import List, Tuple
import gradio as gr
import pandas as pd
import plotly.express as px
from huggingface_hub import snapshot_download
import spaces
import twinbooster
# Lazily-initialized TwinBooster model singleton; populated by get_model().
tb = None
# Hugging Face repo with the fine-tuned text encoder; overridable via env var.
HF_TEXT_REPO = os.environ.get("HF_TEXT_REPO", "mschuh/PubChemDeBERTa-augmented")
# Store models where TwinBooster looks for them by default (~/.cache/twinbooster)
MODEL_DIR = Path.home() / ".cache" / "twinbooster"
# Bundled weight archives shipped next to this file (lgbm_model.tar.xz etc.).
WEIGHTS_SRC = Path(__file__).parent / "weights"
def ensure_models() -> str:
    """Make sure all model weights are present in the TwinBooster cache.

    Downloads the text model from the Hugging Face Hub, unpacks the bundled
    LGBM / Barlow Twins tar archives shipped alongside this app, and finally
    defers to ``twinbooster.download_models()`` for anything still missing.

    Returns:
        A human-readable status message listing what was actually fetched.
    """
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    # Point the HF hub cache at the directory TwinBooster reads from.
    os.environ.setdefault("HF_HOME", str(MODEL_DIR))
    os.environ.setdefault("HF_HUB_CACHE", str(MODEL_DIR))
    downloaded = []

    def grab(repo_id: str, subdir: str) -> bool:
        """Snapshot-download *repo_id* into MODEL_DIR/subdir; True if fetched."""
        local_dir = MODEL_DIR / subdir
        if local_dir.exists():
            return False
        try:
            snapshot_download(
                repo_id=repo_id,
                local_dir=str(local_dir),
                local_dir_use_symlinks=False,
            )
            return True
        except Exception:
            # Best effort: twinbooster.download_models() below can still
            # resolve anything we failed to fetch here.
            return False

    def extract_local(archive: Path, subdir: str) -> bool:
        """Unpack a bundled tar archive into MODEL_DIR/subdir; True if extracted."""
        target = MODEL_DIR / subdir
        if target.exists() or not archive.exists():
            return False
        target.mkdir(parents=True, exist_ok=True)
        with tarfile.open(archive, "r:*") as tf:
            try:
                # "data" filter rejects path-traversal members (Python >= 3.12,
                # also backported to recent 3.8-3.11 patch releases).
                tf.extractall(target, filter="data")
            except TypeError:
                # Older Python without the filter= parameter.
                tf.extractall(target)
        return True

    if grab(HF_TEXT_REPO, "PubChemDeBERTa-augmented"):
        downloaded.append(HF_TEXT_REPO)
    # Report the local archives only when they were actually unpacked; the
    # previous version appended unconditionally, producing a misleading
    # "Downloaded ..." status even when everything was already cached.
    if extract_local(WEIGHTS_SRC / "lgbm_model.tar.xz", "lgbm_model"):
        downloaded.append("local lgbm_model.tar.xz")
    if extract_local(WEIGHTS_SRC / "bt_model.tar.xz", "bt_model"):
        downloaded.append("local bt_model.tar.xz")
    # Ensure any missing pieces are resolved via package helper (will skip if already present)
    twinbooster.download_models()
    if downloaded:
        return "Downloaded from Hugging Face: " + ", ".join(downloaded)
    return "Models already present in cache."
def get_model():
    """Return the process-wide TwinBooster instance, building it on first use."""
    global tb
    if tb is not None:
        return tb
    # First call: make sure weights exist before constructing the model.
    ensure_models()
    tb = twinbooster.TwinBooster()
    return tb
def parse_smiles(smiles_text: str) -> List[str]:
    """Split textbox input into a list of SMILES, one per non-blank line."""
    parsed: List[str] = []
    for raw_line in smiles_text.splitlines():
        candidate = raw_line.strip()
        if candidate:
            parsed.append(candidate)
    if not parsed:
        raise gr.Error("Please provide at least one SMILES (one per line).")
    return parsed
@spaces.GPU(duration=120)
def _gpu_predict(smiles: List[str], assay: str) -> Tuple[pd.DataFrame, object, str, str]:
    """GPU-only path: expects prevalidated inputs and ready models.

    Args:
        smiles: Non-empty list of SMILES strings.
        assay: Free-text bioassay description.

    Returns:
        Tuple of (results DataFrame, plotly figure, CSV path, XLSX path).

    Raises:
        gr.Error: If inference fails for any reason.
    """
    model = get_model()
    try:
        preds, confs = model.predict(smiles, assay, get_confidence=True)
    except TypeError:
        # Older TwinBooster versions lack the get_confidence keyword.
        preds = model.predict(smiles, assay)
        confs = [None] * len(preds)
    except Exception as exc:  # pragma: no cover - shown to user
        # Chain the original exception so the traceback stays debuggable.
        raise gr.Error(f"Inference failed: {exc}") from exc
    df = pd.DataFrame(
        {
            "SMILES": smiles,
            "Assay": assay,
            "Prediction": preds,
            "Confidence": confs,
        }
    )
    fig = px.bar(
        df,
        x="SMILES",
        y="Prediction",
        color="Prediction",
        color_continuous_scale="Blues",
        range_y=[0, 1],
        title="TwinBooster predictions",
        labels={"Prediction": "Predicted activity probability"},
    )
    fig.update_layout(xaxis_tickangle=-45, height=420)
    # Use mkstemp and close the OS handle before pandas reopens the path:
    # writing to a NamedTemporaryFile's .name while the handle is still open
    # fails on Windows, where the open handle locks the file.
    csv_fd, csv_path = tempfile.mkstemp(suffix=".csv")
    os.close(csv_fd)
    df.to_csv(csv_path, index=False)
    xlsx_fd, xlsx_path = tempfile.mkstemp(suffix=".xlsx")
    os.close(xlsx_fd)
    with pd.ExcelWriter(xlsx_path, engine="openpyxl") as writer:
        df.to_excel(writer, index=False)
    return df, fig, csv_path, xlsx_path
def run_prediction(smiles_text: str, assay_text: str) -> Tuple[pd.DataFrame, object, str, str]:
    """Validate inputs on CPU and warm the model cache, then run the GPU path."""
    cleaned_assay = assay_text.strip()
    if not cleaned_assay:
        raise gr.Error("Please provide a bioassay description.")
    smiles_list = parse_smiles(smiles_text)
    # Fetch/refresh weights before the ZeroGPU window opens, keeping the
    # GPU session as short as possible.
    ensure_models()
    return _gpu_predict(smiles_list, cleaned_assay)
def build_demo() -> gr.Blocks:
    """Assemble and return the Gradio Blocks UI for the predictor."""
    # Prefilled demo inputs so the app works on first click.
    default_smiles = "CC1=CC=C(C=C1)C2=CC(=NC3=NC=NC(=C23)N)C4=CC=C(C=C4)F\nCC(=O)C1=CC=C(C=C1)NC(=O)C2=CC3=C(C=C2)N=C(C(=N3)C4=CC=CO4)C5=CC=CO5\nCC1=C(C=C(C=C1)Cl)NC2=C/C(=N\S(=O)(=O)C3=CC=CS3)/C4=CC=CC=C4C2=O\nCC(C)C1=NC2=CC=CC=C2C(=N1)SCC(=O)N3CCCC3 "
    default_assay = "TR-FRET counterscreen for FAK inhibitors: dose-response biochemical high throughput screening assay to identify inhibitors of Proline-rich tyrosine kinase 2 (Pyk2)"
    with gr.Blocks(title="TwinBooster") as demo:
        # Header and citation.
        gr.Markdown(
            "# TwinBooster zero-shot predictor\n"
            "Enter SMILES (one per line) and a bioassay description to obtain activity predictions."
        )
        gr.Markdown(
            "TwinBooster fuses chemical structures and free-text assay descriptions to deliver SOTA zero-shot activity "
            "predictions—useful for early triage and library prioritization when assay data are scarce. "
            "Outputs include a table, bar chart, and CSV/XLSX downloads with predictions and confidence.\n\n"
            "**Reference:** Schuh, M. G.; Boldini, D.; Sieber, S. A. "
            "_Synergizing Chemical Structures and Bioassay Descriptions for Enhanced Molecular Property Prediction in Drug Discovery._ "
            "J. Chem. Inf. Model. 2024, 64, 12, 4640–4650. "
            "[JCIM paper](https://doi.org/10.1021/acs.jcim.4c00765)"
        )
        # Input widgets: molecules on the left, assay text on the right.
        with gr.Row():
            smiles_input = gr.Textbox(
                label="SMILES list",
                lines=10,
                value=default_smiles,
                placeholder="One SMILES per line",
            )
            assay_input = gr.Textbox(
                label="Bioassay description",
                lines=8,
                value=default_assay,
                placeholder="Describe the assay/task to predict.",
            )
        # Action buttons.
        with gr.Row():
            run_button = gr.Button("Run prediction", variant="primary")
            refresh_button = gr.Button("Download / refresh models")
        status_md = gr.Markdown("")
        # Output widgets: table, chart, and file downloads.
        results_table = gr.DataFrame(
            label="Predictions",
            headers=["SMILES", "Assay", "Prediction", "Confidence"],
            datatype=["str", "str", "number", "number"],
            interactive=False,
        )
        results_plot = gr.Plot(label="Prediction chart")
        with gr.Row():
            csv_file_out = gr.File(label="CSV download")
            xlsx_file_out = gr.File(label="Excel download")
        # Wire up the callbacks.
        run_button.click(
            run_prediction,
            inputs=[smiles_input, assay_input],
            outputs=[results_table, results_plot, csv_file_out, xlsx_file_out],
        )
        refresh_button.click(ensure_models, outputs=status_md)
    return demo
if __name__ == "__main__":
    demo = build_demo()
    # Gradio 4.x renamed queue(concurrency_count=...) to
    # default_concurrency_limit; the `spaces` (ZeroGPU) integration requires
    # Gradio 4, where the old keyword raises TypeError. Try the modern name
    # first and fall back for Gradio 3.x installs.
    try:
        demo.queue(default_concurrency_limit=1)
    except TypeError:
        demo.queue(concurrency_count=1)
    demo.launch()