Zhaohan-Meng committed on
Commit
0d9281c
·
verified ·
1 Parent(s): f947a52

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +229 -246
app.py CHANGED
@@ -2,8 +2,8 @@
2
  import gradio_client.utils as _gc_utils
3
 
4
  # back up originals
5
- _orig_get_type = _gc_utils.get_type
6
- _orig_json2py = _gc_utils._json_schema_to_python_type
7
 
8
  def _patched_get_type(schema):
9
  # treat any boolean schema as if it were an empty dict
@@ -22,6 +22,7 @@ _gc_utils._json_schema_to_python_type = _patched_json_schema_to_python_type
22
 
23
  # ─── now it’s safe to import Gradio and build your interface ───────────────────────────
24
  import gradio as gr
 
25
 
26
  import os
27
  import sys
@@ -53,6 +54,7 @@ from utils.foldseek_util import get_struc_seq
53
 
54
  three2one = {k.upper(): v for k, v in IUPACData.protein_letters_3to1.items()}
55
  three2one.update({"MSE": "M", "SEC": "C", "PYL": "K"})
 
56
  def simple_seq_from_structure(path: str) -> str:
57
  parser = MMCIFParser(QUIET=True) if path.endswith(".cif") else PDBParser(QUIET=True)
58
  structure = parser.get_structure("P", path)
@@ -68,7 +70,7 @@ def smiles_to_selfies(smiles: str) -> Optional[str]:
68
  if mol is None:
69
  return None
70
  return selfies.encoder(smiles)
71
- except:
72
  return None
73
 
74
  def parse_config():
@@ -131,20 +133,23 @@ def get_case_feature(model, loader):
131
  p_ids.cpu(), d_ids.cpu(),
132
  p_mask.cpu(), d_mask.cpu(), None)]
133
 
134
-
135
  # ─────────────── visualisation ───────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
137
  """
138
- Render a Protein → Drug cross-attention heat-map and, optionally, a
139
- Top-30 protein-residue table for a chosen drug-token index.
140
- The token index shown on the x-axis (and accepted via *drug_idx*) is **the
141
- position of that token in the *original* drug sequence**, *after* the
142
- tokeniser but *before* any pruning or truncation (1-based in the labels,
143
- 0-based for the function argument).
144
- Returns
145
- -------
146
- html : str
147
- Base64-embedded PNG heat-map (+ optional HTML table).
148
  """
149
  model.eval()
150
  with torch.no_grad():
@@ -155,154 +160,75 @@ def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
155
 
156
  # ── forward pass: Protein → Drug attention (B, n_p, n_d) ───────────────
157
  _, att_pd = model(p_emb, d_emb, p_mask, d_mask)
158
- attn = att_pd.squeeze(0).cpu() # (n_p, n_d)
159
 
160
  # ── decode tokens (skip special symbols) ────────────────────────────────
161
  def clean_ids(ids, tokenizer):
162
  toks = tokenizer.convert_ids_to_tokens(ids.tolist())
163
- return [t for t in toks if t not in tokenizer.all_special_tokens]
164
 
165
- # ── decode full sequences + record 1-based indices ──────────────────
166
  p_tokens_full = clean_ids(p_ids[0], prot_tokenizer)
167
  p_indices_full = list(range(1, len(p_tokens_full) + 1))
168
-
169
  d_tokens_full = clean_ids(d_ids[0], drug_tokenizer)
170
  d_indices_full = list(range(1, len(d_tokens_full) + 1))
171
 
172
- # ── safety cut-off to match attn mat size ───────────────────────────────
173
- p_tokens = p_tokens_full[: attn.size(0)]
174
- p_indices_full = p_indices_full[: attn.size(0)]
175
- d_tokens_full = d_tokens_full[: attn.size(1)]
176
- d_indices_full = d_indices_full[: attn.size(1)]
177
- attn = attn[: len(p_tokens_full), : len(d_tokens_full)]
178
 
179
  orig_attn = attn.clone()
 
180
  # ── adaptive sparsity pruning ───────────────────────────────────────────
181
- thr = attn.max().item() * 0.05
182
- row_keep = (attn.max(dim=1).values > thr)
183
- col_keep = (attn.max(dim=0).values > thr)
184
 
185
- if row_keep.sum() < 3:
186
- row_keep[:] = True
187
- if col_keep.sum() < 3:
188
- col_keep[:] = True
189
 
190
- attn = attn[row_keep][:, col_keep]
191
- p_tokens = [tok for keep, tok in zip(row_keep, p_tokens) if keep]
192
- p_indices = [idx for keep, idx in zip(row_keep, p_indices_full) if keep]
193
- d_tokens = [tok for keep, tok in zip(col_keep, d_tokens_full) if keep]
194
- d_indices = [idx for keep, idx in zip(col_keep, d_indices_full) if keep]
195
 
196
  # ── cap column count at 150 for readability ─────────────────────────────
197
  if attn.size(1) > 150:
198
- topc = torch.topk(attn.sum(0), k=150).indices
199
- attn = attn[:, topc]
200
- d_tokens = [d_tokens [i] for i in topc]
201
- d_indices = [d_indices[i] for i in topc]
202
 
