Update app.py
Browse files
app.py
CHANGED
|
@@ -1,19 +1,5 @@
|
|
| 1 |
# app.py β VeloBind HF Spaces inference app
|
| 2 |
-
|
| 3 |
-
# Uses the exact 45 fold models that produced the reported R=0.8469 on CASF-2016.
|
| 4 |
-
# No retraining required. Upload output/models/ to HF model repo and set
|
| 5 |
-
# HF_MODEL_REPO below.
|
| 6 |
-
#
|
| 7 |
-
# HF model repo should contain:
|
| 8 |
-
# fold_model_s{seed}_{type}_f{fold}.pkl β 45 files (3 seeds Γ 3 types Γ 5 folds)
|
| 9 |
-
# meta_type_casf16.pkl β Ridge meta-learner (from 06_eval_both.py)
|
| 10 |
-
# target_scaler.pkl β TargetScaler (from 03_train.py)
|
| 11 |
-
# ligand_scaler.pkl β from output/preprocessors/
|
| 12 |
-
#
|
| 13 |
-
# Free tier: 16GB RAM, 2 vCPU, 50GB disk β all 45 models fit easily (~2-3GB total).
|
| 14 |
-
# Cold start: ~30-40s to download + load models on first visit.
|
| 15 |
-
|
| 16 |
-
import os, json, warnings, time
|
| 17 |
import numpy as np
|
| 18 |
import pandas as pd
|
| 19 |
import streamlit as st
|
|
@@ -21,7 +7,6 @@ import joblib
|
|
| 21 |
import torch
|
| 22 |
import matplotlib.pyplot as plt
|
| 23 |
from pathlib import Path
|
| 24 |
-
from scipy.stats import pearsonr
|
| 25 |
|
| 26 |
warnings.filterwarnings("ignore")
|
| 27 |
from rdkit import RDLogger
|
|
@@ -33,9 +18,8 @@ MODEL_CACHE = Path("/tmp/velobind_models")
|
|
| 33 |
SEEDS = [42, 123, 456]
|
| 34 |
MODEL_TYPES = ["lgbm", "cb", "xgb"]
|
| 35 |
N_FOLDS = 5
|
|
|
|
| 36 |
|
| 37 |
-
# Best feature config β Step 9 winner from 03_train.py ablation
|
| 38 |
-
# MUST match what the fold models were trained on
|
| 39 |
import sys
|
| 40 |
sys.path.append(str(Path(__file__).parent))
|
| 41 |
from src.features.protein import load_esm, embed_batch, sequence_features
|
|
@@ -45,34 +29,49 @@ from src.config import config
|
|
| 45 |
|
| 46 |
|
| 47 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 48 |
-
#
|
| 49 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
| 52 |
def load_all_models():
|
| 53 |
from huggingface_hub import hf_hub_download
|
| 54 |
MODEL_CACHE.mkdir(parents=True, exist_ok=True)
|
| 55 |
|
| 56 |
-
# Build list of all files to fetch
|
| 57 |
model_files = (
|
| 58 |
[f"fold_model_s{s}_{t}_f{f}.pkl"
|
| 59 |
for s in SEEDS for t in MODEL_TYPES for f in range(N_FOLDS)]
|
| 60 |
+ ["meta_type_casf16.pkl", "target_scaler.pkl", "ligand_scaler.pkl"]
|
| 61 |
)
|
| 62 |
|
| 63 |
-
|
| 64 |
for i, fname in enumerate(model_files):
|
| 65 |
local = MODEL_CACHE / fname
|
| 66 |
if not local.exists():
|
| 67 |
-
hf_hub_download(
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
progress.progress((i + 1) / len(model_files),
|
| 72 |
-
text=f"Loading {fname}β¦")
|
| 73 |
-
progress.empty()
|
| 74 |
|
| 75 |
-
# Load into nested dict: fold_models[seed][type][fold] = model
|
| 76 |
fold_models = {}
|
| 77 |
for s in SEEDS:
|
| 78 |
fold_models[s] = {}
|
|
@@ -85,27 +84,22 @@ def load_all_models():
|
|
| 85 |
meta = joblib.load(MODEL_CACHE / "meta_type_casf16.pkl")
|
| 86 |
scaler = joblib.load(MODEL_CACHE / "target_scaler.pkl")
|
| 87 |
lig_sc = joblib.load(MODEL_CACHE / "ligand_scaler.pkl")
|
| 88 |
-
|
| 89 |
return fold_models, meta, scaler, lig_sc
|
| 90 |
|
| 91 |
-
|
|
|
|
| 92 |
def load_esm_model():
|
| 93 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 94 |
tokenizer, esm_model = load_esm(config.ESM_MODEL, device)
|
| 95 |
return tokenizer, esm_model, device
|
| 96 |
|
|
|
|
| 97 |
@st.cache_resource(show_spinner=False)
|
| 98 |
def load_ad_centroid():
|
| 99 |
-
|
| 100 |
-
local_paths = [
|
| 101 |
-
Path("output/models/deployment"),
|
| 102 |
-
Path("output/models"),
|
| 103 |
-
]
|
| 104 |
-
for p in local_paths:
|
| 105 |
if (p / "ad_centroid.npy").exists():
|
| 106 |
return (np.load(p / "ad_centroid.npy"),
|
| 107 |
float(np.load(p / "ad_threshold.npy")))
|
| 108 |
-
# HF fallback
|
| 109 |
for fname in ["ad_centroid.npy", "ad_threshold.npy"]:
|
| 110 |
local = MODEL_CACHE / fname
|
| 111 |
if not local.exists():
|
|
@@ -118,6 +112,7 @@ def load_ad_centroid():
|
|
| 118 |
return (np.load(MODEL_CACHE / "ad_centroid.npy"),
|
| 119 |
float(np.load(MODEL_CACHE / "ad_threshold.npy")))
|
| 120 |
|
|
|
|
| 121 |
def ad_check(esm_mean_vec, centroid, threshold):
|
| 122 |
if centroid is None:
|
| 123 |
return "UNKNOWN", float("nan")
|
|
@@ -126,29 +121,25 @@ def ad_check(esm_mean_vec, centroid, threshold):
|
|
| 126 |
|
| 127 |
|
| 128 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 129 |
-
# Feature
|
| 130 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 131 |
-
def assemble_from_parts(esm_mean, esm_var, esm_attn, seq_feat, lig_feats
|
| 132 |
-
"""Matches assemble() in 06_casf_eval.py exactly β 10,054d."""
|
| 133 |
return np.concatenate([
|
| 134 |
-
esm_mean[:, -480:],
|
| 135 |
-
seq_feat,
|
| 136 |
-
lig_feats["ecfp"],
|
| 137 |
-
lig_feats["ecfp2"],
|
| 138 |
-
lig_feats["ecfp6"],
|
| 139 |
-
lig_feats["fcfp"],
|
| 140 |
-
lig_feats["estate"],
|
| 141 |
-
lig_feats["maccs"],
|
| 142 |
-
lig_feats["atom_pair"],
|
| 143 |
-
lig_feats["torsion"],
|
| 144 |
-
lig_feats["phys"],
|
| 145 |
], axis=1)
|
| 146 |
|
| 147 |
|
| 148 |
-
def extract_features(sequence
|
| 149 |
-
tokenizer, esm_model, device, lig_scaler):
|
| 150 |
-
"""Returns X [N_valid, D], valid_mask [N_smiles]."""
|
| 151 |
-
# Protein (embed once, tile)
|
| 152 |
esm_mean, esm_var, esm_attn, _ = embed_batch(
|
| 153 |
[sequence], tokenizer, esm_model,
|
| 154 |
config.ESM_LAYERS, config.MAX_SEQ_LEN, config.HALF_SEQ_LEN,
|
|
@@ -156,7 +147,6 @@ def extract_features(sequence: str, smiles_list: list,
|
|
| 156 |
)
|
| 157 |
seq_feat = np.array([sequence_features(sequence)])
|
| 158 |
|
| 159 |
-
# Ligands
|
| 160 |
lig_feats, valid_mask, _ = extract_ligand_features(
|
| 161 |
smiles_list, scaler=lig_scaler, fit_scaler=False
|
| 162 |
)
|
|
@@ -166,44 +156,34 @@ def extract_features(sequence: str, smiles_list: list,
|
|
| 166 |
bool_mask[valid_mask] = True
|
| 167 |
valid_mask = bool_mask
|
| 168 |
|
| 169 |
-
# Tile protein over valid ligands only
|
| 170 |
n_valid = int(valid_mask.sum())
|
| 171 |
-
esm_mean_t = np.tile(esm_mean,
|
| 172 |
-
esm_var_t = np.tile(esm_var,
|
| 173 |
-
esm_attn_t = np.tile(esm_attn,
|
| 174 |
-
seq_feat_t = np.tile(seq_feat,
|
| 175 |
|
| 176 |
X = assemble_from_parts(esm_mean_t, esm_var_t, esm_attn_t, seq_feat_t, lig_feats)
|
| 177 |
return X, valid_mask, esm_mean[0]
|
| 178 |
|
| 179 |
|
| 180 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 181 |
-
# Prediction
|
| 182 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 183 |
def predict(X, fold_models, meta, scaler):
|
| 184 |
-
"""
|
| 185 |
-
Returns:
|
| 186 |
-
preds [N] final ensemble pKd
|
| 187 |
-
preds_all [N, 9] per-(seed,type) predictions for uncertainty
|
| 188 |
-
"""
|
| 189 |
-
# Each entry: average over 5 folds for one (seed, type) combo
|
| 190 |
type_avgs = []
|
| 191 |
for s in SEEDS:
|
| 192 |
for t in MODEL_TYPES:
|
| 193 |
fold_preds = np.stack([
|
| 194 |
scaler.inverse(fold_models[s][t][f].predict(X))
|
| 195 |
for f in range(N_FOLDS)
|
| 196 |
-
], axis=1)
|
| 197 |
-
type_avgs.append(fold_preds.mean(axis=1))
|
| 198 |
-
|
| 199 |
-
preds_all = np.stack(type_avgs, axis=1)
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
xgb_avg = preds_all[:, [2, 5, 8]].mean(axis=1)
|
| 205 |
-
preds = meta.predict(np.column_stack([lgbm_avg, cb_avg, xgb_avg]))
|
| 206 |
-
|
| 207 |
return preds, preds_all
|
| 208 |
|
| 209 |
|
|
@@ -212,34 +192,74 @@ def uncertainty_interval(preds_all, z=1.96):
|
|
| 212 |
return preds_all.mean(axis=1) - z * std, preds_all.mean(axis=1) + z * std
|
| 213 |
|
| 214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 216 |
# Plots
|
| 217 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 218 |
-
def bar_chart(names, preds, lo, hi, title):
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
bars = ax.bar(x, preds, color="#4C72B0", alpha=0.85, width=0.6,
|
| 223 |
-
yerr=err, capsize=5, error_kw=dict(ecolor=
|
| 224 |
ax.set_xticks(x)
|
| 225 |
-
ax.set_xticklabels(names, rotation=30, ha='right', fontsize=10)
|
| 226 |
-
ax.set_ylabel("Predicted pKd", fontsize=11)
|
| 227 |
-
ax.set_title(title, fontsize=12, fontweight='bold')
|
| 228 |
-
ax.
|
|
|
|
|
|
|
| 229 |
for bar, val in zip(bars, preds):
|
| 230 |
ax.text(bar.get_x() + bar.get_width() / 2,
|
| 231 |
bar.get_height() + 0.05, f"{val:.2f}",
|
| 232 |
-
ha='center', va='bottom', fontsize=9,
|
|
|
|
| 233 |
plt.tight_layout()
|
| 234 |
return fig
|
| 235 |
|
| 236 |
|
| 237 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 238 |
-
#
|
| 239 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 240 |
-
st.set_page_config(page_title="VeloBind",
|
| 241 |
-
|
| 242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
|
| 244 |
def load_svg_b64(path):
|
| 245 |
with open(path, "rb") as f:
|
|
@@ -249,6 +269,7 @@ logo_b64 = load_svg_b64("logo.svg")
|
|
| 249 |
|
| 250 |
st.markdown(f"""
|
| 251 |
<style>
|
|
|
|
| 252 |
.header-wrap {{
|
| 253 |
display: flex; align-items: center; gap: 1.5rem;
|
| 254 |
margin-bottom: 1.5rem;
|
|
@@ -259,17 +280,17 @@ st.markdown(f"""
|
|
| 259 |
}}
|
| 260 |
.logo-box img {{ height: 130px; width: auto; display: block; }}
|
| 261 |
.header-text {{
|
| 262 |
-
background:
|
| 263 |
padding: 1.5rem 2rem; border-radius: 12px; flex: 1;
|
| 264 |
}}
|
| 265 |
.header-text h1 {{ color: #fff; font-size: 2.2rem; margin: 0; }}
|
| 266 |
-
.header-text p {{ color: #
|
| 267 |
.metric-card {{
|
| 268 |
-
background:
|
| 269 |
border-radius: 10px; padding: 1rem; text-align: center;
|
| 270 |
}}
|
| 271 |
-
.metric-val {{ font-size: 2rem; font-weight: 700; color:
|
| 272 |
-
.metric-lab {{ font-size: 0.8rem; color:
|
| 273 |
.ad-in {{ background:#1b4332; border:1px solid #2d6a4f; color:#40916c;
|
| 274 |
border-radius:8px; padding:0.4rem 1rem; font-weight:700; display:inline-block; }}
|
| 275 |
.ad-out {{ background:#4a1c24; border:1px solid #9b2335; color:#e74c3c;
|
|
@@ -283,24 +304,25 @@ st.markdown(f"""
|
|
| 283 |
</div>
|
| 284 |
<div class="header-text">
|
| 285 |
<h1>VeloBind</h1>
|
| 286 |
-
<p>Structure-free protein
|
| 287 |
-
|
|
|
|
|
|
|
| 288 |
</div>
|
| 289 |
</div>
|
| 290 |
""", unsafe_allow_html=True)
|
| 291 |
|
| 292 |
-
# ββ Load
|
| 293 |
fold_models, meta, target_scaler, lig_scaler = load_all_models()
|
| 294 |
tokenizer, esm_model, device = load_esm_model()
|
|
|
|
| 295 |
n_loaded = sum(len(fold_models[s][t]) for s in SEEDS for t in MODEL_TYPES)
|
| 296 |
st.success(f"β {n_loaded} fold models loaded | Device: {device.upper()}")
|
| 297 |
|
| 298 |
# ββ Mode selector βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 299 |
mode = st.radio(
|
| 300 |
"Select mode",
|
| 301 |
-
["
|
| 302 |
-
"π Batch screening (CSV)",
|
| 303 |
-
"π― One compound vs. multiple targets"],
|
| 304 |
horizontal=True,
|
| 305 |
)
|
| 306 |
st.markdown("---")
|
|
@@ -309,13 +331,17 @@ st.markdown("---")
|
|
| 309 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 310 |
# MODE 1 β Single query
|
| 311 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 312 |
-
if mode == "
|
| 313 |
|
| 314 |
col_p, col_l = st.columns(2)
|
| 315 |
with col_p:
|
| 316 |
st.subheader("Protein")
|
| 317 |
-
|
| 318 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
with col_l:
|
| 320 |
st.subheader("Ligand")
|
| 321 |
smi = st.text_input("SMILES", placeholder="CC(=O)Oc1ccccc1C(=O)O")
|
|
@@ -329,15 +355,18 @@ if mode == "π¬ Single query":
|
|
| 329 |
if chosen != "β":
|
| 330 |
smi = examples[chosen]
|
| 331 |
|
| 332 |
-
if st.button("Predict
|
| 333 |
-
|
| 334 |
-
|
|
|
|
|
|
|
|
|
|
| 335 |
else:
|
| 336 |
-
with st.spinner("Running inference
|
| 337 |
t0 = time.time()
|
| 338 |
try:
|
| 339 |
X, valid, esm_vec = extract_features(
|
| 340 |
-
seq
|
| 341 |
tokenizer, esm_model, device, lig_scaler
|
| 342 |
)
|
| 343 |
if not valid.any():
|
|
@@ -361,17 +390,14 @@ if mode == "π¬ Single query":
|
|
| 361 |
<div class="metric-lab">95% model interval (Β±1.96Ο, 45 models)</div>
|
| 362 |
</div>""", unsafe_allow_html=True)
|
| 363 |
with c3:
|
| 364 |
-
Ki = 10 ** (9 - pkd)
|
| 365 |
st.markdown(f"""<div class="metric-card">
|
| 366 |
-
<div class="metric-val">{
|
| 367 |
-
<div class="metric-lab">Estimated
|
| 368 |
</div>""", unsafe_allow_html=True)
|
| 369 |
-
ad_centroid, ad_threshold = load_ad_centroid()
|
| 370 |
-
ad_label, ad_dist = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
|
| 371 |
-
|
| 372 |
with c4:
|
| 373 |
-
|
| 374 |
-
|
|
|
|
| 375 |
st.markdown(f"""<div class="metric-card">
|
| 376 |
<div class="{ad_cls}">{ad_label}</div>
|
| 377 |
<div class="metric-lab">Applicability domain</div>
|
|
@@ -379,22 +405,22 @@ if mode == "π¬ Single query":
|
|
| 379 |
|
| 380 |
if ad_label == "OUT OF DOMAIN":
|
| 381 |
st.warning("Protein is outside the training distribution. "
|
| 382 |
-
|
| 383 |
|
| 384 |
st.caption(
|
| 385 |
f"Inference time: {elapsed:.2f}s | "
|
| 386 |
-
f"45-model ensemble (3 seeds
|
| 387 |
f"Device: {device.upper()}"
|
| 388 |
)
|
| 389 |
|
| 390 |
with st.expander("Per-model breakdown"):
|
| 391 |
labels = [f"s{s}_{t}" for s in SEEDS for t in MODEL_TYPES]
|
| 392 |
fig = bar_chart(
|
| 393 |
-
labels,
|
| 394 |
-
preds_all[0],
|
| 395 |
preds_all[0] - preds_all[0].std(),
|
| 396 |
preds_all[0] + preds_all[0].std(),
|
| 397 |
-
"Seed
|
|
|
|
| 398 |
)
|
| 399 |
st.pyplot(fig, use_container_width=True)
|
| 400 |
plt.close(fig)
|
|
@@ -407,7 +433,7 @@ if mode == "π¬ Single query":
|
|
| 407 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 408 |
# MODE 2 β Batch screening
|
| 409 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 410 |
-
elif mode == "
|
| 411 |
|
| 412 |
st.subheader("Batch Screening")
|
| 413 |
st.markdown("One protein, many compounds. Upload a CSV with a `smiles` column "
|
|
@@ -415,8 +441,8 @@ elif mode == "π Batch screening (CSV)":
|
|
| 415 |
|
| 416 |
col_seq, col_csv = st.columns(2)
|
| 417 |
with col_seq:
|
| 418 |
-
|
| 419 |
-
|
| 420 |
with col_csv:
|
| 421 |
uploaded = st.file_uploader("Compound CSV (smiles, name)", type=["csv"])
|
| 422 |
st.code("smiles,name\nCC(=O)Oc1ccccc1C(=O)O,Aspirin", language="csv")
|
|
@@ -424,9 +450,10 @@ elif mode == "π Batch screening (CSV)":
|
|
| 424 |
max_cpds = st.slider("Max compounds", 10, 500, 100,
|
| 425 |
help="~1s per compound on CPU free tier.")
|
| 426 |
|
| 427 |
-
if st.button("Run batch screening
|
| 428 |
-
|
| 429 |
-
|
|
|
|
| 430 |
elif uploaded is None:
|
| 431 |
st.error("Please upload a CSV file.")
|
| 432 |
else:
|
|
@@ -440,22 +467,16 @@ elif mode == "π Batch screening (CSV)":
|
|
| 440 |
names_list = (df_in['name'].tolist() if 'name' in df_in.columns
|
| 441 |
else [f"cpd_{i}" for i in range(len(df_in))])
|
| 442 |
|
| 443 |
-
|
| 444 |
-
with st.spinner(f"Screening {len(smiles_list)} compoundsβ¦"):
|
| 445 |
t0 = time.time()
|
| 446 |
X, valid, esm_vec = extract_features(
|
| 447 |
-
batch_seq
|
| 448 |
tokenizer, esm_model, device, lig_scaler
|
| 449 |
)
|
| 450 |
-
|
| 451 |
-
for i, smiles in enumerate(smiles_list):
|
| 452 |
-
if valid[i]:
|
| 453 |
-
label, _ = ad_check(esm_vec, ad_centroid, ad_threshold)
|
| 454 |
-
ad_labels.append(label)
|
| 455 |
-
|
| 456 |
preds, preds_all = predict(X, fold_models, meta, target_scaler)
|
| 457 |
-
lo, hi
|
| 458 |
-
elapsed
|
| 459 |
|
| 460 |
valid_names = [names_list[i] for i in range(len(names_list)) if valid[i]]
|
| 461 |
valid_smiles = [smiles_list[i] for i in range(len(smiles_list)) if valid[i]]
|
|
@@ -467,12 +488,16 @@ elif mode == "π Batch screening (CSV)":
|
|
| 467 |
'pKd_pred': np.round(preds, 3),
|
| 468 |
'CI_lo': np.round(lo, 3),
|
| 469 |
'CI_hi': np.round(hi, 3),
|
| 470 |
-
'
|
| 471 |
'model_std': np.round(preds_all.std(axis=1), 3),
|
| 472 |
-
'AD'
|
| 473 |
}).sort_values('pKd_pred', ascending=False).reset_index(drop=True)
|
| 474 |
results_df.insert(0, 'rank', range(1, len(results_df) + 1))
|
| 475 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
st.success(
|
| 477 |
f"β {len(results_df)} compounds in {elapsed:.1f}s "
|
| 478 |
f"({elapsed / max(len(results_df), 1):.2f}s/compound)"
|
|
@@ -486,7 +511,8 @@ elif mode == "π Batch screening (CSV)":
|
|
| 486 |
top_df['pKd_pred'].values,
|
| 487 |
top_df['CI_lo'].values,
|
| 488 |
top_df['CI_hi'].values,
|
| 489 |
-
f"Top {top_n} hits"
|
|
|
|
| 490 |
)
|
| 491 |
st.pyplot(fig, use_container_width=True)
|
| 492 |
plt.close(fig)
|
|
@@ -496,7 +522,7 @@ elif mode == "π Batch screening (CSV)":
|
|
| 496 |
use_container_width=True, height=400,
|
| 497 |
)
|
| 498 |
st.download_button(
|
| 499 |
-
"
|
| 500 |
results_df.to_csv(index=False).encode(),
|
| 501 |
file_name="velobind_screening.csv",
|
| 502 |
mime="text/csv",
|
|
@@ -506,7 +532,7 @@ elif mode == "π Batch screening (CSV)":
|
|
| 506 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 507 |
# MODE 3 β One compound vs. multiple targets
|
| 508 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 509 |
-
elif mode == "
|
| 510 |
|
| 511 |
st.subheader("Selectivity Profiling")
|
| 512 |
st.markdown("One SMILES, multiple proteins β ranked by predicted pKd. "
|
|
@@ -517,33 +543,37 @@ elif mode == "π― One compound vs. multiple targets":
|
|
| 517 |
multi_seqs = st.text_area(
|
| 518 |
"Target proteins (one per line)",
|
| 519 |
height=250,
|
| 520 |
-
placeholder=
|
| 521 |
-
"ABL1: MGPSENDPNLFVALY...\n"
|
| 522 |
-
"EGFR: MRPSGTAGAALLALL...\n"
|
| 523 |
-
"CDK2: MENFQKVEKIGEGTY..."
|
| 524 |
-
),
|
| 525 |
)
|
| 526 |
|
| 527 |
-
if st.button("Run selectivity profiling
|
| 528 |
if not multi_smi.strip() or not multi_seqs.strip():
|
| 529 |
st.error("Please enter a SMILES and at least one protein sequence.")
|
| 530 |
else:
|
| 531 |
targets = {}
|
|
|
|
| 532 |
for i, line in enumerate(multi_seqs.strip().splitlines()):
|
| 533 |
line = line.strip()
|
| 534 |
if not line:
|
| 535 |
continue
|
| 536 |
if ":" in line:
|
| 537 |
-
name,
|
| 538 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 539 |
else:
|
| 540 |
-
targets[
|
| 541 |
|
|
|
|
|
|
|
|
|
|
| 542 |
if not targets:
|
| 543 |
-
st.error("
|
| 544 |
st.stop()
|
| 545 |
|
| 546 |
-
ad_centroid, ad_threshold = load_ad_centroid()
|
| 547 |
results, progress = [], st.progress(0)
|
| 548 |
for idx, (name, seq) in enumerate(targets.items()):
|
| 549 |
try:
|
|
@@ -554,15 +584,15 @@ elif mode == "π― One compound vs. multiple targets":
|
|
| 554 |
if valid.any():
|
| 555 |
preds, preds_all = predict(X, fold_models, meta, target_scaler)
|
| 556 |
lo, hi = uncertainty_interval(preds_all)
|
| 557 |
-
ad_label, _ = ad_check(esm_vec, ad_centroid, ad_threshold)
|
| 558 |
results.append({
|
| 559 |
'Target': name,
|
| 560 |
'pKd_pred': round(float(preds[0]), 3),
|
| 561 |
'CI_lo': round(float(lo[0]), 3),
|
| 562 |
'CI_hi': round(float(hi[0]), 3),
|
| 563 |
-
'
|
| 564 |
'model_std': round(float(preds_all.std()), 3),
|
| 565 |
-
'AD':
|
| 566 |
})
|
| 567 |
except Exception as e:
|
| 568 |
st.warning(f"Skipped {name}: {e}")
|
|
@@ -576,30 +606,35 @@ elif mode == "π― One compound vs. multiple targets":
|
|
| 576 |
)
|
| 577 |
res_df.insert(0, 'rank', range(1, len(res_df) + 1))
|
| 578 |
|
| 579 |
-
st.success(f"
|
| 580 |
fig = bar_chart(
|
| 581 |
res_df['Target'].tolist(),
|
| 582 |
res_df['pKd_pred'].values,
|
| 583 |
res_df['CI_lo'].values,
|
| 584 |
res_df['CI_hi'].values,
|
| 585 |
-
"Selectivity profile β predicted pKd by target"
|
|
|
|
| 586 |
)
|
| 587 |
st.pyplot(fig, use_container_width=True)
|
| 588 |
plt.close(fig)
|
| 589 |
|
| 590 |
st.dataframe(res_df, use_container_width=True)
|
| 591 |
st.download_button(
|
| 592 |
-
"
|
| 593 |
res_df.to_csv(index=False).encode(),
|
| 594 |
file_name="velobind_selectivity.csv",
|
| 595 |
mime="text/csv",
|
| 596 |
)
|
| 597 |
|
|
|
|
| 598 |
# ββ Footer ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 599 |
st.markdown("---")
|
| 600 |
-
st.markdown("""
|
| 601 |
-
<div style="color:
|
| 602 |
-
VeloBind Β· Structure-free binding affinity Β·
|
| 603 |
-
|
|
|
|
|
|
|
|
|
|
| 604 |
</div>
|
| 605 |
""", unsafe_allow_html=True)
|
|
|
|
| 1 |
# app.py β VeloBind HF Spaces inference app
|
| 2 |
+
import os, warnings, time, base64
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
import pandas as pd
|
| 5 |
import streamlit as st
|
|
|
|
| 7 |
import torch
|
| 8 |
import matplotlib.pyplot as plt
|
| 9 |
from pathlib import Path
|
|
|
|
| 10 |
|
| 11 |
warnings.filterwarnings("ignore")
|
| 12 |
from rdkit import RDLogger
|
|
|
|
| 18 |
SEEDS = [42, 123, 456]
|
| 19 |
MODEL_TYPES = ["lgbm", "cb", "xgb"]
|
| 20 |
N_FOLDS = 5
|
| 21 |
+
VALID_AA = set("ACDEFGHIKLMNPQRSTVWYacdefghiklmnpqrstvwyX")
|
| 22 |
|
|
|
|
|
|
|
| 23 |
import sys
|
| 24 |
sys.path.append(str(Path(__file__).parent))
|
| 25 |
from src.features.protein import load_esm, embed_batch, sequence_features
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 32 |
+
# Validation
|
| 33 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 34 |
+
def validate_sequence(raw: str):
|
| 35 |
+
raw = raw.strip()
|
| 36 |
+
if not raw:
|
| 37 |
+
return None, "Please enter a sequence."
|
| 38 |
+
|
| 39 |
+
# Strip FASTA header(s)
|
| 40 |
+
lines = raw.splitlines()
|
| 41 |
+
seq_lines = [l.strip() for l in lines if not l.startswith(">")]
|
| 42 |
+
seq = "".join(seq_lines).upper().replace(" ", "")
|
| 43 |
+
|
| 44 |
+
if len(seq) < 10:
|
| 45 |
+
return None, "Sequence too short (minimum 10 residues)."
|
| 46 |
+
invalid = set(seq) - VALID_AA
|
| 47 |
+
if invalid:
|
| 48 |
+
return None, f"Invalid characters: {', '.join(sorted(invalid))}. Only standard amino acid letters accepted."
|
| 49 |
+
return seq, None
|
| 50 |
+
|
| 51 |
|
| 52 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 53 |
+
# Model loading
|
| 54 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 55 |
+
@st.cache_resource(show_spinner="Downloading and loading VeloBind models (first run ~30s)...")
|
| 56 |
def load_all_models():
|
| 57 |
from huggingface_hub import hf_hub_download
|
| 58 |
MODEL_CACHE.mkdir(parents=True, exist_ok=True)
|
| 59 |
|
|
|
|
| 60 |
model_files = (
|
| 61 |
[f"fold_model_s{s}_{t}_f{f}.pkl"
|
| 62 |
for s in SEEDS for t in MODEL_TYPES for f in range(N_FOLDS)]
|
| 63 |
+ ["meta_type_casf16.pkl", "target_scaler.pkl", "ligand_scaler.pkl"]
|
| 64 |
)
|
| 65 |
|
| 66 |
+
bar = st.progress(0, text="Loading models...")
|
| 67 |
for i, fname in enumerate(model_files):
|
| 68 |
local = MODEL_CACHE / fname
|
| 69 |
if not local.exists():
|
| 70 |
+
hf_hub_download(repo_id=HF_MODEL_REPO, filename=fname,
|
| 71 |
+
local_dir=str(MODEL_CACHE))
|
| 72 |
+
bar.progress((i + 1) / len(model_files), text=f"Loading {fname}...")
|
| 73 |
+
bar.empty()
|
|
|
|
|
|
|
|
|
|
| 74 |
|
|
|
|
| 75 |
fold_models = {}
|
| 76 |
for s in SEEDS:
|
| 77 |
fold_models[s] = {}
|
|
|
|
| 84 |
meta = joblib.load(MODEL_CACHE / "meta_type_casf16.pkl")
|
| 85 |
scaler = joblib.load(MODEL_CACHE / "target_scaler.pkl")
|
| 86 |
lig_sc = joblib.load(MODEL_CACHE / "ligand_scaler.pkl")
|
|
|
|
| 87 |
return fold_models, meta, scaler, lig_sc
|
| 88 |
|
| 89 |
+
|
| 90 |
+
@st.cache_resource(show_spinner="Loading ESM-2 protein language model...")
|
| 91 |
def load_esm_model():
|
| 92 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 93 |
tokenizer, esm_model = load_esm(config.ESM_MODEL, device)
|
| 94 |
return tokenizer, esm_model, device
|
| 95 |
|
| 96 |
+
|
| 97 |
@st.cache_resource(show_spinner=False)
|
| 98 |
def load_ad_centroid():
|
| 99 |
+
for p in [Path("output/models/deployment"), Path("output/models")]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
if (p / "ad_centroid.npy").exists():
|
| 101 |
return (np.load(p / "ad_centroid.npy"),
|
| 102 |
float(np.load(p / "ad_threshold.npy")))
|
|
|
|
| 103 |
for fname in ["ad_centroid.npy", "ad_threshold.npy"]:
|
| 104 |
local = MODEL_CACHE / fname
|
| 105 |
if not local.exists():
|
|
|
|
| 112 |
return (np.load(MODEL_CACHE / "ad_centroid.npy"),
|
| 113 |
float(np.load(MODEL_CACHE / "ad_threshold.npy")))
|
| 114 |
|
| 115 |
+
|
| 116 |
def ad_check(esm_mean_vec, centroid, threshold):
|
| 117 |
if centroid is None:
|
| 118 |
return "UNKNOWN", float("nan")
|
|
|
|
| 121 |
|
| 122 |
|
| 123 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 124 |
+
# Feature extraction
|
| 125 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 126 |
+
def assemble_from_parts(esm_mean, esm_var, esm_attn, seq_feat, lig_feats):
|
|
|
|
| 127 |
return np.concatenate([
|
| 128 |
+
esm_mean[:, -480:],
|
| 129 |
+
seq_feat,
|
| 130 |
+
lig_feats["ecfp"],
|
| 131 |
+
lig_feats["ecfp2"],
|
| 132 |
+
lig_feats["ecfp6"],
|
| 133 |
+
lig_feats["fcfp"],
|
| 134 |
+
lig_feats["estate"],
|
| 135 |
+
lig_feats["maccs"],
|
| 136 |
+
lig_feats["atom_pair"],
|
| 137 |
+
lig_feats["torsion"],
|
| 138 |
+
lig_feats["phys"],
|
| 139 |
], axis=1)
|
| 140 |
|
| 141 |
|
| 142 |
+
def extract_features(sequence, smiles_list, tokenizer, esm_model, device, lig_scaler):
|
|
|
|
|
|
|
|
|
|
| 143 |
esm_mean, esm_var, esm_attn, _ = embed_batch(
|
| 144 |
[sequence], tokenizer, esm_model,
|
| 145 |
config.ESM_LAYERS, config.MAX_SEQ_LEN, config.HALF_SEQ_LEN,
|
|
|
|
| 147 |
)
|
| 148 |
seq_feat = np.array([sequence_features(sequence)])
|
| 149 |
|
|
|
|
| 150 |
lig_feats, valid_mask, _ = extract_ligand_features(
|
| 151 |
smiles_list, scaler=lig_scaler, fit_scaler=False
|
| 152 |
)
|
|
|
|
| 156 |
bool_mask[valid_mask] = True
|
| 157 |
valid_mask = bool_mask
|
| 158 |
|
|
|
|
| 159 |
n_valid = int(valid_mask.sum())
|
| 160 |
+
esm_mean_t = np.tile(esm_mean, (n_valid, 1))
|
| 161 |
+
esm_var_t = np.tile(esm_var, (n_valid, 1))
|
| 162 |
+
esm_attn_t = np.tile(esm_attn, (n_valid, 1))
|
| 163 |
+
seq_feat_t = np.tile(seq_feat, (n_valid, 1))
|
| 164 |
|
| 165 |
X = assemble_from_parts(esm_mean_t, esm_var_t, esm_attn_t, seq_feat_t, lig_feats)
|
| 166 |
return X, valid_mask, esm_mean[0]
|
| 167 |
|
| 168 |
|
| 169 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 170 |
+
# Prediction
|
| 171 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 172 |
def predict(X, fold_models, meta, scaler):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
type_avgs = []
|
| 174 |
for s in SEEDS:
|
| 175 |
for t in MODEL_TYPES:
|
| 176 |
fold_preds = np.stack([
|
| 177 |
scaler.inverse(fold_models[s][t][f].predict(X))
|
| 178 |
for f in range(N_FOLDS)
|
| 179 |
+
], axis=1)
|
| 180 |
+
type_avgs.append(fold_preds.mean(axis=1))
|
| 181 |
+
|
| 182 |
+
preds_all = np.stack(type_avgs, axis=1)
|
| 183 |
+
lgbm_avg = preds_all[:, [0, 3, 6]].mean(axis=1)
|
| 184 |
+
cb_avg = preds_all[:, [1, 4, 7]].mean(axis=1)
|
| 185 |
+
xgb_avg = preds_all[:, [2, 5, 8]].mean(axis=1)
|
| 186 |
+
preds = meta.predict(np.column_stack([lgbm_avg, cb_avg, xgb_avg]))
|
|
|
|
|
|
|
|
|
|
| 187 |
return preds, preds_all
|
| 188 |
|
| 189 |
|
|
|
|
| 192 |
return preds_all.mean(axis=1) - z * std, preds_all.mean(axis=1) + z * std
|
| 193 |
|
| 194 |
|
| 195 |
+
def format_ki(pkd):
|
| 196 |
+
"""Format Ki with appropriate unit (nM, uM, mM)."""
|
| 197 |
+
ki_nM = 10 ** (9 - pkd)
|
| 198 |
+
if ki_nM < 1000:
|
| 199 |
+
return f"{ki_nM:.1f} nM"
|
| 200 |
+
elif ki_nM < 1_000_000:
|
| 201 |
+
return f"{ki_nM/1000:.2f} uM"
|
| 202 |
+
else:
|
| 203 |
+
return f"{ki_nM/1_000_000:.2f} mM"
|
| 204 |
+
|
| 205 |
+
|
| 206 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 207 |
# Plots
|
| 208 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 209 |
+
def bar_chart(names, preds, lo, hi, title, dark=True):
|
| 210 |
+
bg = "#1e2a38" if dark else "#f8f9fa"
|
| 211 |
+
fg = "#ffffff" if dark else "#111111"
|
| 212 |
+
grid = "#2d3f55" if dark else "#cccccc"
|
| 213 |
+
|
| 214 |
+
fig, ax = plt.subplots(figsize=(max(6, len(names) * 0.9), 4),
|
| 215 |
+
facecolor=bg)
|
| 216 |
+
ax.set_facecolor(bg)
|
| 217 |
+
x = np.arange(len(names))
|
| 218 |
+
err = [preds - lo, hi - preds]
|
| 219 |
bars = ax.bar(x, preds, color="#4C72B0", alpha=0.85, width=0.6,
|
| 220 |
+
yerr=err, capsize=5, error_kw=dict(ecolor=fg, lw=1.5))
|
| 221 |
ax.set_xticks(x)
|
| 222 |
+
ax.set_xticklabels(names, rotation=30, ha='right', fontsize=10, color=fg)
|
| 223 |
+
ax.set_ylabel("Predicted pKd", fontsize=11, color=fg)
|
| 224 |
+
ax.set_title(title, fontsize=12, fontweight='bold', color=fg)
|
| 225 |
+
ax.tick_params(colors=fg)
|
| 226 |
+
ax.spines[:].set_color(grid)
|
| 227 |
+
ax.grid(True, axis='y', alpha=0.25, color=grid)
|
| 228 |
for bar, val in zip(bars, preds):
|
| 229 |
ax.text(bar.get_x() + bar.get_width() / 2,
|
| 230 |
bar.get_height() + 0.05, f"{val:.2f}",
|
| 231 |
+
ha='center', va='bottom', fontsize=9,
|
| 232 |
+
fontweight='bold', color=fg)
|
| 233 |
plt.tight_layout()
|
| 234 |
return fig
|
| 235 |
|
| 236 |
|
| 237 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 238 |
+
# Page setup
|
| 239 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 240 |
+
st.set_page_config(page_title="VeloBind", layout="wide")
|
| 241 |
+
|
| 242 |
+
# ββ Theme toggle ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 243 |
+
with st.sidebar:
|
| 244 |
+
st.markdown("### Display")
|
| 245 |
+
dark_mode = st.toggle("Dark mode", value=True)
|
| 246 |
+
|
| 247 |
+
if dark_mode:
|
| 248 |
+
header_bg = "linear-gradient(135deg, #1a3a5c, #1e6091, #2980b9)"
|
| 249 |
+
card_bg = "#1e2a38"
|
| 250 |
+
card_border = "#2d3f55"
|
| 251 |
+
val_color = "#4fc3f7"
|
| 252 |
+
lab_color = "#aaa"
|
| 253 |
+
page_bg = "#0e1117"
|
| 254 |
+
text_color = "#ffffff"
|
| 255 |
+
else:
|
| 256 |
+
header_bg = "linear-gradient(135deg, #2980b9, #5dade2, #85c1e9)"
|
| 257 |
+
card_bg = "#f0f4f8"
|
| 258 |
+
card_border = "#b0c4de"
|
| 259 |
+
val_color = "#1a5276"
|
| 260 |
+
lab_color = "#555"
|
| 261 |
+
page_bg = "#ffffff"
|
| 262 |
+
text_color = "#111111"
|
| 263 |
|
| 264 |
def load_svg_b64(path):
|
| 265 |
with open(path, "rb") as f:
|
|
|
|
| 269 |
|
| 270 |
st.markdown(f"""
|
| 271 |
<style>
|
| 272 |
+
.stApp {{ background-color: {page_bg}; color: {text_color}; }}
|
| 273 |
.header-wrap {{
|
| 274 |
display: flex; align-items: center; gap: 1.5rem;
|
| 275 |
margin-bottom: 1.5rem;
|
|
|
|
| 280 |
}}
|
| 281 |
.logo-box img {{ height: 130px; width: auto; display: block; }}
|
| 282 |
.header-text {{
|
| 283 |
+
background: {header_bg};
|
| 284 |
padding: 1.5rem 2rem; border-radius: 12px; flex: 1;
|
| 285 |
}}
|
| 286 |
.header-text h1 {{ color: #fff; font-size: 2.2rem; margin: 0; }}
|
| 287 |
+
.header-text p {{ color: #d6eaf8; margin: 0.3rem 0 0; font-size: 1rem; }}
|
| 288 |
.metric-card {{
|
| 289 |
+
background: {card_bg}; border: 1px solid {card_border};
|
| 290 |
border-radius: 10px; padding: 1rem; text-align: center;
|
| 291 |
}}
|
| 292 |
+
.metric-val {{ font-size: 2rem; font-weight: 700; color: {val_color}; }}
|
| 293 |
+
.metric-lab {{ font-size: 0.8rem; color: {lab_color}; margin-top: 0.2rem; }}
|
| 294 |
.ad-in {{ background:#1b4332; border:1px solid #2d6a4f; color:#40916c;
|
| 295 |
border-radius:8px; padding:0.4rem 1rem; font-weight:700; display:inline-block; }}
|
| 296 |
.ad-out {{ background:#4a1c24; border:1px solid #9b2335; color:#e74c3c;
|
|
|
|
| 304 |
</div>
|
| 305 |
<div class="header-text">
|
| 306 |
<h1>VeloBind</h1>
|
| 307 |
+
<p>Structure-free protein-ligand binding affinity prediction Β·
|
| 308 |
+
Sequence + SMILES only Β·
|
| 309 |
+
Pearson R = 0.8469 on CASF-2016 Β·
|
| 310 |
+
45-model ensemble (LGBM + CatBoost + XGBoost)</p>
|
| 311 |
</div>
|
| 312 |
</div>
|
| 313 |
""", unsafe_allow_html=True)
|
| 314 |
|
| 315 |
+
# ββ Load everything βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 316 |
fold_models, meta, target_scaler, lig_scaler = load_all_models()
|
| 317 |
tokenizer, esm_model, device = load_esm_model()
|
| 318 |
+
ad_centroid, ad_threshold = load_ad_centroid()
|
| 319 |
n_loaded = sum(len(fold_models[s][t]) for s in SEEDS for t in MODEL_TYPES)
|
| 320 |
st.success(f"β {n_loaded} fold models loaded | Device: {device.upper()}")
|
| 321 |
|
| 322 |
# ββ Mode selector βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 323 |
mode = st.radio(
|
| 324 |
"Select mode",
|
| 325 |
+
["Single query", "Batch screening (CSV)", "One compound vs. multiple targets"],
|
|
|
|
|
|
|
| 326 |
horizontal=True,
|
| 327 |
)
|
| 328 |
st.markdown("---")
|
|
|
|
| 331 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 332 |
# MODE 1 β Single query
|
| 333 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 334 |
+
if mode == "Single query":
|
| 335 |
|
| 336 |
col_p, col_l = st.columns(2)
|
| 337 |
with col_p:
|
| 338 |
st.subheader("Protein")
|
| 339 |
+
seq_raw = st.text_area(
|
| 340 |
+
"Amino acid sequence (single-letter FASTA, no header)",
|
| 341 |
+
height=150,
|
| 342 |
+
placeholder="MKTAYIAKQRQISFVK...",
|
| 343 |
+
help="Only standard amino acid letters accepted (A C D E F G H I K L M N P Q R S T V W Y)."
|
| 344 |
+
)
|
| 345 |
with col_l:
|
| 346 |
st.subheader("Ligand")
|
| 347 |
smi = st.text_input("SMILES", placeholder="CC(=O)Oc1ccccc1C(=O)O")
|
|
|
|
| 355 |
if chosen != "β":
|
| 356 |
smi = examples[chosen]
|
| 357 |
|
| 358 |
+
if st.button("Predict", type="primary", use_container_width=True):
|
| 359 |
+
seq, err = validate_sequence(seq_raw)
|
| 360 |
+
if err:
|
| 361 |
+
st.error(err)
|
| 362 |
+
elif not smi.strip():
|
| 363 |
+
st.error("Please enter a SMILES string.")
|
| 364 |
else:
|
| 365 |
+
with st.spinner("Running inference..."):
|
| 366 |
t0 = time.time()
|
| 367 |
try:
|
| 368 |
X, valid, esm_vec = extract_features(
|
| 369 |
+
seq, [smi.strip()],
|
| 370 |
tokenizer, esm_model, device, lig_scaler
|
| 371 |
)
|
| 372 |
if not valid.any():
|
|
|
|
| 390 |
<div class="metric-lab">95% model interval (Β±1.96Ο, 45 models)</div>
|
| 391 |
</div>""", unsafe_allow_html=True)
|
| 392 |
with c3:
|
|
|
|
| 393 |
st.markdown(f"""<div class="metric-card">
|
| 394 |
+
<div class="metric-val">{format_ki(pkd)}</div>
|
| 395 |
+
<div class="metric-lab">Estimated Ki (pKd β pKi assumed)</div>
|
| 396 |
</div>""", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
| 397 |
with c4:
|
| 398 |
+
ad_label, _ = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
|
| 399 |
+
ad_cls = ("ad-in" if ad_label == "IN DOMAIN" else
|
| 400 |
+
"ad-out" if ad_label == "OUT OF DOMAIN" else "ad-unk")
|
| 401 |
st.markdown(f"""<div class="metric-card">
|
| 402 |
<div class="{ad_cls}">{ad_label}</div>
|
| 403 |
<div class="metric-lab">Applicability domain</div>
|
|
|
|
| 405 |
|
| 406 |
if ad_label == "OUT OF DOMAIN":
|
| 407 |
st.warning("Protein is outside the training distribution. "
|
| 408 |
+
"Predictions may be unreliable.")
|
| 409 |
|
| 410 |
st.caption(
|
| 411 |
f"Inference time: {elapsed:.2f}s | "
|
| 412 |
+
f"45-model ensemble (3 seeds x 3 types x 5 folds) | "
|
| 413 |
f"Device: {device.upper()}"
|
| 414 |
)
|
| 415 |
|
| 416 |
with st.expander("Per-model breakdown"):
|
| 417 |
labels = [f"s{s}_{t}" for s in SEEDS for t in MODEL_TYPES]
|
| 418 |
fig = bar_chart(
|
| 419 |
+
labels, preds_all[0],
|
|
|
|
| 420 |
preds_all[0] - preds_all[0].std(),
|
| 421 |
preds_all[0] + preds_all[0].std(),
|
| 422 |
+
"Seed x type predictions (fold-averaged)",
|
| 423 |
+
dark=dark_mode,
|
| 424 |
)
|
| 425 |
st.pyplot(fig, use_container_width=True)
|
| 426 |
plt.close(fig)
|
|
|
|
| 433 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 434 |
# MODE 2 β Batch screening
|
| 435 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 436 |
+
elif mode == "Batch screening (CSV)":
|
| 437 |
|
| 438 |
st.subheader("Batch Screening")
|
| 439 |
st.markdown("One protein, many compounds. Upload a CSV with a `smiles` column "
|
|
|
|
| 441 |
|
| 442 |
col_seq, col_csv = st.columns(2)
|
| 443 |
with col_seq:
|
| 444 |
+
batch_seq_raw = st.text_area("Target protein sequence", height=180,
|
| 445 |
+
placeholder="Paste UniProt sequence...")
|
| 446 |
with col_csv:
|
| 447 |
uploaded = st.file_uploader("Compound CSV (smiles, name)", type=["csv"])
|
| 448 |
st.code("smiles,name\nCC(=O)Oc1ccccc1C(=O)O,Aspirin", language="csv")
|
|
|
|
| 450 |
max_cpds = st.slider("Max compounds", 10, 500, 100,
|
| 451 |
help="~1s per compound on CPU free tier.")
|
| 452 |
|
| 453 |
+
if st.button("Run batch screening", type="primary", use_container_width=True):
|
| 454 |
+
batch_seq, err = validate_sequence(batch_seq_raw)
|
| 455 |
+
if err:
|
| 456 |
+
st.error(err)
|
| 457 |
elif uploaded is None:
|
| 458 |
st.error("Please upload a CSV file.")
|
| 459 |
else:
|
|
|
|
| 467 |
names_list = (df_in['name'].tolist() if 'name' in df_in.columns
|
| 468 |
else [f"cpd_{i}" for i in range(len(df_in))])
|
| 469 |
|
| 470 |
+
with st.spinner(f"Screening {len(smiles_list)} compounds..."):
|
|
|
|
| 471 |
t0 = time.time()
|
| 472 |
X, valid, esm_vec = extract_features(
|
| 473 |
+
batch_seq, smiles_list,
|
| 474 |
tokenizer, esm_model, device, lig_scaler
|
| 475 |
)
|
| 476 |
+
ad_label, _ = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 477 |
preds, preds_all = predict(X, fold_models, meta, target_scaler)
|
| 478 |
+
lo, hi = uncertainty_interval(preds_all)
|
| 479 |
+
elapsed = time.time() - t0
|
| 480 |
|
| 481 |
valid_names = [names_list[i] for i in range(len(names_list)) if valid[i]]
|
| 482 |
valid_smiles = [smiles_list[i] for i in range(len(smiles_list)) if valid[i]]
|
|
|
|
| 488 |
'pKd_pred': np.round(preds, 3),
|
| 489 |
'CI_lo': np.round(lo, 3),
|
| 490 |
'CI_hi': np.round(hi, 3),
|
| 491 |
+
'Ki_est': [format_ki(p) for p in preds],
|
| 492 |
'model_std': np.round(preds_all.std(axis=1), 3),
|
| 493 |
+
'AD': [ad_label] * len(valid_names),
|
| 494 |
}).sort_values('pKd_pred', ascending=False).reset_index(drop=True)
|
| 495 |
results_df.insert(0, 'rank', range(1, len(results_df) + 1))
|
| 496 |
|
| 497 |
+
if ad_label == "OUT OF DOMAIN":
|
| 498 |
+
st.warning("Protein is outside the training distribution. "
|
| 499 |
+
"Predictions may be unreliable.")
|
| 500 |
+
|
| 501 |
st.success(
|
| 502 |
f"β {len(results_df)} compounds in {elapsed:.1f}s "
|
| 503 |
f"({elapsed / max(len(results_df), 1):.2f}s/compound)"
|
|
|
|
| 511 |
top_df['pKd_pred'].values,
|
| 512 |
top_df['CI_lo'].values,
|
| 513 |
top_df['CI_hi'].values,
|
| 514 |
+
f"Top {top_n} hits",
|
| 515 |
+
dark=dark_mode,
|
| 516 |
)
|
| 517 |
st.pyplot(fig, use_container_width=True)
|
| 518 |
plt.close(fig)
|
|
|
|
| 522 |
use_container_width=True, height=400,
|
| 523 |
)
|
| 524 |
st.download_button(
|
| 525 |
+
"Download ranked CSV",
|
| 526 |
results_df.to_csv(index=False).encode(),
|
| 527 |
file_name="velobind_screening.csv",
|
| 528 |
mime="text/csv",
|
|
|
|
| 532 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 533 |
# MODE 3 β One compound vs. multiple targets
|
| 534 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 535 |
+
elif mode == "One compound vs. multiple targets":
|
| 536 |
|
| 537 |
st.subheader("Selectivity Profiling")
|
| 538 |
st.markdown("One SMILES, multiple proteins β ranked by predicted pKd. "
|
|
|
|
| 543 |
multi_seqs = st.text_area(
|
| 544 |
"Target proteins (one per line)",
|
| 545 |
height=250,
|
| 546 |
+
placeholder="ABL1: MGPSENDPNLFVALY...\nEGFR: MRPSGTAGAALLALL...\nCDK2: MENFQKVEKIGEGTY...",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
)
|
| 548 |
|
| 549 |
+
if st.button("Run selectivity profiling", type="primary", use_container_width=True):
|
| 550 |
if not multi_smi.strip() or not multi_seqs.strip():
|
| 551 |
st.error("Please enter a SMILES and at least one protein sequence.")
|
| 552 |
else:
|
| 553 |
targets = {}
|
| 554 |
+
parse_errors = []
|
| 555 |
for i, line in enumerate(multi_seqs.strip().splitlines()):
|
| 556 |
line = line.strip()
|
| 557 |
if not line:
|
| 558 |
continue
|
| 559 |
if ":" in line:
|
| 560 |
+
name, raw_seq = line.split(":", 1)
|
| 561 |
+
name = name.strip()
|
| 562 |
+
else:
|
| 563 |
+
name, raw_seq = f"Target_{i+1}", line
|
| 564 |
+
seq, err = validate_sequence(raw_seq)
|
| 565 |
+
if err:
|
| 566 |
+
parse_errors.append(f"{name}: {err}")
|
| 567 |
else:
|
| 568 |
+
targets[name] = seq
|
| 569 |
|
| 570 |
+
if parse_errors:
|
| 571 |
+
for e in parse_errors:
|
| 572 |
+
st.warning(f"Skipped β {e}")
|
| 573 |
if not targets:
|
| 574 |
+
st.error("No valid sequences found.")
|
| 575 |
st.stop()
|
| 576 |
|
|
|
|
| 577 |
results, progress = [], st.progress(0)
|
| 578 |
for idx, (name, seq) in enumerate(targets.items()):
|
| 579 |
try:
|
|
|
|
| 584 |
if valid.any():
|
| 585 |
preds, preds_all = predict(X, fold_models, meta, target_scaler)
|
| 586 |
lo, hi = uncertainty_interval(preds_all)
|
| 587 |
+
ad_label, _ = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
|
| 588 |
results.append({
|
| 589 |
'Target': name,
|
| 590 |
'pKd_pred': round(float(preds[0]), 3),
|
| 591 |
'CI_lo': round(float(lo[0]), 3),
|
| 592 |
'CI_hi': round(float(hi[0]), 3),
|
| 593 |
+
'Ki_est': format_ki(float(preds[0])),
|
| 594 |
'model_std': round(float(preds_all.std()), 3),
|
| 595 |
+
'AD': ad_label,
|
| 596 |
})
|
| 597 |
except Exception as e:
|
| 598 |
st.warning(f"Skipped {name}: {e}")
|
|
|
|
| 606 |
)
|
| 607 |
res_df.insert(0, 'rank', range(1, len(res_df) + 1))
|
| 608 |
|
| 609 |
+
st.success(f"Profiled {len(res_df)} targets.")
|
| 610 |
fig = bar_chart(
|
| 611 |
res_df['Target'].tolist(),
|
| 612 |
res_df['pKd_pred'].values,
|
| 613 |
res_df['CI_lo'].values,
|
| 614 |
res_df['CI_hi'].values,
|
| 615 |
+
"Selectivity profile β predicted pKd by target",
|
| 616 |
+
dark=dark_mode,
|
| 617 |
)
|
| 618 |
st.pyplot(fig, use_container_width=True)
|
| 619 |
plt.close(fig)
|
| 620 |
|
| 621 |
st.dataframe(res_df, use_container_width=True)
|
| 622 |
st.download_button(
|
| 623 |
+
"Download selectivity CSV",
|
| 624 |
res_df.to_csv(index=False).encode(),
|
| 625 |
file_name="velobind_selectivity.csv",
|
| 626 |
mime="text/csv",
|
| 627 |
)
|
| 628 |
|
| 629 |
+
|
| 630 |
# ββ Footer ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 631 |
st.markdown("---")
|
| 632 |
+
st.markdown(f"""
|
| 633 |
+
<div style="color:{lab_color};font-size:0.8rem;text-align:center;padding:0.5rem">
|
| 634 |
+
VeloBind Β· Structure-free binding affinity Β·
|
| 635 |
+
ESM-2 + GBM ensemble Β·
|
| 636 |
+
Trained on LP-PDBBind Β·
|
| 637 |
+
Evaluated on CASF-2016/2013 Β·
|
| 638 |
+
<b>Not for clinical use.</b>
|
| 639 |
</div>
|
| 640 |
""", unsafe_allow_html=True)
|