""" Professional Protein Sequence Analyzer - With Live Sequence Input """ import streamlit as st import torch import torch.nn as nn import numpy as np import pandas as pd import pickle import plotly.graph_objects as go from collections import Counter import re import os import sys sys.path.append("D:/CAFA project") sys.path.append("D:/CAFA project/scripts") sys.path.append("D:/CAFA project/goontology") from scripts.ontologyparser import GOGraphParser # Page config MUST be first st.set_page_config( page_title="Protein Analyzer", page_icon="🧬", layout="wide" ) # Custom CSS st.markdown(""" """, unsafe_allow_html=True) # Model class class MultiLabelClassifier(nn.Module): def __init__(self, input_dim, output_dim): super(MultiLabelClassifier, self).__init__() self.network = nn.Sequential( nn.Linear(input_dim, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3), nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, output_dim) ) def forward(self, x): return self.network(x) @st.cache_resource def load_prediction_models(): """Load prediction models only""" try: base_path = "D:/CAFA project" with open(f"{base_path}/processed_data/selected_terms.pkl", 'rb') as f: term_mappings = pickle.load(f) with open(f"{base_path}/go_parser.pkl", 'rb') as f: go_parser = pickle.load(f) device = torch.device('cpu') models = {} for ontology in ['MFO', 'BPO', 'CCO']: n_terms = len(term_mappings['selected_terms'][ontology]) model = MultiLabelClassifier(1280, n_terms) checkpoint = torch.load( f"{base_path}/models/model_{ontology}_best.pth", map_location=device ) model.load_state_dict(checkpoint['model_state_dict']) model.eval() models[ontology] = model return models, term_mappings, go_parser, device, None except Exception as e: return None, None, None, None, str(e) @st.cache_resource def load_esm2_model(): """Load ESM2 model for embedding generation""" try: from transformers import AutoTokenizer, AutoModel st.info("🔄 Loading ESM2 model (this takes 2-3 minutes 
first time)...") tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D") model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D") model.eval() st.success("✅ ESM2 model loaded!") return tokenizer, model, None except Exception as e: return None, None, str(e) @st.cache_resource def load_test_embeddings(): """Load pre-computed test embeddings""" try: base_path = "D:/CAFA project" with open(f"{base_path}/scripts/embeddings/test_esm2_embeddings.pkl", 'rb') as f: embeddings = pickle.load(f) def normalize_pid(pid): if '|' in pid: return pid.split('|')[1] return pid embeddings = {normalize_pid(k): v for k, v in embeddings.items()} return embeddings, None except Exception as e: return None, str(e) def convert_three_to_one(sequence): """Convert 3-letter to 1-letter amino acid code""" three_to_one = { 'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C', 'GLN': 'Q', 'GLU': 'E', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P', 'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V' } # Check if sequence contains 3-letter codes if '-' in sequence or len(sequence) > 50 and sequence[3:4] in ['-', ' ']: # Split by dash or space codes = re.split(r'[-\s]+', sequence.upper()) converted = ''.join(three_to_one.get(code, '') for code in codes if code) return converted return sequence def generate_embedding_from_sequence(sequence, tokenizer, esm2_model, device): """Generate embedding from raw sequence""" # Try to convert 3-letter to 1-letter code sequence = convert_three_to_one(sequence) # Clean sequence sequence = re.sub(r'[^ACDEFGHIKLMNPQRSTVWY]', '', sequence.upper()) if len(sequence) < 20: return None, "Sequence too short (minimum 20 amino acids)" if len(sequence) > 1024: sequence = sequence[:1024] st.warning("⚠️ Sequence truncated to 1024 amino acids") try: # Tokenize inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024) inputs = {k: v.to(device) for k, v in 
inputs.items()} # Generate embedding with torch.no_grad(): outputs = esm2_model(**inputs) embeddings = outputs.last_hidden_state # Mean pooling (exclude special tokens) embedding = embeddings[0, 1:-1, :].mean(dim=0) return embedding.cpu().numpy(), None except Exception as e: return None, str(e) def calculate_properties(sequence): """Calculate basic molecular properties""" aa_weights = { 'A': 89, 'R': 174, 'N': 132, 'D': 133, 'C': 121, 'E': 147, 'Q': 146, 'G': 75, 'H': 155, 'I': 131, 'L': 131, 'K': 146, 'M': 149, 'F': 165, 'P': 115, 'S': 105, 'T': 119, 'W': 204, 'Y': 181, 'V': 117 } length = len(sequence) mw = sum(aa_weights.get(aa, 110) for aa in sequence) / 1000 composition = Counter(sequence) hydrophobic = sum(composition.get(aa, 0) for aa in 'AILMFWYV') / length * 100 polar = sum(composition.get(aa, 0) for aa in 'STNQ') / length * 100 charged = sum(composition.get(aa, 0) for aa in 'DEKR') / length * 100 return { 'length': length, 'molecular_weight': round(mw, 1), 'hydrophobic': round(hydrophobic, 1), 'polar': round(polar, 1), 'charged': round(charged, 1), 'composition': composition } def predict_from_embedding(embedding, models, term_mappings, go_parser, device): """Make predictions from embedding""" embedding_tensor = torch.FloatTensor(embedding).unsqueeze(0).to(device) predictions = {} with torch.no_grad(): for ontology in ['MFO', 'BPO', 'CCO']: model = models[ontology] outputs = model(embedding_tensor) probs = torch.sigmoid(outputs).cpu().numpy()[0] terms = term_mappings['selected_terms'][ontology] idx_to_term = term_mappings['idx_to_term'][ontology] pred_list = [] for idx in range(len(probs)): if probs[idx] > 0.05: term_id = terms[idx] try: term_info = go_parser.get_term_info(term_id) name = term_info['name'] if term_info else 'Unknown' except: name = term_id pred_list.append({ 'term_id': term_id, 'confidence': float(probs[idx]), 'name': name }) pred_list.sort(key=lambda x: x['confidence'], reverse=True) predictions[ontology] = pred_list return predictions 
def _confidence_style(percent):
    """Return ``(color, level)`` for a confidence expressed in percent."""
    if percent > 70:
        return "#11998e", "HIGH"
    if percent > 40:
        return "#f5576c", "MEDIUM"
    return "#4facfe", "LOW"


def create_chart(predictions, ontology, top_n=10):
    """Create a horizontal bar chart of the top predictions for an ontology.

    Args:
        predictions: Dict mapping ontology to a list of prediction dicts with
            ``'name'`` and ``'confidence'`` (0-1) keys, best first.
        ontology: Ontology key to plot.
        top_n: Maximum number of bars.

    Returns:
        A plotly ``Figure``, or ``None`` when there is nothing to plot.
    """
    data = predictions[ontology][:top_n]

    if not data:
        return None

    # Names are cut to 50 characters to keep the axis labels short.
    names = [p['name'][:50] for p in data]
    confidences = [p['confidence'] * 100 for p in data]
    colors = [_confidence_style(c)[0] for c in confidences]

    fig = go.Figure(go.Bar(
        y=names,
        x=confidences,
        orientation='h',
        marker=dict(color=colors),
        text=[f'{c:.1f}%' for c in confidences],
        textposition='outside',
    ))

    fig.update_layout(
        title=f'Top {len(data)} {ontology} Predictions',
        xaxis_title='Confidence (%)',
        height=max(400, len(data) * 40),
        # Reversed so the first (highest-confidence) entry is drawn on top.
        yaxis=dict(autorange="reversed"),
        xaxis=dict(range=[0, 100]),
    )

    return fig


def _render_metric_card(value, label):
    """Render one value/label summary card in the current container."""
    # NOTE(review): the markup contains no HTML tags in this revision even
    # though unsafe_allow_html=True is passed -- confirm whether styling
    # wrappers are expected here.
    st.markdown(f"\n\n{value}\n\n{label}\n\n", unsafe_allow_html=True)


def display_results(predictions, sequence=None):
    """Display prediction results: properties, summary, details and export.

    Args:
        predictions: Dict mapping 'MFO' / 'BPO' / 'CCO' to lists of
            prediction dicts (``'term_id'``, ``'confidence'``, ``'name'``).
        sequence: Optional one-letter sequence; when given (and non-empty)
            its molecular properties are shown first.
    """
    st.success("✅ Analysis Complete!")

    # Show sequence properties if provided
    if sequence:
        st.markdown("### 🔬 Sequence Properties")
        props = calculate_properties(sequence)

        cards = [
            (props['length'], "Length (aa)"),
            (props['molecular_weight'], "MW (kDa)"),
            (props['hydrophobic'], "Hydrophobic %"),
            (props['charged'], "Charged %"),
        ]
        for column, (value, label) in zip(st.columns(4), cards):
            with column:
                _render_metric_card(value, label)

    # Prediction summary
    st.markdown("### 📊 Prediction Summary")
    for column, ont in zip(st.columns(3), ['MFO', 'BPO', 'CCO']):
        with column:
            count = sum(1 for p in predictions[ont] if p['confidence'] > 0.5)
            _render_metric_card(count, f"{ont} Predictions (>50%)")

    # Detailed predictions in tabs
    tabs = st.tabs(["🔵 Molecular Function", "🟢 Biological Process", "🟠 Cellular Component"])

    for tab, ont in zip(tabs, ['MFO', 'BPO', 'CCO']):
        with tab:
            preds = predictions[ont][:10]

            if not preds:
                st.info(f"No significant {ont} predictions")
                continue

            fig = create_chart(predictions, ont)
            if fig:
                st.plotly_chart(fig, use_container_width=True)

            st.markdown("#### Top Predictions")
            for i, pred in enumerate(preds, 1):
                conf = pred['confidence'] * 100
                _, level = _confidence_style(conf)

                st.markdown(
                    f"\n{i}. {pred['name']}\n{pred['term_id']}\n{conf:.1f}%\n{level}\n",
                    unsafe_allow_html=True,
                )

    # Export
    st.markdown("### 💾 Export Results")

    all_preds = [
        {
            'Ontology': ont,
            'GO Term': pred['term_id'],
            'Function': pred['name'],
            'Confidence': f"{pred['confidence']*100:.2f}%",
        }
        for ont in ['MFO', 'BPO', 'CCO']
        for pred in predictions[ont]
    ]

    # Explicit columns keep the header row even when nothing was predicted.
    df = pd.DataFrame(all_preds, columns=['Ontology', 'GO Term', 'Function', 'Confidence'])
    csv = df.to_csv(index=False)

    st.download_button(
        "📥 Download Predictions CSV",
        csv,
        "protein_predictions.csv",
        "text/csv",
        use_container_width=True,
    )


def _prepare_sequence(raw_text):
    """Turn pasted text into a cleaned one-letter sequence.

    Three-letter notation is converted *before* non-residue characters are
    stripped; stripping first would remove the separators and make
    "Gly-Ile" read as the six residues G-L-Y-I-L-E.
    """
    converted = convert_three_to_one(raw_text)
    # If conversion produced nothing usable, fall back to the raw text so a
    # one-letter sequence that merely contains dashes is still accepted.
    candidate = converted if converted else raw_text
    return re.sub(r'[^ACDEFGHIKLMNPQRSTVWY]', '', candidate.upper())


def _run_custom_sequence_mode(models, term_mappings, go_parser, device):
    """Handle the "Enter Custom Sequence" mode of the main page."""
    st.markdown("### 📝 Enter Your Protein Sequence")

    st.info("💡 **Tip:** Paste amino acid sequence using single-letter codes (ACDEFGHIKLMNPQRSTVWY)")

    # Example sequences
    with st.expander("📌 Click to see example sequences"):
        st.markdown("**Single-letter format (preferred):**")
        st.code(
            " Example 1 - Small protein (100 aa): "
            "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEKAVQVKVKALPDAQFEVVHSLAK "
            "WSPELAAACEVWKEIKFEFPAMDLVVKAAGAVGS "
            "Example 2 - Kinase domain (250 aa): "
            "MGSSHHHHHHSSGLVPRGSHMQDPPDFLKRTPAATPDLPMFPESAEELEKITAFAKKLGFPKAQKKDEADSLEKLKDV "
            "TLVNDSLVKLGGKFTTAIQQRVAQALENALQDLWLVKYNPVSIKGLGKGSLQYLNEIKFKGKKFVYISVTKDPNLPA "
            "LDNFYTKALLSKTGLKFTNKDKFKELYVLLKKFEVLTYQWLAKAEKQEFCDKLLDLKDYLSDKLQVYKDVFKKLETL "
            "KHKKLDSALSDLEVQENKVFGGNNVVPKLDGLSGDFATSTAQFQKEVRQKIVSILTKNKKFVFGHDDLSKIFSGLHKV "
        )

        st.markdown("**Three-letter format (auto-converted):**")
        st.code(
            " Example: "
            "Gly-Ile-Val-Glu-Gln-Cys-Cys-Thr-Ser-Ile-Cys-Ser-Leu-Tyr-Gln-Leu-Glu-Asn "
            "Will be converted to: GIVEQCCTSICSLYQLEN "
        )

    # Text area for sequence
    sequence_input = st.text_area(
        "Paste your sequence here:",
        height=150,
        placeholder="MKTAYIAKQRQISFVKSHFSRQLEERLGLIEV...",
    )

    analyze_button = st.button("🚀 Analyze Sequence", type="primary", use_container_width=True)

    if not analyze_button:
        return

    if not sequence_input:
        # Give feedback instead of silently ignoring the click.
        st.warning("⚠️ Please paste a sequence before analyzing.")
        return

    # Clean sequence
    sequence = _prepare_sequence(sequence_input)

    if len(sequence) < 20:
        st.error("❌ Sequence too short. Minimum 20 amino acids required.")
        st.stop()

    st.info(f"✓ Valid sequence: {len(sequence)} amino acids")

    # Load ESM2 if not loaded
    with st.spinner("Loading ESM2 model (first time: 2-3 minutes)..."):
        tokenizer, esm2_model, esm2_error = load_esm2_model()

    if esm2_error:
        st.error(f"❌ ESM2 loading failed: {esm2_error}")
        st.info("💡 Install transformers: pip install transformers")
        st.stop()

    # Generate embedding
    with st.spinner("🧬 Generating protein embedding..."):
        embedding, emb_error = generate_embedding_from_sequence(
            sequence, tokenizer, esm2_model, device
        )

    if emb_error:
        st.error(f"❌ Embedding generation failed: {emb_error}")
        st.stop()

    # Make predictions
    with st.spinner("🤖 Running AI predictions..."):
        predictions = predict_from_embedding(
            embedding, models, term_mappings, go_parser, device
        )

    # Display results
    display_results(predictions, sequence)


def _run_test_protein_mode(models, term_mappings, go_parser, device):
    """Handle the "Use Test Protein" mode of the main page."""
    st.markdown("### 📋 Select Test Protein")

    # Load test embeddings
    test_embeddings, test_error = load_test_embeddings()

    if test_error:
        st.error(f"❌ Test embeddings not available: {test_error}")
        st.stop()

    # Only the first 50 ids are offered in the selector.
    available_proteins = list(test_embeddings.keys())[:50]

    if not available_proteins:
        st.warning("⚠️ No test proteins are available.")
        st.stop()

    col1, col2 = st.columns([3, 1])

    with col1:
        selected_protein = st.selectbox(
            "Choose a protein:",
            available_proteins,
        )

    with col2:
        st.metric("Selected", selected_protein)

    if st.button("🚀 Analyze Protein", type="primary", use_container_width=True):
        with st.spinner("Analyzing..."):
            embedding = test_embeddings[selected_protein]
            predictions = predict_from_embedding(
                embedding, models, term_mappings, go_parser, device
            )

        display_results(predictions)


# MAIN APP
def main():
    """Render the page: header, model status sidebar and the chosen mode."""
    st.markdown(
        "\n\n🧬 Protein Sequence Analyzer\n\nAI-Powered Function Prediction\n\n",
        unsafe_allow_html=True,
    )

    # Sidebar
    st.sidebar.header("⚙️ System Status")

    # Load prediction models
    with st.sidebar:
        with st.spinner("Loading prediction models..."):
            models, term_mappings, go_parser, device, error = load_prediction_models()

        if error:
            st.error(f"❌ Failed: {error}")
            st.stop()
        else:
            st.success("✅ Prediction models ready")

    # Main interface
    st.markdown("### 🔍 Choose Analysis Mode")

    mode = st.radio(
        "Select input method:",
        ["🧬 Enter Custom Sequence", "📋 Use Test Protein"],
        horizontal=True,
    )

    if mode == "🧬 Enter Custom Sequence":
        _run_custom_sequence_mode(models, term_mappings, go_parser, device)
    else:  # Use Test Protein
        _run_test_protein_mode(models, term_mappings, go_parser, device)


if __name__ == "__main__":
    main()