wangleiofficial commited on
Commit
1328b24
·
verified ·
1 Parent(s): 6a80b51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -98
app.py CHANGED
@@ -85,64 +85,36 @@ classifier.eval()
85
  print("✅ Ready.")
86
 
87
  # ==========================
88
- # 3. 标签标准化映射 (核心修正)
89
  # ==========================
90
  def clean_label_name(raw_label):
91
- """
92
- 将模型原始输出的各类写法(如 OuterMembrane, Cytoplasmic)
93
- 统一映射为您要求的 6 个标准显示名称。
94
- """
95
  raw = raw_label.strip()
96
-
97
- # 映射字典:Key 是模型可能的输出,Value 是标准显示名称
98
  mapping = {
99
- # 1. Outer membrane
100
  "OuterMembrane": "Outer membrane", "Outer membrane": "Outer membrane",
101
-
102
- # 2. Periplasm
103
  "Periplasmic": "Periplasm", "Periplasm": "Periplasm",
104
-
105
- # 3. Cell wall
106
  "Cellwall": "Cell wall", "Cell wall": "Cell wall",
107
-
108
- # 4. Cytoplasmic membrane (即 Inner Membrane)
109
- "CYtoplasmicMembrane": "Cytoplasmic membrane", "Cytoplasmic membrane": "Cytoplasmic membrane",
110
- "InnerMembrane": "Cytoplasmic membrane", "Inner membrane": "Cytoplasmic membrane",
111
-
112
- # 5. Cytoplasm
113
  "Cytoplasmic": "Cytoplasm", "Cytoplasm": "Cytoplasm",
114
-
115
- # 6. Extracellular
116
  "Extracellular": "Extracellular", "Secreted": "Extracellular"
117
  }
118
-
119
- # 尝试直接匹配
120
- if raw in mapping:
121
- return mapping[raw]
122
-
123
- # 尝试忽略大小写匹配
124
  raw_lower = raw.lower()
125
  for k, v in mapping.items():
126
- if k.lower() == raw_lower:
127
- return v
128
-
129
- return raw # 如果没匹配上,返回原样
130
 
131
  # ==========================
132
- # 4. SVG 引擎 (适配标准标签)
133
  # ==========================
134
  def infer_gram_type(std_label):
135
- """基于标准标签推断"""
136
  if std_label in ["Outer membrane", "Periplasm"]: return "negative"
137
  if std_label == "Cell wall": return "positive"
138
  return "negative"
139
 
140
  def generate_scientific_svg(target_class):
141
- # 先转为标准标签
142
  std_target = clean_label_name(target_class)
143
  gram_type = infer_gram_type(std_target)
144
 
145
- # 状态判断 (使用标准名称)
146
  is_sec = (std_target == "Extracellular")
147
  is_om = (std_target == "Outer membrane")
148
  is_peri = (std_target == "Periplasm")
@@ -160,22 +132,18 @@ def generate_scientific_svg(target_class):
160
  cx, cy = 300, 210
161
  tx = 620
162
 
163
- # 绘制主体
164
  shapes = ""
165
  if gram_type == 'negative':
166
- # Outer Membrane
167
  col_om = c['hl_stroke'] if is_om else c['bg_stroke']
168
  fill_om = c['hl_fill'] if is_peri else c['bg_fill']
169
  w_om = "4" if is_om else "2"
170
  shapes += f'<rect x="{cx-200}" y="{cy-120}" width="400" height="240" rx="120" ry="120" fill="{fill_om}" stroke="{col_om}" stroke-width="{w_om}" />'
171
 
172
- # Cell Wall
173
  col_cw = c['hl_stroke'] if is_cw else '#B0BEC5'
174
  w_cw = "3" if is_cw else "1.5"
175
  dash_cw = "0" if is_cw else "6,4"
176
  shapes += f'<rect x="{cx-170}" y="{cy-90}" width="340" height="180" rx="90" ry="90" fill="none" stroke="{col_cw}" stroke-width="{w_cw}" stroke-dasharray="{dash_cw}" />'
177
 
