BALM-PPI / app.py
Harshit494's picture
Update app.py
78b0ca2 verified
"""
BALM-PPI Pro · ESM-2 + LoRA + Integrated Gradients
=====================================================
Gradio port with HuggingFace ZeroGPU support.
Key porting notes:
- All GPU work (forward pass + IG) is wrapped in @spaces.GPU so it runs on
ZeroGPU's dynamic A100s. Outside the decorator, the model lives on CPU.
- IG kept at float32 with 15 Riemann steps (same fix as the Streamlit version).
- NGL viewer + Plotly visualisations + theming preserved.
- Per-session state (PDB, chain info, IG arrays) is held in gr.State so the
app is safe under concurrent users on Spaces.
"""
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import io, re, json, csv, traceback, tempfile, copy
import io as sio
import gradio as gr
import torch
import pandas as pd
import numpy as np
import requests
import plotly.graph_objects as go
from Bio import PDB
from Bio.Data.PDBData import protein_letters_3to1 as THREE_TO_ONE
# ─── ZeroGPU compatibility shim ────────────────────────────────────────────
# On HF Spaces with ZeroGPU, `spaces` is installed and @spaces.GPU works.
# Locally (or off-Spaces) we fall back to a no-op decorator so the same file
# runs both places without modification.
try:
    import spaces
    _HAS_SPACES = True
except ImportError:
    _HAS_SPACES = False

    class _SpacesShim:
        """No-op stand-in for the `spaces` package when it is not installed."""

        @staticmethod
        def GPU(duration=60, **_ignored):
            """Return a decorator that leaves the wrapped function untouched.

            Supports both spellings used with the real package:
            ``@spaces.GPU(duration=...)`` and the bare ``@spaces.GPU``.
            ``duration`` and any extra keyword arguments are accepted and
            ignored because there is no GPU scheduler off-Spaces.
            """
            # Bare form: the decorated function arrives as the first argument.
            if callable(duration):
                return duration

            def decorator(fn):
                return fn

            return decorator

    spaces = _SpacesShim()  # type: ignore
# Hugging Face Hub location of the fine-tuned checkpoint
# (read by _download_weights_hf).
HF_REPO_ID = "Harshit494/BALM-PPI"
HF_FILENAME = "best_model_fold_1.pth"
# Three-letter -> one-letter residue codes with upper-cased keys, so lookups
# can use residue.get_resname().upper().
UPPER_3TO1 = {k.upper(): v for k, v in THREE_TO_ONE.items()}
# Built-in example complexes: PDB ID, display text, interaction mode and the
# chain IDs to place on each side.
EXAMPLES = [
    {
        "pdb": "1YCR", "label": "1YCR", "subtitle": "MDM2 ↔ p53",
        "desc": "MDM2 oncoprotein bound to the p53 transactivation peptide. Chain A = MDM2, Chain B = p53. Classic drug-target complex.",
        "mode": "Protein-Protein", "chain_a": "A", "chain_b": "B",
    },
    {
        "pdb": "1BRS", "label": "1BRS", "subtitle": "Barnase ↔ Barstar",
        "desc": "Ultra-tight complex (Kd ~10⁻¹⁴ M). Chain A = Barnase, Chain D = Barstar. Classic benchmark.",
        "mode": "Protein-Protein", "chain_a": "A", "chain_b": "D",
    },
]
# CSV templates for batch input; the header names match the row keys read by
# gpu_run_batch (seq_a/seq_b, or heavy_chain/light_chain/antigen).
PPI_TEMPLATE = "seq_a,seq_b\nACDEFGHIKLMNPQRSTVWY,QWERTYIPASDFGKLCVBNM\n"
ABAG_TEMPLATE = "heavy_chain,light_chain,antigen\nEVQLVESGGG...,DIQMTQ...,IYSPT...\n"
# ═══════════════════════════════════════════════════════════════════════════
# GLOBAL MODEL CACHE
# Loaded lazily on CPU; moved to CUDA only inside @spaces.GPU functions.
# ═══════════════════════════════════════════════════════════════════════════
# Module-level cache for the loaded model and the pKd bounds it was last
# configured with; populated and updated by ensure_model().
_MODEL_STATE = {
    "model": None,
    "pkd_bounds": None,  # tuple (lo, hi) — rebuilt if bounds change
}
# ═══════════════════════════════════════════════════════════════════════════
# PDB UTILITIES
# ═══════════════════════════════════════════════════════════════════════════
# Downloaded PDB text keyed by the normalised (stripped, upper-cased) PDB ID.
_PDB_CACHE: dict = {}


def _fetch_pdb_cached(pdb_id: str) -> str:
    """Return the PDB-format text for ``pdb_id``, downloading it at most once.

    Args:
        pdb_id: A 4-character PDB identifier; surrounding whitespace and
            letter case are ignored.

    Returns:
        The text of ``https://files.rcsb.org/download/<ID>.pdb``.

    Raises:
        ValueError: If ``pdb_id`` is not 4 alphanumeric characters.
        requests.RequestException: If the download fails or returns an HTTP
            error status.
    """
    pdb_id = pdb_id.strip().upper()
    # The ID is user input that is spliced into a URL path: reject anything
    # that is not a plain 4-character identifier before building the URL.
    if not re.fullmatch(r"[A-Z0-9]{4}", pdb_id):
        raise ValueError(
            f"Invalid PDB ID {pdb_id!r}: expected 4 alphanumeric characters")
    if pdb_id in _PDB_CACHE:
        return _PDB_CACHE[pdb_id]
    url = f"https://files.rcsb.org/download/{pdb_id}.pdb"
    r = requests.get(url, timeout=30)
    r.raise_for_status()
    _PDB_CACHE[pdb_id] = r.text
    # Keep the cache from growing unbounded: dicts preserve insertion order,
    # so the first key is the oldest entry.
    if len(_PDB_CACHE) > 64:
        _PDB_CACHE.pop(next(iter(_PDB_CACHE)))
    return r.text
def filter_pdb_to_chains(pdb_text: str, keep_chains: set) -> str:
    """Drop ATOM/HETATM records whose chain ID is not in ``keep_chains``.

    The chain ID is read from column 22 of the record (index 21). Lines with
    any other record name are passed through unchanged.
    """

    def _is_kept(record: str) -> bool:
        if record[:6].strip() not in ("ATOM", "HETATM"):
            return True
        return record[21:22] in keep_chains

    return "\n".join(filter(_is_kept, pdb_text.splitlines()))
def get_chains_from_pdb(pdb_id: str):
    """Download a PDB entry and extract per-chain sequences from its first model.

    Returns:
        ``(pdb_text, chains, error)``. On success ``error`` is ``None`` and
        ``chains`` maps chain ID to ``{"seq": str, "resnos": [int, ...]}``,
        built from standard amino-acid residues only; chains with an empty
        sequence are omitted. If the download fails the result is
        ``(None, {}, message)``.
    """
    import warnings as _w

    try:
        pdb_text = _fetch_pdb_cached(pdb_id)
    except Exception as e:
        return None, {}, f"Could not fetch PDB {pdb_id}: {e}"

    with _w.catch_warnings():
        _w.simplefilter("ignore")
        structure = PDB.PDBParser(QUIET=True).get_structure(
            pdb_id, io.StringIO(pdb_text))

    collected: dict = {}
    # Only the first model is read.
    first_model = next(iter(structure), None)
    if first_model is not None:
        for chain in first_model:
            entry = collected.setdefault(chain.id, {"seq": "", "resnos": []})
            letters = []
            for residue in chain:
                if not PDB.is_aa(residue, standard=True):
                    continue
                letters.append(UPPER_3TO1.get(residue.get_resname().upper(), "X"))
                entry["resnos"].append(residue.get_id()[1])
            entry["seq"] += "".join(letters)

    non_empty = {cid: entry for cid, entry in collected.items() if entry["seq"]}
    return pdb_text, non_empty, None
def resolve_chains(chains_dict, ids_str, available):
    """Resolve a comma-separated list of chain IDs against ``chains_dict``.

    Each ID is matched exactly first, then case-insensitively; a chain only
    counts as found when its sequence is non-empty. ``available`` is accepted
    but not used.

    Returns:
        ``(joined, infos, missing)``: the matched sequences joined with "|",
        one ``{"id": ..., **chain}`` dict per match, and the requested IDs
        that could not be resolved.
    """
    if not ids_str:
        return "", [], []

    def _find_key(wanted):
        if wanted in chains_dict and chains_dict[wanted]["seq"]:
            return wanted
        folded = wanted.upper()
        for key in chains_dict:
            if key.upper() == folded and chains_dict[key]["seq"]:
                return key
        return None

    infos, missing = [], []
    for token in str(ids_str).split(","):
        wanted = token.strip()
        if not wanted:
            continue
        key = _find_key(wanted)
        if key:
            infos.append({"id": key, **chains_dict[key]})
        else:
            missing.append(wanted)
    return "|".join(chains_dict[info["id"]]["seq"] for info in infos), infos, missing
def clean_multi(s: str) -> str:
    """Normalise pasted sequence text into ``CHAIN1|CHAIN2|...`` form.

    FASTA header text (from ">" to the end of its line) is removed, every
    non-letter character is dropped, letters are upper-cased and empty chains
    are discarded. "|" separates chains.

    Args:
        s: Raw sequence text; may be empty or ``None``.

    Returns:
        The cleaned chains joined with "|", or "" when nothing remains.
    """
    if not s:
        return ""
    # Remove whole header lines BEFORE splitting on "|": headers such as
    # ">sp|P04637|P53_HUMAN" contain pipes and would otherwise be split into
    # bogus chains.
    s = re.sub(r"^[ \t]*>.*$", "", s, flags=re.MULTILINE)
    out = []
    for p in s.split("|"):
        # A ">" that follows a chain separator still starts a header.
        p = re.sub(r">.*", "", p)
        p = re.sub(r"[^A-Za-z]", "", p).upper()
        if p:
            out.append(p)
    return "|".join(out)
# ═══════════════════════════════════════════════════════════════════════════
# IG HELPERS
# ═══════════════════════════════════════════════════════════════════════════
def split_ig_by_chains(ig_flat, chain_seqs, n_sep=2):
    """Cut a flat score list into one list per chain.

    Consecutive chains are assumed to be separated by ``n_sep`` score
    positions, which are skipped. An empty or ``None`` ``ig_flat`` yields one
    empty list per chain.
    """
    if not ig_flat:
        return [[] for _ in chain_seqs]

    pieces = []
    cursor = 0
    last_index = len(chain_seqs) - 1
    for index, seq in enumerate(chain_seqs):
        stop = cursor + len(seq)
        pieces.append(list(ig_flat[cursor:stop]))
        # Skip the separator positions after every chain except the last.
        cursor = stop if index >= last_index else stop + n_sep
    return pieces
