# ─── Compatibility patch ──────────────────────────────────────────────────────
# Some gradio_client versions crash when a JSON schema is a bare bool (True /
# False are legal JSON Schemas meaning "anything" / "nothing").  Wrap the two
# schema helpers so a boolean schema is treated as an empty dict schema.
import gradio_client.utils as _gc_utils

_orig_get_type = _gc_utils.get_type
_orig_json2py = _gc_utils._json_schema_to_python_type


def _patched_get_type(schema):
    """`get_type` that tolerates boolean JSON schemas."""
    return _orig_get_type({} if isinstance(schema, bool) else schema)


def _patched_json_schema_to_python_type(schema, defs=None):
    """`_json_schema_to_python_type` that tolerates boolean JSON schemas."""
    return _orig_json2py({} if isinstance(schema, bool) else schema, defs)


_gc_utils.get_type = _patched_get_type
_gc_utils._json_schema_to_python_type = _patched_json_schema_to_python_type

# ─── Imports ───────────────────────────────────────────────────────────────────
import os
import io
import base64
import argparse
from typing import Optional, List, Tuple

import numpy as np
import torch
from torch.utils.data import DataLoader
import selfies
# from rdkit import Chem

import matplotlib
matplotlib.use("Agg")  # headless backend; must be set before pyplot is imported
import matplotlib.pyplot as plt
from matplotlib import cm

from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel, AutoTokenizer
from Bio.PDB import PDBParser, MMCIFParser
from Bio.Data import IUPACData
import gradio as gr

# Project utils (ensure these exist in your repository)
from utils.metric_learning_models_att_maps import Pre_encoded, ExplainBind
from utils.foldseek_util import get_struc_seq

# ───────────────────── Paths & Logos ─────────────────────
ROOT = os.path.dirname(os.path.abspath(__file__))
ASSET_DIR = os.path.join(ROOT, "utils")
LOSCAZLO_LOGO = os.path.join(ASSET_DIR, "loscalzo.png")


def _load_logo_b64(path):
    """Return the file at *path* base64-encoded as str, or '' if it is missing."""
    if not os.path.exists(path):
        return ""
    with open(path, "rb") as fh:
        return base64.b64encode(fh.read()).decode("utf-8")


LOSCAZLO_B64 = _load_logo_b64(LOSCAZLO_LOGO)

# ───────────────────── Configurable constants ─────────────────────
# UI-visible names (Halogen bonding removed)
INTERACTION_NAMES = [
    "Hydrogen bonding",
    "Salt Bridging",
    "π–π Stacking",
    "Cation–π",
    "Hydrophobic",
    "Van der Waals",
    "Overall Interaction",
]

# Map visible indices (0..5 = specific, 6 = combined) to underlying channel
# indices.  Underlying channels originally had Halogen at index=5 (0-based);
# we skip 5 entirely.
VISIBLE2UNDERLYING = [1, 2, 3, 4, 6, 0]  # HB, Salt, Pi, Cation-Pi, Hydro, VdW
N_VISIBLE_SPEC = len(VISIBLE2UNDERLYING)  # 6

# ───── Helper utilities ───────────────────────────────────────────
three2one = {k.upper(): v for k, v in IUPACData.protein_letters_3to1.items()}
# Common non-standard residues mapped onto their closest standard letter.
three2one.update({"MSE": "M", "SEC": "C", "PYL": "K"})
STANDARD_AA_SET = set("ACDEFGHIKLMNPQRSTVWY")  # Uppercase FASTA amino acids


def simple_seq_from_structure(path: str) -> str:
    """Extract the longest chain and return standard 1-letter amino acid sequence."""
    parser = MMCIFParser(QUIET=True) if path.endswith(".cif") else PDBParser(QUIET=True)
    structure = parser.get_structure("P", path)
    chains = list(structure.get_chains())
    if not chains:
        return ""
    longest = max(chains, key=lambda c: len(list(c.get_residues())))
    letters = []
    for residue in longest:
        name = residue.get_resname().upper()
        # Residues missing from the table (waters, ligands, …) are skipped.
        if name in three2one:
            letters.append(three2one[name])
    return "".join(letters)


# def smiles_to_selfies(smiles_text: str) -> Optional[str]:
#     """Validate and convert SMILES to SELFIES; return None if invalid."""
#     try:
#         mol = Chem.MolFromSmiles(smiles_text)
#         if mol is None:
#             return None
#         return selfies.encoder(smiles_text)
#     except Exception:
#         return None


def smiles_to_selfies(smiles_text: str) -> Optional[str]:
    """Convert SMILES to SELFIES (round-trip validated); return None if invalid."""
    try:
        sf = selfies.encoder(smiles_text)
        if not selfies.decoder(sf):
            return None
        return sf
    except Exception:
        return None


def detect_protein_type(seq: str) -> str:
    """
    Heuristic for protein input:
      - All uppercase and only the standard 20 amino acids → 'fasta'
      - Otherwise (contains lowercase or non-standard characters) → 'sa'
    """
    stripped = (seq or "").strip()
    if not stripped:
        return "fasta"
    upper = stripped.upper()
    if stripped == upper and all(ch in STANDARD_AA_SET for ch in upper):
        return "fasta"
    return "sa"


