wangleiofficial commited on
Commit
01618b5
·
verified ·
1 Parent(s): ee2bd6b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +223 -503
app.py CHANGED
@@ -1,13 +1,7 @@
1
- # app_locpred_prok.py
2
  import os
3
- import re
4
  import json
 
5
  import uuid
6
- import io
7
- import shutil
8
- import base64
9
- from typing import Tuple
10
-
11
  import torch
12
  import torch.nn as nn
13
  import torch.nn.functional as F
@@ -16,32 +10,30 @@ import matplotlib.pyplot as plt
16
  import numpy as np
17
  from transformers import AutoTokenizer, AutoModel
18
 
19
- # Optional server-side PDF export dependency
20
- try:
21
- import cairosvg
22
- CAIROSVG_AVAILABLE = True
23
- except Exception:
24
- CAIROSVG_AVAILABLE = False
25
-
26
- # ========== Environment (same cache handling as before) ==========
27
- plt.switch_backend('Agg')
28
  os.environ["HF_HOME"] = "/tmp/hf_cache"
29
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
30
  os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
31
 
 
32
  for path in ["/tmp/hf_cache", os.path.expanduser("~/.cache/huggingface")]:
33
  shutil.rmtree(path, ignore_errors=True)
34
  os.makedirs(path, exist_ok=True)
35
 
36
- # ========== Model architecture (same as you had) ==========
 
 
37
  class AttentionPooling(nn.Module):
38
  def __init__(self, d_model):
39
  super().__init__()
40
  self.attention_net = nn.Linear(d_model, 1)
41
 
42
  def forward(self, x, mask):
43
- attn_logits = self.attention_net(x).squeeze(2)
44
- attn_logits = attn_logits.masked_fill(mask == 0, -1e9)
45
  attn_weights = F.softmax(attn_logits, dim=1)
46
  return torch.bmm(attn_weights.unsqueeze(1), x).squeeze(1), attn_weights
47
 
@@ -54,11 +46,7 @@ class ProtDualBranchEnhancedClassifier(nn.Module):
54
  self.tok_projector = nn.Linear(d_model, projection_dim)
55
  fused_dim = projection_dim * 2
56
  self.gate = nn.Sequential(nn.Linear(fused_dim, fused_dim), nn.Sigmoid())
57
- self.classifier_head = nn.Sequential(nn.LayerNorm(fused_dim),
58
- nn.Linear(fused_dim, fused_dim * 2),
59
- nn.ReLU(),
60
- nn.Dropout(dropout),
61
- nn.Linear(fused_dim * 2, num_classes))
62
 
63
  def forward(self, cls_embedding, token_embeddings, mask):
64
  z_cls = self.cls_projector(cls_embedding)
@@ -71,527 +59,259 @@ class ProtDualBranchEnhancedClassifier(nn.Module):
71
  z_fused_gated = z_fused_concat * gate_values
72
  return self.classifier_head(z_fused_gated), pooling_weights
73
 
74
- # ========== Load models (keep same variable names so your logic works) ==========
 
 
75
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
76
  PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
77
  CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
78
  LABEL_MAP_PATH = "label_map.json"
79
 
80
- # If you want to test UI without model files, set MOCK_MODE = True
81
- MOCK_MODE = False
82
-
83
- if not MOCK_MODE:
84
- if not os.path.exists(LABEL_MAP_PATH):
85
- raise FileNotFoundError(f"Missing {LABEL_MAP_PATH}")
86
- if not os.path.exists(CLASSIFIER_PATH):
87
- raise FileNotFoundError(f"Missing {CLASSIFIER_PATH}")
88
-
89
- with open(LABEL_MAP_PATH, 'r') as f:
90
- label_to_idx = json.load(f)
91
- idx_to_label = {v: k for k, v in label_to_idx.items()}
92
- NUM_CLASSES = len(idx_to_label)
93
- D_MODEL = 640
94
-
95
- print("🔹 Loading models...")
96
- tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME)
97
- plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE).eval()
98
- classifier = ProtDualBranchEnhancedClassifier(D_MODEL, 32, NUM_CLASSES, 0.3, 3).to(DEVICE)
99
- classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
100
- classifier.eval()
101
- print("✅ Models loaded.")
102
- else:
103
- # Mock objects for UI development
104
- idx_to_label = {0: "Cytoplasm", 1: "Inner Membrane", 2: "Periplasm", 3: "Outer Membrane", 4: "Cell Wall", 5: "Extracellular"}
105
- NUM_CLASSES = len(idx_to_label)
106
- tokenizer = None
107
- plm_model = None
108
- classifier = None
109
-
110
- # ========== SVG generator (UniProt-like) ==========
111
- def generate_uniprot_style_svg(pred_label: str,
112
- gram: str = "negative",
113
- theme: str = "uniprot-blue",
114
- layout: str = "circular",
115
- high_res: bool = False,
116
- uid: str = None) -> str:
117
- """
118
- Create a UniProt-like bacterial localization diagram:
119
- - pred_label: predicted top class name (used to highlight)
120
- - gram: 'negative' or 'positive'
121
- - theme: 'uniprot-blue', 'red-highlight', 'auto' (auto uses prefers-color-scheme)
122
- - layout: 'circular' (capsule) or 'horizontal' (cell left, labels right)
123
- - high_res: True -> larger viewBox for higher-quality PNG/PDF
124
- Returns: HTML string containing responsive SVG and download buttons.
125
- """
126
-
127
- target = pred_label.lower() if pred_label else ""
128
- is_active = {
129
- "sec": ("extracellular" in target) or ("secreted" in target),
130
- "om": ("outer membrane" in target),
131
- "peri": ("periplasm" in target),
132
- "cw": ("cell wall" in target),
133
- "im": ("inner membrane" in target) or ("plasma membrane" in target),
134
- "cyto": ("cytoplasm" in target) or ("cytosol" in target)
135
- }
136
-
137
- # If gram-positive, there is no outer membrane and cell wall is thicker
138
- have_outer_membrane = (gram == "negative")
139
-
140
- # color themes using CSS variables (supports prefers-color-scheme)
141
- css_vars = {
142
- "uniprot-blue": {
143
- "--om-fill": "#F5F7FA", "--im-fill": "#FFFFFF", "--stroke": "#607D8B",
144
- "--muted": "#B0BEC5", "--text": "#263238", "--highlight": "#0288D1"
145
- },
146
- "red-highlight": {
147
- "--om-fill": "#FFEBEE", "--im-fill": "#FFFFFF", "--stroke": "#607D8B",
148
- "--muted": "#B0BEC5", "--text": "#263238", "--highlight": "#D32F2F"
149
- }
150
- }
151
-
152
- selected = css_vars.get(theme if theme in css_vars else "uniprot-blue")
153
-
154
- # Unique IDs for DOM elements to allow multiple diagrams on page
155
- uid = uid or str(uuid.uuid4()).replace("-", "")[:10]
156
- svg_id = f"loc_svg_{uid}"
157
-
158
- # Sizes
159
- if high_res:
160
- W, H = 1600, 800
161
- else:
162
- W, H = 800, 420
163
-
164
- # anchor coords and label positions (tuned for viewBox)
165
- center_x = int(W * 0.38)
166
- center_y = int(H * 0.5)
167
- label_x = int(W * 0.76)
168
- label_x_left = int(W * 0.58)
169
-
170
- label_y_map = {
171
- "sec": int(H * 0.12),
172
- "om": int(H * 0.22),
173
- "peri": int(H * 0.32),
174
- "cw": int(H * 0.42),
175
- "im": int(H * 0.62),
176
- "cyto": int(H * 0.78)
177
  }
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  anchors = {
180
- "sec": (center_x + 160, center_y - 160),
181
- "om": (center_x + 160, center_y - 100),
182
- "peri":(center_x + 140, center_y - 40),
183
- "cw": (center_x + 120, center_y + 10),
184
- "im": (center_x + 80, center_y + 60),
185
- "cyto":(center_x + 20, center_y + 70)
186
  }
