Spaces:
Sleeping
Sleeping
| """ | |
BALM-PPI Pro · ESM-2 + LoRA + Integrated Gradients
| ===================================================== | |
| Gradio port with HuggingFace ZeroGPU support. | |
| Key porting notes: | |
| - All GPU work (forward pass + IG) is wrapped in @spaces.GPU so it runs on | |
| ZeroGPU's dynamic A100s. Outside the decorator, the model lives on CPU. | |
| - IG kept at float32 with 15 Riemann steps (same fix as the Streamlit version). | |
| - NGL viewer + Plotly visualisations + theming preserved. | |
| - Per-session state (PDB, chain info, IG arrays) is held in gr.State so the | |
| app is safe under concurrent users on Spaces. | |
| """ | |
| import os | |
| os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| import io, re, json, csv, traceback, tempfile, copy | |
| import io as sio | |
| import gradio as gr | |
| import torch | |
| import pandas as pd | |
| import numpy as np | |
| import requests | |
| import plotly.graph_objects as go | |
| from Bio import PDB | |
| from Bio.Data.PDBData import protein_letters_3to1 as THREE_TO_ONE | |
# ─── ZeroGPU compatibility shim ────────────────────────────────────────────
# On HF Spaces with ZeroGPU, `spaces` is installed and @spaces.GPU works.
# Locally (or off-Spaces) we fall back to a no-op decorator so the same file
# runs in both places without modification.
# Use the real ZeroGPU helper when it is installed; otherwise install a
# stand-in exposing the same ``spaces.GPU`` decorator so the module imports.
try:
    import spaces
    _HAS_SPACES = True
except ImportError:
    _HAS_SPACES = False

    class _SpacesShim:
        """Minimal stand-in for the ``spaces`` package when it is not installed."""

        @staticmethod
        def GPU(fn=None, duration=60, **_kwargs):
            """No-op replacement for ``spaces.GPU``.

            Works both bare (``@spaces.GPU``) and parameterised
            (``@spaces.GPU(duration=120)``). It must be a staticmethod:
            the shim is used through an instance, and a plain function
            would receive that instance as its first argument.

            Args:
                fn: The function being decorated when used bare.
                duration: Accepted for API compatibility; ignored.
                **_kwargs: Any other options; ignored.

            Returns:
                ``fn`` unchanged when used bare, otherwise a decorator that
                returns its argument unchanged.
            """
            def decorator(func):
                return func

            if callable(fn):
                return decorator(fn)
            return decorator

    spaces = _SpacesShim()  # type: ignore
# Hugging Face Hub repository and file name passed to hf_hub_download().
HF_REPO_ID = "Harshit494/BALM-PPI"
HF_FILENAME = "best_model_fold_1.pth"
# Three-letter -> one-letter residue codes with upper-cased keys, so lookups
# can be done on residue.get_resname().upper().
UPPER_3TO1 = {k.upper(): v for k, v in THREE_TO_ONE.items()}
# Example complexes: PDB id, display text, mode and the two chain selections.
# NOTE(review): the "subtitle"/"desc" strings look mis-encoded (e.g. "β" where
# a dash or superscript was probably intended) — confirm the file encoding.
EXAMPLES = [
    {
        "pdb": "1YCR", "label": "1YCR", "subtitle": "MDM2 β p53",
        "desc": "MDM2 oncoprotein bound to the p53 transactivation peptide. Chain A = MDM2, Chain B = p53. Classic drug-target complex.",
        "mode": "Protein-Protein", "chain_a": "A", "chain_b": "B",
    },
    {
        "pdb": "1BRS", "label": "1BRS", "subtitle": "Barnase β Barstar",
        "desc": "Ultra-tight complex (Kd ~10β»ΒΉβ΄ M). Chain A = Barnase, Chain D = Barstar. Classic benchmark.",
        "mode": "Protein-Protein", "chain_a": "A", "chain_b": "D",
    },
]
# CSV templates; the header names match the row keys read by gpu_run_batch().
PPI_TEMPLATE = "seq_a,seq_b\nACDEFGHIKLMNPQRSTVWY,QWERTYIPASDFGKLCVBNM\n"
ABAG_TEMPLATE = "heavy_chain,light_chain,antigen\nEVQLVESGGG...,DIQMTQ...,IYSPT...\n"
# ───────────────────────────────────────────────────────────────────────────
# GLOBAL MODEL CACHE
# Loaded lazily on CPU; moved to CUDA only inside @spaces.GPU functions.
# ───────────────────────────────────────────────────────────────────────────
# Module-level cache read and written by ensure_model().
_MODEL_STATE = {
    "model": None,       # the built model, or None until the first load
    "pkd_bounds": None,  # tuple (lo, hi) — rebuilt if bounds change
}
# ───────────────────────────────────────────────────────────────────────────
# PDB UTILITIES
# ───────────────────────────────────────────────────────────────────────────
_PDB_CACHE: dict = {}
# Maximum number of PDB texts kept in memory.
_PDB_CACHE_MAX = 64
# Classic 4-character ids plus the longer "PDB_XXXXXXXX" form; nothing else
# may be interpolated into the download URL.
_PDB_ID_RE = re.compile(r"[0-9A-Z_]{4,12}")


def _fetch_pdb_cached(pdb_id: str) -> str:
    """Return the PDB file text for ``pdb_id``, downloading it at most once.

    Args:
        pdb_id: PDB identifier; surrounding whitespace and case are ignored.

    Returns:
        The raw text of the PDB file.

    Raises:
        ValueError: If ``pdb_id`` is not a plausible PDB identifier.
        requests.RequestException: If the download fails or returns an
            HTTP error status.
    """
    pdb_id = pdb_id.strip().upper()
    if pdb_id in _PDB_CACHE:
        return _PDB_CACHE[pdb_id]
    # The id comes from user input and is placed in a URL path: reject
    # anything that could change the requested path.
    if not _PDB_ID_RE.fullmatch(pdb_id):
        raise ValueError(f"Invalid PDB id: {pdb_id!r}")
    url = f"https://files.rcsb.org/download/{pdb_id}.pdb"
    r = requests.get(url, timeout=30)
    r.raise_for_status()
    _PDB_CACHE[pdb_id] = r.text
    # Keep the cache from growing unbounded: drop the oldest inserted entry.
    if len(_PDB_CACHE) > _PDB_CACHE_MAX:
        _PDB_CACHE.pop(next(iter(_PDB_CACHE)))
    return r.text
def filter_pdb_to_chains(pdb_text: str, keep_chains: set) -> str:
    """Drop ATOM/HETATM records whose chain ID is not in ``keep_chains``.

    All other record types are passed through untouched.

    Args:
        pdb_text: Full PDB file text.
        keep_chains: Chain identifiers (single characters) to retain.

    Returns:
        The remaining lines joined with newlines.
    """
    coordinate_records = ("ATOM", "HETATM")

    def _is_kept(record: str) -> bool:
        if record[:6].strip() not in coordinate_records:
            return True
        # The chain identifier is the character at index 21 of the record.
        return record[21:22] in keep_chains

    return "\n".join(filter(_is_kept, pdb_text.splitlines()))
def get_chains_from_pdb(pdb_id: str):
    """Fetch a PDB entry and extract per-chain sequences from its first model.

    Args:
        pdb_id: PDB identifier to fetch.

    Returns:
        A tuple ``(pdb_text, chains, error)``. On success ``error`` is None
        and ``chains`` maps chain ID to ``{"seq": str, "resnos": list[int]}``
        covering standard amino-acid residues only; chains without any such
        residue are omitted. On a fetch or parse failure the tuple is
        ``(None, {}, message)``.
    """
    import warnings as _w

    try:
        pdb_text = _fetch_pdb_cached(pdb_id)
    except Exception as e:
        return None, {}, f"Could not fetch PDB {pdb_id}: {e}"

    # Report parse failures the same way as fetch failures so callers only
    # have to handle the error-tuple form.
    try:
        with _w.catch_warnings():
            _w.simplefilter("ignore")
            parser = PDB.PDBParser(QUIET=True)
            struct = parser.get_structure(pdb_id, io.StringIO(pdb_text))
    except Exception as e:
        return None, {}, f"Could not parse PDB {pdb_id}: {e}"

    chains: dict = {}
    # Only the first model is read.
    first_model = next(iter(struct), None)
    if first_model is not None:
        for chain in first_model:
            letters, resnos = [], []
            for residue in chain:
                if PDB.is_aa(residue, standard=True):
                    letters.append(
                        UPPER_3TO1.get(residue.get_resname().upper(), "X"))
                    resnos.append(residue.get_id()[1])
            if letters:
                chains[chain.id] = {"seq": "".join(letters), "resnos": resnos}
    return pdb_text, chains, None
def resolve_chains(chains_dict, ids_str, available):
    """Resolve a comma-separated chain selection against parsed chains.

    Each requested ID is matched exactly first, then case-insensitively.
    Chains with an empty sequence never match.

    Args:
        chains_dict: Mapping of chain ID to a dict with at least ``"seq"``.
        ids_str: Comma-separated chain IDs; falsy means "nothing selected".
        available: Unused by this function.

    Returns:
        ``(joined, infos, missing)``: the matched sequences joined with
        ``"|"``, one ``{"id": ..., **chain}`` dict per match, and the
        requested IDs that could not be resolved.
    """
    if not ids_str:
        return "", [], []

    def _match(requested):
        if requested in chains_dict and chains_dict[requested]["seq"]:
            return requested
        wanted = requested.upper()
        for key in chains_dict:
            if key.upper() == wanted and chains_dict[key]["seq"]:
                return key
        return None

    sequences, details, unresolved = [], [], []
    for token in str(ids_str).split(","):
        requested = token.strip()
        if not requested:
            continue
        key = _match(requested)
        if key:
            sequences.append(chains_dict[key]["seq"])
            details.append({"id": key, **chains_dict[key]})
        else:
            unresolved.append(requested)
    return "|".join(sequences), details, unresolved
def clean_multi(s: str) -> str:
    """Normalise a possibly multi-chain sequence string.

    Chains are separated by ``"|"``. FASTA header text is removed, every
    non-letter character is dropped, and the result is upper-cased. Empty
    chains are discarded.

    Args:
        s: Raw user input; may be empty or None.

    Returns:
        The cleaned chains joined with ``"|"`` (``""`` if nothing remains).
    """
    if not s:
        return ""
    # Remove whole header lines BEFORE splitting on "|": headers such as
    # ">sp|P12345|NAME" contain pipes and would otherwise leak fragments
    # of the header into the sequence.
    body = re.sub(r"^[ \t]*>.*$", "", s, flags=re.MULTILINE)
    out = []
    for p in body.split("|"):
        # Headers that follow a chain separator on the same line.
        p = re.sub(r">.*", "", p, flags=re.MULTILINE)
        p = re.sub(r"[^A-Za-z]", "", p).upper()
        if p:
            out.append(p)
    return "|".join(out)
