wangleiofficial commited on
Commit
0c3ed7a
·
verified ·
1 Parent(s): f7d2100

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +237 -286
app.py CHANGED
@@ -5,119 +5,118 @@ import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
  import gradio as gr
 
 
8
  from transformers import AutoTokenizer, AutoModel
9
 
10
  # ==========================
11
  # 0. 环境与缓存设置
12
  # ==========================
 
 
 
13
  os.environ["HF_HOME"] = "/tmp/hf_cache"
14
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
15
  os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
16
 
17
- # 清理旧缓存 (可选)
18
  import shutil
19
  for path in ["/tmp/hf_cache", os.path.expanduser("~/.cache/huggingface")]:
20
  shutil.rmtree(path, ignore_errors=True)
21
  os.makedirs(path, exist_ok=True)
22
 
23
  # ==========================
24
- # 1. 模型架构定义
25
  # ==========================
26
  class AttentionPooling(nn.Module):
27
- """Attention Pooling Layer"""
28
  def __init__(self, d_model):
29
  super().__init__()
30
  self.attention_net = nn.Linear(d_model, 1)
31
 
32
  def forward(self, x, mask):
33
- attn_logits = self.attention_net(x).squeeze(2)
 
34
  attn_logits.masked_fill_(mask == 0, -float('inf'))
35
  attn_weights = F.softmax(attn_logits, dim=1)
36
- return torch.bmm(attn_weights.unsqueeze(1), x).squeeze(1)
 
 
 
37
 
38
  class ProtDualBranchEnhancedClassifier(nn.Module):
39
- """Enhanced dual-branch model architecture"""
40
  def __init__(self, d_model, projection_dim, num_classes, dropout, kernel_size):
41
  super().__init__()
42
  self.cls_projector = nn.Linear(d_model, projection_dim)
43
  self.token_refiner = nn.Sequential(
44
- nn.Conv1d(d_model, d_model, kernel_size, padding='same'),
45
  nn.ReLU()
46
  )
47
  self.attention_pooling = AttentionPooling(d_model)
48
  self.tok_projector = nn.Linear(d_model, projection_dim)
49
  fused_dim = projection_dim * 2
50
- self.gate = nn.Sequential(
51
- nn.Linear(fused_dim, fused_dim),
52
- nn.Sigmoid()
53
- )
54
  self.classifier_head = nn.Sequential(
55
- nn.LayerNorm(fused_dim),
56
- nn.Linear(fused_dim, fused_dim * 2),
57
- nn.ReLU(),
58
- nn.Dropout(dropout),
59
  nn.Linear(fused_dim * 2, num_classes)
60
  )
61
 
62
  def forward(self, cls_embedding, token_embeddings, mask):
 
63
  z_cls = self.cls_projector(cls_embedding)
 
 
64
  tok_emb_permuted = token_embeddings.permute(0, 2, 1)
65
  refined_tok_emb = self.token_refiner(tok_emb_permuted).permute(0, 2, 1)
66
- z_tok_pooled = self.attention_pooling(refined_tok_emb, mask)
 
 
67
  z_tok = self.tok_projector(z_tok_pooled)
 
 
68
  z_fused_concat = torch.cat([z_cls, z_tok], dim=1)
69
  gate_values = self.gate(z_fused_concat)
70
  z_fused_gated = z_fused_concat * gate_values
71
- return self.classifier_head(z_fused_gated)
 
72
 
73
  # ==========================
74
- # 2. 加载模型与资源
75
  # ==========================
76
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
77
  PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
78
  CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
79
  LABEL_MAP_PATH = "label_map.json"
80
 
81
- # 文件存在性检查
82
- if not os.path.exists(LABEL_MAP_PATH):
83
- raise FileNotFoundError(f"Error: Missing '{LABEL_MAP_PATH}'. Please upload it to your Space.")
84
- if not os.path.exists(CLASSIFIER_PATH):
85
- raise FileNotFoundError(f"Error: Missing '{CLASSIFIER_PATH}'. Please upload it to your Space.")
86
 
87
  # 加载 Label Map
88
  with open(LABEL_MAP_PATH, 'r') as f:
89
  label_to_idx = json.load(f)
90
  idx_to_label = {v: k for k, v in label_to_idx.items()}
91
-
92
  NUM_CLASSES = len(idx_to_label)
93
  D_MODEL = 640
94
 
95
- print(f"🔹 Loading ESM-2 Model ({PLM_MODEL_NAME})...")
96
  tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME)
97
- plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE)
98
- plm_model.eval()
99
-
100
- print("🔹 Loading Custom Classifier...")
101
- classifier = ProtDualBranchEnhancedClassifier(
102
- d_model=D_MODEL, projection_dim=32, num_classes=NUM_CLASSES,
103
- dropout=0.3, kernel_size=3
104
- ).to(DEVICE)
105
 
 
 
106
  classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
107
  classifier.eval()
108
- print("✅ All Models Loaded Successfully.")
109
 
110
  # ==========================
111
- # 3. SVG 矢量绘图引擎 (完美对齐版)
112
  # ==========================