187
 
188
- # Build helper for connector group
189
- def connector_svg(key, text):
190
- ex, ey = anchors[key]
191
- tx, ty = label_x, label_y_map[key]
192
- # styling depends on active
193
- active = is_active.get(key, False)
194
- stroke_color = selected["--highlight"] if active else selected["--muted"]
195
- stroke_w = "2.6" if active else "1.4"
196
- fontw = "700" if active else "400"
197
- dot_r = "6" if active else "4"
198
- path = f"M {tx-20} {ty-6} C {tx-100} {ty-6}, {ex+60} {ey+10}, {ex} {ey}"
199
- return f"""
200
- <g class="connector connector-{key}">
201
- <text x="{tx}" y="{ty}" fill="{selected['--text']}" font-weight="{fontw}" font-size="{14 if not high_res else 22}" font-family="Inter, Arial">{text}</text>
202
- <path d="{path}" fill="none" stroke="{stroke_color}" stroke-width="{stroke_w}" stroke-linecap="round" stroke-linejoin="round" />
203
- <circle cx="{ex}" cy="{ey}" r="{dot_r}" fill="{stroke_color}" stroke="white" stroke-width="{1 if not high_res else 1.6}" />
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  </g>
205
  """
206
 
207
- # Build cell shapes: capsule-like curves - simpler parametric shapes
208
- # Outer membrane (or single outer layer for gram-positive), cell wall, inner membrane
209
- # Different shapes when gram-positive: thicker cell wall, no outer membrane ring.
210
- # We'll draw with bezier-like path strings tuned to look UniProt-ish.
211
- if have_outer_membrane:
212
- om_fill = "var(--om-fill)"
213
- om_stroke = "var(--stroke)"
214
- cw_fill = "none"
215
- cw_stroke = "var(--muted)"
216
- im_fill = "var(--im-fill)"
217
- im_stroke = "var(--stroke)"
218
- else:
219
- # gram-positive: cell wall thicker and outer membrane absent
220
- om_fill = "none"
221
- om_stroke = "none"
222
- cw_fill = "var(--om-fill)"
223
- cw_stroke = "var(--stroke)"
224
- im_fill = "var(--im-fill)"
225
- im_stroke = "var(--stroke)"
226
-
227
- # layer highlight override if active
228
- def stroke_override(key, base):
229
- if is_active.get(key, False):
230
- return selected["--highlight"]
231
- return base
232
-
233
- # inline CSS for animations and hover effects
234
- svg_style = f"""
235
- <style>
236
- /* theme vars */
237
- :root {{
238
- --om-fill: {selected['--om-fill']};
239
- --im-fill: {selected['--im-fill']};
240
- --stroke: {selected['--stroke']};
241
- --muted: {selected['--muted']};
242
- --text: {selected['--text']};
243
- --highlight: {selected['--highlight']};
244
- }}
245
- @media (prefers-color-scheme: dark) {{
246
- :root {{
247
- --om-fill: #28343a;
248
- --im-fill: #1f2b30;
249
- --stroke: #90a4ae;
250
- --muted: #546e7a;
251
- --text: #e0f2f1;
252
- }}
253
- }}
254
-
255
- /* connector hover: slightly thicken the path and enlarge dot */
256
- .connector path {{ transition: stroke-width 180ms ease, stroke 180ms ease; opacity:0.95; }}
257
- .connector circle {{ transition: r 160ms ease, transform 160ms ease; transform-origin: center; }}
258
- .connector text {{ transition: fill 160ms ease; }}
259
-
260
- /* on hover of group, emphasize */
261
- .connector:hover path {{ stroke-width: calc(var(--hover-w, 3)); opacity:1; filter: drop-shadow(0 2px 2px rgba(0,0,0,0.06)); }}
262
- .connector:hover circle {{ transform: scale(1.25); }}
263
- /* subtle floating animation for lines */
264
- .connector path {{}}
265
- @keyframes floatx {{ 0% {{ transform: translateX(0px); }} 50% {{ transform: translateX(1px); }} 100% {{ transform: translateX(0px); }} }}
266
- .connector path {{ animation: floatx 4s ease-in-out infinite; animation-delay: calc(var(--i, 0) * 0.12s); opacity:0.95; }}
267
-
268
- /* make the whole svg responsive */
269
- svg {{ max-width: 100%; height: auto; display:block; }}
270
-
271
- /* layer highlight when active: add glow */
272
- .layer-active {{ filter: drop-shadow(0 4px 8px rgba(0,0,0,0.08)); }}
273
- </style>
274
- """
275
-
276
- # Compose SVG core shapes (simplified, but tuned coordinates)
277
- # We use path shapes with translated center for convenience.
278
- cell_shapes = ""
279
- # Outer membrane / envelope
280
- if have_outer_membrane:
281
- cell_shapes += f'''
282
- <g id="outer_membrane" class="layer {'layer-active' if is_active['om'] else ''}">
283
- <ellipse cx="{center_x}" cy="{center_y}" rx="{220 if not high_res else 440}" ry="{170 if not high_res else 340}"
284
- fill="var(--om-fill)" stroke="{stroke_override('om', 'var(--stroke)')}" stroke-width="{3 if is_active['om'] else 2}"/>
285
- </g>
286
- '''
287
- # Cell wall (dashed)
288
- cell_shapes += f'''
289
- <g id="cell_wall">
290
- <ellipse cx="{center_x}" cy="{center_y}" rx="{190 if not high_res else 380}" ry="{150 if not high_res else 300}"
291
- fill="none" stroke="{stroke_override('cw','var(--muted)')}" stroke-width="{4 if is_active['cw'] else 2}" stroke-dasharray="10 6"/>
292
- </g>
293
- '''
294
- # inner membrane
295
- cell_shapes += f'''
296
- <g id="inner_membrane" class="layer {'layer-active' if is_active['im'] else ''}">
297
- <ellipse cx="{center_x}" cy="{center_y}" rx="{140 if not high_res else 280}" ry="{100 if not high_res else 200}"
298
- fill="var(--im-fill)" stroke="{stroke_override('im','var(--stroke)')}" stroke-width="{3 if is_active['im'] else 1.8}"/>
299
- </g>
300
- '''
301
- else:
302
- # Gram positive: thick cell wall as filled ellipse + inner membrane
303
- cell_shapes += f'''
304
- <g id="cell_wall_gp" class="layer {'layer-active' if is_active['cw'] else ''}">
305
- <ellipse cx="{center_x}" cy="{center_y}" rx="{230 if not high_res else 460}" ry="{180 if not high_res else 360}"
306
- fill="{selected['--om-fill']}" stroke="{stroke_override('cw','var(--stroke)')}" stroke-width="{3 if is_active['cw'] else 2}"/>
307
- </g>
308
- <g id="inner_membrane" class="layer {'layer-active' if is_active['im'] else ''}">
309
- <ellipse cx="{center_x}" cy="{center_y}" rx="{150 if not high_res else 300}" ry="{110 if not high_res else 220}"
310
- fill="var(--im-fill)" stroke="{stroke_override('im','var(--stroke)')}" stroke-width="{2 if is_active['im'] else 1.4}"/>
311
- </g>
312
- '''
313
-
314
- # cytoplasm ornament
315
- cell_shapes += f'''
316
- <g id="cytoplasm_wiggles" opacity="0.65">
317
- <path d="M {center_x-60} {center_y+10} q 30 -50 70 0 q 30 50 70 0" stroke="var(--muted)" stroke-width="6" fill="none" stroke-linecap="round"/>
318
- <circle cx="{center_x-40}" cy="{center_y+40}" r="{3 if not high_res else 6}" fill="var(--muted)"/>
319
- <circle cx="{center_x+20}" cy="{center_y+50}" r="{3 if not high_res else 6}" fill="var(--muted)"/>
320
- </g>
321
- '''
322
-
323
- # connectors
324
- connectors = ""
325
- connectors += connector_svg("sec", "Extracellular / Secreted")
326
- # outer membrane only if present
327
- if have_outer_membrane:
328
- connectors += connector_svg("om", "Outer Membrane")
329
- connectors += connector_svg("peri", "Periplasm")
330
- connectors += connector_svg("cw", "Cell Wall")
331
- connectors += connector_svg("im", "Inner Membrane")
332
- connectors += connector_svg("cyto", "Cytoplasm")
333
-
334
- # Build download buttons and client-side JS to download SVG and PNG
335
- # PDF export will call a Gradio server endpoint (provided below)
336
- html = f"""
337
- <div style="width:100%; text-align:center;">
338
- {svg_style}
339
- <svg id="{svg_id}" viewBox="0 0 {W} {H}" xmlns="http://www.w3.org/2000/svg" role="img" aria-label="Bacterial localization diagram">
340
- <defs>
341
- <style><![CDATA[
342
- text {{ font-family: Inter, Arial, sans-serif; }}
343
- ]]></style>
344
- </defs>
345
- {cell_shapes}
346
- {connectors}
347
- </svg>
348
-
349
- <div style="margin-top:8px; display:flex; gap:8px; justify-content:center; align-items:center;">
350
- <button id="download_svg_{uid}" class="download-btn">Download SVG</button>
351
- <button id="download_png_{uid}" class="download-btn">Download PNG</button>
352
- <button id="download_pdf_{uid}" class="download-btn">Download PDF</button>
353
- <div style="font-size:12px; color:var(--text); align-self:center;">{gram.title()} · {layout.title()} {'· High-res' if high_res else ''}</div>
354
- </div>
355
  </div>
356
-
357
  <script>
358
- (function(){{
359
- const svgEl = document.getElementById("{svg_id}");
360
- const btnSvg = document.getElementById("download_svg_{uid}");
361
- const btnPng = document.getElementById("download_png_{uid}");
362
- const btnPdf = document.getElementById("download_pdf_{uid}");
363
-
364
- function downloadFile(filename, blob) {{
365
- const url = URL.createObjectURL(blob);
366
- const a = document.createElement('a');
367
- a.href = url; a.download = filename; document.body.appendChild(a); a.click();
368
- setTimeout(()=>{{ URL.revokeObjectURL(url); a.remove(); }}, 200);
369
- }}
370
-
371
- btnSvg.addEventListener('click', ()=>{{
372
- const serializer = new XMLSerializer();
373
- let source = serializer.serializeToString(svgEl);
374
- if(!source.match(/^<svg[^>]+xmlns="http:\\/\\/www.w3.org\\/2000\\/svg"/)) {{
375
- source = source.replace(/^<svg/, '<svg xmlns="http://www.w3.org/2000/svg"');
376
- }}
377
- const blob = new Blob([source], {{type: 'image/svg+xml;charset=utf-8'}});
378
- downloadFile('locpred_diagram.svg', blob);
379
- }});
380
-
381
- btnPng.addEventListener('click', ()=>{{
382
- const serializer = new XMLSerializer();
383
- let source = serializer.serializeToString(svgEl);
384
- if(!source.match(/^<svg[^>]+xmlns="http:\\/\\/www.w3.org\\/2000\\/svg"/)) {{
385
- source = source.replace(/^<svg/, '<svg xmlns="http://www.w3.org/2000/svg"');
386
- }}
387
- const svgBlob = new Blob([source], {{type: 'image/svg+xml;charset=utf-8'}});
388
- const url = URL.createObjectURL(svgBlob);
389
- const img = new Image();
390
- img.onload = function() {{
391
- const canvas = document.createElement('canvas');
392
- // scale 2x for higher quality
393
- const scale = 2;
394
- canvas.width = img.width * scale;
395
- canvas.height = img.height * scale;
396
- const ctx = canvas.getContext('2d');
397
- // optional white background
398
- ctx.fillStyle = "white";
399
- ctx.fillRect(0,0,canvas.width,canvas.height);
400
- ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
401
- canvas.toBlob(function(blob) {{
402
- downloadFile('locpred_diagram.png', blob);
403
- }}, 'image/png');
404
- URL.revokeObjectURL(url);
405
- }};
406
- img.onerror = function(e) {{
407
- alert('Failed to render PNG in your browser.');
408
- URL.revokeObjectURL(url);
409
- }};
410
- img.src = url;
411
- }});
412
-
413
- btnPdf.addEventListener('click', async ()=>{{
414
- // send the SVG string to the server /gradio route for PDF conversion
415
- const serializer = new XMLSerializer();
416
- let source = serializer.serializeToString(svgEl);
417
- if(!source.match(/^<svg[^>]+xmlns="http:\\/\\/www.w3.org\\/2000\\/svg"/)) {{
418
- source = source.replace(/^<svg/, '<svg xmlns="http://www.w3.org/2000/svg"');
419
- }}
420
- // call Gradio server function via fetch to /convert_svg_to_pdf (provided below)
421
- try {{
422
- const resp = await fetch('/convert_svg_to_pdf', {{
423
- method: 'POST',
424
- headers: {{ 'Content-Type': 'application/json' }},
425
- body: JSON.stringify({{ svg: source }})
426
- }});
427
- if(!resp.ok) {{
428
- const txt = await resp.text();
429
- alert('PDF conversion failed: ' + txt);
430
- return;
431
- }}
432
- const blob = await resp.blob();
433
- downloadFile('locpred_diagram.pdf', blob);
434
- }} catch (err) {{
435
- alert('PDF conversion failed: ' + err);
436
  }}
437
- }});
438
- }})();
439
- </script>
440
- """
441
-
442
  return html
