Zhaohan-Meng committed on
Commit
3df029f
·
verified ·
1 Parent(s): fad71d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +229 -258
app.py CHANGED
@@ -2,8 +2,8 @@
2
  import gradio_client.utils as _gc_utils
3
 
4
  # back up originals
5
- _orig_get_type = _gc_utils.get_type
6
- _orig_json2py = _gc_utils._json_schema_to_python_type
7
 
8
  def _patched_get_type(schema):
9
  # treat any boolean schema as if it were an empty dict
@@ -54,6 +54,7 @@ from utils.foldseek_util import get_struc_seq
54
 
55
  three2one = {k.upper(): v for k, v in IUPACData.protein_letters_3to1.items()}
56
  three2one.update({"MSE": "M", "SEC": "C", "PYL": "K"})
 
57
  def simple_seq_from_structure(path: str) -> str:
58
  parser = MMCIFParser(QUIET=True) if path.endswith(".cif") else PDBParser(QUIET=True)
59
  structure = parser.get_structure("P", path)
@@ -69,7 +70,7 @@ def smiles_to_selfies(smiles: str) -> Optional[str]:
69
  if mol is None:
70
  return None
71
  return selfies.encoder(smiles)
72
- except:
73
  return None
74
 
75
  def parse_config():
@@ -132,20 +133,23 @@ def get_case_feature(model, loader):
132
  p_ids.cpu(), d_ids.cpu(),
133
  p_mask.cpu(), d_mask.cpu(), None)]
134
 
135
-
136
  # ─────────────── visualisation ───────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
138
  """
139
- Render a Protein → Drug cross-attention heat-map and, optionally, a
140
- Top-30 protein-residue table for a chosen drug-token index.
141
- The token index shown on the x-axis (and accepted via *drug_idx*) is **the
142
- position of that token in the *original* drug sequence**, *after* the
143
- tokeniser but *before* any pruning or truncation (1-based in the labels,
144
- 0-based for the function argument).
145
- Returns
146
- -------
147
- html : str
148
- Base64-embedded PNG heat-map (+ optional HTML table).
149
  """
150
  model.eval()
151
  with torch.no_grad():
@@ -156,154 +160,75 @@ def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
156
 
157
  # ── forward pass: Protein → Drug attention (B, n_p, n_d) ───────────────
158
  _, att_pd = model(p_emb, d_emb, p_mask, d_mask)
159
- attn = att_pd.squeeze(0).cpu() # (n_p, n_d)
160
 
161
  # ── decode tokens (skip special symbols) ────────────────────────────────
162
  def clean_ids(ids, tokenizer):
163
  toks = tokenizer.convert_ids_to_tokens(ids.tolist())
164
- return [t for t in toks if t not in tokenizer.all_special_tokens]
165
 
166
- # ── decode full sequences + record 1-based indices ──────────────────
167
  p_tokens_full = clean_ids(p_ids[0], prot_tokenizer)
168
  p_indices_full = list(range(1, len(p_tokens_full) + 1))
169
-
170
  d_tokens_full = clean_ids(d_ids[0], drug_tokenizer)
171
  d_indices_full = list(range(1, len(d_tokens_full) + 1))
172
 
173
- # ── safety cut-off to match attn mat size ───────────────────────────────
174
- p_tokens = p_tokens_full[: attn.size(0)]
175
- p_indices_full = p_indices_full[: attn.size(0)]
176
- d_tokens_full = d_tokens_full[: attn.size(1)]
177
- d_indices_full = d_indices_full[: attn.size(1)]
178
- attn = attn[: len(p_tokens_full), : len(d_tokens_full)]
179
 
180
  orig_attn = attn.clone()
 
181
  # ── adaptive sparsity pruning ───────────────────────────────────────────
182
- thr = attn.max().item() * 0.05
183
- row_keep = (attn.max(dim=1).values > thr)
184
- col_keep = (attn.max(dim=0).values > thr)
185
 
186
- if row_keep.sum() < 3:
187
- row_keep[:] = True
188
- if col_keep.sum() < 3:
189
- col_keep[:] = True
190
 
191
- attn = attn[row_keep][:, col_keep]
192
- p_tokens = [tok for keep, tok in zip(row_keep, p_tokens) if keep]
193
- p_indices = [idx for keep, idx in zip(row_keep, p_indices_full) if keep]
194
- d_tokens = [tok for keep, tok in zip(col_keep, d_tokens_full) if keep]
195
- d_indices = [idx for keep, idx in zip(col_keep, d_indices_full) if keep]
196
 
197
  # ── cap column count at 150 for readability ─────────────────────────────
198
  if attn.size(1) > 150:
199
- topc = torch.topk(attn.sum(0), k=150).indices
200
- attn = attn[:, topc]
201
- d_tokens = [d_tokens [i] for i in topc]
202
- d_indices = [d_indices[i] for i in topc]
203
 
204
- # ── draw heat-map ───────────────────────────────────────────────────────
205
  x_labels = [f"{idx}:{tok}" for idx, tok in zip(d_indices, d_tokens)]
206
  y_labels = [f"{idx}:{tok}" for idx, tok in zip(p_indices, p_tokens)]
207
 
