wangleiofficial commited on
Commit
9539727
·
verified ·
1 Parent(s): 364e78b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -127
app.py CHANGED
@@ -5,14 +5,13 @@ import torch.nn.functional as F
5
  import gradio as gr
6
  import matplotlib.pyplot as plt
7
  import matplotlib.patches as patches
8
- import matplotlib.path as mpath
9
  import numpy as np
10
  from io import BytesIO
11
  from PIL import Image
12
  from transformers import AutoTokenizer, AutoModel
13
 
14
  # ==========================
15
- # 0. 环境初始化 (保持不变)
16
  # ==========================
17
  plt.switch_backend('Agg')
18
 
@@ -64,13 +63,14 @@ class ProtDualBranchEnhancedClassifier(nn.Module):
64
  return self.classifier_head(z_fused_gated)
65
 
66
  # ==========================
67
- # 2. 加载模型 (保持不变)
68
  # ==========================
69
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
70
  PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
71
  CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
72
  LABEL_MAP_PATH = "label_map.json"
73
 
 
74
  if not os.path.exists(LABEL_MAP_PATH): raise FileNotFoundError(f"Missing {LABEL_MAP_PATH}")
75
  if not os.path.exists(CLASSIFIER_PATH): raise FileNotFoundError(f"Missing {CLASSIFIER_PATH}")
76
 
@@ -90,126 +90,141 @@ classifier.eval()
90
  print("✅ Ready.")
91
 
92
  # ==========================
93
- # 3. UniProt 风格绘图引擎 (修复高亮逻辑)
94
  # ==========================
95
  def draw_uniprot_style_cell(target_class):
96
  target = target_class.lower() if target_class else ""
97
 
98
- # 配色
99
  c = {
100
- 'stroke': '#263238', 'cyto_bg': '#FFF9C4', 'peri_bg': '#E1F5FE',
101
- 'mem_line': '#546E7A', 'cw_pattern': '#90A4AE', 'dna': '#B0BEC5',
102
- 'ribo': '#CFD8DC', 'active': '#D32F2F', 'active_fill': '#FFCDD2'
 
 
 
 
103
  }
104
 
105
- # 状态判断
106
- is_secreted = "extracellular" in target or "secreted" in target
107
  is_om = "outer membrane" in target
108
  is_peri = "periplasm" in target
109
  is_cw = "cell wall" in target
110
  is_im = "plasma membrane" in target or "inner membrane" in target
111
  is_cyto = "cytoplasm" in target or "cytosol" in target
 
112
 
 
 
113
  fig, ax = plt.subplots(figsize=(8, 4.5), dpi=150)
114
  ax.set_xlim(0, 10)
115
  ax.set_ylim(0, 6)
116
  ax.axis('off')
117
 
118
- def draw_membrane_layer(x, y, w, h, z, is_active):
119
- # 修复:正确使用 is_active 判断颜色
120
- color = c['active'] if is_active else c['stroke']
121
- lw = 2.5 if is_active else 1.2
122
- p1 = patches.FancyBboxPatch((x, y), w-h, h, boxstyle=f"round,pad=0,rounding_size={h/2}",
123
- fc="none", ec=color, lw=lw, zorder=z)
124
- ax.add_patch(p1)
125
- gap = 0.15
126
- p2 = patches.FancyBboxPatch((x+gap, y+gap), w-h-2*gap, h-2*gap, boxstyle=f"round,pad=0,rounding_size={(h-2*gap)/2}",
127
- fc="none", ec=color, lw=lw, zorder=z)
128
- ax.add_patch(p2)
129
-
130
- # === 1. 背景层 ===
131
- # Periplasm Background
132
- base_capsule = patches.FancyBboxPatch((1, 1), 8, 4, boxstyle="round,pad=0,rounding_size=2",
133
- fc=c['peri_bg'], ec="none", zorder=1)
134
- ax.add_patch(base_capsule)
 
 
 
 
 
 
135
 
136
- # Cytoplasm Background
137
- cyto_capsule = patches.FancyBboxPatch((1.6, 1.6), 6.8, 2.8, boxstyle="round,pad=0,rounding_size=1.4",
138
- fc=c['cyto_bg'], ec="none", zorder=2)
139
- ax.add_patch(cyto_capsule)
140
-
141
- # === 2. 结构层 (修复高亮) ===
142
 