113
  def generate_bacterial_svg(target_class):
114
- """
115
- Generate a high-quality SVG vector diagram for bacterial localization.
116
- Coordinates are hardcoded to ensure perfect alignment.
117
- """
118
  target = target_class.lower() if target_class else ""
119
 
120
- # --- 1. 状态判断 ---
121
  is_om = "outer membrane" in target
122
  is_peri = "periplasm" in target
123
  is_cw = "cell wall" in target
@@ -125,294 +124,246 @@ def generate_bacterial_svg(target_class):
125
  is_cyto = "cytoplasm" in target or "cytosol" in target
126
  is_secreted = "extracellular" in target or "secreted" in target
127
 
128
- # --- 2. 颜色配置 (学术蓝/黄风格) ---
129
- colors = {
130
- # 填充色:平时浅色,激活变粉红
131
- "om_fill": "#FFCDD2" if is_peri else "#E1F5FE",
132
- "im_fill": "#FFCDD2" if is_cyto else "#FFF9C4",
133
-
134
- # 边框色:平时深灰,激活变鲜红
135
- "om_stroke": "#D32F2F" if is_om else "#37474F",
136
- "cw_stroke": "#D32F2F" if is_cw else "#90A4AE",
137
- "im_stroke": "#D32F2F" if is_im else "#37474F",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
- # 线宽
140
- "om_width": "4" if is_om else "2",
141
- "cw_width": "3" if is_cw else "1.5",
142
- "im_width": "4" if is_im else "2",
143
 
144
- # 细胞壁虚线
145
- "cw_dash": "0" if is_cw else "6,4",
146
 
147
- # 标签颜色
148
- "label_hl": "#D32F2F",
149
- "label_norm": "#546E7A",
150
- "arrow_hl": "#D32F2F",
151
- "arrow_norm": "#90A4AE"
152
- }
153
-
154
- # 获取标签样式的辅助函数
155
- def get_style(active):
156
- if active:
157
- return colors["label_hl"], "bold", colors["arrow_hl"], "2.5", "url(#arrowhead_hl)"
158
- else:
159
- return colors["label_norm"], "normal", colors["arrow_norm"], "1.0", "url(#arrowhead_norm)"
160
-
161
- s_om = get_style(is_om)
162
- s_peri = get_style(is_peri)
163
- s_cw = get_style(is_cw)
164
- s_im = get_style(is_im)
165
- s_cyto = get_style(is_cyto)
166
-
167
- # --- 3. 生成 SVG 字符串 ---
168
- svg = f"""
169
- <svg width="100%" height="100%" viewBox="0 0 800 450" xmlns="http://www.w3.org/2000/svg">
170
- <defs>
171
- <marker id="arrowhead_norm" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
172
- <polygon points="0 0, 10 3.5, 0 7" fill="{colors['arrow_norm']}" />
173
- </marker>
174
- <marker id="arrowhead_hl" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
175
- <polygon points="0 0, 10 3.5, 0 7" fill="{colors['arrow_hl']}" />
176
- </marker>
177
- </defs>
178
-
179
- <rect width="800" height="450" fill="white" />
180
-
181
- <g transform="translate(50, 50)">
182
- <rect x="0" y="0" width="500" height="300" rx="150" ry="150"
183
- fill="{colors['om_fill']}" stroke="{colors['om_stroke']}" stroke-width="{colors['om_width']}" />
184
-
185
- <rect x="40" y="40" width="420" height="220" rx="110" ry="110"
186
- fill="none" stroke="{colors['cw_stroke']}" stroke-width="{colors['cw_width']}" stroke-dasharray="{colors['cw_dash']}" />
187
-
188
- <rect x="80" y="80" width="340" height="140" rx="70" ry="70"
189
- fill="{colors['im_fill']}" stroke="{colors['im_stroke']}" stroke-width="{colors['im_width']}" />
190
-
191
- <g opacity="0.6">
192
- <path d="M 180 150 Q 220 100 250 150 T 320 150" fill="none" stroke="#B0BEC5" stroke-width="3" />
193
- <path d="M 190 140 Q 230 190 250 140 T 310 160" fill="none" stroke="#B0BEC5" stroke-width="3" />
194
- <circle cx="150" cy="120" r="3" fill="#90A4AE" />
195
- <circle cx="350" cy="180" r="3" fill="#90A4AE" />
196
- <circle cx="250" cy="100" r="3" fill="#90A4AE" />
197
- <circle cx="200" cy="200" r="3" fill="#90A4AE" />
198
  </g>
199
  </g>
200
 
201
  {f'''
202
- <g transform="translate(300, 20)">
203
- <text x="0" y="0" text-anchor="middle" fill="{colors['label_hl']}" font-weight="bold" font-family="Arial" font-size="14">SECRETED / EXTRACELLULAR</text>
204
- <line x1="0" y1="5" x2="0" y2="25" stroke="{colors['arrow_hl']}" stroke-width="2" marker-end="url(#arrowhead_hl)" />
205
  </g>
206
  ''' if is_secreted else ""}
 
 
 
 
 
 
 
 
 
 
207
 
208
- <g font-family="Arial, sans-serif">
209
-
210
- <g transform="translate(580, 80)">
211
- <text x="0" y="5" fill="{s_om[0]}" font-weight="{s_om[1]}" font-size="14">Outer Membrane</text>
212
- <line x1="-10" y1="0" x2="-80" y2="0" stroke="{s_om[2]}" stroke-width="{s_om[3]}" marker-end="{s_om[4]}" />
213
- </g>
214
-
215
- <g transform="translate(580, 140)">
216
- <text x="0" y="5" fill="{s_peri[0]}" font-weight="{s_peri[1]}" font-size="14">Periplasm</text>
217
- <line x1="-10" y1="0" x2="-100" y2="0" stroke="{s_peri[2]}" stroke-width="{s_peri[3]}" marker-end="{s_peri[4]}" />
218
- </g>
219
-
220
- <g transform="translate(580, 200)">
221
- <text x="0" y="5" fill="{s_cw[0]}" font-weight="{s_cw[1]}" font-size="14">Cell Wall</text>
222
- <line x1="-10" y1="0" x2="-120" y2="0" stroke="{s_cw[2]}" stroke-width="{s_cw[3]}" marker-end="{s_cw[4]}" />
223
- </g>
224
 
225
- <g transform="translate(580, 260)">
226
- <text x="0" y="5" fill="{s_im[0]}" font-weight="{s_im[1]}" font-size="14">Inner Membrane</text>
227
- <line x1="-10" y1="0" x2="-150" y2="0" stroke="{s_im[2]}" stroke-width="{s_im[3]}" marker-end="{s_im[4]}" />
228
- </g>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
- <g transform="translate(580, 320)">
231
- <text x="0" y="5" fill="{s_cyto[0]}" font-weight="{s_cyto[1]}" font-size="14">Cytoplasm</text>
232
- <line x1="-10" y1="0" x2="-200" y2="0" stroke="{s_cyto[2]}" stroke-width="{s_cyto[3]}" marker-end="{s_cyto[4]}" />
233
- </g>
234
- </g>
235
-
236
- <text x="400" y="420" text-anchor="middle" font-family="Arial" font-size="18" font-weight="bold" fill="#37474F">
237
- Predicted Localization: {target_class}
238
- </text>
239
- </svg>
240
- """
241
- return svg
242
 