208
-
209
- fig_w = min(22, max(8, len(x_labels) * 0.6)) # ~0.6″ per column
210
- fig_h = min(24, max(6, len(p_tokens) * 0.8))
211
 
212
  fig, ax = plt.subplots(figsize=(fig_w, fig_h))
213
- im = ax.imshow(attn.numpy(), aspect="auto",
214
- cmap=cm.viridis, interpolation="nearest")
215
-
216
- ax.set_title("Protein → Drug Attention", pad=8, fontsize=10)
217
 
 
218
  ax.set_xticks(range(len(x_labels)))
219
- ax.set_xticklabels(x_labels, rotation=90, fontsize=8,
220
- ha="center", va="center")
221
- ax.tick_params(axis="x", top=True, bottom=False,
222
- labeltop=True, labelbottom=False, pad=27)
223
 
224
  ax.set_yticks(range(len(y_labels)))
225
  ax.set_yticklabels(y_labels, fontsize=7)
226
- ax.tick_params(axis="y", top=True, bottom=False,
227
- labeltop=True, labelbottom=False,
228
- pad=10)
229
 
230
  fig.colorbar(im, fraction=0.026, pad=0.01)
231
  fig.tight_layout()
232
 
233
- buf = io.BytesIO()
234
- fig.savefig(buf, format="png", dpi=140)
235
- plt.close(fig)
236
- html = f'<img src="data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" />'
237
-
238
- # ───────────────────── Top-30 table ─────────────────────
239
- table_html = ""
240
- if drug_idx is not None and 0 <= drug_idx < orig_attn.size(1):
241
- # map original 0-based drug_idx β†’ current column position
242
- if (drug_idx + 1) in d_indices:
243
- col_pos = d_indices.index(drug_idx + 1)
244
- elif 0 <= drug_idx < len(d_tokens):
245
- col_pos = drug_idx
246
- else:
247
- col_pos = None
248
-
249
- if col_pos is not None:
250
- col_vec = attn[:, col_pos]
251
- topk = torch.topk(col_vec, k=min(30, len(col_vec))).indices.tolist()
252
-
253
- rank_hdr = "".join(f"<th>{r+1}</th>" for r in range(len(topk)))
254
- res_row = "".join(f"<td>{p_tokens[i]}</td>" for i in topk)
255
- pos_row = "".join(f"<td>{p_indices[i]}</td>"for i in topk)
256
-
257
- drug_tok_text = d_tokens_full[col_pos]
258
- orig_idx = d_indices_full[col_pos]
259
-
260
- # 1) build the header row: leading “Rank”, then 1…30
261
- header_cells = (
262
- "<th style='border:1px solid #ccc; padding:6px; "
263
- "background:#f7f7f7; text-align:center;'>Rank</th>"
264
- + "".join(
265
- f"<th style='border:1px solid #ccc; padding:6px; "
266
- f"background:#f7f7f7; text-align:center'>{r+1}</th>"
267
- for r in range(len(topk))
268
- )
269
- )
270
-
271
- # 2) build the residue row: leading “Residue”, then the residue tokens
272
- residue_cells = (
273
- "<th style='border:1px solid #ccc; padding:6px; "
274
- "background:#f7f7f7; text-align:center;'>Residue</th>"
275
- + "".join(
276
- f"<td style='border:1px solid #ccc; padding:6px; "
277
- f"text-align:center'>{p_tokens_full[i]}</td>"
278
- for i in topk
279
- )
280
- )
281
-
282
- # 3) build the position row: leading “Position”, then the residue positions
283
- position_cells = (
284
- "<th style='border:1px solid #ccc; padding:6px; "
285
- "background:#f7f7f7; text-align:center;'>Position</th>"
286
- + "".join(
287
- f"<td style='border:1px solid #ccc; padding:6px; "
288
- f"text-align:center'>{p_indices_full[i]}</td>"
289
- for i in topk
290
- )
291
- )
292
-
293
- # 4) assemble your table_html
294
- table_html = (
295
- f"<h4 style='margin-bottom:12px'>"
296
- f"Drug atom #{orig_idx} <code>{drug_tok_text}</code> → Top-30 Protein residues"
297
- f"</h4>"
298
- f"<table style='border-collapse:collapse; margin:0 auto 24px;'>"
299
- f"<tr>{header_cells}</tr>"
300
- f"<tr>{residue_cells}</tr>"
301
- f"<tr>{position_cells}</tr>"
302
- f"</table>"
303
- )
304
-
305
  buf_png = io.BytesIO()
306
- fig.savefig(buf_png, format="png", dpi=140)
307
  buf_png.seek(0)
308
 
309
  buf_pdf = io.BytesIO()
@@ -315,24 +240,73 @@ def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
315
  pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()
316
 
317
  html_heat = (
318
- f"<div style='position: relative; width: 100%;'>"
319
- # the PDF button, absolutely positioned
320
  f"<a href='data:application/pdf;base64,{pdf_b64}' download='attention_heatmap.pdf' "
321
- "style='position: absolute; top: 12px; right: 12px; "
322
- "background: var(--primary); color: #fff; "
323
- "padding: 8px 16px; border-radius: 6px; "
324
- "font-size: 0.9rem; font-weight: 500; "
325
- "text-decoration: none;'>"
326
- "Download PDF"
327
- "</a>"
328
- # the clickable heat‐map image
329
  f"<a href='data:image/png;base64,{png_b64}' target='_blank' title='Click to enlarge'>"
330
  f"<img src='data:image/png;base64,{png_b64}' "
331
- "style='display: block; width: 100%; height: auto; cursor: zoom-in;'/>"
332
  "</a>"
333
  "</div>"
334
  )