def build_ig_chain_map(chain_info_a, chain_info_b, ig_a, ig_b):
    """Map chain ID to ``{"start": first residue number, "scores": [...]}``.

    Each side's flat score list is split per chain with
    ``split_ig_by_chains``. Chains without residue numbers are left out, and
    a side with no chain info contributes nothing.
    """
    chain_map = {}
    for side_info, side_scores in ((chain_info_a, ig_a), (chain_info_b, ig_b)):
        if not side_info:
            continue
        sequences = [entry["seq"] for entry in side_info]
        per_chain = split_ig_by_chains(side_scores, sequences)
        for entry, scores in zip(side_info, per_chain):
            if entry.get("resnos"):
                chain_map[entry["id"]] = {
                    "start": entry["resnos"][0],
                    "scores": scores,
                }
    return chain_map
def flat_ig_for_display(ig_a, ig_b, chain_info_a, chain_info_b):
    """Return ``(seq_a, ig_a, seq_b, ig_b)`` with chain separators removed.

    For each side the chain sequences are concatenated. When both scores and
    chain info are present the scores are split per chain (skipping separator
    positions) and re-flattened; otherwise the scores are copied as-is, or
    ``[]`` when there are none.
    """

    def _one_side(scores, info):
        sequence = "".join(entry["seq"] for entry in info) if info else ""
        if scores and info:
            chunks = split_ig_by_chains(scores, [entry["seq"] for entry in info])
            flat = [value for chunk in chunks for value in chunk]
        elif scores:
            flat = list(scores)
        else:
            flat = []
        return sequence, flat

    seq_a, ig_a_out = _one_side(ig_a, chain_info_a)
    seq_b, ig_b_out = _one_side(ig_b, chain_info_b)
    return seq_a, ig_a_out, seq_b, ig_b_out
# ═══════════════════════════════════════════════════════════════════════════
# MODEL BUILD (CPU-safe — no CUDA touched here)
# ═══════════════════════════════════════════════════════════════════════════
def _load_esm_base():
    """Load the pretrained ESM-2 650M encoder and its tokenizer.

    Returns:
        ``(model, tokenizer)`` for the ``facebook/esm2_t33_650M_UR50D``
        checkpoint; the model weights are requested in safetensors format.
    """
    from transformers import EsmModel, EsmTokenizer

    checkpoint = "facebook/esm2_t33_650M_UR50D"
    encoder = EsmModel.from_pretrained(checkpoint, use_safetensors=True)
    tokenizer = EsmTokenizer.from_pretrained(checkpoint)
    return encoder, tokenizer
def _download_weights_hf() -> bytes:
    """Download the fine-tuned checkpoint from the Hub and return its bytes.

    The file is identified by ``HF_REPO_ID`` / ``HF_FILENAME``.
    """
    from huggingface_hub import hf_hub_download

    # `resume_download` is deprecated in huggingface_hub (downloads resume
    # automatically when possible), so it is no longer passed.
    path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME)
    with open(path, "rb") as f:
        return f.read()
def build_model(weights_bytes, pkd_bounds):
    """Assemble the BALM model (ESM-2 + LoRA adapters + projection head).

    Args:
        weights_bytes: Raw bytes of the fine-tuned checkpoint, a state dict
            readable by ``torch.load``.
        pkd_bounds: ``(lower, upper)`` pair used to rescale the cosine
            similarity into a pKd value.

    Returns:
        The model in eval mode, with the checkpoint loaded onto CPU.
    """
    import torch.nn as nn
    import torch.nn.functional as F
    from peft import LoraConfig, get_peft_model

    class BALMProjectionHead(nn.Module):
        """Project both pooled embeddings and compare them by cosine similarity."""

        def __init__(self, embedding_size, projected_size=256):
            super().__init__()
            self.protein_projection = nn.Linear(embedding_size, projected_size)
            self.proteina_projection = nn.Linear(embedding_size, projected_size)
            self.dropout = nn.Dropout(0.1)

        def forward(self, prot_emb, prot_a_emb):
            p = F.normalize(self.protein_projection(self.dropout(prot_emb)), p=2, dim=1)
            pa = F.normalize(self.proteina_projection(self.dropout(prot_a_emb)), p=2, dim=1)
            # Clamp keeps the similarity strictly inside (-1, 1).
            return torch.clamp(F.cosine_similarity(p, pa), -0.9999, 0.9999)

    class BALMForLoRAFinetuning(nn.Module):
        """ESM encoder + projection head; maps two sequences to (pKd, cosine)."""

        def __init__(self, esm_model, esm_tokenizer, pkd_bounds):
            super().__init__()
            self.esm_model = esm_model
            self.esm_tokenizer = esm_tokenizer
            self.projection_head = BALMProjectionHead(self.esm_model.config.hidden_size)
            self.pkd_lower, self.pkd_upper = pkd_bounds
            self.cls_token = self.esm_tokenizer.cls_token

        def _get_esm_embeddings(self, sequences):
            # Multi-chain inputs use "|" between chains; it is replaced by two
            # CLS tokens before tokenisation.
            processed = [s.replace("|", f"{self.cls_token}{self.cls_token}") for s in sequences]
            inputs = self.esm_tokenizer(processed, return_tensors="pt", padding=True,
                                        truncation=True, max_length=1024)
            inputs = {k: v.to(self.esm_model.device) for k, v in inputs.items()}
            h = self.esm_model(**inputs).last_hidden_state
            # Attention-mask-weighted mean over the token axis.
            mask = inputs["attention_mask"].unsqueeze(-1).expand(h.size()).float()
            return (torch.sum(h * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)).float()

        def forward(self, seq_a, seq_b):
            ea = self._get_esm_embeddings([seq_a])
            eb = self._get_esm_embeddings([seq_b])
            cos = self.projection_head(ea, eb)
            # Affine map from cosine in [-1, 1] to [pkd_lower, pkd_upper].
            pkd = ((cos + 1) / 2) * (self.pkd_upper - self.pkd_lower) + self.pkd_lower
            return pkd, cos

    base_esm, tok = _load_esm_base()
    lora_cfg = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1,
                          target_modules=["key", "query", "value"])
    model = BALMForLoRAFinetuning(
        get_peft_model(base_esm, lora_cfg), tok, pkd_bounds)

    with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as tmp:
        tmp.write(weights_bytes)
        tmp_path = tmp.name
    try:
        # NOTE(review): torch.load unpickles the file, so the checkpoint must
        # come from a trusted source; consider passing weights_only=True.
        # strict=False ignores missing/unexpected keys without raising.
        model.load_state_dict(torch.load(tmp_path, map_location="cpu"), strict=False)
    finally:
        # Always remove the temporary copy, even when loading fails.
        os.unlink(tmp_path)
    model.eval()
    return model
def ensure_model(pkd_lo: float, pkd_hi: float):
    """Return the cached model, building it on first use.

    On the first call the weights are downloaded and the model is built with
    the requested bounds. On later calls with different bounds only the
    model's ``pkd_lower`` / ``pkd_upper`` attributes are updated.
    """
    wanted = (float(pkd_lo), float(pkd_hi))
    model = _MODEL_STATE["model"]
    if model is None:
        model = build_model(_download_weights_hf(), wanted)
        _MODEL_STATE["model"] = model
        _MODEL_STATE["pkd_bounds"] = wanted
    elif _MODEL_STATE["pkd_bounds"] != wanted:
        # Same weights; only the affine output range changes.
        model.pkd_lower, model.pkd_upper = wanted
        _MODEL_STATE["pkd_bounds"] = wanted
    return model
# ═══════════════════════════════════════════════════════════════════════════
# INTEGRATED GRADIENTS (float32-safe, 15 Riemann steps)
# ═══════════════════════════════════════════════════════════════════════════
def compute_ig(model, seq_a: str, seq_b: str, steps: int = 15):
    """Integrated-gradients attribution of the cosine score to each token.

    Each sequence is attributed in turn while the other is held fixed at its
    real embedding. The baseline is an all-zero embedding and the path
    integral is approximated with ``steps`` evenly spaced points in [0, 1].
    Per-token scores are the absolute value of ``avg_grad * (input -
    baseline)`` summed over the embedding dimension; the first and last token
    positions are dropped and the rest min-max normalised to [0, 1].

    Args:
        model: Model exposing ``esm_tokenizer``, ``esm_model`` and
            ``projection_head``.
        seq_a: First sequence; "|" separates chains.
        seq_b: Second sequence; "|" separates chains.
        steps: Number of interpolation points.

    Returns:
        ``(scores_a, scores_b)`` as lists of floats.
    """
    device = next(model.parameters()).device
    tok = model.esm_tokenizer
    cls_tok = tok.cls_token
    esm = model.esm_model

    def tokenise(seq):
        proc = seq.replace("|", f"{cls_tok}{cls_tok}")
        return tok(proc, return_tensors="pt",
                   padding=False, truncation=True, max_length=1024).to(device)

    word_embed = esm.base_model.model.embeddings.word_embeddings
    enc_a = tokenise(seq_a)
    enc_b = tokenise(seq_b)
    emb_a = word_embed(enc_a["input_ids"]).detach().float().clone()
    emb_b = word_embed(enc_b["input_ids"]).detach().float().clone()
    mask_a = enc_a["attention_mask"].float()
    mask_b = enc_b["attention_mask"].float()

    def encode(embs, mask):
        # Run the encoder directly on (interpolated) word embeddings and
        # mean-pool the hidden states under the attention mask.
        int_mask = mask.long()
        ext = esm.base_model.model.get_extended_attention_mask(int_mask, embs.shape[:2])
        h = esm.base_model.model.encoder(
            embs.float(), attention_mask=ext
        ).last_hidden_state.float()
        m = mask.unsqueeze(-1).expand(h.size()).float()
        return (torch.sum(h * m, 1) / torch.clamp(m.sum(1), min=1e-9))

    def fwd(e_a, e_b):
        return model.projection_head(encode(e_a, mask_a), encode(e_b, mask_b))

    def riemann(tgt, baseline, fixed, tgt_is_b):
        grads = []
        for alpha in torch.linspace(0, 1, steps, device=device):
            interp = (baseline + alpha * (tgt - baseline)).detach().requires_grad_(True)
            out = fwd(fixed, interp) if tgt_is_b else fwd(interp, fixed)
            # Differentiate w.r.t. the interpolated input only. backward()
            # would also accumulate .grad on every trainable model parameter
            # at each step, which is never used or cleared.
            (grad,) = torch.autograd.grad(out.sum(), interp)
            grads.append(grad.detach().float().clone())
        avg_grad = torch.stack(grads).mean(0)
        ig = (avg_grad * (tgt - baseline)).abs().sum(-1).squeeze(0)
        return ig.cpu().numpy()

    baseline_a = torch.zeros_like(emb_a)
    baseline_b = torch.zeros_like(emb_b)
    attr_a = riemann(emb_a, baseline_a, emb_b, False)
    attr_b = riemann(emb_b, baseline_b, emb_a, True)

    def norm(a):
        # Nothing to normalise (min()/max() raise on an empty array).
        if len(a) == 0:
            return []
        lo, hi = a.min(), a.max()
        # A flat profile carries no ranking information.
        if hi - lo < 1e-9:
            return [0.0] * len(a)
        return ((a - lo) / (hi - lo)).tolist()

    # [1:-1] drops the first and last token positions.
    return norm(attr_a[1:-1]), norm(attr_b[1:-1])
