"""
BALM-PPI Pro · ESM-2 + LoRA + Integrated Gradients
=====================================================
Gradio port with HuggingFace ZeroGPU support.
Key porting notes:
- All GPU work (forward pass + IG) is wrapped in @spaces.GPU so it runs on
ZeroGPU's dynamic A100s. Outside the decorator, the model lives on CPU.
- IG kept at float32 with 15 Riemann steps (same fix as the Streamlit version).
- NGL viewer + Plotly visualisations + theming preserved.
- Per-session state (PDB, chain info, IG arrays) is held in gr.State so the
app is safe under concurrent users on Spaces.
"""
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import io, re, json, csv, traceback, tempfile, copy
import io as sio
import gradio as gr
import torch
import pandas as pd
import numpy as np
import requests
import plotly.graph_objects as go
from Bio import PDB
from Bio.Data.PDBData import protein_letters_3to1 as THREE_TO_ONE
# ─── ZeroGPU compatibility shim ────────────────────────────────────────────
# On HF Spaces with ZeroGPU, `spaces` is installed and @spaces.GPU works.
# Locally (or off-Spaces) we fall back to a no-op decorator so the same file
# runs both places without modification.
try:
    import spaces
    _HAS_SPACES = True
except ImportError:
    _HAS_SPACES = False

    class _SpacesShim:
        """Local stand-in for the `spaces` package when it is not installed.

        Only the `GPU` decorator is emulated; it never allocates anything and
        always hands the decorated function back unchanged.
        """

        @staticmethod
        def GPU(*args, **kwargs):
            """No-op replacement for `spaces.GPU`.

            Supports both spellings of the decorator:
            `@spaces.GPU` (bare) and `@spaces.GPU(duration=...)` (called).
            Any options are accepted and ignored.
            """
            # Bare usage: the single positional argument is the function
            # being decorated, so return it directly.
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]

            def decorator(fn):
                return fn

            return decorator

    spaces = _SpacesShim()  # type: ignore
# Hugging Face Hub repository and file name of the checkpoint that
# _download_weights_hf() fetches.
HF_REPO_ID = "Harshit494/BALM-PPI"
HF_FILENAME = "best_model_fold_1.pth"
# Three-letter → one-letter residue codes with upper-cased keys, so lookups
# can be done with `resname.upper()`.
UPPER_3TO1 = {k.upper(): v for k, v in THREE_TO_ONE.items()}
# Built-in example complexes. Each entry carries the PDB ID, display strings
# (label / subtitle / desc), the input mode and the chain IDs of the two
# binding partners.
EXAMPLES = [
    {
        "pdb": "1YCR", "label": "1YCR", "subtitle": "MDM2 ↔ p53",
        "desc": "MDM2 oncoprotein bound to the p53 transactivation peptide. Chain A = MDM2, Chain B = p53. Classic drug-target complex.",
        "mode": "Protein-Protein", "chain_a": "A", "chain_b": "B",
    },
    {
        "pdb": "1BRS", "label": "1BRS", "subtitle": "Barnase ↔ Barstar",
        "desc": "Ultra-tight complex (Kd ~10⁻¹⁴ M). Chain A = Barnase, Chain D = Barstar. Classic benchmark.",
        "mode": "Protein-Protein", "chain_a": "A", "chain_b": "D",
    },
]
# Sample CSV contents (header row + one example row). The column names match
# the keys gpu_run_batch() reads for its two batch modes.
PPI_TEMPLATE = "seq_a,seq_b\nACDEFGHIKLMNPQRSTVWY,QWERTYIPASDFGKLCVBNM\n"
ABAG_TEMPLATE = "heavy_chain,light_chain,antigen\nEVQLVESGGG...,DIQMTQ...,IYSPT...\n"
# ═══════════════════════════════════════════════════════════════════════════
# GLOBAL MODEL CACHE
# Loaded lazily on CPU; moved to CUDA only inside @spaces.GPU functions.
# ═══════════════════════════════════════════════════════════════════════════
# Module-level cache read and written by ensure_model().
_MODEL_STATE = {
    "model": None,       # model instance; None until the first ensure_model() call
    "pkd_bounds": None,  # tuple (lo, hi) currently set on the cached model;
                         # ensure_model() updates it in place when bounds change
}
# ═══════════════════════════════════════════════════════════════════════════
# PDB UTILITIES
# ═══════════════════════════════════════════════════════════════════════════
# In-memory cache of downloaded PDB files, keyed by upper-cased PDB ID.
_PDB_CACHE: dict = {}