# ───────────────────────────────────────────────────────────────────────────
# IG HELPERS
# ───────────────────────────────────────────────────────────────────────────
def split_ig_by_chains(ig_flat, chain_seqs, n_sep=2):
    """Cut a flat score list into one list per chain.

    Args:
        ig_flat: Flat sequence of scores; falsy yields empty lists.
        chain_seqs: Chain sequences, in the order they appear in ``ig_flat``.
        n_sep: Number of positions to skip between consecutive chains.

    Returns:
        A list with one score list per entry of ``chain_seqs``.
    """
    if not ig_flat:
        return [[] for _ in chain_seqs]

    pieces = []
    cursor = 0
    last_index = len(chain_seqs) - 1
    for index, chain in enumerate(chain_seqs):
        stop = cursor + len(chain)
        pieces.append(list(ig_flat[cursor:stop]))
        # Skip the separator positions after every chain but the last.
        cursor = stop + n_sep if index < last_index else stop
    return pieces
def build_ig_chain_map(chain_info_a, chain_info_b, ig_a, ig_b):
    """Build the per-chain score map passed to the structure viewer.

    Args:
        chain_info_a: Chain dicts (``id``, ``seq``, ``resnos``) for side A.
        chain_info_b: Chain dicts for side B.
        ig_a: Flat scores for side A.
        ig_b: Flat scores for side B.

    Returns:
        Mapping of chain ID to ``{"start": first residue number,
        "scores": per-residue scores}``. Chains without residue numbers
        are left out.
    """
    chain_map = {}
    for chain_infos, flat_scores in ((chain_info_a, ig_a), (chain_info_b, ig_b)):
        if not chain_infos:
            continue
        sequences = [entry["seq"] for entry in chain_infos]
        split_scores = split_ig_by_chains(flat_scores, sequences)
        for entry, scores in zip(chain_infos, split_scores):
            if entry.get("resnos"):
                chain_map[entry["id"]] = {
                    "start": entry["resnos"][0],
                    "scores": scores,
                }
    return chain_map
def flat_ig_for_display(ig_a, ig_b, chain_info_a, chain_info_b):
    """Return concatenated sequences and separator-free scores for both sides.

    Args:
        ig_a: Flat scores for side A (may be falsy).
        ig_b: Flat scores for side B (may be falsy).
        chain_info_a: Chain dicts for side A (may be falsy).
        chain_info_b: Chain dicts for side B (may be falsy).

    Returns:
        ``(seq_a, ig_a_out, seq_b, ig_b_out)``.
    """
    def _one_side(scores, infos):
        sequence = "".join(c["seq"] for c in infos) if infos else ""
        if scores and infos:
            per_chain = split_ig_by_chains(scores, [c["seq"] for c in infos])
            flat = [value for chunk in per_chain for value in chunk]
        elif scores:
            flat = list(scores)
        else:
            flat = []
        return sequence, flat

    seq_a, ig_a_out = _one_side(ig_a, chain_info_a)
    seq_b, ig_b_out = _one_side(ig_b, chain_info_b)
    return seq_a, ig_a_out, seq_b, ig_b_out
# ───────────────────────────────────────────────────────────────────────────
# MODEL BUILD (CPU-safe — no CUDA touched here)
# ───────────────────────────────────────────────────────────────────────────
def _load_esm_base():
    """Load the pretrained ESM-2 encoder and its tokenizer.

    Returns:
        ``(model, tokenizer)`` for the ``facebook/esm2_t33_650M_UR50D``
        checkpoint.
    """
    from transformers import EsmModel, EsmTokenizer

    encoder = EsmModel.from_pretrained(
        "facebook/esm2_t33_650M_UR50D",
        use_safetensors=True,
    )
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    return encoder, tokenizer
def _download_weights_hf() -> bytes:
    """Download the checkpoint from the Hugging Face Hub and return its bytes."""
    from huggingface_hub import hf_hub_download

    local_path = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=HF_FILENAME,
        resume_download=True,
    )
    with open(local_path, "rb") as handle:
        payload = handle.read()
    return payload
def build_model(weights_bytes, pkd_bounds):
    """Build the LoRA-wrapped ESM model with its projection head and load weights.

    Args:
        weights_bytes: Raw bytes of the saved state dict.
        pkd_bounds: ``(lower, upper)`` pair used to rescale the cosine
            similarity into a pKd value.

    Returns:
        The model in eval mode, with tensors mapped to CPU.
    """
    import torch.nn as nn
    import torch.nn.functional as F
    from peft import LoraConfig, get_peft_model

    class BALMProjectionHead(nn.Module):
        """Projects both pooled embeddings and compares them by cosine similarity."""

        def __init__(self, embedding_size, projected_size=256):
            super().__init__()
            self.protein_projection = nn.Linear(embedding_size, projected_size)
            self.proteina_projection = nn.Linear(embedding_size, projected_size)
            self.dropout = nn.Dropout(0.1)

        def forward(self, prot_emb, prot_a_emb):
            p = F.normalize(self.protein_projection(self.dropout(prot_emb)), p=2, dim=1)
            pa = F.normalize(self.proteina_projection(self.dropout(prot_a_emb)), p=2, dim=1)
            # Clamp just inside (-1, 1).
            return torch.clamp(F.cosine_similarity(p, pa), -0.9999, 0.9999)

    class BALMForLoRAFinetuning(nn.Module):
        """ESM encoder + projection head producing ``(pkd, cosine)``."""

        def __init__(self, esm_model, esm_tokenizer, pkd_bounds):
            super().__init__()
            self.esm_model = esm_model
            self.esm_tokenizer = esm_tokenizer
            self.projection_head = BALMProjectionHead(self.esm_model.config.hidden_size)
            self.pkd_lower, self.pkd_upper = pkd_bounds
            self.cls_token = self.esm_tokenizer.cls_token

        def _get_esm_embeddings(self, sequences):
            # Chain separators ("|") are encoded as two CLS tokens.
            processed = [s.replace("|", f"{self.cls_token}{self.cls_token}") for s in sequences]
            inputs = self.esm_tokenizer(processed, return_tensors="pt", padding=True,
                                        truncation=True, max_length=1024)
            inputs = {k: v.to(self.esm_model.device) for k, v in inputs.items()}
            h = self.esm_model(**inputs).last_hidden_state
            # Attention-mask-weighted mean over the token dimension.
            mask = inputs["attention_mask"].unsqueeze(-1).expand(h.size()).float()
            return (torch.sum(h * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)).float()

        def forward(self, seq_a, seq_b):
            ea = self._get_esm_embeddings([seq_a])
            eb = self._get_esm_embeddings([seq_b])
            cos = self.projection_head(ea, eb)
            # Affine map from cosine in [-1, 1] to [pkd_lower, pkd_upper].
            pkd = ((cos + 1) / 2) * (self.pkd_upper - self.pkd_lower) + self.pkd_lower
            return pkd, cos

    base_esm, tok = _load_esm_base()
    lora_cfg = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1,
                          target_modules=["key", "query", "value"])
    model = BALMForLoRAFinetuning(
        get_peft_model(base_esm, lora_cfg), tok, pkd_bounds)
    # Load straight from memory: no temporary file to create, and therefore
    # nothing left behind on disk if loading raises.
    # SECURITY NOTE(review): torch.load unpickles the payload; only load
    # checkpoints from a trusted repository, and consider weights_only=True
    # once confirmed that the file contains a plain state dict.
    state_dict = torch.load(io.BytesIO(weights_bytes), map_location="cpu")
    # strict=False: key mismatches between checkpoint and model are ignored.
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    return model
def ensure_model(pkd_lo: float, pkd_hi: float):
    """Return the cached model, building it on first use.

    When a model is already cached and the requested bounds differ from the
    cached ones, only the model's ``pkd_lower`` / ``pkd_upper`` attributes
    are updated; the weights are not reloaded.
    """
    requested = (float(pkd_lo), float(pkd_hi))
    if _MODEL_STATE["model"] is None:
        _MODEL_STATE["model"] = build_model(_download_weights_hf(), requested)
        _MODEL_STATE["pkd_bounds"] = requested
    elif _MODEL_STATE["pkd_bounds"] != requested:
        cached = _MODEL_STATE["model"]
        cached.pkd_lower, cached.pkd_upper = requested
        _MODEL_STATE["pkd_bounds"] = requested
    return _MODEL_STATE["model"]
# ───────────────────────────────────────────────────────────────────────────
# INTEGRATED GRADIENTS (float32-safe, 15 Riemann steps)
# ───────────────────────────────────────────────────────────────────────────
def compute_ig(model, seq_a: str, seq_b: str, steps: int = 15):
    """Compute Integrated-Gradients style attributions for both sequences.

    Gradients of the projection-head output are taken with respect to the
    word embeddings of one sequence while the other is held fixed, averaged
    over ``steps`` interpolation points between an all-zero baseline and the
    real embeddings.

    Args:
        model: Model exposing ``esm_tokenizer``, ``esm_model`` and
            ``projection_head``.
        seq_a: First sequence; ``"|"`` separates chains.
        seq_b: Second sequence; ``"|"`` separates chains.
        steps: Number of interpolation points.

    Returns:
        ``(scores_a, scores_b)``: two lists of per-position scores, each
        min-max normalised to [0, 1] (all zeros when the range is < 1e-9).
        The first and last token positions are dropped — presumably the
        tokenizer's start/end special tokens; confirm against the tokenizer.
    """
    device = next(model.parameters()).device
    tok = model.esm_tokenizer
    cls_tok = tok.cls_token
    esm = model.esm_model

    def tokenise(seq):
        # Chain separators become two CLS tokens.
        proc = seq.replace("|", f"{cls_tok}{cls_tok}")
        return tok(proc, return_tensors="pt",
                   padding=False, truncation=True, max_length=1024).to(device)

    word_embed = esm.base_model.model.embeddings.word_embeddings
    enc_a = tokenise(seq_a); mask_a = enc_a["attention_mask"]
    enc_b = tokenise(seq_b); mask_b = enc_b["attention_mask"]
    # Detached float32 copies of the embeddings: these are the IG inputs.
    emb_a = word_embed(enc_a["input_ids"]).detach().float().clone()
    emb_b = word_embed(enc_b["input_ids"]).detach().float().clone()
    mask_a = mask_a.float()
    mask_b = mask_b.float()

    def encode(embs, mask):
        # Run the encoder directly on embeddings, then mask-weighted mean-pool.
        int_mask = mask.long()
        ext = esm.base_model.model.get_extended_attention_mask(int_mask, embs.shape[:2])
        h = esm.base_model.model.encoder(
            embs.float(), attention_mask=ext
        ).last_hidden_state.float()
        m = mask.unsqueeze(-1).expand(h.size()).float()
        return (torch.sum(h * m, 1) / torch.clamp(m.sum(1), min=1e-9))

    def fwd(e_a, e_b):
        return model.projection_head(encode(e_a, mask_a), encode(e_b, mask_b))

    def riemann(tgt, baseline, fixed, tgt_is_b):
        # Average the input gradient over `steps` points on the straight
        # line from `baseline` to `tgt`.
        grads = []
        for alpha in torch.linspace(0, 1, steps, device=device):
            interp = (baseline + alpha * (tgt - baseline)).detach().requires_grad_(True)
            out = fwd(fixed, interp) if tgt_is_b else fwd(interp, fixed)
            out.sum().backward()
            grads.append(interp.grad.detach().float().clone())
        avg_grad = torch.stack(grads).mean(0)
        # Absolute attribution summed over the embedding dimension.
        ig = (avg_grad * (tgt - baseline)).abs().sum(-1).squeeze(0)
        return ig.cpu().numpy()

    baseline_a = torch.zeros_like(emb_a)
    baseline_b = torch.zeros_like(emb_b)
    attr_a = riemann(emb_a, baseline_a, emb_b, False)
    attr_b = riemann(emb_b, baseline_b, emb_a, True)

    def norm(a):
        # Min-max normalise; a (near-)constant array maps to all zeros.
        lo, hi = a.min(), a.max()
        if hi - lo < 1e-9:
            return [0.0] * len(a)
        return ((a - lo) / (hi - lo)).tolist()

    return norm(attr_a[1:-1]), norm(attr_b[1:-1])
