Spaces:
Runtime error
Runtime error
| import os | |
| import tarfile | |
| import tempfile | |
| from pathlib import Path | |
| from typing import List, Tuple | |
| import gradio as gr | |
| import pandas as pd | |
| import plotly.express as px | |
| from huggingface_hub import snapshot_download | |
| import spaces | |
| import twinbooster | |
# Process-wide TwinBooster instance; remains None until get_model() builds it.
tb = None

# Hub repository holding the text encoder; can be overridden via the
# HF_TEXT_REPO environment variable.
HF_TEXT_REPO = os.getenv("HF_TEXT_REPO", "mschuh/PubChemDeBERTa-augmented")

# Store models where TwinBooster looks for them by default (~/.cache/twinbooster)
MODEL_DIR = Path.home().joinpath(".cache", "twinbooster")

# Directory next to this file that holds the bundled weight archives.
WEIGHTS_SRC = Path(__file__).parent.joinpath("weights")
def ensure_models() -> str:
    """Make sure every model artefact is available under ``MODEL_DIR``.

    Steps, in order:

    1. Download ``HF_TEXT_REPO`` from the Hugging Face Hub into
       ``MODEL_DIR / "PubChemDeBERTa-augmented"`` unless that directory exists.
    2. Unpack the bundled ``lgbm_model.tar.xz`` and ``bt_model.tar.xz``
       archives from ``WEIGHTS_SRC`` unless their target directories exist.
    3. Call ``twinbooster.download_models()`` to resolve anything still missing.

    Returns:
        A status line listing what was actually downloaded and/or extracted
        during this call, or ``"Models already present in cache."`` when
        nothing had to be fetched.

    Raises:
        Any exception raised while opening or extracting a bundled archive
        (the half-created target directory is removed first so that a later
        call can retry), and anything ``twinbooster.download_models()`` raises.
    """
    # Function-scope imports keep this block self-contained.
    import logging
    import shutil

    log = logging.getLogger(__name__)

    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    # NOTE(review): huggingface_hub is imported before these variables are
    # set, so it may already have resolved its cache location -- confirm
    # that setting them here still has the intended effect.
    os.environ.setdefault("HF_HOME", str(MODEL_DIR))
    os.environ.setdefault("HF_HUB_CACHE", str(MODEL_DIR))

    def grab(repo_id: str, subdir: str) -> bool:
        """Download *repo_id* into ``MODEL_DIR / subdir``; True if fetched now."""
        local_dir = MODEL_DIR / subdir
        if local_dir.exists():
            return False
        try:
            # NOTE(review): `local_dir_use_symlinks` is deprecated in recent
            # huggingface_hub releases -- confirm it is still accepted by the
            # pinned version before relying on it.
            snapshot_download(
                repo_id=repo_id,
                local_dir=str(local_dir),
                local_dir_use_symlinks=False,
            )
        except Exception:
            # Best effort: the package helper below is the fallback. Record
            # the failure instead of hiding it, and drop the partial directory
            # so the `exists()` check above does not block every retry.
            log.warning(
                "Could not download %s; falling back to twinbooster.download_models()",
                repo_id,
                exc_info=True,
            )
            shutil.rmtree(local_dir, ignore_errors=True)
            return False
        return True

    def extract_local(archive: Path, subdir: str) -> bool:
        """Unpack *archive* into ``MODEL_DIR / subdir``; True if extracted now."""
        target = MODEL_DIR / subdir
        if target.exists() or not archive.exists():
            return False
        target.mkdir(parents=True, exist_ok=True)
        try:
            with tarfile.open(archive, "r:*") as tf:
                if hasattr(tarfile, "data_filter"):
                    # Extraction filters exist on this interpreter: use the
                    # safe "data" filter explicitly.
                    tf.extractall(target, filter="data")
                else:
                    tf.extractall(target)
        except Exception:
            # Do not leave an empty/partial directory that would make the
            # next call believe the weights are already in place.
            shutil.rmtree(target, ignore_errors=True)
            raise
        return True

    downloaded = []
    if grab(HF_TEXT_REPO, "PubChemDeBERTa-augmented"):
        downloaded.append(HF_TEXT_REPO)

    # Only report archives that were really unpacked during this call.
    extracted = [
        f"local {name}.tar.xz"
        for name in ("lgbm_model", "bt_model")
        if extract_local(WEIGHTS_SRC / f"{name}.tar.xz", name)
    ]

    # Ensure any missing pieces are resolved via package helper (will skip if already present)
    twinbooster.download_models()

    report = []
    if downloaded:
        report.append("Downloaded from Hugging Face: " + ", ".join(downloaded))
    if extracted:
        report.append("Extracted bundled archives: " + ", ".join(extracted))
    if report:
        return "; ".join(report)
    return "Models already present in cache."
