wangleiofficial commited on
Commit
b8514fa
·
verified ·
1 Parent(s): 68c3b40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -163
app.py CHANGED
@@ -3,10 +3,14 @@ import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
  import gradio as gr
 
 
 
 
6
  from transformers import AutoTokenizer, AutoModel
7
 
8
  # ==========================
9
- # 0-3. 基础设置与模型定义 (保持你的核心逻辑不变)
10
  # ==========================
11
  os.environ["HF_HOME"] = "/tmp/hf_cache"
12
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
@@ -78,7 +82,99 @@ classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
78
  classifier.eval()
79
  print("✅ Ready.")
80
 
81
- # --- 预测逻辑 ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  def predict(sequence_input):
83
  if not sequence_input or sequence_input.isspace():
84
  raise gr.Error("Please input a sequence.")
@@ -92,201 +188,95 @@ def predict(sequence_input):
92
  logits = classifier(outputs.last_hidden_state[:, 0, :], outputs.last_hidden_state[:, 1:-1, :], inputs['attention_mask'][:, 1:-1])
93
  probs = F.softmax(logits, dim=1)[0]
94
 
95
- return {idx_to_label[i]: float(p) for i, p in enumerate(probs)}
 
 
 
 
 
 
 
 
 
 
96
 
97
  # ==========================
98
- # 4. 旗舰版 UI (Rich & Modern)
99
  # ==========================
100
 
101
- # CSS:结合了学术严谨性和现代视觉
102
  flagship_css = """
103
  @import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Sans:wght@400;600;700&display=swap');
104
-
105
  body { font-family: 'IBM Plex Sans', sans-serif !important; background-color: #f0f2f5; }
106
-
107
- /* 标题区域 */
108
  .header-box {
109
- background: linear-gradient(120deg, #0284c7 0%, #2563eb 100%);
110
- color: white;
111
- padding: 2rem;
112
- border-radius: 12px;
113
- margin-bottom: 1.5rem;
114
- box-shadow: 0 10px 15px -3px rgba(37, 99, 235, 0.2);
115
  }
116
- .header-title { font-size: 2.2rem; font-weight: 700; letter-spacing: -0.5px; }
117
- .header-badges { display: flex; gap: 10px; margin-top: 10px; flex-wrap: wrap; }
118
- .badge {
119
- background: rgba(255,255,255,0.2);
120
- padding: 4px 12px;
121
- border-radius: 99px;
122
- font-size: 0.85rem;
123
- backdrop-filter: blur(4px);
124
- border: 1px solid rgba(255,255,255,0.3);
125
- }
126
-
127
- /* 内容卡片 */
128
- .content-box {
129
- background: white;
130
- padding: 1.5rem;
131
- border-radius: 12px;
132
- border: 1px solid #e5e7eb;
133
- box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.05);
134
- }
135
-
136
- /* 表格美化 */
137
- table { width: 100%; border-collapse: collapse; font-size: 0.9rem; }
138
- th { text-align: left; padding: 12px; background: #f8fafc; color: #475569; border-bottom: 2px solid #e2e8f0; }
139
- td { padding: 12px; border-bottom: 1px solid #e2e8f0; color: #1e293b; }
140
- tr:last-child td { border-bottom: none; }
141
  """
142
 
143
- theme = gr.themes.Soft(
144
- primary_hue="blue",
145
- radius_size="md",
146
- font=[gr.themes.GoogleFont("IBM Plex Sans"), "ui-sans-serif", "system-ui"]
147
- )
148
 
149
  with gr.Blocks(theme=theme, css=flagship_css, title="LocPred-Prok") as app:
150
 
151
- # --- Header ---
152
  with gr.Column(elem_classes="header-box"):
153
  gr.HTML("""
154
  <div class="header-title">LocPred-Prok</div>
155
- <div style="opacity: 0.9; font-size: 1.1rem; margin-bottom: 1rem;">
156
- State-of-the-Art Prokaryotic Subcellular Localization Prediction
157
- </div>
158
- <div class="header-badges">
159
  <span class="badge">🧬 ESM-2 Enhanced</span>
160
- <span class="badge">🚀 Dual-Branch Architecture</span>
161
- <span class="badge">🏆 91.2% Accuracy</span>
162
- <span class="badge">🎯 MCC 0.889</span>
163
  </div>
164
  """)
165
 
166
  with gr.Tabs():
167
-
168
- # === TAB 1: Predict (功能区) ===
169
  with gr.TabItem("🚀 Predict", id="predict"):
170
  with gr.Row():
