wangleiofficial committed on
Commit
b3298fd
·
verified ·
1 Parent(s): 4782d51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -45
app.py CHANGED
@@ -6,22 +6,20 @@ import gradio as gr
6
  from transformers import AutoTokenizer, AutoModel
7
 
8
  # ==========================
9
- # 🚧 0. 防止 Hugging Face 缓存溢出
10
  # ==========================
11
  os.environ["HF_HOME"] = "/tmp/hf_cache"
12
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
13
  os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
14
 
15
- # 每次启动时清理旧缓存,防止超过 50G 限制
16
  for path in ["/tmp/hf_cache", os.path.expanduser("~/.cache/huggingface")]:
17
  shutil.rmtree(path, ignore_errors=True)
18
  os.makedirs(path, exist_ok=True)
19
 
20
  # ==========================
21
- # 1. Model Definition
22
  # ==========================
23
  class AttentionPooling(nn.Module):
24
- """Attention Pooling Layer"""
25
  def __init__(self, d_model):
26
  super().__init__()
27
  self.attention_net = nn.Linear(d_model, 1)
@@ -33,7 +31,6 @@ class AttentionPooling(nn.Module):
33
  return torch.bmm(attn_weights.unsqueeze(1), x).squeeze(1)
34
 
35
  class ProtDualBranchEnhancedClassifier(nn.Module):
36
- """Enhanced dual-branch model"""
37
  def __init__(self, d_model, projection_dim, num_classes, dropout, kernel_size):
38
  super().__init__()
39
  self.cls_projector = nn.Linear(d_model, projection_dim)
@@ -68,16 +65,15 @@ class ProtDualBranchEnhancedClassifier(nn.Module):
68
  return self.classifier_head(z_fused_gated)
69
 
70
  # ==========================
71
- # 2. Load Models and Files
72
  # ==========================
73
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
74
- PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D" # 可改为 esm2_t12_35M_UR50D 减少体积
75
  CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
76
  LABEL_MAP_PATH = "label_map.json"
77
 
78
- # --- 加载标签映射 ---
79
  if not os.path.exists(LABEL_MAP_PATH):
80
- raise FileNotFoundError(f"Error: Missing '{LABEL_MAP_PATH}'. Please upload it to your Space.")
81
  with open(LABEL_MAP_PATH, 'r') as f:
82
  label_to_idx = json.load(f)
83
  idx_to_label = {v: k for k, v in label_to_idx.items()}
@@ -85,40 +81,38 @@ with open(LABEL_MAP_PATH, 'r') as f:
85
  NUM_CLASSES = len(idx_to_label)
86
  D_MODEL = 640
87
 
88
- # --- 加载预训练蛋白模型 ---
89
  print("🔹 Loading Protein Language Model...")
90
  tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME)
91
  plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE)
92
  plm_model.eval()
93
- print("✅ PLM loaded successfully.")
94
 
95
- # --- 加载下游分类器 ---
96
- print("🔹 Loading downstream classifier...")
97
  classifier = ProtDualBranchEnhancedClassifier(
98
  d_model=D_MODEL, projection_dim=32, num_classes=NUM_CLASSES,
99
  dropout=0.3, kernel_size=3
100
  ).to(DEVICE)
101
 
102
  if not os.path.exists(CLASSIFIER_PATH):
103
- raise FileNotFoundError(f"Error: Could not find '{CLASSIFIER_PATH}'. Please upload your trained .pth file.")
104
 
105
  classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
106
  classifier.eval()
107
- print("✅ Classifier loaded. Application is ready!")
108
 
109
  # ==========================
110
- # 3. Prediction Function
111
  # ==========================
112
  def predict(sequence_input):
113
  if not sequence_input or sequence_input.isspace():
114
- return {"Error": "Please enter a protein sequence."}
 
115
 
116
- # Clean FASTA header if present
117
  sequence = "".join(sequence_input.split('\n')[1:]) if sequence_input.startswith('>') else sequence_input
