Spaces:
Build error
Build error
| """ | |
| Professional Protein Sequence Analyzer - With Live Sequence Input | |
| """ | |
| import streamlit as st | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import pandas as pd | |
| import pickle | |
| import plotly.graph_objects as go | |
| from collections import Counter | |
| import re | |
| import os | |
| import sys | |
| sys.path.append("D:/CAFA project") | |
| sys.path.append("D:/CAFA project/scripts") | |
| sys.path.append("D:/CAFA project/goontology") | |
| from scripts.ontologyparser import GOGraphParser | |
# Page config MUST be the first Streamlit command executed by the script.
st.set_page_config(
    page_title="Protein Analyzer",
    # A single real emoji: Streamlit validates string icons, and the
    # previous mis-decoded byte sequence was not a valid emoji.
    page_icon="🧬",
    layout="wide",
)
# App-wide stylesheet: gradient header (.main-title), metric tiles
# (.metric-card) and a restyle of Streamlit's own buttons (.stButton).
# Injected once via raw HTML, so it must be rendered with unsafe_allow_html.
_CUSTOM_CSS = """
<style>
.main-title {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 2rem;
border-radius: 15px;
color: white;
text-align: center;
margin-bottom: 2rem;
}
.metric-card {
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
padding: 1.5rem;
border-radius: 12px;
text-align: center;
}
.stButton>button {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border-radius: 50px;
padding: 0.75rem 2rem;
font-weight: 600;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Model class
class MultiLabelClassifier(nn.Module):
    """Feed-forward multi-label head mapping an embedding to per-term logits.

    The network is ``[Linear -> BatchNorm1d -> ReLU -> Dropout]`` once per
    hidden width, followed by a final ``Linear`` to ``output_dim`` logits.
    No sigmoid is applied here; callers apply it to get probabilities.

    With the default arguments the layers are built in exactly the same
    order as the original hard-coded definition (``network.0`` ...
    ``network.8``). This keeps ``state_dict`` keys compatible with
    checkpoints saved from that definition.

    Args:
        input_dim: size of the input feature vector (the app passes 1280).
        output_dim: number of output labels (GO terms).
        hidden_dims: widths of the hidden layers, in order.
        dropout: dropout probability applied after each hidden block.
    """

    def __init__(self, input_dim, output_dim, *, hidden_dims=(512, 256), dropout=0.3):
        super().__init__()
        layers = []
        in_features = input_dim
        for width in hidden_dims:
            layers.extend([
                nn.Linear(in_features, width),
                nn.BatchNorm1d(width),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            in_features = width
        layers.append(nn.Linear(in_features, output_dim))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw logits of shape (batch, output_dim) for a 2-D input batch."""
        return self.network(x)
@st.cache_resource(show_spinner=False)
def _load_prediction_artifacts(base_path):
    """Load GO term mappings, the GO parser and the three ontology models.

    Cached with st.cache_resource so Streamlit reruns (triggered by every
    widget interaction) reuse the loaded objects. Exceptions propagate and
    are not cached, so a failed load is retried on the next call.

    NOTE(review): pickle.load executes arbitrary code; only point base_path
    at trusted, locally produced artifacts. Unpickling the GO parser
    presumably needs GOGraphParser to be importable (it is imported at the
    top of the file) -- confirm.
    """
    with open(f"{base_path}/processed_data/selected_terms.pkl", 'rb') as f:
        term_mappings = pickle.load(f)
    with open(f"{base_path}/go_parser.pkl", 'rb') as f:
        go_parser = pickle.load(f)
    device = torch.device('cpu')
    models = {}
    for ontology in ('MFO', 'BPO', 'CCO'):
        # One output unit per selected GO term of this ontology.
        n_terms = len(term_mappings['selected_terms'][ontology])
        model = MultiLabelClassifier(1280, n_terms)
        checkpoint = torch.load(
            f"{base_path}/models/model_{ontology}_best.pth",
            map_location=device,
        )
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        models[ontology] = model
    return models, term_mappings, go_parser, device