def get_model():
    """Return the shared TwinBooster instance, creating it on first use."""
    global tb
    if tb is not None:
        # Already built by an earlier call.
        return tb
    # First call: make sure the weights are in place, then build the model.
    ensure_models()
    tb = twinbooster.TwinBooster()
    return tb
def parse_smiles(smiles_text: str) -> List[str]:
    """Split a multi-line text block into a list of SMILES strings.

    Each line is stripped of surrounding whitespace and blank lines are
    dropped.

    Raises:
        gr.Error: if no non-blank line remains.
    """
    entries: List[str] = []
    for raw_line in smiles_text.splitlines():
        candidate = raw_line.strip()
        if candidate:
            entries.append(candidate)
    if entries:
        return entries
    raise gr.Error("Please provide at least one SMILES (one per line).")
def _gpu_predict(smiles: List[str], assay: str) -> Tuple[pd.DataFrame, object, str, str]:
    """GPU-only path: expects prevalidated inputs and ready models.

    Args:
        smiles: SMILES strings to score, one table row per entry.
        assay: Bioassay description applied to every SMILES.

    Returns:
        A tuple ``(df, fig, csv_path, xlsx_path)``: the results table, a
        Plotly bar chart of the predictions, and the paths of two temporary
        files holding the table as CSV and as XLSX.

    Raises:
        gr.Error: if model inference fails.
    """
    # NOTE(review): this is described as the GPU path, yet no `@spaces.GPU`
    # decorator is applied even though `spaces` is imported at file level --
    # confirm whether the decorator is missing.
    model = get_model()
    try:
        try:
            preds, confs = model.predict(smiles, assay, get_confidence=True)
        except TypeError:
            # Fallback for a predict() that rejects `get_confidence`.
            preds = model.predict(smiles, assay)
            confs = [None] * len(preds)
    except Exception as exc:  # pragma: no cover - shown to user
        # The outer `try` also covers the fallback call above, so every
        # inference failure reaches the user as a gr.Error.
        raise gr.Error(f"Inference failed: {exc}") from exc

    df = pd.DataFrame(
        {
            "SMILES": smiles,
            "Assay": assay,
            "Prediction": preds,
            "Confidence": confs,
        }
    )

    fig = px.bar(
        df,
        x="SMILES",
        y="Prediction",
        color="Prediction",
        color_continuous_scale="Blues",
        range_y=[0, 1],
        title="TwinBooster predictions",
        labels={"Prediction": "Predicted activity probability"},
    )
    fig.update_layout(xaxis_tickangle=-45, height=420)

    # Reserve the file names, then close the handles *before* pandas opens the
    # paths itself; delete=False keeps the files on disk for the download
    # widgets.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as csv_handle:
        csv_path = csv_handle.name
    df.to_csv(csv_path, index=False)

    with tempfile.NamedTemporaryFile(delete=False, suffix=".xlsx") as xlsx_handle:
        xlsx_path = xlsx_handle.name
    with pd.ExcelWriter(xlsx_path, engine="openpyxl") as writer:
        df.to_excel(writer, index=False)

    return df, fig, csv_path, xlsx_path
def run_prediction(smiles_text: str, assay_text: str) -> Tuple[pd.DataFrame, object, str, str]:
    """Validate the raw textbox inputs, prepare the models, then predict.

    Args:
        smiles_text: Raw text with one SMILES per line.
        assay_text: Raw bioassay description.

    Returns:
        Whatever ``_gpu_predict`` returns for the cleaned inputs.

    Raises:
        gr.Error: if the assay description is blank, or if ``parse_smiles``
            rejects the SMILES text.
    """
    description = assay_text.strip()
    if len(description) == 0:
        raise gr.Error("Please provide a bioassay description.")

    molecules = parse_smiles(smiles_text)

    # Download/refresh weights on CPU to keep ZeroGPU sessions short
    ensure_models()

    results = _gpu_predict(molecules, description)
    return results