118
  sequence = re.sub(r'[^A-Z]', '', sequence.upper())
119
 
120
  if not sequence:
121
- return {"Error": "Sequence is empty after cleaning. Please enter a valid amino acid sequence."}
122
 
123
  with torch.no_grad():
124
  inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
@@ -135,52 +129,129 @@ def predict(sequence_input):
135
  return confidences
136
 
137
  # ==========================
138
- # 4. Gradio Interface
139
  # ==========================
140
- with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 800px; margin: auto;}") as app:
141
- gr.Markdown(
142
- """
143
- # 🧬 Protein Subcellular Localization Prediction
144
- A prediction tool based on **ESM-2 (150M)** and a custom **dual-branch enhanced classifier**.
145
- """
146
- )
147
 
148
- with gr.Row():
149
- with gr.Column(scale=1):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  sequence_input = gr.Textbox(
151
- lines=10,
152
- label="Protein Sequence",
153
- placeholder="Paste your amino acid sequence here..."
 
154
  )
155
 
156
  with gr.Row():
157
- clear_btn = gr.ClearButton()
158
- submit_btn = gr.Button("🚀 Predict", variant="primary")
159
 
 
160
  gr.Examples(
161
  examples=[
162
- [">sp|P27361|PBP2_ECOLI Penicillin-binding protein 2 OS=Escherichia coli (strain K12) OX=83333 GN=mrdA PE=1 SV=2\nMKFKLTAGCLAVAGVLLASSFGADAEIVVNAIYDQVARTEDGVYTQGQLTGRRIELLNKLGIEPEDSLASTVIHEFVARVGDDHGIETIIDEFYRQHPSASL"],
163
  ["MSKLVKTLTISEISKAQNNGGKPAWCWYTLAMCGAGYDSGTCDYMYSHCFGIKHHSSGSSSYHC"],
164
  ],
165
  inputs=sequence_input,
166
- label="Examples"
167
  )
168
 
169
- with gr.Column(scale=1):
170
- output_label = gr.Label(num_top_classes=NUM_CLASSES, label="Prediction Results")
 
 
 
 
 
 
 
171
 
172
- with gr.Accordion("Model Information", open=False):
 
173
  gr.Markdown(
174
  """
175
- * **Protein Language Model (PLM)**: `facebook/esm2_t30_150M_UR50D`
176
- * **Downstream Classifier**: `ProtDualBranchEnhancedClassifier`
177
- * **GitHub**: github.com/isyslab-hust
 
 
 
178
  """
179
  )
180
 
181
- gr.Markdown("---\n*Built by isyslab*")
 
 
 
 
 
 
 
182
 
183
- submit_btn.click(fn=predict, inputs=sequence_input, outputs=output_label, api_name="predict")
184
- clear_btn.click(lambda: [None, None], outputs=[sequence_input, output_label])
 
185
 
 
186
  app.launch()
 
6
  from transformers import AutoTokenizer, AutoModel
7
 
8
  # ==========================
9
+ # 🚧 0. 防止 Hugging Face 缓存溢出 (保持不变)
10
  # ==========================
11
  os.environ["HF_HOME"] = "/tmp/hf_cache"
12
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
13
  os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
14
 
 
15
  for path in ["/tmp/hf_cache", os.path.expanduser("~/.cache/huggingface")]:
16
  shutil.rmtree(path, ignore_errors=True)
17
  os.makedirs(path, exist_ok=True)
18
 
19
  # ==========================
20
+ # 1. Model Definition (保持不变)
21
  # ==========================
22
  class AttentionPooling(nn.Module):
 
23
  def __init__(self, d_model):
24
  super().__init__()
25
  self.attention_net = nn.Linear(d_model, 1)
 
31
  return torch.bmm(attn_weights.unsqueeze(1), x).squeeze(1)
32
 
33
  class ProtDualBranchEnhancedClassifier(nn.Module):
 