def load_prediction_models():
    """Load the prediction models and their supporting data.

    The project root defaults to "D:/CAFA project". It can be overridden
    with the CAFA_BASE_PATH environment variable, e.g. for deployments that
    do not have that Windows path.

    Returns:
        (models, term_mappings, go_parser, device, error):
        - models: dict 'MFO'/'BPO'/'CCO' -> MultiLabelClassifier in eval mode.
        - term_mappings: the unpickled selected_terms.pkl object, indexed as
          term_mappings['selected_terms'][ontology].
        - go_parser: the unpickled go_parser.pkl object.
        - device: torch.device('cpu').
        - error: None on success. On failure the first four items are None
          and error is the exception message.
    """
    base_path = os.environ.get("CAFA_BASE_PATH", "D:/CAFA project")
    try:
        models, term_mappings, go_parser, device = _load_prediction_artifacts(base_path)
    except Exception as e:
        return None, None, None, None, str(e)
    return models, term_mappings, go_parser, device, None
ESM2_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"


@st.cache_resource(show_spinner=False)
def _load_esm2_cached(model_name):
    """Load the ESM2 tokenizer and encoder once per server process.

    Without caching, the multi-gigabyte model would be reloaded on every
    Analyze click. Exceptions are not cached, so a failed load is retried on
    the next call.
    """
    from transformers import AutoTokenizer, AutoModel

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.eval()
    return tokenizer, model


def load_esm2_model():
    """Load the ESM2 model used to embed user-supplied sequences.

    Returns:
        (tokenizer, model, error): error is None on success. On failure
        (including transformers not being installed) the first two items are
        None and error is the exception message.
    """
    try:
        st.info("π Loading ESM2 model (this takes 2-3 minutes first time)...")
        tokenizer, model = _load_esm2_cached(ESM2_MODEL_NAME)
        st.success("β ESM2 model loaded!")
        return tokenizer, model, None
    except Exception as e:
        return None, None, str(e)
def _normalize_protein_id(pid):
    """Return the second '|'-separated field of pid (e.g. 'sp|P12345|X' -> 'P12345'), else pid."""
    if '|' in pid:
        return pid.split('|')[1]
    return pid


@st.cache_resource(show_spinner=False)
def _load_test_embeddings_cached(path):
    """Unpickle the embeddings file once and normalize its keys.

    Cached so reruns (e.g. changing the selectbox) don't re-read a large
    pickle. Exceptions are not cached. The returned dict is shared across
    reruns and must be treated as read-only.

    NOTE(review): pickle.load executes arbitrary code; only load trusted files.
    """
    with open(path, 'rb') as f:
        embeddings = pickle.load(f)
    return {_normalize_protein_id(k): v for k, v in embeddings.items()}


def load_test_embeddings():
    """Load the pre-computed test-set embeddings.

    The project root defaults to "D:/CAFA project". It can be overridden
    with the CAFA_BASE_PATH environment variable.

    Returns:
        (embeddings, error): embeddings maps normalized protein IDs to the
        stored embedding values, and error is None. On failure the result is
        (None, message).
    """
    base_path = os.environ.get("CAFA_BASE_PATH", "D:/CAFA project")
    path = f"{base_path}/scripts/embeddings/test_esm2_embeddings.pkl"
    try:
        return _load_test_embeddings_cached(path), None
    except Exception as e:
        return None, str(e)
# Three-letter -> one-letter codes for the 20 canonical amino acids.
_THREE_TO_ONE = {
    'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C',
    'GLN': 'Q', 'GLU': 'E', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
    'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P',
    'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V',
}