171
-
172
- # 左侧:输入 + 示例
173
- with gr.Column(scale=3, elem_classes="content-box"):
174
- gr.Markdown("### 📥 Sequence Input")
175
- gr.Markdown("Enter a protein sequence (FASTA format supported).")
176
-
177
- sequence_input = gr.Textbox(
178
- lines=10,
179
- placeholder=">Header\nMKFKLTAGCLAVAGVLLASSFGAD...",
180
- show_label=False
181
- )
182
-
183
  with gr.Row():
184
- clear_btn = gr.ClearButton(sequence_input, value="Clear Input")
185
- submit_btn = gr.Button("✨ Run Prediction", variant="primary", scale=2)
186
-
187
- # ✅ 示例回归:这对用户极其重要
188
- gr.Markdown("### 💡 Quick Examples")
189
  gr.Examples(
190
  examples=[
191
- [">Gram-negative Outer Membrane\nMSKLVKTLTISEISKAQNNGGKPAWCWYTLAMCGAGYDSGTCDYMYSHCFGIKHHSSGSSSYHC"],
192
- [">Gram-positive Cell Wall\nMKFKLTAGCLAVAGVLLASSFGADAEIVVNAIYDQVARTEDGVYTQGQLTGRRIELLNKLGIEPEDSLASTVIHEFVARVGDDHGIETIIDEFYRQHPSASL"],
193
- [">Cytoplasmic Protein\nMAKQDYYEILGVSKTAEEREIRKAYKRLAMKYHPDRNQGDKEAEAKFKEIKEAYEVLTDSQKRAAYDQYGHAAFEQGPE"],
194
  ],
195
- inputs=sequence_input,
196
- label=None
197
  )
198
 
199
- # 右侧:输出 + 简要说明
200
- with gr.Column(scale=2, elem_classes="content-box"):
201
- gr.Markdown("### 📊 Analysis Result")
202
-
203
- output_label = gr.Label(num_top_classes=NUM_CLASSES, show_label=False)
204
 
205
- gr.HTML("""
206
- <div style="background: #eff6ff; padding: 15px; border-radius: 8px; margin-top: 20px; border-left: 4px solid #3b82f6;">
207
- <h4 style="margin: 0 0 5px 0; color: #1e40af;">Performance Note</h4>
208
- <p style="margin: 0; font-size: 0.9rem; color: #1e3a8a;">
209
- This model excels at distinguishing <b>Outer Membrane</b> and <b>Cell Wall</b> proteins,
210
- outperforming traditional methods by utilizing deep semantic features from ESM-2.
211
- </p>
212
- </div>
213
- """)
214
 
215
- # === TAB 2: Model Details (学术区) ===
216
- with gr.TabItem("📈 Model Performance", id="stats"):
217
- with gr.Row():
218
- with gr.Column(elem_classes="content-box"):
219
- gr.Markdown("### 🔬 Why LocPred-Prok?")
220
- gr.Markdown("""
221
- Existing predictors often struggle with "Hard Classes" like Cell Wall and Outer Membrane proteins.
222
- **LocPred-Prok** solves this by fusing:
223
- 1. **Global Semantics:** From the pre-trained `ESM-2-150M` model.
224
- 2. **Local Motifs:** Captured by our custom CNN + Attention pooling branch.
225
- """)
226
-
227
- # ✅ 找回数据表格:增加专业度
228
  gr.HTML("""
229
- <h3>Comparative Performance (Homology Partitioned)</h3>
230
- <table>
231
- <thead>
232
- <tr>
233
- <th>Method</th>
234
- <th>Accuracy</th>
235
- <th>MCC (Overall)</th>
236
- <th>Outer Membrane MCC</th>
237
- </tr>
238
- </thead>
239
- <tbody>
240
- <tr style="background-color: #f0fdf4; font-weight: bold;">
241
- <td>✨ LocPred-Prok (Ours)</td>
242
- <td>91.2%</td>
243
- <td>0.889</td>
244
- <td>0.910</td>
245
- </tr>
246
- <tr>
247
- <td>Standard ESM-2 Only</td>
248
- <td>89.5%</td>
249
- <td>0.865</td>
250
- <td>0.872</td>
251
- </tr>
252
- <tr>
253
- <td>DeepLoc 2.0 (Prok)</td>
254
- <td>87.1%</td>
255
- <td>0.840</td>
256
- <td>0.855</td>
257
- </tr>
258
- </tbody>
259
- </table>
260
- <p style="margin-top: 10px; font-size: 0.8rem; color: #666;">* Benchmarked on strict homology-reduced datasets.</p>
261
  """)