443
 
444
- # ========== Heatmap (same as before) ==========
 
 
445
  def draw_attention_heatmap_strip(weights, sequence):
 
446
  if weights.max() > 0:
447
  weights = (weights - weights.min()) / (weights.max() - weights.min())
 
 
448
  data = weights.reshape(1, -1)
449
- fig, ax = plt.subplots(figsize=(8, 1.5), dpi=150)
 
450
  im = ax.imshow(data, cmap='Reds', aspect='auto', vmin=0, vmax=1)
451
- ax.set_title('Sequence Attention Heatmap (High Color = Key Feature)', fontsize=10, fontweight='bold', color='#37474F', pad=6)
452
- ax.set_xlabel('Residue Position', fontsize=9)
453
- ax.set_yticks([])
454
- cbar = plt.colorbar(im, ax=ax, orientation='vertical', fraction=0.02, pad=0.02)
455
- cbar.ax.tick_params(labelsize=8)
456
- cbar.outline.set_visible(False)
457
- for spine in ax.spines.values():
458
- spine.set_visible(False)
459
  plt.tight_layout()
460
  return fig
461
 
462
- # ========== Prediction function ==========
463
- def predict(sequence_input: str, gram_choice: str, theme_choice: str, layout_choice: str, high_res_flag: bool):
464
- """
465
- Returns:
466
- - confidences dict (for Label)
467
- - svg html string (for HTML output)
468
- - attention heatmap figure (for Plot)
469
- """
470
- if not sequence_input or sequence_input.isspace():
471
- raise gr.Error("Empty Input")
472
-
473
  seq = "".join(sequence_input.split('\n')[1:]) if sequence_input.startswith('>') else sequence_input