# ───────────────────────────────────────────────────────────────────────────
# GPU-WRAPPED PREDICTION
# These are the only functions that touch CUDA. ZeroGPU will allocate an A100
# for the duration of the call.
# ───────────────────────────────────────────────────────────────────────────
# NOTE(review): the comments around this function describe it as wrapped in
# @spaces.GPU, but no decorator is applied here — confirm whether it is
# applied elsewhere or was dropped.
def gpu_run_prediction(seq_a, seq_b, run_ig, pkd_lo, pkd_hi):
    """Run one prediction, optionally with Integrated Gradients.

    Args:
        seq_a: First sequence (``"|"`` separates chains).
        seq_b: Second sequence.
        run_ig: Whether to also compute attributions.
        pkd_lo: Lower pKd bound for output scaling.
        pkd_hi: Upper pKd bound for output scaling.

    Returns:
        ``(pkd, cos, ig_a, ig_b, device)`` where ``ig_a`` / ``ig_b`` are
        None unless ``run_ig`` is true and ``device`` is the device name
        used for the computation.
    """
    model = ensure_model(pkd_lo, pkd_hi)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        model.to(device)
        with torch.no_grad():
            pkd_t, cos_t = model(seq_a, seq_b)
        pkd = float(pkd_t.item())
        cos = float(cos_t.item())
        ig_a, ig_b = None, None
        if run_ig:
            ig_a, ig_b = compute_ig(model, seq_a, seq_b, steps=15)
    finally:
        # Always move back to CPU — also when inference raised — so the
        # cached model never stays resident on the GPU between calls.
        model.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return pkd, cos, ig_a, ig_b, str(device)
def _top5_residues(seq_multi, scores):
    """Format the five highest-scoring residues as ``"<aa><pos>:<score>"``.

    ``scores`` still contains the separator positions between chains, so it
    is split per chain first; otherwise every residue after the first chain
    would be labelled with a shifted score.
    """
    chain_seqs = seq_multi.split("|")
    flat = [s for part in split_ig_by_chains(scores, chain_seqs) for s in part]
    residues = "".join(chain_seqs)
    limit = min(len(residues), len(flat))
    best = sorted(range(limit), key=lambda j: -flat[j])[:5]
    return "|".join(f"{residues[j]}{j+1}:{flat[j]:.3f}" for j in best)


def gpu_run_batch(rows, batch_mode, run_ig, pkd_lo, pkd_hi):
    """Run predictions over a batch of rows.

    Args:
        rows: List of dicts with ``seq_a``/``seq_b`` keys (when
            ``batch_mode`` is ``"Protein-Protein"``) or
            ``heavy_chain``/``light_chain``/``antigen`` keys (otherwise).
        batch_mode: Selects which row keys are read.
        run_ig: Whether to add top-5 attribution columns.
        pkd_lo: Lower pKd bound for output scaling.
        pkd_hi: Upper pKd bound for output scaling.

    Returns:
        One dict per row with ``pKd`` and ``Cosine`` (None when an input is
        empty or the row failed, in which case ``error`` holds the message),
        plus ``top5_Target`` / ``top5_proteina`` when ``run_ig`` is true.
    """
    model = ensure_model(pkd_lo, pkd_hi)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    results = []
    try:
        model.to(device)
        for row in rows:
            if batch_mode == "Protein-Protein":
                sa = clean_multi(str(row.get("seq_a", "")))
                sb = clean_multi(str(row.get("seq_b", "")))
            else:
                heavy = clean_multi(str(row.get("heavy_chain", "")))
                light = clean_multi(str(row.get("light_chain", "")))
                ag = clean_multi(str(row.get("antigen", "")))
                sa = ag
                # NOTE(review): when only one of heavy/light is present the
                # joined string keeps a dangling "|" — confirm intended.
                sb = f"{heavy}|{light}" if (heavy or light) else ""
            if not sa or not sb:
                results.append({"pKd": None, "Cosine": None})
                continue
            # A failing row is reported in its own result; the batch goes on.
            try:
                with torch.no_grad():
                    pkd_t, cos_t = model(sa, sb)
                rr = {"pKd": round(float(pkd_t.item()), 4),
                      "Cosine": round(float(cos_t.item()), 6)}
                if run_ig:
                    ia, ib = compute_ig(model, sa, sb, steps=15)
                    rr["top5_Target"] = _top5_residues(sa, ia)
                    rr["top5_proteina"] = _top5_residues(sb, ib)
            except Exception as e:
                rr = {"pKd": None, "Cosine": None, "error": str(e)}
            results.append(rr)
    finally:
        # Always release the GPU, even if something outside the per-row
        # handler raised.
        model.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return results