203
- # ── draw heat-map ───────────────────────────────────────────────────────
204
  x_labels = [f"{idx}:{tok}" for idx, tok in zip(d_indices, d_tokens)]
205
  y_labels = [f"{idx}:{tok}" for idx, tok in zip(p_indices, p_tokens)]
206
 
207
-
208
- fig_w = min(22, max(8, len(x_labels) * 0.6)) # ~0.6″ per column
209
- fig_h = min(24, max(6, len(p_tokens) * 0.8))
210
 
211
  fig, ax = plt.subplots(figsize=(fig_w, fig_h))
212
- im = ax.imshow(attn.numpy(), aspect="auto",
213
- cmap=cm.viridis, interpolation="nearest")
214
-
215
- ax.set_title("Protein β†’ Drug Attention", pad=8, fontsize=10)
216
 
 
217
  ax.set_xticks(range(len(x_labels)))
218
- ax.set_xticklabels(x_labels, rotation=90, fontsize=8,
219
- ha="center", va="center")
220
- ax.tick_params(axis="x", top=True, bottom=False,
221
- labeltop=True, labelbottom=False, pad=27)
222
 
223
  ax.set_yticks(range(len(y_labels)))
224
  ax.set_yticklabels(y_labels, fontsize=7)
225
- ax.tick_params(axis="y", top=True, bottom=False,
226
- labeltop=True, labelbottom=False,
227
- pad=10)
228
 
229
  fig.colorbar(im, fraction=0.026, pad=0.01)
230
  fig.tight_layout()
231
 
232
- buf = io.BytesIO()
233
- fig.savefig(buf, format="png", dpi=140)
234
- plt.close(fig)
235
- html = f'<img src="data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" />'
236
-
237
- # ───────────────────── Top-30 table ─────────────────────
238
- table_html = ""
239
- if drug_idx is not None and 0 <= drug_idx < orig_attn.size(1):
240
- # map original 0-based drug_idx β†’ current column position
241
- if (drug_idx + 1) in d_indices:
242
- col_pos = d_indices.index(drug_idx + 1)
243
- elif 0 <= drug_idx < len(d_tokens):
244
- col_pos = drug_idx
245
- else:
246
- col_pos = None
247
-
248
- if col_pos is not None:
249
- col_vec = attn[:, col_pos]
250
- topk = torch.topk(col_vec, k=min(30, len(col_vec))).indices.tolist()
251
-
252
- rank_hdr = "".join(f"<th>{r+1}</th>" for r in range(len(topk)))
253
- res_row = "".join(f"<td>{p_tokens[i]}</td>" for i in topk)
254
- pos_row = "".join(f"<td>{p_indices[i]}</td>"for i in topk)
255
-
256
- drug_tok_text = d_tokens_full[col_pos]
257
- orig_idx = d_indices_full[col_pos]
258
-
259
- # 1) build the header row: leading “Rank”, then 1…30
260
- header_cells = (
261
- "<th style='border:1px solid #ccc; padding:6px; "
262
- "background:#f7f7f7; text-align:center;'>Rank</th>"
263
- + "".join(
264
- f"<th style='border:1px solid #ccc; padding:6px; "
265
- f"background:#f7f7f7; text-align:center'>{r+1}</th>"
266
- for r in range(len(topk))
267
- )
268
- )
269
-
270
- # 2) build the residue row: leading “Residue”, then the residue tokens
271
- residue_cells = (
272
- "<th style='border:1px solid #ccc; padding:6px; "
273
- "background:#f7f7f7; text-align:center;'>Residue</th>"
274
- + "".join(
275
- f"<td style='border:1px solid #ccc; padding:6px; "
276
- f"text-align:center'>{p_tokens_full[i]}</td>"
277
- for i in topk
278
- )
279
- )
280
-
281
- # 3) build the position row: leading “Position”, then the residue positions
282
- position_cells = (
283
- "<th style='border:1px solid #ccc; padding:6px; "
284
- "background:#f7f7f7; text-align:center;'>Position</th>"
285
- + "".join(
286
- f"<td style='border:1px solid #ccc; padding:6px; "
287
- f"text-align:center'>{p_indices_full[i]}</td>"
288
- for i in topk
289
- )
290
- )
291
-
292
- # 4) assemble your table_html
293
- table_html = (
294
- f"<h4 style='margin-bottom:12px'>"
295
- f"Drug atom #{orig_idx} <code>{drug_tok_text}</code> β†’ Top-30 Protein residues"
296
- f"</h4>"
297
- f"<table style='border-collapse:collapse; margin:0 auto 24px;'>"
298
- f"<tr>{header_cells}</tr>"
299
- f"<tr>{residue_cells}</tr>"
300
- f"<tr>{position_cells}</tr>"
301
- f"</table>"
302
- )
303
-
304
  buf_png = io.BytesIO()
305
- fig.savefig(buf_png, format="png", dpi=140)
306
  buf_png.seek(0)
307
 
308
  buf_pdf = io.BytesIO()