# ═══════════════════════════════════════════════════════════════════════════
# GPU-WRAPPED PREDICTION
# This is the only function that touches CUDA. ZeroGPU will allocate an A100
# for the duration of the call.
# ═══════════════════════════════════════════════════════════════════════════
@spaces.GPU(duration=120)
def gpu_run_prediction(seq_a, seq_b, run_ig, pkd_lo, pkd_hi):
    """Predict pKd for one sequence pair, optionally with IG attributions.

    Args:
        seq_a: First sequence ("|" separates chains).
        seq_b: Second sequence ("|" separates chains).
        run_ig: When true, also compute integrated-gradients scores.
        pkd_lo: Lower bound of the pKd output range.
        pkd_hi: Upper bound of the pKd output range.

    Returns:
        ``(pkd, cos, ig_a, ig_b, device)``; ``ig_a`` / ``ig_b`` are ``None``
        when ``run_ig`` is false and ``device`` is the device name used.
    """
    model = ensure_model(pkd_lo, pkd_hi)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        model.to(device)
        with torch.no_grad():
            pkd_t, cos_t = model(seq_a, seq_b)
        pkd = float(pkd_t.item())
        cos = float(cos_t.item())
        ig_a, ig_b = None, None
        if run_ig:
            ig_a, ig_b = compute_ig(model, seq_a, seq_b, steps=15)
    finally:
        # Always move the cached model back to CPU, including when inference
        # raises, so it is never left on the GPU after this call.
        model.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return pkd, cos, ig_a, ig_b, str(device)
@spaces.GPU(duration=300)
def gpu_run_batch(rows, batch_mode, run_ig, pkd_lo, pkd_hi):
    """Run predictions for a list of rows and return one result dict per row.

    Args:
        rows: List of dicts with ``seq_a``/``seq_b`` (Protein-Protein mode) or
            ``heavy_chain``/``light_chain``/``antigen`` (any other mode).
        batch_mode: ``"Protein-Protein"`` selects the seq_a/seq_b columns.
        run_ig: When true, add the five highest-scoring residues per side.
        pkd_lo: Lower bound of the pKd output range.
        pkd_hi: Upper bound of the pKd output range.

    Returns:
        A list aligned with ``rows``. Each item has ``pKd`` and ``Cosine``
        (``None`` when a side is empty or the row failed), plus
        ``top5_Target`` / ``top5_proteina`` when ``run_ig`` is set, or
        ``error`` when the row raised.
    """

    def _cell(row, key):
        # A missing cell may be None or float NaN (presumably when rows come
        # from a pandas DataFrame — confirm against caller). str(NaN) would
        # otherwise be cleaned to the bogus sequence "NAN".
        value = row.get(key, "")
        if value is None or (isinstance(value, float) and value != value):
            return ""
        return clean_multi(str(value))

    def _top5(seq, scores):
        # "<residue><1-based position>:<score>" for the five highest scores.
        count = min(len(seq), len(scores))
        best = sorted(range(count), key=lambda j: -scores[j])[:5]
        return "|".join(f"{seq[j]}{j+1}:{scores[j]:.3f}" for j in best)

    model = ensure_model(pkd_lo, pkd_hi)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    results = []
    try:
        model.to(device)
        for row in rows:
            if batch_mode == "Protein-Protein":
                sa = _cell(row, "seq_a")
                sb = _cell(row, "seq_b")
            else:
                heavy = _cell(row, "heavy_chain")
                light = _cell(row, "light_chain")
                sa = _cell(row, "antigen")
                sb = f"{heavy}|{light}" if (heavy or light) else ""
            if not sa or not sb:
                results.append({"pKd": None, "Cosine": None})
                continue
            try:
                with torch.no_grad():
                    pkd_t, cos_t = model(sa, sb)
                rr = {"pKd": round(float(pkd_t.item()), 4),
                      "Cosine": round(float(cos_t.item()), 6)}
                if run_ig:
                    ia, ib = compute_ig(model, sa, sb, steps=15)
                    rr["top5_Target"] = _top5(sa.replace("|", ""), ia)
                    rr["top5_proteina"] = _top5(sb.replace("|", ""), ib)
            except Exception as e:
                # Best effort: one bad row must not abort the whole batch.
                rr = {"pKd": None, "Cosine": None, "error": str(e)}
            results.append(rr)
    finally:
        # Always move the cached model back to CPU, even on failure.
        model.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return results
# ═══════════════════════════════════════════════════════════════════════════
# NGL VIEWER (theme-reactive, filtered to selected chains)
# ═══════════════════════════════════════════════════════════════════════════
def ngl_viewer_html(pdb_content, ig_chain_map=None, height=420) -> str:
    """Build the HTML for the 3D structure viewer.

    Args:
        pdb_content: PDB-format text to display. When empty or ``None`` a
            placeholder ``<div>`` is returned instead of a viewer page.
        ig_chain_map: Optional ``{chain_id: {"start": int, "scores": [...]}}``.
            When given, atoms are coloured by ``scores[resno - start]``, so
            scores are assumed to follow consecutive residue numbers;
            otherwise the structure is coloured by chain.
        height: Height in pixels of the viewport (and of the placeholder).

    Returns:
        A complete HTML document that loads NGL from a CDN, or the
        placeholder markup. The page follows ``prefers-color-scheme`` and
        also switches theme on a ``{balmTheme: 'dark'|'light'}`` window
        message.
    """
    if not pdb_content:
        return (
            '<div style="display:flex;flex-direction:column;align-items:center;'
            f'justify-content:center;height:{height}px;background:var(--bg2,#f1f4f9);'
            'border:1px solid var(--border,#dde3ec);border-radius:12px;'
            'color:var(--text2,#64748b);font-family:JetBrains Mono,monospace;'
            'font-size:.8rem;text-align:center;line-height:2">'
            '<div style="font-size:2.2rem">🧬</div>'
            '<div style="margin-top:12px;font-weight:600">Load an example or fetch a PDB</div>'
            '<div style="font-size:.72rem;margin-top:4px">Only selected chains will be shown</div>'
            '</div>'
        )
    # Both payloads are embedded inside a <script> element, so "</" is written
    # as "<\/" (the same string to JavaScript) to keep embedded text from
    # closing the element early.
    ig_json = (json.dumps(ig_chain_map) if ig_chain_map else "null").replace("</", "<\\/")
    # The PDB text goes into a JS template literal: escape backslash first,
    # then the backtick and "$" that are special there.
    escaped = (pdb_content.replace("\\", "\\\\").replace("`", "\\`")
               .replace("$", "\\$").replace("</", "<\\/"))
    return f"""<!DOCTYPE html><html><head>
<meta charset="utf-8">
<script src="https://cdn.jsdelivr.net/npm/ngl@2.0.0-dev.37/dist/ngl.js"></script>
<style>
*{{box-sizing:border-box;margin:0;padding:0}}
body{{overflow:hidden;font-family:'JetBrains Mono',monospace;transition:background .35s}}
#vp{{width:100%;height:{height}px}}
#hint{{position:absolute;top:10px;left:12px;font-size:10px;pointer-events:none;letter-spacing:.04em;transition:color .3s}}
#ctrl{{position:absolute;bottom:12px;right:12px;display:flex;flex-direction:column;gap:4px}}
.cb{{padding:4px 11px;font-family:'JetBrains Mono',monospace;font-size:9px;font-weight:500;
border-radius:5px;cursor:pointer;letter-spacing:.1em;text-transform:uppercase;transition:all .15s}}
#leg{{position:absolute;bottom:12px;left:12px;font-size:9px;letter-spacing:.05em;
border-radius:5px;padding:5px 10px;line-height:1.9;transition:all .3s}}
</style></head><body>
<div id="vp"></div>
<div id="hint">🧬 Drag · Scroll · Right-drag pan</div>
<div id="ctrl">
<button class="cb" onclick="setRep('cartoon')">Cartoon</button>
<button class="cb" onclick="setRep('surface')">Surface</button>
<button class="cb" onclick="setRep('ball+stick')">Ball+Stick</button>
<button class="cb" onclick="stage.autoView()">Reset</button>
</div>
<div id="leg"></div>
<script>
var T={{
light:{{bg:'#f0f4fa',hint:'rgba(51,65,85,.5)',
btn:'rgba(240,244,250,.93)',btn_b:'rgba(37,99,235,.16)',btn_c:'rgba(71,85,105,.6)',
btn_hc:'#2563eb',btn_hb:'rgba(37,99,235,.55)',btn_hbg:'rgba(37,99,235,.06)',
leg:'rgba(240,244,250,.93)',leg_b:'rgba(37,99,235,.13)',leg_c:'rgba(71,85,105,.65)',
ig_hi:'#1d4ed8',ig_lo:'#c7d7f0',null_col:0xe8ecf4}},
dark:{{bg:'#06090f',hint:'rgba(168,184,216,.4)',
btn:'rgba(6,9,15,.88)',btn_b:'rgba(96,165,250,.16)',btn_c:'rgba(168,184,216,.55)',
btn_hc:'#60a5fa',btn_hb:'rgba(96,165,250,.55)',btn_hbg:'rgba(96,165,250,.06)',
leg:'rgba(6,9,15,.88)',leg_b:'rgba(96,165,250,.13)',leg_c:'rgba(168,184,216,.55)',
ig_hi:'#60a5fa',ig_lo:'#172135',null_col:0x172135}}
}};
function isDark(){{return window.matchMedia('(prefers-color-scheme:dark)').matches;}}
var theme=isDark()?T.dark:T.light,stage=null,comp=null,curRep='cartoon',igData={ig_json};
function igColor(s,dark){{
s=Math.max(0,Math.min(1,s||0));
if(dark) return(Math.round(23+s*73)<<16)|(Math.round(33+s*132)<<8)|Math.round(53+s*197);
return(Math.round(199-s*170)<<16)|(Math.round(215-s*137)<<8)|Math.round(240-s*24);
}}
function styleUI(dark){{
var Th=dark?T.dark:T.light; theme=Th;
document.body.style.background=Th.bg;
document.getElementById('hint').style.color=Th.hint;
document.querySelectorAll('.cb').forEach(function(b){{
b.style.background=Th.btn;b.style.border='1px solid '+Th.btn_b;b.style.color=Th.btn_c;
b.onmouseenter=function(){{b.style.borderColor=Th.btn_hb;b.style.color=Th.btn_hc;b.style.background=Th.btn_hbg;}};
b.onmouseleave=function(){{b.style.borderColor=Th.btn_b;b.style.color=Th.btn_c;b.style.background=Th.btn;}};
}});
var leg=document.getElementById('leg');
leg.style.background=Th.leg;leg.style.border='1px solid '+Th.leg_b;leg.style.color=Th.leg_c;
if(stage)stage.setParameters({{backgroundColor:Th.bg}});
}}
function addRepr(c){{
comp=c;comp.removeAllRepresentations();
var dark=isDark(),Th=dark?T.dark:T.light,leg=document.getElementById('leg');
if(igData){{
var sid=NGL.ColormakerRegistry.addScheme(function(){{
this.atomColor=function(atom){{
var cd=igData[atom.chainname];
if(!cd||!cd.scores||!cd.scores.length) return Th.null_col;
var i=atom.resno-cd.start;
return(i<0||i>=cd.scores.length)?Th.null_col:igColor(cd.scores[i],dark);
}};
}});
comp.addRepresentation(curRep,{{color:sid}});
leg.innerHTML='<span style="color:'+Th.ig_hi+'">■</span> High IG &nbsp;&nbsp;<span style="color:'+Th.ig_lo+'">■</span> Low IG';
}}else{{
comp.addRepresentation(curRep,{{colorScheme:'chainname'}});
leg.innerHTML='Coloured by chain';
}}
stage.autoView();
}}
function setRep(r){{curRep=r;if(comp){{comp.removeAllRepresentations();addRepr(comp);}}}}
function themeSwitch(dark){{styleUI(dark);if(comp){{comp.removeAllRepresentations();addRepr(comp);}}}}
window.matchMedia('(prefers-color-scheme:dark)').addEventListener('change',function(e){{themeSwitch(e.matches);}});
window.addEventListener('message',function(e){{if(e.data&&e.data.balmTheme)themeSwitch(e.data.balmTheme==='dark');}});
window.addEventListener('resize',function(){{if(stage)stage.handleResize();}});
stage=new NGL.Stage('vp',{{backgroundColor:theme.bg,quality:'medium',tooltip:true}});
styleUI(isDark());
var blob=new Blob([`{escaped}`],{{type:'text/plain'}});
stage.loadFile(URL.createObjectURL(blob),{{ext:'pdb',name:'s'}}).then(addRepr);
</script></body></html>"""
def wrap_iframe(html_src: str, height: int = 425) -> str:
    """Wrap a full HTML document in a sandboxed ``<iframe srcdoc=...>``.

    Args:
        html_src: The document to embed.
        height: Height of the iframe in pixels.

    Returns:
        The ``<iframe>`` markup as a string.
    """
    import html

    # Escape "&" as well as quotes: otherwise entities already present in the
    # document (e.g. "&quot;" or "&amp;") are decoded when the attribute is
    # parsed, so the embedded page differs from html_src.
    srcdoc = html.escape(html_src, quote=True)
    return (
        f'<iframe srcdoc="{srcdoc}" '
        f'style="width:100%;height:{height}px;border:0;border-radius:12px" '
        f'sandbox="allow-scripts allow-same-origin"></iframe>'
    )
