# Monkeypatch for gradio_client: its JSON-schema walker crashes when a schema
# is a bare boolean (JSON Schema allows `true`/`false` as a complete schema,
# e.g. for "additionalProperties"). Coerce bool → {} before delegating to the
# original implementations, then install the wrappers back onto the module.
import gradio_client.utils as _gc_utils

# Keep references to the originals so the patched wrappers can delegate.
_orig_get_type = _gc_utils.get_type
_orig_json2py = _gc_utils._json_schema_to_python_type


def _patched_get_type(schema):
    # A bare boolean is a legal JSON schema; treat it like the permissive {}.
    if isinstance(schema, bool):
        schema = {}
    return _orig_get_type(schema)


def _patched_json_schema_to_python_type(schema, defs=None):
    # Same bool → {} coercion for the private schema-to-type converter.
    if isinstance(schema, bool):
        schema = {}
    return _orig_json2py(schema, defs)


_gc_utils.get_type = _patched_get_type
_gc_utils._json_schema_to_python_type = _patched_json_schema_to_python_type
# ─── Imports ───────────────────────────────────────────────────────────────────
import os
import io
import base64
import argparse
from typing import Optional, List, Tuple
import numpy as np
import torch
from torch.utils.data import DataLoader
import selfies
# from rdkit import Chem
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib import cm
from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel, AutoTokenizer
from Bio.PDB import PDBParser, MMCIFParser
from Bio.Data import IUPACData
import gradio as gr
# Project utils (ensure these exist in your repository)
from utils.metric_learning_models_att_maps import Pre_encoded, ExplainBind
from utils.foldseek_util import get_struc_seq
# ───────────────────── Paths & Logos ─────────────────────
ROOT = os.path.dirname(os.path.abspath(__file__))  # directory containing this script
ASSET_DIR = os.path.join(ROOT, "utils")  # static assets live alongside project utils
# NOTE(review): constant is spelled LOSCAZLO but the asset is "loscalzo.png" — confirm intended.
LOSCAZLO_LOGO = os.path.join(ASSET_DIR, "loscalzo.png")
def _load_logo_b64(path):
if not os.path.exists(path):
return ""
with open(path, "rb") as f:
return base64.b64encode(f.read()).decode("utf-8")
# Base64-encoded logo ("" when the asset file is absent).
LOSCAZLO_B64 = _load_logo_b64(LOSCAZLO_LOGO)

# ───────────────────── Configurable constants ─────────────────────
# UI-visible names (Halogen bonding removed)
INTERACTION_NAMES = [
    "Hydrogen bonding",
    "Salt Bridging",
    "π–π Stacking",
    "Cation–π",
    "Hydrophobic",
    "Van der Waals",
    "Overall Interaction",
]
# Map visible indices (0..5 = specific, 6 = combined) to underlying channel indices
# Underlying channels originally had Halogen at index=5 (0-based). We skip 5 entirely.
VISIBLE2UNDERLYING = [1, 2, 3, 4, 6, 0]  # HB, Salt, Pi, Cation-Pi, Hydro, VdW
N_VISIBLE_SPEC = len(VISIBLE2UNDERLYING)  # 6

# ───── Helper utilities ───────────────────────────────────────────
# 3-letter → 1-letter residue codes, extended with common non-standard residues.
three2one = {k.upper(): v for k, v in IUPACData.protein_letters_3to1.items()}
three2one.update({"MSE": "M", "SEC": "C", "PYL": "K"})
STANDARD_AA_SET = set("ACDEFGHIKLMNPQRSTVWY")  # Uppercase FASTA amino acids
def simple_seq_from_structure(path: str) -> str:
    """Extract the longest chain and return its standard 1-letter amino acid sequence.

    Non-standard residues (anything not in the module-level ``three2one`` map)
    are silently skipped. Returns "" when the structure contains no chains.
    """
    # mmCIF files get the CIF parser; everything else is treated as PDB.
    parser = MMCIFParser(QUIET=True) if path.endswith(".cif") else PDBParser(QUIET=True)
    structure = parser.get_structure("P", path)
    chains = list(structure.get_chains())
    if not chains:
        return ""
    # Pick the chain with the most residues.
    longest = max(chains, key=lambda ch: sum(1 for _ in ch.get_residues()))
    letters = []
    for residue in longest:
        code = residue.get_resname().upper()
        if code in three2one:
            letters.append(three2one[code])
        # else: skip non-standard residues
    return "".join(letters)
# def smiles_to_selfies(smiles_text: str) -> Optional[str]:
# """Validate and convert SMILES to SELFIES; return None if invalid."""
# try:
# mol = Chem.MolFromSmiles(smiles_text)
# if mol is None:
# return None
# return selfies.encoder(smiles_text)
# except Exception:
# return None
def smiles_to_selfies(smiles_text: str) -> Optional[str]:
    """Encode a SMILES string as SELFIES; return None when the round-trip fails.

    Validation is done by decoding the encoded SELFIES back: a falsy decode
    result (or any exception from the selfies library) means the input was bad.
    """
    try:
        encoded = selfies.encoder(smiles_text)
        decoded_back = selfies.decoder(encoded)
    except Exception:
        return None
    return encoded if decoded_back else None
def detect_protein_type(seq: str) -> str:
    """
    Heuristic for protein input:
      - All uppercase and only the standard 20 amino acids -> 'fasta'
      - Otherwise (contains lowercase or non-standard characters) -> 'sa'
    """
    stripped = (seq or "").strip()
    if not stripped:
        return "fasta"
    # Mirrors module-level STANDARD_AA_SET (the 20 canonical residues).
    standard = set("ACDEFGHIKLMNPQRSTVWY")
    upper = stripped.upper()
    if stripped == upper and all(ch in standard for ch in upper):
        return "fasta"
    return "sa"
def detect_ligand_type(text: str) -> str:
    """
    Heuristic for ligand input:
      - Starts with '[' and contains ']' -> 'selfies'
      - Otherwise (including empty input) -> 'smiles'
    """
    candidate = (text or "").strip()
    if candidate.startswith("[") and "]" in candidate:
        return "selfies"
    return "smiles"
