wangleiofficial committed on
Commit
7b163dc
·
verified ·
1 Parent(s): ab6d96f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -49
app.py CHANGED
@@ -55,7 +55,6 @@ PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
55
  CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
56
  LABEL_MAP_PATH = "label_map.json"
57
 
58
- # Load the label map file
59
  try:
60
  with open(LABEL_MAP_PATH, 'r') as f:
61
  label_to_idx = json.load(f)
@@ -64,23 +63,18 @@ except FileNotFoundError:
64
  raise FileNotFoundError(f"Error: Could not find '{LABEL_MAP_PATH}'. Please make sure this file is uploaded to the Space.")
65
 
66
  NUM_CLASSES = len(idx_to_label)
67
- D_MODEL = 640 # Dimension for esm2_t30_150M_UR50D
68
 
69
- # Load Protein Language Model (PLM) and tokenizer
70
  print("Loading Protein Language Model...")
71
  tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME)
72
  plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE)
73
  plm_model.eval()
74
  print("PLM loaded successfully.")
75
 
76
- # Load your trained downstream classifier
77
  print("Loading downstream classifier...")
78
  classifier = ProtDualBranchEnhancedClassifier(
79
- d_model=D_MODEL,
80
- projection_dim=32,
81
- num_classes=NUM_CLASSES,
82
- dropout=0.3,
83
- kernel_size=3
84
  ).to(DEVICE)
85
 
86
  if not os.path.exists(CLASSIFIER_PATH):
@@ -92,24 +86,15 @@ print("Classifier loaded. Application is ready!")
92
 
93
  # --- 3. Prediction Function ---
94
  def predict(sequence_input):
95
- """
96
- Receives a protein sequence and returns a dictionary of class probabilities.
97
- """
98
  if not sequence_input or sequence_input.isspace():
99
  return {"Error": "Please enter a protein sequence."}
100
 
101
- # Clean the input, support FASTA format
102
- if sequence_input.startswith('>'):
103
- sequence = "".join(sequence_input.split('\n')[1:])
104
- else:
105
- sequence = sequence_input
106
-
107
  sequence = re.sub(r'[^A-Z]', '', sequence.upper())
108
 
109
  if not sequence:
110
  return {"Error": "Sequence is empty after cleaning. Please enter a valid amino acid sequence."}
111
 
112
- # Feature extraction with PLM
113
  with torch.no_grad():
114
  inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
115
  outputs = plm_model(**inputs)
@@ -117,41 +102,68 @@ def predict(sequence_input):
117
  hidden_states = outputs.last_hidden_state
118
  cls_embedding = hidden_states[:, 0, :]
119
  token_embeddings = hidden_states[:, 1:-1, :]
120
- token_mask = inputs['attention_mask'][:, 2:]
121
-
122
- # Prediction with the downstream classifier
123
  with torch.no_grad():
124
  logits = classifier(cls_embedding, token_embeddings, token_mask)
125
  probabilities = F.softmax(logits, dim=1)[0]
126
 
127
- # Format the output
128
  confidences = {idx_to_label[i]: float(prob) for i, prob in enumerate(probabilities)}
129
-
130
  return confidences
131
 
132
- # --- 4. Create Gradio Interface ---
133
- title = "Predicting the subcellular location of prokaryotic proteins with LocPred-Prok"
134
- description = """
135
- This is a prediction tool based on the **ESM-2 (150M)** Protein Language Model and a custom **`dual_branch_enhanced`** classifier.
136
- Simply paste a protein's amino acid sequence (FASTA format or raw sequence are both supported) into the text box below, and the model will predict its localization within the cell.
137
- """
138
- examples = [
139
- [">sp|P27361|PBP2_ECOLI Penicillin-binding protein 2 OS=Escherichia coli (strain K12) OX=83333 GN=mrdA PE=1 SV=2\nMKFKLTAGCLAVAGVLLASSFGADAEIVVNAIYDQVARTEDGVYTQGQLTGRRIELLNKLGIEPEDSLASTVIHEFVARVGDDHGIETIIDEFYRQHPSASL"],
140
- ["MSKLVKTLTISEISKAQNNGGKPAWCWYTLAMCGAGYDSGTCDYMYSHCFGIKHHSSGSSSYHC"],
141
- ]
142
-
143
- gr.Interface(
144
- fn=predict,
145
- inputs=gr.Textbox(
146
- lines=10,
147
- label="Protein Sequence",
148
- placeholder="Paste your amino acid sequence here..."
149
- ),
150
- outputs=gr.Label(num_top_classes=NUM_CLASSES, label="Prediction Results"),
151
- title=title,
152
- description=description,
153
- examples=examples,
154
- allow_flagging="never",
155
- theme=gr.themes.Soft()
156
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
 
55
  CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
56
  LABEL_MAP_PATH = "label_map.json"
57
 
 
58
  try:
59
  with open(LABEL_MAP_PATH, 'r') as f:
60
  label_to_idx = json.load(f)
 
63
  raise FileNotFoundError(f"Error: Could not find '{LABEL_MAP_PATH}'. Please make sure this file is uploaded to the Space.")
64
 
65
  NUM_CLASSES = len(idx_to_label)
66
+ D_MODEL = 640
67
 
 
68
  print("Loading Protein Language Model...")
69
  tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME)