# ═══════════════════════════════════════════════════════════════════════════
# PLOTLY
# ═══════════════════════════════════════════════════════════════════════════
# Shared Plotly styling, unpacked into the figures built below:
# background colours, tick/label font, and grid/zero-line colours.
_PL = dict(paper_bgcolor="rgba(255,255,255,.97)", plot_bgcolor="rgba(241,244,249,1)")
_FONT = dict(family="JetBrains Mono, monospace", color="#475569", size=9)
_GRID = dict(gridcolor="rgba(37,99,235,.07)", zerolinecolor="rgba(37,99,235,.1)")
def make_heatmap(seq, attr, title):
    """Plot per-residue scores as a heatmap wrapped at 50 residues per row.

    Args:
        seq: Residue letters (one per position).
        attr: Scores aligned with ``seq``; may be shorter, in which case the
            unscored residues are drawn as empty cells.
        title: Figure title.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    W = 50
    # A cell carries a score only where both a residue and a score exist. The
    # score list can be shorter than the sequence, and indexing it by
    # sequence length alone raised IndexError.
    n_scored = min(len(seq), len(attr))
    rows, hover, ylabels = [], [], []
    for r in range(0, len(seq), W):
        row_scores, row_hover = [], []
        for i in range(r, r + W):
            if i < n_scored:
                row_scores.append(attr[i])
                row_hover.append(f"{seq[i]}{i+1} IG:{attr[i]:.3f}")
            elif i < len(seq):
                row_scores.append(None)
                row_hover.append(f"{seq[i]}{i+1}")
            else:
                # Padding past the end of the sequence in the last row.
                row_scores.append(None)
                row_hover.append("")
        rows.append(row_scores)
        hover.append(row_hover)
        ylabels.append(f"{r+1}–{min(r+W, len(seq))}")
    fig = go.Figure(go.Heatmap(
        z=rows, text=hover, hoverinfo="text",
        colorscale=[[0, "#e8f0fe"], [.4, "#93c5fd"], [.75, "#3b82f6"], [1, "#1d4ed8"]],
        zmin=0, zmax=1, xgap=1, ygap=1,
        colorbar=dict(thickness=10, len=.9, tickfont=dict(**_FONT),
                      title=dict(text="IG", font=dict(**_FONT)))))
    fig.update_layout(
        title=dict(text=title, font=dict(family="JetBrains Mono, monospace", size=10, color="#1d4ed8")),
        **_PL, height=240, margin=dict(t=30, b=24, l=55, r=45),
        xaxis=dict(title=dict(text="Position in window", font=dict(**_FONT)),
                   tickfont=dict(**_FONT), **_GRID),
        yaxis=dict(tickvals=list(range(len(ylabels))), ticktext=ylabels,
                   tickfont=dict(**_FONT), **_GRID))
    return fig
def top10_bar(seq, ig, title):
    """Horizontal bar chart of the ten highest-scoring residues.

    Bars are labelled ``<residue><1-based position>`` and ordered from the
    highest score downwards; only positions that have both a residue and a
    score are considered.
    """
    ranked = sorted(
        ((aa, pos, score) for pos, (aa, score) in enumerate(zip(seq, ig), start=1)),
        key=lambda item: -item[2],
    )[:10]
    scores = [score for _, _, score in ranked]
    labels = [f"{aa}{pos}" for aa, pos, _ in ranked]

    bar = go.Bar(
        x=scores,
        y=labels,
        orientation="h",
        marker=dict(color=scores,
                    colorscale=[[0, "#dbeafe"], [.5, "#3b82f6"], [1, "#1e40af"]],
                    showscale=False, line_width=0))
    fig = go.Figure(bar)
    title_font = dict(family="JetBrains Mono, monospace", size=10, color="#1d4ed8")
    fig.update_layout(
        title=dict(text=title, font=title_font),
        **_PL, height=260, margin=dict(t=30, b=18, l=52, r=16),
        xaxis=dict(title=dict(text="IG Score", font=dict(**_FONT)),
                   tickfont=dict(**_FONT), **_GRID),
        yaxis=dict(tickfont=dict(**_FONT), autorange="reversed"))
    return fig
def residue_strip_html(seq, attr) -> str:
    """Render the sequence as coloured residue chips with score tooltips.

    A residue without a matching entry in ``attr`` is drawn with score 0.0.
    The chip background is interpolated from the score and the text colour
    switches when the score is above 0.5.
    """

    def _chip(position, aa):
        score = attr[position] if position < len(attr) else 0.0
        red = int(248 - score * 211)
        green = int(250 - score * 151)
        blue = int(252 - score * 17)
        text_colour = "#1e40af" if score > 0.5 else "#475569"
        return "".join((
            f'<span style="background:rgb({red},{green},{blue});color:{text_colour};',
            'font-family:JetBrains Mono,monospace;font-size:11px;font-weight:600;',
            'padding:2px 4px;border-radius:4px;cursor:default;display:inline-block;margin:1px;',
            f'border:1px solid rgba(37,99,235,.09)" title="{aa}{position+1} IG:{score:.3f}">{aa}</span>',
        ))

    heading = (
        '<div style="font-size:9px;font-weight:700;letter-spacing:.18em;'
        'text-transform:uppercase;color:#2563eb;margin-bottom:8px;'
        'font-family:JetBrains Mono,monospace">'
        'Residue Attribution (hover for score)</div>'
    )
    chips = "".join(_chip(position, aa) for position, aa in enumerate(seq))
    return heading + '<div style="line-height:2.2;word-break:break-all">' + chips + "</div>"
# ═══════════════════════════════════════════════════════════════════════════
# RESULT FORMATTING
# ═══════════════════════════════════════════════════════════════════════════
def result_card_html(pkd: float, cos: float, pkd_lo: float, pkd_hi: float) -> str:
    """Render the three-card result row: pKd, cosine similarity, strength bar.

    The bar fill is the position of ``pkd`` inside ``[pkd_lo, pkd_hi]`` as a
    percentage clamped to 0-100. The strength label is Weak below 30,
    Moderate below 55, Strong below 75 and Very Strong otherwise.
    """
    span = max(pkd_hi - pkd_lo, 1e-9)
    pct = max(0.0, min(100.0, ((pkd - pkd_lo) / span) * 100))
    if pct < 30:
        strength, badge_cls = "Weak", "badge-weak"
    elif pct < 55:
        strength, badge_cls = "Moderate", "badge-moderate"
    elif pct < 75:
        strength, badge_cls = "Strong", "badge-strong"
    else:
        # "Very Strong" shares the strong badge colour.
        strength, badge_cls = "Very Strong", "badge-strong"
    return f"""
