wangleiofficial commited on
Commit
6a80b51
·
verified ·
1 Parent(s): 7d42e13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -66
app.py CHANGED
@@ -85,41 +85,70 @@ classifier.eval()
85
  print("✅ Ready.")
86
 
87
  # ==========================
88
- # 3. 标签映射与 SVG 引擎 (UniProt Font)
89
  # ==========================
90
-
91
- def map_model_label_to_svg_key(model_label):
92
- l = model_label.strip()
 
 
 
 
 
93
  mapping = {
94
- "Extracellular": "sec",
95
- "OuterMembrane": "om",
96
- "Periplasmic": "peri",
97
- "Cellwall": "cw",
98
- "CYtoplasmicMembrane": "im",
99
- "Cytoplasmic": "cyto"
 
 
 
 
 
 
 
 
 
 
 
 
100
  }
101
- if l in mapping: return mapping[l]
102
- l_lower = l.lower()
 
 
 
 
 
103
  for k, v in mapping.items():
104
- if k.lower() == l_lower: return v
105
- return None
 
 
106
 
107
- def infer_gram_type(model_label):
108
- key = map_model_label_to_svg_key(model_label)
109
- if key in ["om", "peri"]: return "negative"
110
- if key == "cw": return "positive"
 
 
 
111
  return "negative"
112
 
113
  def generate_scientific_svg(target_class):
114
- active_key = map_model_label_to_svg_key(target_class)
115
- gram_type = infer_gram_type(target_class)
 
116
 
117
- is_sec = (active_key == "sec")
118
- is_om = (active_key == "om")
119
- is_peri = (active_key == "peri")
120
- is_cw = (active_key == "cw")
121
- is_im = (active_key == "im")
122
- is_cyto = (active_key == "cyto")
 
123
 
124
  c = {
125
  'hl_stroke': '#D32F2F', 'hl_fill': '#FFEBEE', 'hl_text': '#B71C1C', 'hl_dot': '#D32F2F',
@@ -131,19 +160,22 @@ def generate_scientific_svg(target_class):
131
  cx, cy = 300, 210
132
  tx = 620
133
 
134
- # 绘制主体形状
135
  shapes = ""
136
  if gram_type == 'negative':
 
137
  col_om = c['hl_stroke'] if is_om else c['bg_stroke']
138
  fill_om = c['hl_fill'] if is_peri else c['bg_fill']
139
  w_om = "4" if is_om else "2"
140
  shapes += f'<rect x="{cx-200}" y="{cy-120}" width="400" height="240" rx="120" ry="120" fill="{fill_om}" stroke="{col_om}" stroke-width="{w_om}" />'
141
 
 
142
  col_cw = c['hl_stroke'] if is_cw else '#B0BEC5'
143
  w_cw = "3" if is_cw else "1.5"
144
  dash_cw = "0" if is_cw else "6,4"
145
  shapes += f'<rect x="{cx-170}" y="{cy-90}" width="340" height="180" rx="90" ry="90" fill="none" stroke="{col_cw}" stroke-width="{w_cw}" stroke-dasharray="{dash_cw}" />'
146
 
 
147
  col_im = c['hl_stroke'] if is_im else c['bg_stroke']
148
  fill_im = c['hl_fill'] if is_cyto else c['bg_fill']
149
  w_im = "4" if is_im else "2"
@@ -154,6 +186,7 @@ def generate_scientific_svg(target_class):
154
  "cw": (cx+170, cy), "im": (cx+140, cy+30), "cyto": (cx, cy)
155
  }
156
  else:
 
157
  col_cw = c['hl_stroke'] if is_cw else c['bg_stroke']
158
  fill_bg = c['hl_fill'] if is_peri else c['bg_fill']
159
  w_cw = "6" if is_cw else "4"
@@ -178,16 +211,17 @@ def generate_scientific_svg(target_class):
178
  sec_svg = ""
179
  if is_sec:
180
  sec_svg = f"""<g transform="translate({cx}, {cy-170})">
181
- <text x="0" y="0" text-anchor="middle" fill="{c['hl_text']}" font-weight="bold" font-family="'Lato', 'Helvetica Neue', Arial, sans-serif" font-size="14">SECRETED</text>
182
  <path d="M 0 5 L 0 30" stroke="{c['hl_stroke']}" stroke-width="2" marker-end="url(#arrow_hl)" />
183
  </g>"""
