# ExplainBind / app.py
# Page-header residue captured with the file: "Zhaohan-Meng's picture"
# Last commit message: "Update app.py"
# Revision: 01d9e59 (verified)
import gradio_client.utils as _gc_utils
_orig_get_type = _gc_utils.get_type
_orig_json2py = _gc_utils._json_schema_to_python_type
def _patched_get_type(schema):
    """Delegate to the saved ``gradio_client.utils.get_type``.

    A boolean ``schema`` is replaced by an empty dict before the call;
    every other value is forwarded unchanged.
    """
    normalised = {} if isinstance(schema, bool) else schema
    return _orig_get_type(normalised)
def _patched_json_schema_to_python_type(schema, defs=None):
    """Delegate to the saved ``gradio_client.utils._json_schema_to_python_type``.

    A boolean ``schema`` is replaced by an empty dict before the call;
    ``defs`` is passed through untouched.
    """
    normalised = {} if isinstance(schema, bool) else schema
    return _orig_json2py(normalised, defs)
_gc_utils.get_type = _patched_get_type
_gc_utils._json_schema_to_python_type = _patched_json_schema_to_python_type
# ─── Imports ───────────────────────────────────────────────────────────────────
import os
import io
import base64
import argparse
from typing import Optional, List, Tuple
import numpy as np
import torch
from torch.utils.data import DataLoader
import selfies
# from rdkit import Chem
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib import cm
from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel, AutoTokenizer
from Bio.PDB import PDBParser, MMCIFParser
from Bio.Data import IUPACData
import gradio as gr
# Project utils (ensure these exist in your repository)
from utils.metric_learning_models_att_maps import Pre_encoded, ExplainBind
from utils.foldseek_util import get_struc_seq
# ───────────────────── Paths & Logos ─────────────────────
ROOT = os.path.dirname(os.path.abspath(__file__))
ASSET_DIR = os.path.join(ROOT, "utils")
LOSCAZLO_LOGO = os.path.join(ASSET_DIR, "loscalzo.png")
def _load_logo_b64(path):
if not os.path.exists(path):
return ""
with open(path, "rb") as f:
return base64.b64encode(f.read()).decode("utf-8")
LOSCAZLO_B64 = _load_logo_b64(LOSCAZLO_LOGO)
# ───────────────────── Configurable constants ─────────────────────
# UI-visible names (Halogen bonding removed)
# Display names indexed by visible head: entries 0..5 are the specific
# interaction types, the last entry (index 6) is the combined view.
INTERACTION_NAMES = [
    "Hydrogen bonding",
    "Salt Bridging",
    "π–π Stacking",
    "Cation–π",
    "Hydrophobic",
    "Van der Waals",
    "Overall Interaction",
]
# VISIBLE2UNDERLYING[i] is the underlying attention channel read for visible
# head i (i in 0..5). The combined head (visible index 6) has no entry here;
# visualize_attention_and_ranges computes it as the sum over these six channels.
# Underlying channels originally had Halogen at index=5 (0-based); it is skipped entirely.
VISIBLE2UNDERLYING = [1, 2, 3, 4, 6, 0]  # HB, Salt, Pi, Cation-Pi, Hydro, VdW
N_VISIBLE_SPEC = len(VISIBLE2UNDERLYING)  # 6 specific (non-combined) heads
# ───── Helper utilities ───────────────────────────────────────────
# 3-letter -> 1-letter residue codes from Biopython's IUPACData table,
# with the 3-letter keys upper-cased so they can be compared against
# the upper-cased residue names read from structure files.
three2one = {k.upper(): v for k, v in IUPACData.protein_letters_3to1.items()}
# Extra 3-letter codes mapped onto a standard letter (MSE->M, SEC->C, PYL->K).
three2one.update({"MSE": "M", "SEC": "C", "PYL": "K"})
# The 20 letters accepted by detect_protein_type as a plain FASTA sequence.
STANDARD_AA_SET = set("ACDEFGHIKLMNPQRSTVWY")  # Uppercase FASTA amino acids
def simple_seq_from_structure(path: str) -> str:
    """Return the 1-letter sequence of the chain with the most residues.

    The file is parsed as mmCIF when its name ends in ``.cif`` or ``.mmcif``
    (case-insensitive, so ``1ABC.CIF`` is no longer sent to the PDB parser)
    and as PDB otherwise.

    Parameters
    ----------
    path : str
        Path to the structure file on disk.

    Returns
    -------
    str
        Concatenated 1-letter codes of the residues whose upper-cased name is
        a key of ``three2one``; all other residues are skipped. An empty
        string is returned when the structure has no chains.
    """
    is_mmcif = path.lower().endswith((".cif", ".mmcif"))
    parser = MMCIFParser(QUIET=True) if is_mmcif else PDBParser(QUIET=True)
    structure = parser.get_structure("P", path)

    chains = list(structure.get_chains())
    if not chains:
        return ""

    # Count residues lazily instead of materialising a list per chain.
    longest = max(chains, key=lambda c: sum(1 for _ in c.get_residues()))

    residue_names = (res.get_resname().upper() for res in longest)
    # Residues without a mapping (non-standard names) are dropped.
    return "".join(three2one[name] for name in residue_names if name in three2one)
# def smiles_to_selfies(smiles_text: str) -> Optional[str]:
# """Validate and convert SMILES to SELFIES; return None if invalid."""
# try:
# mol = Chem.MolFromSmiles(smiles_text)
# if mol is None:
# return None
# return selfies.encoder(smiles_text)
# except Exception:
# return None
def smiles_to_selfies(smiles_text: str) -> Optional[str]:
    """Encode *smiles_text* as SELFIES, or return None when that fails.

    The encoded string is decoded again as a sanity check: an empty decode
    result, or any exception raised by either call, yields ``None``.
    """
    try:
        encoded = selfies.encoder(smiles_text)
        round_trip = selfies.decoder(encoded)
    except Exception:
        return None
    return encoded if round_trip else None
def detect_protein_type(seq: str) -> str:
    """
    Classify a protein string as plain FASTA or structure-aware (SA).

    - Empty, whitespace-only or None input → 'fasta'
    - All uppercase and only the standard 20 amino acids → 'fasta'
    - Otherwise (contains lowercase or non-standard characters) → 'sa'
    """
    stripped = (seq or "").strip()
    if not stripped:
        return "fasta"
    upper = stripped.upper()
    # Every character (case-folded) must be a standard amino-acid letter ...
    standard_only = set(upper) <= STANDARD_AA_SET
    # ... and the input itself must already be upper case.
    already_upper = stripped == upper
    if standard_only and already_upper:
        return "fasta"
    return "sa"
def detect_ligand_type(text: str) -> str:
    """
    Classify a ligand string as SELFIES or SMILES.

    - Starts with '[' and contains ']' (after stripping whitespace) → 'selfies'
    - Anything else, including empty or None input → 'smiles'
    """
    stripped = (text or "").strip()
    looks_like_selfies = stripped.startswith("[") and "]" in stripped
    return "selfies" if looks_like_selfies else "smiles"