def parse_config(argv: Optional[List[str]] = None):
    """Parse command-line options.

    Parameters
    ----------
    argv : list[str] | None
        Argument vector to parse. ``None`` (the default) preserves the original
        behaviour of reading ``sys.argv[1:]``; passing an explicit list makes
        the function usable from tests and programmatic callers without
        depending on process-level state.

    Returns
    -------
    argparse.Namespace with agg_mode, group_size, fusion, device,
    save_path_prefix and dataset attributes.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--agg_mode", type=str, default="mean_all_tok")
    p.add_argument("--group_size", type=int, default=1)
    p.add_argument("--fusion", default="CAN")
    # Default to GPU when available; DEVICE below is derived from this value.
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    # Root folder containing the per-input-type checkpoint sub-folders.
    p.add_argument("--save_path_prefix", default="save_model_ckp/")
    p.add_argument("--dataset", default="Human")
    return p.parse_args(argv)
# Parse CLI options once at import time; DEVICE drives all .to(...) placement below.
args = parse_config()
DEVICE = args.device

# ───── Dynamic model registry ─────────────────────────────────────
# Protein encoders: SaProt for structure-aware (SA) sequences, ESM-2 for FASTA.
PROT_MODELS = {
    "sa": "westlake-repl/SaProt_650M_AF2",
    "fasta": "facebook/esm2_t33_650M_UR50D",
}
# Ligand encoder registry: only the SELFIES-based SELFormer is currently active.
DRUG_MODELS = {
    "selfies": "HUBioDataLab/SELFormer",
    # "smiles": "ibm/MoLFormer-XL-both-10pct",
}
def load_encoders(ptype: str, ltype: str, args):
    """
    Dynamically load encoders and tokenisers based on input types.

    The protein branch picks ESM-2 for plain FASTA and SaProt for
    structure-aware input. The ligand branch is currently fixed to the
    SELFormer (SELFIES) encoder regardless of ``ltype`` — callers normalise
    SMILES to SELFIES before tokenisation.

    Returns: (prot_tokenizer, prot_model, drug_tokenizer, drug_model, encoding_module)
    """
    # Protein encoder + tokeniser.
    if ptype == "fasta":
        prot_path = PROT_MODELS["fasta"]
        prot_tokenizer = EsmTokenizer.from_pretrained(prot_path, do_lower_case=False)
    else:  # 'sa'
        prot_path = PROT_MODELS["sa"]
        prot_tokenizer = EsmTokenizer.from_pretrained(prot_path)
    prot_model = EsmForMaskedLM.from_pretrained(prot_path)

    # Ligand encoder (always SELFormer / SELFIES).
    drug_path = DRUG_MODELS["selfies"]
    drug_tokenizer = AutoTokenizer.from_pretrained(drug_path)
    drug_model = AutoModel.from_pretrained(drug_path)

    # Wrap both encoders with the project's Pre_encoded module on the target device.
    encoding = Pre_encoded(prot_model, drug_model, args).to(DEVICE)
    return prot_tokenizer, prot_model, drug_tokenizer, drug_model, encoding
def make_collate_fn(prot_tokenizer, drug_tokenizer):
    """Create a batch collation function bound to the given tokenisers.

    The returned callable maps a batch of (protein, ligand, score) triples to
    (prot_ids, prot_mask(bool), drug_ids, drug_mask(bool), scores_tensor),
    padding/truncating both sequences to 512 tokens.
    """
    tok_kwargs = dict(
        max_length=512,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )

    def _collate_fn(batch):
        prots, drugs, labels = zip(*batch)
        enc_p = prot_tokenizer(list(prots), **tok_kwargs)
        enc_d = drug_tokenizer(list(drugs), **tok_kwargs)
        return (
            enc_p["input_ids"],
            enc_p["attention_mask"].bool(),
            enc_d["input_ids"],
            enc_d["attention_mask"].bool(),
            torch.tensor(list(labels)),
        )

    return _collate_fn
def get_case_feature(model, loader):
    """Generate features for one protein–ligand pair using the provided model.

    Only the first batch of ``loader`` is consumed; the result is a 1-element
    list of CPU tensors matching what the visualiser expects.
    NOTE: implicitly returns None when the loader yields no batches.
    """
    model.eval()
    with torch.no_grad():
        for prot_ids, prot_mask, drug_ids, drug_mask, _ in loader:
            prot_ids = prot_ids.to(DEVICE)
            prot_mask = prot_mask.to(DEVICE)
            drug_ids = drug_ids.to(DEVICE)
            drug_mask = drug_mask.to(DEVICE)
            prot_emb, drug_emb = model.encoding(prot_ids, prot_mask, drug_ids, drug_mask)
            return [(
                prot_emb.cpu(), drug_emb.cpu(),
                prot_ids.cpu(), drug_ids.cpu(),
                prot_mask.cpu(), drug_mask.cpu(),
                None,
            )]
# ─────────────── SELFIES grouping by ORIGINAL string ─────────────
def _group_rows_by_selfies_string(n_rows: int, selfies_str: str):
"""
Partition the attention matrix's n_rows along ligand axis into groups per SELFIES token '[ ... ]'.
Each group is a contiguous row span; we assign rows ≈ equally using linspace.
Returns:
groups: List[(start_row, end_row)] inclusive
labels: List['[X]','[=O]', ...]
"""
if n_rows <= 0:
return [], []
try:
toks = list(selfies.split_selfies((selfies_str or "").strip()))
except Exception:
toks = []
if not toks:
# Fallback: treat whole ligand as one token
return [(0, n_rows - 1)], [selfies_str or "[?]"]
g = len(toks)
edges = np.linspace(0, n_rows, g + 1, dtype=int)
groups = []
for i in range(g):
s, e = edges[i], edges[i + 1] - 1
if e < s:
e = s
groups.append((s, e))
return groups, toks
def _connected_components_2d(mask: torch.Tensor) -> List[List[Tuple[int, int]]]:
"""4-connected components over a 2D boolean mask (rows=ligand tokens, cols=protein residues)."""
h, w = mask.shape
visited = torch.zeros_like(mask, dtype=torch.bool)
comps: List[List[Tuple[int,int]]] = []
for i in range(h):
for j in range(w):
if mask[i, j] and not visited[i, j]:
stack = [(i, j)]
visited[i, j] = True
comp = []
while stack:
r, c = stack.pop()
comp.append((r, c))
for dr, dc in ((1,0), (-1,0), (0,1), (0,-1)):
rr, cc = r + dr, c + dc
if 0 <= rr < h and 0 <= cc < w and mask[rr, cc] and not visited[rr, cc]:
visited[rr, cc] = True
stack.append((rr, cc))
comps.append(comp)
return comps
def _format_component_table(
    components,
    p_tokens,
    d_tokens,
    *,
    mode: str = "pair",  # "pair" | "residue"
):
    """
    Render HTML table for highlighted interaction components.
    Parameters
    ----------
    components : List[List[Tuple[int,int]]]
        Each component is a list of (ligand_index, protein_index) pairs.
    p_tokens : List[str]
        Protein token strings.
    d_tokens : List[str]
        Ligand token strings.
    mode : str
        "pair" -> show Protein range + Ligand range
        "residue" -> show Protein residue(s) only
    """
    # NOTE(review): the HTML literals in this function appear to have had their
    # markup tags stripped (string fragments span raw newlines) — the original
    # tag text must be restored before shipping; code kept token-identical here.
    # ----------------------------
    # Residue-only mode
    # ----------------------------
    if mode == "residue":
        if not components:
            return (
                "
Highlighted protein residues
"
                "No residues selected.
"
            )
        rows = []
        for comp in components:
            # comp = [(lig_idx, prot_idx), ...]
            prot_indices = [j for (_, j) in comp]
            p_start, p_end = min(prot_indices), max(prot_indices)
            # Convert to 1-based indices for display.
            p_s_idx, p_e_idx = p_start + 1, p_end + 1
            p_s_tok = p_tokens[p_start] if p_start < len(p_tokens) else "?"
            p_e_tok = p_tokens[p_end] if p_end < len(p_tokens) else "?"
            if p_start == p_end:
                label = f"{p_s_idx}:{p_s_tok}"
            else:
                label = f"{p_s_idx}:{p_s_tok} – {p_e_idx}:{p_e_tok}"
            rows.append(
                f""
                f"| "
                f"{label}"
                f" | "
                f"
"
            )
        return (
            "Highlighted protein residues
"
            ""
            ""
            "| Protein residue(s) | "
            "
"
            f"{''.join(rows)}
"
        )
    # ----------------------------
    # Pair mode (default behaviour)
    # ----------------------------
    if not components:
        return (
            "Highlighted interaction segments
"
            "No interaction pairs selected.
"
        )
    rows = []
    for comp in components:
        # Bounding span of the component along each axis (0-based, inclusive).
        lig_indices = [i for (i, _) in comp]
        prot_indices = [j for (_, j) in comp]
        d_start, d_end = min(lig_indices), max(lig_indices)
        p_start, p_end = min(prot_indices), max(prot_indices)
        # 1-based display indices.
        d_s_idx, d_e_idx = d_start + 1, d_end + 1
        p_s_idx, p_e_idx = p_start + 1, p_end + 1
        d_s_tok = d_tokens[d_start] if d_start < len(d_tokens) else "?"
        d_e_tok = d_tokens[d_end] if d_end < len(d_tokens) else "?"
        p_s_tok = p_tokens[p_start] if p_start < len(p_tokens) else "?"
        p_e_tok = p_tokens[p_end] if p_end < len(p_tokens) else "?"
        rows.append(
            f""
            f"| Protein: "
            f"{p_s_idx}:{p_s_tok}"
            f"{' – '+str(p_e_idx)+':'+p_e_tok+'' if p_end > p_start else ''}"
            f" | "
            f"Ligand: "
            f"{d_s_idx}:{d_s_tok}"
            f"{' – '+str(d_e_idx)+':'+d_e_tok+'' if d_end > d_start else ''}"
            f" | "
            f"
"
        )
    return (
        "Highlighted Binding site
"
        ""
        ""
        "| Protein range | "
        "Ligand range | "
        "
"
        f"{''.join(rows)}
"
    )
def visualize_attention_and_ranges(
    model,
    feats,
    head_idx: int,
    *,
    mode: str = "pair",  # "pair" | "residue"
    topk_pairs: int = 1,  # Top-K interaction pairs (default=1)
    topk_residues: int = 1,  # Top-K residues (1–20, default=1)
    prot_tokenizer=None,
    drug_tokenizer=None,
    ligand_type: str = "selfies",
    raw_selfies: Optional[str] = None,
) -> Tuple[str, str, str]:
    """
    Visualise interaction attention with two complementary Top-K modes.
    Modes
    -----
    mode="pair":
        - Select Top-K highest-scoring (ligand token, protein residue) pairs
        - Project selected pairs onto protein axis (evaluation-aligned)
        - Default K = 1 (user-controlled)
    mode="residue":
        - Aggregate attention over ligand dimension
        - Rank residues by aggregated score
        - Select Top-K residues (1–100)
        - Default K = 1 (binding pocket discovery)
    Notes
    -----
    - Per-head GLOBAL SUM normalisation (matches test()).
    - Specific heads mapped exactly to GT channels.
    - Combined head = sum of 6 specific heads (NOT overall=7).
    - Returns a 3-tuple: (prob_html, ranges_html, heat_html).
    - NOTE(review): several HTML literals below have had their markup tags
      stripped (strings broken across lines) — restore before shipping.
    """
    assert mode in {"pair", "residue"}
    assert topk_pairs >= 1
    assert 1 <= topk_residues <= 100
    model.eval()
    with torch.no_grad():
        # --------------------------------------------------
        # Unpack features
        # --------------------------------------------------
        p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, _ = feats[0]
        p_emb, d_emb = p_emb.to(DEVICE), d_emb.to(DEVICE)
        p_mask, d_mask = p_mask.to(DEVICE), d_mask.to(DEVICE)
        # --------------------------------------------------
        # Forward
        # --------------------------------------------------
        prob, att_pd = model(p_emb, d_emb, p_mask, d_mask)
        att = att_pd.squeeze(0)
        prob = prob.item()
        # expected: [Ld, Lp, 8] or [8, Ld, Lp]
        # --------------------------------------------------
        # Channel mapping (must match test())
        # --------------------------------------------------
        # Local shadow of the module-level constants of the same names.
        VISIBLE2UNDERLYING = [1, 2, 3, 4, 6, 0]  # HB, Salt, Pi, Cat-Pi, Hydro, VdW
        N_VISIBLE_SPEC = 6

        def select_channel_map(att_):
            # Handles both channel-last [Ld, Lp, C] and channel-first [C, Ld, Lp]
            # layouts; the combined view sums the six specific channels.
            if att_.dim() == 3 and att_.shape[-1] >= 8:
                if head_idx < N_VISIBLE_SPEC:
                    return att_[:, :, VISIBLE2UNDERLYING[head_idx]].cpu()
                return att_[:, :, VISIBLE2UNDERLYING].sum(dim=2).cpu()
            if att_.dim() == 3 and att_.shape[0] >= 8:
                if head_idx < N_VISIBLE_SPEC:
                    return att_[VISIBLE2UNDERLYING[head_idx]].cpu()
                return att_[VISIBLE2UNDERLYING].sum(dim=0).cpu()
            return att_.squeeze().cpu()

        att2d_raw = select_channel_map(att)  # [Ld, Lp]
        # --------------------------------------------------
        # Per-head GLOBAL SUM normalisation (critical)
        # --------------------------------------------------
        att2d_raw = att2d_raw / (att2d_raw.sum() + 1e-8)
        # --------------------------------------------------
        # Token decoding & trimming
        # --------------------------------------------------
        def clean_tokens(ids, tokenizer):
            # Decode ids and drop special tokens (CLS/SEP/PAD etc.).
            toks = tokenizer.convert_ids_to_tokens(ids.tolist())
            if hasattr(tokenizer, "all_special_tokens"):
                toks = [t for t in toks if t not in tokenizer.all_special_tokens]
            return toks

        p_tokens_full = clean_tokens(p_ids[0], prot_tokenizer)
        d_tokens_full = clean_tokens(d_ids[0], drug_tokenizer)
        # Trim the map to the overlap of token lists and attention dims.
        n_d = min(len(d_tokens_full), att2d_raw.size(0))
        n_p = min(len(p_tokens_full), att2d_raw.size(1))
        att2d = att2d_raw[:n_d, :n_p]
        p_tokens = p_tokens_full[:n_p]
        d_tokens = d_tokens_full[:n_d]
        p_indices = list(range(1, n_p + 1))
        d_indices = list(range(1, n_d + 1))
        # --------------------------------------------------
        # SELFIES row merging (for interpretability)
        # --------------------------------------------------
        if ligand_type == "selfies" and raw_selfies:
            groups, labels = _group_rows_by_selfies_string(att2d.size(0), raw_selfies)
            if groups:
                # Mean-pool the attention rows belonging to each SELFIES token.
                merged = []
                for s, e in groups:
                    merged.append(att2d[s:e + 1].mean(dim=0, keepdim=True))
                att2d = torch.cat(merged, dim=0)
                d_tokens = labels
                d_indices = list(range(1, len(labels) + 1))
        # --------------------------------------------------
        # Top-K selection (two modes, STRICT RANKING)
        # --------------------------------------------------
        if mode == "pair":
            flat = att2d.reshape(-1)
            k_eff = min(topk_pairs, flat.numel())
            topk_vals, topk_idx = torch.topk(flat, k=k_eff)
            mask_top = torch.zeros_like(flat, dtype=torch.bool)
            mask_top[topk_idx] = True
            mask_top = mask_top.view_as(att2d)
            rows = []
            n_cols = att2d.size(1)
            for rank, (val, linear_idx) in enumerate(zip(topk_vals, topk_idx), start=1):
                # Recover 2D (ligand, protein) coordinates from the flat index.
                i = (linear_idx // n_cols).item()
                j = (linear_idx % n_cols).item()
                rows.append(
                    f""
                    f"| Top {rank} | "
                    f"Protein: {j+1}:{p_tokens[j]} | "
                    f"Ligand: {i+1}:{d_tokens[i]} | "
                    f"Score: {val.item():.6f} | "
                    f"
"
                )
            ranges_html = (
                "Top-K Interaction Pairs (ranked by attention score)
"
                ""
                ""
                "| Rank | "
                "Protein | "
                "Ligand | "
                "Attention Score | "
                "
"
                f"{''.join(rows)}
"
            )
        else:
            # --- STRICT Top-K residue ranking ---
            # Aggregate over the ligand axis, then rank protein residues.
            residue_score = att2d.sum(dim=0)
            k_eff = min(topk_residues, residue_score.numel())
            topk_vals, topk_res_idx = torch.topk(residue_score, k=k_eff)
            mask_top = torch.zeros_like(att2d, dtype=torch.bool)
            mask_top[:, topk_res_idx] = True
            rows = []
            for rank, (val, j) in enumerate(zip(topk_vals, topk_res_idx), start=1):
                j = j.item()
                rows.append(
                    f""
                    f"| Top {rank} | "
                    f""
                    f"Protein residue: {j+1}:{p_tokens[j]}"
                    f" | "
                    f""
                    f"Aggregated Score: {val.item():.6f}"
                    f" | "
                    f"
"
                )
            ranges_html = (
                "Top-K Residues (ranked by aggregated attention)
"
                ""
                ""
                "| Rank | "
                "Protein Residue | "
                "Aggregated Score | "
                "
"
                f"{''.join(rows)}
"
            )
        # --------------------------------------------------
        # Connected components (visual coherence)
        # --------------------------------------------------
        # p_tokens_orig = p_tokens.copy()
        # d_tokens_orig = d_tokens.copy()
        # components = _connected_components_2d(mask_top)
        # ranges_html = _format_component_table(
        #     components,
        #     p_tokens_orig,
        #     d_tokens_orig,
        #     mode=mode,
        # )
        # --------------------------------------------------
        # Crop to union of selected rows / columns
        # --------------------------------------------------
        rows_keep = mask_top.any(dim=1)
        cols_keep = mask_top.any(dim=0)
        # If nothing was selected, keep everything so the plot is non-empty.
        if not rows_keep.any():
            rows_keep[:] = True
        if not cols_keep.any():
            cols_keep[:] = True
        vis = att2d[rows_keep][:, cols_keep]
        d_tokens_vis = [t for k, t in zip(rows_keep.tolist(), d_tokens) if k]
        p_tokens_vis = [t for k, t in zip(cols_keep.tolist(), p_tokens) if k]
        d_indices_vis = [i for k, i in zip(rows_keep.tolist(), d_indices) if k]
        p_indices_vis = [i for k, i in zip(cols_keep.tolist(), p_indices) if k]
        # Cap columns for readability
        if vis.size(1) > 150:
            topc = torch.topk(vis.sum(0), k=150).indices
            vis = vis[:, topc]
            p_tokens_vis = [p_tokens_vis[i] for i in topc]
            p_indices_vis = [p_indices_vis[i] for i in topc]
        # --------------------------------------------------
        # Plot
        # --------------------------------------------------
        x_labels = [f"{i}:{t}" for i, t in zip(p_indices_vis, p_tokens_vis)]
        y_labels = [f"{i}:{t}" for i, t in zip(d_indices_vis, d_tokens_vis)]
        # Scale the figure with the number of labels, clamped to sane bounds.
        fig_w = min(22, max(6, len(x_labels) * 0.6))
        fig_h = min(24, max(6, len(y_labels) * 0.8))
        fig, ax = plt.subplots(figsize=(fig_w, fig_h))
        im = ax.imshow(vis.numpy(), aspect="auto", cmap=cm.viridis)
        title = INTERACTION_NAMES[head_idx]
        suffix = "Top-K pairs" if mode == "pair" else "Top-K residues"
        ax.set_title(f"Ligand × Protein — {title} ({suffix})", fontsize=10, pad=8)
        ax.set_xlabel("Protein residues")
        ax.set_ylabel("Ligand tokens")
        ax.set_xticks(range(len(x_labels)))
        ax.set_xticklabels(x_labels, rotation=90, fontsize=8)
        ax.set_yticks(range(len(y_labels)))
        ax.set_yticklabels(y_labels, fontsize=7)
        ax.xaxis.tick_top()
        ax.xaxis.set_label_position("top")
        ax.tick_params(axis="x", bottom=False)
        fig.colorbar(im, fraction=0.026, pad=0.01)
        fig.tight_layout()
        # --------------------------------------------------
        # Export
        # --------------------------------------------------
        buf_png = io.BytesIO()
        buf_pdf = io.BytesIO()
        fig.savefig(buf_png, format="png", dpi=140)
        fig.savefig(buf_pdf, format="pdf")
        plt.close(fig)
        png_b64 = base64.b64encode(buf_png.getvalue()).decode()
        pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()
        # NOTE(review): the embedding markup (img/download tags using png_b64 /
        # pdf_b64) was stripped from this literal — restore before shipping.
        heat_html = f"""