474
  seq = re.sub(r'[^A-Z]', '', seq.upper())[:1024]
475
- if not seq:
476
- raise gr.Error("Invalid Sequence")
477
-
478
- if MOCK_MODE:
479
- # mock probabilities
480
- probs = torch.softmax(torch.randn(NUM_CLASSES), dim=0)
481
- logits = probs
482
- pooling_weights = np.abs(np.random.randn(len(seq))).astype(float)
483
- pooling_weights = pooling_weights / pooling_weights.sum()
484
- top_label = idx_to_label[int(torch.argmax(probs).item())]
485
- else:
486
- with torch.no_grad():
487
- inputs = tokenizer(seq, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
488
- outputs = plm_model(**inputs)
489
- hidden_states = outputs.last_hidden_state
490
- cls_embedding = hidden_states[:, 0, :]
491
- token_embeddings = hidden_states[:, 1:-1, :]
492
- token_mask = inputs['attention_mask'][:, 1:-1]
493
- logits, pooling_weights = classifier(cls_embedding, token_embeddings, token_mask)
494
- probs = F.softmax(logits, dim=1)[0]
495
- top_label = idx_to_label[torch.argmax(probs).item()]
496
- pooling_weights = pooling_weights[0].cpu().numpy()
497
-
498
- confidences = { idx_to_label[i]: float(p) for i,p in enumerate(probs) } if not MOCK_MODE else { idx_to_label[i]: float(p) for i,p in enumerate(probs)}
499
- svg_html = generate_uniprot_style_svg(top_label, gram=gram_choice, theme=theme_choice, layout=layout_choice, high_res=high_res_flag)
500
- heatmap_fig = draw_attention_heatmap_strip(np.array(pooling_weights), seq)
501
-
502
- return confidences, svg_html, heatmap_fig
503
-
504
- # ========== Server-side PDF conversion endpoint for Gradio ==========
505
- # This function will be exposed at /convert_svg_to_pdf when app launches.
506
- def convert_svg_to_pdf_endpoint(svg_str: str):
507
- """
508
- Convert an SVG string to PDF bytes using cairosvg (if available).
509
- Return bytes-like object (PDF) or raise error.
510
- """
511
- if not CAIROSVG_AVAILABLE:
512
- raise RuntimeError("Server-side PDF conversion requires 'cairosvg' package. Install with: pip install cairosvg")
513
-
514
- # cairosvg.svg2pdf can take bytes or string
515
- pdf_bytes = cairosvg.svg2pdf(bytestring=svg_str.encode('utf-8'))
516
- return ("locpred_diagram.pdf", pdf_bytes)
517
-
518
- # ========== UI layout (Gradio) ==========
519
  layout_css = """
520
- /* small customizations and CSS variables fallback */
521
- :root { --panel-bg: white; }
522
- .gradio-container { max-width: 1200px; margin: 0 auto; }
523
- .download-btn { padding:8px 12px; border-radius:6px; border:1px solid #D1D5DB; background:transparent; cursor:pointer; }
524
- .download-btn:hover { box-shadow: 0 4px 14px rgba(16,24,40,0.08); }
525
- .panel-card { border:1px solid #e6eef5; border-radius:8px; padding:12px; background:var(--panel-bg); }
526
- .panel-header { font-weight:700; color:#475569; border-bottom:2px solid #f1f5f9; padding-bottom:8px; margin-bottom:10px; }
 
527
  """
528
 
529
  theme = gr.themes.Soft(primary_hue="sky").set(body_background_fill="white", block_background_fill="white", block_border_width="0px")
530
 
531
- with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok (UniProt-style)") as app:
532
- gr.Markdown("<div style='font-size:22px; font-weight:800; color:#0288D1;'>LocPred-Prok — UniProt-style visualization</div>")
 
 
 
533
  with gr.Row():
534
- with gr.Column(scale=6):
535
- gr.Markdown("<div class='panel-header'><span style='background:#E0F7FA;color:#0277BD;padding:3px 6px;border-radius:4px;font-weight:800;margin-right:8px;'>A</span>Sequence Input</div>")
536
- sequence_input = gr.Textbox(lines=8, placeholder=">Sequence (single-letter amino acids) or paste raw sequence", show_label=False)
537
  with gr.Row():
538
  clear_btn = gr.ClearButton(sequence_input, value="Clear")
539
  submit_btn = gr.Button("Predict Analysis", variant="primary")
540
- with gr.Row():
541
- gram_choice = gr.Radio(choices=["negative", "positive"], value="negative", label="Gram type")
542
- theme_choice = gr.Radio(choices=["uniprot-blue", "red-highlight"], value="uniprot-blue", label="Color Theme")
543
- layout_choice = gr.Radio(choices=["circular", "horizontal"], value="circular", label="Diagram Layout")
544
- high_res_flag = gr.Checkbox(value=False, label="High resolution (bigger SVG/PDF)")
545
- gr.Examples([[">Outer Membrane\nAPKNTWYTGAKLGWSQYHDTGFINNNGPTHENQLGAGAF..."]], inputs=sequence_input)
546
 
547
- with gr.Column(scale=6):
548
- gr.Markdown("<div class='panel-header'><span style='background:#E0F7FA;color:#0277BD;padding:3px 6px;border-radius:4px;font-weight:800;margin-right:8px;'>B</span>Localization Visualization</div>")
549
  output_svg = gr.HTML(label="Visual", show_label=False)
550
 
 
551
  with gr.Row():
552
- with gr.Column(scale=6):
553
- gr.Markdown("<div class='panel-header'><span style='background:#E0F7FA;color:#0277BD;padding:3px 6px;border-radius:4px;font-weight:800;margin-right:8px;'>C</span>Prediction Confidence</div>")
554
  output_label = gr.Label(num_top_classes=NUM_CLASSES, show_label=False)
555
- with gr.Column(scale=6):
556
- gr.Markdown("<div class='panel-header'><span style='background:#E0F7FA;color:#0277BD;padding:3px 6px;border-radius:4px;font-weight:800;margin-right:8px;'>D</span>Attention Heatmap</div>")
557
- output_plot = gr.Plot(show_label=False)
558
 
559
- submit_btn.click(fn=predict, inputs=[sequence_input, gram_choice, theme_choice, layout_choice, high_res_flag], outputs=[output_label, output_svg, output_plot])
560
- clear_btn.click(lambda: [None, None, None], outputs=[output_label, output_svg, output_plot])
561
-
562
- # Expose PDF conversion endpoint: Gradio allows adding a separate route handler via app.launch later.
563
- # We'll attach the endpoint to the FastAPI app used by Gradio when launching.
564
-
565
- # ========== Run server with custom route for PDF conversion ==========
566
- if __name__ == "__main__":
567
- from fastapi import FastAPI, Request, Response
568
- import uvicorn
569
-
570
- # Build Gradio app to get underlying FastAPI instance
571
- demo = app
572
- # Get the underlying FastAPI app (gradio >= 3.0)
573
- # When launching, we'll mount a custom route /convert_svg_to_pdf handled by convert_svg_to_pdf_endpoint
574
- # Gradio's launch will create a FastAPI object; to avoid internal changes, we use the `server_name` arg.
575
 
576
- # Create a lightweight FastAPI for the PDF endpoint and mount the Gradio interface into it
577
- fast_app = FastAPI()
578
-
579
- @fast_app.post("/convert_svg_to_pdf")
580
- async def convert_svg_to_pdf_api(request: Request):
581
- payload = await request.json()
582
- svg = payload.get("svg", None)
583
- if not svg:
584
- return Response(content="No svg provided", status_code=400)
585
- if not CAIROSVG_AVAILABLE:
586
- return Response(content="Server-side PDF conversion unavailable: install 'cairosvg' in the server environment.", status_code=501)
587
- try:
588
- pdf_bytes = cairosvg.svg2pdf(bytestring=svg.encode('utf-8'))
589
- return Response(content=pdf_bytes, media_type="application/pdf")
590
- except Exception as e:
591
- return Response(content=f"PDF conversion error: {e}", status_code=500)
592
-
593
- # Mount the Gradio interface at root
594
- gr.mount_gradio_app(fast_app, demo, path="/")
595
 
596
- # Launch uvicorn with the FastAPI app
597
- uvicorn.run(fast_app, host="0.0.0.0", port=7860, log_level="info")
 
 
1
  import os
 
2
  import json
3
+ import re
4
  import uuid
 
 
 
 
 
5
  import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as F
 
10
  import numpy as np
11
  from transformers import AutoTokenizer, AutoModel
12
 
13
+ # ==========================
14
+ # 0. 环境初始化
15
+ # ==========================
16
+ plt.switch_backend('Agg')
 
 
 
 
 
17
  os.environ["HF_HOME"] = "/tmp/hf_cache"
18
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
19
  os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
20
 
21
+ import shutil
22
  for path in ["/tmp/hf_cache", os.path.expanduser("~/.cache/huggingface")]:
23
  shutil.rmtree(path, ignore_errors=True)
24
  os.makedirs(path, exist_ok=True)
25
 
26
+ # ==========================
27
+ # 1. 模型架构
28
+ # ==========================
29
  class AttentionPooling(nn.Module):
30
  def __init__(self, d_model):
31
  super().__init__()
32
  self.attention_net = nn.Linear(d_model, 1)
33
 
34
  def forward(self, x, mask):
35
+ attn_logits = self.attention_net(x).squeeze(2)
36
+ attn_logits.masked_fill_(mask == 0, -float('inf'))
37
  attn_weights = F.softmax(attn_logits, dim=1)
38
  return torch.bmm(attn_weights.unsqueeze(1), x).squeeze(1), attn_weights
39
 
 
46
  self.tok_projector = nn.Linear(d_model, projection_dim)
47
  fused_dim = projection_dim * 2
48
  self.gate = nn.Sequential(nn.Linear(fused_dim, fused_dim), nn.Sigmoid())
49
+ self.classifier_head = nn.Sequential(nn.LayerNorm(fused_dim), nn.Linear(fused_dim, fused_dim * 2), nn.ReLU(), nn.Dropout(dropout), nn.Linear(fused_dim * 2, num_classes))
 
 
 
 
50
 
51
  def forward(self, cls_embedding, token_embeddings, mask):
52
  z_cls = self.cls_projector(cls_embedding)
 
59
  z_fused_gated = z_fused_concat * gate_values
60
  return self.classifier_head(z_fused_gated), pooling_weights
61
 
62
+ # ==========================
63
+ # 2. 加载模型
64
+ # ==========================
65
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
  PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
67
  CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
68
  LABEL_MAP_PATH = "label_map.json"
69
 
70
+ if not os.path.exists(LABEL_MAP_PATH): raise FileNotFoundError(f"Missing {LABEL_MAP_PATH}")
71
+ if not os.path.exists(CLASSIFIER_PATH): raise FileNotFoundError(f"Missing {CLASSIFIER_PATH}")
72
+
73
+ with open(LABEL_MAP_PATH, 'r') as f:
74
+ label_to_idx = json.load(f)
75
+ idx_to_label = {v: k for k, v in label_to_idx.items()}
76
+ NUM_CLASSES = len(idx_to_label)
77
+ D_MODEL = 640
78
+
79
+ print("🔹 Loading models...")
80
+ tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME)
81
+ plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE).eval()
82
+ classifier = ProtDualBranchEnhancedClassifier(D_MODEL, 32, NUM_CLASSES, 0.3, 3).to(DEVICE)
83
+ classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
84
+ classifier.eval()
85
+ print(" Ready.")
86
+
87
+ # ==========================
88
+ # 3. Panel B: SVG 绘图引擎 (6标签标准版)
89
+ # ==========================
90
+ def generate_scientific_svg(target_class):
91
+ target = target_class.lower() if target_class else ""
92
+
93
+ # 状态判断 (6类)
94
+ is_sec = "extracellular" in target or "secreted" in target
95
+ is_om = "outer membrane" in target
96
+ is_peri = "periplasm" in target
97
+ is_cw = "cell wall" in target
98
+ is_im = "plasma membrane" in target or "inner membrane" in target
99
+ is_cyto = "cytoplasm" in target or "cytosol" in target
100
+
101
+ # 颜色配置 (Nature Style)
102
+ c = {
103
+ 'hl_stroke': '#D32F2F', 'hl_fill': '#FFEBEE', 'hl_text': '#B71C1C', 'hl_dot': '#D32F2F',
104
+ 'bg_stroke': '#90A4AE', 'bg_fill': '#F9FAFB', # 极淡灰
105
+ 'bg_text': '#78909C', 'bg_line': '#CFD8DC', 'bg_dot': '#B0BEC5'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  }
107
 