243
  # ==========================
244
- # 4. 预测逻辑
245
  # ==========================
246
  def predict(sequence_input):
247
- if not sequence_input or sequence_input.isspace():
248
- raise gr.Error("Please input a protein sequence.")
249
 
250
- # 清洗输入
251
  seq = "".join(sequence_input.split('\n')[1:]) if sequence_input.startswith('>') else sequence_input
252
- seq = re.sub(r'[^A-Z]', '', seq.upper())
 
253
 
254
- if not seq: raise gr.Error("Invalid Amino Acid Sequence.")
255
- if len(seq) > 1024: seq = seq[:1024] # 截断防止OOM
256
-
257
  with torch.no_grad():
258
  inputs = tokenizer(seq, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
259
  outputs = plm_model(**inputs)
260
 
261
- # 提取特征
262
  hidden_states = outputs.last_hidden_state
263
  cls_embedding = hidden_states[:, 0, :]
264
- token_embeddings = hidden_states[:, 1:-1, :]
265
  token_mask = inputs['attention_mask'][:, 1:-1]
266
 
267
- # 模型推理
268
- logits = classifier(cls_embedding, token_embeddings, token_mask)
269
  probs = F.softmax(logits, dim=1)[0]
270
 
271
- # 获取结果
272
- top_prob, top_idx = torch.max(probs, dim=0)
273
- top_label = idx_to_label[top_idx.item()]
274
  confidences = {idx_to_label[i]: float(p) for i, p in enumerate(probs)}
275
 
276
- # 生成 SVG 可视化
277
- svg_content = generate_bacterial_svg(top_label)
278
 
279
- return confidences, svg_content
 
 
 
 
 
280
 
281
  # ==========================
282
- # 5. UI 界面 (学术风格)
283
  # ==========================
284
- paper_css = """
285
- @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;500;700&display=swap');
286
- body { font-family: 'Roboto', sans-serif !important; background-color: #ffffff; color: #1a1a1a; }
287
-
288
- /* Header */
289
- .header-box {
290
- background: #ffffff;
291
- padding: 2rem 0;
292
- border-bottom: 1px solid #e5e7eb;
293
- margin-bottom: 2rem;
294
- }
295
- .header-title {
296
- font-size: 2.2rem;
297
- font-weight: 700;
298
- color: #0f172a;
299
- letter-spacing: -0.5px;
300
  }
301
- .header-subtitle {
302
- font-size: 1.1rem;
303
- color: #64748b;
304
- font-weight: 300;
305
- margin-top: 8px;
 
 
 
 
 
 
 
306
  }
307
- .badge {
308
- display: inline-flex;
309
- align-items: center;
310
- padding: 4px 12px;
311
- font-size: 0.85rem;
312
- font-weight: 500;
313
- color: #0f172a;
314
- background: #f1f5f9;
315
- border: 1px solid #e2e8f0;
316
- border-radius: 99px;
317
- margin-right: 10px;
318
  }
319
-
320
- /* Content Box */
321
- .content-box {
322
- background: #ffffff;
323
- border: 1px solid #e2e8f0;
324
- border-radius: 8px;
325
- padding: 1.5rem;
326
- box-shadow: 0 1px 2px 0 rgba(0, 0, 0, 0.05);
327
- }
328
-
329
- /* Button */
330
- button.primary {
331
- background-color: #2563eb !important;
332
- color: white !important;
333
- border-radius: 6px !important;
334
- font-weight: 500;
335
  }
336
  """
337
 
338
- theme = gr.themes.Base(
339
- primary_hue="blue",
340
- font=[gr.themes.GoogleFont("Roboto"), "ui-sans-serif", "system-ui"]
341
- ).set(
342
- body_background_fill="#ffffff",
343
- block_background_fill="#ffffff",
344
- block_border_width="1px",
345
- block_label_background_fill="#ffffff"
346
- )
347
-
348
- with gr.Blocks(theme=theme, css=paper_css, title="LocPred-Prok") as app:
349
 
350
- # --- Header ---
351
- with gr.Column(elem_classes="header-box"):
352
- gr.HTML("""
353
  <div class="header-title">LocPred-Prok</div>
354
- <div class="header-subtitle">
355
- Deep learning framework for prokaryotic subcellular localization using dual-branch architecture
356
- </div>
357
- <div style="margin-top: 15px;">
358
- <span class="badge">Research Article</span>
359
- <span class="badge">ESM-2 Enhanced</span>
360
- <span class="badge">Gram-negative Bacteria</span>
361
- </div>
362
- """)
363
-
364
- # --- Main Content ---
365
- with gr.Tabs():
366
- with gr.TabItem("Prediction Interface"):
367
  with gr.Row():
368
- # Input Column
369
- with gr.Column(scale=4, elem_classes="content-box"):
370
- gr.Markdown("### 1. Sequence Input")
371
- gr.Markdown("<span style='color:#64748b; font-size:0.9rem'>Enter a protein sequence in FASTA format or raw amino acids.</span>")
372
-
373
- sequence_input = gr.Textbox(
374
- lines=12,
375
- show_label=False,
376
- placeholder=">Sequence_ID\nMKFKLTAGCL..."
377
- )
378
-
379
- with gr.Row():
380
- clear_btn = gr.ClearButton(sequence_input, value="Clear")
381
- submit_btn = gr.Button("Run Analysis", variant="primary")
382
-
383
- gr.Markdown("#### Test Examples")
384
- gr.Examples(
385
- examples=[
386
- [">Outer Membrane Protein (OmpA)\nAPKNTWYTGAKLGWSQYHDTGFINNNGPTHENQLGAGAFGGYQVNPYVGFEMGYDWLGRMPYKGSVENGAYKAQGVQLTAKLGYPITDDLDIYTRLGGMVWRADTKSNVYGKNHDTGVSPVFAGGVEYAITPEIATRLEYQWTNNIGDAHTIGTRPDNGMLSLGVSYRFGQGEAAPVVAPAPAPAPEVQTKHFTLKSDVLFNFNKATLKPEGQAALDQLYSQLSNLDPKDGSVVVLGYTDRIGSDAYNQGLSERRAQSVVDYLISKGIPADKISARGMGESNPVTGNTCDNVKQRAALIDCLAPDRRVEIEVKGIKDVVTQPQA"],
387
- [">Cytoplasmic Protein (Ribosomal)\nARYLGPKLKLSRREGTDLFLKSGVRAIDTKCKIEQAPGQHGARKPRLSDYGVQLREKQKVRRIYGVLERQFRNYYKEAARLKGNTGENLLALLEGRLDNVVYRMGFG"]
388
- ],
389
- inputs=sequence_input,
390
- label=None
391
- )
392
-
393
- # Output Column
394
- with gr.Column(scale=6, elem_classes="content-box"):
395
- gr.Markdown("### 2. Localization Results")
396
-
397
- # 使用 HTML 组件展示 SVG
398
- output_svg = gr.HTML(label="Visualization", show_label=False)
399
-
400
- gr.Markdown("#### Confidence Scores")
401
- output_label = gr.Label(num_top_classes=NUM_CLASSES, show_label=False)
402
-
403
- with gr.TabItem("About & Methodology"):
404
- gr.Markdown("""
405
- ### Methodology
406
- **LocPred-Prok** employs a dual-branch neural network architecture...
407
- """)
408
-
409
- # --- Interaction ---
410
- submit_btn.click(
411
- fn=predict,
412
- inputs=sequence_input,
413
- outputs=[output_label, output_svg]
414
- )
415
- clear_btn.click(lambda: [None, None], outputs=[output_label, output_svg])
416
-
417
- # Launch
418
  app.launch()
 
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
  import gradio as gr
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
  from transformers import AutoTokenizer, AutoModel
11
 
12
  # ==========================
13
  # 0. 环境与缓存设置
14
  # ==========================
15
+ # 强制使用非交互式后端,防止 matplotlib 在服务器报错
16
+ plt.switch_backend('Agg')
17
+
18
  os.environ["HF_HOME"] = "/tmp/hf_cache"
19
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
20
  os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
21
 
 
22
  import shutil
23
  for path in ["/tmp/hf_cache", os.path.expanduser("~/.cache/huggingface")]:
24
  shutil.rmtree(path, ignore_errors=True)
25
  os.makedirs(path, exist_ok=True)
26
 
27
  # ==========================
28
+ # 1. 模型架构定义 (支持 Attention 输出)
29
  # ==========================
30
  class AttentionPooling(nn.Module):
 
31
  def __init__(self, d_model):
32
  super().__init__()
33
  self.attention_net = nn.Linear(d_model, 1)
34
 
35
  def forward(self, x, mask):
36
+ # x shape: (Batch, Seq_Len, Dim)
37
+ attn_logits = self.attention_net(x).squeeze(2)
38
  attn_logits.masked_fill_(mask == 0, -float('inf'))
39
  attn_weights = F.softmax(attn_logits, dim=1)
40
+
41
+ # 返回: (Pooled_Embedding, Weights)
42
+ # Weights 用于 Panel D 的可视化
43
+ return torch.bmm(attn_weights.unsqueeze(1), x).squeeze(1), attn_weights
44
 
45
  class ProtDualBranchEnhancedClassifier(nn.Module):
 
46
  def __init__(self, d_model, projection_dim, num_classes, dropout, kernel_size):
47
  super().__init__()
48
  self.cls_projector = nn.Linear(d_model, projection_dim)
49
  self.token_refiner = nn.Sequential(
50
+ nn.Conv1d(d_model, d_model, kernel_size, padding='same'),
51
  nn.ReLU()
52
  )
53
  self.attention_pooling = AttentionPooling(d_model)
54
  self.tok_projector = nn.Linear(d_model, projection_dim)
55
  fused_dim = projection_dim * 2
56
+ self.gate = nn.Sequential(nn.Linear(fused_dim, fused_dim), nn.Sigmoid())
 
 
 
57
  self.classifier_head = nn.Sequential(
58
+ nn.LayerNorm(fused_dim),
59
+ nn.Linear(fused_dim, fused_dim * 2),
60
+ nn.ReLU(),
61
+ nn.Dropout(dropout),
62
  nn.Linear(fused_dim * 2, num_classes)
63
  )
64
 
65
  def forward(self, cls_embedding, token_embeddings, mask):
66
+ # Branch 1: Global Semantic
67
  z_cls = self.cls_projector(cls_embedding)
68
+
69
+ # Branch 2: Local Structural
70
  tok_emb_permuted = token_embeddings.permute(0, 2, 1)
71
  refined_tok_emb = self.token_refiner(tok_emb_permuted).permute(0, 2, 1)
72
+
73
+ # ⚠️ 获取 Pooling 权重用于可视化
74
+ z_tok_pooled, pooling_weights = self.attention_pooling(refined_tok_emb, mask)
75
  z_tok = self.tok_projector(z_tok_pooled)
76
+
77
+ # Fusion Gate
78
  z_fused_concat = torch.cat([z_cls, z_tok], dim=1)
79
  gate_values = self.gate(z_fused_concat)
80
  z_fused_gated = z_fused_concat * gate_values
81
+
82
+ return self.classifier_head(z_fused_gated), pooling_weights
83
 
84
  # ==========================
85
+ # 2. 加载模型与配置
86
  # ==========================
87
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
88
  PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
89
  CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
90
  LABEL_MAP_PATH = "label_map.json"
91
 
92
+ # 检查文件
93
+ if not os.path.exists(LABEL_MAP_PATH): raise FileNotFoundError(f"Missing {LABEL_MAP_PATH}")
94
+ if not os.path.exists(CLASSIFIER_PATH): raise FileNotFoundError(f"Missing {CLASSIFIER_PATH}")
 
 
95
 
96
  # 加载 Label Map
97
  with open(LABEL_MAP_PATH, 'r') as f:
98
  label_to_idx = json.load(f)
99
  idx_to_label = {v: k for k, v in label_to_idx.items()}
 
100
  NUM_CLASSES = len(idx_to_label)
101
  D_MODEL = 640
102
 
103
+ print("🔹 Loading models...")
104
  tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME)