# ───────────────────────────────────────────────────────────────────────────
# NGL VIEWER (theme-reactive, filtered to selected chains)
# ───────────────────────────────────────────────────────────────────────────
def ngl_viewer_html(pdb_content, ig_chain_map=None, height=420) -> str:
    """Build a standalone HTML page that renders a structure with NGL.

    Args:
        pdb_content: PDB file text. When falsy, a placeholder panel is
            returned instead of a viewer.
        ig_chain_map: Optional mapping of chain ID to
            ``{"start": int, "scores": list}``; when given, atoms are
            coloured by score, otherwise by chain.
        height: Height of the viewer (and of the placeholder) in pixels.

    Returns:
        HTML source as a string.
    """
    # NOTE(review): several literals below ("π§¬", "𧬠", "β ") look
    # mis-encoded — confirm the file encoding before changing them.
    if not pdb_content:
        return (
            '<div style="display:flex;flex-direction:column;align-items:center;'
            # Honour the requested height instead of a fixed 420px.
            f'justify-content:center;height:{height}px;background:var(--bg2,#f1f4f9);'
            'border:1px solid var(--border,#dde3ec);border-radius:12px;'
            'color:var(--text2,#64748b);font-family:JetBrains Mono,monospace;'
            'font-size:.8rem;text-align:center;line-height:2">'
            '<div style="font-size:2.2rem">π§¬</div>'
            '<div style="margin-top:12px;font-weight:600">Load an example or fetch a PDB</div>'
            '<div style="font-size:.72rem;margin-top:4px">Only selected chains will be shown</div>'
            '</div>'
        )
    # "</" is rewritten to "<\/" in both payloads so that a literal
    # "</script>" inside the data cannot terminate the inline script.
    ig_json = json.dumps(ig_chain_map).replace("</", "<\\/") if ig_chain_map else "null"
    # Escape for embedding inside a JavaScript template literal.
    escaped = (pdb_content.replace("\\", "\\\\").replace("`", "\\`")
               .replace("$", "\\$").replace("</", "<\\/"))
    return f"""<!DOCTYPE html><html><head>
<meta charset="utf-8">
<script src="https://cdn.jsdelivr.net/npm/ngl@2.0.0-dev.37/dist/ngl.js"></script>
<style>
*{{box-sizing:border-box;margin:0;padding:0}}
body{{overflow:hidden;font-family:'JetBrains Mono',monospace;transition:background .35s}}
#vp{{width:100%;height:{height}px}}
#hint{{position:absolute;top:10px;left:12px;font-size:10px;pointer-events:none;letter-spacing:.04em;transition:color .3s}}
#ctrl{{position:absolute;bottom:12px;right:12px;display:flex;flex-direction:column;gap:4px}}
.cb{{padding:4px 11px;font-family:'JetBrains Mono',monospace;font-size:9px;font-weight:500;
border-radius:5px;cursor:pointer;letter-spacing:.1em;text-transform:uppercase;transition:all .15s}}
#leg{{position:absolute;bottom:12px;left:12px;font-size:9px;letter-spacing:.05em;
border-radius:5px;padding:5px 10px;line-height:1.9;transition:all .3s}}
</style></head><body>
<div id="vp"></div>
<div id="hint">𧬠Drag · Scroll · Right-drag pan</div>
<div id="ctrl">
<button class="cb" onclick="setRep('cartoon')">Cartoon</button>
<button class="cb" onclick="setRep('surface')">Surface</button>
<button class="cb" onclick="setRep('ball+stick')">Ball+Stick</button>
<button class="cb" onclick="stage.autoView()">Reset</button>
</div>
<div id="leg"></div>
<script>
var T={{
light:{{bg:'#f0f4fa',hint:'rgba(51,65,85,.5)',
btn:'rgba(240,244,250,.93)',btn_b:'rgba(37,99,235,.16)',btn_c:'rgba(71,85,105,.6)',
btn_hc:'#2563eb',btn_hb:'rgba(37,99,235,.55)',btn_hbg:'rgba(37,99,235,.06)',
leg:'rgba(240,244,250,.93)',leg_b:'rgba(37,99,235,.13)',leg_c:'rgba(71,85,105,.65)',
ig_hi:'#1d4ed8',ig_lo:'#c7d7f0',null_col:0xe8ecf4}},
dark:{{bg:'#06090f',hint:'rgba(168,184,216,.4)',
btn:'rgba(6,9,15,.88)',btn_b:'rgba(96,165,250,.16)',btn_c:'rgba(168,184,216,.55)',
btn_hc:'#60a5fa',btn_hb:'rgba(96,165,250,.55)',btn_hbg:'rgba(96,165,250,.06)',
leg:'rgba(6,9,15,.88)',leg_b:'rgba(96,165,250,.13)',leg_c:'rgba(168,184,216,.55)',
ig_hi:'#60a5fa',ig_lo:'#172135',null_col:0x172135}}
}};
function isDark(){{return window.matchMedia('(prefers-color-scheme:dark)').matches;}}
var theme=isDark()?T.dark:T.light,stage=null,comp=null,curRep='cartoon',igData={ig_json};
function igColor(s,dark){{
s=Math.max(0,Math.min(1,s||0));
if(dark) return(Math.round(23+s*73)<<16)|(Math.round(33+s*132)<<8)|Math.round(53+s*197);
return(Math.round(199-s*170)<<16)|(Math.round(215-s*137)<<8)|Math.round(240-s*24);
}}
function styleUI(dark){{
var Th=dark?T.dark:T.light; theme=Th;
document.body.style.background=Th.bg;
document.getElementById('hint').style.color=Th.hint;
document.querySelectorAll('.cb').forEach(function(b){{
b.style.background=Th.btn;b.style.border='1px solid '+Th.btn_b;b.style.color=Th.btn_c;
b.onmouseenter=function(){{b.style.borderColor=Th.btn_hb;b.style.color=Th.btn_hc;b.style.background=Th.btn_hbg;}};
b.onmouseleave=function(){{b.style.borderColor=Th.btn_b;b.style.color=Th.btn_c;b.style.background=Th.btn;}};
}});
var leg=document.getElementById('leg');
leg.style.background=Th.leg;leg.style.border='1px solid '+Th.leg_b;leg.style.color=Th.leg_c;
if(stage)stage.setParameters({{backgroundColor:Th.bg}});
}}
function addRepr(c){{
comp=c;comp.removeAllRepresentations();
var dark=isDark(),Th=dark?T.dark:T.light,leg=document.getElementById('leg');
if(igData){{
var sid=NGL.ColormakerRegistry.addScheme(function(){{
this.atomColor=function(atom){{
var cd=igData[atom.chainname];
if(!cd||!cd.scores||!cd.scores.length) return Th.null_col;
var i=atom.resno-cd.start;
return(i<0||i>=cd.scores.length)?Th.null_col:igColor(cd.scores[i],dark);
}};
}});
comp.addRepresentation(curRep,{{color:sid}});
leg.innerHTML='<span style="color:'+Th.ig_hi+'">β </span> High IG <span style="color:'+Th.ig_lo+'">β </span> Low IG';
}}else{{
comp.addRepresentation(curRep,{{colorScheme:'chainname'}});
leg.innerHTML='Coloured by chain';
}}
stage.autoView();
}}
function setRep(r){{curRep=r;if(comp){{comp.removeAllRepresentations();addRepr(comp);}}}}
function themeSwitch(dark){{styleUI(dark);if(comp){{comp.removeAllRepresentations();addRepr(comp);}}}}
window.matchMedia('(prefers-color-scheme:dark)').addEventListener('change',function(e){{themeSwitch(e.matches);}});
window.addEventListener('message',function(e){{if(e.data&&e.data.balmTheme)themeSwitch(e.data.balmTheme==='dark');}});
window.addEventListener('resize',function(){{if(stage)stage.handleResize();}});
stage=new NGL.Stage('vp',{{backgroundColor:theme.bg,quality:'medium',tooltip:true}});
styleUI(isDark());
var blob=new Blob([`{escaped}`],{{type:'text/plain'}});
stage.loadFile(URL.createObjectURL(blob),{{ext:'pdb',name:'s'}}).then(addRepr);
</script></body></html>"""
def wrap_iframe(html_src: str, height: int = 425) -> str:
    """Wrap a full HTML document in a sandboxed ``<iframe srcdoc=...>``.

    Args:
        html_src: Complete HTML document to embed.
        height: Height of the iframe in pixels.

    Returns:
        The ``<iframe>`` element as an HTML string.
    """
    from html import escape

    # The document is placed inside a double-quoted attribute, so it must be
    # attribute-escaped: an unescaped '"' would end the attribute early and
    # an unescaped '&' would be decoded one level too soon.
    srcdoc = escape(html_src, quote=True)
    return (
        f'<iframe srcdoc="{srcdoc}" '
        f'style="width:100%;height:{height}px;border:0;border-radius:12px" '
        f'sandbox="allow-scripts allow-same-origin"></iframe>'
    )
# ───────────────────────────────────────────────────────────────────────────
# PLOTLY
# ───────────────────────────────────────────────────────────────────────────
# Shared Plotly styling: figure backgrounds, tick/label font, grid colours.
_PL = {
    "paper_bgcolor": "rgba(255,255,255,.97)",
    "plot_bgcolor": "rgba(241,244,249,1)",
}
_FONT = {
    "family": "JetBrains Mono, monospace",
    "color": "#475569",
    "size": 9,
}
_GRID = {
    "gridcolor": "rgba(37,99,235,.07)",
    "zerolinecolor": "rgba(37,99,235,.1)",
}
def make_heatmap(seq, attr, title):
    """Build a heatmap of per-residue scores, wrapped at 50 residues per row.

    Args:
        seq: Residue letters.
        attr: Per-residue scores; may be shorter than ``seq``, in which case
            the remaining cells are empty.
        title: Figure title.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    W = 50
    rows, hover, ylabels = [], [], []
    n_seq, n_attr = len(seq), len(attr)
    for r in range(0, n_seq, W):
        row, tips = [], []
        for i in range(r, r + W):
            has_score = i < n_attr
            row.append(attr[i] if has_score else None)
            if i >= n_seq:
                tips.append("")
            elif has_score:
                tips.append(f"{seq[i]}{i+1} IG:{attr[i]:.3f}")
            else:
                # No score for this residue: label it without indexing past
                # the end of `attr`.
                tips.append(f"{seq[i]}{i+1}")
        rows.append(row)
        hover.append(tips)
        # NOTE(review): the "β" range separator looks mis-encoded — confirm.
        ylabels.append(f"{r+1}β{min(r+W, n_seq)}")
    fig = go.Figure(go.Heatmap(
        z=rows, text=hover, hoverinfo="text",
        colorscale=[[0, "#e8f0fe"], [.4, "#93c5fd"], [.75, "#3b82f6"], [1, "#1d4ed8"]],
        zmin=0, zmax=1, xgap=1, ygap=1,
        colorbar=dict(thickness=10, len=.9, tickfont=dict(**_FONT),
                      title=dict(text="IG", font=dict(**_FONT)))))
    fig.update_layout(
        title=dict(text=title, font=dict(family="JetBrains Mono, monospace", size=10, color="#1d4ed8")),
        **_PL, height=240, margin=dict(t=30, b=24, l=55, r=45),
        xaxis=dict(title=dict(text="Position in window", font=dict(**_FONT)),
                   tickfont=dict(**_FONT), **_GRID),
        yaxis=dict(tickvals=list(range(len(ylabels))), ticktext=ylabels,
                   tickfont=dict(**_FONT), **_GRID))
    return fig
def top10_bar(seq, ig, title):
    """Build a horizontal bar chart of the ten highest-scoring residues.

    Args:
        seq: Residue letters.
        ig: Per-residue scores.
        title: Figure title.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    usable = min(len(seq), len(ig))
    candidates = [{"aa": seq[i], "idx": i + 1, "sc": ig[i]} for i in range(usable)]
    top = sorted(candidates, key=lambda x: -x["sc"])[:10]
    scores = [item["sc"] for item in top]
    labels = [f"{item['aa']}{item['idx']}" for item in top]

    bars = go.Bar(
        x=scores,
        y=labels,
        orientation="h",
        marker=dict(color=[item["sc"] for item in top],
                    colorscale=[[0, "#dbeafe"], [.5, "#3b82f6"], [1, "#1e40af"]],
                    showscale=False, line_width=0))
    fig = go.Figure(bars)
    title_font = dict(family="JetBrains Mono, monospace", size=10, color="#1d4ed8")
    fig.update_layout(
        title=dict(text=title, font=title_font),
        **_PL, height=260, margin=dict(t=30, b=18, l=52, r=16),
        xaxis=dict(title=dict(text="IG Score", font=dict(**_FONT)),
                   tickfont=dict(**_FONT), **_GRID),
        yaxis=dict(tickfont=dict(**_FONT), autorange="reversed"))
    return fig
def residue_strip_html(seq, attr) -> str:
    """Render the sequence as a strip of coloured, hoverable residue chips.

    Args:
        seq: Residue letters.
        attr: Per-residue scores; residues beyond its length get 0.0.

    Returns:
        HTML with a heading followed by one ``<span>`` per residue.
    """
    def _chip(position, letter):
        score = attr[position] if position < len(attr) else 0.0
        red = int(248 - score * 211)
        green = int(250 - score * 151)
        blue = int(252 - score * 17)
        text_colour = "#1e40af" if score > 0.5 else "#475569"
        return (
            f'<span style="background:rgb({red},{green},{blue});color:{text_colour};'
            f'font-family:JetBrains Mono,monospace;font-size:11px;font-weight:600;'
            f'padding:2px 4px;border-radius:4px;cursor:default;display:inline-block;margin:1px;'
            f'border:1px solid rgba(37,99,235,.09)" title="{letter}{position+1} IG:{score:.3f}">{letter}</span>')

    heading = ('<div style="font-size:9px;font-weight:700;letter-spacing:.18em;'
               'text-transform:uppercase;color:#2563eb;margin-bottom:8px;'
               'font-family:JetBrains Mono,monospace">'
               'Residue Attribution (hover for score)</div>')
    chips = "".join(_chip(pos, letter) for pos, letter in enumerate(seq))
    return heading + '<div style="line-height:2.2;word-break:break-all">' + chips + "</div>"