143
- # A. Outer Membrane (外膜) - 最外层
144
- draw_membrane_layer(1, 1, 8+2, 4, z=5, is_active=is_om)
145
-
146
- # B. Cell Wall (细胞壁)
147
- cw_color = c['active'] if is_cw else c['cw_pattern']
148
- cw_lw = 2 if is_cw else 0.5
149
- cw_patch = patches.FancyBboxPatch((1.3, 1.3), 7.4, 3.4, boxstyle="round,pad=0,rounding_size=1.7",
150
- fc="none", ec=cw_color, lw=cw_lw, linestyle=(0, (5, 5)), zorder=4)
151
- ax.add_patch(cw_patch)
152
 
153
- # Periplasm Highlight Fill
154
- if is_peri:
155
- peri_hl = patches.FancyBboxPatch((1.1, 1.1), 7.8, 3.8, boxstyle="round,pad=0,rounding_size=1.9",
156
- fc=c['active_fill'], ec="none", alpha=0.5, zorder=1.5)
157
- ax.add_patch(peri_hl)
158
-
159
- # C. Inner Membrane (内膜)
160
- draw_membrane_layer(1.6, 1.6, 6.8, 2.8, z=6, is_active=is_im)
161
-
162
- # D. Cytoplasm Highlight Fill (修复:只在胞质激活时填充)
163
- if is_cyto:
164
- cyto_hl = patches.FancyBboxPatch((1.7, 1.7), 6.6, 2.6, boxstyle="round,pad=0,rounding_size=1.3",
165
- fc=c['active_fill'], ec="none", zorder=2.1)
166
- ax.add_patch(cyto_hl)
167
-
168
- # === 3. 内容细节 ===
169
- # DNA
170
- t = np.linspace(0, 15, 300)
171
- x_dna = 5 + 2.0 * np.cos(t) * np.sin(t*0.5)
172
  y_dna = 3 + 0.6 * np.sin(t)
173
- ax.plot(x_dna, y_dna, color=c['dna'], lw=2, zorder=3, solid_capstyle='round')
174
 
175
- # Ribosomes
176
  rng = np.random.default_rng(42)
177
- for _ in range(30):
178
- rx, ry = rng.uniform(2.5, 7.5), rng.uniform(2.2, 3.8)
179
- if not (4 < rx < 6 and 2.5 < ry < 3.5):
180
- circle = patches.Circle((rx, ry), radius=0.06, fc=c['ribo'], ec="none", zorder=3)
181
- ax.add_patch(circle)
182
-
183
- # Secreted Proteins
 
184
  if is_secreted:
185
- ax.text(5, 5.5, "Secreted", ha='center', va='bottom', color=c['active'], fontweight='bold', fontsize=11)
186
- for i in range(3):
187
- sx = 4 + i
188
- ax.arrow(sx, 4.8, 0, 0.4, head_width=0.15, head_length=0.15, fc=c['active'], ec=c['active'], width=0.04, zorder=10)
189
- circle = patches.Circle((sx, 4.7), radius=0.15, fc=c['active'], ec=c['stroke'], zorder=10)
190
- ax.add_patch(circle)
191
-
192
- # === 4. 标注 (修复指示线) ===
 
 
 
 
 
 
 
193
  labels = [
194
- # (Text, y_pos, x_connect, is_active)
195
- ("Outer Membrane", 5, 1.0, is_om),
196
- ("Periplasm", 4.2, 1.2, is_peri),
197
- ("Cell Wall", 3.6, 1.3, is_cw),
198
- ("Inner Membrane", 3.0, 1.6, is_im),
199
- ("Cytoplasm", 2.0, 2.0, is_cyto)
200
  ]
