Spaces:
Running
Running
# ─── monkey-patch gradio_client so bool schemas don't crash json_schema_to_python_type ───
import gradio_client.utils as _gc_utils

# Keep handles on the unpatched implementations so the wrappers can delegate.
_orig_get_type = _gc_utils.get_type
_orig_json2py = _gc_utils._json_schema_to_python_type


def _patched_get_type(schema):
    """Call the original ``get_type``, passing ``{}`` in place of a boolean schema."""
    return _orig_get_type({} if isinstance(schema, bool) else schema)


def _patched_json_schema_to_python_type(schema, defs=None):
    """Call the original converter, passing ``{}`` in place of a boolean schema.

    Args:
        schema: JSON-schema fragment; a bare ``True``/``False`` is replaced
            by an empty dict before delegation.
        defs: Forwarded unchanged to the original function.
    """
    normalised = {} if isinstance(schema, bool) else schema
    return _orig_json2py(normalised, defs)


# Install the wrappers in place of the library functions.
_gc_utils.get_type = _patched_get_type
_gc_utils._json_schema_to_python_type = _patched_json_schema_to_python_type
# ─── now it's safe to import Gradio and build your interface ───
| import gradio as gr | |
| from gradio.themes import Soft | |
| import os | |
| import sys | |
| import argparse | |
| import tempfile | |
| import shutil | |
| import base64 | |
| import io | |
| import torch | |
| import selfies | |
| from rdkit import Chem | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from matplotlib import cm | |
| from typing import Optional | |
| from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel | |
| from torch.utils.data import DataLoader | |
| from Bio.PDB import PDBParser, MMCIFParser | |
| from Bio.Data import IUPACData | |
| from utils.drug_tokenizer import DrugTokenizer | |
| from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI | |
| from utils.foldseek_util import get_struc_seq | |
# ───── Helpers ─────
# Three-letter -> one-letter residue lookup with upper-case keys. Biopython's
# table is extended so that MSE, SEC and PYL map to M, C and K.
three2one = {k.upper(): v for k, v in IUPACData.protein_letters_3to1.items()}
three2one.update({"MSE": "M", "SEC": "C", "PYL": "K"})


def simple_seq_from_structure(path: str) -> str:
    """Return the one-letter sequence of the longest chain in a structure file.

    Args:
        path: Path to a structure file. Files ending in ``.cif`` or
            ``.mmcif`` (any letter case) are read with ``MMCIFParser``;
            everything else is read with ``PDBParser``.

    Returns:
        The residues of the chain with the most residues, translated through
        ``three2one``; residue names missing from the table become ``"X"``.
        An empty string is returned when the structure has no chains.
    """
    # Compare the suffix case-insensitively so "MODEL.CIF" is not sent to the
    # PDB parser by mistake.
    is_mmcif = path.lower().endswith((".cif", ".mmcif"))
    parser = MMCIFParser(QUIET=True) if is_mmcif else PDBParser(QUIET=True)
    structure = parser.get_structure("P", path)
    chains = list(structure.get_chains())
    if not chains:
        return ""
    # Count residues without materialising a throwaway list per chain.
    chain = max(chains, key=lambda c: sum(1 for _ in c.get_residues()))
    # NOTE(review): every residue object in the chain is translated, so
    # non-polymer entries (if the parser yields any) will show up as "X".
    return "".join(three2one.get(res.get_resname().upper(), "X") for res in chain)
def smiles_to_selfies(smiles: str) -> Optional[str]:
    """Convert a SMILES string to SELFIES.

    Args:
        smiles: Candidate SMILES string.

    Returns:
        The output of ``selfies.encoder`` for *smiles*, or ``None`` when
        RDKit cannot build a molecule from it or any exception is raised
        during parsing or encoding.
    """
    try:
        # RDKit is used purely as a validity gate; the encoder receives the
        # original string, not the parsed molecule.
        if Chem.MolFromSmiles(smiles) is None:
            return None
        encoded = selfies.encoder(smiles)
    except Exception:
        return None
    return encoded