# ───────────────────────────────────────────────────────────────────────────
# RESULT FORMATTING
# ───────────────────────────────────────────────────────────────────────────
def result_card_html(pkd: float, cos: float, pkd_lo: float, pkd_hi: float) -> str:
    """Render the three-card result summary (pKd, cosine, strength bar).

    Args:
        pkd: Predicted pKd.
        cos: Cosine similarity.
        pkd_lo: Lower end of the pKd scale.
        pkd_hi: Upper end of the pKd scale.

    Returns:
        HTML for the result cards.
    """
    # Position of pkd on the [pkd_lo, pkd_hi] scale as a clamped percentage.
    scale = max(pkd_hi - pkd_lo, 1e-9)
    pct = ((pkd - pkd_lo) / scale) * 100
    pct = max(0.0, min(100.0, pct))

    if pct < 30:
        strength, badge_cls = "Weak", "badge-weak"
    elif pct < 55:
        strength, badge_cls = "Moderate", "badge-moderate"
    elif pct < 75:
        strength, badge_cls = "Strong", "badge-strong"
    else:
        strength, badge_cls = "Very Strong", "badge-strong"

    return f"""
<div style="display:grid;grid-template-columns:1fr 1fr 2fr;gap:14px;align-items:stretch">
<div class="pkd-card">
<div class="pkd-lbl">Predicted pKd</div>
<div class="pkd-val">{pkd:.3f}</div>
</div>
<div class="pkd-card">
<div class="pkd-lbl">Cosine Similarity</div>
<div class="pkd-val" style="font-size:1.55rem">{cos:.4f}</div>
</div>
<div class="pkd-card">
<div class="str-labels">
<span>Weak ({pkd_lo:.0f})</span>
<span style="color:var(--text0,#0f172a);font-weight:600">{strength}</span>
<span>Strong ({pkd_hi:.0f})</span>
</div>
<div class="str-bar"><div class="str-fill" style="width:{pct:.1f}%"></div></div>
<span class="pkd-badge {badge_cls}">{strength}</span>
</div>
</div>
"""
def status_badge_html(model_loaded: bool, device: str = "cpu") -> str:
    """Return the status badge: "ready" with the device name, or "not loaded"."""
    # NOTE(review): "Β·" and the leading "β" look mis-encoded — confirm.
    if not model_loaded:
        return '<div class="idle-badge">β NOT LOADED</div>'
    dot = '<div class="ready-badge"><span class="ready-dot"></span>'
    return dot + f'READY Β· {device}</div>'
def sidebar_result_html(pkd: float, cos: float) -> str:
    """Return the compact result card showing pKd and cosine similarity."""
    pieces = [
        '<div class="pkd-card"><div class="pkd-lbl">Predicted pKd</div>',
        f'<div class="pkd-val">{pkd:.3f}</div>',
        '<div class="pkd-lbl" style="margin-top:10px">Cosine Similarity</div>',
        '<div style="font-family:var(--mono,monospace);font-size:1rem;font-weight:700;',
        f'color:var(--text0,#0f172a)">{cos:.4f}</div></div>',
    ]
    return "".join(pieces)
# ───────────────────────────────────────────────────────────────────────────
# CSS (adapted for Gradio's DOM; theme vars + custom classes preserved)
# ───────────────────────────────────────────────────────────────────────────
| CSS = """ | |
| @import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;500;600;700&family=Inter:wght@300;400;500;600;700&display=swap'); | |
| :root { | |
| --bg0:#ffffff; --bg1:#f8f9fc; --bg2:#f1f4f9; --bg3:#e4e9f2; | |
| --border:rgba(37,99,235,0.13); --shadow:rgba(37,99,235,0.07); | |
| --accent:#2563eb; --accent2:#7c3aed; | |
| --green:#16a34a; --amber:#d97706; --red:#dc2626; | |
| --text0:#0f172a; --text1:#334155; --text2:#64748b; --text3:#94a3b8; | |
| --mono:'JetBrains Mono',monospace; --sans:'Inter',sans-serif; | |
| } | |
| @media (prefers-color-scheme:dark){ | |
| :root { | |
| --bg0:#06090f; --bg1:#0b1120; --bg2:#101828; --bg3:#172135; | |
| --border:rgba(96,165,250,0.13); --shadow:rgba(0,0,0,0.35); | |
| --accent:#60a5fa; --accent2:#a78bfa; | |
| --green:#34d399; --amber:#fbbf24; --red:#f87171; | |
| --text0:#f1f5f9; --text1:#cbd5e1; --text2:#64748b; --text3:#334155; | |
| } | |
| } | |
| .gradio-container { font-family: var(--sans) !important; max-width: 1400px !important; } | |
| .gradio-container h1, .gradio-container h2, .gradio-container h3 { | |
| font-family: var(--sans) !important; color: var(--text0) !important; | |
| } | |
| code, pre { | |
| font-family: var(--mono) !important; font-size: .82rem !important; | |
| background: var(--bg2) !important; border-radius: 4px !important; | |
| color: var(--accent) !important; | |
| } | |
| .app-header{display:flex;align-items:center;gap:14px;padding:4px 0 14px;} | |
| .app-logo{width:40px;height:40px;border-radius:10px; | |
| background:linear-gradient(135deg,var(--accent),var(--accent2)); | |
| display:flex;align-items:center;justify-content:center; | |
| font-size:20px;flex-shrink:0;box-shadow:0 2px 12px rgba(37,99,235,.22);} | |
| .app-title{font-family:var(--mono)!important;font-size:1.22rem!important; | |
| font-weight:700!important;color:var(--text0)!important;letter-spacing:-.02em;} | |
| .app-subtitle{font-size:.77rem!important;color:var(--text2)!important;margin-top:1px;letter-spacing:.03em;} | |
| .sec-hdr{font-family:var(--mono);font-size:.65rem;font-weight:700;letter-spacing:.2em; | |
| text-transform:uppercase;color:var(--accent);margin:14px 0 8px; | |
| display:flex;align-items:center;gap:8px;} | |
| .sec-hdr::after{content:'';flex:1;height:1px;background:var(--border);} | |
| .ex-card{background:var(--bg1);border:1px solid var(--border);border-radius:10px; | |
| padding:12px 15px;margin-bottom:8px;position:relative;overflow:hidden; | |
| transition:box-shadow .15s,border-color .15s;} | |
| .ex-card::before{content:'';position:absolute;left:0;top:0;bottom:0;width:3px; | |
| background:linear-gradient(180deg,var(--accent),var(--accent2));border-radius:3px 0 0 3px;} | |
| .ex-card:hover{border-color:rgba(37,99,235,.3);box-shadow:0 2px 12px var(--shadow);} | |
| .ex-pdb{font-family:var(--mono);font-size:.95rem;font-weight:700;color:var(--accent);} | |
| .ex-sub{font-size:.8rem;font-weight:600;color:var(--text1);margin:2px 0;} | |
| .ex-desc{font-size:.72rem;line-height:1.5;color:var(--text2);margin-top:3px;} | |
| .pkd-card{background:linear-gradient(135deg,rgba(37,99,235,.05),rgba(124,58,237,.04)); | |
| border:1px solid var(--border);border-radius:12px;padding:14px 18px; | |
| box-shadow:0 1px 6px var(--shadow);} | |
| .pkd-lbl{font-family:var(--mono);font-size:.67rem;color:var(--text2); | |
| text-transform:uppercase;letter-spacing:.12em;margin-bottom:2px;} | |
| .pkd-val{font-family:var(--mono);font-size:2rem;font-weight:700;color:var(--accent);line-height:1.1;} | |
| .pkd-badge{display:inline-block;padding:2px 9px;border-radius:20px; | |
| font-size:.7rem;font-weight:600;font-family:var(--mono);margin-top:5px;} | |
| .badge-weak{background:rgba(220,38,38,.08);color:var(--red);border:1px solid rgba(220,38,38,.22);} | |
| .badge-moderate{background:rgba(217,119,6,.08);color:var(--amber);border:1px solid rgba(217,119,6,.22);} | |
| .badge-strong{background:rgba(22,163,74,.08);color:var(--green);border:1px solid rgba(22,163,74,.22);} | |
| .str-bar{height:5px;border-radius:3px;background:var(--bg3);overflow:hidden;margin:8px 0 3px;} | |
| .str-fill{height:100%;border-radius:3px; | |
| background:linear-gradient(90deg,var(--accent),var(--accent2)); | |
| transition:width .7s cubic-bezier(.4,0,.2,1);} | |
| .str-labels{display:flex;justify-content:space-between;font-size:.65rem; | |
| color:var(--text2);font-family:var(--mono);} | |
| .ready-badge{display:inline-flex;align-items:center;gap:6px;padding:3px 10px; | |
| border-radius:20px;background:rgba(22,163,74,.07); | |
| border:1px solid rgba(22,163,74,.22); | |
| font-family:var(--mono);font-size:.72rem;color:var(--green);font-weight:600;} | |
| .ready-dot{display:inline-block;width:6px;height:6px;border-radius:50%; | |
| background:var(--green);animation:pulse 2s infinite;} | |
| .idle-badge{display:inline-flex;align-items:center;gap:6px;padding:3px 10px; | |
| border-radius:20px;background:rgba(100,116,139,.07); | |
| border:1px solid rgba(100,116,139,.18); | |
| font-family:var(--mono);font-size:.72rem;color:var(--text2);font-weight:600;} | |
| @keyframes pulse{0%,100%{opacity:1;transform:scale(1)}50%{opacity:.5;transform:scale(.85)}} | |
| textarea, input[type="text"], input[type="number"] { | |
| font-family: var(--mono) !important; font-size: .82rem !important; | |
| } | |
| """ | |
# ───────────────────────────────────────────────────────────────────────────
# GRADIO EVENT HANDLERS
# ───────────────────────────────────────────────────────────────────────────
def handle_mode_change(mode):
    """Toggle visibility of the PPI vs. antibody-antigen input groups.

    Returns four visibility updates, in order: PPI sequence panel, Ab-Ag
    sequence panel, PPI chain selectors, Ab-Ag chain selectors.
    """
    is_ppi = mode == "Protein-Protein"
    # The PPI widgets and the Ab-Ag widgets are always shown in opposition.
    return (
        gr.update(visible=is_ppi),       # ppi panel
        gr.update(visible=not is_ppi),   # ab-ag panel
        gr.update(visible=is_ppi),       # ppi chain inputs
        gr.update(visible=not is_ppi),   # ab-ag chain inputs
    )