201
-
202
- for txt, y_anchor, x_connect, active in labels:
203
- if active:
204
- # 激活状态:红色加粗 + 指示线连到正确位置
205
- ax.text(0.2, y_anchor, txt, ha='left', va='center', color=c['active'], fontweight='bold', fontsize=10)
206
- # 修复:指示线连接到对应的 x_connect 位置
207
- ax.plot([0.8, x_connect], [y_anchor, y_anchor], color=c['active'], lw=1.5, ls='-', alpha=0.8)
208
- else:
209
- # 非激活状态:灰色小字
210
- ax.text(0.2, y_anchor, txt, ha='left', va='center', color='#90A4AE', fontsize=8)
211
-
212
- ax.text(5, 0.2, f"Localization: {target_class}", ha='center', va='bottom', fontsize=11, fontweight='bold', color=c['stroke'])
 
 
 
 
213
 
214
  buf = BytesIO()
215
  plt.savefig(buf, format='png', bbox_inches='tight', transparent=True, dpi=150)
@@ -219,14 +234,14 @@ def draw_uniprot_style_cell(target_class):
219
  return img
220
 
221
  # ==========================
222
- # 4. 预测逻辑 (保持不变)
223
  # ==========================
224
  def predict(sequence_input):
225
  if not sequence_input or sequence_input.isspace():
226
  raise gr.Error("Please input a protein sequence.")
227
  seq = "".join(sequence_input.split('\n')[1:]) if sequence_input.startswith('>') else sequence_input
228
  seq = re.sub(r'[^A-Z]', '', seq.upper())[:1024]
229
- if not seq: raise gr.Error("Invalid Amino Acid Sequence.")
230
 
231
  with torch.no_grad():