def parse_config():
    """Parse command-line options, ignoring flags this app does not define.

    The module calls this at import time, so ``parse_known_args`` is used
    instead of ``parse_args``: extra arguments injected by whatever launches
    the app no longer terminate the process with an argparse error. The
    trade-off is that a misspelt flag is silently ignored.

    Returns
    -------
    argparse.Namespace
        With attributes ``agg_mode``, ``group_size``, ``fusion``, ``device``,
        ``save_path_prefix`` and ``dataset``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--agg_mode", type=str, default="mean_all_tok")
    p.add_argument("--group_size", type=int, default=1)
    p.add_argument("--fusion", default="CAN")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    # Root folder containing checkpoints
    p.add_argument("--save_path_prefix", default="save_model_ckp/")
    p.add_argument("--dataset", default="Human")
    known, _unknown = p.parse_known_args()
    return known
args = parse_config()
DEVICE = args.device
# ───── Dynamic model registry ─────────────────────────────────────
# Model identifiers passed to ``from_pretrained`` in load_encoders for the
# protein side, keyed by the protein type returned by detect_protein_type.
PROT_MODELS = {
    "sa": "westlake-repl/SaProt_650M_AF2",
    "fasta": "facebook/esm2_t33_650M_UR50D",
}
# Model identifiers for the ligand side. Only the SELFIES entry is active;
# the SMILES encoder is kept here commented out.
DRUG_MODELS = {
    "selfies": "HUBioDataLab/SELFormer",
    # "smiles": "ibm/MoLFormer-XL-both-10pct",
}
def load_encoders(ptype: str, ltype: str, args):
    """
    Load the protein and ligand encoders plus their tokenisers.

    ``ptype == "fasta"`` selects PROT_MODELS["fasta"] (tokeniser created with
    ``do_lower_case=False``); any other value selects PROT_MODELS["sa"].
    ``ltype`` is accepted for interface compatibility but is not consulted:
    the ligand encoder is always DRUG_MODELS["selfies"].

    Returns: (prot_tokenizer, prot_model, drug_tokenizer, drug_model, encoding_module)
    """
    # Protein side: pick the checkpoint and the tokenizer keyword arguments.
    if ptype == "fasta":
        protein_repo = PROT_MODELS["fasta"]
        tokenizer_kwargs = {"do_lower_case": False}
    else:  # 'sa'
        protein_repo = PROT_MODELS["sa"]
        tokenizer_kwargs = {}
    prot_tokenizer = EsmTokenizer.from_pretrained(protein_repo, **tokenizer_kwargs)
    prot_model = EsmForMaskedLM.from_pretrained(protein_repo)

    # Ligand side: always the SELFIES encoder.
    ligand_repo = DRUG_MODELS["selfies"]
    drug_tokenizer = AutoTokenizer.from_pretrained(ligand_repo)
    drug_model = AutoModel.from_pretrained(ligand_repo)

    # Wrap both encoders in the Pre_encoded module and move it to DEVICE.
    encoding = Pre_encoded(prot_model, drug_model, args).to(DEVICE)
    return prot_tokenizer, prot_model, drug_tokenizer, drug_model, encoding
def make_collate_fn(prot_tokenizer, drug_tokenizer):
    """Build a DataLoader ``collate_fn`` bound to the two given tokenisers.

    The returned function takes a batch of ``(protein, ligand, score)``
    triples and returns
    ``(protein_ids, protein_mask, ligand_ids, ligand_mask, scores)``,
    where both masks are cast to bool.
    """
    # Identical tokenisation settings are used for both modalities.
    shared_kwargs = dict(
        max_length=512,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )

    def _collate_fn(batch):
        proteins, ligands, labels = zip(*batch)
        protein_enc = prot_tokenizer(list(proteins), **shared_kwargs)
        ligand_enc = drug_tokenizer(list(ligands), **shared_kwargs)
        label_tensor = torch.tensor(list(labels))
        return (
            protein_enc["input_ids"],
            protein_enc["attention_mask"].bool(),
            ligand_enc["input_ids"],
            ligand_enc["attention_mask"].bool(),
            label_tensor,
        )

    return _collate_fn
def get_case_feature(model, loader):
    """Encode the first batch of *loader* and return it as a one-element list.

    The model is put in eval mode and run without gradients. The single
    list entry is ``(p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, None)`` with
    every tensor moved to the CPU. Only the first batch is consumed; an empty
    loader yields ``None``.
    """
    model.eval()
    with torch.no_grad():
        for batch in loader:
            p_ids, p_mask, d_ids, d_mask, _unused = batch
            p_ids, p_mask, d_ids, d_mask = (
                tensor.to(DEVICE) for tensor in (p_ids, p_mask, d_ids, d_mask)
            )
            p_emb, d_emb = model.encoding(p_ids, p_mask, d_ids, d_mask)
            on_cpu = tuple(
                tensor.cpu()
                for tensor in (p_emb, d_emb, p_ids, d_ids, p_mask, d_mask)
            )
            # Trailing None keeps the 7-slot layout consumers unpack.
            return [on_cpu + (None,)]
# ─────────────── SELFIES grouping by ORIGINAL string ─────────────
def _group_rows_by_selfies_string(n_rows: int, selfies_str: str):
"""
Partition the attention matrix's n_rows along ligand axis into groups per SELFIES token '[ ... ]'.
Each group is a contiguous row span; we assign rows β‰ˆ equally using linspace.
Returns:
groups: List[(start_row, end_row)] inclusive
labels: List['[X]','[=O]', ...]
"""
if n_rows <= 0:
return [], []
try:
toks = list(selfies.split_selfies((selfies_str or "").strip()))
except Exception:
toks = []
if not toks:
# Fallback: treat whole ligand as one token
return [(0, n_rows - 1)], [selfies_str or "[?]"]
g = len(toks)
edges = np.linspace(0, n_rows, g + 1, dtype=int)
groups = []
for i in range(g):
s, e = edges[i], edges[i + 1] - 1
if e < s:
e = s
groups.append((s, e))
return groups, toks
def _connected_components_2d(mask: torch.Tensor) -> List[List[Tuple[int, int]]]:
"""4-connected components over a 2D boolean mask (rows=ligand tokens, cols=protein residues)."""
h, w = mask.shape
visited = torch.zeros_like(mask, dtype=torch.bool)
comps: List[List[Tuple[int,int]]] = []
for i in range(h):
for j in range(w):
if mask[i, j] and not visited[i, j]:
stack = [(i, j)]
visited[i, j] = True
comp = []
while stack:
r, c = stack.pop()
comp.append((r, c))
for dr, dc in ((1,0), (-1,0), (0,1), (0,-1)):
rr, cc = r + dr, c + dc
if 0 <= rr < h and 0 <= cc < w and mask[rr, cc] and not visited[rr, cc]:
visited[rr, cc] = True
stack.append((rr, cc))
comps.append(comp)
return comps
def _format_component_table(
components,
p_tokens,
d_tokens,
*,
mode: str = "pair", # "pair" | "residue"
):
"""
Render HTML table for highlighted interaction components.
Parameters
----------
components : List[List[Tuple[int,int]]]
Each component is a list of (ligand_index, protein_index) pairs.
p_tokens : List[str]
Protein token strings.
d_tokens : List[str]
Ligand token strings.
mode : str
"pair" -> show Protein range + Ligand range
"residue" -> show Protein residue(s) only
"""
# ----------------------------
# Residue-only mode
# ----------------------------
if mode == "residue":
if not components:
return (
"<h4 style='margin:12px 0 6px'>Highlighted protein residues</h4>"
"<p>No residues selected.</p>"
)
rows = []
for comp in components:
# comp = [(lig_idx, prot_idx), ...]
prot_indices = [j for (_, j) in comp]
p_start, p_end = min(prot_indices), max(prot_indices)
p_s_idx, p_e_idx = p_start + 1, p_end + 1
p_s_tok = p_tokens[p_start] if p_start < len(p_tokens) else "?"
p_e_tok = p_tokens[p_end] if p_end < len(p_tokens) else "?"
if p_start == p_end:
label = f"{p_s_idx}:{p_s_tok}"
else:
label = f"{p_s_idx}:{p_s_tok} – {p_e_idx}:{p_e_tok}"
rows.append(
f"<tr>"
f"<td style='border:1px solid #ddd;padding:6px'>"
f"<strong>{label}</strong>"
f"</td>"
f"</tr>"
)
return (
"<h4 style='margin:12px 0 6px'>Highlighted protein residues</h4>"
"<table style='border-collapse:collapse;margin:6px 0 16px;width:60%'>"
"<thead><tr style='background:#f5f5f5'>"
"<th style='border:1px solid #ddd;padding:6px'>Protein residue(s)</th>"
"</tr></thead>"
f"<tbody>{''.join(rows)}</tbody></table>"
)
# ----------------------------
# Pair mode (default behaviour)
# ----------------------------
if not components:
return (
"<h4 style='margin:12px 0 6px'>Highlighted interaction segments</h4>"
"<p>No interaction pairs selected.</p>"
)
rows = []
for comp in components:
lig_indices = [i for (i, _) in comp]
prot_indices = [j for (_, j) in comp]
d_start, d_end = min(lig_indices), max(lig_indices)
p_start, p_end = min(prot_indices), max(prot_indices)
d_s_idx, d_e_idx = d_start + 1, d_end + 1
p_s_idx, p_e_idx = p_start + 1, p_end + 1
d_s_tok = d_tokens[d_start] if d_start < len(d_tokens) else "?"
d_e_tok = d_tokens[d_end] if d_end < len(d_tokens) else "?"
p_s_tok = p_tokens[p_start] if p_start < len(p_tokens) else "?"
p_e_tok = p_tokens[p_end] if p_end < len(p_tokens) else "?"
rows.append(
f"<tr>"
f"<td style='border:1px solid #ddd;padding:6px'>Protein: "
f"<strong>{p_s_idx}:{p_s_tok}</strong>"
f"{' – <strong>'+str(p_e_idx)+':'+p_e_tok+'</strong>' if p_end > p_start else ''}"
f"</td>"
f"<td style='border:1px solid #ddd;padding:6px'>Ligand: "
f"<strong>{d_s_idx}:{d_s_tok}</strong>"
f"{' – <strong>'+str(d_e_idx)+':'+d_e_tok+'</strong>' if d_end > d_start else ''}"
f"</td>"
f"</tr>"
)
return (
"<h4 style='margin:12px 0 6px'>Highlighted Binding site</h4>"
"<table style='border-collapse:collapse;margin:6px 0 16px;width:100%'>"
"<thead><tr style='background:#f5f5f5'>"
"<th style='border:1px solid #ddd;padding:6px'>Protein range</th>"
"<th style='border:1px solid #ddd;padding:6px'>Ligand range</th>"
"</tr></thead>"
f"<tbody>{''.join(rows)}</tbody></table>"
)
def _select_attention_map(att: torch.Tensor, head_idx: int) -> torch.Tensor:
    """Pick the 2D map for visible head ``head_idx`` and move it to the CPU.

    Two 3D layouts are recognised: channels last (``shape[-1] >= 8``, checked
    first) and channels first (``shape[0] >= 8``). For ``head_idx`` below
    N_VISIBLE_SPEC the channel VISIBLE2UNDERLYING[head_idx] is returned;
    otherwise the six mapped channels are summed (combined head). Any other
    tensor is simply squeezed.
    """
    combined = head_idx >= N_VISIBLE_SPEC
    if att.dim() == 3 and att.shape[-1] >= 8:
        if combined:
            return att[:, :, VISIBLE2UNDERLYING].sum(dim=2).cpu()
        return att[:, :, VISIBLE2UNDERLYING[head_idx]].cpu()
    if att.dim() == 3 and att.shape[0] >= 8:
        if combined:
            return att[VISIBLE2UNDERLYING].sum(dim=0).cpu()
        return att[VISIBLE2UNDERLYING[head_idx]].cpu()
    return att.squeeze().cpu()


def _decode_without_special_tokens(ids, tokenizer):
    """Convert token ids to token strings, dropping the tokenizer's special tokens."""
    tokens = tokenizer.convert_ids_to_tokens(ids.tolist())
    if hasattr(tokenizer, "all_special_tokens"):
        # Read the special-token list once rather than once per token.
        special = set(tokenizer.all_special_tokens)
        tokens = [t for t in tokens if t not in special]
    return tokens