"""
        # ------------------------------
        # Probability display card
        # ------------------------------
        # Colour thresholds: >=0.8 green "high", >=0.4 blue "moderate", else red "low".
        if prob >= 0.8:
            bg = "#ecfdf5"
            border = "#10b981"
            label = "High binding confidence"
        elif prob >= 0.4:
            bg = "#eff6ff"
            border = "#3b82f6"
            label = "Moderate binding confidence"
        else:
            bg = "#fef2f2"
            border = "#ef4444"
            label = "Low binding confidence"
        prob_html = f"""
Predicted Binding Probability
{prob:.4f}
{label}
"""
        return prob_html, ranges_html, heat_html
# ───── Gradio callbacks ─────────────────────────────────────────
# ROOT re-computed here (same value as the module-top ROOT); the foldseek
# binary is expected to ship inside the project's utils/ directory.
ROOT = os.path.dirname(os.path.abspath(__file__))
FOLDSEEK_BIN = os.path.join(ROOT, "utils", "foldseek")
def extract_aa_seq_cb(structure_file, protein_text):
    """
    Extract plain amino acid sequence from uploaded PDB/mmCIF.

    Returns (sequence_text, status_html): the text for the protein textbox
    (unchanged on failure) plus an HTML status message.
    NOTE(review): the HTML literals below lost their markup tags — restore.
    """
    prot_seq_out = (protein_text or "").strip()
    msgs = []
    if structure_file is None:
        return prot_seq_out, "Please upload a structure file.
"
    try:
        seq = simple_seq_from_structure(structure_file.name)
        if seq:
            prot_seq_out = seq
            msgs.append("✅ Extracted amino acid sequence from structure.")
        else:
            msgs.append("❌ No valid amino acid sequence found.")
    except Exception as e:
        msgs.append(f"❌ Extraction failed: {e}")
    # NOTE(review): msgs is collected but the status literal is empty here —
    # the f-string interpolating msgs was presumably lost with the markup.
    status_html = (
        ""
    )
    return prot_seq_out, status_html
def extract_sa_seq_cb(structure_file, protein_text):
    """
    Extract a structure-aware (SA) sequence from an uploaded structure using
    foldseek (via utils.foldseek_util.get_struc_seq).

    Returns (sequence_text, status_html).
    NOTE(review): the HTML literals below lost their markup tags — restore.
    """
    prot_seq_out = (protein_text or "").strip()
    msgs = []
    if structure_file is None:
        return prot_seq_out, "Please upload a structure file.
"
    try:
        parsed = get_struc_seq(
            FOLDSEEK_BIN,
            structure_file.name,
            None,
            plddt_mask=False,
        )
        # Use the first chain; assumes per-chain 3-tuples with the SA
        # (combined) sequence in the third slot — TODO confirm against
        # utils.foldseek_util.get_struc_seq.
        first_chain = next(iter(parsed))
        _, _, struct_seq = parsed[first_chain]
        if struct_seq:
            prot_seq_out = struct_seq
            msgs.append("✅ Extracted structure-aware sequence (SA).")
        else:
            msgs.append("❌ Structure parsed but no SA sequence found.")
    except Exception as e:
        msgs.append(f"❌ SA extraction failed: {e}")
    # NOTE(review): msgs collected but not interpolated — markup likely lost.
    status_html = (
        ""
    )
    return prot_seq_out, status_html
def _choose_ckpt_by_types(prot_seq: str, ligand_text: str) -> Tuple[str, str, str]:
    """Return (folder_name, protein_type, ligand_type) for checkpoint routing."""
    protein_type = detect_protein_type(prot_seq)
    ligand_type = detect_ligand_type(ligand_text)
    # Folder naming: sa_selfies / fasta_selfies / sa_smiles / fasta_smiles
    return f"{protein_type}_{ligand_type}", protein_type, ligand_type
def inference_cb(prot_seq, drug_seq, head_choice, topk_choice):
    """
    Inference callback supporting two Top-K modes:
    - Top-K interaction pairs
    - Top-K residues

    Returns a single HTML string (probability card + ranking table + heat map).
    NOTE(review): several HTML literals below lost their markup tags — restore.
    """
    # ------------------------------
    # Input validation
    # ------------------------------
    if not prot_seq or not prot_seq.strip():
        return "Please extract or enter a protein sequence first.
"
    if not drug_seq or not drug_seq.strip():
        return "Please enter a ligand sequence (SELFIES or SMILES).
"
    prot_seq = prot_seq.strip()
    drug_seq_in = drug_seq.strip()
    # ------------------------------
    # Detect types & checkpoint routing
    # ------------------------------
    folder, ptype, ltype = _choose_ckpt_by_types(prot_seq, drug_seq_in)
    # Ligand normalisation: always tokenise as SELFIES
    if ltype == "smiles":
        conv = smiles_to_selfies(drug_seq_in)
        if conv is None:
            # NOTE(review): this branch returns a 2-tuple while the other
            # paths return a single string — verify against the Gradio wiring.
            return (
                "SMILES→SELFIES conversion failed. "
                "The SMILES appears invalid.
",
                "",
            )
        drug_seq_for_tokenizer = conv
    else:
        drug_seq_for_tokenizer = drug_seq_in
    # 🔒 Force a single unified ligand type: everything downstream is SELFIES
    ltype = "selfies"
    ligand_type_flag = "selfies"
    raw_selfies = drug_seq_for_tokenizer
    folder = f"{ptype}_selfies"
    # # Ligand normalisation: always tokenise as SELFIES
    # if ltype == "smiles":
    # conv = smiles_to_selfies(drug_seq_in)
    # if conv is None:
    # return (
    # "SMILES→SELFIES conversion failed. "
    # "The SMILES appears invalid.
    ",
    # "",
    # )
    # drug_seq_for_tokenizer = conv
    # ligand_type_flag = "selfies"
    # else:
    # drug_seq_for_tokenizer = drug_seq_in
    # ligand_type_flag = "selfies"
    # raw_selfies = drug_seq_for_tokenizer if ligand_type_flag == "selfies" else None
    # ------------------------------
    # Load encoders
    # ------------------------------
    prot_tok, prot_m, drug_tok, drug_m, encoding = load_encoders(ptype, ltype, args)
    loader = DataLoader(
        [(prot_seq, drug_seq_for_tokenizer, 1)],
        batch_size=1,
        collate_fn=make_collate_fn(prot_tok, drug_tok),
    )
    feats = get_case_feature(encoding, loader)
    # ------------------------------
    # Load trained checkpoint (if exists)
    # ------------------------------
    ckpt = os.path.join(args.save_path_prefix, folder, "best_model.ckpt")
    model = ExplainBind(1280, 768, args=args).to(DEVICE)
    if os.path.isfile(ckpt):
        model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
        warn_html = (
            ""
            f"Loaded model: {folder}/best_model.ckpt
"
        )
    else:
        warn_html = (
            ""
            "Warning: checkpoint not found "
            f"{folder}/best_model.ckpt. "
            "Using randomly initialised weights for visualisation.
"
        )
    # NOTE(review): warn_html is built in both branches but never included in
    # the returned HTML below — confirm whether it should be prepended.
    # ------------------------------
    # Parse interaction head
    # ------------------------------
    sel = str(head_choice).strip()
    if sel in INTERACTION_NAMES:
        head_idx = INTERACTION_NAMES.index(sel)
    else:
        # Accept "N. Name"-style values by parsing the leading number.
        try:
            n = int(sel.split(".", 1)[0])
            head_idx = max(0, min(len(INTERACTION_NAMES) - 1, n - 1))
        except Exception:
            head_idx = len(INTERACTION_NAMES) - 1  # Combined Interaction
    # ------------------------------
    # Parse Top-K value
    # ------------------------------
    try:
        topk = int(str(topk_choice).strip())
    except Exception:
        topk = 1
    topk = max(1, topk)
    # Residue mode is always used by this callback; pair mode is unreachable here.
    mode = "residue"
    topk_pairs = 1
    topk_residues = min(100, topk)
    # ------------------------------
    # Visualisation
    # ------------------------------
    prob_html, table_html, heat_html = visualize_attention_and_ranges(
        model,
        feats,
        head_idx,
        mode=mode,
        topk_pairs=topk_pairs,
        topk_residues=topk_residues,
        prot_tokenizer=prot_tok,
        drug_tokenizer=drug_tok,
        ligand_type=ligand_type_flag,
        raw_selfies=raw_selfies,
    )
    full_html = prob_html + table_html + heat_html  # ✅ enforce top-to-bottom display order
    return full_html
def clear_cb():
    """Reset the UI: (protein_seq, drug_seq, output_full, structure_file, status_box)."""
    empty_text = ""
    # structure_file (gr.File) is cleared with None; text/HTML boxes with "".
    return (empty_text, empty_text, empty_text, None, empty_text)
# ───── Gradio interface definition ───────────────────────────────
# Global stylesheet for the app (cards, buttons, link pills, logo block).
# Passed as css=... to demo.launch(...) at the bottom of the file.
css = """
:root{
--bg:#f8fafc; --card:#f8fafc; --text:#0f172a;
--muted:#6b7280; --border:#e5e7eb; --shadow:0 6px 24px rgba(2,6,23,.06);
--radius:14px; --icon-size:20px;
}
*{box-sizing:border-box}
html,body{background:#fff!important;color:var(--text)!important}
.gradio-container{max-width:1120px;margin:0 auto}
/* Title and subtitle */
h1{
font-family:Inter,ui-sans-serif;letter-spacing:.2px;font-weight:700;
font-size:32px;margin:22px 0 12px;text-align:center
}
.subtle{color:var(--muted);font-size:14px;text-align:center;margin:-6px 0 18px}
/* Card style */
.card{
background:var(--card); border:1px solid var(--border); border-radius:var(--radius);
box-shadow:var(--shadow); padding:22px;
}
/* Top links */
.link-row{display:flex;justify-content:center;gap:14px;margin:0 auto 18px;flex-wrap:wrap}
/* Two-column grid: left=input, right=controls */
.grid-2{display:grid;grid-template-columns:1.4fr .9fr;gap:16px}
.grid-2 .col{display:flex;flex-direction:column;gap:12px}
/* Buttons */
.gr-button{border-radius:12px !important;font-weight:700 !important;letter-spacing:.2px}
#extract-btn{background:linear-gradient(90deg,#EFAFB2,#EFAFB2); color:#0f172a}
#inference-btn{background:linear-gradient(90deg,#B2CBDF,#B2CBDF); color:#0f172a}
#clear-btn{background:#FFE2B5; color:#0A0A0A; border:1px solid var(--border)}
/* Result spacing */
#result-table{margin-bottom:16px}
/* Figure container */
.figure-wrap{border:1px solid var(--border);border-radius:12px;overflow:hidden;box-shadow:var(--shadow)}
.figure-wrap img{display:block;width:100%;height:auto}
/* Right pane: vertical radio layout and full-width controls (kept for button styling) */
.right-pane .gr-button{
width:100% !important;
height:48px !important;
border-radius:12px !important;
font-weight:700 !important;
letter-spacing:.2px;
}
/* ───────── Publication links (Bulma-like) ───────── */
.publication-links {
display: flex;
justify-content: center;
gap: 14px;
flex-wrap: wrap;
margin: 6px 0 18px;
}
.link-block a {
display: inline-flex;
align-items: center;
gap: 8px;
padding: 10px 18px;
font-size: 14px;
font-weight: 600;
border-radius: 9999px;
text-decoration: none;
transition: all 0.15s ease-in-out;
}
/* colour variants */
.btn-danger { background:#e2e8f0; color:#0f172a; }
.btn-dark { background:#e2e8f0; color:#0f172a; }
.btn-link { background:#e2e8f0; color:#0f172a; }
.btn-warning { background:#e2e8f0; color:#0f172a; }
.link-block a:hover {
filter: brightness(0.95);
transform: translateY(-1px);
}
.loscalzo-block img {
height: 100px;
width: auto;
object-fit: contain;
}
.loscalzo-block {
display: flex;
align-items: center;
gap: 10px;
margin: 0 auto;
justify-content: center;
}
.link-btn{
display:inline-flex !important;
align-items:center !important;
gap:8px !important;
padding:10px 18px !important;
font-size:14px !important;
font-weight:600 !important;
border-radius:9999px !important;
background:#e2e8f0 !important;
color:#0f172a !important;
text-decoration:none !important;
border:1px solid #e5e7eb !important;
transition:all 0.15s ease-in-out !important;
}
.link-btn:hover{
filter:brightness(0.95);
transform:translateY(-1px);
}
.project-links{
display:flex !important;
justify-content:center !important;
gap:28px !important;
flex-wrap:wrap !important;
margin-bottom:32px !important;
}
#example-btn {
background: #979ea8 !important;
color: #1e293b !important;
}
#extract-aa-btn{
background:#DCE7F3 !important;
color:#0f172a !important;
}
#extract-sa-btn{
background:#EADCF8 !important;
color:#0f172a !important;
}
"""
# Build the Gradio UI. NOTE(review): several gr.Markdown / gr.HTML literals
# below have had their HTML markup stripped (strings broken across lines or
# emptied) — restore the original markup before shipping.
with gr.Blocks() as demo:
    gr.Markdown("ExplainBind: Token-level Protein–Ligand Interaction Visualiser
")
    # gr.HTML(f"""
    #
    #

    #
    #
    #
    # """)
    # ───────────────────────────────
    # Top links
    # ───────────────────────────────
    gr.HTML("""
""")
    # ───────────────────────────────
    # Guidelines
    # ───────────────────────────────
    with gr.Accordion("Guidelines for Users", open=True, elem_classes=["card"]):
        gr.HTML("""
-
Input formats:
The system supports either structure-aware (SA) sequences derived from
protein structures or conventional FASTA sequences.
For structure-based analysis, users may upload
.pdb or
.cif files to extract the corresponding sequence representation.
Ligands can be provided in SMILES or SELFIES format.
-
Interaction channel selection:
Users may select a specific non-covalent interaction type
(e.g., hydrogen bonding, hydrophobic interactions) or the
overall interaction channel to visualise the corresponding
token-level binding patterns.
-
Model outputs:
The system reports (i) a predicted binding probability for the
protein–ligand pair, (ii) a ranked Top-K residue table, and (iii) a token-level interaction
heat map illustrating spatial interaction patterns.
""")
    # ───────────────────────────────
    # Inputs + Controls
    # ───────────────────────────────
    with gr.Row():
        with gr.Column(elem_classes=["card", "grid-2"]):
            # ────────────────
            # LEFT PANEL
            # ────────────────
            with gr.Column(elem_id="left"):
                # Components are created with render=False so gr.Examples can
                # reference them before they are placed; .render() below fixes
                # their visual position.
                protein_seq = gr.Textbox(
                    label="Protein structure-aware / FASTA sequence",
                    lines=3,
                    placeholder="Paste SA/FASTA sequence or click Extract…",
                    elem_id="protein-seq",
                    render=False,
                )
                drug_seq = gr.Textbox(
                    label="Ligand (SELFIES / SMILES)",
                    lines=3,
                    placeholder="Paste SELFIES or SMILES",
                    elem_id="drug-seq",
                    render=False,
                )
                structure_file = gr.File(
                    label="Upload protein structure (.pdb / .cif)",
                    file_types=[".pdb", ".cif"],
                    elem_id="structure-file",
                    render=False,
                )
                with gr.Group():
                    gr.Markdown("### Example")
                    gr.Examples(
                        examples=[[
                            "SLALSLTADQMVSALLDAEPPILYSEYDPTRPFSEASMMGLLTNLADRELVHMINWAKRVPGFVDLTSHDQVHLLECAWLEILMIGLVWRSMEHPGKLLFAPNLLLDRNQGKCVEGMVEIFDMLLATSSRFRMMNLQGEEFVCLKSIILLNSGVYTFLSSTLKSLEEKDHIHRVLDKITDTLIHLMAKAGLTLQQQHQRLAQLLLILSHIRHMSNKGMEHLYSMKCKNVVPSYDLLLEMLDA",
                            "[C][=C][C][=Branch2][Branch1][#C][=C][C][=C][Ring1][=Branch1][C][=C][Branch2][Ring2][#Branch2][C@H1][C@@H1][Branch1][Branch2][C][C@@H1][Ring1][=Branch1][O][Ring1][Branch1][S][=Branch1][C][=O][=Branch1][C][=O][N][Branch1][#Branch2][C][C][Branch1][C][F][Branch1][C][F][F][C][=C][C][=C][Branch1][Branch1][C][=C][Ring1][=Branch1][Cl][C][=C][C][=C][Branch1][Branch1][C][=C][Ring1][=Branch1][O][O]"
                        ]],
                        inputs=[protein_seq, drug_seq],
                        label="Click to load an example",
                    )
                    btn_load_example = gr.Button(
                        "Load Example",
                        elem_id="example-btn",
                        # variant="secondary"
                    )
                structure_file.render()
                with gr.Row():
                    btn_extract_aa = gr.Button(
                        "Extract amino acid sequence",
                        elem_id="extract-aa-btn"
                    )
                    btn_extract_sa = gr.Button(
                        "Extract structure-aware sequence",
                        elem_id="extract-sa-btn"
                    )
                protein_seq.render()
                drug_seq.render()
            # ────────────────
            # RIGHT PANEL
            # ────────────────
            with gr.Column(elem_id="right", elem_classes=["right-pane"]):
                head_dd = gr.Dropdown(
                    label="Non-covalent interaction type/Overall",
                    choices=INTERACTION_NAMES,
                    value="Overall Interaction",
                    interactive=True,
                )
                top_k_dd = gr.Dropdown(
                    label="Top-K residue",
                    choices=[str(i) for i in range(1, 21)],
                    value="1",
                    interactive=True,
                )
                with gr.Row():
                    btn_infer = gr.Button(
                        "Inference",
                        elem_id="inference-btn"
                    )
                    clear_btn = gr.Button(
                        "Clear",
                        elem_id="clear-btn"
                    )
    # ───────────────────────────────
    # Outputs
    # ───────────────────────────────
    with gr.Column(elem_classes=["card"]):
        status_box = gr.HTML(elem_id="status-box")
        output_full = gr.HTML(elem_id="result-full")
    # ───────────────────────────────
    # Example Loader Callback
    # ───────────────────────────────
    def load_example_cb():
        # Returns the same (protein, ligand) pair as the gr.Examples entry above.
        return (
            "SLALSLTADQMVSALLDAEPPILYSEYDPTRPFSEASMMGLLTNLADRELVHMINWAKRVPGFVDLTSHDQVHLLECAWLEILMIGLVWRSMEHPGKLLFAPNLLLDRNQGKCVEGMVEIFDMLLATSSRFRMMNLQGEEFVCLKSIILLNSGVYTFLSSTLKSLEEKDHIHRVLDKITDTLIHLMAKAGLTLQQQHQRLAQLLLILSHIRHMSNKGMEHLYSMKCKNVVPSYDLLLEMLDA",
            "[C][=C][C][=Branch2][Branch1][#C][=C][C][=C][Ring1][=Branch1][C][=C][Branch2][Ring2][#Branch2][C@H1][C@@H1][Branch1][Branch2][C][C@@H1][Ring1][=Branch1][O][Ring1][Branch1][S][=Branch1][C][=O][=Branch1][C][=O][N][Branch1][#Branch2][C][C][Branch1][C][F][Branch1][C][F][F][C][=C][C][=C][Branch1][Branch1][C][=C][Ring1][=Branch1][Cl][C][=C][C][=C][Branch1][Branch1][C][=C][Ring1][=Branch1][O][O]"
        )
    # ───────────────────────────────
    # Wiring
    # ───────────────────────────────
    btn_load_example.click(
        fn=load_example_cb,
        inputs=[],
        outputs=[protein_seq, drug_seq],
    )
    btn_extract_aa.click(
        fn=extract_aa_seq_cb,
        inputs=[structure_file, protein_seq],
        outputs=[protein_seq, status_box],
    )
    btn_extract_sa.click(
        fn=extract_sa_seq_cb,
        inputs=[structure_file, protein_seq],
        outputs=[protein_seq, status_box],
    )
    btn_infer.click(
        fn=inference_cb,
        inputs=[protein_seq, drug_seq, head_dd, top_k_dd],
        outputs=[output_full],
    )
    clear_btn.click(
        fn=clear_cb,
        inputs=[],
        outputs=[
            protein_seq,
            drug_seq,
            output_full,
            structure_file,
            status_box,
        ],
    )

# NOTE(review): in current Gradio versions `theme` and `css` are arguments of
# gr.Blocks(...), not of launch() — verify against the installed Gradio release.
demo.launch(
    theme=gr.themes.Default(),
    css=css,
    show_error=True
)