Spaces:
Runtime error
Runtime error
File size: 7,238 Bytes
942392b e174389 7a91b3c 942392b bed36cc 942392b e174389 942392b 3df1a88 942392b 5468e06 e174389 5468e06 e174389 942392b e174389 5468e06 e174389 5468e06 942392b 74d2a42 942392b 74d2a42 942392b 5468e06 942392b e19f3e1 5468e06 e19f3e1 5468e06 e19f3e1 942392b bed36cc 942392b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 | import os
import tarfile
import tempfile
from pathlib import Path
from typing import List, Tuple
import gradio as gr
import pandas as pd
import plotly.express as px
from huggingface_hub import snapshot_download
import spaces
import twinbooster
# Cached TwinBooster instance; created lazily on first use by get_model().
tb = None
# Hugging Face repo for the text-model weights (overridable via the environment).
HF_TEXT_REPO = os.environ.get("HF_TEXT_REPO", "mschuh/PubChemDeBERTa-augmented")
# Store models where TwinBooster looks for them by default (~/.cache/twinbooster)
MODEL_DIR = Path.home().joinpath(".cache", "twinbooster")
# Weight archives bundled next to this script.
WEIGHTS_SRC = Path(__file__).parent.joinpath("weights")
def ensure_models() -> str:
    """Make sure all TwinBooster weights are present under ``MODEL_DIR``.

    Each step is skipped when its target directory already exists:

    1. snapshot ``HF_TEXT_REPO`` from the Hugging Face Hub into
       ``MODEL_DIR/PubChemDeBERTa-augmented`` (best effort: failures are
       logged and left to step 3);
    2. unpack ``lgbm_model.tar.xz`` and ``bt_model.tar.xz`` from
       ``WEIGHTS_SRC`` into ``MODEL_DIR/lgbm_model`` and ``MODEL_DIR/bt_model``;
    3. call ``twinbooster.download_models()`` for anything still missing.

    Returns:
        Status text describing what this call fetched or unpacked, or
        ``"Models already present in cache."`` when nothing was needed.

    Raises:
        tarfile.TarError, OSError: if a bundled archive cannot be extracted.
    """
    import logging
    import shutil

    log = logging.getLogger(__name__)

    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    # NOTE(review): huggingface_hub typically reads these variables when it is
    # imported (top of file), so setting them here may not redirect its cache.
    os.environ.setdefault("HF_HOME", str(MODEL_DIR))
    os.environ.setdefault("HF_HUB_CACHE", str(MODEL_DIR))

    def grab(repo_id: str, subdir: str) -> bool:
        """Snapshot ``repo_id`` into ``MODEL_DIR/subdir``; True if fetched now."""
        local_dir = MODEL_DIR / subdir
        if local_dir.exists():
            return False
        try:
            snapshot_download(
                repo_id=repo_id,
                local_dir=str(local_dir),
                local_dir_use_symlinks=False,
            )
        except Exception:
            # Best effort: twinbooster.download_models() below is the fallback.
            # Remove any partial snapshot so the next call retries instead of
            # mistaking a half-written directory for a complete download.
            log.warning(
                "Could not download %s; relying on twinbooster.download_models()",
                repo_id,
                exc_info=True,
            )
            shutil.rmtree(local_dir, ignore_errors=True)
            return False
        return True

    def extract_local(archive: Path, subdir: str) -> bool:
        """Unpack ``archive`` into ``MODEL_DIR/subdir``; True if unpacked now."""
        target = MODEL_DIR / subdir
        if target.exists() or not archive.exists():
            return False
        # Extract into a scratch directory and rename it into place only on
        # success, so an interrupted extraction never leaves a partial
        # ``target`` that later calls would accept as complete.
        staging = Path(tempfile.mkdtemp(prefix=f".{subdir}-", dir=MODEL_DIR))
        try:
            with tarfile.open(archive, "r:*") as tf:
                if hasattr(tarfile, "data_filter"):
                    # PEP 706 filter: refuse members that escape the destination.
                    tf.extractall(staging, filter="data")
                else:
                    tf.extractall(staging)
            staging.rename(target)
        except BaseException:
            shutil.rmtree(staging, ignore_errors=True)
            raise
        return True

    fetched: List[str] = []
    if grab(HF_TEXT_REPO, "PubChemDeBERTa-augmented"):
        fetched.append(HF_TEXT_REPO)

    extracted: List[str] = []
    for archive_name, subdir in (
        ("lgbm_model.tar.xz", "lgbm_model"),
        ("bt_model.tar.xz", "bt_model"),
    ):
        # Report an archive only when this call actually unpacked it.
        if extract_local(WEIGHTS_SRC / archive_name, subdir):
            extracted.append(f"local {archive_name}")

    # Ensure any missing pieces are resolved via the package helper
    # (presumably it skips what is already present — TODO confirm).
    twinbooster.download_models()

    status: List[str] = []
    if fetched:
        status.append("Downloaded from Hugging Face: " + ", ".join(fetched))
    if extracted:
        status.append("Extracted: " + ", ".join(extracted))
    return "; ".join(status) if status else "Models already present in cache."