@@ -314,24 +240,73 @@ def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
314
  pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()
315
 
316
  html_heat = (
317
- f"<div style='position: relative; width: 100%;'>"
318
- # the PDF button, absolutely positioned
319
  f"<a href='data:application/pdf;base64,{pdf_b64}' download='attention_heatmap.pdf' "
320
- "style='position: absolute; top: 12px; right: 12px; "
321
- "background: var(--primary); color: #fff; "
322
- "padding: 8px 16px; border-radius: 6px; "
323
- "font-size: 0.9rem; font-weight: 500; "
324
- "text-decoration: none;'>"
325
- "Download PDF"
326
- "</a>"
327
- # the clickable heat‐map image
328
  f"<a href='data:image/png;base64,{png_b64}' target='_blank' title='Click to enlarge'>"
329
  f"<img src='data:image/png;base64,{png_b64}' "
330
- "style='display: block; width: 100%; height: auto; cursor: zoom-in;'/>"
331
  "</a>"
332
  "</div>"
333
  )
334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  return table_html + html_heat
336
 
337
  # ───── Gradio Callbacks ─────────────────────────────────────────
@@ -366,105 +341,127 @@ def inference_cb(prot_seq, drug_seq, atom_idx):
366
  return visualize_attention(model, feats, int(atom_idx)-1 if atom_idx else None)
367
 
368
  def clear_cb():
369
- return None, "", "", None, ""
370
 
371
- # ───── Gradio Interface Definition ───────────────────────────────
372
 
373
  css = """
374
  :root {
375
- --bg: #f3f4f6;
376
- --card: #ffffff;
377
- --border: #e5e7eb;
378
- --primary: #6366f1;
379
- --primary-dark: #4f46e5;
380
- --text: #111827;
 
 
 
381
  }
382
- * { box-sizing: border-box; margin: 0; padding: 0; }
383
- body { background: var(--bg); color: var(--text); font-family: Inter,system-ui,Arial,sans-serif; }
384
- h1 { font-family: Poppins,Inter,sans-serif; font-weight: 600; font-size: 2rem; text-align: center; margin: 24px 0; }
385
- button, .gr-button { font-family: Inter,sans-serif; font-weight: 600; }
386
- #project-links { text-align: center; margin-bottom: 32px; }
387
- #project-links .gr-button { margin: 0 8px; min-width: 160px; }
388
- #project-links .gr-button:nth-child(1) { background: #10b981; }
389
- #project-links .gr-button:nth-child(2) { background: #ef4444; }
390
- #project-links .gr-button:nth-child(3) { background: #3b82f6; }
391
- #project-links .gr-button:hover { opacity: 0.9; }
392
- .link-btn{display:inline-block;margin:0 8px;padding:10px 20px;border-radius:8px;
393
- color:white;font-weight:600;text-decoration:none;box-shadow:0 2px 6px rgba(0,0,0,0.12);
394
- transition:all .2s ease-in-out;}
395
- .link-btn:hover{opacity:.9;}
396
- .link-btn.project{background:linear-gradient(to right,#10b981,#059669);}
397
- .link-btn.arxiv {background:linear-gradient(to right,#ef4444,#dc2626);}
398
- .link-btn.github {background:linear-gradient(to right,#3b82f6,#2563eb);}
399
- /* make *all* gradio buttons a bit taller */
400
- .gr-button { min-height: 10px !important; }
401
- /* now target just our two big action buttons */
402
- #extract-btn, #inference-btn {
403
- width: 5px !important;
404
- min-height: 36px !important;
405
- margin-top: 12px !important;
406
  }
407
- /* and make clear button full width but shorter */
408
- #clear-btn {
409
- width: 10px !important;
410
- min-height: 36px !important;
411
- margin-top: 12px !important;
412
- }
413
- #input-card label {
414
- font-weight: 600 !important; /* make the text bold */
415
- color: var(--text) !important; /* use your standard text color */
416
- }
417
- .card {
418
- background: var(--card);
419
- border: 1px solid var(--border);
420
- border-radius: 12px;
421
- padding: 24px;
422
- max-width: 1000px;
423
- margin: 0 auto 32px;
424
- box-shadow: 0 2px 6px rgba(0,0,0,0.05);
425
- }
426
- #guidelines-card h2 {
427
- font-size: 1.4rem;
428
- margin-bottom: 16px;
429
- text-align: center;
430
  }
431
- #guidelines-card ol {
432
- margin-left: 20px;
433
- line-height: 1.6;
434
- font-size: 1rem;
 
 
 
 
 
 
435
  }
436
- #input-card .gr-row, #input-card .gr-cols {
437
- gap: 16px;
 
 
 
 
438
  }
439
- #input-card .gr-button {
440
- flex: 1;
 
 
 
 
 
 
 
 
 
 
441
  }
442
- #output-card {
443
- padding-top: 0;
 
 
 
 
 
 
444
  }
445
  """
446
 
447
- with gr.Blocks(css=css) as demo:
 
448
  # ───────────── Title ─────────────
449
- gr.Markdown(
450
- "<h1 style='text-align: center;'>Token-level Visualiser for Drug-Target Interaction</h1>"
451
- )
452
 
