wangleiofficial commited on
Commit
f7d2100
·
verified ·
1 Parent(s): 9539727

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +293 -181
app.py CHANGED
@@ -1,35 +1,34 @@
1
- import os, shutil, json, re
 
 
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
  import gradio as gr
6
- import matplotlib.pyplot as plt
7
- import matplotlib.patches as patches
8
- import numpy as np
9
- from io import BytesIO
10
- from PIL import Image
11
  from transformers import AutoTokenizer, AutoModel
12
 
13
  # ==========================
14
- # 0. 环境初始化
15
  # ==========================
16
- plt.switch_backend('Agg')
17
-
18
  os.environ["HF_HOME"] = "/tmp/hf_cache"
19
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
20
  os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
21
 
 
 
22
  for path in ["/tmp/hf_cache", os.path.expanduser("~/.cache/huggingface")]:
23
  shutil.rmtree(path, ignore_errors=True)
24
  os.makedirs(path, exist_ok=True)
25
 
26
  # ==========================
27
- # 1. 模型定义 (保持不变)
28
  # ==========================
29
  class AttentionPooling(nn.Module):
 
30
  def __init__(self, d_model):
31
  super().__init__()
32
  self.attention_net = nn.Linear(d_model, 1)
 
33
  def forward(self, x, mask):
34
  attn_logits = self.attention_net(x).squeeze(2)
35
  attn_logits.masked_fill_(mask == 0, -float('inf'))
@@ -37,20 +36,29 @@ class AttentionPooling(nn.Module):
37
  return torch.bmm(attn_weights.unsqueeze(1), x).squeeze(1)
38
 
39
  class ProtDualBranchEnhancedClassifier(nn.Module):
 
40
  def __init__(self, d_model, projection_dim, num_classes, dropout, kernel_size):
41
  super().__init__()
42
  self.cls_projector = nn.Linear(d_model, projection_dim)
43
  self.token_refiner = nn.Sequential(
44
- nn.Conv1d(d_model, d_model, kernel_size, padding='same'), nn.ReLU()
 
45
  )
46
  self.attention_pooling = AttentionPooling(d_model)
47
  self.tok_projector = nn.Linear(d_model, projection_dim)
48
  fused_dim = projection_dim * 2
49
- self.gate = nn.Sequential(nn.Linear(fused_dim, fused_dim), nn.Sigmoid())
 
 
 
50
  self.classifier_head = nn.Sequential(
51
- nn.LayerNorm(fused_dim), nn.Linear(fused_dim, fused_dim * 2),
52
- nn.ReLU(), nn.Dropout(dropout), nn.Linear(fused_dim * 2, num_classes)
 
 
 
53
  )
 
54
  def forward(self, cls_embedding, token_embeddings, mask):
55
  z_cls = self.cls_projector(cls_embedding)
56
  tok_emb_permuted = token_embeddings.permute(0, 2, 1)
@@ -63,17 +71,20 @@ class ProtDualBranchEnhancedClassifier(nn.Module):
63
  return self.classifier_head(z_fused_gated)
64
 
65
  # ==========================
66
- # 2. 加载模型
67
  # ==========================
68
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
69
  PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
70
  CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
71
  LABEL_MAP_PATH = "label_map.json"
72
 
73
- # 简单检查
74
- if not os.path.exists(LABEL_MAP_PATH): raise FileNotFoundError(f"Missing {LABEL_MAP_PATH}")
75
- if not os.path.exists(CLASSIFIER_PATH): raise FileNotFoundError(f"Missing {CLASSIFIER_PATH}")
 
 
76
 
 
77
  with open(LABEL_MAP_PATH, 'r') as f:
78
  label_to_idx = json.load(f)
79
  idx_to_label = {v: k for k, v in label_to_idx.items()}
@@ -81,32 +92,32 @@ with open(LABEL_MAP_PATH, 'r') as f:
81
  NUM_CLASSES = len(idx_to_label)
82
  D_MODEL = 640
83
 
84
- print("🔹 Loading models...")
85
  tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME)
86
- plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE).eval()
87
- classifier = ProtDualBranchEnhancedClassifier(D_MODEL, 32, NUM_CLASSES, 0.3, 3).to(DEVICE)
 
 
 
 
 
 
 
88
  classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
89
  classifier.eval()
90
- print("✅ Ready.")
91
 
92
  # ==========================
93
- # 3. 彻底修复的绘图引擎 (Centric Coordinates)
94
  # ==========================
95
- def draw_uniprot_style_cell(target_class):
 
 
 
 
96
  target = target_class.lower() if target_class else ""
97
 
