nl45 commited on
Commit
dcddf39
Β·
verified Β·
1 Parent(s): 2630dec

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +555 -70
app.py CHANGED
@@ -1,70 +1,555 @@
1
- import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
-
5
- def respond(
6
- message,
7
- history: list[dict[str, str]],
8
- system_message,
9
- max_tokens,
10
- temperature,
11
- top_p,
12
- hf_token: gr.OAuthToken,
13
- ):
14
- """
15
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
16
- """
17
- client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
18
-
19
- messages = [{"role": "system", "content": system_message}]
20
-
21
- messages.extend(history)
22
-
23
- messages.append({"role": "user", "content": message})
24
-
25
- response = ""
26
-
27
- for message in client.chat_completion(
28
- messages,
29
- max_tokens=max_tokens,
30
- stream=True,
31
- temperature=temperature,
32
- top_p=top_p,
33
- ):
34
- choices = message.choices
35
- token = ""
36
- if len(choices) and choices[0].delta.content:
37
- token = choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
-
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- chatbot = gr.ChatInterface(
47
- respond,
48
- type="messages",
49
- additional_inputs=[
50
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
51
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
52
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
53
- gr.Slider(
54
- minimum=0.1,
55
- maximum=1.0,
56
- value=0.95,
57
- step=0.05,
58
- label="Top-p (nucleus sampling)",
59
- ),
60
- ],
61
- )
62
-
63
- with gr.Blocks() as demo:
64
- with gr.Sidebar():
65
- gr.LoginButton()
66
- chatbot.render()
67
-
68
-
69
- if __name__ == "__main__":
70
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Professional Protein Sequence Analyzer - With Live Sequence Input
3
+ """
4
+
5
+ import streamlit as st
6
+ import torch
7
+ import torch.nn as nn
8
+ import numpy as np
9
+ import pandas as pd
10
+ import pickle
11
+ import plotly.graph_objects as go
12
+ from collections import Counter
13
+ import re
14
+ import os
15
+ import sys
16
+ sys.path.append("D:/CAFA project")
17
+ sys.path.append("D:/CAFA project/scripts")
18
+ sys.path.append("D:/CAFA project/goontology")
19
+ from scripts.ontologyparser import GOGraphParser
20
+
21
+ # Page config MUST be first
22
+ st.set_page_config(
23
+ page_title="Protein Analyzer",
24
+ page_icon="🧬",
25
+ layout="wide"
26
+ )
27
+
28
+ # Custom CSS
29
+ st.markdown("""
30
+ <style>
31
+ .main-title {
32
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
33
+ padding: 2rem;
34
+ border-radius: 15px;
35
+ color: white;
36
+ text-align: center;
37
+ margin-bottom: 2rem;
38
+ }
39
+ .metric-card {
40
+ background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
41
+ padding: 1.5rem;
42
+ border-radius: 12px;
43
+ text-align: center;
44
+ }
45
+ .stButton>button {
46
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
47
+ color: white;
48
+ border-radius: 50px;
49
+ padding: 0.75rem 2rem;
50
+ font-weight: 600;
51
+ }
52
+ </style>
53
+ """, unsafe_allow_html=True)
54
+
55
+ # Model class
56
+ class MultiLabelClassifier(nn.Module):
57
+ def __init__(self, input_dim, output_dim):
58
+ super(MultiLabelClassifier, self).__init__()
59
+ self.network = nn.Sequential(
60
+ nn.Linear(input_dim, 512),
61
+ nn.BatchNorm1d(512),
62
+ nn.ReLU(),
63
+ nn.Dropout(0.3),
64
+ nn.Linear(512, 256),
65
+ nn.BatchNorm1d(256),
66
+ nn.ReLU(),
67
+ nn.Dropout(0.3),
68
+ nn.Linear(256, output_dim)
69
+ )
70
+
71
+ def forward(self, x):
72
+ return self.network(x)
73
+
74
+ @st.cache_resource
75
+ def load_prediction_models():
76
+ """Load prediction models only"""
77
+ try:
78
+ base_path = "D:/CAFA project"
79
+
80
+ with open(f"{base_path}/processed_data/selected_terms.pkl", 'rb') as f:
81
+ term_mappings = pickle.load(f)
82
+
83
+ with open(f"{base_path}/go_parser.pkl", 'rb') as f:
84
+ go_parser = pickle.load(f)
85
+
86
+ device = torch.device('cpu')
87
+ models = {}
88
+
89
+ for ontology in ['MFO', 'BPO', 'CCO']:
90
+ n_terms = len(term_mappings['selected_terms'][ontology])
91
+ model = MultiLabelClassifier(1280, n_terms)
92
+
93
+ checkpoint = torch.load(
94
+ f"{base_path}/models/model_{ontology}_best.pth",
95
+ map_location=device
96
+ )
97
+ model.load_state_dict(checkpoint['model_state_dict'])
98
+ model.eval()
99
+ models[ontology] = model
100
+
101
+ return models, term_mappings, go_parser, device, None
102
+
103
+ except Exception as e:
104
+ return None, None, None, None, str(e)
105
+
106
+ @st.cache_resource
107
+ def load_esm2_model():
108
+ """Load ESM2 model for embedding generation"""
109
+ try:
110
+ from transformers import AutoTokenizer, AutoModel
111
+
112
+ st.info("πŸ”„ Loading ESM2 model (this takes 2-3 minutes first time)...")
113
+
114
+ tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
115
+ model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
116
+ model.eval()
117
+
118
+ st.success("βœ… ESM2 model loaded!")
119
+ return tokenizer, model, None
120
+ except Exception as e:
121
+ return None, None, str(e)
122
+
123
+ @st.cache_resource
124
+ def load_test_embeddings():
125
+ """Load pre-computed test embeddings"""
126
+ try:
127
+ base_path = "D:/CAFA project"
128
+ with open(f"{base_path}/scripts/embeddings/test_esm2_embeddings.pkl", 'rb') as f:
129
+ embeddings = pickle.load(f)
130
+
131
+ def normalize_pid(pid):
132
+ if '|' in pid:
133
+ return pid.split('|')[1]
134
+ return pid
135
+
136
+ embeddings = {normalize_pid(k): v for k, v in embeddings.items()}
137
+ return embeddings, None
138
+ except Exception as e:
139
+ return None, str(e)
140
+
141
+ def convert_three_to_one(sequence):
142
+ """Convert 3-letter to 1-letter amino acid code"""
143
+ three_to_one = {
144
+ 'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C',
145
+ 'GLN': 'Q', 'GLU': 'E', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
146
+ 'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P',
147
+ 'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V'
148
+ }
149
+
150
+ # Check if sequence contains 3-letter codes
151
+ if '-' in sequence or len(sequence) > 50 and sequence[3:4] in ['-', ' ']:
152
+ # Split by dash or space
153
+ codes = re.split(r'[-\s]+', sequence.upper())
154
+ converted = ''.join(three_to_one.get(code, '') for code in codes if code)
155
+ return converted
156
+
157
+ return sequence
158
+
159
+ def generate_embedding_from_sequence(sequence, tokenizer, esm2_model, device):
160
+ """Generate embedding from raw sequence"""
161
+ # Try to convert 3-letter to 1-letter code
162
+ sequence = convert_three_to_one(sequence)
163
+
164
+ # Clean sequence
165
+ sequence = re.sub(r'[^ACDEFGHIKLMNPQRSTVWY]', '', sequence.upper())
166
+
167
+ if len(sequence) < 20:
168
+ return None, "Sequence too short (minimum 20 amino acids)"
169
+
170
+ if len(sequence) > 1024:
171
+ sequence = sequence[:1024]
172
+ st.warning("⚠️ Sequence truncated to 1024 amino acids")
173
+
174
+ try:
175
+ # Tokenize
176
+ inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024)
177
+ inputs = {k: v.to(device) for k, v in inputs.items()}
178
+
179
+ # Generate embedding
180
+ with torch.no_grad():
181
+ outputs = esm2_model(**inputs)
182
+ embeddings = outputs.last_hidden_state
183
+ # Mean pooling (exclude special tokens)
184
+ embedding = embeddings[0, 1:-1, :].mean(dim=0)
185
+
186
+ return embedding.cpu().numpy(), None
187
+ except Exception as e:
188
+ return None, str(e)
189
+
190
+ def calculate_properties(sequence):
191
+ """Calculate basic molecular properties"""
192
+ aa_weights = {
193
+ 'A': 89, 'R': 174, 'N': 132, 'D': 133, 'C': 121,
194
+ 'E': 147, 'Q': 146, 'G': 75, 'H': 155, 'I': 131,
195
+ 'L': 131, 'K': 146, 'M': 149, 'F': 165, 'P': 115,
196
+ 'S': 105, 'T': 119, 'W': 204, 'Y': 181, 'V': 117
197
+ }
198
+
199
+ length = len(sequence)
200
+ mw = sum(aa_weights.get(aa, 110) for aa in sequence) / 1000
201
+ composition = Counter(sequence)
202
+
203
+ hydrophobic = sum(composition.get(aa, 0) for aa in 'AILMFWYV') / length * 100
204
+ polar = sum(composition.get(aa, 0) for aa in 'STNQ') / length * 100
205
+ charged = sum(composition.get(aa, 0) for aa in 'DEKR') / length * 100
206
+
207
+ return {
208
+ 'length': length,
209
+ 'molecular_weight': round(mw, 1),
210
+ 'hydrophobic': round(hydrophobic, 1),
211
+ 'polar': round(polar, 1),
212
+ 'charged': round(charged, 1),
213
+ 'composition': composition
214
+ }
215
+
216
+ def predict_from_embedding(embedding, models, term_mappings, go_parser, device):
217
+ """Make predictions from embedding"""
218
+ embedding_tensor = torch.FloatTensor(embedding).unsqueeze(0).to(device)
219
+ predictions = {}
220
+
221
+ with torch.no_grad():
222
+ for ontology in ['MFO', 'BPO', 'CCO']:
223
+ model = models[ontology]
224
+ outputs = model(embedding_tensor)
225
+ probs = torch.sigmoid(outputs).cpu().numpy()[0]
226
+
227
+ terms = term_mappings['selected_terms'][ontology]
228
+ idx_to_term = term_mappings['idx_to_term'][ontology]
229
+
230
+ pred_list = []
231
+ for idx in range(len(probs)):
232
+ if probs[idx] > 0.05:
233
+ term_id = terms[idx]
234
+ try:
235
+ term_info = go_parser.get_term_info(term_id)
236
+ name = term_info['name'] if term_info else 'Unknown'
237
+ except:
238
+ name = term_id
239
+
240
+ pred_list.append({
241
+ 'term_id': term_id,
242
+ 'confidence': float(probs[idx]),
243
+ 'name': name
244
+ })
245
+
246
+ pred_list.sort(key=lambda x: x['confidence'], reverse=True)
247
+ predictions[ontology] = pred_list
248
+
249
+ return predictions
250
+
251
+ def create_chart(predictions, ontology, top_n=10):
252
+ """Create visualization"""
253
+ data = predictions[ontology][:top_n]
254
+
255
+ if not data:
256
+ return None
257
+
258
+ names = [p['name'][:50] for p in data]
259
+ confidences = [p['confidence'] * 100 for p in data]
260
+ colors = ['#11998e' if c > 70 else '#f5576c' if c > 40 else '#4facfe' for c in confidences]
261
+
262
+ fig = go.Figure(go.Bar(
263
+ y=names,
264
+ x=confidences,
265
+ orientation='h',
266
+ marker=dict(color=colors),
267
+ text=[f'{c:.1f}%' for c in confidences],
268
+ textposition='outside'
269
+ ))
270
+
271
+ fig.update_layout(
272
+ title=f'Top {len(data)} {ontology} Predictions',
273
+ xaxis_title='Confidence (%)',
274
+ height=max(400, len(data) * 40),
275
+ yaxis=dict(autorange="reversed"),
276
+ xaxis=dict(range=[0, 100])
277
+ )
278
+
279
+ return fig
280
+
281
+ def display_results(predictions, sequence=None):
282
+ """Display prediction results"""
283
+ st.success("βœ… Analysis Complete!")
284
+
285
+ # Show sequence properties if provided
286
+ if sequence:
287
+ st.markdown("### πŸ”¬ Sequence Properties")
288
+ props = calculate_properties(sequence)
289
+
290
+ col1, col2, col3, col4 = st.columns(4)
291
+ with col1:
292
+ st.markdown(f"""
293
+ <div class="metric-card">
294
+ <h3>{props['length']}</h3>
295
+ <p>Length (aa)</p>
296
+ </div>
297
+ """, unsafe_allow_html=True)
298
+ with col2:
299
+ st.markdown(f"""
300
+ <div class="metric-card">
301
+ <h3>{props['molecular_weight']}</h3>
302
+ <p>MW (kDa)</p>
303
+ </div>
304
+ """, unsafe_allow_html=True)
305
+ with col3:
306
+ st.markdown(f"""
307
+ <div class="metric-card">
308
+ <h3>{props['hydrophobic']}</h3>
309
+ <p>Hydrophobic %</p>
310
+ </div>
311
+ """, unsafe_allow_html=True)
312
+ with col4:
313
+ st.markdown(f"""
314
+ <div class="metric-card">
315
+ <h3>{props['charged']}</h3>
316
+ <p>Charged %</p>
317
+ </div>
318
+ """, unsafe_allow_html=True)
319
+
320
+ # Prediction summary
321
+ st.markdown("### πŸ“Š Prediction Summary")
322
+ col1, col2, col3 = st.columns(3)
323
+
324
+ with col1:
325
+ count = len([p for p in predictions['MFO'] if p['confidence'] > 0.5])
326
+ st.markdown(f"""
327
+ <div class="metric-card">
328
+ <h3>{count}</h3>
329
+ <p>MFO Predictions (>50%)</p>
330
+ </div>
331
+ """, unsafe_allow_html=True)
332
+
333
+ with col2:
334
+ count = len([p for p in predictions['BPO'] if p['confidence'] > 0.5])
335
+ st.markdown(f"""
336
+ <div class="metric-card">
337
+ <h3>{count}</h3>
338
+ <p>BPO Predictions (>50%)</p>
339
+ </div>
340
+ """, unsafe_allow_html=True)
341
+
342
+ with col3:
343
+ count = len([p for p in predictions['CCO'] if p['confidence'] > 0.5])
344
+ st.markdown(f"""
345
+ <div class="metric-card">
346
+ <h3>{count}</h3>
347
+ <p>CCO Predictions (>50%)</p>
348
+ </div>
349
+ """, unsafe_allow_html=True)
350
+
351
+ # Detailed predictions in tabs
352
+ tabs = st.tabs(["πŸ”΅ Molecular Function", "🟒 Biological Process", "🟠 Cellular Component"])
353
+
354
+ for tab, ont in zip(tabs, ['MFO', 'BPO', 'CCO']):
355
+ with tab:
356
+ preds = predictions[ont][:10]
357
+
358
+ if preds:
359
+ fig = create_chart(predictions, ont)
360
+ if fig:
361
+ st.plotly_chart(fig, use_container_width=True)
362
+
363
+ st.markdown("#### Top Predictions")
364
+ for i, pred in enumerate(preds, 1):
365
+ conf = pred['confidence'] * 100
366
+
367
+ if conf > 70:
368
+ color = "#11998e"
369
+ level = "HIGH"
370
+ elif conf > 40:
371
+ color = "#f5576c"
372
+ level = "MEDIUM"
373
+ else:
374
+ color = "#4facfe"
375
+ level = "LOW"
376
+
377
+ st.markdown(f"""
378
+ <div style="background: {color}; color: white; padding: 1rem; border-radius: 10px; margin: 0.5rem 0;">
379
+ <div style="display: flex; justify-content: space-between;">
380
+ <div>
381
+ <strong>{i}. {pred['name']}</strong><br>
382
+ <small>{pred['term_id']}</small>
383
+ </div>
384
+ <div style="text-align: right;">
385
+ <div style="font-size: 1.5rem; font-weight: bold;">{conf:.1f}%</div>
386
+ <small>{level}</small>
387
+ </div>
388
+ </div>
389
+ </div>
390
+ """, unsafe_allow_html=True)
391
+ else:
392
+ st.info(f"No significant {ont} predictions")
393
+
394
+ # Export
395
+ st.markdown("### πŸ’Ύ Export Results")
396
+ all_preds = []
397
+ for ont in ['MFO', 'BPO', 'CCO']:
398
+ for pred in predictions[ont]:
399
+ all_preds.append({
400
+ 'Ontology': ont,
401
+ 'GO Term': pred['term_id'],
402
+ 'Function': pred['name'],
403
+ 'Confidence': f"{pred['confidence']*100:.2f}%"
404
+ })
405
+
406
+ df = pd.DataFrame(all_preds)
407
+ csv = df.to_csv(index=False)
408
+
409
+ st.download_button(
410
+ "πŸ“₯ Download Predictions CSV",
411
+ csv,
412
+ "protein_predictions.csv",
413
+ "text/csv",
414
+ use_container_width=True
415
+ )
416
+
417
+ # MAIN APP
418
+ def main():
419
+ st.markdown("""
420
+ <div class="main-title">
421
+ <h1>🧬 Protein Sequence Analyzer</h1>
422
+ <p>AI-Powered Function Prediction</p>
423
+ </div>
424
+ """, unsafe_allow_html=True)
425
+
426
+ # Sidebar
427
+ st.sidebar.header("βš™οΈ System Status")
428
+
429
+ # Load prediction models
430
+ with st.sidebar:
431
+ with st.spinner("Loading prediction models..."):
432
+ models, term_mappings, go_parser, device, error = load_prediction_models()
433
+
434
+ if error:
435
+ st.error(f"❌ Failed: {error}")
436
+ st.stop()
437
+ else:
438
+ st.success("βœ… Prediction models ready")
439
+
440
+ # Main interface
441
+ st.markdown("### πŸ” Choose Analysis Mode")
442
+
443
+ mode = st.radio(
444
+ "Select input method:",
445
+ ["🧬 Enter Custom Sequence", "πŸ“‹ Use Test Protein"],
446
+ horizontal=True
447
+ )
448
+
449
+ if mode == "🧬 Enter Custom Sequence":
450
+ st.markdown("### πŸ“ Enter Your Protein Sequence")
451
+
452
+ st.info("πŸ’‘ **Tip:** Paste amino acid sequence using single-letter codes (ACDEFGHIKLMNPQRSTVWY)")
453
+
454
+ # Example sequences
455
+ with st.expander("πŸ“Œ Click to see example sequences"):
456
+ st.markdown("**Single-letter format (preferred):**")
457
+ st.code("""
458
+ Example 1 - Small protein (100 aa):
459
+ MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEKAVQVKVKALPDAQFEVVHSLAK
460
+ WSPELAAACEVWKEIKFEFPAMDLVVKAAGAVGS
461
+
462
+ Example 2 - Kinase domain (250 aa):
463
+ MGSSHHHHHHSSGLVPRGSHMQDPPDFLKRTPAATPDLPMFPESAEELEKITAFAKKLGFPKAQKKDEADSLEKLKDV
464
+ TLVNDSLVKLGGKFTTAIQQRVAQALENALQDLWLVKYNPVSIKGLGKGSLQYLNEIKFKGKKFVYISVTKDPNLPA
465
+ LDNFYTKALLSKTGLKFTNKDKFKELYVLLKKFEVLTYQWLAKAEKQEFCDKLLDLKDYLSDKLQVYKDVFKKLETL
466
+ KHKKLDSALSDLEVQENKVFGGNNVVPKLDGLSGDFATSTAQFQKEVRQKIVSILTKNKKFVFGHDDLSKIFSGLHKV
467
+ """)
468
+
469
+ st.markdown("**Three-letter format (auto-converted):**")
470
+ st.code("""
471
+ Example: Gly-Ile-Val-Glu-Gln-Cys-Cys-Thr-Ser-Ile-Cys-Ser-Leu-Tyr-Gln-Leu-Glu-Asn
472
+ Will be converted to: GIVEQCCTSICSLYQLEN
473
+ """)
474
+
475
+ # Text area for sequence
476
+ sequence_input = st.text_area(
477
+ "Paste your sequence here:",
478
+ height=150,
479
+ placeholder="MKTAYIAKQRQISFVKSHFSRQLEERLGLIEV..."
480
+ )
481
+
482
+ analyze_button = st.button("πŸš€ Analyze Sequence", type="primary", use_container_width=True)
483
+
484
+ if analyze_button and sequence_input:
485
+ # Clean sequence
486
+ sequence = re.sub(r'[^ACDEFGHIKLMNPQRSTVWY]', '', sequence_input.upper())
487
+
488
+ if len(sequence) < 20:
489
+ st.error("❌ Sequence too short. Minimum 20 amino acids required.")
490
+ st.stop()
491
+
492
+ st.info(f"βœ“ Valid sequence: {len(sequence)} amino acids")
493
+
494
+ # Load ESM2 if not loaded
495
+ with st.spinner("Loading ESM2 model (first time: 2-3 minutes)..."):
496
+ tokenizer, esm2_model, esm2_error = load_esm2_model()
497
+
498
+ if esm2_error:
499
+ st.error(f"❌ ESM2 loading failed: {esm2_error}")
500
+ st.info("πŸ’‘ Install transformers: pip install transformers")
501
+ st.stop()
502
+
503
+ # Generate embedding
504
+ with st.spinner("🧬 Generating protein embedding..."):
505
+ embedding, emb_error = generate_embedding_from_sequence(
506
+ sequence, tokenizer, esm2_model, device
507
+ )
508
+
509
+ if emb_error:
510
+ st.error(f"❌ Embedding generation failed: {emb_error}")
511
+ st.stop()
512
+
513
+ # Make predictions
514
+ with st.spinner("πŸ€– Running AI predictions..."):
515
+ predictions = predict_from_embedding(
516
+ embedding, models, term_mappings, go_parser, device
517
+ )
518
+
519
+ # Display results
520
+ display_results(predictions, sequence)
521
+
522
+ else: # Use Test Protein
523
+ st.markdown("### πŸ“‹ Select Test Protein")
524
+
525
+ # Load test embeddings
526
+ test_embeddings, test_error = load_test_embeddings()
527
+
528
+ if test_error:
529
+ st.error(f"❌ Test embeddings not available: {test_error}")
530
+ st.stop()
531
+
532
+ available_proteins = list(test_embeddings.keys())[:50]
533
+
534
+ col1, col2 = st.columns([3, 1])
535
+
536
+ with col1:
537
+ selected_protein = st.selectbox(
538
+ "Choose a protein:",
539
+ available_proteins
540
+ )
541
+
542
+ with col2:
543
+ st.metric("Selected", selected_protein)
544
+
545
+ if st.button("πŸš€ Analyze Protein", type="primary", use_container_width=True):
546
+ with st.spinner("Analyzing..."):
547
+ embedding = test_embeddings[selected_protein]
548
+ predictions = predict_from_embedding(
549
+ embedding, models, term_mappings, go_parser, device
550
+ )
551
+
552
+ display_results(predictions)
553
+
554
+ if __name__ == "__main__":
555
+ main()