Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
# app.py β VeloBind HF Spaces inference app
|
| 2 |
-
import
|
| 3 |
import numpy as np
|
| 4 |
import pandas as pd
|
| 5 |
import streamlit as st
|
|
@@ -18,7 +18,7 @@ MODEL_CACHE = Path("/tmp/velobind_models")
|
|
| 18 |
SEEDS = [42, 123, 456]
|
| 19 |
MODEL_TYPES = ["lgbm", "cb", "xgb"]
|
| 20 |
N_FOLDS = 5
|
| 21 |
-
VALID_AA = set("
|
| 22 |
|
| 23 |
import sys
|
| 24 |
sys.path.append(str(Path(__file__).parent))
|
|
@@ -28,6 +28,13 @@ from src.models.ensemble import TargetScaler
|
|
| 28 |
from src.config import config
|
| 29 |
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 32 |
# Validation
|
| 33 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -35,12 +42,8 @@ def validate_sequence(raw: str):
|
|
| 35 |
raw = raw.strip()
|
| 36 |
if not raw:
|
| 37 |
return None, "Please enter a sequence."
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
lines = raw.splitlines()
|
| 41 |
-
seq_lines = [l.strip() for l in lines if not l.startswith(">")]
|
| 42 |
-
seq = "".join(seq_lines).upper().replace(" ", "")
|
| 43 |
-
|
| 44 |
if len(seq) < 10:
|
| 45 |
return None, "Sequence too short (minimum 10 residues)."
|
| 46 |
invalid = set(seq) - VALID_AA
|
|
@@ -52,21 +55,18 @@ def validate_sequence(raw: str):
|
|
| 52 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 53 |
# Model loading
|
| 54 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 55 |
-
@st.cache_resource(show_spinner="
|
| 56 |
def load_all_models():
|
| 57 |
from huggingface_hub import hf_hub_download
|
| 58 |
MODEL_CACHE.mkdir(parents=True, exist_ok=True)
|
| 59 |
-
|
| 60 |
model_files = (
|
| 61 |
[f"fold_model_s{s}_{t}_f{f}.pkl"
|
| 62 |
for s in SEEDS for t in MODEL_TYPES for f in range(N_FOLDS)]
|
| 63 |
+ ["meta_type_casf16.pkl", "target_scaler.pkl", "ligand_scaler.pkl"]
|
| 64 |
)
|
| 65 |
-
|
| 66 |
bar = st.progress(0, text="Loading models...")
|
| 67 |
for i, fname in enumerate(model_files):
|
| 68 |
-
|
| 69 |
-
if not local.exists():
|
| 70 |
hf_hub_download(repo_id=HF_MODEL_REPO, filename=fname,
|
| 71 |
local_dir=str(MODEL_CACHE))
|
| 72 |
bar.progress((i + 1) / len(model_files), text=f"Loading {fname}...")
|
|
@@ -80,14 +80,13 @@ def load_all_models():
|
|
| 80 |
joblib.load(MODEL_CACHE / f"fold_model_s{s}_{t}_f{f}.pkl")
|
| 81 |
for f in range(N_FOLDS)
|
| 82 |
]
|
| 83 |
-
|
| 84 |
meta = joblib.load(MODEL_CACHE / "meta_type_casf16.pkl")
|
| 85 |
scaler = joblib.load(MODEL_CACHE / "target_scaler.pkl")
|
| 86 |
lig_sc = joblib.load(MODEL_CACHE / "ligand_scaler.pkl")
|
| 87 |
return fold_models, meta, scaler, lig_sc
|
| 88 |
|
| 89 |
|
| 90 |
-
@st.cache_resource(show_spinner="Loading ESM-2
|
| 91 |
def load_esm_model():
|
| 92 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 93 |
tokenizer, esm_model = load_esm(config.ESM_MODEL, device)
|
|
@@ -98,8 +97,7 @@ def load_esm_model():
|
|
| 98 |
def load_ad_centroid():
|
| 99 |
for p in [Path("output/models/deployment"), Path("output/models")]:
|
| 100 |
if (p / "ad_centroid.npy").exists():
|
| 101 |
-
return
|
| 102 |
-
float(np.load(p / "ad_threshold.npy")))
|
| 103 |
for fname in ["ad_centroid.npy", "ad_threshold.npy"]:
|
| 104 |
local = MODEL_CACHE / fname
|
| 105 |
if not local.exists():
|
|
@@ -109,33 +107,25 @@ def load_ad_centroid():
|
|
| 109 |
local_dir=str(MODEL_CACHE))
|
| 110 |
except Exception:
|
| 111 |
return None, None
|
| 112 |
-
return
|
| 113 |
-
float(np.load(MODEL_CACHE / "ad_threshold.npy")))
|
| 114 |
|
| 115 |
|
| 116 |
-
def ad_check(
|
| 117 |
if centroid is None:
|
| 118 |
return "UNKNOWN", float("nan")
|
| 119 |
-
dist = float(np.linalg.norm(
|
| 120 |
return ("IN DOMAIN" if dist <= threshold else "OUT OF DOMAIN"), dist
|
| 121 |
|
| 122 |
|
| 123 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 124 |
-
#
|
| 125 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 126 |
def assemble_from_parts(esm_mean, esm_var, esm_attn, seq_feat, lig_feats):
|
| 127 |
return np.concatenate([
|
| 128 |
-
esm_mean[:, -480:],
|
| 129 |
-
|
| 130 |
-
lig_feats["
|
| 131 |
-
lig_feats["
|
| 132 |
-
lig_feats["ecfp6"],
|
| 133 |
-
lig_feats["fcfp"],
|
| 134 |
-
lig_feats["estate"],
|
| 135 |
-
lig_feats["maccs"],
|
| 136 |
-
lig_feats["atom_pair"],
|
| 137 |
-
lig_feats["torsion"],
|
| 138 |
-
lig_feats["phys"],
|
| 139 |
], axis=1)
|
| 140 |
|
| 141 |
|
|
@@ -145,45 +135,34 @@ def extract_features(sequence, smiles_list, tokenizer, esm_model, device, lig_sc
|
|
| 145 |
config.ESM_LAYERS, config.MAX_SEQ_LEN, config.HALF_SEQ_LEN,
|
| 146 |
batch_size=1, device=device,
|
| 147 |
)
|
| 148 |
-
seq_feat
|
| 149 |
-
|
| 150 |
lig_feats, valid_mask, _ = extract_ligand_features(
|
| 151 |
-
smiles_list, scaler=lig_scaler, fit_scaler=False
|
| 152 |
-
)
|
| 153 |
valid_mask = np.array(valid_mask)
|
| 154 |
if valid_mask.dtype != bool:
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
valid_mask =
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
esm_attn_t = np.tile(esm_attn, (n_valid, 1))
|
| 163 |
-
seq_feat_t = np.tile(seq_feat, (n_valid, 1))
|
| 164 |
-
|
| 165 |
-
X = assemble_from_parts(esm_mean_t, esm_var_t, esm_attn_t, seq_feat_t, lig_feats)
|
| 166 |
return X, valid_mask, esm_mean[0]
|
| 167 |
|
| 168 |
|
| 169 |
-
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 170 |
-
# Prediction
|
| 171 |
-
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 172 |
def predict(X, fold_models, meta, scaler):
|
| 173 |
type_avgs = []
|
| 174 |
for s in SEEDS:
|
| 175 |
for t in MODEL_TYPES:
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
], axis=1)
|
| 180 |
-
type_avgs.append(fold_preds.mean(axis=1))
|
| 181 |
-
|
| 182 |
preds_all = np.stack(type_avgs, axis=1)
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
|
|
|
| 187 |
return preds, preds_all
|
| 188 |
|
| 189 |
|
|
@@ -193,73 +172,76 @@ def uncertainty_interval(preds_all, z=1.96):
|
|
| 193 |
|
| 194 |
|
| 195 |
def format_ki(pkd):
|
| 196 |
-
"""Format Ki with appropriate unit (nM, uM, mM)."""
|
| 197 |
ki_nM = 10 ** (9 - pkd)
|
| 198 |
-
if ki_nM < 1000:
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
return f"{ki_nM/1000:.2f} uM"
|
| 202 |
-
else:
|
| 203 |
-
return f"{ki_nM/1_000_000:.2f} mM"
|
| 204 |
|
| 205 |
|
| 206 |
# ββββββββββββββββββββββββββββββββββοΏ½οΏ½οΏ½βββββββββββββββββββββββββββββββββββ
|
| 207 |
-
#
|
| 208 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 209 |
def bar_chart(names, preds, lo, hi, title, dark=True):
|
| 210 |
-
bg
|
| 211 |
-
|
| 212 |
-
grid = "#2d3f55" if dark else "#cccccc"
|
| 213 |
-
|
| 214 |
-
fig, ax = plt.subplots(figsize=(max(6, len(names) * 0.9), 4),
|
| 215 |
-
facecolor=bg)
|
| 216 |
ax.set_facecolor(bg)
|
| 217 |
x = np.arange(len(names))
|
| 218 |
err = [preds - lo, hi - preds]
|
| 219 |
-
bars = ax.bar(x, preds, color="#
|
| 220 |
yerr=err, capsize=5, error_kw=dict(ecolor=fg, lw=1.5))
|
| 221 |
ax.set_xticks(x)
|
| 222 |
ax.set_xticklabels(names, rotation=30, ha='right', fontsize=10, color=fg)
|
| 223 |
ax.set_ylabel("Predicted pKd", fontsize=11, color=fg)
|
| 224 |
ax.set_title(title, fontsize=12, fontweight='bold', color=fg)
|
| 225 |
ax.tick_params(colors=fg)
|
| 226 |
-
ax.spines
|
| 227 |
-
|
|
|
|
| 228 |
for bar, val in zip(bars, preds):
|
| 229 |
ax.text(bar.get_x() + bar.get_width() / 2,
|
| 230 |
bar.get_height() + 0.05, f"{val:.2f}",
|
| 231 |
-
ha='center', va='bottom', fontsize=9,
|
| 232 |
-
fontweight='bold', color=fg)
|
| 233 |
plt.tight_layout()
|
| 234 |
return fig
|
| 235 |
|
| 236 |
|
| 237 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 238 |
-
# Page
|
| 239 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 240 |
st.set_page_config(page_title="VeloBind", layout="wide")
|
| 241 |
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
else:
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
|
| 264 |
def load_svg_b64(path):
|
| 265 |
with open(path, "rb") as f:
|
|
@@ -269,93 +251,158 @@ logo_b64 = load_svg_b64("logo.svg")
|
|
| 269 |
|
| 270 |
st.markdown(f"""
|
| 271 |
<style>
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
}
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
</style>
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
</div>
|
| 305 |
-
<div class="
|
| 306 |
<h1>VeloBind</h1>
|
| 307 |
-
<p>
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
|
|
|
|
|
|
| 311 |
</div>
|
| 312 |
</div>
|
| 313 |
""", unsafe_allow_html=True)
|
| 314 |
|
| 315 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
fold_models, meta, target_scaler, lig_scaler = load_all_models()
|
| 317 |
tokenizer, esm_model, device = load_esm_model()
|
| 318 |
ad_centroid, ad_threshold = load_ad_centroid()
|
| 319 |
n_loaded = sum(len(fold_models[s][t]) for s in SEEDS for t in MODEL_TYPES)
|
| 320 |
-
st.success(f"
|
| 321 |
-
|
| 322 |
-
# ββ Mode
|
| 323 |
-
|
| 324 |
-
"
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
)
|
| 328 |
-
st.markdown("---")
|
| 329 |
|
| 330 |
|
| 331 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 332 |
-
#
|
| 333 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 334 |
-
|
| 335 |
-
|
| 336 |
col_p, col_l = st.columns(2)
|
| 337 |
with col_p:
|
| 338 |
st.subheader("Protein")
|
| 339 |
seq_raw = st.text_area(
|
| 340 |
-
"Amino acid sequence (
|
| 341 |
-
height=
|
| 342 |
-
placeholder="
|
| 343 |
-
help="Only standard amino acid letters
|
|
|
|
| 344 |
)
|
| 345 |
with col_l:
|
| 346 |
st.subheader("Ligand")
|
| 347 |
-
smi = st.text_input("SMILES", placeholder="CC(=O)Oc1ccccc1C(=O)O")
|
| 348 |
examples = {
|
| 349 |
"Aspirin": "CC(=O)Oc1ccccc1C(=O)O",
|
| 350 |
"Imatinib": "Cc1ccc(NC(=O)c2ccc(CN3CCN(C)CC3)cc2)cc1Nc1nccc(-c2cccnc2)n1",
|
| 351 |
"Gefitinib": "COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCCN1CCOCC1",
|
| 352 |
"Staurosporine": "C[C@@H]1CCCN2C(=O)c3[nH]c4ccccc4c3C2=N1",
|
| 353 |
}
|
| 354 |
-
chosen = st.selectbox("Load example SMILES", ["β"] + list(examples))
|
| 355 |
if chosen != "β":
|
| 356 |
smi = examples[chosen]
|
| 357 |
|
| 358 |
-
if st.button("Predict", type="primary", use_container_width=True):
|
| 359 |
seq, err = validate_sequence(seq_raw)
|
| 360 |
if err:
|
| 361 |
st.error(err)
|
|
@@ -366,18 +413,17 @@ if mode == "Single query":
|
|
| 366 |
t0 = time.time()
|
| 367 |
try:
|
| 368 |
X, valid, esm_vec = extract_features(
|
| 369 |
-
seq, [smi.strip()],
|
| 370 |
-
tokenizer, esm_model, device, lig_scaler
|
| 371 |
-
)
|
| 372 |
if not valid.any():
|
| 373 |
-
st.error("RDKit could not parse this SMILES.
|
| 374 |
else:
|
| 375 |
preds, preds_all = predict(X, fold_models, meta, target_scaler)
|
| 376 |
lo, hi = uncertainty_interval(preds_all)
|
| 377 |
elapsed = time.time() - t0
|
| 378 |
pkd = float(preds[0])
|
|
|
|
| 379 |
|
| 380 |
-
st.markdown("### Results")
|
| 381 |
c1, c2, c3, c4 = st.columns(4)
|
| 382 |
with c1:
|
| 383 |
st.markdown(f"""<div class="metric-card">
|
|
@@ -387,25 +433,25 @@ if mode == "Single query":
|
|
| 387 |
with c2:
|
| 388 |
st.markdown(f"""<div class="metric-card">
|
| 389 |
<div class="metric-val">[{lo[0]:.2f}, {hi[0]:.2f}]</div>
|
| 390 |
-
<div class="metric-lab">95% model interval
|
| 391 |
</div>""", unsafe_allow_html=True)
|
| 392 |
with c3:
|
| 393 |
st.markdown(f"""<div class="metric-card">
|
| 394 |
<div class="metric-val">{format_ki(pkd)}</div>
|
| 395 |
-
<div class="metric-lab">Estimated Ki
|
| 396 |
</div>""", unsafe_allow_html=True)
|
| 397 |
with c4:
|
| 398 |
-
ad_label, _ = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
|
| 399 |
ad_cls = ("ad-in" if ad_label == "IN DOMAIN" else
|
| 400 |
"ad-out" if ad_label == "OUT OF DOMAIN" else "ad-unk")
|
| 401 |
st.markdown(f"""<div class="metric-card">
|
| 402 |
-
<div
|
| 403 |
-
|
|
|
|
|
|
|
| 404 |
</div>""", unsafe_allow_html=True)
|
| 405 |
|
| 406 |
if ad_label == "OUT OF DOMAIN":
|
| 407 |
-
st.warning("Protein is outside the training distribution. "
|
| 408 |
-
"Predictions may be unreliable.")
|
| 409 |
|
| 410 |
st.caption(
|
| 411 |
f"Inference time: {elapsed:.2f}s | "
|
|
@@ -419,8 +465,8 @@ if mode == "Single query":
|
|
| 419 |
labels, preds_all[0],
|
| 420 |
preds_all[0] - preds_all[0].std(),
|
| 421 |
preds_all[0] + preds_all[0].std(),
|
| 422 |
-
"
|
| 423 |
-
dark=
|
| 424 |
)
|
| 425 |
st.pyplot(fig, use_container_width=True)
|
| 426 |
plt.close(fig)
|
|
@@ -431,26 +477,26 @@ if mode == "Single query":
|
|
| 431 |
|
| 432 |
|
| 433 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 434 |
-
#
|
| 435 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 436 |
-
|
| 437 |
-
|
| 438 |
st.subheader("Batch Screening")
|
| 439 |
-
st.markdown("
|
| 440 |
-
"(and optionally `name`).
|
|
|
|
| 441 |
|
| 442 |
col_seq, col_csv = st.columns(2)
|
| 443 |
with col_seq:
|
| 444 |
-
batch_seq_raw = st.text_area("Target protein sequence", height=180,
|
| 445 |
-
placeholder="
|
| 446 |
with col_csv:
|
| 447 |
-
uploaded = st.file_uploader("Compound CSV (smiles, name)", type=["csv"])
|
| 448 |
st.code("smiles,name\nCC(=O)Oc1ccccc1C(=O)O,Aspirin", language="csv")
|
| 449 |
|
| 450 |
-
max_cpds = st.slider("Max compounds", 10, 500, 100,
|
| 451 |
help="~1s per compound on CPU free tier.")
|
| 452 |
|
| 453 |
-
if st.button("Run batch screening", type="primary", use_container_width=True):
|
| 454 |
batch_seq, err = validate_sequence(batch_seq_raw)
|
| 455 |
if err:
|
| 456 |
st.error(err)
|
|
@@ -470,9 +516,7 @@ elif mode == "Batch screening (CSV)":
|
|
| 470 |
with st.spinner(f"Screening {len(smiles_list)} compounds..."):
|
| 471 |
t0 = time.time()
|
| 472 |
X, valid, esm_vec = extract_features(
|
| 473 |
-
batch_seq, smiles_list,
|
| 474 |
-
tokenizer, esm_model, device, lig_scaler
|
| 475 |
-
)
|
| 476 |
ad_label, _ = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
|
| 477 |
preds, preds_all = predict(X, fold_models, meta, target_scaler)
|
| 478 |
lo, hi = uncertainty_interval(preds_all)
|
|
@@ -482,6 +526,9 @@ elif mode == "Batch screening (CSV)":
|
|
| 482 |
valid_smiles = [smiles_list[i] for i in range(len(smiles_list)) if valid[i]]
|
| 483 |
n_invalid = int((~valid).sum())
|
| 484 |
|
|
|
|
|
|
|
|
|
|
| 485 |
results_df = pd.DataFrame({
|
| 486 |
'name': valid_names,
|
| 487 |
'smiles': valid_smiles,
|
|
@@ -494,25 +541,19 @@ elif mode == "Batch screening (CSV)":
|
|
| 494 |
}).sort_values('pKd_pred', ascending=False).reset_index(drop=True)
|
| 495 |
results_df.insert(0, 'rank', range(1, len(results_df) + 1))
|
| 496 |
|
| 497 |
-
if ad_label == "OUT OF DOMAIN":
|
| 498 |
-
st.warning("Protein is outside the training distribution. "
|
| 499 |
-
"Predictions may be unreliable.")
|
| 500 |
-
|
| 501 |
st.success(
|
| 502 |
-
f"
|
| 503 |
f"({elapsed / max(len(results_df), 1):.2f}s/compound)"
|
| 504 |
+ (f" | {n_invalid} invalid SMILES skipped" if n_invalid else "")
|
| 505 |
)
|
| 506 |
|
| 507 |
top_n = min(20, len(results_df))
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
f"Top {top_n} hits",
|
| 515 |
-
dark=dark_mode,
|
| 516 |
)
|
| 517 |
st.pyplot(fig, use_container_width=True)
|
| 518 |
plt.close(fig)
|
|
@@ -524,52 +565,46 @@ elif mode == "Batch screening (CSV)":
|
|
| 524 |
st.download_button(
|
| 525 |
"Download ranked CSV",
|
| 526 |
results_df.to_csv(index=False).encode(),
|
| 527 |
-
file_name="velobind_screening.csv",
|
| 528 |
-
mime="text/csv",
|
| 529 |
)
|
| 530 |
|
| 531 |
|
| 532 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 533 |
-
#
|
| 534 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 535 |
-
|
| 536 |
-
|
| 537 |
st.subheader("Selectivity Profiling")
|
| 538 |
-
st.markdown("One
|
| 539 |
-
"Format: `TargetName: SEQUENCE` (name optional)."
|
|
|
|
| 540 |
|
| 541 |
-
multi_smi = st.text_input("Compound SMILES",
|
| 542 |
-
placeholder="Cc1ccc(NC(=O)...)cc1Nc1nccc(...)n1")
|
| 543 |
multi_seqs = st.text_area(
|
| 544 |
"Target proteins (one per line)",
|
| 545 |
height=250,
|
| 546 |
placeholder="ABL1: MGPSENDPNLFVALY...\nEGFR: MRPSGTAGAALLALL...\nCDK2: MENFQKVEKIGEGTY...",
|
|
|
|
| 547 |
)
|
| 548 |
|
| 549 |
-
if st.button("Run selectivity profiling", type="primary", use_container_width=True):
|
| 550 |
if not multi_smi.strip() or not multi_seqs.strip():
|
| 551 |
st.error("Please enter a SMILES and at least one protein sequence.")
|
| 552 |
else:
|
| 553 |
-
targets = {}
|
| 554 |
-
parse_errors = []
|
| 555 |
for i, line in enumerate(multi_seqs.strip().splitlines()):
|
| 556 |
line = line.strip()
|
| 557 |
if not line:
|
| 558 |
continue
|
| 559 |
-
if ":" in line
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
else:
|
| 563 |
-
name, raw_seq = f"Target_{i+1}", line
|
| 564 |
-
seq, err = validate_sequence(raw_seq)
|
| 565 |
if err:
|
| 566 |
-
parse_errors.append(f"{name}: {err}")
|
| 567 |
else:
|
| 568 |
-
targets[name] = seq
|
| 569 |
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
st.warning(f"Skipped β {e}")
|
| 573 |
if not targets:
|
| 574 |
st.error("No valid sequences found.")
|
| 575 |
st.stop()
|
|
@@ -578,18 +613,16 @@ elif mode == "One compound vs. multiple targets":
|
|
| 578 |
for idx, (name, seq) in enumerate(targets.items()):
|
| 579 |
try:
|
| 580 |
X, valid, esm_vec = extract_features(
|
| 581 |
-
seq, [multi_smi.strip()],
|
| 582 |
-
tokenizer, esm_model, device, lig_scaler
|
| 583 |
-
)
|
| 584 |
if valid.any():
|
| 585 |
preds, preds_all = predict(X, fold_models, meta, target_scaler)
|
| 586 |
lo, hi = uncertainty_interval(preds_all)
|
| 587 |
ad_label, _ = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
|
| 588 |
results.append({
|
| 589 |
'Target': name,
|
| 590 |
-
'pKd_pred': round(float(preds[0]),
|
| 591 |
-
'CI_lo': round(float(lo[0]),
|
| 592 |
-
'CI_hi': round(float(hi[0]),
|
| 593 |
'Ki_est': format_ki(float(preds[0])),
|
| 594 |
'model_std': round(float(preds_all.std()), 3),
|
| 595 |
'AD': ad_label,
|
|
@@ -599,21 +632,16 @@ elif mode == "One compound vs. multiple targets":
|
|
| 599 |
progress.progress((idx + 1) / len(targets))
|
| 600 |
|
| 601 |
progress.empty()
|
| 602 |
-
res_df = (
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
.reset_index(drop=True)
|
| 606 |
-
)
|
| 607 |
res_df.insert(0, 'rank', range(1, len(res_df) + 1))
|
| 608 |
|
| 609 |
st.success(f"Profiled {len(res_df)} targets.")
|
| 610 |
fig = bar_chart(
|
| 611 |
-
res_df['Target'].tolist(),
|
| 612 |
-
res_df['
|
| 613 |
-
|
| 614 |
-
res_df['CI_hi'].values,
|
| 615 |
-
"Selectivity profile β predicted pKd by target",
|
| 616 |
-
dark=dark_mode,
|
| 617 |
)
|
| 618 |
st.pyplot(fig, use_container_width=True)
|
| 619 |
plt.close(fig)
|
|
@@ -622,19 +650,18 @@ elif mode == "One compound vs. multiple targets":
|
|
| 622 |
st.download_button(
|
| 623 |
"Download selectivity CSV",
|
| 624 |
res_df.to_csv(index=False).encode(),
|
| 625 |
-
file_name="velobind_selectivity.csv",
|
| 626 |
-
mime="text/csv",
|
| 627 |
)
|
| 628 |
|
| 629 |
|
| 630 |
# ββ Footer ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 631 |
st.markdown("---")
|
| 632 |
st.markdown(f"""
|
| 633 |
-
<div style="color:{
|
| 634 |
-
VeloBind
|
| 635 |
-
ESM-2 +
|
| 636 |
-
Trained on LP-PDBBind
|
| 637 |
-
Evaluated on CASF-2016
|
| 638 |
-
<b>Not for clinical use
|
| 639 |
</div>
|
| 640 |
""", unsafe_allow_html=True)
|
|
|
|
| 1 |
# app.py β VeloBind HF Spaces inference app
|
| 2 |
+
import warnings, time, base64
|
| 3 |
import numpy as np
|
| 4 |
import pandas as pd
|
| 5 |
import streamlit as st
|
|
|
|
| 18 |
SEEDS = [42, 123, 456]
|
| 19 |
MODEL_TYPES = ["lgbm", "cb", "xgb"]
|
| 20 |
N_FOLDS = 5
|
| 21 |
+
VALID_AA = set("ACDEFGHIKLMNPQRSTVWYX")
|
| 22 |
|
| 23 |
import sys
|
| 24 |
sys.path.append(str(Path(__file__).parent))
|
|
|
|
| 28 |
from src.config import config
|
| 29 |
|
| 30 |
|
| 31 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 32 |
+
# Session state β theme
|
| 33 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 34 |
+
if "dark_mode" not in st.session_state:
|
| 35 |
+
st.session_state.dark_mode = True
|
| 36 |
+
|
| 37 |
+
|
| 38 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 39 |
# Validation
|
| 40 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 42 |
raw = raw.strip()
|
| 43 |
if not raw:
|
| 44 |
return None, "Please enter a sequence."
|
| 45 |
+
lines = raw.splitlines()
|
| 46 |
+
seq = "".join(l.strip() for l in lines if not l.startswith(">")).upper().replace(" ", "")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
if len(seq) < 10:
|
| 48 |
return None, "Sequence too short (minimum 10 residues)."
|
| 49 |
invalid = set(seq) - VALID_AA
|
|
|
|
| 55 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 56 |
# Model loading
|
| 57 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 58 |
+
@st.cache_resource(show_spinner="Loading VeloBind models (first run ~30s)...")
|
| 59 |
def load_all_models():
|
| 60 |
from huggingface_hub import hf_hub_download
|
| 61 |
MODEL_CACHE.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 62 |
model_files = (
|
| 63 |
[f"fold_model_s{s}_{t}_f{f}.pkl"
|
| 64 |
for s in SEEDS for t in MODEL_TYPES for f in range(N_FOLDS)]
|
| 65 |
+ ["meta_type_casf16.pkl", "target_scaler.pkl", "ligand_scaler.pkl"]
|
| 66 |
)
|
|
|
|
| 67 |
bar = st.progress(0, text="Loading models...")
|
| 68 |
for i, fname in enumerate(model_files):
|
| 69 |
+
if not (MODEL_CACHE / fname).exists():
|
|
|
|
| 70 |
hf_hub_download(repo_id=HF_MODEL_REPO, filename=fname,
|
| 71 |
local_dir=str(MODEL_CACHE))
|
| 72 |
bar.progress((i + 1) / len(model_files), text=f"Loading {fname}...")
|
|
|
|
| 80 |
joblib.load(MODEL_CACHE / f"fold_model_s{s}_{t}_f{f}.pkl")
|
| 81 |
for f in range(N_FOLDS)
|
| 82 |
]
|
|
|
|
| 83 |
meta = joblib.load(MODEL_CACHE / "meta_type_casf16.pkl")
|
| 84 |
scaler = joblib.load(MODEL_CACHE / "target_scaler.pkl")
|
| 85 |
lig_sc = joblib.load(MODEL_CACHE / "ligand_scaler.pkl")
|
| 86 |
return fold_models, meta, scaler, lig_sc
|
| 87 |
|
| 88 |
|
| 89 |
+
@st.cache_resource(show_spinner="Loading ESM-2...")
|
| 90 |
def load_esm_model():
|
| 91 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 92 |
tokenizer, esm_model = load_esm(config.ESM_MODEL, device)
|
|
|
|
| 97 |
def load_ad_centroid():
|
| 98 |
for p in [Path("output/models/deployment"), Path("output/models")]:
|
| 99 |
if (p / "ad_centroid.npy").exists():
|
| 100 |
+
return np.load(p / "ad_centroid.npy"), float(np.load(p / "ad_threshold.npy"))
|
|
|
|
| 101 |
for fname in ["ad_centroid.npy", "ad_threshold.npy"]:
|
| 102 |
local = MODEL_CACHE / fname
|
| 103 |
if not local.exists():
|
|
|
|
| 107 |
local_dir=str(MODEL_CACHE))
|
| 108 |
except Exception:
|
| 109 |
return None, None
|
| 110 |
+
return np.load(MODEL_CACHE / "ad_centroid.npy"), float(np.load(MODEL_CACHE / "ad_threshold.npy"))
|
|
|
|
| 111 |
|
| 112 |
|
| 113 |
+
def ad_check(esm_vec, centroid, threshold):
|
| 114 |
if centroid is None:
|
| 115 |
return "UNKNOWN", float("nan")
|
| 116 |
+
dist = float(np.linalg.norm(esm_vec - centroid))
|
| 117 |
return ("IN DOMAIN" if dist <= threshold else "OUT OF DOMAIN"), dist
|
| 118 |
|
| 119 |
|
| 120 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 121 |
+
# Features + prediction
|
| 122 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 123 |
def assemble_from_parts(esm_mean, esm_var, esm_attn, seq_feat, lig_feats):
|
| 124 |
return np.concatenate([
|
| 125 |
+
esm_mean[:, -480:], seq_feat,
|
| 126 |
+
lig_feats["ecfp"], lig_feats["ecfp2"], lig_feats["ecfp6"], lig_feats["fcfp"],
|
| 127 |
+
lig_feats["estate"], lig_feats["maccs"], lig_feats["atom_pair"],
|
| 128 |
+
lig_feats["torsion"], lig_feats["phys"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
], axis=1)
|
| 130 |
|
| 131 |
|
|
|
|
| 135 |
config.ESM_LAYERS, config.MAX_SEQ_LEN, config.HALF_SEQ_LEN,
|
| 136 |
batch_size=1, device=device,
|
| 137 |
)
|
| 138 |
+
seq_feat = np.array([sequence_features(sequence)])
|
|
|
|
| 139 |
lig_feats, valid_mask, _ = extract_ligand_features(
|
| 140 |
+
smiles_list, scaler=lig_scaler, fit_scaler=False)
|
|
|
|
| 141 |
valid_mask = np.array(valid_mask)
|
| 142 |
if valid_mask.dtype != bool:
|
| 143 |
+
bm = np.zeros(len(smiles_list), dtype=bool)
|
| 144 |
+
bm[valid_mask] = True
|
| 145 |
+
valid_mask = bm
|
| 146 |
+
n = int(valid_mask.sum())
|
| 147 |
+
X = assemble_from_parts(
|
| 148 |
+
np.tile(esm_mean, (n, 1)), np.tile(esm_var, (n, 1)),
|
| 149 |
+
np.tile(esm_attn, (n, 1)), np.tile(seq_feat, (n, 1)), lig_feats)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
return X, valid_mask, esm_mean[0]
|
| 151 |
|
| 152 |
|
|
|
|
|
|
|
|
|
|
| 153 |
def predict(X, fold_models, meta, scaler):
|
| 154 |
type_avgs = []
|
| 155 |
for s in SEEDS:
|
| 156 |
for t in MODEL_TYPES:
|
| 157 |
+
fp = np.stack([scaler.inverse(fold_models[s][t][f].predict(X))
|
| 158 |
+
for f in range(N_FOLDS)], axis=1)
|
| 159 |
+
type_avgs.append(fp.mean(axis=1))
|
|
|
|
|
|
|
|
|
|
| 160 |
preds_all = np.stack(type_avgs, axis=1)
|
| 161 |
+
preds = meta.predict(np.column_stack([
|
| 162 |
+
preds_all[:, [0,3,6]].mean(1),
|
| 163 |
+
preds_all[:, [1,4,7]].mean(1),
|
| 164 |
+
preds_all[:, [2,5,8]].mean(1),
|
| 165 |
+
]))
|
| 166 |
return preds, preds_all
|
| 167 |
|
| 168 |
|
|
|
|
| 172 |
|
| 173 |
|
| 174 |
def format_ki(pkd):
|
|
|
|
| 175 |
ki_nM = 10 ** (9 - pkd)
|
| 176 |
+
if ki_nM < 1000: return f"{ki_nM:.1f} nM"
|
| 177 |
+
elif ki_nM < 1_000_000: return f"{ki_nM/1000:.2f} uM"
|
| 178 |
+
else: return f"{ki_nM/1_000_000:.2f} mM"
|
|
|
|
|
|
|
|
|
|
| 179 |
|
| 180 |
|
| 181 |
# ββββββββββββββββββββββββββββββββββοΏ½οΏ½οΏ½βββββββββββββββββββββββββββββββββββ
|
| 182 |
+
# Plot
|
| 183 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 184 |
def bar_chart(names, preds, lo, hi, title, dark=True):
|
| 185 |
+
bg, fg, gc = ("#1a2332", "#e8edf2", "#2d3f55") if dark else ("#f8fafc", "#1a202c", "#cbd5e0")
|
| 186 |
+
fig, ax = plt.subplots(figsize=(max(6, len(names) * 0.9), 4), facecolor=bg)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
ax.set_facecolor(bg)
|
| 188 |
x = np.arange(len(names))
|
| 189 |
err = [preds - lo, hi - preds]
|
| 190 |
+
bars = ax.bar(x, preds, color="#3b82f6", alpha=0.9, width=0.6,
|
| 191 |
yerr=err, capsize=5, error_kw=dict(ecolor=fg, lw=1.5))
|
| 192 |
ax.set_xticks(x)
|
| 193 |
ax.set_xticklabels(names, rotation=30, ha='right', fontsize=10, color=fg)
|
| 194 |
ax.set_ylabel("Predicted pKd", fontsize=11, color=fg)
|
| 195 |
ax.set_title(title, fontsize=12, fontweight='bold', color=fg)
|
| 196 |
ax.tick_params(colors=fg)
|
| 197 |
+
for spine in ax.spines.values():
|
| 198 |
+
spine.set_color(gc)
|
| 199 |
+
ax.grid(True, axis='y', alpha=0.3, color=gc)
|
| 200 |
for bar, val in zip(bars, preds):
|
| 201 |
ax.text(bar.get_x() + bar.get_width() / 2,
|
| 202 |
bar.get_height() + 0.05, f"{val:.2f}",
|
| 203 |
+
ha='center', va='bottom', fontsize=9, fontweight='bold', color=fg)
|
|
|
|
| 204 |
plt.tight_layout()
|
| 205 |
return fig
|
| 206 |
|
| 207 |
|
| 208 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 209 |
+
# Page layout
|
| 210 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 211 |
st.set_page_config(page_title="VeloBind", layout="wide")
|
| 212 |
|
| 213 |
+
dark = st.session_state.dark_mode
|
| 214 |
+
|
| 215 |
+
# ββ Theme-aware CSS (only custom elements, never .stApp) ββββββββββββββ
|
| 216 |
+
if dark:
|
| 217 |
+
card_bg, card_border = "#1a2332", "#2d4a6b"
|
| 218 |
+
val_col, lab_col = "#60a5fa", "#94a3b8"
|
| 219 |
+
banner_grad = "linear-gradient(135deg, #0f172a 0%, #1e3a5f 50%, #1e40af 100%)"
|
| 220 |
+
banner_sub = "#93c5fd"
|
| 221 |
+
logo_bg = "rgba(255,255,255,0.12)"
|
| 222 |
+
logo_border = "rgba(255,255,255,0.2)"
|
| 223 |
+
toggle_bg = "#1e3a5f"
|
| 224 |
+
toggle_knob = "#60a5fa"
|
| 225 |
+
toggle_label = "#93c5fd"
|
| 226 |
else:
|
| 227 |
+
card_bg, card_border = "#f0f7ff", "#bfdbfe"
|
| 228 |
+
val_col, lab_col = "#1d4ed8", "#475569"
|
| 229 |
+
banner_grad = "linear-gradient(135deg, #1d4ed8 0%, #2563eb 50%, #3b82f6 100%)"
|
| 230 |
+
banner_sub = "#dbeafe"
|
| 231 |
+
logo_bg = "rgba(255,255,255,0.85)"
|
| 232 |
+
logo_border = "rgba(255,255,255,0.9)"
|
| 233 |
+
toggle_bg = "#93c5fd"
|
| 234 |
+
toggle_knob = "#1d4ed8"
|
| 235 |
+
toggle_label = "#dbeafe"
|
| 236 |
+
|
| 237 |
+
ad_css = """
|
| 238 |
+
.ad-in { background:#064e3b; border:1px solid #059669; color:#34d399;
|
| 239 |
+
border-radius:20px; padding:0.3rem 1rem; font-weight:700; display:inline-block; font-size:0.9rem; }
|
| 240 |
+
.ad-out { background:#450a0a; border:1px solid #dc2626; color:#f87171;
|
| 241 |
+
border-radius:20px; padding:0.3rem 1rem; font-weight:700; display:inline-block; font-size:0.9rem; }
|
| 242 |
+
.ad-unk { background:#1e293b; border:1px solid #475569; color:#94a3b8;
|
| 243 |
+
border-radius:20px; padding:0.3rem 1rem; font-weight:700; display:inline-block; font-size:0.9rem; }
|
| 244 |
+
"""
|
| 245 |
|
| 246 |
def load_svg_b64(path):
|
| 247 |
with open(path, "rb") as f:
|
|
|
|
| 251 |
|
| 252 |
st.markdown(f"""
|
| 253 |
<style>
|
| 254 |
+
{ad_css}
|
| 255 |
+
.vb-banner {{
|
| 256 |
+
background: {banner_grad};
|
| 257 |
+
border-radius: 16px;
|
| 258 |
+
padding: 1.2rem 1.8rem;
|
| 259 |
+
display: flex;
|
| 260 |
+
align-items: center;
|
| 261 |
+
gap: 1.5rem;
|
| 262 |
+
margin-bottom: 0.5rem;
|
| 263 |
+
box-shadow: 0 4px 24px rgba(0,0,0,0.18);
|
| 264 |
+
position: relative;
|
| 265 |
+
}}
|
| 266 |
+
.vb-logo-wrap {{
|
| 267 |
+
background: {logo_bg};
|
| 268 |
+
border: 1px solid {logo_border};
|
| 269 |
+
border-radius: 14px;
|
| 270 |
+
padding: 0.6rem;
|
| 271 |
+
backdrop-filter: blur(8px);
|
| 272 |
+
flex-shrink: 0;
|
| 273 |
+
}}
|
| 274 |
+
.vb-logo-wrap img {{
|
| 275 |
+
height: 110px;
|
| 276 |
+
width: auto;
|
| 277 |
+
display: block;
|
| 278 |
+
}}
|
| 279 |
+
.vb-title-wrap {{
|
| 280 |
+
flex: 1;
|
| 281 |
+
}}
|
| 282 |
+
.vb-title-wrap h1 {{
|
| 283 |
+
color: #ffffff;
|
| 284 |
+
font-size: 2.4rem;
|
| 285 |
+
font-weight: 800;
|
| 286 |
+
margin: 0 0 0.3rem 0;
|
| 287 |
+
letter-spacing: -0.5px;
|
| 288 |
+
}}
|
| 289 |
+
.vb-title-wrap p {{
|
| 290 |
+
color: {banner_sub};
|
| 291 |
+
font-size: 0.92rem;
|
| 292 |
+
margin: 0;
|
| 293 |
+
line-height: 1.6;
|
| 294 |
+
}}
|
| 295 |
+
.vb-toggle-wrap {{
|
| 296 |
+
position: absolute;
|
| 297 |
+
top: 1rem;
|
| 298 |
+
right: 1.2rem;
|
| 299 |
+
display: flex;
|
| 300 |
+
align-items: center;
|
| 301 |
+
gap: 0.5rem;
|
| 302 |
+
}}
|
| 303 |
+
.vb-toggle-label {{
|
| 304 |
+
color: {toggle_label};
|
| 305 |
+
font-size: 0.78rem;
|
| 306 |
+
font-weight: 600;
|
| 307 |
+
letter-spacing: 0.03em;
|
| 308 |
+
}}
|
| 309 |
+
.metric-card {{
|
| 310 |
+
background: {card_bg};
|
| 311 |
+
border: 1px solid {card_border};
|
| 312 |
+
border-radius: 12px;
|
| 313 |
+
padding: 1.1rem;
|
| 314 |
+
text-align: center;
|
| 315 |
+
transition: box-shadow 0.2s;
|
| 316 |
+
}}
|
| 317 |
+
.metric-card:hover {{
|
| 318 |
+
box-shadow: 0 4px 16px rgba(59,130,246,0.15);
|
| 319 |
+
}}
|
| 320 |
+
.metric-val {{
|
| 321 |
+
font-size: 1.9rem;
|
| 322 |
+
font-weight: 700;
|
| 323 |
+
color: {val_col};
|
| 324 |
+
line-height: 1.2;
|
| 325 |
+
}}
|
| 326 |
+
.metric-lab {{
|
| 327 |
+
font-size: 0.75rem;
|
| 328 |
+
color: {lab_col};
|
| 329 |
+
margin-top: 0.35rem;
|
| 330 |
+
line-height: 1.4;
|
| 331 |
+
}}
|
| 332 |
</style>
|
| 333 |
+
""", unsafe_allow_html=True)
|
| 334 |
+
|
| 335 |
+
# ββ Banner ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 336 |
+
toggle_icon = "β" if dark else "βΎ"
|
| 337 |
+
toggle_text = "Light mode" if dark else "Dark mode"
|
| 338 |
+
|
| 339 |
+
st.markdown(f"""
|
| 340 |
+
<div class="vb-banner">
|
| 341 |
+
<div class="vb-logo-wrap">
|
| 342 |
+
<img src="data:image/svg+xml;base64,{logo_b64}" alt="VeloBind"/>
|
| 343 |
</div>
|
| 344 |
+
<div class="vb-title-wrap">
|
| 345 |
<h1>VeloBind</h1>
|
| 346 |
+
<p>
|
| 347 |
+
Structure-free protein-ligand binding affinity prediction ·
|
| 348 |
+
Sequence + SMILES ·
|
| 349 |
+
Pearson R = 0.8469 on CASF-2016 ·
|
| 350 |
+
45-model ensemble (LGBM + CatBoost + XGBoost)
|
| 351 |
+
</p>
|
| 352 |
</div>
|
| 353 |
</div>
|
| 354 |
""", unsafe_allow_html=True)
|
| 355 |
|
| 356 |
+
# Theme toggle β just below banner, right-aligned
|
| 357 |
+
_, tcol = st.columns([6, 1])
|
| 358 |
+
with tcol:
|
| 359 |
+
if st.button(f"{toggle_icon} {toggle_text}", use_container_width=True):
|
| 360 |
+
st.session_state.dark_mode = not st.session_state.dark_mode
|
| 361 |
+
st.rerun()
|
| 362 |
+
|
| 363 |
+
# ββ Load models βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 364 |
fold_models, meta, target_scaler, lig_scaler = load_all_models()
|
| 365 |
tokenizer, esm_model, device = load_esm_model()
|
| 366 |
ad_centroid, ad_threshold = load_ad_centroid()
|
| 367 |
n_loaded = sum(len(fold_models[s][t]) for s in SEEDS for t in MODEL_TYPES)
|
| 368 |
+
st.success(f"{n_loaded} fold models loaded | Device: {device.upper()}")
|
| 369 |
+
|
| 370 |
+
# ββ Mode tabs βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 371 |
+
tab1, tab2, tab3 = st.tabs([
|
| 372 |
+
"Single query",
|
| 373 |
+
"Batch screening (CSV)",
|
| 374 |
+
"One compound vs. multiple targets",
|
| 375 |
+
])
|
|
|
|
| 376 |
|
| 377 |
|
| 378 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 379 |
+
# TAB 1 β Single query
|
| 380 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 381 |
+
with tab1:
|
|
|
|
| 382 |
col_p, col_l = st.columns(2)
|
| 383 |
with col_p:
|
| 384 |
st.subheader("Protein")
|
| 385 |
seq_raw = st.text_area(
|
| 386 |
+
"Amino acid sequence (plain or FASTA format)",
|
| 387 |
+
height=160,
|
| 388 |
+
placeholder=">ProteinName\nMKTAYIAKQRQISFVK...",
|
| 389 |
+
help="Plain sequence or FASTA with >header line. Only standard amino acid letters (A-Z subset).",
|
| 390 |
+
key="sq_seq"
|
| 391 |
)
|
| 392 |
with col_l:
|
| 393 |
st.subheader("Ligand")
|
| 394 |
+
smi = st.text_input("SMILES", placeholder="CC(=O)Oc1ccccc1C(=O)O", key="sq_smi")
|
| 395 |
examples = {
|
| 396 |
"Aspirin": "CC(=O)Oc1ccccc1C(=O)O",
|
| 397 |
"Imatinib": "Cc1ccc(NC(=O)c2ccc(CN3CCN(C)CC3)cc2)cc1Nc1nccc(-c2cccnc2)n1",
|
| 398 |
"Gefitinib": "COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCCN1CCOCC1",
|
| 399 |
"Staurosporine": "C[C@@H]1CCCN2C(=O)c3[nH]c4ccccc4c3C2=N1",
|
| 400 |
}
|
| 401 |
+
chosen = st.selectbox("Load example SMILES", ["β"] + list(examples), key="sq_ex")
|
| 402 |
if chosen != "β":
|
| 403 |
smi = examples[chosen]
|
| 404 |
|
| 405 |
+
if st.button("Predict", type="primary", use_container_width=True, key="sq_btn"):
|
| 406 |
seq, err = validate_sequence(seq_raw)
|
| 407 |
if err:
|
| 408 |
st.error(err)
|
|
|
|
| 413 |
t0 = time.time()
|
| 414 |
try:
|
| 415 |
X, valid, esm_vec = extract_features(
|
| 416 |
+
seq, [smi.strip()], tokenizer, esm_model, device, lig_scaler)
|
|
|
|
|
|
|
| 417 |
if not valid.any():
|
| 418 |
+
st.error("RDKit could not parse this SMILES.")
|
| 419 |
else:
|
| 420 |
preds, preds_all = predict(X, fold_models, meta, target_scaler)
|
| 421 |
lo, hi = uncertainty_interval(preds_all)
|
| 422 |
elapsed = time.time() - t0
|
| 423 |
pkd = float(preds[0])
|
| 424 |
+
ad_label, _ = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
|
| 425 |
|
| 426 |
+
st.markdown("#### Results")
|
| 427 |
c1, c2, c3, c4 = st.columns(4)
|
| 428 |
with c1:
|
| 429 |
st.markdown(f"""<div class="metric-card">
|
|
|
|
| 433 |
with c2:
|
| 434 |
st.markdown(f"""<div class="metric-card">
|
| 435 |
<div class="metric-val">[{lo[0]:.2f}, {hi[0]:.2f}]</div>
|
| 436 |
+
<div class="metric-lab">95% model interval<br>(Β±1.96Ο Β· 45 models)</div>
|
| 437 |
</div>""", unsafe_allow_html=True)
|
| 438 |
with c3:
|
| 439 |
st.markdown(f"""<div class="metric-card">
|
| 440 |
<div class="metric-val">{format_ki(pkd)}</div>
|
| 441 |
+
<div class="metric-lab">Estimated Ki<br>(pKd β pKi assumed)</div>
|
| 442 |
</div>""", unsafe_allow_html=True)
|
| 443 |
with c4:
|
|
|
|
| 444 |
ad_cls = ("ad-in" if ad_label == "IN DOMAIN" else
|
| 445 |
"ad-out" if ad_label == "OUT OF DOMAIN" else "ad-unk")
|
| 446 |
st.markdown(f"""<div class="metric-card">
|
| 447 |
+
<div style="padding-top:0.4rem">
|
| 448 |
+
<span class="{ad_cls}">{ad_label}</span>
|
| 449 |
+
</div>
|
| 450 |
+
<div class="metric-lab" style="margin-top:0.6rem">Applicability domain</div>
|
| 451 |
</div>""", unsafe_allow_html=True)
|
| 452 |
|
| 453 |
if ad_label == "OUT OF DOMAIN":
|
| 454 |
+
st.warning("Protein is outside the training distribution. Predictions may be unreliable.")
|
|
|
|
| 455 |
|
| 456 |
st.caption(
|
| 457 |
f"Inference time: {elapsed:.2f}s | "
|
|
|
|
| 465 |
labels, preds_all[0],
|
| 466 |
preds_all[0] - preds_all[0].std(),
|
| 467 |
preds_all[0] + preds_all[0].std(),
|
| 468 |
+
"Per-seed and model-type predictions (fold-averaged)",
|
| 469 |
+
dark=dark,
|
| 470 |
)
|
| 471 |
st.pyplot(fig, use_container_width=True)
|
| 472 |
plt.close(fig)
|
|
|
|
| 477 |
|
| 478 |
|
| 479 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 480 |
+
# TAB 2 β Batch screening
|
| 481 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 482 |
+
with tab2:
|
|
|
|
| 483 |
st.subheader("Batch Screening")
|
| 484 |
+
st.markdown("Screen a library of compounds against one target. "
|
| 485 |
+
"Upload a CSV with a `smiles` column (and optionally `name`). "
|
| 486 |
+
"Results are ranked by predicted pKd.")
|
| 487 |
|
| 488 |
col_seq, col_csv = st.columns(2)
|
| 489 |
with col_seq:
|
| 490 |
+
batch_seq_raw = st.text_area("Target protein sequence (plain or FASTA)", height=180,
|
| 491 |
+
placeholder=">Target\nMKTAYIAKQRQISFVK...", key="bs_seq")
|
| 492 |
with col_csv:
|
| 493 |
+
uploaded = st.file_uploader("Compound CSV (smiles, name)", type=["csv"], key="bs_up")
|
| 494 |
st.code("smiles,name\nCC(=O)Oc1ccccc1C(=O)O,Aspirin", language="csv")
|
| 495 |
|
| 496 |
+
max_cpds = st.slider("Max compounds", 10, 500, 100, key="bs_max",
|
| 497 |
help="~1s per compound on CPU free tier.")
|
| 498 |
|
| 499 |
+
if st.button("Run batch screening", type="primary", use_container_width=True, key="bs_btn"):
|
| 500 |
batch_seq, err = validate_sequence(batch_seq_raw)
|
| 501 |
if err:
|
| 502 |
st.error(err)
|
|
|
|
| 516 |
with st.spinner(f"Screening {len(smiles_list)} compounds..."):
|
| 517 |
t0 = time.time()
|
| 518 |
X, valid, esm_vec = extract_features(
|
| 519 |
+
batch_seq, smiles_list, tokenizer, esm_model, device, lig_scaler)
|
|
|
|
|
|
|
| 520 |
ad_label, _ = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
|
| 521 |
preds, preds_all = predict(X, fold_models, meta, target_scaler)
|
| 522 |
lo, hi = uncertainty_interval(preds_all)
|
|
|
|
| 526 |
valid_smiles = [smiles_list[i] for i in range(len(smiles_list)) if valid[i]]
|
| 527 |
n_invalid = int((~valid).sum())
|
| 528 |
|
| 529 |
+
if ad_label == "OUT OF DOMAIN":
|
| 530 |
+
st.warning("Protein is outside the training distribution. Predictions may be unreliable.")
|
| 531 |
+
|
| 532 |
results_df = pd.DataFrame({
|
| 533 |
'name': valid_names,
|
| 534 |
'smiles': valid_smiles,
|
|
|
|
| 541 |
}).sort_values('pKd_pred', ascending=False).reset_index(drop=True)
|
| 542 |
results_df.insert(0, 'rank', range(1, len(results_df) + 1))
|
| 543 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 544 |
st.success(
|
| 545 |
+
f"{len(results_df)} compounds screened in {elapsed:.1f}s "
|
| 546 |
f"({elapsed / max(len(results_df), 1):.2f}s/compound)"
|
| 547 |
+ (f" | {n_invalid} invalid SMILES skipped" if n_invalid else "")
|
| 548 |
)
|
| 549 |
|
| 550 |
top_n = min(20, len(results_df))
|
| 551 |
+
fig = bar_chart(
|
| 552 |
+
results_df.head(top_n)['name'].tolist(),
|
| 553 |
+
results_df.head(top_n)['pKd_pred'].values,
|
| 554 |
+
results_df.head(top_n)['CI_lo'].values,
|
| 555 |
+
results_df.head(top_n)['CI_hi'].values,
|
| 556 |
+
f"Top {top_n} hits by predicted pKd", dark=dark,
|
|
|
|
|
|
|
| 557 |
)
|
| 558 |
st.pyplot(fig, use_container_width=True)
|
| 559 |
plt.close(fig)
|
|
|
|
| 565 |
st.download_button(
|
| 566 |
"Download ranked CSV",
|
| 567 |
results_df.to_csv(index=False).encode(),
|
| 568 |
+
file_name="velobind_screening.csv", mime="text/csv",
|
|
|
|
| 569 |
)
|
| 570 |
|
| 571 |
|
| 572 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 573 |
+
# TAB 3 β Selectivity profiling
|
| 574 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 575 |
+
with tab3:
|
|
|
|
| 576 |
st.subheader("Selectivity Profiling")
|
| 577 |
+
st.markdown("One compound, multiple targets β ranked by predicted pKd. "
|
| 578 |
+
"Format: `TargetName: SEQUENCE` (name optional). "
|
| 579 |
+
"Accepts plain sequence or FASTA per line.")
|
| 580 |
|
| 581 |
+
multi_smi = st.text_input("Compound SMILES", placeholder="Cc1ccc(...)cc1Nc1nccc(...)n1", key="sp_smi")
|
|
|
|
| 582 |
multi_seqs = st.text_area(
|
| 583 |
"Target proteins (one per line)",
|
| 584 |
height=250,
|
| 585 |
placeholder="ABL1: MGPSENDPNLFVALY...\nEGFR: MRPSGTAGAALLALL...\nCDK2: MENFQKVEKIGEGTY...",
|
| 586 |
+
key="sp_seqs",
|
| 587 |
)
|
| 588 |
|
| 589 |
+
if st.button("Run selectivity profiling", type="primary", use_container_width=True, key="sp_btn"):
|
| 590 |
if not multi_smi.strip() or not multi_seqs.strip():
|
| 591 |
st.error("Please enter a SMILES and at least one protein sequence.")
|
| 592 |
else:
|
| 593 |
+
targets, parse_errors = {}, []
|
|
|
|
| 594 |
for i, line in enumerate(multi_seqs.strip().splitlines()):
|
| 595 |
line = line.strip()
|
| 596 |
if not line:
|
| 597 |
continue
|
| 598 |
+
name, raw_seq = (line.split(":", 1) if ":" in line
|
| 599 |
+
else (f"Target_{i+1}", line))
|
| 600 |
+
seq, err = validate_sequence(raw_seq if isinstance(raw_seq, str) else raw_seq)
|
|
|
|
|
|
|
|
|
|
| 601 |
if err:
|
| 602 |
+
parse_errors.append(f"{name.strip()}: {err}")
|
| 603 |
else:
|
| 604 |
+
targets[name.strip()] = seq
|
| 605 |
|
| 606 |
+
for e in parse_errors:
|
| 607 |
+
st.warning(f"Skipped β {e}")
|
|
|
|
| 608 |
if not targets:
|
| 609 |
st.error("No valid sequences found.")
|
| 610 |
st.stop()
|
|
|
|
| 613 |
for idx, (name, seq) in enumerate(targets.items()):
|
| 614 |
try:
|
| 615 |
X, valid, esm_vec = extract_features(
|
| 616 |
+
seq, [multi_smi.strip()], tokenizer, esm_model, device, lig_scaler)
|
|
|
|
|
|
|
| 617 |
if valid.any():
|
| 618 |
preds, preds_all = predict(X, fold_models, meta, target_scaler)
|
| 619 |
lo, hi = uncertainty_interval(preds_all)
|
| 620 |
ad_label, _ = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
|
| 621 |
results.append({
|
| 622 |
'Target': name,
|
| 623 |
+
'pKd_pred': round(float(preds[0]), 3),
|
| 624 |
+
'CI_lo': round(float(lo[0]), 3),
|
| 625 |
+
'CI_hi': round(float(hi[0]), 3),
|
| 626 |
'Ki_est': format_ki(float(preds[0])),
|
| 627 |
'model_std': round(float(preds_all.std()), 3),
|
| 628 |
'AD': ad_label,
|
|
|
|
| 632 |
progress.progress((idx + 1) / len(targets))
|
| 633 |
|
| 634 |
progress.empty()
|
| 635 |
+
res_df = (pd.DataFrame(results)
|
| 636 |
+
.sort_values('pKd_pred', ascending=False)
|
| 637 |
+
.reset_index(drop=True))
|
|
|
|
|
|
|
| 638 |
res_df.insert(0, 'rank', range(1, len(res_df) + 1))
|
| 639 |
|
| 640 |
st.success(f"Profiled {len(res_df)} targets.")
|
| 641 |
fig = bar_chart(
|
| 642 |
+
res_df['Target'].tolist(), res_df['pKd_pred'].values,
|
| 643 |
+
res_df['CI_lo'].values, res_df['CI_hi'].values,
|
| 644 |
+
"Selectivity profile β predicted pKd by target", dark=dark,
|
|
|
|
|
|
|
|
|
|
| 645 |
)
|
| 646 |
st.pyplot(fig, use_container_width=True)
|
| 647 |
plt.close(fig)
|
|
|
|
| 650 |
st.download_button(
|
| 651 |
"Download selectivity CSV",
|
| 652 |
res_df.to_csv(index=False).encode(),
|
| 653 |
+
file_name="velobind_selectivity.csv", mime="text/csv",
|
|
|
|
| 654 |
)
|
| 655 |
|
| 656 |
|
| 657 |
# ββ Footer ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 658 |
st.markdown("---")
|
| 659 |
st.markdown(f"""
|
| 660 |
+
<div style="color:{lab_col};font-size:0.78rem;text-align:center;padding:0.4rem 0 0.8rem">
|
| 661 |
+
VeloBind · Structure-free binding affinity ·
|
| 662 |
+
ESM-2 + gradient-boosted ensemble ·
|
| 663 |
+
Trained on LP-PDBBind ·
|
| 664 |
+
Evaluated on CASF-2016 and CASF-2013 ·
|
| 665 |
+
<b>Not for clinical use</b>
|
| 666 |
</div>
|
| 667 |
""", unsafe_allow_html=True)
|