wangleiofficial committed on
Commit
4782d51
·
verified ·
1 Parent(s): 7b163dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -44
app.py CHANGED
@@ -1,13 +1,25 @@
1
- import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
 
5
  from transformers import AutoTokenizer, AutoModel
6
- import json
7
- import os
8
- import re
9
 
10
- # --- 1. Model Definition (Must be identical to the one used during training) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  class AttentionPooling(nn.Module):
12
  """Attention Pooling Layer"""
13
  def __init__(self, d_model):
@@ -27,17 +39,23 @@ class ProtDualBranchEnhancedClassifier(nn.Module):
27
  self.cls_projector = nn.Linear(d_model, projection_dim)
28
  self.token_refiner = nn.Sequential(
29
  nn.Conv1d(d_model, d_model, kernel_size, padding='same'),
30
- nn.ReLU())
 
31
  self.attention_pooling = AttentionPooling(d_model)
32
  self.tok_projector = nn.Linear(d_model, projection_dim)
33
  fused_dim = projection_dim * 2
34
- self.gate = nn.Sequential(nn.Linear(fused_dim, fused_dim), nn.Sigmoid())
 
 
 
35
  self.classifier_head = nn.Sequential(
36
  nn.LayerNorm(fused_dim),
37
  nn.Linear(fused_dim, fused_dim * 2),
38
  nn.ReLU(),
39
  nn.Dropout(dropout),
40
- nn.Linear(fused_dim * 2, num_classes))
 
 
41
  def forward(self, cls_embedding, token_embeddings, mask):
42
  z_cls = self.cls_projector(cls_embedding)
43
  tok_emb_permuted = token_embeddings.permute(0, 2, 1)
@@ -49,76 +67,81 @@ class ProtDualBranchEnhancedClassifier(nn.Module):
49
  z_fused_gated = z_fused_concat * gate_values
50
  return self.classifier_head(z_fused_gated)
51
 
52
- # --- 2. Load Models and Auxiliary Files ---
 
 
53
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
- PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
55
  CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
56
  LABEL_MAP_PATH = "label_map.json"
57
 
58
- try:
59
- with open(LABEL_MAP_PATH, 'r') as f:
60
- label_to_idx = json.load(f)
61
- idx_to_label = {v: k for k, v in label_to_idx.items()}
62
- except FileNotFoundError:
63
- raise FileNotFoundError(f"Error: Could not find '{LABEL_MAP_PATH}'. Please make sure this file is uploaded to the Space.")
64
 
65
  NUM_CLASSES = len(idx_to_label)
66
  D_MODEL = 640
67
 
68
- print("Loading Protein Language Model...")
 
69
  tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME)
70
  plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE)
71
  plm_model.eval()
72
- print("PLM loaded successfully.")
73
 
74
- print("Loading downstream classifier...")
 
75
  classifier = ProtDualBranchEnhancedClassifier(
76
  d_model=D_MODEL, projection_dim=32, num_classes=NUM_CLASSES,
77
  dropout=0.3, kernel_size=3
78
  ).to(DEVICE)
79
 
80
  if not os.path.exists(CLASSIFIER_PATH):
81
- raise FileNotFoundError(f"Error: Could not find the trained model file '{CLASSIFIER_PATH}'. Please make sure the correct .pth file is uploaded.")
82
 
83
  classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
84
  classifier.eval()
85
- print("Classifier loaded. Application is ready!")
86
 
87
- # --- 3. Prediction Function ---
 
 
88
  def predict(sequence_input):
89
  if not sequence_input or sequence_input.isspace():
90
  return {"Error": "Please enter a protein sequence."}
91
 
 
92
  sequence = "".join(sequence_input.split('\n')[1:]) if sequence_input.startswith('>') else sequence_input
93
  sequence = re.sub(r'[^A-Z]', '', sequence.upper())
94
-
95
  if not sequence:
96
  return {"Error": "Sequence is empty after cleaning. Please enter a valid amino acid sequence."}
97
 
98
  with torch.no_grad():
99
  inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
100
  outputs = plm_model(**inputs)
101
-
102
  hidden_states = outputs.last_hidden_state
103
  cls_embedding = hidden_states[:, 0, :]
104
  token_embeddings = hidden_states[:, 1:-1, :]
105
  token_mask = inputs['attention_mask'][:, 1:-1]
106
 
107
- with torch.no_grad():
108
  logits = classifier(cls_embedding, token_embeddings, token_mask)
109
  probabilities = F.softmax(logits, dim=1)[0]
110
-
111
  confidences = {idx_to_label[i]: float(prob) for i, prob in enumerate(probabilities)}
112
  return confidences
113
 
114
- # --- 4. Create Beautified Gradio Interface using Blocks ---
 
 
115
  with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 800px; margin: auto;}") as app:
116
  gr.Markdown(
117
  """
118
- # Protein Subcellular Localization Prediction
119
- An online prediction tool based on the **ESM-2 (150M)** Protein Language Model and a custom **`dual_branch_enhanced`** classifier.
120
-
121
- Just paste the amino acid sequence of a protein (FASTA format or raw sequence are supported), and the model will predict its location within the cell.
122
  """
123
  )