def get_model():
    """Return the shared TwinBooster instance, building it on the first call.

    On first use the weights are ensured via ``ensure_models()`` before
    ``twinbooster.TwinBooster()`` is constructed; the instance is cached in
    the module-level ``tb`` and reused by later calls.
    """
    global tb
    if tb is not None:
        return tb
    ensure_models()
    tb = twinbooster.TwinBooster()
    return tb
def parse_smiles(smiles_text: str) -> List[str]:
    """Split ``smiles_text`` into SMILES strings, one per non-blank line.

    Each line is stripped of surrounding whitespace; blank lines are dropped.

    Raises:
        gr.Error: if no non-blank line remains.
    """
    stripped_lines = (raw.strip() for raw in smiles_text.splitlines())
    entries = list(filter(None, stripped_lines))
    if entries:
        return entries
    raise gr.Error("Please provide at least one SMILES (one per line).")
@spaces.GPU(duration=120)
def _gpu_predict(smiles: List[str], assay: str) -> Tuple[pd.DataFrame, object, str, str]:
    """GPU-only path: run TwinBooster and build the UI outputs.

    Expects prevalidated inputs (non-empty ``smiles``, non-blank ``assay``);
    weights are normally fetched beforehand by the CPU-side caller so the
    ``spaces.GPU`` allocation is not spent on downloads.

    Args:
        smiles: SMILES strings to score, one output row each.
        assay: Free-text bioassay description shared by every row.

    Returns:
        ``(df, fig, csv_path, xlsx_path)``: the predictions table, a plotly
        bar chart of it, and paths to CSV and XLSX copies of the table.

    Raises:
        gr.Error: if inference fails, including in the fallback call.
    """
    model = get_model()
    preds, confs = _predict_with_confidence(model, smiles, assay)
    df = pd.DataFrame(
        {
            "SMILES": smiles,
            "Assay": assay,
            "Prediction": preds,
            "Confidence": confs,
        }
    )
    fig = _prediction_figure(df)
    csv_path, xlsx_path = _export_predictions(df)
    return df, fig, csv_path, xlsx_path


def _predict_with_confidence(model, smiles: List[str], assay: str) -> Tuple[object, object]:
    """Return ``(preds, confs)``; confidences are ``None`` if unsupported."""
    try:
        try:
            preds, confs = model.predict(smiles, assay, get_confidence=True)
        except TypeError:
            # Presumably a predict() without ``get_confidence``: retry without it.
            preds = model.predict(smiles, assay)
            confs = [None] * len(preds)
    except Exception as exc:  # pragma: no cover - shown to user
        # Outer handler so failures in the fallback call are reported as well;
        # a sibling ``except`` clause would not catch them.
        raise gr.Error(f"Inference failed: {exc}") from exc
    return preds, confs


def _prediction_figure(df: pd.DataFrame) -> object:
    """Build a bar chart of the ``Prediction`` column per SMILES."""
    fig = px.bar(
        df,
        x="SMILES",
        y="Prediction",
        color="Prediction",
        color_continuous_scale="Blues",
        range_y=[0, 1],
        title="TwinBooster predictions",
        labels={"Prediction": "Predicted activity probability"},
    )
    fig.update_layout(xaxis_tickangle=-45, height=420)
    return fig


def _export_predictions(df: pd.DataFrame) -> Tuple[str, str]:
    """Write ``df`` to new temporary CSV and XLSX files and return their paths.

    The paths are returned for download, so the files are intentionally not
    deleted here.
    """
    paths = []
    for suffix in (".csv", ".xlsx"):
        # Create the file, then close our descriptor before pandas reopens the
        # path, so no second handle to the same file is held open.
        fd, path = tempfile.mkstemp(suffix=suffix)
        os.close(fd)
        paths.append(path)
    csv_path, xlsx_path = paths
    df.to_csv(csv_path, index=False)
    with pd.ExcelWriter(xlsx_path, engine="openpyxl") as writer:
        df.to_excel(writer, index=False)
    return csv_path, xlsx_path