def detect_ligand_type(text: str) -> str:
    """
    Heuristic for ligand input:
      - Starts with '[' and contains ']' → 'selfies'
      - Otherwise → 'smiles'
    """
    t = (text or "").strip()
    if not t:
        return "smiles"
    return "selfies" if (t.startswith("[") and "]" in t) else "smiles"
def parse_config():
    """
    Parse command-line options and return the resulting namespace.

    Fix: use ``parse_known_args`` instead of ``parse_args``.  This module calls
    ``parse_config()`` at import time; ``parse_args`` raises ``SystemExit`` on
    any unrecognized argv entry, so importing/launching the app under pytest,
    notebooks, or deployment wrappers that inject extra argv aborts the whole
    process.  ``parse_known_args`` keeps identical behavior for every supported
    flag and silently ignores extras.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--agg_mode", type=str, default="mean_all_tok")
    p.add_argument("--group_size", type=int, default=1)
    p.add_argument("--fusion", default="CAN")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    # Root folder containing checkpoints
    p.add_argument("--save_path_prefix", default="save_model_ckp/")
    p.add_argument("--dataset", default="Human")
    known, _unknown = p.parse_known_args()
    return known


args = parse_config()
DEVICE = args.device

# ───── Dynamic model registry ─────────────────────────────────────
PROT_MODELS = {
    "sa": "westlake-repl/SaProt_650M_AF2",
    "fasta": "facebook/esm2_t33_650M_UR50D",
}

DRUG_MODELS = {
    "selfies": "HUBioDataLab/SELFormer",
    # "smiles": "ibm/MoLFormer-XL-both-10pct",
}
def load_encoders(ptype: str, ltype: str, args):
    """
    Dynamically load encoders and tokenisers based on the detected input types.

    NOTE(review): ``ltype`` is currently unused — the ligand branch is pinned to
    the SELFormer (SELFIES) encoder, and ``inference_cb`` normalises every
    ligand to SELFIES before calling this.  Confirm before re-enabling the
    commented SMILES (MoLFormer) branch.

    Returns:
        (prot_tokenizer, prot_model, drug_tokenizer, drug_model, encoding_module)
    """
    # Protein encoder: plain FASTA → ESM-2; structure-aware sequence → SaProt.
    if ptype == "fasta":
        prot_path = PROT_MODELS["fasta"]
        prot_tokenizer = EsmTokenizer.from_pretrained(prot_path, do_lower_case=False)
    else:  # 'sa'
        prot_path = PROT_MODELS["sa"]
        prot_tokenizer = EsmTokenizer.from_pretrained(prot_path)
    prot_model = EsmForMaskedLM.from_pretrained(prot_path)

    # Ligand encoder: always SELFormer over SELFIES tokens.
    drug_path = DRUG_MODELS["selfies"]
    drug_tokenizer = AutoTokenizer.from_pretrained(drug_path)
    drug_model = AutoModel.from_pretrained(drug_path)

    # Wrap both backbones in the project's Pre_encoded feature extractor.
    encoding = Pre_encoded(prot_model, drug_model, args).to(DEVICE)
    return prot_tokenizer, prot_model, drug_tokenizer, drug_model, encoding


def make_collate_fn(prot_tokenizer, drug_tokenizer):
    """Create a batch collation function using the given tokenisers."""

    def _collate_fn(batch):
        prots, drugs, labels = zip(*batch)
        enc_kwargs = dict(
            max_length=512,
            padding="max_length",
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        enc_p = prot_tokenizer(list(prots), **enc_kwargs)
        enc_d = drug_tokenizer(list(drugs), **enc_kwargs)
        return (
            enc_p["input_ids"],
            enc_p["attention_mask"].bool(),
            enc_d["input_ids"],
            enc_d["attention_mask"].bool(),
            torch.tensor(list(labels)),
        )

    return _collate_fn
loader): """Generate features for one protein–ligand pair using the provided model.""" model.eval() with torch.no_grad(): for p_ids, p_mask, d_ids, d_mask, _ in loader: p_ids, p_mask = p_ids.to(DEVICE), p_mask.to(DEVICE) d_ids, d_mask = d_ids.to(DEVICE), d_mask.to(DEVICE) p_emb, d_emb = model.encoding(p_ids, p_mask, d_ids, d_mask) return [(p_emb.cpu(), d_emb.cpu(), p_ids.cpu(), d_ids.cpu(), p_mask.cpu(), d_mask.cpu(), None)] # ─────────────── SELFIES grouping by ORIGINAL string ───────────── def _group_rows_by_selfies_string(n_rows: int, selfies_str: str): """ Partition the attention matrix's n_rows along ligand axis into groups per SELFIES token '[ ... ]'. Each group is a contiguous row span; we assign rows ≈ equally using linspace. Returns: groups: List[(start_row, end_row)] inclusive labels: List['[X]','[=O]', ...] """ if n_rows <= 0: return [], [] try: toks = list(selfies.split_selfies((selfies_str or "").strip())) except Exception: toks = [] if not toks: # Fallback: treat whole ligand as one token return [(0, n_rows - 1)], [selfies_str or "[?]"] g = len(toks) edges = np.linspace(0, n_rows, g + 1, dtype=int) groups = [] for i in range(g): s, e = edges[i], edges[i + 1] - 1 if e < s: e = s groups.append((s, e)) return groups, toks def _connected_components_2d(mask: torch.Tensor) -> List[List[Tuple[int, int]]]: """4-connected components over a 2D boolean mask (rows=ligand tokens, cols=protein residues).""" h, w = mask.shape visited = torch.zeros_like(mask, dtype=torch.bool) comps: List[List[Tuple[int,int]]] = [] for i in range(h): for j in range(w): if mask[i, j] and not visited[i, j]: stack = [(i, j)] visited[i, j] = True comp = [] while stack: r, c = stack.pop() comp.append((r, c)) for dr, dc in ((1,0), (-1,0), (0,1), (0,-1)): rr, cc = r + dr, c + dc if 0 <= rr < h and 0 <= cc < w and mask[rr, cc] and not visited[rr, cc]: visited[rr, cc] = True stack.append((rr, cc)) comps.append(comp) return comps def _format_component_table( components, p_tokens, d_tokens, *, 
mode: str = "pair", # "pair" | "residue" ): """ Render HTML table for highlighted interaction components. Parameters ---------- components : List[List[Tuple[int,int]]] Each component is a list of (ligand_index, protein_index) pairs. p_tokens : List[str] Protein token strings. d_tokens : List[str] Ligand token strings. mode : str "pair" -> show Protein range + Ligand range "residue" -> show Protein residue(s) only """ # ---------------------------- # Residue-only mode # ---------------------------- if mode == "residue": if not components: return ( "

