wangleiofficial committed on
Commit
886c88b
·
verified ·
1 Parent(s): b3298fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -222
app.py CHANGED
@@ -6,252 +6,192 @@ import gradio as gr
6
  from transformers import AutoTokenizer, AutoModel
7
 
8
  # ==========================
9
- # 🚧 0. 防止 Hugging Face 缓存溢出 (保持不变)
10
  # ==========================
11
- os.environ["HF_HOME"] = "/tmp/hf_cache"
12
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
13
- os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
14
-
15
- for path in ["/tmp/hf_cache", os.path.expanduser("~/.cache/huggingface")]:
16
- shutil.rmtree(path, ignore_errors=True)
17
- os.makedirs(path, exist_ok=True)
18
-
19
- # ==========================
20
- # 1. Model Definition (保持不变)
21
- # ==========================
22
- class AttentionPooling(nn.Module):
23
- def __init__(self, d_model):
24
- super().__init__()
25
- self.attention_net = nn.Linear(d_model, 1)
26
-
27
- def forward(self, x, mask):
28
- attn_logits = self.attention_net(x).squeeze(2)
29
- attn_logits.masked_fill_(mask == 0, -float('inf'))
30
- attn_weights = F.softmax(attn_logits, dim=1)
31
- return torch.bmm(attn_weights.unsqueeze(1), x).squeeze(1)
32
-
33
- class ProtDualBranchEnhancedClassifier(nn.Module):
34
- def __init__(self, d_model, projection_dim, num_classes, dropout, kernel_size):
35
- super().__init__()
36
- self.cls_projector = nn.Linear(d_model, projection_dim)
37
- self.token_refiner = nn.Sequential(
38
- nn.Conv1d(d_model, d_model, kernel_size, padding='same'),
39
- nn.ReLU()
40
- )
41
- self.attention_pooling = AttentionPooling(d_model)
42
- self.tok_projector = nn.Linear(d_model, projection_dim)
43
- fused_dim = projection_dim * 2
44
- self.gate = nn.Sequential(
45
- nn.Linear(fused_dim, fused_dim),
46
- nn.Sigmoid()
47
- )
48
- self.classifier_head = nn.Sequential(
49
- nn.LayerNorm(fused_dim),
50
- nn.Linear(fused_dim, fused_dim * 2),
51
- nn.ReLU(),
52
- nn.Dropout(dropout),
53
- nn.Linear(fused_dim * 2, num_classes)
54
- )
55
-
56
- def forward(self, cls_embedding, token_embeddings, mask):
57
- z_cls = self.cls_projector(cls_embedding)
58
- tok_emb_permuted = token_embeddings.permute(0, 2, 1)
59
- refined_tok_emb = self.token_refiner(tok_emb_permuted).permute(0, 2, 1)
60
- z_tok_pooled = self.attention_pooling(refined_tok_emb, mask)
61
- z_tok = self.tok_projector(z_tok_pooled)
62
- z_fused_concat = torch.cat([z_cls, z_tok], dim=1)
63
- gate_values = self.gate(z_fused_concat)
64
- z_fused_gated = z_fused_concat * gate_values
65
- return self.classifier_head(z_fused_gated)
66
-
67
- # ==========================
68
- # 2. Load Models and Files (保持不变)
69
- # ==========================
70
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71
- PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
72
- CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
73
- LABEL_MAP_PATH = "label_map.json"
74
-
75
- if not os.path.exists(LABEL_MAP_PATH):
76
- raise FileNotFoundError(f"Error: Missing '{LABEL_MAP_PATH}'.")
77
- with open(LABEL_MAP_PATH, 'r') as f:
78
- label_to_idx = json.load(f)
79
- idx_to_label = {v: k for k, v in label_to_idx.items()}
80
-
81
- NUM_CLASSES = len(idx_to_label)
82
- D_MODEL = 640
83
-
84
- print("🔹 Loading Protein Language Model...")
85
- tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME)
86
- plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE)
87
- plm_model.eval()
88
- print("✅ PLM loaded.")
89
-
90
- print("🔹 Loading classifier...")
91
- classifier = ProtDualBranchEnhancedClassifier(
92
- d_model=D_MODEL, projection_dim=32, num_classes=NUM_CLASSES,
93
- dropout=0.3, kernel_size=3
94
- ).to(DEVICE)
95
-
96
- if not os.path.exists(CLASSIFIER_PATH):
97
- raise FileNotFoundError(f"Error: Could not find '{CLASSIFIER_PATH}'.")
98
-
99
- classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
100
- classifier.eval()
101
- print("✅ System Ready.")
102
-
103
- # ==========================
104
- # 3. Prediction Function (微调)
105
- # ==========================
106
- def predict(sequence_input):
107
- if not sequence_input or sequence_input.isspace():
108
- # Raise gr.Error instead of returning a dict, so the Label component stays clean
109
- raise gr.Error("Sequence cannot be empty.")
110
-
111
- sequence = "".join(sequence_input.split('\n')[1:]) if sequence_input.startswith('>') else sequence_input
112
- sequence = re.sub(r'[^A-Z]', '', sequence.upper())
113
-
114
- if not sequence:
115
- raise gr.Error("Invalid sequence format.")
116
-
117
- with torch.no_grad():
118
- inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
119
- outputs = plm_model(**inputs)
120
- hidden_states = outputs.last_hidden_state
121
- cls_embedding = hidden_states[:, 0, :]
122
- token_embeddings = hidden_states[:, 1:-1, :]
123
- token_mask = inputs['attention_mask'][:, 1:-1]
124
-
125
- logits = classifier(cls_embedding, token_embeddings, token_mask)
126
- probabilities = F.softmax(logits, dim=1)[0]
127
-
128
- confidences = {idx_to_label[i]: float(prob) for i, prob in enumerate(probabilities)}
129
- return confidences
130
 
131
  # ==========================
132
- # 4. Modernized Gradio Interface
133
  # ==========================
134
 
135
- # 自定义 CSS:增加渐变标题、阴影、圆角
136
- custom_css = """
137
- .gradio-container {
138
- font-family: 'IBM Plex Sans', sans-serif;
139
- }
140
- .main-header {
141
- text-align: center;
142
- background: linear-gradient(135deg, #3b82f6 0%, #06b6d4 100%);
143
- color: white;
144
- padding: 2rem;
145
- border-radius: 12px;
146
- margin-bottom: 1.5rem;
147
- box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06);
148
- }
149
- .main-header h1 {
150
  color: white;
151
- margin-bottom: 0.5rem;
152
- font-size: 2.2rem;
 
 
153
  }
154
- .main-header p {
155
- color: #e0f2fe;
156
- font-size: 1.1rem;
 
 
 
 
 
 
157
  }
158
- .input-card, .output-card {
159
- border: 1px solid #e5e7eb;
160
- border-radius: 12px;
161
- padding: 1.5rem;
162
- background: white;
163
- box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.1);
164
  }
 
 
 
 
165
  """
166
 
167
- # 使用更清爽的 Teal (青色) 主题,符合生物信息学特征
168
- theme = gr.themes.Soft(
169
- primary_hue="teal",
170
- secondary_hue="blue",
171
  neutral_hue="slate",
172
- font=[gr.themes.GoogleFont("IBM Plex Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
173
- ).set(
174
- button_primary_background_fill="*primary_600",
175
- button_primary_background_fill_hover="*primary_700",
176
- block_shadow="*shadow_drop_lg"
177
  )
178
 
179
- with gr.Blocks(theme=theme, css=custom_css, title="LocPred-Prok") as app:
180
 
181
- # --- 顶部 Header ---
182
- with gr.Column(elem_classes="main-header"):
183
- gr.Markdown(
184
- """
185
- # 🧬 Prokaryotic Subcellular Localization
186
- ### Dual-Branch Architecture with Protein Language Models
187
- Identify where your protein functions using State-of-the-Art Deep Learning.
188
- """
189
- )
190
-
191
- # --- 主体内容 ---
192
- with gr.Row(equal_height=False):
 
 
 
 
193
 
194
- # 左侧:输入区
195
- with gr.Column(scale=5, elem_classes="input-card"):
196
- gr.Markdown("### 📥 Input Sequence")
197
- gr.Markdown("Paste your amino acid sequence (FASTA format supported).")
198
-
199
- sequence_input = gr.Textbox(
200
- lines=8,
201
- label="",
202
- placeholder=">Example Header\nMKFKLTAGCLAVAGVLLASSFGADAEIVV...",
203
- show_label=False
204
- )
205
-
206
  with gr.Row():
207
- clear_btn = gr.ClearButton(components=[sequence_input], value="Clear")
208
- submit_btn = gr.Button("✨ Run Prediction", variant="primary", scale=2)
209
-
210
- gr.Markdown("#### 💡 Example Sequences")
211
- gr.Examples(
212
- examples=[
213
- [">sp|P27361|PBP2_ECOLI Penicillin-binding protein 2\nMKFKLTAGCLAVAGVLLASSFGADAEIVVNAIYDQVARTEDGVYTQGQLTGRRIELLNKLGIEPEDSLASTVIHEFVARVGDDHGIETIIDEFYRQHPSASL"],
214
- ["MSKLVKTLTISEISKAQNNGGKPAWCWYTLAMCGAGYDSGTCDYMYSHCFGIKHHSSGSSSYHC"],
215
- ],
216
- inputs=sequence_input,
217
- label=None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  )
219
-
220
- # 右侧:输出区
221
- with gr.Column(scale=4, elem_classes="output-card"):
222
- gr.Markdown("### 📊 Prediction Results")
223
 
224
- output_label = gr.Label(
225
- num_top_classes=NUM_CLASSES,
226
- label="Probability Distribution",
227
- show_label=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  )
229
 
230
- # 信息折叠面板
231
- with gr.Accordion("📘 Model Architecture & Details", open=False):
232
- gr.Markdown(
233
- """
234
- This model utilizes a **Dual-Branch Architecture**:
235
- 1. **Semantic Branch**: Extracts global features using `ESM-2 (150M)` CLS token.
236
- 2. **Structural Branch**: Refines residue-level embeddings via CNN and Attention Pooling.
237
-
238
- **Citation:**
239
- *LocPred-Prok: Prokaryotic protein subcellular localization prediction with a dual-branch architecture.*
240
- """
241
- )
242
-
243
- # --- 底部 Footer ---
244
- gr.Markdown(
245
- """
246
- <div style="text-align: center; margin-top: 2rem; color: #64748b; font-size: 0.9rem;">
247
- © 2025 iSysLab HUST | Powered by ESM-2 & PyTorch
248
  </div>
249
- """
250
- )
251
 
252
- # --- 绑定事件 ---
253
  submit_btn.click(fn=predict, inputs=sequence_input, outputs=output_label)
254
  clear_btn.click(lambda: None, outputs=[output_label])
255
 
256
- # 启动
257
  app.launch()
 
6
  from transformers import AutoTokenizer, AutoModel
7
 
8
  # ==========================
9
+ # 0-3 部分:保持你的底层逻辑完全不变
10
  # ==========================
11
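+ # ... (Keep the earlier Imports, Model Definition, Load Models, and Predict Function code exactly as it was) ...
12
+ # To save space, this listing assumes sections 0 through 3 of the earlier code (up to and including def predict) are still present
13
+ # Be sure the earlier model class definitions and loading logic are included before running; otherwise `predict` and `NUM_CLASSES`, which the interface below uses, are undefined!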
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  # ==========================
16
+ # 4. Academic Research Interface
17
  # ==========================
18
 
19
+ # 学术风格 CSS
20
+ academic_css = """
21
+ body { font-family: 'Roboto', 'Helvetica Neue', Arial, sans-serif; }
22
+ .header-container {
23
+ background: linear-gradient(to right, #1e3a8a, #3b82f6); /* 深蓝学术风 */
 
 
 
 
 
 
 
 
 
 
24
  color: white;
25
+ padding: 2.5rem;
26
+ border-radius: 10px;
27
+ margin-bottom: 20px;
28
+ text-align: center;
29
  }
30
+ .header-title { font-size: 2.5rem; font-weight: 700; margin-bottom: 0.5rem; }
31
+ .header-subtitle { font-size: 1.2rem; opacity: 0.9; font-weight: 300; }
32
+ .badge-container { display: flex; justify-content: center; gap: 15px; margin-top: 15px; }
33
+ .badge {
34
+ background: rgba(255,255,255,0.2);
35
+ padding: 5px 15px;
36
+ border-radius: 20px;
37
+ font-size: 0.9rem;
38
+ border: 1px solid rgba(255,255,255,0.4);
39
  }
40
+ .highlight-box {
41
+ background: #f8fafc;
42
+ border-left: 5px solid #3b82f6;
43
+ padding: 15px;
44
+ margin: 20px 0;
45
+ color: #334155;
46
  }
47
+ .performance-table { width: 100%; border-collapse: collapse; margin-top: 10px; }
48
+ .performance-table th { background: #e2e8f0; padding: 8px; text-align: left; }
49
+ .performance-table td { border-bottom: 1px solid #e2e8f0; padding: 8px; }
50
+ .footer { text-align: center; color: #94a3b8; margin-top: 30px; font-size: 0.85rem; }
51
  """
52
 
53
+ # 定义主题
54
+ theme = gr.themes.Default(
55
+ primary_hue="blue",
56
+ secondary_hue="slate",
57
  neutral_hue="slate",
58
+ font=[gr.themes.GoogleFont("Roboto"), "ui-sans-serif", "system-ui"]
 
 
 
 
59
  )
60
 
61
+ with gr.Blocks(theme=theme, css=academic_css, title="LocPred-Prok Web Server") as app:
62
 
63
+ # --- 1. 学术 Header ---
64
+ with gr.Column(elem_classes="header-container"):
65
+ gr.HTML("""
66
+ <div class="header-title">LocPred-Prok</div>
67
+ <div class="header-subtitle">
68
+ Prokaryotic Protein Subcellular Localization Prediction with Dual-Branch Architecture
69
+ </div>
70
+ <div class="badge-container">
71
+ <span class="badge">🧬 ESM-2 150M Backbone</span>
72
+ <span class="badge">🏆 91.2% Accuracy</span>
73
+ <span class="badge">🎯 MCC 0.889</span>
74
+ </div>
75
+ """)
76
+
77
+ # --- 2. 核心功能区 (Tab结构) ---
78
+ with gr.Tabs():
79
 
80
+ # === Tab 1: Web Server (预测工具) ===
81
+ with gr.TabItem("🚀 Prediction Server"):
 
 
 
 
 
 
 
 
 
 
82
  with gr.Row():
83
+ # 左侧输入
84
+ with gr.Column(scale=5):
85
+ gr.Markdown("### 📥 Input Sequence (FASTA)")
86
+ sequence_input = gr.Textbox(
87
+ lines=8,
88
+ placeholder=">Example_Protein\nMKFKLTAGCLAVAGVLLASSFGADAEIVVNAIYDQVARTEDGVYTQGQLTGRRIELLNKLGIEPEDSLASTVIHEFVARVGDDHGIETIIDEFYRQHPSASL...",
89
+ show_label=False,
90
+ elem_id="seq-input"
91
+ )
92
+ with gr.Row():
93
+ clear_btn = gr.ClearButton(components=[sequence_input], value="Clear")
94
+ submit_btn = gr.Button("Run Prediction", variant="primary", scale=2)
95
+
96
+ gr.Markdown("#### Example Sequences")
97
+ gr.Examples(
98
+ examples=[
99
+ [">Gram-negative Outer Membrane Protein\nMSKLVKTLTISEISKAQNNGGKPAWCWYTLAMCGAGYDSGTCDYMYSHCFGIKHHSSGSSSYHC"],
100
+ [">Gram-positive Cell Wall Protein\nMKFKLTAGCLAVAGVLLASSFGADAEIVVNAIYDQVARTEDGVYTQGQLTGRRIELLNKLGIEPEDSLASTVIHEFVARVGDDHGIETIIDEFYRQHPSASL"],
101
+ ],
102
+ inputs=sequence_input,
103
+ label=None
104
+ )
105
+
106
+ # 右侧输出
107
+ with gr.Column(scale=4):
108
+ gr.Markdown("### 📊 Prediction Results")
109
+ output_label = gr.Label(num_top_classes=NUM_CLASSES, label="Probabilities")
110
+
111
+ # 解释性文字
112
+ gr.Markdown(
113
+ """
114
+ <div style="font-size: 0.9rem; color: #64748b; margin-top: 10px;">
115
+ <b>Note:</b> This model is optimized for challenging classes including
116
+ <i>Gram-positive cell wall</i> and <i>Gram-negative outer membrane</i> proteins.
117
+ </div>
118
+ """
119
+ )
120
+
121
+ # === Tab 2: About & Abstract (论文展示) ===
122
+ with gr.TabItem("📖 About & Abstract"):
123
+ gr.Markdown("### Abstract")
124
+ gr.Markdown(
125
+ """
126
+ The precise localization of proteins within prokaryotic cells is fundamental to understanding their function.
127
+ **LocPred-Prok** is a novel deep learning framework that employs a purpose-built **dual-branch architecture**,
128
+ synergistically integrating global and local sequence features extracted from **ESM-2 (150M)** embeddings.
129
+ """
130
  )
 
 
 
 
131
 
132
+ # 高亮核心发现
133
+ gr.HTML("""
134
+ <div class="highlight-box">
135
+ <b>💡 Key Findings:</b><br>
136
+ 1. <b>Bigger ≠ Better:</b> Peak performance is achieved by the mid-sized ESM-2-150M, not the largest models.<br>
137
+ 2. <b>Hard Classes Solved:</b> Exceptional performance on Gram-positive cell wall (MCC=0.84) and Gram-negative outer membrane (MCC=0.91).
138
+ </div>
139
+ """)
140
+
141
+ gr.Markdown("### 📈 Performance Metrics (Homology-Partitioned Benchmark)")
142
+ gr.HTML("""
143
+ <table class="performance-table">
144
+ <tr>
145
+ <th>Metric</th>
146
+ <th>LocPred-Prok Score</th>
147
+ <th>Improvement</th>
148
+ </tr>
149
+ <tr>
150
+ <td><b>Accuracy</b></td>
151
+ <td><b>91.2%</b></td>
152
+ <td>State-of-the-Art</td>
153
+ </tr>
154
+ <tr>
155
+ <td><b>MCC (Overall)</b></td>
156
+ <td><b>0.889</b></td>
157
+ <td>Significant Leap</td>
158
+ </tr>
159
+ <tr>
160
+ <td>MCC (Outer Membrane)</td>
161
+ <td>0.91</td>
162
+ <td>High Precision</td>
163
+ </tr>
164
+ </table>
165
+ """)
166
+
167
+ # 这里可以放架构图,如果你有图片链接的话
168
+ # gr.Image("https://your-image-url.com/architecture.png", label="Model Architecture")
169
+
170
+ # === Tab 3: Citation (引用) ===
171
+ with gr.TabItem("📝 Citation"):
172
+ gr.Markdown("If you use LocPred-Prok in your research, please cite our paper:")
173
+ gr.Code(
174
+ """
175
+ @article{LocPredProk2025,
176
+ title={LocPred-Prok: Prokaryotic protein subcellular localization prediction with a dual-branch architecture and protein language model},
177
+ author={Your Name and Co-authors},
178
+ journal={Submission Journal},
179
+ year={2025}
180
+ }
181
+ """,
182
+ language="bibtex",
183
+ label="BibTeX"
184
  )
185
 
186
+ # --- Footer ---
187
+ gr.HTML("""
188
+ <div class="footer">
189
+ Developed by iSysLab | <a href="https://github.com/isyslab-hust" target="_blank">GitHub</a> | Based on ESM-2 & PyTorch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  </div>
191
+ """)
 
192
 
193
+ # 绑定事件
194
  submit_btn.click(fn=predict, inputs=sequence_input, outputs=output_label)
195
  clear_btn.click(lambda: None, outputs=[output_label])
196
 
 
197
  app.launch()