453
- # ───────────── Project Links ─────────────
454
- gr.Markdown("""
455
  <div style="text-align:center;margin-bottom:32px;">
456
- <a class="link-btn project" href="https://zhaohanm.github.io/FusionDTI.github.io/" target="_blank">🌐 Project Page</a>
457
- <a class="link-btn arxiv" href="https://arxiv.org/abs/2406.01651" target="_blank">πŸ“„ ArXiv: 2406.01651</a>
458
- <a class="link-btn github" href="https://github.com/ZhaohanM/FusionDTI" target="_blank">πŸ’» GitHub Repo</a>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
459
  </div>
460
  """)
 
461
  # ───────────── Guidelines Card ─────────────
462
-
463
  gr.HTML(
464
  """
465
- <div class="card" style="margin-bottom:24px">
466
- <h2 style="font-size:1.2rem;margin-bottom:14px">Guidelines for User</h2>
467
- <ul style="font-size:1rem; margin-left:18px;line-height:1.55;list-style:decimal;">
468
  <li><strong>Convert protein structure into a structure-aware sequence:</strong>
469
  Upload a <code>.pdb</code> or <code>.cif</code> file. A structure-aware
470
  sequence will be generated using
@@ -473,26 +470,25 @@ with gr.Blocks(css=css) as demo:
473
  <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a> or the
474
  <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>.</li>
475
  <li><strong>If you only have an amino acid sequence or a UniProt ID,</strong>
476
- you must first visit the
477
  <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>
478
  or <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a>
479
- to search and download the corresponding <code>.cif</code> or <code>.pdb</code> file.</li>
480
- <li><strong>Drug input supports both SELFIES and SMILES:</strong><br>
481
- You can enter a SELFIES string directly, or paste a SMILES string.
482
- SMILES will be automatically converted to SELFIES using
483
  <a href="https://github.com/aspuru-guzik-group/selfies" target="_blank">SELFIES encoder</a>.
484
  If conversion fails, a red error message will be displayed.</li>
485
- <li>Optionally enter a <strong>1-based</strong> drug atom or substructure index
486
  to highlight the Top-30 interacting protein residues.</li>
487
- <li>After inference, you can use the
488
- β€œDownload PDF” link to export a high-resolution vector version.</li>
489
  </ul>
490
  </div>
491
- """)
492
-
 
493
  # ───────────── Input Card ─────────────
494
  with gr.Column(elem_id="input-card", elem_classes="card"):
495
-
496
  protein_seq = gr.Textbox(
497
  label="Protein Structure-aware Sequence",
498
  lines=3,
@@ -517,25 +513,12 @@ with gr.Blocks(css=css) as demo:
517
 
518
  # ───────────── Action Buttons ─────────────
519
  with gr.Row(elem_id="action-buttons", equal_height=True):
520
- btn_extract = gr.Button(
521
- "Extract sequence",
522
- variant="primary",
523
- elem_id="extract-btn"
524
- )
525
- btn_infer = gr.Button(
526
- "Inference",
527
- variant="primary",
528
- elem_id="inference-btn"
529
- )
530
  with gr.Row():
531
- clear_btn = gr.Button(
532
- "Clear",
533
- variant="secondary",
534
- elem_classes="full-width",
535
- elem_id="clear-btn"
536
- )
537
 
538
- # ───────────── Output Visualization ─────────────
539
  output_html = gr.HTML(elem_id="result-html")
540
 
541
  # ───────────── Event Wiring ─────────────
@@ -550,7 +533,7 @@ with gr.Blocks(css=css) as demo:
550
  outputs=[output_html]
551
  )
552
  clear_btn.click(
553
- fn=lambda: ("", "", None, "", None),
554
  inputs=[],
555
  outputs=[protein_seq, drug_seq, drug_idx, output_html, structure_file]
556
  )
 
2
  import gradio_client.utils as _gc_utils
3
 
4
  # back up originals
5
+ _orig_get_type = _gc_utils.get_type
6
+ _orig_json2py = _gc_utils._json_schema_to_python_type
7
 
8
  def _patched_get_type(schema):
9
  # treat any boolean schema as if it were an empty dict
 
22
 
23
  # ─── now it’s safe to import Gradio and build your interface ───────────────────────────
24
  import gradio as gr
25
+ from gradio.themes import Soft
26
 
27
  import os
28
  import sys
 
54
 
55
  three2one = {k.upper(): v for k, v in IUPACData.protein_letters_3to1.items()}
56
  three2one.update({"MSE": "M", "SEC": "C", "PYL": "K"})
57
+
58
  def simple_seq_from_structure(path: str) -> str:
59
  parser = MMCIFParser(QUIET=True) if path.endswith(".cif") else PDBParser(QUIET=True)
60
  structure = parser.get_structure("P", path)
 
70
  if mol is None:
71
  return None
72
  return selfies.encoder(smiles)
73
+ except Exception:
74
  return None
75
 
76
  def parse_config():
 
133
  p_ids.cpu(), d_ids.cpu(),