184
 
 
185
  labels_config = [
186
  ("Extracellular", "sec", is_sec),
187
- ("Outer Membrane", "om", is_om),
188
  ("Periplasm", "peri", is_peri),
189
- ("Cell Wall", "cw", is_cw),
190
- ("Inner Membrane", "im", is_im),
191
  ("Cytoplasm", "cyto", is_cyto)
192
  ]
193
  if gram_type == 'positive':
@@ -214,7 +248,7 @@ def generate_scientific_svg(target_class):
214
 
215
  label_svg += f"""
216
  <g>
217
- <text x="{tx}" y="{ty}" fill="{col_txt}" font-weight="{w_txt}" font-size="14" font-family="'Lato', 'Helvetica Neue', Arial, sans-serif">{text}</text>
218
  <path d="{path_d}" fill="none" stroke="{col_line}" stroke-width="{w_line}" />
219
  <circle cx="{ex}" cy="{ey}" r="{r_dot}" fill="{col_dot}" stroke="white" stroke-width="1" />
220
  </g>
@@ -222,17 +256,14 @@ def generate_scientific_svg(target_class):
222
 
223
  final_svg = f"""<svg id="{svg_id}" width="100%" height="100%" viewBox="0 0 800 420" xmlns="http://www.w3.org/2000/svg">
224
  <defs>
225
- <style>
226
- @import url('https://fonts.googleapis.com/css2?family=Lato:wght@400;700&display=swap');
227
- text {{ font-family: 'Lato', 'Helvetica Neue', Arial, sans-serif; }}
228
- </style>
229
  <marker id="arrow_hl" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto"><polygon points="0 0, 10 3.5, 0 7" fill="{c['hl_stroke']}" /></marker>
230
  </defs>
231
  <rect width="800" height="420" fill="white" />
232
  {shapes}
233
  {sec_svg}
234
  {label_svg}
235
- <text x="400" y="400" text-anchor="middle" font-family="'Lato', 'Helvetica Neue', Arial, sans-serif" font-size="16" fill="#546E7A" font-weight="bold">Prediction: {target_class}</text>
236
  </svg>"""
237
 
238
  html = f"""<div>{final_svg}
@@ -252,25 +283,71 @@ def generate_scientific_svg(target_class):
252
  return html
253
 
254
  # ==========================
255
- # 4. Attention Heatmap
256
  # ==========================
257
- def draw_attention_heatmap_strip(weights, sequence):
 
 
 
 
258
  if weights.max() > 0:
259
  weights = (weights - weights.min()) / (weights.max() - weights.min())
260
 
261
- # 字体设置 (Matplotlib)
262
- plt.rcParams['font.family'] = 'sans-serif'
263
- plt.rcParams['font.sans-serif'] = ['Lato', 'Helvetica Neue', 'Arial']
264
 
265
- fig, ax = plt.subplots(figsize=(8, 1.5), dpi=150)
266
- data = weights.reshape(1, -1)
267
- im = ax.imshow(data, cmap='Reds', aspect='auto', vmin=0, vmax=1)
268
 