Highlighted protein residues

" "

No residues selected.

" ) rows = [] for comp in components: # comp = [(lig_idx, prot_idx), ...] prot_indices = [j for (_, j) in comp] p_start, p_end = min(prot_indices), max(prot_indices) p_s_idx, p_e_idx = p_start + 1, p_end + 1 p_s_tok = p_tokens[p_start] if p_start < len(p_tokens) else "?" p_e_tok = p_tokens[p_end] if p_end < len(p_tokens) else "?" if p_start == p_end: label = f"{p_s_idx}:{p_s_tok}" else: label = f"{p_s_idx}:{p_s_tok} – {p_e_idx}:{p_e_tok}" rows.append( f"" f"" f"{label}" f"" f"" ) return ( "

Highlighted protein residues

" "" "" "" "" f"{''.join(rows)}
Protein residue(s)
" ) # ---------------------------- # Pair mode (default behaviour) # ---------------------------- if not components: return ( "

Highlighted interaction segments

" "

No interaction pairs selected.

" ) rows = [] for comp in components: lig_indices = [i for (i, _) in comp] prot_indices = [j for (_, j) in comp] d_start, d_end = min(lig_indices), max(lig_indices) p_start, p_end = min(prot_indices), max(prot_indices) d_s_idx, d_e_idx = d_start + 1, d_end + 1 p_s_idx, p_e_idx = p_start + 1, p_end + 1 d_s_tok = d_tokens[d_start] if d_start < len(d_tokens) else "?" d_e_tok = d_tokens[d_end] if d_end < len(d_tokens) else "?" p_s_tok = p_tokens[p_start] if p_start < len(p_tokens) else "?" p_e_tok = p_tokens[p_end] if p_end < len(p_tokens) else "?" rows.append( f"" f"Protein: " f"{p_s_idx}:{p_s_tok}" f"{' – '+str(p_e_idx)+':'+p_e_tok+'' if p_end > p_start else ''}" f"" f"Ligand: " f"{d_s_idx}:{d_s_tok}" f"{' – '+str(d_e_idx)+':'+d_e_tok+'' if d_end > d_start else ''}" f"" f"" ) return ( "

Highlighted Binding site