108
+ # 几何参数
109
+ svg_id = f"svg_{str(uuid.uuid4())[:8]}"
110
+ cx, cy = 300, 210 # 细菌中心
111
+ tx = 620 # 标签 X 坐标
112
+
113
+ # --- 1. 绘制细菌主体 (默认使用 G- 结构以展示全要素) ---
114
+ shapes = ""
115
+
116
+ # Outer Membrane
117
+ col_om = c['hl_stroke'] if is_om else c['bg_stroke']
118
+ fill_om = c['hl_fill'] if is_peri else c['bg_fill']
119
+ w_om = "4" if is_om else "2"
120
+ shapes += f'<rect x="{cx-200}" y="{cy-120}" width="400" height="240" rx="120" ry="120" fill="{fill_om}" stroke="{col_om}" stroke-width="{w_om}" />'
121
+
122
+ # Cell Wall (Dashed)
123
+ col_cw = c['hl_stroke'] if is_cw else '#B0BEC5'
124
+ w_cw = "3" if is_cw else "1.5"
125
+ dash_cw = "0" if is_cw else "6,4"
126
+ shapes += f'<rect x="{cx-170}" y="{cy-90}" width="340" height="180" rx="90" ry="90" fill="none" stroke="{col_cw}" stroke-width="{w_cw}" stroke-dasharray="{dash_cw}" />'
127
+
128
+ # Inner Membrane & Cytoplasm
129
+ col_im = c['hl_stroke'] if is_im else c['bg_stroke']
130
+ fill_im = c['hl_fill'] if is_cyto else c['bg_fill']
131
+ w_im = "4" if is_im else "2"
132
+ shapes += f'<rect x="{cx-140}" y="{cy-60}" width="280" height="120" rx="60" ry="60" fill="{fill_im}" stroke="{col_im}" stroke-width="{w_im}" />'
133
+
134
+ # DNA Decoration
135
+ shapes += f"""<g opacity="0.4">
136
+ <path d="M {cx-30} {cy-10} Q {cx} {cy-50} {cx+30} {cy-10} T {cx+60} {cy}" fill="none" stroke="#CFD8DC" stroke-width="3" />
137
+ <circle cx="{cx-40}" cy="{cy+20}" r="3" fill="#B0BEC5" /> <circle cx="{cx+20}" cy="{cy+30}" r="3" fill="#B0BEC5" />
138
+ </g>"""
139
+
140
+ # --- 2. 标签系统 (6个完整标签 + 贝塞尔曲线) ---
141
+
142
+ # 锚点目标坐标 (Target Anchor Points)
143
  anchors = {
144
+ "sec": (cx, cy - 160), # 胞外 (悬浮)
145
+ "om": (cx + 200, cy - 60), # 外膜边界
146
+ "peri": (cx + 180, cy - 30), # 周质间隙
147
+ "cw": (cx + 170, cy), # 细胞壁
148
+ "im": (cx + 140, cy + 30), # 内膜边界
149
+ "cyto": (cx, cy) # 胞质中心
150
  }