def _rank_pairs(att2d, k, p_tokens, d_tokens):
    """Select the ``k`` highest-scoring cells of ``att2d``.

    Returns ``(mask, html)``: a boolean mask shaped like ``att2d`` marking the
    selected cells, and the ranked HTML table describing them.
    """
    import html

    flat = att2d.reshape(-1)
    top = torch.topk(flat, k=min(k, flat.numel()))
    selected = torch.zeros_like(flat, dtype=torch.bool)
    selected[top.indices] = True

    n_cols = att2d.size(1)
    rows = []
    ranked = zip(top.values.tolist(), top.indices.tolist())
    for rank, (score, linear_idx) in enumerate(ranked, start=1):
        i, j = divmod(linear_idx, n_cols)  # row = ligand token, col = protein residue
        # Tokens can originate from user-supplied ligand text, so escape them.
        p_tok = html.escape(str(p_tokens[j]))
        d_tok = html.escape(str(d_tokens[i]))
        rows.append(
            f"<tr>"
            f"<td style='border:1px solid #ddd;padding:6px'><strong>Top {rank}</strong></td>"
            f"<td style='border:1px solid #ddd;padding:6px'>Protein: <strong>{j+1}:{p_tok}</strong></td>"
            f"<td style='border:1px solid #ddd;padding:6px'>Ligand: <strong>{i+1}:{d_tok}</strong></td>"
            f"<td style='border:1px solid #ddd;padding:6px'>Score: <strong>{score:.6f}</strong></td>"
            f"</tr>"
        )
    table = (
        "<h4 style='margin:12px 0 6px'>Top-K Interaction Pairs (ranked by attention score)</h4>"
        "<table style='border-collapse:collapse;margin:6px 0 16px;width:100%'>"
        "<thead><tr style='background:#f5f5f5'>"
        "<th style='border:1px solid #ddd;padding:6px'>Rank</th>"
        "<th style='border:1px solid #ddd;padding:6px'>Protein</th>"
        "<th style='border:1px solid #ddd;padding:6px'>Ligand</th>"
        "<th style='border:1px solid #ddd;padding:6px'>Attention Score</th>"
        "</tr></thead>"
        f"<tbody>{''.join(rows)}</tbody></table>"
    )
    return selected.view_as(att2d), table