124
 
@@ -129,7 +152,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 800px;
129
  label="Protein Sequence",
130
  placeholder="Paste your amino acid sequence here..."
131
  )
132
-
133
  with gr.Row():
134
  clear_btn = gr.ClearButton()
135
  submit_btn = gr.Button("🚀 Predict", variant="primary")
@@ -145,25 +168,19 @@ with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 800px;
145
 
146
  with gr.Column(scale=1):
147
  output_label = gr.Label(num_top_classes=NUM_CLASSES, label="Prediction Results")
148
-
149
  with gr.Accordion("Model Information", open=False):
150
  gr.Markdown(
151
  """
152
- * **Protein Language Model (PLM)**: `facebook/esm2_t30_150M_UR50D`
153
- * **Downstream Classifier**: `ProtDualBranchEnhancedClassifier`
154
- * **GitHub Repository**: github.com/isyslab-hust
155
  """
156
  )
157
 
158
- gr.Markdown(
159
- """
160
- ---
161
- *Built by isyslab*
162
- """
163
- )
164
 
165
  submit_btn.click(fn=predict, inputs=sequence_input, outputs=output_label, api_name="predict")
166
  clear_btn.click(lambda: [None, None], outputs=[sequence_input, output_label])
167
 
168
- app.launch()
169
-
 
1
+ import os, shutil, json, re
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
+ import gradio as gr
6
  from transformers import AutoTokenizer, AutoModel
 
 
 
7
 
8
# ==========================
# 🚧 0. Keep the Hugging Face cache from overflowing
# ==========================
# Point the Hugging Face caches at /tmp and silence the symlink warning.
# NOTE(review): transformers is imported above this point in the file;
# confirm the library still honours variables that are set after its import.
os.environ.update(
    {
        "HF_HOME": "/tmp/hf_cache",
        "TRANSFORMERS_CACHE": "/tmp/hf_cache",
        "HF_HUB_DISABLE_SYMLINKS_WARNING": "1",
    }
)

# Wipe and recreate both cache locations on every startup so that stale
# downloads cannot push disk usage past the 50G limit.
for path in ("/tmp/hf_cache", os.path.expanduser("~/.cache/huggingface")):
    shutil.rmtree(path, ignore_errors=True)
    os.makedirs(path, exist_ok=True)
19
+
20
+ # ==========================
21
+ # 1. Model Definition
22
+ # ==========================
23
  class AttentionPooling(nn.Module):
24
  """Attention Pooling Layer"""
25
  def __init__(self, d_model):
 
39
  self.cls_projector = nn.Linear(d_model, projection_dim)
40
  self.token_refiner = nn.Sequential(
41
  nn.Conv1d(d_model, d_model, kernel_size, padding='same'),
42
+ nn.ReLU()
43
+ )
44
  self.attention_pooling = AttentionPooling(d_model)
45
  self.tok_projector = nn.Linear(d_model, projection_dim)
46
  fused_dim = projection_dim * 2
47
+ self.gate = nn.Sequential(
48
+ nn.Linear(fused_dim, fused_dim),
49
+ nn.Sigmoid()
50
+ )
51
  self.classifier_head = nn.Sequential(
52
  nn.LayerNorm(fused_dim),
53
  nn.Linear(fused_dim, fused_dim * 2),
54
  nn.ReLU(),
55
  nn.Dropout(dropout),
56
+ nn.Linear(fused_dim * 2, num_classes)
57
+ )
58
+
59
  def forward(self, cls_embedding, token_embeddings, mask):
60
  z_cls = self.cls_projector(cls_embedding)
61
  tok_emb_permuted = token_embeddings.permute(0, 2, 1)
 
67
  z_fused_gated = z_fused_concat * gate_values
68
  return self.classifier_head(z_fused_gated)
69
 
70
# ==========================
# 2. Load Models and Files
# ==========================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# A smaller backbone such as esm2_t12_35M_UR50D can be substituted to reduce
# the download size; the classifier checkpoint must then match that backbone.
PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
LABEL_MAP_PATH = "label_map.json"

# --- Load the label map (label name -> class index) and invert it ---
if not os.path.exists(LABEL_MAP_PATH):
    raise FileNotFoundError(f"Error: Missing '{LABEL_MAP_PATH}'. Please upload it to your Space.")
with open(LABEL_MAP_PATH, 'r', encoding="utf-8") as f:
    label_to_idx = json.load(f)
idx_to_label = {v: k for k, v in label_to_idx.items()}

NUM_CLASSES = len(idx_to_label)

