# --- Hugging Face Spaces page header (captured artifact, not code) ---
# Spaces: Running
# Workaround for a gradio_client bug: JSON Schema permits a schema to be the
# boolean literal true/false, but some gradio_client versions crash when
# get_type() / _json_schema_to_python_type() receive a bool instead of a dict.
# Both functions are wrapped so a bool schema is coerced to the empty schema {}.
import gradio_client.utils as _gc_utils

# Keep references to the originals so the patched wrappers can delegate.
_orig_get_type = _gc_utils.get_type
_orig_json2py = _gc_utils._json_schema_to_python_type

def _patched_get_type(schema):
    # Coerce boolean schemas (valid JSON Schema) to {} before delegating.
    if isinstance(schema, bool):
        schema = {}
    return _orig_get_type(schema)

def _patched_json_schema_to_python_type(schema, defs=None):
    # Same coercion for the private schema-to-python-type helper.
    if isinstance(schema, bool):
        schema = {}
    return _orig_json2py(schema, defs)

# Install the patched versions in place of the originals.
_gc_utils.get_type = _patched_get_type
_gc_utils._json_schema_to_python_type = _patched_json_schema_to_python_type
| # βββ Imports βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| import os | |
| import io | |
| import base64 | |
| import argparse | |
| from typing import Optional, List, Tuple | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import DataLoader | |
| import selfies | |
| # from rdkit import Chem | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from matplotlib import cm | |
| from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel, AutoTokenizer | |
| from Bio.PDB import PDBParser, MMCIFParser | |
| from Bio.Data import IUPACData | |
| import gradio as gr | |
| # Project utils (ensure these exist in your repository) | |
| from utils.metric_learning_models_att_maps import Pre_encoded, ExplainBind | |
| from utils.foldseek_util import get_struc_seq | |
# --------------------- Paths & Logos ---------------------
ROOT = os.path.dirname(os.path.abspath(__file__))  # directory containing this script
ASSET_DIR = os.path.join(ROOT, "utils")  # static assets (logos, foldseek binary) live here
# NOTE(review): constant name spells "LOSCAZLO" while the file is "loscalzo.png";
# the constant name is likely a typo but is kept because it is referenced below.
LOSCAZLO_LOGO = os.path.join(ASSET_DIR, "loscalzo.png")
| def _load_logo_b64(path): | |
| if not os.path.exists(path): | |
| return "" | |
| with open(path, "rb") as f: | |
| return base64.b64encode(f.read()).decode("utf-8") | |
LOSCAZLO_B64 = _load_logo_b64(LOSCAZLO_LOGO)  # "" when the logo file is missing

# --------------------- Configurable constants ---------------------
# UI-visible names (Halogen bonding removed)
INTERACTION_NAMES = [
    "Hydrogen bonding",
    "Salt Bridging",
    "ΟβΟ Stacking",
    "CationβΟ",
    "Hydrophobic",
    "Van der Waals",
    "Overall Interaction",
]
# Map visible indices (0..5 = specific, 6 = combined) to underlying channel indices
# Underlying channels originally had Halogen at index=5 (0-based). We skip 5 entirely.
VISIBLE2UNDERLYING = [1, 2, 3, 4, 6, 0]  # HB, Salt, Pi, Cation-Pi, Hydro, VdW
N_VISIBLE_SPEC = len(VISIBLE2UNDERLYING)  # 6 specific heads; index 6 is the combined head
# ----- Helper utilities -----------------------------------------
# 3-letter -> 1-letter residue map from Biopython, keys upper-cased for lookup.
three2one = {k.upper(): v for k, v in IUPACData.protein_letters_3to1.items()}
# Common non-standard residues mapped to their closest standard letters.
three2one.update({"MSE": "M", "SEC": "C", "PYL": "K"})
STANDARD_AA_SET = set("ACDEFGHIKLMNPQRSTVWY")  # Uppercase FASTA amino acids
def simple_seq_from_structure(path: str) -> str:
    """Extract the longest chain and return standard 1-letter amino acid sequence."""
    # mmCIF files get the CIF parser; everything else is treated as PDB.
    if path.endswith(".cif"):
        parser = MMCIFParser(QUIET=True)
    else:
        parser = PDBParser(QUIET=True)
    structure = parser.get_structure("P", path)
    chains = list(structure.get_chains())
    if not chains:
        return ""
    longest = max(chains, key=lambda c: len(list(c.get_residues())))
    # Keep only residues we can map to a 1-letter code; others are skipped.
    letters = [
        three2one[res.get_resname().upper()]
        for res in longest
        if res.get_resname().upper() in three2one
    ]
    return "".join(letters)
| # def smiles_to_selfies(smiles_text: str) -> Optional[str]: | |
| # """Validate and convert SMILES to SELFIES; return None if invalid.""" | |
| # try: | |
| # mol = Chem.MolFromSmiles(smiles_text) | |
| # if mol is None: | |
| # return None | |
| # return selfies.encoder(smiles_text) | |
| # except Exception: | |
| # return None | |
def smiles_to_selfies(smiles_text: str) -> Optional[str]:
    """Encode SMILES as SELFIES; return None when encoding or the round-trip fails."""
    try:
        encoded = selfies.encoder(smiles_text)
        round_trip = selfies.decoder(encoded)
    except Exception:
        return None
    # An empty decode means the SELFIES does not represent a usable molecule.
    return encoded if round_trip else None
def detect_protein_type(seq: str) -> str:
    """
    Classify protein input.

    Returns 'fasta' when the trimmed input is empty, or consists solely of the
    standard 20 uppercase amino-acid letters; returns 'sa' (structure-aware)
    otherwise (lowercase or non-standard characters present).
    """
    stripped = (seq or "").strip()
    if not stripped:
        return "fasta"
    if stripped != stripped.upper():
        return "sa"
    alphabet = set("ACDEFGHIKLMNPQRSTVWY")  # standard 20 amino acids, uppercase
    return "fasta" if all(ch in alphabet for ch in stripped) else "sa"
def detect_ligand_type(text: str) -> str:
    """
    Guess the ligand notation.

    A trimmed input opening with '[' and containing ']' is treated as SELFIES;
    anything else (including empty input) is treated as SMILES.
    """
    candidate = (text or "").strip()
    if candidate.startswith("[") and "]" in candidate:
        return "selfies"
    return "smiles"