def handle_load_example(pdb_id, mode, chain_a, chain_b, chain_h=None, chain_l=None, chain_ag=None):
    """Fetch a PDB entry and build the updates for every "load" output.

    Args:
        pdb_id: PDB identifier handed to ``get_chains_from_pdb``.
        mode: ``"Protein-Protein"`` selects the two-sided PPI path; any other
            value is handled as antibody-antigen.
        chain_a, chain_b: Chain selectors for side A / side B (PPI mode).
            Falsy values fall back to the first / second available chain.
        chain_h, chain_l: Heavy / light chain IDs (Ab-Ag mode); falsy values
            default to ``"H"`` / ``"L"``.
        chain_ag: Antigen chain selector (Ab-Ag mode); a falsy value falls
            back to the first available chain that is neither heavy nor light.

    Returns:
        A 21-tuple ordered like ``load_outputs`` in ``build_ui``: the five
        sequence boxes, chain info for both sides, the (filtered) PDB text,
        cleared IG/result state, viewer HTML, cleared result/plot widgets and
        the info Markdown. On fetch failure the sequence boxes are left
        untouched and all state and visualisations are reset.
    """
    pdb_content, chains, err = get_chains_from_pdb(pdb_id)
    if err or not chains:
        gr.Warning(err or f"No chains found in {pdb_id}")
        return (
            gr.update(),      # ppi_a
            gr.update(),      # ppi_b
            gr.update(),      # ab_h
            gr.update(),      # ab_l
            gr.update(),      # ab_ag
            [],               # state_chain_info_a
            [],               # state_chain_info_b
            "",               # state_pdb_content
            None,             # state_ig_chain_map
            None, None,       # state_ig_a, state_ig_b
            None,             # state_result
            # Wrap the empty viewer in an iframe like every other call site;
            # the bare viewer document was previously emitted here.
            gr.update(value=wrap_iframe(ngl_viewer_html(None))),  # ngl_html
            gr.update(value=""),   # result_html
            gr.update(value=""),   # strip_a_html
            None,                  # bar_a
            None,                  # hm_a
            gr.update(value=""),   # strip_b_html
            None,                  # bar_b
            None,                  # hm_b
            gr.update(value=""),   # pdb_info_md
        )

    available = sorted(chains.keys())
    available_str = ', '.join(available)
    info_msg = f"β **{pdb_id.upper()}** β chains: **{', '.join(available)}**"
    selected_chain_ids: set = set()

    if mode == "Protein-Protein":
        id_a = chain_a or available[0]
        id_b = chain_b or (available[1] if len(available) > 1 else available[0])
        seq_a, info_a, miss_a = resolve_chains(chains, id_a, available)
        seq_b, info_b, miss_b = resolve_chains(chains, id_b, available)
        for m in miss_a:
            gr.Warning(f"Chain {m} not found (Side A). Available: {available_str}")
        for m in miss_b:
            gr.Warning(f"Chain {m} not found (Side B). Available: {available_str}")
        selected_chain_ids = {c["id"] for c in info_a} | {c["id"] for c in info_b}
        ppi_a_val = seq_a
        ppi_b_val = seq_b
        # Leave the (hidden) Ab-Ag boxes untouched in PPI mode.
        ah_val = al_val = aag_val = gr.update()
        chain_info_a, chain_info_b = info_a, info_b
    else:
        id_h = chain_h or "H"
        id_l = chain_l or "L"
        id_ag_str = chain_ag or next(
            (c for c in available if c not in (id_h, id_l)), available[0])

        def _single(cid):
            """Look up one chain: exact ID first, then case-insensitively."""
            if cid in chains and chains[cid]["seq"]:
                return cid, chains[cid]
            f = next((k for k in chains
                      if k.upper() == cid.upper() and chains[k]["seq"]), None)
            return (f, chains[f]) if f else (cid, {"seq": "", "resnos": []})

        real_h, ch = _single(id_h)
        real_l, cl = _single(id_l)
        seq_ag, info_ag, miss_ag = resolve_chains(chains, id_ag_str, available)
        if not ch["seq"]:
            gr.Warning(f"Heavy chain {id_h} not found. Available: {available_str}")
        if not cl["seq"]:
            gr.Warning(f"Light chain {id_l} not found. Available: {available_str}")
        for m in miss_ag:
            gr.Warning(f"Antigen chain {m} not found. Available: {available_str}")

        # Side A is the antigen; side B is the antibody (heavy, then light).
        chain_info_a = info_ag if info_ag else []
        chain_info_b = [{"id": real_h, **ch}, {"id": real_l, **cl}]
        selected_chain_ids = ({c["id"] for c in info_ag} if info_ag else set())
        if ch["seq"]:
            selected_chain_ids.add(real_h)
        if cl["seq"]:
            selected_chain_ids.add(real_l)
        # Leave the (hidden) PPI boxes untouched in Ab-Ag mode.
        ppi_a_val = ppi_b_val = gr.update()
        ah_val = ch["seq"]
        al_val = cl["seq"]
        aag_val = seq_ag

    # Show only the chains that were actually selected; fall back to the full
    # structure when nothing resolved.
    filtered_pdb = (filter_pdb_to_chains(pdb_content, selected_chain_ids)
                    if selected_chain_ids else pdb_content)
    return (
        ppi_a_val,        # ppi_a
        ppi_b_val,        # ppi_b
        ah_val,           # ab_h
        al_val,           # ab_l
        aag_val,          # ab_ag
        chain_info_a,     # state_chain_info_a
        chain_info_b,     # state_chain_info_b
        filtered_pdb,     # state_pdb_content
        None,             # state_ig_chain_map (cleared)
        None, None,       # state_ig_a, state_ig_b
        None,             # state_result
        wrap_iframe(ngl_viewer_html(filtered_pdb, None)),  # ngl_html
        "",               # result_html
        "",               # strip_a_html
        None,             # bar_a
        None,             # hm_a
        "",               # strip_b_html
        None,             # bar_b
        None,             # hm_b
        info_msg,         # pdb_info_md
    )
def handle_fetch_pdb(pdb_id, mode, ca_in, cb_in, ch_in, cl_in, cag_in):
    """Validate the "Fetch PDB" form and delegate to ``handle_load_example``.

    Args:
        pdb_id: Text from the PDB ID box; blank or whitespace-only is rejected
            with a warning.
        mode: Current interaction mode, forwarded unchanged.
        ca_in, cb_in, ch_in, cl_in, cag_in: Raw chain-selector text. Each is
            stripped, and empty/missing values are forwarded as ``None``.

    Returns:
        The 21-tuple produced by ``handle_load_example``, or 21 no-op updates
        when no PDB ID was entered.
    """
    if not pdb_id or not pdb_id.strip():
        gr.Warning("Enter a PDB ID.")
        # One no-op per component in ``load_outputs`` (21 entries). The
        # previous count of 20 made Gradio reject the return value.
        return tuple(gr.update() for _ in range(21))

    def _opt(text):
        """Return stripped text, or None when it is missing or blank."""
        return (text.strip() or None) if text else None

    return handle_load_example(
        pdb_id.strip(), mode,
        chain_a=_opt(ca_in),
        chain_b=_opt(cb_in),
        chain_h=_opt(ch_in),
        chain_l=_opt(cl_in),
        chain_ag=_opt(cag_in),
    )
def handle_run_prediction(
    mode, ppi_a, ppi_b, ab_h, ab_l, ab_ag, run_ig, pkd_lo, pkd_hi,
    chain_info_a, chain_info_b, pdb_content,
):
    """Main predict callback; delegates the heavy work to ``gpu_run_prediction``.

    Args:
        mode: ``"Protein-Protein"`` or ``"Antibody-Antigen"``.
        ppi_a, ppi_b: Sequence text for the PPI mode boxes.
        ab_h, ab_l, ab_ag: Heavy / light / antigen sequence text (Ab-Ag mode).
        run_ig: Whether Integrated Gradients should be computed.
        pkd_lo, pkd_hi: pKd bounds, converted with ``float``.
        chain_info_a, chain_info_b: Per-chain dicts (``id``/``seq``/``resnos``)
            from a previous PDB load; synthesised from the pasted sequences
            when empty.
        pdb_content: PDB text for the viewer; falsy means no structure.

    Returns:
        A 14-tuple ordered like the ``outputs`` of ``run_btn.click`` in
        ``build_ui``. When a sequence box is empty, 14 no-op updates are
        returned after a warning.

    Raises:
        gr.Error: If ``gpu_run_prediction`` fails, so the UI shows the message.
    """
    # Must match the number of components wired to run_btn.click; the early
    # return used to yield 12 values, which Gradio rejects.
    n_outputs = 14

    if mode == "Protein-Protein":
        seq_a = clean_multi(ppi_a or "")
        seq_b = clean_multi(ppi_b or "")
    else:
        seq_a = clean_multi(ab_ag or "")
        heavy = clean_multi(ab_h or "")
        light = clean_multi(ab_l or "")
        seq_b = f"{heavy}|{light}" if (heavy or light) else ""

    if not seq_a or not seq_b:
        gr.Warning("Fill all sequence boxes before running.")
        return tuple(gr.update() for _ in range(n_outputs))

    # If no chain info exists yet (user pasted sequences directly), synthesise it
    if not chain_info_a:
        chain_info_a = [{"id": ch, "seq": s, "resnos": list(range(1, len(s) + 1))}
                        for ch, s in zip("ABCDE", seq_a.split("|"))]
    if not chain_info_b:
        if mode == "Antibody-Antigen":
            h_s = clean_multi(ab_h or "")
            l_s = clean_multi(ab_l or "")
            chain_info_b = [
                {"id": "H", "seq": h_s, "resnos": list(range(1, len(h_s) + 1))},
                {"id": "L", "seq": l_s, "resnos": list(range(1, len(l_s) + 1))},
            ]
        else:
            chain_info_b = [{"id": ch, "seq": s, "resnos": list(range(1, len(s) + 1))}
                            for ch, s in zip("FGHIJ", seq_b.split("|"))]

    try:
        pkd, cos, ig_a, ig_b, device_used = gpu_run_prediction(
            seq_a, seq_b, bool(run_ig), float(pkd_lo), float(pkd_hi))
    except Exception as e:
        traceback.print_exc()
        # gr.Error is an exception class: it only reaches the user when raised.
        raise gr.Error(f"Prediction failed: {e}") from e

    result = {"pkd": pkd, "cosine": cos, "sa": seq_a, "sb": seq_b}

    # Display data
    seq_a_d, ig_a_d, seq_b_d, ig_b_d = flat_ig_for_display(ig_a, ig_b, chain_info_a, chain_info_b)
    if not seq_a_d:
        seq_a_d = seq_a.replace("|", "")
    if not seq_b_d:
        seq_b_d = seq_b.replace("|", "")

    ig_chain_map = None
    if ig_a is not None and ig_b is not None:
        ig_chain_map = build_ig_chain_map(chain_info_a, chain_info_b, ig_a, ig_b)

    lbl_a = "Antigen (Target)" if mode == "Antibody-Antigen" else "Target (Protein A)"
    lbl_b = "Antibody H+L" if mode == "Antibody-Antigen" else "Binder (proteina)"

    if ig_a_d and seq_a_d:
        strip_a = residue_strip_html(seq_a_d, ig_a_d)
        bar_a = top10_bar(seq_a_d, ig_a_d, f"Top 10 Β· {lbl_a}")
        hm_a = make_heatmap(seq_a_d, ig_a_d, f"{lbl_a} Β· IG Heatmap")
    else:
        strip_a, bar_a, hm_a = "", None, None

    if ig_b_d and seq_b_d:
        strip_b = residue_strip_html(seq_b_d, ig_b_d)
        bar_b = top10_bar(seq_b_d, ig_b_d, f"Top 10 Β· {lbl_b}")
        hm_b = make_heatmap(seq_b_d, ig_b_d, f"{lbl_b} Β· IG Heatmap")
    else:
        strip_b, bar_b, hm_b = "", None, None

    # Without a loaded structure, fall back to the empty viewer.
    if pdb_content:
        ngl_html = wrap_iframe(ngl_viewer_html(pdb_content, ig_chain_map))
    else:
        ngl_html = wrap_iframe(ngl_viewer_html(None))

    result_card = result_card_html(pkd, cos, float(pkd_lo), float(pkd_hi))
    sidebar_card = sidebar_result_html(pkd, cos)
    badge = status_badge_html(True, device_used)
    return (
        result,                  # state_result
        ig_a, ig_b,              # state_ig_a, state_ig_b
        ig_chain_map,            # state_ig_chain_map
        result_card,             # result_html
        ngl_html,                # ngl_html
        strip_a, bar_a, hm_a,    # vis A
        strip_b, bar_b, hm_b,    # vis B
        sidebar_card,            # sidebar_result_html
        badge,                   # status badge
    )