<div style="display:grid;grid-template-columns:1fr 1fr 2fr;gap:14px;align-items:stretch">
<div class="pkd-card">
<div class="pkd-lbl">Predicted pKd</div>
<div class="pkd-val">{pkd:.3f}</div>
</div>
<div class="pkd-card">
<div class="pkd-lbl">Cosine Similarity</div>
<div class="pkd-val" style="font-size:1.55rem">{cos:.4f}</div>
</div>
<div class="pkd-card">
<div class="str-labels">
<span>Weak ({pkd_lo:.0f})</span>
<span style="color:var(--text0,#0f172a);font-weight:600">{strength}</span>
<span>Strong ({pkd_hi:.0f})</span>
</div>
<div class="str-bar"><div class="str-fill" style="width:{pct:.1f}%"></div></div>
<span class="pkd-badge {badge_cls}">{strength}</span>
</div>
</div>
"""
def status_badge_html(model_loaded: bool, device: str = "cpu") -> str:
    """Return the model-status badge markup.

    Args:
        model_loaded: Whether the model is loaded.
        device: Device name shown in the "READY" badge.

    Returns:
        A "READY · <device>" badge when loaded, otherwise a "NOT LOADED"
        badge.
    """
    # The separator and the idle marker were mis-encoded ("Β·", "β—‹");
    # they are restored to the intended middle dot and hollow circle.
    if model_loaded:
        return (f'<div class="ready-badge"><span class="ready-dot"></span>'
                f'READY &nbsp;·&nbsp; {device}</div>')
    return '<div class="idle-badge">○ &nbsp;NOT LOADED</div>'
def sidebar_result_html(pkd: float, cos: float) -> str:
return (
f'<div class="pkd-card"><div class="pkd-lbl">Predicted pKd</div>'
f'<div class="pkd-val">{pkd:.3f}</div>'
f'<div class="pkd-lbl" style="margin-top:10px">Cosine Similarity</div>'
f'<div style="font-family:var(--mono,monospace);font-size:1rem;font-weight:700;'
f'color:var(--text0,#0f172a)">{cos:.4f}</div></div>'
)
# ═══════════════════════════════════════════════════════════════════════════
# CSS (adapted for Gradio's DOM; theme vars + custom classes preserved)
# ═══════════════════════════════════════════════════════════════════════════
# Stylesheet text for the app. It declares the colour/font custom properties
# (light values in :root, dark overrides under prefers-color-scheme:dark) and
# the classes used by the HTML builders in this file (.pkd-card, .pkd-lbl,
# .pkd-val, .pkd-badge, .str-bar, .str-fill, .str-labels, .ready-badge,
# .idle-badge, ...).
CSS = """
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;500;600;700&family=Inter:wght@300;400;500;600;700&display=swap');
:root {
--bg0:#ffffff; --bg1:#f8f9fc; --bg2:#f1f4f9; --bg3:#e4e9f2;
--border:rgba(37,99,235,0.13); --shadow:rgba(37,99,235,0.07);
--accent:#2563eb; --accent2:#7c3aed;
--green:#16a34a; --amber:#d97706; --red:#dc2626;
--text0:#0f172a; --text1:#334155; --text2:#64748b; --text3:#94a3b8;
--mono:'JetBrains Mono',monospace; --sans:'Inter',sans-serif;
}
@media (prefers-color-scheme:dark){
:root {
--bg0:#06090f; --bg1:#0b1120; --bg2:#101828; --bg3:#172135;
--border:rgba(96,165,250,0.13); --shadow:rgba(0,0,0,0.35);
--accent:#60a5fa; --accent2:#a78bfa;
--green:#34d399; --amber:#fbbf24; --red:#f87171;
--text0:#f1f5f9; --text1:#cbd5e1; --text2:#64748b; --text3:#334155;
}
}
.gradio-container { font-family: var(--sans) !important; max-width: 1400px !important; }
.gradio-container h1, .gradio-container h2, .gradio-container h3 {
font-family: var(--sans) !important; color: var(--text0) !important;
}
code, pre {
font-family: var(--mono) !important; font-size: .82rem !important;
background: var(--bg2) !important; border-radius: 4px !important;
color: var(--accent) !important;
}
.app-header{display:flex;align-items:center;gap:14px;padding:4px 0 14px;}
.app-logo{width:40px;height:40px;border-radius:10px;
background:linear-gradient(135deg,var(--accent),var(--accent2));
display:flex;align-items:center;justify-content:center;
font-size:20px;flex-shrink:0;box-shadow:0 2px 12px rgba(37,99,235,.22);}
.app-title{font-family:var(--mono)!important;font-size:1.22rem!important;
font-weight:700!important;color:var(--text0)!important;letter-spacing:-.02em;}
.app-subtitle{font-size:.77rem!important;color:var(--text2)!important;margin-top:1px;letter-spacing:.03em;}
.sec-hdr{font-family:var(--mono);font-size:.65rem;font-weight:700;letter-spacing:.2em;
text-transform:uppercase;color:var(--accent);margin:14px 0 8px;
display:flex;align-items:center;gap:8px;}
.sec-hdr::after{content:'';flex:1;height:1px;background:var(--border);}
.ex-card{background:var(--bg1);border:1px solid var(--border);border-radius:10px;
padding:12px 15px;margin-bottom:8px;position:relative;overflow:hidden;
transition:box-shadow .15s,border-color .15s;}
.ex-card::before{content:'';position:absolute;left:0;top:0;bottom:0;width:3px;
background:linear-gradient(180deg,var(--accent),var(--accent2));border-radius:3px 0 0 3px;}
.ex-card:hover{border-color:rgba(37,99,235,.3);box-shadow:0 2px 12px var(--shadow);}
.ex-pdb{font-family:var(--mono);font-size:.95rem;font-weight:700;color:var(--accent);}
.ex-sub{font-size:.8rem;font-weight:600;color:var(--text1);margin:2px 0;}
.ex-desc{font-size:.72rem;line-height:1.5;color:var(--text2);margin-top:3px;}
.pkd-card{background:linear-gradient(135deg,rgba(37,99,235,.05),rgba(124,58,237,.04));
border:1px solid var(--border);border-radius:12px;padding:14px 18px;
box-shadow:0 1px 6px var(--shadow);}
.pkd-lbl{font-family:var(--mono);font-size:.67rem;color:var(--text2);
text-transform:uppercase;letter-spacing:.12em;margin-bottom:2px;}
.pkd-val{font-family:var(--mono);font-size:2rem;font-weight:700;color:var(--accent);line-height:1.1;}
.pkd-badge{display:inline-block;padding:2px 9px;border-radius:20px;
font-size:.7rem;font-weight:600;font-family:var(--mono);margin-top:5px;}
.badge-weak{background:rgba(220,38,38,.08);color:var(--red);border:1px solid rgba(220,38,38,.22);}
.badge-moderate{background:rgba(217,119,6,.08);color:var(--amber);border:1px solid rgba(217,119,6,.22);}
.badge-strong{background:rgba(22,163,74,.08);color:var(--green);border:1px solid rgba(22,163,74,.22);}
.str-bar{height:5px;border-radius:3px;background:var(--bg3);overflow:hidden;margin:8px 0 3px;}
.str-fill{height:100%;border-radius:3px;
background:linear-gradient(90deg,var(--accent),var(--accent2));
transition:width .7s cubic-bezier(.4,0,.2,1);}
.str-labels{display:flex;justify-content:space-between;font-size:.65rem;
color:var(--text2);font-family:var(--mono);}
.ready-badge{display:inline-flex;align-items:center;gap:6px;padding:3px 10px;
border-radius:20px;background:rgba(22,163,74,.07);
border:1px solid rgba(22,163,74,.22);
font-family:var(--mono);font-size:.72rem;color:var(--green);font-weight:600;}
.ready-dot{display:inline-block;width:6px;height:6px;border-radius:50%;
background:var(--green);animation:pulse 2s infinite;}
.idle-badge{display:inline-flex;align-items:center;gap:6px;padding:3px 10px;
border-radius:20px;background:rgba(100,116,139,.07);
border:1px solid rgba(100,116,139,.18);
font-family:var(--mono);font-size:.72rem;color:var(--text2);font-weight:600;}
@keyframes pulse{0%,100%{opacity:1;transform:scale(1)}50%{opacity:.5;transform:scale(.85)}}
textarea, input[type="text"], input[type="number"] {
font-family: var(--mono) !important; font-size: .82rem !important;
}
"""
# ═══════════════════════════════════════════════════════════════════════════
# GRADIO EVENT HANDLERS
# ═══════════════════════════════════════════════════════════════════════════
def handle_mode_change(mode):
    """Return visibility updates for the mode-specific input panels.

    The four updates are, in order: PPI panel, Ab-Ag panel, PPI chain inputs,
    Ab-Ag chain inputs. The PPI pair is visible for "Protein-Protein" and the
    Ab-Ag pair for any other mode.
    """
    ppi_selected = mode == "Protein-Protein"
    visibility = (ppi_selected, not ppi_selected, ppi_selected, not ppi_selected)
    return tuple(gr.update(visible=flag) for flag in visibility)
def handle_load_example(pdb_id, mode, chain_a, chain_b, chain_h=None, chain_l=None, chain_ag=None):
    """Fetch a PDB entry and build the UI updates for the selected chains.

    Args:
        pdb_id: PDB identifier to fetch.
        mode: ``"Protein-Protein"`` uses ``chain_a`` / ``chain_b``; any other
            mode uses the antibody-antigen inputs.
        chain_a: Comma-separated chain IDs for side A (default: first chain).
        chain_b: Comma-separated chain IDs for side B (default: second chain,
            or the first when there is only one).
        chain_h: Heavy-chain ID (default "H").
        chain_l: Light-chain ID (default "L").
        chain_ag: Comma-separated antigen chain IDs (default: the first chain
            that is neither the heavy nor the light chain).

    Returns:
        A 21-tuple, in order: ppi_a, ppi_b, ab_h, ab_l, ab_ag,
        state_chain_info_a, state_chain_info_b, state_pdb_content,
        state_ig_chain_map, state_ig_a, state_ig_b, state_result, ngl_html,
        result_html, strip_a_html, bar_a, hm_a, strip_b_html, bar_b, hm_b,
        pdb_info_md. Previous prediction results are always cleared. On a
        fetch failure a warning is raised and the sequence boxes are left
        unchanged.
    """
    pdb_content, chains, err = get_chains_from_pdb(pdb_id)
    if err or not chains:
        gr.Warning(err or f"No chains found in {pdb_id}")
        return (
            gr.update(),  # ppi_a
            gr.update(),  # ppi_b
            gr.update(),  # ab_h
            gr.update(),  # ab_l
            gr.update(),  # ab_ag
            [],  # state_chain_info_a
            [],  # state_chain_info_b
            "",  # state_pdb_content
            None,  # state_ig_chain_map
            None, None,  # state_ig_a, state_ig_b
            None,  # state_result
            gr.update(value=ngl_viewer_html(None)),  # ngl_html
            gr.update(value=""),  # result_html
            gr.update(value=""),  # strip_a_html
            None,  # bar_a
            None,  # hm_a
            gr.update(value=""),  # strip_b_html
            None,  # bar_b
            None,  # hm_b
            gr.update(value=""),  # pdb_info_md
        )

    available = sorted(chains.keys())
    # The check mark and dash were mis-encoded ("βœ…", "β€”") and are restored.
    info_msg = f"✅ **{pdb_id.upper()}** — chains: **{', '.join(available)}**"
    selected_chain_ids: set = set()

    if mode == "Protein-Protein":
        id_a = chain_a or available[0]
        id_b = chain_b or (available[1] if len(available) > 1 else available[0])
        seq_a, info_a, miss_a = resolve_chains(chains, id_a, available)
        seq_b, info_b, miss_b = resolve_chains(chains, id_b, available)
        for m in miss_a:
            gr.Warning(f"Chain {m} not found (Side A). Available: {', '.join(available)}")
        for m in miss_b:
            gr.Warning(f"Chain {m} not found (Side B). Available: {', '.join(available)}")
        selected_chain_ids = {c["id"] for c in info_a} | {c["id"] for c in info_b}
        ppi_a_val = seq_a
        ppi_b_val = seq_b
        ah_val = al_val = aag_val = gr.update()
        chain_info_a, chain_info_b = info_a, info_b
    else:
        id_h = chain_h or "H"
        id_l = chain_l or "L"
        id_ag_str = chain_ag or next((c for c in available if c not in (id_h, id_l)), available[0])

        def _single(cid):
            # Exact match first, then case-insensitive; an empty placeholder
            # is returned when the chain is absent.
            if cid in chains and chains[cid]["seq"]:
                return cid, chains[cid]
            f = next((k for k in chains if k.upper() == cid.upper() and chains[k]["seq"]), None)
            return (f, chains[f]) if f else (cid, {"seq": "", "resnos": []})

        real_h, ch = _single(id_h)
        real_l, cl = _single(id_l)
        seq_ag, info_ag, miss_ag = resolve_chains(chains, id_ag_str, available)
        if not ch["seq"]:
            gr.Warning(f"Heavy chain {id_h} not found. Available: {', '.join(available)}")
        if not cl["seq"]:
            gr.Warning(f"Light chain {id_l} not found. Available: {', '.join(available)}")
        for m in miss_ag:
            gr.Warning(f"Antigen chain {m} not found. Available: {', '.join(available)}")
        chain_info_a = info_ag if info_ag else []
        # Keep only chains that were actually found. An empty placeholder
        # still counts as a chain in split_ig_by_chains and adds a separator
        # offset, shifting the scores of the chain after it.
        chain_info_b = [
            entry
            for entry in ({"id": real_h, **ch}, {"id": real_l, **cl})
            if entry["seq"]
        ]
        selected_chain_ids = ({c["id"] for c in info_ag} if info_ag else set())
        if ch["seq"]:
            selected_chain_ids.add(real_h)
        if cl["seq"]:
            selected_chain_ids.add(real_l)
        ppi_a_val = ppi_b_val = gr.update()
        ah_val = ch["seq"]
        al_val = cl["seq"]
        aag_val = seq_ag

    # Show only the selected chains; fall back to the full entry when nothing
    # could be selected.
    filtered_pdb = filter_pdb_to_chains(pdb_content, selected_chain_ids) if selected_chain_ids else pdb_content
    return (
        ppi_a_val,  # ppi_a
        ppi_b_val,  # ppi_b
        ah_val,  # ab_h
        al_val,  # ab_l
        aag_val,  # ab_ag
        chain_info_a,  # state_chain_info_a
        chain_info_b,  # state_chain_info_b
        filtered_pdb,  # state_pdb_content
        None,  # state_ig_chain_map (cleared)
        None, None,  # state_ig_a, state_ig_b
        None,  # state_result
        wrap_iframe(ngl_viewer_html(filtered_pdb, None)),  # ngl_html
        "",  # result_html
        "",  # strip_a_html
        None,  # bar_a
        None,  # hm_a
        "",  # strip_b_html
        None,  # bar_b
        None,  # hm_b
        info_msg,  # pdb_info_md
    )
def handle_fetch_pdb(pdb_id, mode, ca_in, cb_in, ch_in, cl_in, cag_in):
    """Validate the "Custom PDB Fetch" inputs and delegate to ``handle_load_example``.

    Args:
        pdb_id: Text of the PDB ID box (may be empty or ``None``).
        mode: Current interaction mode, forwarded unchanged.
        ca_in, cb_in: Side A / side B chain selectors (PPI mode).
        ch_in, cl_in, cag_in: Heavy / light / antigen chain selectors (Ab-Ag mode).

    Returns:
        The 21-tuple produced by ``handle_load_example``, or 21 no-op updates
        when no PDB ID was entered.
    """

    def _opt(text):
        """Return *text* stripped of whitespace, or None when nothing is left."""
        return (text or "").strip() or None

    pdb_id = _opt(pdb_id)
    if pdb_id is None:
        gr.Warning("Enter a PDB ID.")
        # This handler is wired to the 21 components in `load_outputs`;
        # returning fewer values makes Gradio reject the callback result.
        return (gr.update(),) * 21
    return handle_load_example(
        pdb_id, mode,
        chain_a=_opt(ca_in),
        chain_b=_opt(cb_in),
        chain_h=_opt(ch_in),
        chain_l=_opt(cl_in),
        chain_ag=_opt(cag_in),
    )
def handle_run_prediction(
    mode, ppi_a, ppi_b, ab_h, ab_l, ab_ag, run_ig, pkd_lo, pkd_hi,
    chain_info_a, chain_info_b, pdb_content,
):
    """Main predict callback. Calls the @spaces.GPU function.

    Args:
        mode: ``"Protein-Protein"`` or ``"Antibody-Antigen"``.
        ppi_a, ppi_b: Raw text of the PPI sequence boxes.
        ab_h, ab_l, ab_ag: Raw text of the heavy / light / antigen boxes.
        run_ig: Whether Integrated Gradients should be computed.
        pkd_lo, pkd_hi: pKd bounds forwarded to the model and the result card.
        chain_info_a, chain_info_b: Per-chain dicts (``id``, ``seq``,
            ``resnos``) from a PDB load; synthesised from the pasted
            sequences when empty.
        pdb_content: PDB text held in session state (may be empty).

    Returns:
        A 14-tuple matching the outputs wired to ``run_btn.click``: result
        state, IG arrays for A and B, IG chain map, result card HTML, viewer
        HTML, (strip, bar, heatmap) for side A and for side B, sidebar card
        and status badge. On validation or prediction failure every slot is a
        no-op ``gr.update()``.
    """
    # One no-op per wired output (14). The early exits used to return only
    # 12 values, which Gradio rejects as an output-count mismatch.
    no_change = (gr.update(),) * 14

    def _synth_chain_info(ids, seqs):
        """Build chain-info dicts with 1-based residue numbering."""
        return [{"id": cid, "seq": s, "resnos": list(range(1, len(s) + 1))}
                for cid, s in zip(ids, seqs)]

    heavy = light = ""
    if mode == "Protein-Protein":
        seq_a = clean_multi(ppi_a or "")
        seq_b = clean_multi(ppi_b or "")
    else:
        # Side A is the antigen; side B is heavy|light joined as two chains.
        seq_a = clean_multi(ab_ag or "")
        heavy = clean_multi(ab_h or "")
        light = clean_multi(ab_l or "")
        seq_b = f"{heavy}|{light}" if (heavy or light) else ""

    if not seq_a or not seq_b:
        gr.Warning("Fill all sequence boxes before running.")
        return no_change

    # If no chain info exists yet (user pasted sequences directly), synthesise it
    if not chain_info_a:
        chain_info_a = _synth_chain_info("ABCDE", seq_a.split("|"))
    if not chain_info_b:
        if mode == "Antibody-Antigen":
            chain_info_b = _synth_chain_info("HL", (heavy, light))
        else:
            chain_info_b = _synth_chain_info("FGHIJ", seq_b.split("|"))

    try:
        lo, hi = float(pkd_lo), float(pkd_hi)
        pkd, cos, ig_a, ig_b, device_used = gpu_run_prediction(
            seq_a, seq_b, bool(run_ig), lo, hi)
    except Exception as e:
        # gr.Error is an exception class and shows nothing unless raised;
        # gr.Warning surfaces the message while the outputs stay untouched.
        gr.Warning(f"Prediction failed: {e}")
        traceback.print_exc()
        return no_change

    result = {"pkd": pkd, "cosine": cos, "sa": seq_a, "sb": seq_b}

    # Display data
    seq_a_d, ig_a_d, seq_b_d, ig_b_d = flat_ig_for_display(ig_a, ig_b, chain_info_a, chain_info_b)
    if not seq_a_d:
        seq_a_d = seq_a.replace("|", "")
    if not seq_b_d:
        seq_b_d = seq_b.replace("|", "")

    ig_chain_map = None
    if ig_a is not None and ig_b is not None:
        ig_chain_map = build_ig_chain_map(chain_info_a, chain_info_b, ig_a, ig_b)

    lbl_a = "Antigen (Target)" if mode == "Antibody-Antigen" else "Target (Protein A)"
    lbl_b = "Antibody H+L" if mode == "Antibody-Antigen" else "Binder (proteina)"

    if ig_a_d and seq_a_d:
        strip_a = residue_strip_html(seq_a_d, ig_a_d)
        bar_a = top10_bar(seq_a_d, ig_a_d, f"Top 10 · {lbl_a}")
        hm_a = make_heatmap(seq_a_d, ig_a_d, f"{lbl_a} · IG Heatmap")
    else:
        strip_a, bar_a, hm_a = "", None, None

    if ig_b_d and seq_b_d:
        strip_b = residue_strip_html(seq_b_d, ig_b_d)
        bar_b = top10_bar(seq_b_d, ig_b_d, f"Top 10 · {lbl_b}")
        hm_b = make_heatmap(seq_b_d, ig_b_d, f"{lbl_b} · IG Heatmap")
    else:
        strip_b, bar_b, hm_b = "", None, None

    # Without a loaded structure the viewer falls back to its empty state.
    if pdb_content:
        ngl_html = wrap_iframe(ngl_viewer_html(pdb_content, ig_chain_map))
    else:
        ngl_html = wrap_iframe(ngl_viewer_html(None))

    result_card = result_card_html(pkd, cos, lo, hi)
    sidebar_card = sidebar_result_html(pkd, cos)
    badge = status_badge_html(True, device_used)

    return (
        result,  # state_result
        ig_a, ig_b,  # state_ig_a, state_ig_b
        ig_chain_map,  # state_ig_chain_map
        result_card,  # result_html
        ngl_html,  # ngl_html
        strip_a, bar_a, hm_a,  # vis A
        strip_b, bar_b, hm_b,  # vis B
        sidebar_card,  # sidebar_result_html
        badge,  # status badge
    )
def _top10_residues(seq, scores):
    """Return the ten highest-scoring residues as ``{"res", "ig"}`` dicts.

    Residues are labelled ``<letter><1-based position>``. Scores are cast to
    built-in floats so the result is JSON-serialisable even when the inputs
    are numpy scalars. Ties keep their sequence order.
    """
    ranked = sorted(
        ({"res": f"{aa}{pos}", "ig": float(score)}
         for pos, (aa, score) in enumerate(zip(seq, scores), start=1)),
        key=lambda item: item["ig"],
        reverse=True,
    )
    return ranked[:10]


def handle_download_results(result, state_ig_a, state_ig_b,
                            chain_info_a, chain_info_b):
    """Build CSV + JSON for download.

    Args:
        result: Result dict stored by the prediction callback (keys ``pkd``,
            ``cosine``, ``sa``, ``sb``); falsy when nothing has been run.
        state_ig_a, state_ig_b: IG scores for side A / B held in session state.
        chain_info_a, chain_info_b: Chain info lists held in session state.

    Returns:
        ``(csv_path, json_path)``, or ``(None, None)`` when there is no result.
    """
    if not result:
        return None, None

    seq_a_d, ig_a_d, seq_b_d, ig_b_d = flat_ig_for_display(
        state_ig_a, state_ig_b, chain_info_a, chain_info_b)
    if not seq_a_d:
        seq_a_d = result["sa"].replace("|", "")
    if not seq_b_d:
        seq_b_d = result["sb"].replace("|", "")
    seq_a_d, seq_b_d = seq_a_d or "", seq_b_d or ""
    ig_a_d, ig_b_d = ig_a_d or [], ig_b_d or []

    # A fresh directory per call: fixed names directly in the shared temp dir
    # would let concurrent sessions overwrite each other's files between the
    # write and the moment Gradio picks them up.
    out_dir = tempfile.mkdtemp(prefix="balm_dl_")

    # CSV
    csv_path = os.path.join(out_dir, "ig_scores.csv")
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        wc = csv.writer(f)
        wc.writerow(["chain", "position", "residue", "ig_score"])
        for label, seq, scores in (("Target", seq_a_d, ig_a_d),
                                   ("proteina", seq_b_d, ig_b_d)):
            wc.writerows(
                [label, pos, aa, f"{sc:.6f}"]
                for pos, (aa, sc) in enumerate(zip(seq, scores), start=1)
            )

    # JSON
    summary = {
        "pkd": result["pkd"], "cosine": result["cosine"],
        "Target_length": len(seq_a_d), "proteina_length": len(seq_b_d),
        "top10_Target": _top10_residues(seq_a_d, ig_a_d),
        "top10_proteina": _top10_residues(seq_b_d, ig_b_d),
    }
    json_path = os.path.join(out_dir, "balm_result.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    return csv_path, json_path
def handle_load_model(pkd_lo, pkd_hi):
    """Force-load the model on CPU (download weights). GPU is reserved for prediction time.

    Args:
        pkd_lo, pkd_hi: pKd bounds forwarded to ``ensure_model`` as floats.

    Returns:
        Status badge HTML: the "loaded" badge on success, the "not loaded"
        badge when loading raised.
    """
    try:
        ensure_model(float(pkd_lo), float(pkd_hi))
    except Exception as e:
        traceback.print_exc()
        # gr.Error is an exception class: merely constructing it shows
        # nothing. The failure badge must still be returned, so surface the
        # message as a toast instead of raising.
        gr.Warning(f"Load failed: {e}")
        return status_badge_html(False)
    else:
        gr.Info("Model weights downloaded and ready. GPU will be allocated when you click Run.")
        return status_badge_html(True, "cpu (GPU on demand)")
def handle_run_batch(file_obj, batch_mode, run_ig_b, pkd_lo, pkd_hi):
    """Run predictions for every row of an uploaded CSV.

    Args:
        file_obj: Value of the upload component; either a path-like string or
            an object exposing the path as ``.name``.
        batch_mode: ``"Protein-Protein"`` or ``"Antibody-Antigen"``.
        run_ig_b: Whether IG should be computed per row.
        pkd_lo, pkd_hi: pKd bounds forwarded to ``gpu_run_batch``.

    Returns:
        ``(results_dataframe, results_csv_path)``, or ``(None, None)`` when
        the input is missing, unreadable, empty, or the batch run fails.
    """
    if file_obj is None:
        gr.Warning("Upload a CSV file first.")
        return None, None

    # Accept both a plain path and an object carrying the path in `.name`.
    csv_path = getattr(file_obj, "name", file_obj)
    try:
        df = pd.read_csv(csv_path)
    except Exception as e:
        # gr.Error only shows when raised; a toast keeps the (None, None)
        # return that clears stale results.
        gr.Warning(f"CSV read failed: {e}")
        return None, None

    if df.empty:
        # Nothing to predict: do not request a GPU for zero rows.
        gr.Warning("The CSV file contains no data rows.")
        return None, None

    rows = df.to_dict(orient="records")
    try:
        results = gpu_run_batch(rows, batch_mode, bool(run_ig_b), float(pkd_lo), float(pkd_hi))
    except Exception as e:
        traceback.print_exc()
        gr.Warning(f"Batch failed: {e}")
        return None, None

    # Per-row results are appended as extra columns next to the input columns.
    df_res = pd.concat([df, pd.DataFrame(results)], axis=1)

    # A private directory per run keeps concurrent sessions from overwriting
    # each other's result file in the shared temp dir.
    out_dir = tempfile.mkdtemp(prefix="balm_batch_")
    out_path = os.path.join(out_dir, "balm_batch_results.csv")
    df_res.to_csv(out_path, index=False)
    return df_res, out_path
# ═══════════════════════════════════════════════════════════════════════════
# GRADIO LAYOUT
# ═══════════════════════════════════════════════════════════════════════════
# Markdown body of the "Model Info" tab. Kept at column 0 so the rendering
# never depends on de-indentation; blank lines separate the table from the
# surrounding paragraphs so it is parsed as a table.
_MODEL_INFO_MD = """
### 📚 About BALM-PPI

