wangleiofficial committed on
Commit
80c9b83
·
verified ·
1 Parent(s): 0c3ed7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -36
app.py CHANGED
@@ -9,9 +9,9 @@ import matplotlib.pyplot as plt
9
  import numpy as np
10
  from transformers import AutoTokenizer, AutoModel
11
 
12
- # ==========================
13
- # 0. 环境与缓存设置
14
- # ==========================
15
  # 强制使用非交互式后端,防止 matplotlib 在服务器报错
16
  plt.switch_backend('Agg')
17
 
@@ -24,9 +24,9 @@ for path in ["/tmp/hf_cache", os.path.expanduser("~/.cache/huggingface")]:
24
  shutil.rmtree(path, ignore_errors=True)
25
  os.makedirs(path, exist_ok=True)
26
 
27
- # ==========================
28
- # 1. 模型架构定义 (支持 Attention 输出)
29
- # ==========================
30
  class AttentionPooling(nn.Module):
31
  def __init__(self, d_model):
32
  super().__init__()
@@ -81,9 +81,9 @@ class ProtDualBranchEnhancedClassifier(nn.Module):
81
 
82
  return self.classifier_head(z_fused_gated), pooling_weights
83
 
84
- # ==========================
85
- # 2. 加载模型与配置
86
- # ==========================
87
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
88
  PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
89
  CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
@@ -105,14 +105,13 @@ tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME)
105
  plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE).eval()
106
 
107
  classifier = ProtDualBranchEnhancedClassifier(D_MODEL, 32, NUM_CLASSES, 0.3, 3).to(DEVICE)
108
- # strict=False 允许加载即使权重文件中没有 pooling_weights 相关的特定状态(通常不影响)
109
  classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
110
  classifier.eval()
111
  print("✅ Ready.")
112
 
113
- # ==========================
114
- # 3. Panel B: SVG 绘图引擎 (贝塞尔曲线 + 锚点)
115
- # ==========================
116
  def generate_bacterial_svg(target_class):
117
  target = target_class.lower() if target_class else ""
118
 
@@ -133,28 +132,29 @@ def generate_bacterial_svg(target_class):
133
  "bg_text": "#78909C", "bg_line": "#CFD8DC", "bg_dot": "#B0BEC5"
134
  }
135
 
136
- # 样式生成器
137
  def style(active, base_fill, base_stroke, w_act="4", w_norm="2"):
138
- if active: return c["hl_fill"], c["hl_stroke"], w_act
139
- return base_fill, base_stroke, width_norm
 
 
140
 
141
  om_f, om_s, om_w = style(is_peri, c["bg_fill_om"], c["hl_stroke"] if is_om else c["bg_stroke"])
142
  cw_s = c["hl_stroke"] if is_cw else "#B0BEC5"
143
  cw_w, cw_d = ("3", "0") if is_cw else ("1.5", "6,4")
144
  im_f, im_s, im_w = style(is_cyto, c["bg_fill_im"], c["hl_stroke"] if is_im else c["bg_stroke"])
145
 
146
- # 标签样式 (文字颜色, 字重, 线条颜色, 线宽, 锚点颜色, 锚点半径)
147
  def label_style(active):
148
  if active: return c["hl_text"], "bold", c["hl_stroke"], "2.5", c["hl_dot"], "5"
149
  return c["bg_text"], "normal", c["bg_line"], "1.5", c["bg_dot"], "3"
150
 
151
  l_om, l_peri, l_cw, l_im, l_cyto = label_style(is_om), label_style(is_peri), label_style(is_cw), label_style(is_im), label_style(is_cyto)
152
 
153
- # 3. 坐标定义
154
  bx, by = 280, 210 # 细菌中心
155
  tx = 600 # 标签文字起始 X 坐标
156
 