def _fetch_pdb_cached(pdb_id: str) -> str:
    """Return the text of ``<pdb_id>.pdb`` from files.rcsb.org, with caching.

    The ID is stripped and upper-cased before use. Results are kept in
    ``_PDB_CACHE``; once the cache holds more than 64 entries the oldest
    inserted entry is dropped.

    Raises:
        ValueError: if the normalised ID is empty or contains anything other
            than ASCII letters, digits or underscores.
        requests.RequestException: on network failure or a non-2xx response.
    """
    pdb_id = pdb_id.strip().upper()
    # The ID is interpolated into the download URL, so reject anything that
    # could change the path or add a query string ("/", ".", "?", spaces, ...).
    # This also avoids a pointless request for an empty ID.
    if not re.fullmatch(r"[0-9A-Z_]+", pdb_id):
        raise ValueError(f"Invalid PDB ID: {pdb_id!r}")

    cached = _PDB_CACHE.get(pdb_id)
    if cached is not None:
        return cached

    url = f"https://files.rcsb.org/download/{pdb_id}.pdb"
    r = requests.get(url, timeout=30)
    r.raise_for_status()
    _PDB_CACHE[pdb_id] = r.text
    # Keep the cache from growing unbounded (dicts preserve insertion order,
    # so the first key is the oldest entry).
    if len(_PDB_CACHE) > 64:
        _PDB_CACHE.pop(next(iter(_PDB_CACHE)))
    return r.text
def filter_pdb_to_chains(pdb_text: str, keep_chains: set) -> str:
    """Drop ATOM/HETATM records whose chain ID is not in ``keep_chains``.

    The chain ID is read from column 22 of the record (index 21). Lines of
    any other record type are passed through unchanged. The surviving lines
    are re-joined with newlines.
    """
    atom_records = ("ATOM", "HETATM")

    def _keep(record: str) -> bool:
        # Non-coordinate records are always retained.
        if record[:6].strip() not in atom_records:
            return True
        return record[21:22] in keep_chains

    return "\n".join(filter(_keep, pdb_text.splitlines()))
def get_chains_from_pdb(pdb_id: str):
    """Fetch a PDB entry and extract the amino-acid sequence of each chain.

    Only the first model of the structure is read, and only residues that
    ``PDB.is_aa(..., standard=True)`` accepts are included.

    Returns:
        A ``(pdb_text, chains, error)`` tuple.
        On success ``error`` is None and ``chains`` maps chain ID to
        ``{"seq": <one-letter string>, "resnos": [<residue numbers>]}``;
        chains without any accepted residue are omitted.
        On a fetch or parse failure the result is ``(None, {}, <message>)``.
    """
    import warnings as _w

    try:
        pdb_text = _fetch_pdb_cached(pdb_id)
    except Exception as e:
        return None, {}, f"Could not fetch PDB {pdb_id}: {e}"

    # Report parse failures the same way as fetch failures instead of letting
    # them escape as exceptions.
    try:
        with _w.catch_warnings():
            _w.simplefilter("ignore")
            parser = PDB.PDBParser(QUIET=True)
            struct = parser.get_structure(pdb_id, io.StringIO(pdb_text))
    except Exception as e:
        return None, {}, f"Could not parse PDB {pdb_id}: {e}"

    chains: dict = {}
    for mdl in struct:
        for chain in mdl:
            letters, resnos = [], []
            for residue in chain:
                if PDB.is_aa(residue, standard=True):
                    letters.append(
                        UPPER_3TO1.get(residue.get_resname().upper(), "X"))
                    # Sequence number only; the insertion code is not kept.
                    resnos.append(residue.get_id()[1])
            if letters:
                chains[chain.id] = {"seq": "".join(letters), "resnos": resnos}
        break  # first model only
    return pdb_text, chains, None
