# FusionDTI / app.py
# (HuggingFace Space page header preserved as comments: uploaded by Zhaohan-Meng,
#  commit "Update app.py", 3df029f verified)
# ─── monkey-patch gradio_client so bool schemas don’t crash json_schema_to_python_type ───
import gradio_client.utils as _gc_utils

# Keep references to the unpatched originals so we can delegate to them.
_orig_get_type = _gc_utils.get_type
_orig_json2py = _gc_utils._json_schema_to_python_type


def _patched_get_type(schema):
    """Delegate to the original get_type, coercing boolean schemas to {} first."""
    return _orig_get_type({} if isinstance(schema, bool) else schema)


def _patched_json_schema_to_python_type(schema, defs=None):
    """Delegate to the original converter, coercing boolean schemas to {} first."""
    return _orig_json2py({} if isinstance(schema, bool) else schema, defs)


_gc_utils.get_type = _patched_get_type
_gc_utils._json_schema_to_python_type = _patched_json_schema_to_python_type
# ─── now it’s safe to import Gradio and build your interface ───────────────────────────
import gradio as gr
from gradio.themes import Soft
import os
import sys
import argparse
import tempfile
import shutil
import base64
import io
import torch
import selfies
from rdkit import Chem
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib import cm
from typing import Optional
from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel
from torch.utils.data import DataLoader
from Bio.PDB import PDBParser, MMCIFParser
from Bio.Data import IUPACData
from utils.drug_tokenizer import DrugTokenizer
from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI
from utils.foldseek_util import get_struc_seq
# ───── Helpers ─────────────────────────────────────────────────
# Upper-cased three-letter residue name -> one-letter amino-acid code.
three2one = {name.upper(): letter for name, letter in IUPACData.protein_letters_3to1.items()}
# Common non-standard residues folded onto their closest standard letters.
three2one.update(MSE="M", SEC="C", PYL="K")
def simple_seq_from_structure(path: str) -> str:
    """Return the one-letter sequence of the longest chain in a .pdb/.cif file.

    Unknown residue names fall back to "X"; an empty string is returned when
    the structure contains no chains.
    """
    is_cif = path.endswith(".cif")
    parser = MMCIFParser(QUIET=True) if is_cif else PDBParser(QUIET=True)
    structure = parser.get_structure("P", path)
    chains = list(structure.get_chains())
    if not chains:
        return ""
    # Pick the chain with the most residues.
    longest = max(chains, key=lambda c: sum(1 for _ in c.get_residues()))
    letters = [three2one.get(res.get_resname().upper(), "X") for res in longest]
    return "".join(letters)
def smiles_to_selfies(smiles: str) -> Optional[str]:
    """Encode a SMILES string as SELFIES; return None if parsing or encoding fails."""
    try:
        # Validate the SMILES with RDKit before handing it to the encoder.
        if Chem.MolFromSmiles(smiles) is None:
            return None
        return selfies.encoder(smiles)
    except Exception:
        # Best-effort conversion: any encoder failure maps to None for the caller.
        return None