269
- ax.set_title("Sequence Attention Heatmap (Darker = Higher Attention)", fontsize=10, fontweight='bold', color='#37474F', pad=8)
270
- ax.set_xlabel("Residue Position", fontsize=9)
271
- ax.set_yticks([])
272
- for s in ax.spines.values(): s.set_visible(False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  plt.tight_layout()
 
 
274
  return fig
275
 
276
  # ==========================
@@ -287,38 +364,45 @@ def predict(sequence_input):
287
  logits, pooling_weights = classifier(outputs.last_hidden_state[:, 0, :], outputs.last_hidden_state[:, 1:-1, :], inputs['attention_mask'][:, 1:-1])
288
  probs = F.softmax(logits, dim=1)[0]
289
 
290
- top_label = idx_to_label[torch.max(probs, dim=0)[1].item()]
291
- confidences = {idx_to_label[i]: float(p) for i, p in enumerate(probs)}
 
 
 
 
 
292
 
293
- svg = generate_scientific_svg(top_label)
294
- heatmap = draw_attention_heatmap_strip(pooling_weights[0].cpu().numpy(), seq)
 
 
 
 
 
 
 
 
 
 
295
 
296
  return confidences, svg, heatmap
297
 
298
  # ==========================
299
- # 6. UI Layout (UniProt Style Fonts)
300
  # ==========================
301
  layout_css = """
302
  @import url('https://fonts.googleapis.com/css2?family=Lato:wght@300;400;700&display=swap');
303
-
304
- /* 全局字体设置为 Lato */
305
- body, button, input, textarea, .gradio-container {
306
- font-family: 'Lato', 'Helvetica Neue', Arial, sans-serif !important;
307
- }
308
 
309
  .header-div { background: linear-gradient(to right, #E0F7FA, #E1F5FE); padding: 1.5rem; border-radius: 8px; margin-bottom: 20px; text-align: center; border: 1px solid #B3E5FC; }
310
- .header-title { font-size: 2.2rem; font-weight: 800; color: #0288D1; margin-bottom: 5px; letter-spacing: -0.5px; }
311
  .header-sub { font-size: 1.0rem; color: #0277BD; }
312
  .panel-card { border: 1px solid #e2e8f0; border-radius: 8px; padding: 15px; background: white; height: 100%; display: flex; flex-direction: column; }
313
  .panel-header { font-weight: 700; color: #475569; border-bottom: 2px solid #f1f5f9; padding-bottom: 8px; margin-bottom: 12px; font-size: 1.0rem; }
314
  .panel-label { display: inline-block; background: #E0F7FA; color: #0277BD; border: 1px solid #B2EBF2; padding: 2px 8px; border-radius: 4px; font-size: 0.8rem; margin-right: 8px; font-weight: 800; }
315
  """
316
 
317
- # 设置 Gradio 主题字体为 Lato
318
- theme = gr.themes.Soft(
319
- primary_hue="sky",
320
- font=[gr.themes.GoogleFont("Lato"), "Helvetica Neue", "Arial", "sans-serif"]
321
- ).set(body_background_fill="white", block_background_fill="white", block_border_width="0px")
322
 
323
  with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
324
  gr.HTML("""<div class="header-div"><div class="header-title">LocPred-Prok</div><div class="header-sub">Deep Learning Framework for Prokaryotic Subcellular Localization</div></div>""")
@@ -342,7 +426,7 @@ with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
342
  output_label = gr.Label(num_top_classes=NUM_CLASSES, show_label=False)
343
 
344
  with gr.Column(elem_classes="panel-card"):
345
- gr.Markdown("<div class='panel-header'><span class='panel-label'>D</span>Attention Heatmap</div>")
346
  output_plot = gr.Plot(label="Attention", show_label=False)
347
 
348
  submit_btn.click(fn=predict, inputs=sequence_input, outputs=[output_label, output_svg, output_plot])
 
85
  print("✅ Ready.")
86
 
87
  # ==========================
88
+ # 3. 标签标准化映射 (核心修正)
89
  # ==========================
90
+ def clean_label_name(raw_label):
91
+ """
92
+ 将模型原始输出的各类写法(如 OuterMembrane, Cytoplasmic)
93
+ 统一映射为您要求的 6 个标准显示名称。
94
+ """
95
+ raw = raw_label.strip()
96
+
97
+ # 映射字典:Key 是模型可能的输出,Value 是标准显示名称
98
  mapping = {
99
+ # 1. Outer membrane
100
+ "OuterMembrane": "Outer membrane", "Outer membrane": "Outer membrane",
101
+
102
+ # 2. Periplasm
103
+ "Periplasmic": "Periplasm", "Periplasm": "Periplasm",
104
+
105
+ # 3. Cell wall
106
+ "Cellwall": "Cell wall", "Cell wall": "Cell wall",
107
+
108
+ # 4. Cytoplasmic membrane (即 Inner Membrane)
109
+ "CYtoplasmicMembrane": "Cytoplasmic membrane", "Cytoplasmic membrane": "Cytoplasmic membrane",
110
+ "InnerMembrane": "Cytoplasmic membrane", "Inner membrane": "Cytoplasmic membrane",
111
+
112
+ # 5. Cytoplasm
113
+ "Cytoplasmic": "Cytoplasm", "Cytoplasm": "Cytoplasm",
114
+
115
+ # 6. Extracellular
116
+ "Extracellular": "Extracellular", "Secreted": "Extracellular"
117
  }
118
+
119
+ # 尝试直接匹配
120
+ if raw in mapping:
121
+ return mapping[raw]
122
+
123
+ # 尝试忽略大小写匹配
124
+ raw_lower = raw.lower()
125
  for k, v in mapping.items():
126
+ if k.lower() == raw_lower:
127
+ return v
128
+
129
+ return raw # 如果没匹配上,返回原样
130
 
131
+ # ==========================
132
+ # 4. SVG 引擎 (适配标准标签)
133
+ # ==========================
134
+ def infer_gram_type(std_label):
135
+ """基于标准标签推断"""
136
+ if std_label in ["Outer membrane", "Periplasm"]: return "negative"
137
+ if std_label == "Cell wall": return "positive"
138
  return "negative"
139
 
140
  def generate_scientific_svg(target_class):
141
+ # 先转为标准标签
142
+ std_target = clean_label_name(target_class)
143
+ gram_type = infer_gram_type(std_target)
144
 
145
+ # 状态判断 (使用标准名称)
146
+ is_sec = (std_target == "Extracellular")
147
+ is_om = (std_target == "Outer membrane")
148
+ is_peri = (std_target == "Periplasm")
149
+ is_cw = (std_target == "Cell wall")
150
+ is_im = (std_target == "Cytoplasmic membrane")
151
+ is_cyto = (std_target == "Cytoplasm")
152
 
153
  c = {
154
  'hl_stroke': '#D32F2F', 'hl_fill': '#FFEBEE', 'hl_text': '#B71C1C', 'hl_dot': '#D32F2F',
 
160
  cx, cy = 300, 210
161
  tx = 620
162
 
163
+ # 绘制主体
164
  shapes = ""
165
  if gram_type == 'negative':
166
+ # Outer Membrane
167
  col_om = c['hl_stroke'] if is_om else c['bg_stroke']
168
  fill_om = c['hl_fill'] if is_peri else c['bg_fill']
169
  w_om = "4" if is_om else "2"
170
  shapes += f'<rect x="{cx-200}" y="{cy-120}" width="400" height="240" rx="120" ry="120" fill="{fill_om}" stroke="{col_om}" stroke-width="{w_om}" />'
171
 
172
+ # Cell Wall
173
  col_cw = c['hl_stroke'] if is_cw else '#B0BEC5'
174
  w_cw = "3" if is_cw else "1.5"
175
  dash_cw = "0" if is_cw else "6,4"
176
  shapes += f'<rect x="{cx-170}" y="{cy-90}" width="340" height="180" rx="90" ry="90" fill="none" stroke="{col_cw}" stroke-width="{w_cw}" stroke-dasharray="{dash_cw}" />'
177
 
178
+ # Cytoplasmic Membrane (Inner)
179
  col_im = c['hl_stroke'] if is_im else c['bg_stroke']
180
  fill_im = c['hl_fill'] if is_cyto else c['bg_fill']
181
  w_im = "4" if is_im else "2"
 
186
  "cw": (cx+170, cy), "im": (cx+140, cy+30), "cyto": (cx, cy)
187
  }
188
  else:
189
+ # Gram Positive (Thick Wall)
190
  col_cw = c['hl_stroke'] if is_cw else c['bg_stroke']
191
  fill_bg = c['hl_fill'] if is_peri else c['bg_fill']
192
  w_cw = "6" if is_cw else "4"
 
211
  sec_svg = ""
212
  if is_sec:
213
  sec_svg = f"""<g transform="translate({cx}, {cy-170})">
214
+ <text x="0" y="0" text-anchor="middle" fill="{c['hl_text']}" font-weight="bold" font-family="'Lato', sans-serif" font-size="14">EXTRACELLULAR</text>
215
  <path d="M 0 5 L 0 30" stroke="{c['hl_stroke']}" stroke-width="2" marker-end="url(#arrow_hl)" />
216
  </g>"""
217
 
218
+ # --- 标签列表 (使用您要求的标准名称) ---
219
  labels_config = [
220
  ("Extracellular", "sec", is_sec),
221
+ ("Outer membrane", "om", is_om),
222
  ("Periplasm", "peri", is_peri),
223
+ ("Cell wall", "cw", is_cw),
224
+ ("Cytoplasmic membrane", "im", is_im),
225
  ("Cytoplasm", "cyto", is_cyto)
226
  ]
227
  if gram_type == 'positive':
 
248
 
249
  label_svg += f"""
250
  <g>
251
+ <text x="{tx}" y="{ty}" fill="{col_txt}" font-weight="{w_txt}" font-size="14" font-family="'Lato', sans-serif">{text}</text>
252
  <path d="{path_d}" fill="none" stroke="{col_line}" stroke-width="{w_line}" />
253
  <circle cx="{ex}" cy="{ey}" r="{r_dot}" fill="{col_dot}" stroke="white" stroke-width="1" />
254
  </g>
 
256
 
257
  final_svg = f"""<svg id="{svg_id}" width="100%" height="100%" viewBox="0 0 800 420" xmlns="http://www.w3.org/2000/svg">
258
  <defs>
259
+ <style>@import url('https://fonts.googleapis.com/css2?family=Lato:wght@400;700&display=swap'); text {{ font-family: 'Lato', sans-serif; }}</style>
 
 
 
260
  <marker id="arrow_hl" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto"><polygon points="0 0, 10 3.5, 0 7" fill="{c['hl_stroke']}" /></marker>
261
  </defs>
262
  <rect width="800" height="420" fill="white" />
263
  {shapes}
264
  {sec_svg}
265
  {label_svg}
266
+ <text x="400" y="400" text-anchor="middle" font-family="'Lato', sans-serif" font-size="16" fill="#546E7A" font-weight="bold">Prediction: {std_target}</text>
267
  </svg>"""
268
 
269
  html = f"""<div>{final_svg}
 
283
  return html
284
 
285
  # ==========================
286
+ # 4. Panel D: Wrapped Attention Heatmap (折行显示氨基酸)
287
  # ==========================
288
+ def draw_wrapped_attention_heatmap(weights, sequence, chars_per_line=60):
289
+ """
290
+ 绘制折行热图:每行显示固定数量的氨基酸,下方显示字母。
291
+ """
292
+ # 归一化权重 (0-1)
293
  if weights.max() > 0:
294
  weights = (weights - weights.min()) / (weights.max() - weights.min())
295
 
296
+ seq_len = len(sequence)
297
+ # 计算行数
298
+ num_rows = (seq_len + chars_per_line - 1) // chars_per_line
299
 
300
+ # 动态调整画布高度
301
+ fig_height = max(2, num_rows * 0.8)
302
+ fig, axes = plt.subplots(num_rows, 1, figsize=(10, fig_height), dpi=150)
303
 
304
+ # 如果只有一行,axes不是列表,强制转列表
305
+ if num_rows == 1:
306
+ axes = [axes]
307
+
308
+ # 字体设置
309
+ plt.rcParams['font.family'] = 'sans-serif'
310
+ plt.rcParams['font.sans-serif'] = ['Lato', 'monospace']
311
+
312
+ for i in range(num_rows):
313
+ ax = axes[i]
314
+ start_idx = i * chars_per_line
315
+ end_idx = min((i + 1) * chars_per_line, seq_len)
316
+
317
+ # 截取片段
318
+ sub_weights = weights[start_idx:end_idx]
319
+ sub_seq = sequence[start_idx:end_idx]
320
+ current_len = len(sub_seq)
321
+
322
+ # 补全最后一行以便绘图 (保持对齐)
323
+ display_weights = np.zeros((1, chars_per_line))
324
+ display_weights[0, :current_len] = sub_weights
325
+
326
+ # 绘制热图 (Reds)
327
+ im = ax.imshow(display_weights, cmap='Reds', aspect='auto', vmin=0, vmax=1)
328
+
329
+ # 在每个格子上写氨基酸字母
330
+ for j, char in enumerate(sub_seq):
331
+ # 文字颜色根据背景深浅调整 (这里简化为黑色,因为权重一般不会全黑)
332
+ ax.text(j, 0, char, ha='center', va='center', fontsize=9, color='black', fontweight='bold')
333
+
334
+ # 设置 X 轴 (不显示刻度,只显示格子边界)
335
+ ax.set_xticks(np.arange(chars_per_line) - 0.5, minor=True)
336
+ ax.set_yticks([])
337
+ ax.grid(which="minor", color="w", linestyle='-', linewidth=1)
338
+ ax.tick_params(which="minor", bottom=False, left=False)
339
+ ax.tick_params(which="major", bottom=False, left=False, labelbottom=False)
340
+
341
+ # 隐藏边框
342
+ for spine in ax.spines.values():
343
+ spine.set_visible(False)
344
+
345
+ # 左侧添加行号索引 (1, 61, 121...)
346
+ ax.set_ylabel(f"{start_idx+1}", rotation=0, ha='right', va='center', fontsize=10, color='#546E7A')
347
+
348
  plt.tight_layout()
349
+ # 顶部标题
350
+ fig.suptitle(f"Attention Heatmap with Sequence ({seq_len} residues)", fontsize=12, fontweight='bold', color='#37474F', y=1.02)
351
  return fig
352
 
353
  # ==========================
 
364
  logits, pooling_weights = classifier(outputs.last_hidden_state[:, 0, :], outputs.last_hidden_state[:, 1:-1, :], inputs['attention_mask'][:, 1:-1])
365
  probs = F.softmax(logits, dim=1)[0]
366
 
367
+ # 获取原始预测 ID
368
+ top_id = torch.max(probs, dim=0)[1].item()
369
+ # 获取原始标签文本 (如 OuterMembrane)
370
+ raw_label = idx_to_label[top_id]
371
+
372
+ # 转换为标准显示文本 (如 Outer membrane)
373
+ clean_top_label = clean_label_name(raw_label)
374
 
375
+ # 构建置信度字典 (全部转换为标准名称)
376
+ confidences = {}
377
+ for i, p in enumerate(probs):
378
+ orig_name = idx_to_label[i]
379
+ std_name = clean_label_name(orig_name)
380
+ confidences[std_name] = float(p)
381
+
382
+ # 绘图 (传入标准名称)
383
+ svg = generate_scientific_svg(clean_top_label)
384
+
385
+ # 绘制折行热图
386
+ heatmap = draw_wrapped_attention_heatmap(pooling_weights[0].cpu().numpy(), seq, chars_per_line=60)
387
 
388
  return confidences, svg, heatmap
389
 
390
  # ==========================
391
+ # 6. UI Layout (UniProt Font Style)
392
  # ==========================
393
  layout_css = """
394
  @import url('https://fonts.googleapis.com/css2?family=Lato:wght@300;400;700&display=swap');
395
+ body, button, input, textarea, .gradio-container { font-family: 'Lato', sans-serif !important; }
 
 
 
 
396
 
397
  .header-div { background: linear-gradient(to right, #E0F7FA, #E1F5FE); padding: 1.5rem; border-radius: 8px; margin-bottom: 20px; text-align: center; border: 1px solid #B3E5FC; }
398
+ .header-title { font-size: 2.2rem; font-weight: 800; color: #0288D1; margin-bottom: 5px; }
399
  .header-sub { font-size: 1.0rem; color: #0277BD; }
400
  .panel-card { border: 1px solid #e2e8f0; border-radius: 8px; padding: 15px; background: white; height: 100%; display: flex; flex-direction: column; }
401
  .panel-header { font-weight: 700; color: #475569; border-bottom: 2px solid #f1f5f9; padding-bottom: 8px; margin-bottom: 12px; font-size: 1.0rem; }
402
  .panel-label { display: inline-block; background: #E0F7FA; color: #0277BD; border: 1px solid #B2EBF2; padding: 2px 8px; border-radius: 4px; font-size: 0.8rem; margin-right: 8px; font-weight: 800; }
403
  """
404
 
405
+ theme = gr.themes.Soft(primary_hue="sky").set(body_background_fill="white", block_background_fill="white", block_border_width="0px")
 
 
 
 
406
 
407
  with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
408
  gr.HTML("""<div class="header-div"><div class="header-title">LocPred-Prok</div><div class="header-sub">Deep Learning Framework for Prokaryotic Subcellular Localization</div></div>""")
 
426
  output_label = gr.Label(num_top_classes=NUM_CLASSES, show_label=False)
427
 
428
  with gr.Column(elem_classes="panel-card"):
429
+ gr.Markdown("<div class='panel-header'><span class='panel-label'>D</span>Attention Heatmap (Sequence Weights)</div>")
430
  output_plot = gr.Plot(label="Attention", show_label=False)
431
 
432
  submit_btn.click(fn=predict, inputs=sequence_input, outputs=[output_label, output_svg, output_plot])