def resolve_chains(chains_dict, ids_str, available):
    """Look up a comma-separated list of chain IDs in ``chains_dict``.

    An ID matches exactly first, then case-insensitively; in both cases the
    chain must have a non-empty ``"seq"``. ``available`` is accepted but not
    used here.

    Returns:
        ``(joined_seq, infos, missing)`` where ``joined_seq`` is the matched
        sequences joined by ``"|"``, ``infos`` is a list of
        ``{"id": <key>, **chains_dict[<key>]}`` dicts in request order, and
        ``missing`` lists the requested IDs that could not be matched.
    """
    if not ids_str:
        return "", [], []

    def _locate(wanted):
        # Exact key first, then the first key equal ignoring case.
        if wanted in chains_dict and chains_dict[wanted]["seq"]:
            return wanted
        folded = wanted.upper()
        for key in chains_dict:
            if key.upper() == folded and chains_dict[key]["seq"]:
                return key
        return None

    sequences, resolved, unresolved = [], [], []
    for token in str(ids_str).split(","):
        token = token.strip()
        if not token:
            continue
        key = _locate(token)
        if key is None:
            unresolved.append(token)
            continue
        sequences.append(chains_dict[key]["seq"])
        resolved.append({"id": key, **chains_dict[key]})
    return "|".join(sequences), resolved, unresolved
def clean_multi(s: str) -> str:
if not s:
return ""
out = []
for p in s.split("|"):
p = re.sub(r">.*", "", p, flags=re.MULTILINE)
p = re.sub(r"[^A-Za-z]", "", p).upper()
if p:
out.append(p)
return "|".join(out)
# ═══════════════════════════════════════════════════════════════════════════
# IG HELPERS
# ═══════════════════════════════════════════════════════════════════════════
def split_ig_by_chains(ig_flat, chain_seqs, n_sep=2):
    """Cut a flat score sequence into one list per chain.

    Consecutive chains are assumed to be separated by ``n_sep`` entries in
    ``ig_flat``; those entries are skipped. If ``ig_flat`` is empty or None,
    an empty list is returned for every chain.
    """
    if not ig_flat:
        return [[] for _ in chain_seqs]

    pieces = []
    cursor = 0
    last_index = len(chain_seqs) - 1
    for index, chain in enumerate(chain_seqs):
        end = cursor + len(chain)
        pieces.append(list(ig_flat[cursor:end]))
        # Step over the separator entries, except after the final chain.
        cursor = end + n_sep if index < last_index else end
    return pieces
def build_ig_chain_map(chain_info_a, chain_info_b, ig_a, ig_b):
    """Map each chain ID to its first residue number and its own scores.

    For each side (A, then B) the flat score list is divided per chain with
    ``split_ig_by_chains``. Chains with no ``"resnos"`` are left out, and a
    side whose chain info is empty or None is skipped entirely.

    Returns:
        ``{chain_id: {"start": <first residue number>, "scores": [...]}}``.
    """
    chain_map = {}
    for chains, flat_scores in ((chain_info_a, ig_a), (chain_info_b, ig_b)):
        if not chains:
            continue
        split_scores = split_ig_by_chains(
            flat_scores, [entry["seq"] for entry in chains])
        chain_map.update(
            (entry["id"], {"start": entry["resnos"][0], "scores": scores})
            for entry, scores in zip(chains, split_scores)
            if entry.get("resnos")
        )
    return chain_map
