""" BALM-PPI Pro · ESM-2 + LoRA + Integrated Gradients ===================================================== Gradio port with HuggingFace ZeroGPU support. Key porting notes: - All GPU work (forward pass + IG) is wrapped in @spaces.GPU so it runs on ZeroGPU's dynamic A100s. Outside the decorator, the model lives on CPU. - IG kept at float32 with 15 Riemann steps (same fix as the Streamlit version). - NGL viewer + Plotly visualisations + theming preserved. - Per-session state (PDB, chain info, IG arrays) is held in gr.State so the app is safe under concurrent users on Spaces. """ import os os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" os.environ["TOKENIZERS_PARALLELISM"] = "false" import io, re, json, csv, traceback, tempfile, copy import io as sio import gradio as gr import torch import pandas as pd import numpy as np import requests import plotly.graph_objects as go from Bio import PDB from Bio.Data.PDBData import protein_letters_3to1 as THREE_TO_ONE # ─── ZeroGPU compatibility shim ──────────────────────────────────────────── # On HF Spaces with ZeroGPU, `spaces` is installed and @spaces.GPU works. # Locally (or off-Spaces) we fall back to a no-op decorator so the same file # runs both places without modification. try: import spaces _HAS_SPACES = True except ImportError: _HAS_SPACES = False class _SpacesShim: @staticmethod def GPU(duration=60): def decorator(fn): return fn return decorator spaces = _SpacesShim() # type: ignore HF_REPO_ID = "Harshit494/BALM-PPI" HF_FILENAME = "best_model_fold_1.pth" UPPER_3TO1 = {k.upper(): v for k, v in THREE_TO_ONE.items()} EXAMPLES = [ { "pdb": "1YCR", "label": "1YCR", "subtitle": "MDM2 ↔ p53", "desc": "MDM2 oncoprotein bound to the p53 transactivation peptide. Chain A = MDM2, Chain B = p53. Classic drug-target complex.", "mode": "Protein-Protein", "chain_a": "A", "chain_b": "B", }, { "pdb": "1BRS", "label": "1BRS", "subtitle": "Barnase ↔ Barstar", "desc": "Ultra-tight complex (Kd ~10⁻¹⁴ M). Chain A = Barnase, Chain D = Barstar. Classic benchmark.", "mode": "Protein-Protein", "chain_a": "A", "chain_b": "D", }, ] PPI_TEMPLATE = "seq_a,seq_b\nACDEFGHIKLMNPQRSTVWY,QWERTYIPASDFGKLCVBNM\n" ABAG_TEMPLATE = "heavy_chain,light_chain,antigen\nEVQLVESGGG...,DIQMTQ...,IYSPT...\n" # ═══════════════════════════════════════════════════════════════════════════ # GLOBAL MODEL CACHE # Loaded lazily on CPU; moved to CUDA only inside @spaces.GPU functions. # ═══════════════════════════════════════════════════════════════════════════ _MODEL_STATE = { "model": None, "pkd_bounds": None, # tuple (lo, hi) — rebuilt if bounds change } # ═══════════════════════════════════════════════════════════════════════════ # PDB UTILITIES # ═══════════════════════════════════════════════════════════════════════════ _PDB_CACHE: dict = {} def _fetch_pdb_cached(pdb_id: str) -> str: pdb_id = pdb_id.strip().upper() if pdb_id in _PDB_CACHE: return _PDB_CACHE[pdb_id] url = f"https://files.rcsb.org/download/{pdb_id}.pdb" r = requests.get(url, timeout=30) r.raise_for_status() _PDB_CACHE[pdb_id] = r.text # Keep the cache from growing unbounded if len(_PDB_CACHE) > 64: _PDB_CACHE.pop(next(iter(_PDB_CACHE))) return r.text def filter_pdb_to_chains(pdb_text: str, keep_chains: set) -> str: """Strip PDB to only ATOM/HETATM records for chain IDs in keep_chains.""" out = [] for line in pdb_text.splitlines(): rec = line[:6].strip() if rec in ("ATOM", "HETATM"): chain_col = line[21:22] if chain_col not in keep_chains: continue out.append(line) return "\n".join(out) def get_chains_from_pdb(pdb_id: str): try: pdb_text = _fetch_pdb_cached(pdb_id) except Exception as e: return None, {}, f"Could not fetch PDB {pdb_id}: {e}" import warnings as _w with _w.catch_warnings(): _w.simplefilter("ignore") parser = PDB.PDBParser(QUIET=True) struct = parser.get_structure(pdb_id, io.StringIO(pdb_text)) chains: dict = {} for mdl in struct: for chain in mdl: cid = chain.id if cid not in chains: chains[cid] = {"seq": "", "resnos": []} for residue in chain: if PDB.is_aa(residue, standard=True): aa = UPPER_3TO1.get(residue.get_resname().upper(), "X") chains[cid]["seq"] += aa chains[cid]["resnos"].append(residue.get_id()[1]) break chains = {k: v for k, v in chains.items() if v["seq"]} return pdb_text, chains, None def resolve_chains(chains_dict, ids_str, available): if not ids_str: return "", [], [] ids = [c.strip() for c in str(ids_str).split(",") if c.strip()] seqs, infos, missing = [], [], [] for cid in ids: if cid in chains_dict and chains_dict[cid]["seq"]: seqs.append(chains_dict[cid]["seq"]) infos.append({"id": cid, **chains_dict[cid]}) else: found = next((k for k in chains_dict if k.upper() == cid.upper() and chains_dict[k]["seq"]), None) if found: seqs.append(chains_dict[found]["seq"]) infos.append({"id": found, **chains_dict[found]}) else: missing.append(cid) return "|".join(seqs), infos, missing def clean_multi(s: str) -> str: if not s: return "" out = [] for p in s.split("|"): p = re.sub(r">.*", "", p, flags=re.MULTILINE) p = re.sub(r"[^A-Za-z]", "", p).upper() if p: out.append(p) return "|".join(out) # ═══════════════════════════════════════════════════════════════════════════ # IG HELPERS # ═══════════════════════════════════════════════════════════════════════════ def split_ig_by_chains(ig_flat, chain_seqs, n_sep=2): if not ig_flat: return [[] for _ in chain_seqs] result, offset = [], 0 for i, seq in enumerate(chain_seqs): result.append(list(ig_flat[offset:offset + len(seq)])) offset += len(seq) if i < len(chain_seqs) - 1: offset += n_sep return result def build_ig_chain_map(chain_info_a, chain_info_b, ig_a, ig_b): result = {} for info, ig_flat in [(chain_info_a, ig_a), (chain_info_b, ig_b)]: if not info: continue per_chain = split_ig_by_chains(ig_flat, [c["seq"] for c in info]) for i, ch in enumerate(info): if ch.get("resnos"): result[ch["id"]] = {"start": ch["resnos"][0], "scores": per_chain[i]} return result def flat_ig_for_display(ig_a, ig_b, chain_info_a, chain_info_b): seq_a = "".join(c["seq"] for c in chain_info_a) if chain_info_a else "" seq_b = "".join(c["seq"] for c in chain_info_b) if chain_info_b else "" if ig_a and chain_info_a: parts = split_ig_by_chains(ig_a, [c["seq"] for c in chain_info_a]) ig_a_out = [s for sub in parts for s in sub] else: ig_a_out = list(ig_a) if ig_a else [] if ig_b and chain_info_b: parts = split_ig_by_chains(ig_b, [c["seq"] for c in chain_info_b]) ig_b_out = [s for sub in parts for s in sub] else: ig_b_out = list(ig_b) if ig_b else [] return seq_a, ig_a_out, seq_b, ig_b_out # ═══════════════════════════════════════════════════════════════════════════ # MODEL BUILD (CPU-safe — no CUDA touched here) # ═══════════════════════════════════════════════════════════════════════════ def _load_esm_base(): from transformers import EsmModel, EsmTokenizer base = EsmModel.from_pretrained( "facebook/esm2_t33_650M_UR50D", use_safetensors=True, ) tok = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D") return base, tok def _download_weights_hf() -> bytes: from huggingface_hub import hf_hub_download path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME, resume_download=True) with open(path, "rb") as f: return f.read() def build_model(weights_bytes, pkd_bounds): import torch.nn as nn import torch.nn.functional as F from peft import LoraConfig, get_peft_model class BALMProjectionHead(nn.Module): def __init__(self, embedding_size, projected_size=256): super().__init__() self.protein_projection = nn.Linear(embedding_size, projected_size) self.proteina_projection = nn.Linear(embedding_size, projected_size) self.dropout = nn.Dropout(0.1) def forward(self, prot_emb, prot_a_emb): p = F.normalize(self.protein_projection(self.dropout(prot_emb)), p=2, dim=1) pa = F.normalize(self.proteina_projection(self.dropout(prot_a_emb)), p=2, dim=1) return torch.clamp(F.cosine_similarity(p, pa), -0.9999, 0.9999) class BALMForLoRAFinetuning(nn.Module): def __init__(self, esm_model, esm_tokenizer, pkd_bounds): super().__init__() self.esm_model = esm_model self.esm_tokenizer = esm_tokenizer self.projection_head = BALMProjectionHead(self.esm_model.config.hidden_size) self.pkd_lower, self.pkd_upper = pkd_bounds self.cls_token = self.esm_tokenizer.cls_token def _get_esm_embeddings(self, sequences): processed = [s.replace("|", f"{self.cls_token}{self.cls_token}") for s in sequences] inputs = self.esm_tokenizer(processed, return_tensors="pt", padding=True, truncation=True, max_length=1024) inputs = {k: v.to(self.esm_model.device) for k, v in inputs.items()} h = self.esm_model(**inputs).last_hidden_state mask = inputs["attention_mask"].unsqueeze(-1).expand(h.size()).float() return (torch.sum(h * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)).float() def forward(self, seq_a, seq_b): ea = self._get_esm_embeddings([seq_a]) eb = self._get_esm_embeddings([seq_b]) cos = self.projection_head(ea, eb) pkd = ((cos + 1) / 2) * (self.pkd_upper - self.pkd_lower) + self.pkd_lower return pkd, cos base_esm, tok = _load_esm_base() lora_cfg = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, target_modules=["key", "query", "value"]) model = BALMForLoRAFinetuning( get_peft_model(base_esm, lora_cfg), tok, pkd_bounds) with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as tmp: tmp.write(weights_bytes) tmp_path = tmp.name model.load_state_dict(torch.load(tmp_path, map_location="cpu"), strict=False) os.unlink(tmp_path) model.eval() return model def ensure_model(pkd_lo: float, pkd_hi: float): """Lazy load (on CPU). Rebuilds head bounds if user changes them.""" bounds = (float(pkd_lo), float(pkd_hi)) if _MODEL_STATE["model"] is None: wb = _download_weights_hf() _MODEL_STATE["model"] = build_model(wb, bounds) _MODEL_STATE["pkd_bounds"] = bounds elif _MODEL_STATE["pkd_bounds"] != bounds: # Same weights, just update the affine output bounds _MODEL_STATE["model"].pkd_lower = bounds[0] _MODEL_STATE["model"].pkd_upper = bounds[1] _MODEL_STATE["pkd_bounds"] = bounds return _MODEL_STATE["model"] # ═══════════════════════════════════════════════════════════════════════════ # INTEGRATED GRADIENTS (float32-safe, 15 Riemann steps) # ═══════════════════════════════════════════════════════════════════════════ def compute_ig(model, seq_a: str, seq_b: str, steps: int = 15): device = next(model.parameters()).device tok = model.esm_tokenizer cls_tok = tok.cls_token esm = model.esm_model def tokenise(seq): proc = seq.replace("|", f"{cls_tok}{cls_tok}") return tok(proc, return_tensors="pt", padding=False, truncation=True, max_length=1024).to(device) word_embed = esm.base_model.model.embeddings.word_embeddings enc_a = tokenise(seq_a); mask_a = enc_a["attention_mask"] enc_b = tokenise(seq_b); mask_b = enc_b["attention_mask"] emb_a = word_embed(enc_a["input_ids"]).detach().float().clone() emb_b = word_embed(enc_b["input_ids"]).detach().float().clone() mask_a = mask_a.float() mask_b = mask_b.float() def encode(embs, mask): int_mask = mask.long() ext = esm.base_model.model.get_extended_attention_mask(int_mask, embs.shape[:2]) h = esm.base_model.model.encoder( embs.float(), attention_mask=ext ).last_hidden_state.float() m = mask.unsqueeze(-1).expand(h.size()).float() return (torch.sum(h * m, 1) / torch.clamp(m.sum(1), min=1e-9)) def fwd(e_a, e_b): return model.projection_head(encode(e_a, mask_a), encode(e_b, mask_b)) def riemann(tgt, baseline, fixed, tgt_is_b): grads = [] for alpha in torch.linspace(0, 1, steps, device=device): interp = (baseline + alpha * (tgt - baseline)).detach().requires_grad_(True) out = fwd(fixed, interp) if tgt_is_b else fwd(interp, fixed) out.sum().backward() grads.append(interp.grad.detach().float().clone()) avg_grad = torch.stack(grads).mean(0) ig = (avg_grad * (tgt - baseline)).abs().sum(-1).squeeze(0) return ig.cpu().numpy() baseline_a = torch.zeros_like(emb_a) baseline_b = torch.zeros_like(emb_b) attr_a = riemann(emb_a, baseline_a, emb_b, False) attr_b = riemann(emb_b, baseline_b, emb_a, True) def norm(a): lo, hi = a.min(), a.max() if hi - lo < 1e-9: return [0.0] * len(a) return ((a - lo) / (hi - lo)).tolist() return norm(attr_a[1:-1]), norm(attr_b[1:-1]) # ═══════════════════════════════════════════════════════════════════════════ # GPU-WRAPPED PREDICTION # This is the only function that touches CUDA. ZeroGPU will allocate an A100 # for the duration of the call. # ═══════════════════════════════════════════════════════════════════════════ @spaces.GPU(duration=120) def gpu_run_prediction(seq_a, seq_b, run_ig, pkd_lo, pkd_hi): model = ensure_model(pkd_lo, pkd_hi) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) with torch.no_grad(): pkd_t, cos_t = model(seq_a, seq_b) pkd = float(pkd_t.item()) cos = float(cos_t.item()) ig_a, ig_b = None, None if run_ig: ig_a, ig_b = compute_ig(model, seq_a, seq_b, steps=15) # Move back to CPU so the model stays small in resident memory between calls model.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() return pkd, cos, ig_a, ig_b, str(device) @spaces.GPU(duration=300) def gpu_run_batch(rows, batch_mode, run_ig, pkd_lo, pkd_hi): """Batch over rows. `rows` is a list of dicts with seq_a/seq_b or heavy/light/antigen.""" model = ensure_model(pkd_lo, pkd_hi) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) results = [] for row in rows: if batch_mode == "Protein-Protein": sa = clean_multi(str(row.get("seq_a", ""))) sb = clean_multi(str(row.get("seq_b", ""))) else: h = clean_multi(str(row.get("heavy_chain", ""))) l = clean_multi(str(row.get("light_chain", ""))) ag = clean_multi(str(row.get("antigen", ""))) sa = ag sb = f"{h}|{l}" if (h or l) else "" if not sa or not sb: results.append({"pKd": None, "Cosine": None}) continue try: with torch.no_grad(): pkd_t, cos_t = model(sa, sb) rr = {"pKd": round(float(pkd_t.item()), 4), "Cosine": round(float(cos_t.item()), 6)} if run_ig: ia, ib = compute_ig(model, sa, sb, steps=15) sap, sbp = sa.replace("|", ""), sb.replace("|", "") rr["top5_Target"] = "|".join( f"{sap[j]}{j+1}:{ia[j]:.3f}" for j in sorted(range(min(len(sap), len(ia))), key=lambda x: -ia[x])[:5]) rr["top5_proteina"] = "|".join( f"{sbp[j]}{j+1}:{ib[j]:.3f}" for j in sorted(range(min(len(sbp), len(ib))), key=lambda x: -ib[x])[:5]) except Exception as e: rr = {"pKd": None, "Cosine": None, "error": str(e)} results.append(rr) model.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() return results # ═══════════════════════════════════════════════════════════════════════════ # NGL VIEWER (theme-reactive, filtered to selected chains) # ═══════════════════════════════════════════════════════════════════════════ def ngl_viewer_html(pdb_content, ig_chain_map=None, height=420) -> str: if not pdb_content: return ( '
' '
🧬
' '
Load an example or fetch a PDB
' '
Only selected chains will be shown
' '
' ) ig_json = json.dumps(ig_chain_map) if ig_chain_map else "null" escaped = pdb_content.replace("\\", "\\\\").replace("`", "\\`").replace("$", "\\$") return f"""
🧬 Drag · Scroll · Right-drag pan
""" def wrap_iframe(html_src: str, height: int = 425) -> str: """Sandboxed iframe wrapper so NGL renders cleanly inside Gradio HTML.""" srcdoc = html_src.replace('"', '"') return ( f'' ) # ═══════════════════════════════════════════════════════════════════════════ # PLOTLY # ═══════════════════════════════════════════════════════════════════════════ _PL = dict(paper_bgcolor="rgba(255,255,255,.97)", plot_bgcolor="rgba(241,244,249,1)") _FONT = dict(family="JetBrains Mono, monospace", color="#475569", size=9) _GRID = dict(gridcolor="rgba(37,99,235,.07)", zerolinecolor="rgba(37,99,235,.1)") def make_heatmap(seq, attr, title): W = 50; rows, hover, ylabels = [], [], [] for r in range(0, len(seq), W): rows.append([attr[r + c] if r + c < len(attr) else None for c in range(W)]) hover.append([f"{seq[r+c]}{r+c+1} IG:{attr[r+c]:.3f}" if r + c < len(seq) else "" for c in range(W)]) ylabels.append(f"{r+1}–{min(r+W, len(seq))}") fig = go.Figure(go.Heatmap( z=rows, text=hover, hoverinfo="text", colorscale=[[0, "#e8f0fe"], [.4, "#93c5fd"], [.75, "#3b82f6"], [1, "#1d4ed8"]], zmin=0, zmax=1, xgap=1, ygap=1, colorbar=dict(thickness=10, len=.9, tickfont=dict(**_FONT), title=dict(text="IG", font=dict(**_FONT))))) fig.update_layout( title=dict(text=title, font=dict(family="JetBrains Mono, monospace", size=10, color="#1d4ed8")), **_PL, height=240, margin=dict(t=30, b=24, l=55, r=45), xaxis=dict(title=dict(text="Position in window", font=dict(**_FONT)), tickfont=dict(**_FONT), **_GRID), yaxis=dict(tickvals=list(range(len(ylabels))), ticktext=ylabels, tickfont=dict(**_FONT), **_GRID)) return fig def top10_bar(seq, ig, title): indexed = sorted([{"aa": seq[i], "idx": i + 1, "sc": ig[i]} for i in range(min(len(seq), len(ig)))], key=lambda x: -x["sc"])[:10] fig = go.Figure(go.Bar( x=[x["sc"] for x in indexed], y=[f"{x['aa']}{x['idx']}" for x in indexed], orientation="h", marker=dict(color=[x["sc"] for x in indexed], colorscale=[[0, "#dbeafe"], [.5, "#3b82f6"], [1, "#1e40af"]], showscale=False, line_width=0))) fig.update_layout( title=dict(text=title, font=dict(family="JetBrains Mono, monospace", size=10, color="#1d4ed8")), **_PL, height=260, margin=dict(t=30, b=18, l=52, r=16), xaxis=dict(title=dict(text="IG Score", font=dict(**_FONT)), tickfont=dict(**_FONT), **_GRID), yaxis=dict(tickfont=dict(**_FONT), autorange="reversed")) return fig def residue_strip_html(seq, attr) -> str: cells = [] for i, aa in enumerate(seq): s = attr[i] if i < len(attr) else 0.0 r = int(248 - s * 211); g = int(250 - s * 151); b = int(252 - s * 17) fg = "#1e40af" if s > 0.5 else "#475569" cells.append( f'{aa}') return ('
' 'Residue Attribution (hover for score)
' '
' + "".join(cells) + "
") # ═══════════════════════════════════════════════════════════════════════════ # RESULT FORMATTING # ═══════════════════════════════════════════════════════════════════════════ def result_card_html(pkd: float, cos: float, pkd_lo: float, pkd_hi: float) -> str: pct = max(0.0, min(100.0, ((pkd - pkd_lo) / max(pkd_hi - pkd_lo, 1e-9)) * 100)) strength = "Weak" if pct < 30 else "Moderate" if pct < 55 else "Strong" if pct < 75 else "Very Strong" badge_cls = "badge-weak" if pct < 30 else "badge-moderate" if pct < 55 else "badge-strong" return f"""
Predicted pKd
{pkd:.3f}
Cosine Similarity
{cos:.4f}
Weak ({pkd_lo:.0f}) {strength} Strong ({pkd_hi:.0f})
{strength}
""" def status_badge_html(model_loaded: bool, device: str = "cpu") -> str: if model_loaded: return (f'
' f'READY  ·  {device}
') return '
○  NOT LOADED
' def sidebar_result_html(pkd: float, cos: float) -> str: return ( f'
Predicted pKd
' f'
{pkd:.3f}
' f'
Cosine Similarity
' f'
{cos:.4f}
' ) # ═══════════════════════════════════════════════════════════════════════════ # CSS (adapted for Gradio's DOM; theme vars + custom classes preserved) # ═══════════════════════════════════════════════════════════════════════════ CSS = """ @import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;500;600;700&family=Inter:wght@300;400;500;600;700&display=swap'); :root { --bg0:#ffffff; --bg1:#f8f9fc; --bg2:#f1f4f9; --bg3:#e4e9f2; --border:rgba(37,99,235,0.13); --shadow:rgba(37,99,235,0.07); --accent:#2563eb; --accent2:#7c3aed; --green:#16a34a; --amber:#d97706; --red:#dc2626; --text0:#0f172a; --text1:#334155; --text2:#64748b; --text3:#94a3b8; --mono:'JetBrains Mono',monospace; --sans:'Inter',sans-serif; } @media (prefers-color-scheme:dark){ :root { --bg0:#06090f; --bg1:#0b1120; --bg2:#101828; --bg3:#172135; --border:rgba(96,165,250,0.13); --shadow:rgba(0,0,0,0.35); --accent:#60a5fa; --accent2:#a78bfa; --green:#34d399; --amber:#fbbf24; --red:#f87171; --text0:#f1f5f9; --text1:#cbd5e1; --text2:#64748b; --text3:#334155; } } .gradio-container { font-family: var(--sans) !important; max-width: 1400px !important; } .gradio-container h1, .gradio-container h2, .gradio-container h3 { font-family: var(--sans) !important; color: var(--text0) !important; } code, pre { font-family: var(--mono) !important; font-size: .82rem !important; background: var(--bg2) !important; border-radius: 4px !important; color: var(--accent) !important; } .app-header{display:flex;align-items:center;gap:14px;padding:4px 0 14px;} .app-logo{width:40px;height:40px;border-radius:10px; background:linear-gradient(135deg,var(--accent),var(--accent2)); display:flex;align-items:center;justify-content:center; font-size:20px;flex-shrink:0;box-shadow:0 2px 12px rgba(37,99,235,.22);} .app-title{font-family:var(--mono)!important;font-size:1.22rem!important; font-weight:700!important;color:var(--text0)!important;letter-spacing:-.02em;} .app-subtitle{font-size:.77rem!important;color:var(--text2)!important;margin-top:1px;letter-spacing:.03em;} .sec-hdr{font-family:var(--mono);font-size:.65rem;font-weight:700;letter-spacing:.2em; text-transform:uppercase;color:var(--accent);margin:14px 0 8px; display:flex;align-items:center;gap:8px;} .sec-hdr::after{content:'';flex:1;height:1px;background:var(--border);} .ex-card{background:var(--bg1);border:1px solid var(--border);border-radius:10px; padding:12px 15px;margin-bottom:8px;position:relative;overflow:hidden; transition:box-shadow .15s,border-color .15s;} .ex-card::before{content:'';position:absolute;left:0;top:0;bottom:0;width:3px; background:linear-gradient(180deg,var(--accent),var(--accent2));border-radius:3px 0 0 3px;} .ex-card:hover{border-color:rgba(37,99,235,.3);box-shadow:0 2px 12px var(--shadow);} .ex-pdb{font-family:var(--mono);font-size:.95rem;font-weight:700;color:var(--accent);} .ex-sub{font-size:.8rem;font-weight:600;color:var(--text1);margin:2px 0;} .ex-desc{font-size:.72rem;line-height:1.5;color:var(--text2);margin-top:3px;} .pkd-card{background:linear-gradient(135deg,rgba(37,99,235,.05),rgba(124,58,237,.04)); border:1px solid var(--border);border-radius:12px;padding:14px 18px; box-shadow:0 1px 6px var(--shadow);} .pkd-lbl{font-family:var(--mono);font-size:.67rem;color:var(--text2); text-transform:uppercase;letter-spacing:.12em;margin-bottom:2px;} .pkd-val{font-family:var(--mono);font-size:2rem;font-weight:700;color:var(--accent);line-height:1.1;} .pkd-badge{display:inline-block;padding:2px 9px;border-radius:20px; font-size:.7rem;font-weight:600;font-family:var(--mono);margin-top:5px;} .badge-weak{background:rgba(220,38,38,.08);color:var(--red);border:1px solid rgba(220,38,38,.22);} .badge-moderate{background:rgba(217,119,6,.08);color:var(--amber);border:1px solid rgba(217,119,6,.22);} .badge-strong{background:rgba(22,163,74,.08);color:var(--green);border:1px solid rgba(22,163,74,.22);} .str-bar{height:5px;border-radius:3px;background:var(--bg3);overflow:hidden;margin:8px 0 3px;} .str-fill{height:100%;border-radius:3px; background:linear-gradient(90deg,var(--accent),var(--accent2)); transition:width .7s cubic-bezier(.4,0,.2,1);} .str-labels{display:flex;justify-content:space-between;font-size:.65rem; color:var(--text2);font-family:var(--mono);} .ready-badge{display:inline-flex;align-items:center;gap:6px;padding:3px 10px; border-radius:20px;background:rgba(22,163,74,.07); border:1px solid rgba(22,163,74,.22); font-family:var(--mono);font-size:.72rem;color:var(--green);font-weight:600;} .ready-dot{display:inline-block;width:6px;height:6px;border-radius:50%; background:var(--green);animation:pulse 2s infinite;} .idle-badge{display:inline-flex;align-items:center;gap:6px;padding:3px 10px; border-radius:20px;background:rgba(100,116,139,.07); border:1px solid rgba(100,116,139,.18); font-family:var(--mono);font-size:.72rem;color:var(--text2);font-weight:600;} @keyframes pulse{0%,100%{opacity:1;transform:scale(1)}50%{opacity:.5;transform:scale(.85)}} textarea, input[type="text"], input[type="number"] { font-family: var(--mono) !important; font-size: .82rem !important; } """ # ═══════════════════════════════════════════════════════════════════════════ # GRADIO EVENT HANDLERS # ═══════════════════════════════════════════════════════════════════════════ def handle_mode_change(mode): """Show the right input panel for the selected interaction mode.""" if mode == "Protein-Protein": return ( gr.update(visible=True), # ppi panel gr.update(visible=False), # ab-ag panel gr.update(visible=True), # ppi chain inputs gr.update(visible=False), # ab-ag chain inputs ) return ( gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), ) def handle_load_example(pdb_id, mode, chain_a, chain_b, chain_h=None, chain_l=None, chain_ag=None): """Fetch a PDB, populate sequences + chain info, return updates for the UI.""" pdb_content, chains, err = get_chains_from_pdb(pdb_id) if err or not chains: gr.Warning(err or f"No chains found in {pdb_id}") return ( gr.update(), # ppi_a gr.update(), # ppi_b gr.update(), # ab_h gr.update(), # ab_l gr.update(), # ab_ag [], # state_chain_info_a [], # state_chain_info_b "", # state_pdb_content None, # state_ig_chain_map None, None, # state_ig_a, state_ig_b None, # state_result gr.update(value=ngl_viewer_html(None)), # ngl_html gr.update(value=""), # result_html gr.update(value=""), # strip_a_html None, # bar_a None, # hm_a gr.update(value=""), # strip_b_html None, # bar_b None, # hm_b gr.update(value=""), # pdb_info_md ) available = sorted(chains.keys()) info_msg = f"✅ **{pdb_id.upper()}** — chains: **{', '.join(available)}**" selected_chain_ids: set = set() if mode == "Protein-Protein": id_a = chain_a or available[0] id_b = chain_b or (available[1] if len(available) > 1 else available[0]) seq_a, info_a, miss_a = resolve_chains(chains, id_a, available) seq_b, info_b, miss_b = resolve_chains(chains, id_b, available) for m in miss_a: gr.Warning(f"Chain {m} not found (Side A). Available: {', '.join(available)}") for m in miss_b: gr.Warning(f"Chain {m} not found (Side B). Available: {', '.join(available)}") selected_chain_ids = {c["id"] for c in info_a} | {c["id"] for c in info_b} ppi_a_val = seq_a ppi_b_val = seq_b ah_val = al_val = aag_val = gr.update() chain_info_a, chain_info_b = info_a, info_b else: id_h = chain_h or "H" id_l = chain_l or "L" id_ag_str = chain_ag or next((c for c in available if c not in (id_h, id_l)), available[0]) def _single(cid): if cid in chains and chains[cid]["seq"]: return cid, chains[cid] f = next((k for k in chains if k.upper() == cid.upper() and chains[k]["seq"]), None) return (f, chains[f]) if f else (cid, {"seq": "", "resnos": []}) real_h, ch = _single(id_h) real_l, cl = _single(id_l) seq_ag, info_ag, miss_ag = resolve_chains(chains, id_ag_str, available) if not ch["seq"]: gr.Warning(f"Heavy chain {id_h} not found. Available: {', '.join(available)}") if not cl["seq"]: gr.Warning(f"Light chain {id_l} not found. Available: {', '.join(available)}") for m in miss_ag: gr.Warning(f"Antigen chain {m} not found. Available: {', '.join(available)}") chain_info_a = info_ag if info_ag else [] chain_info_b = [{"id": real_h, **ch}, {"id": real_l, **cl}] selected_chain_ids = ({c["id"] for c in info_ag} if info_ag else set()) if ch["seq"]: selected_chain_ids.add(real_h) if cl["seq"]: selected_chain_ids.add(real_l) ppi_a_val = ppi_b_val = gr.update() ah_val = ch["seq"] al_val = cl["seq"] aag_val = seq_ag filtered_pdb = filter_pdb_to_chains(pdb_content, selected_chain_ids) if selected_chain_ids else pdb_content return ( ppi_a_val, # ppi_a ppi_b_val, # ppi_b ah_val, # ab_h al_val, # ab_l aag_val, # ab_ag chain_info_a, # state_chain_info_a chain_info_b, # state_chain_info_b filtered_pdb, # state_pdb_content None, # state_ig_chain_map (cleared) None, None, # state_ig_a, state_ig_b None, # state_result wrap_iframe(ngl_viewer_html(filtered_pdb, None)), # ngl_html "", # result_html "", # strip_a_html None, # bar_a None, # hm_a "", # strip_b_html None, # bar_b None, # hm_b info_msg, # pdb_info_md ) def handle_fetch_pdb(pdb_id, mode, ca_in, cb_in, ch_in, cl_in, cag_in): if not pdb_id or not pdb_id.strip(): gr.Warning("Enter a PDB ID.") return (gr.update(),) * 20 return handle_load_example( pdb_id.strip(), mode, chain_a=ca_in.strip() or None if ca_in else None, chain_b=cb_in.strip() or None if cb_in else None, chain_h=ch_in.strip() or None if ch_in else None, chain_l=cl_in.strip() or None if cl_in else None, chain_ag=cag_in.strip() or None if cag_in else None, ) def handle_run_prediction( mode, ppi_a, ppi_b, ab_h, ab_l, ab_ag, run_ig, pkd_lo, pkd_hi, chain_info_a, chain_info_b, pdb_content, ): """Main predict callback. Calls the @spaces.GPU function.""" if mode == "Protein-Protein": seq_a = clean_multi(ppi_a or "") seq_b = clean_multi(ppi_b or "") else: seq_a = clean_multi(ab_ag or "") h = clean_multi(ab_h or "") l = clean_multi(ab_l or "") seq_b = f"{h}|{l}" if (h or l) else "" if not seq_a or not seq_b: gr.Warning("Fill all sequence boxes before running.") return (gr.update(),) * 12 # If no chain info exists yet (user pasted sequences directly), synthesise it if not chain_info_a: chain_info_a = [{"id": ch, "seq": s, "resnos": list(range(1, len(s) + 1))} for ch, s in zip("ABCDE", seq_a.split("|"))] if not chain_info_b: if mode == "Antibody-Antigen": h_s = clean_multi(ab_h or "") l_s = clean_multi(ab_l or "") chain_info_b = [ {"id": "H", "seq": h_s, "resnos": list(range(1, len(h_s) + 1))}, {"id": "L", "seq": l_s, "resnos": list(range(1, len(l_s) + 1))}, ] else: chain_info_b = [{"id": ch, "seq": s, "resnos": list(range(1, len(s) + 1))} for ch, s in zip("FGHIJ", seq_b.split("|"))] try: pkd, cos, ig_a, ig_b, device_used = gpu_run_prediction( seq_a, seq_b, bool(run_ig), float(pkd_lo), float(pkd_hi)) except Exception as e: gr.Error(f"Prediction failed: {e}") traceback.print_exc() return (gr.update(),) * 12 result = {"pkd": pkd, "cosine": cos, "sa": seq_a, "sb": seq_b} # Display data seq_a_d, ig_a_d, seq_b_d, ig_b_d = flat_ig_for_display(ig_a, ig_b, chain_info_a, chain_info_b) if not seq_a_d: seq_a_d = seq_a.replace("|", "") if not seq_b_d: seq_b_d = seq_b.replace("|", "") ig_chain_map = None if ig_a is not None and ig_b is not None: ig_chain_map = build_ig_chain_map(chain_info_a, chain_info_b, ig_a, ig_b) lbl_a = "Antigen (Target)" if mode == "Antibody-Antigen" else "Target (Protein A)" lbl_b = "Antibody H+L" if mode == "Antibody-Antigen" else "Binder (proteina)" if ig_a_d and seq_a_d: strip_a = residue_strip_html(seq_a_d, ig_a_d) bar_a = top10_bar(seq_a_d, ig_a_d, f"Top 10 · {lbl_a}") hm_a = make_heatmap(seq_a_d, ig_a_d, f"{lbl_a} · IG Heatmap") else: strip_a, bar_a, hm_a = "", None, None if ig_b_d and seq_b_d: strip_b = residue_strip_html(seq_b_d, ig_b_d) bar_b = top10_bar(seq_b_d, ig_b_d, f"Top 10 · {lbl_b}") hm_b = make_heatmap(seq_b_d, ig_b_d, f"{lbl_b} · IG Heatmap") else: strip_b, bar_b, hm_b = "", None, None ngl_html = wrap_iframe(ngl_viewer_html(pdb_content, ig_chain_map)) if pdb_content else \ wrap_iframe(ngl_viewer_html(None)) result_card = result_card_html(pkd, cos, float(pkd_lo), float(pkd_hi)) sidebar_card = sidebar_result_html(pkd, cos) badge = status_badge_html(True, device_used) return ( result, # state_result ig_a, ig_b, # state_ig_a, state_ig_b ig_chain_map, # state_ig_chain_map result_card, # result_html ngl_html, # ngl_html strip_a, bar_a, hm_a, # vis A strip_b, bar_b, hm_b, # vis B sidebar_card, # sidebar_result_html badge, # status badge ) def handle_download_results(result, state_ig_a, state_ig_b, chain_info_a, chain_info_b): """Build CSV + JSON for download.""" if not result: return None, None seq_a_d, ig_a_d, seq_b_d, ig_b_d = flat_ig_for_display( state_ig_a, state_ig_b, chain_info_a, chain_info_b) if not seq_a_d: seq_a_d = result["sa"].replace("|", "") if not seq_b_d: seq_b_d = result["sb"].replace("|", "") # CSV csv_path = os.path.join(tempfile.gettempdir(), "ig_scores.csv") with open(csv_path, "w", newline="") as f: wc = csv.writer(f) wc.writerow(["chain", "position", "residue", "ig_score"]) for i, (aa, sc) in enumerate(zip(seq_a_d or "", ig_a_d or [])): wc.writerow(["Target", i + 1, aa, f"{sc:.6f}"]) for i, (aa, sc) in enumerate(zip(seq_b_d or "", ig_b_d or [])): wc.writerow(["proteina", i + 1, aa, f"{sc:.6f}"]) # JSON summary = { "pkd": result["pkd"], "cosine": result["cosine"], "Target_length": len(seq_a_d or ""), "proteina_length": len(seq_b_d or ""), "top10_Target": sorted( [{"res": f"{(seq_a_d or '')[i]}{i+1}", "ig": (ig_a_d or [])[i]} for i in range(min(len(seq_a_d or ""), len(ig_a_d or [])))], key=lambda x: -x["ig"])[:10], "top10_proteina": sorted( [{"res": f"{(seq_b_d or '')[i]}{i+1}", "ig": (ig_b_d or [])[i]} for i in range(min(len(seq_b_d or ""), len(ig_b_d or [])))], key=lambda x: -x["ig"])[:10], } json_path = os.path.join(tempfile.gettempdir(), "balm_result.json") with open(json_path, "w") as f: json.dump(summary, f, indent=2) return csv_path, json_path def handle_load_model(pkd_lo, pkd_hi): """Force-load the model on CPU (download weights). GPU is reserved for prediction time.""" try: ensure_model(float(pkd_lo), float(pkd_hi)) gr.Info("Model weights downloaded and ready. GPU will be allocated when you click Run.") return status_badge_html(True, "cpu (GPU on demand)") except Exception as e: traceback.print_exc() gr.Error(f"Load failed: {e}") return status_badge_html(False) def handle_run_batch(file_obj, batch_mode, run_ig_b, pkd_lo, pkd_hi): if file_obj is None: gr.Warning("Upload a CSV file first.") return None, None try: df = pd.read_csv(file_obj.name) except Exception as e: gr.Error(f"CSV read failed: {e}") return None, None rows = df.to_dict(orient="records") try: results = gpu_run_batch(rows, batch_mode, bool(run_ig_b), float(pkd_lo), float(pkd_hi)) except Exception as e: traceback.print_exc() gr.Error(f"Batch failed: {e}") return None, None df_res = pd.concat([df, pd.DataFrame(results)], axis=1) out_path = os.path.join(tempfile.gettempdir(), "balm_batch_results.csv") df_res.to_csv(out_path, index=False) return df_res, out_path # ═══════════════════════════════════════════════════════════════════════════ # GRADIO LAYOUT # ═══════════════════════════════════════════════════════════════════════════ def build_ui(): with gr.Blocks(title="BALM-PPI Pro") as demo: # ── Persistent state ─────────────────────────────────────────────── state_pdb_content = gr.State("") state_chain_info_a = gr.State([]) state_chain_info_b = gr.State([]) state_ig_a = gr.State(None) state_ig_b = gr.State(None) state_ig_chain_map = gr.State(None) state_result = gr.State(None) # ── Header ───────────────────────────────────────────────────────── gr.HTML("""
BALM-PPI Pro
ESM-2 · LoRA · Integrated Gradients  ·  Protein Binding Affinity Prediction
""") with gr.Tabs(): # ═══════════════════════════════════════════════════════════════ # TAB 1 — SINGLE PREDICTION # ═══════════════════════════════════════════════════════════════ with gr.Tab("🎯 Single Prediction"): with gr.Row(): # ─── Sidebar (left column) ───────────────────────────── with gr.Column(scale=1, min_width=260): gr.Markdown("### ⚙️ Model") with gr.Accordion("pKd Bounds", open=False): pkd_lo = gr.Number(value=1.0, label="Min", precision=2) pkd_hi = gr.Number(value=16.0, label="Max", precision=2) load_btn = gr.Button("⚡ Reload Model", variant="secondary") status_html = gr.HTML( status_badge_html(_MODEL_STATE["model"] is not None, "cpu (GPU on demand)")) sidebar_result = gr.HTML("") gr.Markdown( '
' '🤗 Harshit494/BALM-PPI
' '💻 rgorantla04/BALM-PPI' '
' ) # ─── Main column ─────────────────────────────────────── with gr.Column(scale=4): mode = gr.Radio( ["Protein-Protein", "Antibody-Antigen"], value="Protein-Protein", label="Interaction mode", container=False, ) # PPI input panel with gr.Group(visible=True) as ppi_panel: with gr.Row(): ppi_a = gr.Textbox( label="Target Protein — Seq A", placeholder="Paste FASTA or raw sequence… | = chain separator", lines=4, max_lines=8, ) ppi_b = gr.Textbox( label="Binder Protein — Seq B ( | = chain separator)", placeholder="Paste FASTA or raw sequence…", lines=4, max_lines=8, ) # Ab-Ag input panel with gr.Group(visible=False) as abag_panel: with gr.Row(): ab_ag = gr.Textbox( label="Antigen / Target — Seq A", placeholder="Antigen sequence…", lines=4, max_lines=8, scale=2, ) with gr.Column(scale=3): with gr.Row(): ab_h = gr.Textbox( label="Heavy Chain (H)", placeholder="VH sequence…", lines=4, max_lines=8, ) ab_l = gr.Textbox( label="Light Chain (L)", placeholder="VL sequence…", lines=4, max_lines=8, ) with gr.Row(): run_btn = gr.Button("🔬 Run Prediction", variant="primary", scale=3) run_ig = gr.Checkbox(value=True, label="Run Integrated Gradients", info="~45–90s CPU · <5s GPU", scale=2) result_html = gr.HTML("") # ─── Lower section: PDB controls + visualizations ───── with gr.Row(): # ── LEFT: Custom PDB Fetch ─────────────────── with gr.Column(scale=1, min_width=240): gr.HTML('
Custom PDB Fetch
') pdb_in = gr.Textbox(label="PDB ID", placeholder="1YCR, 1BRS, 2VXQ…") # PPI chain selectors with gr.Group(visible=True) as ppi_chains_row: ca_in = gr.Textbox(label="Side A chain(s)", placeholder="A or A,B") cb_in = gr.Textbox(label="Side B chain(s)", placeholder="B or B,C") # Ab-Ag chain selectors with gr.Group(visible=False) as abag_chains_row: ch_in = gr.Textbox(label="Heavy", placeholder="H") cl_in = gr.Textbox(label="Light", placeholder="L") cag_in = gr.Textbox(label="Antigen", placeholder="A or A,B") fetch_btn = gr.Button("🔍 Fetch PDB", variant="secondary", size="sm") pdb_info_md = gr.Markdown("") # ── MIDDLE: Quick Examples ─────────────────── with gr.Column(scale=1, min_width=240): gr.HTML('
Quick Examples
') example_buttons = [] for ex in EXAMPLES: gr.HTML( f'
' f'
{ex["label"]}
' f'
{ex["subtitle"]}
' f'
{ex["desc"]}
' ) b = gr.Button(f"⬇ Load {ex['pdb']}", size="sm") example_buttons.append((b, ex)) # ── RIGHT: Visualizations ──────────────────── with gr.Column(scale=3): with gr.Tabs(): with gr.Tab("🧬 Structure"): ngl_html = gr.HTML(wrap_iframe(ngl_viewer_html(None))) with gr.Tab("📊 Side A Strip"): strip_a_html = gr.HTML("") bar_a_plot = gr.Plot() with gr.Tab("🗺 Side A Heatmap"): hm_a_plot = gr.Plot() with gr.Tab("📊 Side B Strip"): strip_b_html = gr.HTML("") bar_b_plot = gr.Plot() with gr.Tab("🗺 Side B Heatmap"): hm_b_plot = gr.Plot() with gr.Tab("📥 Download"): dl_btn = gr.Button("Build downloads", variant="secondary") dl_csv = gr.File(label="IG Scores CSV") dl_json = gr.File(label="Summary JSON") # ═══════════════════════════════════════════════════════════════ # TAB 2 — BATCH PREDICTION # ═══════════════════════════════════════════════════════════════ with gr.Tab("📂 Batch Prediction"): gr.Markdown("### 📂 Batch Prediction") batch_mode = gr.Radio( ["Protein-Protein", "Antibody-Antigen"], value="Protein-Protein", label="Mode", container=False, ) with gr.Row(): gr.File(label="PPI Template", value=lambda: _write_template(PPI_TEMPLATE, "ppi_template.csv"), interactive=False) gr.File(label="Ab-Ag Template", value=lambda: _write_template(ABAG_TEMPLATE, "abag_template.csv"), interactive=False) gr.Markdown( "Required columns: **seq_a, seq_b** (PPI mode) " "or **heavy_chain, light_chain, antigen** (Ab-Ag mode)." ) batch_file = gr.File(label="CSV file", file_types=[".csv"]) run_ig_b = gr.Checkbox(value=False, label="Compute IG per row (slow on CPU; fine on GPU)") batch_run = gr.Button("🏃 Run Batch", variant="primary") batch_df = gr.Dataframe(label="Results", interactive=False) batch_dl = gr.File(label="Download results CSV") # ═══════════════════════════════════════════════════════════════ # TAB 3 — MODEL INFO # ═══════════════════════════════════════════════════════════════ with gr.Tab("📚 Model Info"): gr.Markdown(""" ### 📚 About BALM-PPI **BALM-PPI** predicts protein–protein and antibody–antigen binding affinity using ESM-2 + LoRA with Integrated Gradients explainability. | Component | Detail | |-----------|--------| | Backbone | ESM-2 650M (`facebook/esm2_t33_650M_UR50D`) | | Fine-tuning | LoRA r=8, α=16, dropout=0.1 · targets: key/query/value | | Column mapping | `seq_a` → **Target** · `seq_b` → **proteina** | | Multi-chain | `\\|` separator → `` double CLS token | | Affinity output | pKd = ((cos+1)/2) × (pKd_max − pKd_min) + pKd_min | | IG | Integrated Gradients, **15-step** Riemann (float32-safe) | | Viewer | PDB filtered to **selected chains only** | | Hardware | HF ZeroGPU (dynamic A100); CPU fallback supported | **Examples** · **1YCR** — MDM2 ↔ p53  ·  **1BRS** — Barnase ↔ Barstar 🤗 [Harshit494/BALM-PPI](https://huggingface.co/Harshit494/BALM-PPI)  ·  💻 [rgorantla04/BALM-PPI](https://github.com/rgorantla04/BALM-PPI) """) # ─────────────────────────────────────────────────────────────────── # EVENT WIRING # ─────────────────────────────────────────────────────────────────── mode.change( handle_mode_change, inputs=[mode], outputs=[ppi_panel, abag_panel, ppi_chains_row, abag_chains_row], ) load_btn.click( handle_load_model, inputs=[pkd_lo, pkd_hi], outputs=[status_html], ) # Outputs that load_example/fetch_pdb returns load_outputs = [ ppi_a, ppi_b, ab_h, ab_l, ab_ag, state_chain_info_a, state_chain_info_b, state_pdb_content, state_ig_chain_map, state_ig_a, state_ig_b, state_result, ngl_html, result_html, strip_a_html, bar_a_plot, hm_a_plot, strip_b_html, bar_b_plot, hm_b_plot, pdb_info_md, ] for btn, ex in example_buttons: btn.click( lambda mode_val=ex["mode"], pdb=ex["pdb"], ca=ex.get("chain_a"), cb=ex.get("chain_b"), ch=ex.get("chain_h"), cl=ex.get("chain_l"), cag=ex.get("chain_ag"): handle_load_example(pdb, mode_val, ca, cb, ch, cl, cag), outputs=load_outputs, ) fetch_btn.click( handle_fetch_pdb, inputs=[pdb_in, mode, ca_in, cb_in, ch_in, cl_in, cag_in], outputs=load_outputs, ) run_btn.click( handle_run_prediction, inputs=[ mode, ppi_a, ppi_b, ab_h, ab_l, ab_ag, run_ig, pkd_lo, pkd_hi, state_chain_info_a, state_chain_info_b, state_pdb_content, ], outputs=[ state_result, state_ig_a, state_ig_b, state_ig_chain_map, result_html, ngl_html, strip_a_html, bar_a_plot, hm_a_plot, strip_b_html, bar_b_plot, hm_b_plot, sidebar_result, status_html, ], ) dl_btn.click( handle_download_results, inputs=[state_result, state_ig_a, state_ig_b, state_chain_info_a, state_chain_info_b], outputs=[dl_csv, dl_json], ) batch_run.click( handle_run_batch, inputs=[batch_file, batch_mode, run_ig_b, pkd_lo, pkd_hi], outputs=[batch_df, batch_dl], ) return demo def _write_template(content: str, name: str) -> str: path = os.path.join(tempfile.gettempdir(), name) with open(path, "w") as f: f.write(content) return path # ═══════════════════════════════════════════════════════════════════════════ if __name__ == "__main__": # Eager model load at startup — happens during Space's "Starting" phase, # so the user opens a ready UI instead of waiting ~5 GB of downloads on # their first click. Progress shows in the Space's runtime logs. print("[BALM-PPI] Pre-loading model on CPU…", flush=True) try: ensure_model(1.0, 16.0) print("[BALM-PPI] Model ready.", flush=True) except Exception as _e: print(f"[BALM-PPI] WARNING — pre-load failed: {_e}", flush=True) traceback.print_exc() demo = build_ui() demo.queue(max_size=20).launch( css=CSS, theme=gr.themes.Soft(), ssr_mode=False, # cleaner logs; SSR is experimental in Gradio 6 )