98
- # === 配色定义 ===
99
- c = {
100
- 'stroke': '#37474F', # 默认深灰轮廓
101
- 'bg_peri': '#E1F5FE', # 默认周质背景 (浅蓝)
102
- 'bg_cyto': '#FFF9C4', # 默认胞质背景 (浅黄)
103
- 'highlight_stroke': '#D50000', # 高亮轮廓 (深红)
104
- 'highlight_fill': '#FFCDD2', # 高亮填充 (淡红)
105
- 'dna': '#B0BEC5',
106
- 'ribo': '#90A4AE'
107
- }
108
-
109
- # === 状态判断 ===
110
  is_om = "outer membrane" in target
111
  is_peri = "periplasm" in target
112
  is_cw = "cell wall" in target
@@ -114,124 +125,120 @@ def draw_uniprot_style_cell(target_class):
114
  is_cyto = "cytoplasm" in target or "cytosol" in target
115
  is_secreted = "extracellular" in target or "secreted" in target
116
 
117
- # === 画布初始化 ===
118
- # 中心点 (Cx, Cy) = (5, 3)
119
- fig, ax = plt.subplots(figsize=(8, 4.5), dpi=150)
120
- ax.set_xlim(0, 10)
121
- ax.set_ylim(0, 6)
122
- ax.axis('off')
123
-
124
- # === 核心辅助函数:绘制绝对居中的胶囊 ===
125
- def draw_centered_capsule(width, height, fill_color, edge_color, lw, z, linestyle='-'):
126
- # FancyBboxPatch 的 xy 是左下角坐标。
127
- # 要居中,左下角 x = CenterX - Width/2
128
- x = 5.0 - width / 2
129
- y = 3.0 - height / 2
130
- # rounding_size 设为高度的一半,这就变成了标准的胶囊/药丸形状
131
- r = height / 2
132
 
133
- patch = patches.FancyBboxPatch(
134
- (x, y), width, height,
135
- boxstyle=f"round,pad=0,rounding_size={r}",
136
- fc=fill_color, ec=edge_color, lw=lw, linestyle=linestyle, zorder=z
137
- )
138
- ax.add_patch(patch)
139
- return x, y, width, height # 返回坐标供后续标注使用
140
-
141
- # === 1. 绘制 Layer 1: 外膜 (Outer Membrane) ===
142
- # 如果是 Periplasm 高亮,那么底色变红;否则是默认浅蓝
143
- # 如果是 OuterMembrane 高亮,那么边框变红变粗
144
- peri_fill = c['highlight_fill'] if is_peri else c['bg_peri']
145
- om_edge = c['highlight_stroke'] if is_om else c['stroke']
146
- om_lw = 3.5 if is_om else 1.5
147
-
148
- # 绘制最大的胶囊 (代表外膜轮廓 + 周质背景)
149
- # 尺寸: 8.5 x 4.2
150
- draw_centered_capsule(8.5, 4.2, peri_fill, om_edge, om_lw, z=1)
151
-
152
- # === 2. 绘制 Layer 2: 细胞壁 (Cell Wall) ===
153
- # 位于中间层
154
- cw_edge = c['highlight_stroke'] if is_cw else '#78909C'
155
- cw_lw = 2.5 if is_cw else 1.0
156
- cw_ls = '-' if is_cw else '--' # 平时虚线,高亮实线
157
-
158
- # 尺寸: 7.5 x 3.2
159
- draw_centered_capsule(7.5, 3.2, "none", cw_edge, cw_lw, z=2, linestyle=cw_ls)
160
-
161
- # === 3. 绘制 Layer 3: 内膜 (Inner Membrane) + 胞质 (Cytoplasm) ===
162
- # 如果是 Cytoplasm 高亮,填充变红;否则默认浅黄
163
- # 如果是 InnerMembrane 高亮,边框变红变粗
164
- cyto_fill = c['highlight_fill'] if is_cyto else c['bg_cyto']
165
- im_edge = c['highlight_stroke'] if is_im else c['stroke']
166
- im_lw = 3.5 if is_im else 1.5
167
-
168
- # 尺寸: 6.5 x 2.2
169
- draw_centered_capsule(6.5, 2.2, cyto_fill, im_edge, im_lw, z=3)
170
-
171
- # === 4. 内部细节 (DNA & Ribosomes) ===
172
- # 仅装饰,画在最中心
173
- # DNA 线条
174
- t = np.linspace(0, 12, 200)
175
- x_dna = 5 + 2.2 * np.cos(t) * np.sin(t*0.5)
176
- y_dna = 3 + 0.6 * np.sin(t)
177
- ax.plot(x_dna, y_dna, color=c['dna'], lw=1.5, zorder=4, alpha=0.6)
178
-
179
- # 核糖体 (点)
180
- rng = np.random.default_rng(42)
181
- for _ in range(25):
182
- # 在中心区域随机撒点
183
- rx = rng.uniform(3.0, 7.0)
184
- ry = rng.uniform(2.3, 3.7)
185
- circle = patches.Circle((rx, ry), radius=0.05, fc=c['ribo'], zorder=4)
186
- ax.add_patch(circle)
187
-
188
- # === 5. 分泌蛋白 (Secreted) ===
189
- if is_secreted:
190
- ax.text(5, 5.5, "SECRETED / EXTRACELLULAR", ha='center', va='center',
191
- color=c['highlight_stroke'], fontweight='bold')
192
- # 画几个向上的箭头
193
- ax.arrow(5, 5.2, 0, 0.4, head_width=0.2, fc=c['highlight_stroke'], ec=c['highlight_stroke'], width=0.05)
194
-
195
- # === 6. 标注系统 (Labeling) ===
196
- # 使用 annotate 自动画箭头指引
197
-
198
- # 定义各层的指引坐标 (全部取右侧中点)
199
- # CenterY = 3.
200
- # OuterMembrane Edge X ≈ 5 + 8.5/2 = 9.25
201
- # Periplasm X ≈ 5 + 8.0/2 = 9.0 (Inside OM)
202
- # InnerMembrane Edge X ≈ 5 + 6.5/2 = 8.25
203
- # Cytoplasm X ≈ 5
204
-
205
- labels = [
206
- ("Outer Membrane", (9.25, 3.0), (10, 4.5), is_om),
207
- ("Periplasm", (8.0, 3.8), (9.5, 5.2), is_peri), # 指向胶囊上方空隙
208
- ("Cell Wall", (8.75, 3.0), (10, 3.5), is_cw), # 指向中间虚线
209
- ("Inner Membrane", (8.25, 3.0), (10, 2.5), is_im),
210
- ("Cytoplasm", (5.0, 3.0), (5.0, 1.0), is_cyto) # 指向中心,文字在下方
211
- ]
212
-
213
- for txt, xy_target, xy_text, active in labels:
214
- color = c['highlight_stroke'] if active else '#546E7A'
215
- weight = 'bold' if active else 'normal'
216
 
