wangleiofficial committed on
Commit d03f374 · verified · 1 Parent(s): 80c9b83

Update app.py

Files changed (1)
  1. app.py +91 -120
app.py CHANGED
@@ -9,12 +9,10 @@ import matplotlib.pyplot as plt
 import numpy as np
 from transformers import AutoTokenizer, AutoModel
 
-# ==============================================================================
-# 0. Environment & Cache Setup
-# ==============================================================================
-# Force a non-interactive backend so matplotlib does not error on the server
+# ==========================
+# 0. Environment & cache
+# ==========================
 plt.switch_backend('Agg')
-
 os.environ["HF_HOME"] = "/tmp/hf_cache"
 os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
 os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
@@ -24,76 +22,53 @@ for path in ["/tmp/hf_cache", os.path.expanduser("~/.cache/huggingface")]:
     shutil.rmtree(path, ignore_errors=True)
     os.makedirs(path, exist_ok=True)
 
-# ==============================================================================
-# 1. Model Architecture
-# ==============================================================================
+# ==========================
+# 1. Model architecture (with attention output)
+# ==========================
 class AttentionPooling(nn.Module):
     def __init__(self, d_model):
         super().__init__()
         self.attention_net = nn.Linear(d_model, 1)
 
     def forward(self, x, mask):
-        # x shape: (Batch, Seq_Len, Dim)
         attn_logits = self.attention_net(x).squeeze(2)
         attn_logits.masked_fill_(mask == 0, -float('inf'))
         attn_weights = F.softmax(attn_logits, dim=1)
-
-        # Returns: (pooled_embedding, weights)
-        # The weights feed the Panel D visualization
         return torch.bmm(attn_weights.unsqueeze(1), x).squeeze(1), attn_weights
 
 class ProtDualBranchEnhancedClassifier(nn.Module):
     def __init__(self, d_model, projection_dim, num_classes, dropout, kernel_size):
         super().__init__()
         self.cls_projector = nn.Linear(d_model, projection_dim)
-        self.token_refiner = nn.Sequential(
-            nn.Conv1d(d_model, d_model, kernel_size, padding='same'),
-            nn.ReLU()
-        )
+        self.token_refiner = nn.Sequential(nn.Conv1d(d_model, d_model, kernel_size, padding='same'), nn.ReLU())
         self.attention_pooling = AttentionPooling(d_model)
         self.tok_projector = nn.Linear(d_model, projection_dim)
         fused_dim = projection_dim * 2
         self.gate = nn.Sequential(nn.Linear(fused_dim, fused_dim), nn.Sigmoid())
-        self.classifier_head = nn.Sequential(
-            nn.LayerNorm(fused_dim),
-            nn.Linear(fused_dim, fused_dim * 2),
-            nn.ReLU(),
-            nn.Dropout(dropout),
-            nn.Linear(fused_dim * 2, num_classes)
-        )
+        self.classifier_head = nn.Sequential(nn.LayerNorm(fused_dim), nn.Linear(fused_dim, fused_dim * 2), nn.ReLU(), nn.Dropout(dropout), nn.Linear(fused_dim * 2, num_classes))
 
     def forward(self, cls_embedding, token_embeddings, mask):
-        # Branch 1: global semantic
         z_cls = self.cls_projector(cls_embedding)
-
-        # Branch 2: local structural
         tok_emb_permuted = token_embeddings.permute(0, 2, 1)
         refined_tok_emb = self.token_refiner(tok_emb_permuted).permute(0, 2, 1)
-
-        # ⚠️ Capture the pooling weights for visualization
         z_tok_pooled, pooling_weights = self.attention_pooling(refined_tok_emb, mask)
         z_tok = self.tok_projector(z_tok_pooled)
-
-        # Fusion gate
         z_fused_concat = torch.cat([z_cls, z_tok], dim=1)
         gate_values = self.gate(z_fused_concat)
         z_fused_gated = z_fused_concat * gate_values
-
         return self.classifier_head(z_fused_gated), pooling_weights
 
-# ==============================================================================
-# 2. Load Models & Config
-# ==============================================================================
+# ==========================
+# 2. Load models
+# ==========================
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
 CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
 LABEL_MAP_PATH = "label_map.json"
 
-# Check required files
 if not os.path.exists(LABEL_MAP_PATH): raise FileNotFoundError(f"Missing {LABEL_MAP_PATH}")
 if not os.path.exists(CLASSIFIER_PATH): raise FileNotFoundError(f"Missing {CLASSIFIER_PATH}")
 
-# Load the label map
 with open(LABEL_MAP_PATH, 'r') as f:
     label_to_idx = json.load(f)
 idx_to_label = {v: k for k, v in label_to_idx.items()}
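For orientation: a minimal, self-contained sketch of the two mechanisms this hunk wires together, masked attention pooling and the sigmoid fusion gate, on toy tensors (shapes are illustrative, not the app's real config).

import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionPooling(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.attention_net = nn.Linear(d_model, 1)

    def forward(self, x, mask):
        # x: (B, L, D); mask: (B, L) with 1 = residue, 0 = padding
        attn_logits = self.attention_net(x).squeeze(2)                  # (B, L)
        attn_logits = attn_logits.masked_fill(mask == 0, -float('inf'))
        attn_weights = F.softmax(attn_logits, dim=1)                    # padding gets weight 0
        return torch.bmm(attn_weights.unsqueeze(1), x).squeeze(1), attn_weights

B, L, D, P = 2, 6, 16, 4
x = torch.randn(B, L, D)
mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1, 1]])
pooled, w = AttentionPooling(D)(x, mask)
print(pooled.shape, w.shape)  # torch.Size([2, 16]) torch.Size([2, 6])
print(w[0])                   # last two weights are exactly 0.0

# The fusion gate scales each feature of the concatenated branches by a
# learned sigmoid in (0, 1): soft feature selection, not a convex mix.
gate = nn.Sequential(nn.Linear(P * 2, P * 2), nn.Sigmoid())
z = torch.cat([torch.randn(B, P), torch.randn(B, P)], dim=1)  # (2, 8)
print((z * gate(z)).shape)    # torch.Size([2, 8])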
@@ -103,41 +78,36 @@ D_MODEL = 640
 print("🔹 Loading models...")
 tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME)
 plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE).eval()
-
 classifier = ProtDualBranchEnhancedClassifier(D_MODEL, 32, NUM_CLASSES, 0.3, 3).to(DEVICE)
 classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
 classifier.eval()
 print("✅ Ready.")
 
-# ==============================================================================
-# 3. Panel B: SVG Visualization Engine
-# ==============================================================================
+# ==========================
+# 3. Panel B: SVG cell diagram (all 6 classes shown)
+# ==========================
 def generate_bacterial_svg(target_class):
     target = target_class.lower() if target_class else ""
 
-    # 1. State checks
-    is_om = "outer membrane" in target
+    # States
+    is_sec = "extracellular" in target or "secreted" in target
+    is_om = "outer membrane" in target
     is_peri = "periplasm" in target
-    is_cw = "cell wall" in target
-    is_im = "plasma membrane" in target or "inner membrane" in target
+    is_cw = "cell wall" in target
+    is_im = "plasma membrane" in target or "inner membrane" in target
     is_cyto = "cytoplasm" in target or "cytosol" in target
-    is_secreted = "extracellular" in target or "secreted" in target
 
-    # 2. Color palette (high-contrast, publication style)
+    # Colors
     c = {
-        # Active: vivid red
         "hl_stroke": "#D32F2F", "hl_fill": "#FFEBEE", "hl_text": "#B71C1C", "hl_dot": "#D32F2F",
-        # Inactive: very light gray (recedes into the background)
         "bg_stroke": "#90A4AE", "bg_fill_om": "#F5F5F5", "bg_fill_im": "#FAFAFA",
         "bg_text": "#78909C", "bg_line": "#CFD8DC", "bg_dot": "#B0BEC5"
     }
 
-    # 3. Style generator (this fixed an earlier bug)
+    # Structure styles
     def style(active, base_fill, base_stroke, w_act="4", w_norm="2"):
-        if active:
-            return c["hl_fill"], c["hl_stroke"], w_act
-        # ✅ Fix: this previously read width_norm; corrected to w_norm
-        return base_fill, base_stroke, w_norm
+        if active: return c["hl_fill"], c["hl_stroke"], w_act
+        return base_fill, base_stroke, w_norm
 
     om_f, om_s, om_w = style(is_peri, c["bg_fill_om"], c["hl_stroke"] if is_om else c["bg_stroke"])
     cw_s = c["hl_stroke"] if is_cw else "#B0BEC5"
@@ -149,37 +119,47 @@ def generate_bacterial_svg(target_class):
         if active: return c["hl_text"], "bold", c["hl_stroke"], "2.5", c["hl_dot"], "5"
         return c["bg_text"], "normal", c["bg_line"], "1.5", c["bg_dot"], "3"
 
-    l_om, l_peri, l_cw, l_im, l_cyto = label_style(is_om), label_style(is_peri), label_style(is_cw), label_style(is_im), label_style(is_cyto)
+    l_sec = label_style(is_sec)
+    l_om = label_style(is_om)
+    l_peri = label_style(is_peri)
+    l_cw = label_style(is_cw)
+    l_im = label_style(is_im)
+    l_cyto = label_style(is_cyto)
 
-    # 4. Coordinate definitions
-    bx, by = 280, 210  # bacterium center
-    tx = 600  # label text start X
+    # Coordinate system (center 280, 210)
+    bx, by = 280, 210
+    tx = 600  # label start X
 
+    # Target anchor points
     targets = {
-        "om": (bx + 140, by - 120),   # outer membrane line
-        "peri": (bx + 120, by - 90),  # periplasmic space
-        "cw": (bx + 100, by - 70),    # cell wall line
-        "im": (bx + 70, by - 50),     # inner membrane line
-        "cyto": (bx, by)              # cytoplasm center
+        "sec": (bx, by - 180),        # extracellular (floats above)
+        "om": (bx + 140, by - 120),   # outer membrane
+        "peri": (bx + 120, by - 90),  # periplasm
+        "cw": (bx + 100, by - 70),    # cell wall
+        "im": (bx + 70, by - 50),     # inner membrane
+        "cyto": (bx, by)              # cytoplasm
     }
 
-    text_y = {"om": 90, "peri": 150, "cw": 210, "im": 270, "cyto": 330}
+    # Label text Y coordinates (6 evenly spaced)
+    text_y = {
+        "sec": 50, "om": 110, "peri": 170, "cw": 230, "im": 290, "cyto": 350
+    }
 
-    # 5. Bezier curve connectors
+    # Bezier curve generator
    def draw_connector(key, style_tuple, label_text):
         txt_col, weight, line_col, width, dot_col, r = style_tuple
         tx_pos, ty_pos = tx, text_y[key]
         ex, ey = targets[key]
 
-        # Bezier control points
-        c1x, c1y = tx_pos - 100, ty_pos
-        c2x, c2y = ex + 50, ey
+        # S-curve control points
+        c1x, c1y = tx_pos - 80, ty_pos
+        c2x, c2y = ex + 60, ey
 
         path = f"M {tx_pos - 10} {ty_pos - 5} C {c1x} {c1y}, {c2x} {c2y}, {ex} {ey}"
 
         return f"""
         <g>
-            <text x="{tx_pos}" y="{ty_pos}" fill="{txt_col}" font-weight="{weight}" font-size="15" font-family="Arial">{label_text}</text>
+            <text x="{tx_pos}" y="{ty_pos}" fill="{txt_col}" font-weight="{weight}" font-size="14" font-family="Arial">{label_text}</text>
             <path d="{path}" fill="none" stroke="{line_col}" stroke-width="{width}" />
             <circle cx="{ex}" cy="{ey}" r="{r}" fill="{dot_col}" stroke="white" stroke-width="1" />
         </g>
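The connector geometry above is a plain cubic Bezier. A tiny standalone sketch of the same path construction, using the new control-point offsets (-80 / +60) and the "om" coordinates as a worked example:

def connector_path(tx_pos, ty_pos, ex, ey):
    # Control points pull the curve horizontally away from the label and
    # into the anchor dot, producing the S-shaped connector.
    c1x, c1y = tx_pos - 80, ty_pos
    c2x, c2y = ex + 60, ey
    return f"M {tx_pos - 10} {ty_pos - 5} C {c1x} {c1y}, {c2x} {c2y}, {ex} {ey}"

# "om": label at (600, 110), anchor at (280 + 140, 210 - 120) = (420, 90)
print(connector_path(600, 110, 420, 90))
# -> M 590 105 C 520 110, 480 90, 420 90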
@@ -196,15 +176,7 @@ def generate_bacterial_svg(target_class):
     </g>
     </g>
 
-    {f'''
-    <g transform="translate(500, 40)">
-        <text x="0" y="0" text-anchor="middle" fill="{c['hl_stroke']}" font-weight="bold" font-family="Arial" font-size="14">SECRETED</text>
-        <path d="M 0 10 L 0 40" stroke="{c['hl_stroke']}" stroke-width="2" marker-end="url(#arrow_hl)" />
-    </g>
-    ''' if is_secreted else ""}
-
-    <defs><marker id="arrow_hl" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto"><polygon points="0 0, 10 3.5, 0 7" fill="{c['hl_stroke']}" /></marker></defs>
-
+    {draw_connector("sec", l_sec, "Extracellular / Secreted")}
     {draw_connector("om", l_om, "Outer Membrane")}
     {draw_connector("peri", l_peri, "Periplasm")}
     {draw_connector("cw", l_cw, "Cell Wall")}
@@ -213,46 +185,48 @@ def generate_bacterial_svg(target_class):
     </svg>"""
     return svg
 
-# ==============================================================================
-# 4. Panel D: Attention Plotting Engine (Interpretability)
-# ==============================================================================
-def draw_pooling_weights(weights, sequence):
+# ==========================
+# 4. Panel D: Attention Heatmap (strip version)
+# ==========================
+def draw_attention_heatmap_strip(weights, sequence):
     """
-    Visualize Attention Pooling Weights (1D Heatmap/Bar).
+    Draws a 1D Heatmap Strip for Attention Weights.
+    Standard Bioinformatics visualization style.
     """
-    # Normalize
+    # Normalize to [0, 1]
     if weights.max() > 0:
         weights = (weights - weights.min()) / (weights.max() - weights.min())
+
+    # Prepare data (reshape to 2D for imshow: [1, Seq_Len])
+    data = weights.reshape(1, -1)
 
-    fig, ax = plt.subplots(figsize=(6, 3), dpi=120)
-    x = np.arange(len(weights))
+    fig, ax = plt.subplots(figsize=(8, 1.5), dpi=150)  # long strip
 
-    # Draw red bars
-    ax.bar(x, weights, width=1.0, color='#D32F2F', alpha=0.8, label='Attention')
+    # Draw the heatmap (Reds colormap: darker = higher attention)
+    im = ax.imshow(data, cmap='Reds', aspect='auto', vmin=0, vmax=1)
 
-    # Styling
-    ax.set_title("Learned Motif Importance (Attention Pooling)", fontsize=10, fontweight='bold', color='#37474F')
+    # Styling
+    ax.set_title("Sequence Attention Heatmap (High Color = Key Motif)", fontsize=10, fontweight='bold', color='#37474F', pad=10)
     ax.set_xlabel("Residue Position", fontsize=9)
-    ax.set_ylabel("Weight", fontsize=9)
-    ax.spines['top'].set_visible(False)
-    ax.spines['right'].set_visible(False)
-    ax.spines['left'].set_visible(False)
-    ax.set_yticks([])
 
-    # Annotate the highest peak (key motif)
-    threshold = np.percentile(weights, 98)
-    if weights.max() > threshold:
-        peak_idx = np.argmax(weights)
-        ax.annotate('Key Motif', xy=(peak_idx, weights[peak_idx]), xytext=(peak_idx, weights[peak_idx]+0.2),
-                    arrowprops=dict(facecolor='#37474F', shrink=0.05, width=1, headwidth=5),
-                    ha='center', fontsize=8, color='#37474F')
+    # Hide Y-axis ticks
+    ax.set_yticks([])
+
+    # Add a colorbar
+    cbar = plt.colorbar(im, ax=ax, orientation='vertical', fraction=0.02, pad=0.02)
+    cbar.ax.tick_params(labelsize=8)
+    cbar.outline.set_visible(False)
+
+    # Hide the spines
+    for spine in ax.spines.values():
+        spine.set_visible(False)
 
     plt.tight_layout()
     return fig
 
-# ==============================================================================
-# 5. Prediction Logic
-# ==============================================================================
+# ==========================
+# 5. Main prediction logic
+# ==========================
 def predict(sequence_input):
     if not sequence_input or sequence_input.isspace(): raise gr.Error("Empty Input")
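The new Panel D swaps the bar chart for a 1xL imshow strip. A minimal standalone sketch of the technique, with synthetic weights standing in for the model's pooling output (the guard against a constant weight vector is an extra assumption, not part of the commit):

import numpy as np
import matplotlib
matplotlib.use('Agg')  # same non-interactive backend the app forces
import matplotlib.pyplot as plt

weights = np.abs(np.random.randn(120))  # synthetic attention weights
span = weights.max() - weights.min()
if span > 0:                            # a constant vector would divide by zero
    weights = (weights - weights.min()) / span

data = weights.reshape(1, -1)           # imshow expects 2D: (1, L)
fig, ax = plt.subplots(figsize=(8, 1.5), dpi=150)
im = ax.imshow(data, cmap='Reds', aspect='auto', vmin=0, vmax=1)
ax.set_yticks([])
ax.set_xlabel("Residue Position", fontsize=9)
cbar = plt.colorbar(im, ax=ax, fraction=0.02, pad=0.02)
cbar.outline.set_visible(False)
plt.tight_layout()
fig.savefig("attention_strip.png")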
@@ -266,35 +240,33 @@ def predict(sequence_input):
 
     hidden_states = outputs.last_hidden_state
     cls_embedding = hidden_states[:, 0, :]
-    token_embeddings = hidden_states[:, 1:-1, :]  # No CLS/EOS
+    token_embeddings = hidden_states[:, 1:-1, :]
     token_mask = inputs['attention_mask'][:, 1:-1]
 
-    # ⚠️ Grab both the logits and the pooling_weights
     logits, pooling_weights = classifier(cls_embedding, token_embeddings, token_mask)
     probs = F.softmax(logits, dim=1)[0]
 
-    # 1. Results (Panel C)
+    # 1. Results
     top_label = idx_to_label[torch.max(probs, dim=0)[1].item()]
     confidences = {idx_to_label[i]: float(p) for i, p in enumerate(probs)}
 
-    # 2. Panel B: SVG
+    # 2. SVG (Panel B)
     svg = generate_bacterial_svg(top_label)
 
-    # 3. Panel D: attention plot
-    # Take the first sample's weights from the batch
+    # 3. Heatmap (Panel D)
     w_np = pooling_weights[0].cpu().numpy()
-    attn_plot = draw_pooling_weights(w_np, seq)
+    heatmap_plot = draw_attention_heatmap_strip(w_np, seq)
 
-    return confidences, svg, attn_plot
+    return confidences, svg, heatmap_plot
 
-# ==============================================================================
-# 6. UI Layout (Four-Block Paper Style)
-# ==============================================================================
+# ==========================
+# 6. UI Layout (4-Block)
+# ==========================
 layout_css = """
 @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;800&display=swap');
 body { background-color: #ffffff; font-family: 'Inter', sans-serif; }
 
-/* Header: Sky Blue Theme */
+/* Header */
 .header-div {
     background: linear-gradient(to right, #E0F7FA, #E1F5FE);
     padding: 1.5rem;
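One detail in predict() worth making explicit: ESM-2 tokenization brackets the sequence with CLS and EOS, so the token embeddings and the attention mask are both sliced with [:, 1:-1] to stay aligned on residue positions. A shape-only sketch, with fake tensors standing in for the ESM-2 forward pass:

import torch

seq_len = 7                                    # residues in the input protein
hidden = torch.randn(1, seq_len + 2, 640)      # (B, CLS + residues + EOS, D_MODEL)
attention_mask = torch.ones(1, seq_len + 2, dtype=torch.long)

cls_embedding = hidden[:, 0, :]                # global CLS vector: (1, 640)
token_embeddings = hidden[:, 1:-1, :]          # residue vectors: (1, 7, 640)
token_mask = attention_mask[:, 1:-1]           # aligned mask: (1, 7)
assert token_embeddings.shape[1] == token_mask.shape[1] == seq_len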
@@ -330,7 +302,6 @@ theme = gr.themes.Soft(primary_hue="sky").set(body_background_fill="white", bloc
 
 with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
 
-    # --- Header ---
     gr.HTML("""
     <div class="header-div">
     <div class="header-title">LocPred-Prok</div>
@@ -338,7 +309,7 @@ with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
     </div>
     """)
 
-    # --- Row 1: Panels A & B ---
+    # Row 1
     with gr.Row():
         with gr.Column(elem_classes="panel-card"):
             gr.Markdown("<div class='panel-header'><span class='panel-label'>A</span>Sequence Input</div>")
@@ -354,14 +325,14 @@ with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
             gr.Markdown("<div class='panel-header'><span class='panel-label'>B</span>Localization Visualization</div>")
             output_svg = gr.HTML(label="Visual", show_label=False)
 
-    # --- Row 2: Panels C & D ---
+    # Row 2
     with gr.Row():
         with gr.Column(elem_classes="panel-card"):
             gr.Markdown("<div class='panel-header'><span class='panel-label'>C</span>Prediction Confidence</div>")
             output_label = gr.Label(num_top_classes=NUM_CLASSES, show_label=False)
 
         with gr.Column(elem_classes="panel-card"):
-            gr.Markdown("<div class='panel-header'><span class='panel-label'>D</span>Learned Motif Importance (Attention)</div>")
+            gr.Markdown("<div class='panel-header'><span class='panel-label'>D</span>Attention Heatmap (Motif Discovery)</div>")
             output_plot = gr.Plot(label="Attention", show_label=False)
 
     submit_btn.click(fn=predict, inputs=sequence_input, outputs=[output_label, output_svg, output_plot])
 