ym59 committed on
Commit
b1223e2
·
verified ·
1 Parent(s): 12e3a24

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +272 -245
app.py CHANGED
@@ -1,5 +1,5 @@
1
  # app.py — VeloBind HF Spaces inference app
2
- import os, warnings, time, base64
3
  import numpy as np
4
  import pandas as pd
5
  import streamlit as st
@@ -18,7 +18,7 @@ MODEL_CACHE = Path("/tmp/velobind_models")
18
  SEEDS = [42, 123, 456]
19
  MODEL_TYPES = ["lgbm", "cb", "xgb"]
20
  N_FOLDS = 5
21
- VALID_AA = set("ACDEFGHIKLMNPQRSTVWYacdefghiklmnpqrstvwyX")
22
 
23
  import sys
24
  sys.path.append(str(Path(__file__).parent))
@@ -28,6 +28,13 @@ from src.models.ensemble import TargetScaler
28
  from src.config import config
29
 
30
 
 
 
 
 
 
 
 
31
  # ══════════════════════════════════════════════════════════════════════
32
  # Validation
33
  # ══════════════════════════════════════════════════════════════════════
@@ -35,12 +42,8 @@ def validate_sequence(raw: str):
35
  raw = raw.strip()
36
  if not raw:
37
  return None, "Please enter a sequence."
38
-
39
- # Strip FASTA header(s)
40
- lines = raw.splitlines()
41
- seq_lines = [l.strip() for l in lines if not l.startswith(">")]
42
- seq = "".join(seq_lines).upper().replace(" ", "")
43
-
44
  if len(seq) < 10:
45
  return None, "Sequence too short (minimum 10 residues)."
46
  invalid = set(seq) - VALID_AA
@@ -52,21 +55,18 @@ def validate_sequence(raw: str):
52
  # ══════════════════════════════════════════════════════════════════════
53
  # Model loading
54
  # ══════════════════════════════════════════════════════════════════════
55
- @st.cache_resource(show_spinner="Downloading and loading VeloBind models (first run ~30s)...")
56
  def load_all_models():
57
  from huggingface_hub import hf_hub_download
58
  MODEL_CACHE.mkdir(parents=True, exist_ok=True)
59
-
60
  model_files = (
61
  [f"fold_model_s{s}_{t}_f{f}.pkl"
62
  for s in SEEDS for t in MODEL_TYPES for f in range(N_FOLDS)]
63
  + ["meta_type_casf16.pkl", "target_scaler.pkl", "ligand_scaler.pkl"]
64
  )
65
-
66
  bar = st.progress(0, text="Loading models...")
67
  for i, fname in enumerate(model_files):
68
- local = MODEL_CACHE / fname
69
- if not local.exists():
70
  hf_hub_download(repo_id=HF_MODEL_REPO, filename=fname,
71
  local_dir=str(MODEL_CACHE))
72
  bar.progress((i + 1) / len(model_files), text=f"Loading {fname}...")
@@ -80,14 +80,13 @@ def load_all_models():
80
  joblib.load(MODEL_CACHE / f"fold_model_s{s}_{t}_f{f}.pkl")
81
  for f in range(N_FOLDS)
82
  ]
83
-
84
  meta = joblib.load(MODEL_CACHE / "meta_type_casf16.pkl")
85
  scaler = joblib.load(MODEL_CACHE / "target_scaler.pkl")
86
  lig_sc = joblib.load(MODEL_CACHE / "ligand_scaler.pkl")
87
  return fold_models, meta, scaler, lig_sc
88
 
89
 
90
- @st.cache_resource(show_spinner="Loading ESM-2 protein language model...")
91
  def load_esm_model():
92
  device = "cuda" if torch.cuda.is_available() else "cpu"
93
  tokenizer, esm_model = load_esm(config.ESM_MODEL, device)
@@ -98,8 +97,7 @@ def load_esm_model():
98
  def load_ad_centroid():
99
  for p in [Path("output/models/deployment"), Path("output/models")]:
100
  if (p / "ad_centroid.npy").exists():
101
- return (np.load(p / "ad_centroid.npy"),
102
- float(np.load(p / "ad_threshold.npy")))
103
  for fname in ["ad_centroid.npy", "ad_threshold.npy"]:
104
  local = MODEL_CACHE / fname
105
  if not local.exists():
@@ -109,33 +107,25 @@ def load_ad_centroid():
109
  local_dir=str(MODEL_CACHE))
110
  except Exception:
111
  return None, None
112
- return (np.load(MODEL_CACHE / "ad_centroid.npy"),
113
- float(np.load(MODEL_CACHE / "ad_threshold.npy")))
114
 
115
 
116
- def ad_check(esm_mean_vec, centroid, threshold):
117
  if centroid is None:
118
  return "UNKNOWN", float("nan")
119
- dist = float(np.linalg.norm(esm_mean_vec - centroid))
120
  return ("IN DOMAIN" if dist <= threshold else "OUT OF DOMAIN"), dist
121
 
122
 
123
  # ══════════════════════════════════════════════════════════════════════
124
- # Feature extraction
125
  # ══════════════════════════════════════════════════════════════════════
126
  def assemble_from_parts(esm_mean, esm_var, esm_attn, seq_feat, lig_feats):
127
  return np.concatenate([
128
- esm_mean[:, -480:],
129
- seq_feat,
130
- lig_feats["ecfp"],
131
- lig_feats["ecfp2"],
132
- lig_feats["ecfp6"],
133
- lig_feats["fcfp"],
134
- lig_feats["estate"],
135
- lig_feats["maccs"],
136
- lig_feats["atom_pair"],
137
- lig_feats["torsion"],
138
- lig_feats["phys"],
139
  ], axis=1)
140
 
141
 
@@ -145,45 +135,34 @@ def extract_features(sequence, smiles_list, tokenizer, esm_model, device, lig_sc
145
  config.ESM_LAYERS, config.MAX_SEQ_LEN, config.HALF_SEQ_LEN,
146
  batch_size=1, device=device,
147
  )
148
- seq_feat = np.array([sequence_features(sequence)])
149
-
150
  lig_feats, valid_mask, _ = extract_ligand_features(
151
- smiles_list, scaler=lig_scaler, fit_scaler=False
152
- )
153
  valid_mask = np.array(valid_mask)
154
  if valid_mask.dtype != bool:
155
- bool_mask = np.zeros(len(smiles_list), dtype=bool)
156
- bool_mask[valid_mask] = True
157
- valid_mask = bool_mask
158
-
159
- n_valid = int(valid_mask.sum())
160
- esm_mean_t = np.tile(esm_mean, (n_valid, 1))
161
- esm_var_t = np.tile(esm_var, (n_valid, 1))
162
- esm_attn_t = np.tile(esm_attn, (n_valid, 1))
163
- seq_feat_t = np.tile(seq_feat, (n_valid, 1))
164
-
165
- X = assemble_from_parts(esm_mean_t, esm_var_t, esm_attn_t, seq_feat_t, lig_feats)
166
  return X, valid_mask, esm_mean[0]
167
 
168
 