def flat_ig_for_display(ig_a, ig_b, chain_info_a, chain_info_b):
    """Return concatenated sequences and matching flat score lists.

    For each side the chain sequences are joined without separators. When
    both scores and chain info are present, the scores are split per chain
    with ``split_ig_by_chains`` and re-flattened, which removes the separator
    entries. Otherwise the scores are copied as-is, or ``[]`` if there are
    none.

    Returns:
        ``(seq_a, ig_a_out, seq_b, ig_b_out)``.
    """

    def _joined(chains):
        return "".join(entry["seq"] for entry in chains) if chains else ""

    def _flattened(scores, chains):
        if not scores:
            return []
        if not chains:
            return list(scores)
        per_chain = split_ig_by_chains(scores, [entry["seq"] for entry in chains])
        return [value for part in per_chain for value in part]

    seq_a = _joined(chain_info_a)
    seq_b = _joined(chain_info_b)
    return (
        seq_a,
        _flattened(ig_a, chain_info_a),
        seq_b,
        _flattened(ig_b, chain_info_b),
    )
# ═══════════════════════════════════════════════════════════════════════════
# MODEL BUILD (CPU-safe — no CUDA touched here)
# ═══════════════════════════════════════════════════════════════════════════
def _load_esm_base():
    """Load the pretrained ESM-2 encoder and its tokenizer.

    Both come from the ``facebook/esm2_t33_650M_UR50D`` checkpoint; the model
    is loaded from safetensors. Returns ``(model, tokenizer)``.
    """
    from transformers import EsmModel, EsmTokenizer

    checkpoint = "facebook/esm2_t33_650M_UR50D"
    encoder = EsmModel.from_pretrained(checkpoint, use_safetensors=True)
    tokenizer = EsmTokenizer.from_pretrained(checkpoint)
    return encoder, tokenizer
def _download_weights_hf() -> bytes:
    """Download the checkpoint from the Hugging Face Hub and return its bytes.

    The file is identified by ``HF_REPO_ID`` / ``HF_FILENAME``;
    ``hf_hub_download`` returns a local path, which is then read in full.
    """
    from huggingface_hub import hf_hub_download

    # `resume_download` is deprecated in huggingface_hub (interrupted
    # downloads are resumed automatically), so it is no longer passed.
    path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME)
    with open(path, "rb") as f:
        return f.read()