217
- # 如果激活,画红色实线箭头;否则画灰色细箭头
218
- arrow_props = dict(arrowstyle="->", color=color, lw=1.5 if active else 0.8)
 
 
219
 
220
- ax.annotate(txt, xy=xy_target, xytext=xy_text,
221
- arrowprops=arrow_props,
222
- fontsize=10, fontweight=weight, color=color,
223
- ha='center', va='center')
224
-
225
- # 底部标题
226
- ax.text(5, 0.2, f"Prediction: {target_class}", ha='center', va='bottom',
227
- fontsize=12, fontweight='bold', color='#263238')
228
-
229
- buf = BytesIO()
230
- plt.savefig(buf, format='png', bbox_inches='tight', transparent=True, dpi=150)
231
- buf.seek(0)
232
- img = Image.open(buf)
233
- plt.close(fig)
234
- return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
  # ==========================
237
  # 4. 预测逻辑
@@ -239,68 +246,173 @@ def draw_uniprot_style_cell(target_class):
239
  def predict(sequence_input):
240
  if not sequence_input or sequence_input.isspace():
241
  raise gr.Error("Please input a protein sequence.")
 
 
242
  seq = "".join(sequence_input.split('\n')[1:]) if sequence_input.startswith('>') else sequence_input
243
- seq = re.sub(r'[^A-Z]', '', seq.upper())[:1024]
244
- if not seq: raise gr.Error("Invalid Sequence.")
245
 
 
 
 
246
  with torch.no_grad():