151
 
152
+ # 标签配置
153
+ labels_config = [
154
+ ("Extracellular", "sec", is_sec),
155
+ ("Outer Membrane", "om", is_om),
156
+ ("Periplasm", "peri", is_peri),
157
+ ("Cell Wall", "cw", is_cw),
158
+ ("Inner Membrane", "im", is_im),
159
+ ("Cytoplasm", "cyto", is_cyto)
160
+ ]
161
+
162
+ label_svg = ""
163
+ y_start = 50
164
+ y_step = 60 # 间距
165
+
166
+ for i, (text, key, active) in enumerate(labels_config):
167
+ ty = y_start + i * y_step
168
+ ex, ey = anchors.get(key, (0,0))
169
+
170
+ # 样式
171
+ col_txt = c['hl_text'] if active else c['bg_text']
172
+ w_txt = "bold" if active else "normal"
173
+ col_line = c['hl_stroke'] if active else c['bg_line']
174
+ w_line = "2.5" if active else "1.0"
175
+ col_dot = c['hl_dot'] if active else c['bg_dot']
176
+ r_dot = "5" if active else "3"
177
+
178
+ # 贝塞尔 S 形曲线
179
+ # c1: 从文字左侧水平延伸; c2: 向锚点垂直延伸
180
+ c1x, c1y = tx - 80, ty
181
+ c2x, c2y = ex + 60, ey
182
+ path_d = f"M {tx-10} {ty-5} C {c1x} {c1y}, {c2x} {c2y}, {ex} {ey}"
183
+
184
+ label_svg += f"""
185
+ <g>
186
+ <text x="{tx}" y="{ty}" fill="{col_txt}" font-weight="{w_txt}" font-size="14" font-family="Arial">{text}</text>
187
+ <path d="{path_d}" fill="none" stroke="{col_line}" stroke-width="{w_line}" />
188
+ <circle cx="{ex}" cy="{ey}" r="{r_dot}" fill="{col_dot}" stroke="white" stroke-width="1" />
189
  </g>
190
  """