105
+ plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE).eval()
 
 
 
 
 
 
 
106
 
107
+ classifier = ProtDualBranchEnhancedClassifier(D_MODEL, 32, NUM_CLASSES, 0.3, 3).to(DEVICE)
108
+ # strict=False 允许加载即使权重文件中没有 pooling_weights 相关的特定状态(通常不影响)
109
  classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
110
  classifier.eval()
111
+ print("✅ Ready.")
112
 
113
  # ==========================
114
+ # 3. Panel B: SVG 绘图引擎 (贝塞尔曲线 + 锚点)
115
  # ==========================
116
  def generate_bacterial_svg(target_class):
 
 
 
 
117
  target = target_class.lower() if target_class else ""
118
 
119
+ # 1. 状态判断
120
  is_om = "outer membrane" in target
121
  is_peri = "periplasm" in target
122
  is_cw = "cell wall" in target
 
124
  is_cyto = "cytoplasm" in target or "cytosol" in target
125
  is_secreted = "extracellular" in target or "secreted" in target
126
 
127
+ # 2. 颜色配置 (高对比度科研风)
128
+ c = {
129
+ # 激活态: 鲜红
130
+ "hl_stroke": "#D32F2F", "hl_fill": "#FFEBEE", "hl_text": "#B71C1C", "hl_dot": "#D32F2F",
131
+ # 未激活态: 极淡的灰白 (背景化)
132
+ "bg_stroke": "#90A4AE", "bg_fill_om": "#F5F5F5", "bg_fill_im": "#FAFAFA",
133
+ "bg_text": "#78909C", "bg_line": "#CFD8DC", "bg_dot": "#B0BEC5"
134
+ }
135
+
136
+ # 样式生成器
137
+ def style(active, base_fill, base_stroke, w_act="4", w_norm="2"):
138
+ if active: return c["hl_fill"], c["hl_stroke"], w_act
139
+ return base_fill, base_stroke, width_norm
140
+
141
+ om_f, om_s, om_w = style(is_peri, c["bg_fill_om"], c["hl_stroke"] if is_om else c["bg_stroke"])
142
+ cw_s = c["hl_stroke"] if is_cw else "#B0BEC5"
143
+ cw_w, cw_d = ("3", "0") if is_cw else ("1.5", "6,4")
144
+ im_f, im_s, im_w = style(is_cyto, c["bg_fill_im"], c["hl_stroke"] if is_im else c["bg_stroke"])
145
+
146
+ # 标签样式 (文字颜色, 字重, 线条颜色, 线宽, 锚点颜色, 锚点半径)
147
+ def label_style(active):
148
+ if active: return c["hl_text"], "bold", c["hl_stroke"], "2.5", c["hl_dot"], "5"
149
+ return c["bg_text"], "normal", c["bg_line"], "1.5", c["bg_dot"], "3"
150
+
151
+ l_om, l_peri, l_cw, l_im, l_cyto = label_style(is_om), label_style(is_peri), label_style(is_cw), label_style(is_im), label_style(is_cyto)
152
+
153
+ # 3. 坐标定义
154
+ bx, by = 280, 210 # 细菌中心
155
+ tx = 600 # 标签文字起始 X 坐标
156
+
157
+ # 目标锚点 (Target Anchor Points) - 精确落在结构上
158
+ targets = {
159
+ "om": (bx + 140, by - 120), # 外膜线
160
+ "peri": (bx + 120, by - 90), # 周质间隙
161
+ "cw": (bx + 100, by - 70), # 细胞壁线
162
+ "im": (bx + 70, by - 50), # 内膜线
163
+ "cyto": (bx, by) # 胞质中心
164
+ }
165
+
166
+ text_y = {"om": 90, "peri": 150, "cw": 210, "im": 270, "cyto": 330}
167
+
168
+ # 4. 贝塞尔曲线连接器
169
+ def draw_connector(key, style_tuple, label_text):
170
+ txt_col, weight, line_col, width, dot_col, r = style_tuple
171
+ tx_pos, ty_pos = tx, text_y[key]
172
+ ex, ey = targets[key]
173
 