169
- # ══════════════════════════════════════════════════════════════════════
170
- # Prediction
171
- # ══════════════════════════════════════════════════════════════════════
172
  def predict(X, fold_models, meta, scaler):
173
  type_avgs = []
174
  for s in SEEDS:
175
  for t in MODEL_TYPES:
176
- fold_preds = np.stack([
177
- scaler.inverse(fold_models[s][t][f].predict(X))
178
- for f in range(N_FOLDS)
179
- ], axis=1)
180
- type_avgs.append(fold_preds.mean(axis=1))
181
-
182
  preds_all = np.stack(type_avgs, axis=1)
183
- lgbm_avg = preds_all[:, [0, 3, 6]].mean(axis=1)
184
- cb_avg = preds_all[:, [1, 4, 7]].mean(axis=1)
185
- xgb_avg = preds_all[:, [2, 5, 8]].mean(axis=1)
186
- preds = meta.predict(np.column_stack([lgbm_avg, cb_avg, xgb_avg]))
 
187
  return preds, preds_all
188
 
189
 
@@ -193,73 +172,76 @@ def uncertainty_interval(preds_all, z=1.96):
193
 
194
 
195
  def format_ki(pkd):
196
- """Format Ki with appropriate unit (nM, uM, mM)."""
197
  ki_nM = 10 ** (9 - pkd)
198
- if ki_nM < 1000:
199
- return f"{ki_nM:.1f} nM"
200
- elif ki_nM < 1_000_000:
201
- return f"{ki_nM/1000:.2f} uM"
202
- else:
203
- return f"{ki_nM/1_000_000:.2f} mM"
204
 
205
 
206
  # ══════════════════════════════════════════════════════════════════════
207
- # Plots
208
  # ══════════════════════════════════════════════════════════════════════
209
  def bar_chart(names, preds, lo, hi, title, dark=True):
210
- bg = "#1e2a38" if dark else "#f8f9fa"
211
- fg = "#ffffff" if dark else "#111111"
212
- grid = "#2d3f55" if dark else "#cccccc"
213
-
214
- fig, ax = plt.subplots(figsize=(max(6, len(names) * 0.9), 4),
215
- facecolor=bg)
216
  ax.set_facecolor(bg)
217
  x = np.arange(len(names))
218
  err = [preds - lo, hi - preds]
219
- bars = ax.bar(x, preds, color="#4C72B0", alpha=0.85, width=0.6,
220
  yerr=err, capsize=5, error_kw=dict(ecolor=fg, lw=1.5))
221
  ax.set_xticks(x)
222
  ax.set_xticklabels(names, rotation=30, ha='right', fontsize=10, color=fg)
223
  ax.set_ylabel("Predicted pKd", fontsize=11, color=fg)
224
  ax.set_title(title, fontsize=12, fontweight='bold', color=fg)
225
  ax.tick_params(colors=fg)
226
- ax.spines[:].set_color(grid)
227
- ax.grid(True, axis='y', alpha=0.25, color=grid)
 
228
  for bar, val in zip(bars, preds):
229
  ax.text(bar.get_x() + bar.get_width() / 2,
230
  bar.get_height() + 0.05, f"{val:.2f}",
231
- ha='center', va='bottom', fontsize=9,
232
- fontweight='bold', color=fg)
233
  plt.tight_layout()
234
  return fig
235
 
236
 
237
  # ══════════════════════════════════════════════════════════════════════
238
- # Page setup
239
  # ══════════════════════════════════════════════════════════════════════
240
  st.set_page_config(page_title="VeloBind", layout="wide")
241
 
242
- # ── Theme toggle ──────────────────────────────────────────────────────
243
- with st.sidebar:
244
- st.markdown("### Display")
245
- dark_mode = st.toggle("Dark mode", value=True)
246
-
247
- if dark_mode:
248
- header_bg = "linear-gradient(135deg, #1a3a5c, #1e6091, #2980b9)"
249
- card_bg = "#1e2a38"
250
- card_border = "#2d3f55"
251
- val_color = "#4fc3f7"
252
- lab_color = "#aaa"
253
- page_bg = "#0e1117"
254
- text_color = "#ffffff"
255
  else:
256
- header_bg = "linear-gradient(135deg, #2980b9, #5dade2, #85c1e9)"
257
- card_bg = "#f0f4f8"
258
- card_border = "#b0c4de"
259
- val_color = "#1a5276"
260
- lab_color = "#555"
261
- page_bg = "#ffffff"
262
- text_color = "#111111"
 
 
 
 
 
 
 
 
 
 
 
263
 
264
  def load_svg_b64(path):
265
  with open(path, "rb") as f:
@@ -269,93 +251,158 @@ logo_b64 = load_svg_b64("logo.svg")
269
 
270
  st.markdown(f"""
271
  <style>
272
- .stApp {{ background-color: {page_bg}; color: {text_color}; }}
273
- .header-wrap {{
274
- display: flex; align-items: center; gap: 1.5rem;
275
- margin-bottom: 1.5rem;
276
- }}
277
- .logo-box {{
278
- background: #ffffff; border-radius: 12px;
279
- padding: 0.75rem; flex-shrink: 0;
280
- }}
281
- .logo-box img {{ height: 130px; width: auto; display: block; }}
282
- .header-text {{
283
- background: {header_bg};
284
- padding: 1.5rem 2rem; border-radius: 12px; flex: 1;
285
- }}
286
- .header-text h1 {{ color: #fff; font-size: 2.2rem; margin: 0; }}
287
- .header-text p {{ color: #d6eaf8; margin: 0.3rem 0 0; font-size: 1rem; }}
288
- .metric-card {{
289
- background: {card_bg}; border: 1px solid {card_border};
290
- border-radius: 10px; padding: 1rem; text-align: center;
291
- }}
292
- .metric-val {{ font-size: 2rem; font-weight: 700; color: {val_color}; }}
293
- .metric-lab {{ font-size: 0.8rem; color: {lab_color}; margin-top: 0.2rem; }}
294
- .ad-in {{ background:#1b4332; border:1px solid #2d6a4f; color:#40916c;
295
- border-radius:8px; padding:0.4rem 1rem; font-weight:700; display:inline-block; }}
296
- .ad-out {{ background:#4a1c24; border:1px solid #9b2335; color:#e74c3c;
297
- border-radius:8px; padding:0.4rem 1rem; font-weight:700; display:inline-block; }}
298
- .ad-unk {{ background:#2d2d2d; border:1px solid #555; color:#aaa;
299
- border-radius:8px; padding:0.4rem 1rem; font-weight:700; display:inline-block; }}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  </style>
301
- <div class="header-wrap">
302
- <div class="logo-box">
303
- <img src="data:image/svg+xml;base64,{logo_b64}" alt="VeloBind logo"/>
 
 
 
 
 
 
 
304
  </div>
305
- <div class="header-text">
306
  <h1>VeloBind</h1>
307
- <p>Structure-free protein-ligand binding affinity prediction &nbsp;Β·&nbsp;
308
- Sequence + SMILES only &nbsp;Β·&nbsp;
309
- Pearson R = 0.8469 on CASF-2016 &nbsp;Β·&nbsp;
310
- 45-model ensemble (LGBM + CatBoost + XGBoost)</p>
 
 
311
  </div>
312
  </div>
313
  """, unsafe_allow_html=True)