191
 
192
+ final_svg = f"""<svg id="{svg_id}" width="100%" height="100%" viewBox="0 0 800 420" xmlns="http://www.w3.org/2000/svg">
193
+ <rect width="800" height="420" fill="white" />
194
+ {shapes}
195
+ {label_svg}
196
+ <text x="400" y="400" text-anchor="middle" font-family="Arial" font-size="16" fill="#546E7A" font-weight="bold">Prediction: {target_class}</text>
197
+ </svg>"""
198
+
199
+ # 嵌入 JS 下载
200
+ html = f"""<div>{final_svg}
201
+ <div style="display:flex; justify-content:center; gap:10px; margin-top:5px;">
202
+ <button onclick="downloadSVG('{svg_id}')" style="font-size:11px; padding:4px 8px; border:1px solid #ccc; border-radius:4px; cursor:pointer;">Download SVG</button>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  </div>
 
204
  <script>
205
+ function downloadSVG(id) {{
206
+ const svg = document.getElementById(id);
207
+ const s = new XMLSerializer().serializeToString(svg);
208
+ const b = new Blob([s], {{type: "image/svg+xml;charset=utf-8"}});
209
+ const u = URL.createObjectURL(b);
210
+ const a = document.createElement("a"); a.href = u; a.download = "cell_loc.svg";
211
+ document.body.appendChild(a); a.click(); document.body.removeChild(a);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  }}
213
+ </script></div>"""
 
 
 
 
214
  return html
215
 
216
+ # ==========================
217
+ # 4. Panel D: Attention Heatmap (纯净热图)
218
+ # ==========================
219
  def draw_attention_heatmap_strip(weights, sequence):
220
+ # 归一化
221
  if weights.max() > 0:
222
  weights = (weights - weights.min()) / (weights.max() - weights.min())
223
+
224
+ fig, ax = plt.subplots(figsize=(8, 2), dpi=150) # 稍微加宽
225
  data = weights.reshape(1, -1)
226
+
227
+ # 绘制热图 (Reds)
228
  im = ax.imshow(data, cmap='Reds', aspect='auto', vmin=0, vmax=1)
