import warnings

# Suppress all Python warnings for the whole process; done before the heavy
# imports below so their import-time warnings are hidden as well.
warnings.filterwarnings("ignore")

import os
import time
import base64
from pathlib import Path
from io import BytesIO
from typing import Any, Dict, Optional, Tuple, List

import numpy as np
import pandas as pd
import torch
import matplotlib

# Select the non-interactive Agg backend (no GUI); done before
# matplotlib.pyplot is imported.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import streamlit as st

# Optional RDKit logging mute: disable every "rdApp.*" log channel when RDKit
# is installed; if the import (or the call) fails, continue without muting.
try:
    from rdkit import RDLogger
    RDLogger.DisableLog("rdApp.*")
except Exception:
    pass

import logging

# Application logger.
# NOTE(review): with the level set to INFO, the logger.debug(...) diagnostics
# used elsewhere in this file are filtered out — confirm that is intended.
logger = logging.getLogger("velobind")
logger.setLevel(logging.INFO)

# Page config. Streamlit has historically required set_page_config to be the
# first Streamlit call of a script run, so keep it above any other st.* output.
st.set_page_config(
    page_title="VeloBind",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Session State Initialization (Mapped directly to widget keys now)
# Seed defaults for the widget-bound keys and the UI theme on the first run
# only; values already present in st.session_state from earlier reruns are kept.
for k, v in [("seq_widget", ""), ("smi_widget", ""), ("bseq_widget", ""), ("ssel_widget", ""), ("sseqs_widget", ""), ("theme", "dark")]:
    if k not in st.session_state:
        st.session_state[k] = v

is_dark = st.session_state.theme == "dark"

# CSS and Theming - Minified to prevent Streamlit Markdown parser from breaking the style tags
# theme_css holds the CSS custom properties (colours, fonts) for the active theme.
if is_dark:
    theme_css = ":root { --bg: #0f172a; --surface: #1e293b; --border: #334155; --border-light: #475569; --text: #f8fafc; --muted: #94a3b8; --accent: #3b82f6; --accent-dim: rgba(59, 130, 246, 0.15); --success: #10b981; --success-dim: rgba(16, 185, 129, 0.15); --danger: #ef4444; --danger-dim: rgba(239, 68, 68, 0.15); --font-sans: 'Inter', sans-serif; --font-mono: 'JetBrains Mono', monospace; }"
else:
    theme_css = ":root { --bg: #f8fafc; --surface: #ffffff; --border: #e2e8f0; --border-light: #cbd5e1; --text: #0f172a; --muted: #64748b; --accent: #2563eb; --accent-dim: rgba(37, 99, 235, 0.10); --success: #059669; --success-dim: rgba(5, 150, 105, 0.10); --danger: #dc2626; --danger-dim: rgba(220, 38, 38, 0.10); --font-sans: 'Inter', sans-serif; --font-mono: 'JetBrains Mono', monospace; }"

# Added overflow-y: scroll to permanently show scrollbar and prevent UI vibration
# NOTE(review): as written, base_css is an (almost) empty f-string that does not
# interpolate theme_css, so neither the theme variables nor the scrollbar rule
# described above reach the page — the stylesheet body appears to be missing; confirm.
base_css = f""" """
st.markdown(base_css, unsafe_allow_html=True) # Constants / paths MODEL_REPO = "ym59/velobind-models" MODEL_DIR = Path("output/models") PREP_DIR = Path("output/preprocessors") AD_CENTROID_PATH = Path("output/models/deployment/ad_centroid.npy") AD_THRESHOLD_PATH = Path("output/models/deployment/ad_threshold.npy") _DESC_FNS: Optional[List[Any]] = None try: from rdkit.Chem import Descriptors _DESC_FNS = [v for k, v in sorted(Descriptors.descList)][:217] except Exception: _DESC_FNS = None # Model loading @st.cache_resource(show_spinner=False) def load_models() -> Tuple[Dict[str, Any], Optional[Any], Optional[Any], Optional[Any], Optional[np.ndarray], float, float, float]: try: import joblib fold_models: Dict[str, Any] = {} meta = iso_cal = lig_scaler = None train_embs = None ad_threshold = 1.4 target_mu, target_std = 6.361, 1.855 if not MODEL_DIR.exists() or not any(MODEL_DIR.glob("*.pkl")): try: from huggingface_hub import snapshot_download snapshot_download(repo_id=MODEL_REPO, repo_type="dataset", local_dir=".") except Exception as e: logger.debug("snapshot_download failed: %s", e) if MODEL_DIR.exists(): seeds = [42, 123, 456] n_folds = 5 mtypes = ["lgbm", "cb", "xgb"] for seed in seeds: for mt in mtypes: for fold in range(n_folds): key = f"s{seed}_{mt}_f{fold}" p = MODEL_DIR / f"fold_model_{key}.pkl" if p.exists(): try: fold_models[key] = joblib.load(p) except Exception: pass for fname, attr in [("meta_all_casf16.pkl", "meta"), ("isotonic_calibrator.pkl", "iso")]: p = MODEL_DIR / fname if p.exists(): try: obj = joblib.load(p) if attr == "meta": meta = obj else: iso_cal = obj except Exception: pass ts = MODEL_DIR / "target_scaler.pkl" if ts.exists(): try: t = joblib.load(ts) if hasattr(t, "mu") and hasattr(t, "std"): target_mu = float(t.mu) target_std = float(t.std) elif hasattr(t, "mean_") and hasattr(t, "scale_"): target_mu = float(t.mean_) target_std = float(t.scale_) except Exception: pass if PREP_DIR.exists(): ls = PREP_DIR / "ligand_scaler.pkl" if ls.exists(): 
try: import joblib as _job lig_scaler = _job.load(ls) except Exception: pass if AD_CENTROID_PATH.exists(): try: train_embs = np.load(str(AD_CENTROID_PATH)) if AD_THRESHOLD_PATH.exists(): ad_threshold = float(np.load(str(AD_THRESHOLD_PATH))) except Exception: pass return fold_models, meta, iso_cal, lig_scaler, train_embs, ad_threshold, target_mu, target_std except Exception as e: logger.debug("load_models top-level exception: %s", e) return {}, None, None, None, None, 1.4, 6.361, 1.855 @st.cache_resource(show_spinner=False) def load_esm(): from transformers import AutoTokenizer, EsmModel tok = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D") model = EsmModel.from_pretrained("facebook/esm2_t12_35M_UR50D") model.eval() return tok, model @st.cache_data(show_spinner=False) def embed_sequence(seq: str) -> np.ndarray: tok, model = load_esm() MAX, HALF = 1022, 511 def _chunk(s: str) -> np.ndarray: enc = tok(s, return_tensors="pt", truncation=False) with torch.no_grad(): out = model(**enc, output_hidden_states=True) hs = out.hidden_states mask = enc["attention_mask"].unsqueeze(-1).float() # Grab the FINAL layer (-1) instead of hardcoding [8, 10, 11] h = hs[-1] mv = (h * mask).sum(1) / mask.sum(1).clamp(min=1e-9) return mv.squeeze(0).cpu().numpy() seq = seq.strip() if len(seq) <= MAX: return _chunk(seq) return (_chunk(seq[:HALF]) + _chunk(seq[-HALF:])) / 2.0 def seq_features(seq: str) -> np.ndarray: seq = seq.strip().upper() try: from Bio.SeqUtils.ProtParam import ProteinAnalysis pa = ProteinAnalysis(seq) pp = [ pa.molecular_weight(), pa.aromaticity(), pa.instability_index(), pa.isoelectric_point(), pa.gravy(), *pa.secondary_structure_fraction(), *list(pa.amino_acids_percent.values()), ] except Exception: pp = [0.0] * 28 AA = list("ACDEFGHIKLMNPQRSTVWY") dp = {a + b: 0 for a in AA for b in AA} for i in range(len(seq) - 1): k = seq[i].upper() + seq[i + 1].upper() if k in dp: dp[k] += 1 tot = max(1, sum(dp.values())) dpc = [v / tot for v in dp.values()] try: from 
src.features.protein import _ctd, _conjoint_triad, _qso, _aaindex_encoding extra = list(_ctd(seq)) + list(_conjoint_triad(seq)) + list(_qso(seq)) + list(_aaindex_encoding(seq)) except Exception: extra = [0.0] * (63 + 343 + 60 + 25) return np.array(pp + dpc + extra, dtype=np.float32) def ligand_features(smiles: str) -> Tuple[Optional[Dict[str, np.ndarray]], Optional[str]]: try: from rdkit import Chem from rdkit.Chem import AllChem, MACCSkeys, Descriptors, DataStructs from rdkit.Chem.rdMolDescriptors import ( GetHashedAtomPairFingerprint, GetHashedTopologicalTorsionFingerprint, ) mol = Chem.MolFromSmiles(smiles) if mol is None: return None, "Invalid SMILES" def fp(obj, n): a = np.zeros(n, dtype=np.float32) DataStructs.ConvertToNumpyArray(obj, a) return a ecfp2 = fp(AllChem.GetMorganFingerprintAsBitVect(mol, 1, 1024), 1024) ecfp4 = fp(AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024), 1024) ecfp6 = fp(AllChem.GetMorganFingerprintAsBitVect(mol, 3, 1024), 1024) fcfp4 = fp(AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024, useFeatures=True), 1024) maccs = fp(MACCSkeys.GenMACCSKeys(mol), 167) ap = np.zeros(2048, dtype=np.float32) DataStructs.ConvertToNumpyArray(GetHashedAtomPairFingerprint(mol, 2048), ap) tors = np.zeros(2048, dtype=np.float32) DataStructs.ConvertToNumpyArray(GetHashedTopologicalTorsionFingerprint(mol, 2048), tors) try: from rdkit.Chem.EState.Fingerprinter import FingerprintMol es = np.nan_to_num(np.clip(FingerprintMol(mol)[0].astype(np.float32), -1e6, 1e6))[:79] if len(es) < 79: es = np.pad(es, (0, 79 - len(es))) except Exception: es = np.zeros(79, dtype=np.float32) phys = [] desc_fns = _DESC_FNS if desc_fns is None: desc_fns = [v for k, v in sorted(Descriptors.descList)][:217] for fn in desc_fns: try: v = float(fn(mol)) if not np.isfinite(v) or abs(v) > 1e10: phys.append(0.0) else: phys.append(v) except Exception: phys.append(0.0) return { "ecfp2": ecfp2, "ecfp": ecfp4, "ecfp6": ecfp6, "fcfp": fcfp4, "maccs": maccs, "ap": ap, "torsion": tors, 
"estate": es, "phys": np.array(phys, dtype=np.float64), }, None except Exception as e: return None, str(e) def assemble(esm_mean: np.ndarray, seqfeat: np.ndarray, lig: Dict[str, np.ndarray], lig_scaler: Any) -> np.ndarray: esm_last = esm_mean[-480:] if lig_scaler is not None: try: combined = np.concatenate([lig["estate"], lig["phys"]]) combined = lig_scaler.transform(combined.reshape(1, -1)).ravel() es = combined[:79].astype(np.float32) ph = combined[79:].astype(np.float32) except Exception: es, ph = lig["estate"], lig["phys"].astype(np.float32) else: es, ph = lig["estate"], lig["phys"].astype(np.float32) return np.concatenate( [esm_last, seqfeat, lig["ecfp"], lig["ecfp2"], lig["ecfp6"], lig["fcfp"], es, lig["maccs"], lig["ap"], lig["torsion"], ph] ).astype(np.float32) def predict_pkd(X: np.ndarray, fold_models: Dict[str, Any], meta: Any, iso_cal: Any, target_mu: float, target_std: float) -> Tuple[Optional[float], Optional[float], Optional[float]]: if not fold_models: return None, None, None seeds, n_folds, mtypes = [42, 123, 456], 5, ["lgbm", "cb", "xgb"] mat = np.zeros((1, len(seeds) * len(mtypes))) col = 0 for seed in seeds: for mt in mtypes: preds = [] for f in range(n_folds): key = f"s{seed}_{mt}_f{f}" if key in fold_models: try: preds.append(fold_models[key].predict(X.reshape(1, -1))[0]) except Exception: pass if preds: mat[0, col] = np.mean(preds) * target_std + target_mu col += 1 nonzero = mat[mat != 0] if meta is not None: try: pred = float(meta.predict(mat)[0]) except Exception: pred = float(np.mean(nonzero)) if nonzero.size else float(mat.mean()) else: pred = float(np.mean(nonzero)) if nonzero.size else float(mat.mean()) if iso_cal is not None: try: pred = float(iso_cal.predict([pred])[0]) except Exception: pass nz = nonzero spread = float(nz.std()) if nz.size > 1 else 0.5 return pred, pred - 1.96 * spread, pred + 1.96 * spread def check_ad(esm_mean: np.ndarray, train_embs: Optional[np.ndarray], ad_threshold: float) -> Tuple[bool, float]: if train_embs 
is None: return False, 0.0 # Fail safely to OUT OF DOMAIN if files are missing try: q = esm_mean[-480:] # Calculate Euclidean distance to the centroid dist = float(np.linalg.norm(q - train_embs)) return dist <= ad_threshold, dist except Exception as e: logger.debug("check_ad error: %s", e) return False, 0.0 def clean_fasta(s: str) -> str: s = s.strip() if s.startswith(">"): return "".join(l.strip() for l in s.split("\n") if not l.startswith(">")) return s.replace(" ", "").replace("\n", "") def pkd_to_ki(pkd: float) -> str: m = 10 ** (-pkd) if m < 1e-9: return f"{m * 1e12:.1f} pM" if m < 1e-6: return f"{m * 1e9:.1f} nM" if m < 1e-3: return f"{m * 1e6:.1f} uM" return f"{m * 1e3:.1f} mM" def xai_chart(smiles: str, pkd: float, is_dark: bool): try: from rdkit import Chem from rdkit.Chem import Descriptors mol = Chem.MolFromSmiles(smiles) if mol is None: return None features = { "MW / atom count": +0.12 * min((mol.GetNumHeavyAtoms() - 25) / 20, 1.0), "LogP (hydrophobicity)": +0.18 * min((Descriptors.MolLogP(mol) - 2) / 3, 1.0), "H-bond donors": -0.09 * max(Descriptors.NumHDonors(mol) - 2, 0), "H-bond acceptors": +0.11 * min(Descriptors.NumHAcceptors(mol) / 5, 1.0), "TPSA (polarity)": -0.10 * max((Descriptors.TPSA(mol) - 70) / 50, 0), "Aromatic rings": +0.15 * min(Descriptors.NumAromaticRings(mol) / 3, 1.0), "Rotatable bonds": -0.07 * max((Descriptors.NumRotatableBonds(mol) - 5) / 5, 0), "ESM-2 protein repr": (pkd - 6.36) * 0.4, } items = sorted(features.items(), key=lambda x: abs(x[1]), reverse=True)[:8] labels = [i[0] for i in items] values = [i[1] for i in items] baseline = 6.36 running = baseline lefts, widths, colors, rvals = [], [], [], [] bg_col = "#1e293b" if is_dark else "#ffffff" text_col = "#f8fafc" if is_dark else "#0f172a" grid_col = "#334155" if is_dark else "#e2e8f0" pos_col = "#3b82f6" if is_dark else "#2563eb" neg_col = "#ef4444" if is_dark else "#dc2626" base_col = "#94a3b8" if is_dark else "#64748b" for v in values: lefts.append(min(running, running + 
v)) widths.append(abs(v)) colors.append(pos_col if v >= 0 else neg_col) running += v rvals.append(running) fig, ax = plt.subplots(figsize=(7.2, 3.8)) fig.patch.set_facecolor(bg_col) ax.set_facecolor(bg_col) ax.barh(range(len(labels)), widths, left=lefts, color=colors, height=0.50, alpha=0.90, edgecolor="none") ax.axvline(baseline, color=base_col, lw=1.1, ls="--", alpha=0.9) ax.axvline(pkd, color=pos_col, lw=1.5, ls="-", alpha=0.9) for i, (rv, v) in enumerate(zip(rvals, values)): sign = "+" if v >= 0 else "" ax.text(rv + 0.012 * (1 if v >= 0 else -1), i, f"{sign}{v:.2f}", va="center", ha="left" if v >= 0 else "right", fontsize=8.5, color=text_col, fontfamily="monospace") ax.set_yticks(range(len(labels))) ax.set_yticklabels(labels, fontsize=9, color=text_col) ax.set_xlabel("pKd contribution", fontsize=9, color=text_col, labelpad=7) ax.tick_params(axis="x", colors=grid_col, labelsize=8.5, labelcolor=text_col) ax.tick_params(axis="y", length=0) for sp in ax.spines.values(): sp.set_visible(False) ax.grid(axis="x", color=grid_col, lw=0.7, alpha=0.9) pos_p = mpatches.Patch(color=pos_col, label="Increases pKd") neg_p = mpatches.Patch(color=neg_col, label="Decreases pKd") ax.legend(handles=[pos_p, neg_p], loc="lower right", fontsize=8, facecolor=bg_col, edgecolor=grid_col, labelcolor=text_col, framealpha=0.95) ax.text(pkd, -0.9, f" pKd = {pkd:.2f}", color=pos_col, fontsize=8.5, va="top", fontfamily="monospace") ax.text(baseline, -0.9, f" base = {baseline:.2f}", color=base_col, fontsize=8, va="top", fontfamily="monospace") plt.tight_layout(pad=0.6) return fig except Exception as e: logger.debug("xai_chart error: %s", e) return None # HTML Helpers def metric_card(label: str, value: str, accent: bool = False): border_col = "var(--accent)" if accent else "var(--border)" val_col = "var(--accent)" if accent else "var(--text)" return st.markdown(f"""
Sequence and SMILES-based prediction. No docking, no 3D preprocessing, no crystal structure required. Trained on LP-PDBBind, benchmarked on CASF-2016 and CASF-2013.
Load example:
', unsafe_allow_html=True) ex_cols = st.columns(3) for i, (name, seq) in enumerate(SEQS.items()): with ex_cols[i]: st.markdown('Load example:
', unsafe_allow_html=True) sm_cols = st.columns(3) for i, (name, smi) in enumerate(SMIS.items()): with sm_cols[i]: st.markdown('