def run_prediction(smiles_text: str, assay_text: str) -> Tuple[pd.DataFrame, object, str, str]:
    """Validate the UI inputs on CPU, then hand them to the GPU worker.

    The assay text is checked first, then the SMILES list. Weights are fetched
    here, before ``_gpu_predict`` is invoked, so the GPU session stays short.

    Raises:
        gr.Error: if the assay description is blank or no SMILES were given.
    """
    cleaned_assay = assay_text.strip()
    if cleaned_assay == "":
        raise gr.Error("Please provide a bioassay description.")
    molecules = parse_smiles(smiles_text)

    # CPU-side download/refresh keeps the ZeroGPU allocation short.
    ensure_models()
    return _gpu_predict(molecules, cleaned_assay)
def build_demo() -> gr.Blocks:
    """Build the Gradio UI: inputs, prediction outputs and a model-download button.

    "Run prediction" calls ``run_prediction`` and fills the table, chart and
    CSV/XLSX download slots; "Download / refresh models" calls
    ``ensure_models`` and writes its status text into a Markdown field.
    """
    # One SMILES per line. The third entry contains a backslash that is part of
    # the SMILES itself (a directional bond), so it is written as a raw string
    # instead of the former invalid "\S" escape; the runtime value is unchanged.
    example_smiles = "\n".join(
        [
            "CC1=CC=C(C=C1)C2=CC(=NC3=NC=NC(=C23)N)C4=CC=C(C=C4)F",
            "CC(=O)C1=CC=C(C=C1)NC(=O)C2=CC3=C(C=C2)N=C(C(=N3)C4=CC=CO4)C5=CC=CO5",
            r"CC1=C(C=C(C=C1)Cl)NC2=C/C(=N\S(=O)(=O)C3=CC=CS3)/C4=CC=CC=C4C2=O",
            "CC(C)C1=NC2=CC=CC=C2C(=N1)SCC(=O)N3CCCC3 ",
        ]
    )
    example_assay = "TR-FRET counterscreen for FAK inhibitors: dose-response biochemical high throughput screening assay to identify inhibitors of Proline-rich tyrosine kinase 2 (Pyk2)"
    with gr.Blocks(title="TwinBooster") as demo:
        gr.Markdown(
            "# TwinBooster zero-shot predictor\n"
            "Enter SMILES (one per line) and a bioassay description to obtain activity predictions."
        )
        gr.Markdown(
            "TwinBooster fuses chemical structures and free-text assay descriptions to deliver SOTA zero-shot activity "
            "predictions—useful for early triage and library prioritization when assay data are scarce. "
            "Outputs include a table, bar chart, and CSV/XLSX downloads with predictions and confidence.\n\n"
            "**Reference:** Schuh, M. G.; Boldini, D.; Sieber, S. A. "
            "_Synergizing Chemical Structures and Bioassay Descriptions for Enhanced Molecular Property Prediction in Drug Discovery._ "
            "J. Chem. Inf. Model. 2024, 64, 12, 4640–4650. "
            "[JCIM paper](https://doi.org/10.1021/acs.jcim.4c00765)"
        )
        with gr.Row():
            smiles_box = gr.Textbox(
                label="SMILES list",
                lines=10,
                value=example_smiles,
                placeholder="One SMILES per line",
            )
            assay_box = gr.Textbox(
                label="Bioassay description",
                lines=8,
                value=example_assay,
                placeholder="Describe the assay/task to predict.",
            )
        with gr.Row():
            predict_btn = gr.Button("Run prediction", variant="primary")
            download_btn = gr.Button("Download / refresh models")
        status = gr.Markdown("")
        table = gr.DataFrame(
            label="Predictions",
            headers=["SMILES", "Assay", "Prediction", "Confidence"],
            datatype=["str", "str", "number", "number"],
            interactive=False,
        )
        plot = gr.Plot(label="Prediction chart")
        with gr.Row():
            csv_out = gr.File(label="CSV download")
            xlsx_out = gr.File(label="Excel download")
        predict_btn.click(
            run_prediction,
            inputs=[smiles_box, assay_box],
            outputs=[table, plot, csv_out, xlsx_out],
        )
        download_btn.click(ensure_models, outputs=status)
    return demo
if __name__ == "__main__":
    demo = build_demo()
    # Process one event at a time. Gradio 4+ replaced queue()'s
    # ``concurrency_count`` with ``default_concurrency_limit`` and rejects the
    # old argument; Gradio 3's queue() does not accept the new one (TypeError).
    try:
        demo.queue(default_concurrency_limit=1)
    except TypeError:
        demo.queue(concurrency_count=1)
    demo.launch()
|