def handle_download_results(result, state_ig_a, state_ig_b,
                            chain_info_a, chain_info_b):
    """Build the per-residue IG CSV and the summary JSON for download.

    Args:
        result: Dict stored by the predict callback (keys ``pkd``, ``cosine``,
            ``sa``, ``sb``); falsy when no prediction has been run.
        state_ig_a, state_ig_b: IG attributions for each side (may be None).
        chain_info_a, chain_info_b: Per-chain info for each side.

    Returns:
        ``(csv_path, json_path)``, or ``(None, None)`` when there is no result.
    """
    if not result:
        return None, None

    seq_a_d, ig_a_d, seq_b_d, ig_b_d = flat_ig_for_display(
        state_ig_a, state_ig_b, chain_info_a, chain_info_b)
    if not seq_a_d:
        seq_a_d = result["sa"].replace("|", "")
    if not seq_b_d:
        seq_b_d = result["sb"].replace("|", "")
    seq_a_d, seq_b_d = seq_a_d or "", seq_b_d or ""
    ig_a_d, ig_b_d = ig_a_d or [], ig_b_d or []

    # A fresh directory per call: fixed names in the shared temp dir let
    # concurrent sessions overwrite each other's downloads. The directory is
    # left for the OS temp cleanup.
    out_dir = tempfile.mkdtemp(prefix="balm_dl_")

    # CSV
    csv_path = os.path.join(out_dir, "ig_scores.csv")
    with open(csv_path, "w", newline="") as f:
        wc = csv.writer(f)
        wc.writerow(["chain", "position", "residue", "ig_score"])
        for label, seq, scores in (("Target", seq_a_d, ig_a_d),
                                   ("proteina", seq_b_d, ig_b_d)):
            for pos, (aa, sc) in enumerate(zip(seq, scores), start=1):
                wc.writerow([label, pos, aa, f"{sc:.6f}"])

    def _top10(seq, scores):
        """Return the ten highest-IG residues as ``{"res", "ig"}`` dicts."""
        ranked = [
            # float() keeps the value JSON-serialisable if the scores are
            # NumPy scalars (their exact type is not visible from here).
            {"res": f"{aa}{pos}", "ig": float(sc)}
            for pos, (aa, sc) in enumerate(zip(seq, scores), start=1)
        ]
        return sorted(ranked, key=lambda x: -x["ig"])[:10]

    # JSON
    summary = {
        "pkd": result["pkd"], "cosine": result["cosine"],
        "Target_length": len(seq_a_d), "proteina_length": len(seq_b_d),
        "top10_Target": _top10(seq_a_d, ig_a_d),
        "top10_proteina": _top10(seq_b_d, ig_b_d),
    }
    json_path = os.path.join(out_dir, "balm_result.json")
    with open(json_path, "w") as f:
        json.dump(summary, f, indent=2)
    return csv_path, json_path
def handle_load_model(pkd_lo, pkd_hi):
    """Force-load the model on CPU (download weights). GPU is reserved for prediction time.

    Args:
        pkd_lo, pkd_hi: pKd bounds forwarded to ``ensure_model`` as floats.

    Returns:
        Status-badge HTML: the "ready" badge on success, the idle badge on
        failure.
    """
    try:
        ensure_model(float(pkd_lo), float(pkd_hi))
    except Exception as e:
        traceback.print_exc()
        # gr.Error is an exception class and shows nothing unless raised, so
        # the old call was a silent no-op. A warning toast surfaces the
        # failure while still letting the badge below be returned.
        gr.Warning(f"Load failed: {e}")
        return status_badge_html(False)
    gr.Info("Model weights downloaded and ready. GPU will be allocated when you click Run.")
    return status_badge_html(True, "cpu (GPU on demand)")
def handle_run_batch(file_obj, batch_mode, run_ig_b, pkd_lo, pkd_hi):
    """Run batch prediction over an uploaded CSV and write a results CSV.

    Args:
        file_obj: Value of the upload component; either a path string or an
            object exposing the path as ``.name``. ``None`` means no upload.
        batch_mode: Interaction mode forwarded to ``gpu_run_batch``.
        run_ig_b: Whether to compute IG per row.
        pkd_lo, pkd_hi: pKd bounds, converted with ``float``.

    Returns:
        ``(results_dataframe, results_csv_path)``, or ``(None, None)`` when no
        file was uploaded.

    Raises:
        gr.Error: If the CSV cannot be read or the batch run fails.
    """
    if file_obj is None:
        gr.Warning("Upload a CSV file first.")
        return None, None

    # Accept both a plain filepath string and a file-like wrapper with .name.
    csv_in = getattr(file_obj, "name", file_obj)
    try:
        df = pd.read_csv(csv_in)
    except Exception as e:
        # gr.Error must be raised to reach the user; constructing it is a no-op.
        raise gr.Error(f"CSV read failed: {e}") from e

    rows = df.to_dict(orient="records")
    try:
        results = gpu_run_batch(rows, batch_mode, bool(run_ig_b), float(pkd_lo), float(pkd_hi))
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Batch failed: {e}") from e

    df_res = pd.concat([df, pd.DataFrame(results)], axis=1)
    # Per-call directory so concurrent sessions do not overwrite one another.
    out_path = os.path.join(tempfile.mkdtemp(prefix="balm_batch_"),
                            "balm_batch_results.csv")
    df_res.to_csv(out_path, index=False)
    return df_res, out_path
