ym59 commited on
Commit
12e3a24
Β·
verified Β·
1 Parent(s): bfd6953

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +205 -170
app.py CHANGED
@@ -1,19 +1,5 @@
1
  # app.py β€” VeloBind HF Spaces inference app
2
- #
3
- # Uses the exact 45 fold models that produced the reported R=0.8469 on CASF-2016.
4
- # No retraining required. Upload output/models/ to HF model repo and set
5
- # HF_MODEL_REPO below.
6
- #
7
- # HF model repo should contain:
8
- # fold_model_s{seed}_{type}_f{fold}.pkl β€” 45 files (3 seeds Γ— 3 types Γ— 5 folds)
9
- # meta_type_casf16.pkl β€” Ridge meta-learner (from 06_eval_both.py)
10
- # target_scaler.pkl β€” TargetScaler (from 03_train.py)
11
- # ligand_scaler.pkl β€” from output/preprocessors/
12
- #
13
- # Free tier: 16GB RAM, 2 vCPU, 50GB disk β€” all 45 models fit easily (~2-3GB total).
14
- # Cold start: ~30-40s to download + load models on first visit.
15
-
16
- import os, json, warnings, time
17
  import numpy as np
18
  import pandas as pd
19
  import streamlit as st
@@ -21,7 +7,6 @@ import joblib
21
  import torch
22
  import matplotlib.pyplot as plt
23
  from pathlib import Path
24
- from scipy.stats import pearsonr
25
 
26
  warnings.filterwarnings("ignore")
27
  from rdkit import RDLogger
@@ -33,9 +18,8 @@ MODEL_CACHE = Path("/tmp/velobind_models")
33
  SEEDS = [42, 123, 456]
34
  MODEL_TYPES = ["lgbm", "cb", "xgb"]
35
  N_FOLDS = 5
 
36
 
37
- # Best feature config β€” Step 9 winner from 03_train.py ablation
38
- # MUST match what the fold models were trained on
39
  import sys
40
  sys.path.append(str(Path(__file__).parent))
41
  from src.features.protein import load_esm, embed_batch, sequence_features
@@ -45,34 +29,49 @@ from src.config import config
45
 
46
 
47
  # ══════════════════════════════════════════════════════════════════════
48
- # Model loading
49
  # ══════════════════════════════════════════════════════════════════════
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- @st.cache_resource(show_spinner="Downloading & loading VeloBind models (~30s first run)…")
 
 
 
52
  def load_all_models():
53
  from huggingface_hub import hf_hub_download
54
  MODEL_CACHE.mkdir(parents=True, exist_ok=True)
55
 
56
- # Build list of all files to fetch
57
  model_files = (
58
  [f"fold_model_s{s}_{t}_f{f}.pkl"
59
  for s in SEEDS for t in MODEL_TYPES for f in range(N_FOLDS)]
60
  + ["meta_type_casf16.pkl", "target_scaler.pkl", "ligand_scaler.pkl"]
61
  )
62
 
63
- progress = st.progress(0, text="Downloading models…")
64
  for i, fname in enumerate(model_files):
65
  local = MODEL_CACHE / fname
66
  if not local.exists():
67
- hf_hub_download(
68
- repo_id=HF_MODEL_REPO, filename=fname,
69
- local_dir=str(MODEL_CACHE),
70
- )
71
- progress.progress((i + 1) / len(model_files),
72
- text=f"Loading {fname}…")
73
- progress.empty()
74
 
75
- # Load into nested dict: fold_models[seed][type][fold] = model
76
  fold_models = {}
77
  for s in SEEDS:
78
  fold_models[s] = {}
@@ -85,27 +84,22 @@ def load_all_models():
85
  meta = joblib.load(MODEL_CACHE / "meta_type_casf16.pkl")
86
  scaler = joblib.load(MODEL_CACHE / "target_scaler.pkl")
87
  lig_sc = joblib.load(MODEL_CACHE / "ligand_scaler.pkl")
88
-
89
  return fold_models, meta, scaler, lig_sc
90
 
91
- @st.cache_resource(show_spinner="Loading ESM-2 protein language model…")
 
92
  def load_esm_model():
93
  device = "cuda" if torch.cuda.is_available() else "cpu"
94
  tokenizer, esm_model = load_esm(config.ESM_MODEL, device)
95
  return tokenizer, esm_model, device
96
 
 
97
  @st.cache_resource(show_spinner=False)
98
  def load_ad_centroid():
99
- # local fallback
100
- local_paths = [
101
- Path("output/models/deployment"),
102
- Path("output/models"),
103
- ]
104
- for p in local_paths:
105
  if (p / "ad_centroid.npy").exists():
106
  return (np.load(p / "ad_centroid.npy"),
107
  float(np.load(p / "ad_threshold.npy")))
108
- # HF fallback
109
  for fname in ["ad_centroid.npy", "ad_threshold.npy"]:
110
  local = MODEL_CACHE / fname
111
  if not local.exists():
@@ -118,6 +112,7 @@ def load_ad_centroid():
118
  return (np.load(MODEL_CACHE / "ad_centroid.npy"),
119
  float(np.load(MODEL_CACHE / "ad_threshold.npy")))
120
 
 
121
  def ad_check(esm_mean_vec, centroid, threshold):
122
  if centroid is None:
123
  return "UNKNOWN", float("nan")
@@ -126,29 +121,25 @@ def ad_check(esm_mean_vec, centroid, threshold):
126
 
127
 
128
  # ══════════════════════════════════════════════════════════════════════
129
- # Feature assembly β€” mirrors assemble() in 03_train.py exactly
130
  # ══════════════════════════════════════════════════════════════════════
131
- def assemble_from_parts(esm_mean, esm_var, esm_attn, seq_feat, lig_feats, cfg=None):
132
- """Matches assemble() in 06_casf_eval.py exactly β€” 10,054d."""
133
  return np.concatenate([
134
- esm_mean[:, -480:], # last layer only: 480d
135
- seq_feat, # 919d
136
- lig_feats["ecfp"], # 1024d
137
- lig_feats["ecfp2"], # 1024d
138
- lig_feats["ecfp6"], # 1024d
139
- lig_feats["fcfp"], # 1024d
140
- lig_feats["estate"], # 79d
141
- lig_feats["maccs"], # 167d
142
- lig_feats["atom_pair"], # 2048d
143
- lig_feats["torsion"], # 2048d
144
- lig_feats["phys"], # 217d
145
  ], axis=1)
146
 
147
 