134
  p_mask.cpu(), d_mask.cpu(), None)]
135
 
 
136
  # ─────────────── visualisation ───────────────────────────────────────────
137
+ def _safe_is_special(tokenizer, tok: str) -> bool:
138
+ # Some tokenisers expose different special token sets; fall back conservatively.
139
+ special_sets = []
140
+ if hasattr(tokenizer, "all_special_tokens"):
141
+ special_sets.append(set(tokenizer.all_special_tokens))
142
+ if hasattr(tokenizer, "special_tokens_map"):
143
+ special_sets.extend(set(v) if isinstance(v, list) else {v}
144
+ for v in tokenizer.special_tokens_map.values())
145
+ for s in special_sets:
146
+ if tok in s:
147
+ return True
148
+ return False
149
+
150
  def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
151
  """
152
+ Render a Protein → Drug cross-attention heat-map and optional Top-30 residue table.
 
 
 
 
 
 
 
 
 
153
  """
154
  model.eval()
155
  with torch.no_grad():
 
160
 
161
  # ── forward pass: Protein → Drug attention (B, n_p, n_d) ───────────────
162
  _, att_pd = model(p_emb, d_emb, p_mask, d_mask)
163
+ attn = att_pd.squeeze(0).cpu() # (n_p, n_d)
164
 
165
  # ── decode tokens (skip special symbols) ────────────────────────────────
166
  def clean_ids(ids, tokenizer):
167
  toks = tokenizer.convert_ids_to_tokens(ids.tolist())
168
+ return [t for t in toks if not _safe_is_special(tokenizer, t)]
169
 
 
170
  p_tokens_full = clean_ids(p_ids[0], prot_tokenizer)
171
  p_indices_full = list(range(1, len(p_tokens_full) + 1))
 
172
  d_tokens_full = clean_ids(d_ids[0], drug_tokenizer)
173
  d_indices_full = list(range(1, len(d_tokens_full) + 1))
174
 
175
+ # ── safety cut-off to match attn mat size ──────────────────────────────
176
+ p_tokens = p_tokens_full[: attn.size(0)]
177
+ p_indices = p_indices_full[: attn.size(0)]
178
+ d_tokens = d_tokens_full[: attn.size(1)]
179
+ d_indices = d_indices_full[: attn.size(1)]
180
+ attn = attn[: len(p_tokens), : len(d_tokens)]
181
 
182
  orig_attn = attn.clone()
183
+
184
  # ── adaptive sparsity pruning ───────────────────────────────────────────
185
+ thr = attn.max().item() * 0.05 if attn.numel() > 0 else 0.0
186
+ row_keep = (attn.max(dim=1).values > thr) if attn.size(0) else torch.tensor([], dtype=torch.bool)
187
+ col_keep = (attn.max(dim=0).values > thr) if attn.size(1) else torch.tensor([], dtype=torch.bool)
188
 
189
+ if row_keep.sum().item() < 3 and attn.size(0) > 0:
190
+ row_keep = torch.ones(attn.size(0), dtype=torch.bool)
191
+ if col_keep.sum().item() < 3 and attn.size(1) > 0:
192
+ col_keep = torch.ones(attn.size(1), dtype=torch.bool)
193
 
194
+ attn = attn[row_keep][:, col_keep]
195
+ p_tokens = [tok for keep, tok in zip(row_keep.tolist(), p_tokens) if keep]
196
+ p_indices = [idx for keep, idx in zip(row_keep.tolist(), p_indices) if keep]
197
+ d_tokens = [tok for keep, tok in zip(col_keep.tolist(), d_tokens) if keep]
198
+ d_indices = [idx for keep, idx in zip(col_keep.tolist(), d_indices) if keep]
199
 
200
  # ── cap column count at 150 for readability ─────────────────────────────
201
  if attn.size(1) > 150:
202
+ topc = torch.topk(attn.sum(0), k=150).indices
203
+ attn = attn[:, topc]
204
+ d_tokens = [d_tokens[i] for i in topc]
205
+ d_indices = [d_indices[i] for i in topc]
206
 
207
+ # ── draw heat-map ──────────────────────────────────────────────────────
208
  x_labels = [f"{idx}:{tok}" for idx, tok in zip(d_indices, d_tokens)]
209
  y_labels = [f"{idx}:{tok}" for idx, tok in zip(p_indices, p_tokens)]
210
 
211
+ fig_w = min(22, max(8, len(x_labels) * 0.6))
212
+ fig_h = min(24, max(6, len(y_labels) * 0.8))
 
213
 
214
  fig, ax = plt.subplots(figsize=(fig_w, fig_h))
215
+ im = ax.imshow(attn.numpy(), aspect="auto", cmap=cm.viridis, interpolation="nearest")
 
 
 
216
 
217
+ ax.set_title("Protein β†’ Drug Attention", pad=8, fontsize=11)
218
  ax.set_xticks(range(len(x_labels)))
219
+ ax.set_xticklabels(x_labels, rotation=90, fontsize=8, ha="center", va="center")
220
+ ax.tick_params(axis="x", top=True, bottom=False, labeltop=True, labelbottom=False, pad=27)
 
 
221
 