247
  inputs = tokenizer(seq, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
248
  outputs = plm_model(**inputs)
249
- logits = classifier(outputs.last_hidden_state[:, 0, :], outputs.last_hidden_state[:, 1:-1, :], inputs['attention_mask'][:, 1:-1])
 
 
 
 
 
 
 
 
250
  probs = F.softmax(logits, dim=1)[0]
251
 
 
252
  top_prob, top_idx = torch.max(probs, dim=0)
253
  top_label = idx_to_label[top_idx.item()]
254
  confidences = {idx_to_label[i]: float(p) for i, p in enumerate(probs)}
255
 
256
- # 生成修正后的图
257
- cell_diagram = draw_uniprot_style_cell(top_label)
258
- return confidences, cell_diagram
 
259
 
260
  # ==========================
261
- # 5. UI 界面
262
  # ==========================
263
  paper_css = """
264
- @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;700&display=swap');
265
  body { font-family: 'Roboto', sans-serif !important; background-color: #ffffff; color: #1a1a1a; }
266
- .header-box { background: #ffffff; padding: 2rem 0; border-bottom: 1px solid #e5e7eb; margin-bottom: 2rem; }
267
- .header-title { font-size: 2.5rem; font-weight: 700; color: #000000; letter-spacing: -1px; }
268
- .badge { display: inline-block; padding: 4px 10px; font-size: 0.8rem; background: #f1f5f9; border: 1px solid #e2e8f0; border-radius: 4px; margin-right: 8px; }
269
- .content-box { background: #ffffff; border: 1px solid #e5e7eb; border-radius: 8px; padding: 1.5rem; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  """
271
 
272
- theme = gr.themes.Base(primary_hue="blue", font=[gr.themes.GoogleFont("Roboto"), "ui-sans-serif", "system-ui"]).set(
273
- body_background_fill="#ffffff", block_background_fill="#ffffff", block_border_width="1px"
 
 
 
 
 
 
274
  )
275
 
276
  with gr.Blocks(theme=theme, css=paper_css, title="LocPred-Prok") as app:
 
 
277
  with gr.Column(elem_classes="header-box"):
278
  gr.HTML("""
279
  <div class="header-title">LocPred-Prok</div>
280
- <div style="font-size: 1.2rem; color: #52525b; margin: 10px 0;">Accurate prokaryotic subcellular localization using dual-branch protein language models</div>
281
- <div><span class="badge">Article</span><span class="badge">ESM-2 Enhanced</span><span class="badge">Gram-negative</span></div>
 
 
 
 
 
 
282
  """)
283
 
 
284
  with gr.Tabs():
285
- with gr.TabItem("Prediction"):
286
  with gr.Row():
 
287
  with gr.Column(scale=4, elem_classes="content-box"):
288
- gr.Markdown("### Sequence Input")
289
- sequence_input = gr.Textbox(lines=10, show_label=False, placeholder=">Sequence...")
 
 
 
 
 
 
 
290
  with gr.Row():
291
- gr.ClearButton(sequence_input, value="Clear")
292
- submit_btn = gr.Button("Analyze", variant="primary")
293
- gr.Examples([
294
- [">Outer Membrane Protein\nAPKNTWYTGAKLGWSQYHDTGFINNNGPTHENQLGAGAFGGYQVNPYVGFEMGYDWLGRMPYKGSVENGAYKAQGVQLTAKLGYPITDDLDIYTRLGGMVWRADTKSNVYGKNHDTGVSPVFAGGVEYAITPEIATRLEYQWTNNIGDAHTIGTRPDNGMLSLGVSYRFGQGEAAPVVAPAPAPAPEVQTKHFTLKSDVLFNFNKATLKPEGQAALDQLYSQLSNLDPKDGSVVVLGYTDRIGSDAYNQGLSERRAQSVVDYLISKGIPADKISARGMGESNPVTGNTCDNVKQRAALIDCLAPDRRVEIEVKGIKDVVTQPQA"],
295
- [">Cytoplasmic Protein\nARYLGPKLKLSRREGTDLFLKSGVRAIDTKCKIEQAPGQHGARKPRLSDYGVQLREKQKVRRIYGVLERQFRNYYKEAARLKGNTGENLLALLEGRLDNVVYRMGFG"]
296
- ], inputs=sequence_input, label="Examples")
297
-
 
 
 
 
 
 
 
298
  with gr.Column(scale=6, elem_classes="content-box"):
299
- gr.Markdown("### Localization Visualization")
300
- output_image = gr.Image(label="Visualization", show_label=False, show_download_button=True, interactive=False, type="pil", height=400)
 
 
 
301
  gr.Markdown("#### Confidence Scores")
302
  output_label = gr.Label(num_top_classes=NUM_CLASSES, show_label=False)
303
 
304
- submit_btn.click(fn=predict, inputs=sequence_input, outputs=[output_label, output_image])
305
-
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  app.launch()
 
1
+ import os
2
+ import json
3
+ import re
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
  import gradio as gr
 
 
 
 
 
8
  from transformers import AutoTokenizer, AutoModel
9
 
10
  # ==========================
11
+ # 0. 环境与缓存设置
12
  # ==========================
 
 
13
  os.environ["HF_HOME"] = "/tmp/hf_cache"
14
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
15
  os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
16
 
17
+ # 清理旧缓存 (可选)
18
+ import shutil
19
  for path in ["/tmp/hf_cache", os.path.expanduser("~/.cache/huggingface")]:
20
  shutil.rmtree(path, ignore_errors=True)
21
  os.makedirs(path, exist_ok=True)
22
 
23
  # ==========================
24
+ # 1. 模型架构定义
25
  # ==========================
26
  class AttentionPooling(nn.Module):
27
+ """Attention Pooling Layer"""
28
  def __init__(self, d_model):
29
  super().__init__()
30
  self.attention_net = nn.Linear(d_model, 1)
31
+
32
  def forward(self, x, mask):
33
  attn_logits = self.attention_net(x).squeeze(2)
34
  attn_logits.masked_fill_(mask == 0, -float('inf'))
 
36
  return torch.bmm(attn_weights.unsqueeze(1), x).squeeze(1)
37
 
38
  class ProtDualBranchEnhancedClassifier(nn.Module):
39
+ """Enhanced dual-branch model architecture"""
40
  def __init__(self, d_model, projection_dim, num_classes, dropout, kernel_size):
41
  super().__init__()
42
  self.cls_projector = nn.Linear(d_model, projection_dim)
43
  self.token_refiner = nn.Sequential(
44
+ nn.Conv1d(d_model, d_model, kernel_size, padding='same'),
45
+ nn.ReLU()
46
  )
47
  self.attention_pooling = AttentionPooling(d_model)
48
  self.tok_projector = nn.Linear(d_model, projection_dim)
49
  fused_dim = projection_dim * 2
50
+ self.gate = nn.Sequential(
51
+ nn.Linear(fused_dim, fused_dim),
52
+ nn.Sigmoid()
53
+ )
54
  self.classifier_head = nn.Sequential(
55
+ nn.LayerNorm(fused_dim),
56
+ nn.Linear(fused_dim, fused_dim * 2),
57
+ nn.ReLU(),
58
+ nn.Dropout(dropout),
59
+ nn.Linear(fused_dim * 2, num_classes)
60
  )
61
+
62
  def forward(self, cls_embedding, token_embeddings, mask):
63
  z_cls = self.cls_projector(cls_embedding)
64
  tok_emb_permuted = token_embeddings.permute(0, 2, 1)
 
71
  return self.classifier_head(z_fused_gated)
72
 
73
  # ==========================
74
+ # 2. 加载模型与资源
75
  # ==========================
76
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
77
  PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
78
  CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
79
  LABEL_MAP_PATH = "label_map.json"
80
 
81
+ # 文件存在性检查
82
+ if not os.path.exists(LABEL_MAP_PATH):
83
+ raise FileNotFoundError(f"Error: Missing '{LABEL_MAP_PATH}'. Please upload it to your Space.")
84
+ if not os.path.exists(CLASSIFIER_PATH):
85
+ raise FileNotFoundError(f"Error: Missing '{CLASSIFIER_PATH}'. Please upload it to your Space.")
86
 
87
+ # 加载 Label Map
88
  with open(LABEL_MAP_PATH, 'r') as f:
89
  label_to_idx = json.load(f)
90
  idx_to_label = {v: k for k, v in label_to_idx.items()}
 
92
  NUM_CLASSES = len(idx_to_label)
93
  D_MODEL = 640
94
 
95
+ print(f"🔹 Loading ESM-2 Model ({PLM_MODEL_NAME})...")
96
  tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME)