def parse_config():
    """Parse command-line options.

    Uses parse_known_args() instead of parse_args() so unrecognized CLI
    arguments (e.g. those injected by notebook kernels or hosting platforms
    such as HF Spaces) are ignored instead of aborting the app with
    SystemExit. Recognized options and their defaults are unchanged.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--agg_mode", type=str, default="mean_all_tok")
    p.add_argument("--group_size", type=int, default=1)
    p.add_argument("--fusion", default="CAN")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--save_path_prefix", default="save_model_ckp/")  # Root folder containing checkpoints
    p.add_argument("--dataset", default="Human")
    args, _unknown = p.parse_known_args()  # tolerate foreign argv entries
    return args
# Parse CLI options once at import time; DEVICE drives all .to(...) placement below.
args = parse_config()
DEVICE = args.device

# ----- Dynamic model registry -----------------------------------
# Hugging Face model ids keyed by input representation.
PROT_MODELS = {
    "sa": "westlake-repl/SaProt_650M_AF2",  # structure-aware (SaProt) encoder
    "fasta": "facebook/esm2_t33_650M_UR50D",  # plain-sequence ESM-2 encoder
}
DRUG_MODELS = {
    "selfies": "HUBioDataLab/SELFormer",
    # "smiles": "ibm/MoLFormer-XL-both-10pct",
}
def load_encoders(ptype: str, ltype: str, args):
    """
    Dynamically load encoders and tokenisers based on input types.

    Returns: (prot_tokenizer, prot_model, drug_tokenizer, drug_model, encoding_module)
    """
    # Protein encoder: ESM-2 for plain FASTA, SaProt for structure-aware input.
    if ptype == "fasta":
        prot_path = PROT_MODELS["fasta"]
        prot_tokenizer = EsmTokenizer.from_pretrained(prot_path, do_lower_case=False)
    else:  # 'sa'
        prot_path = PROT_MODELS["sa"]
        prot_tokenizer = EsmTokenizer.from_pretrained(prot_path)
    prot_model = EsmForMaskedLM.from_pretrained(prot_path)

    # Ligand encoder: currently always SELFormer (SELFIES); `ltype` is not
    # consulted here because all ligands are normalised to SELFIES upstream.
    drug_path = DRUG_MODELS["selfies"]
    drug_tokenizer = AutoTokenizer.from_pretrained(drug_path)
    drug_model = AutoModel.from_pretrained(drug_path)

    # Wrap both encoders with the project's Pre_encoded module.
    encoding = Pre_encoded(prot_model, drug_model, args).to(DEVICE)
    return prot_tokenizer, prot_model, drug_tokenizer, drug_model, encoding
def make_collate_fn(prot_tokenizer, drug_tokenizer):
    """Create a batch collation function using the given tokenisers."""

    def _tokenize(tokenizer, texts):
        # Shared tokeniser settings: fixed 512-length padded pt tensors.
        return tokenizer(
            list(texts), max_length=512, padding="max_length", truncation=True,
            add_special_tokens=True, return_tensors="pt",
        )

    def _collate_fn(batch):
        prot_seqs, drug_seqs, labels = zip(*batch)
        enc_prot = _tokenize(prot_tokenizer, prot_seqs)
        enc_drug = _tokenize(drug_tokenizer, drug_seqs)
        # Masks are returned as bool tensors; labels as a single tensor.
        return (
            enc_prot["input_ids"],
            enc_prot["attention_mask"].bool(),
            enc_drug["input_ids"],
            enc_drug["attention_mask"].bool(),
            torch.tensor(list(labels)),
        )

    return _collate_fn
def get_case_feature(model, loader):
    """Generate features for one proteinβligand pair using the provided model."""
    model.eval()
    with torch.no_grad():
        # The loader holds exactly one pair, so we return after the first batch.
        for prot_ids, prot_mask, drug_ids, drug_mask, _ in loader:
            prot_ids = prot_ids.to(DEVICE)
            prot_mask = prot_mask.to(DEVICE)
            drug_ids = drug_ids.to(DEVICE)
            drug_mask = drug_mask.to(DEVICE)
            prot_emb, drug_emb = model.encoding(prot_ids, prot_mask, drug_ids, drug_mask)
            # Everything is moved back to CPU; final None slot mirrors the
            # (..., label) tuple layout expected downstream.
            return [(
                prot_emb.cpu(), drug_emb.cpu(),
                prot_ids.cpu(), drug_ids.cpu(),
                prot_mask.cpu(), drug_mask.cpu(), None,
            )]
| # βββββββββββββββ SELFIES grouping by ORIGINAL string βββββββββββββ | |
| def _group_rows_by_selfies_string(n_rows: int, selfies_str: str): | |
| """ | |
| Partition the attention matrix's n_rows along ligand axis into groups per SELFIES token '[ ... ]'. | |
| Each group is a contiguous row span; we assign rows β equally using linspace. | |
| Returns: | |
| groups: List[(start_row, end_row)] inclusive | |
| labels: List['[X]','[=O]', ...] | |
| """ | |
| if n_rows <= 0: | |
| return [], [] | |
| try: | |
| toks = list(selfies.split_selfies((selfies_str or "").strip())) | |
| except Exception: | |
| toks = [] | |
| if not toks: | |
| # Fallback: treat whole ligand as one token | |
| return [(0, n_rows - 1)], [selfies_str or "[?]"] | |
| g = len(toks) | |
| edges = np.linspace(0, n_rows, g + 1, dtype=int) | |
| groups = [] | |
| for i in range(g): | |
| s, e = edges[i], edges[i + 1] - 1 | |
| if e < s: | |
| e = s | |
| groups.append((s, e)) | |
| return groups, toks | |
| def _connected_components_2d(mask: torch.Tensor) -> List[List[Tuple[int, int]]]: | |
| """4-connected components over a 2D boolean mask (rows=ligand tokens, cols=protein residues).""" | |
| h, w = mask.shape | |
| visited = torch.zeros_like(mask, dtype=torch.bool) | |
| comps: List[List[Tuple[int,int]]] = [] | |
| for i in range(h): | |
| for j in range(w): | |
| if mask[i, j] and not visited[i, j]: | |
| stack = [(i, j)] | |
| visited[i, j] = True | |
| comp = [] | |
| while stack: | |
| r, c = stack.pop() | |
| comp.append((r, c)) | |
| for dr, dc in ((1,0), (-1,0), (0,1), (0,-1)): | |
| rr, cc = r + dr, c + dc | |
| if 0 <= rr < h and 0 <= cc < w and mask[rr, cc] and not visited[rr, cc]: | |
| visited[rr, cc] = True | |
| stack.append((rr, cc)) | |
| comps.append(comp) | |
| return comps | |
| def _format_component_table( | |
| components, | |
| p_tokens, | |
| d_tokens, | |
| *, | |
| mode: str = "pair", # "pair" | "residue" | |
| ): | |
| """ | |
| Render HTML table for highlighted interaction components. | |
| Parameters | |
| ---------- | |
| components : List[List[Tuple[int,int]]] | |
| Each component is a list of (ligand_index, protein_index) pairs. | |
| p_tokens : List[str] | |
| Protein token strings. | |
| d_tokens : List[str] | |
| Ligand token strings. | |
| mode : str | |
| "pair" -> show Protein range + Ligand range | |
| "residue" -> show Protein residue(s) only | |
| """ | |
| # ---------------------------- | |
| # Residue-only mode | |
| # ---------------------------- | |
| if mode == "residue": | |
| if not components: | |
| return ( | |
| "<h4 style='margin:12px 0 6px'>Highlighted protein residues</h4>" | |
| "<p>No residues selected.</p>" | |
| ) | |
| rows = [] | |
| for comp in components: | |
| # comp = [(lig_idx, prot_idx), ...] | |
| prot_indices = [j for (_, j) in comp] | |
| p_start, p_end = min(prot_indices), max(prot_indices) | |
| p_s_idx, p_e_idx = p_start + 1, p_end + 1 | |
| p_s_tok = p_tokens[p_start] if p_start < len(p_tokens) else "?" | |
| p_e_tok = p_tokens[p_end] if p_end < len(p_tokens) else "?" | |
| if p_start == p_end: | |
| label = f"{p_s_idx}:{p_s_tok}" | |
| else: | |
| label = f"{p_s_idx}:{p_s_tok} β {p_e_idx}:{p_e_tok}" | |
| rows.append( | |
| f"<tr>" | |
| f"<td style='border:1px solid #ddd;padding:6px'>" | |
| f"<strong>{label}</strong>" | |
| f"</td>" | |
| f"</tr>" | |
| ) | |
| return ( | |
| "<h4 style='margin:12px 0 6px'>Highlighted protein residues</h4>" | |
| "<table style='border-collapse:collapse;margin:6px 0 16px;width:60%'>" | |
| "<thead><tr style='background:#f5f5f5'>" | |
| "<th style='border:1px solid #ddd;padding:6px'>Protein residue(s)</th>" | |
| "</tr></thead>" | |
| f"<tbody>{''.join(rows)}</tbody></table>" | |
| ) | |
| # ---------------------------- | |
| # Pair mode (default behaviour) | |
| # ---------------------------- | |
| if not components: | |
| return ( | |
| "<h4 style='margin:12px 0 6px'>Highlighted interaction segments</h4>" | |
| "<p>No interaction pairs selected.</p>" | |
| ) | |
| rows = [] | |
| for comp in components: | |
| lig_indices = [i for (i, _) in comp] | |
| prot_indices = [j for (_, j) in comp] | |
| d_start, d_end = min(lig_indices), max(lig_indices) | |
| p_start, p_end = min(prot_indices), max(prot_indices) | |
| d_s_idx, d_e_idx = d_start + 1, d_end + 1 | |
| p_s_idx, p_e_idx = p_start + 1, p_end + 1 | |
| d_s_tok = d_tokens[d_start] if d_start < len(d_tokens) else "?" | |
| d_e_tok = d_tokens[d_end] if d_end < len(d_tokens) else "?" | |
| p_s_tok = p_tokens[p_start] if p_start < len(p_tokens) else "?" | |
| p_e_tok = p_tokens[p_end] if p_end < len(p_tokens) else "?" | |
| rows.append( | |
| f"<tr>" | |
| f"<td style='border:1px solid #ddd;padding:6px'>Protein: " | |
| f"<strong>{p_s_idx}:{p_s_tok}</strong>" | |
| f"{' β <strong>'+str(p_e_idx)+':'+p_e_tok+'</strong>' if p_end > p_start else ''}" | |
| f"</td>" | |
| f"<td style='border:1px solid #ddd;padding:6px'>Ligand: " | |
| f"<strong>{d_s_idx}:{d_s_tok}</strong>" | |
| f"{' β <strong>'+str(d_e_idx)+':'+d_e_tok+'</strong>' if d_end > d_start else ''}" | |
| f"</td>" | |
| f"</tr>" | |
| ) | |
| return ( | |
| "<h4 style='margin:12px 0 6px'>Highlighted Binding site</h4>" | |
| "<table style='border-collapse:collapse;margin:6px 0 16px;width:100%'>" | |
| "<thead><tr style='background:#f5f5f5'>" | |
| "<th style='border:1px solid #ddd;padding:6px'>Protein range</th>" | |
| "<th style='border:1px solid #ddd;padding:6px'>Ligand range</th>" | |
| "</tr></thead>" | |
| f"<tbody>{''.join(rows)}</tbody></table>" | |
| ) | |
def visualize_attention_and_ranges(
    model,
    feats,
    head_idx: int,
    *,
    mode: str = "pair",  # "pair" | "residue"
    topk_pairs: int = 1,  # Top-K interaction pairs (default=1)
    topk_residues: int = 1,  # Top-K residues (1-100, default=1)
    prot_tokenizer=None,
    drug_tokenizer=None,
    ligand_type: str = "selfies",
    raw_selfies: Optional[str] = None,
) -> Tuple[str, str, str]:  # annotation corrected: three HTML strings are returned
    """
    Visualise interaction attention with two complementary Top-K modes.

    Modes
    -----
    mode="pair":
        - Select Top-K highest-scoring (ligand token, protein residue) pairs
        - Project selected pairs onto protein axis (evaluation-aligned)
        - Default K = 1 (user-controlled)
    mode="residue":
        - Aggregate attention over ligand dimension
        - Rank residues by aggregated score
        - Select Top-K residues (1-100)
        - Default K = 1 (binding pocket discovery)

    Returns
    -------
    (prob_html, ranges_html, heat_html)
        Probability card HTML, ranked-table HTML, heatmap HTML with an
        embedded PNG and a PDF download link.

    Notes
    -----
    - Per-head GLOBAL SUM normalisation (matches test()).
    - Specific heads mapped exactly to GT channels.
    - Combined head = sum of 6 specific heads (NOT overall=7).
    """
    assert mode in {"pair", "residue"}
    assert topk_pairs >= 1
    assert 1 <= topk_residues <= 100
    model.eval()
    with torch.no_grad():
        # --------------------------------------------------
        # Unpack the single cached (protein, ligand) feature tuple
        # --------------------------------------------------
        p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, _ = feats[0]
        p_emb, d_emb = p_emb.to(DEVICE), d_emb.to(DEVICE)
        p_mask, d_mask = p_mask.to(DEVICE), d_mask.to(DEVICE)
        # --------------------------------------------------
        # Forward pass: binding probability + attention maps
        # --------------------------------------------------
        prob, att_pd = model(p_emb, d_emb, p_mask, d_mask)
        att = att_pd.squeeze(0)
        prob = prob.item()
        # expected att layout: [Ld, Lp, 8] or [8, Ld, Lp]
        # --------------------------------------------------
        # Channel mapping (must match test())
        # --------------------------------------------------
        # NOTE: this local list deliberately shadows (and mirrors) the
        # module-level VISIBLE2UNDERLYING constant.
        VISIBLE2UNDERLYING = [1, 2, 3, 4, 6, 0]  # HB, Salt, Pi, Cat-Pi, Hydro, VdW
        N_VISIBLE_SPEC = 6

        def select_channel_map(att_):
            # Pick a single [Ld, Lp] map for the requested head, handling both
            # channel-last and channel-first layouts; the combined head (index
            # >= N_VISIBLE_SPEC) is the sum over the six specific channels.
            if att_.dim() == 3 and att_.shape[-1] >= 8:
                if head_idx < N_VISIBLE_SPEC:
                    return att_[:, :, VISIBLE2UNDERLYING[head_idx]].cpu()
                return att_[:, :, VISIBLE2UNDERLYING].sum(dim=2).cpu()
            if att_.dim() == 3 and att_.shape[0] >= 8:
                if head_idx < N_VISIBLE_SPEC:
                    return att_[VISIBLE2UNDERLYING[head_idx]].cpu()
                return att_[VISIBLE2UNDERLYING].sum(dim=0).cpu()
            return att_.squeeze().cpu()

        att2d_raw = select_channel_map(att)  # [Ld, Lp]
        # --------------------------------------------------
        # Per-head GLOBAL SUM normalisation (critical)
        # --------------------------------------------------
        att2d_raw = att2d_raw / (att2d_raw.sum() + 1e-8)

        # --------------------------------------------------
        # Token decoding & trimming
        # --------------------------------------------------
        def clean_tokens(ids, tokenizer):
            # Drop special tokens so axes align with real sequence content.
            toks = tokenizer.convert_ids_to_tokens(ids.tolist())
            if hasattr(tokenizer, "all_special_tokens"):
                toks = [t for t in toks if t not in tokenizer.all_special_tokens]
            return toks

        p_tokens_full = clean_tokens(p_ids[0], prot_tokenizer)
        d_tokens_full = clean_tokens(d_ids[0], drug_tokenizer)
        # Trim both axes to the smaller of token count vs. attention extent.
        n_d = min(len(d_tokens_full), att2d_raw.size(0))
        n_p = min(len(p_tokens_full), att2d_raw.size(1))
        att2d = att2d_raw[:n_d, :n_p]
        p_tokens = p_tokens_full[:n_p]
        d_tokens = d_tokens_full[:n_d]
        p_indices = list(range(1, n_p + 1))  # 1-based display indices
        d_indices = list(range(1, n_d + 1))
        # --------------------------------------------------
        # SELFIES row merging (for interpretability): average sub-token rows
        # per '[...]' token so the ligand axis is human-readable.
        # --------------------------------------------------
        if ligand_type == "selfies" and raw_selfies:
            groups, labels = _group_rows_by_selfies_string(att2d.size(0), raw_selfies)
            if groups:
                merged = []
                for s, e in groups:
                    merged.append(att2d[s:e + 1].mean(dim=0, keepdim=True))
                att2d = torch.cat(merged, dim=0)
                d_tokens = labels
                d_indices = list(range(1, len(labels) + 1))
        # --------------------------------------------------
        # Top-K selection (two modes, STRICT RANKING)
        # --------------------------------------------------
        if mode == "pair":
            flat = att2d.reshape(-1)
            k_eff = min(topk_pairs, flat.numel())
            topk_vals, topk_idx = torch.topk(flat, k=k_eff)
            mask_top = torch.zeros_like(flat, dtype=torch.bool)
            mask_top[topk_idx] = True
            mask_top = mask_top.view_as(att2d)
            rows = []
            n_cols = att2d.size(1)
            for rank, (val, linear_idx) in enumerate(zip(topk_vals, topk_idx), start=1):
                # Recover 2-D (ligand, protein) coordinates from the flat index.
                i = (linear_idx // n_cols).item()
                j = (linear_idx % n_cols).item()
                rows.append(
                    f"<tr>"
                    f"<td style='border:1px solid #ddd;padding:6px'><strong>Top {rank}</strong></td>"
                    f"<td style='border:1px solid #ddd;padding:6px'>Protein: <strong>{j+1}:{p_tokens[j]}</strong></td>"
                    f"<td style='border:1px solid #ddd;padding:6px'>Ligand: <strong>{i+1}:{d_tokens[i]}</strong></td>"
                    f"<td style='border:1px solid #ddd;padding:6px'>Score: <strong>{val.item():.6f}</strong></td>"
                    f"</tr>"
                )
            ranges_html = (
                "<h4 style='margin:12px 0 6px'>Top-K Interaction Pairs (ranked by attention score)</h4>"
                "<table style='border-collapse:collapse;margin:6px 0 16px;width:100%'>"
                "<thead><tr style='background:#f5f5f5'>"
                "<th style='border:1px solid #ddd;padding:6px'>Rank</th>"
                "<th style='border:1px solid #ddd;padding:6px'>Protein</th>"
                "<th style='border:1px solid #ddd;padding:6px'>Ligand</th>"
                "<th style='border:1px solid #ddd;padding:6px'>Attention Score</th>"
                "</tr></thead>"
                f"<tbody>{''.join(rows)}</tbody></table>"
            )
        else:
            # --- STRICT Top-K residue ranking: rank columns (residues) by the
            # sum of attention received over all ligand rows ---
            residue_score = att2d.sum(dim=0)
            k_eff = min(topk_residues, residue_score.numel())
            topk_vals, topk_res_idx = torch.topk(residue_score, k=k_eff)
            mask_top = torch.zeros_like(att2d, dtype=torch.bool)
            mask_top[:, topk_res_idx] = True
            rows = []
            for rank, (val, j) in enumerate(zip(topk_vals, topk_res_idx), start=1):
                j = j.item()
                rows.append(
                    f"<tr>"
                    f"<td style='border:1px solid #ddd;padding:6px'><strong>Top {rank}</strong></td>"
                    f"<td style='border:1px solid #ddd;padding:6px'>"
                    f"Protein residue: <strong>{j+1}:{p_tokens[j]}</strong>"
                    f"</td>"
                    f"<td style='border:1px solid #ddd;padding:6px'>"
                    f"Aggregated Score: <strong>{val.item():.6f}</strong>"
                    f"</td>"
                    f"</tr>"
                )
            ranges_html = (
                "<h4 style='margin:12px 0 6px'>Top-K Residues (ranked by aggregated attention)</h4>"
                "<table style='border-collapse:collapse;margin:6px 0 16px;width:100%'>"
                "<thead><tr style='background:#f5f5f5'>"
                "<th style='border:1px solid #ddd;padding:6px'>Rank</th>"
                "<th style='border:1px solid #ddd;padding:6px'>Protein Residue</th>"
                "<th style='border:1px solid #ddd;padding:6px'>Aggregated Score</th>"
                "</tr></thead>"
                f"<tbody>{''.join(rows)}</tbody></table>"
            )
        # --------------------------------------------------
        # Connected components (visual coherence) β disabled alternative that
        # rendered components instead of ranked Top-K tables.
        # --------------------------------------------------
        # p_tokens_orig = p_tokens.copy()
        # d_tokens_orig = d_tokens.copy()
        # components = _connected_components_2d(mask_top)
        # ranges_html = _format_component_table(
        #     components,
        #     p_tokens_orig,
        #     d_tokens_orig,
        #     mode=mode,
        # )
        # --------------------------------------------------
        # Crop heatmap to the union of selected rows / columns
        # --------------------------------------------------
        rows_keep = mask_top.any(dim=1)
        cols_keep = mask_top.any(dim=0)
        # If nothing was selected, fall back to showing everything.
        if not rows_keep.any():
            rows_keep[:] = True
        if not cols_keep.any():
            cols_keep[:] = True
        vis = att2d[rows_keep][:, cols_keep]
        d_tokens_vis = [t for k, t in zip(rows_keep.tolist(), d_tokens) if k]
        p_tokens_vis = [t for k, t in zip(cols_keep.tolist(), p_tokens) if k]
        d_indices_vis = [i for k, i in zip(rows_keep.tolist(), d_indices) if k]
        p_indices_vis = [i for k, i in zip(cols_keep.tolist(), p_indices) if k]
        # Cap columns for readability: keep the 150 highest-mass residues.
        if vis.size(1) > 150:
            topc = torch.topk(vis.sum(0), k=150).indices
            vis = vis[:, topc]
            p_tokens_vis = [p_tokens_vis[i] for i in topc]
            p_indices_vis = [p_indices_vis[i] for i in topc]
        # --------------------------------------------------
        # Plot: heatmap with 1-based "index:token" axis labels
        # --------------------------------------------------
        x_labels = [f"{i}:{t}" for i, t in zip(p_indices_vis, p_tokens_vis)]
        y_labels = [f"{i}:{t}" for i, t in zip(d_indices_vis, d_tokens_vis)]
        # Figure size scales with label counts, clamped to sane bounds.
        fig_w = min(22, max(6, len(x_labels) * 0.6))
        fig_h = min(24, max(6, len(y_labels) * 0.8))
        fig, ax = plt.subplots(figsize=(fig_w, fig_h))
        im = ax.imshow(vis.numpy(), aspect="auto", cmap=cm.viridis)
        title = INTERACTION_NAMES[head_idx]
        suffix = "Top-K pairs" if mode == "pair" else "Top-K residues"
        ax.set_title(f"Ligand Γ Protein β {title} ({suffix})", fontsize=10, pad=8)
        ax.set_xlabel("Protein residues")
        ax.set_ylabel("Ligand tokens")
        ax.set_xticks(range(len(x_labels)))
        ax.set_xticklabels(x_labels, rotation=90, fontsize=8)
        ax.set_yticks(range(len(y_labels)))
        ax.set_yticklabels(y_labels, fontsize=7)
        # Protein axis reads best on top for wide matrices.
        ax.xaxis.tick_top()
        ax.xaxis.set_label_position("top")
        ax.tick_params(axis="x", bottom=False)
        fig.colorbar(im, fraction=0.026, pad=0.01)
        fig.tight_layout()
        # --------------------------------------------------
        # Export: inline PNG for display + PDF download link
        # --------------------------------------------------
        buf_png = io.BytesIO()
        buf_pdf = io.BytesIO()
        fig.savefig(buf_png, format="png", dpi=140)
        fig.savefig(buf_pdf, format="pdf")
        plt.close(fig)
        png_b64 = base64.b64encode(buf_png.getvalue()).decode()
        pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()
        heat_html = f"""
        <div style='position:relative'>
        <a href='data:application/pdf;base64,{pdf_b64}' download='attention_{head_idx+1}.pdf'
        style='position:absolute;top:10px;right:10px;
        background:#111;color:#fff;padding:8px 12px;
        border-radius:10px;font-size:.85rem;text-decoration:none'>
        Download PDF
        </a>
        <img src='data:image/png;base64,{png_b64}' />
        </div>
        """
        # ------------------------------
        # Probability display card: colour/label by confidence band
        # ------------------------------
        if prob >= 0.8:
            bg = "#ecfdf5"
            border = "#10b981"
            label = "High binding confidence"
        elif prob >= 0.4:
            bg = "#eff6ff"
            border = "#3b82f6"
            label = "Moderate binding confidence"
        else:
            bg = "#fef2f2"
            border = "#ef4444"
            label = "Low binding confidence"
        prob_html = f"""
        <div style='margin:10px 0 18px;
        padding:14px 16px;
        border-left:5px solid {border};
        border-radius:12px;
        background:{bg};
        font-size:1rem'>
        <div style='font-weight:600;margin-bottom:4px'>
        Predicted Binding Probability
        </div>
        <div style='font-size:1.4rem;font-weight:700'>
        {prob:.4f}
        </div>
        <div style='font-size:0.85rem;color:#64748b;margin-top:4px'>
        {label}
        </div>
        </div>
        """
        return prob_html, ranges_html, heat_html
# ----- Gradio callbacks -----------------------------------------
# NOTE(review): ROOT is recomputed here with the same value as the module-level
# ROOT defined earlier; redundant but harmless.
ROOT = os.path.dirname(os.path.abspath(__file__))
FOLDSEEK_BIN = os.path.join(ROOT, "utils", "foldseek")  # foldseek executable for SA sequences
def extract_aa_seq_cb(structure_file, protein_text):
    """
    Extract a plain amino-acid sequence from an uploaded PDB/mmCIF file.

    Returns (sequence_text, status_html); the current textbox content is kept
    when no file is supplied or extraction fails.
    """
    seq_out = (protein_text or "").strip()
    if structure_file is None:
        return seq_out, "<p style='color:red'>Please upload a structure file.</p>"
    notes = []
    try:
        extracted = simple_seq_from_structure(structure_file.name)
    except Exception as exc:
        notes.append(f"<li>β Extraction failed: <b>{exc}</b></li>")
    else:
        if extracted:
            seq_out = extracted
            notes.append("<li>β Extracted <b>amino acid sequence</b> from structure.</li>")
        else:
            notes.append("<li>β No valid amino acid sequence found.</li>")
    status_html = (
        "<div style='margin:10px 0;padding:10px 12px;"
        "border:1px solid #e5e7eb;border-radius:10px;"
        "background:#f8fafc;color:#0f172a'>"
        "<ul style='margin:0 0 0 18px;padding:0'>"
        f"{''.join(notes)}"
        "</ul></div>"
    )
    return seq_out, status_html
def extract_sa_seq_cb(structure_file, protein_text):
    """
    Extract a structure-aware (foldseek) sequence from an uploaded structure.

    Returns (sequence_text, status_html); keeps the current textbox content
    when no file is supplied or extraction fails.
    """
    seq_out = (protein_text or "").strip()
    if structure_file is None:
        return seq_out, "<p style='color:red'>Please upload a structure file.</p>"
    notes = []
    try:
        parsed = get_struc_seq(
            FOLDSEEK_BIN,
            structure_file.name,
            None,
            plddt_mask=False,
        )
        # Use the first chain reported by foldseek; third field is the SA sequence.
        _, _, sa_seq = parsed[next(iter(parsed))]
        if sa_seq:
            seq_out = sa_seq
            notes.append("<li>β Extracted <b>structure-aware sequence</b> (SA).</li>")
        else:
            notes.append("<li>β Structure parsed but no SA sequence found.</li>")
    except Exception as exc:
        notes.append(f"<li>β SA extraction failed: <b>{exc}</b></li>")
    status_html = (
        "<div style='margin:10px 0;padding:10px 12px;"
        "border:1px solid #e5e7eb;border-radius:10px;"
        "background:#f8fafc;color:#0f172a'>"
        "<ul style='margin:0 0 0 18px;padding:0'>"
        f"{''.join(notes)}"
        "</ul></div>"
    )
    return seq_out, status_html
def _choose_ckpt_by_types(prot_seq: str, ligand_text: str) -> Tuple[str, str, str]:
    """Return (folder_name, protein_type, ligand_type) for checkpoint routing."""
    protein_kind = detect_protein_type(prot_seq)
    ligand_kind = detect_ligand_type(ligand_text)
    # Folder layout: sa_selfies / fasta_selfies / sa_smiles / fasta_smiles
    return f"{protein_kind}_{ligand_kind}", protein_kind, ligand_kind
def inference_cb(prot_seq, drug_seq, head_choice, topk_choice):
    """
    Run binding prediction + attention visualisation for one protein/ligand pair.

    Parameters
    ----------
    prot_seq : str
        Protein sequence (plain FASTA or structure-aware).
    drug_seq : str
        Ligand as SELFIES or SMILES (SMILES is converted to SELFIES).
    head_choice : str
        Interaction-head name, or an "N. name" style dropdown value.
    topk_choice : str | int
        Number of top residues to highlight (clamped to 1..100).

    Returns
    -------
    str
        A single HTML string (probability card + ranked table + heatmap),
        or a red error paragraph when validation fails.
    """
    # ------------------------------
    # Input validation
    # ------------------------------
    if not prot_seq or not prot_seq.strip():
        return "<p style='color:red'>Please extract or enter a protein sequence first.</p>"
    if not drug_seq or not drug_seq.strip():
        return "<p style='color:red'>Please enter a ligand sequence (SELFIES or SMILES).</p>"
    prot_seq = prot_seq.strip()
    drug_seq_in = drug_seq.strip()
    # ------------------------------
    # Detect types & checkpoint routing
    # ------------------------------
    folder, ptype, ltype = _choose_ckpt_by_types(prot_seq, drug_seq_in)
    # Ligand normalisation: always tokenise as SELFIES
    if ltype == "smiles":
        conv = smiles_to_selfies(drug_seq_in)
        if conv is None:
            # BUGFIX: this branch previously returned a 2-tuple while every
            # other path returns a single HTML string; the arity is now
            # consistent with the rest of the callback.
            return (
                "<p style='color:red'>SMILESβSELFIES conversion failed. "
                "The SMILES appears invalid.</p>"
            )
        drug_seq_for_tokenizer = conv
    else:
        drug_seq_for_tokenizer = drug_seq_in
    # Force a single ligand representation downstream (SELFIES only).
    ltype = "selfies"
    ligand_type_flag = "selfies"
    raw_selfies = drug_seq_for_tokenizer
    folder = f"{ptype}_selfies"
    # ------------------------------
    # Load encoders & featurise the single pair
    # ------------------------------
    prot_tok, prot_m, drug_tok, drug_m, encoding = load_encoders(ptype, ltype, args)
    loader = DataLoader(
        [(prot_seq, drug_seq_for_tokenizer, 1)],
        batch_size=1,
        collate_fn=make_collate_fn(prot_tok, drug_tok),
    )
    feats = get_case_feature(encoding, loader)
    # ------------------------------
    # Load trained checkpoint (if it exists)
    # ------------------------------
    ckpt = os.path.join(args.save_path_prefix, folder, "best_model.ckpt")
    model = ExplainBind(1280, 768, args=args).to(DEVICE)
    if os.path.isfile(ckpt):
        model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
        warn_html = (
            "<div style='margin:8px 0 14px;padding:8px 10px;"
            "border-left:4px solid #10b981;background:#ecfdf5'>"
            f"<b>Loaded model:</b> <code>{folder}/best_model.ckpt</code></div>"
        )
    else:
        warn_html = (
            "<div style='margin:8px 0 14px;padding:8px 10px;"
            "border-left:4px solid #f59e0b;background:#fffbeb'>"
            "<b>Warning:</b> checkpoint not found "
            f"<code>{folder}/best_model.ckpt</code>. "
            "Using randomly initialised weights for visualisation.</div>"
        )
    # NOTE(review): warn_html is built but never included in the returned HTML
    # (see full_html below); left unchanged to preserve current UI output.
    # ------------------------------
    # Parse interaction head: accept exact names or "N. name" prefixes;
    # fall back to the combined head on anything unparseable.
    # ------------------------------
    sel = str(head_choice).strip()
    if sel in INTERACTION_NAMES:
        head_idx = INTERACTION_NAMES.index(sel)
    else:
        try:
            n = int(sel.split(".", 1)[0])
            head_idx = max(0, min(len(INTERACTION_NAMES) - 1, n - 1))
        except Exception:
            head_idx = len(INTERACTION_NAMES) - 1  # Combined Interaction
    # ------------------------------
    # Parse Top-K value (default 1, clamped to 1..100 for residue mode)
    # ------------------------------
    try:
        topk = int(str(topk_choice).strip())
    except Exception:
        topk = 1
    topk = max(1, topk)
    mode = "residue"
    topk_pairs = 1
    topk_residues = min(100, topk)
    # ------------------------------
    # Visualisation
    # ------------------------------
    prob_html, table_html, heat_html = visualize_attention_and_ranges(
        model,
        feats,
        head_idx,
        mode=mode,
        topk_pairs=topk_pairs,
        topk_residues=topk_residues,
        prot_tokenizer=prot_tok,
        drug_tokenizer=drug_tok,
        ligand_type=ligand_type_flag,
        raw_selfies=raw_selfies,
    )
    # Concatenate in a fixed top-to-bottom order: probability, table, heatmap.
    full_html = prob_html + table_html + heat_html
    return full_html
def clear_cb():
    """Reset the UI to its initial empty state.

    Returns a 5-tuple matching the wired output components, in order:
    (protein_seq, drug_seq, output_full, structure_file, status_box).
    The file-upload component is cleared with ``None``; the text and
    HTML components are cleared with empty strings.
    """
    blank = ""
    return (blank, blank, blank, None, blank)
| # βββββ Gradio interface definition βββββββββββββββββββββββββββββββ | |
# Custom stylesheet for the Gradio UI (cards, buttons, link pills, grid
# layout).  NOTE(review): to take effect this must be passed to the
# gr.Blocks(css=...) constructor — Blocks.launch() does not accept a
# `css` keyword argument.
css: str = """
:root{
--bg:#f8fafc; --card:#f8fafc; --text:#0f172a;
--muted:#6b7280; --border:#e5e7eb; --shadow:0 6px 24px rgba(2,6,23,.06);
--radius:14px; --icon-size:20px;
}
*{box-sizing:border-box}
html,body{background:#fff!important;color:var(--text)!important}
.gradio-container{max-width:1120px;margin:0 auto}
/* Title and subtitle */
h1{
font-family:Inter,ui-sans-serif;letter-spacing:.2px;font-weight:700;
font-size:32px;margin:22px 0 12px;text-align:center
}
.subtle{color:var(--muted);font-size:14px;text-align:center;margin:-6px 0 18px}
/* Card style */
.card{
background:var(--card); border:1px solid var(--border); border-radius:var(--radius);
box-shadow:var(--shadow); padding:22px;
}
/* Top links */
.link-row{display:flex;justify-content:center;gap:14px;margin:0 auto 18px;flex-wrap:wrap}
/* Two-column grid: left=input, right=controls */
.grid-2{display:grid;grid-template-columns:1.4fr .9fr;gap:16px}
.grid-2 .col{display:flex;flex-direction:column;gap:12px}
/* Buttons */
.gr-button{border-radius:12px !important;font-weight:700 !important;letter-spacing:.2px}
#extract-btn{background:linear-gradient(90deg,#EFAFB2,#EFAFB2); color:#0f172a}
#inference-btn{background:linear-gradient(90deg,#B2CBDF,#B2CBDF); color:#0f172a}
#clear-btn{background:#FFE2B5; color:#0A0A0A; border:1px solid var(--border)}
/* Result spacing */
#result-table{margin-bottom:16px}
/* Figure container */
.figure-wrap{border:1px solid var(--border);border-radius:12px;overflow:hidden;box-shadow:var(--shadow)}
.figure-wrap img{display:block;width:100%;height:auto}
/* Right pane: vertical radio layout and full-width controls (kept for button styling) */
.right-pane .gr-button{
width:100% !important;
height:48px !important;
border-radius:12px !important;
font-weight:700 !important;
letter-spacing:.2px;
}
/* βββββββββ Publication links (Bulma-like) βββββββββ */
.publication-links {
display: flex;
justify-content: center;
gap: 14px;
flex-wrap: wrap;
margin: 6px 0 18px;
}
.link-block a {
display: inline-flex;
align-items: center;
gap: 8px;
padding: 10px 18px;
font-size: 14px;
font-weight: 600;
border-radius: 9999px;
text-decoration: none;
transition: all 0.15s ease-in-out;
}
/* colour variants */
.btn-danger { background:#e2e8f0; color:#0f172a; }
.btn-dark { background:#e2e8f0; color:#0f172a; }
.btn-link { background:#e2e8f0; color:#0f172a; }
.btn-warning { background:#e2e8f0; color:#0f172a; }
.link-block a:hover {
filter: brightness(0.95);
transform: translateY(-1px);
}
.loscalzo-block img {
height: 100px;
width: auto;
object-fit: contain;
}
.loscalzo-block {
display: flex;
align-items: center;
gap: 10px;
margin: 0 auto;
justify-content: center;
}
.link-btn{
display:inline-flex !important;
align-items:center !important;
gap:8px !important;
padding:10px 18px !important;
font-size:14px !important;
font-weight:600 !important;
border-radius:9999px !important;
background:#e2e8f0 !important;
color:#0f172a !important;
text-decoration:none !important;
border:1px solid #e5e7eb !important;
transition:all 0.15s ease-in-out !important;
}
.link-btn:hover{
filter:brightness(0.95);
transform:translateY(-1px);
}
.project-links{
display:flex !important;
justify-content:center !important;
gap:28px !important;
flex-wrap:wrap !important;
margin-bottom:32px !important;
}
#example-btn {
background: #979ea8 !important;
color: #1e293b !important;
}
#extract-aa-btn{
background:#DCE7F3 !important;
color:#0f172a !important;
}
#extract-sa-btn{
background:#EADCF8 !important;
color:#0f172a !important;
}
"""
# ------------------------------------------------------------------
# Gradio interface definition.
#
# FIX: `theme` and `css` are gr.Blocks() *constructor* arguments, not
# Blocks.launch() arguments.  The previous code passed them to
# launch(), which raises TypeError on current Gradio releases (and at
# best silently discarded the stylesheet), so they are supplied here.
# ------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
    gr.Markdown("<h1>ExplainBind: Token-level ProteinβLigand Interaction Visualiser</h1>")

    # Lab-logo header, intentionally disabled for now.
    # gr.HTML(f"""
    # <div class="loscalzo-block">
    #     <img src="data:image/png;base64,{LOSCAZLO_B64}"
    #          alt="Loscalzo Research Group logo" />
    #     <a class="loscalzo-name"
    #        href="https://ogephd.hms.harvard.edu/people/joseph-loscalzo"
    #        target="_blank" rel="noopener">
    #     </a>
    # </div>
    # """)

    # ---------------------------------
    # Top links (project page / preprint / source code)
    # ---------------------------------
    gr.HTML("""
<div class="project-links">
<a class="link-btn" href="https://zhaohanm.github.io/ExplainBind/" target="_blank" rel="noopener noreferrer" aria-label="Project Page">
<!-- globe icon -->
<svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="currentColor" aria-hidden="true">
<path d="M12 2a10 10 0 1 0 10 10A10.012 10.012 0 0 0 12 2Zm7.93 9h-3.18a15.84 15.84 0 0 0-1.19-5.02A8.02 8.02 0 0 1 19.93 11ZM12 4c.86 0 2.25 1.86 3.01 6H8.99C9.75 5.86 11.14 4 12 4ZM4.07 13h3.18c.2 1.79.66 3.47 1.19 5.02A8.02 8.02 0 0 1 4.07 13Zm3.18-2H4.07A8.02 8.02 0 0 1 8.44 5.98 15.84 15.84 0 0 0 7.25 11Zm1.37 2h6.76c-.76 4.14-2.15 6-3.01 6s-2.25-1.86-3.01-6Zm9.05 0h3.18a8.02 8.02 0 0 1-4.37 5.02 15.84 15.84 0 0 0 1.19-5.02Z"/>
</svg>
Project Page
</a>
<a class="link-btn" href="https://doi.org/10.64898/2026.03.03.707476" target="_blank" rel="noopener noreferrer" aria-label="Biorxiv: 2406.01651">
<!-- arXiv-like paper icon -->
<svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="currentColor" aria-hidden="true">
<path d="M6 2h9l5 5v13a2 2 0 0 1-2 2H6a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2Zm8 1.5V8h4.5L14 3.5ZM7 12h10v2H7v-2Zm0 4h10v2H7v-2Zm0-8h6v2H7V8Z"/>
</svg>
biorXiv preprint
</a>
<a class="link-btn" href="https://github.com/ZhaohanM/ExplainBind" target="_blank" rel="noopener noreferrer" aria-label="GitHub Repo">
<!-- GitHub mark -->
<svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="currentColor" aria-hidden="true">
<path d="M12 .5A12 12 0 0 0 0 12.76c0 5.4 3.44 9.98 8.2 11.6.6.12.82-.28.82-.6v-2.3c-3.34.74-4.04-1.44-4.04-1.44-.54-1.38-1.32-1.74-1.32-1.74-1.08-.76.08-.74.08-.74 1.2.08 1.84 1.26 1.84 1.26 1.06 1.86 2.78 1.32 3.46 1.02.1-.8.42-1.32.76-1.62-2.66-.32-5.46-1.36-5.46-6.02 0-1.34.46-2.44 1.22-3.3-.12-.32-.54-1.64.12-3.42 0 0 1-.34 3.32 1.26.96-.28 1.98-.42 3-.42s2.04.14 3 .42c2.32-1.6 3.32-1.26 3.32-1.26.66 1.78.24 3.1.12 3.42.76.86 1.22 1.96 1.22 3.3 0 4.68-2.8 5.68-5.48 6 .44.38.84 1.12.84 2.28v3.38c0 .32.22.74.84.6A12.02 12.02 0 0 0 24 12.76 12 12 0 0 0 12 .5Z"/>
</svg>
Source code
</a>
</div>
""")

    # ---------------------------------
    # Guidelines
    # ---------------------------------
    with gr.Accordion("Guidelines for Users", open=True, elem_classes=["card"]):
        gr.HTML("""
<ol style="font-size:1rem;line-height:1.6;margin-left:22px;">
<li>
<strong>Input formats:</strong>
The system supports either <em>structure-aware (SA)</em> sequences derived from
protein structures or conventional <em>FASTA</em> sequences.
For structure-based analysis, users may upload <code>.pdb</code> or
<code>.cif</code> files to extract the corresponding sequence representation.
Ligands can be provided in <em>SMILES</em> or <em>SELFIES</em> format.
</li>
<li>
<strong>Interaction channel selection:</strong>
Users may select a specific non-covalent interaction type
(e.g., hydrogen bonding, hydrophobic interactions) or the
overall interaction channel to visualise the corresponding
token-level binding patterns.
</li>
<li>
<strong>Model outputs:</strong>
The system reports (i) a predicted binding probability for the
proteinβligand pair, (ii) a ranked Top-K residue table, and (iii) a token-level interaction
heat map illustrating spatial interaction patterns.
</li>
</ol>
""")

    # ---------------------------------
    # Inputs + Controls
    # ---------------------------------
    with gr.Row():
        with gr.Column(elem_classes=["card", "grid-2"]):
            # ---- LEFT PANEL: sequence inputs, example loader, extractors ----
            with gr.Column(elem_id="left"):
                # Text/file components are created with render=False so
                # they can be placed *below* the example/extract controls
                # while still being referenced by gr.Examples and the
                # event wiring defined further down.
                protein_seq = gr.Textbox(
                    label="Protein structure-aware / FASTA sequence",
                    lines=3,
                    placeholder="Paste SA/FASTA sequence or click Extractβ¦",
                    elem_id="protein-seq",
                    render=False,
                )
                drug_seq = gr.Textbox(
                    label="Ligand (SELFIES / SMILES)",
                    lines=3,
                    placeholder="Paste SELFIES or SMILES",
                    elem_id="drug-seq",
                    render=False,
                )
                structure_file = gr.File(
                    label="Upload protein structure (.pdb / .cif)",
                    file_types=[".pdb", ".cif"],
                    elem_id="structure-file",
                    render=False,
                )
                with gr.Group():
                    gr.Markdown("### Example")
                    gr.Examples(
                        examples=[[
                            "SLALSLTADQMVSALLDAEPPILYSEYDPTRPFSEASMMGLLTNLADRELVHMINWAKRVPGFVDLTSHDQVHLLECAWLEILMIGLVWRSMEHPGKLLFAPNLLLDRNQGKCVEGMVEIFDMLLATSSRFRMMNLQGEEFVCLKSIILLNSGVYTFLSSTLKSLEEKDHIHRVLDKITDTLIHLMAKAGLTLQQQHQRLAQLLLILSHIRHMSNKGMEHLYSMKCKNVVPSYDLLLEMLDA",
                            "[C][=C][C][=Branch2][Branch1][#C][=C][C][=C][Ring1][=Branch1][C][=C][Branch2][Ring2][#Branch2][C@H1][C@@H1][Branch1][Branch2][C][C@@H1][Ring1][=Branch1][O][Ring1][Branch1][S][=Branch1][C][=O][=Branch1][C][=O][N][Branch1][#Branch2][C][C][Branch1][C][F][Branch1][C][F][F][C][=C][C][=C][Branch1][Branch1][C][=C][Ring1][=Branch1][Cl][C][=C][C][=C][Branch1][Branch1][C][=C][Ring1][=Branch1][O][O]"
                        ]],
                        inputs=[protein_seq, drug_seq],
                        label="Click to load an example",
                    )
                    btn_load_example = gr.Button(
                        "Load Example",
                        elem_id="example-btn",
                    )
                structure_file.render()
                with gr.Row():
                    btn_extract_aa = gr.Button(
                        "Extract amino acid sequence",
                        elem_id="extract-aa-btn"
                    )
                    btn_extract_sa = gr.Button(
                        "Extract structure-aware sequence",
                        elem_id="extract-sa-btn"
                    )
                protein_seq.render()
                drug_seq.render()

            # ---- RIGHT PANEL: interaction-head / Top-K controls ----
            with gr.Column(elem_id="right", elem_classes=["right-pane"]):
                head_dd = gr.Dropdown(
                    label="Non-covalent interaction type/Overall",
                    choices=INTERACTION_NAMES,
                    value="Overall Interaction",
                    interactive=True,
                )
                top_k_dd = gr.Dropdown(
                    label="Top-K residue",
                    choices=[str(i) for i in range(1, 21)],
                    value="1",
                    interactive=True,
                )
                with gr.Row():
                    btn_infer = gr.Button(
                        "Inference",
                        elem_id="inference-btn"
                    )
                    clear_btn = gr.Button(
                        "Clear",
                        elem_id="clear-btn"
                    )

    # ---------------------------------
    # Outputs
    # ---------------------------------
    with gr.Column(elem_classes=["card"]):
        status_box = gr.HTML(elem_id="status-box")
        output_full = gr.HTML(elem_id="result-full")

    # ---------------------------------
    # Example loader callback
    # ---------------------------------
    def load_example_cb():
        """Return the demo (protein sequence, ligand SELFIES) example pair."""
        return (
            "SLALSLTADQMVSALLDAEPPILYSEYDPTRPFSEASMMGLLTNLADRELVHMINWAKRVPGFVDLTSHDQVHLLECAWLEILMIGLVWRSMEHPGKLLFAPNLLLDRNQGKCVEGMVEIFDMLLATSSRFRMMNLQGEEFVCLKSIILLNSGVYTFLSSTLKSLEEKDHIHRVLDKITDTLIHLMAKAGLTLQQQHQRLAQLLLILSHIRHMSNKGMEHLYSMKCKNVVPSYDLLLEMLDA",
            "[C][=C][C][=Branch2][Branch1][#C][=C][C][=C][Ring1][=Branch1][C][=C][Branch2][Ring2][#Branch2][C@H1][C@@H1][Branch1][Branch2][C][C@@H1][Ring1][=Branch1][O][Ring1][Branch1][S][=Branch1][C][=O][=Branch1][C][=O][N][Branch1][#Branch2][C][C][Branch1][C][F][Branch1][C][F][F][C][=C][C][=C][Branch1][Branch1][C][=C][Ring1][=Branch1][Cl][C][=C][C][=C][Branch1][Branch1][C][=C][Ring1][=Branch1][O][O]"
        )

    # ---------------------------------
    # Event wiring
    # ---------------------------------
    btn_load_example.click(
        fn=load_example_cb,
        inputs=[],
        outputs=[protein_seq, drug_seq],
    )
    btn_extract_aa.click(
        fn=extract_aa_seq_cb,
        inputs=[structure_file, protein_seq],
        outputs=[protein_seq, status_box],
    )
    btn_extract_sa.click(
        fn=extract_sa_seq_cb,
        inputs=[structure_file, protein_seq],
        outputs=[protein_seq, status_box],
    )
    btn_infer.click(
        fn=inference_cb,
        inputs=[protein_seq, drug_seq, head_dd, top_k_dd],
        outputs=[output_full],
    )
    clear_btn.click(
        fn=clear_cb,
        inputs=[],
        outputs=[
            protein_seq,
            drug_seq,
            output_full,
            structure_file,
            status_box,
        ],
    )

# Theme and CSS are configured on the gr.Blocks constructor above;
# launch() only receives arguments it actually supports.
demo.launch(show_error=True)