174
+ # 贝塞尔控制点:形成 S 形曲线
175
+ c1x, c1y = tx_pos - 100, ty_pos
176
+ c2x, c2y = ex + 50, ey
 
177
 
178
+ path = f"M {tx_pos - 10} {ty_pos - 5} C {c1x} {c1y}, {c2x} {c2y}, {ex} {ey}"
 
179
 
180
+ return f"""
181
+ <g>
182
+ <text x="{tx_pos}" y="{ty_pos}" fill="{txt_col}" font-weight="{weight}" font-size="15" font-family="Arial">{label_text}</text>
183
+ <path d="{path}" fill="none" stroke="{line_col}" stroke-width="{width}" />
184
+ <circle cx="{ex}" cy="{ey}" r="{r}" fill="{dot_col}" stroke="white" stroke-width="1" />
185
+ </g>
186
+ """
187
+
188
+ svg = f"""<svg width="100%" height="100%" viewBox="0 0 800 420" xmlns="http://www.w3.org/2000/svg">
189
+ <g transform="translate(280, 210)">
190
+ <rect x="-150" y="-150" width="300" height="300" rx="150" ry="150" fill="{om_f}" stroke="{om_s}" stroke-width="{om_w}" />
191
+ <rect x="-110" y="-110" width="220" height="220" rx="110" ry="110" fill="none" stroke="{cw_s}" stroke-width="{cw_w}" stroke-dasharray="{cw_d}" />
192
+ <rect x="-70" y="-70" width="140" height="140" rx="70" ry="70" fill="{im_f}" stroke="{im_s}" stroke-width="{im_w}" />
193
+ <g opacity="0.4">
194
+ <path d="M -30 -20 Q 0 -60 30 -20 T 60 -10" fill="none" stroke="#CFD8DC" stroke-width="3" />
195
+ <circle cx="-40" cy="30" r="3" fill="#B0BEC5" /> <circle cx="20" cy="40" r="3" fill="#B0BEC5" />
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  </g>
197
  </g>
198
 
199
  {f'''
200
+ <g transform="translate(500, 40)">
201
+ <text x="0" y="0" text-anchor="middle" fill="{c['hl_stroke']}" font-weight="bold" font-family="Arial" font-size="14">SECRETED</text>
202
+ <path d="M 0 10 L 0 40" stroke="{c['hl_stroke']}" stroke-width="2" marker-end="url(#arrow_hl)" />
203
  </g>
204
  ''' if is_secreted else ""}
205
+
206
+ <defs><marker id="arrow_hl" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto"><polygon points="0 0, 10 3.5, 0 7" fill="{c['hl_stroke']}" /></marker></defs>
207
+
208
+ {draw_connector("om", l_om, "Outer Membrane")}
209
+ {draw_connector("peri", l_peri, "Periplasm")}
210
+ {draw_connector("cw", l_cw, "Cell Wall")}
211
+ {draw_connector("im", l_im, "Inner Membrane")}
212
+ {draw_connector("cyto", l_cyto, "Cytoplasm")}
213
+ </svg>"""
214
+ return svg
215
 