232
  inputs = tokenizer(seq, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
@@ -238,64 +253,54 @@ def predict(sequence_input):
238
  top_label = idx_to_label[top_idx.item()]
239
  confidences = {idx_to_label[i]: float(p) for i, p in enumerate(probs)}
240
 
 
241
  cell_diagram = draw_uniprot_style_cell(top_label)
242
  return confidences, cell_diagram
243
 
244
  # ==========================
245
- # 5. UI 界面 (Paper-White Style, 保持不变)
246
  # ==========================
247
  paper_css = """
248
  @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;700&display=swap');
249
  body { font-family: 'Roboto', sans-serif !important; background-color: #ffffff; color: #1a1a1a; }
250
- .header-box { background: #ffffff; padding: 2rem 0; border-bottom: 1px solid #e5e7eb; margin-bottom: 2rem; text-align: left; }
251
- .header-title { font-size: 2.5rem; font-weight: 700; color: #000000; letter-spacing: -1px; line-height: 1.1; }
252
- .header-subtitle { font-size: 1.2rem; color: #52525b; font-weight: 300; margin-top: 10px; }
253
- .badge { display: inline-block; padding: 4px 10px; font-size: 0.8rem; font-weight: 500; color: #0f172a; background: #f1f5f9; border-radius: 4px; border: 1px solid #e2e8f0; margin-right: 8px; }
254
  .content-box { background: #ffffff; border: 1px solid #e5e7eb; border-radius: 8px; padding: 1.5rem; }
255
- button.primary { background: #2563eb !important; color: white !important; font-weight: 500; }
256
  """
257
 
258
- theme = gr.themes.Base(
259
- primary_hue="blue",
260
- font=[gr.themes.GoogleFont("Roboto"), "ui-sans-serif", "system-ui"]
261
- ).set(
262
- body_background_fill="#ffffff",
263
- block_background_fill="#ffffff",
264
- block_border_width="1px",
265
- block_label_background_fill="#ffffff"
266
  )
267
 
268
  with gr.Blocks(theme=theme, css=paper_css, title="LocPred-Prok") as app:
269
  with gr.Column(elem_classes="header-box"):
270
  gr.HTML("""
271
  <div class="header-title">LocPred-Prok</div>
272
- <div class="header-subtitle">Accurate prokaryotic subcellular localization using dual-branch protein language models</div>
273
- <div style="margin-top: 15px;"><span class="badge">Article</span><span class="badge">ESM-2 Enhanced</span><span class="badge">Gram-negative</span></div>
274
  """)
275
 
276
  with gr.Tabs():
277
- with gr.TabItem("Prediction Interface"):
278
  with gr.Row():
279
  with gr.Column(scale=4, elem_classes="content-box"):
280
- gr.Markdown("### 1. Sequence Input")
281
- gr.Markdown("<small style='color:gray'>Enter protein sequence in FASTA format.</small>")
282
  sequence_input = gr.Textbox(lines=10, show_label=False, placeholder=">Sequence...")
283
  with gr.Row():
284
- clear_btn = gr.ClearButton(sequence_input, value="Clear")
285
  submit_btn = gr.Button("Analyze", variant="primary")
286
- gr.Markdown("#### Example Data")
287
- gr.Examples(examples=[[">Outer Membrane Protein A (OmpA)\nAPKNTWYTGAKLGWSQYHDTGFINNNGPTHENQLGAGAFGGYQVNPYVGFEMGYDWLGRMPYKGSVENGAYKAQGVQLTAKLGYPITDDLDIYTRLGGMVWRADTKSNVYGKNHDTGVSPVFAGGVEYAITPEIATRLEYQWTNNIGDAHTIGTRPDNGMLSLGVSYRFGQGEAAPVVAPAPAPAPEVQTKHFTLKSDVLFNFNKATLKPEGQAALDQLYSQLSNLDPKDGSVVVLGYTDRIGSDAYNQGLSERRAQSVVDYLISKGIPADKISARGMGESNPVTGNTCDNVKQRAALIDCLAPDRRVEIEVKGIKDVVTQPQA"], [">Cytoplasmic Protein (Ribo)\nARYLGPKLKLSRREGTDLFLKSGVRAIDTKCKIEQAPGQHGARKPRLSDYGVQLREKQKVRRIYGVLERQFRNYYKEAARLKGNTGENLLALLEGRLDNVVYRMGFG"]], inputs=sequence_input, label=None)
 
 
288
 
289
  with gr.Column(scale=6, elem_classes="content-box"):
290
- gr.Markdown("### 2. Localization Analysis")
291
- output_image = gr.Image(label="Schematic Visualization", show_label=False, show_download_button=True, interactive=False, type="pil", height=350)
292
- gr.Markdown("#### Probability Scores")
293
  output_label = gr.Label(num_top_classes=NUM_CLASSES, show_label=False)
294
 
295
- with gr.TabItem("About & Citation"):
296
- gr.Markdown("### Method Description\nLocPred-Prok utilizes...")
297
-
298
  submit_btn.click(fn=predict, inputs=sequence_input, outputs=[output_label, output_image])
299
- clear_btn.click(lambda: [None, None], outputs=[output_label, output_image])
300
 
301
  app.launch()
 
5
  import gradio as gr
6
  import matplotlib.pyplot as plt
7
  import matplotlib.patches as patches
 
8
  import numpy as np
9
  from io import BytesIO
10
  from PIL import Image
11
  from transformers import AutoTokenizer, AutoModel
12
 
13
  # ==========================
14
+ # 0. 环境初始化
15
  # ==========================
16
  plt.switch_backend('Agg')
17
 
 
63
  return self.classifier_head(z_fused_gated)
64
 
65
  # ==========================
66
+ # 2. 加载模型
67
  # ==========================
68
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
69
  PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
70
  CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
71
  LABEL_MAP_PATH = "label_map.json"
72
 
73
+ # 简单检查
74
  if not os.path.exists(LABEL_MAP_PATH): raise FileNotFoundError(f"Missing {LABEL_MAP_PATH}")
75
  if not os.path.exists(CLASSIFIER_PATH): raise FileNotFoundError(f"Missing {CLASSIFIER_PATH}")
76
 
 
90
  print("✅ Ready.")
91
 
92
  # ==========================
93
+ # 3. 彻底修复的绘图引擎 (Centric Coordinates)
94
  # ==========================
95
  def draw_uniprot_style_cell(target_class):
96
  target = target_class.lower() if target_class else ""
97
 
98
+ # === 配色定义 ===
99
  c = {
100
+ 'stroke': '#37474F', # 默认深灰轮廓
101
+ 'bg_peri': '#E1F5FE', # 默认周质背景 (浅蓝)
102
+ 'bg_cyto': '#FFF9C4', # 默认胞质背景 (浅黄)
103
+ 'highlight_stroke': '#D50000', # 高亮轮廓 (深红)
104
+ 'highlight_fill': '#FFCDD2', # 高亮填充 (淡红)
105
+ 'dna': '#B0BEC5',
106
+ 'ribo': '#90A4AE'
107
  }
108
 
109
+ # === 状态判断 ===
 
110
  is_om = "outer membrane" in target
111
  is_peri = "periplasm" in target
112
  is_cw = "cell wall" in target
113
  is_im = "plasma membrane" in target or "inner membrane" in target
114
  is_cyto = "cytoplasm" in target or "cytosol" in target
115
+ is_secreted = "extracellular" in target or "secreted" in target
116
 
117
+ # === 画布初始化 ===
118
+ # 中心点 (Cx, Cy) = (5, 3)
119
  fig, ax = plt.subplots(figsize=(8, 4.5), dpi=150)
120
  ax.set_xlim(0, 10)
121
  ax.set_ylim(0, 6)
122
  ax.axis('off')
123
 
124
+ # === 核心辅助函数:绘制绝对居中的胶囊 ===
125
+ def draw_centered_capsule(width, height, fill_color, edge_color, lw, z, linestyle='-'):
126
+ # FancyBboxPatch xy 是左下角坐标。
127
+ # 要居中,左下角 x = CenterX - Width/2
128
+ x = 5.0 - width / 2
129
+ y = 3.0 - height / 2
130
+ # rounding_size 设为高度的一半,这就变成了标准的胶囊/药丸形状
131
+ r = height / 2
132
+
133
+ patch = patches.FancyBboxPatch(
134
+ (x, y), width, height,
135
+ boxstyle=f"round,pad=0,rounding_size={r}",
136
+ fc=fill_color, ec=edge_color, lw=lw, linestyle=linestyle, zorder=z
137
+ )
138
+ ax.add_patch(patch)
139
+ return x, y, width, height # 返回坐标供后续标注使用
140
+
141
+ # === 1. 绘制 Layer 1: 外膜 (Outer Membrane) ===
142
+ # 如果是 Periplasm 高亮,那么底色变红;否则是默认浅蓝
143
+ # 如果是 OuterMembrane 高亮,那么边框变红变粗
144
+ peri_fill = c['highlight_fill'] if is_peri else c['bg_peri']
145
+ om_edge = c['highlight_stroke'] if is_om else c['stroke']
146
+ om_lw = 3.5 if is_om else 1.5
147
 
148
+ # 绘制最大的���囊 (代表外膜轮廓 + 周质背景)
149
+ # 尺寸: 8.5 x 4.2
150
+ draw_centered_capsule(8.5, 4.2, peri_fill, om_edge, om_lw, z=1)
 
 
 
151
 
152
+ # === 2. 绘制 Layer 2: 细胞壁 (Cell Wall) ===
153
+ # 位于中间层
154
+ cw_edge = c['highlight_stroke'] if is_cw else '#78909C'
155
+ cw_lw = 2.5 if is_cw else 1.0
156
+ cw_ls = '-' if is_cw else '--' # 平时虚线,高亮实线
 
 
 
 
157
 
158
+ # 尺寸: 7.5 x 3.2
159
+ draw_centered_capsule(7.5, 3.2, "none", cw_edge, cw_lw, z=2, linestyle=cw_ls)
160
+
161
+ # === 3. 绘制 Layer 3: 内膜 (Inner Membrane) + 胞质 (Cytoplasm) ===
162
+ # 如果是 Cytoplasm 高亮,填充变红;否则默认浅黄
163
+ # 如果是 InnerMembrane 高亮,边框变红变粗
164
+ cyto_fill = c['highlight_fill'] if is_cyto else c['bg_cyto']
165
+ im_edge = c['highlight_stroke'] if is_im else c['stroke']
166
+ im_lw = 3.5 if is_im else 1.5
167
+
168
+ # 尺寸: 6.5 x 2.2
169
+ draw_centered_capsule(6.5, 2.2, cyto_fill, im_edge, im_lw, z=3)
170
+
171
+ # === 4. 内部细节 (DNA & Ribosomes) ===
172
+ # 仅装饰,画在最中心
173
+ # DNA 线条
174
+ t = np.linspace(0, 12, 200)
175
+ x_dna = 5 + 2.2 * np.cos(t) * np.sin(t*0.5)
 
176
  y_dna = 3 + 0.6 * np.sin(t)
177
+ ax.plot(x_dna, y_dna, color=c['dna'], lw=1.5, zorder=4, alpha=0.6)
178
 
179
+ # 核糖体 (点)
180
  rng = np.random.default_rng(42)
181
+ for _ in range(25):
182
+ # 在中心区域随机撒点
183
+ rx = rng.uniform(3.0, 7.0)
184
+ ry = rng.uniform(2.3, 3.7)
185
+ circle = patches.Circle((rx, ry), radius=0.05, fc=c['ribo'], zorder=4)
186
+ ax.add_patch(circle)
187
+
188
+ # === 5. 分泌蛋白 (Secreted) ===
189
  if is_secreted:
190
+ ax.text(5, 5.5, "SECRETED / EXTRACELLULAR", ha='center', va='center',
191
+ color=c['highlight_stroke'], fontweight='bold')
192
+ # 画几个向上的箭头
193
+ ax.arrow(5, 5.2, 0, 0.4, head_width=0.2, fc=c['highlight_stroke'], ec=c['highlight_stroke'], width=0.05)
194
+
195
+ # === 6. 标注系统 (Labeling) ===
196
+ # 使用 annotate 自动画箭头指引
197
+
198
+ # 定义各层的指引坐标 (全部取右侧中点)
199
+ # CenterY = 3.
200
+ # OuterMembrane Edge X ≈ 5 + 8.5/2 = 9.25
201
+ # Periplasm X ≈ 5 + 8.0/2 = 9.0 (Inside OM)
202
+ # InnerMembrane Edge X ≈ 5 + 6.5/2 = 8.25
203
+ # Cytoplasm X ≈ 5
204
+
205
  labels = [
206
+ ("Outer Membrane", (9.25, 3.0), (10, 4.5), is_om),
207
+ ("Periplasm", (8.0, 3.8), (9.5, 5.2), is_peri), # 指向胶囊上方空隙
208
+ ("Cell Wall", (8.75, 3.0), (10, 3.5), is_cw), # 指向中间虚线
209
+ ("Inner Membrane", (8.25, 3.0), (10, 2.5), is_im),
210
+ ("Cytoplasm", (5.0, 3.0), (5.0, 1.0), is_cyto) # 指向中心,文字在下方
 
211
  ]
212
+
213
+ for txt, xy_target, xy_text, active in labels:
214
+ color = c['highlight_stroke'] if active else '#546E7A'
215
+ weight = 'bold' if active else 'normal'
216
+
217
+ # 如果激活,画红色实线箭头;否则画灰色细箭头
218
+ arrow_props = dict(arrowstyle="->", color=color, lw=1.5 if active else 0.8)
219
+
220
+ ax.annotate(txt, xy=xy_target, xytext=xy_text,
221
+ arrowprops=arrow_props,
222
+ fontsize=10, fontweight=weight, color=color,
223
+ ha='center', va='center')
224
+
225
+ # 底部标题
226
+ ax.text(5, 0.2, f"Prediction: {target_class}", ha='center', va='bottom',
227
+ fontsize=12, fontweight='bold', color='#263238')
228
 
229
  buf = BytesIO()
230
  plt.savefig(buf, format='png', bbox_inches='tight', transparent=True, dpi=150)
 
234
  return img
235
 
236
  # ==========================
237
+ # 4. 预测逻辑
238
  # ==========================
239
  def predict(sequence_input):
240
  if not sequence_input or sequence_input.isspace():
241
  raise gr.Error("Please input a protein sequence.")
242
  seq = "".join(sequence_input.split('\n')[1:]) if sequence_input.startswith('>') else sequence_input
243
  seq = re.sub(r'[^A-Z]', '', seq.upper())[:1024]
244
+ if not seq: raise gr.Error("Invalid Sequence.")
245
 
246
  with torch.no_grad():
247
  inputs = tokenizer(seq, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
 
253
  top_label = idx_to_label[top_idx.item()]
254
  confidences = {idx_to_label[i]: float(p) for i, p in enumerate(probs)}
255
 
256
+ # 生成修正后的图
257
  cell_diagram = draw_uniprot_style_cell(top_label)
258
  return confidences, cell_diagram
259
 
260
  # ==========================
261
+ # 5. UI 界面
262
  # ==========================
263
  paper_css = """
264
  @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;700&display=swap');
265
  body { font-family: 'Roboto', sans-serif !important; background-color: #ffffff; color: #1a1a1a; }
266
+ .header-box { background: #ffffff; padding: 2rem 0; border-bottom: 1px solid #e5e7eb; margin-bottom: 2rem; }
267
+ .header-title { font-size: 2.5rem; font-weight: 700; color: #000000; letter-spacing: -1px; }
268
+ .badge { display: inline-block; padding: 4px 10px; font-size: 0.8rem; background: #f1f5f9; border: 1px solid #e2e8f0; border-radius: 4px; margin-right: 8px; }
 
269
  .content-box { background: #ffffff; border: 1px solid #e5e7eb; border-radius: 8px; padding: 1.5rem; }
 
270
  """
271
 
272
+ theme = gr.themes.Base(primary_hue="blue", font=[gr.themes.GoogleFont("Roboto"), "ui-sans-serif", "system-ui"]).set(
273
+ body_background_fill="#ffffff", block_background_fill="#ffffff", block_border_width="1px"
 
 
 
 
 
 
274
  )
275
 
276
  with gr.Blocks(theme=theme, css=paper_css, title="LocPred-Prok") as app:
277
  with gr.Column(elem_classes="header-box"):
278
  gr.HTML("""
279
  <div class="header-title">LocPred-Prok</div>
280
+ <div style="font-size: 1.2rem; color: #52525b; margin: 10px 0;">Accurate prokaryotic subcellular localization using dual-branch protein language models</div>
281
+ <div><span class="badge">Article</span><span class="badge">ESM-2 Enhanced</span><span class="badge">Gram-negative</span></div>
282
  """)
283
 
284
  with gr.Tabs():
285
+ with gr.TabItem("Prediction"):
286
  with gr.Row():
287
  with gr.Column(scale=4, elem_classes="content-box"):
288
+ gr.Markdown("### Sequence Input")
 
289
  sequence_input = gr.Textbox(lines=10, show_label=False, placeholder=">Sequence...")
290
  with gr.Row():
291
+ gr.ClearButton(sequence_input, value="Clear")
292
  submit_btn = gr.Button("Analyze", variant="primary")
293
+ gr.Examples([
294
+ [">Outer Membrane Protein\nAPKNTWYTGAKLGWSQYHDTGFINNNGPTHENQLGAGAFGGYQVNPYVGFEMGYDWLGRMPYKGSVENGAYKAQGVQLTAKLGYPITDDLDIYTRLGGMVWRADTKSNVYGKNHDTGVSPVFAGGVEYAITPEIATRLEYQWTNNIGDAHTIGTRPDNGMLSLGVSYRFGQGEAAPVVAPAPAPAPEVQTKHFTLKSDVLFNFNKATLKPEGQAALDQLYSQLSNLDPKDGSVVVLGYTDRIGSDAYNQGLSERRAQSVVDYLISKGIPADKISARGMGESNPVTGNTCDNVKQRAALIDCLAPDRRVEIEVKGIKDVVTQPQA"],
295
+ [">Cytoplasmic Protein\nARYLGPKLKLSRREGTDLFLKSGVRAIDTKCKIEQAPGQHGARKPRLSDYGVQLREKQKVRRIYGVLERQFRNYYKEAARLKGNTGENLLALLEGRLDNVVYRMGFG"]
296
+ ], inputs=sequence_input, label="Examples")
297
 
298
  with gr.Column(scale=6, elem_classes="content-box"):
299
+ gr.Markdown("### Localization Visualization")
300
+ output_image = gr.Image(label="Visualization", show_label=False, show_download_button=True, interactive=False, type="pil", height=400)
301
+ gr.Markdown("#### Confidence Scores")
302
  output_label = gr.Label(num_top_classes=NUM_CLASSES, show_label=False)
303
 
 
 
 
304
  submit_btn.click(fn=predict, inputs=sequence_input, outputs=[output_label, output_image])
 
305
 
306
  app.launch()