222
  ax.set_yticks(range(len(y_labels)))
223
  ax.set_yticklabels(y_labels, fontsize=7)
224
+ ax.tick_params(axis="y", top=True, bottom=False, labeltop=True, labelbottom=False, pad=10)
 
 
225
 
226
  fig.colorbar(im, fraction=0.026, pad=0.01)
227
  fig.tight_layout()
228
 
229
+ # build PNG / PDF
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  buf_png = io.BytesIO()
231
+ fig.savefig(buf_png, format="png", dpi=140)
232
  buf_png.seek(0)
233
 
234
  buf_pdf = io.BytesIO()
 
240
  pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()
241
 
242
  html_heat = (
243
+ f"<div class='heatmap-card' style='position: relative; width: 100%;'>"
 
244
  f"<a href='data:application/pdf;base64,{pdf_b64}' download='attention_heatmap.pdf' "
245
+ "style='position:absolute; top:12px; right:12px; "
246
+ "background: var(--primary); color:#fff; padding:8px 16px; border-radius:8px; "
247
+ "font-size:.92rem; font-weight:600; text-decoration:none;'>Download PDF</a>"
 
 
 
 
 
248
  f"<a href='data:image/png;base64,{png_b64}' target='_blank' title='Click to enlarge'>"
249
  f"<img src='data:image/png;base64,{png_b64}' "
250
+ "style='display:block; width:100%; height:auto; cursor:zoom-in;'/>"
251
  "</a>"
252
  "</div>"
253
  )
254
 
255
+ # ───────────────────── Top-30 table (optional) ─────────────────────
256
+ table_html = ""
257
+ if drug_idx is not None and orig_attn.size(1) > 0 and 0 <= drug_idx < orig_attn.size(1):
258
+ # map original 0-based drug_idx → pruned column
259
+ col_pos = None
260
+ if (drug_idx + 1) in d_indices:
261
+ col_pos = d_indices.index(drug_idx + 1)
262
+ elif 0 <= drug_idx < len(d_tokens):
263
+ col_pos = drug_idx
264
+
265
+ if col_pos is not None:
266
+ col_vec = attn[:, col_pos]
267
+ k = min(30, len(col_vec))
268
+ if k > 0:
269
+ topk = torch.topk(col_vec, k=k).indices.tolist()
270
+
271
+ # header cells
272
+ header_cells = (
273
+ "<th style='border:1px solid #e5e7eb; padding:6px; background:#f8fafc; text-align:center;'>Rank</th>"
274
+ + "".join(
275
+ f"<th style='border:1px solid #e5e7eb; padding:6px; background:#f8fafc; text-align:center'>{r+1}</th>"
276
+ for r in range(len(topk))
277
+ )
278
+ )
279
+ residue_cells = (
280
+ "<th style='border:1px solid #e5e7eb; padding:6px; background:#f8fafc; text-align:center;'>Residue</th>"
281
+ + "".join(
282
+ f"<td style='border:1px solid #e5e7eb; padding:6px; text-align:center'>{p_tokens[i]}</td>"
283
+ for i in topk
284
+ )
285
+ )
286
+ position_cells = (
287
+ "<th style='border:1px solid #e5e7eb; padding:6px; background:#f8fafc; text-align:center;'>Position</th>"
288
+ + "".join(
289
+ f"<td style='border:1px solid #e5e7eb; padding:6px; text-align:center'>{p_indices[i]}</td>"
290
+ for i in topk
291
+ )
292
+ )
293
+
294
+ drug_tok_text = d_tokens[col_pos]
295
+ orig_idx_disp = d_indices[col_pos]
296
+
297
+ table_html = (
298
+ f"<div class='card' style='margin-top:18px'>"
299
+ f"<h4 style='margin:0 0 12px; font-size:1rem;'>"
300
+ f"Drug atom #{orig_idx_disp} <code>{drug_tok_text}</code> β†’ Top-30 Protein residues"
301
+ f"</h4>"
302
+ f"<table style='border-collapse:collapse; margin:0 auto 4px; font-size:.95rem'>"
303
+ f"<tr>{header_cells}</tr>"
304
+ f"<tr>{residue_cells}</tr>"
305
+ f"<tr>{position_cells}</tr>"
306
+ f"</table>"
307
+ f"</div>"
308
+ )
309
+
310
  return table_html + html_heat
311
 
312
  # ───── Gradio Callbacks ─────────────────────────────────────────
 
341
  return visualize_attention(model, feats, int(atom_idx)-1 if atom_idx else None)
342
 
343
  def clear_cb():
344
+ return "", "", None, "", None
345
 
346
+ # ───── Theme & CSS ─────────────────────────────────────────────
347
 
