wangleiofficial committed on
Commit
a2cdbf7
·
verified ·
1 Parent(s): 886c88b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -9
app.py CHANGED
@@ -6,21 +6,144 @@ import gradio as gr
6
  from transformers import AutoTokenizer, AutoModel
7
 
8
  # ==========================
9
- # 0-3 部分:保持你的底层逻辑完全不变
10
  # ==========================
11
- # ... (请保持之前的 Imports, Model Definition, Load Models, Predict Function 代码完全一致) ...
12
- # 为了节省篇幅,这里假设你已经保留了之前代码的第0到第3部分 (直到 def predict 为止)
13
- # 务必确保运行前包含之前的 Model 类定义和加载逻辑!
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  # ==========================
16
- # 4. Academic Research Interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  # ==========================
18
 
19
  # 学术风格 CSS
20
  academic_css = """
21
  body { font-family: 'Roboto', 'Helvetica Neue', Arial, sans-serif; }
22
  .header-container {
23
- background: linear-gradient(to right, #1e3a8a, #3b82f6); /* 深蓝学术风 */
24
  color: white;
25
  padding: 2.5rem;
26
  border-radius: 10px;
@@ -106,6 +229,8 @@ with gr.Blocks(theme=theme, css=academic_css, title="LocPred-Prok Web Server") a
106
  # 右侧输出
107
  with gr.Column(scale=4):
108
  gr.Markdown("### 📊 Prediction Results")
 
 
109
  output_label = gr.Label(num_top_classes=NUM_CLASSES, label="Probabilities")
110
 
111
  # 解释性文字
@@ -164,9 +289,6 @@ with gr.Blocks(theme=theme, css=academic_css, title="LocPred-Prok Web Server") a
164
  </table>
165
  """)
166
 
167
- # 这里可以放架构图,如果你有图片链接的话
168
- # gr.Image("https://your-image-url.com/architecture.png", label="Model Architecture")
169
-
170
  # === Tab 3: Citation (引用) ===
171
  with gr.TabItem("📝 Citation"):
172
  gr.Markdown("If you use LocPred-Prok in your research, please cite our paper:")
@@ -194,4 +316,5 @@ with gr.Blocks(theme=theme, css=academic_css, title="LocPred-Prok Web Server") a
194
  submit_btn.click(fn=predict, inputs=sequence_input, outputs=output_label)
195
  clear_btn.click(lambda: None, outputs=[output_label])
196
 
 
197
  app.launch()
 
6
  from transformers import AutoTokenizer, AutoModel
7
 
8
# ==========================
# 0. Guard against Hugging Face cache overflow
# ==========================
# Point every HF cache at /tmp so the Space's persistent disk is never filled.
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

# Wipe any stale cache (both the /tmp location and the default home
# cache) on every startup, then recreate the empty directories.
for cache_dir in ["/tmp/hf_cache", os.path.expanduser("~/.cache/huggingface")]:
    shutil.rmtree(cache_dir, ignore_errors=True)
    os.makedirs(cache_dir, exist_ok=True)
20
+ # ==========================
21
+ # 1. Model Definition (模型架构定义)
22
+ # ==========================
23
class AttentionPooling(nn.Module):
    """Attention pooling over a padded token sequence.

    Learns one scalar relevance score per token and returns the
    softmax-weighted average of the token embeddings, giving zero
    weight to positions where the mask is 0.
    """

    def __init__(self, d_model):
        super().__init__()
        # Produces a single attention logit per token embedding.
        self.attention_net = nn.Linear(d_model, 1)

    def forward(self, x, mask):
        # x: (batch, seq, d_model); mask: (batch, seq) with 0 = padding.
        scores = self.attention_net(x).squeeze(2)
        # -inf logits make softmax assign padded positions zero weight.
        scores.masked_fill_(mask == 0, -float('inf'))
        weights = F.softmax(scores, dim=1)
        # Weighted sum over the sequence axis -> (batch, d_model).
        pooled = torch.bmm(weights.unsqueeze(1), x)
        return pooled.squeeze(1)
+
35
class ProtDualBranchEnhancedClassifier(nn.Module):
    """Enhanced dual-branch model.

    Branch 1 projects the sequence-level ([CLS]) embedding; branch 2
    refines the per-token embeddings with a 1-D convolution, pools them
    with attention, and projects the result. The two projections are
    concatenated, scaled by a learned sigmoid gate, and fed to an MLP
    classification head.
    """

    def __init__(self, d_model, projection_dim, num_classes, dropout, kernel_size):
        super().__init__()
        # NOTE: submodule creation order is part of the checkpoint layout
        # (and of seeded initialization) — keep it stable.
        self.cls_projector = nn.Linear(d_model, projection_dim)
        self.token_refiner = nn.Sequential(
            nn.Conv1d(d_model, d_model, kernel_size, padding='same'),
            nn.ReLU()
        )
        self.attention_pooling = AttentionPooling(d_model)
        self.tok_projector = nn.Linear(d_model, projection_dim)
        fused = projection_dim * 2
        # Element-wise gate over the concatenated branch features.
        self.gate = nn.Sequential(
            nn.Linear(fused, fused),
            nn.Sigmoid()
        )
        self.classifier_head = nn.Sequential(
            nn.LayerNorm(fused),
            nn.Linear(fused, fused * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fused * 2, num_classes)
        )

    def forward(self, cls_embedding, token_embeddings, mask):
        # Branch 1: project the pooled [CLS] representation.
        cls_feat = self.cls_projector(cls_embedding)

        # Branch 2: Conv1d wants (batch, channels, seq), so permute in,
        # refine, and permute back to (batch, seq, channels) before pooling.
        refined = self.token_refiner(token_embeddings.permute(0, 2, 1)).permute(0, 2, 1)
        tok_feat = self.tok_projector(self.attention_pooling(refined, mask))

        # Fuse both branches, modulate with the learned gate, classify.
        fused = torch.cat([cls_feat, tok_feat], dim=1)
        gated = fused * self.gate(fused)
        return self.classifier_head(gated)