216
+ # ==========================
217
+ # 4. Panel D: Attention 绘图引擎
218
+ # ==========================
219
+ def draw_pooling_weights(weights, sequence):
220
+ """
221
+ Visualize Attention Pooling Weights (1D Heatmap/Bar).
222
+ """
223
+ # 归一化
224
+ if weights.max() > 0:
225
+ weights = (weights - weights.min()) / (weights.max() - weights.min())
 
 
 
 
 
 
226
 
227
+ fig, ax = plt.subplots(figsize=(6, 3), dpi=120)
228
+ x = np.arange(len(weights))
229
+
230
+ # 绘制红色条形
231
+ ax.bar(x, weights, width=1.0, color='#D32F2F', alpha=0.8, label='Attention')
232
+
233
+ # 样式
234
+ ax.set_title("Learned Motif Importance (Attention Pooling)", fontsize=10, fontweight='bold', color='#37474F')
235
+ ax.set_xlabel("Residue Position", fontsize=9)
236
+ ax.set_ylabel("Weight", fontsize=9)
237
+ ax.spines['top'].set_visible(False)
238
+ ax.spines['right'].set_visible(False)
239
+ ax.spines['left'].set_visible(False)
240
+ ax.set_yticks([])
241
+
242
+ # 标注最高峰 (Potential Motif)
243
+ threshold = np.percentile(weights, 98) # 更加严格的阈值
244
+ if weights.max() > threshold:
245
+ peak_idx = np.argmax(weights)
246
+ ax.annotate('Key Motif', xy=(peak_idx, weights[peak_idx]), xytext=(peak_idx, weights[peak_idx]+0.2),
247
+ arrowprops=dict(facecolor='#37474F', shrink=0.05, width=1, headwidth=5),
248
+ ha='center', fontsize=8, color='#37474F')
249
 