70
  plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE)
71
  plm_model.eval()
72
  print("PLM loaded successfully.")
73
 
 
74
  print("Loading downstream classifier...")
75
  classifier = ProtDualBranchEnhancedClassifier(
76
+ d_model=D_MODEL, projection_dim=32, num_classes=NUM_CLASSES,
77
+ dropout=0.3, kernel_size=3
 
 
 
78
  ).to(DEVICE)
79
 
80
  if not os.path.exists(CLASSIFIER_PATH):
 
86
 
87
  # --- 3. Prediction Function ---
88
  def predict(sequence_input):
 
 
 
89
  if not sequence_input or sequence_input.isspace():
90
  return {"Error": "Please enter a protein sequence."}
91
 
92
+ sequence = "".join(sequence_input.split('\n')[1:]) if sequence_input.startswith('>') else sequence_input
 
 
 
 
 
93
  sequence = re.sub(r'[^A-Z]', '', sequence.upper())
94
 
95
  if not sequence:
96
  return {"Error": "Sequence is empty after cleaning. Please enter a valid amino acid sequence."}
97
 
 
98
  with torch.no_grad():
99
  inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
100
  outputs = plm_model(**inputs)
 
102
  hidden_states = outputs.last_hidden_state
103
  cls_embedding = hidden_states[:, 0, :]
104
  token_embeddings = hidden_states[:, 1:-1, :]
105
+ token_mask = inputs['attention_mask'][:, 1:-1]
106
+
 
107
  with torch.no_grad():
108
  logits = classifier(cls_embedding, token_embeddings, token_mask)
109
  probabilities = F.softmax(logits, dim=1)[0]
110
 
 
111
  confidences = {idx_to_label[i]: float(prob) for i, prob in enumerate(probabilities)}
 
112
  return confidences
113
 
114
+ # --- 4. Create Beautified Gradio Interface using Blocks ---
115
+ with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 800px; margin: auto;}") as app:
116
+ gr.Markdown(
117
+ """
118
+ # Protein Subcellular Localization Prediction
119
+ An online prediction tool based on the **ESM-2 (150M)** Protein Language Model and a custom **`dual_branch_enhanced`** classifier.
120
+
121
+ Just paste the amino acid sequence of a protein (FASTA format or raw sequence are supported), and the model will predict its location within the cell.
122
+ """
123
+ )
124
+
125
+ with gr.Row():
126
+ with gr.Column(scale=1):
127
+ sequence_input = gr.Textbox(
128
+ lines=10,
129
+ label="Protein Sequence",
130
+ placeholder="Paste your amino acid sequence here..."
131
+ )
132
+
133
+ with gr.Row():
134
+ clear_btn = gr.ClearButton()
135
+ submit_btn = gr.Button("🚀 Predict", variant="primary")
136
+
137
+ gr.Examples(
138
+ examples=[
139
+ [">sp|P27361|PBP2_ECOLI Penicillin-binding protein 2 OS=Escherichia coli (strain K12) OX=83333 GN=mrdA PE=1 SV=2\nMKFKLTAGCLAVAGVLLASSFGADAEIVVNAIYDQVARTEDGVYTQGQLTGRRIELLNKLGIEPEDSLASTVIHEFVARVGDDHGIETIIDEFYRQHPSASL"],
140
+ ["MSKLVKTLTISEISKAQNNGGKPAWCWYTLAMCGAGYDSGTCDYMYSHCFGIKHHSSGSSSYHC"],
141
+ ],
142
+ inputs=sequence_input,
143
+ label="Examples"
144
+ )
145
+
146
+ with gr.Column(scale=1):
147
+ output_label = gr.Label(num_top_classes=NUM_CLASSES, label="Prediction Results")
148
+
149
+ with gr.Accordion("Model Information", open=False):
150
+ gr.Markdown(
151
+ """
152
+ * **Protein Language Model (PLM)**: `facebook/esm2_t30_150M_UR50D`
153
+ * **Downstream Classifier**: `ProtDualBranchEnhancedClassifier`
154
+ * **GitHub Repository**: github.com/isyslab-hust
155
+ """
156
+ )
157
+
158
+ gr.Markdown(
159
+ """
160
+ ---
161
+ *Built by isyslab*
162
+ """
163
+ )
164
+
165
+ submit_btn.click(fn=predict, inputs=sequence_input, outputs=output_label, api_name="predict")
166
+ clear_btn.click(lambda: [None, None], outputs=[sequence_input, output_label])
167
+
168
+ app.launch()
169