335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  return table_html + html_heat
337
 
338
  # ───── Gradio Callbacks ─────────────────────────────────────────
@@ -367,116 +341,127 @@ def inference_cb(prot_seq, drug_seq, atom_idx):
367
  return visualize_attention(model, feats, int(atom_idx)-1 if atom_idx else None)
368
 
369
  def clear_cb():
370
- return None, "", "", None, ""
371
 
372
- # ───── Gradio Interface Definition ───────────────────────────────
373
 
374
  css = """
375
  :root {
376
- --bg: #f3f4f6;
377
- --card: #ffffff;
378
- --border: #e5e7eb;
379
- --primary: #6366f1;
380
- --primary-dark: #4f46e5;
381
- --text: #111827;
 
 
 
382
  }
383
- * { box-sizing: border-box; margin: 0; padding: 0; }
384
- body { background: var(--bg); color: var(--text); font-family: Inter,system-ui,Arial,sans-serif; }
385
- h1 { font-family: Poppins,Inter,sans-serif; font-weight: 600; font-size: 2rem; text-align: center; margin: 24px 0; }
386
- button, .gr-button { font-family: Inter,sans-serif; font-weight: 600; }
387
- #project-links { text-align: center; margin-bottom: 32px; }
388
- #project-links .gr-button { margin: 0 8px; min-width: 160px; }
389
- #project-links .gr-button:nth-child(1) { background: #10b981; }
390
- #project-links .gr-button:nth-child(2) { background: #ef4444; }
391
- #project-links .gr-button:nth-child(3) { background: #3b82f6; }
392
- #project-links .gr-button:hover { opacity: 0.9; }
393
- .link-btn{display:inline-block;margin:0 8px;padding:10px 20px;border-radius:8px;
394
- color:white;font-weight:600;text-decoration:none;box-shadow:0 2px 6px rgba(0,0,0,0.12);
395
- transition:all .2s ease-in-out;}
396
- .link-btn:hover{opacity:.9;}
397
- .link-btn.project{background:linear-gradient(to right,#10b981,#059669);}
398
- .link-btn.arxiv {background:linear-gradient(to right,#ef4444,#dc2626);}
399
- .link-btn.github {background:linear-gradient(to right,#3b82f6,#2563eb);}
400
- /* make *all* gradio buttons a bit taller */
401
- .gr-button { min-height: 10px !important; }
402
- /* now target just our two big action buttons */
403
- #extract-btn, #inference-btn {
404
- width: 5px !important;
405
- min-height: 36px !important;
406
- margin-top: 12px !important;
407
  }
408
- /* and make clear button full width but shorter */
409
- #clear-btn {
410
- width: 10px !important;
411
- min-height: 36px !important;
412
- margin-top: 12px !important;
413
- }
414
- #input-card label {
415
- font-weight: 600 !important; /* make the text bold */
416
- color: var(--text) !important; /* use your standard text color */
417
- }
418
- .card {
419
- background: var(--card);
420
- border: 1px solid var(--border);
421
- border-radius: 12px;
422
- padding: 24px;
423
- max-width: 1000px;
424
- margin: 0 auto 32px;
425
- box-shadow: 0 2px 6px rgba(0,0,0,0.05);
426
- }
427
- #guidelines-card h2 {
428
- font-size: 1.4rem;
429
- margin-bottom: 16px;
430
- text-align: center;
431
  }
432
- #guidelines-card ol {
433
- margin-left: 20px;
434
- line-height: 1.6;
435
- font-size: 1rem;
 
 
 
 
 
 
436
  }
437
- #input-card .gr-row, #input-card .gr-cols {
438
- gap: 16px;
 
 
 
 
439
  }
440
- #input-card .gr-button {
441
- flex: 1;
 
 
 
 
 
 
 
 
 
 
442
  }
443
- #output-card {
444
- padding-top: 0;
 
 
 
 
 
 
445
  }
446
  """
447
 
448
- # ───── Gradio Interface ─────────────────────────────────────
449
  with gr.Blocks(theme=Soft(primary_hue="indigo", neutral_hue="slate"), css=css) as demo:
450
- gr.Markdown("<h1>Token-level Visualiser for Drug-Target Interaction</h1>")
 
451
 