262
 
263
- # === TAB 3: Citation (引用区) ===
264
- with gr.TabItem("📝 Citation", id="cite"):
265
- with gr.Column(elem_classes="content-box"):
266
- gr.Markdown("### Cite This Work")
267
- gr.Markdown("If you find this tool useful, please cite our paper:")
268
- # 修复了 Code 组件的报错,去掉了 language="bibtex"
269
- gr.Code(
270
- value="""@article{LocPredProk2025,
271
- title={LocPred-Prok: Prokaryotic protein subcellular localization prediction with a dual-branch architecture and protein language model},
272
- author={Your Name and Co-authors},
273
- journal={Submitted to Bioinformatics},
274
- year={2025}
275
- }""",
276
- label="BibTeX",
277
- language=None,
278
- interactive=False
279
- )
280
-
281
- # --- Footer ---
282
- gr.HTML("""
283
- <div style="text-align: center; margin-top: 40px; color: #94a3b8; font-size: 0.85rem;">
284
- © 2025 iSysLab HUST • Powered by PyTorch & Hugging Face
285
- </div>
286
- """)
287
-
288
- # 逻辑绑定
289
- submit_btn.click(fn=predict, inputs=sequence_input, outputs=output_label)
290
- clear_btn.click(lambda: None, outputs=[output_label])
291
 
292
  app.launch()
 
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
  import gradio as gr
6
+ import matplotlib.pyplot as plt
7
+ import matplotlib.patches as patches
8
+ from io import BytesIO
9
+ from PIL import Image
10
  from transformers import AutoTokenizer, AutoModel
11
 
12
  # ==========================
13
+ # 0-3. 基础设置与模型定义 (保持核心逻辑不变)
14
  # ==========================
15
  os.environ["HF_HOME"] = "/tmp/hf_cache"
16
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
 
82
  classifier.eval()
83
  print("✅ Ready.")
84
 