157
- # 目标锚点 (Target Anchor Points) - 精确落在结构上
158
  targets = {
159
  "om": (bx + 140, by - 120), # 外膜线
160
  "peri": (bx + 120, by - 90), # 周质间隙
@@ -165,13 +165,13 @@ def generate_bacterial_svg(target_class):
165
 
166
  text_y = {"om": 90, "peri": 150, "cw": 210, "im": 270, "cyto": 330}
167
 
168
- # 4. 贝塞尔曲线连接器
169
  def draw_connector(key, style_tuple, label_text):
170
  txt_col, weight, line_col, width, dot_col, r = style_tuple
171
  tx_pos, ty_pos = tx, text_y[key]
172
  ex, ey = targets[key]
173
 
174
- # 贝塞尔控制点:形成 S 形曲线
175
  c1x, c1y = tx_pos - 100, ty_pos
176
  c2x, c2y = ex + 50, ey
177
 
@@ -213,9 +213,9 @@ def generate_bacterial_svg(target_class):
213
  </svg>"""
214
  return svg
215
 
216
- # ==========================
217
- # 4. Panel D: Attention 绘图引擎
218
- # ==========================
219
  def draw_pooling_weights(weights, sequence):
220
  """
221
  Visualize Attention Pooling Weights (1D Heatmap/Bar).
@@ -239,8 +239,8 @@ def draw_pooling_weights(weights, sequence):
239
  ax.spines['left'].set_visible(False)
240
  ax.set_yticks([])
241
 
242
- # 标注最高峰 (Potential Motif)
243
- threshold = np.percentile(weights, 98) # 更加严格的阈值
244
  if weights.max() > threshold:
245
  peak_idx = np.argmax(weights)
246
  ax.annotate('Key Motif', xy=(peak_idx, weights[peak_idx]), xytext=(peak_idx, weights[peak_idx]+0.2),
@@ -250,9 +250,9 @@ def draw_pooling_weights(weights, sequence):
250
  plt.tight_layout()
251
  return fig
252
 
253
- # ==========================
254
- # 5. 预测主逻辑
255
- # ==========================
256
  def predict(sequence_input):
257
  if not sequence_input or sequence_input.isspace(): raise gr.Error("Empty Input")
258
 
@@ -269,11 +269,11 @@ def predict(sequence_input):
269
  token_embeddings = hidden_states[:, 1:-1, :] # No CLS/EOS
270
  token_mask = inputs['attention_mask'][:, 1:-1]
271
 
272
- # ⚠️ 获取 logits 和 weights
273
  logits, pooling_weights = classifier(cls_embedding, token_embeddings, token_mask)
274
  probs = F.softmax(logits, dim=1)[0]
275
 
276
- # 1. 结果
277
  top_label = idx_to_label[torch.max(probs, dim=0)[1].item()]
278
  confidences = {idx_to_label[i]: float(p) for i, p in enumerate(probs)}
279
 
@@ -287,9 +287,9 @@ def predict(sequence_input):
287
 
288
  return confidences, svg, attn_plot
289
 
290
- # ==========================
291
- # 6. UI Layout (4-Block Paper Style)
292
- # ==========================
293
  layout_css = """
294
  @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;800&display=swap');
295
  body { background-color: #ffffff; font-family: 'Inter', sans-serif; }
@@ -330,6 +330,7 @@ theme = gr.themes.Soft(primary_hue="sky").set(body_background_fill="white", bloc
330
 
331
  with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
332
 
 
333
  gr.HTML("""
334
  <div class="header-div">
335
  <div class="header-title">LocPred-Prok</div>
@@ -337,7 +338,7 @@ with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
337
  </div>
338
  """)
339
 
340
- # Row 1: A & B
341
  with gr.Row():
342
  with gr.Column(elem_classes="panel-card"):
343
  gr.Markdown("<div class='panel-header'><span class='panel-label'>A</span>Sequence Input</div>")
@@ -353,7 +354,7 @@ with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
353
  gr.Markdown("<div class='panel-header'><span class='panel-label'>B</span>Localization Visualization</div>")
354
  output_svg = gr.HTML(label="Visual", show_label=False)
355
 
356
- # Row 2: C & D
357
  with gr.Row():
358
  with gr.Column(elem_classes="panel-card"):
359
  gr.Markdown("<div class='panel-header'><span class='panel-label'>C</span>Prediction Confidence</div>")
 
9
  import numpy as np
10
  from transformers import AutoTokenizer, AutoModel
11
 
12
+ # ==============================================================================
13
+ # 0. 环境与缓存设置 (Environment Setup)
14
+ # ==============================================================================
15
  # 强制使用非交互式后端,防止 matplotlib 在服务器报错
16
  plt.switch_backend('Agg')
17
 
 
24
  shutil.rmtree(path, ignore_errors=True)
25
  os.makedirs(path, exist_ok=True)
26
 
27
+ # ==============================================================================
28
+ # 1. 模型架构定义 (Model Architecture)
29
+ # ==============================================================================
30
  class AttentionPooling(nn.Module):
31
  def __init__(self, d_model):
32
  super().__init__()
 
81
 
82
  return self.classifier_head(z_fused_gated), pooling_weights
83
 
84
+ # ==============================================================================
85
+ # 2. 加载模型与配置 (Load Resources)
86
+ # ==============================================================================
87
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
88
  PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
89
  CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
 
105
  plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE).eval()
106
 
107
  classifier = ProtDualBranchEnhancedClassifier(D_MODEL, 32, NUM_CLASSES, 0.3, 3).to(DEVICE)
 
108
  classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
109
  classifier.eval()
110
  print("✅ Ready.")
111
 
112
+ # ==============================================================================
113
+ # 3. Panel B: SVG 绘图引擎 (Visualization Engine)
114
+ # ==============================================================================
115
  def generate_bacterial_svg(target_class):
116
  target = target_class.lower() if target_class else ""
117
 
 
132
  "bg_text": "#78909C", "bg_line": "#CFD8DC", "bg_dot": "#B0BEC5"
133
  }
134
 
135
+ # 3. 样式生成器 (这里修复了之前的 bug)
136
  def style(active, base_fill, base_stroke, w_act="4", w_norm="2"):
137
+ if active:
138
+ return c["hl_fill"], c["hl_stroke"], w_act
139
+ # ✅ 修复点:这里原来写成了 width_norm,现已修正为 w_norm
140
+ return base_fill, base_stroke, w_norm
141
 
142
  om_f, om_s, om_w = style(is_peri, c["bg_fill_om"], c["hl_stroke"] if is_om else c["bg_stroke"])
143
  cw_s = c["hl_stroke"] if is_cw else "#B0BEC5"
144
  cw_w, cw_d = ("3", "0") if is_cw else ("1.5", "6,4")
145
  im_f, im_s, im_w = style(is_cyto, c["bg_fill_im"], c["hl_stroke"] if is_im else c["bg_stroke"])
146
 
147
+ # 标签样式
148
  def label_style(active):
149
  if active: return c["hl_text"], "bold", c["hl_stroke"], "2.5", c["hl_dot"], "5"
150
  return c["bg_text"], "normal", c["bg_line"], "1.5", c["bg_dot"], "3"
151
 
152
  l_om, l_peri, l_cw, l_im, l_cyto = label_style(is_om), label_style(is_peri), label_style(is_cw), label_style(is_im), label_style(is_cyto)
153
 
154
+ # 4. 坐标定义
155
  bx, by = 280, 210 # 细菌中心
156
  tx = 600 # 标签文字起始 X 坐标
157
 
 
158
  targets = {
159
  "om": (bx + 140, by - 120), # 外膜线
160
  "peri": (bx + 120, by - 90), # 周质间隙
 
165
 
166
  text_y = {"om": 90, "peri": 150, "cw": 210, "im": 270, "cyto": 330}
167
 
168
+ # 5. 贝塞尔曲线连接器
169
  def draw_connector(key, style_tuple, label_text):
170
  txt_col, weight, line_col, width, dot_col, r = style_tuple
171
  tx_pos, ty_pos = tx, text_y[key]
172
  ex, ey = targets[key]
173
 
174
+ # 贝塞尔控制点
175
  c1x, c1y = tx_pos - 100, ty_pos
176
  c2x, c2y = ex + 50, ey
177
 
 
213
  </svg>"""
214
  return svg
215
 
216
+ # ==============================================================================
217
+ # 4. Panel D: Attention 绘图引擎 (Interpretability)
218
+ # ==============================================================================
219
  def draw_pooling_weights(weights, sequence):
220
  """
221
  Visualize Attention Pooling Weights (1D Heatmap/Bar).
 
239
  ax.spines['left'].set_visible(False)
240
  ax.set_yticks([])
241
 
242
+ # 标注最高峰 (Key Motif)
243
+ threshold = np.percentile(weights, 98)
244
  if weights.max() > threshold:
245
  peak_idx = np.argmax(weights)
246
  ax.annotate('Key Motif', xy=(peak_idx, weights[peak_idx]), xytext=(peak_idx, weights[peak_idx]+0.2),
 
250
  plt.tight_layout()
251
  return fig
252
 
253
+ # ==============================================================================
254
+ # 5. 预测主逻辑 (Prediction Logic)
255
+ # ==============================================================================
256
  def predict(sequence_input):
257
  if not sequence_input or sequence_input.isspace(): raise gr.Error("Empty Input")
258
 
 
269
  token_embeddings = hidden_states[:, 1:-1, :] # No CLS/EOS
270
  token_mask = inputs['attention_mask'][:, 1:-1]
271
 
272
+ # ⚠️ 获取 logits 和 pooling_weights
273
  logits, pooling_weights = classifier(cls_embedding, token_embeddings, token_mask)
274
  probs = F.softmax(logits, dim=1)[0]
275
 
276
+ # 1. 结果 (Panel C)
277
  top_label = idx_to_label[torch.max(probs, dim=0)[1].item()]
278
  confidences = {idx_to_label[i]: float(p) for i, p in enumerate(probs)}
279
 
 
287
 
288
  return confidences, svg, attn_plot
289
 
290
+ # ==============================================================================
291
+ # 6. UI 布局 (Four-Block Paper Style)
292
+ # ==============================================================================
293
  layout_css = """
294
  @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;800&display=swap');
295
  body { background-color: #ffffff; font-family: 'Inter', sans-serif; }
 
330
 
331
  with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
332
 
333
+ # --- Header ---
334
  gr.HTML("""
335
  <div class="header-div">
336
  <div class="header-title">LocPred-Prok</div>
 
338
  </div>
339
  """)
340
 
341
+ # --- Row 1: Panels A & B ---
342
  with gr.Row():
343
  with gr.Column(elem_classes="panel-card"):
344
  gr.Markdown("<div class='panel-header'><span class='panel-label'>A</span>Sequence Input</div>")
 
354
  gr.Markdown("<div class='panel-header'><span class='panel-label'>B</span>Localization Visualization</div>")
355
  output_svg = gr.HTML(label="Visual", show_label=False)
356
 
357
+ # --- Row 2: Panels C & D ---
358
  with gr.Row():
359
  with gr.Column(elem_classes="panel-card"):
360
  gr.Markdown("<div class='panel-header'><span class='panel-label'>C</span>Prediction Confidence</div>")