def parse_config():
    """Build the runtime configuration namespace.

    Uses parse_known_args() instead of parse_args() so that extra argv entries
    injected by hosting environments or test runners are ignored rather than
    causing an argparse SystemExit at import time.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--prot_encoder_path", default="westlake-repl/SaProt_650M_AF2")
    p.add_argument("--drug_encoder_path", default="HUBioDataLab/SELFormer")
    p.add_argument("--agg_mode", type=str, default="mean_all_tok")
    p.add_argument("--group_size", type=int, default=1)
    p.add_argument("--fusion", default="CAN")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--save_path_prefix", default="save_model_ckp/")
    p.add_argument("--dataset", default="Human")
    # Unrecognized arguments are deliberately discarded.
    args, _unknown = p.parse_known_args()
    return args


args = parse_config()
DEVICE = args.device  # "cuda" when a GPU is visible, else "cpu"
# ───── Load models & tokenizers ─────────────────────────────────
# One-time module-level loads (pretrained weights are fetched from the HF Hub
# on first run). Order matters: the encoders feed the Pre_encoded wrapper.
prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path)  # SaProt vocabulary
prot_model = EsmForMaskedLM.from_pretrained(args.prot_encoder_path)    # protein encoder
drug_tokenizer = DrugTokenizer()       # project tokenizer for drug (SELFIES) sequences
drug_model = AutoModel.from_pretrained(args.drug_encoder_path)         # SELFormer drug encoder
# Wrapper producing token-level embeddings for both modalities on DEVICE.
encoding = Pre_encoded(prot_model, drug_model, args).to(DEVICE)
def collate_fn(batch):
    """Tokenize a batch of (protein, drug, label) triples into padded tensors.

    Both modalities are padded/truncated to 512 tokens; attention masks are
    returned as boolean tensors.
    """
    prot_seqs, drug_seqs, labels = zip(*batch)
    tokenize_kwargs = dict(
        max_length=512,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    prot_enc = prot_tokenizer.batch_encode_plus(list(prot_seqs), **tokenize_kwargs)
    drug_enc = drug_tokenizer.batch_encode_plus(list(drug_seqs), **tokenize_kwargs)
    label_tensor = torch.tensor(list(labels))
    return (
        prot_enc["input_ids"],
        prot_enc["attention_mask"].bool(),
        drug_enc["input_ids"],
        drug_enc["attention_mask"].bool(),
        label_tensor,
    )
def get_case_feature(model, loader):
    """Encode the first (and only) batch in `loader` with the pre-encoder.

    Returns a single-element list [(p_emb, d_emb, p_ids, d_ids, p_mask,
    d_mask, None)] with every tensor moved back to CPU, or None when the
    loader yields nothing.
    """
    model.eval()
    with torch.no_grad():
        for p_ids, p_mask, d_ids, d_mask, _ in loader:
            p_ids = p_ids.to(DEVICE)
            p_mask = p_mask.to(DEVICE)
            d_ids = d_ids.to(DEVICE)
            d_mask = d_mask.to(DEVICE)
            p_emb, d_emb = model.encoding(p_ids, p_mask, d_ids, d_mask)
            # Only one case is ever loaded, so return on the first batch.
            return [(
                p_emb.cpu(), d_emb.cpu(),
                p_ids.cpu(), d_ids.cpu(),
                p_mask.cpu(), d_mask.cpu(),
                None,
            )]
# ─────────────── visualisation ───────────────────────────────────────────
def _safe_is_special(tokenizer, tok: str) -> bool:
# Some tokenisers expose different special token sets; fall back conservatively.
special_sets = []
if hasattr(tokenizer, "all_special_tokens"):
special_sets.append(set(tokenizer.all_special_tokens))
if hasattr(tokenizer, "special_tokens_map"):
special_sets.extend(set(v) if isinstance(v, list) else {v}
for v in tokenizer.special_tokens_map.values())
for s in special_sets:
if tok in s:
return True
return False
def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
    """
    Render a Protein → Drug cross-attention heat-map and optional Top-30 residue table.

    model    : FusionDTI fusion module (already on DEVICE).
    feats    : single-case feature list produced by get_case_feature().
    drug_idx : optional 0-based drug-token index; when given, a table of the 30
               most-attended protein residues for that token is prepended.
    Returns an HTML fragment (table + heat-map with embedded PNG/PDF) for gr.HTML.

    NOTE: the only behavioral change vs. the previous revision is repairing
    mojibake in the two user-visible "→" strings (title and table heading).
    """
    model.eval()
    with torch.no_grad():
        # ── unpack single-case tensors ───────────────────────────────────────────
        p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, _ = feats[0]
        p_emb, d_emb = p_emb.to(DEVICE), d_emb.to(DEVICE)
        p_mask, d_mask = p_mask.to(DEVICE), d_mask.to(DEVICE)
        # ── forward pass: Protein → Drug attention (B, n_p, n_d) ───────────────
        _, att_pd = model(p_emb, d_emb, p_mask, d_mask)
        attn = att_pd.squeeze(0).cpu()  # (n_p, n_d)

        # ── decode tokens (skip special symbols) ────────────────────────────────
        def clean_ids(ids, tokenizer):
            toks = tokenizer.convert_ids_to_tokens(ids.tolist())
            return [t for t in toks if not _safe_is_special(tokenizer, t)]

        p_tokens_full = clean_ids(p_ids[0], prot_tokenizer)
        p_indices_full = list(range(1, len(p_tokens_full) + 1))  # 1-based positions
        d_tokens_full = clean_ids(d_ids[0], drug_tokenizer)
        d_indices_full = list(range(1, len(d_tokens_full) + 1))
        # ── safety cut-off to match attn mat size ──────────────────────────────
        p_tokens = p_tokens_full[: attn.size(0)]
        p_indices = p_indices_full[: attn.size(0)]
        d_tokens = d_tokens_full[: attn.size(1)]
        d_indices = d_indices_full[: attn.size(1)]
        attn = attn[: len(p_tokens), : len(d_tokens)]
        orig_attn = attn.clone()  # un-pruned copy kept for the Top-30 lookup
        # ── adaptive sparsity pruning ───────────────────────────────────────────
        # Drop rows/columns whose strongest weight is below 5% of the global max,
        # but never prune below 3 rows/columns total.
        thr = attn.max().item() * 0.05 if attn.numel() > 0 else 0.0
        row_keep = (attn.max(dim=1).values > thr) if attn.size(0) else torch.tensor([], dtype=torch.bool)
        col_keep = (attn.max(dim=0).values > thr) if attn.size(1) else torch.tensor([], dtype=torch.bool)
        if row_keep.sum().item() < 3 and attn.size(0) > 0:
            row_keep = torch.ones(attn.size(0), dtype=torch.bool)
        if col_keep.sum().item() < 3 and attn.size(1) > 0:
            col_keep = torch.ones(attn.size(1), dtype=torch.bool)
        attn = attn[row_keep][:, col_keep]
        p_tokens = [tok for keep, tok in zip(row_keep.tolist(), p_tokens) if keep]
        p_indices = [idx for keep, idx in zip(row_keep.tolist(), p_indices) if keep]
        d_tokens = [tok for keep, tok in zip(col_keep.tolist(), d_tokens) if keep]
        d_indices = [idx for keep, idx in zip(col_keep.tolist(), d_indices) if keep]
        # ── cap column count at 150 for readability ─────────────────────────────
        if attn.size(1) > 150:
            topc = torch.topk(attn.sum(0), k=150).indices
            attn = attn[:, topc]
            d_tokens = [d_tokens[i] for i in topc]
            d_indices = [d_indices[i] for i in topc]
        # ── draw heat-map ──────────────────────────────────────────────────────
        x_labels = [f"{idx}:{tok}" for idx, tok in zip(d_indices, d_tokens)]
        y_labels = [f"{idx}:{tok}" for idx, tok in zip(p_indices, p_tokens)]
        fig_w = min(22, max(8, len(x_labels) * 0.6))
        fig_h = min(24, max(6, len(y_labels) * 0.8))
        fig, ax = plt.subplots(figsize=(fig_w, fig_h))
        im = ax.imshow(attn.numpy(), aspect="auto", cmap=cm.viridis, interpolation="nearest")
        ax.set_title("Protein → Drug Attention", pad=8, fontsize=11)
        ax.set_xticks(range(len(x_labels)))
        ax.set_xticklabels(x_labels, rotation=90, fontsize=8, ha="center", va="center")
        ax.tick_params(axis="x", top=True, bottom=False, labeltop=True, labelbottom=False, pad=27)
        ax.set_yticks(range(len(y_labels)))
        ax.set_yticklabels(y_labels, fontsize=7)
        ax.tick_params(axis="y", top=True, bottom=False, labeltop=True, labelbottom=False, pad=10)
        fig.colorbar(im, fraction=0.026, pad=0.01)
        fig.tight_layout()
        # build PNG / PDF (PNG shown inline, PDF offered as a vector download)
        buf_png = io.BytesIO()
        fig.savefig(buf_png, format="png", dpi=140)
        buf_png.seek(0)
        buf_pdf = io.BytesIO()
        fig.savefig(buf_pdf, format="pdf")
        buf_pdf.seek(0)
        plt.close(fig)
        png_b64 = base64.b64encode(buf_png.getvalue()).decode()
        pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()
        html_heat = (
            f"<div class='heatmap-card' style='position: relative; width: 100%;'>"
            f"<a href='data:application/pdf;base64,{pdf_b64}' download='attention_heatmap.pdf' "
            "style='position:absolute; top:12px; right:12px; "
            "background: var(--primary); color:#fff; padding:8px 16px; border-radius:8px; "
            "font-size:.92rem; font-weight:600; text-decoration:none;'>Download PDF</a>"
            f"<a href='data:image/png;base64,{png_b64}' target='_blank' title='Click to enlarge'>"
            f"<img src='data:image/png;base64,{png_b64}' "
            "style='display:block; width:100%; height:auto; cursor:zoom-in;'/>"
            "</a>"
            "</div>"
        )
        # ───────────────────── Top-30 table (optional) ─────────────────────
        table_html = ""
        if drug_idx is not None and orig_attn.size(1) > 0 and 0 <= drug_idx < orig_attn.size(1):
            # map original 0-based drug_idx → pruned column
            col_pos = None
            if (drug_idx + 1) in d_indices:
                col_pos = d_indices.index(drug_idx + 1)
            elif 0 <= drug_idx < len(d_tokens):
                col_pos = drug_idx
            if col_pos is not None:
                col_vec = attn[:, col_pos]
                k = min(30, len(col_vec))
                if k > 0:
                    topk = torch.topk(col_vec, k=k).indices.tolist()
                    # header cells
                    header_cells = (
                        "<th style='border:1px solid #e5e7eb; padding:6px; background:#f8fafc; text-align:center;'>Rank</th>"
                        + "".join(
                            f"<th style='border:1px solid #e5e7eb; padding:6px; background:#f8fafc; text-align:center'>{r+1}</th>"
                            for r in range(len(topk))
                        )
                    )
                    residue_cells = (
                        "<th style='border:1px solid #e5e7eb; padding:6px; background:#f8fafc; text-align:center;'>Residue</th>"
                        + "".join(
                            f"<td style='border:1px solid #e5e7eb; padding:6px; text-align:center'>{p_tokens[i]}</td>"
                            for i in topk
                        )
                    )
                    position_cells = (
                        "<th style='border:1px solid #e5e7eb; padding:6px; background:#f8fafc; text-align:center;'>Position</th>"
                        + "".join(
                            f"<td style='border:1px solid #e5e7eb; padding:6px; text-align:center'>{p_indices[i]}</td>"
                            for i in topk
                        )
                    )
                    drug_tok_text = d_tokens[col_pos]
                    orig_idx_disp = d_indices[col_pos]
                    table_html = (
                        f"<div class='card' style='margin-top:18px'>"
                        f"<h4 style='margin:0 0 12px; font-size:1rem;'>"
                        f"Drug atom #{orig_idx_disp} <code>{drug_tok_text}</code> → Top-30 Protein residues"
                        f"</h4>"
                        f"<table style='border-collapse:collapse; margin:0 auto 4px; font-size:.95rem'>"
                        f"<tr>{header_cells}</tr>"
                        f"<tr>{residue_cells}</tr>"
                        f"<tr>{position_cells}</tr>"
                        f"</table>"
                        f"</div>"
                    )
        return table_html + html_heat
# ───── Gradio Callbacks ─────────────────────────────────────────
# Resolve the bundled foldseek binary relative to this file (<repo>/bin/foldseek).
ROOT = os.path.dirname(os.path.abspath(__file__))
FOLDSEEK_BIN = os.path.join(ROOT, "bin", "foldseek")
def extract_sequence_cb(structure_file):
    """Gradio callback: derive a Foldseek structure-aware sequence from an upload.

    Returns "" when no file was provided, the path is stale, or foldseek
    produced no chains (the previous revision raised StopIteration on an
    empty parse result).
    """
    if structure_file is None or not os.path.exists(structure_file.name):
        return ""
    parsed = get_struc_seq(FOLDSEEK_BIN, structure_file.name, None, plddt_mask=False)
    if not parsed:
        # Foldseek found no chains — surface an empty sequence instead of crashing.
        return ""
    first_chain = next(iter(parsed))
    # Each entry is (aa_seq, foldseek_seq, combined structure-aware seq).
    _, _, struct_seq = parsed[first_chain]
    return struct_seq
def inference_cb(prot_seq, drug_seq, atom_idx):
    """Gradio callback: run FusionDTI on one pair and return the attention HTML.

    prot_seq : structure-aware protein sequence (Foldseek alphabet).
    drug_seq : SELFIES string, or SMILES (auto-converted to SELFIES).
    atom_idx : optional 1-based drug-token index for the Top-30 residue table.

    Fixes vs. previous revision: whitespace-only inputs are now rejected, the
    drug sequence is stripped before tokenization, and the mojibake in the
    SMILES-conversion error message is repaired.
    """
    if not prot_seq or not prot_seq.strip():
        return "<p style='color:red'>Please extract or enter a protein sequence first.</p>"
    if not drug_seq or not drug_seq.strip():
        return "<p style='color:red'>Please enter a drug sequence.</p>"
    drug_seq = drug_seq.strip()
    # SELFIES tokens always start with '['; anything else is treated as SMILES.
    if not drug_seq.startswith("["):
        conv = smiles_to_selfies(drug_seq)
        if conv is None:
            return "<p style='color:red'>SMILES→SELFIES conversion failed.</p>"
        drug_seq = conv
    loader = DataLoader([(prot_seq, drug_seq, 1)], batch_size=1, collate_fn=collate_fn)
    feats = get_case_feature(encoding, loader)
    model = FusionDTI(446, 768, args).to(DEVICE)
    ckpt = os.path.join(f"{args.save_path_prefix}{args.dataset}_{args.fusion}", "best_model.ckpt")
    if os.path.isfile(ckpt):
        model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
    # NOTE(review): when the checkpoint is absent the model runs with random
    # weights and the visualisation is meaningless — confirm the ckpt ships
    # with the deployment.
    return visualize_attention(model, feats, int(atom_idx) - 1 if atom_idx else None)
def clear_cb():
    """Reset every input widget and the output panel to its empty state."""
    # (protein_seq, drug_seq, drug_idx, output_html, structure_file)
    empty = ""
    return empty, empty, None, empty, None
# ───── Theme & CSS ─────────────────────────────────────────────
# Custom stylesheet injected into gr.Blocks below: design tokens in :root,
# then card, link-button, input, button, output, and responsive rules.
css = """
:root {
--bg:#f7f7fb;
--card:#ffffff;
--border:#e6e7eb;
--primary:#4f46e5;
--primary-dark:#4338ca;
--text:#0f172a;
--muted:#6b7280;
--radius:14px;
--shadow:0 10px 30px rgba(15,23,42,.06);
}
*{box-sizing:border-box}
html,body{background:var(--bg)!important;color:var(--text)!important;font-family:Inter,system-ui,Arial,sans-serif}
h1{font-weight:700;font-size:32px;margin:22px 0 10px;text-align:center;letter-spacing:.2px}
p,li,button,.gr-button,label,.gr-text{font-size:14px}
/* Cards */
.card{
background:var(--card); border:1px solid var(--border); border-radius:var(--radius);
box-shadow:var(--shadow); padding:24px; max-width:1100px; margin:0 auto 28px;
}
/* Project links */
.link-btn{
display:inline-flex; /* icon + text centred vertically */
align-items:center;
justify-content:center;
margin:0 8px;
padding:10px 18px;
border-radius:10px;
color:#fff;
font-weight:650;
text-decoration:none;
box-shadow:0 6px 18px rgba(79,70,229,.18);
transition:transform .12s ease,filter .12s ease;
}
.link-btn:hover{transform:translateY(-1px);filter:brightness(1.03)}
.link-btn svg{margin-right:6px;vertical-align:middle}
.link-btn.project{background:linear-gradient(135deg,#10b981,#059669)}
.link-btn.arxiv {background:linear-gradient(135deg,#ef4444,#dc2626)}
.link-btn.github {background:linear-gradient(135deg,#3b82f6,#2563eb)}
/* Labels & inputs */
#input-card label{font-weight:650!important;color:var(--text)!important}
textarea, input, .gr-textbox, .gr-number{
border-radius:12px!important; border:1px solid var(--border)!important;
}
#input-card .gr-row, #input-card .gr-cols{gap:16px}
/* Buttons */
.gr-button{min-height:42px!important; padding:0 18px!important; border-radius:12px!important; font-weight:700!important}
.gr-button.primary, .gr-button-primary{
background:var(--primary)!important; border-color:var(--primary)!important; color:#fff!important
}
.gr-button.primary:hover, .gr-button-primary:hover{background:var(--primary-dark)!important;border-color:var(--primary-dark)!important}
/* Action buttons row */
#action-buttons{gap:12px}
#extract-btn, #inference-btn{flex:1 1 260px!important; min-width:180px!important}
#clear-btn{width:100%!important}
/* Output */
#output-card{padding-top:0}
#result-html{padding:0; margin:0}
#result-html .heatmap-card{
background:var(--card); border:1px solid var(--border); border-radius:12px; padding:12px; box-shadow:var(--shadow)
}
/* Guidance */
#guidelines-card h2{font-size:18px;margin-bottom:14px;text-align:center}
#guidelines-card ul{margin-left:18px;line-height:1.6}
/* Small screens */
@media (max-width: 900px){
.card{margin:0 12px 24px}
}
"""
# ───── Gradio Interface Definition ───────────────────────────────
# Layout: title → project links → guidelines → input card → buttons → output.
# Fix vs. previous revision: mojibake in the guidelines ("Download PDF" quotes)
# repaired; structure and wiring unchanged.
with gr.Blocks(theme=Soft(primary_hue="indigo", neutral_hue="slate"), css=css) as demo:
    # ───────────── Title ─────────────
    gr.Markdown("<h1 style='text-align: center;'>Token-level Visualiser for Drug-Target Interaction</h1>")
    # ───────────── Project Links (SVG icons) ─────────────
    gr.HTML("""
    <div style="text-align:center;margin-bottom:32px;">
      <a class="link-btn project" href="https://zhaohanm.github.io/FusionDTI.github.io/" target="_blank" rel="noopener noreferrer" aria-label="Project Page">
        <!-- globe icon -->
        <svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="currentColor" aria-hidden="true">
          <path d="M12 2a10 10 0 1 0 10 10A10.012 10.012 0 0 0 12 2Zm7.93 9h-3.18a15.84 15.84 0 0 0-1.19-5.02A8.02 8.02 0 0 1 19.93 11ZM12 4c.86 0 2.25 1.86 3.01 6H8.99C9.75 5.86 11.14 4 12 4ZM4.07 13h3.18c.2 1.79.66 3.47 1.19 5.02A8.02 8.02 0 0 1 4.07 13Zm3.18-2H4.07A8.02 8.02 0 0 1 8.44 5.98 15.84 15.84 0 0 0 7.25 11Zm1.37 2h6.76c-.76 4.14-2.15 6-3.01 6s-2.25-1.86-3.01-6Zm9.05 0h3.18a8.02 8.02 0 0 1-4.37 5.02 15.84 15.84 0 0 0 1.19-5.02Z"/>
        </svg>
        Project Page
      </a>
      <a class="link-btn arxiv" href="https://arxiv.org/abs/2406.01651" target="_blank" rel="noopener noreferrer" aria-label="ArXiv: 2406.01651">
        <!-- arXiv-like paper icon -->
        <svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="currentColor" aria-hidden="true">
          <path d="M6 2h9l5 5v13a2 2 0 0 1-2 2H6a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2Zm8 1.5V8h4.5L14 3.5ZM7 12h10v2H7v-2Zm0 4h10v2H7v-2Zm0-8h6v2H7V8Z"/>
        </svg>
        ArXiv: 2406.01651
      </a>
      <a class="link-btn github" href="https://github.com/ZhaohanM/FusionDTI" target="_blank" rel="noopener noreferrer" aria-label="GitHub Repo">
        <!-- GitHub mark -->
        <svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="currentColor" aria-hidden="true">
          <path d="M12 .5A12 12 0 0 0 0 12.76c0 5.4 3.44 9.98 8.2 11.6.6.12.82-.28.82-.6v-2.3c-3.34.74-4.04-1.44-4.04-1.44-.54-1.38-1.32-1.74-1.32-1.74-1.08-.76.08-.74.08-.74 1.2.08 1.84 1.26 1.84 1.26 1.06 1.86 2.78 1.32 3.46 1.02.1-.8.42-1.32.76-1.62-2.66-.32-5.46-1.36-5.46-6.02 0-1.34.46-2.44 1.22-3.3-.12-.32-.54-1.64.12-3.42 0 0 1-.34 3.32 1.26.96-.28 1.98-.42 3-.42s2.04.14 3 .42c2.32-1.6 3.32-1.26 3.32-1.26.66 1.78.24 3.1.12 3.42.76.86 1.22 1.96 1.22 3.3 0 4.68-2.8 5.68-5.48 6 .44.38.84 1.12.84 2.28v3.38c0 .32.22.74.84.6A12.02 12.02 0 0 0 24 12.76 12 12 0 0 0 12 .5Z"/>
        </svg>
        GitHub Repo
      </a>
    </div>
    """)
    # ───────────── Guidelines Card ─────────────
    gr.HTML(
        """
        <div class="card" id="guidelines-card" style="margin-bottom:24px">
          <h2>Guidelines for Users</h2>
          <ul style="list-style:decimal;">
            <li><strong>Convert protein structure into a structure-aware sequence:</strong>
                Upload a <code>.pdb</code> or <code>.cif</code> file. A structure-aware
                sequence will be generated using
                <a href="https://github.com/steineggerlab/foldseek" target="_blank">Foldseek</a>,
                based on 3D structures from
                <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a> or the
                <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>.</li>
            <li><strong>If you only have an amino acid sequence or a UniProt ID,</strong>
                please first visit the
                <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>
                or <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a>
                to download the corresponding <code>.cif</code> or <code>.pdb</code> file.</li>
            <li><strong>Drug input supports both SELFIES and SMILES:</strong>
                Enter a SELFIES string directly, or paste a SMILES string. SMILES will
                be converted to SELFIES using the
                <a href="https://github.com/aspuru-guzik-group/selfies" target="_blank">SELFIES encoder</a>.
                If conversion fails, a red error message will be displayed.</li>
            <li>Optionally enter a <strong>1-based</strong> drug atom/substructure index
                to highlight the Top-30 interacting protein residues.</li>
            <li>After inference, use “Download PDF” to export a high-resolution vector figure.</li>
          </ul>
        </div>
        """
    )
    # ───────────── Input Card ─────────────
    with gr.Column(elem_id="input-card", elem_classes="card"):
        protein_seq = gr.Textbox(
            label="Protein Structure-aware Sequence",
            lines=3,
            elem_id="protein-seq"
        )
        drug_seq = gr.Textbox(
            label="Drug Sequence (SELFIES/SMILES)",
            lines=3,
            elem_id="drug-seq"
        )
        structure_file = gr.File(
            label="Upload Protein Structure (.pdb/.cif)",
            file_types=[".pdb", ".cif"],
            elem_id="structure-file"
        )
        drug_idx = gr.Number(
            label="Drug atom/substructure index (1-based)",
            value=None,
            precision=0,
            elem_id="drug-idx"
        )
        # ───────────── Action Buttons ─────────────
        with gr.Row(elem_id="action-buttons", equal_height=True):
            btn_extract = gr.Button("Extract sequence", variant="primary", elem_id="extract-btn")
            btn_infer = gr.Button("Inference", variant="primary", elem_id="inference-btn")
        with gr.Row():
            clear_btn = gr.Button("Clear", variant="secondary", elem_id="clear-btn")
    # ───────────── Output Visualisation ─────────────
    output_html = gr.HTML(elem_id="result-html")
    # ───────────── Event Wiring ─────────────
    btn_extract.click(
        fn=extract_sequence_cb,
        inputs=[structure_file],
        outputs=[protein_seq]
    )
    btn_infer.click(
        fn=inference_cb,
        inputs=[protein_seq, drug_seq, drug_idx],
        outputs=[output_html]
    )
    clear_btn.click(
        fn=clear_cb,
        inputs=[],
        outputs=[protein_seq, drug_seq, drug_idx, output_html, structure_file]
    )
if __name__ == "__main__":
    # Bind on all interfaces (container deployment); share=True additionally
    # creates a temporary public gradio.live tunnel URL.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)