def _rank_residues(att2d, k, p_tokens):
    """Select the ``k`` columns of ``att2d`` with the largest column sums.

    Returns ``(mask, html)``: a boolean mask shaped like ``att2d`` with the
    selected columns fully set, and the ranked HTML table describing them.
    """
    import html

    residue_score = att2d.sum(dim=0)
    top = torch.topk(residue_score, k=min(k, residue_score.numel()))
    selected = torch.zeros_like(att2d, dtype=torch.bool)
    selected[:, top.indices] = True

    rows = []
    ranked = zip(top.values.tolist(), top.indices.tolist())
    for rank, (score, j) in enumerate(ranked, start=1):
        p_tok = html.escape(str(p_tokens[j]))
        rows.append(
            f"<tr>"
            f"<td style='border:1px solid #ddd;padding:6px'><strong>Top {rank}</strong></td>"
            f"<td style='border:1px solid #ddd;padding:6px'>"
            f"Protein residue: <strong>{j+1}:{p_tok}</strong>"
            f"</td>"
            f"<td style='border:1px solid #ddd;padding:6px'>"
            f"Aggregated Score: <strong>{score:.6f}</strong>"
            f"</td>"
            f"</tr>"
        )
    table = (
        "<h4 style='margin:12px 0 6px'>Top-K Residues (ranked by aggregated attention)</h4>"
        "<table style='border-collapse:collapse;margin:6px 0 16px;width:100%'>"
        "<thead><tr style='background:#f5f5f5'>"
        "<th style='border:1px solid #ddd;padding:6px'>Rank</th>"
        "<th style='border:1px solid #ddd;padding:6px'>Protein Residue</th>"
        "<th style='border:1px solid #ddd;padding:6px'>Aggregated Score</th>"
        "</tr></thead>"
        f"<tbody>{''.join(rows)}</tbody></table>"
    )
    return selected, table


def _render_heatmap_html(vis, x_labels, y_labels, head_idx, mode):
    """Draw ``vis`` as a heatmap and return HTML embedding a PNG plus a PDF download link."""
    fig_w = min(22, max(6, len(x_labels) * 0.6))
    fig_h = min(24, max(6, len(y_labels) * 0.8))
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    try:
        im = ax.imshow(vis.numpy(), aspect="auto", cmap=cm.viridis)
        title = INTERACTION_NAMES[head_idx]
        suffix = "Top-K pairs" if mode == "pair" else "Top-K residues"
        ax.set_title(f"Ligand × Protein — {title} ({suffix})", fontsize=10, pad=8)
        ax.set_xlabel("Protein residues")
        ax.set_ylabel("Ligand tokens")
        ax.set_xticks(range(len(x_labels)))
        ax.set_xticklabels(x_labels, rotation=90, fontsize=8)
        ax.set_yticks(range(len(y_labels)))
        ax.set_yticklabels(y_labels, fontsize=7)
        # Protein axis labels go on top of the plot.
        ax.xaxis.tick_top()
        ax.xaxis.set_label_position("top")
        ax.tick_params(axis="x", bottom=False)
        fig.colorbar(im, fraction=0.026, pad=0.01)
        fig.tight_layout()

        buf_png = io.BytesIO()
        buf_pdf = io.BytesIO()
        fig.savefig(buf_png, format="png", dpi=140)
        fig.savefig(buf_pdf, format="pdf")
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)

    png_b64 = base64.b64encode(buf_png.getvalue()).decode()
    pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()
    return f"""
    <div style='position:relative'>
    <a href='data:application/pdf;base64,{pdf_b64}' download='attention_{head_idx+1}.pdf'
    style='position:absolute;top:10px;right:10px;
    background:#111;color:#fff;padding:8px 12px;
    border-radius:10px;font-size:.85rem;text-decoration:none'>
    Download PDF
    </a>
    <img src='data:image/png;base64,{png_b64}' />
    </div>
    """


def _probability_card_html(prob: float) -> str:
    """Return the coloured card showing the predicted binding probability.

    Thresholds: >= 0.8 high, >= 0.4 moderate, otherwise low confidence.
    """
    if prob >= 0.8:
        bg, border, label = "#ecfdf5", "#10b981", "High binding confidence"
    elif prob >= 0.4:
        bg, border, label = "#eff6ff", "#3b82f6", "Moderate binding confidence"
    else:
        bg, border, label = "#fef2f2", "#ef4444", "Low binding confidence"
    return f"""
    <div style='margin:10px 0 18px;
    padding:14px 16px;
    border-left:5px solid {border};
    border-radius:12px;
    background:{bg};
    font-size:1rem'>
    <div style='font-weight:600;margin-bottom:4px'>
    Predicted Binding Probability
    </div>
    <div style='font-size:1.4rem;font-weight:700'>
    {prob:.4f}
    </div>
    <div style='font-size:0.85rem;color:#64748b;margin-top:4px'>
    {label}
    </div>
    </div>
    """