def build_model(weights_bytes, pkd_bounds):
    """Build the ESM-2 + LoRA + projection-head model and load its weights.

    Args:
        weights_bytes: serialized checkpoint (the bytes of a ``torch.save``
            file) containing the state dict to load.
        pkd_bounds: ``(lower, upper)`` pair used to rescale the cosine
            similarity into a pKd value.

    Returns:
        The model in eval mode, with checkpoint tensors mapped to CPU.
    """
    import torch.nn as nn
    import torch.nn.functional as F
    from peft import LoraConfig, get_peft_model

    class BALMProjectionHead(nn.Module):
        """Projects both embeddings and returns their clamped cosine similarity."""

        def __init__(self, embedding_size, projected_size=256):
            super().__init__()
            self.protein_projection = nn.Linear(embedding_size, projected_size)
            self.proteina_projection = nn.Linear(embedding_size, projected_size)
            self.dropout = nn.Dropout(0.1)

        def forward(self, prot_emb, prot_a_emb):
            p = F.normalize(self.protein_projection(self.dropout(prot_emb)), p=2, dim=1)
            pa = F.normalize(self.proteina_projection(self.dropout(prot_a_emb)), p=2, dim=1)
            # Clamp keeps the similarity strictly inside (-1, 1).
            return torch.clamp(F.cosine_similarity(p, pa), -0.9999, 0.9999)

    class BALMForLoRAFinetuning(nn.Module):
        """ESM encoder + projection head producing ``(pkd, cosine)``."""

        def __init__(self, esm_model, esm_tokenizer, pkd_bounds):
            super().__init__()
            self.esm_model = esm_model
            self.esm_tokenizer = esm_tokenizer
            self.projection_head = BALMProjectionHead(self.esm_model.config.hidden_size)
            self.pkd_lower, self.pkd_upper = pkd_bounds
            self.cls_token = self.esm_tokenizer.cls_token

        def _get_esm_embeddings(self, sequences):
            # Each "|" chain separator becomes two CLS tokens.
            processed = [s.replace("|", f"{self.cls_token}{self.cls_token}") for s in sequences]
            inputs = self.esm_tokenizer(processed, return_tensors="pt", padding=True,
                                        truncation=True, max_length=1024)
            inputs = {k: v.to(self.esm_model.device) for k, v in inputs.items()}
            h = self.esm_model(**inputs).last_hidden_state
            # Attention-mask-weighted mean over the token dimension.
            mask = inputs["attention_mask"].unsqueeze(-1).expand(h.size()).float()
            return (torch.sum(h * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)).float()

        def forward(self, seq_a, seq_b):
            ea = self._get_esm_embeddings([seq_a])
            eb = self._get_esm_embeddings([seq_b])
            cos = self.projection_head(ea, eb)
            # Affine map of cosine from [-1, 1] onto [pkd_lower, pkd_upper].
            pkd = ((cos + 1) / 2) * (self.pkd_upper - self.pkd_lower) + self.pkd_lower
            return pkd, cos

    base_esm, tok = _load_esm_base()
    lora_cfg = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1,
                          target_modules=["key", "query", "value"])
    model = BALMForLoRAFinetuning(
        get_peft_model(base_esm, lora_cfg), tok, pkd_bounds)

    # Load straight from memory: no temporary file is written, so nothing is
    # left behind on disk if loading fails.
    # NOTE(security): torch.load unpickles the checkpoint. Only load files
    # from a trusted repository; consider weights_only=True once the
    # checkpoint is confirmed to contain plain tensors.
    state_dict = torch.load(io.BytesIO(weights_bytes), map_location="cpu")
    # strict=False: keys that do not match between checkpoint and model are
    # ignored rather than raising.
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    return model
def ensure_model(pkd_lo: float, pkd_hi: float):
    """Return the cached model, building it on first use.

    On the first call the weights are downloaded and the model is built with
    the requested bounds. On later calls, if the bounds differ from the
    cached ones, only ``pkd_lower`` / ``pkd_upper`` on the existing model are
    updated; the weights are not reloaded.
    """
    requested = (float(pkd_lo), float(pkd_hi))
    cached = _MODEL_STATE["model"]

    if cached is None:
        cached = build_model(_download_weights_hf(), requested)
        _MODEL_STATE["model"] = cached
        _MODEL_STATE["pkd_bounds"] = requested
        return cached

    if _MODEL_STATE["pkd_bounds"] != requested:
        # Same weights; only the output rescaling changes.
        cached.pkd_lower, cached.pkd_upper = requested
        _MODEL_STATE["pkd_bounds"] = requested
    return cached