**BALM-PPI** predicts protein–protein and antibody–antigen binding affinity using ESM-2 + LoRA
with Integrated Gradients explainability.

| Component | Detail |
|-----------|--------|
| Backbone | ESM-2 650M (`facebook/esm2_t33_650M_UR50D`) |
| Fine-tuning | LoRA r=8, α=16, dropout=0.1 · targets: key/query/value |
| Column mapping | `seq_a` → **Target** · `seq_b` → **proteina** |
| Multi-chain | `\\|` separator → `<cls><cls>` double CLS token |
| Affinity output | pKd = ((cos+1)/2) × (pKd_max − pKd_min) + pKd_min |
| IG | Integrated Gradients, **15-step** Riemann (float32-safe) |
| Viewer | PDB filtered to **selected chains only** |
| Hardware | HF ZeroGPU (dynamic A100); CPU fallback supported |

**Examples** · **1YCR** — MDM2 ↔ p53 &nbsp;·&nbsp; **1BRS** — Barnase ↔ Barstar

🤗 [Harshit494/BALM-PPI](https://huggingface.co/Harshit494/BALM-PPI) &nbsp;·&nbsp;
💻 [rgorantla04/BALM-PPI](https://github.com/rgorantla04/BALM-PPI)
"""


def build_ui():
    """Assemble the Gradio Blocks app and wire all event handlers.

    Layout: a "Single Prediction" tab (model sidebar, sequence inputs, PDB
    fetch, quick examples, visualisation tabs), a "Batch Prediction" tab and
    a "Model Info" tab. Per-session data lives in ``gr.State`` components.

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(title="BALM-PPI Pro") as demo:
        # ── Persistent state ───────────────────────────────────────────────
        state_pdb_content = gr.State("")
        state_chain_info_a = gr.State([])
        state_chain_info_b = gr.State([])
        state_ig_a = gr.State(None)
        state_ig_b = gr.State(None)
        state_ig_chain_map = gr.State(None)
        state_result = gr.State(None)

        # ── Header ─────────────────────────────────────────────────────────
        gr.HTML("""
        <div class="app-header">
        <div class="app-logo">🧬</div>
        <div>
        <div class="app-title">BALM-PPI Pro</div>
        <div class="app-subtitle">ESM-2 · LoRA · Integrated Gradients &nbsp;·&nbsp; Protein Binding Affinity Prediction</div>
        </div>
        </div>
        """)

        with gr.Tabs():
            # ═══════════════════════════════════════════════════════════════
            # TAB 1 — SINGLE PREDICTION
            # ═══════════════════════════════════════════════════════════════
            with gr.Tab("🎯 Single Prediction"):
                with gr.Row():
                    # ─── Sidebar (left column) ─────────────────────────────
                    with gr.Column(scale=1, min_width=260):
                        gr.Markdown("### ⚙️ Model")
                        with gr.Accordion("pKd Bounds", open=False):
                            pkd_lo = gr.Number(value=1.0, label="Min", precision=2)
                            pkd_hi = gr.Number(value=16.0, label="Max", precision=2)
                        load_btn = gr.Button("⚡ Reload Model", variant="secondary")
                        # Badge reflects whether the model was already loaded
                        # when the UI was built.
                        status_html = gr.HTML(
                            status_badge_html(_MODEL_STATE["model"] is not None,
                                              "cpu (GPU on demand)"))
                        sidebar_result = gr.HTML("")
                        gr.Markdown(
                            '<div style="font-size:.72rem;color:var(--text2);'
                            'font-family:JetBrains Mono,monospace;line-height:1.9;margin-top:14px">'
                            '🤗 <a href="https://huggingface.co/Harshit494/BALM-PPI" '
                            'style="color:var(--accent);text-decoration:none">Harshit494/BALM-PPI</a><br>'
                            '💻 <a href="https://github.com/rgorantla04/BALM-PPI" '
                            'style="color:var(--accent);text-decoration:none">rgorantla04/BALM-PPI</a>'
                            '</div>'
                        )

                    # ─── Main column ───────────────────────────────────────
                    with gr.Column(scale=4):
                        mode = gr.Radio(
                            ["Protein-Protein", "Antibody-Antigen"],
                            value="Protein-Protein", label="Interaction mode",
                            container=False,
                        )

                        # PPI input panel
                        with gr.Group(visible=True) as ppi_panel:
                            with gr.Row():
                                ppi_a = gr.Textbox(
                                    label="Target Protein — Seq A",
                                    placeholder="Paste FASTA or raw sequence… | = chain separator",
                                    lines=4, max_lines=8,
                                )
                                ppi_b = gr.Textbox(
                                    label="Binder Protein — Seq B ( | = chain separator)",
                                    placeholder="Paste FASTA or raw sequence…",
                                    lines=4, max_lines=8,
                                )

                        # Ab-Ag input panel
                        with gr.Group(visible=False) as abag_panel:
                            with gr.Row():
                                ab_ag = gr.Textbox(
                                    label="Antigen / Target — Seq A",
                                    placeholder="Antigen sequence…",
                                    lines=4, max_lines=8, scale=2,
                                )
                                with gr.Column(scale=3):
                                    with gr.Row():
                                        ab_h = gr.Textbox(
                                            label="Heavy Chain (H)",
                                            placeholder="VH sequence…",
                                            lines=4, max_lines=8,
                                        )
                                        ab_l = gr.Textbox(
                                            label="Light Chain (L)",
                                            placeholder="VL sequence…",
                                            lines=4, max_lines=8,
                                        )

                        with gr.Row():
                            run_btn = gr.Button("🔬 Run Prediction", variant="primary", scale=3)
                            run_ig = gr.Checkbox(value=True, label="Run Integrated Gradients",
                                                 info="~45–90s CPU · <5s GPU", scale=2)
                        result_html = gr.HTML("")

                        # ─── Lower section: PDB controls + visualizations ─────
                        with gr.Row():
                            # ── LEFT: Custom PDB Fetch ───────────────────
                            with gr.Column(scale=1, min_width=240):
                                gr.HTML('<div class="sec-hdr">Custom PDB Fetch</div>')
                                pdb_in = gr.Textbox(label="PDB ID", placeholder="1YCR, 1BRS, 2VXQ…")
                                # PPI chain selectors
                                with gr.Group(visible=True) as ppi_chains_row:
                                    ca_in = gr.Textbox(label="Side A chain(s)", placeholder="A or A,B")
                                    cb_in = gr.Textbox(label="Side B chain(s)", placeholder="B or B,C")
                                # Ab-Ag chain selectors
                                with gr.Group(visible=False) as abag_chains_row:
                                    ch_in = gr.Textbox(label="Heavy", placeholder="H")
                                    cl_in = gr.Textbox(label="Light", placeholder="L")
                                    cag_in = gr.Textbox(label="Antigen", placeholder="A or A,B")
                                fetch_btn = gr.Button("🔍 Fetch PDB", variant="secondary", size="sm")
                                pdb_info_md = gr.Markdown("")

                            # ── MIDDLE: Quick Examples ───────────────────
                            with gr.Column(scale=1, min_width=240):
                                gr.HTML('<div class="sec-hdr">Quick Examples</div>')
                                example_buttons = []
                                for ex in EXAMPLES:
                                    gr.HTML(
                                        f'<div class="ex-card">'
                                        f'<div class="ex-pdb">{ex["label"]}</div>'
                                        f'<div class="ex-sub">{ex["subtitle"]}</div>'
                                        f'<div class="ex-desc">{ex["desc"]}</div></div>'
                                    )
                                    b = gr.Button(f"⬇ Load {ex['pdb']}", size="sm")
                                    example_buttons.append((b, ex))

                            # ── RIGHT: Visualizations ────────────────────
                            with gr.Column(scale=3):
                                with gr.Tabs():
                                    with gr.Tab("🧬 Structure"):
                                        ngl_html = gr.HTML(wrap_iframe(ngl_viewer_html(None)))
                                    with gr.Tab("📊 Side A Strip"):
                                        strip_a_html = gr.HTML("")
                                        bar_a_plot = gr.Plot()
                                    with gr.Tab("🗺 Side A Heatmap"):
                                        hm_a_plot = gr.Plot()
                                    with gr.Tab("📊 Side B Strip"):
                                        strip_b_html = gr.HTML("")
                                        bar_b_plot = gr.Plot()
                                    with gr.Tab("🗺 Side B Heatmap"):
                                        hm_b_plot = gr.Plot()
                                    with gr.Tab("📥 Download"):
                                        dl_btn = gr.Button("Build downloads", variant="secondary")
                                        dl_csv = gr.File(label="IG Scores CSV")
                                        dl_json = gr.File(label="Summary JSON")

            # ═══════════════════════════════════════════════════════════════
            # TAB 2 — BATCH PREDICTION
            # ═══════════════════════════════════════════════════════════════
            with gr.Tab("📂 Batch Prediction"):
                gr.Markdown("### 📂 Batch Prediction")
                batch_mode = gr.Radio(
                    ["Protein-Protein", "Antibody-Antigen"],
                    value="Protein-Protein", label="Mode", container=False,
                )
                with gr.Row():
                    # Callables as `value`: the template files are written
                    # when Gradio evaluates the default, not at build time.
                    gr.File(label="PPI Template",
                            value=lambda: _write_template(PPI_TEMPLATE, "ppi_template.csv"),
                            interactive=False)
                    gr.File(label="Ab-Ag Template",
                            value=lambda: _write_template(ABAG_TEMPLATE, "abag_template.csv"),
                            interactive=False)
                gr.Markdown(
                    "Required columns: **seq_a, seq_b** (PPI mode) "
                    "or **heavy_chain, light_chain, antigen** (Ab-Ag mode)."
                )
                batch_file = gr.File(label="CSV file", file_types=[".csv"])
                run_ig_b = gr.Checkbox(value=False,
                                       label="Compute IG per row (slow on CPU; fine on GPU)")
                batch_run = gr.Button("🏃 Run Batch", variant="primary")
                batch_df = gr.Dataframe(label="Results", interactive=False)
                batch_dl = gr.File(label="Download results CSV")

            # ═══════════════════════════════════════════════════════════════
            # TAB 3 — MODEL INFO
            # ═══════════════════════════════════════════════════════════════
            with gr.Tab("📚 Model Info"):
                gr.Markdown(_MODEL_INFO_MD)

        # ───────────────────────────────────────────────────────────────────
        # EVENT WIRING
        # ───────────────────────────────────────────────────────────────────
        mode.change(
            handle_mode_change,
            inputs=[mode],
            outputs=[ppi_panel, abag_panel, ppi_chains_row, abag_chains_row],
        )
        load_btn.click(
            handle_load_model,
            inputs=[pkd_lo, pkd_hi],
            outputs=[status_html],
        )

        # Outputs that load_example/fetch_pdb returns (21 components; the
        # handlers must return exactly one value per entry, in this order).
        load_outputs = [
            ppi_a, ppi_b, ab_h, ab_l, ab_ag,
            state_chain_info_a, state_chain_info_b,
            state_pdb_content, state_ig_chain_map,
            state_ig_a, state_ig_b, state_result,
            ngl_html, result_html,
            strip_a_html, bar_a_plot, hm_a_plot,
            strip_b_html, bar_b_plot, hm_b_plot,
            pdb_info_md,
        ]

        def _example_loader(example):
            """Bind one EXAMPLES entry to a zero-argument click handler."""
            def _load():
                updates = handle_load_example(
                    example["pdb"], example["mode"],
                    example.get("chain_a"), example.get("chain_b"),
                    example.get("chain_h"), example.get("chain_l"),
                    example.get("chain_ag"),
                )
                # Also move the mode radio to the example's mode. Otherwise
                # an example of the other mode fills text boxes that are
                # hidden, and "Run" reads the (empty) visible ones.
                return (*updates, example["mode"])
            return _load

        for btn, ex in example_buttons:
            btn.click(_example_loader(ex), outputs=load_outputs + [mode])

        fetch_btn.click(
            handle_fetch_pdb,
            inputs=[pdb_in, mode, ca_in, cb_in, ch_in, cl_in, cag_in],
            outputs=load_outputs,
        )
        run_btn.click(
            handle_run_prediction,
            inputs=[
                mode, ppi_a, ppi_b, ab_h, ab_l, ab_ag,
                run_ig, pkd_lo, pkd_hi,
                state_chain_info_a, state_chain_info_b, state_pdb_content,
            ],
            outputs=[
                state_result,
                state_ig_a, state_ig_b, state_ig_chain_map,
                result_html, ngl_html,
                strip_a_html, bar_a_plot, hm_a_plot,
                strip_b_html, bar_b_plot, hm_b_plot,
                sidebar_result, status_html,
            ],
        )
        dl_btn.click(
            handle_download_results,
            inputs=[state_result, state_ig_a, state_ig_b,
                    state_chain_info_a, state_chain_info_b],
            outputs=[dl_csv, dl_json],
        )
        # The batch tab reuses the pKd bounds from the single-prediction sidebar.
        batch_run.click(
            handle_run_batch,
            inputs=[batch_file, batch_mode, run_ig_b, pkd_lo, pkd_hi],
            outputs=[batch_df, batch_dl],
        )
    return demo
def _write_template(content: str, name: str) -> str:
    """Write *content* to a file called *name* in the system temp directory.

    An existing file of the same name is overwritten.

    Returns:
        The full path of the written file.
    """
    temp_root = tempfile.gettempdir()
    target = os.path.join(temp_root, name)
    with open(target, "w") as handle:
        handle.write(content)
    return target
# ═══════════════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    # Eager model load at startup — happens during Space's "Starting" phase,
    # so the user opens a ready UI instead of waiting ~5 GB of downloads on
    # their first click. Progress shows in the Space's runtime logs.
    # The bounds passed here equal the defaults of the pKd inputs in build_ui.
    print("[BALM-PPI] Pre-loading model on CPU…", flush=True)
    try:
        ensure_model(1.0, 16.0)
        print("[BALM-PPI] Model ready.", flush=True)
    except Exception as _e:
        # A failed pre-load is not fatal: the UI still starts and the
        # traceback is left in the logs.
        print(f"[BALM-PPI] WARNING — pre-load failed: {_e}", flush=True)
        traceback.print_exc()

    demo = build_ui()
    demo.queue(max_size=20).launch(
        css=CSS,
        theme=gr.themes.Soft(),
        ssr_mode=False,  # cleaner logs; SSR is experimental in Gradio 6
    )