250
+ plt.tight_layout()
251
+ return fig
 
 
 
 
 
 
 
 
 
 
252
 
253
  # ==========================
254
+ # 5. 预测主逻辑
255
  # ==========================
256
  def predict(sequence_input):
257
+ if not sequence_input or sequence_input.isspace(): raise gr.Error("Empty Input")
 
258
 
 
259
  seq = "".join(sequence_input.split('\n')[1:]) if sequence_input.startswith('>') else sequence_input
260
+ seq = re.sub(r'[^A-Z]', '', seq.upper())[:1024]
261
+ if not seq: raise gr.Error("Invalid Sequence")
262
 
 
 
 
263
  with torch.no_grad():
264
  inputs = tokenizer(seq, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
265
  outputs = plm_model(**inputs)
266
 
 
267
  hidden_states = outputs.last_hidden_state
268
  cls_embedding = hidden_states[:, 0, :]
269
+ token_embeddings = hidden_states[:, 1:-1, :] # No CLS/EOS
270
  token_mask = inputs['attention_mask'][:, 1:-1]
271
 
272
+ # ⚠️ 获取 logits 和 weights
273
+ logits, pooling_weights = classifier(cls_embedding, token_embeddings, token_mask)
274
  probs = F.softmax(logits, dim=1)[0]
275
 
276
+ # 1. 结果
277
+ top_label = idx_to_label[torch.max(probs, dim=0)[1].item()]
 
278
  confidences = {idx_to_label[i]: float(p) for i, p in enumerate(probs)}
279
 
280
+ # 2. Panel B: SVG
281
+ svg = generate_bacterial_svg(top_label)
282
 
283
+ # 3. Panel D: Attention Plot
284
+ # 取 batch 中第一个样本的 weights
285
+ w_np = pooling_weights[0].cpu().numpy()
286
+ attn_plot = draw_pooling_weights(w_np, seq)
287
+
288
+ return confidences, svg, attn_plot
289
 
290
  # ==========================
291
+ # 6. UI Layout (4-Block Paper Style)
292
  # ==========================
293
+ layout_css = """
294
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;800&display=swap');
295
+ body { background-color: #ffffff; font-family: 'Inter', sans-serif; }
296
+
297
+ /* Header: Sky Blue Theme */
298
+ .header-div {
299
+ background: linear-gradient(to right, #E0F7FA, #E1F5FE);
300
+ padding: 1.5rem;
301
+ border-radius: 8px;
302
+ margin-bottom: 20px;
303
+ text-align: center;
304
+ border: 1px solid #B3E5FC;
 
 
 
 
305
  }
306
+ .header-title { font-size: 2.2rem; font-weight: 800; color: #0288D1; margin-bottom: 5px; }
307
+ .header-sub { font-size: 1.0rem; color: #0277BD; }
308
+
309
+ /* Panel Cards */
310
+ .panel-card {
311
+ border: 1px solid #e2e8f0;
312
+ border-radius: 8px;
313
+ padding: 15px;
314
+ background: white;
315
+ height: 100%;
316
+ display: flex;
317
+ flex-direction: column;
318
  }
319
+ .panel-header {
320
+ font-weight: 700; color: #475569; border-bottom: 2px solid #f1f5f9;
321
+ padding-bottom: 8px; margin-bottom: 12px; font-size: 1.0rem;
 
 
 
 
 
 
 
 
322
  }
323
+ .panel-label {
324
+ display: inline-block; background: #E0F7FA; color: #0277BD; border: 1px solid #B2EBF2;
325
+ padding: 2px 8px; border-radius: 4px; font-size: 0.8rem; margin-right: 8px; font-weight: 800;
 
 
 
 
 
 
 
 
 
 
 
 
 
326
  }
327
  """
328
 
329
+ theme = gr.themes.Soft(primary_hue="sky").set(body_background_fill="white", block_background_fill="white", block_border_width="0px")
330
+
331
+ with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
 
 
 
 
 
 
 
 
332
 
333
+ gr.HTML("""
334
+ <div class="header-div">
 
335
  <div class="header-title">LocPred-Prok</div>
336
+ <div class="header-sub">Deep Learning Framework for Prokaryotic Subcellular Localization</div>
337
+ </div>
338
+ """)
339
+
340
+ # Row 1: A & B
341
+ with gr.Row():
342
+ with gr.Column(elem_classes="panel-card"):
343
+ gr.Markdown("<div class='panel-header'><span class='panel-label'>A</span>Sequence Input</div>")
344
+ sequence_input = gr.Textbox(lines=8, show_label=False, placeholder=">Sequence...")
 
 
 
 
345
  with gr.Row():
346
+ clear_btn = gr.ClearButton(sequence_input, value="Clear")
347
+ submit_btn = gr.Button("Predict Analysis", variant="primary")
348
+ gr.Examples([
349
+ [">Outer Membrane\nAPKNTWYTGAKLGWSQYHDTGFINNNGPTHENQLGAGAFGGYQVNPYVGFEMGYDWLGRMPYKGSVENGAYKAQGVQLTAKLGYPITDDLDIYTRLGGMVWRADTKSNVYGKNHDTGVSPVFAGGVEYAITPEIATRLEYQWTNNIGDAHTIGTRPDNGMLSLGVSYRFGQGEAAPVVAPAPAPAPEVQTKHFTLKSDVLFNFNKATLKPEGQAALDQLYSQLSNLDPKDGSVVVLGYTDRIGSDAYNQGLSERRAQSVVDYLISKGIPADKISARGMGESNPVTGNTCDNVKQRAALIDCLAPDRRVEIEVKGIKDVVTQPQA"]
350
+ ], inputs=sequence_input, label=None)
351
+
352
+ with gr.Column(elem_classes="panel-card"):
353
+ gr.Markdown("<div class='panel-header'><span class='panel-label'>B</span>Localization Visualization</div>")
354
+ output_svg = gr.HTML(label="Visual", show_label=False)
355
+
356
+ # Row 2: C & D
357
+ with gr.Row():
358
+ with gr.Column(elem_classes="panel-card"):
359
+ gr.Markdown("<div class='panel-header'><span class='panel-label'>C</span>Prediction Confidence</div>")
360
+ output_label = gr.Label(num_top_classes=NUM_CLASSES, show_label=False)
361
+
362
+ with gr.Column(elem_classes="panel-card"):
363
+ gr.Markdown("<div class='panel-header'><span class='panel-label'>D</span>Learned Motif Importance (Attention)</div>")
364
+ output_plot = gr.Plot(label="Attention", show_label=False)
365
+
366
+ submit_btn.click(fn=predict, inputs=sequence_input, outputs=[output_label, output_svg, output_plot])
367
+ clear_btn.click(lambda: [None, None, None], outputs=[output_label, output_svg, output_plot])
368
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
  app.launch()