348
  css = """
349
  :root {
350
+ --bg:#f7f7fb;
351
+ --card:#ffffff;
352
+ --border:#e6e7eb;
353
+ --primary:#4f46e5;
354
+ --primary-dark:#4338ca;
355
+ --text:#0f172a;
356
+ --muted:#6b7280;
357
+ --radius:14px;
358
+ --shadow:0 10px 30px rgba(15,23,42,.06);
359
  }
360
+ *{box-sizing:border-box}
361
+ html,body{background:var(--bg)!important;color:var(--text)!important;font-family:Inter,system-ui,Arial,sans-serif}
362
+ h1{font-weight:700;font-size:32px;margin:22px 0 10px;text-align:center;letter-spacing:.2px}
363
+ p,li,button,.gr-button,label,.gr-text{font-size:14px}
364
+
365
+ /* Cards */
366
+ .card{
367
+ background:var(--card); border:1px solid var(--border); border-radius:var(--radius);
368
+ box-shadow:var(--shadow); padding:24px; max-width:1100px; margin:0 auto 28px;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
  }
370
+
371
+ /* Project links */
372
+ .link-btn{
373
+ display:inline-flex; /* icon + text centred vertically */
374
+ align-items:center;
375
+ justify-content:center;
376
+ margin:0 8px;
377
+ padding:10px 18px;
378
+ border-radius:10px;
379
+ color:#fff;
380
+ font-weight:650;
381
+ text-decoration:none;
382
+ box-shadow:0 6px 18px rgba(79,70,229,.18);
383
+ transition:transform .12s ease,filter .12s ease;
 
 
 
 
 
 
 
 
 
384
  }
385
+ .link-btn:hover{transform:translateY(-1px);filter:brightness(1.03)}
386
+ .link-btn svg{margin-right:6px;vertical-align:middle}
387
+ .link-btn.project{background:linear-gradient(135deg,#10b981,#059669)}
388
+ .link-btn.arxiv {background:linear-gradient(135deg,#ef4444,#dc2626)}
389
+ .link-btn.github {background:linear-gradient(135deg,#3b82f6,#2563eb)}
390
+
391
+ /* Labels & inputs */
392
+ #input-card label{font-weight:650!important;color:var(--text)!important}
393
+ textarea, input, .gr-textbox, .gr-number{
394
+ border-radius:12px!important; border:1px solid var(--border)!important;
395
  }
396
+ #input-card .gr-row, #input-card .gr-cols{gap:16px}
397
+
398
+ /* Buttons */
399
+ .gr-button{min-height:42px!important; padding:0 18px!important; border-radius:12px!important; font-weight:700!important}
400
+ .gr-button.primary, .gr-button-primary{
401
+ background:var(--primary)!important; border-color:var(--primary)!important; color:#fff!important
402
  }
403
+ .gr-button.primary:hover, .gr-button-primary:hover{background:var(--primary-dark)!important;border-color:var(--primary-dark)!important}
404
+
405
+ /* Action buttons row */
406
+ #action-buttons{gap:12px}
407
+ #extract-btn, #inference-btn{flex:1 1 260px!important; min-width:180px!important}
408
+ #clear-btn{width:100%!important}
409
+
410
+ /* Output */
411
+ #output-card{padding-top:0}
412
+ #result-html{padding:0; margin:0}
413
+ #result-html .heatmap-card{
414
+ background:var(--card); border:1px solid var(--border); border-radius:12px; padding:12px; box-shadow:var(--shadow)
415
  }
416
+
417
+ /* Guidance */
418
+ #guidelines-card h2{font-size:18px;margin-bottom:14px;text-align:center}
419
+ #guidelines-card ul{margin-left:18px;line-height:1.6}
420
+
421
+ /* Small screens */
422
+ @media (max-width: 900px){
423
+ .card{margin:0 12px 24px}
424
  }
425
  """
426
 
427
+ # ───── Gradio Interface Definition ───────────────────────────────
428
+ with gr.Blocks(theme=Soft(primary_hue="indigo", neutral_hue="slate"), css=css) as demo:
429
  # ───────────── Title ─────────────
430
+ gr.Markdown("<h1 style='text-align: center;'>Token-level Visualiser for Drug-Target Interaction</h1>")
 
 
431
 