85
+ # ==========================
86
+ # 🆕 动态绘图函数 (绘制原核细胞)
87
+ # ==========================
88
+ def draw_prokaryotic_cell(target_class):
89
+ """
90
+ 根据预测类别绘制原核细胞结构,并高亮特定区域。
91
+ """
92
+ # 转换输入类别为小写以便匹配
93
+ target = target_class.lower() if target_class else ""
94
+
95
+ # 创建画布
96
+ fig, ax = plt.subplots(figsize=(6, 4.5), dpi=100)
97
+ ax.set_xlim(-1.5, 1.5)
98
+ ax.set_ylim(-1.2, 1.2)
99
+ ax.axis('off') # 隐藏坐标轴
100
+
101
+ # 默认颜色 (未激活状态)
102
+ colors = {
103
+ "bg": "#f8fafc",
104
+ "inactive": "#e2e8f0",
105
+ "text": "#64748b",
106
+ "highlight": "#ef4444", # 红色高亮
107
+ "highlight_fill": "#fee2e2"
108
+ }
109
+
110
+ # 定义各部分的状态 (Is Highlighted?)
111
+ # 根据你的具体 label_map.json 里的标签名称进行模糊匹配
112
+ is_extracellular = "extracellular" in target or "secreted" in target
113
+ is_outer_mem = "outer membrane" in target
114
+ is_periplasm = "periplasm" in target
115
+ is_cell_wall = "cell wall" in target
116
+ is_inner_mem = "plasma membrane" in target or "inner membrane" in target or "cytoplasmic membrane" in target
117
+ is_cytoplasm = "cytoplasm" in target or "cytosol" in target
118
+
119
+ # 1. 胞外区域 (Extracellular) - 用箭头或背景表示
120
+ if is_extracellular:
121
+ ax.text(0, 1.1, "Extracellular / Secreted", ha='center', va='center', fontsize=12, fontweight='bold', color=colors['highlight'])
122
+ # 画一些向外的箭头
123
+ ax.arrow(0, 0.9, 0, 0.2, head_width=0.05, head_length=0.05, fc=colors['highlight'], ec=colors['highlight'])
124
+ else:
125
+ ax.text(0, 1.1, "Extracellular Space", ha='center', va='center', fontsize=10, color=colors['text'])
126
+
127
+ # 2. 外膜 (Outer Membrane) - 最外层的圈
128
+ om_color = colors['highlight'] if is_outer_mem else "#94a3b8"
129
+ om_width = 4 if is_outer_mem else 2
130
+ om = patches.Ellipse((0, 0), 2.4, 1.6, fill=False, edgecolor=om_color, linewidth=om_width)
131
+ ax.add_patch(om)
132
+ ax.text(1.3, 0, "Outer Mem.", ha='left', va='center', fontsize=9, color=om_color)
133
+
134
+ # 3. 细胞壁 (Cell Wall) - 中间层
135
+ cw_color = colors['highlight'] if is_cell_wall else "#cbd5e1"
136
+ cw_width = 4 if is_cell_wall else 2
137
+ # 稍微向内一点
138
+ cw = patches.Ellipse((0, 0), 2.2, 1.45, fill=False, edgecolor=cw_color, linewidth=cw_width, linestyle='--')
139
+ ax.add_patch(cw)
140
+ ax.text(1.2, -0.4, "Cell Wall", ha='left', va='center', fontsize=9, color=cw_color)
141
+
142
+ # 4. 周质空间 (Periplasm) - 外膜和内膜之间
143
+ if is_periplasm:
144
+ peri = patches.Ellipse((0, 0), 2.3, 1.52, fill=False, edgecolor=colors['highlight'], linewidth=10, alpha=0.3)
145
+ ax.add_patch(peri)
146
+ ax.text(0, 0.85, "Periplasm", ha='center', va='center', fontsize=10, fontweight='bold', color=colors['highlight'])
147
+
148
+ # 5. 内膜 (Inner/Plasma Membrane)
149
+ im_color = colors['highlight'] if is_inner_mem else "#94a3b8"
150
+ im_width = 4 if is_inner_mem else 2
151
+ im = patches.Ellipse((0, 0), 2.0, 1.3, fill=False, edgecolor=im_color, linewidth=im_width)
152
+ ax.add_patch(im)
153
+ ax.text(1.1, -0.7, "Inner Mem.", ha='left', va='center', fontsize=9, color=im_color)
154
+
155
+ # 6. 胞质 (Cytoplasm) - 最里面填充
156
+ cyto_color = colors['highlight_fill'] if is_cytoplasm else "#f1f5f9"
157
+ cyto_text_color = colors['highlight'] if is_cytoplasm else colors['text']
158
+ cyto = patches.Ellipse((0, 0), 1.95, 1.25, facecolor=cyto_color, edgecolor='none')
159
+ ax.add_patch(cyto)
160
+
161
+ # 加上 DNA 示意图 (不管预测结果如何都在)
162
+ ax.text(0, -0.1, "Cytoplasm", ha='center', va='center', fontsize=10, fontweight='bold', color=cyto_text_color)
163
+ # 画一个简单的 DNA 团
164
+ plt.plot([-0.3, -0.1, 0.1, 0.3], [0.1, -0.1, 0.1, -0.1], color="#cbd5e1", linewidth=1)
165
+ ax.text(0, 0.2, "DNA", ha='center', va='bottom', fontsize=7, color="#cbd5e1")
166
+
167
+ # 保存为图片对象
168
+ buf = BytesIO()
169
+ plt.savefig(buf, format='png', bbox_inches='tight', transparent=True)
170
+ buf.seek(0)
171
+ img = Image.open(buf)
172
+ plt.close(fig)
173
+ return img
174
+
175
+ # ==========================
176
+ # 3. 预测逻辑 (更新,返回图片)
177
+ # ==========================
178
  def predict(sequence_input):
179
  if not sequence_input or sequence_input.isspace():
180
  raise gr.Error("Please input a sequence.")
 
188
  logits = classifier(outputs.last_hidden_state[:, 0, :], outputs.last_hidden_state[:, 1:-1, :], inputs['attention_mask'][:, 1:-1])
189
  probs = F.softmax(logits, dim=1)[0]
190
 
191
+ # 获取最高概率的类别
192
+ top_prob, top_idx = torch.max(probs, dim=0)
193
+ top_label = idx_to_label[top_idx.item()]
194
+
195
+ # 生成置信度字典
196
+ confidences = {idx_to_label[i]: float(p) for i, p in enumerate(probs)}
197
+
198
+ # 生成对应的细胞图
199
+ cell_diagram = draw_prokaryotic_cell(top_label)
200
+
201
+ return confidences, cell_diagram
202
 
203
  # ==========================
204
+ # 4. 旗舰版 UI (包含细胞可视化)
205
  # ==========================