69
+
70
# ==========================
# 2. Load Models and Files
# ==========================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
LABEL_MAP_PATH = "label_map.json"

# --- Label mapping (this defines NUM_CLASSES) ---
if not os.path.exists(LABEL_MAP_PATH):
    raise FileNotFoundError(f"Error: Missing '{LABEL_MAP_PATH}'. Please upload it to your Space.")
with open(LABEL_MAP_PATH, 'r') as f:
    label_to_idx = json.load(f)
idx_to_label = {v: k for k, v in label_to_idx.items()}

# Key globals consumed by the classifier and the Gradio UI.
NUM_CLASSES = len(idx_to_label)
D_MODEL = 640  # embedding width expected from the ESM-2 150M model

# --- Pretrained protein language model ---
print("🔹 Loading Protein Language Model...")
tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME)
plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE)
plm_model.eval()
print("✅ PLM loaded successfully.")

# --- Downstream classifier ---
print("🔹 Loading downstream classifier...")
classifier = ProtDualBranchEnhancedClassifier(
    d_model=D_MODEL, projection_dim=32, num_classes=NUM_CLASSES,
    dropout=0.3, kernel_size=3
).to(DEVICE)

if not os.path.exists(CLASSIFIER_PATH):
    raise FileNotFoundError(f"Error: Could not find '{CLASSIFIER_PATH}'. Please upload your trained .pth file.")

# weights_only=True restricts torch.load to plain tensor data, avoiding
# arbitrary-pickle code execution while deserializing the checkpoint
# (safe here since a state_dict contains only tensors).
classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE, weights_only=True))
classifier.eval()
print("✅ Classifier loaded. Application is ready!")
109
 
110
# ==========================
# 3. Prediction Function
# ==========================
def predict(sequence_input):
    """Predict localization probabilities for one protein sequence.

    Accepts raw amino-acid text or a single FASTA record, embeds it with
    the protein language model, and classifies the embedding with the
    dual-branch head. Returns a {label: probability} dict for gr.Label.
    Raises gr.Error on empty or non-letter input.
    """
    if not sequence_input or sequence_input.isspace():
        raise gr.Error("Sequence cannot be empty.")

    # Drop the FASTA header line if one is present, then keep only
    # upper-case letters (removes whitespace, digits, gap symbols).
    if sequence_input.startswith('>'):
        raw = "".join(sequence_input.split('\n')[1:])
    else:
        raw = sequence_input
    sequence = re.sub(r'[^A-Z]', '', raw.upper())

    if not sequence:
        raise gr.Error("Invalid sequence format. Please enter amino acids (A-Z).")

    with torch.no_grad():
        encoded = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
        hidden = plm_model(**encoded).last_hidden_state
        # Sequence-level [CLS] vector plus per-residue embeddings; the
        # special tokens at both ends are sliced off, mask sliced to match.
        cls_vec = hidden[:, 0, :]
        residue_vecs = hidden[:, 1:-1, :]
        residue_mask = encoded['attention_mask'][:, 1:-1]

        logits = classifier(cls_vec, residue_vecs, residue_mask)
        probs = F.softmax(logits, dim=1)[0]

    return {idx_to_label[i]: float(p) for i, p in enumerate(probs)}
137
+
138
+ # ==========================
139
+ # 4. Academic Research Interface (UI 界面)
140
  # ==========================
141
 
142
  # 学术风格 CSS
143
  academic_css = """
144
  body { font-family: 'Roboto', 'Helvetica Neue', Arial, sans-serif; }
145
  .header-container {
146
+ background: linear-gradient(to right, #1e3a8a, #3b82f6);
147
  color: white;
148
  padding: 2.5rem;
149
  border-radius: 10px;
 
229
  # 右侧输出
230
  with gr.Column(scale=4):
231
  gr.Markdown("### 📊 Prediction Results")
232
+
233
+ # ✅ 这里使用了 NUM_CLASSES,现在它已经在前面定义过了
234
  output_label = gr.Label(num_top_classes=NUM_CLASSES, label="Probabilities")
235
 
236
  # 解释性文字
 
289
  </table>
290
  """)
291
 
 
 
 
292
  # === Tab 3: Citation (引用) ===
293
  with gr.TabItem("📝 Citation"):
294
  gr.Markdown("If you use LocPred-Prok in your research, please cite our paper:")
 
316
  submit_btn.click(fn=predict, inputs=sequence_input, outputs=output_label)
317
  clear_btn.click(lambda: None, outputs=[output_label])
318
 
319
+ # 启动
320
  app.launch()