def convert_three_to_one(sequence):
    """Convert a three-letter-code sequence (e.g. "Gly-Ile-Val") to one-letter codes.

    The input is split on hyphens and/or whitespace (case-insensitive). It is
    treated as three-letter notation only when all of these hold:

    - it yields at least two tokens;
    - every token is exactly three characters long;
    - more than half of the tokens are recognised codes.

    Unrecognised three-character tokens are then dropped. Anything else is
    returned unchanged, for the caller's character filter to handle. That
    includes plain one-letter sequences and one-letter sequences with
    alignment gaps such as "MKTAY-IAKQ".

    Args:
        sequence: raw user text.

    Returns:
        The one-letter sequence when conversion applies, otherwise
        ``sequence`` unchanged.
    """
    tokens = [tok for tok in re.split(r'[-\s]+', sequence.upper()) if tok]
    if len(tokens) < 2 or any(len(tok) != 3 for tok in tokens):
        return sequence
    known = [tok for tok in tokens if tok in _THREE_TO_ONE]
    # Require a majority of real codes so gapped one-letter text in chunks
    # of three (e.g. "MKT-AYI") is not mistaken for three-letter notation.
    if len(known) * 2 <= len(tokens):
        return sequence
    return ''.join(_THREE_TO_ONE[tok] for tok in known)