206
 
 
207
  flagship_css = """
208
  @import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Sans:wght@400;600;700&display=swap');
 
209
  body { font-family: 'IBM Plex Sans', sans-serif !important; background-color: #f0f2f5; }
 
 
210
  .header-box {
211
+ background: linear-gradient(135deg, #1e3a8a 0%, #3b82f6 100%);
212
+ color: white; padding: 2rem; border-radius: 12px; margin-bottom: 1.5rem;
213
+ box-shadow: 0 4px 15px rgba(37, 99, 235, 0.3);
 
 
 
214
  }
215
+ .header-title { font-size: 2.2rem; font-weight: 700; }
216
+ .badge { background: rgba(255,255,255,0.2); padding: 4px 12px; border-radius: 99px; font-size: 0.85rem; border: 1px solid rgba(255,255,255,0.3); }
217
+ .content-box { background: white; padding: 1.5rem; border-radius: 12px; border: 1px solid #e5e7eb; box-shadow: 0 4px 6px -1px rgba(0,0,0,0.05); }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  """
219
 
220
+ theme = gr.themes.Soft(primary_hue="blue", font=[gr.themes.GoogleFont("IBM Plex Sans"), "ui-sans-serif", "system-ui"])
 
 
 
 
221
 
222
  with gr.Blocks(theme=theme, css=flagship_css, title="LocPred-Prok") as app:
223
 
224
+ # Header
225
  with gr.Column(elem_classes="header-box"):
226
  gr.HTML("""
227
  <div class="header-title">LocPred-Prok</div>
228
+ <div style="opacity: 0.9; margin-bottom: 10px;">Prokaryotic Subcellular Localization Prediction</div>
229
+ <div>
 
 
230
  <span class="badge">🧬 ESM-2 Enhanced</span>
231
+ <span class="badge">🏆 SOTA Accuracy</span>
232
+ <span class="badge">👁️ Visual Interpretation</span>
 
233
  </div>
234
  """)
235
 
236
  with gr.Tabs():
 
 
237
  with gr.TabItem("🚀 Predict", id="predict"):
238
  with gr.Row():
239
+ # Input
240
+ with gr.Column(scale=4, elem_classes="content-box"):
241
+ gr.Markdown("### 📥 Input Sequence")
242
+ sequence_input = gr.Textbox(lines=8, placeholder=">Sequence...", show_label=False)
 
 
 
 
 
 
 
 
243
  with gr.Row():
244
+ clear_btn = gr.ClearButton(sequence_input, value="Clear")
245
+ submit_btn = gr.Button("✨ Predict & Visualize", variant="primary", scale=2)
246
+
247
+ gr.Markdown("#### Examples")
 
248
  gr.Examples(
249
  examples=[
250
+ [">Outer Membrane Protein\nMSKLVKTLTISEISKAQNNGGKPAWCWYTLAMCGAGYDSGTCDYMYSHCFGIKHHSSGSSSYHC"],
251
+ [">Cytoplasmic Protein\nMAKQDYYEILGVSKTAEEREIRKAYKRLAMKYHPDRNQGDKEAEAKFKEIKEAYEVLTDSQKRAAYDQYGHAAFEQGPE"]
 
252
  ],
253
+ inputs=sequence_input, label=None
 
254
  )
255
 
256
+ # Output (Split into Charts and Visuals)
257
+ with gr.Column(scale=5, elem_classes="content-box"):
258
+ gr.Markdown("### 📊 Analysis Results")
 
 
259
 
260
+ with gr.Row():
261
+ # 左侧:概率条
262
+ with gr.Column(scale=1):
263
+ output_label = gr.Label(num_top_classes=4, show_label=False)
264
+
265
+ # 右侧:细胞可视化图
266
+ with gr.Column(scale=1):
267
+ output_image = gr.Image(label="Cellular Localization Map", show_label=True, show_download_button=False, interactive=False, type="pil")
 
268
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  gr.HTML("""
270
+ <div style="margin-top:10px; padding:10px; background:#f0f9ff; border-radius:8px; color:#0369a1; font-size:0.9rem;">
271
+ <b>Visualization:</b> The diagram on the right dynamically highlights the predicted localization site within a schematic prokaryotic cell.
272
+ </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  """)
274
 
275
+ # Other tabs (About/Cite) kept simple for brevity
276
+ with gr.TabItem("📖 About"):
277
+ gr.Markdown("### About LocPred-Prok\nThis tool uses a Dual-Branch architecture...")
278
+
279
+ submit_btn.click(fn=predict, inputs=sequence_input, outputs=[output_label, output_image])
280
+ clear_btn.click(lambda: [None, None], outputs=[output_label, output_image])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
 
282
  app.launch()