def visualize_attention_and_ranges(
    model,
    feats,
    head_idx: int,
    *,
    mode: str = "pair",  # "pair" | "residue"
    topk_pairs: int = 1,  # Top-K interaction pairs (>= 1, default=1)
    topk_residues: int = 1,  # Top-K residues (1–100, default=1)
    prot_tokenizer=None,
    drug_tokenizer=None,
    ligand_type: str = "selfies",
    raw_selfies: Optional[str] = None,
) -> Tuple[str, str, str]:
    """
    Visualise interaction attention with two complementary Top-K modes.

    Modes
    -----
    mode="pair":
        - Select the Top-K highest-scoring (ligand token, protein residue) cells
        - K = ``topk_pairs`` (>= 1)
    mode="residue":
        - Sum attention over the ligand dimension
        - Rank residues by that aggregated score
        - K = ``topk_residues`` (1–100)

    Parameters
    ----------
    model : callable
        Called as ``model(p_emb, d_emb, p_mask, d_mask)`` and expected to
        return ``(prob, attention)``; put in eval mode first.
    feats : list
        ``feats[0]`` is the 7-tuple
        ``(p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, _)`` produced by
        ``get_case_feature``.
    head_idx : int
        Index into INTERACTION_NAMES; values >= N_VISIBLE_SPEC select the
        combined head.
    prot_tokenizer, drug_tokenizer :
        Tokenisers used to turn ids back into token strings.
    ligand_type, raw_selfies :
        When ``ligand_type == "selfies"`` and ``raw_selfies`` is non-empty,
        ligand rows are averaged into one row per SELFIES token.

    Returns
    -------
    (prob_html, ranges_html, heat_html) : Tuple[str, str, str]
        Probability card, ranked Top-K table, and heatmap (inline PNG plus a
        PDF download link).

    Raises
    ------
    ValueError
        If ``mode`` or either Top-K value is outside the accepted range.

    Notes
    -----
    - The selected map is normalised by its global sum.
    - Specific heads are mapped to channels through VISIBLE2UNDERLYING.
    - Combined head = sum of the 6 specific heads.
    - The heatmap is cropped to the selected rows/columns and capped at the
      150 highest-sum columns.
    """
    if mode not in {"pair", "residue"}:
        raise ValueError(f"mode must be 'pair' or 'residue', got {mode!r}")
    if topk_pairs < 1:
        raise ValueError(f"topk_pairs must be >= 1, got {topk_pairs}")
    if not 1 <= topk_residues <= 100:
        raise ValueError(f"topk_residues must be in 1..100, got {topk_residues}")

    model.eval()
    with torch.no_grad():
        # Unpack features and move model inputs to the compute device.
        p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, _ = feats[0]
        p_emb, d_emb = p_emb.to(DEVICE), d_emb.to(DEVICE)
        p_mask, d_mask = p_mask.to(DEVICE), d_mask.to(DEVICE)

        # Forward pass. Attention is expected as [Ld, Lp, 8] or [8, Ld, Lp].
        prob, att_pd = model(p_emb, d_emb, p_mask, d_mask)
        prob = prob.item()
        att2d_raw = _select_attention_map(att_pd.squeeze(0), head_idx)  # [Ld, Lp]

        # Per-head global-sum normalisation.
        att2d_raw = att2d_raw / (att2d_raw.sum() + 1e-8)

        # Token decoding; trim map and token lists to a common size.
        p_tokens_full = _decode_without_special_tokens(p_ids[0], prot_tokenizer)
        d_tokens_full = _decode_without_special_tokens(d_ids[0], drug_tokenizer)
        n_d = min(len(d_tokens_full), att2d_raw.size(0))
        n_p = min(len(p_tokens_full), att2d_raw.size(1))
        att2d = att2d_raw[:n_d, :n_p]
        p_tokens = p_tokens_full[:n_p]
        d_tokens = d_tokens_full[:n_d]

        # SELFIES row merging (for interpretability): one averaged row per token.
        if ligand_type == "selfies" and raw_selfies:
            groups, labels = _group_rows_by_selfies_string(att2d.size(0), raw_selfies)
            if groups:
                att2d = torch.cat(
                    [att2d[s:e + 1].mean(dim=0, keepdim=True) for s, e in groups],
                    dim=0,
                )
                d_tokens = labels

        # Top-K selection (strict ranking).
        if mode == "pair":
            mask_top, ranges_html = _rank_pairs(att2d, topk_pairs, p_tokens, d_tokens)
        else:
            mask_top, ranges_html = _rank_residues(att2d, topk_residues, p_tokens)

        # Crop to the union of selected rows / columns; keep everything on an
        # axis where nothing was selected.
        rows_keep = mask_top.any(dim=1)
        cols_keep = mask_top.any(dim=0)
        if not rows_keep.any():
            rows_keep[:] = True
        if not cols_keep.any():
            cols_keep[:] = True
        vis = att2d[rows_keep][:, cols_keep]

        # Axis labels are "1-based index:token" for the kept rows / columns.
        y_labels = [
            f"{idx}:{tok}"
            for idx, (keep, tok) in enumerate(zip(rows_keep.tolist(), d_tokens), start=1)
            if keep
        ]
        x_labels = [
            f"{idx}:{tok}"
            for idx, (keep, tok) in enumerate(zip(cols_keep.tolist(), p_tokens), start=1)
            if keep
        ]

        # Cap columns for readability.
        if vis.size(1) > 150:
            topc = torch.topk(vis.sum(0), k=150).indices.tolist()
            vis = vis[:, topc]
            x_labels = [x_labels[c] for c in topc]

    heat_html = _render_heatmap_html(vis, x_labels, y_labels, head_idx, mode)
    prob_html = _probability_card_html(prob)
    return prob_html, ranges_html, heat_html
# ───── Gradio callbacks ─────────────────────────────────────────
# Absolute directory containing this file. This repeats the expression used
# for the ROOT defined near the top of the module, so the value is unchanged.
ROOT = os.path.dirname(os.path.abspath(__file__))
# Path passed as the first argument to get_struc_seq in extract_sa_seq_cb.
FOLDSEEK_BIN = os.path.join(ROOT, "utils", "foldseek")
def extract_aa_seq_cb(structure_file, protein_text):
    """
    Extract plain amino acid sequence from uploaded PDB/mmCIF.

    Parameters
    ----------
    structure_file :
        The uploaded structure: a path (``str`` / ``os.PathLike``) or an
        object exposing the path as ``.name``. ``None`` means nothing was
        uploaded.
    protein_text : str or None
        Current content of the protein text box; returned (stripped) when no
        sequence could be extracted.

    Returns
    -------
    (sequence, status_html) : Tuple[str, str]
        The extracted sequence (or the stripped previous text) and an HTML
        status message. Extraction errors are reported in the status message
        rather than raised.
    """
    import html

    prot_seq_out = (protein_text or "").strip()
    if structure_file is None:
        return prot_seq_out, "<p style='color:red'>Please upload a structure file.</p>"

    msgs = []
    try:
        # Accept both plain paths and file-like wrappers carrying ``.name``.
        if isinstance(structure_file, (str, os.PathLike)):
            path = os.fspath(structure_file)
        else:
            path = structure_file.name
        seq = simple_seq_from_structure(path)
        if seq:
            prot_seq_out = seq
            msgs.append("<li>✅ Extracted <b>amino acid sequence</b> from structure.</li>")
        else:
            msgs.append("<li>❌ No valid amino acid sequence found.</li>")
    except Exception as e:
        # Best-effort UI callback: report the failure instead of raising.
        # The message may echo file content or names, so escape it for HTML.
        msgs.append(f"<li>❌ Extraction failed: <b>{html.escape(str(e))}</b></li>")

    status_html = (
        "<div style='margin:10px 0;padding:10px 12px;"
        "border:1px solid #e5e7eb;border-radius:10px;"
        "background:#f8fafc;color:#0f172a'>"
        "<ul style='margin:0 0 0 18px;padding:0'>"
        f"{''.join(msgs)}"
        "</ul></div>"
    )
    return prot_seq_out, status_html
def extract_sa_seq_cb(structure_file, protein_text):
    """
    Extract the structure-aware (SA) sequence from an uploaded structure file.

    ``get_struc_seq`` is called with FOLDSEEK_BIN and ``plddt_mask=False``;
    the third element of the entry for the first returned chain is used as
    the SA sequence.

    Parameters
    ----------
    structure_file :
        The uploaded structure: a path (``str`` / ``os.PathLike``) or an
        object exposing the path as ``.name``. ``None`` means nothing was
        uploaded.
    protein_text : str or None
        Current content of the protein text box; returned (stripped) when no
        sequence could be extracted.

    Returns
    -------
    (sequence, status_html) : Tuple[str, str]
        The SA sequence (or the stripped previous text) and an HTML status
        message. Extraction errors are reported in the status message rather
        than raised.
    """
    import html

    prot_seq_out = (protein_text or "").strip()
    if structure_file is None:
        return prot_seq_out, "<p style='color:red'>Please upload a structure file.</p>"

    msgs = []
    try:
        # Accept both plain paths and file-like wrappers carrying ``.name``.
        if isinstance(structure_file, (str, os.PathLike)):
            path = os.fspath(structure_file)
        else:
            path = structure_file.name
        parsed = get_struc_seq(
            FOLDSEEK_BIN,
            path,
            None,
            plddt_mask=False,
        )
        struct_seq = None
        if parsed:
            # An empty result is reported below instead of surfacing as a
            # StopIteration with an empty message.
            first_chain = next(iter(parsed))
            _, _, struct_seq = parsed[first_chain]
        if struct_seq:
            prot_seq_out = struct_seq
            msgs.append("<li>✅ Extracted <b>structure-aware sequence</b> (SA).</li>")
        else:
            msgs.append("<li>❌ Structure parsed but no SA sequence found.</li>")
    except Exception as e:
        # Best-effort UI callback: report the failure instead of raising.
        # The message may echo file content or names, so escape it for HTML.
        msgs.append(f"<li>❌ SA extraction failed: <b>{html.escape(str(e))}</b></li>")

    status_html = (
        "<div style='margin:10px 0;padding:10px 12px;"
        "border:1px solid #e5e7eb;border-radius:10px;"
        "background:#f8fafc;color:#0f172a'>"
        "<ul style='margin:0 0 0 18px;padding:0'>"
        f"{''.join(msgs)}"
        "</ul></div>"
    )
    return prot_seq_out, status_html