# ═══════════════════════════════════════════════════════════════════════════
# INTEGRATED GRADIENTS (float32-safe, 15 Riemann steps)
# ═══════════════════════════════════════════════════════════════════════════
def compute_ig(model, seq_a: str, seq_b: str, steps: int = 15):
    """Integrated-gradients attribution of the cosine score to input tokens.

    Each sequence is attributed in turn against an all-zeros embedding
    baseline while the other sequence's embeddings are held fixed. Gradients
    are averaged over ``steps`` evenly spaced interpolation points, multiplied
    by (input - baseline), and the absolute values are summed over the
    embedding dimension to give one score per token.

    Returns:
        ``(scores_a, scores_b)``: two lists of floats min-max normalised to
        [0, 1] (all zeros when the scores are constant). The first and last
        token positions are dropped (``[1:-1]``). The two CLS tokens that
        replace each "|" separator are still present in the output.
    """
    device = next(model.parameters()).device
    tok = model.esm_tokenizer
    cls_tok = tok.cls_token
    esm = model.esm_model

    def tokenise(seq):
        proc = seq.replace("|", f"{cls_tok}{cls_tok}")
        return tok(proc, return_tensors="pt",
                   padding=False, truncation=True, max_length=1024).to(device)

    word_embed = esm.base_model.model.embeddings.word_embeddings
    enc_a = tokenise(seq_a)
    enc_b = tokenise(seq_b)
    mask_a = enc_a["attention_mask"].float()
    mask_b = enc_b["attention_mask"].float()
    # Detached float32 copies: these are the IG inputs, not graph nodes.
    emb_a = word_embed(enc_a["input_ids"]).detach().float().clone()
    emb_b = word_embed(enc_b["input_ids"]).detach().float().clone()

    def encode(embs, mask):
        # Run the encoder directly on embeddings, then mask-weighted mean-pool.
        int_mask = mask.long()
        ext = esm.base_model.model.get_extended_attention_mask(int_mask, embs.shape[:2])
        h = esm.base_model.model.encoder(
            embs.float(), attention_mask=ext
        ).last_hidden_state.float()
        m = mask.unsqueeze(-1).expand(h.size()).float()
        return (torch.sum(h * m, 1) / torch.clamp(m.sum(1), min=1e-9))

    def fwd(e_a, e_b):
        return model.projection_head(encode(e_a, mask_a), encode(e_b, mask_b))

    def riemann(tgt, baseline, fixed, tgt_is_b):
        delta = tgt - baseline
        grads = []
        # enable_grad makes this work even if the caller is inside
        # torch.no_grad().
        with torch.enable_grad():
            for alpha in torch.linspace(0, 1, steps, device=device):
                interp = (baseline + alpha * delta).detach().requires_grad_(True)
                out = fwd(fixed, interp) if tgt_is_b else fwd(interp, fixed)
                # autograd.grad returns only d(out)/d(interp). Unlike
                # backward(), it does not accumulate .grad on the model's
                # trainable parameters.
                (grad,) = torch.autograd.grad(out.sum(), interp)
                grads.append(grad.detach().float())
        avg_grad = torch.stack(grads).mean(0)
        ig = (avg_grad * delta).abs().sum(-1).squeeze(0)
        return ig.cpu().numpy()

    baseline_a = torch.zeros_like(emb_a)
    baseline_b = torch.zeros_like(emb_b)
    attr_a = riemann(emb_a, baseline_a, emb_b, False)
    attr_b = riemann(emb_b, baseline_b, emb_a, True)

    def norm(a):
        # Nothing to normalise (e.g. an empty input sequence).
        if a.size == 0:
            return []
        lo, hi = a.min(), a.max()
        if hi - lo < 1e-9:
            return [0.0] * len(a)
        return ((a - lo) / (hi - lo)).tolist()

    return norm(attr_a[1:-1]), norm(attr_b[1:-1])
# ═══════════════════════════════════════════════════════════════════════════
# GPU-WRAPPED PREDICTION
# gpu_run_prediction and gpu_run_batch are the only functions that touch CUDA.
# ZeroGPU allocates a GPU for the duration of each call.
# ═══════════════════════════════════════════════════════════════════════════
@spaces.GPU(duration=120)
def gpu_run_prediction(seq_a, seq_b, run_ig, pkd_lo, pkd_hi):
    """Predict pKd for one sequence pair, optionally with IG attributions.

    Returns:
        ``(pkd, cos, ig_a, ig_b, device)`` where ``ig_a`` / ``ig_b`` are None
        unless ``run_ig`` is truthy, and ``device`` is the name of the device
        the computation ran on.
    """
    model = ensure_model(pkd_lo, pkd_hi)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        model.to(device)
        with torch.no_grad():
            pkd_t, cos_t = model(seq_a, seq_b)
        pkd = float(pkd_t.item())
        cos = float(cos_t.item())
        ig_a, ig_b = None, None
        if run_ig:
            # IG needs gradients, so it runs outside the no_grad block.
            ig_a, ig_b = compute_ig(model, seq_a, seq_b, steps=15)
    finally:
        # Always move the shared model back to CPU, even when the forward
        # pass or IG raises, so it is never left on the GPU between calls.
        model.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return pkd, cos, ig_a, ig_b, str(device)