178
- # Cytoplasmic Membrane (Inner)
179
  col_im = c['hl_stroke'] if is_im else c['bg_stroke']
180
  fill_im = c['hl_fill'] if is_cyto else c['bg_fill']
181
  w_im = "4" if is_im else "2"
@@ -186,7 +154,6 @@ def generate_scientific_svg(target_class):
186
  "cw": (cx+170, cy), "im": (cx+140, cy+30), "cyto": (cx, cy)
187
  }
188
  else:
189
- # Gram Positive (Thick Wall)
190
  col_cw = c['hl_stroke'] if is_cw else c['bg_stroke']
191
  fill_bg = c['hl_fill'] if is_peri else c['bg_fill']
192
  w_cw = "6" if is_cw else "4"
@@ -215,7 +182,6 @@ def generate_scientific_svg(target_class):
215
  <path d="M 0 5 L 0 30" stroke="{c['hl_stroke']}" stroke-width="2" marker-end="url(#arrow_hl)" />
216
  </g>"""
217
 
218
- # --- 标签列表 (使用您要求的标准名称) ---
219
  labels_config = [
220
  ("Extracellular", "sec", is_sec),
221
  ("Outer membrane", "om", is_om),
@@ -266,88 +232,49 @@ def generate_scientific_svg(target_class):
266
  <text x="400" y="400" text-anchor="middle" font-family="'Lato', sans-serif" font-size="16" fill="#546E7A" font-weight="bold">Prediction: {std_target}</text>
267
  </svg>"""
268
 
269
- html = f"""<div>{final_svg}
270
- <div style="display:flex; justify-content:center; gap:10px; margin-top:5px;">
271
- <button onclick="downloadSVG('{svg_id}')" style="font-size:11px; padding:4px 8px; border:1px solid #ccc; border-radius:4px; cursor:pointer; font-family:'Lato', sans-serif;">Download SVG</button>
272
- </div>
273
- <script>
274
- function downloadSVG(id) {{
275
- const svg = document.getElementById(id);
276
- const s = new XMLSerializer().serializeToString(svg);
277
- const b = new Blob([s], {{type: "image/svg+xml;charset=utf-8"}});
278
- const u = URL.createObjectURL(b);
279
- const a = document.createElement("a"); a.href = u; a.download = "cell_loc.svg";
280
- document.body.appendChild(a); a.click(); document.body.removeChild(a);
281
- }}
282
- </script></div>"""
283
  return html
284
 
285
  # ==========================
286
- # 4. Panel D: Wrapped Attention Heatmap (折行显示氨基酸)
287
  # ==========================
288
  def draw_wrapped_attention_heatmap(weights, sequence, chars_per_line=60):
289
- """
290
- 绘制折行热图:每行显示固定数量的氨基酸,下方显示字母。
291
- """
292
- # 归一化权重 (0-1)
293
- if weights.max() > 0:
294
- weights = (weights - weights.min()) / (weights.max() - weights.min())
295
-
296
  seq_len = len(sequence)
297
- # 计算行数
298
  num_rows = (seq_len + chars_per_line - 1) // chars_per_line
299
-
300
- # 动态调整画布高度
301
  fig_height = max(2, num_rows * 0.8)
302
  fig, axes = plt.subplots(num_rows, 1, figsize=(10, fig_height), dpi=150)
303
-
304
- # 如果只有一行,axes不是列表,强制转列表
305
- if num_rows == 1:
306
- axes = [axes]
307
-
308
- # 字体设置
309
  plt.rcParams['font.family'] = 'sans-serif'
310
  plt.rcParams['font.sans-serif'] = ['Lato', 'monospace']
 
311
 
312
  for i in range(num_rows):
313
  ax = axes[i]
314
  start_idx = i * chars_per_line
315
  end_idx = min((i + 1) * chars_per_line, seq_len)
316
-
317
- # 截取片段
318
  sub_weights = weights[start_idx:end_idx]
319
  sub_seq = sequence[start_idx:end_idx]
320
  current_len = len(sub_seq)
321
 