def _choose_ckpt_by_types(prot_seq: str, ligand_text: str) -> Tuple[str, str, str]:
    """Detect both input types and derive the checkpoint folder name.

    Returns ``(folder_name, protein_type, ligand_type)`` where the folder is
    ``"<protein_type>_<ligand_type>"``, e.g. sa_selfies / fasta_selfies /
    sa_smiles / fasta_smiles.
    """
    protein_kind = detect_protein_type(prot_seq)
    ligand_kind = detect_ligand_type(ligand_text)
    return "_".join((protein_kind, ligand_kind)), protein_kind, ligand_kind
def _parse_head_index(head_choice) -> int:
    """Map the UI head selection to an index into INTERACTION_NAMES.

    An exact name match wins. Otherwise the text before the first "." is read
    as a 1-based number and clamped to the valid range; if that is not a
    number, the last entry (combined interaction) is used.
    """
    sel = str(head_choice).strip()
    if sel in INTERACTION_NAMES:
        return INTERACTION_NAMES.index(sel)
    last = len(INTERACTION_NAMES) - 1
    try:
        number = int(sel.split(".", 1)[0])
    except ValueError:
        return last  # Combined Interaction
    return max(0, min(last, number - 1))


def _parse_topk(topk_choice, upper: int = 100) -> int:
    """Parse the UI Top-K value into an int clamped to ``1..upper``.

    Integer text is parsed as before. Float-like values such as ``5.0`` are
    truncated instead of silently falling back to 1. Anything unparsable
    yields 1.
    """
    text = str(topk_choice).strip()
    try:
        value = int(text)
    except ValueError:
        try:
            value = int(float(text))
        except (ValueError, OverflowError):
            value = 1
    return min(upper, max(1, value))


def inference_cb(prot_seq, drug_seq, head_choice, topk_choice):
    """
    Run one protein–ligand prediction and return the result as a single HTML string.

    Parameters
    ----------
    prot_seq : str
        Protein sequence (plain FASTA or structure-aware).
    drug_seq : str
        Ligand as SELFIES or SMILES; SMILES is converted to SELFIES first.
    head_choice :
        Interaction head: a name from INTERACTION_NAMES or text starting with
        a 1-based number.
    topk_choice :
        Number of top residues to highlight; clamped to 1..100.

    Returns
    -------
    str
        On success: probability card + Top-K residue table + heatmap, in that
        order, preceded by a warning banner when no trained checkpoint was
        found. On invalid input: a red error paragraph. Every path returns
        one string.
    """
    # ------------------------------
    # Input validation
    # ------------------------------
    if not prot_seq or not prot_seq.strip():
        return "<p style='color:red'>Please extract or enter a protein sequence first.</p>"
    if not drug_seq or not drug_seq.strip():
        return "<p style='color:red'>Please enter a ligand sequence (SELFIES or SMILES).</p>"
    prot_seq = prot_seq.strip()
    drug_seq_in = drug_seq.strip()

    # ------------------------------
    # Detect types; the ligand is always tokenised as SELFIES
    # ------------------------------
    _, ptype, ltype = _choose_ckpt_by_types(prot_seq, drug_seq_in)
    if ltype == "smiles":
        raw_selfies = smiles_to_selfies(drug_seq_in)
        if raw_selfies is None:
            # Single string, like every other return path of this callback.
            return (
                "<p style='color:red'>SMILES→SELFIES conversion failed. "
                "The SMILES appears invalid.</p>"
            )
    else:
        raw_selfies = drug_seq_in
    # Force a single ligand type from here on: only SELFIES checkpoints are routed to.
    ltype = "selfies"
    folder = f"{ptype}_selfies"

    # ------------------------------
    # Load encoders and encode the pair
    # ------------------------------
    # NOTE(review): the encoders are re-loaded on every call; caching them
    # would trade memory for latency — left unchanged here.
    prot_tok, _prot_model, drug_tok, _drug_model, encoding = load_encoders(ptype, ltype, args)
    loader = DataLoader(
        [(prot_seq, raw_selfies, 1)],
        batch_size=1,
        collate_fn=make_collate_fn(prot_tok, drug_tok),
    )
    feats = get_case_feature(encoding, loader)

    # ------------------------------
    # Load trained checkpoint (if it exists)
    # ------------------------------
    ckpt = os.path.join(args.save_path_prefix, folder, "best_model.ckpt")
    model = ExplainBind(1280, 768, args=args).to(DEVICE)
    warn_html = ""
    if os.path.isfile(ckpt):
        # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
        model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
    else:
        # Surface the problem: results below come from untrained weights.
        warn_html = (
            "<div style='margin:8px 0 14px;padding:8px 10px;"
            "border-left:4px solid #f59e0b;background:#fffbeb'>"
            "<b>Warning:</b> checkpoint not found "
            f"<code>{folder}/best_model.ckpt</code>. "
            "Using randomly initialised weights for visualisation.</div>"
        )

    # ------------------------------
    # Parse UI selections
    # ------------------------------
    head_idx = _parse_head_index(head_choice)
    topk_residues = _parse_topk(topk_choice, upper=100)

    # ------------------------------
    # Visualisation (residue mode only; topk_pairs is unused in that mode)
    # ------------------------------
    prob_html, table_html, heat_html = visualize_attention_and_ranges(
        model,
        feats,
        head_idx,
        mode="residue",
        topk_pairs=1,
        topk_residues=topk_residues,
        prot_tokenizer=prot_tok,
        drug_tokenizer=drug_tok,
        ligand_type="selfies",
        raw_selfies=raw_selfies,
    )
    # Enforce top-to-bottom order: warning, probability, table, heatmap.
    return warn_html + prob_html + table_html + heat_html
def clear_cb():
    """Return blank values, one per output slot, as a 5-tuple."""
    # Slot order: protein_seq, drug_seq, output_full, structure_file, status_box
    blank = ""
    return (blank, blank, blank, None, blank)