452
- # Project links with SVG icons
453
  gr.HTML("""
454
- <div style="text-align:center;margin-bottom:32px;">
455
- <a class="link-btn project" href="https://zhaohanm.github.io/FusionDTI.github.io/" target="_blank">
456
- <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor"
457
- viewBox="0 0 16 16"><path d="M8 0a8 8 0 1 0 8 8A8.009 8.009 0 0 0 8 0ZM4.5 8a3.5 3.5 0 1 1 3.5 3.5A3.504 3.504 0 0 1 4.5 8Z"/></svg>
458
- Project Page
459
- </a>
460
- <a class="link-btn arxiv" href="https://arxiv.org/abs/2406.01651" target="_blank">
461
- <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor"
462
- viewBox="0 0 16 16"><path d="M4 1.5a.5.5 0 0 0-.5.5V3h9V2a.5.5 0 0 0-.5-.5h-8ZM2 4v9a2 2 0 0 0 2 2h8a2 2 0 0 0 2-2V4H2Z"/></svg>
463
- ArXiv: 2406.01651
464
- </a>
465
- <a class="link-btn github" href="https://github.com/ZhaohanM/FusionDTI" target="_blank">
466
- <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor"
467
- viewBox="0 0 16 16"><path d="M8 .198a8 8 0 0 0-2.53 15.598c.4.074.547-.174.547-.385v-1.352c-2.23.484-2.7-1.073-2.7-1.073a2.132 2.132 0 0 0-.9-1.184c-.735-.503.056-.493.056-.493a1.688 1.688 0 0 1 1.232.83 1.707 1.707 0 0 0 2.34.667 1.706 1.706 0 0 1 .509-1.073c-1.78-.202-3.644-.89-3.644-3.962A3.106 3.106 0 0 1 5.066 5.47a2.882 2.882 0 0 1 .078-2.13s.672-.215 2.2.82a7.634 7.634 0 0 1 4.004 0c1.528-1.035 2.2-.82 2.2-.82a2.882 2.882 0 0 1 .078 2.13 3.106 3.106 0 0 1 .966 2.152c0 3.08-1.866 3.756-3.648 3.954a1.918 1.918 0 0 1 .547 1.482v2.197c0 .214.146.462.55.383A8 8 0 0 0 8 .198Z"/></svg>
468
- GitHub Repo
469
- </a>
470
- </div>
471
- """)
 
 
 
 
 
 
472
 
473
  # ───────────── Guidelines Card ─────────────
474
-
475
  gr.HTML(
476
  """
477
- <div class="card" style="margin-bottom:24px">
478
- <h2 style="font-size:1.2rem;margin-bottom:14px">Guidelines for User</h2>
479
- <ul style="font-size:1rem; margin-left:18px;line-height:1.55;list-style:decimal;">
480
  <li><strong>Convert protein structure into a structure-aware sequence:</strong>
481
  Upload a <code>.pdb</code> or <code>.cif</code> file. A structure-aware
482
  sequence will be generated using
@@ -485,26 +470,25 @@ with gr.Blocks(theme=Soft(primary_hue="indigo", neutral_hue="slate"), css=css) a
485
  <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a> or the
486
  <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>.</li>
487
  <li><strong>If you only have an amino acid sequence or a UniProt ID,</strong>
488
- you must first visit the
489
  <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>
490
  or <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a>
491
- to search and download the corresponding <code>.cif</code> or <code>.pdb</code> file.</li>
492
- <li><strong>Drug input supports both SELFIES and SMILES:</strong><br>
493
- You can enter a SELFIES string directly, or paste a SMILES string.
494
- SMILES will be automatically converted to SELFIES using
495
  <a href="https://github.com/aspuru-guzik-group/selfies" target="_blank">SELFIES encoder</a>.
496
  If conversion fails, a red error message will be displayed.</li>
497
- <li>Optionally enter a <strong>1-based</strong> drug atom or substructure index
498
  to highlight the Top-30 interacting protein residues.</li>
499
- <li>After inference, you can use the
500
- “Download PDF” link to export a high-resolution vector version.</li>
501
  </ul>
502
  </div>
503
- """)
504
-
 
505
  # ───────────── Input Card ─────────────
506
  with gr.Column(elem_id="input-card", elem_classes="card"):
507
-
508
  protein_seq = gr.Textbox(
509
  label="Protein Structure-aware Sequence",
510
  lines=3,
@@ -529,25 +513,12 @@ with gr.Blocks(theme=Soft(primary_hue="indigo", neutral_hue="slate"), css=css) a
529
 
530
  # ───────────── Action Buttons ─────────────
531
  with gr.Row(elem_id="action-buttons", equal_height=True):
532
- btn_extract = gr.Button(
533
- "Extract sequence",
534
- variant="primary",
535
- elem_id="extract-btn"
536
- )
537
- btn_infer = gr.Button(
538
- "Inference",
539
- variant="primary",
540
- elem_id="inference-btn"
541
- )
542
  with gr.Row():
543
- clear_btn = gr.Button(
544
- "Clear",
545
- variant="secondary",
546
- elem_classes="full-width",
547
- elem_id="clear-btn"
548
- )
549
 
550
- # ───────────── Output Visualization ─────────────
551
  output_html = gr.HTML(elem_id="result-html")
552
 
553
  # ───────────── Event Wiring ─────────────
@@ -562,7 +533,7 @@ with gr.Blocks(theme=Soft(primary_hue="indigo", neutral_hue="slate"), css=css) a
562
  outputs=[output_html]
563
  )
564
  clear_btn.click(
565
- fn=lambda: ("", "", None, "", None),
566
  inputs=[],
567
  outputs=[protein_seq, drug_seq, drug_idx, output_html, structure_file]
568
  )
 
2
  import gradio_client.utils as _gc_utils
3
 
4
  # back up originals
5
+ _orig_get_type = _gc_utils.get_type
6
+ _orig_json2py = _gc_utils._json_schema_to_python_type
7
 
8
  def _patched_get_type(schema):