def _batch_cell_text(row, key):
    """Return ``row[key]`` as text, mapping missing values (None / NaN) to ''."""
    value = row.get(key, "")
    # NaN is the only float that is not equal to itself.
    if value is None or (isinstance(value, float) and value != value):
        return ""
    return str(value)


def _batch_top5(seq, ig_flat):
    """Format the five highest-scoring residues of ``seq`` as ``X12:0.123|...``.

    ``ig_flat`` still contains the entries for the chain separators, so it is
    split per chain first. Positions are 1-based within the concatenated
    chains.
    """
    chains = seq.split("|")
    residues = "".join(chains)
    scores = [s for part in split_ig_by_chains(ig_flat, chains) for s in part]
    count = min(len(residues), len(scores))
    top = sorted(range(count), key=lambda j: -scores[j])[:5]
    return "|".join(f"{residues[j]}{j + 1}:{scores[j]:.3f}" for j in top)


@spaces.GPU(duration=300)
def gpu_run_batch(rows, batch_mode, run_ig, pkd_lo, pkd_hi):
    """Run predictions for a batch of rows.

    Args:
        rows: list of dicts with ``seq_a`` / ``seq_b`` keys when ``batch_mode``
            is "Protein-Protein", otherwise ``heavy_chain`` / ``light_chain``
            / ``antigen`` (antigen is the first input, ``heavy|light`` the
            second).
        batch_mode: "Protein-Protein" or any other value for the
            heavy/light/antigen layout.
        run_ig: when truthy, add ``top5_Target`` and ``top5_proteina``.
        pkd_lo, pkd_hi: bounds passed to ensure_model().

    Returns:
        One result dict per row, in order. Rows with a missing input, and
        rows that raise, get ``pKd`` and ``Cosine`` set to None (plus an
        ``error`` message in the latter case).
    """
    model = ensure_model(pkd_lo, pkd_hi)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    results = []
    try:
        model.to(device)
        for row in rows:
            if batch_mode == "Protein-Protein":
                sa = clean_multi(_batch_cell_text(row, "seq_a"))
                sb = clean_multi(_batch_cell_text(row, "seq_b"))
            else:
                h = clean_multi(_batch_cell_text(row, "heavy_chain"))
                l = clean_multi(_batch_cell_text(row, "light_chain"))
                ag = clean_multi(_batch_cell_text(row, "antigen"))
                sa = ag
                sb = f"{h}|{l}" if (h or l) else ""
            if not sa or not sb:
                results.append({"pKd": None, "Cosine": None})
                continue
            try:
                with torch.no_grad():
                    pkd_t, cos_t = model(sa, sb)
                rr = {"pKd": round(float(pkd_t.item()), 4),
                      "Cosine": round(float(cos_t.item()), 6)}
                if run_ig:
                    # IG needs gradients, so it runs outside the no_grad block.
                    ia, ib = compute_ig(model, sa, sb, steps=15)
                    rr["top5_Target"] = _batch_top5(sa, ia)
                    rr["top5_proteina"] = _batch_top5(sb, ib)
            except Exception as e:
                # Per-row boundary: one bad row must not abort the batch.
                rr = {"pKd": None, "Cosine": None, "error": str(e)}
            results.append(rr)
    finally:
        # Always return the shared model to CPU, even if the batch aborts.
        model.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return results
# ═══════════════════════════════════════════════════════════════════════════
# NGL VIEWER (theme-reactive, filtered to selected chains)
# ═══════════════════════════════════════════════════════════════════════════
def ngl_viewer_html(pdb_content, ig_chain_map=None, height=420) -> str:
if not pdb_content:
return (
'