34
  def __init__(self, d_model, projection_dim, num_classes, dropout, kernel_size):
35
  super().__init__()
36
  self.cls_projector = nn.Linear(d_model, projection_dim)
 
65
  return self.classifier_head(z_fused_gated)
66
 
67
  # ==========================
68
+ # 2. Load Models and Files (保持不变)
69
  # ==========================
70
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71
+ PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
72
  CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
73
  LABEL_MAP_PATH = "label_map.json"
74
 
 
75
  if not os.path.exists(LABEL_MAP_PATH):
76
+ raise FileNotFoundError(f"Error: Missing '{LABEL_MAP_PATH}'.")
77
  with open(LABEL_MAP_PATH, 'r') as f:
78
  label_to_idx = json.load(f)
79
  idx_to_label = {v: k for k, v in label_to_idx.items()}
 
81
  NUM_CLASSES = len(idx_to_label)
82
  D_MODEL = 640
83
 
 
84
  print("🔹 Loading Protein Language Model...")
85
  tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME)
86
  plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE)
87
  plm_model.eval()
88
+ print("✅ PLM loaded.")
89
 
90
+ print("🔹 Loading classifier...")
 
91
  classifier = ProtDualBranchEnhancedClassifier(
92
  d_model=D_MODEL, projection_dim=32, num_classes=NUM_CLASSES,
93
  dropout=0.3, kernel_size=3
94
  ).to(DEVICE)
95
 
96
  if not os.path.exists(CLASSIFIER_PATH):
97
+ raise FileNotFoundError(f"Error: Could not find '{CLASSIFIER_PATH}'.")
98
 
99
  classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
100
  classifier.eval()
101
+ print("✅ System Ready.")
102
 
103
  # ==========================
104
+ # 3. Prediction Function (微调)
105
  # ==========================
106
  def predict(sequence_input):
107
  if not sequence_input or sequence_input.isspace():
108
+ # Raise gr.Error instead of returning an error dict, so the Label component stays clean
109
+ raise gr.Error("Sequence cannot be empty.")
110
 
 
111
  sequence = "".join(sequence_input.split('\n')[1:]) if sequence_input.startswith('>') else sequence_input
112
  sequence = re.sub(r'[^A-Z]', '', sequence.upper())
113
 
114
  if not sequence:
115
+ raise gr.Error("Invalid sequence format.")
116
 
117
  with torch.no_grad():
118
  inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
 
129
  return confidences
130
 
131
  # ==========================
132
+ # 4. Modernized Gradio Interface
133
  # ==========================
 
 
 
 
 
 
 
134
 
