ym59 committed on
Commit
8a8b83b
·
verified ·
1 Parent(s): fe7ed04

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +728 -611
app.py CHANGED
@@ -1,18 +1,21 @@
1
- # app.py β€” VeloBind HF Spaces inference app
2
- import warnings, time, base64
 
 
 
3
  import numpy as np
4
  import pandas as pd
5
- import streamlit as st
6
  import joblib
7
  import torch
8
- import matplotlib.pyplot as plt
9
  from pathlib import Path
 
 
10
 
11
  warnings.filterwarnings("ignore")
12
  from rdkit import RDLogger
13
  RDLogger.DisableLog('rdApp.*')
14
 
15
- # ── Config ────────────────────────────────────────────────────────────
16
  HF_MODEL_REPO = "ym59/velobind-models"
17
  MODEL_CACHE = Path("/tmp/velobind_models")
18
  SEEDS = [42, 123, 456]
@@ -27,100 +30,94 @@ from src.features.ligand import extract_ligand_features
27
  from src.models.ensemble import TargetScaler
28
  from src.config import config
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- # ══════════════════════════════════════════════════════════════════════
32
- # Session state β€” theme
33
- # ══════════════════════════════════════════════════════════════════════
34
- if "dark_mode" not in st.session_state:
35
- st.session_state.dark_mode = True
36
-
37
-
38
- # ══════════════════════════════════════════════════════════════════════
39
- # Validation
40
- # ══════════════════════════════════════════════════════════════════════
41
- def validate_sequence(raw: str):
42
- raw = raw.strip()
43
- if not raw:
44
- return None, "Please enter a sequence."
45
- lines = raw.splitlines()
46
- seq = "".join(l.strip() for l in lines if not l.startswith(">")).upper().replace(" ", "")
47
- if len(seq) < 10:
48
- return None, "Sequence too short (minimum 10 residues)."
49
- invalid = set(seq) - VALID_AA
50
- if invalid:
51
- return None, f"Invalid characters: {', '.join(sorted(invalid))}. Only standard amino acid letters accepted."
52
- return seq, None
53
-
54
-
55
- # ══════════════════════════════════════════════════════════════════════
56
- # Model loading
57
- # ══════════════════════════════════════════════════════════════════════
58
- @st.cache_resource(show_spinner="Loading VeloBind models (first run ~30s)...")
59
- def load_all_models():
60
- from huggingface_hub import hf_hub_download
61
- MODEL_CACHE.mkdir(parents=True, exist_ok=True)
62
- model_files = (
63
- [f"fold_model_s{s}_{t}_f{f}.pkl"
64
- for s in SEEDS for t in MODEL_TYPES for f in range(N_FOLDS)]
65
- + ["meta_type_casf16.pkl", "target_scaler.pkl", "ligand_scaler.pkl"]
66
- )
67
- bar = st.progress(0, text="Loading models...")
68
- for i, fname in enumerate(model_files):
69
- if not (MODEL_CACHE / fname).exists():
70
- hf_hub_download(repo_id=HF_MODEL_REPO, filename=fname,
71
- local_dir=str(MODEL_CACHE))
72
- bar.progress((i + 1) / len(model_files), text=f"Loading {fname}...")
73
- bar.empty()
74
-
75
- fold_models = {}
76
  for s in SEEDS:
77
- fold_models[s] = {}
78
  for t in MODEL_TYPES:
79
- fold_models[s][t] = [
80
- joblib.load(MODEL_CACHE / f"fold_model_s{s}_{t}_f{f}.pkl")
81
  for f in range(N_FOLDS)
82
  ]
83
- meta = joblib.load(MODEL_CACHE / "meta_type_casf16.pkl")
84
- scaler = joblib.load(MODEL_CACHE / "target_scaler.pkl")
85
- lig_sc = joblib.load(MODEL_CACHE / "ligand_scaler.pkl")
86
- return fold_models, meta, scaler, lig_sc
87
-
88
-
89
- @st.cache_resource(show_spinner="Loading ESM-2...")
90
- def load_esm_model():
91
- device = "cuda" if torch.cuda.is_available() else "cpu"
92
- tokenizer, esm_model = load_esm(config.ESM_MODEL, device)
93
- return tokenizer, esm_model, device
94
-
95
-
96
- @st.cache_resource(show_spinner=False)
97
- def load_ad_centroid():
98
- for p in [Path("output/models/deployment"), Path("output/models")]:
 
 
 
 
99
  if (p / "ad_centroid.npy").exists():
100
- return np.load(p / "ad_centroid.npy"), float(np.load(p / "ad_threshold.npy"))
101
- for fname in ["ad_centroid.npy", "ad_threshold.npy"]:
102
- local = MODEL_CACHE / fname
103
- if not local.exists():
104
- try:
105
- from huggingface_hub import hf_hub_download
 
106
  hf_hub_download(repo_id=HF_MODEL_REPO, filename=fname,
107
  local_dir=str(MODEL_CACHE))
108
- except Exception:
109
- return None, None
110
- return np.load(MODEL_CACHE / "ad_centroid.npy"), float(np.load(MODEL_CACHE / "ad_threshold.npy"))
 
 
111
 
112
 
113
- def ad_check(esm_vec, centroid, threshold):
114
- if centroid is None:
115
- return "UNKNOWN", float("nan")
116
- dist = float(np.linalg.norm(esm_vec - centroid))
117
- return ("IN DOMAIN" if dist <= threshold else "OUT OF DOMAIN"), dist
 
 
 
 
 
118
 
119
 
120
- # ══════════════════════════════════════════════════════════════════════
121
- # Features + prediction
122
- # ══════════════════════════════════════════════════════════════════════
123
- def assemble_from_parts(esm_mean, esm_var, esm_attn, seq_feat, lig_feats):
124
  return np.concatenate([
125
  esm_mean[:, -480:], seq_feat,
126
  lig_feats["ecfp"], lig_feats["ecfp2"], lig_feats["ecfp6"], lig_feats["fcfp"],
@@ -129,13 +126,13 @@ def assemble_from_parts(esm_mean, esm_var, esm_attn, seq_feat, lig_feats):
129
  ], axis=1)
130
 
131
 
132
- def extract_features(sequence, smiles_list, tokenizer, esm_model, device, lig_scaler):
 
133
  esm_mean, esm_var, esm_attn, _ = embed_batch(
134
- [sequence], tokenizer, esm_model,
135
  config.ESM_LAYERS, config.MAX_SEQ_LEN, config.HALF_SEQ_LEN,
136
- batch_size=1, device=device,
137
- )
138
- seq_feat = np.array([sequence_features(sequence)])
139
  lig_feats, valid_mask, _ = extract_ligand_features(
140
  smiles_list, scaler=lig_scaler, fit_scaler=False)
141
  valid_mask = np.array(valid_mask)