# --- Load the pretrained protein language model ---
print("🔹 Loading Protein Language Model...")
# The cache directory is passed explicitly: the Hugging Face libraries resolve
# their default cache location at import time, so environment variables that
# are assigned after `import transformers` do not redirect downloads.
# A value of None (variable unset) keeps the library default.
_PLM_CACHE_DIR = os.environ.get("TRANSFORMERS_CACHE")
tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME, cache_dir=_PLM_CACHE_DIR)
plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME, cache_dir=_PLM_CACHE_DIR).to(DEVICE)
plm_model.eval()
# Read the embedding width from the loaded backbone instead of hard-coding it,
# so changing PLM_MODEL_NAME cannot silently leave a stale dimension behind
# (640 for esm2_t30_150M_UR50D).
D_MODEL = plm_model.config.hidden_size
print("PLM loaded successfully.")

# --- Load the downstream classifier ---
print("🔹 Loading downstream classifier...")
classifier = ProtDualBranchEnhancedClassifier(
    d_model=D_MODEL,
    projection_dim=32,
    num_classes=NUM_CLASSES,
    dropout=0.3,
    kernel_size=3,
).to(DEVICE)

if not os.path.exists(CLASSIFIER_PATH):
    raise FileNotFoundError(f"Error: Could not find '{CLASSIFIER_PATH}'. Please upload your trained .pth file.")

# NOTE: torch.load unpickles the checkpoint, so only ship files you trust.
# Consider weights_only=True where the installed torch version supports it.
classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
classifier.eval()
print("Classifier loaded. Application is ready!")
108
 
109
# ==========================
# 3. Prediction Function
# ==========================
def _extract_sequence(raw_text):
    """Return the cleaned residue string of the first record in *raw_text*.

    Accepts either a raw sequence or FASTA text. Lines starting with '>' are
    treated as headers and never contribute residues; if a second header
    appears after residues were collected, everything from there on is
    ignored. The result is upper-cased and every character outside A-Z is
    removed.
    """
    residue_lines = []
    # Strip first so leading blank lines or spaces do not hide a FASTA header.
    for line in raw_text.strip().splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith('>'):
            if residue_lines:
                break  # a second record begins; keep only the first one
            continue
        residue_lines.append(line)
    return re.sub(r'[^A-Z]', '', "".join(residue_lines).upper())


def predict(sequence_input):
    """Predict class confidences for a protein sequence.

    Args:
        sequence_input: Raw amino-acid sequence or FASTA text.

    Returns:
        A dict mapping each label name to its softmax probability, or a
        single-entry ``{"Error": message}`` dict when the input is missing
        or contains no letters after cleaning.
    """
    # NOTE(review): the error dicts carry a string value; confirm the output
    # component accepts that, or switch to a UI-level error.
    if not sequence_input or sequence_input.isspace():
        return {"Error": "Please enter a protein sequence."}

    sequence = _extract_sequence(sequence_input)
    if not sequence:
        return {"Error": "Sequence is empty after cleaning. Please enter a valid amino acid sequence."}

    with torch.no_grad():
        # Inputs longer than max_length tokens are truncated by the tokenizer.
        inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
        outputs = plm_model(**inputs)

        hidden_states = outputs.last_hidden_state
        # Position 0 feeds the CLS branch; the first and last positions are
        # dropped for the token branch (presumably the tokenizer's special
        # start/end tokens — confirm against the training code).
        cls_embedding = hidden_states[:, 0, :]
        token_embeddings = hidden_states[:, 1:-1, :]
        token_mask = inputs['attention_mask'][:, 1:-1]

        logits = classifier(cls_embedding, token_embeddings, token_mask)
        probabilities = F.softmax(logits, dim=1)[0]

    # tolist() yields plain Python floats in class-index order.
    return {idx_to_label[i]: prob for i, prob in enumerate(probabilities.tolist())}
136
 
137
+ # ==========================
138
+ # 4. Gradio Interface
139
+ # ==========================
140
  with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 800px; margin: auto;}") as app:
141
  gr.Markdown(
142
  """
143
+ # 🧬 Protein Subcellular Localization Prediction
144
+ A prediction tool based on **ESM-2 (150M)** and a custom **dual-branch enhanced classifier**.
 
 
145
  """
146
  )
147
 
 
152
  label="Protein Sequence",
153
  placeholder="Paste your amino acid sequence here..."
154
  )
155
+
156
  with gr.Row():
157
  clear_btn = gr.ClearButton()
158
  submit_btn = gr.Button("🚀 Predict", variant="primary")
 
168
 
169
  with gr.Column(scale=1):
170
  output_label = gr.Label(num_top_classes=NUM_CLASSES, label="Prediction Results")
171
+
172
  with gr.Accordion("Model Information", open=False):
173
  gr.Markdown(
174
  """
175
+ * **Protein Language Model (PLM)**: `facebook/esm2_t30_150M_UR50D`
176
+ * **Downstream Classifier**: `ProtDualBranchEnhancedClassifier`
177
+ * **GitHub**: github.com/isyslab-hust
178
  """
179
  )
180
 
181
+ gr.Markdown("---\n*Built by isyslab*")
 
 
 
 
 
182
 
183
  submit_btn.click(fn=predict, inputs=sequence_input, outputs=output_label, api_name="predict")
184
  clear_btn.click(lambda: [None, None], outputs=[sequence_input, output_label])
185
 
186
+ app.launch()