9
  # treat any boolean schema as if it were an empty dict
 
54
 
55
  three2one = {k.upper(): v for k, v in IUPACData.protein_letters_3to1.items()}
56
  three2one.update({"MSE": "M", "SEC": "C", "PYL": "K"})
57
+
58
  def simple_seq_from_structure(path: str) -> str:
59
  parser = MMCIFParser(QUIET=True) if path.endswith(".cif") else PDBParser(QUIET=True)
60
  structure = parser.get_structure("P", path)
 
70
  if mol is None:
71
  return None
72
  return selfies.encoder(smiles)
73
+ except Exception:
74
  return None
75
 
76
  def parse_config():
 
133
  p_ids.cpu(), d_ids.cpu(),
134
  p_mask.cpu(), d_mask.cpu(), None)]
135
 
 
136
  # ─────────────── visualisation ───────────────────────────────────────────
137
+ def _safe_is_special(tokenizer, tok: str) -> bool:
138
+ # Some tokenisers expose different special token sets; fall back conservatively.
139
+ special_sets = []
140
+ if hasattr(tokenizer, "all_special_tokens"):
141
+ special_sets.append(set(tokenizer.all_special_tokens))
142
+ if hasattr(tokenizer, "special_tokens_map"):
143
+ special_sets.extend(set(v) if isinstance(v, list) else {v}
144
+ for v in tokenizer.special_tokens_map.values())
145
+ for s in special_sets:
146
+ if tok in s:
147
+ return True
148
+ return False
149
+
150
  def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
151
  """
152
+ Render a Protein → Drug cross-attention heat-map and optional Top-30 residue table.
 
 
 
 
 
 
 
 
 
153
  """
154
  model.eval()
155
  with torch.no_grad():
 
160
 
161
  # ── forward pass: Protein → Drug attention (B, n_p, n_d) ───────────────
162
  _, att_pd = model(p_emb, d_emb, p_mask, d_mask)
163
+ attn = att_pd.squeeze(0).cpu() # (n_p, n_d)
164
 
165
  # ── decode tokens (skip special symbols) ────────────────────────────────
166
  def clean_ids(ids, tokenizer):
167
  toks = tokenizer.convert_ids_to_tokens(ids.tolist())
168
+ return [t for t in toks if not _safe_is_special(tokenizer, t)]
169
 
 
170
  p_tokens_full = clean_ids(p_ids[0], prot_tokenizer)
171
  p_indices_full = list(range(1, len(p_tokens_full) + 1))
 
172
  d_tokens_full = clean_ids(d_ids[0], drug_tokenizer)
173
  d_indices_full = list(range(1, len(d_tokens_full) + 1))
174
 
175
+ # ── safety cut-off to match attn mat size ──────────────────────────────
176
+ p_tokens = p_tokens_full[: attn.size(0)]
177
+ p_indices = p_indices_full[: attn.size(0)]
178
+ d_tokens = d_tokens_full[: attn.size(1)]
179
+ d_indices = d_indices_full[: attn.size(1)]
180
+ attn = attn[: len(p_tokens), : len(d_tokens)]
181
 
182
  orig_attn = attn.clone()
183
+
184
  # ── adaptive sparsity pruning ───────────────────────────────────────────
185
+ thr = attn.max().item() * 0.05 if attn.numel() > 0 else 0.0
186
+ row_keep = (attn.max(dim=1).values > thr) if attn.size(0) else torch.tensor([], dtype=torch.bool)
187
+ col_keep = (attn.max(dim=0).values > thr) if attn.size(1) else torch.tensor([], dtype=torch.bool)
188
 
189
+ if row_keep.sum().item() < 3 and attn.size(0) > 0:
190
+ row_keep = torch.ones(attn.size(0), dtype=torch.bool)
191
+ if col_keep.sum().item() < 3 and attn.size(1) > 0:
192
+ col_keep = torch.ones(attn.size(1), dtype=torch.bool)
193
 
194
+ attn = attn[row_keep][:, col_keep]
195
+ p_tokens = [tok for keep, tok in zip(row_keep.tolist(), p_tokens) if keep]
196
+ p_indices = [idx for keep, idx in zip(row_keep.tolist(), p_indices) if keep]
197
+ d_tokens = [tok for keep, tok in zip(col_keep.tolist(), d_tokens) if keep]
198
+ d_indices = [idx for keep, idx in zip(col_keep.tolist(), d_indices) if keep]
199
 
200
  # ── cap column count at 150 for readability ─────────────────────────────
201
  if attn.size(1) > 150:
202
+ topc = torch.topk(attn.sum(0), k=150).indices
203
+ attn = attn[:, topc]
204
+ d_tokens = [d_tokens[i] for i in topc]
205
+ d_indices = [d_indices[i] for i in topc]
206
 
207
+ # ── draw heat-map ──────────────────────────────────────────────────────
208
  x_labels = [f"{idx}:{tok}" for idx, tok in zip(d_indices, d_tokens)]
209
  y_labels = [f"{idx}:{tok}" for idx, tok in zip(p_indices, p_tokens)]
210
 
211
+ fig_w = min(22, max(8, len(x_labels) * 0.6))
212
+ fig_h = min(24, max(6, len(y_labels) * 0.8))
 
213
 