148
- def extract_features(sequence: str, smiles_list: list,
149
- tokenizer, esm_model, device, lig_scaler):
150
- """Returns X [N_valid, D], valid_mask [N_smiles]."""
151
- # Protein (embed once, tile)
152
  esm_mean, esm_var, esm_attn, _ = embed_batch(
153
  [sequence], tokenizer, esm_model,
154
  config.ESM_LAYERS, config.MAX_SEQ_LEN, config.HALF_SEQ_LEN,
@@ -156,7 +147,6 @@ def extract_features(sequence: str, smiles_list: list,
156
  )
157
  seq_feat = np.array([sequence_features(sequence)])
158
 
159
- # Ligands
160
  lig_feats, valid_mask, _ = extract_ligand_features(
161
  smiles_list, scaler=lig_scaler, fit_scaler=False
162
  )
@@ -166,44 +156,34 @@ def extract_features(sequence: str, smiles_list: list,
166
  bool_mask[valid_mask] = True
167
  valid_mask = bool_mask
168
 
169
- # Tile protein over valid ligands only
170
  n_valid = int(valid_mask.sum())
171
- esm_mean_t = np.tile(esm_mean, (n_valid, 1))
172
- esm_var_t = np.tile(esm_var, (n_valid, 1))
173
- esm_attn_t = np.tile(esm_attn, (n_valid, 1))
174
- seq_feat_t = np.tile(seq_feat, (n_valid, 1))
175
 
176
  X = assemble_from_parts(esm_mean_t, esm_var_t, esm_attn_t, seq_feat_t, lig_feats)
177
  return X, valid_mask, esm_mean[0]
178
 
179
 
180
  # ══════════════════════════════════════════════════════════════════════
181
- # Prediction β€” mirrors build_test_matrix + blend from 06_eval_both.py
182
  # ══════════════════════════════════════════════════════════════════════
183
  def predict(X, fold_models, meta, scaler):
184
- """
185
- Returns:
186
- preds [N] final ensemble pKd
187
- preds_all [N, 9] per-(seed,type) predictions for uncertainty
188
- """
189
- # Each entry: average over 5 folds for one (seed, type) combo
190
  type_avgs = []
191
  for s in SEEDS:
192
  for t in MODEL_TYPES:
193
  fold_preds = np.stack([
194
  scaler.inverse(fold_models[s][t][f].predict(X))
195
  for f in range(N_FOLDS)
196
- ], axis=1) # [N, 5]
197
- type_avgs.append(fold_preds.mean(axis=1)) # [N]
198
-
199
- preds_all = np.stack(type_avgs, axis=1) # [N, 9]
200
-
201
- # Per-model-type average β†’ Ridge meta (matches blend() in 06_eval_both.py)
202
- lgbm_avg = preds_all[:, [0, 3, 6]].mean(axis=1)
203
- cb_avg = preds_all[:, [1, 4, 7]].mean(axis=1)
204
- xgb_avg = preds_all[:, [2, 5, 8]].mean(axis=1)
205
- preds = meta.predict(np.column_stack([lgbm_avg, cb_avg, xgb_avg]))
206
-
207
  return preds, preds_all
208
 
209
 
@@ -212,34 +192,74 @@ def uncertainty_interval(preds_all, z=1.96):
212
  return preds_all.mean(axis=1) - z * std, preds_all.mean(axis=1) + z * std
213
 
214
 
 
 
 
 
 
 
 
 
 
 
 
215
  # ══════════════════════════════════════════════════════════════════════
216
  # Plots
217
  # ══════════════════════════════════════════════════════════════════════
218
- def bar_chart(names, preds, lo, hi, title):
219
- fig, ax = plt.subplots(figsize=(max(6, len(names) * 0.9), 4))
220
- x = np.arange(len(names))
221
- err = [preds - lo, hi - preds]
 
 
 
 
 
 
222
  bars = ax.bar(x, preds, color="#4C72B0", alpha=0.85, width=0.6,
223
- yerr=err, capsize=5, error_kw=dict(ecolor="#333", lw=1.5))
224
  ax.set_xticks(x)
225
- ax.set_xticklabels(names, rotation=30, ha='right', fontsize=10)
226
- ax.set_ylabel("Predicted pKd", fontsize=11)
227
- ax.set_title(title, fontsize=12, fontweight='bold')
228
- ax.grid(True, axis='y', alpha=0.25)
 
 
229
  for bar, val in zip(bars, preds):
230
  ax.text(bar.get_x() + bar.get_width() / 2,
231
  bar.get_height() + 0.05, f"{val:.2f}",
232
- ha='center', va='bottom', fontsize=9, fontweight='bold')
 
233
  plt.tight_layout()
234
  return fig
235
 
236
 
237
  # ══════════════════════════════════════════════════════════════════════
238
- # App layout
239
  # ══════════════════════════════════════════════════════════════════════
240
- st.set_page_config(page_title="VeloBind", page_icon="⚑", layout="wide")
241
-
242
- import base64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
 
244
  def load_svg_b64(path):
245
  with open(path, "rb") as f:
@@ -249,6 +269,7 @@ logo_b64 = load_svg_b64("logo.svg")
249
 
250
  st.markdown(f"""
251
  <style>
 
252
  .header-wrap {{
253
  display: flex; align-items: center; gap: 1.5rem;
254
  margin-bottom: 1.5rem;
@@ -259,17 +280,17 @@ st.markdown(f"""
259
  }}
260
  .logo-box img {{ height: 130px; width: auto; display: block; }}
261
  .header-text {{
262
- background: linear-gradient(135deg, #1a3a5c, #1e6091, #2980b9);
263
  padding: 1.5rem 2rem; border-radius: 12px; flex: 1;
264
  }}
265
  .header-text h1 {{ color: #fff; font-size: 2.2rem; margin: 0; }}
266
- .header-text p {{ color: #aad4f5; margin: 0.3rem 0 0; font-size: 1rem; }}
267
  .metric-card {{
268
- background: #1e2a38; border: 1px solid #2d3f55;
269
  border-radius: 10px; padding: 1rem; text-align: center;
270
  }}
271
- .metric-val {{ font-size: 2rem; font-weight: 700; color: #4fc3f7; }}
272
- .metric-lab {{ font-size: 0.8rem; color: #aaa; margin-top: 0.2rem; }}
273
  .ad-in {{ background:#1b4332; border:1px solid #2d6a4f; color:#40916c;
274
  border-radius:8px; padding:0.4rem 1rem; font-weight:700; display:inline-block; }}
275
  .ad-out {{ background:#4a1c24; border:1px solid #9b2335; color:#e74c3c;
@@ -283,24 +304,25 @@ st.markdown(f"""
283
  </div>
284
  <div class="header-text">
285
  <h1>VeloBind</h1>
286
- <p>Structure-free protein–ligand binding affinity Β· sequence + SMILES only Β·
287
- Pearson R = 0.8469 on CASF-2016 Β· 45-model ensemble (LGBM + CatBoost + XGBoost)</p>
 
 
288
  </div>
289
  </div>
290
  """, unsafe_allow_html=True)
291
 
292
- # ── Load models (cached) ──────────────────────────────────────────────
293
  fold_models, meta, target_scaler, lig_scaler = load_all_models()
294
  tokenizer, esm_model, device = load_esm_model()
 
295
  n_loaded = sum(len(fold_models[s][t]) for s in SEEDS for t in MODEL_TYPES)
296
  st.success(f"βœ“ {n_loaded} fold models loaded | Device: {device.upper()}")
297
 
298
  # ── Mode selector ─────────────────────────────────────────────────────
299
  mode = st.radio(
300
  "Select mode",
301
- ["πŸ”¬ Single query",
302
- "πŸ“‹ Batch screening (CSV)",
303
- "🎯 One compound vs. multiple targets"],
304
  horizontal=True,
305
  )
306
  st.markdown("---")
@@ -309,13 +331,17 @@ st.markdown("---")
309
  # ══════════════════════════════════════════════════════════════════════
310
  # MODE 1 β€” Single query
311
  # ══════════════════════════════════════════════════════════════════════
312
- if mode == "πŸ”¬ Single query":
313
 
314
  col_p, col_l = st.columns(2)
315
  with col_p:
316
  st.subheader("Protein")
317
- seq = st.text_area("Amino acid sequence (single-letter)", height=150,
318
- placeholder="MKTAYIAKQRQISFVK…")
 
 
 
 
319
  with col_l:
320
  st.subheader("Ligand")
321
  smi = st.text_input("SMILES", placeholder="CC(=O)Oc1ccccc1C(=O)O")
@@ -329,15 +355,18 @@ if mode == "πŸ”¬ Single query":
329
  if chosen != "β€”":
330
  smi = examples[chosen]
331
 
332
- if st.button("Predict ⚑", type="primary", use_container_width=True):
333
- if not seq.strip() or not smi.strip():
334
- st.error("Please enter both a sequence and a SMILES string.")
 
 
 
335
  else:
336
- with st.spinner("Running inference…"):
337
  t0 = time.time()
338
  try:
339
  X, valid, esm_vec = extract_features(
340
- seq.strip(), [smi.strip()],
341
  tokenizer, esm_model, device, lig_scaler
342
  )
343
  if not valid.any():
@@ -361,17 +390,14 @@ if mode == "πŸ”¬ Single query":
361
  <div class="metric-lab">95% model interval (Β±1.96Οƒ, 45 models)</div>
362
  </div>""", unsafe_allow_html=True)
363
  with c3:
364
- Ki = 10 ** (9 - pkd)
365
  st.markdown(f"""<div class="metric-card">
366
- <div class="metric-val">{Ki:.1f} nM</div>
367
- <div class="metric-lab">Estimated Kα΅’ (pKd β‰ˆ pKα΅’ assumed)</div>
368
  </div>""", unsafe_allow_html=True)
369
- ad_centroid, ad_threshold = load_ad_centroid()
370
- ad_label, ad_dist = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
371
-
372
  with c4:
373
- ad_cls = "ad-in" if ad_label == "IN DOMAIN" else \
374
- "ad-out" if ad_label == "OUT OF DOMAIN" else "ad-unk"
 
375
  st.markdown(f"""<div class="metric-card">
376
  <div class="{ad_cls}">{ad_label}</div>
377
  <div class="metric-lab">Applicability domain</div>
@@ -379,22 +405,22 @@ if mode == "πŸ”¬ Single query":
379
 
380
  if ad_label == "OUT OF DOMAIN":
381
  st.warning("Protein is outside the training distribution. "
382
- "Predictions may be unreliable.", icon="⚠️")
383
 
384
  st.caption(
385
  f"Inference time: {elapsed:.2f}s | "
386
- f"45-model ensemble (3 seeds Γ— 3 types Γ— 5 folds) | "
387
  f"Device: {device.upper()}"
388
  )
389
 
390
  with st.expander("Per-model breakdown"):
391
  labels = [f"s{s}_{t}" for s in SEEDS for t in MODEL_TYPES]
392
  fig = bar_chart(
393
- labels,
394
- preds_all[0],
395
  preds_all[0] - preds_all[0].std(),
396
  preds_all[0] + preds_all[0].std(),
397
- "Seed Γ— type predictions (fold-averaged)"
 
398
  )
399
  st.pyplot(fig, use_container_width=True)
400
  plt.close(fig)
@@ -407,7 +433,7 @@ if mode == "πŸ”¬ Single query":
407
  # ══════════════════════════════════════════════════════════════════════
408
  # MODE 2 β€” Batch screening
409
  # ══════════════════════════════════════════════════════════════════════
410
- elif mode == "πŸ“‹ Batch screening (CSV)":
411
 
412
  st.subheader("Batch Screening")
413
  st.markdown("One protein, many compounds. Upload a CSV with a `smiles` column "
@@ -415,8 +441,8 @@ elif mode == "πŸ“‹ Batch screening (CSV)":
415
 
416
  col_seq, col_csv = st.columns(2)
417
  with col_seq:
418
- batch_seq = st.text_area("Target protein sequence", height=180,
419
- placeholder="Paste UniProt sequence…")
420
  with col_csv:
421
  uploaded = st.file_uploader("Compound CSV (smiles, name)", type=["csv"])
422
  st.code("smiles,name\nCC(=O)Oc1ccccc1C(=O)O,Aspirin", language="csv")
@@ -424,9 +450,10 @@ elif mode == "πŸ“‹ Batch screening (CSV)":
424
  max_cpds = st.slider("Max compounds", 10, 500, 100,
425
  help="~1s per compound on CPU free tier.")
426
 
427
- if st.button("Run batch screening ⚑", type="primary", use_container_width=True):
428
- if not batch_seq.strip():
429
- st.error("Please enter a protein sequence.")
 
430
  elif uploaded is None:
431
  st.error("Please upload a CSV file.")
432
  else:
@@ -440,22 +467,16 @@ elif mode == "πŸ“‹ Batch screening (CSV)":
440
  names_list = (df_in['name'].tolist() if 'name' in df_in.columns
441
  else [f"cpd_{i}" for i in range(len(df_in))])
442
 
443
- ad_centroid, ad_threshold = load_ad_centroid()
444
- with st.spinner(f"Screening {len(smiles_list)} compounds…"):
445
  t0 = time.time()
446
  X, valid, esm_vec = extract_features(
447
- batch_seq.strip(), smiles_list,
448
  tokenizer, esm_model, device, lig_scaler
449
  )
450
- ad_labels = []
451
- for i, smiles in enumerate(smiles_list):
452
- if valid[i]:
453
- label, _ = ad_check(esm_vec, ad_centroid, ad_threshold)
454
- ad_labels.append(label)
455
-
456
  preds, preds_all = predict(X, fold_models, meta, target_scaler)
457
- lo, hi = uncertainty_interval(preds_all)
458
- elapsed = time.time() - t0
459
 
460
  valid_names = [names_list[i] for i in range(len(names_list)) if valid[i]]
461
  valid_smiles = [smiles_list[i] for i in range(len(smiles_list)) if valid[i]]
@@ -467,12 +488,16 @@ elif mode == "πŸ“‹ Batch screening (CSV)":
467
  'pKd_pred': np.round(preds, 3),
468
  'CI_lo': np.round(lo, 3),
469
  'CI_hi': np.round(hi, 3),
470
- 'Ki_nM_est': np.round(10 ** (9 - preds), 1),
471
  'model_std': np.round(preds_all.std(axis=1), 3),
472
- 'AD' : ad_labels
473
  }).sort_values('pKd_pred', ascending=False).reset_index(drop=True)
474
  results_df.insert(0, 'rank', range(1, len(results_df) + 1))
475
 
 
 
 
 
476
  st.success(
477
  f"βœ“ {len(results_df)} compounds in {elapsed:.1f}s "
478
  f"({elapsed / max(len(results_df), 1):.2f}s/compound)"
@@ -486,7 +511,8 @@ elif mode == "πŸ“‹ Batch screening (CSV)":
486
  top_df['pKd_pred'].values,
487
  top_df['CI_lo'].values,
488
  top_df['CI_hi'].values,
489
- f"Top {top_n} hits"
 
490
  )
491
  st.pyplot(fig, use_container_width=True)
492
  plt.close(fig)
@@ -496,7 +522,7 @@ elif mode == "πŸ“‹ Batch screening (CSV)":
496
  use_container_width=True, height=400,
497
  )
498
  st.download_button(
499
- "⬇ Download ranked CSV",
500
  results_df.to_csv(index=False).encode(),
501
  file_name="velobind_screening.csv",
502
  mime="text/csv",
@@ -506,7 +532,7 @@ elif mode == "πŸ“‹ Batch screening (CSV)":
506
  # ══════════════════════════════════════════════════════════════════════
507
  # MODE 3 β€” One compound vs. multiple targets
508
  # ══════════════════════════════════════════════════════════════════════
509
- elif mode == "🎯 One compound vs. multiple targets":
510
 
511
  st.subheader("Selectivity Profiling")
512
  st.markdown("One SMILES, multiple proteins β€” ranked by predicted pKd. "
@@ -517,33 +543,37 @@ elif mode == "🎯 One compound vs. multiple targets":
517
  multi_seqs = st.text_area(
518
  "Target proteins (one per line)",
519
  height=250,
520
- placeholder=(
521
- "ABL1: MGPSENDPNLFVALY...\n"
522
- "EGFR: MRPSGTAGAALLALL...\n"
523
- "CDK2: MENFQKVEKIGEGTY..."
524
- ),
525
  )
526
 
527
- if st.button("Run selectivity profiling ⚑", type="primary", use_container_width=True):
528
  if not multi_smi.strip() or not multi_seqs.strip():
529
  st.error("Please enter a SMILES and at least one protein sequence.")
530
  else:
531
  targets = {}
 
532
  for i, line in enumerate(multi_seqs.strip().splitlines()):
533
  line = line.strip()
534
  if not line:
535
  continue
536
  if ":" in line:
537
- name, seq = line.split(":", 1)
538
- targets[name.strip()] = seq.strip()
 
 
 
 
 
539
  else:
540
- targets[f"Target_{i+1}"] = line
541
 
 
 
 
542
  if not targets:
543
- st.error("Could not parse any sequences.")
544
  st.stop()
545
 
546
- ad_centroid, ad_threshold = load_ad_centroid()
547
  results, progress = [], st.progress(0)
548
  for idx, (name, seq) in enumerate(targets.items()):
549
  try:
@@ -554,15 +584,15 @@ elif mode == "🎯 One compound vs. multiple targets":
554
  if valid.any():
555
  preds, preds_all = predict(X, fold_models, meta, target_scaler)
556
  lo, hi = uncertainty_interval(preds_all)
557
- ad_label, _ = ad_check(esm_vec, ad_centroid, ad_threshold)
558
  results.append({
559
  'Target': name,
560
  'pKd_pred': round(float(preds[0]), 3),
561
  'CI_lo': round(float(lo[0]), 3),
562
  'CI_hi': round(float(hi[0]), 3),
563
- 'Ki_nM_est': round(10 ** (9 - float(preds[0])), 1),
564
  'model_std': round(float(preds_all.std()), 3),
565
- 'AD': ad_label,
566
  })
567
  except Exception as e:
568
  st.warning(f"Skipped {name}: {e}")
@@ -576,30 +606,35 @@ elif mode == "🎯 One compound vs. multiple targets":
576
  )
577
  res_df.insert(0, 'rank', range(1, len(res_df) + 1))
578
 
579
- st.success(f"βœ“ Profiled {len(res_df)} targets.")
580
  fig = bar_chart(
581
  res_df['Target'].tolist(),
582
  res_df['pKd_pred'].values,
583
  res_df['CI_lo'].values,
584
  res_df['CI_hi'].values,
585
- "Selectivity profile β€” predicted pKd by target"
 
586
  )
587
  st.pyplot(fig, use_container_width=True)
588
  plt.close(fig)
589
 
590
  st.dataframe(res_df, use_container_width=True)
591
  st.download_button(
592
- "⬇ Download selectivity CSV",
593
  res_df.to_csv(index=False).encode(),
594
  file_name="velobind_selectivity.csv",
595
  mime="text/csv",
596
  )
597
 
 
598
  # ── Footer ────────────────────────────────────────────────────────────
599
  st.markdown("---")
600
- st.markdown("""
601
- <div style="color:#666;font-size:0.8rem;text-align:center;padding:0.5rem">
602
- VeloBind Β· Structure-free binding affinity Β· ESM-2 + GBM ensemble Β·
603
- Trained on LP-PDBBind Β· Evaluated on CASF-2016/2013 Β· <b>Not for clinical use.</b>
 
 
 
604
  </div>
605
  """, unsafe_allow_html=True)
 
1
  # app.py β€” VeloBind HF Spaces inference app
2
+ import os, warnings, time, base64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import numpy as np
4
  import pandas as pd
5
  import streamlit as st
 
7
  import torch
8
  import matplotlib.pyplot as plt
9
  from pathlib import Path
 
10
 
11
  warnings.filterwarnings("ignore")
12
  from rdkit import RDLogger
 
18
  SEEDS = [42, 123, 456]
19
  MODEL_TYPES = ["lgbm", "cb", "xgb"]
20
  N_FOLDS = 5
21
+ VALID_AA = set("ACDEFGHIKLMNPQRSTVWYacdefghiklmnpqrstvwyX")
22
 
 
 
23
  import sys
24
  sys.path.append(str(Path(__file__).parent))
25
  from src.features.protein import load_esm, embed_batch, sequence_features
 
29
 
30
 
31
  # ══════════════════════════════════════════════════════════════════════
32
+ # Validation
33
  # ══════════════════════════════════════════════════════════════════════
34
+ def validate_sequence(raw: str):
35
+ raw = raw.strip()
36
+ if not raw:
37
+ return None, "Please enter a sequence."
38
+
39
+ # Strip FASTA header(s)
40
+ lines = raw.splitlines()
41
+ seq_lines = [l.strip() for l in lines if not l.startswith(">")]
42
+ seq = "".join(seq_lines).upper().replace(" ", "")
43
+
44
+ if len(seq) < 10:
45
+ return None, "Sequence too short (minimum 10 residues)."
46
+ invalid = set(seq) - VALID_AA
47
+ if invalid:
48
+ return None, f"Invalid characters: {', '.join(sorted(invalid))}. Only standard amino acid letters accepted."
49
+ return seq, None
50
+
51
 
52
+ # ══════════════════════════════════════════════════════════════════════
53
+ # Model loading
54
+ # ══════════════════════════════════════════════════════════════════════
55
+ @st.cache_resource(show_spinner="Downloading and loading VeloBind models (first run ~30s)...")
56
  def load_all_models():
57
  from huggingface_hub import hf_hub_download
58
  MODEL_CACHE.mkdir(parents=True, exist_ok=True)
59
 
 
60
  model_files = (
61
  [f"fold_model_s{s}_{t}_f{f}.pkl"
62
  for s in SEEDS for t in MODEL_TYPES for f in range(N_FOLDS)]
63
  + ["meta_type_casf16.pkl", "target_scaler.pkl", "ligand_scaler.pkl"]
64
  )
65
 
66
+ bar = st.progress(0, text="Loading models...")
67
  for i, fname in enumerate(model_files):
68
  local = MODEL_CACHE / fname
69
  if not local.exists():
70
+ hf_hub_download(repo_id=HF_MODEL_REPO, filename=fname,
71
+ local_dir=str(MODEL_CACHE))
72
+ bar.progress((i + 1) / len(model_files), text=f"Loading {fname}...")
73
+ bar.empty()
 
 
 
74
 
 
75
  fold_models = {}
76
  for s in SEEDS:
77
  fold_models[s] = {}
 
84
  meta = joblib.load(MODEL_CACHE / "meta_type_casf16.pkl")
85
  scaler = joblib.load(MODEL_CACHE / "target_scaler.pkl")
86
  lig_sc = joblib.load(MODEL_CACHE / "ligand_scaler.pkl")
 
87
  return fold_models, meta, scaler, lig_sc
88
 
89
+
90
+ @st.cache_resource(show_spinner="Loading ESM-2 protein language model...")
91
  def load_esm_model():
92
  device = "cuda" if torch.cuda.is_available() else "cpu"
93
  tokenizer, esm_model = load_esm(config.ESM_MODEL, device)
94
  return tokenizer, esm_model, device
95
 
96
+
97
  @st.cache_resource(show_spinner=False)
98
  def load_ad_centroid():
99
+ for p in [Path("output/models/deployment"), Path("output/models")]:
 
 
 
 
 
100
  if (p / "ad_centroid.npy").exists():
101
  return (np.load(p / "ad_centroid.npy"),
102
  float(np.load(p / "ad_threshold.npy")))
 
103
  for fname in ["ad_centroid.npy", "ad_threshold.npy"]:
104
  local = MODEL_CACHE / fname
105
  if not local.exists():
 
112
  return (np.load(MODEL_CACHE / "ad_centroid.npy"),
113
  float(np.load(MODEL_CACHE / "ad_threshold.npy")))
114
 
115
+
116
  def ad_check(esm_mean_vec, centroid, threshold):
117
  if centroid is None:
118
  return "UNKNOWN", float("nan")
 
121
 
122
 
123
  # ══════════════════════════════════════════════════════════════════════
124
+ # Feature extraction
125
  # ══════════════════════════════════════════════════════════════════════
126
+ def assemble_from_parts(esm_mean, esm_var, esm_attn, seq_feat, lig_feats):
 
127
  return np.concatenate([
128
+ esm_mean[:, -480:],
129
+ seq_feat,
130
+ lig_feats["ecfp"],
131
+ lig_feats["ecfp2"],
132
+ lig_feats["ecfp6"],
133
+ lig_feats["fcfp"],
134
+ lig_feats["estate"],
135
+ lig_feats["maccs"],
136
+ lig_feats["atom_pair"],
137
+ lig_feats["torsion"],
138
+ lig_feats["phys"],
139
  ], axis=1)
140
 
141
 
142
+ def extract_features(sequence, smiles_list, tokenizer, esm_model, device, lig_scaler):
 
 
 
143
  esm_mean, esm_var, esm_attn, _ = embed_batch(
144
  [sequence], tokenizer, esm_model,
145
  config.ESM_LAYERS, config.MAX_SEQ_LEN, config.HALF_SEQ_LEN,
 
147
  )
148
  seq_feat = np.array([sequence_features(sequence)])
149
 
 
150
  lig_feats, valid_mask, _ = extract_ligand_features(
151
  smiles_list, scaler=lig_scaler, fit_scaler=False
152
  )
 
156
  bool_mask[valid_mask] = True
157
  valid_mask = bool_mask
158
 
 
159
  n_valid = int(valid_mask.sum())
160
+ esm_mean_t = np.tile(esm_mean, (n_valid, 1))
161
+ esm_var_t = np.tile(esm_var, (n_valid, 1))
162
+ esm_attn_t = np.tile(esm_attn, (n_valid, 1))
163
+ seq_feat_t = np.tile(seq_feat, (n_valid, 1))
164
 
165
  X = assemble_from_parts(esm_mean_t, esm_var_t, esm_attn_t, seq_feat_t, lig_feats)
166
  return X, valid_mask, esm_mean[0]
167
 
168
 
169
  # ══════════════════════════════════════════════════════════════════════
170
+ # Prediction
171
  # ══════════════════════════════════════════════════════════════════════
172
  def predict(X, fold_models, meta, scaler):
 
 
 
 
 
 
173
  type_avgs = []
174
  for s in SEEDS:
175
  for t in MODEL_TYPES:
176
  fold_preds = np.stack([
177
  scaler.inverse(fold_models[s][t][f].predict(X))
178
  for f in range(N_FOLDS)
179
+ ], axis=1)
180
+ type_avgs.append(fold_preds.mean(axis=1))
181
+
182
+ preds_all = np.stack(type_avgs, axis=1)
183
+ lgbm_avg = preds_all[:, [0, 3, 6]].mean(axis=1)
184
+ cb_avg = preds_all[:, [1, 4, 7]].mean(axis=1)
185
+ xgb_avg = preds_all[:, [2, 5, 8]].mean(axis=1)
186
+ preds = meta.predict(np.column_stack([lgbm_avg, cb_avg, xgb_avg]))
 
 
 
187
  return preds, preds_all
188
 
189
 
 
192
  return preds_all.mean(axis=1) - z * std, preds_all.mean(axis=1) + z * std
193
 
194
 
195
+ def format_ki(pkd):
196
+ """Format Ki with appropriate unit (nM, uM, mM)."""
197
+ ki_nM = 10 ** (9 - pkd)
198
+ if ki_nM < 1000:
199
+ return f"{ki_nM:.1f} nM"
200
+ elif ki_nM < 1_000_000:
201
+ return f"{ki_nM/1000:.2f} uM"
202
+ else:
203
+ return f"{ki_nM/1_000_000:.2f} mM"
204
+
205
+
206
  # ══════════════════════════════════════════════════════════════════════
207
  # Plots
208
  # ══════════════════════════════════════════════════════════════════════
209
+ def bar_chart(names, preds, lo, hi, title, dark=True):
210
+ bg = "#1e2a38" if dark else "#f8f9fa"
211
+ fg = "#ffffff" if dark else "#111111"
212
+ grid = "#2d3f55" if dark else "#cccccc"
213
+
214
+ fig, ax = plt.subplots(figsize=(max(6, len(names) * 0.9), 4),
215
+ facecolor=bg)
216
+ ax.set_facecolor(bg)
217
+ x = np.arange(len(names))
218
+ err = [preds - lo, hi - preds]
219
  bars = ax.bar(x, preds, color="#4C72B0", alpha=0.85, width=0.6,
220
+ yerr=err, capsize=5, error_kw=dict(ecolor=fg, lw=1.5))
221
  ax.set_xticks(x)
222
+ ax.set_xticklabels(names, rotation=30, ha='right', fontsize=10, color=fg)
223
+ ax.set_ylabel("Predicted pKd", fontsize=11, color=fg)
224
+ ax.set_title(title, fontsize=12, fontweight='bold', color=fg)
225
+ ax.tick_params(colors=fg)
226
+ ax.spines[:].set_color(grid)
227
+ ax.grid(True, axis='y', alpha=0.25, color=grid)
228
  for bar, val in zip(bars, preds):
229
  ax.text(bar.get_x() + bar.get_width() / 2,
230
  bar.get_height() + 0.05, f"{val:.2f}",
231
+ ha='center', va='bottom', fontsize=9,
232
+ fontweight='bold', color=fg)
233
  plt.tight_layout()
234
  return fig
235
 
236
 
237
  # ══════════════════════════════════════════════════════════════════════
238
+ # Page setup
239
  # ══════════════════════════════════════════════════════════════════════
240
+ st.set_page_config(page_title="VeloBind", layout="wide")
241
+
242
+ # ── Theme toggle ──────────────────────────────────────────────────────
243
+ with st.sidebar:
244
+ st.markdown("### Display")
245
+ dark_mode = st.toggle("Dark mode", value=True)
246
+
247
+ if dark_mode:
248
+ header_bg = "linear-gradient(135deg, #1a3a5c, #1e6091, #2980b9)"
249
+ card_bg = "#1e2a38"
250
+ card_border = "#2d3f55"
251
+ val_color = "#4fc3f7"
252
+ lab_color = "#aaa"
253
+ page_bg = "#0e1117"
254
+ text_color = "#ffffff"
255
+ else:
256
+ header_bg = "linear-gradient(135deg, #2980b9, #5dade2, #85c1e9)"
257
+ card_bg = "#f0f4f8"
258
+ card_border = "#b0c4de"
259
+ val_color = "#1a5276"
260
+ lab_color = "#555"
261
+ page_bg = "#ffffff"
262
+ text_color = "#111111"
263
 
264
  def load_svg_b64(path):
265
  with open(path, "rb") as f:
 
269
 
270
  st.markdown(f"""
271
  <style>
272
+ .stApp {{ background-color: {page_bg}; color: {text_color}; }}
273
  .header-wrap {{
274
  display: flex; align-items: center; gap: 1.5rem;
275
  margin-bottom: 1.5rem;
 
280
  }}
281
  .logo-box img {{ height: 130px; width: auto; display: block; }}
282
  .header-text {{
283
+ background: {header_bg};
284
  padding: 1.5rem 2rem; border-radius: 12px; flex: 1;
285
  }}
286
  .header-text h1 {{ color: #fff; font-size: 2.2rem; margin: 0; }}
287
+ .header-text p {{ color: #d6eaf8; margin: 0.3rem 0 0; font-size: 1rem; }}
288
  .metric-card {{
289
+ background: {card_bg}; border: 1px solid {card_border};
290
  border-radius: 10px; padding: 1rem; text-align: center;
291
  }}
292
+ .metric-val {{ font-size: 2rem; font-weight: 700; color: {val_color}; }}
293
+ .metric-lab {{ font-size: 0.8rem; color: {lab_color}; margin-top: 0.2rem; }}
294
  .ad-in {{ background:#1b4332; border:1px solid #2d6a4f; color:#40916c;
295
  border-radius:8px; padding:0.4rem 1rem; font-weight:700; display:inline-block; }}
296
  .ad-out {{ background:#4a1c24; border:1px solid #9b2335; color:#e74c3c;
 
304
  </div>
305
  <div class="header-text">
306
  <h1>VeloBind</h1>
307
+ <p>Structure-free protein-ligand binding affinity prediction &nbsp;Β·&nbsp;
308
+ Sequence + SMILES only &nbsp;Β·&nbsp;
309
+ Pearson R = 0.8469 on CASF-2016 &nbsp;Β·&nbsp;
310
+ 45-model ensemble (LGBM + CatBoost + XGBoost)</p>
311
  </div>
312
  </div>
313
  """, unsafe_allow_html=True)
314
 
315
+ # ── Load everything ───────────────────────────────────────────────────
316
  fold_models, meta, target_scaler, lig_scaler = load_all_models()
317
  tokenizer, esm_model, device = load_esm_model()
318
+ ad_centroid, ad_threshold = load_ad_centroid()
319
  n_loaded = sum(len(fold_models[s][t]) for s in SEEDS for t in MODEL_TYPES)
320
  st.success(f"βœ“ {n_loaded} fold models loaded | Device: {device.upper()}")
321
 
322
  # ── Mode selector ─────────────────────────────────────────────────────
323
  mode = st.radio(
324
  "Select mode",
325
+ ["Single query", "Batch screening (CSV)", "One compound vs. multiple targets"],
 
 
326
  horizontal=True,
327
  )
328
  st.markdown("---")
 
331
  # ══════════════════════════════════════════════════════════════════════
332
  # MODE 1 β€” Single query
333
  # ══════════════════════════════════════════════════════════════════════
334
+ if mode == "Single query":
335
 
336
  col_p, col_l = st.columns(2)
337
  with col_p:
338
  st.subheader("Protein")
339
+ seq_raw = st.text_area(
340
+ "Amino acid sequence (single-letter FASTA, no header)",
341
+ height=150,
342
+ placeholder="MKTAYIAKQRQISFVK...",
343
+ help="Only standard amino acid letters accepted (A C D E F G H I K L M N P Q R S T V W Y)."
344
+ )
345
  with col_l:
346
  st.subheader("Ligand")
347
  smi = st.text_input("SMILES", placeholder="CC(=O)Oc1ccccc1C(=O)O")
 
355
  if chosen != "β€”":
356
  smi = examples[chosen]
357
 
358
+ if st.button("Predict", type="primary", use_container_width=True):
359
+ seq, err = validate_sequence(seq_raw)
360
+ if err:
361
+ st.error(err)
362
+ elif not smi.strip():
363
+ st.error("Please enter a SMILES string.")
364
  else:
365
+ with st.spinner("Running inference..."):
366
  t0 = time.time()
367
  try:
368
  X, valid, esm_vec = extract_features(
369
+ seq, [smi.strip()],
370
  tokenizer, esm_model, device, lig_scaler
371
  )
372
  if not valid.any():
 
390
  <div class="metric-lab">95% model interval (Β±1.96Οƒ, 45 models)</div>
391
  </div>""", unsafe_allow_html=True)
392
  with c3:
 
393
  st.markdown(f"""<div class="metric-card">
394
+ <div class="metric-val">{format_ki(pkd)}</div>
395
+ <div class="metric-lab">Estimated Ki (pKd β‰ˆ pKi assumed)</div>
396
  </div>""", unsafe_allow_html=True)
 
 
 
397
  with c4:
398
+ ad_label, _ = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
399
+ ad_cls = ("ad-in" if ad_label == "IN DOMAIN" else
400
+ "ad-out" if ad_label == "OUT OF DOMAIN" else "ad-unk")
401
  st.markdown(f"""<div class="metric-card">
402
  <div class="{ad_cls}">{ad_label}</div>
403
  <div class="metric-lab">Applicability domain</div>
 
405
 
406
  if ad_label == "OUT OF DOMAIN":
407
  st.warning("Protein is outside the training distribution. "
408
+ "Predictions may be unreliable.")
409
 
410
  st.caption(
411
  f"Inference time: {elapsed:.2f}s | "
412
+ f"45-model ensemble (3 seeds x 3 types x 5 folds) | "
413
  f"Device: {device.upper()}"
414
  )
415
 
416
  with st.expander("Per-model breakdown"):
417
  labels = [f"s{s}_{t}" for s in SEEDS for t in MODEL_TYPES]
418
  fig = bar_chart(
419
+ labels, preds_all[0],
 
420
  preds_all[0] - preds_all[0].std(),
421
  preds_all[0] + preds_all[0].std(),
422
+ "Seed x type predictions (fold-averaged)",
423
+ dark=dark_mode,
424
  )
425
  st.pyplot(fig, use_container_width=True)
426
  plt.close(fig)
 
433
  # ══════════════════════════════════════════════════════════════════════
434
  # MODE 2 β€” Batch screening
435
  # ══════════════════════════════════════════════════════════════════════
436
+ elif mode == "Batch screening (CSV)":
437
 
438
  st.subheader("Batch Screening")
439
  st.markdown("One protein, many compounds. Upload a CSV with a `smiles` column "
 
441
 
442
  col_seq, col_csv = st.columns(2)
443
  with col_seq:
444
+ batch_seq_raw = st.text_area("Target protein sequence", height=180,
445
+ placeholder="Paste UniProt sequence...")
446
  with col_csv:
447
  uploaded = st.file_uploader("Compound CSV (smiles, name)", type=["csv"])
448
  st.code("smiles,name\nCC(=O)Oc1ccccc1C(=O)O,Aspirin", language="csv")
 
450
  max_cpds = st.slider("Max compounds", 10, 500, 100,
451
  help="~1s per compound on CPU free tier.")
452
 
453
+ if st.button("Run batch screening", type="primary", use_container_width=True):
454
+ batch_seq, err = validate_sequence(batch_seq_raw)
455
+ if err:
456
+ st.error(err)
457
  elif uploaded is None:
458
  st.error("Please upload a CSV file.")
459
  else:
 
467
  names_list = (df_in['name'].tolist() if 'name' in df_in.columns
468
  else [f"cpd_{i}" for i in range(len(df_in))])
469
 
470
+ with st.spinner(f"Screening {len(smiles_list)} compounds..."):
 
471
  t0 = time.time()
472
  X, valid, esm_vec = extract_features(
473
+ batch_seq, smiles_list,
474
  tokenizer, esm_model, device, lig_scaler
475
  )
476
+ ad_label, _ = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
 
 
 
 
 
477
  preds, preds_all = predict(X, fold_models, meta, target_scaler)
478
+ lo, hi = uncertainty_interval(preds_all)
479
+ elapsed = time.time() - t0
480
 
481
  valid_names = [names_list[i] for i in range(len(names_list)) if valid[i]]
482
  valid_smiles = [smiles_list[i] for i in range(len(smiles_list)) if valid[i]]
 
488
  'pKd_pred': np.round(preds, 3),
489
  'CI_lo': np.round(lo, 3),
490
  'CI_hi': np.round(hi, 3),
491
+ 'Ki_est': [format_ki(p) for p in preds],
492
  'model_std': np.round(preds_all.std(axis=1), 3),
493
+ 'AD': [ad_label] * len(valid_names),
494
  }).sort_values('pKd_pred', ascending=False).reset_index(drop=True)
495
  results_df.insert(0, 'rank', range(1, len(results_df) + 1))
496
 
497
+ if ad_label == "OUT OF DOMAIN":
498
+ st.warning("Protein is outside the training distribution. "
499
+ "Predictions may be unreliable.")
500
+
501
  st.success(
502
  f"βœ“ {len(results_df)} compounds in {elapsed:.1f}s "
503
  f"({elapsed / max(len(results_df), 1):.2f}s/compound)"
 
511
  top_df['pKd_pred'].values,
512
  top_df['CI_lo'].values,
513
  top_df['CI_hi'].values,
514
+ f"Top {top_n} hits",
515
+ dark=dark_mode,
516
  )
517
  st.pyplot(fig, use_container_width=True)
518
  plt.close(fig)
 
522
  use_container_width=True, height=400,
523
  )
524
  st.download_button(
525
+ "Download ranked CSV",
526
  results_df.to_csv(index=False).encode(),
527
  file_name="velobind_screening.csv",
528
  mime="text/csv",
 
532
  # ══════════════════════════════════════════════════════════════════════
533
  # MODE 3 β€” One compound vs. multiple targets
534
  # ══════════════════════════════════════════════════════════════════════
535
+ elif mode == "One compound vs. multiple targets":
536
 
537
  st.subheader("Selectivity Profiling")
538
  st.markdown("One SMILES, multiple proteins β€” ranked by predicted pKd. "
 
543
  multi_seqs = st.text_area(
544
  "Target proteins (one per line)",
545
  height=250,
546
+ placeholder="ABL1: MGPSENDPNLFVALY...\nEGFR: MRPSGTAGAALLALL...\nCDK2: MENFQKVEKIGEGTY...",
 
 
 
 
547
  )
548
 
549
+ if st.button("Run selectivity profiling", type="primary", use_container_width=True):
550
  if not multi_smi.strip() or not multi_seqs.strip():
551
  st.error("Please enter a SMILES and at least one protein sequence.")
552
  else:
553
  targets = {}
554
+ parse_errors = []
555
  for i, line in enumerate(multi_seqs.strip().splitlines()):
556
  line = line.strip()
557
  if not line:
558
  continue
559
  if ":" in line:
560
+ name, raw_seq = line.split(":", 1)
561
+ name = name.strip()
562
+ else:
563
+ name, raw_seq = f"Target_{i+1}", line
564
+ seq, err = validate_sequence(raw_seq)
565
+ if err:
566
+ parse_errors.append(f"{name}: {err}")
567
  else:
568
+ targets[name] = seq
569
 
570
+ if parse_errors:
571
+ for e in parse_errors:
572
+ st.warning(f"Skipped β€” {e}")
573
  if not targets:
574
+ st.error("No valid sequences found.")
575
  st.stop()
576
 
 
577
  results, progress = [], st.progress(0)
578
  for idx, (name, seq) in enumerate(targets.items()):
579
  try:
 
584
  if valid.any():
585
  preds, preds_all = predict(X, fold_models, meta, target_scaler)
586
  lo, hi = uncertainty_interval(preds_all)
587
+ ad_label, _ = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
588
  results.append({
589
  'Target': name,
590
  'pKd_pred': round(float(preds[0]), 3),
591
  'CI_lo': round(float(lo[0]), 3),
592
  'CI_hi': round(float(hi[0]), 3),
593
+ 'Ki_est': format_ki(float(preds[0])),
594
  'model_std': round(float(preds_all.std()), 3),
595
+ 'AD': ad_label,
596
  })
597
  except Exception as e:
598
  st.warning(f"Skipped {name}: {e}")
 
606
  )
607
  res_df.insert(0, 'rank', range(1, len(res_df) + 1))
608
 
609
+ st.success(f"Profiled {len(res_df)} targets.")
610
  fig = bar_chart(
611
  res_df['Target'].tolist(),
612
  res_df['pKd_pred'].values,
613
  res_df['CI_lo'].values,
614
  res_df['CI_hi'].values,
615
+ "Selectivity profile β€” predicted pKd by target",
616
+ dark=dark_mode,
617
  )
618
  st.pyplot(fig, use_container_width=True)
619
  plt.close(fig)
620
 
621
  st.dataframe(res_df, use_container_width=True)
622
  st.download_button(
623
+ "Download selectivity CSV",
624
  res_df.to_csv(index=False).encode(),
625
  file_name="velobind_selectivity.csv",
626
  mime="text/csv",
627
  )
628
 
629
+
630
  # ── Footer ────────────────────────────────────────────────────────────
631
  st.markdown("---")
632
+ st.markdown(f"""
633
+ <div style="color:{lab_color};font-size:0.8rem;text-align:center;padding:0.5rem">
634
+ VeloBind &nbsp;Β·&nbsp; Structure-free binding affinity &nbsp;Β·&nbsp;
635
+ ESM-2 + GBM ensemble &nbsp;Β·&nbsp;
636
+ Trained on LP-PDBBind &nbsp;Β·&nbsp;
637
+ Evaluated on CASF-2016/2013 &nbsp;Β·&nbsp;
638
+ <b>Not for clinical use.</b>
639
  </div>
640
  """, unsafe_allow_html=True)