97
+ plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE)
98
+ plm_model.eval()
99
+
100
+ print("🔹 Loading Custom Classifier...")
101
+ classifier = ProtDualBranchEnhancedClassifier(
102
+ d_model=D_MODEL, projection_dim=32, num_classes=NUM_CLASSES,
103
+ dropout=0.3, kernel_size=3
104
+ ).to(DEVICE)
105
+
106
  classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
107
  classifier.eval()
108
+ print("✅ All Models Loaded Successfully.")
109
 
110
  # ==========================
111
+ # 3. SVG 矢量绘图引擎 (完美对齐版)
112
  # ==========================
113
+ def generate_bacterial_svg(target_class):
114
+ """
115
+ Generate a high-quality SVG vector diagram for bacterial localization.
116
+ Coordinates are hardcoded to ensure perfect alignment.
117
+ """
118
  target = target_class.lower() if target_class else ""
119
 
120
+ # --- 1. 状态判断 ---
 
 
 
 
 
 
 
 
 
 
 
121
  is_om = "outer membrane" in target
122
  is_peri = "periplasm" in target
123
  is_cw = "cell wall" in target
 
125
  is_cyto = "cytoplasm" in target or "cytosol" in target
126
  is_secreted = "extracellular" in target or "secreted" in target
127
 
