Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -425,7 +425,7 @@ def visualize_attention_and_ranges(
|
|
| 425 |
mode="residue":
|
| 426 |
- Aggregate attention over ligand dimension
|
| 427 |
- Rank residues by aggregated score
|
| 428 |
-
- Select Top-K residues (1β
|
| 429 |
- Default K = 1 (binding pocket discovery)
|
| 430 |
|
| 431 |
Notes
|
|
@@ -437,7 +437,7 @@ def visualize_attention_and_ranges(
|
|
| 437 |
|
| 438 |
assert mode in {"pair", "residue"}
|
| 439 |
assert topk_pairs >= 1
|
| 440 |
-
assert 1 <= topk_residues <=
|
| 441 |
|
| 442 |
model.eval()
|
| 443 |
with torch.no_grad():
|
|
@@ -451,8 +451,9 @@ def visualize_attention_and_ranges(
|
|
| 451 |
# --------------------------------------------------
|
| 452 |
# Forward
|
| 453 |
# --------------------------------------------------
|
| 454 |
-
|
| 455 |
att = att_pd.squeeze(0)
|
|
|
|
| 456 |
# expected: [Ld, Lp, 8] or [8, Ld, Lp]
|
| 457 |
|
| 458 |
# --------------------------------------------------
|
|
@@ -514,42 +515,101 @@ def visualize_attention_and_ranges(
|
|
| 514 |
d_tokens = labels
|
| 515 |
d_indices = list(range(1, len(labels) + 1))
|
| 516 |
|
|
|
|
| 517 |
# --------------------------------------------------
|
| 518 |
-
# Top-K selection (two modes)
|
| 519 |
# --------------------------------------------------
|
| 520 |
if mode == "pair":
|
| 521 |
-
|
| 522 |
flat = att2d.reshape(-1)
|
| 523 |
k_eff = min(topk_pairs, flat.numel())
|
| 524 |
-
|
|
|
|
| 525 |
|
| 526 |
mask_top = torch.zeros_like(flat, dtype=torch.bool)
|
| 527 |
-
mask_top[
|
| 528 |
mask_top = mask_top.view_as(att2d)
|
| 529 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 530 |
else:
|
| 531 |
-
# --- Top-K
|
| 532 |
-
residue_score = att2d.sum(dim=0)
|
| 533 |
k_eff = min(topk_residues, residue_score.numel())
|
| 534 |
-
|
|
|
|
| 535 |
|
| 536 |
mask_top = torch.zeros_like(att2d, dtype=torch.bool)
|
| 537 |
-
mask_top[:, topk_res_idx] = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 538 |
|
| 539 |
# --------------------------------------------------
|
| 540 |
# Connected components (visual coherence)
|
| 541 |
# --------------------------------------------------
|
| 542 |
-
p_tokens_orig = p_tokens.copy()
|
| 543 |
-
d_tokens_orig = d_tokens.copy()
|
| 544 |
|
| 545 |
-
components = _connected_components_2d(mask_top)
|
| 546 |
|
| 547 |
-
ranges_html = _format_component_table(
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
)
|
| 553 |
|
| 554 |
|
| 555 |
# --------------------------------------------------
|
|
@@ -630,8 +690,42 @@ def visualize_attention_and_ranges(
|
|
| 630 |
<img src='data:image/png;base64,{png_b64}' />
|
| 631 |
</div>
|
| 632 |
"""
|
| 633 |
-
|
| 634 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 635 |
|
| 636 |
|
| 637 |
|
|
@@ -775,11 +869,30 @@ def inference_cb(prot_seq, drug_seq, head_choice, topk_choice, mode_choice):
|
|
| 775 |
else:
|
| 776 |
drug_seq_for_tokenizer = drug_seq_in
|
| 777 |
|
|
|
|
| 778 |
ltype = "selfies"
|
| 779 |
ligand_type_flag = "selfies"
|
| 780 |
raw_selfies = drug_seq_for_tokenizer
|
| 781 |
folder = f"{ptype}_selfies"
|
| 782 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 783 |
# ------------------------------
|
| 784 |
# Load encoders
|
| 785 |
# ------------------------------
|
|
@@ -846,7 +959,7 @@ def inference_cb(prot_seq, drug_seq, head_choice, topk_choice, mode_choice):
|
|
| 846 |
if mode_choice_clean == "Top-K residues":
|
| 847 |
mode = "residue"
|
| 848 |
topk_pairs = 1
|
| 849 |
-
topk_residues = min(
|
| 850 |
|
| 851 |
elif mode_choice_clean == "Top-K residues-atom pairs":
|
| 852 |
mode = "pair"
|
|
@@ -859,11 +972,10 @@ def inference_cb(prot_seq, drug_seq, head_choice, topk_choice, mode_choice):
|
|
| 859 |
topk_pairs = topk
|
| 860 |
topk_residues = 1
|
| 861 |
|
| 862 |
-
|
| 863 |
# ------------------------------
|
| 864 |
# Visualisation
|
| 865 |
# ------------------------------
|
| 866 |
-
|
| 867 |
model,
|
| 868 |
feats,
|
| 869 |
head_idx,
|
|
@@ -875,12 +987,12 @@ def inference_cb(prot_seq, drug_seq, head_choice, topk_choice, mode_choice):
|
|
| 875 |
ligand_type=ligand_type_flag,
|
| 876 |
raw_selfies=raw_selfies,
|
| 877 |
)
|
| 878 |
-
|
| 879 |
-
|
| 880 |
-
|
| 881 |
|
| 882 |
def clear_cb():
|
| 883 |
-
return "", "", "",
|
| 884 |
# protein, drug, table, heat, file, status
|
| 885 |
|
| 886 |
|
|
@@ -1014,6 +1126,12 @@ h1{
|
|
| 1014 |
margin-bottom:32px !important;
|
| 1015 |
}
|
| 1016 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1017 |
"""
|
| 1018 |
with gr.Blocks() as demo:
|
| 1019 |
|
|
@@ -1065,47 +1183,53 @@ with gr.Blocks() as demo:
|
|
| 1065 |
gr.HTML("""
|
| 1066 |
<ol style="font-size:1rem;line-height:1.6;margin-left:22px;">
|
| 1067 |
<li>
|
| 1068 |
-
<strong>Input
|
| 1069 |
-
The model
|
| 1070 |
-
and
|
|
|
|
| 1071 |
</li>
|
|
|
|
| 1072 |
<li>
|
| 1073 |
-
<strong>
|
| 1074 |
-
|
| 1075 |
-
(
|
| 1076 |
-
|
| 1077 |
</li>
|
|
|
|
| 1078 |
<li>
|
| 1079 |
<strong>Top-K mode:</strong>
|
| 1080 |
<ul style="margin-top:6px;">
|
| 1081 |
<li>
|
| 1082 |
-
<em>Top-K
|
| 1083 |
-
ranks individual protein
|
|
|
|
| 1084 |
</li>
|
| 1085 |
<li>
|
| 1086 |
<em>Top-K residues</em>:
|
| 1087 |
-
ranks protein residues by
|
|
|
|
| 1088 |
</li>
|
| 1089 |
</ul>
|
| 1090 |
</li>
|
|
|
|
| 1091 |
<li>
|
| 1092 |
-
<strong>
|
| 1093 |
-
|
| 1094 |
-
and a corresponding results table
|
| 1095 |
-
based on the selected Top-K mode.
|
| 1096 |
</li>
|
| 1097 |
</ol>
|
| 1098 |
""")
|
| 1099 |
|
| 1100 |
|
| 1101 |
# βββββββββββββββββββββββββββββββ
|
| 1102 |
-
# Inputs
|
| 1103 |
# βββββββββββββββββββββββββββββββ
|
| 1104 |
with gr.Row():
|
| 1105 |
with gr.Column(elem_classes=["card", "grid-2"]):
|
| 1106 |
|
|
|
|
|
|
|
|
|
|
| 1107 |
with gr.Column(elem_id="left"):
|
| 1108 |
-
|
| 1109 |
protein_seq = gr.Textbox(
|
| 1110 |
label="Protein structure-aware / FASTA sequence",
|
| 1111 |
lines=3,
|
|
@@ -1135,8 +1259,8 @@ with gr.Blocks() as demo:
|
|
| 1135 |
|
| 1136 |
gr.Examples(
|
| 1137 |
examples=[[
|
| 1138 |
-
"
|
| 1139 |
-
"[C][
|
| 1140 |
]],
|
| 1141 |
inputs=[protein_seq, drug_seq],
|
| 1142 |
label="Click to load an example",
|
|
@@ -1145,16 +1269,15 @@ with gr.Blocks() as demo:
|
|
| 1145 |
btn_load_example = gr.Button(
|
| 1146 |
"Load Example",
|
| 1147 |
elem_id="example-btn",
|
| 1148 |
-
variant="secondary"
|
| 1149 |
)
|
| 1150 |
-
|
| 1151 |
protein_seq.render()
|
| 1152 |
drug_seq.render()
|
|
|
|
| 1153 |
btn_extract = gr.Button(
|
| 1154 |
"Extract sequences",
|
| 1155 |
elem_id="extract-btn"
|
| 1156 |
)
|
| 1157 |
-
structure_file.render()
|
| 1158 |
|
| 1159 |
# ββββββββββββββββ
|
| 1160 |
# RIGHT PANEL
|
|
@@ -1162,7 +1285,7 @@ with gr.Blocks() as demo:
|
|
| 1162 |
with gr.Column(elem_id="right", elem_classes=["right-pane"]):
|
| 1163 |
|
| 1164 |
head_dd = gr.Dropdown(
|
| 1165 |
-
label="
|
| 1166 |
choices=INTERACTION_NAMES,
|
| 1167 |
value="Overall Interaction",
|
| 1168 |
interactive=True,
|
|
@@ -1200,8 +1323,10 @@ with gr.Blocks() as demo:
|
|
| 1200 |
# βββββββββββββββββββββββββββββββ
|
| 1201 |
with gr.Column(elem_classes=["card"]):
|
| 1202 |
status_box = gr.HTML(elem_id="status-box")
|
| 1203 |
-
|
| 1204 |
-
|
|
|
|
|
|
|
| 1205 |
|
| 1206 |
# βββββββββββββββββββββββββββββββ
|
| 1207 |
# Example Loader Callback
|
|
@@ -1209,7 +1334,7 @@ with gr.Blocks() as demo:
|
|
| 1209 |
def load_example_cb():
|
| 1210 |
return (
|
| 1211 |
"MTLSILVAHDLQRVIGFENQLPWHLPNDLKHVKKLSTGHTLVMGRKTFESIGKPLPNRRNVVLTSDTSFNVEGVDVIHSIEDIYQLPGHVFIFGGQTLFEEMIDKVDDMYITVIEGKFRGDTFFPPYTFEDWEVASSVEGKLDEKNTIPHTFLHLIRKK",
|
| 1212 |
-
"[C][
|
| 1213 |
)
|
| 1214 |
|
| 1215 |
# βββββββββββββββββββββββββββββββ
|
|
@@ -1223,22 +1348,14 @@ with gr.Blocks() as demo:
|
|
| 1223 |
|
| 1224 |
btn_extract.click(
|
| 1225 |
fn=extract_sequence_cb,
|
| 1226 |
-
inputs=[
|
| 1227 |
-
|
| 1228 |
-
drug_seq,
|
| 1229 |
-
protein_seq,
|
| 1230 |
-
],
|
| 1231 |
-
outputs=[
|
| 1232 |
-
protein_seq,
|
| 1233 |
-
drug_seq,
|
| 1234 |
-
status_box,
|
| 1235 |
-
],
|
| 1236 |
)
|
| 1237 |
|
| 1238 |
btn_infer.click(
|
| 1239 |
fn=inference_cb,
|
| 1240 |
inputs=[protein_seq, drug_seq, head_dd, top_k_dd, mode_dd],
|
| 1241 |
-
outputs=[
|
| 1242 |
)
|
| 1243 |
|
| 1244 |
clear_btn.click(
|
|
@@ -1247,13 +1364,14 @@ with gr.Blocks() as demo:
|
|
| 1247 |
outputs=[
|
| 1248 |
protein_seq,
|
| 1249 |
drug_seq,
|
| 1250 |
-
|
| 1251 |
-
output_heat,
|
| 1252 |
structure_file,
|
| 1253 |
status_box,
|
| 1254 |
],
|
| 1255 |
)
|
| 1256 |
|
|
|
|
|
|
|
| 1257 |
demo.launch(
|
| 1258 |
theme=gr.themes.Default(),
|
| 1259 |
css=css,
|
|
|
|
| 425 |
mode="residue":
|
| 426 |
- Aggregate attention over ligand dimension
|
| 427 |
- Rank residues by aggregated score
|
| 428 |
+
- Select Top-K residues (1–100)
|
| 429 |
- Default K = 1 (binding pocket discovery)
|
| 430 |
|
| 431 |
Notes
|
|
|
|
| 437 |
|
| 438 |
assert mode in {"pair", "residue"}
|
| 439 |
assert topk_pairs >= 1
|
| 440 |
+
assert 1 <= topk_residues <= 100
|
| 441 |
|
| 442 |
model.eval()
|
| 443 |
with torch.no_grad():
|
|
|
|
| 451 |
# --------------------------------------------------
|
| 452 |
# Forward
|
| 453 |
# --------------------------------------------------
|
| 454 |
+
prob, att_pd = model(p_emb, d_emb, p_mask, d_mask)
|
| 455 |
att = att_pd.squeeze(0)
|
| 456 |
+
prob = prob.item()
|
| 457 |
# expected: [Ld, Lp, 8] or [8, Ld, Lp]
|
| 458 |
|
| 459 |
# --------------------------------------------------
|
|
|
|
| 515 |
d_tokens = labels
|
| 516 |
d_indices = list(range(1, len(labels) + 1))
|
| 517 |
|
| 518 |
+
|
| 519 |
# --------------------------------------------------
|
| 520 |
+
# Top-K selection (two modes, STRICT RANKING)
|
| 521 |
# --------------------------------------------------
|
| 522 |
if mode == "pair":
|
| 523 |
+
|
| 524 |
flat = att2d.reshape(-1)
|
| 525 |
k_eff = min(topk_pairs, flat.numel())
|
| 526 |
+
|
| 527 |
+
topk_vals, topk_idx = torch.topk(flat, k=k_eff)
|
| 528 |
|
| 529 |
mask_top = torch.zeros_like(flat, dtype=torch.bool)
|
| 530 |
+
mask_top[topk_idx] = True
|
| 531 |
mask_top = mask_top.view_as(att2d)
|
| 532 |
|
| 533 |
+
rows = []
|
| 534 |
+
n_cols = att2d.size(1)
|
| 535 |
+
|
| 536 |
+
for rank, (val, linear_idx) in enumerate(zip(topk_vals, topk_idx), start=1):
|
| 537 |
+
i = (linear_idx // n_cols).item()
|
| 538 |
+
j = (linear_idx % n_cols).item()
|
| 539 |
+
|
| 540 |
+
rows.append(
|
| 541 |
+
f"<tr>"
|
| 542 |
+
f"<td style='border:1px solid #ddd;padding:6px'><strong>Top {rank}</strong></td>"
|
| 543 |
+
f"<td style='border:1px solid #ddd;padding:6px'>Protein: <strong>{j+1}:{p_tokens[j]}</strong></td>"
|
| 544 |
+
f"<td style='border:1px solid #ddd;padding:6px'>Ligand: <strong>{i+1}:{d_tokens[i]}</strong></td>"
|
| 545 |
+
f"<td style='border:1px solid #ddd;padding:6px'>Score: <strong>{val.item():.6f}</strong></td>"
|
| 546 |
+
f"</tr>"
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
ranges_html = (
|
| 550 |
+
"<h4 style='margin:12px 0 6px'>Top-K Interaction Pairs (ranked by attention score)</h4>"
|
| 551 |
+
"<table style='border-collapse:collapse;margin:6px 0 16px;width:100%'>"
|
| 552 |
+
"<thead><tr style='background:#f5f5f5'>"
|
| 553 |
+
"<th style='border:1px solid #ddd;padding:6px'>Rank</th>"
|
| 554 |
+
"<th style='border:1px solid #ddd;padding:6px'>Protein</th>"
|
| 555 |
+
"<th style='border:1px solid #ddd;padding:6px'>Ligand</th>"
|
| 556 |
+
"<th style='border:1px solid #ddd;padding:6px'>Attention Score</th>"
|
| 557 |
+
"</tr></thead>"
|
| 558 |
+
f"<tbody>{''.join(rows)}</tbody></table>"
|
| 559 |
+
)
|
| 560 |
+
|
| 561 |
else:
|
| 562 |
+
# --- STRICT Top-K residue ranking ---
|
| 563 |
+
residue_score = att2d.sum(dim=0)
|
| 564 |
k_eff = min(topk_residues, residue_score.numel())
|
| 565 |
+
|
| 566 |
+
topk_vals, topk_res_idx = torch.topk(residue_score, k=k_eff)
|
| 567 |
|
| 568 |
mask_top = torch.zeros_like(att2d, dtype=torch.bool)
|
| 569 |
+
mask_top[:, topk_res_idx] = True
|
| 570 |
+
|
| 571 |
+
rows = []
|
| 572 |
+
|
| 573 |
+
for rank, (val, j) in enumerate(zip(topk_vals, topk_res_idx), start=1):
|
| 574 |
+
j = j.item()
|
| 575 |
+
|
| 576 |
+
rows.append(
|
| 577 |
+
f"<tr>"
|
| 578 |
+
f"<td style='border:1px solid #ddd;padding:6px'><strong>Top {rank}</strong></td>"
|
| 579 |
+
f"<td style='border:1px solid #ddd;padding:6px'>"
|
| 580 |
+
f"Protein residue: <strong>{j+1}:{p_tokens[j]}</strong>"
|
| 581 |
+
f"</td>"
|
| 582 |
+
f"<td style='border:1px solid #ddd;padding:6px'>"
|
| 583 |
+
f"Aggregated Score: <strong>{val.item():.6f}</strong>"
|
| 584 |
+
f"</td>"
|
| 585 |
+
f"</tr>"
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
ranges_html = (
|
| 589 |
+
"<h4 style='margin:12px 0 6px'>Top-K Residues (ranked by aggregated attention)</h4>"
|
| 590 |
+
"<table style='border-collapse:collapse;margin:6px 0 16px;width:100%'>"
|
| 591 |
+
"<thead><tr style='background:#f5f5f5'>"
|
| 592 |
+
"<th style='border:1px solid #ddd;padding:6px'>Rank</th>"
|
| 593 |
+
"<th style='border:1px solid #ddd;padding:6px'>Protein Residue</th>"
|
| 594 |
+
"<th style='border:1px solid #ddd;padding:6px'>Aggregated Score</th>"
|
| 595 |
+
"</tr></thead>"
|
| 596 |
+
f"<tbody>{''.join(rows)}</tbody></table>"
|
| 597 |
+
)
|
| 598 |
|
| 599 |
# --------------------------------------------------
|
| 600 |
# Connected components (visual coherence)
|
| 601 |
# --------------------------------------------------
|
| 602 |
+
# p_tokens_orig = p_tokens.copy()
|
| 603 |
+
# d_tokens_orig = d_tokens.copy()
|
| 604 |
|
| 605 |
+
# components = _connected_components_2d(mask_top)
|
| 606 |
|
| 607 |
+
# ranges_html = _format_component_table(
|
| 608 |
+
# components,
|
| 609 |
+
# p_tokens_orig,
|
| 610 |
+
# d_tokens_orig,
|
| 611 |
+
# mode=mode,
|
| 612 |
+
# )
|
| 613 |
|
| 614 |
|
| 615 |
# --------------------------------------------------
|
|
|
|
| 690 |
<img src='data:image/png;base64,{png_b64}' />
|
| 691 |
</div>
|
| 692 |
"""
|
| 693 |
+
# ------------------------------
|
| 694 |
+
# Probability display card
|
| 695 |
+
# ------------------------------
|
| 696 |
+
if prob >= 0.8:
|
| 697 |
+
bg = "#ecfdf5"
|
| 698 |
+
border = "#10b981"
|
| 699 |
+
label = "High binding confidence"
|
| 700 |
+
elif prob >= 0.4:
|
| 701 |
+
bg = "#eff6ff"
|
| 702 |
+
border = "#3b82f6"
|
| 703 |
+
label = "Moderate binding confidence"
|
| 704 |
+
else:
|
| 705 |
+
bg = "#fef2f2"
|
| 706 |
+
border = "#ef4444"
|
| 707 |
+
label = "Low binding confidence"
|
| 708 |
+
|
| 709 |
+
prob_html = f"""
|
| 710 |
+
<div style='margin:10px 0 18px;
|
| 711 |
+
padding:14px 16px;
|
| 712 |
+
border-left:5px solid {border};
|
| 713 |
+
border-radius:12px;
|
| 714 |
+
background:{bg};
|
| 715 |
+
font-size:1rem'>
|
| 716 |
+
<div style='font-weight:600;margin-bottom:4px'>
|
| 717 |
+
Predicted Binding Probability
|
| 718 |
+
</div>
|
| 719 |
+
<div style='font-size:1.4rem;font-weight:700'>
|
| 720 |
+
{prob:.4f}
|
| 721 |
+
</div>
|
| 722 |
+
<div style='font-size:0.85rem;color:#64748b;margin-top:4px'>
|
| 723 |
+
{label}
|
| 724 |
+
</div>
|
| 725 |
+
</div>
|
| 726 |
+
"""
|
| 727 |
+
|
| 728 |
+
return prob_html, ranges_html, heat_html
|
| 729 |
|
| 730 |
|
| 731 |
|
|
|
|
| 869 |
else:
|
| 870 |
drug_seq_for_tokenizer = drug_seq_in
|
| 871 |
|
| 872 |
+
# Force a unified ligand type: the ligand is always handled as SELFIES
|
| 873 |
ltype = "selfies"
|
| 874 |
ligand_type_flag = "selfies"
|
| 875 |
raw_selfies = drug_seq_for_tokenizer
|
| 876 |
folder = f"{ptype}_selfies"
|
| 877 |
|
| 878 |
+
|
| 879 |
+
# # Ligand normalisation: always tokenise as SELFIES
|
| 880 |
+
# if ltype == "smiles":
|
| 881 |
+
# conv = smiles_to_selfies(drug_seq_in)
|
| 882 |
+
# if conv is None:
|
| 883 |
+
# return (
|
| 884 |
+
# "<p style='color:red'>SMILES→SELFIES conversion failed. "
|
| 885 |
+
# "The SMILES appears invalid.</p>",
|
| 886 |
+
# "",
|
| 887 |
+
# )
|
| 888 |
+
# drug_seq_for_tokenizer = conv
|
| 889 |
+
# ligand_type_flag = "selfies"
|
| 890 |
+
# else:
|
| 891 |
+
# drug_seq_for_tokenizer = drug_seq_in
|
| 892 |
+
# ligand_type_flag = "selfies"
|
| 893 |
+
|
| 894 |
+
# raw_selfies = drug_seq_for_tokenizer if ligand_type_flag == "selfies" else None
|
| 895 |
+
|
| 896 |
# ------------------------------
|
| 897 |
# Load encoders
|
| 898 |
# ------------------------------
|
|
|
|
| 959 |
if mode_choice_clean == "Top-K residues":
|
| 960 |
mode = "residue"
|
| 961 |
topk_pairs = 1
|
| 962 |
+
topk_residues = min(100, topk)
|
| 963 |
|
| 964 |
elif mode_choice_clean == "Top-K residues-atom pairs":
|
| 965 |
mode = "pair"
|
|
|
|
| 972 |
topk_pairs = topk
|
| 973 |
topk_residues = 1
|
| 974 |
|
|
|
|
| 975 |
# ------------------------------
|
| 976 |
# Visualisation
|
| 977 |
# ------------------------------
|
| 978 |
+
prob_html, table_html, heat_html = visualize_attention_and_ranges(
|
| 979 |
model,
|
| 980 |
feats,
|
| 981 |
head_idx,
|
|
|
|
| 987 |
ligand_type=ligand_type_flag,
|
| 988 |
raw_selfies=raw_selfies,
|
| 989 |
)
|
| 990 |
+
|
| 991 |
+
full_html = prob_html + table_html + heat_html  # enforce top-to-bottom order: probability card, table, heat map
|
| 992 |
+
return full_html
|
| 993 |
|
| 994 |
def clear_cb():
|
| 995 |
+
return "", "", "", None, ""
|
| 996 |
# protein, drug, full output, file, status
|
| 997 |
|
| 998 |
|
|
|
|
| 1126 |
margin-bottom:32px !important;
|
| 1127 |
}
|
| 1128 |
|
| 1129 |
+
#example-btn {
|
| 1130 |
+
background: #979ea8 !important;
|
| 1131 |
+
color: #1e293b !important;
|
| 1132 |
+
}
|
| 1133 |
+
|
| 1134 |
+
|
| 1135 |
"""
|
| 1136 |
with gr.Blocks() as demo:
|
| 1137 |
|
|
|
|
| 1183 |
gr.HTML("""
|
| 1184 |
<ol style="font-size:1rem;line-height:1.6;margin-left:22px;">
|
| 1185 |
<li>
|
| 1186 |
+
<strong>Input formats:</strong>
|
| 1187 |
+
The model accepts <em>structure-aware (SA)</em> or <em>FASTA</em> protein sequences,
|
| 1188 |
+
and <em>SMILES</em> or <em>SELFIES</em> representations for ligands.
|
| 1189 |
+
For SA mode, <code>.pdb</code> or <code>.cif</code> files can be uploaded directly.
|
| 1190 |
</li>
|
| 1191 |
+
|
| 1192 |
<li>
|
| 1193 |
+
<strong>Interaction type selection:</strong>
|
| 1194 |
+
Choose the desired non-covalent interaction type
|
| 1195 |
+
(e.g., overall interaction or specific physicochemical channels)
|
| 1196 |
+
to visualise token-level binding patterns.
|
| 1197 |
</li>
|
| 1198 |
+
|
| 1199 |
<li>
|
| 1200 |
<strong>Top-K mode:</strong>
|
| 1201 |
<ul style="margin-top:6px;">
|
| 1202 |
<li>
|
| 1203 |
+
<em>Top-K residue–atom pairs</em>:
|
| 1204 |
+
ranks individual protein residue–ligand atom pairs
|
| 1205 |
+
according to their attention scores.
|
| 1206 |
</li>
|
| 1207 |
<li>
|
| 1208 |
<em>Top-K residues</em>:
|
| 1209 |
+
ranks protein residues by aggregating attention
|
| 1210 |
+
across all ligand tokens.
|
| 1211 |
</li>
|
| 1212 |
</ul>
|
| 1213 |
</li>
|
| 1214 |
+
|
| 1215 |
<li>
|
| 1216 |
+
<strong>Output:</strong>
|
| 1217 |
+
The demo system reports a predicted binding probability, a ranked Top-K interaction table, and a token-level attention heat map.
|
|
|
|
|
|
|
| 1218 |
</li>
|
| 1219 |
</ol>
|
| 1220 |
""")
|
| 1221 |
|
| 1222 |
|
| 1223 |
# ───────────────────────────────
|
| 1224 |
+
# Inputs + Controls
|
| 1225 |
# ───────────────────────────────
|
| 1226 |
with gr.Row():
|
| 1227 |
with gr.Column(elem_classes=["card", "grid-2"]):
|
| 1228 |
|
| 1229 |
+
# ────────────────
|
| 1230 |
+
# LEFT PANEL
|
| 1231 |
+
# ────────────────
|
| 1232 |
with gr.Column(elem_id="left"):
|
|
|
|
| 1233 |
protein_seq = gr.Textbox(
|
| 1234 |
label="Protein structure-aware / FASTA sequence",
|
| 1235 |
lines=3,
|
|
|
|
| 1259 |
|
| 1260 |
gr.Examples(
|
| 1261 |
examples=[[
|
| 1262 |
+
"SLALSLTADQMVSALLDAEPPILYSEYDPTRPFSEASMMGLLTNLADRELVHMINWAKRVPGFVDLTSHDQVHLLECAWLEILMIGLVWRSMEHPGKLLFAPNLLLDRNQGKCVEGMVEIFDMLLATSSRFRMMNLQGEEFVCLKSIILLNSGVYTFLSSTLKSLEEKDHIHRVLDKITDTLIHLMAKAGLTLQQQHQRLAQLLLILSHIRHMSNKGMEHLYSMKCKNVVPSYDLLLEMLDA",
|
| 1263 |
+
"[C][=C][C][=Branch2][Branch1][#C][=C][C][=C][Ring1][=Branch1][C][=C][Branch2][Ring2][#Branch2][C@H1][C@@H1][Branch1][Branch2][C][C@@H1][Ring1][=Branch1][O][Ring1][Branch1][S][=Branch1][C][=O][=Branch1][C][=O][N][Branch1][#Branch2][C][C][Branch1][C][F][Branch1][C][F][F][C][=C][C][=C][Branch1][Branch1][C][=C][Ring1][=Branch1][Cl][C][=C][C][=C][Branch1][Branch1][C][=C][Ring1][=Branch1][O][O]"
|
| 1264 |
]],
|
| 1265 |
inputs=[protein_seq, drug_seq],
|
| 1266 |
label="Click to load an example",
|
|
|
|
| 1269 |
btn_load_example = gr.Button(
|
| 1270 |
"Load Example",
|
| 1271 |
elem_id="example-btn",
|
| 1272 |
+
# variant="secondary"
|
| 1273 |
)
|
|
|
|
| 1274 |
protein_seq.render()
|
| 1275 |
drug_seq.render()
|
| 1276 |
+
structure_file.render()
|
| 1277 |
btn_extract = gr.Button(
|
| 1278 |
"Extract sequences",
|
| 1279 |
elem_id="extract-btn"
|
| 1280 |
)
|
|
|
|
| 1281 |
|
| 1282 |
# ββββββββββββββββ
|
| 1283 |
# RIGHT PANEL
|
|
|
|
| 1285 |
with gr.Column(elem_id="right", elem_classes=["right-pane"]):
|
| 1286 |
|
| 1287 |
head_dd = gr.Dropdown(
|
| 1288 |
+
label="Non-covalent interaction type/Overall",
|
| 1289 |
choices=INTERACTION_NAMES,
|
| 1290 |
value="Overall Interaction",
|
| 1291 |
interactive=True,
|
|
|
|
| 1323 |
# βββββββββββββββββββββββββββββββ
|
| 1324 |
with gr.Column(elem_classes=["card"]):
|
| 1325 |
status_box = gr.HTML(elem_id="status-box")
|
| 1326 |
+
output_full = gr.HTML(elem_id="result-full")
|
| 1327 |
+
|
| 1328 |
+
|
| 1329 |
+
|
| 1330 |
|
| 1331 |
# βββββββββββββββββββββββββββββββ
|
| 1332 |
# Example Loader Callback
|
|
|
|
| 1334 |
def load_example_cb():
|
| 1335 |
return (
|
| 1336 |
"MTLSILVAHDLQRVIGFENQLPWHLPNDLKHVKKLSTGHTLVMGRKTFESIGKPLPNRRNVVLTSDTSFNVEGVDVIHSIEDIYQLPGHVFIFGGQTLFEEMIDKVDDMYITVIEGKFRGDTFFPPYTFEDWEVASSVEGKLDEKNTIPHTFLHLIRKK",
|
| 1337 |
+
"[C][=C][C][=Branch2][Branch1][#C][=C][C][=C][Ring1][=Branch1][C][=C][Branch2][Ring2][#Branch2][C@H1][C@@H1][Branch1][Branch2][C][C@@H1][Ring1][=Branch1][O][Ring1][Branch1][S][=Branch1][C][=O][=Branch1][C][=O][N][Branch1][#Branch2][C][C][Branch1][C][F][Branch1][C][F][F][C][=C][C][=C][Branch1][Branch1][C][=C][Ring1][=Branch1][Cl][C][=C][C][=C][Branch1][Branch1][C][=C][Ring1][=Branch1][O][O]"
|
| 1338 |
)
|
| 1339 |
|
| 1340 |
# βββββββββββββββββββββββββββββββ
|
|
|
|
| 1348 |
|
| 1349 |
btn_extract.click(
|
| 1350 |
fn=extract_sequence_cb,
|
| 1351 |
+
inputs=[structure_file, drug_seq, protein_seq],
|
| 1352 |
+
outputs=[protein_seq, drug_seq, status_box],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1353 |
)
|
| 1354 |
|
| 1355 |
btn_infer.click(
|
| 1356 |
fn=inference_cb,
|
| 1357 |
inputs=[protein_seq, drug_seq, head_dd, top_k_dd, mode_dd],
|
| 1358 |
+
outputs=[output_full],
|
| 1359 |
)
|
| 1360 |
|
| 1361 |
clear_btn.click(
|
|
|
|
| 1364 |
outputs=[
|
| 1365 |
protein_seq,
|
| 1366 |
drug_seq,
|
| 1367 |
+
output_full,
|
|
|
|
| 1368 |
structure_file,
|
| 1369 |
status_box,
|
| 1370 |
],
|
| 1371 |
)
|
| 1372 |
|
| 1373 |
+
|
| 1374 |
+
|
| 1375 |
demo.launch(
|
| 1376 |
theme=gr.themes.Default(),
|
| 1377 |
css=css,
|