# ───── Gradio interface definition ───────────────────────────────
css = """
:root{
--bg:#f8fafc; --card:#f8fafc; --text:#0f172a;
--muted:#6b7280; --border:#e5e7eb; --shadow:0 6px 24px rgba(2,6,23,.06);
--radius:14px; --icon-size:20px;
}
*{box-sizing:border-box}
html,body{background:#fff!important;color:var(--text)!important}
.gradio-container{max-width:1120px;margin:0 auto}
/* Title and subtitle */
h1{
font-family:Inter,ui-sans-serif;letter-spacing:.2px;font-weight:700;
font-size:32px;margin:22px 0 12px;text-align:center
}
.subtle{color:var(--muted);font-size:14px;text-align:center;margin:-6px 0 18px}
/* Card style */
.card{
background:var(--card); border:1px solid var(--border); border-radius:var(--radius);
box-shadow:var(--shadow); padding:22px;
}
/* Top links */
.link-row{display:flex;justify-content:center;gap:14px;margin:0 auto 18px;flex-wrap:wrap}
/* Two-column grid: left=input, right=controls */
.grid-2{display:grid;grid-template-columns:1.4fr .9fr;gap:16px}
.grid-2 .col{display:flex;flex-direction:column;gap:12px}
/* Buttons */
.gr-button{border-radius:12px !important;font-weight:700 !important;letter-spacing:.2px}
#extract-btn{background:linear-gradient(90deg,#EFAFB2,#EFAFB2); color:#0f172a}
#inference-btn{background:linear-gradient(90deg,#B2CBDF,#B2CBDF); color:#0f172a}
#clear-btn{background:#FFE2B5; color:#0A0A0A; border:1px solid var(--border)}
/* Result spacing */
#result-table{margin-bottom:16px}
/* Figure container */
.figure-wrap{border:1px solid var(--border);border-radius:12px;overflow:hidden;box-shadow:var(--shadow)}
.figure-wrap img{display:block;width:100%;height:auto}
/* Right pane: vertical radio layout and full-width controls (kept for button styling) */
.right-pane .gr-button{
width:100% !important;
height:48px !important;
border-radius:12px !important;
font-weight:700 !important;
letter-spacing:.2px;
}
/* ───────── Publication links (Bulma-like) ───────── */
.publication-links {
display: flex;
justify-content: center;
gap: 14px;
flex-wrap: wrap;
margin: 6px 0 18px;
}
.link-block a {
display: inline-flex;
align-items: center;
gap: 8px;
padding: 10px 18px;
font-size: 14px;
font-weight: 600;
border-radius: 9999px;
text-decoration: none;
transition: all 0.15s ease-in-out;
}
/* colour variants */
.btn-danger { background:#e2e8f0; color:#0f172a; }
.btn-dark { background:#e2e8f0; color:#0f172a; }
.btn-link { background:#e2e8f0; color:#0f172a; }
.btn-warning { background:#e2e8f0; color:#0f172a; }
.link-block a:hover {
filter: brightness(0.95);
transform: translateY(-1px);
}
.loscalzo-block img {
height: 100px;
width: auto;
object-fit: contain;
}
.loscalzo-block {
display: flex;
align-items: center;
gap: 10px;
margin: 0 auto;
justify-content: center;
}
.link-btn{
display:inline-flex !important;
align-items:center !important;
gap:8px !important;
padding:10px 18px !important;
font-size:14px !important;
font-weight:600 !important;
border-radius:9999px !important;
background:#e2e8f0 !important;
color:#0f172a !important;
text-decoration:none !important;
border:1px solid #e5e7eb !important;
transition:all 0.15s ease-in-out !important;
}
.link-btn:hover{
filter:brightness(0.95);
transform:translateY(-1px);
}
.project-links{
display:flex !important;
justify-content:center !important;
gap:28px !important;
flex-wrap:wrap !important;
margin-bottom:32px !important;
}
#example-btn {
background: #979ea8 !important;
color: #1e293b !important;
}
#extract-aa-btn{
background:#DCE7F3 !important;
color:#0f172a !important;
}
#extract-sa-btn{
background:#EADCF8 !important;
color:#0f172a !important;
}
"""
# ───────────────────── Static UI content ─────────────────────
# Example inputs. Defined once so the gr.Examples row and the "Load Example"
# button always serve the same pair.
_EXAMPLE_PROTEIN_SEQ = "SLALSLTADQMVSALLDAEPPILYSEYDPTRPFSEASMMGLLTNLADRELVHMINWAKRVPGFVDLTSHDQVHLLECAWLEILMIGLVWRSMEHPGKLLFAPNLLLDRNQGKCVEGMVEIFDMLLATSSRFRMMNLQGEEFVCLKSIILLNSGVYTFLSSTLKSLEEKDHIHRVLDKITDTLIHLMAKAGLTLQQQHQRLAQLLLILSHIRHMSNKGMEHLYSMKCKNVVPSYDLLLEMLDA"
_EXAMPLE_LIGAND_SELFIES = "[C][=C][C][=Branch2][Branch1][#C][=C][C][=C][Ring1][=Branch1][C][=C][Branch2][Ring2][#Branch2][C@H1][C@@H1][Branch1][Branch2][C][C@@H1][Ring1][=Branch1][O][Ring1][Branch1][S][=Branch1][C][=O][=Branch1][C][=O][N][Branch1][#Branch2][C][C][Branch1][C][F][Branch1][C][F][F][C][=C][C][=C][Branch1][Branch1][C][=C][Ring1][=Branch1][Cl][C][=C][C][=C][Branch1][Branch1][C][=C][Ring1][=Branch1][O][O]"

# Link bar shown under the title. Each aria-label contains the link's visible
# text so assistive technology announces what sighted users see.
_PROJECT_LINKS_HTML = """
<div class="project-links">
<a class="link-btn" href="https://zhaohanm.github.io/ExplainBind/" target="_blank" rel="noopener noreferrer" aria-label="Project Page">
<!-- globe icon -->
<svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="currentColor" aria-hidden="true">
<path d="M12 2a10 10 0 1 0 10 10A10.012 10.012 0 0 0 12 2Zm7.93 9h-3.18a15.84 15.84 0 0 0-1.19-5.02A8.02 8.02 0 0 1 19.93 11ZM12 4c.86 0 2.25 1.86 3.01 6H8.99C9.75 5.86 11.14 4 12 4ZM4.07 13h3.18c.2 1.79.66 3.47 1.19 5.02A8.02 8.02 0 0 1 4.07 13Zm3.18-2H4.07A8.02 8.02 0 0 1 8.44 5.98 15.84 15.84 0 0 0 7.25 11Zm1.37 2h6.76c-.76 4.14-2.15 6-3.01 6s-2.25-1.86-3.01-6Zm9.05 0h3.18a8.02 8.02 0 0 1-4.37 5.02 15.84 15.84 0 0 0 1.19-5.02Z"/>
</svg>
Project Page
</a>
<a class="link-btn" href="https://doi.org/10.64898/2026.03.03.707476" target="_blank" rel="noopener noreferrer" aria-label="bioRxiv preprint">
<!-- paper icon -->
<svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="currentColor" aria-hidden="true">
<path d="M6 2h9l5 5v13a2 2 0 0 1-2 2H6a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2Zm8 1.5V8h4.5L14 3.5ZM7 12h10v2H7v-2Zm0 4h10v2H7v-2Zm0-8h6v2H7V8Z"/>
</svg>
bioRxiv preprint
</a>
<a class="link-btn" href="https://github.com/ZhaohanM/ExplainBind" target="_blank" rel="noopener noreferrer" aria-label="Source code on GitHub">
<!-- GitHub mark -->
<svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="currentColor" aria-hidden="true">
<path d="M12 .5A12 12 0 0 0 0 12.76c0 5.4 3.44 9.98 8.2 11.6.6.12.82-.28.82-.6v-2.3c-3.34.74-4.04-1.44-4.04-1.44-.54-1.38-1.32-1.74-1.32-1.74-1.08-.76.08-.74.08-.74 1.2.08 1.84 1.26 1.84 1.26 1.06 1.86 2.78 1.32 3.46 1.02.1-.8.42-1.32.76-1.62-2.66-.32-5.46-1.36-5.46-6.02 0-1.34.46-2.44 1.22-3.3-.12-.32-.54-1.64.12-3.42 0 0 1-.34 3.32 1.26.96-.28 1.98-.42 3-.42s2.04.14 3 .42c2.32-1.6 3.32-1.26 3.32-1.26.66 1.78.24 3.1.12 3.42.76.86 1.22 1.96 1.22 3.3 0 4.68-2.8 5.68-5.48 6 .44.38.84 1.12.84 2.28v3.38c0 .32.22.74.84.6A12.02 12.02 0 0 0 24 12.76 12 12 0 0 0 12 .5Z"/>
</svg>
Source code
</a>
</div>
"""