128
+ # --- 2. 颜色配置 (学术蓝/黄风格) ---
129
+ colors = {
130
+ # 填充色:平时浅色,激活变粉红
131
+ "om_fill": "#FFCDD2" if is_peri else "#E1F5FE",
132
+ "im_fill": "#FFCDD2" if is_cyto else "#FFF9C4",
 
 
 
 
 
 
 
 
 
 
133
 
134
+ # 边框色:平时深灰,激活变鲜红
135
+ "om_stroke": "#D32F2F" if is_om else "#37474F",
136
+ "cw_stroke": "#D32F2F" if is_cw else "#90A4AE",
137
+ "im_stroke": "#D32F2F" if is_im else "#37474F",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
+ # 线宽
140
+ "om_width": "4" if is_om else "2",
141
+ "cw_width": "3" if is_cw else "1.5",
142
+ "im_width": "4" if is_im else "2",
143
 
144
+ # 细胞壁虚线
145
+ "cw_dash": "0" if is_cw else "6,4",
146
+
147
+ # 标签颜色
148
+ "label_hl": "#D32F2F",
149
+ "label_norm": "#546E7A",
150
+ "arrow_hl": "#D32F2F",
151
+ "arrow_norm": "#90A4AE"
152
+ }
153
+
154
+ # 获取标签样式的辅助函数
155
+ def get_style(active):
156
+ if active:
157
+ return colors["label_hl"], "bold", colors["arrow_hl"], "2.5", "url(#arrowhead_hl)"
158
+ else:
159
+ return colors["label_norm"], "normal", colors["arrow_norm"], "1.0", "url(#arrowhead_norm)"
160
+
161
+ s_om = get_style(is_om)
162
+ s_peri = get_style(is_peri)
163
+ s_cw = get_style(is_cw)
164
+ s_im = get_style(is_im)
165
+ s_cyto = get_style(is_cyto)
166
+
167
+ # --- 3. 生成 SVG 字符串 ---
168
+ svg = f"""
169
+ <svg width="100%" height="100%" viewBox="0 0 800 450" xmlns="http://www.w3.org/2000/svg">
170
+ <defs>
171
+ <marker id="arrowhead_norm" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
172
+ <polygon points="0 0, 10 3.5, 0 7" fill="{colors['arrow_norm']}" />
173
+ </marker>
174
+ <marker id="arrowhead_hl" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
175
+ <polygon points="0 0, 10 3.5, 0 7" fill="{colors['arrow_hl']}" />
176
+ </marker>
177
+ </defs>
178
+
179
+ <rect width="800" height="450" fill="white" />
180
+
181
+ <g transform="translate(50, 50)">
182
+ <rect x="0" y="0" width="500" height="300" rx="150" ry="150"
183
+ fill="{colors['om_fill']}" stroke="{colors['om_stroke']}" stroke-width="{colors['om_width']}" />
184
+
185
+ <rect x="40" y="40" width="420" height="220" rx="110" ry="110"
186
+ fill="none" stroke="{colors['cw_stroke']}" stroke-width="{colors['cw_width']}" stroke-dasharray="{colors['cw_dash']}" />
187
+
188
+ <rect x="80" y="80" width="340" height="140" rx="70" ry="70"
189
+ fill="{colors['im_fill']}" stroke="{colors['im_stroke']}" stroke-width="{colors['im_width']}" />
190
+
191
+ <g opacity="0.6">
192
+ <path d="M 180 150 Q 220 100 250 150 T 320 150" fill="none" stroke="#B0BEC5" stroke-width="3" />
193
+ <path d="M 190 140 Q 230 190 250 140 T 310 160" fill="none" stroke="#B0BEC5" stroke-width="3" />
194
+ <circle cx="150" cy="120" r="3" fill="#90A4AE" />
195
+ <circle cx="350" cy="180" r="3" fill="#90A4AE" />
196
+ <circle cx="250" cy="100" r="3" fill="#90A4AE" />
197
+ <circle cx="200" cy="200" r="3" fill="#90A4AE" />
198
+ </g>
199
+ </g>
200
+
201
+ {f'''
202
+ <g transform="translate(300, 20)">
203
+ <text x="0" y="0" text-anchor="middle" fill="{colors['label_hl']}" font-weight="bold" font-family="Arial" font-size="14">SECRETED / EXTRACELLULAR</text>
204
+ <line x1="0" y1="5" x2="0" y2="25" stroke="{colors['arrow_hl']}" stroke-width="2" marker-end="url(#arrowhead_hl)" />
205
+ </g>
206
+ ''' if is_secreted else ""}
207
+
208
+ <g font-family="Arial, sans-serif">
209
+
210
+ <g transform="translate(580, 80)">
211
+ <text x="0" y="5" fill="{s_om[0]}" font-weight="{s_om[1]}" font-size="14">Outer Membrane</text>
212
+ <line x1="-10" y1="0" x2="-80" y2="0" stroke="{s_om[2]}" stroke-width="{s_om[3]}" marker-end="{s_om[4]}" />
213
+ </g>
214
+
215
+ <g transform="translate(580, 140)">
216
+ <text x="0" y="5" fill="{s_peri[0]}" font-weight="{s_peri[1]}" font-size="14">Periplasm</text>
217
+ <line x1="-10" y1="0" x2="-100" y2="0" stroke="{s_peri[2]}" stroke-width="{s_peri[3]}" marker-end="{s_peri[4]}" />
218
+ </g>
219
+
220
+ <g transform="translate(580, 200)">
221
+ <text x="0" y="5" fill="{s_cw[0]}" font-weight="{s_cw[1]}" font-size="14">Cell Wall</text>
222
+ <line x1="-10" y1="0" x2="-120" y2="0" stroke="{s_cw[2]}" stroke-width="{s_cw[3]}" marker-end="{s_cw[4]}" />
223
+ </g>
224
+
225
+ <g transform="translate(580, 260)">
226
+ <text x="0" y="5" fill="{s_im[0]}" font-weight="{s_im[1]}" font-size="14">Inner Membrane</text>
227
+ <line x1="-10" y1="0" x2="-150" y2="0" stroke="{s_im[2]}" stroke-width="{s_im[3]}" marker-end="{s_im[4]}" />
228
+ </g>
229
+
230
+ <g transform="translate(580, 320)">
231
+ <text x="0" y="5" fill="{s_cyto[0]}" font-weight="{s_cyto[1]}" font-size="14">Cytoplasm</text>
232
+ <line x1="-10" y1="0" x2="-200" y2="0" stroke="{s_cyto[2]}" stroke-width="{s_cyto[3]}" marker-end="{s_cyto[4]}" />
233
+ </g>
234
+ </g>
235
+
236
+ <text x="400" y="420" text-anchor="middle" font-family="Arial" font-size="18" font-weight="bold" fill="#37474F">
237
+ Predicted Localization: {target_class}
238
+ </text>
239
+ </svg>
240
+ """
241
+ return svg
242
 