def parse_config():
    """Build the runtime configuration from command-line options.

    Recognised options (all optional): ``--prot_encoder_path``,
    ``--drug_encoder_path``, ``--agg_mode``, ``--group_size``, ``--fusion``,
    ``--device``, ``--save_path_prefix`` and ``--dataset``. ``--device``
    defaults to ``"cuda"`` when ``torch.cuda.is_available()`` and ``"cpu"``
    otherwise.

    Returns:
        argparse.Namespace holding the recognised options. Arguments that
        this parser does not define are ignored rather than treated as errors.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--prot_encoder_path", default="westlake-repl/SaProt_650M_AF2")
    p.add_argument("--drug_encoder_path", default="HUBioDataLab/SELFormer")
    p.add_argument("--agg_mode", type=str, default="mean_all_tok")
    p.add_argument("--group_size", type=int, default=1)
    p.add_argument("--fusion", default="CAN")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--save_path_prefix", default="save_model_ckp/")
    p.add_argument("--dataset", default="Human")
    # This function runs at import time. sys.argv may then contain arguments
    # that belong to whatever launched the process, and parse_args() would
    # terminate the interpreter on the first one it does not recognise.
    known, _unrecognised = p.parse_known_args()
    return known


# Parsed once at import; the rest of the module reads these two globals.
args = parse_config()
DEVICE = args.device
# ───── Load models & tokenizers ─────
# Protein side: tokenizer and masked-LM weights, both loaded from the
# location given by --prot_encoder_path.
prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path)
prot_model = EsmForMaskedLM.from_pretrained(args.prot_encoder_path)
# Drug side: the project-local tokenizer (constructed with its defaults) and
# the encoder loaded from --drug_encoder_path.
drug_tokenizer = DrugTokenizer()
drug_model = AutoModel.from_pretrained(args.drug_encoder_path)
# Both encoders are wrapped in the project's Pre_encoded module and moved to
# DEVICE; get_case_feature() calls its .encoding(...) method.
encoding = Pre_encoded(prot_model, drug_model, args).to(DEVICE)
def collate_fn(batch):
    """Tokenise a batch of ``(protein, drug, score)`` triples.

    Both sequences are encoded with the module-level tokenizers, padded or
    truncated to 512 tokens, and returned as PyTorch tensors.

    Args:
        batch: Iterable of 3-tuples ``(protein_sequence, drug_sequence, score)``.

    Returns:
        Tuple ``(protein_input_ids, protein_mask, drug_input_ids, drug_mask,
        scores)`` where the masks are the tokenizers' attention masks cast to
        ``bool`` and ``scores`` is ``torch.tensor`` of the third tuple fields.
    """
    proteins, drugs, labels = zip(*batch)

    # Identical encoding options are applied to both modalities.
    encode_options = dict(
        max_length=512,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    prot_enc = prot_tokenizer.batch_encode_plus(list(proteins), **encode_options)
    drug_enc = drug_tokenizer.batch_encode_plus(list(drugs), **encode_options)

    score_tensor = torch.tensor(list(labels))
    return (
        prot_enc["input_ids"],
        prot_enc["attention_mask"].bool(),
        drug_enc["input_ids"],
        drug_enc["attention_mask"].bool(),
        score_tensor,
    )
def get_case_feature(model, loader):
    """Encode one batch from *loader* and return its features on the CPU.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    The function returns from inside the loop, so only the first batch the
    loader yields is processed; if the loader yields nothing the function
    falls through and returns ``None``.

    Args:
        model: Object exposing ``eval()`` and
            ``encoding(p_ids, p_mask, d_ids, d_mask)``.
        loader: Iterable yielding 5-tuples as produced by ``collate_fn``.

    Returns:
        A one-element list holding the tuple ``(protein_embedding,
        drug_embedding, protein_ids, drug_ids, protein_mask, drug_mask, None)``
        with every tensor moved to the CPU.
    """
    model.eval()
    with torch.no_grad():
        for prot_ids, prot_mask, drug_ids, drug_mask, _ in loader:
            prot_ids = prot_ids.to(DEVICE)
            prot_mask = prot_mask.to(DEVICE)
            drug_ids = drug_ids.to(DEVICE)
            drug_mask = drug_mask.to(DEVICE)
            prot_emb, drug_emb = model.encoding(prot_ids, prot_mask, drug_ids, drug_mask)
            on_device = (prot_emb, drug_emb, prot_ids, drug_ids, prot_mask, drug_mask)
            # The trailing None fills the seventh slot that consumers unpack.
            return [tuple(t.cpu() for t in on_device) + (None,)]
# ───── Visualisation ─────
| def _safe_is_special(tokenizer, tok: str) -> bool: | |
| # Some tokenisers expose different special token sets; fall back conservatively. | |
| special_sets = [] | |
| if hasattr(tokenizer, "all_special_tokens"): | |
| special_sets.append(set(tokenizer.all_special_tokens)) | |
| if hasattr(tokenizer, "special_tokens_map"): | |
| special_sets.extend(set(v) if isinstance(v, list) else {v} | |
| for v in tokenizer.special_tokens_map.values()) | |
| for s in special_sets: | |
| if tok in s: | |
| return True | |
| return False | |
# Tuning knobs for the heat-map (values unchanged from the inline literals).
_PRUNE_RATIO = 0.05       # rows/columns whose peak is <= ratio * global max are dropped
_MIN_KEEP = 3             # if fewer than this survive pruning, keep the whole axis
_MAX_HEATMAP_COLS = 150   # upper bound on drug columns drawn
_TOP_RESIDUES = 30        # upper bound on residues listed in the table

_TH_STYLE = "border:1px solid #e5e7eb; padding:6px; background:#f8fafc; text-align:center"
_TD_STYLE = "border:1px solid #e5e7eb; padding:6px; text-align:center"


def _decode_without_specials(ids, tokenizer):
    """Turn a 1-D tensor of token ids into token strings, minus special tokens."""
    toks = tokenizer.convert_ids_to_tokens(ids.tolist())
    return [t for t in toks if not _safe_is_special(tokenizer, t)]


def _keep_mask(peaks, thr, min_keep=_MIN_KEEP):
    """Boolean mask of ``peaks > thr``; all-True when fewer than *min_keep* pass."""
    keep = peaks > thr
    if keep.sum().item() < min_keep:
        keep = torch.ones_like(keep)
    return keep


def _heatmap_html(attn, x_labels, y_labels):
    """Draw *attn* as a heat-map; return an HTML card with PNG preview and PDF link."""
    fig_w = min(22, max(8, len(x_labels) * 0.6))
    fig_h = min(24, max(6, len(y_labels) * 0.8))
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    try:
        im = ax.imshow(attn.numpy(), aspect="auto", cmap=cm.viridis, interpolation="nearest")
        ax.set_title("Protein → Drug Attention", pad=8, fontsize=11)
        ax.set_xticks(range(len(x_labels)))
        ax.set_xticklabels(x_labels, rotation=90, fontsize=8, ha="center", va="center")
        ax.tick_params(axis="x", top=True, bottom=False, labeltop=True, labelbottom=False, pad=27)
        ax.set_yticks(range(len(y_labels)))
        ax.set_yticklabels(y_labels, fontsize=7)
        ax.tick_params(axis="y", top=True, bottom=False, labeltop=True, labelbottom=False, pad=10)
        fig.colorbar(im, fraction=0.026, pad=0.01)
        fig.tight_layout()

        buf_png = io.BytesIO()
        fig.savefig(buf_png, format="png", dpi=140)
        buf_pdf = io.BytesIO()
        fig.savefig(buf_pdf, format="pdf")
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)

    png_b64 = base64.b64encode(buf_png.getvalue()).decode()
    pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()
    return (
        f"<div class='heatmap-card' style='position: relative; width: 100%;'>"
        f"<a href='data:application/pdf;base64,{pdf_b64}' download='attention_heatmap.pdf' "
        "style='position:absolute; top:12px; right:12px; "
        "background: var(--primary); color:#fff; padding:8px 16px; border-radius:8px; "
        "font-size:.92rem; font-weight:600; text-decoration:none;'>Download PDF</a>"
        f"<a href='data:image/png;base64,{png_b64}' target='_blank' title='Click to enlarge'>"
        f"<img src='data:image/png;base64,{png_b64}' "
        "style='display:block; width:100%; height:auto; cursor:zoom-in;'/>"
        "</a>"
        "</div>"
    )


def _top_residue_table_html(attn, p_tokens, p_indices, col, drug_pos, drug_token,
                            top_n=_TOP_RESIDUES):
    """Build the HTML table of the protein rows with the highest values in column *col*."""
    import html  # local import: only needed for escaping token text

    col_vec = attn[:, col]
    k = min(top_n, col_vec.numel())
    if k == 0:
        return ""
    top_rows = torch.topk(col_vec, k=k).indices.tolist()

    header_cells = f"<th style='{_TH_STYLE};'>Rank</th>" + "".join(
        f"<th style='{_TH_STYLE}'>{rank}</th>" for rank in range(1, k + 1)
    )
    residue_cells = f"<th style='{_TH_STYLE};'>Residue</th>" + "".join(
        f"<td style='{_TD_STYLE}'>{html.escape(str(p_tokens[i]))}</td>" for i in top_rows
    )
    position_cells = f"<th style='{_TH_STYLE};'>Position</th>" + "".join(
        f"<td style='{_TD_STYLE}'>{p_indices[i]}</td>" for i in top_rows
    )
    return (
        "<div class='card' style='margin-top:18px'>"
        "<h4 style='margin:0 0 12px; font-size:1rem;'>"
        f"Drug atom #{drug_pos} <code>{html.escape(str(drug_token))}</code> "
        f"&rarr; Top-{k} Protein residues"
        "</h4>"
        "<table style='border-collapse:collapse; margin:0 auto 4px; font-size:.95rem'>"
        f"<tr>{header_cells}</tr>"
        f"<tr>{residue_cells}</tr>"
        f"<tr>{position_cells}</tr>"
        "</table>"
        "</div>"
    )


def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
    """Render the protein-to-drug attention heat-map and an optional residue table.

    Steps:
      1. Run *model* on the single case in ``feats[0]`` and take the second
         element of its output as the attention matrix (after ``squeeze(0)``).
      2. Decode both token-id sequences, drop special tokens, and cut tokens
         and matrix to a common size.
      3. For the heat-map only, drop rows/columns whose peak is at most 5% of
         the global maximum (an axis is kept whole if fewer than three
         entries would survive) and keep at most 150 drug columns, chosen by
         column sum and shown in their original order.
      4. If *drug_idx* is given and in range, list the protein rows with the
         highest values in that drug column. The table is computed on the
         matrix from step 2, so pruning can never redirect it to another token.

    Args:
        model: Callable as ``model(p_emb, d_emb, p_mask, d_mask)`` returning
            a pair whose second item is the attention tensor.
        feats: Sequence whose first item is the 7-tuple produced by
            ``get_case_feature``.
        drug_idx: 0-based drug token index for the table, or ``None``.

    Returns:
        HTML string: the table (if any) followed by the heat-map card, or a
        short message when there is nothing to plot.
    """
    model.eval()
    with torch.no_grad():
        p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, _ = feats[0]
        p_emb, d_emb = p_emb.to(DEVICE), d_emb.to(DEVICE)
        p_mask, d_mask = p_mask.to(DEVICE), d_mask.to(DEVICE)
        _, att_pd = model(p_emb, d_emb, p_mask, d_mask)
        attn = att_pd.squeeze(0).cpu()

    # NOTE(review): tokens are matched to attention rows/columns purely by
    # position after special tokens are removed; confirm the model's attention
    # output excludes those positions too, otherwise labels are shifted.
    p_tokens_all = _decode_without_specials(p_ids[0], prot_tokenizer)[: attn.size(0)]
    d_tokens_all = _decode_without_specials(d_ids[0], drug_tokenizer)[: attn.size(1)]
    p_indices_all = list(range(1, len(p_tokens_all) + 1))
    d_indices_all = list(range(1, len(d_tokens_all) + 1))
    full_attn = attn[: len(p_tokens_all), : len(d_tokens_all)]

    if full_attn.numel() == 0:
        # Reductions over an empty axis and imshow on an empty array both fail.
        return "<p style='color:red'>No attention values to display.</p>"

    # Adaptive sparsity pruning (heat-map only).
    thr = full_attn.max().item() * _PRUNE_RATIO
    row_keep = _keep_mask(full_attn.max(dim=1).values, thr)
    col_keep = _keep_mask(full_attn.max(dim=0).values, thr)
    pruned = full_attn[row_keep][:, col_keep]
    p_tokens = [t for t, keep in zip(p_tokens_all, row_keep.tolist()) if keep]
    p_indices = [i for i, keep in zip(p_indices_all, row_keep.tolist()) if keep]
    d_tokens = [t for t, keep in zip(d_tokens_all, col_keep.tolist()) if keep]
    d_indices = [i for i, keep in zip(d_indices_all, col_keep.tolist()) if keep]

    # Cap the column count; sort the chosen indices so the x axis stays in
    # sequence order instead of being ordered by attention mass.
    if pruned.size(1) > _MAX_HEATMAP_COLS:
        chosen = torch.topk(pruned.sum(0), k=_MAX_HEATMAP_COLS).indices.sort().values.tolist()
        pruned = pruned[:, chosen]
        d_tokens = [d_tokens[i] for i in chosen]
        d_indices = [d_indices[i] for i in chosen]

    x_labels = [f"{idx}:{tok}" for idx, tok in zip(d_indices, d_tokens)]
    y_labels = [f"{idx}:{tok}" for idx, tok in zip(p_indices, p_tokens)]
    html_heat = _heatmap_html(pruned, x_labels, y_labels)

    table_html = ""
    if drug_idx is not None and 0 <= drug_idx < full_attn.size(1):
        table_html = _top_residue_table_html(
            full_attn,
            p_tokens_all,
            p_indices_all,
            drug_idx,
            d_indices_all[drug_idx],
            d_tokens_all[drug_idx],
        )
    return table_html + html_heat
# ───── Gradio Callbacks ─────
# Directory containing this file, and the foldseek binary expected under bin/.
ROOT = os.path.dirname(os.path.abspath(__file__))
FOLDSEEK_BIN = os.path.join(ROOT, "bin", "foldseek")


def extract_sequence_cb(structure_file):
    """Gradio callback: turn an uploaded structure file into a sequence string.

    Args:
        structure_file: Value delivered by the file-upload component. Either
            an object exposing the path as ``.name``, a plain path string,
            or ``None`` when nothing was uploaded.

    Returns:
        The third element of the tuple that ``get_struc_seq`` reports for the
        first chain, or ``""`` when no file was supplied, the path does not
        exist, or no chain was parsed.
    """
    if structure_file is None:
        return ""
    # Accept both a file wrapper (.name) and a bare path string.
    path = getattr(structure_file, "name", structure_file)
    if not isinstance(path, (str, os.PathLike)) or not os.path.exists(path):
        return ""
    parsed = get_struc_seq(FOLDSEEK_BIN, path, None, plddt_mask=False)
    if not parsed:
        # next(iter(...)) on an empty result would raise StopIteration.
        return ""
    first_chain = next(iter(parsed))
    _, _, struct_seq = parsed[first_chain]
    return struct_seq
# Holds the FusionDTI instance once its checkpoint has been loaded, so that
# repeated clicks do not rebuild the model and re-read the file.
_FUSION_MODEL_CACHE = {}


def _get_fusion_model():
    """Return ``(model, checkpoint_loaded)`` for the fusion model.

    The model is cached only after a successful checkpoint load; while the
    checkpoint file is missing a fresh model is built on every call, so a
    file that appears later is still picked up.
    """
    cached = _FUSION_MODEL_CACHE.get("model")
    if cached is not None:
        return cached, True
    # NOTE(review): 446 and 768 are passed positionally to FusionDTI; their
    # meaning is defined by that class, not visible here.
    model = FusionDTI(446, 768, args).to(DEVICE)
    ckpt = os.path.join(f"{args.save_path_prefix}{args.dataset}_{args.fusion}", "best_model.ckpt")
    if not os.path.isfile(ckpt):
        return model, False
    model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
    _FUSION_MODEL_CACHE["model"] = model
    return model, True


def inference_cb(prot_seq, drug_seq, atom_idx):
    """Gradio callback: run the model on one protein/drug pair and return HTML.

    Args:
        prot_seq: Protein sequence text (``None`` is treated as empty).
        drug_seq: Drug text. Input that does not start with ``"["`` is
            passed through ``smiles_to_selfies`` first.
        atom_idx: Optional 1-based drug token index; a falsy value (``None``
            or ``0``) disables the residue table.

    Returns:
        The HTML produced by ``visualize_attention``, preceded by a warning
        when no checkpoint file was found, or a red error paragraph when an
        input is missing or the SMILES conversion fails.
    """
    prot_seq = (prot_seq or "").strip()
    drug_seq = (drug_seq or "").strip()
    if not prot_seq:
        return "<p style='color:red'>Please extract or enter a protein sequence first.</p>"
    if not drug_seq:
        return "<p style='color:red'>Please enter a drug sequence.</p>"
    if not drug_seq.startswith("["):
        conv = smiles_to_selfies(drug_seq)
        if conv is None:
            return "<p style='color:red'>SMILES→SELFIES conversion failed.</p>"
        drug_seq = conv

    # The score slot is a placeholder; only the two sequences are encoded.
    loader = DataLoader([(prot_seq, drug_seq, 1)], batch_size=1, collate_fn=collate_fn)
    feats = get_case_feature(encoding, loader)

    model, ckpt_loaded = _get_fusion_model()
    drug_idx = int(atom_idx) - 1 if atom_idx else None
    result = visualize_attention(model, feats, drug_idx)
    if not ckpt_loaded:
        # Without the checkpoint the model runs with whatever weights its
        # constructor produced; make that visible instead of failing silently.
        result = (
            "<p style='color:#b45309'>Warning: model checkpoint not found; "
            "the attention map below was produced without trained weights.</p>"
            + result
        )
    return result
def clear_cb():
    """Gradio callback: return the reset values for the five wired outputs.

    Returns:
        A 5-tuple ``("", "", None, "", None)``; strings blank the text/HTML
        components and ``None`` clears the other two.
    """
    blank, unset = "", None
    return blank, blank, unset, blank, unset
# ───── Theme & CSS ─────
# Custom stylesheet handed to gr.Blocks(css=...). It declares the colour /
# radius / shadow variables under :root and styles the ".card" containers,
# the ".link-btn" project links, inputs, buttons, the "#result-html" output
# (including the ".heatmap-card" emitted by visualize_attention) and the
# guidelines list, plus one rule for viewports up to 900px wide.
# NOTE(review): several selectors target ".gr-*" class names; confirm they
# match the class names of the Gradio version actually deployed.
css = """
:root {
--bg:#f7f7fb;
--card:#ffffff;
--border:#e6e7eb;
--primary:#4f46e5;
--primary-dark:#4338ca;
--text:#0f172a;
--muted:#6b7280;
--radius:14px;
--shadow:0 10px 30px rgba(15,23,42,.06);
}
*{box-sizing:border-box}
html,body{background:var(--bg)!important;color:var(--text)!important;font-family:Inter,system-ui,Arial,sans-serif}
h1{font-weight:700;font-size:32px;margin:22px 0 10px;text-align:center;letter-spacing:.2px}
p,li,button,.gr-button,label,.gr-text{font-size:14px}
/* Cards */
.card{
background:var(--card); border:1px solid var(--border); border-radius:var(--radius);
box-shadow:var(--shadow); padding:24px; max-width:1100px; margin:0 auto 28px;
}
/* Project links */
.link-btn{
display:inline-flex; /* icon + text centred vertically */
align-items:center;
justify-content:center;
margin:0 8px;
padding:10px 18px;
border-radius:10px;
color:#fff;
font-weight:650;
text-decoration:none;
box-shadow:0 6px 18px rgba(79,70,229,.18);
transition:transform .12s ease,filter .12s ease;
}
.link-btn:hover{transform:translateY(-1px);filter:brightness(1.03)}
.link-btn svg{margin-right:6px;vertical-align:middle}
.link-btn.project{background:linear-gradient(135deg,#10b981,#059669)}
.link-btn.arxiv {background:linear-gradient(135deg,#ef4444,#dc2626)}
.link-btn.github {background:linear-gradient(135deg,#3b82f6,#2563eb)}
/* Labels & inputs */
#input-card label{font-weight:650!important;color:var(--text)!important}
textarea, input, .gr-textbox, .gr-number{
border-radius:12px!important; border:1px solid var(--border)!important;
}
#input-card .gr-row, #input-card .gr-cols{gap:16px}
/* Buttons */
.gr-button{min-height:42px!important; padding:0 18px!important; border-radius:12px!important; font-weight:700!important}
.gr-button.primary, .gr-button-primary{
background:var(--primary)!important; border-color:var(--primary)!important; color:#fff!important
}
.gr-button.primary:hover, .gr-button-primary:hover{background:var(--primary-dark)!important;border-color:var(--primary-dark)!important}
/* Action buttons row */
#action-buttons{gap:12px}
#extract-btn, #inference-btn{flex:1 1 260px!important; min-width:180px!important}
#clear-btn{width:100%!important}
/* Output */
#output-card{padding-top:0}
#result-html{padding:0; margin:0}
#result-html .heatmap-card{
background:var(--card); border:1px solid var(--border); border-radius:12px; padding:12px; box-shadow:var(--shadow)
}
/* Guidance */
#guidelines-card h2{font-size:18px;margin-bottom:14px;text-align:center}
#guidelines-card ul{margin-left:18px;line-height:1.6}
/* Small screens */
@media (max-width: 900px){
.card{margin:0 12px 24px}
}
"""
# ───── Gradio Interface Definition ─────
# Page layout: title, project links, usage guidelines, the input card with
# its action buttons, the HTML output area, and the three click handlers.
with gr.Blocks(theme=Soft(primary_hue="indigo", neutral_hue="slate"), css=css) as demo:
    # ── Title ──
    gr.Markdown("<h1 style='text-align: center;'>Token-level Visualiser for Drug-Target Interaction</h1>")

    # ── Project links (inline SVG icons) ──
    gr.HTML("""
    <div style="text-align:center;margin-bottom:32px;">
      <a class="link-btn project" href="https://zhaohanm.github.io/FusionDTI.github.io/" target="_blank" rel="noopener noreferrer" aria-label="Project Page">
        <!-- globe icon -->
        <svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="currentColor" aria-hidden="true">
          <path d="M12 2a10 10 0 1 0 10 10A10.012 10.012 0 0 0 12 2Zm7.93 9h-3.18a15.84 15.84 0 0 0-1.19-5.02A8.02 8.02 0 0 1 19.93 11ZM12 4c.86 0 2.25 1.86 3.01 6H8.99C9.75 5.86 11.14 4 12 4ZM4.07 13h3.18c.2 1.79.66 3.47 1.19 5.02A8.02 8.02 0 0 1 4.07 13Zm3.18-2H4.07A8.02 8.02 0 0 1 8.44 5.98 15.84 15.84 0 0 0 7.25 11Zm1.37 2h6.76c-.76 4.14-2.15 6-3.01 6s-2.25-1.86-3.01-6Zm9.05 0h3.18a8.02 8.02 0 0 1-4.37 5.02 15.84 15.84 0 0 0 1.19-5.02Z"/>
        </svg>
        Project Page
      </a>
      <a class="link-btn arxiv" href="https://arxiv.org/abs/2406.01651" target="_blank" rel="noopener noreferrer" aria-label="ArXiv: 2406.01651">
        <!-- arXiv-like paper icon -->
        <svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="currentColor" aria-hidden="true">
          <path d="M6 2h9l5 5v13a2 2 0 0 1-2 2H6a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2Zm8 1.5V8h4.5L14 3.5ZM7 12h10v2H7v-2Zm0 4h10v2H7v-2Zm0-8h6v2H7V8Z"/>
        </svg>
        ArXiv: 2406.01651
      </a>
      <a class="link-btn github" href="https://github.com/ZhaohanM/FusionDTI" target="_blank" rel="noopener noreferrer" aria-label="GitHub Repo">
        <!-- GitHub mark -->
        <svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="currentColor" aria-hidden="true">
          <path d="M12 .5A12 12 0 0 0 0 12.76c0 5.4 3.44 9.98 8.2 11.6.6.12.82-.28.82-.6v-2.3c-3.34.74-4.04-1.44-4.04-1.44-.54-1.38-1.32-1.74-1.32-1.74-1.08-.76.08-.74.08-.74 1.2.08 1.84 1.26 1.84 1.26 1.06 1.86 2.78 1.32 3.46 1.02.1-.8.42-1.32.76-1.62-2.66-.32-5.46-1.36-5.46-6.02 0-1.34.46-2.44 1.22-3.3-.12-.32-.54-1.64.12-3.42 0 0 1-.34 3.32 1.26.96-.28 1.98-.42 3-.42s2.04.14 3 .42c2.32-1.6 3.32-1.26 3.32-1.26.66 1.78.24 3.1.12 3.42.76.86 1.22 1.96 1.22 3.3 0 4.68-2.8 5.68-5.48 6 .44.38.84 1.12.84 2.28v3.38c0 .32.22.74.84.6A12.02 12.02 0 0 0 24 12.76 12 12 0 0 0 12 .5Z"/>
        </svg>
        GitHub Repo
      </a>
    </div>
    """)

    # ── Guidelines card ──
    # External links carry rel="noopener noreferrer", matching the project
    # links above, so the opened page gets no handle on this one.
    gr.HTML(
        """
        <div class="card" id="guidelines-card" style="margin-bottom:24px">
          <h2>Guidelines for Users</h2>
          <ul style="list-style:decimal;">
            <li><strong>Convert protein structure into a structure-aware sequence:</strong>
              Upload a <code>.pdb</code> or <code>.cif</code> file. A structure-aware
              sequence will be generated using
              <a href="https://github.com/steineggerlab/foldseek" target="_blank" rel="noopener noreferrer">Foldseek</a>,
              based on 3D structures from
              <a href="https://alphafold.ebi.ac.uk" target="_blank" rel="noopener noreferrer">AlphaFold DB</a> or the
              <a href="https://www.rcsb.org" target="_blank" rel="noopener noreferrer">Protein Data Bank (PDB)</a>.</li>
            <li><strong>If you only have an amino acid sequence or a UniProt ID,</strong>
              please first visit the
              <a href="https://www.rcsb.org" target="_blank" rel="noopener noreferrer">Protein Data Bank (PDB)</a>
              or <a href="https://alphafold.ebi.ac.uk" target="_blank" rel="noopener noreferrer">AlphaFold DB</a>
              to download the corresponding <code>.cif</code> or <code>.pdb</code> file.</li>
            <li><strong>Drug input supports both SELFIES and SMILES:</strong>
              Enter a SELFIES string directly, or paste a SMILES string. SMILES will
              be converted to SELFIES using the
              <a href="https://github.com/aspuru-guzik-group/selfies" target="_blank" rel="noopener noreferrer">SELFIES encoder</a>.
              If conversion fails, a red error message will be displayed.</li>
            <li>Optionally enter a <strong>1-based</strong> drug atom/substructure index
              to highlight the Top-30 interacting protein residues.</li>
            <li>After inference, use “Download PDF” to export a high-resolution vector figure.</li>
          </ul>
        </div>
        """
    )

    # ── Input card ──
    with gr.Column(elem_id="input-card", elem_classes="card"):
        protein_seq = gr.Textbox(
            label="Protein Structure-aware Sequence",
            lines=3,
            elem_id="protein-seq",
        )
        drug_seq = gr.Textbox(
            label="Drug Sequence (SELFIES/SMILES)",
            lines=3,
            elem_id="drug-seq",
        )
        structure_file = gr.File(
            label="Upload Protein Structure (.pdb/.cif)",
            file_types=[".pdb", ".cif"],
            elem_id="structure-file",
        )
        drug_idx = gr.Number(
            label="Drug atom/substructure index (1-based)",
            value=None,
            precision=0,
            elem_id="drug-idx",
        )

        # ── Action buttons ──
        with gr.Row(elem_id="action-buttons", equal_height=True):
            btn_extract = gr.Button("Extract sequence", variant="primary", elem_id="extract-btn")
            btn_infer = gr.Button("Inference", variant="primary", elem_id="inference-btn")
        with gr.Row():
            clear_btn = gr.Button("Clear", variant="secondary", elem_id="clear-btn")

    # ── Output visualisation ──
    output_html = gr.HTML(elem_id="result-html")

    # ── Event wiring ──
    btn_extract.click(
        fn=extract_sequence_cb,
        inputs=[structure_file],
        outputs=[protein_seq],
    )
    btn_infer.click(
        fn=inference_cb,
        inputs=[protein_seq, drug_seq, drug_idx],
        outputs=[output_html],
    )
    # clear_cb returns five values, one per output listed here, in this order.
    clear_btn.click(
        fn=clear_cb,
        inputs=[],
        outputs=[protein_seq, drug_seq, drug_idx, output_html, structure_file],
    )
if __name__ == "__main__":
    # Serve on every network interface at port 7860 and request a share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo.launch(**launch_options)