# Body of the "Guidelines for Users" accordion.
_GUIDELINES_HTML = """
<ol style="font-size:1rem;line-height:1.6;margin-left:22px;">
<li>
<strong>Input formats:</strong>
The system supports either <em>structure-aware (SA)</em> sequences derived from
protein structures or conventional <em>FASTA</em> sequences.
For structure-based analysis, users may upload <code>.pdb</code> or
<code>.cif</code> files to extract the corresponding sequence representation.
Ligands can be provided in <em>SMILES</em> or <em>SELFIES</em> format.
</li>
<li>
<strong>Interaction channel selection:</strong>
Users may select a specific non-covalent interaction type
(e.g., hydrogen bonding, hydrophobic interactions) or the
overall interaction channel to visualise the corresponding
token-level binding patterns.
</li>
<li>
<strong>Model outputs:</strong>
The system reports (i) a predicted binding probability for the
protein–ligand pair, (ii) a ranked Top-K residue table, and (iii) a token-level interaction
heat map illustrating spatial interaction patterns.
</li>
</ol>
"""

# ───────────────────── UI definition ─────────────────────
# NOTE(review): the nesting of the layout containers below was reconstructed
# from the section comments and elem_ids (left/right panes inside one
# "grid-2" card, outputs card underneath) — confirm against the rendered page.
with gr.Blocks() as demo:
    gr.Markdown("<h1>ExplainBind: Token-level Protein–Ligand Interaction Visualiser</h1>")

    # Research-group logo block, currently disabled.
    # gr.HTML(f"""
    # <div class="loscalzo-block">
    # <img src="data:image/png;base64,{LOSCAZLO_B64}"
    # alt="Loscalzo Research Group logo" />
    # <a class="loscalzo-name"
    # href="https://ogephd.hms.harvard.edu/people/joseph-loscalzo"
    # target="_blank" rel="noopener">
    # </a>
    # </div>
    # """)

    # ───────────────────────────────
    # Top links
    # ───────────────────────────────
    gr.HTML(_PROJECT_LINKS_HTML)

    # ───────────────────────────────
    # Guidelines
    # ───────────────────────────────
    with gr.Accordion("Guidelines for Users", open=True, elem_classes=["card"]):
        gr.HTML(_GUIDELINES_HTML)

    # ───────────────────────────────
    # Inputs + Controls
    # ───────────────────────────────
    with gr.Row():
        with gr.Column(elem_classes=["card", "grid-2"]):
            # ────────────────
            # LEFT PANEL
            # ────────────────
            with gr.Column(elem_id="left"):
                # The three inputs are created with render=False so that
                # gr.Examples can reference them before they are placed;
                # each one is positioned later by its .render() call.
                protein_seq = gr.Textbox(
                    label="Protein structure-aware / FASTA sequence",
                    lines=3,
                    placeholder="Paste SA/FASTA sequence or click Extract…",
                    elem_id="protein-seq",
                    render=False,
                )
                drug_seq = gr.Textbox(
                    label="Ligand (SELFIES / SMILES)",
                    lines=3,
                    placeholder="Paste SELFIES or SMILES",
                    elem_id="drug-seq",
                    render=False,
                )
                structure_file = gr.File(
                    label="Upload protein structure (.pdb / .cif)",
                    file_types=[".pdb", ".cif"],
                    elem_id="structure-file",
                    render=False,
                )

                with gr.Group():
                    gr.Markdown("### Example")
                    gr.Examples(
                        examples=[[_EXAMPLE_PROTEIN_SEQ, _EXAMPLE_LIGAND_SELFIES]],
                        inputs=[protein_seq, drug_seq],
                        label="Click to load an example",
                    )
                    btn_load_example = gr.Button(
                        "Load Example",
                        elem_id="example-btn",
                        # variant="secondary"
                    )

                structure_file.render()
                with gr.Row():
                    btn_extract_aa = gr.Button(
                        "Extract amino acid sequence",
                        elem_id="extract-aa-btn",
                    )
                    btn_extract_sa = gr.Button(
                        "Extract structure-aware sequence",
                        elem_id="extract-sa-btn",
                    )
                protein_seq.render()
                drug_seq.render()

            # ────────────────
            # RIGHT PANEL
            # ────────────────
            with gr.Column(elem_id="right", elem_classes=["right-pane"]):
                head_dd = gr.Dropdown(
                    label="Non-covalent interaction type/Overall",
                    choices=INTERACTION_NAMES,
                    value="Overall Interaction",
                    interactive=True,
                )
                # Choices are the strings "1".."20"; the callback receives a str.
                top_k_dd = gr.Dropdown(
                    label="Top-K residue",
                    choices=[str(i) for i in range(1, 21)],
                    value="1",
                    interactive=True,
                )
                with gr.Row():
                    btn_infer = gr.Button(
                        "Inference",
                        elem_id="inference-btn",
                    )
                    clear_btn = gr.Button(
                        "Clear",
                        elem_id="clear-btn",
                    )

    # ───────────────────────────────
    # Outputs
    # ───────────────────────────────
    with gr.Column(elem_classes=["card"]):
        status_box = gr.HTML(elem_id="status-box")
        output_full = gr.HTML(elem_id="result-full")

    # ───────────────────────────────
    # Example Loader Callback
    # ───────────────────────────────
    def load_example_cb():
        """Return the example ``(protein, ligand)`` pair for the two input textboxes."""
        return (_EXAMPLE_PROTEIN_SEQ, _EXAMPLE_LIGAND_SELFIES)

    # ───────────────────────────────
    # Wiring
    # ───────────────────────────────
    btn_load_example.click(
        fn=load_example_cb,
        inputs=[],
        outputs=[protein_seq, drug_seq],
    )
    # Both extractors take the uploaded file plus the current textbox content
    # and write back the sequence textbox and the status box.
    btn_extract_aa.click(
        fn=extract_aa_seq_cb,
        inputs=[structure_file, protein_seq],
        outputs=[protein_seq, status_box],
    )
    btn_extract_sa.click(
        fn=extract_sa_seq_cb,
        inputs=[structure_file, protein_seq],
        outputs=[protein_seq, status_box],
    )
    btn_infer.click(
        fn=inference_cb,
        inputs=[protein_seq, drug_seq, head_dd, top_k_dd],
        outputs=[output_full],
    )
    # clear_cb must return one value per component listed here, in this order.
    clear_btn.click(
        fn=clear_cb,
        inputs=[],
        outputs=[
            protein_seq,
            drug_seq,
            output_full,
            structure_file,
            status_box,
        ],
    )
# Start the Gradio server for the UI built above. The stylesheet string and
# theme are supplied here rather than on gr.Blocks(); show_error=True asks
# Gradio to display errors raised while handling events.
demo.launch(
    css=css,
    theme=gr.themes.Default(),
    show_error=True,
)