File size: 7,238 Bytes
942392b
e174389
7a91b3c
942392b
 
 
 
 
 
 
bed36cc
942392b
 
 
 
 
 
 
e174389
942392b
 
 
 
 
 
 
 
 
 
 
 
 
 
3df1a88
 
 
 
 
 
 
 
 
942392b
5468e06
e174389
 
5468e06
e174389
 
 
 
942392b
 
e174389
5468e06
 
e174389
5468e06
 
942392b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74d2a42
 
 
942392b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74d2a42
 
 
 
 
 
 
 
 
 
 
 
 
942392b
5468e06
 
942392b
 
 
 
 
 
e19f3e1
5468e06
e19f3e1
 
 
 
5468e06
e19f3e1
 
942392b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bed36cc
942392b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import os
import tarfile
import tempfile
from pathlib import Path
from typing import List, Tuple

import gradio as gr
import pandas as pd
import plotly.express as px
from huggingface_hub import snapshot_download
import spaces
import twinbooster

# Lazily-initialized shared TwinBooster model instance (populated by get_model()).
tb = None
# Hugging Face repo holding the fine-tuned text encoder; overridable via env var.
HF_TEXT_REPO = os.environ.get("HF_TEXT_REPO", "mschuh/PubChemDeBERTa-augmented")

# Store models where TwinBooster looks for them by default (~/.cache/twinbooster)
MODEL_DIR = Path.home() / ".cache" / "twinbooster"
# Local tar.xz weight archives shipped alongside this script.
WEIGHTS_SRC = Path(__file__).parent / "weights"


def ensure_models() -> str:
    """Fetch model weights into the TwinBooster cache directory.

    Downloads the text encoder from the Hugging Face Hub and unpacks the
    bundled LGBM / Barlow-Twins archives shipped next to this file.
    Anything already present in the cache is left untouched.

    Returns:
        A human-readable status message listing what was actually fetched.
    """
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    # Point the HF caches at the TwinBooster cache so everything lives together.
    os.environ.setdefault("HF_HOME", str(MODEL_DIR))
    os.environ.setdefault("HF_HUB_CACHE", str(MODEL_DIR))

    downloaded: List[str] = []

    def grab(repo_id: str, subdir: str) -> bool:
        """Snapshot *repo_id* into MODEL_DIR/subdir; return True if fetched."""
        local_dir = MODEL_DIR / subdir
        if local_dir.exists():
            return False
        try:
            snapshot_download(
                repo_id=repo_id,
                local_dir=str(local_dir),
                local_dir_use_symlinks=False,
            )
            return True
        except Exception:
            # Best effort: a failed download falls back to the package helper below.
            return False

    def extract_local(archive: Path, subdir: str) -> bool:
        """Unpack *archive* into MODEL_DIR/subdir; return True if extracted."""
        target = MODEL_DIR / subdir
        if target.exists() or not archive.exists():
            return False
        target.mkdir(parents=True, exist_ok=True)
        with tarfile.open(archive, "r:*") as tf:
            try:
                # Python 3.12+: reject path-traversal members (CVE-2007-4559).
                tf.extractall(target, filter="data")
            except TypeError:
                # Older Python without the ``filter`` argument.
                tf.extractall(target)
        return True

    if grab(HF_TEXT_REPO, "PubChemDeBERTa-augmented"):
        downloaded.append(HF_TEXT_REPO)

    # Report the local archives only when they were actually unpacked
    # (previously they were listed unconditionally, making the status misleading).
    if extract_local(WEIGHTS_SRC / "lgbm_model.tar.xz", "lgbm_model"):
        downloaded.append("local lgbm_model.tar.xz")
    if extract_local(WEIGHTS_SRC / "bt_model.tar.xz", "bt_model"):
        downloaded.append("local bt_model.tar.xz")

    # Ensure any missing pieces are resolved via package helper (will skip if already present)
    twinbooster.download_models()

    if downloaded:
        return "Downloaded from Hugging Face: " + ", ".join(downloaded)
    return "Models already present in cache."


def get_model():
    """Return the shared TwinBooster instance, constructing it on first use."""
    global tb
    if tb is not None:
        return tb
    ensure_models()
    tb = twinbooster.TwinBooster()
    return tb


def parse_smiles(smiles_text: str) -> List[str]:
    """Split a newline-separated SMILES block into a list, dropping blank lines."""
    stripped = (candidate.strip() for candidate in smiles_text.splitlines())
    smiles = [candidate for candidate in stripped if candidate]
    if not smiles:
        raise gr.Error("Please provide at least one SMILES (one per line).")
    return smiles