314
 
315
- # ── Load everything ───────────────────────────────────────────────────
 
 
 
 
 
 
 
316
  fold_models, meta, target_scaler, lig_scaler = load_all_models()
317
  tokenizer, esm_model, device = load_esm_model()
318
  ad_centroid, ad_threshold = load_ad_centroid()
319
  n_loaded = sum(len(fold_models[s][t]) for s in SEEDS for t in MODEL_TYPES)
320
- st.success(f"βœ“ {n_loaded} fold models loaded | Device: {device.upper()}")
321
-
322
- # ── Mode selector ─────────────────────────────────────────────────────
323
- mode = st.radio(
324
- "Select mode",
325
- ["Single query", "Batch screening (CSV)", "One compound vs. multiple targets"],
326
- horizontal=True,
327
- )
328
- st.markdown("---")
329
 
330
 
331
  # ══════════════════════════════════════════════════════════════════════
332
- # MODE 1 — Single query
333
  # ══════════════════════════════════════════════════════════════════════
334
- if mode == "Single query":
335
-
336
  col_p, col_l = st.columns(2)
337
  with col_p:
338
  st.subheader("Protein")
339
  seq_raw = st.text_area(
340
- "Amino acid sequence (single-letter FASTA, no header)",
341
- height=150,
342
- placeholder="MKTAYIAKQRQISFVK...",
343
- help="Only standard amino acid letters accepted (A C D E F G H I K L M N P Q R S T V W Y)."
 
344
  )
345
  with col_l:
346
  st.subheader("Ligand")
347
- smi = st.text_input("SMILES", placeholder="CC(=O)Oc1ccccc1C(=O)O")
348
  examples = {
349
  "Aspirin": "CC(=O)Oc1ccccc1C(=O)O",
350
  "Imatinib": "Cc1ccc(NC(=O)c2ccc(CN3CCN(C)CC3)cc2)cc1Nc1nccc(-c2cccnc2)n1",
351
  "Gefitinib": "COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCCN1CCOCC1",
352
  "Staurosporine": "C[C@@H]1CCCN2C(=O)c3[nH]c4ccccc4c3C2=N1",
353
  }
354
- chosen = st.selectbox("Load example SMILES", ["β€”"] + list(examples))
355
  if chosen != "β€”":
356
  smi = examples[chosen]
357
 
358
- if st.button("Predict", type="primary", use_container_width=True):
359
  seq, err = validate_sequence(seq_raw)
360
  if err:
361
  st.error(err)
@@ -366,18 +413,17 @@ if mode == "Single query":
366
  t0 = time.time()
367
  try:
368
  X, valid, esm_vec = extract_features(
369
- seq, [smi.strip()],
370
- tokenizer, esm_model, device, lig_scaler
371
- )
372
  if not valid.any():
373
- st.error("RDKit could not parse this SMILES. Please check the input.")
374
  else:
375
  preds, preds_all = predict(X, fold_models, meta, target_scaler)
376
  lo, hi = uncertainty_interval(preds_all)
377
  elapsed = time.time() - t0
378
  pkd = float(preds[0])
 
379
 
380
- st.markdown("### Results")
381
  c1, c2, c3, c4 = st.columns(4)
382
  with c1:
383
  st.markdown(f"""<div class="metric-card">
@@ -387,25 +433,25 @@ if mode == "Single query":
387
  with c2:
388
  st.markdown(f"""<div class="metric-card">
389
  <div class="metric-val">[{lo[0]:.2f}, {hi[0]:.2f}]</div>
390
- <div class="metric-lab">95% model interval (Β±1.96Οƒ, 45 models)</div>
391
  </div>""", unsafe_allow_html=True)
392
  with c3:
393
  st.markdown(f"""<div class="metric-card">
394
  <div class="metric-val">{format_ki(pkd)}</div>
395
- <div class="metric-lab">Estimated Ki (pKd β‰ˆ pKi assumed)</div>
396
  </div>""", unsafe_allow_html=True)
397
  with c4:
398
- ad_label, _ = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
399
  ad_cls = ("ad-in" if ad_label == "IN DOMAIN" else
400
  "ad-out" if ad_label == "OUT OF DOMAIN" else "ad-unk")
401
  st.markdown(f"""<div class="metric-card">
402
- <div class="{ad_cls}">{ad_label}</div>
403
- <div class="metric-lab">Applicability domain</div>
 
 
404
  </div>""", unsafe_allow_html=True)
405
 
406
  if ad_label == "OUT OF DOMAIN":
407
- st.warning("Protein is outside the training distribution. "
408
- "Predictions may be unreliable.")
409
 
410
  st.caption(
411
  f"Inference time: {elapsed:.2f}s | "
@@ -419,8 +465,8 @@ if mode == "Single query":
419
  labels, preds_all[0],
420
  preds_all[0] - preds_all[0].std(),
421
  preds_all[0] + preds_all[0].std(),
422
- "Seed x type predictions (fold-averaged)",
423
- dark=dark_mode,
424
  )
425
  st.pyplot(fig, use_container_width=True)
426
  plt.close(fig)
@@ -431,26 +477,26 @@ if mode == "Single query":
431
 
432
 
433
  # ══════════════════════════════════════════════════════════════════════
434
- # MODE 2 — Batch screening
435
  # ══════════════════════════════════════════════════════════════════════
436
- elif mode == "Batch screening (CSV)":
437
-
438
  st.subheader("Batch Screening")
439
- st.markdown("One protein, many compounds. Upload a CSV with a `smiles` column "
440
- "(and optionally `name`). Results are ranked by predicted pKd.")
 
441
 
442
  col_seq, col_csv = st.columns(2)
443
  with col_seq:
444
- batch_seq_raw = st.text_area("Target protein sequence", height=180,
445
- placeholder="Paste UniProt sequence...")
446
  with col_csv:
447
- uploaded = st.file_uploader("Compound CSV (smiles, name)", type=["csv"])
448
  st.code("smiles,name\nCC(=O)Oc1ccccc1C(=O)O,Aspirin", language="csv")
449
 
450
- max_cpds = st.slider("Max compounds", 10, 500, 100,
451
  help="~1s per compound on CPU free tier.")
452
 
453
- if st.button("Run batch screening", type="primary", use_container_width=True):
454
  batch_seq, err = validate_sequence(batch_seq_raw)
455
  if err:
456
  st.error(err)
@@ -470,9 +516,7 @@ elif mode == "Batch screening (CSV)":
470
  with st.spinner(f"Screening {len(smiles_list)} compounds..."):
471
  t0 = time.time()
472
  X, valid, esm_vec = extract_features(
473
- batch_seq, smiles_list,
474
- tokenizer, esm_model, device, lig_scaler
475
- )
476
  ad_label, _ = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
477
  preds, preds_all = predict(X, fold_models, meta, target_scaler)
478
  lo, hi = uncertainty_interval(preds_all)
@@ -482,6 +526,9 @@ elif mode == "Batch screening (CSV)":
482
  valid_smiles = [smiles_list[i] for i in range(len(smiles_list)) if valid[i]]
483
  n_invalid = int((~valid).sum())
484
 
 
 
 
485
  results_df = pd.DataFrame({
486
  'name': valid_names,
487
  'smiles': valid_smiles,
@@ -494,25 +541,19 @@ elif mode == "Batch screening (CSV)":
494
  }).sort_values('pKd_pred', ascending=False).reset_index(drop=True)
495
  results_df.insert(0, 'rank', range(1, len(results_df) + 1))
496
 
497
- if ad_label == "OUT OF DOMAIN":
498
- st.warning("Protein is outside the training distribution. "
499
- "Predictions may be unreliable.")
500
-
501
  st.success(
502
- f"βœ“ {len(results_df)} compounds in {elapsed:.1f}s "
503
  f"({elapsed / max(len(results_df), 1):.2f}s/compound)"
504
  + (f" | {n_invalid} invalid SMILES skipped" if n_invalid else "")
505
  )
506
 
507
  top_n = min(20, len(results_df))
508
- top_df = results_df.head(top_n)
509
- fig = bar_chart(
510
- top_df['name'].tolist(),
511
- top_df['pKd_pred'].values,
512
- top_df['CI_lo'].values,
513
- top_df['CI_hi'].values,
514
- f"Top {top_n} hits",
515
- dark=dark_mode,
516
  )
517
  st.pyplot(fig, use_container_width=True)
518
  plt.close(fig)
@@ -524,52 +565,46 @@ elif mode == "Batch screening (CSV)":
524
  st.download_button(
525
  "Download ranked CSV",
526
  results_df.to_csv(index=False).encode(),
527
- file_name="velobind_screening.csv",
528
- mime="text/csv",
529
  )
530
 
531
 
532
  # ══════════════════════════════════════════════════════════════════════
533
- # MODE 3 — One compound vs. multiple targets
534
  # ══════════════════════════════════════════════════════════════════════
535
- elif mode == "One compound vs. multiple targets":
536
-
537
  st.subheader("Selectivity Profiling")
538
- st.markdown("One SMILES, multiple proteins β€” ranked by predicted pKd. "
539
- "Format: `TargetName: SEQUENCE` (name optional).")
 
540
 
541
- multi_smi = st.text_input("Compound SMILES",
542
- placeholder="Cc1ccc(NC(=O)...)cc1Nc1nccc(...)n1")
543
  multi_seqs = st.text_area(
544
  "Target proteins (one per line)",
545
  height=250,
546
  placeholder="ABL1: MGPSENDPNLFVALY...\nEGFR: MRPSGTAGAALLALL...\nCDK2: MENFQKVEKIGEGTY...",
 
547
  )
548
 
549
- if st.button("Run selectivity profiling", type="primary", use_container_width=True):
550
  if not multi_smi.strip() or not multi_seqs.strip():
551
  st.error("Please enter a SMILES and at least one protein sequence.")
552
  else:
553
- targets = {}
554
- parse_errors = []
555
  for i, line in enumerate(multi_seqs.strip().splitlines()):
556
  line = line.strip()
557
  if not line:
558
  continue
559
- if ":" in line:
560
- name, raw_seq = line.split(":", 1)
561
- name = name.strip()
562
- else:
563
- name, raw_seq = f"Target_{i+1}", line
564
- seq, err = validate_sequence(raw_seq)
565
  if err:
566
- parse_errors.append(f"{name}: {err}")
567
  else:
568
- targets[name] = seq
569
 
570
- if parse_errors:
571
- for e in parse_errors:
572
- st.warning(f"Skipped β€” {e}")
573
  if not targets:
574
  st.error("No valid sequences found.")
575
  st.stop()
@@ -578,18 +613,16 @@ elif mode == "One compound vs. multiple targets":
578
  for idx, (name, seq) in enumerate(targets.items()):
579
  try:
580
  X, valid, esm_vec = extract_features(
581
- seq, [multi_smi.strip()],
582
- tokenizer, esm_model, device, lig_scaler
583
- )
584
  if valid.any():
585
  preds, preds_all = predict(X, fold_models, meta, target_scaler)
586
  lo, hi = uncertainty_interval(preds_all)
587
  ad_label, _ = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
588
  results.append({
589
  'Target': name,
590
- 'pKd_pred': round(float(preds[0]), 3),
591
- 'CI_lo': round(float(lo[0]), 3),
592
- 'CI_hi': round(float(hi[0]), 3),
593
  'Ki_est': format_ki(float(preds[0])),
594
  'model_std': round(float(preds_all.std()), 3),
595
  'AD': ad_label,
@@ -599,21 +632,16 @@ elif mode == "One compound vs. multiple targets":
599
  progress.progress((idx + 1) / len(targets))
600
 
601
  progress.empty()
602
- res_df = (
603
- pd.DataFrame(results)
604
- .sort_values('pKd_pred', ascending=False)
605
- .reset_index(drop=True)
606
- )
607
  res_df.insert(0, 'rank', range(1, len(res_df) + 1))
608
 
609
  st.success(f"Profiled {len(res_df)} targets.")
610
  fig = bar_chart(
611
- res_df['Target'].tolist(),
612
- res_df['pKd_pred'].values,
613
- res_df['CI_lo'].values,
614
- res_df['CI_hi'].values,
615
- "Selectivity profile β€” predicted pKd by target",
616
- dark=dark_mode,
617
  )
618
  st.pyplot(fig, use_container_width=True)
619
  plt.close(fig)
@@ -622,19 +650,18 @@ elif mode == "One compound vs. multiple targets":
622
  st.download_button(
623
  "Download selectivity CSV",
624
  res_df.to_csv(index=False).encode(),
625
- file_name="velobind_selectivity.csv",
626
- mime="text/csv",
627
  )
628
 
629
 
630
  # ── Footer ────────────────────────────────────────────────────────────
631
  st.markdown("---")
632
  st.markdown(f"""
633
- <div style="color:{lab_color};font-size:0.8rem;text-align:center;padding:0.5rem">
634
- VeloBind &nbsp;Β·&nbsp; Structure-free binding affinity &nbsp;Β·&nbsp;
635
- ESM-2 + GBM ensemble &nbsp;Β·&nbsp;
636
- Trained on LP-PDBBind &nbsp;Β·&nbsp;
637
- Evaluated on CASF-2016/2013 &nbsp;Β·&nbsp;
638
- <b>Not for clinical use.</b>
639
  </div>
640
  """, unsafe_allow_html=True)
 
1
  # app.py — VeloBind HF Spaces inference app
2
+ import warnings, time, base64
3
  import numpy as np
4
  import pandas as pd
5
  import streamlit as st
 
18
  SEEDS = [42, 123, 456]
19
  MODEL_TYPES = ["lgbm", "cb", "xgb"]
20
  N_FOLDS = 5
21
+ VALID_AA = set("ACDEFGHIKLMNPQRSTVWYX")
22
 
23
  import sys
24
  sys.path.append(str(Path(__file__).parent))
 
28
  from src.config import config
29
 
30
 
31
+ # ══════════════════════════════════════════════════════════════════════
32
+ # Session state — theme
33
+ # ══════════════════════════════════════════════════════════════════════
34
+ if "dark_mode" not in st.session_state:
35
+ st.session_state.dark_mode = True
36
+
37
+
38
  # ══════════════════════════════════════════════════════════════════════
39
  # Validation
40
  # ══════════════════════════════════════════════════════════════════════
 
42
  raw = raw.strip()
43
  if not raw:
44
  return None, "Please enter a sequence."
45
+ lines = raw.splitlines()
46
+ seq = "".join(l.strip() for l in lines if not l.startswith(">")).upper().replace(" ", "")
 
 
 
 
47
  if len(seq) < 10:
48
  return None, "Sequence too short (minimum 10 residues)."
49
  invalid = set(seq) - VALID_AA
 
55
  # ══════════════════════════════════════════════════════════════════════
56
  # Model loading
57
  # ══════════════════════════════════════════════════════════════════════
58
+ @st.cache_resource(show_spinner="Loading VeloBind models (first run ~30s)...")
59
  def load_all_models():
60
  from huggingface_hub import hf_hub_download
61
  MODEL_CACHE.mkdir(parents=True, exist_ok=True)
 
62
  model_files = (
63
  [f"fold_model_s{s}_{t}_f{f}.pkl"
64
  for s in SEEDS for t in MODEL_TYPES for f in range(N_FOLDS)]
65
  + ["meta_type_casf16.pkl", "target_scaler.pkl", "ligand_scaler.pkl"]
66
  )
 
67
  bar = st.progress(0, text="Loading models...")
68
  for i, fname in enumerate(model_files):
69
+ if not (MODEL_CACHE / fname).exists():
 
70
  hf_hub_download(repo_id=HF_MODEL_REPO, filename=fname,
71
  local_dir=str(MODEL_CACHE))
72
  bar.progress((i + 1) / len(model_files), text=f"Loading {fname}...")
 
80
  joblib.load(MODEL_CACHE / f"fold_model_s{s}_{t}_f{f}.pkl")
81
  for f in range(N_FOLDS)
82
  ]
 
83
  meta = joblib.load(MODEL_CACHE / "meta_type_casf16.pkl")
84
  scaler = joblib.load(MODEL_CACHE / "target_scaler.pkl")
85
  lig_sc = joblib.load(MODEL_CACHE / "ligand_scaler.pkl")
86
  return fold_models, meta, scaler, lig_sc
87
 
88
 
89
+ @st.cache_resource(show_spinner="Loading ESM-2...")
90
  def load_esm_model():
91
  device = "cuda" if torch.cuda.is_available() else "cpu"
92
  tokenizer, esm_model = load_esm(config.ESM_MODEL, device)
 
97
  def load_ad_centroid():
98
  for p in [Path("output/models/deployment"), Path("output/models")]:
99
  if (p / "ad_centroid.npy").exists():
100
+ return np.load(p / "ad_centroid.npy"), float(np.load(p / "ad_threshold.npy"))
 
101
  for fname in ["ad_centroid.npy", "ad_threshold.npy"]:
102
  local = MODEL_CACHE / fname
103
  if not local.exists():
 
107
  local_dir=str(MODEL_CACHE))
108
  except Exception:
109
  return None, None
110
+ return np.load(MODEL_CACHE / "ad_centroid.npy"), float(np.load(MODEL_CACHE / "ad_threshold.npy"))
 
111
 
112
 
113
+ def ad_check(esm_vec, centroid, threshold):
114
  if centroid is None:
115
  return "UNKNOWN", float("nan")
116
+ dist = float(np.linalg.norm(esm_vec - centroid))
117
  return ("IN DOMAIN" if dist <= threshold else "OUT OF DOMAIN"), dist
118
 
119
 
120
  # ══════════════════════════════════════════════════════════════════════
121
+ # Features + prediction
122
  # ══════════════════════════════════════════════════════════════════════
123
  def assemble_from_parts(esm_mean, esm_var, esm_attn, seq_feat, lig_feats):
124
  return np.concatenate([
125
+ esm_mean[:, -480:], seq_feat,
126
+ lig_feats["ecfp"], lig_feats["ecfp2"], lig_feats["ecfp6"], lig_feats["fcfp"],
127
+ lig_feats["estate"], lig_feats["maccs"], lig_feats["atom_pair"],
128
+ lig_feats["torsion"], lig_feats["phys"],
 
 
 
 
 
 
 
129
  ], axis=1)
130
 
131
 
 
135
  config.ESM_LAYERS, config.MAX_SEQ_LEN, config.HALF_SEQ_LEN,
136
  batch_size=1, device=device,
137
  )
138
+ seq_feat = np.array([sequence_features(sequence)])
 
139
  lig_feats, valid_mask, _ = extract_ligand_features(
140
+ smiles_list, scaler=lig_scaler, fit_scaler=False)
 
141
  valid_mask = np.array(valid_mask)
142
  if valid_mask.dtype != bool:
143
+ bm = np.zeros(len(smiles_list), dtype=bool)
144
+ bm[valid_mask] = True
145
+ valid_mask = bm
146
+ n = int(valid_mask.sum())
147
+ X = assemble_from_parts(
148
+ np.tile(esm_mean, (n, 1)), np.tile(esm_var, (n, 1)),
149
+ np.tile(esm_attn, (n, 1)), np.tile(seq_feat, (n, 1)), lig_feats)
 
 
 
 
150
  return X, valid_mask, esm_mean[0]
151
 
152
 
 
 
 
153
  def predict(X, fold_models, meta, scaler):
154
  type_avgs = []
155
  for s in SEEDS:
156
  for t in MODEL_TYPES:
157
+ fp = np.stack([scaler.inverse(fold_models[s][t][f].predict(X))
158
+ for f in range(N_FOLDS)], axis=1)
159
+ type_avgs.append(fp.mean(axis=1))
 
 
 
160
  preds_all = np.stack(type_avgs, axis=1)
161
+ preds = meta.predict(np.column_stack([
162
+ preds_all[:, [0,3,6]].mean(1),
163
+ preds_all[:, [1,4,7]].mean(1),
164
+ preds_all[:, [2,5,8]].mean(1),
165
+ ]))
166
  return preds, preds_all
167
 
168
 
 
172
 
173
 
174
  def format_ki(pkd):
 
175
  ki_nM = 10 ** (9 - pkd)
176
+ if ki_nM < 1000: return f"{ki_nM:.1f} nM"
177
+ elif ki_nM < 1_000_000: return f"{ki_nM/1000:.2f} uM"
178
+ else: return f"{ki_nM/1_000_000:.2f} mM"
 
 
 
179
 
180
 
181
  # ══════════════════════════════════════════════════════════════════════
182
+ # Plot
183
  # ══════════════════════════════════════════════════════════════════════
184
  def bar_chart(names, preds, lo, hi, title, dark=True):
185
+ bg, fg, gc = ("#1a2332", "#e8edf2", "#2d3f55") if dark else ("#f8fafc", "#1a202c", "#cbd5e0")
186
+ fig, ax = plt.subplots(figsize=(max(6, len(names) * 0.9), 4), facecolor=bg)
 
 
 
 
187
  ax.set_facecolor(bg)
188
  x = np.arange(len(names))
189
  err = [preds - lo, hi - preds]
190
+ bars = ax.bar(x, preds, color="#3b82f6", alpha=0.9, width=0.6,
191
  yerr=err, capsize=5, error_kw=dict(ecolor=fg, lw=1.5))
192
  ax.set_xticks(x)
193
  ax.set_xticklabels(names, rotation=30, ha='right', fontsize=10, color=fg)
194
  ax.set_ylabel("Predicted pKd", fontsize=11, color=fg)
195
  ax.set_title(title, fontsize=12, fontweight='bold', color=fg)
196
  ax.tick_params(colors=fg)
197
+ for spine in ax.spines.values():
198
+ spine.set_color(gc)
199
+ ax.grid(True, axis='y', alpha=0.3, color=gc)
200
  for bar, val in zip(bars, preds):
201
  ax.text(bar.get_x() + bar.get_width() / 2,
202
  bar.get_height() + 0.05, f"{val:.2f}",
203
+ ha='center', va='bottom', fontsize=9, fontweight='bold', color=fg)
 
204
  plt.tight_layout()
205
  return fig
206
 
207
 
208
  # ══════════════════════════════════════════════════════════════════════
209
+ # Page layout
210
  # ══════════════════════════════════════════════════════════════════════
211
  st.set_page_config(page_title="VeloBind", layout="wide")
212
 
213
+ dark = st.session_state.dark_mode
214
+
215
+ # ── Theme-aware CSS (only custom elements, never .stApp) ──────────────
216
+ if dark:
217
+ card_bg, card_border = "#1a2332", "#2d4a6b"
218
+ val_col, lab_col = "#60a5fa", "#94a3b8"
219
+ banner_grad = "linear-gradient(135deg, #0f172a 0%, #1e3a5f 50%, #1e40af 100%)"
220
+ banner_sub = "#93c5fd"
221
+ logo_bg = "rgba(255,255,255,0.12)"
222
+ logo_border = "rgba(255,255,255,0.2)"
223
+ toggle_bg = "#1e3a5f"
224
+ toggle_knob = "#60a5fa"
225
+ toggle_label = "#93c5fd"
226
  else:
227
+ card_bg, card_border = "#f0f7ff", "#bfdbfe"
228
+ val_col, lab_col = "#1d4ed8", "#475569"
229
+ banner_grad = "linear-gradient(135deg, #1d4ed8 0%, #2563eb 50%, #3b82f6 100%)"
230
+ banner_sub = "#dbeafe"
231
+ logo_bg = "rgba(255,255,255,0.85)"
232
+ logo_border = "rgba(255,255,255,0.9)"
233
+ toggle_bg = "#93c5fd"
234
+ toggle_knob = "#1d4ed8"
235
+ toggle_label = "#dbeafe"
236
+
237
+ ad_css = """
238
+ .ad-in { background:#064e3b; border:1px solid #059669; color:#34d399;
239
+ border-radius:20px; padding:0.3rem 1rem; font-weight:700; display:inline-block; font-size:0.9rem; }
240
+ .ad-out { background:#450a0a; border:1px solid #dc2626; color:#f87171;
241
+ border-radius:20px; padding:0.3rem 1rem; font-weight:700; display:inline-block; font-size:0.9rem; }
242
+ .ad-unk { background:#1e293b; border:1px solid #475569; color:#94a3b8;
243
+ border-radius:20px; padding:0.3rem 1rem; font-weight:700; display:inline-block; font-size:0.9rem; }
244
+ """
245
 
246
  def load_svg_b64(path):
247
  with open(path, "rb") as f:
 
251
 
252
  st.markdown(f"""
253
  <style>
254
+ {ad_css}
255
+ .vb-banner {{
256
+ background: {banner_grad};
257
+ border-radius: 16px;
258
+ padding: 1.2rem 1.8rem;
259
+ display: flex;
260
+ align-items: center;
261
+ gap: 1.5rem;
262
+ margin-bottom: 0.5rem;
263
+ box-shadow: 0 4px 24px rgba(0,0,0,0.18);
264
+ position: relative;
265
+ }}
266
+ .vb-logo-wrap {{
267
+ background: {logo_bg};
268
+ border: 1px solid {logo_border};
269
+ border-radius: 14px;
270
+ padding: 0.6rem;
271
+ backdrop-filter: blur(8px);
272
+ flex-shrink: 0;
273
+ }}
274
+ .vb-logo-wrap img {{
275
+ height: 110px;
276
+ width: auto;
277
+ display: block;
278
+ }}
279
+ .vb-title-wrap {{
280
+ flex: 1;
281
+ }}
282
+ .vb-title-wrap h1 {{
283
+ color: #ffffff;
284
+ font-size: 2.4rem;
285
+ font-weight: 800;
286
+ margin: 0 0 0.3rem 0;
287
+ letter-spacing: -0.5px;
288
+ }}
289
+ .vb-title-wrap p {{
290
+ color: {banner_sub};
291
+ font-size: 0.92rem;
292
+ margin: 0;
293
+ line-height: 1.6;
294
+ }}
295
+ .vb-toggle-wrap {{
296
+ position: absolute;
297
+ top: 1rem;
298
+ right: 1.2rem;
299
+ display: flex;
300
+ align-items: center;
301
+ gap: 0.5rem;
302
+ }}
303
+ .vb-toggle-label {{
304
+ color: {toggle_label};
305
+ font-size: 0.78rem;
306
+ font-weight: 600;
307
+ letter-spacing: 0.03em;
308
+ }}
309
+ .metric-card {{
310
+ background: {card_bg};
311
+ border: 1px solid {card_border};
312
+ border-radius: 12px;
313
+ padding: 1.1rem;
314
+ text-align: center;
315
+ transition: box-shadow 0.2s;
316
+ }}
317
+ .metric-card:hover {{
318
+ box-shadow: 0 4px 16px rgba(59,130,246,0.15);
319
+ }}
320
+ .metric-val {{
321
+ font-size: 1.9rem;
322
+ font-weight: 700;
323
+ color: {val_col};
324
+ line-height: 1.2;
325
+ }}
326
+ .metric-lab {{
327
+ font-size: 0.75rem;
328
+ color: {lab_col};
329
+ margin-top: 0.35rem;
330
+ line-height: 1.4;
331
+ }}
332
  </style>
333
+ """, unsafe_allow_html=True)
334
+
335
+ # ── Banner ────────────────────────────────────────────────────────────
336
+ toggle_icon = "β˜€" if dark else "☾"
337
+ toggle_text = "Light mode" if dark else "Dark mode"
338
+
339
+ st.markdown(f"""
340
+ <div class="vb-banner">
341
+ <div class="vb-logo-wrap">
342
+ <img src="data:image/svg+xml;base64,{logo_b64}" alt="VeloBind"/>
343
  </div>
344
+ <div class="vb-title-wrap">
345
  <h1>VeloBind</h1>
346
+ <p>
347
+ Structure-free protein-ligand binding affinity prediction &nbsp;&middot;&nbsp;
348
+ Sequence + SMILES &nbsp;&middot;&nbsp;
349
+ Pearson R = 0.8469 on CASF-2016 &nbsp;&middot;&nbsp;
350
+ 45-model ensemble (LGBM + CatBoost + XGBoost)
351
+ </p>
352
  </div>
353
  </div>
354
  """, unsafe_allow_html=True)
355
 
356
+ # Theme toggle — just below banner, right-aligned
357
+ _, tcol = st.columns([6, 1])
358
+ with tcol:
359
+ if st.button(f"{toggle_icon} {toggle_text}", use_container_width=True):
360
+ st.session_state.dark_mode = not st.session_state.dark_mode
361
+ st.rerun()
362
+
363
+ # ── Load models ───────────────────────────────────────────────────────
364
  fold_models, meta, target_scaler, lig_scaler = load_all_models()
365
  tokenizer, esm_model, device = load_esm_model()
366
  ad_centroid, ad_threshold = load_ad_centroid()
367
  n_loaded = sum(len(fold_models[s][t]) for s in SEEDS for t in MODEL_TYPES)
368
+ st.success(f"{n_loaded} fold models loaded | Device: {device.upper()}")
369
+
370
+ # ── Mode tabs ─────────────────────────────────────────────────────────
371
+ tab1, tab2, tab3 = st.tabs([
372
+ "Single query",
373
+ "Batch screening (CSV)",
374
+ "One compound vs. multiple targets",
375
+ ])
 
376
 
377
 
378
  # ══════════════════════════════════════════════════════════════════════
379
+ # TAB 1 — Single query
380
  # ══════════════════════════════════════════════════════════════════════
381
+ with tab1:
 
382
  col_p, col_l = st.columns(2)
383
  with col_p:
384
  st.subheader("Protein")
385
  seq_raw = st.text_area(
386
+ "Amino acid sequence (plain or FASTA format)",
387
+ height=160,
388
+ placeholder=">ProteinName\nMKTAYIAKQRQISFVK...",
389
+ help="Plain sequence or FASTA with >header line. Only standard amino acid letters (A-Z subset).",
390
+ key="sq_seq"
391
  )
392
  with col_l:
393
  st.subheader("Ligand")
394
+ smi = st.text_input("SMILES", placeholder="CC(=O)Oc1ccccc1C(=O)O", key="sq_smi")
395
  examples = {
396
  "Aspirin": "CC(=O)Oc1ccccc1C(=O)O",
397
  "Imatinib": "Cc1ccc(NC(=O)c2ccc(CN3CCN(C)CC3)cc2)cc1Nc1nccc(-c2cccnc2)n1",
398
  "Gefitinib": "COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCCN1CCOCC1",
399
  "Staurosporine": "C[C@@H]1CCCN2C(=O)c3[nH]c4ccccc4c3C2=N1",
400
  }
401
+ chosen = st.selectbox("Load example SMILES", ["β€”"] + list(examples), key="sq_ex")
402
  if chosen != "β€”":
403
  smi = examples[chosen]
404
 
405
+ if st.button("Predict", type="primary", use_container_width=True, key="sq_btn"):
406
  seq, err = validate_sequence(seq_raw)
407
  if err:
408
  st.error(err)
 
413
  t0 = time.time()
414
  try:
415
  X, valid, esm_vec = extract_features(
416
+ seq, [smi.strip()], tokenizer, esm_model, device, lig_scaler)
 
 
417
  if not valid.any():
418
+ st.error("RDKit could not parse this SMILES.")
419
  else:
420
  preds, preds_all = predict(X, fold_models, meta, target_scaler)
421
  lo, hi = uncertainty_interval(preds_all)
422
  elapsed = time.time() - t0
423
  pkd = float(preds[0])
424
+ ad_label, _ = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
425
 
426
+ st.markdown("#### Results")
427
  c1, c2, c3, c4 = st.columns(4)
428
  with c1:
429
  st.markdown(f"""<div class="metric-card">
 
433
  with c2:
434
  st.markdown(f"""<div class="metric-card">
435
  <div class="metric-val">[{lo[0]:.2f}, {hi[0]:.2f}]</div>
436
+ <div class="metric-lab">95% model interval<br>(Β±1.96Οƒ Β· 45 models)</div>
437
  </div>""", unsafe_allow_html=True)
438
  with c3:
439
  st.markdown(f"""<div class="metric-card">
440
  <div class="metric-val">{format_ki(pkd)}</div>
441
+ <div class="metric-lab">Estimated Ki<br>(pKd β‰ˆ pKi assumed)</div>
442
  </div>""", unsafe_allow_html=True)
443
  with c4:
 
444
  ad_cls = ("ad-in" if ad_label == "IN DOMAIN" else
445
  "ad-out" if ad_label == "OUT OF DOMAIN" else "ad-unk")
446
  st.markdown(f"""<div class="metric-card">
447
+ <div style="padding-top:0.4rem">
448
+ <span class="{ad_cls}">{ad_label}</span>
449
+ </div>
450
+ <div class="metric-lab" style="margin-top:0.6rem">Applicability domain</div>
451
  </div>""", unsafe_allow_html=True)
452
 
453
  if ad_label == "OUT OF DOMAIN":
454
+ st.warning("Protein is outside the training distribution. Predictions may be unreliable.")
 
455
 
456
  st.caption(
457
  f"Inference time: {elapsed:.2f}s | "
 
465
  labels, preds_all[0],
466
  preds_all[0] - preds_all[0].std(),
467
  preds_all[0] + preds_all[0].std(),
468
+ "Per-seed and model-type predictions (fold-averaged)",
469
+ dark=dark,
470
  )
471
  st.pyplot(fig, use_container_width=True)
472
  plt.close(fig)
 
477
 
478
 
479
  # ══════════════════════════════════════════════════════════════════════
480
+ # TAB 2 — Batch screening
481
  # ══════════════════════════════════════════════════════════════════════
482
+ with tab2:
 
483
  st.subheader("Batch Screening")
484
+ st.markdown("Screen a library of compounds against one target. "
485
+ "Upload a CSV with a `smiles` column (and optionally `name`). "
486
+ "Results are ranked by predicted pKd.")
487
 
488
  col_seq, col_csv = st.columns(2)
489
  with col_seq:
490
+ batch_seq_raw = st.text_area("Target protein sequence (plain or FASTA)", height=180,
491
+ placeholder=">Target\nMKTAYIAKQRQISFVK...", key="bs_seq")
492
  with col_csv:
493
+ uploaded = st.file_uploader("Compound CSV (smiles, name)", type=["csv"], key="bs_up")
494
  st.code("smiles,name\nCC(=O)Oc1ccccc1C(=O)O,Aspirin", language="csv")
495
 
496
+ max_cpds = st.slider("Max compounds", 10, 500, 100, key="bs_max",
497
  help="~1s per compound on CPU free tier.")
498
 
499
+ if st.button("Run batch screening", type="primary", use_container_width=True, key="bs_btn"):
500
  batch_seq, err = validate_sequence(batch_seq_raw)
501
  if err:
502
  st.error(err)
 
516
  with st.spinner(f"Screening {len(smiles_list)} compounds..."):
517
  t0 = time.time()
518
  X, valid, esm_vec = extract_features(
519
+ batch_seq, smiles_list, tokenizer, esm_model, device, lig_scaler)
 
 
520
  ad_label, _ = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
521
  preds, preds_all = predict(X, fold_models, meta, target_scaler)
522
  lo, hi = uncertainty_interval(preds_all)
 
526
  valid_smiles = [smiles_list[i] for i in range(len(smiles_list)) if valid[i]]
527
  n_invalid = int((~valid).sum())
528
 
529
+ if ad_label == "OUT OF DOMAIN":
530
+ st.warning("Protein is outside the training distribution. Predictions may be unreliable.")
531
+
532
  results_df = pd.DataFrame({
533
  'name': valid_names,
534
  'smiles': valid_smiles,
 
541
  }).sort_values('pKd_pred', ascending=False).reset_index(drop=True)
542
  results_df.insert(0, 'rank', range(1, len(results_df) + 1))
543
 
 
 
 
 
544
  st.success(
545
+ f"{len(results_df)} compounds screened in {elapsed:.1f}s "
546
  f"({elapsed / max(len(results_df), 1):.2f}s/compound)"
547
  + (f" | {n_invalid} invalid SMILES skipped" if n_invalid else "")
548
  )
549
 
550
  top_n = min(20, len(results_df))
551
+ fig = bar_chart(
552
+ results_df.head(top_n)['name'].tolist(),
553
+ results_df.head(top_n)['pKd_pred'].values,
554
+ results_df.head(top_n)['CI_lo'].values,
555
+ results_df.head(top_n)['CI_hi'].values,
556
+ f"Top {top_n} hits by predicted pKd", dark=dark,
 
 
557
  )
558
  st.pyplot(fig, use_container_width=True)
559
  plt.close(fig)
 
565
  st.download_button(
566
  "Download ranked CSV",
567
  results_df.to_csv(index=False).encode(),
568
+ file_name="velobind_screening.csv", mime="text/csv",
 
569
  )
570
 
571
 
572
  # ══════════════════════════════════════════════════════════════════════
573
+ # TAB 3 — Selectivity profiling
574
  # ══════════════════════════════════════════════════════════════════════
575
+ with tab3:
 
576
  st.subheader("Selectivity Profiling")
577
+ st.markdown("One compound, multiple targets β€” ranked by predicted pKd. "
578
+ "Format: `TargetName: SEQUENCE` (name optional). "
579
+ "Accepts plain sequence or FASTA per line.")
580
 
581
+ multi_smi = st.text_input("Compound SMILES", placeholder="Cc1ccc(...)cc1Nc1nccc(...)n1", key="sp_smi")
 
582
  multi_seqs = st.text_area(
583
  "Target proteins (one per line)",
584
  height=250,
585
  placeholder="ABL1: MGPSENDPNLFVALY...\nEGFR: MRPSGTAGAALLALL...\nCDK2: MENFQKVEKIGEGTY...",
586
+ key="sp_seqs",
587
  )
588
 
589
+ if st.button("Run selectivity profiling", type="primary", use_container_width=True, key="sp_btn"):
590
  if not multi_smi.strip() or not multi_seqs.strip():
591
  st.error("Please enter a SMILES and at least one protein sequence.")
592
  else:
593
+ targets, parse_errors = {}, []
 
594
  for i, line in enumerate(multi_seqs.strip().splitlines()):
595
  line = line.strip()
596
  if not line:
597
  continue
598
+ name, raw_seq = (line.split(":", 1) if ":" in line
599
+ else (f"Target_{i+1}", line))
600
+ seq, err = validate_sequence(raw_seq if isinstance(raw_seq, str) else raw_seq)
 
 
 
601
  if err:
602
+ parse_errors.append(f"{name.strip()}: {err}")
603
  else:
604
+ targets[name.strip()] = seq
605
 
606
+ for e in parse_errors:
607
+ st.warning(f"Skipped β€” {e}")
 
608
  if not targets:
609
  st.error("No valid sequences found.")
610
  st.stop()
 
613
  for idx, (name, seq) in enumerate(targets.items()):
614
  try:
615
  X, valid, esm_vec = extract_features(
616
+ seq, [multi_smi.strip()], tokenizer, esm_model, device, lig_scaler)
 
 
617
  if valid.any():
618
  preds, preds_all = predict(X, fold_models, meta, target_scaler)
619
  lo, hi = uncertainty_interval(preds_all)
620
  ad_label, _ = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
621
  results.append({
622
  'Target': name,
623
+ 'pKd_pred': round(float(preds[0]), 3),
624
+ 'CI_lo': round(float(lo[0]), 3),
625
+ 'CI_hi': round(float(hi[0]), 3),
626
  'Ki_est': format_ki(float(preds[0])),
627
  'model_std': round(float(preds_all.std()), 3),
628
  'AD': ad_label,
 
632
  progress.progress((idx + 1) / len(targets))
633
 
634
  progress.empty()
635
+ res_df = (pd.DataFrame(results)
636
+ .sort_values('pKd_pred', ascending=False)
637
+ .reset_index(drop=True))
 
 
638
  res_df.insert(0, 'rank', range(1, len(res_df) + 1))
639
 
640
  st.success(f"Profiled {len(res_df)} targets.")
641
  fig = bar_chart(
642
+ res_df['Target'].tolist(), res_df['pKd_pred'].values,
643
+ res_df['CI_lo'].values, res_df['CI_hi'].values,
644
+ "Selectivity profile β€” predicted pKd by target", dark=dark,
 
 
 
645
  )
646
  st.pyplot(fig, use_container_width=True)
647
  plt.close(fig)
 
650
  st.download_button(
651
  "Download selectivity CSV",
652
  res_df.to_csv(index=False).encode(),
653
+ file_name="velobind_selectivity.csv", mime="text/csv",
 
654
  )
655
 
656
 
657
  # ── Footer ────────────────────────────────────────────────────────────
658
  st.markdown("---")
659
  st.markdown(f"""
660
+ <div style="color:{lab_col};font-size:0.78rem;text-align:center;padding:0.4rem 0 0.8rem">
661
+ VeloBind &nbsp;&middot;&nbsp; Structure-free binding affinity &nbsp;&middot;&nbsp;
662
+ ESM-2 + gradient-boosted ensemble &nbsp;&middot;&nbsp;
663
+ Trained on LP-PDBBind &nbsp;&middot;&nbsp;
664
+ Evaluated on CASF-2016 and CASF-2013 &nbsp;&middot;&nbsp;
665
+ <b>Not for clinical use</b>
666
  </div>
667
  """, unsafe_allow_html=True)