wangleiofficial commited on
Commit
4c7c3fa
·
verified ·
1 Parent(s): 662dd77

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -0
app.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from transformers import AutoTokenizer, AutoModel
6
+ import json
7
+ import os
8
+ import re
9
+
10
+ # --- 1. Model Definition (Must be identical to the one used during training) ---
11
+ class AttentionPooling(nn.Module):
12
+ """Attention Pooling Layer"""
13
+ def __init__(self, d_model):
14
+ super().__init__()
15
+ self.attention_net = nn.Linear(d_model, 1)
16
+
17
+ def forward(self, x, mask):
18
+ attn_logits = self.attention_net(x).squeeze(2)
19
+ attn_logits.masked_fill_(mask == 0, -float('inf'))
20
+ attn_weights = F.softmax(attn_logits, dim=1)
21
+ return torch.bmm(attn_weights.unsqueeze(1), x).squeeze(1)
22
+
23
+ class ProtDualBranchEnhancedClassifier(nn.Module):
24
+ """Enhanced dual-branch model"""
25
+ def __init__(self, d_model, projection_dim, num_classes, dropout, kernel_size):
26
+ super().__init__()
27
+ self.cls_projector = nn.Linear(d_model, projection_dim)
28
+ self.token_refiner = nn.Sequential(
29
+ nn.Conv1d(d_model, d_model, kernel_size, padding='same'),
30
+ nn.ReLU())
31
+ self.attention_pooling = AttentionPooling(d_model)
32
+ self.tok_projector = nn.Linear(d_model, projection_dim)
33
+ fused_dim = projection_dim * 2
34
+ self.gate = nn.Sequential(nn.Linear(fused_dim, fused_dim), nn.Sigmoid())
35
+ self.classifier_head = nn.Sequential(
36
+ nn.LayerNorm(fused_dim),
37
+ nn.Linear(fused_dim, fused_dim * 2),
38
+ nn.ReLU(),
39
+ nn.Dropout(dropout),
40
+ nn.Linear(fused_dim * 2, num_classes))
41
+ def forward(self, cls_embedding, token_embeddings, mask):
42
+ z_cls = self.cls_projector(cls_embedding)
43
+ tok_emb_permuted = token_embeddings.permute(0, 2, 1)
44
+ refined_tok_emb = self.token_refiner(tok_emb_permuted).permute(0, 2, 1)
45
+ z_tok_pooled = self.attention_pooling(refined_tok_emb, mask)
46
+ z_tok = self.tok_projector(z_tok_pooled)
47
+ z_fused_concat = torch.cat([z_cls, z_tok], dim=1)
48
+ gate_values = self.gate(z_fused_concat)
49
+ z_fused_gated = z_fused_concat * gate_values
50
+ return self.classifier_head(z_fused_gated)
51
+
52
+ # --- 2. Load Models and Auxiliary Files ---
53
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
+ PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
55
+ CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
56
+ LABEL_MAP_PATH = "label_map.json"
57
+
58
+ # Load the label map file
59
+ try:
60
+ with open(LABEL_MAP_PATH, 'r') as f:
61
+ label_to_idx = json.load(f)
62
+ idx_to_label = {v: k for k, v in label_to_idx.items()}
63
+ except FileNotFoundError:
64
+ raise FileNotFoundError(f"Error: Could not find '{LABEL_MAP_PATH}'. Please make sure this file is uploaded to the Space.")
65
+
66
+ NUM_CLASSES = len(idx_to_label)
67
+ D_MODEL = 640 # Dimension for esm2_t30_150M_UR50D
68
+
69
+ # Load Protein Language Model (PLM) and tokenizer
70
+ print("Loading Protein Language Model...")
71
+ tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME)
72
+ plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE)
73
+ plm_model.eval()
74
+ print("PLM loaded successfully.")
75
+
76
+ # Load your trained downstream classifier
77
+ print("Loading downstream classifier...")
78
+ classifier = ProtDualBranchEnhancedClassifier(
79
+ d_model=D_MODEL,
80
+ projection_dim=32,
81
+ num_classes=NUM_CLASSES,
82
+ dropout=0.3,
83
+ kernel_size=3
84
+ ).to(DEVICE)
85
+
86
+ if not os.path.exists(CLASSIFIER_PATH):
87
+ raise FileNotFoundError(f"Error: Could not find the trained model file '{CLASSIFIER_PATH}'. Please make sure the correct .pth file is uploaded.")
88
+
89
+ classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
90
+ classifier.eval()
91
+ print("Classifier loaded. Application is ready!")
92
+
93
+ # --- 3. Prediction Function ---
94
+ def predict(sequence_input):
95
+ """
96
+ Receives a protein sequence and returns a dictionary of class probabilities.
97
+ """
98
+ if not sequence_input or sequence_input.isspace():
99
+ return {"Error": "Please enter a protein sequence."}
100
+
101
+ # Clean the input, support FASTA format
102
+ if sequence_input.startswith('>'):
103
+ sequence = "".join(sequence_input.split('\n')[1:])
104
+ else:
105
+ sequence = sequence_input
106
+
107
+ sequence = re.sub(r'[^A-Z]', '', sequence.upper())
108
+
109
+ if not sequence:
110
+ return {"Error": "Sequence is empty after cleaning. Please enter a valid amino acid sequence."}
111
+
112
+ # Feature extraction with PLM
113
+ with torch.no_grad():
114
+ inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
115
+ outputs = plm_model(**inputs)
116
+
117
+ hidden_states = outputs.last_hidden_state
118
+ cls_embedding = hidden_states[:, 0, :]
119
+ token_embeddings = hidden_states[:, 1:-1, :]
120
+ token_mask = inputs['attention_mask'][:, 2:]
121
+
122
+ # Prediction with the downstream classifier
123
+ with torch.no_grad():
124
+ logits = classifier(cls_embedding, token_embeddings, token_mask)
125
+ probabilities = F.softmax(logits, dim=1)[0]
126
+
127
+ # Format the output
128
+ confidences = {idx_to_label[i]: float(prob) for i, prob in enumerate(probabilities)}
129
+
130
+ return confidences
131
+
132
+ # --- 4. Create Gradio Interface ---
133
+ title = "Predicting the subcellular location of prokaryotic proteins with LocPred-Prok"
134
+ description = """
135
+ This is a prediction tool based on the **ESM-2 (150M)** Protein Language Model and a custom **`dual_branch_enhanced`** classifier.
136
+ Simply paste a protein's amino acid sequence (FASTA format or raw sequence are both supported) into the text box below, and the model will predict its localization within the cell.
137
+ """
138
+ examples = [
139
+ [">sp|P27361|PBP2_ECOLI Penicillin-binding protein 2 OS=Escherichia coli (strain K12) OX=83333 GN=mrdA PE=1 SV=2\nMKFKLTAGCLAVAGVLLASSFGADAEIVVNAIYDQVARTEDGVYTQGQLTGRRIELLNKLGIEPEDSLASTVIHEFVARVGDDHGIETIIDEFYRQHPSASL"],
140
+ ["MSKLVKTLTISEISKAQNNGGKPAWCWYTLAMCGAGYDSGTCDYMYSHCFGIKHHSSGSSSYHC"],
141
+ ]
142
+
143
+ gr.Interface(
144
+ fn=predict,
145
+ inputs=gr.Textbox(
146
+ lines=10,
147
+ label="Protein Sequence",
148
+ placeholder="Paste your amino acid sequence here..."
149
+ ),
150
+ outputs=gr.Label(num_top_classes=NUM_CLASSES, label="Prediction Results"),
151
+ title=title,
152
+ description=description,
153
+ examples=examples,
154
+ allow_flagging="never",
155
+ theme=gr.themes.Soft()
156
+ ).launch()
157
+