214
  fig, ax = plt.subplots(figsize=(fig_w, fig_h))
215
+ im = ax.imshow(attn.numpy(), aspect="auto", cmap=cm.viridis, interpolation="nearest")
 
 
 
216
 
217
+ ax.set_title("Protein → Drug Attention", pad=8, fontsize=11)
218
  ax.set_xticks(range(len(x_labels)))
219
+ ax.set_xticklabels(x_labels, rotation=90, fontsize=8, ha="center", va="center")
220
+ ax.tick_params(axis="x", top=True, bottom=False, labeltop=True, labelbottom=False, pad=27)
 
 
221
 
222
  ax.set_yticks(range(len(y_labels)))
223
  ax.set_yticklabels(y_labels, fontsize=7)
224
+ ax.tick_params(axis="y", top=True, bottom=False, labeltop=True, labelbottom=False, pad=10)
 
 
225
 
226
  fig.colorbar(im, fraction=0.026, pad=0.01)
227
  fig.tight_layout()
228
 
229
+ # build PNG / PDF
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  buf_png = io.BytesIO()
231
+ fig.savefig(buf_png, format="png", dpi=140)
232
  buf_png.seek(0)
233
 
234
  buf_pdf = io.BytesIO()
 
240
  pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()
241
 
242
  html_heat = (
243
+ f"<div class='heatmap-card' style='position: relative; width: 100%;'>"
 
244
  f"<a href='data:application/pdf;base64,{pdf_b64}' download='attention_heatmap.pdf' "
245
+ "style='position:absolute; top:12px; right:12px; "
246
+ "background: var(--primary); color:#fff; padding:8px 16px; border-radius:8px; "
247
+ "font-size:.92rem; font-weight:600; text-decoration:none;'>Download PDF</a>"
 
 
 
 
 
248
  f"<a href='data:image/png;base64,{png_b64}' target='_blank' title='Click to enlarge'>"
249
  f"<img src='data:image/png;base64,{png_b64}' "
250
+ "style='display:block; width:100%; height:auto; cursor:zoom-in;'/>"
251
  "</a>"
252
  "</div>"
253
  )
254
 
255
+ # ───────────────────── Top-30 table (optional) ─────────────────────
256
+ table_html = ""
257
+ if drug_idx is not None and orig_attn.size(1) > 0 and 0 <= drug_idx < orig_attn.size(1):
258
+ # map original 0-based drug_idx → pruned column
259
+ col_pos = None
260
+ if (drug_idx + 1) in d_indices:
261
+ col_pos = d_indices.index(drug_idx + 1)
262
+ elif 0 <= drug_idx < len(d_tokens):
263
+ col_pos = drug_idx
264
+
265
+ if col_pos is not None:
266
+ col_vec = attn[:, col_pos]
267
+ k = min(30, len(col_vec))
268
+ if k > 0:
269
+ topk = torch.topk(col_vec, k=k).indices.tolist()
270
+
271
+ # header cells
272
+ header_cells = (
273
+ "<th style='border:1px solid #e5e7eb; padding:6px; background:#f8fafc; text-align:center;'>Rank</th>"
274
+ + "".join(
275
+ f"<th style='border:1px solid #e5e7eb; padding:6px; background:#f8fafc; text-align:center'>{r+1}</th>"
276
+ for r in range(len(topk))
277
+ )
278
+ )
279
+ residue_cells = (
280
+ "<th style='border:1px solid #e5e7eb; padding:6px; background:#f8fafc; text-align:center;'>Residue</th>"
281
+ + "".join(
282
+ f"<td style='border:1px solid #e5e7eb; padding:6px; text-align:center'>{p_tokens[i]}</td>"
283
+ for i in topk
284
+ )
285
+ )
286
+ position_cells = (
287
+ "<th style='border:1px solid #e5e7eb; padding:6px; background:#f8fafc; text-align:center;'>Position</th>"
288
+ + "".join(
289
+ f"<td style='border:1px solid #e5e7eb; padding:6px; text-align:center'>{p_indices[i]}</td>"
290
+ for i in topk
291
+ )
292
+ )
293
+
294
+ drug_tok_text = d_tokens[col_pos]
295
+ orig_idx_disp = d_indices[col_pos]
296
+
297
+ table_html = (
298
+ f"<div class='card' style='margin-top:18px'>"
299
+ f"<h4 style='margin:0 0 12px; font-size:1rem;'>"
300
+ f"Drug atom #{orig_idx_disp} <code>{drug_tok_text}</code> → Top-30 Protein residues"
301
+ f"</h4>"
302
+ f"<table style='border-collapse:collapse; margin:0 auto 4px; font-size:.95rem'>"
303
+ f"<tr>{header_cells}</tr>"
304
+ f"<tr>{residue_cells}</tr>"
305
+ f"<tr>{position_cells}</tr>"
306
+ f"</table>"
307
+ f"</div>"
308
+ )
309
+
310
  return table_html + html_heat
311
 
312
  # ───── Gradio Callbacks ─────────────────────────────────────────
 
341
  return visualize_attention(model, feats, int(atom_idx)-1 if atom_idx else None)
342
 
343
  def clear_cb():
344
+ return "", "", None, "", None
345
 
346
+ # ───── Theme & CSS ─────────────────────────────────────────────
347
 