432
+ # ───────────── Project Links (SVG icons) ─────────────
433
+ gr.HTML("""
434
  <div style="text-align:center;margin-bottom:32px;">
435
+ <a class="link-btn project" href="https://zhaohanm.github.io/FusionDTI.github.io/" target="_blank" rel="noopener noreferrer" aria-label="Project Page">
436
+ <!-- globe icon -->
437
+ <svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="currentColor" aria-hidden="true">
438
+ <path d="M12 2a10 10 0 1 0 10 10A10.012 10.012 0 0 0 12 2Zm7.93 9h-3.18a15.84 15.84 0 0 0-1.19-5.02A8.02 8.02 0 0 1 19.93 11ZM12 4c.86 0 2.25 1.86 3.01 6H8.99C9.75 5.86 11.14 4 12 4ZM4.07 13h3.18c.2 1.79.66 3.47 1.19 5.02A8.02 8.02 0 0 1 4.07 13Zm3.18-2H4.07A8.02 8.02 0 0 1 8.44 5.98 15.84 15.84 0 0 0 7.25 11Zm1.37 2h6.76c-.76 4.14-2.15 6-3.01 6s-2.25-1.86-3.01-6Zm9.05 0h3.18a8.02 8.02 0 0 1-4.37 5.02 15.84 15.84 0 0 0 1.19-5.02Z"/>
439
+ </svg>
440
+ Project Page
441
+ </a>
442
+ <a class="link-btn arxiv" href="https://arxiv.org/abs/2406.01651" target="_blank" rel="noopener noreferrer" aria-label="ArXiv: 2406.01651">
443
+ <!-- arXiv-like paper icon -->
444
+ <svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="currentColor" aria-hidden="true">
445
+ <path d="M6 2h9l5 5v13a2 2 0 0 1-2 2H6a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2Zm8 1.5V8h4.5L14 3.5ZM7 12h10v2H7v-2Zm0 4h10v2H7v-2Zm0-8h6v2H7V8Z"/>
446
+ </svg>
447
+ ArXiv: 2406.01651
448
+ </a>
449
+ <a class="link-btn github" href="https://github.com/ZhaohanM/FusionDTI" target="_blank" rel="noopener noreferrer" aria-label="GitHub Repo">
450
+ <!-- GitHub mark -->
451
+ <svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="currentColor" aria-hidden="true">
452
+ <path d="M12 .5A12 12 0 0 0 0 12.76c0 5.4 3.44 9.98 8.2 11.6.6.12.82-.28.82-.6v-2.3c-3.34.74-4.04-1.44-4.04-1.44-.54-1.38-1.32-1.74-1.32-1.74-1.08-.76.08-.74.08-.74 1.2.08 1.84 1.26 1.84 1.26 1.06 1.86 2.78 1.32 3.46 1.02.1-.8.42-1.32.76-1.62-2.66-.32-5.46-1.36-5.46-6.02 0-1.34.46-2.44 1.22-3.3-.12-.32-.54-1.64.12-3.42 0 0 1-.34 3.32 1.26.96-.28 1.98-.42 3-.42s2.04.14 3 .42c2.32-1.6 3.32-1.26 3.32-1.26.66 1.78.24 3.1.12 3.42.76.86 1.22 1.96 1.22 3.3 0 4.68-2.8 5.68-5.48 6 .44.38.84 1.12.84 2.28v3.38c0 .32.22.74.84.6A12.02 12.02 0 0 0 24 12.76 12 12 0 0 0 12 .5Z"/>
453
+ </svg>
454
+ GitHub Repo
455
+ </a>
456
  </div>
457
  """)
458
+
459
  # ───────────── Guidelines Card ─────────────
 
460
  gr.HTML(
461
  """
462
+ <div class="card" id="guidelines-card" style="margin-bottom:24px">
463
+ <h2>Guidelines for Users</h2>
464
+ <ul style="list-style:decimal;">
465
  <li><strong>Convert protein structure into a structure-aware sequence:</strong>
466
  Upload a <code>.pdb</code> or <code>.cif</code> file. A structure-aware
467
  sequence will be generated using
 
470
  <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a> or the
471
  <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>.</li>
472
  <li><strong>If you only have an amino acid sequence or a UniProt ID,</strong>
473
+ please first visit the
474
  <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>
475
  or <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a>
476
+ to download the corresponding <code>.cif</code> or <code>.pdb</code> file.</li>
477
+ <li><strong>Drug input supports both SELFIES and SMILES:</strong>
478
+ Enter a SELFIES string directly, or paste a SMILES string. SMILES will
479
+ be converted to SELFIES using the
480
  <a href="https://github.com/aspuru-guzik-group/selfies" target="_blank">SELFIES encoder</a>.
481
  If conversion fails, a red error message will be displayed.</li>
482
+ <li>Optionally enter a <strong>1-based</strong> drug atom/substructure index
483
  to highlight the Top-30 interacting protein residues.</li>
484
+ <li>After inference, use β€œDownload PDF” to export a high-resolution vector figure.</li>
 
485
  </ul>
486
  </div>
487
+ """
488
+ )
489
+
490
  # ───────────── Input Card ─────────────
491
  with gr.Column(elem_id="input-card", elem_classes="card"):
 
492
  protein_seq = gr.Textbox(
493
  label="Protein Structure-aware Sequence",
494
  lines=3,
 
513
 
514
  # ───────────── Action Buttons ─────────────
515
  with gr.Row(elem_id="action-buttons", equal_height=True):
516
+ btn_extract = gr.Button("Extract sequence", variant="primary", elem_id="extract-btn")
517
+ btn_infer = gr.Button("Inference", variant="primary", elem_id="inference-btn")
 
 
 
 
 
 
 
 
518
  with gr.Row():
519
+ clear_btn = gr.Button("Clear", variant="secondary", elem_id="clear-btn")
 
 
 
 
 
520
 
521
+ # ───────────── Output Visualisation ─────────────
522
  output_html = gr.HTML(elem_id="result-html")
523
 
524
  # ───────────── Event Wiring ─────────────
 
533
  outputs=[output_html]
534
  )
535
  clear_btn.click(
536
+ fn=clear_cb,
537
  inputs=[],
538
  outputs=[protein_seq, drug_seq, drug_idx, output_html, structure_file]
539
  )