135
+ # 自定义 CSS:增加渐变标题、阴影、圆角
136
+ custom_css = """
137
+ .gradio-container {
138
+ font-family: 'IBM Plex Sans', sans-serif;
139
+ }
140
+ .main-header {
141
+ text-align: center;
142
+ background: linear-gradient(135deg, #3b82f6 0%, #06b6d4 100%);
143
+ color: white;
144
+ padding: 2rem;
145
+ border-radius: 12px;
146
+ margin-bottom: 1.5rem;
147
+ box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06);
148
+ }
149
+ .main-header h1 {
150
+ color: white;
151
+ margin-bottom: 0.5rem;
152
+ font-size: 2.2rem;
153
+ }
154
+ .main-header p {
155
+ color: #e0f2fe;
156
+ font-size: 1.1rem;
157
+ }
158
+ .input-card, .output-card {
159
+ border: 1px solid #e5e7eb;
160
+ border-radius: 12px;
161
+ padding: 1.5rem;
162
+ background: white;
163
+ box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.1);
164
+ }
165
+ """
166
+
167
+ # 使用更清爽的 Teal (青色) 主题,符合生物信息学特征
168
+ theme = gr.themes.Soft(
169
+ primary_hue="teal",
170
+ secondary_hue="blue",
171
+ neutral_hue="slate",
172
+ font=[gr.themes.GoogleFont("IBM Plex Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
173
+ ).set(
174
+ button_primary_background_fill="*primary_600",
175
+ button_primary_background_fill_hover="*primary_700",
176
+ block_shadow="*shadow_drop_lg"
177
+ )
178
+
179
+ with gr.Blocks(theme=theme, css=custom_css, title="LocPred-Prok") as app:
180
+
181
+ # --- 顶部 Header ---
182
+ with gr.Column(elem_classes="main-header"):
183
+ gr.Markdown(
184
+ """
185
+ # 🧬 Prokaryotic Subcellular Localization
186
+ ### Dual-Branch Architecture with Protein Language Models
187
+ Identify where your protein functions using State-of-the-Art Deep Learning.
188
+ """
189
+ )
190
+
191
+ # --- 主体内容 ---
192
+ with gr.Row(equal_height=False):
193
+
194
+ # 左侧:输入区
195
+ with gr.Column(scale=5, elem_classes="input-card"):
196
+ gr.Markdown("### 📥 Input Sequence")
197
+ gr.Markdown("Paste your amino acid sequence (FASTA format supported).")
198
+
199
  sequence_input = gr.Textbox(
200
+ lines=8,
201
+ label="",
202
+ placeholder=">Example Header\nMKFKLTAGCLAVAGVLLASSFGADAEIVV...",
203
+ show_label=False
204
  )
205
 
206
  with gr.Row():
207
+ clear_btn = gr.ClearButton(components=[sequence_input], value="Clear")
208
+ submit_btn = gr.Button(" Run Prediction", variant="primary", scale=2)
209
 
210
+ gr.Markdown("#### 💡 Example Sequences")
211
  gr.Examples(
212
  examples=[
213
+ [">sp|P27361|PBP2_ECOLI Penicillin-binding protein 2\nMKFKLTAGCLAVAGVLLASSFGADAEIVVNAIYDQVARTEDGVYTQGQLTGRRIELLNKLGIEPEDSLASTVIHEFVARVGDDHGIETIIDEFYRQHPSASL"],
214
  ["MSKLVKTLTISEISKAQNNGGKPAWCWYTLAMCGAGYDSGTCDYMYSHCFGIKHHSSGSSSYHC"],
215
  ],
216
  inputs=sequence_input,
217
+ label=None
218
  )
219
 
220
+ # 右侧:输出区
221
+ with gr.Column(scale=4, elem_classes="output-card"):
222
+ gr.Markdown("### 📊 Prediction Results")
223
+
224
+ output_label = gr.Label(
225
+ num_top_classes=NUM_CLASSES,
226
+ label="Probability Distribution",
227
+ show_label=False
228
+ )
229
 
230
+ # 信息折叠面板
231
+ with gr.Accordion("📘 Model Architecture & Details", open=False):
232
  gr.Markdown(
233
  """
234
+ This model utilizes a **Dual-Branch Architecture**:
235
+ 1. **Semantic Branch**: Extracts global features using `ESM-2 (150M)` CLS token.
236
+ 2. **Structural Branch**: Refines residue-level embeddings via CNN and Attention Pooling.
237
+
238
+ **Citation:**
239
+ *LocPred-Prok: Prokaryotic protein subcellular localization prediction with a dual-branch architecture.*
240
  """
241
  )
242
 
243
+ # --- 底部 Footer ---
244
+ gr.Markdown(
245
+ """
246
+ <div style="text-align: center; margin-top: 2rem; color: #64748b; font-size: 0.9rem;">
247
+ © 2025 iSysLab HUST | Powered by ESM-2 & PyTorch
248
+ </div>
249
+ """
250
+ )
251
 
252
+ # --- 绑定事件 ---
253
+ submit_btn.click(fn=predict, inputs=sequence_input, outputs=output_label)
254
+ clear_btn.click(lambda: None, outputs=[output_label])
255
 
256
+ # 启动
257
  app.launch()