322
- # 补全最后一行以便绘图 (保持对齐)
323
  display_weights = np.zeros((1, chars_per_line))
324
  display_weights[0, :current_len] = sub_weights
325
 
326
- # 绘制热图 (Reds)
327
  im = ax.imshow(display_weights, cmap='Reds', aspect='auto', vmin=0, vmax=1)
328
 
329
- # 在每个格子上写氨基酸字母
330
  for j, char in enumerate(sub_seq):
331
- # 文字颜色根据背景深浅调整 (这里简化为黑色,因为权重一般不会全黑)
332
  ax.text(j, 0, char, ha='center', va='center', fontsize=9, color='black', fontweight='bold')
333
 
334
- # 设置 X 轴 (不显示刻度,只显示格子边界)
335
  ax.set_xticks(np.arange(chars_per_line) - 0.5, minor=True)
336
  ax.set_yticks([])
337
  ax.grid(which="minor", color="w", linestyle='-', linewidth=1)
338
  ax.tick_params(which="minor", bottom=False, left=False)
339
  ax.tick_params(which="major", bottom=False, left=False, labelbottom=False)
340
-
341
- # 隐藏边框
342
- for spine in ax.spines.values():
343
- spine.set_visible(False)
344
-
345
- # 左侧添加行号索引 (1, 61, 121...)
346
  ax.set_ylabel(f"{start_idx+1}", rotation=0, ha='right', va='center', fontsize=10, color='#546E7A')
347
 
348
  plt.tight_layout()
349
- # 顶部标题
350
- fig.suptitle(f"Attention Heatmap with Sequence ({seq_len} residues)", fontsize=12, fontweight='bold', color='#37474F', y=1.02)
351
  return fig
352
 
353
  # ==========================
@@ -364,39 +291,50 @@ def predict(sequence_input):
364
  logits, pooling_weights = classifier(outputs.last_hidden_state[:, 0, :], outputs.last_hidden_state[:, 1:-1, :], inputs['attention_mask'][:, 1:-1])
365
  probs = F.softmax(logits, dim=1)[0]
366
 
367
- # 获取原始预测 ID
368
  top_id = torch.max(probs, dim=0)[1].item()
369
- # 获取原始标签文本 (如 OuterMembrane)
370
  raw_label = idx_to_label[top_id]
371
-
372
- # 转换为标准显示文本 (如 Outer membrane)
373
  clean_top_label = clean_label_name(raw_label)
374
 
375
- # 构建置信度字典 (全部转换为标准名称)
376
  confidences = {}
377
  for i, p in enumerate(probs):
378
  orig_name = idx_to_label[i]
379
  std_name = clean_label_name(orig_name)
380
  confidences[std_name] = float(p)
381
 
382
- # 绘图 (传入标准名称)
383
  svg = generate_scientific_svg(clean_top_label)
384
-
385
- # 绘制折行热图
386
  heatmap = draw_wrapped_attention_heatmap(pooling_weights[0].cpu().numpy(), seq, chars_per_line=60)
387
 
388
  return confidences, svg, heatmap
389
 
390
  # ==========================
391
- # 6. UI Layout (UniProt Font Style)
392
  # ==========================