229
+
230
+ ax.set_title("Sequence Attention Heatmap (Darker = Higher Attention)", fontsize=10, fontweight='bold', color='#37474F', pad=10)
231
+ ax.set_xlabel("Residue Position", fontsize=9)
232
+ ax.set_yticks([]) # 不显示 Y
233
+
234
+ # 隐藏四周边框
235
+ for spine in ax.spines.values(): spine.set_visible(False)
236
+
237
  plt.tight_layout()
238
  return fig
239
 
240
+ # ==========================
241
+ # 5. 预测主逻辑
242
+ # ==========================
243
+ def predict(sequence_input):
244
+ if not sequence_input or sequence_input.isspace(): raise gr.Error("Empty Input")
 
 
 
 
 
 
245
  seq = "".join(sequence_input.split('\n')[1:]) if sequence_input.startswith('>') else sequence_input
246
  seq = re.sub(r'[^A-Z]', '', seq.upper())[:1024]
247
+
248
+ with torch.no_grad():
249
+ inputs = tokenizer(seq, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
250
+ outputs = plm_model(**inputs)
251
+
252
+ logits, pooling_weights = classifier(
253
+ outputs.last_hidden_state[:, 0, :],
254
+ outputs.last_hidden_state[:, 1:-1, :],
255
+ inputs['attention_mask'][:, 1:-1]
256
+ )
257
+ probs = F.softmax(logits, dim=1)[0]
258
+
259
+ top_label = idx_to_label[torch.max(probs, dim=0)[1].item()]
260
+ confidences = {idx_to_label[i]: float(p) for i, p in enumerate(probs)}
261
+
262
+ # Panel B: SVG
263
+ svg = generate_scientific_svg(top_label)
264
+
265
+ # Panel D: Heatmap (纯净版)
266
+ heatmap = draw_attention_heatmap_strip(pooling_weights[0].cpu().numpy(), seq)
267
+
268
+ return confidences, svg, heatmap
269
+
270
+ # ==========================
271
+ # 6. UI Layout (4-Block Paper Style)
272
+ # ==========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  layout_css = """
274
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;800&display=swap');
275
+ body { background-color: #ffffff; font-family: 'Inter', sans-serif; }
276
+ .header-div { background: linear-gradient(to right, #E0F7FA, #E1F5FE); padding: 1.5rem; border-radius: 8px; margin-bottom: 20px; text-align: center; border: 1px solid #B3E5FC; }
277
+ .header-title { font-size: 2.2rem; font-weight: 800; color: #0288D1; margin-bottom: 5px; }
278
+ .header-sub { font-size: 1.0rem; color: #0277BD; }
279
+ .panel-card { border: 1px solid #e2e8f0; border-radius: 8px; padding: 15px; background: white; height: 100%; display: flex; flex-direction: column; }
280
+ .panel-header { font-weight: 700; color: #475569; border-bottom: 2px solid #f1f5f9; padding-bottom: 8px; margin-bottom: 12px; font-size: 1.0rem; }
281
+ .panel-label { display: inline-block; background: #E0F7FA; color: #0277BD; border: 1px solid #B2EBF2; padding: 2px 8px; border-radius: 4px; font-size: 0.8rem; margin-right: 8px; font-weight: 800; }
282
  """
283
 
284
  theme = gr.themes.Soft(primary_hue="sky").set(body_background_fill="white", block_background_fill="white", block_border_width="0px")
285
 
286
+ with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
287
+
288
+ gr.HTML("""<div class="header-div"><div class="header-title">LocPred-Prok</div><div class="header-sub">Deep Learning Framework for Prokaryotic Subcellular Localization</div></div>""")
289
+
290
+ # Row 1: A & B
291
  with gr.Row():
292
+ with gr.Column(elem_classes="panel-card"):
293
+ gr.Markdown("<div class='panel-header'><span class='panel-label'>A</span>Sequence Input</div>")
294
+ sequence_input = gr.Textbox(lines=8, show_label=False, placeholder=">Sequence...")
295
  with gr.Row():
296
  clear_btn = gr.ClearButton(sequence_input, value="Clear")
297
  submit_btn = gr.Button("Predict Analysis", variant="primary")
298
+ gr.Examples([[">Outer Membrane\nAPKNTWYTGAKLGWSQYHDTGFINNNGPTHENQLGAGAFGGYQVNPYVGFEMGYDWLGRMPYKGSVENGAYKAQGVQLTAKLGYPITDDLDIYTRLGGMVWRADTKSNVYGKNHDTGVSPVFAGGVEYAITPEIATRLEYQWTNNIGDAHTIGTRPDNGMLSLGVSYRFGQGEAAPVVAPAPAPAPEVQTKHFTLKSDVLFNFNKATLKPEGQAALDQLYSQLSNLDPKDGSVVVLGYTDRIGSDAYNQGLSERRAQSVVDYLISKGIPADKISARGMGESNPVTGNTCDNVKQRAALIDCLAPDRRVEIEVKGIKDVVTQPQA"]], inputs=sequence_input, label=None)
 
 
 
 
 
299
 
300
+ with gr.Column(elem_classes="panel-card"):
301
+ gr.Markdown("<div class='panel-header'><span class='panel-label'>B</span>Localization Visualization</div>")
302
  output_svg = gr.HTML(label="Visual", show_label=False)
303
 
304
+ # Row 2: C & D
305
  with gr.Row():
306
+ with gr.Column(elem_classes="panel-card"):
307
+ gr.Markdown("<div class='panel-header'><span class='panel-label'>C</span>Prediction Confidence</div>")
308
  output_label = gr.Label(num_top_classes=NUM_CLASSES, show_label=False)
 
 
 
309
 
310
+ with gr.Column(elem_classes="panel-card"):
311
+ gr.Markdown("<div class='panel-header'><span class='panel-label'>D</span>Attention Heatmap</div>")
312
+ output_plot = gr.Plot(label="Attention", show_label=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
313
 
314
+ submit_btn.click(fn=predict, inputs=sequence_input, outputs=[output_label, output_svg, output_plot])
315
+ clear_btn.click(lambda: [None, None, None], outputs=[output_label, output_svg, output_plot])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
 
317
+ app.launch()