243
  # ==========================
244
  # 4. 预测逻辑
 
246
  def predict(sequence_input):
247
  if not sequence_input or sequence_input.isspace():
248
  raise gr.Error("Please input a protein sequence.")
249
+
250
+ # 清洗输入
251
  seq = "".join(sequence_input.split('\n')[1:]) if sequence_input.startswith('>') else sequence_input
252
+ seq = re.sub(r'[^A-Z]', '', seq.upper())
 
253
 
254
+ if not seq: raise gr.Error("Invalid Amino Acid Sequence.")
255
+ if len(seq) > 1024: seq = seq[:1024] # 截断防止OOM
256
+
257
  with torch.no_grad():
258
  inputs = tokenizer(seq, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
259
  outputs = plm_model(**inputs)
260
+
261
+ # 提取特征
262
+ hidden_states = outputs.last_hidden_state
263
+ cls_embedding = hidden_states[:, 0, :]
264
+ token_embeddings = hidden_states[:, 1:-1, :]
265
+ token_mask = inputs['attention_mask'][:, 1:-1]
266
+
267
+ # 模型推理
268
+ logits = classifier(cls_embedding, token_embeddings, token_mask)
269
  probs = F.softmax(logits, dim=1)[0]
270
 
271
+ # 获取结果
272
  top_prob, top_idx = torch.max(probs, dim=0)
273
  top_label = idx_to_label[top_idx.item()]
274
  confidences = {idx_to_label[i]: float(p) for i, p in enumerate(probs)}
275
 
276
+ # 生成 SVG 可视化
277
+ svg_content = generate_bacterial_svg(top_label)
278
+
279
+ return confidences, svg_content
280
 
281
  # ==========================
282
+ # 5. UI 界面 (学术风格)
283
  # ==========================
284
  paper_css = """
285
+ @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;500;700&display=swap');
286
  body { font-family: 'Roboto', sans-serif !important; background-color: #ffffff; color: #1a1a1a; }
287
+
288
+ /* Header */
289
+ .header-box {
290
+ background: #ffffff;
291
+ padding: 2rem 0;
292
+ border-bottom: 1px solid #e5e7eb;
293
+ margin-bottom: 2rem;
294
+ }
295
+ .header-title {
296
+ font-size: 2.2rem;
297
+ font-weight: 700;
298
+ color: #0f172a;
299
+ letter-spacing: -0.5px;
300
+ }
301
+ .header-subtitle {
302
+ font-size: 1.1rem;
303
+ color: #64748b;
304
+ font-weight: 300;
305
+ margin-top: 8px;
306
+ }
307
+ .badge {
308
+ display: inline-flex;
309
+ align-items: center;
310
+ padding: 4px 12px;
311
+ font-size: 0.85rem;
312
+ font-weight: 500;
313
+ color: #0f172a;
314
+ background: #f1f5f9;
315
+ border: 1px solid #e2e8f0;
316
+ border-radius: 99px;
317
+ margin-right: 10px;
318
+ }
319
+
320
+ /* Content Box */
321
+ .content-box {
322
+ background: #ffffff;
323
+ border: 1px solid #e2e8f0;
324
+ border-radius: 8px;
325
+ padding: 1.5rem;
326
+ box-shadow: 0 1px 2px 0 rgba(0, 0, 0, 0.05);
327
+ }
328
+
329
+ /* Button */
330
+ button.primary {
331
+ background-color: #2563eb !important;
332
+ color: white !important;
333
+ border-radius: 6px !important;
334
+ font-weight: 500;
335
+ }
336
  """
337
 
338
+ theme = gr.themes.Base(
339
+ primary_hue="blue",
340
+ font=[gr.themes.GoogleFont("Roboto"), "ui-sans-serif", "system-ui"]
341
+ ).set(
342
+ body_background_fill="#ffffff",
343
+ block_background_fill="#ffffff",
344
+ block_border_width="1px",
345
+ block_label_background_fill="#ffffff"
346
  )
347
 
348
  with gr.Blocks(theme=theme, css=paper_css, title="LocPred-Prok") as app:
349
+
350
+ # --- Header ---
351
  with gr.Column(elem_classes="header-box"):
352
  gr.HTML("""
353
  <div class="header-title">LocPred-Prok</div>
354
+ <div class="header-subtitle">
355
+ Deep learning framework for prokaryotic subcellular localization using dual-branch architecture
356
+ </div>
357
+ <div style="margin-top: 15px;">
358
+ <span class="badge">Research Article</span>
359
+ <span class="badge">ESM-2 Enhanced</span>
360
+ <span class="badge">Gram-negative Bacteria</span>
361
+ </div>
362
  """)
363
 
364
+ # --- Main Content ---
365
  with gr.Tabs():
366
+ with gr.TabItem("Prediction Interface"):
367
  with gr.Row():
368
+ # Input Column
369
  with gr.Column(scale=4, elem_classes="content-box"):
370
+ gr.Markdown("### 1. Sequence Input")
371
+ gr.Markdown("<span style='color:#64748b; font-size:0.9rem'>Enter a protein sequence in FASTA format or raw amino acids.</span>")
372
+
373
+ sequence_input = gr.Textbox(
374
+ lines=12,
375
+ show_label=False,
376
+ placeholder=">Sequence_ID\nMKFKLTAGCL..."
377
+ )
378
+
379
  with gr.Row():
380
+ clear_btn = gr.ClearButton(sequence_input, value="Clear")
381
+ submit_btn = gr.Button("Run Analysis", variant="primary")
382
+
383
+ gr.Markdown("#### Test Examples")
384
+ gr.Examples(
385
+ examples=[
386
+ [">Outer Membrane Protein (OmpA)\nAPKNTWYTGAKLGWSQYHDTGFINNNGPTHENQLGAGAFGGYQVNPYVGFEMGYDWLGRMPYKGSVENGAYKAQGVQLTAKLGYPITDDLDIYTRLGGMVWRADTKSNVYGKNHDTGVSPVFAGGVEYAITPEIATRLEYQWTNNIGDAHTIGTRPDNGMLSLGVSYRFGQGEAAPVVAPAPAPAPEVQTKHFTLKSDVLFNFNKATLKPEGQAALDQLYSQLSNLDPKDGSVVVLGYTDRIGSDAYNQGLSERRAQSVVDYLISKGIPADKISARGMGESNPVTGNTCDNVKQRAALIDCLAPDRRVEIEVKGIKDVVTQPQA"],
387
+ [">Cytoplasmic Protein (Ribosomal)\nARYLGPKLKLSRREGTDLFLKSGVRAIDTKCKIEQAPGQHGARKPRLSDYGVQLREKQKVRRIYGVLERQFRNYYKEAARLKGNTGENLLALLEGRLDNVVYRMGFG"]
388
+ ],
389
+ inputs=sequence_input,
390
+ label=None
391
+ )
392
+
393
+ # Output Column
394
  with gr.Column(scale=6, elem_classes="content-box"):
395
+ gr.Markdown("### 2. Localization Results")
396
+
397
+ # 使用 HTML 组件展示 SVG
398
+ output_svg = gr.HTML(label="Visualization", show_label=False)
399
+
400
  gr.Markdown("#### Confidence Scores")
401
  output_label = gr.Label(num_top_classes=NUM_CLASSES, show_label=False)
402
 
403
+ with gr.TabItem("About & Methodology"):
404
+ gr.Markdown("""
405
+ ### Methodology
406
+ **LocPred-Prok** employs a dual-branch neural network architecture...
407
+ """)
408
+
409
+ # --- Interaction ---
410
+ submit_btn.click(
411
+ fn=predict,
412
+ inputs=sequence_input,
413
+ outputs=[output_label, output_svg]
414
+ )
415
+ clear_btn.click(lambda: [None, None], outputs=[output_label, output_svg])
416
+
417
+ # Launch
418
  app.launch()