393
  layout_css = """
394
  @import url('https://fonts.googleapis.com/css2?family=Lato:wght@300;400;700&display=swap');
395
  body, button, input, textarea, .gradio-container { font-family: 'Lato', sans-serif !important; }
396
 
397
- .header-div { background: linear-gradient(to right, #E0F7FA, #E1F5FE); padding: 1.5rem; border-radius: 8px; margin-bottom: 20px; text-align: center; border: 1px solid #B3E5FC; }
 
 
 
 
 
398
  .header-title { font-size: 2.2rem; font-weight: 800; color: #0288D1; margin-bottom: 5px; }
399
- .header-sub { font-size: 1.0rem; color: #0277BD; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  .panel-card { border: 1px solid #e2e8f0; border-radius: 8px; padding: 15px; background: white; height: 100%; display: flex; flex-direction: column; }
401
  .panel-header { font-weight: 700; color: #475569; border-bottom: 2px solid #f1f5f9; padding-bottom: 8px; margin-bottom: 12px; font-size: 1.0rem; }
402
  .panel-label { display: inline-block; background: #E0F7FA; color: #0277BD; border: 1px solid #B2EBF2; padding: 2px 8px; border-radius: 4px; font-size: 0.8rem; margin-right: 8px; font-weight: 800; }
@@ -405,7 +343,28 @@ body, button, input, textarea, .gradio-container { font-family: 'Lato', sans-ser
405
  theme = gr.themes.Soft(primary_hue="sky").set(body_background_fill="white", block_background_fill="white", block_border_width="0px")
406
 
407
  with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
408
- gr.HTML("""<div class="header-div"><div class="header-title">LocPred-Prok</div><div class="header-sub">Deep Learning Framework for Prokaryotic Subcellular Localization</div></div>""")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
 
410
  with gr.Row():
411
  with gr.Column(elem_classes="panel-card"):
 
85
  print("✅ Ready.")
86
 
87
  # ==========================
88
+ # 3. 标签标准化映射
89
  # ==========================
90
  def clean_label_name(raw_label):
 
 
 
 
91
  raw = raw_label.strip()
 
 
92
  mapping = {
 
93
  "OuterMembrane": "Outer membrane", "Outer membrane": "Outer membrane",
 
 
94
  "Periplasmic": "Periplasm", "Periplasm": "Periplasm",
 
 
95
  "Cellwall": "Cell wall", "Cell wall": "Cell wall",
96
+ "CYtoplasmicMembrane": "Cytoplasmic membrane", "InnerMembrane": "Cytoplasmic membrane",
 
 
 
 
 
97
  "Cytoplasmic": "Cytoplasm", "Cytoplasm": "Cytoplasm",
 
 
98
  "Extracellular": "Extracellular", "Secreted": "Extracellular"
99
  }
100
+ if raw in mapping: return mapping[raw]
 
 
 
 
 
101
  raw_lower = raw.lower()
102
  for k, v in mapping.items():
103
+ if k.lower() == raw_lower: return v
104
+ return raw
 
 
105
 
106
  # ==========================
107
+ # 4. SVG 引擎 (纯净展示版 - 无下载按钮)
108
  # ==========================
109
  def infer_gram_type(std_label):
 
110
  if std_label in ["Outer membrane", "Periplasm"]: return "negative"
111
  if std_label == "Cell wall": return "positive"
112
  return "negative"
113
 
114
  def generate_scientific_svg(target_class):
 
115
  std_target = clean_label_name(target_class)
116
  gram_type = infer_gram_type(std_target)
117
 
 
118
  is_sec = (std_target == "Extracellular")
119
  is_om = (std_target == "Outer membrane")
120
  is_peri = (std_target == "Periplasm")
 
132
  cx, cy = 300, 210
133
  tx = 620
134
 
 
135
  shapes = ""
136
  if gram_type == 'negative':
 
137
  col_om = c['hl_stroke'] if is_om else c['bg_stroke']
138
  fill_om = c['hl_fill'] if is_peri else c['bg_fill']
139
  w_om = "4" if is_om else "2"
140
  shapes += f'<rect x="{cx-200}" y="{cy-120}" width="400" height="240" rx="120" ry="120" fill="{fill_om}" stroke="{col_om}" stroke-width="{w_om}" />'
141
 
 
142
  col_cw = c['hl_stroke'] if is_cw else '#B0BEC5'
143
  w_cw = "3" if is_cw else "1.5"
144
  dash_cw = "0" if is_cw else "6,4"
145
  shapes += f'<rect x="{cx-170}" y="{cy-90}" width="340" height="180" rx="90" ry="90" fill="none" stroke="{col_cw}" stroke-width="{w_cw}" stroke-dasharray="{dash_cw}" />'
146
 
 
147
  col_im = c['hl_stroke'] if is_im else c['bg_stroke']
148
  fill_im = c['hl_fill'] if is_cyto else c['bg_fill']
149
  w_im = "4" if is_im else "2"
 
154
  "cw": (cx+170, cy), "im": (cx+140, cy+30), "cyto": (cx, cy)
155
  }
156
  else:
 
157
  col_cw = c['hl_stroke'] if is_cw else c['bg_stroke']
158
  fill_bg = c['hl_fill'] if is_peri else c['bg_fill']
159
  w_cw = "6" if is_cw else "4"
 
182
  <path d="M 0 5 L 0 30" stroke="{c['hl_stroke']}" stroke-width="2" marker-end="url(#arrow_hl)" />
183
  </g>"""