@@ -144,13 +141,13 @@ def extract_features(sequence, smiles_list, tokenizer, esm_model, device, lig_sc
144
  bm[valid_mask] = True
145
  valid_mask = bm
146
  n = int(valid_mask.sum())
147
- X = assemble_from_parts(
148
- np.tile(esm_mean, (n, 1)), np.tile(esm_var, (n, 1)),
149
- np.tile(esm_attn, (n, 1)), np.tile(seq_feat, (n, 1)), lig_feats)
150
  return X, valid_mask, esm_mean[0]
151
 
152
 
153
- def predict(X, fold_models, meta, scaler):
154
  type_avgs = []
155
  for s in SEEDS:
156
  for t in MODEL_TYPES:
@@ -159,530 +156,650 @@ def predict(X, fold_models, meta, scaler):
159
  type_avgs.append(fp.mean(axis=1))
160
  preds_all = np.stack(type_avgs, axis=1)
161
  preds = meta.predict(np.column_stack([
162
- preds_all[:, [0,3,6]].mean(1),
163
- preds_all[:, [1,4,7]].mean(1),
164
- preds_all[:, [2,5,8]].mean(1),
165
  ]))
166
- return preds, preds_all
167
-
168
-
169
- def uncertainty_interval(preds_all, z=1.96):
170
  std = preds_all.std(axis=1)
171
- return preds_all.mean(axis=1) - z * std, preds_all.mean(axis=1) + z * std
172
 
173
 
174
  def format_ki(pkd):
175
- ki_nM = 10 ** (9 - pkd)
176
- if ki_nM < 1000: return f"{ki_nM:.1f} nM"
177
- elif ki_nM < 1_000_000: return f"{ki_nM/1000:.2f} uM"
178
- else: return f"{ki_nM/1_000_000:.2f} mM"
179
-
180
-
181
- # ══════════════════════════════════════════════════════════════════════
182
- # Plot
183
- # ══════════════════════════════════════════════════════════════════════
184
- def bar_chart(names, preds, lo, hi, title, dark=True):
185
- bg, fg, gc = ("#1a2332", "#e8edf2", "#2d3f55") if dark else ("#f8fafc", "#1a202c", "#cbd5e0")
186
- fig, ax = plt.subplots(figsize=(max(6, len(names) * 0.9), 4), facecolor=bg)
187
- ax.set_facecolor(bg)
188
- x = np.arange(len(names))
189
- err = [preds - lo, hi - preds]
190
- bars = ax.bar(x, preds, color="#3b82f6", alpha=0.9, width=0.6,
191
- yerr=err, capsize=5, error_kw=dict(ecolor=fg, lw=1.5))
192
- ax.set_xticks(x)
193
- ax.set_xticklabels(names, rotation=30, ha='right', fontsize=10, color=fg)
194
- ax.set_ylabel("Predicted pKd", fontsize=11, color=fg)
195
- ax.set_title(title, fontsize=12, fontweight='bold', color=fg)
196
- ax.tick_params(colors=fg)
197
- for spine in ax.spines.values():
198
- spine.set_color(gc)
199
- ax.grid(True, axis='y', alpha=0.3, color=gc)
200
- for bar, val in zip(bars, preds):
201
- ax.text(bar.get_x() + bar.get_width() / 2,
202
- bar.get_height() + 0.05, f"{val:.2f}",
203
- ha='center', va='bottom', fontsize=9, fontweight='bold', color=fg)
204
- plt.tight_layout()
205
- return fig
206
-
207
-
208
- # ══════════════════════════════════════════════════════════════════════
209
- # Page layout
210
- # ══════════════════════════════════════════════════════════════════════
211
- st.set_page_config(page_title="VeloBind", layout="wide")
212
-
213
- dark = st.session_state.dark_mode
214
-
215
- # ── Theme-aware CSS (only custom elements, never .stApp) ──────────────
216
- if dark:
217
- card_bg, card_border = "#1a2332", "#2d4a6b"
218
- val_col, lab_col = "#60a5fa", "#94a3b8"
219
- banner_grad = "linear-gradient(135deg, #0f172a 0%, #1e3a5f 50%, #1e40af 100%)"
220
- banner_sub = "#93c5fd"
221
- logo_bg = "rgba(255,255,255,0.12)"
222
- logo_border = "rgba(255,255,255,0.2)"
223
- toggle_bg = "#1e3a5f"
224
- toggle_knob = "#60a5fa"
225
- toggle_label = "#93c5fd"
226
- else:
227
- card_bg, card_border = "#f0f7ff", "#bfdbfe"
228
- val_col, lab_col = "#1d4ed8", "#475569"
229
- banner_grad = "linear-gradient(135deg, #1d4ed8 0%, #2563eb 50%, #3b82f6 100%)"
230
- banner_sub = "#dbeafe"
231
- logo_bg = "rgba(255,255,255,0.85)"
232
- logo_border = "rgba(255,255,255,0.9)"
233
- toggle_bg = "#93c5fd"
234
- toggle_knob = "#1d4ed8"
235
- toggle_label = "#dbeafe"
236
-
237
- ad_css = """
238
- .ad-in { background:#064e3b; border:1px solid #059669; color:#34d399;
239
- border-radius:20px; padding:0.3rem 1rem; font-weight:700; display:inline-block; font-size:0.9rem; }
240
- .ad-out { background:#450a0a; border:1px solid #dc2626; color:#f87171;
241
- border-radius:20px; padding:0.3rem 1rem; font-weight:700; display:inline-block; font-size:0.9rem; }
242
- .ad-unk { background:#1e293b; border:1px solid #475569; color:#94a3b8;
243
- border-radius:20px; padding:0.3rem 1rem; font-weight:700; display:inline-block; font-size:0.9rem; }
244
- """
245
-
246
- def load_img_b64(path):
247
- with open(path, "rb") as f:
248
- return base64.b64encode(f.read()).decode()
249
-
250
- logo_b64 = load_img_b64("logo.png")
251
-
252
- st.markdown(f"""
253
  <style>
254
- {ad_css}
255
- .vb-banner {{
256
- background: {banner_grad};
257
- border-radius: 16px;
258
- padding: 1.2rem 1.8rem;
259
- display: flex;
260
- align-items: center;
261
- gap: 1.5rem;
262
- margin-bottom: 0.5rem;
263
- box-shadow: 0 4px 24px rgba(0,0,0,0.18);
264
- position: relative;
265
- }}
266
- .vb-logo-wrap {{
267
- background: {logo_bg};
268
- border: 1px solid {logo_border};
269
- border-radius: 14px;
270
- padding: 0.6rem;
271
- backdrop-filter: blur(8px);
272
- flex-shrink: 0;
273
- }}
274
- .vb-logo-wrap img {{
275
- height: 110px;
276
- width: auto;
277
- display: block;
278
- }}
279
- .vb-title-wrap {{
280
- flex: 1;
281
- }}
282
- .vb-title-wrap h1 {{
283
- color: #ffffff;
284
- font-size: 2.4rem;
285
- font-weight: 800;
286
- margin: 0 0 0.3rem 0;
287
- letter-spacing: -0.5px;
288
- }}
289
- .vb-title-wrap p {{
290
- color: {banner_sub};
291
- font-size: 0.92rem;
292
- margin: 0;
293
- line-height: 1.6;
294
- }}
295
- .vb-toggle-wrap {{
296
- position: absolute;
297
- top: 1rem;
298
- right: 1.2rem;
299
- display: flex;
300
- align-items: center;
301
- gap: 0.5rem;
302
- }}
303
- .vb-toggle-label {{
304
- color: {toggle_label};
305
- font-size: 0.78rem;
306
- font-weight: 600;
307
- letter-spacing: 0.03em;
308
- }}
309
- .metric-card {{
310
- background: {card_bg};
311
- border: 1px solid {card_border};
312
- border-radius: 12px;
313
- padding: 1.1rem;
314
- text-align: center;
315
- transition: box-shadow 0.2s;
316
- }}
317
- .metric-card:hover {{
318
- box-shadow: 0 4px 16px rgba(59,130,246,0.15);
319
- }}
320
- .metric-val {{
321
- font-size: 1.9rem;
322
- font-weight: 700;
323
- color: {val_col};
324
- line-height: 1.2;
325
- }}
326
- .metric-lab {{
327
- font-size: 0.75rem;
328
- color: {lab_col};
329
- margin-top: 0.35rem;
330
- line-height: 1.4;
331
- }}
332
- @media (max-width: 640px) {{
333
- .vb-banner {{
334
- flex-direction: column;
335
- text-align: center;
336
- padding: 1rem;
337
- }}
338
- .vb-logo-wrap img {{
339
- height: 80px;
340
- }}
341
- .vb-title-wrap h1 {{
342
- font-size: 1.6rem;
343
- }}
344
- .vb-title-wrap p {{
345
- font-size: 0.78rem;
346
- }}
347
- .vb-toggle-wrap {{
348
- position: static;
349
- justify-content: center;
350
- margin-top: 0.5rem;
351
- }}
352
- }}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  </style>
354
- """, unsafe_allow_html=True)
 
355
 
356
- # ── Banner ────────────────────────────────────────────────────────────
357
- toggle_icon = "β˜€" if dark else "☾"
358
- toggle_text = "Light mode" if dark else "Dark mode"
 
359
 
360
- st.markdown(f"""
361
- <div class="vb-banner">
362
- <div class="vb-logo-wrap">
363
- <img src="data:image/png;base64,{logo_b64}" alt="VeloBind"/>=
 
 
364
  </div>
365
- <div class="vb-title-wrap">
366
- <h1>VeloBind</h1>
367
- <p>
368
- Structure-free protein-ligand binding affinity prediction &nbsp;&middot;&nbsp;
369
- Sequence + SMILES &nbsp;&middot;&nbsp;
370
- Pearson R = 0.8469 on CASF-2016 &nbsp;&middot;&nbsp;
371
- 45-model ensemble (LGBM + CatBoost + XGBoost)
372
- </p>
373
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
  </div>
375
- """, unsafe_allow_html=True)
376
-
377
- # Theme toggle β€” just below banner, right-aligned
378
- _, tcol = st.columns([6, 1])
379
- with tcol:
380
- if st.button(f"{toggle_icon} {toggle_text}", use_container_width=True):
381
- st.session_state.dark_mode = not st.session_state.dark_mode
382
- st.rerun()
383
-
384
- # ── Load models ───────────────────────────────────────────────────────
385
- fold_models, meta, target_scaler, lig_scaler = load_all_models()
386
- tokenizer, esm_model, device = load_esm_model()
387
- ad_centroid, ad_threshold = load_ad_centroid()
388
- n_loaded = sum(len(fold_models[s][t]) for s in SEEDS for t in MODEL_TYPES)
389
- st.success(f"{n_loaded} fold models loaded | Device: {device.upper()}")
390
-
391
- # ── Mode tabs ─────────────────────────────────────────────────────────
392
- tab1, tab2, tab3 = st.tabs([
393
- "Single query",
394
- "Batch screening (CSV)",
395
- "One compound vs. multiple targets",
396
- ])
397
-
398
-
399
- # ══════════════════════════════════════════════════════════════════════
400
- # TAB 1 β€” Single query
401
- # ══════════════════════════════════════════════════════════════════════
402
- with tab1:
403
- col_p, col_l = st.columns(2)
404
- with col_p:
405
- st.subheader("Protein")
406
- seq_raw = st.text_area(
407
- "Amino acid sequence (plain or FASTA format)",
408
- height=160,
409
- placeholder=">ProteinName\nMKTAYIAKQRQISFVK...",
410
- help="Plain sequence or FASTA with >header line. Only standard amino acid letters (A-Z subset).",
411
- key="sq_seq"
412
- )
413
- with col_l:
414
- st.subheader("Ligand")
415
- smi = st.text_input("SMILES", placeholder="CC(=O)Oc1ccccc1C(=O)O", key="sq_smi")
416
- examples = {
417
- "Aspirin": "CC(=O)Oc1ccccc1C(=O)O",
418
- "Imatinib": "Cc1ccc(NC(=O)c2ccc(CN3CCN(C)CC3)cc2)cc1Nc1nccc(-c2cccnc2)n1",
419
- "Gefitinib": "COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCCN1CCOCC1",
420
- "Staurosporine": "C[C@@H]1CCCN2C(=O)c3[nH]c4ccccc4c3C2=N1",
421
- }
422
- chosen = st.selectbox("Load example SMILES", ["β€”"] + list(examples), key="sq_ex")
423
- if chosen != "β€”":
424
- smi = examples[chosen]
425
-
426
- if st.button("Predict", type="primary", use_container_width=True, key="sq_btn"):
427
- seq, err = validate_sequence(seq_raw)
428
- if err:
429
- st.error(err)
430
- elif not smi.strip():
431
- st.error("Please enter a SMILES string.")
432
- else:
433
- with st.spinner("Running inference..."):
434
- t0 = time.time()
435
- try:
436
- X, valid, esm_vec = extract_features(
437
- seq, [smi.strip()], tokenizer, esm_model, device, lig_scaler)
438
- if not valid.any():
439
- st.error("RDKit could not parse this SMILES.")
440
- else:
441
- preds, preds_all = predict(X, fold_models, meta, target_scaler)
442
- lo, hi = uncertainty_interval(preds_all)
443
- elapsed = time.time() - t0
444
- pkd = float(preds[0])
445
- ad_label, _ = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
446
-
447
- st.markdown("#### Results")
448
- c1, c2, c3, c4 = st.columns(4)
449
- with c1:
450
- st.markdown(f"""<div class="metric-card">
451
- <div class="metric-val">{pkd:.2f}</div>
452
- <div class="metric-lab">Predicted pKd</div>
453
- </div>""", unsafe_allow_html=True)
454
- with c2:
455
- st.markdown(f"""<div class="metric-card">
456
- <div class="metric-val">[{lo[0]:.2f}, {hi[0]:.2f}]</div>
457
- <div class="metric-lab">95% model interval<br>(Β±1.96Οƒ Β· 45 models)</div>
458
- </div>""", unsafe_allow_html=True)
459
- with c3:
460
- st.markdown(f"""<div class="metric-card">
461
- <div class="metric-val">{format_ki(pkd)}</div>
462
- <div class="metric-lab">Estimated Ki<br>(pKd β‰ˆ pKi assumed)</div>
463
- </div>""", unsafe_allow_html=True)
464
- with c4:
465
- ad_cls = ("ad-in" if ad_label == "IN DOMAIN" else
466
- "ad-out" if ad_label == "OUT OF DOMAIN" else "ad-unk")
467
- st.markdown(f"""<div class="metric-card">
468
- <div style="padding-top:0.4rem">
469
- <span class="{ad_cls}">{ad_label}</span>
470
- </div>
471
- <div class="metric-lab" style="margin-top:0.6rem">Applicability domain</div>
472
- </div>""", unsafe_allow_html=True)
473
-
474
- if ad_label == "OUT OF DOMAIN":
475
- st.warning("Protein is outside the training distribution. Predictions may be unreliable.")
476
-
477
- st.caption(
478
- f"Inference time: {elapsed:.2f}s | "
479
- f"45-model ensemble (3 seeds x 3 types x 5 folds) | "
480
- f"Device: {device.upper()}"
481
- )
482
-
483
- with st.expander("Per-model breakdown"):
484
- labels = [f"s{s}_{t}" for s in SEEDS for t in MODEL_TYPES]
485
- fig = bar_chart(
486
- labels, preds_all[0],
487
- preds_all[0] - preds_all[0].std(),
488
- preds_all[0] + preds_all[0].std(),
489
- "Per-seed and model-type predictions (fold-averaged)",
490
- dark=dark,
491
- )
492
- st.pyplot(fig, use_container_width=True)
493
- plt.close(fig)
494
-
495
- except Exception as e:
496
- st.error(f"Inference error: {e}")
497
- st.exception(e)
498
-
499
-
500
- # ══════════════════════════════════════════════════════════════════════
501
- # TAB 2 β€” Batch screening
502
- # ══════════════════════════════════════════════════════════════════════
503
- with tab2:
504
- st.subheader("Batch Screening")
505
- st.markdown("Screen a library of compounds against one target. "
506
- "Upload a CSV with a `smiles` column (and optionally `name`). "
507
- "Results are ranked by predicted pKd.")
508
-
509
- col_seq, col_csv = st.columns(2)
510
- with col_seq:
511
- batch_seq_raw = st.text_area("Target protein sequence (plain or FASTA)", height=180,
512
- placeholder=">Target\nMKTAYIAKQRQISFVK...", key="bs_seq")
513
- with col_csv:
514
- uploaded = st.file_uploader("Compound CSV (smiles, name)", type=["csv"], key="bs_up")
515
- st.code("smiles,name\nCC(=O)Oc1ccccc1C(=O)O,Aspirin", language="csv")
516
-
517
- max_cpds = st.slider("Max compounds", 10, 500, 100, key="bs_max",
518
- help="~1s per compound on CPU free tier.")
519
-
520
- if st.button("Run batch screening", type="primary", use_container_width=True, key="bs_btn"):
521
- batch_seq, err = validate_sequence(batch_seq_raw)
522
- if err:
523
- st.error(err)
524
- elif uploaded is None:
525
- st.error("Please upload a CSV file.")
526
- else:
527
- df_in = pd.read_csv(uploaded)
528
- if 'smiles' not in df_in.columns:
529
- st.error("CSV must have a 'smiles' column.")
530
- st.stop()
531
-
532
- df_in = df_in.head(max_cpds)
533
- smiles_list = df_in['smiles'].tolist()
534
- names_list = (df_in['name'].tolist() if 'name' in df_in.columns
535
- else [f"cpd_{i}" for i in range(len(df_in))])
536
-
537
- with st.spinner(f"Screening {len(smiles_list)} compounds..."):
538
- t0 = time.time()
539
- X, valid, esm_vec = extract_features(
540
- batch_seq, smiles_list, tokenizer, esm_model, device, lig_scaler)
541
- ad_label, _ = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
542
- preds, preds_all = predict(X, fold_models, meta, target_scaler)
543
- lo, hi = uncertainty_interval(preds_all)
544
- elapsed = time.time() - t0
545
-
546
- valid_names = [names_list[i] for i in range(len(names_list)) if valid[i]]
547
- valid_smiles = [smiles_list[i] for i in range(len(smiles_list)) if valid[i]]
548
- n_invalid = int((~valid).sum())
549
-
550
- if ad_label == "OUT OF DOMAIN":
551
- st.warning("Protein is outside the training distribution. Predictions may be unreliable.")
552
-
553
- results_df = pd.DataFrame({
554
- 'name': valid_names,
555
- 'smiles': valid_smiles,
556
- 'pKd_pred': np.round(preds, 3),
557
- 'CI_lo': np.round(lo, 3),
558
- 'CI_hi': np.round(hi, 3),
559
- 'Ki_est': [format_ki(p) for p in preds],
560
- 'model_std': np.round(preds_all.std(axis=1), 3),
561
- 'AD': [ad_label] * len(valid_names),
562
- }).sort_values('pKd_pred', ascending=False).reset_index(drop=True)
563
- results_df.insert(0, 'rank', range(1, len(results_df) + 1))
564
-
565
- st.success(
566
- f"{len(results_df)} compounds screened in {elapsed:.1f}s "
567
- f"({elapsed / max(len(results_df), 1):.2f}s/compound)"
568
- + (f" | {n_invalid} invalid SMILES skipped" if n_invalid else "")
569
- )
570
-
571
- top_n = min(20, len(results_df))
572
- fig = bar_chart(
573
- results_df.head(top_n)['name'].tolist(),
574
- results_df.head(top_n)['pKd_pred'].values,
575
- results_df.head(top_n)['CI_lo'].values,
576
- results_df.head(top_n)['CI_hi'].values,
577
- f"Top {top_n} hits by predicted pKd", dark=dark,
578
- )
579
- st.pyplot(fig, use_container_width=True)
580
- plt.close(fig)
581
-
582
- st.dataframe(
583
- results_df.style.background_gradient(subset=['pKd_pred'], cmap='Blues'),
584
- use_container_width=True, height=400,
585
- )
586
- st.download_button(
587
- "Download ranked CSV",
588
- results_df.to_csv(index=False).encode(),
589
- file_name="velobind_screening.csv", mime="text/csv",
590
- )
591
-
592
-
593
- # ══════════════════════════════════════════════════════════════════════
594
- # TAB 3 β€” Selectivity profiling
595
- # ══════════════════════════════════════════════════════════════════════
596
- with tab3:
597
- st.subheader("Selectivity Profiling")
598
- st.markdown("One compound, multiple targets β€” ranked by predicted pKd. "
599
- "Format: `TargetName: SEQUENCE` (name optional). "
600
- "Accepts plain sequence or FASTA per line.")
601
-
602
- multi_smi = st.text_input("Compound SMILES", placeholder="Cc1ccc(...)cc1Nc1nccc(...)n1", key="sp_smi")
603
- multi_seqs = st.text_area(
604
- "Target proteins (one per line)",
605
- height=250,
606
- placeholder="ABL1: MGPSENDPNLFVALY...\nEGFR: MRPSGTAGAALLALL...\nCDK2: MENFQKVEKIGEGTY...",
607
- key="sp_seqs",
608
- )
609
-
610
- if st.button("Run selectivity profiling", type="primary", use_container_width=True, key="sp_btn"):
611
- if not multi_smi.strip() or not multi_seqs.strip():
612
- st.error("Please enter a SMILES and at least one protein sequence.")
613
- else:
614
- targets, parse_errors = {}, []
615
- for i, line in enumerate(multi_seqs.strip().splitlines()):
616
- line = line.strip()
617
- if not line:
618
- continue
619
- name, raw_seq = (line.split(":", 1) if ":" in line
620
- else (f"Target_{i+1}", line))
621
- seq, err = validate_sequence(raw_seq if isinstance(raw_seq, str) else raw_seq)
622
- if err:
623
- parse_errors.append(f"{name.strip()}: {err}")
624
- else:
625
- targets[name.strip()] = seq
626
-
627
- for e in parse_errors:
628
- st.warning(f"Skipped β€” {e}")
629
- if not targets:
630
- st.error("No valid sequences found.")
631
- st.stop()
632
-
633
- results, progress = [], st.progress(0)
634
- for idx, (name, seq) in enumerate(targets.items()):
635
- try:
636
- X, valid, esm_vec = extract_features(
637
- seq, [multi_smi.strip()], tokenizer, esm_model, device, lig_scaler)
638
- if valid.any():
639
- preds, preds_all = predict(X, fold_models, meta, target_scaler)
640
- lo, hi = uncertainty_interval(preds_all)
641
- ad_label, _ = ad_check(esm_vec[-480:], ad_centroid, ad_threshold)
642
- results.append({
643
- 'Target': name,
644
- 'pKd_pred': round(float(preds[0]), 3),
645
- 'CI_lo': round(float(lo[0]), 3),
646
- 'CI_hi': round(float(hi[0]), 3),
647
- 'Ki_est': format_ki(float(preds[0])),
648
- 'model_std': round(float(preds_all.std()), 3),
649
- 'AD': ad_label,
650
- })
651
- except Exception as e:
652
- st.warning(f"Skipped {name}: {e}")
653
- progress.progress((idx + 1) / len(targets))
654
-
655
- progress.empty()
656
- res_df = (pd.DataFrame(results)
657
- .sort_values('pKd_pred', ascending=False)
658
- .reset_index(drop=True))
659
- res_df.insert(0, 'rank', range(1, len(res_df) + 1))
660
-
661
- st.success(f"Profiled {len(res_df)} targets.")
662
- fig = bar_chart(
663
- res_df['Target'].tolist(), res_df['pKd_pred'].values,
664
- res_df['CI_lo'].values, res_df['CI_hi'].values,
665
- "Selectivity profile β€” predicted pKd by target", dark=dark,
666
- )
667
- st.pyplot(fig, use_container_width=True)
668
- plt.close(fig)
669
-
670
- st.dataframe(res_df, use_container_width=True)
671
- st.download_button(
672
- "Download selectivity CSV",
673
- res_df.to_csv(index=False).encode(),
674
- file_name="velobind_selectivity.csv", mime="text/csv",
675
- )
676
-
677
-
678
- # ── Footer ────────────────────────────────────────────────────────────
679
- st.markdown("---")
680
- st.markdown(f"""
681
- <div style="color:{lab_col};font-size:0.78rem;text-align:center;padding:0.4rem 0 0.8rem">
682
- VeloBind &nbsp;&middot;&nbsp; Structure-free binding affinity &nbsp;&middot;&nbsp;
683
- ESM-2 + gradient-boosted ensemble &nbsp;&middot;&nbsp;
684
- Trained on LP-PDBBind &nbsp;&middot;&nbsp;
685
- Evaluated on CASF-2016 and CASF-2013 &nbsp;&middot;&nbsp;
686
- <b>Not for clinical use</b>
687
- </div>
688
- """, unsafe_allow_html=True)
 
1
+ # app.py — VeloBind Flask inference app
2
+ # Run locally: python app.py
3
+ # HF Spaces: add a Dockerfile or use the gradio wrapper below
4
+
5
+ import os, warnings, time, base64, json
6
  import numpy as np
7
  import pandas as pd
 
8
  import joblib
9
  import torch
 
10
  from pathlib import Path
11
+ from io import StringIO
12
+ from flask import Flask, request, jsonify, render_template_string
13
 
14
  warnings.filterwarnings("ignore")
15
  from rdkit import RDLogger
16
  RDLogger.DisableLog('rdApp.*')
17
 
18
+ # ── Config ────────────────────────────────────────────────────
19
  HF_MODEL_REPO = "ym59/velobind-models"
20
  MODEL_CACHE = Path("/tmp/velobind_models")
21
  SEEDS = [42, 123, 456]
 
30
  from src.models.ensemble import TargetScaler
31
  from src.config import config
32
 
33
+ # ── Lazy-loaded globals ───────────────────────────────────────
34
+ _fold_models = _meta = _scaler = _lig_scaler = None
35
+ _tokenizer = _esm_model = _device = None
36
+ _ad_centroid = _ad_threshold = None
37
+
38
+
39
def get_models():
    """Load and memoize the fold models, meta-learner and scalers.

    A local checkout is preferred: if ``output/models/meta_type_casf16.pkl``
    exists, models are read from ``output/models`` and the ligand scaler from
    ``output/preprocessors``. Otherwise every file that is not yet present in
    ``MODEL_CACHE`` is downloaded from ``HF_MODEL_REPO`` and everything is
    read from ``MODEL_CACHE``.

    Returns:
        Tuple ``(fold_models, meta, scaler, lig_scaler)`` where
        ``fold_models[seed][model_type]`` is a list with one loaded model per
        fold (``N_FOLDS`` entries).

    Raises:
        Any exception raised by ``hf_hub_download`` or ``joblib.load``. On
        failure nothing is cached, so the next call retries from scratch.
    """
    global _fold_models, _meta, _scaler, _lig_scaler
    if _fold_models is not None:
        return _fold_models, _meta, _scaler, _lig_scaler

    local_dir = Path("output/models")
    local_pre = Path("output/preprocessors")

    if (local_dir / "meta_type_casf16.pkl").exists():
        model_dir, preproc_dir = local_dir, local_pre
    else:
        from huggingface_hub import hf_hub_download
        MODEL_CACHE.mkdir(parents=True, exist_ok=True)
        files = (
            [f"fold_model_s{s}_{t}_f{f}.pkl"
             for s in SEEDS for t in MODEL_TYPES for f in range(N_FOLDS)]
            + ["meta_type_casf16.pkl", "target_scaler.pkl", "ligand_scaler.pkl"]
        )
        for fname in files:
            # Only fetch what is missing so warm restarts skip the network.
            if not (MODEL_CACHE / fname).exists():
                hf_hub_download(repo_id=HF_MODEL_REPO, filename=fname,
                                local_dir=str(MODEL_CACHE))
        model_dir = preproc_dir = MODEL_CACHE

    # NOTE(security): joblib.load unpickles its input and can execute
    # arbitrary code. Only point HF_MODEL_REPO / the local directories at
    # artifacts you trust.
    #
    # Load into locals first: assigning the module globals before every file
    # has loaded would leave a half-populated cache behind if a later load
    # raised, and the `is not None` guard above would then hand it out.
    fold_models = {
        s: {
            t: [
                joblib.load(model_dir / f"fold_model_s{s}_{t}_f{f}.pkl")
                for f in range(N_FOLDS)
            ]
            for t in MODEL_TYPES
        }
        for s in SEEDS
    }
    meta = joblib.load(model_dir / "meta_type_casf16.pkl")
    scaler = joblib.load(model_dir / "target_scaler.pkl")
    lig_scaler = joblib.load(preproc_dir / "ligand_scaler.pkl")

    # Publish the sentinel (_fold_models) last so the guard only ever sees a
    # fully loaded set.
    _meta, _scaler, _lig_scaler = meta, scaler, lig_scaler
    _fold_models = fold_models
    return _fold_models, _meta, _scaler, _lig_scaler
75
+
76
+
77
def get_esm():
    """Return the memoized ``(tokenizer, esm_model, device)`` triple.

    On the first call the device string is chosen (``"cuda"`` when
    ``torch.cuda.is_available()``, else ``"cpu"``) and ``load_esm`` is invoked
    with ``config.ESM_MODEL``; later calls reuse the stored objects.
    """
    global _tokenizer, _esm_model, _device
    if _tokenizer is None:
        _device = "cuda" if torch.cuda.is_available() else "cpu"
        _tokenizer, _esm_model = load_esm(config.ESM_MODEL, _device)
    return _tokenizer, _esm_model, _device
84
+
85
+
86
def get_ad():
    """Return the memoized applicability-domain ``(centroid, threshold)``.

    Lookup order: ``output/models/deployment``, ``output/models``, then
    ``MODEL_CACHE``. A directory is used only when it holds BOTH
    ``ad_centroid.npy`` and ``ad_threshold.npy``. If none qualifies, the
    missing files are downloaded from ``HF_MODEL_REPO`` into ``MODEL_CACHE``.

    Returns:
        ``(centroid, threshold)`` with ``threshold`` converted to ``float``,
        or ``(None, None)`` when the files cannot be obtained. The failure
        case is not cached, so a later call tries again.
    """
    global _ad_centroid, _ad_threshold
    if _ad_centroid is not None:
        return _ad_centroid, _ad_threshold

    centroid_name, threshold_name = "ad_centroid.npy", "ad_threshold.npy"

    for p in [Path("output/models/deployment"), Path("output/models"), MODEL_CACHE]:
        # Require both files: checking only the centroid would let a missing
        # threshold raise here and leave a centroid cached without one.
        if (p / centroid_name).exists() and (p / threshold_name).exists():
            centroid = np.load(p / centroid_name)
            threshold = float(np.load(p / threshold_name))
            # Assign the pair only after both loads succeeded.
            _ad_centroid, _ad_threshold = centroid, threshold
            return _ad_centroid, _ad_threshold

    try:
        from huggingface_hub import hf_hub_download
        for fname in (centroid_name, threshold_name):
            if not (MODEL_CACHE / fname).exists():
                hf_hub_download(repo_id=HF_MODEL_REPO, filename=fname,
                                local_dir=str(MODEL_CACHE))
        centroid = np.load(MODEL_CACHE / centroid_name)
        threshold = float(np.load(MODEL_CACHE / threshold_name))
    except Exception:
        # Deliberately best-effort: the domain check is optional and callers
        # treat a None centroid as "unknown". Log instead of failing silently.
        import logging
        logging.getLogger(__name__).warning(
            "Applicability-domain files unavailable; AD check disabled.",
            exc_info=True,
        )
        _ad_centroid = _ad_threshold = None
        return _ad_centroid, _ad_threshold

    _ad_centroid, _ad_threshold = centroid, threshold
    return _ad_centroid, _ad_threshold
106
 
107
 
108
+ # ── Helpers ───────────────────────────────────────────────────
109
def validate_sequence(raw):
    """Normalize a plain or FASTA-formatted protein sequence and validate it.

    Lines starting with ``>`` (FASTA headers) are dropped, all whitespace is
    removed, and the remainder is upper-cased and joined into one string.

    Args:
        raw: User-supplied text; ``None`` is treated like an empty string.

    Returns:
        ``(sequence, None)`` on success, or ``(None, message)`` when the
        sequence has fewer than 10 residues or contains characters outside
        ``VALID_AA`` (a set defined elsewhere in this module).
    """
    residues = []
    for line in (raw or "").strip().splitlines():
        line = line.strip()
        # Test for the header marker AFTER stripping, so an indented ">name"
        # line is still recognised as a header rather than read as residues.
        if line.startswith(">"):
            continue
        # split()/join removes every kind of whitespace, not just spaces.
        residues.append("".join(line.split()))
    seq = "".join(residues).upper()

    if len(seq) < 10:
        return None, "Sequence too short (minimum 10 residues)."
    bad = set(seq) - VALID_AA
    if bad:
        return None, f"Invalid characters: {', '.join(sorted(bad))}."
    return seq, None
118
 
119
 
120
+ def assemble(esm_mean, esm_var, esm_attn, seq_feat, lig_feats):
 
 
 
121
  return np.concatenate([
122
  esm_mean[:, -480:], seq_feat,
123
  lig_feats["ecfp"], lig_feats["ecfp2"], lig_feats["ecfp6"], lig_feats["fcfp"],
 
126
  ], axis=1)
127
 
128
 
129
+ def extract_features(sequence, smiles_list, lig_scaler):
130
+ tok, esm, dev = get_esm()
131
  esm_mean, esm_var, esm_attn, _ = embed_batch(
132
+ [sequence], tok, esm,
133
  config.ESM_LAYERS, config.MAX_SEQ_LEN, config.HALF_SEQ_LEN,
134
+ batch_size=1, device=dev)
135
+ seq_feat = np.array([sequence_features(sequence)])
 
136
  lig_feats, valid_mask, _ = extract_ligand_features(
137
  smiles_list, scaler=lig_scaler, fit_scaler=False)
138
  valid_mask = np.array(valid_mask)
 
141
  bm[valid_mask] = True
142
  valid_mask = bm
143
  n = int(valid_mask.sum())
144
+ X = assemble(
145
+ np.tile(esm_mean,(n,1)), np.tile(esm_var,(n,1)),
146
+ np.tile(esm_attn,(n,1)), np.tile(seq_feat,(n,1)), lig_feats)
147
  return X, valid_mask, esm_mean[0]
148
 
149
 
150
+ def run_predict(X, fold_models, meta, scaler):
151
  type_avgs = []
152
  for s in SEEDS:
153
  for t in MODEL_TYPES:
 
156
  type_avgs.append(fp.mean(axis=1))
157
  preds_all = np.stack(type_avgs, axis=1)
158
  preds = meta.predict(np.column_stack([
159
+ preds_all[:,[0,3,6]].mean(1),
160
+ preds_all[:,[1,4,7]].mean(1),
161
+ preds_all[:,[2,5,8]].mean(1),
162
  ]))
 
 
 
 
163
  std = preds_all.std(axis=1)
164
+ return preds, preds_all.mean(1) - 1.96*std, preds_all.mean(1) + 1.96*std
165
 
166
 
def format_ki(pkd):
    """Render a pKd value as a human-readable Ki string.

    Ki is computed in nanomolar as ``10 ** (9 - pkd)`` and then scaled to
    the unit (pM, nM, µM or mM) that keeps the number readable.

    Args:
        pkd: Predicted pKd (number).

    Returns:
        A string such as ``"12.3 nM"`` or ``"4.56 µM"``.
    """
    ki = 10 ** (9 - pkd)
    if ki < 1:
        # Without a picomolar branch very tight binders all print as
        # "0.0 nM", which hides the value entirely.
        return f"{ki * 1000:.1f} pM"
    if ki < 1000:
        return f"{ki:.1f} nM"
    if ki < 1_000_000:
        return f"{ki / 1000:.2f} \u00b5M"
    return f"{ki / 1_000_000:.2f} mM"
172
+
173
+
174
def ad_label(esm_vec):
    """Label a protein embedding as inside or outside the applicability domain.

    ``get_ad()`` supplies a reference vector and a distance cutoff
    (presumably a training-set centroid and threshold; confirm against
    ``get_ad``). The Euclidean distance between the last 480 entries of
    ``esm_vec`` and the reference decides the label.

    Returns:
        ``"UNKNOWN"`` when no reference is available, otherwise
        ``"IN DOMAIN"`` or ``"OUT OF DOMAIN"``.
    """
    reference, cutoff = get_ad()
    if reference is None:
        return "UNKNOWN"
    offset = esm_vec[-480:] - reference
    distance = float(np.linalg.norm(offset))
    if distance <= cutoff:
        return "IN DOMAIN"
    return "OUT OF DOMAIN"
179
+
180
+
181
def load_logo_b64():
    """Return the first available logo file as a base64 ``data:`` URI.

    ``logo.png`` is tried before ``logo.svg``; both are looked up relative
    to the current working directory. An empty string is returned when
    neither file exists.
    """
    mime_by_ext = {"png": "image/png", "svg": "image/svg+xml"}
    for ext, mime in mime_by_ext.items():
        candidate = Path(f"logo.{ext}")
        if not candidate.exists():
            continue
        encoded = base64.b64encode(candidate.read_bytes()).decode()
        return f"data:{mime};base64,{encoded}"
    return ""


# Resolved once at import time and injected into the page template.
LOGO_URI = load_logo_b64()
192
+
193
+ # ── HTML template ─────────────────────────────────────────────
194
+ HTML = r"""<!DOCTYPE html>
195
+ <html lang="en" data-theme="dark">
196
+ <head>
197
+ <meta charset="UTF-8">
198
+ <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0">
199
+ <title>VeloBind</title>
200
+ <link rel="preconnect" href="https://fonts.googleapis.com">
201
+ <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
202
+ <link href="https://fonts.googleapis.com/css2?family=DM+Mono:wght@400;500&family=Syne:wght@600;700;800&family=DM+Sans:wght@400;500;600&display=swap" rel="stylesheet">
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  <style>
204
+ :root[data-theme="dark"] {
205
+ --bg:#080c14; --surface:#0f1623; --surface2:#162030; --surface3:#1c2a3f;
206
+ --border:#1e3050; --border2:#243a56;
207
+ --text:#e8edf5; --text2:#8ba4c0; --text3:#4d6a87;
208
+ --accent:#3b82f6; --accent2:#60a5fa; --accent-glow:rgba(59,130,246,0.18);
209
+ --green:#22c55e; --red:#ef4444; --amber:#f59e0b;
210
+ --card-shadow:0 4px 24px rgba(0,0,0,0.4); --input-bg:#0f1623;
211
+ }
212
+ :root[data-theme="light"] {
213
+ --bg:#f0f4fa; --surface:#fff; --surface2:#f5f8ff; --surface3:#eaf0fb;
214
+ --border:#d0dce8; --border2:#b8ccd8;
215
+ --text:#0d1b2a; --text2:#4a6480; --text3:#8aa3bc;
216
+ --accent:#1d5cbf; --accent2:#2563eb; --accent-glow:rgba(29,92,191,0.12);
217
+ --green:#16a34a; --red:#dc2626; --amber:#d97706;
218
+ --card-shadow:0 4px 24px rgba(0,0,0,0.08); --input-bg:#fff;
219
+ }
220
+ *,*::before,*::after{box-sizing:border-box;margin:0;padding:0}
221
+ html{scroll-behavior:smooth}
222
+ body{background:var(--bg);color:var(--text);font-family:'DM Sans',sans-serif;
223
+ font-size:15px;line-height:1.6;min-height:100vh;transition:background .3s,color .3s}
224
+ ::selection{background:var(--accent);color:#fff}
225
+
226
+ .wrap{max-width:1100px;margin:0 auto;padding:0 20px 60px}
227
+
228
+ /* HEADER */
229
+ .header{padding:24px 0 20px;display:flex;align-items:center;gap:18px;
230
+ border-bottom:1px solid var(--border);margin-bottom:28px;flex-wrap:wrap}
231
+ .logo-wrap{background:#fff;border-radius:12px;padding:7px;flex-shrink:0;
232
+ box-shadow:0 2px 12px rgba(0,0,0,0.15)}
233
+ .logo-wrap img{height:64px;width:auto;display:block}
234
+ .header-text{flex:1;min-width:180px}
235
+ .header-text h1{font-family:'Syne',sans-serif;font-size:clamp(1.5rem,3.5vw,2.2rem);
236
+ font-weight:800;letter-spacing:-1px;color:var(--text);line-height:1.1}
237
+ .header-text h1 span{color:var(--accent2)}
238
+ .header-text p{font-size:.78rem;color:var(--text2);margin-top:5px;
239
+ font-family:'DM Mono',monospace;letter-spacing:.02em}
240
+ .header-right{display:flex;align-items:center;gap:8px;flex-wrap:wrap}
241
+ .badge{font-family:'DM Mono',monospace;font-size:.68rem;font-weight:500;
242
+ padding:3px 9px;border-radius:5px;border:1px solid var(--border2);
243
+ color:var(--text2);background:var(--surface2);white-space:nowrap}
244
+ .badge.green{border-color:var(--green);color:var(--green);background:rgba(34,197,94,.08)}
245
+
246
+ /* THEME BUTTON */
247
+ .theme-btn{background:var(--surface2);border:1px solid var(--border2);color:var(--text);
248
+ border-radius:999px;padding:6px 14px;font-size:.8rem;font-weight:600;
249
+ font-family:'DM Sans',sans-serif;cursor:pointer;display:flex;align-items:center;
250
+ gap:6px;transition:all .2s;white-space:nowrap}
251
+ .theme-btn:hover{border-color:var(--accent);color:var(--accent)}
252
+ .theme-btn svg{width:14px;height:14px}
253
+
254
+ /* TABS */
255
+ .tabs{display:flex;gap:4px;background:var(--surface);border:1px solid var(--border);
256
+ border-radius:13px;padding:4px;margin-bottom:24px;overflow-x:auto;
257
+ -webkit-overflow-scrolling:touch}
258
+ .tab-btn{flex:1;min-width:110px;padding:9px 14px;border:none;border-radius:9px;
259
+ background:transparent;color:var(--text2);font-size:.8rem;font-weight:600;
260
+ font-family:'DM Sans',sans-serif;cursor:pointer;transition:all .2s;
261
+ white-space:nowrap;text-align:center}
262
+ .tab-btn:hover{color:var(--text);background:var(--surface2)}
263
+ .tab-btn.active{background:var(--accent);color:#fff;box-shadow:0 2px 10px var(--accent-glow)}
264
+ .tab-panel{display:none}
265
+ .tab-panel.active{display:block;animation:fadeUp .25s ease}
266
+ @keyframes fadeUp{from{opacity:0;transform:translateY(8px)}to{opacity:1;transform:translateY(0)}}
267
+
268
+ /* CARD */
269
+ .card{background:var(--surface);border:1px solid var(--border);border-radius:16px;
270
+ padding:20px;box-shadow:var(--card-shadow);transition:background .3s,border-color .3s}
271
+ .card+.card{margin-top:14px}
272
+ .card-title{font-family:'Syne',sans-serif;font-size:.68rem;font-weight:700;
273
+ letter-spacing:.12em;text-transform:uppercase;color:var(--text3);margin-bottom:12px}
274
+
275
+ /* GRID */
276
+ .grid-2{display:grid;grid-template-columns:1fr 1fr;gap:14px}
277
+ @media(max-width:640px){.grid-2{grid-template-columns:1fr}}
278
+ .metric-grid{display:grid;grid-template-columns:repeat(4,1fr);gap:10px;margin-top:20px}
279
+ @media(max-width:700px){.metric-grid{grid-template-columns:repeat(2,1fr)}}
280
+
281
+ /* FORM */
282
+ label{display:block;font-size:.74rem;font-weight:600;color:var(--text2);
283
+ margin-bottom:7px;letter-spacing:.04em;text-transform:uppercase}
284
+ textarea,input[type=text],input[type=number]{width:100%;background:var(--input-bg);
285
+ border:1.5px solid var(--border);border-radius:11px;color:var(--text);
286
+ font-family:'DM Mono',monospace;font-size:.82rem;padding:11px 13px;
287
+ resize:vertical;transition:border-color .2s,box-shadow .2s;outline:none}
288
+ textarea{min-height:120px}
289
+ textarea:focus,input:focus{border-color:var(--accent);box-shadow:0 0 0 3px var(--accent-glow)}
290
+
291
+ /* BUTTONS */
292
+ .btn{display:inline-flex;align-items:center;justify-content:center;gap:7px;
293
+ padding:11px 24px;border-radius:11px;border:none;font-family:'DM Sans',sans-serif;
294
+ font-size:.88rem;font-weight:700;cursor:pointer;transition:all .2s}
295
+ .btn-primary{background:var(--accent);color:#fff;width:100%;margin-top:16px;
296
+ padding:13px;font-size:.92rem;box-shadow:0 4px 14px var(--accent-glow)}
297
+ .btn-primary:hover{background:var(--accent2);transform:translateY(-1px)}
298
+ .btn-primary:disabled{opacity:.55;cursor:not-allowed;transform:none}
299
+ .btn-outline{background:var(--surface2);border:1px solid var(--border2);
300
+ color:var(--text);padding:7px 14px;font-size:.78rem}
301
+ .btn-outline:hover{border-color:var(--accent);color:var(--accent)}
302
+
303
+ /* EXAMPLE PILLS */
304
+ .examples{display:flex;flex-wrap:wrap;gap:5px;margin-top:8px}
305
+ .ex-pill{font-size:.72rem;font-family:'DM Mono',monospace;padding:3px 9px;
306
+ border-radius:5px;border:1px solid var(--border2);background:var(--surface2);
307
+ color:var(--text2);cursor:pointer;transition:all .15s}
308
+ .ex-pill:hover{border-color:var(--accent);color:var(--accent)}
309
+
310
+ /* FILE DROP */
311
+ .file-drop{border:2px dashed var(--border2);border-radius:11px;padding:20px;
312
+ text-align:center;cursor:pointer;transition:all .2s;background:var(--surface2);
313
+ color:var(--text2);font-size:.82rem}
314
+ .file-drop:hover,.file-drop.drag{border-color:var(--accent);color:var(--accent);
315
+ background:var(--accent-glow)}
316
+ .file-drop input{display:none}
317
+
318
+ /* SPINNER */
319
+ .spinner{display:inline-block;width:15px;height:15px;border:2px solid rgba(255,255,255,.3);
320
+ border-top-color:#fff;border-radius:50%;animation:spin .7s linear infinite}
321
+ @keyframes spin{to{transform:rotate(360deg)}}
322
+
323
+ /* RESULTS */
324
+ .results{display:none;margin-top:20px}
325
+ .results.show{display:block;animation:fadeUp .3s ease}
326
+
327
+ /* METRIC CARDS */
328
+ .metric{background:var(--surface2);border:1px solid var(--border);
329
+ border-radius:13px;padding:14px;text-align:center}
330
+ .metric-val{font-family:'Syne',sans-serif;font-size:1.5rem;font-weight:800;
331
+ color:var(--accent2);line-height:1.1;margin-bottom:3px}
332
+ .metric-lab{font-size:.67rem;color:var(--text3);text-transform:uppercase;
333
+ letter-spacing:.06em;font-weight:600}
334
+
335
+ /* AD BADGES */
336
+ .ad-badge{display:inline-block;padding:4px 12px;border-radius:999px;
337
+ font-size:.7rem;font-weight:700;font-family:'DM Mono',monospace;
338
+ letter-spacing:.04em;text-transform:uppercase}
339
+ .ad-in {background:rgba(34,197,94,.1); border:1px solid var(--green);color:var(--green)}
340
+ .ad-out{background:rgba(239,68,68,.1); border:1px solid var(--red); color:var(--red)}
341
+ .ad-unk{background:rgba(77,106,135,.1);border:1px solid var(--text3);color:var(--text3)}
342
+
343
+ .inf-caption{font-family:'DM Mono',monospace;font-size:.68rem;color:var(--text3);margin-top:10px}
344
+
345
+ .warn-box{background:rgba(245,158,11,.08);border:1px solid var(--amber);border-radius:11px;
346
+ padding:10px 14px;color:var(--amber);font-size:.8rem;margin-top:10px;display:none}
347
+
348
+ /* TABLE */
349
+ .tbl-wrap{overflow-x:auto;margin-top:14px;border-radius:11px;border:1px solid var(--border)}
350
+ table{width:100%;border-collapse:collapse;font-size:.8rem}
351
+ thead{background:var(--surface2)}
352
+ th{text-align:left;padding:9px 13px;font-size:.67rem;font-weight:700;
353
+ text-transform:uppercase;letter-spacing:.06em;color:var(--text3);white-space:nowrap}
354
+ td{padding:9px 13px;border-top:1px solid var(--border);font-family:'DM Mono',monospace;
355
+ font-size:.78rem;color:var(--text)}
356
+ tr:hover td{background:var(--surface2)}
357
+ .pkd-val{color:var(--accent2);font-weight:600}
358
+
359
+ .download-wrap{margin-top:12px;display:flex;gap:8px;flex-wrap:wrap}
360
+
361
+ /* LOADING OVERLAY */
362
+ .overlay{display:none;position:fixed;inset:0;background:rgba(8,12,20,.75);
363
+ backdrop-filter:blur(4px);z-index:100;align-items:center;justify-content:center;
364
+ flex-direction:column;gap:14px}
365
+ .overlay.show{display:flex}
366
+ .ov-spinner{width:42px;height:42px;border:3px solid rgba(59,130,246,.2);
367
+ border-top-color:var(--accent);border-radius:50%;animation:spin .8s linear infinite}
368
+ .ov-text{font-family:'DM Mono',monospace;font-size:.82rem;color:var(--text2)}
369
+
370
+ /* FOOTER */
371
+ footer{border-top:1px solid var(--border);padding:20px 0 0;margin-top:44px;
372
+ text-align:center;font-size:.72rem;color:var(--text3);font-family:'DM Mono',monospace}
373
+ footer span{color:var(--text2)}
374
  </style>
375
+ </head>
376
+ <body>
377
 
378
+ <div id="overlay" class="overlay">
379
+ <div class="ov-spinner"></div>
380
+ <div class="ov-text" id="ov-text">Running inference...</div>
381
+ </div>
382
 
383
+ <div class="wrap">
384
+ <header class="header">
385
+ <div class="logo-wrap"><img src="{{ logo_uri }}" alt="VeloBind"></div>
386
+ <div class="header-text">
387
+ <h1>Velo<span>Bind</span></h1>
388
+ <p>R = 0.8469 on CASF-2016 &nbsp;·&nbsp; 45-model ensemble &nbsp;·&nbsp; sequence + SMILES only</p>
389
  </div>
390
+ <div class="header-right">
391
+ <span class="badge green">No 3D structure</span>
392
+ <span class="badge" id="dev-badge">CPU</span>
393
+ <button class="theme-btn" onclick="toggleTheme()">
394
+ <svg id="t-icon" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"></svg>
395
+ <span id="t-label">Light mode</span>
396
+ </button>
 
397
  </div>
398
+ </header>
399
+
400
+ <div class="tabs">
401
+ <button class="tab-btn active" onclick="switchTab(0)">Single Query</button>
402
+ <button class="tab-btn" onclick="switchTab(1)">Batch Screening</button>
403
+ <button class="tab-btn" onclick="switchTab(2)">Selectivity Profile</button>
404
+ </div>
405
+
406
+ <!-- ═══ TAB 0 ═══════════════════════════════════════════════ -->
407
+ <div class="tab-panel active" id="tab-0">
408
+ <div class="grid-2">
409
+ <div class="card">
410
+ <div class="card-title">Protein</div>
411
+ <label>Sequence — plain or FASTA</label>
412
+ <textarea id="sq-seq" placeholder=">ProteinName&#10;MKTAYIAKQRQISFVK..."></textarea>
413
+ </div>
414
+ <div class="card">
415
+ <div class="card-title">Ligand</div>
416
+ <label>SMILES</label>
417
+ <input type="text" id="sq-smi" placeholder="CC(=O)Oc1ccccc1C(=O)O">
418
+ <div style="margin-top:10px">
419
+ <div class="card-title">Quick examples</div>
420
+ <div class="examples">
421
+ <span class="ex-pill" onclick="setSmi('CC(=O)Oc1ccccc1C(=O)O')">Aspirin</span>
422
+ <span class="ex-pill" onclick="setSmi('Cc1ccc(NC(=O)c2ccc(CN3CCN(C)CC3)cc2)cc1Nc1nccc(-c2cccnc2)n1')">Imatinib</span>
423
+ <span class="ex-pill" onclick="setSmi('COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCCN1CCOCC1')">Gefitinib</span>
424
+ <span class="ex-pill" onclick="setSmi('C[C@@H]1CCCN2C(=O)c3[nH]c4ccccc4c3C2=N1')">Staurosporine</span>
425
+ </div>
426
+ </div>
427
+ </div>
428
+ </div>
429
+ <button class="btn btn-primary" id="sq-btn" onclick="runSingle()">
430
+ <span id="sq-txt">Predict binding affinity</span>
431
+ </button>
432
+ <div class="warn-box" id="sq-warn"></div>
433
+ <div class="results" id="sq-res">
434
+ <div class="metric-grid">
435
+ <div class="metric"><div class="metric-val" id="sq-pkd">β€”</div><div class="metric-lab">Predicted pKd</div></div>
436
+ <div class="metric"><div class="metric-val" id="sq-ci" style="font-size:1.1rem">β€”</div><div class="metric-lab">95% interval</div></div>
437
+ <div class="metric"><div class="metric-val" id="sq-ki">β€”</div><div class="metric-lab">Estimated Ki</div></div>
438
+ <div class="metric">
439
+ <div class="metric-val" style="font-size:0;padding-top:8px">
440
+ <span class="ad-badge ad-unk" id="sq-ad">β€”</span>
441
+ </div>
442
+ <div class="metric-lab" style="margin-top:7px">Applicability domain</div>
443
+ </div>
444
+ </div>
445
+ <div class="inf-caption" id="sq-cap"></div>
446
+ </div>
447
+ </div>
448
+
449
+ <!-- ═══ TAB 1 ═══════════════════════════════════════════════ -->
450
+ <div class="tab-panel" id="tab-1">
451
+ <div class="grid-2">
452
+ <div class="card">
453
+ <div class="card-title">Target protein</div>
454
+ <label>Sequence — plain or FASTA</label>
455
+ <textarea id="bs-seq" placeholder=">Target&#10;MKTAYIAKQRQISFVK..."></textarea>
456
+ </div>
457
+ <div class="card">
458
+ <div class="card-title">Compound library</div>
459
+ <label>CSV with <code style="color:var(--accent2)">smiles</code> column (+ optional <code style="color:var(--accent2)">name</code>)</label>
460
+ <div class="file-drop" id="bs-drop" onclick="document.getElementById('bs-file').click()"
461
+ ondragover="event.preventDefault();this.classList.add('drag')"
462
+ ondragleave="this.classList.remove('drag')"
463
+ ondrop="handleDrop(event,'bs-file','bs-lbl')">
464
+ <input type="file" id="bs-file" accept=".csv" onchange="fileSel(this,'bs-lbl')">
465
+ <div id="bs-lbl">Drop CSV here or click to upload</div>
466
+ </div>
467
+ <div style="margin-top:10px">
468
+ <label>Max compounds</label>
469
+ <input type="number" id="bs-max" value="100" min="1" max="500">
470
+ </div>
471
+ </div>
472
+ </div>
473
+ <button class="btn btn-primary" id="bs-btn" onclick="runBatch()">
474
+ <span id="bs-txt">Run batch screening</span>
475
+ </button>
476
+ <div class="warn-box" id="bs-warn"></div>
477
+ <div class="results" id="bs-res">
478
+ <div class="card" style="margin-top:0">
479
+ <div class="card-title">Ranked results</div>
480
+ <div class="tbl-wrap">
481
+ <table>
482
+ <thead><tr><th>#</th><th>Name</th><th>pKd</th><th>CI lo</th><th>CI hi</th><th>Ki</th><th>Std</th><th>AD</th></tr></thead>
483
+ <tbody id="bs-body"></tbody>
484
+ </table>
485
+ </div>
486
+ <div class="download-wrap">
487
+ <button class="btn btn-outline" onclick="dlCSV('bs')">Download CSV</button>
488
+ </div>
489
+ </div>
490
+ </div>
491
+ </div>
492
+
493
+ <!-- ═══ TAB 2 ═══════════════════════════════════════════════ -->
494
+ <div class="tab-panel" id="tab-2">
495
+ <div class="grid-2">
496
+ <div class="card">
497
+ <div class="card-title">Compound</div>
498
+ <label>SMILES</label>
499
+ <input type="text" id="sp-smi" placeholder="Cc1ccc(NC(=O)...)cc1...">
500
+ <div style="margin-top:10px">
501
+ <div class="card-title">Examples</div>
502
+ <div class="examples">
503
+ <span class="ex-pill" onclick="document.getElementById('sp-smi').value='Cc1ccc(NC(=O)c2ccc(CN3CCN(C)CC3)cc2)cc1Nc1nccc(-c2cccnc2)n1'">Imatinib</span>
504
+ <span class="ex-pill" onclick="document.getElementById('sp-smi').value='COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCCN1CCOCC1'">Gefitinib</span>
505
+ </div>
506
+ </div>
507
+ </div>
508
+ <div class="card">
509
+ <div class="card-title">Target proteins</div>
510
+ <label>One per line — <code style="color:var(--accent2)">Name: SEQUENCE</code></label>
511
+ <textarea id="sp-seqs" style="min-height:160px"
512
+ placeholder="ABL1: MGPSENDPNLFVALY...&#10;EGFR: MRPSGTAGAALLALL...&#10;CDK2: MENFQKVEKIGEGTY..."></textarea>
513
+ </div>
514
+ </div>
515
+ <button class="btn btn-primary" id="sp-btn" onclick="runSelectivity()">
516
+ <span id="sp-txt">Run selectivity profiling</span>
517
+ </button>
518
+ <div class="warn-box" id="sp-warn"></div>
519
+ <div class="results" id="sp-res">
520
+ <div class="card" style="margin-top:0">
521
+ <div class="card-title">Selectivity profile</div>
522
+ <div class="tbl-wrap">
523
+ <table>
524
+ <thead><tr><th>#</th><th>Target</th><th>pKd</th><th>CI lo</th><th>CI hi</th><th>Ki</th><th>Std</th><th>AD</th></tr></thead>
525
+ <tbody id="sp-body"></tbody>
526
+ </table>
527
+ </div>
528
+ <div class="download-wrap">
529
+ <button class="btn btn-outline" onclick="dlCSV('sp')">Download CSV</button>
530
+ </div>
531
+ </div>
532
+ </div>
533
+ </div>
534
+
535
+ <footer>
536
+ <p><span>VeloBind</span> &nbsp;Β·&nbsp; ESM-2 + GBM ensemble &nbsp;Β·&nbsp;
537
+ LP-PDBBind training &nbsp;Β·&nbsp; CASF-2016/2013 &nbsp;Β·&nbsp;
538
+ <span>Not for clinical use</span></p>
539
+ </footer>
540
  </div>
541
+
542
+ <script>
543
+ // ── THEME ────────────────────────────────────────────────────
544
const html = document.documentElement;
const MOON = `<path d="M21 12.79A9 9 0 1 1 11.21 3 7 7 0 0 0 21 12.79z"/>`;
const SUN  = `<circle cx="12" cy="12" r="5"/>
  <line x1="12" y1="1" x2="12" y2="3"/><line x1="12" y1="21" x2="12" y2="23"/>
  <line x1="4.22" y1="4.22" x2="5.64" y2="5.64"/>
  <line x1="18.36" y1="18.36" x2="19.78" y2="19.78"/>
  <line x1="1" y1="12" x2="3" y2="12"/><line x1="21" y1="12" x2="23" y2="12"/>
  <line x1="4.22" y1="19.78" x2="5.64" y2="18.36"/>
  <line x1="18.36" y1="5.64" x2="19.78" y2="4.22"/>`;

// Read the saved theme. Storage access can throw (for example when the page
// is embedded in an iframe with storage blocked); an exception here would
// abort the rest of this script, so fall back to no saved preference.
function readStoredTheme() {
  try { return localStorage.getItem('vb-theme'); }
  catch (e) { return null; }
}

// Apply theme `t` ('dark' or 'light'): set the root attribute, swap the
// button icon/label to the *other* mode, and persist the choice if possible.
function applyTheme(t) {
  html.setAttribute('data-theme', t);
  document.getElementById('t-icon').innerHTML = t==='dark' ? SUN : MOON;
  document.getElementById('t-label').textContent = t==='dark' ? 'Light mode' : 'Dark mode';
  try { localStorage.setItem('vb-theme', t); }
  catch (e) { /* preference simply will not persist */ }
}

// Flip between the two themes.
function toggleTheme() {
  applyTheme(html.getAttribute('data-theme')==='dark' ? 'light' : 'dark');
}

// Only 'light' and 'dark' have CSS variable sets; anything else means dark.
applyTheme(readStoredTheme()==='light' ? 'light' : 'dark');
564
+
565
+ // ── TABS ─────────────────────────────────────────────────────
566
// Activate tab button and panel number `i`; deactivate all others.
function switchTab(i) {
  const mark = (el, j) => el.classList.toggle('active', i === j);
  document.querySelectorAll('.tab-btn').forEach(mark);
  document.querySelectorAll('.tab-panel').forEach(mark);
}
570
+
571
+ // ── UTILS ─────────────────────────────────────────────────────
572
// Fill the single-query SMILES field with example string `s`.
function setSmi(s) {
  const field = document.getElementById('sq-smi');
  field.value = s;
}
573
// Show or hide the full-page loading overlay and set its caption.
function overlay(on, msg='Running inference...') {
  const layer = document.getElementById('overlay');
  layer.classList.toggle('show', on);
  const caption = document.getElementById('ov-text');
  caption.textContent = msg;
}
577
// Toggle the busy state of the action button for tab prefix `id`
// ('sq', 'bs' or 'sp'): disable it and show a spinner while loading,
// restore its normal label otherwise.
function setBtn(id, loading) {
  const labels = {sq:'Predict binding affinity', bs:'Run batch screening', sp:'Run selectivity profiling'};
  const button = document.getElementById(id+'-btn');
  const caption = document.getElementById(id+'-txt');
  button.disabled = loading;
  if (loading) {
    caption.innerHTML = '<div class="spinner"></div> Processing...';
  } else {
    caption.innerHTML = labels[id];
  }
}
583
// Show `msg` in the warning box of tab prefix `id`; an empty message hides it.
function warn(id, msg) {
  const box = document.getElementById(id+'-warn');
  box.textContent = msg;
  if (msg) {
    box.style.display = 'block';
  } else {
    box.style.display = 'none';
  }
}
587
// Map an applicability-domain label to its badge CSS class.
function adCls(l) {
  switch (l) {
    case 'IN DOMAIN':     return 'ad-in';
    case 'OUT OF DOMAIN': return 'ad-out';
    default:              return 'ad-unk';
  }
}
590
// Show the chosen file's name (or 'No file') in the element with id `lid`.
function fileSel(inp, lid) {
  const picked = inp.files[0];
  const caption = (picked && picked.name) || 'No file';
  document.getElementById(lid).textContent = caption;
}
593
// Drop handler for the upload area: take the first dropped file, assign it
// to the hidden file input `fid`, and show its name in element `lid`.
function handleDrop(ev, fid, lid) {
  ev.preventDefault();
  ev.currentTarget.classList.remove('drag');
  const dropped = ev.dataTransfer.files[0];
  if (!dropped) return;
  // A file input's FileList can only be replaced via a DataTransfer object.
  const transfer = new DataTransfer();
  transfer.items.add(dropped);
  document.getElementById(fid).files = transfer.files;
  document.getElementById(lid).textContent = dropped.name;
}
600
+
601
+ // CSV data store
602
// Last result rows per tab, kept so they can be exported.
const csvStore = {bs:[], sp:[]};

// Download the stored rows for tab prefix `p` ('bs' or 'sp') as a CSV file.
function dlCSV(p) {
  const rows = csvStore[p]; if (!rows || !rows.length) return;
  const keys = Object.keys(rows[0]);
  // CSV quoting: wrap a field in double quotes only when it contains a
  // quote, comma or line break, and double any embedded quotes.
  // (JSON.stringify escapes quotes with a backslash, which CSV readers
  // do not understand.)
  const cell = v => {
    const s = String(v ?? '');
    return /[",\r\n]/.test(s) ? '"' + s.replace(/"/g, '""') + '"' : s;
  };
  const records = [keys, ...rows.map(r => keys.map(k => r[k]))];
  const csv = records.map(rec => rec.map(cell).join(',')).join('\n');
  const url = URL.createObjectURL(new Blob([csv],{type:'text/csv'}));
  const a = document.createElement('a');
  a.href = url;
  a.download = `velobind_${p}_${Date.now()}.csv`;
  a.click();
  // Release the blob once the download has had time to start.
  setTimeout(() => URL.revokeObjectURL(url), 1000);
}
614
+
615
+ // Device badge
616
// Ask the server which device it runs on and show it in the header badge.
// On any failure the badge keeps its default text instead of leaving an
// unhandled promise rejection.
fetch('/device')
  .then(r => r.ok ? r.json() : Promise.reject(new Error('HTTP ' + r.status)))
  .then(d => {
    if (d && d.device)
      document.getElementById('dev-badge').textContent = String(d.device).toUpperCase();
  })
  .catch(() => { /* keep default badge */ });
619
+
620
+ // ── SINGLE QUERY ─────────────────────────────────────────────
621
// Single query: POST the sequence and SMILES to /predict_single and fill
// the metric cards with the returned pKd, interval, Ki and domain label.
async function runSingle() {
  const seq = document.getElementById('sq-seq').value.trim();
  const smi = document.getElementById('sq-smi').value.trim();
  const panel = document.getElementById('sq-res');
  warn('sq','');
  if (!seq||!smi) { warn('sq','Please enter both a sequence and a SMILES.'); return; }
  // Hide the previous prediction so a failed request is not shown next to
  // stale numbers.
  panel.classList.remove('show');
  setBtn('sq',true); overlay(true,'Embedding protein and ligand...');
  try {
    const r = await fetch('/predict_single',{method:'POST',
      headers:{'Content-Type':'application/json'},body:JSON.stringify({seq,smi})});
    const d = await r.json();
    if (!r.ok) { warn('sq', d.error||'Server error'); return; }
    document.getElementById('sq-pkd').textContent = d.pkd.toFixed(2);
    document.getElementById('sq-ci').textContent  = `[${d.lo.toFixed(2)}, ${d.hi.toFixed(2)}]`;
    document.getElementById('sq-ki').textContent  = d.ki;
    const adEl = document.getElementById('sq-ad');
    adEl.textContent = d.ad; adEl.className = 'ad-badge '+adCls(d.ad);
    document.getElementById('sq-cap').textContent =
      `${d.elapsed}s · 45 models (3 seeds × 3 types × 5 folds) · ${d.device.toUpperCase()}`;
    if (d.ad==='OUT OF DOMAIN')
      warn('sq','Protein outside training distribution — predictions may be unreliable.');
    panel.classList.add('show');
  } catch(e) { warn('sq','Request failed: '+e.message); }
  finally { setBtn('sq',false); overlay(false); }
}
645
+
646
+ // ── BATCH ─────────────────────────────────────────────────────
647
// Batch screening: POST the target sequence and uploaded CSV text to
// /predict_batch and render the ranked rows into the results table.
async function runBatch() {
  // Escape text before it goes into innerHTML: compound names come straight
  // from the uploaded CSV and must not be interpreted as markup.
  const esc = v => String(v ?? '').replace(/[&<>"']/g,
    c => ({'&':'&amp;','<':'&lt;','>':'&gt;','"':'&quot;',"'":'&#39;'})[c]);
  const seq  = document.getElementById('bs-seq').value.trim();
  const file = document.getElementById('bs-file').files[0];
  const maxC = parseInt(document.getElementById('bs-max').value, 10)||100;
  const panel = document.getElementById('bs-res');
  warn('bs','');
  if (!seq)  { warn('bs','Please enter a protein sequence.'); return; }
  if (!file) { warn('bs','Please upload a CSV file.'); return; }
  // Hide the previous table so an error is not shown next to stale rows.
  panel.classList.remove('show');
  setBtn('bs',true); overlay(true,`Screening up to ${maxC} compounds...`);
  try {
    // Read the file inside the try so a read failure is reported like any
    // other request error.
    const csv = await file.text();
    const r = await fetch('/predict_batch',{method:'POST',
      headers:{'Content-Type':'application/json'},body:JSON.stringify({seq,csv,max_cpds:maxC})});
    const d = await r.json();
    if (!r.ok) { warn('bs', d.error||'Server error'); return; }
    csvStore.bs = d.rows;
    document.getElementById('bs-body').innerHTML = d.rows.map((row,i)=>`
      <tr>
        <td>${i+1}</td><td>${esc(row.name)}</td>
        <td class="pkd-val">${esc(row.pKd)}</td>
        <td>${esc(row.lo)}</td><td>${esc(row.hi)}</td>
        <td>${esc(row.ki)}</td><td>${esc(row.std)}</td>
        <td><span class="ad-badge ${adCls(row.ad)}">${esc(row.ad)}</span></td>
      </tr>`).join('');
    if (d.n_invalid) warn('bs',`${d.n_invalid} invalid SMILES skipped.`);
    panel.classList.add('show');
  } catch(e) { warn('bs','Request failed: '+e.message); }
  finally { setBtn('bs',false); overlay(false); }
}
675
+
676
+ // ── SELECTIVITY ───────────────────────────────────────────────
677
// Selectivity profiling: POST one SMILES and the "Name: SEQUENCE" lines to
// /predict_selectivity and render the ranked targets into the table.
async function runSelectivity() {
  // Escape text before it goes into innerHTML: target names are typed by
  // the user and must not be interpreted as markup.
  const esc = v => String(v ?? '').replace(/[&<>"']/g,
    c => ({'&':'&amp;','<':'&lt;','>':'&gt;','"':'&quot;',"'":'&#39;'})[c]);
  const smi  = document.getElementById('sp-smi').value.trim();
  const seqs = document.getElementById('sp-seqs').value.trim();
  const panel = document.getElementById('sp-res');
  warn('sp','');
  if (!smi||!seqs) { warn('sp','Please enter a SMILES and at least one sequence.'); return; }
  // Hide the previous table so an error is not shown next to stale rows.
  panel.classList.remove('show');
  setBtn('sp',true); overlay(true,'Profiling targets...');
  try {
    const r = await fetch('/predict_selectivity',{method:'POST',
      headers:{'Content-Type':'application/json'},body:JSON.stringify({smi,seqs})});
    const d = await r.json();
    if (!r.ok) { warn('sp', d.error||'Server error'); return; }
    csvStore.sp = d.rows;
    document.getElementById('sp-body').innerHTML = d.rows.map((row,i)=>`
      <tr>
        <td>${i+1}</td><td>${esc(row.target)}</td>
        <td class="pkd-val">${esc(row.pKd)}</td>
        <td>${esc(row.lo)}</td><td>${esc(row.hi)}</td>
        <td>${esc(row.ki)}</td><td>${esc(row.std)}</td>
        <td><span class="ad-badge ${adCls(row.ad)}">${esc(row.ad)}</span></td>
      </tr>`).join('');
    if (d.skipped?.length) warn('sp','Skipped: '+d.skipped.join(', '));
    panel.classList.add('show');
  } catch(e) { warn('sp','Request failed: '+e.message); }
  finally { setBtn('sp',false); overlay(false); }
}
702
+ </script>
703
+ </body>
704
+ </html>"""
705
+
706
+ # ── Flask routes ──────────────────────────────────────────────
707
+ app = Flask(__name__)
708
+
709
@app.route("/")
def index():
    """Serve the single-page UI with the logo data URI filled in."""
    page = render_template_string(HTML, logo_uri=LOGO_URI)
    return page
712
+
713
@app.route("/device")
def device_info():
    """Return ``{"device": ...}`` using the device reported by ``get_esm()``."""
    _tokenizer, _model, device = get_esm()
    payload = {"device": device}
    return jsonify(payload)
717
+
718
@app.route("/predict_single", methods=["POST"])
def predict_single():
    """Predict binding affinity for one protein/ligand pair.

    Expects a JSON object with ``seq`` (plain or FASTA sequence) and
    ``smi`` (SMILES). Responds with ``pkd``, ``lo``, ``hi``, ``ki``,
    ``ad``, ``elapsed`` and ``device`` on success, or ``{"error": ...}``
    with status 400 (bad input) or 500 (inference failure).
    """
    # silent=True makes a missing or malformed JSON body come back as None
    # instead of raising, so the client always receives a JSON error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400
    raw_seq = data.get("seq", "")
    seq, err = validate_sequence(raw_seq if isinstance(raw_seq, str) else "")
    if err:
        return jsonify({"error": err}), 400
    smi = data.get("smi", "")
    if not isinstance(smi, str) or not smi.strip():
        return jsonify({"error": "Missing SMILES"}), 400
    smi = smi.strip()
    try:
        fm, meta, sc, ls = get_models()
        t0 = time.time()
        X, valid, esm_vec = extract_features(seq, [smi], ls)
        if not valid.any():
            return jsonify({"error": "RDKit could not parse this SMILES."}), 400
        preds, lo, hi = run_predict(X, fm, meta, sc)
        elapsed = round(time.time() - t0, 2)
        pkd = float(preds[0])
        _, _, dev = get_esm()
        return jsonify({
            "pkd": pkd, "lo": float(lo[0]), "hi": float(hi[0]),
            "ki": format_ki(pkd), "ad": ad_label(esm_vec),
            "elapsed": elapsed, "device": dev,
        })
    except Exception as e:
        # Route boundary: report the failure as JSON rather than an HTML 500.
        return jsonify({"error": str(e)}), 500
740
+
741
@app.route("/predict_batch", methods=["POST"])
def predict_batch():
    """Screen a CSV library of compounds against one protein.

    Expects a JSON object with ``seq`` (protein sequence), ``csv`` (CSV
    text with a ``smiles`` column and optional ``name`` column) and
    ``max_cpds`` (row limit). Responds with ``rows`` sorted by descending
    pKd and ``n_invalid`` (rows skipped because the SMILES was blank or
    unparseable), or ``{"error": ...}`` with status 400/500.
    """
    max_rows = 500  # same ceiling as the "Max compounds" input in the UI

    # silent=True: a missing or malformed JSON body yields None, not an
    # exception, so the client always receives a JSON error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400
    raw_seq = data.get("seq", "")
    seq, err = validate_sequence(raw_seq if isinstance(raw_seq, str) else "")
    if err:
        return jsonify({"error": err}), 400

    # Clamp the limit: a negative value would make head() drop rows from
    # the end instead of limiting, and a huge one is unbounded work.
    try:
        limit = int(data.get("max_cpds", 100))
    except (TypeError, ValueError):
        limit = 100
    limit = max(1, min(limit, max_rows))

    csv_text = data.get("csv", "")
    try:
        df = pd.read_csv(StringIO(csv_text if isinstance(csv_text, str) else ""))
    except (pd.errors.EmptyDataError, pd.errors.ParserError):
        # An unreadable upload is a client error, not a server failure.
        return jsonify({"error": "Could not parse the uploaded CSV."}), 400
    df = df.head(limit)
    if "smiles" not in df.columns:
        return jsonify({"error": "CSV must have a 'smiles' column."}), 400

    # Default names use the row position in the (truncated) CSV; they also
    # replace blank name cells, which would otherwise serialise as NaN.
    default_names = [f"cpd_{i}" for i in range(len(df))]
    if "name" in df.columns:
        names = [d if pd.isna(n) else n
                 for n, d in zip(df["name"].tolist(), default_names)]
    else:
        names = default_names

    # Blank SMILES cells are read as NaN; drop them up front and count
    # them as invalid instead of handing floats to the featurizer.
    present = df["smiles"].notna().tolist()
    smiles_list = [str(s) for s, ok in zip(df["smiles"].tolist(), present) if ok]
    names_list = [n for n, ok in zip(names, present) if ok]
    n_blank = len(present) - len(smiles_list)
    if not smiles_list:
        return jsonify({"error": "No SMILES found in the CSV."}), 400

    try:
        fm, meta, sc, ls = get_models()
        X, valid, esm_vec = extract_features(seq, smiles_list, ls)
        if not valid.any():
            return jsonify({"error": "RDKit could not parse any SMILES in the CSV."}), 400
        preds, lo, hi = run_predict(X, fm, meta, sc)
        ad = ad_label(esm_vec)
        # lo/hi are mean -/+ 1.96*std, so this recovers the ensemble std.
        std = (hi - lo) / (2 * 1.96)
        valid_names = [n for n, ok in zip(names_list, valid) if ok]
        rows = [
            {"name": n, "pKd": round(float(p), 3), "lo": round(float(l), 3),
             "hi": round(float(h), 3), "ki": format_ki(float(p)),
             "std": round(float(s), 3), "ad": ad}
            for n, p, l, h, s in zip(valid_names, preds, lo, hi, std)
        ]
        rows.sort(key=lambda r: r["pKd"], reverse=True)
        return jsonify({"rows": rows,
                        "n_invalid": int((~valid).sum()) + n_blank})
    except Exception as e:
        # Route boundary: report the failure as JSON rather than an HTML 500.
        return jsonify({"error": str(e)}), 500
766
+
767
@app.route("/predict_selectivity", methods=["POST"])
def predict_selectivity():
    """Profile one compound against several protein targets.

    Expects a JSON object with ``smi`` (SMILES) and ``seqs`` (one target
    per line, either ``Name: SEQUENCE`` or a bare sequence). Responds with
    ``rows`` sorted by descending pKd and ``skipped`` (names of targets
    whose sequence was rejected or whose prediction failed), or
    ``{"error": ...}`` with status 400/500.
    """
    # silent=True: a missing or malformed JSON body yields None, not an
    # exception, so the client always receives a JSON error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400
    smi = data.get("smi", "")
    if not isinstance(smi, str) or not smi.strip():
        return jsonify({"error": "Missing SMILES"}), 400
    smi = smi.strip()
    raw_targets = data.get("seqs", "")
    if not isinstance(raw_targets, str):
        raw_targets = ""

    targets, skipped = {}, []
    for i, line in enumerate(raw_targets.strip().splitlines()):
        line = line.strip()
        if not line:
            continue
        if ":" in line:
            name, raw = line.split(":", 1)
        else:
            name, raw = f"Target_{i+1}", line
        # A line such as ": SEQ" has an empty name; use the positional one.
        name = name.strip() or f"Target_{i+1}"
        seq, err = validate_sequence(raw)
        if err:
            skipped.append(name)
        else:
            targets[name] = seq
    if not targets:
        return jsonify({"error": "No valid sequences found."}), 400

    try:
        fm, meta, sc, ls = get_models()
        rows = []
        for name, seq in targets.items():
            try:
                X, valid, esm_vec = extract_features(seq, [smi], ls)
                if not valid.any():
                    # The ligand is the same for every target, so an
                    # unparseable SMILES can never succeed; say so
                    # instead of returning an empty table.
                    return jsonify({"error": "RDKit could not parse this SMILES."}), 400
                preds, lo, hi = run_predict(X, fm, meta, sc)
                pkd = float(preds[0])
                std = (float(hi[0]) - float(lo[0])) / (2 * 1.96)
                rows.append({"target": name, "pKd": round(pkd, 3),
                             "lo": round(float(lo[0]), 3),
                             "hi": round(float(hi[0]), 3),
                             "ki": format_ki(pkd), "std": round(std, 3),
                             "ad": ad_label(esm_vec)})
            except Exception:
                # Best effort: one failing target must not sink the rest.
                skipped.append(name)
        rows.sort(key=lambda r: r["pKd"], reverse=True)
        return jsonify({"rows": rows, "skipped": skipped})
    except Exception as e:
        # Route boundary: report the failure as JSON rather than an HTML 500.
        return jsonify({"error": str(e)}), 500
800
+
801
if __name__ == "__main__":
    # Resolve the port once so the banner shows the port actually bound,
    # including when the PORT environment variable overrides the default.
    port = int(os.environ.get("PORT", 7860))
    # Load everything before serving so the first request is not slowed
    # by model loading.
    print("Loading models...")
    get_models()
    print("Loading ESM-2...")
    get_esm()
    print(f"Ready \u2014 http://localhost:{port}")
    app.run(host="0.0.0.0", port=port, debug=False)