# ───────────────────────────────────────────────────────────────────────────
# GRADIO LAYOUT
# ───────────────────────────────────────────────────────────────────────────
def build_ui():
    """Assemble the Gradio ``Blocks`` app and wire its event handlers.

    Layout: a header followed by three tabs (single prediction, batch
    prediction, model info). Per-session values are held in ``gr.State``
    components and passed to the handlers as inputs/outputs.

    NOTE(review): the nesting of the ``with`` blocks below was reconstructed
    from statement order and the section comments; confirm the rendered
    layout matches the intended one.

    Returns:
        The constructed ``gr.Blocks`` instance (launching is left to the caller).
    """
    with gr.Blocks(title="BALM-PPI Pro") as demo:
        # ── Persistent per-session state ─────────────────────────────────
        state_pdb_content = gr.State("")
        state_chain_info_a = gr.State([])
        state_chain_info_b = gr.State([])
        state_ig_a = gr.State(None)
        state_ig_b = gr.State(None)
        state_ig_chain_map = gr.State(None)
        state_result = gr.State(None)
        # ── Header ───────────────────────────────────────────────────────
        gr.HTML("""
<div class="app-header">
<div class="app-logo">π§¬</div>
<div>
<div class="app-title">BALM-PPI Pro</div>
<div class="app-subtitle">ESM-2 Β· LoRA Β· Integrated Gradients Β· Protein Binding Affinity Prediction</div>
</div>
</div>
""")
        with gr.Tabs():
            # ─────────────────────────────────────────────────────────────
            # TAB 1 - SINGLE PREDICTION
            # ─────────────────────────────────────────────────────────────
            with gr.Tab("π― Single Prediction"):
                with gr.Row():
                    # ─── Sidebar (left column): model controls + status ───
                    with gr.Column(scale=1, min_width=260):
                        gr.Markdown("### βοΈ Model")
                        with gr.Accordion("pKd Bounds", open=False):
                            pkd_lo = gr.Number(value=1.0, label="Min", precision=2)
                            pkd_hi = gr.Number(value=16.0, label="Max", precision=2)
                        load_btn = gr.Button("β‘ Reload Model", variant="secondary")
                        # Initial badge reflects whether a model is already
                        # present in _MODEL_STATE when the UI is built.
                        status_html = gr.HTML(
                            status_badge_html(_MODEL_STATE["model"] is not None,
                                              "cpu (GPU on demand)"))
                        sidebar_result = gr.HTML("")
                        gr.Markdown(
                            '<div style="font-size:.72rem;color:var(--text2);'
                            'font-family:JetBrains Mono,monospace;line-height:1.9;margin-top:14px">'
                            'π€ <a href="https://huggingface.co/Harshit494/BALM-PPI" '
                            'style="color:var(--accent);text-decoration:none">Harshit494/BALM-PPI</a><br>'
                            'π» <a href="https://github.com/rgorantla04/BALM-PPI" '
                            'style="color:var(--accent);text-decoration:none">rgorantla04/BALM-PPI</a>'
                            '</div>'
                        )
                    # ─── Main column ──────────────────────────────────────
                    with gr.Column(scale=4):
                        mode = gr.Radio(
                            ["Protein-Protein", "Antibody-Antigen"],
                            value="Protein-Protein", label="Interaction mode",
                            container=False,
                        )
                        # PPI input panel
                        with gr.Group(visible=True) as ppi_panel:
                            with gr.Row():
                                ppi_a = gr.Textbox(
                                    label="Target Protein β Seq A",
                                    placeholder="Paste FASTA or raw sequence⦠| = chain separator",
                                    lines=4, max_lines=8,
                                )
                                ppi_b = gr.Textbox(
                                    label="Binder Protein β Seq B ( | = chain separator)",
                                    placeholder="Paste FASTA or raw sequenceβ¦",
                                    lines=4, max_lines=8,
                                )
                        # Ab-Ag input panel (hidden until the mode radio switches)
                        with gr.Group(visible=False) as abag_panel:
                            with gr.Row():
                                ab_ag = gr.Textbox(
                                    label="Antigen / Target β Seq A",
                                    placeholder="Antigen sequenceβ¦",
                                    lines=4, max_lines=8, scale=2,
                                )
                                with gr.Column(scale=3):
                                    with gr.Row():
                                        ab_h = gr.Textbox(
                                            label="Heavy Chain (H)",
                                            placeholder="VH sequenceβ¦",
                                            lines=4, max_lines=8,
                                        )
                                        ab_l = gr.Textbox(
                                            label="Light Chain (L)",
                                            placeholder="VL sequenceβ¦",
                                            lines=4, max_lines=8,
                                        )
                        with gr.Row():
                            run_btn = gr.Button("π¬ Run Prediction", variant="primary", scale=3)
                            run_ig = gr.Checkbox(value=True, label="Run Integrated Gradients",
                                                 info="~45β90s CPU Β· <5s GPU", scale=2)
                        result_html = gr.HTML("")
                        # ─── Lower section: PDB controls + visualizations ─
                        with gr.Row():
                            # ── LEFT: Custom PDB Fetch ────────────────────
                            with gr.Column(scale=1, min_width=240):
                                gr.HTML('<div class="sec-hdr">Custom PDB Fetch</div>')
                                pdb_in = gr.Textbox(label="PDB ID", placeholder="1YCR, 1BRS, 2VXQβ¦")
                                # PPI chain selectors
                                with gr.Group(visible=True) as ppi_chains_row:
                                    ca_in = gr.Textbox(label="Side A chain(s)", placeholder="A or A,B")
                                    cb_in = gr.Textbox(label="Side B chain(s)", placeholder="B or B,C")
                                # Ab-Ag chain selectors
                                with gr.Group(visible=False) as abag_chains_row:
                                    ch_in = gr.Textbox(label="Heavy", placeholder="H")
                                    cl_in = gr.Textbox(label="Light", placeholder="L")
                                    cag_in = gr.Textbox(label="Antigen", placeholder="A or A,B")
                                fetch_btn = gr.Button("π Fetch PDB", variant="secondary", size="sm")
                                pdb_info_md = gr.Markdown("")
                            # ── MIDDLE: Quick Examples ────────────────────
                            with gr.Column(scale=1, min_width=240):
                                gr.HTML('<div class="sec-hdr">Quick Examples</div>')
                                # (button, example-dict) pairs, wired up in the
                                # event section below.
                                example_buttons = []
                                for ex in EXAMPLES:
                                    gr.HTML(
                                        f'<div class="ex-card">'
                                        f'<div class="ex-pdb">{ex["label"]}</div>'
                                        f'<div class="ex-sub">{ex["subtitle"]}</div>'
                                        f'<div class="ex-desc">{ex["desc"]}</div></div>'
                                    )
                                    b = gr.Button(f"β¬ Load {ex['pdb']}", size="sm")
                                    example_buttons.append((b, ex))
                            # ── RIGHT: Visualizations ─────────────────────
                            with gr.Column(scale=3):
                                with gr.Tabs():
                                    with gr.Tab("𧬠Structure"):
                                        ngl_html = gr.HTML(wrap_iframe(ngl_viewer_html(None)))
                                    with gr.Tab("π Side A Strip"):
                                        strip_a_html = gr.HTML("")
                                        bar_a_plot = gr.Plot()
                                    with gr.Tab("πΊ Side A Heatmap"):
                                        hm_a_plot = gr.Plot()
                                    with gr.Tab("π Side B Strip"):
                                        strip_b_html = gr.HTML("")
                                        bar_b_plot = gr.Plot()
                                    with gr.Tab("πΊ Side B Heatmap"):
                                        hm_b_plot = gr.Plot()
                                    with gr.Tab("π₯ Download"):
                                        dl_btn = gr.Button("Build downloads", variant="secondary")
                                        dl_csv = gr.File(label="IG Scores CSV")
                                        dl_json = gr.File(label="Summary JSON")
            # ─────────────────────────────────────────────────────────────
            # TAB 2 - BATCH PREDICTION
            # ─────────────────────────────────────────────────────────────
            with gr.Tab("π Batch Prediction"):
                gr.Markdown("### π Batch Prediction")
                batch_mode = gr.Radio(
                    ["Protein-Protein", "Antibody-Antigen"],
                    value="Protein-Protein", label="Mode", container=False,
                )
                with gr.Row():
                    # Template files are produced by callables that write the
                    # template text to the temp directory.
                    gr.File(label="PPI Template",
                            value=lambda: _write_template(PPI_TEMPLATE, "ppi_template.csv"),
                            interactive=False)
                    gr.File(label="Ab-Ag Template",
                            value=lambda: _write_template(ABAG_TEMPLATE, "abag_template.csv"),
                            interactive=False)
                gr.Markdown(
                    "Required columns: **seq_a, seq_b** (PPI mode) "
                    "or **heavy_chain, light_chain, antigen** (Ab-Ag mode)."
                )
                batch_file = gr.File(label="CSV file", file_types=[".csv"])
                run_ig_b = gr.Checkbox(value=False,
                                       label="Compute IG per row (slow on CPU; fine on GPU)")
                batch_run = gr.Button("π Run Batch", variant="primary")
                batch_df = gr.Dataframe(label="Results", interactive=False)
                batch_dl = gr.File(label="Download results CSV")
            # ─────────────────────────────────────────────────────────────
            # TAB 3 - MODEL INFO
            # ─────────────────────────────────────────────────────────────
            with gr.Tab("π Model Info"):
                gr.Markdown("""
### π About BALM-PPI
**BALM-PPI** predicts proteinβprotein and antibodyβantigen binding affinity using ESM-2 + LoRA
with Integrated Gradients explainability.
| Component | Detail |
|-----------|--------|
| Backbone | ESM-2 650M (`facebook/esm2_t33_650M_UR50D`) |
| Fine-tuning | LoRA r=8, Ξ±=16, dropout=0.1 Β· targets: key/query/value |
| Column mapping | `seq_a` β **Target** Β· `seq_b` β **proteina** |
| Multi-chain | `\\|` separator β `<cls><cls>` double CLS token |
| Affinity output | pKd = ((cos+1)/2) Γ (pKd_max β pKd_min) + pKd_min |
| IG | Integrated Gradients, **15-step** Riemann (float32-safe) |
| Viewer | PDB filtered to **selected chains only** |
| Hardware | HF ZeroGPU (dynamic A100); CPU fallback supported |
**Examples** Β· **1YCR** β MDM2 β p53 Β· **1BRS** β Barnase β Barstar
π€ [Harshit494/BALM-PPI](https://huggingface.co/Harshit494/BALM-PPI) Β·
π» [rgorantla04/BALM-PPI](https://github.com/rgorantla04/BALM-PPI)
""")
        # ─────────────────────────────────────────────────────────────────
        # EVENT WIRING
        # ─────────────────────────────────────────────────────────────────
        # Switching the mode radio toggles which input/chain groups are shown.
        mode.change(
            handle_mode_change,
            inputs=[mode],
            outputs=[ppi_panel, abag_panel, ppi_chains_row, abag_chains_row],
        )
        load_btn.click(
            handle_load_model,
            inputs=[pkd_lo, pkd_hi],
            outputs=[status_html],
        )
        # Outputs that load_example/fetch_pdb returns (21 components; the
        # handlers must return values in exactly this order).
        load_outputs = [
            ppi_a, ppi_b, ab_h, ab_l, ab_ag,
            state_chain_info_a, state_chain_info_b,
            state_pdb_content, state_ig_chain_map,
            state_ig_a, state_ig_b, state_result,
            ngl_html, result_html,
            strip_a_html, bar_a_plot, hm_a_plot,
            strip_b_html, bar_b_plot, hm_b_plot,
            pdb_info_md,
        ]
        for btn, ex in example_buttons:
            # Example values are bound as lambda defaults so each button keeps
            # its own example rather than the last loop value.
            btn.click(
                lambda mode_val=ex["mode"], pdb=ex["pdb"],
                       ca=ex.get("chain_a"), cb=ex.get("chain_b"),
                       ch=ex.get("chain_h"), cl=ex.get("chain_l"),
                       cag=ex.get("chain_ag"):
                    handle_load_example(pdb, mode_val, ca, cb, ch, cl, cag),
                outputs=load_outputs,
            )
        fetch_btn.click(
            handle_fetch_pdb,
            inputs=[pdb_in, mode, ca_in, cb_in, ch_in, cl_in, cag_in],
            outputs=load_outputs,
        )
        # The predict handler must return one value per entry in `outputs`
        # below (14 components).
        run_btn.click(
            handle_run_prediction,
            inputs=[
                mode, ppi_a, ppi_b, ab_h, ab_l, ab_ag,
                run_ig, pkd_lo, pkd_hi,
                state_chain_info_a, state_chain_info_b, state_pdb_content,
            ],
            outputs=[
                state_result,
                state_ig_a, state_ig_b, state_ig_chain_map,
                result_html, ngl_html,
                strip_a_html, bar_a_plot, hm_a_plot,
                strip_b_html, bar_b_plot, hm_b_plot,
                sidebar_result, status_html,
            ],
        )
        dl_btn.click(
            handle_download_results,
            inputs=[state_result, state_ig_a, state_ig_b,
                    state_chain_info_a, state_chain_info_b],
            outputs=[dl_csv, dl_json],
        )
        # The batch tab reuses the pKd bounds from the single-prediction sidebar.
        batch_run.click(
            handle_run_batch,
            inputs=[batch_file, batch_mode, run_ig_b, pkd_lo, pkd_hi],
            outputs=[batch_df, batch_dl],
        )
    return demo
def _write_template(content: str, name: str) -> str:
    """Write ``content`` to ``name`` inside the system temp dir; return the path."""
    target = os.path.join(tempfile.gettempdir(), name)
    with open(target, "w") as handle:
        handle.write(content)
    return target
# ───────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Warm the model before the UI comes up: a failed pre-load is logged but
    # does not stop the app from starting.
    print("[BALM-PPI] Pre-loading model on CPUβ¦", flush=True)
    try:
        ensure_model(1.0, 16.0)
    except Exception as exc:
        print(f"[BALM-PPI] WARNING β pre-load failed: {exc}", flush=True)
        traceback.print_exc()
    else:
        print("[BALM-PPI] Model ready.", flush=True)

    demo = build_ui()
    queued_app = demo.queue(max_size=20)
    queued_app.launch(
        css=CSS,
        theme=gr.themes.Soft(),
        ssr_mode=False,  # server-side rendering disabled (cleaner logs)
    )