184
 
 
185
  labels_config = [
186
  ("Extracellular", "sec", is_sec),
187
  ("Outer membrane", "om", is_om),
 
232
  <text x="400" y="400" text-anchor="middle" font-family="'Lato', sans-serif" font-size="16" fill="#546E7A" font-weight="bold">Prediction: {std_target}</text>
233
  </svg>"""
234
 
235
+ # 纯净 HTML,无按钮
236
+ html = f"<div style='text-align:center;'>{final_svg}</div>"
 
 
 
 
 
 
 
 
 
 
 
 
237
  return html
238
 
239
  # ==========================
240
+ # 4. Wrapped Attention Heatmap
241
  # ==========================
242
  def draw_wrapped_attention_heatmap(weights, sequence, chars_per_line=60):
243
+ if weights.max() > 0: weights = (weights - weights.min()) / (weights.max() - weights.min())
 
 
 
 
 
 
244
  seq_len = len(sequence)
 
245
  num_rows = (seq_len + chars_per_line - 1) // chars_per_line
 
 
246
  fig_height = max(2, num_rows * 0.8)
247
  fig, axes = plt.subplots(num_rows, 1, figsize=(10, fig_height), dpi=150)
 
 
 
 
 
 
248
  plt.rcParams['font.family'] = 'sans-serif'
249
  plt.rcParams['font.sans-serif'] = ['Lato', 'monospace']
250
+ if num_rows == 1: axes = [axes]
251
 
252
  for i in range(num_rows):
253
  ax = axes[i]
254
  start_idx = i * chars_per_line
255
  end_idx = min((i + 1) * chars_per_line, seq_len)
 
 
256
  sub_weights = weights[start_idx:end_idx]
257
  sub_seq = sequence[start_idx:end_idx]
258
  current_len = len(sub_seq)
259
 
 
260
  display_weights = np.zeros((1, chars_per_line))
261
  display_weights[0, :current_len] = sub_weights
262
 
 
263
  im = ax.imshow(display_weights, cmap='Reds', aspect='auto', vmin=0, vmax=1)
264
 
 
265
  for j, char in enumerate(sub_seq):
 
266
  ax.text(j, 0, char, ha='center', va='center', fontsize=9, color='black', fontweight='bold')
267
 
 
268
  ax.set_xticks(np.arange(chars_per_line) - 0.5, minor=True)
269
  ax.set_yticks([])
270
  ax.grid(which="minor", color="w", linestyle='-', linewidth=1)
271
  ax.tick_params(which="minor", bottom=False, left=False)
272
  ax.tick_params(which="major", bottom=False, left=False, labelbottom=False)
273
+ for spine in ax.spines.values(): spine.set_visible(False)
 
 
 
 
 
274
  ax.set_ylabel(f"{start_idx+1}", rotation=0, ha='right', va='center', fontsize=10, color='#546E7A')
275
 
276
  plt.tight_layout()
277
+ fig.suptitle(f"Attention Heatmap (Sequence Length: {seq_len})", fontsize=12, fontweight='bold', color='#37474F', y=1.02)
 
278
  return fig
279
 
280
  # ==========================
 
291
  logits, pooling_weights = classifier(outputs.last_hidden_state[:, 0, :], outputs.last_hidden_state[:, 1:-1, :], inputs['attention_mask'][:, 1:-1])
292
  probs = F.softmax(logits, dim=1)[0]
293
 
 
294
  top_id = torch.max(probs, dim=0)[1].item()
 
295
  raw_label = idx_to_label[top_id]
 
 
296
  clean_top_label = clean_label_name(raw_label)
297
 
 
298
  confidences = {}
299
  for i, p in enumerate(probs):
300
  orig_name = idx_to_label[i]
301
  std_name = clean_label_name(orig_name)
302
  confidences[std_name] = float(p)
303
 
 
304
  svg = generate_scientific_svg(clean_top_label)
 
 
305
  heatmap = draw_wrapped_attention_heatmap(pooling_weights[0].cpu().numpy(), seq, chars_per_line=60)
306
 
307
  return confidences, svg, heatmap
308
 
309
  # ==========================
310
+ # 6. UI Layout (Enhanced Header)
311
  # ==========================
312
  layout_css = """
313
  @import url('https://fonts.googleapis.com/css2?family=Lato:wght@300;400;700&display=swap');
314
  body, button, input, textarea, .gradio-container { font-family: 'Lato', sans-serif !important; }
315
 
316
+ /* Header 样式 */
317
+ .header-div {
318
+ background: linear-gradient(to right, #E0F7FA, #E1F5FE);
319
+ padding: 1.5rem; border-radius: 8px; margin-bottom: 20px;
320
+ text-align: center; border: 1px solid #B3E5FC;
321
+ }
322
  .header-title { font-size: 2.2rem; font-weight: 800; color: #0288D1; margin-bottom: 5px; }
323
+ .header-sub { font-size: 1.0rem; color: #0277BD; margin-bottom: 12px; }
324
+
325
+ /* Badge 链接样式 */
326
+ .badge-container { display: flex; justify-content: center; gap: 10px; flex-wrap: wrap; }
327
+ .badge-link {
328
+ text-decoration: none; display: inline-flex; align-items: center;
329
+ background-color: #ffffff; color: #334155;
330
+ padding: 4px 10px; border-radius: 6px;
331
+ font-size: 0.85rem; font-weight: 600;
332
+ border: 1px solid #cbd5e1; transition: all 0.2s;
333
+ }
334
+ .badge-link:hover { background-color: #f1f5f9; border-color: #0288D1; color: #0288D1; }
335
+ .badge-icon { margin-right: 5px; }
336
+
337
+ /* Panel 样式 */
338
  .panel-card { border: 1px solid #e2e8f0; border-radius: 8px; padding: 15px; background: white; height: 100%; display: flex; flex-direction: column; }
339
  .panel-header { font-weight: 700; color: #475569; border-bottom: 2px solid #f1f5f9; padding-bottom: 8px; margin-bottom: 12px; font-size: 1.0rem; }
340
  .panel-label { display: inline-block; background: #E0F7FA; color: #0277BD; border: 1px solid #B2EBF2; padding: 2px 8px; border-radius: 4px; font-size: 0.8rem; margin-right: 8px; font-weight: 800; }
 
343
  theme = gr.themes.Soft(primary_hue="sky").set(body_background_fill="white", block_background_fill="white", block_border_width="0px")
344
 
345
  with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
346
+
347
+ # --- Enhanced Header ---
348
+ gr.HTML("""
349
+ <div class="header-div">
350
+ <div class="header-title">LocPred-Prok</div>
351
+ <div class="header-sub">Dual-Branch Deep Learning for Prokaryotic Subcellular Localization</div>
352
+ <div class="badge-container">
353
+ <a href="https://github.com/isyslab-hust/LocPred-Prok" target="_blank" class="badge-link">
354
+ GitHub
355
+ </a>
356
+ <a href="#" target="_blank" class="badge-link">
357
+ Paper
358
+ </a>
359
+ <span class="badge-link" style="cursor:default">
360
+ 🧬 ESM-2 Enhanced
361
+ </span>
362
+ <span class="badge-link" style="cursor:default">
363
+ ⚖️ MIT License
364
+ </span>
365
+ </div>
366
+ </div>
367
+ """)
368
 
369
  with gr.Row():
370
  with gr.Column(elem_classes="panel-card"):