348
  css = """
349
  :root {
350
+ --bg:#f7f7fb;
351
+ --card:#ffffff;
352
+ --border:#e6e7eb;
353
+ --primary:#4f46e5;
354
+ --primary-dark:#4338ca;
355
+ --text:#0f172a;
356
+ --muted:#6b7280;
357
+ --radius:14px;
358
+ --shadow:0 10px 30px rgba(15,23,42,.06);
359
  }
360
+ *{box-sizing:border-box}
361
+ html,body{background:var(--bg)!important;color:var(--text)!important;font-family:Inter,system-ui,Arial,sans-serif}
362
+ h1{font-weight:700;font-size:32px;margin:22px 0 10px;text-align:center;letter-spacing:.2px}
363
+ p,li,button,.gr-button,label,.gr-text{font-size:14px}
364
+
365
+ /* Cards */
366
+ .card{
367
+ background:var(--card); border:1px solid var(--border); border-radius:var(--radius);
368
+ box-shadow:var(--shadow); padding:24px; max-width:1100px; margin:0 auto 28px;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
  }
370
+
371
+ /* Project links */
372
+ .link-btn{
373
+ display:inline-flex; /* icon + text centred vertically */
374
+ align-items:center;
375
+ justify-content:center;
376
+ margin:0 8px;
377
+ padding:10px 18px;
378
+ border-radius:10px;
379
+ color:#fff;
380
+ font-weight:650;
381
+ text-decoration:none;
382
+ box-shadow:0 6px 18px rgba(79,70,229,.18);
383
+ transition:transform .12s ease,filter .12s ease;
 
 
 
 
 
 
 
 
 
384
  }
385
+ .link-btn:hover{transform:translateY(-1px);filter:brightness(1.03)}
386
+ .link-btn svg{margin-right:6px;vertical-align:middle}
387
+ .link-btn.project{background:linear-gradient(135deg,#10b981,#059669)}
388
+ .link-btn.arxiv {background:linear-gradient(135deg,#ef4444,#dc2626)}
389
+ .link-btn.github {background:linear-gradient(135deg,#3b82f6,#2563eb)}
390
+
391
+ /* Labels & inputs */
392
+ #input-card label{font-weight:650!important;color:var(--text)!important}
393
+ textarea, input, .gr-textbox, .gr-number{
394
+ border-radius:12px!important; border:1px solid var(--border)!important;
395
  }
396
+ #input-card .gr-row, #input-card .gr-cols{gap:16px}
397
+
398
+ /* Buttons */
399
+ .gr-button{min-height:42px!important; padding:0 18px!important; border-radius:12px!important; font-weight:700!important}
400
+ .gr-button.primary, .gr-button-primary{
401
+ background:var(--primary)!important; border-color:var(--primary)!important; color:#fff!important
402
  }
403
+ .gr-button.primary:hover, .gr-button-primary:hover{background:var(--primary-dark)!important;border-color:var(--primary-dark)!important}
404
+
405
+ /* Action buttons row */
406
+ #action-buttons{gap:12px}
407
+ #extract-btn, #inference-btn{flex:1 1 260px!important; min-width:180px!important}
408
+ #clear-btn{width:100%!important}
409
+
410
+ /* Output */
411
+ #output-card{padding-top:0}
412
+ #result-html{padding:0; margin:0}
413
+ #result-html .heatmap-card{
414
+ background:var(--card); border:1px solid var(--border); border-radius:12px; padding:12px; box-shadow:var(--shadow)
415
  }
416
+
417
+ /* Guidance */
418
+ #guidelines-card h2{font-size:18px;margin-bottom:14px;text-align:center}
419
+ #guidelines-card ul{margin-left:18px;line-height:1.6}
420
+
421
+ /* Small screens */
422
+ @media (max-width: 900px){
423
+ .card{margin:0 12px 24px}
424
  }
425
  """
426
 
427
+ # ───── Gradio Interface Definition ───────────────────────────────
428
  with gr.Blocks(theme=Soft(primary_hue="indigo", neutral_hue="slate"), css=css) as demo:
429
+ # ───────────── Title ─────────────
430
+ gr.Markdown("<h1 style='text-align: center;'>Token-level Visualiser for Drug-Target Interaction</h1>")
431
 