def build_demo() -> gr.Blocks:
    """Assemble the Gradio UI and wire its two buttons.

    Layout: a title and description, a row with the SMILES and assay text
    boxes (both pre-filled with an example), a row with the two action
    buttons, a status line, the results table, the chart, and a row with the
    CSV/XLSX download widgets.

    Returns:
        The configured ``gr.Blocks`` app (not yet queued or launched).
    """
    # One SMILES per line. The third entry contains a literal backslash
    # (SMILES bond-direction marker), so it is written as a raw string:
    # "\S" in a normal literal is an invalid escape sequence.
    example_smiles = "\n".join(
        [
            "CC1=CC=C(C=C1)C2=CC(=NC3=NC=NC(=C23)N)C4=CC=C(C=C4)F",
            "CC(=O)C1=CC=C(C=C1)NC(=O)C2=CC3=C(C=C2)N=C(C(=N3)C4=CC=CO4)C5=CC=CO5",
            r"CC1=C(C=C(C=C1)Cl)NC2=C/C(=N\S(=O)(=O)C3=CC=CS3)/C4=CC=CC=C4C2=O",
            "CC(C)C1=NC2=CC=CC=C2C(=N1)SCC(=O)N3CCCC3 ",
        ]
    )
    example_assay = "TR-FRET counterscreen for FAK inhibitors: dose-response biochemical high throughput screening assay to identify inhibitors of Proline-rich tyrosine kinase 2 (Pyk2)"

    with gr.Blocks(title="TwinBooster") as demo:
        gr.Markdown(
            "# TwinBooster zero-shot predictor\n"
            "Enter SMILES (one per line) and a bioassay description to obtain activity predictions."
        )
        gr.Markdown(
            "TwinBooster fuses chemical structures and free-text assay descriptions to deliver SOTA zero-shot activity "
            "predictions—useful for early triage and library prioritization when assay data are scarce. "
            "Outputs include a table, bar chart, and CSV/XLSX downloads with predictions and confidence.\n\n"
            "**Reference:** Schuh, M. G.; Boldini, D.; Sieber, S. A. "
            "_Synergizing Chemical Structures and Bioassay Descriptions for Enhanced Molecular Property Prediction in Drug Discovery._ "
            "J. Chem. Inf. Model. 2024, 64, 12, 4640–4650. "
            "[JCIM paper](https://doi.org/10.1021/acs.jcim.4c00765)"
        )

        # Inputs
        with gr.Row():
            smiles_box = gr.Textbox(
                label="SMILES list",
                lines=10,
                value=example_smiles,
                placeholder="One SMILES per line",
            )
            assay_box = gr.Textbox(
                label="Bioassay description",
                lines=8,
                value=example_assay,
                placeholder="Describe the assay/task to predict.",
            )

        # Actions
        with gr.Row():
            predict_btn = gr.Button("Run prediction", variant="primary")
            download_btn = gr.Button("Download / refresh models")

        # Receives the status string returned by ensure_models().
        status = gr.Markdown("")

        # Outputs; column headers mirror the table built by the predictor.
        table = gr.DataFrame(
            label="Predictions",
            headers=["SMILES", "Assay", "Prediction", "Confidence"],
            datatype=["str", "str", "number", "number"],
            interactive=False,
        )
        plot = gr.Plot(label="Prediction chart")
        with gr.Row():
            csv_out = gr.File(label="CSV download")
            xlsx_out = gr.File(label="Excel download")

        # Output order must match the tuple returned by run_prediction.
        predict_btn.click(
            run_prediction,
            inputs=[smiles_box, assay_box],
            outputs=[table, plot, csv_out, xlsx_out],
        )
        download_btn.click(ensure_models, outputs=status)

    return demo
if __name__ == "__main__":
    demo = build_demo()
    # `concurrency_count` is no longer accepted by Blocks.queue() in Gradio 4+
    # (which the ZeroGPU `spaces` package targets); passing it aborts startup.
    # `default_concurrency_limit=1` keeps one running job per event listener.
    demo.queue(default_concurrency_limit=1)
    demo.launch()