@spaces.GPU(duration=120)
def _gpu_predict(smiles: List[str], assay: str) -> Tuple[pd.DataFrame, object, str, str]:
    """GPU-only path: expects prevalidated inputs and ready models."""
    model = get_model()

    try:
        preds, confs = model.predict(smiles, assay, get_confidence=True)
    except TypeError:
        # Older TwinBooster versions do not accept ``get_confidence``.
        preds = model.predict(smiles, assay)
        confs = [None] * len(preds)
    except Exception as exc:  # pragma: no cover - shown to user
        raise gr.Error(f"Inference failed: {exc}")

    results = pd.DataFrame(
        {
            "SMILES": smiles,
            "Assay": assay,
            "Prediction": preds,
            "Confidence": confs,
        }
    )

    chart = px.bar(
        results,
        x="SMILES",
        y="Prediction",
        color="Prediction",
        color_continuous_scale="Blues",
        range_y=[0, 1],
        title="TwinBooster predictions",
        labels={"Prediction": "Predicted activity probability"},
    )
    chart.update_layout(xaxis_tickangle=-45, height=420)

    def _fresh_path(suffix: str) -> str:
        """Reserve a persistent temp file and hand back its path."""
        handle = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        handle.close()
        return handle.name

    csv_path = _fresh_path(".csv")
    results.to_csv(csv_path, index=False)

    xlsx_path = _fresh_path(".xlsx")
    with pd.ExcelWriter(xlsx_path, engine="openpyxl") as writer:
        results.to_excel(writer, index=False)

    return results, chart, csv_path, xlsx_path


def run_prediction(smiles_text: str, assay_text: str) -> Tuple[pd.DataFrame, object, str, str]:
    """CPU wrapper: validates inputs and prepares models before GPU allocation."""
    assay = assay_text.strip()
    if not assay:
        raise gr.Error("Please provide a bioassay description.")

    molecules = parse_smiles(smiles_text)
    # Weights are fetched on CPU so the ZeroGPU session stays short.
    ensure_models()
    return _gpu_predict(molecules, assay)


def build_demo() -> gr.Blocks:
    """Assemble the Gradio Blocks UI for the TwinBooster predictor.

    Returns:
        The configured (but not yet launched) ``gr.Blocks`` application.
    """
    # ``\\S`` is doubled so Python does not parse it as an (invalid) escape
    # sequence; the SMILES string still contains a single literal backslash.
    example_smiles = "CC1=CC=C(C=C1)C2=CC(=NC3=NC=NC(=C23)N)C4=CC=C(C=C4)F\nCC(=O)C1=CC=C(C=C1)NC(=O)C2=CC3=C(C=C2)N=C(C(=N3)C4=CC=CO4)C5=CC=CO5\nCC1=C(C=C(C=C1)Cl)NC2=C/C(=N\\S(=O)(=O)C3=CC=CS3)/C4=CC=CC=C4C2=O\nCC(C)C1=NC2=CC=CC=C2C(=N1)SCC(=O)N3CCCC3 "
    example_assay = "TR-FRET counterscreen for FAK inhibitors: dose-response biochemical high throughput screening assay to identify inhibitors of Proline-rich tyrosine kinase 2 (Pyk2)"

    with gr.Blocks(title="TwinBooster") as demo:
        gr.Markdown(
            "# TwinBooster zero-shot predictor\n"
            "Enter SMILES (one per line) and a bioassay description to obtain activity predictions."
        )
        gr.Markdown(
            "TwinBooster fuses chemical structures and free-text assay descriptions to deliver SOTA zero-shot activity "
            "predictions—useful for early triage and library prioritization when assay data are scarce. "
            "Outputs include a table, bar chart, and CSV/XLSX downloads with predictions and confidence.\n\n"
            "**Reference:** Schuh, M. G.; Boldini, D.; Sieber, S. A. "
            "_Synergizing Chemical Structures and Bioassay Descriptions for Enhanced Molecular Property Prediction in Drug Discovery._ "
            "J. Chem. Inf. Model. 2024, 64, 12, 4640–4650. "
            "[JCIM paper](https://doi.org/10.1021/acs.jcim.4c00765)"
        )

        # Input widgets, prefilled with a working example.
        with gr.Row():
            smiles_box = gr.Textbox(
                label="SMILES list",
                lines=10,
                value=example_smiles,
                placeholder="One SMILES per line",
            )
            assay_box = gr.Textbox(
                label="Bioassay description",
                lines=8,
                value=example_assay,
                placeholder="Describe the assay/task to predict.",
            )

        with gr.Row():
            predict_btn = gr.Button("Run prediction", variant="primary")
            download_btn = gr.Button("Download / refresh models")

        # Status line updated by the model-download button.
        status = gr.Markdown("")

        # Output widgets: table, chart, and file downloads.
        table = gr.DataFrame(
            label="Predictions",
            headers=["SMILES", "Assay", "Prediction", "Confidence"],
            datatype=["str", "str", "number", "number"],
            interactive=False,
        )
        plot = gr.Plot(label="Prediction chart")

        with gr.Row():
            csv_out = gr.File(label="CSV download")
            xlsx_out = gr.File(label="Excel download")

        predict_btn.click(
            run_prediction,
            inputs=[smiles_box, assay_box],
            outputs=[table, plot, csv_out, xlsx_out],
        )
        download_btn.click(ensure_models, outputs=status)

    return demo


if __name__ == "__main__":
    demo = build_demo()
    # Gradio 4 (required by the ``spaces`` ZeroGPU package used above) renamed
    # ``concurrency_count`` to ``default_concurrency_limit``; the old keyword
    # raises TypeError on Gradio 4.x.
    demo.queue(default_concurrency_limit=1)
    demo.launch()