432
+ # ───────────── Project Links (SVG icons) ─────────────
433
  gr.HTML("""
434
+ <div style="text-align:center;margin-bottom:32px;">
435
+ <a class="link-btn project" href="https://zhaohanm.github.io/FusionDTI.github.io/" target="_blank" rel="noopener noreferrer" aria-label="Project Page">
436
+ <!-- globe icon -->
437
+ <svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="currentColor" aria-hidden="true">
438
+ <path d="M12 2a10 10 0 1 0 10 10A10.012 10.012 0 0 0 12 2Zm7.93 9h-3.18a15.84 15.84 0 0 0-1.19-5.02A8.02 8.02 0 0 1 19.93 11ZM12 4c.86 0 2.25 1.86 3.01 6H8.99C9.75 5.86 11.14 4 12 4ZM4.07 13h3.18c.2 1.79.66 3.47 1.19 5.02A8.02 8.02 0 0 1 4.07 13Zm3.18-2H4.07A8.02 8.02 0 0 1 8.44 5.98 15.84 15.84 0 0 0 7.25 11Zm1.37 2h6.76c-.76 4.14-2.15 6-3.01 6s-2.25-1.86-3.01-6Zm9.05 0h3.18a8.02 8.02 0 0 1-4.37 5.02 15.84 15.84 0 0 0 1.19-5.02Z"/>
439
+ </svg>
440
+ Project Page
441
+ </a>
442
+ <a class="link-btn arxiv" href="https://arxiv.org/abs/2406.01651" target="_blank" rel="noopener noreferrer" aria-label="ArXiv: 2406.01651">
443
+ <!-- arXiv-like paper icon -->
444
+ <svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="currentColor" aria-hidden="true">
445
+ <path d="M6 2h9l5 5v13a2 2 0 0 1-2 2H6a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2Zm8 1.5V8h4.5L14 3.5ZM7 12h10v2H7v-2Zm0 4h10v2H7v-2Zm0-8h6v2H7V8Z"/>
446
+ </svg>
447
+ ArXiv: 2406.01651
448
+ </a>
449
+ <a class="link-btn github" href="https://github.com/ZhaohanM/FusionDTI" target="_blank" rel="noopener noreferrer" aria-label="GitHub Repo">
450
+ <!-- GitHub mark -->
451
+ <svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="currentColor" aria-hidden="true">
452
+ <path d="M12 .5A12 12 0 0 0 0 12.76c0 5.4 3.44 9.98 8.2 11.6.6.12.82-.28.82-.6v-2.3c-3.34.74-4.04-1.44-4.04-1.44-.54-1.38-1.32-1.74-1.32-1.74-1.08-.76.08-.74.08-.74 1.2.08 1.84 1.26 1.84 1.26 1.06 1.86 2.78 1.32 3.46 1.02.1-.8.42-1.32.76-1.62-2.66-.32-5.46-1.36-5.46-6.02 0-1.34.46-2.44 1.22-3.3-.12-.32-.54-1.64.12-3.42 0 0 1-.34 3.32 1.26.96-.28 1.98-.42 3-.42s2.04.14 3 .42c2.32-1.6 3.32-1.26 3.32-1.26.66 1.78.24 3.1.12 3.42.76.86 1.22 1.96 1.22 3.3 0 4.68-2.8 5.68-5.48 6 .44.38.84 1.12.84 2.28v3.38c0 .32.22.74.84.6A12.02 12.02 0 0 0 24 12.76 12 12 0 0 0 12 .5Z"/>
453
+ </svg>
454
+ GitHub Repo
455
+ </a>
456
+ </div>
457
+ """)
458
 
459
  # ───────────── Guidelines Card ─────────────
 
460
  gr.HTML(
461
  """
462
+ <div class="card" id="guidelines-card" style="margin-bottom:24px">
463
+ <h2>Guidelines for Users</h2>
464
+ <ul style="list-style:decimal;">
465
  <li><strong>Convert protein structure into a structure-aware sequence:</strong>
466
  Upload a <code>.pdb</code> or <code>.cif</code> file. A structure-aware
467
  sequence will be generated using
 
470
  <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a> or the
471
  <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>.</li>
472
  <li><strong>If you only have an amino acid sequence or a UniProt ID,</strong>
473
+ please first visit the
474
  <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>
475
  or <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a>
476
+ to download the corresponding <code>.cif</code> or <code>.pdb</code> file.</li>
477
+ <li><strong>Drug input supports both SELFIES and SMILES:</strong>
478
+ Enter a SELFIES string directly, or paste a SMILES string. SMILES will
479
+ be converted to SELFIES using the
480
  <a href="https://github.com/aspuru-guzik-group/selfies" target="_blank">SELFIES encoder</a>.
481
  If conversion fails, a red error message will be displayed.</li>
482
+ <li>Optionally enter a <strong>1-based</strong> drug atom/substructure index
483
  to highlight the Top-30 interacting protein residues.</li>
484
+ <li>After inference, use “Download PDF” to export a high-resolution vector figure.</li>
 
485
  </ul>
486
  </div>
487
+ """
488
+ )
489
+
490
  # ───────────── Input Card ─────────────
491
  with gr.Column(elem_id="input-card", elem_classes="card"):
 
492
  protein_seq = gr.Textbox(
493
  label="Protein Structure-aware Sequence",
494
  lines=3,
 
513
 
514
  # ───────────── Action Buttons ─────────────
515
  with gr.Row(elem_id="action-buttons", equal_height=True):
516
+ btn_extract = gr.Button("Extract sequence", variant="primary", elem_id="extract-btn")
517
+ btn_infer = gr.Button("Inference", variant="primary", elem_id="inference-btn")
 
 
 
 
 
 
 
 
518
  with gr.Row():
519
+ clear_btn = gr.Button("Clear", variant="secondary", elem_id="clear-btn")
 
 
 
 
 
520
 
521
+ # ───────────── Output Visualisation ─────────────
522
  output_html = gr.HTML(elem_id="result-html")
523
 
524
  # ───────────── Event Wiring ─────────────
 
533
  outputs=[output_html]
534
  )
535
  clear_btn.click(
536
+ fn=clear_cb,
537
  inputs=[],
538
  outputs=[protein_seq, drug_seq, drug_idx, output_html, structure_file]
539
  )