def generate_embedding_from_sequence(sequence, tokenizer, esm2_model, device, max_residues=1024):
    """Embed a protein sequence by mean-pooling the encoder's last hidden state.

    The sequence is converted from three-letter codes when applicable and
    reduced to the 20 canonical one-letter codes. It is rejected if shorter
    than 20 residues and truncated (with a UI warning) to ``max_residues``.

    Args:
        sequence: raw or cleaned amino-acid sequence text.
        tokenizer: tokenizer returned by load_esm2_model.
        esm2_model: encoder returned by load_esm2_model.
        device: torch device the tokenized inputs are moved to.
        max_residues: maximum number of residues fed to the model.

    Returns:
        (embedding, None) where embedding is a 1-D numpy array, or
        (None, error_message) on failure.
    """
    sequence = convert_three_to_one(sequence)
    sequence = re.sub(r'[^ACDEFGHIKLMNPQRSTVWY]', '', sequence.upper())
    if len(sequence) < 20:
        return None, "Sequence too short (minimum 20 amino acids)"
    if len(sequence) > max_residues:
        sequence = sequence[:max_residues]
        st.warning(f"β οΈ Sequence truncated to {max_residues} amino acids")
    try:
        # The tokenizer wraps the residues in a start and an end special
        # token, and its max_length counts them. Budget +2, otherwise the
        # tokenizer silently drops the last two residues.
        # NOTE(review): ESM2 checkpoints are assumed to accept 1026 positions
        # (1024 residues + 2 specials) -- confirm if max_residues is raised.
        inputs = tokenizer(
            sequence,
            return_tensors="pt",
            truncation=True,
            max_length=max_residues + 2,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            hidden = esm2_model(**inputs).last_hidden_state
        # Mean-pool over residue positions only, excluding the first and last
        # (special) tokens of the single sequence in the batch.
        embedding = hidden[0, 1:-1, :].mean(dim=0)
        return embedding.cpu().numpy(), None
    except Exception as e:
        return None, str(e)
# Approximate molecular weights (Da) of the free amino acids.
_AA_WEIGHTS = {
    'A': 89, 'R': 174, 'N': 132, 'D': 133, 'C': 121,
    'E': 147, 'Q': 146, 'G': 75, 'H': 155, 'I': 131,
    'L': 131, 'K': 146, 'M': 149, 'F': 165, 'P': 115,
    'S': 105, 'T': 119, 'W': 204, 'Y': 181, 'V': 117,
}
# Mass of the water molecule released by each peptide bond (Da).
_WATER_MASS = 18.015


def calculate_properties(sequence):
    """Calculate basic molecular properties of a one-letter amino-acid sequence.

    Molecular weight is the sum of the free amino-acid weights minus one
    water per peptide bond (length - 1). Letters missing from the weight
    table count as 110 Da.

    Args:
        sequence: one-letter amino-acid sequence (expected uppercase).

    Returns:
        dict with 'length', 'molecular_weight' (kDa, 1 decimal), the
        'hydrophobic' (AILMFWYV), 'polar' (STNQ) and 'charged' (DEKR)
        percentages (1 decimal), and 'composition' (a Counter of residues).
        An empty sequence yields zeros instead of raising ZeroDivisionError.
    """
    length = len(sequence)
    composition = Counter(sequence)
    if length == 0:
        return {
            'length': 0,
            'molecular_weight': 0.0,
            'hydrophobic': 0.0,
            'polar': 0.0,
            'charged': 0.0,
            'composition': composition,
        }

    free_mass = sum(_AA_WEIGHTS.get(aa, 110) for aa in sequence)
    mw_kda = (free_mass - _WATER_MASS * (length - 1)) / 1000

    def percent(residues):
        """Share of the sequence made of the given residues, in percent."""
        return sum(composition.get(aa, 0) for aa in residues) / length * 100

    return {
        'length': length,
        'molecular_weight': round(mw_kda, 1),
        'hydrophobic': round(percent('AILMFWYV'), 1),
        'polar': round(percent('STNQ'), 1),
        'charged': round(percent('DEKR'), 1),
        'composition': composition,
    }
def _go_term_name(go_parser, term_id):
    """Return the display name for term_id: its parser name, 'Unknown' if the parser has no info, or term_id on error."""
    try:
        term_info = go_parser.get_term_info(term_id)
        return term_info['name'] if term_info else 'Unknown'
    except Exception:
        return term_id


def predict_from_embedding(embedding, models, term_mappings, go_parser, device, threshold=0.05):
    """Predict GO terms for one embedding across the MFO, BPO and CCO ontologies.

    Args:
        embedding: array-like or tensor holding one embedding vector; it is
            flattened into a batch of one.
        models: dict 'MFO'/'BPO'/'CCO' -> callable returning logits for a
            (1, dim) float tensor.
        term_mappings: object whose ``['selected_terms'][ontology]`` lists the
            GO IDs in model-output order.
        go_parser: object exposing ``get_term_info(term_id)`` (a dict with
            'name', or a falsy value).
        device: torch device the input tensor is moved to.
        threshold: only terms whose sigmoid probability exceeds this are kept.

    Returns:
        dict ontology -> list of {'term_id', 'confidence', 'name'} dicts,
        sorted by confidence (0-1 float) descending.
    """
    # as_tensor handles numpy arrays (any float dtype), lists and tensors;
    # reshape guards against already-batched (1, dim) inputs.
    embedding_tensor = torch.as_tensor(embedding, dtype=torch.float32).reshape(1, -1).to(device)
    predictions = {}
    with torch.no_grad():
        for ontology in ('MFO', 'BPO', 'CCO'):
            logits = models[ontology](embedding_tensor)
            probs = torch.sigmoid(logits).cpu().numpy()[0]
            terms = term_mappings['selected_terms'][ontology]
            pred_list = [
                {
                    'term_id': terms[idx],
                    'confidence': float(prob),
                    'name': _go_term_name(go_parser, terms[idx]),
                }
                for idx, prob in enumerate(probs)
                if prob > threshold
            ]
            pred_list.sort(key=lambda p: p['confidence'], reverse=True)
            predictions[ontology] = pred_list
    return predictions
def create_chart(predictions, ontology, top_n=10):
    """Build a horizontal bar chart of the top predictions for one ontology.

    Args:
        predictions: dict ontology -> list of prediction dicts with
            'term_id', 'name' and 'confidence' (0-1). The list is charted in
            the order given (predict_from_embedding sorts it best-first).
        ontology: key into predictions, e.g. 'MFO'.
        top_n: maximum number of bars.

    Returns:
        A plotly Figure, or None when the ontology has no predictions.
    """
    data = predictions[ontology][:top_n]
    if not data:
        return None

    # Plotly merges bars that share a category label. Names can repeat
    # (e.g. several 'Unknown' terms, or equal 50-char prefixes), so the
    # GO ID is appended to keep every label unique.
    labels = [f"{str(p['name'])[:50]} ({p['term_id']})" for p in data]
    confidences = [p['confidence'] * 100 for p in data]
    colors = [
        '#11998e' if c > 70 else '#f5576c' if c > 40 else '#4facfe'
        for c in confidences
    ]

    fig = go.Figure(go.Bar(
        y=labels,
        x=confidences,
        orientation='h',
        marker=dict(color=colors),
        text=[f'{c:.1f}%' for c in confidences],
        textposition='outside',
    ))
    fig.update_layout(
        title=f'Top {len(data)} {ontology} Predictions',
        xaxis_title='Confidence (%)',
        height=max(400, len(data) * 40),
        # Highest-confidence bar at the top.
        yaxis=dict(autorange="reversed"),
        xaxis=dict(range=[0, 100]),
    )
    return fig
def _confidence_style(confidence_pct):
    """Map a confidence percentage to the (background colour, level label) of a result card."""
    if confidence_pct > 70:
        return "#11998e", "HIGH"
    if confidence_pct > 40:
        return "#f5576c", "MEDIUM"
    return "#4facfe", "LOW"


def _metric_card_html(value, label):
    """Return the HTML for one ``.metric-card`` tile.

    Emitted without leading indentation: Markdown treats lines indented by
    four or more spaces as a code block, which would show raw HTML.
    """
    return f'<div class="metric-card"><h3>{value}</h3><p>{label}</p></div>'


def _prediction_card_html(rank, pred):
    """Return the coloured HTML card for one prediction (rank is 1-based)."""
    import html

    conf = pred['confidence'] * 100
    color, level = _confidence_style(conf)
    # Term names/IDs come from external data files; escape them because the
    # card is rendered with unsafe_allow_html.
    name = html.escape(str(pred['name']))
    term_id = html.escape(str(pred['term_id']))
    return (
        f'<div style="background: {color}; color: white; padding: 1rem; border-radius: 10px; margin: 0.5rem 0;">'
        '<div style="display: flex; justify-content: space-between;">'
        f'<div><strong>{rank}. {name}</strong><br><small>{term_id}</small></div>'
        '<div style="text-align: right;">'
        f'<div style="font-size: 1.5rem; font-weight: bold;">{conf:.1f}%</div>'
        f'<small>{level}</small>'
        '</div>'
        '</div>'
        '</div>'
    )


def display_results(predictions, sequence=None):
    """Render prediction results: properties, summary, per-ontology tabs, CSV export.

    Args:
        predictions: dict 'MFO'/'BPO'/'CCO' -> list of prediction dicts with
            'term_id', 'name' and 'confidence' (0-1), best first.
        sequence: optional one-letter sequence. When given (non-empty), its
            basic properties are shown too.
    """
    ontology_tabs = (
        ('MFO', "π΅ Molecular Function"),
        ('BPO', "π’ Biological Process"),
        ('CCO', "π Cellular Component"),
    )

    st.success("β Analysis Complete!")

    if sequence:
        st.markdown("### π¬ Sequence Properties")
        props = calculate_properties(sequence)
        property_cards = [
            (props['length'], "Length (aa)"),
            (props['molecular_weight'], "MW (kDa)"),
            (props['hydrophobic'], "Hydrophobic %"),
            (props['charged'], "Charged %"),
        ]
        for column, (value, label) in zip(st.columns(4), property_cards):
            with column:
                st.markdown(_metric_card_html(value, label), unsafe_allow_html=True)

    st.markdown("### π Prediction Summary")
    for column, (ont, _) in zip(st.columns(3), ontology_tabs):
        count = sum(1 for p in predictions[ont] if p['confidence'] > 0.5)
        with column:
            st.markdown(
                _metric_card_html(count, f"{ont} Predictions (>50%)"),
                unsafe_allow_html=True,
            )

    tabs = st.tabs([title for _, title in ontology_tabs])
    for tab, (ont, _) in zip(tabs, ontology_tabs):
        with tab:
            preds = predictions[ont][:10]
            if not preds:
                st.info(f"No significant {ont} predictions")
                continue
            fig = create_chart(predictions, ont)
            if fig is not None:
                st.plotly_chart(fig, use_container_width=True)
            st.markdown("#### Top Predictions")
            for rank, pred in enumerate(preds, 1):
                st.markdown(_prediction_card_html(rank, pred), unsafe_allow_html=True)

    st.markdown("### πΎ Export Results")
    rows = [
        {
            'Ontology': ont,
            'GO Term': pred['term_id'],
            'Function': pred['name'],
            'Confidence': f"{pred['confidence']*100:.2f}%",
        }
        for ont, _ in ontology_tabs
        for pred in predictions[ont]
    ]
    # Explicit columns keep the CSV header even when there are no predictions.
    df = pd.DataFrame(rows, columns=['Ontology', 'GO Term', 'Function', 'Confidence'])
    st.download_button(
        "π₯ Download Predictions CSV",
        df.to_csv(index=False),
        "protein_predictions.csv",
        "text/csv",
        use_container_width=True,
    )
# MAIN APP
_MODE_CUSTOM = "𧬠Enter Custom Sequence"
_MODE_TEST = "π Use Test Protein"
_MIN_SEQUENCE_LENGTH = 20
# st.session_state keys holding the last analysis of each mode, so results
# survive reruns (e.g. clicking the CSV download button).
_CUSTOM_RESULT_KEY = "custom_sequence_result"
_TEST_RESULT_KEY = "test_protein_result"

_HEADER_HTML = """
<div class="main-title">
<h1>𧬠Protein Sequence Analyzer</h1>
<p>AI-Powered Function Prediction</p>
</div>
"""

_EXAMPLE_SINGLE_LETTER = """
Example 1 - Small protein (100 aa):
MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEKAVQVKVKALPDAQFEVVHSLAK
WSPELAAACEVWKEIKFEFPAMDLVVKAAGAVGS
Example 2 - Kinase domain (250 aa):
MGSSHHHHHHSSGLVPRGSHMQDPPDFLKRTPAATPDLPMFPESAEELEKITAFAKKLGFPKAQKKDEADSLEKLKDV
TLVNDSLVKLGGKFTTAIQQRVAQALENALQDLWLVKYNPVSIKGLGKGSLQYLNEIKFKGKKFVYISVTKDPNLPA
LDNFYTKALLSKTGLKFTNKDKFKELYVLLKKFEVLTYQWLAKAEKQEFCDKLLDLKDYLSDKLQVYKDVFKKLETL
KHKKLDSALSDLEVQENKVFGGNNVVPKLDGLSGDFATSTAQFQKEVRQKIVSILTKNKKFVFGHDDLSKIFSGLHKV
"""

_EXAMPLE_THREE_LETTER = """
Example: Gly-Ile-Val-Glu-Gln-Cys-Cys-Thr-Ser-Ile-Cys-Ser-Leu-Tyr-Gln-Leu-Glu-Asn
Will be converted to: GIVEQCCTSICSLYQLEN
"""


def _prepare_sequence(raw_text):
    """Turn pasted text (plain, FASTA or three-letter codes) into a one-letter sequence.

    FASTA header lines (starting with '>') are dropped first, because their
    letters would otherwise survive the character filter. Three-letter
    conversion runs *before* filtering: filtering first removes the
    separators and turns "Gly-Ile" into the bogus one-letter string "GLYILE".
    """
    body = "\n".join(
        line for line in raw_text.splitlines()
        if not line.lstrip().startswith(">")
    )
    body = convert_three_to_one(body)
    return re.sub(r'[^ACDEFGHIKLMNPQRSTVWY]', '', body.upper())


def _render_custom_sequence_mode(models, term_mappings, go_parser, device):
    """Render the free-text input mode and analyze the pasted sequence on demand."""
    st.markdown("### π Enter Your Protein Sequence")
    st.info("π‘ **Tip:** Paste amino acid sequence using single-letter codes (ACDEFGHIKLMNPQRSTVWY)")

    with st.expander("π Click to see example sequences"):
        st.markdown("**Single-letter format (preferred):**")
        st.code(_EXAMPLE_SINGLE_LETTER)
        st.markdown("**Three-letter format (auto-converted):**")
        st.code(_EXAMPLE_THREE_LETTER)

    sequence_input = st.text_area(
        "Paste your sequence here:",
        height=150,
        placeholder="MKTAYIAKQRQISFVKSHFSRQLEERLGLIEV...",
    )
    analyze_button = st.button("π Analyze Sequence", type="primary", use_container_width=True)

    if analyze_button and sequence_input:
        # A new analysis replaces any previously displayed result.
        st.session_state.pop(_CUSTOM_RESULT_KEY, None)
        sequence = _prepare_sequence(sequence_input)
        if len(sequence) < _MIN_SEQUENCE_LENGTH:
            st.error("β Sequence too short. Minimum 20 amino acids required.")
            st.stop()
        st.info(f"β Valid sequence: {len(sequence)} amino acids")

        with st.spinner("Loading ESM2 model (first time: 2-3 minutes)..."):
            tokenizer, esm2_model, esm2_error = load_esm2_model()
        if esm2_error:
            st.error(f"β ESM2 loading failed: {esm2_error}")
            st.info("π‘ Install transformers: pip install transformers")
            st.stop()

        with st.spinner("𧬠Generating protein embedding..."):
            embedding, emb_error = generate_embedding_from_sequence(
                sequence, tokenizer, esm2_model, device
            )
        if emb_error:
            st.error(f"β Embedding generation failed: {emb_error}")
            st.stop()

        with st.spinner("π€ Running AI predictions..."):
            predictions = predict_from_embedding(
                embedding, models, term_mappings, go_parser, device
            )
        st.session_state[_CUSTOM_RESULT_KEY] = (predictions, sequence)

    # Re-render the last result on every rerun, not only in the run where
    # Analyze was pressed; otherwise e.g. the download button wipes it.
    result = st.session_state.get(_CUSTOM_RESULT_KEY)
    if result is not None:
        predictions, sequence = result
        display_results(predictions, sequence)


def _render_test_protein_mode(models, term_mappings, go_parser, device):
    """Render the test-protein mode: pick a pre-embedded protein and analyze it."""
    st.markdown("### π Select Test Protein")

    test_embeddings, test_error = load_test_embeddings()
    if test_error:
        st.error(f"β Test embeddings not available: {test_error}")
        st.stop()

    available_proteins = list(test_embeddings)[:50]
    if not available_proteins:
        st.warning("No test proteins found in the embeddings file.")
        st.stop()

    col1, col2 = st.columns([3, 1])
    with col1:
        selected_protein = st.selectbox("Choose a protein:", available_proteins)
    with col2:
        st.metric("Selected", selected_protein)

    if st.button("π Analyze Protein", type="primary", use_container_width=True):
        with st.spinner("Analyzing..."):
            predictions = predict_from_embedding(
                test_embeddings[selected_protein], models, term_mappings, go_parser, device
            )
        st.session_state[_TEST_RESULT_KEY] = (selected_protein, predictions)

    # Show the stored result only while the same protein remains selected.
    result = st.session_state.get(_TEST_RESULT_KEY)
    if result is not None and result[0] == selected_protein:
        display_results(result[1])


def main():
    """Entry point: render the header and sidebar status, then the chosen analysis mode."""
    st.markdown(_HEADER_HTML, unsafe_allow_html=True)

    st.sidebar.header("βοΈ System Status")
    with st.sidebar:
        with st.spinner("Loading prediction models..."):
            models, term_mappings, go_parser, device, error = load_prediction_models()
        if error:
            st.error(f"β Failed: {error}")
            st.stop()
        st.success("β Prediction models ready")

    st.markdown("### π Choose Analysis Mode")
    mode = st.radio(
        "Select input method:",
        [_MODE_CUSTOM, _MODE_TEST],
        horizontal=True,
    )
    if mode == _MODE_CUSTOM:
        _render_custom_sequence_mode(models, term_mappings, go_parser, device)
    else:
        _render_test_protein_mode(models, term_mappings, go_parser, device)


if __name__ == "__main__":
    main()