" "" "" "" "" "" f"{''.join(rows)}
Protein rangeLigand range
" ) def visualize_attention_and_ranges( model, feats, head_idx: int, *, mode: str = "pair", # "pair" | "residue" topk_pairs: int = 1, # Top-K interaction pairs (default=1) topk_residues: int = 1, # Top-K residues (1–20, default=1) prot_tokenizer=None, drug_tokenizer=None, ligand_type: str = "selfies", raw_selfies: Optional[str] = None, ) -> Tuple[str, str]: """ Visualise interaction attention with two complementary Top-K modes. Modes ----- mode="pair": - Select Top-K highest-scoring (ligand token, protein residue) pairs - Project selected pairs onto protein axis (evaluation-aligned) - Default K = 1 (user-controlled) mode="residue": - Aggregate attention over ligand dimension - Rank residues by aggregated score - Select Top-K residues (1–100) - Default K = 1 (binding pocket discovery) Notes ----- - Per-head GLOBAL SUM normalisation (matches test()). - Specific heads mapped exactly to GT channels. - Combined head = sum of 6 specific heads (NOT overall=7). """ assert mode in {"pair", "residue"} assert topk_pairs >= 1 assert 1 <= topk_residues <= 100 model.eval() with torch.no_grad(): # -------------------------------------------------- # Unpack features # -------------------------------------------------- p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, _ = feats[0] p_emb, d_emb = p_emb.to(DEVICE), d_emb.to(DEVICE) p_mask, d_mask = p_mask.to(DEVICE), d_mask.to(DEVICE) # -------------------------------------------------- # Forward # -------------------------------------------------- prob, att_pd = model(p_emb, d_emb, p_mask, d_mask) att = att_pd.squeeze(0) prob = prob.item() # expected: [Ld, Lp, 8] or [8, Ld, Lp] # -------------------------------------------------- # Channel mapping (must match test()) # -------------------------------------------------- VISIBLE2UNDERLYING = [1, 2, 3, 4, 6, 0] # HB, Salt, Pi, Cat-Pi, Hydro, VdW N_VISIBLE_SPEC = 6 def select_channel_map(att_): if att_.dim() == 3 and att_.shape[-1] >= 8: if head_idx < N_VISIBLE_SPEC: return att_[:, :, 
VISIBLE2UNDERLYING[head_idx]].cpu() return att_[:, :, VISIBLE2UNDERLYING].sum(dim=2).cpu() if att_.dim() == 3 and att_.shape[0] >= 8: if head_idx < N_VISIBLE_SPEC: return att_[VISIBLE2UNDERLYING[head_idx]].cpu() return att_[VISIBLE2UNDERLYING].sum(dim=0).cpu() return att_.squeeze().cpu() att2d_raw = select_channel_map(att) # [Ld, Lp] # -------------------------------------------------- # Per-head GLOBAL SUM normalisation (critical) # -------------------------------------------------- att2d_raw = att2d_raw / (att2d_raw.sum() + 1e-8) # -------------------------------------------------- # Token decoding & trimming # -------------------------------------------------- def clean_tokens(ids, tokenizer): toks = tokenizer.convert_ids_to_tokens(ids.tolist()) if hasattr(tokenizer, "all_special_tokens"): toks = [t for t in toks if t not in tokenizer.all_special_tokens] return toks p_tokens_full = clean_tokens(p_ids[0], prot_tokenizer) d_tokens_full = clean_tokens(d_ids[0], drug_tokenizer) n_d = min(len(d_tokens_full), att2d_raw.size(0)) n_p = min(len(p_tokens_full), att2d_raw.size(1)) att2d = att2d_raw[:n_d, :n_p] p_tokens = p_tokens_full[:n_p] d_tokens = d_tokens_full[:n_d] p_indices = list(range(1, n_p + 1)) d_indices = list(range(1, n_d + 1)) # -------------------------------------------------- # SELFIES row merging (for interpretability) # -------------------------------------------------- if ligand_type == "selfies" and raw_selfies: groups, labels = _group_rows_by_selfies_string(att2d.size(0), raw_selfies) if groups: merged = [] for s, e in groups: merged.append(att2d[s:e + 1].mean(dim=0, keepdim=True)) att2d = torch.cat(merged, dim=0) d_tokens = labels d_indices = list(range(1, len(labels) + 1)) # -------------------------------------------------- # Top-K selection (two modes, STRICT RANKING) # -------------------------------------------------- if mode == "pair": flat = att2d.reshape(-1) k_eff = min(topk_pairs, flat.numel()) topk_vals, topk_idx = torch.topk(flat, 
k=k_eff) mask_top = torch.zeros_like(flat, dtype=torch.bool) mask_top[topk_idx] = True mask_top = mask_top.view_as(att2d) rows = [] n_cols = att2d.size(1) for rank, (val, linear_idx) in enumerate(zip(topk_vals, topk_idx), start=1): i = (linear_idx // n_cols).item() j = (linear_idx % n_cols).item() rows.append( f"" f"Top {rank}" f"Protein: {j+1}:{p_tokens[j]}" f"Ligand: {i+1}:{d_tokens[i]}" f"Score: {val.item():.6f}" f"" ) ranges_html = ( "

Top-K Interaction Pairs (ranked by attention score)

" "" "" "" "" "" "" "" f"{''.join(rows)}
RankProteinLigandAttention Score
" ) else: # --- STRICT Top-K residue ranking --- residue_score = att2d.sum(dim=0) k_eff = min(topk_residues, residue_score.numel()) topk_vals, topk_res_idx = torch.topk(residue_score, k=k_eff) mask_top = torch.zeros_like(att2d, dtype=torch.bool) mask_top[:, topk_res_idx] = True rows = [] for rank, (val, j) in enumerate(zip(topk_vals, topk_res_idx), start=1): j = j.item() rows.append( f"" f"Top {rank}" f"" f"Protein residue: {j+1}:{p_tokens[j]}" f"" f"" f"Aggregated Score: {val.item():.6f}" f"" f"" ) ranges_html = ( "

Top-K Residues (ranked by aggregated attention)

" "" "" "" "" "" "" f"{''.join(rows)}
RankProtein ResidueAggregated Score
" ) # -------------------------------------------------- # Connected components (visual coherence) # -------------------------------------------------- # p_tokens_orig = p_tokens.copy() # d_tokens_orig = d_tokens.copy() # components = _connected_components_2d(mask_top) # ranges_html = _format_component_table( # components, # p_tokens_orig, # d_tokens_orig, # mode=mode, # ) # -------------------------------------------------- # Crop to union of selected rows / columns # -------------------------------------------------- rows_keep = mask_top.any(dim=1) cols_keep = mask_top.any(dim=0) if not rows_keep.any(): rows_keep[:] = True if not cols_keep.any(): cols_keep[:] = True vis = att2d[rows_keep][:, cols_keep] d_tokens_vis = [t for k, t in zip(rows_keep.tolist(), d_tokens) if k] p_tokens_vis = [t for k, t in zip(cols_keep.tolist(), p_tokens) if k] d_indices_vis = [i for k, i in zip(rows_keep.tolist(), d_indices) if k] p_indices_vis = [i for k, i in zip(cols_keep.tolist(), p_indices) if k] # Cap columns for readability if vis.size(1) > 150: topc = torch.topk(vis.sum(0), k=150).indices vis = vis[:, topc] p_tokens_vis = [p_tokens_vis[i] for i in topc] p_indices_vis = [p_indices_vis[i] for i in topc] # -------------------------------------------------- # Plot # -------------------------------------------------- x_labels = [f"{i}:{t}" for i, t in zip(p_indices_vis, p_tokens_vis)] y_labels = [f"{i}:{t}" for i, t in zip(d_indices_vis, d_tokens_vis)] fig_w = min(22, max(6, len(x_labels) * 0.6)) fig_h = min(24, max(6, len(y_labels) * 0.8)) fig, ax = plt.subplots(figsize=(fig_w, fig_h)) im = ax.imshow(vis.numpy(), aspect="auto", cmap=cm.viridis) title = INTERACTION_NAMES[head_idx] suffix = "Top-K pairs" if mode == "pair" else "Top-K residues" ax.set_title(f"Ligand × Protein — {title} ({suffix})", fontsize=10, pad=8) ax.set_xlabel("Protein residues") ax.set_ylabel("Ligand tokens") ax.set_xticks(range(len(x_labels))) ax.set_xticklabels(x_labels, rotation=90, fontsize=8) 
ax.set_yticks(range(len(y_labels))) ax.set_yticklabels(y_labels, fontsize=7) ax.xaxis.tick_top() ax.xaxis.set_label_position("top") ax.tick_params(axis="x", bottom=False) fig.colorbar(im, fraction=0.026, pad=0.01) fig.tight_layout() # -------------------------------------------------- # Export # -------------------------------------------------- buf_png = io.BytesIO() buf_pdf = io.BytesIO() fig.savefig(buf_png, format="png", dpi=140) fig.savefig(buf_pdf, format="pdf") plt.close(fig) png_b64 = base64.b64encode(buf_png.getvalue()).decode() pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode() heat_html = f"""
Download PDF
""" # ------------------------------ # Probability display card # ------------------------------ if prob >= 0.8: bg = "#ecfdf5" border = "#10b981" label = "High binding confidence" elif prob >= 0.4: bg = "#eff6ff" border = "#3b82f6" label = "Moderate binding confidence" else: bg = "#fef2f2" border = "#ef4444" label = "Low binding confidence" prob_html = f"""
Predicted Binding Probability
{prob:.4f}
{label}
""" return prob_html, ranges_html, heat_html # ───── Gradio callbacks ───────────────────────────────────────── ROOT = os.path.dirname(os.path.abspath(__file__)) FOLDSEEK_BIN = os.path.join(ROOT, "utils", "foldseek") def extract_aa_seq_cb(structure_file, protein_text): """ Extract plain amino acid sequence from uploaded PDB/mmCIF. """ prot_seq_out = (protein_text or "").strip() msgs = [] if structure_file is None: return prot_seq_out, "

Please upload a structure file.

" try: seq = simple_seq_from_structure(structure_file.name) if seq: prot_seq_out = seq msgs.append("
  • ✅ Extracted amino acid sequence from structure.
  • ") else: msgs.append("
  • ❌ No valid amino acid sequence found.
  • ") except Exception as e: msgs.append(f"
  • ❌ Extraction failed: {e}
  • ") status_html = ( "
    " "
    " ) return prot_seq_out, status_html def extract_sa_seq_cb(structure_file, protein_text): prot_seq_out = (protein_text or "").strip() msgs = [] if structure_file is None: return prot_seq_out, "

    Please upload a structure file.

    " try: parsed = get_struc_seq( FOLDSEEK_BIN, structure_file.name, None, plddt_mask=False, ) first_chain = next(iter(parsed)) _, _, struct_seq = parsed[first_chain] if struct_seq: prot_seq_out = struct_seq msgs.append("
  • ✅ Extracted structure-aware sequence (SA).
  • ") else: msgs.append("
  • ❌ Structure parsed but no SA sequence found.
  • ") except Exception as e: msgs.append(f"
  • ❌ SA extraction failed: {e}
  • ") status_html = ( "
    " "
    " ) return prot_seq_out, status_html def _choose_ckpt_by_types(prot_seq: str, ligand_text: str) -> Tuple[str, str, str]: """Return (folder_name, protein_type, ligand_type) for checkpoint routing.""" ptype = detect_protein_type(prot_seq) ltype = detect_ligand_type(ligand_text) folder = f"{ptype}_{ltype}" # sa_selfies / fasta_selfies / sa_smiles / fasta_smiles return folder, ptype, ltype def inference_cb(prot_seq, drug_seq, head_choice, topk_choice): """ Inference callback supporting two Top-K modes: - Top-K interaction pairs - Top-K residues """ # ------------------------------ # Input validation # ------------------------------ if not prot_seq or not prot_seq.strip(): return "

    Please extract or enter a protein sequence first.

    " if not drug_seq or not drug_seq.strip(): return "

    Please enter a ligand sequence (SELFIES or SMILES).

    " prot_seq = prot_seq.strip() drug_seq_in = drug_seq.strip() # ------------------------------ # Detect types & checkpoint routing # ------------------------------ folder, ptype, ltype = _choose_ckpt_by_types(prot_seq, drug_seq_in) # Ligand normalisation: always tokenise as SELFIES if ltype == "smiles": conv = smiles_to_selfies(drug_seq_in) if conv is None: return ( "

    SMILES→SELFIES conversion failed. " "The SMILES appears invalid.

    ", "", ) drug_seq_for_tokenizer = conv else: drug_seq_for_tokenizer = drug_seq_in # 🔒 强制统一类型 ltype = "selfies" ligand_type_flag = "selfies" raw_selfies = drug_seq_for_tokenizer folder = f"{ptype}_selfies" # # Ligand normalisation: always tokenise as SELFIES # if ltype == "smiles": # conv = smiles_to_selfies(drug_seq_in) # if conv is None: # return ( # "

    SMILES→SELFIES conversion failed. " # "The SMILES appears invalid.

    ", # "", # ) # drug_seq_for_tokenizer = conv # ligand_type_flag = "selfies" # else: # drug_seq_for_tokenizer = drug_seq_in # ligand_type_flag = "selfies" # raw_selfies = drug_seq_for_tokenizer if ligand_type_flag == "selfies" else None # ------------------------------ # Load encoders # ------------------------------ prot_tok, prot_m, drug_tok, drug_m, encoding = load_encoders(ptype, ltype, args) loader = DataLoader( [(prot_seq, drug_seq_for_tokenizer, 1)], batch_size=1, collate_fn=make_collate_fn(prot_tok, drug_tok), ) feats = get_case_feature(encoding, loader) # ------------------------------ # Load trained checkpoint (if exists) # ------------------------------ ckpt = os.path.join(args.save_path_prefix, folder, "best_model.ckpt") model = ExplainBind(1280, 768, args=args).to(DEVICE) if os.path.isfile(ckpt): model.load_state_dict(torch.load(ckpt, map_location=DEVICE)) warn_html = ( "
    " f"Loaded model: {folder}/best_model.ckpt
    " ) else: warn_html = ( "
    " "Warning: checkpoint not found " f"{folder}/best_model.ckpt. " "Using randomly initialised weights for visualisation.
    " ) # ------------------------------ # Parse interaction head # ------------------------------ sel = str(head_choice).strip() if sel in INTERACTION_NAMES: head_idx = INTERACTION_NAMES.index(sel) else: try: n = int(sel.split(".", 1)[0]) head_idx = max(0, min(len(INTERACTION_NAMES) - 1, n - 1)) except Exception: head_idx = len(INTERACTION_NAMES) - 1 # Combined Interaction # ------------------------------ # Parse Top-K value # ------------------------------ try: topk = int(str(topk_choice).strip()) except Exception: topk = 1 topk = max(1, topk) mode = "residue" topk_pairs = 1 topk_residues = min(100, topk) # ------------------------------ # Visualisation # ------------------------------ prob_html, table_html, heat_html = visualize_attention_and_ranges( model, feats, head_idx, mode=mode, topk_pairs=topk_pairs, topk_residues=topk_residues, prot_tokenizer=prot_tok, drug_tokenizer=drug_tok, ligand_type=ligand_type_flag, raw_selfies=raw_selfies, ) full_html = prob_html + table_html + heat_html # ✅ 强制上下顺序 return full_html def clear_cb(): return "", "", "", None, "" # protein_seq, drug_seq, output_full, structure_file, status_box # ───── Gradio interface definition ─────────────────────────────── css = """ :root{ --bg:#f8fafc; --card:#f8fafc; --text:#0f172a; --muted:#6b7280; --border:#e5e7eb; --shadow:0 6px 24px rgba(2,6,23,.06); --radius:14px; --icon-size:20px; } *{box-sizing:border-box} html,body{background:#fff!important;color:var(--text)!important} .gradio-container{max-width:1120px;margin:0 auto} /* Title and subtitle */ h1{ font-family:Inter,ui-sans-serif;letter-spacing:.2px;font-weight:700; font-size:32px;margin:22px 0 12px;text-align:center } .subtle{color:var(--muted);font-size:14px;text-align:center;margin:-6px 0 18px} /* Card style */ .card{ background:var(--card); border:1px solid var(--border); border-radius:var(--radius); box-shadow:var(--shadow); padding:22px; } /* Top links */ .link-row{display:flex;justify-content:center;gap:14px;margin:0 auto 
18px;flex-wrap:wrap} /* Two-column grid: left=input, right=controls */ .grid-2{display:grid;grid-template-columns:1.4fr .9fr;gap:16px} .grid-2 .col{display:flex;flex-direction:column;gap:12px} /* Buttons */ .gr-button{border-radius:12px !important;font-weight:700 !important;letter-spacing:.2px} #extract-btn{background:linear-gradient(90deg,#EFAFB2,#EFAFB2); color:#0f172a} #inference-btn{background:linear-gradient(90deg,#B2CBDF,#B2CBDF); color:#0f172a} #clear-btn{background:#FFE2B5; color:#0A0A0A; border:1px solid var(--border)} /* Result spacing */ #result-table{margin-bottom:16px} /* Figure container */ .figure-wrap{border:1px solid var(--border);border-radius:12px;overflow:hidden;box-shadow:var(--shadow)} .figure-wrap img{display:block;width:100%;height:auto} /* Right pane: vertical radio layout and full-width controls (kept for button styling) */ .right-pane .gr-button{ width:100% !important; height:48px !important; border-radius:12px !important; font-weight:700 !important; letter-spacing:.2px; } /* ───────── Publication links (Bulma-like) ───────── */ .publication-links { display: flex; justify-content: center; gap: 14px; flex-wrap: wrap; margin: 6px 0 18px; } .link-block a { display: inline-flex; align-items: center; gap: 8px; padding: 10px 18px; font-size: 14px; font-weight: 600; border-radius: 9999px; text-decoration: none; transition: all 0.15s ease-in-out; } /* colour variants */ .btn-danger { background:#e2e8f0; color:#0f172a; } .btn-dark { background:#e2e8f0; color:#0f172a; } .btn-link { background:#e2e8f0; color:#0f172a; } .btn-warning { background:#e2e8f0; color:#0f172a; } .link-block a:hover { filter: brightness(0.95); transform: translateY(-1px); } .loscalzo-block img { height: 100px; width: auto; object-fit: contain; } .loscalzo-block { display: flex; align-items: center; gap: 10px; margin: 0 auto; justify-content: center; } .link-btn{ display:inline-flex !important; align-items:center !important; gap:8px !important; padding:10px 18px !important; 
font-size:14px !important; font-weight:600 !important; border-radius:9999px !important; background:#e2e8f0 !important; color:#0f172a !important; text-decoration:none !important; border:1px solid #e5e7eb !important; transition:all 0.15s ease-in-out !important; } .link-btn:hover{ filter:brightness(0.95); transform:translateY(-1px); } .project-links{ display:flex !important; justify-content:center !important; gap:28px !important; flex-wrap:wrap !important; margin-bottom:32px !important; } #example-btn { background: #979ea8 !important; color: #1e293b !important; } #extract-aa-btn{ background:#DCE7F3 !important; color:#0f172a !important; } #extract-sa-btn{ background:#EADCF8 !important; color:#0f172a !important; } """ with gr.Blocks() as demo: gr.Markdown("

    ExplainBind: Token-level Protein–Ligand Interaction Visualiser

    ") # gr.HTML(f""" #
    # Loscalzo Research Group logo # # #
    # """) # ─────────────────────────────── # Top links # ─────────────────────────────── gr.HTML(""" """) # ─────────────────────────────── # Guidelines # ─────────────────────────────── with gr.Accordion("Guidelines for Users", open=True, elem_classes=["card"]): gr.HTML("""
    1. Input formats: The system supports either structure-aware (SA) sequences derived from protein structures or conventional FASTA sequences. For structure-based analysis, users may upload .pdb or .cif files to extract the corresponding sequence representation. Ligands can be provided in SMILES or SELFIES format.
    2. Interaction channel selection: Users may select a specific non-covalent interaction type (e.g., hydrogen bonding, hydrophobic interactions) or the overall interaction channel to visualise the corresponding token-level binding patterns.
    3. Model outputs: The system reports (i) a predicted binding probability for the protein–ligand pair, (ii) a ranked Top-K residue table, and (iii) a token-level interaction heat map illustrating spatial interaction patterns.
    """) # ─────────────────────────────── # Inputs + Controls # ─────────────────────────────── with gr.Row(): with gr.Column(elem_classes=["card", "grid-2"]): # ──────────────── # LEFT PANEL # ──────────────── with gr.Column(elem_id="left"): protein_seq = gr.Textbox( label="Protein structure-aware / FASTA sequence", lines=3, placeholder="Paste SA/FASTA sequence or click Extract…", elem_id="protein-seq", render=False, ) drug_seq = gr.Textbox( label="Ligand (SELFIES / SMILES)", lines=3, placeholder="Paste SELFIES or SMILES", elem_id="drug-seq", render=False, ) structure_file = gr.File( label="Upload protein structure (.pdb / .cif)", file_types=[".pdb", ".cif"], elem_id="structure-file", render=False, ) with gr.Group(): gr.Markdown("### Example") gr.Examples( examples=[[ "SLALSLTADQMVSALLDAEPPILYSEYDPTRPFSEASMMGLLTNLADRELVHMINWAKRVPGFVDLTSHDQVHLLECAWLEILMIGLVWRSMEHPGKLLFAPNLLLDRNQGKCVEGMVEIFDMLLATSSRFRMMNLQGEEFVCLKSIILLNSGVYTFLSSTLKSLEEKDHIHRVLDKITDTLIHLMAKAGLTLQQQHQRLAQLLLILSHIRHMSNKGMEHLYSMKCKNVVPSYDLLLEMLDA", "[C][=C][C][=Branch2][Branch1][#C][=C][C][=C][Ring1][=Branch1][C][=C][Branch2][Ring2][#Branch2][C@H1][C@@H1][Branch1][Branch2][C][C@@H1][Ring1][=Branch1][O][Ring1][Branch1][S][=Branch1][C][=O][=Branch1][C][=O][N][Branch1][#Branch2][C][C][Branch1][C][F][Branch1][C][F][F][C][=C][C][=C][Branch1][Branch1][C][=C][Ring1][=Branch1][Cl][C][=C][C][=C][Branch1][Branch1][C][=C][Ring1][=Branch1][O][O]" ]], inputs=[protein_seq, drug_seq], label="Click to load an example", ) btn_load_example = gr.Button( "Load Example", elem_id="example-btn", # variant="secondary" ) structure_file.render() with gr.Row(): btn_extract_aa = gr.Button( "Extract amino acid sequence", elem_id="extract-aa-btn" ) btn_extract_sa = gr.Button( "Extract structure-aware sequence", elem_id="extract-sa-btn" ) protein_seq.render() drug_seq.render() # ──────────────── # RIGHT PANEL # ──────────────── with gr.Column(elem_id="right", elem_classes=["right-pane"]): head_dd = gr.Dropdown( label="Non-covalent 
interaction type/Overall", choices=INTERACTION_NAMES, value="Overall Interaction", interactive=True, ) top_k_dd = gr.Dropdown( label="Top-K residue", choices=[str(i) for i in range(1, 21)], value="1", interactive=True, ) with gr.Row(): btn_infer = gr.Button( "Inference", elem_id="inference-btn" ) clear_btn = gr.Button( "Clear", elem_id="clear-btn" ) # ─────────────────────────────── # Outputs # ─────────────────────────────── with gr.Column(elem_classes=["card"]): status_box = gr.HTML(elem_id="status-box") output_full = gr.HTML(elem_id="result-full") # ─────────────────────────────── # Example Loader Callback # ─────────────────────────────── def load_example_cb(): return ( "SLALSLTADQMVSALLDAEPPILYSEYDPTRPFSEASMMGLLTNLADRELVHMINWAKRVPGFVDLTSHDQVHLLECAWLEILMIGLVWRSMEHPGKLLFAPNLLLDRNQGKCVEGMVEIFDMLLATSSRFRMMNLQGEEFVCLKSIILLNSGVYTFLSSTLKSLEEKDHIHRVLDKITDTLIHLMAKAGLTLQQQHQRLAQLLLILSHIRHMSNKGMEHLYSMKCKNVVPSYDLLLEMLDA", "[C][=C][C][=Branch2][Branch1][#C][=C][C][=C][Ring1][=Branch1][C][=C][Branch2][Ring2][#Branch2][C@H1][C@@H1][Branch1][Branch2][C][C@@H1][Ring1][=Branch1][O][Ring1][Branch1][S][=Branch1][C][=O][=Branch1][C][=O][N][Branch1][#Branch2][C][C][Branch1][C][F][Branch1][C][F][F][C][=C][C][=C][Branch1][Branch1][C][=C][Ring1][=Branch1][Cl][C][=C][C][=C][Branch1][Branch1][C][=C][Ring1][=Branch1][O][O]" ) # ─────────────────────────────── # Wiring # ─────────────────────────────── btn_load_example.click( fn=load_example_cb, inputs=[], outputs=[protein_seq, drug_seq], ) btn_extract_aa.click( fn=extract_aa_seq_cb, inputs=[structure_file, protein_seq], outputs=[protein_seq, status_box], ) btn_extract_sa.click( fn=extract_sa_seq_cb, inputs=[structure_file, protein_seq], outputs=[protein_seq, status_box], ) btn_infer.click( fn=inference_cb, inputs=[protein_seq, drug_seq, head_dd, top_k_dd], outputs=[output_full], ) clear_btn.click( fn=clear_cb, inputs=[], outputs=[ protein_seq, drug_seq, output_full, structure_file, status_box, ], ) demo.launch( theme=gr.themes.Default(), 
css=css, show_error=True )