🧬 Protein Sequence Analyzer
AI-Powered Function Prediction
""" Professional Protein Sequence Analyzer - With Live Sequence Input """ import streamlit as st import torch import torch.nn as nn import numpy as np import pandas as pd import pickle import plotly.graph_objects as go from collections import Counter import re import os import sys sys.path.append("D:/CAFA project") sys.path.append("D:/CAFA project/scripts") sys.path.append("D:/CAFA project/goontology") from scripts.ontologyparser import GOGraphParser # Page config MUST be first st.set_page_config( page_title="Protein Analyzer", page_icon="🧬", layout="wide" ) # Custom CSS st.markdown(""" """, unsafe_allow_html=True) # Model class class MultiLabelClassifier(nn.Module): def __init__(self, input_dim, output_dim): super(MultiLabelClassifier, self).__init__() self.network = nn.Sequential( nn.Linear(input_dim, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3), nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, output_dim) ) def forward(self, x): return self.network(x) @st.cache_resource def load_prediction_models(): """Load prediction models only""" try: base_path = "D:/CAFA project" with open(f"{base_path}/processed_data/selected_terms.pkl", 'rb') as f: term_mappings = pickle.load(f) with open(f"{base_path}/go_parser.pkl", 'rb') as f: go_parser = pickle.load(f) device = torch.device('cpu') models = {} for ontology in ['MFO', 'BPO', 'CCO']: n_terms = len(term_mappings['selected_terms'][ontology]) model = MultiLabelClassifier(1280, n_terms) checkpoint = torch.load( f"{base_path}/models/model_{ontology}_best.pth", map_location=device ) model.load_state_dict(checkpoint['model_state_dict']) model.eval() models[ontology] = model return models, term_mappings, go_parser, device, None except Exception as e: return None, None, None, None, str(e) @st.cache_resource def load_esm2_model(): """Load ESM2 model for embedding generation""" try: from transformers import AutoTokenizer, AutoModel st.info("🔄 Loading ESM2 model (this takes 2-3 minutes 
first time)...") tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D") model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D") model.eval() st.success("✅ ESM2 model loaded!") return tokenizer, model, None except Exception as e: return None, None, str(e) @st.cache_resource def load_test_embeddings(): """Load pre-computed test embeddings""" try: base_path = "D:/CAFA project" with open(f"{base_path}/scripts/embeddings/test_esm2_embeddings.pkl", 'rb') as f: embeddings = pickle.load(f) def normalize_pid(pid): if '|' in pid: return pid.split('|')[1] return pid embeddings = {normalize_pid(k): v for k, v in embeddings.items()} return embeddings, None except Exception as e: return None, str(e) def convert_three_to_one(sequence): """Convert 3-letter to 1-letter amino acid code""" three_to_one = { 'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C', 'GLN': 'Q', 'GLU': 'E', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P', 'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V' } # Check if sequence contains 3-letter codes if '-' in sequence or len(sequence) > 50 and sequence[3:4] in ['-', ' ']: # Split by dash or space codes = re.split(r'[-\s]+', sequence.upper()) converted = ''.join(three_to_one.get(code, '') for code in codes if code) return converted return sequence def generate_embedding_from_sequence(sequence, tokenizer, esm2_model, device): """Generate embedding from raw sequence""" # Try to convert 3-letter to 1-letter code sequence = convert_three_to_one(sequence) # Clean sequence sequence = re.sub(r'[^ACDEFGHIKLMNPQRSTVWY]', '', sequence.upper()) if len(sequence) < 20: return None, "Sequence too short (minimum 20 amino acids)" if len(sequence) > 1024: sequence = sequence[:1024] st.warning("⚠️ Sequence truncated to 1024 amino acids") try: # Tokenize inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024) inputs = {k: v.to(device) for k, v in 
inputs.items()} # Generate embedding with torch.no_grad(): outputs = esm2_model(**inputs) embeddings = outputs.last_hidden_state # Mean pooling (exclude special tokens) embedding = embeddings[0, 1:-1, :].mean(dim=0) return embedding.cpu().numpy(), None except Exception as e: return None, str(e) def calculate_properties(sequence): """Calculate basic molecular properties""" aa_weights = { 'A': 89, 'R': 174, 'N': 132, 'D': 133, 'C': 121, 'E': 147, 'Q': 146, 'G': 75, 'H': 155, 'I': 131, 'L': 131, 'K': 146, 'M': 149, 'F': 165, 'P': 115, 'S': 105, 'T': 119, 'W': 204, 'Y': 181, 'V': 117 } length = len(sequence) mw = sum(aa_weights.get(aa, 110) for aa in sequence) / 1000 composition = Counter(sequence) hydrophobic = sum(composition.get(aa, 0) for aa in 'AILMFWYV') / length * 100 polar = sum(composition.get(aa, 0) for aa in 'STNQ') / length * 100 charged = sum(composition.get(aa, 0) for aa in 'DEKR') / length * 100 return { 'length': length, 'molecular_weight': round(mw, 1), 'hydrophobic': round(hydrophobic, 1), 'polar': round(polar, 1), 'charged': round(charged, 1), 'composition': composition } def predict_from_embedding(embedding, models, term_mappings, go_parser, device): """Make predictions from embedding""" embedding_tensor = torch.FloatTensor(embedding).unsqueeze(0).to(device) predictions = {} with torch.no_grad(): for ontology in ['MFO', 'BPO', 'CCO']: model = models[ontology] outputs = model(embedding_tensor) probs = torch.sigmoid(outputs).cpu().numpy()[0] terms = term_mappings['selected_terms'][ontology] idx_to_term = term_mappings['idx_to_term'][ontology] pred_list = [] for idx in range(len(probs)): if probs[idx] > 0.05: term_id = terms[idx] try: term_info = go_parser.get_term_info(term_id) name = term_info['name'] if term_info else 'Unknown' except: name = term_id pred_list.append({ 'term_id': term_id, 'confidence': float(probs[idx]), 'name': name }) pred_list.sort(key=lambda x: x['confidence'], reverse=True) predictions[ontology] = pred_list return predictions 
# NOTE(review): this chunk fuses two functions onto one physical line:
#   * create_chart(predictions, ontology, top_n=10) -- builds a horizontal
#     Plotly bar chart of the top-N predictions for one ontology (bar color
#     by confidence: >70% green, >40% red, else blue); returns None when the
#     ontology has no predictions.
#   * display_results(predictions, sequence=None) -- begins here but is
#     truncated mid f-string; its remainder is not visible in this chunk.
# Code left byte-identical: reformatting is unsafe while the second
# function is incomplete.
def create_chart(predictions, ontology, top_n=10): """Create visualization""" data = predictions[ontology][:top_n] if not data: return None names = [p['name'][:50] for p in data] confidences = [p['confidence'] * 100 for p in data] colors = ['#11998e' if c > 70 else '#f5576c' if c > 40 else '#4facfe' for c in confidences] fig = go.Figure(go.Bar( y=names, x=confidences, orientation='h', marker=dict(color=colors), text=[f'{c:.1f}%' for c in confidences], textposition='outside' )) fig.update_layout( title=f'Top {len(data)} {ontology} Predictions', xaxis_title='Confidence (%)', height=max(400, len(data) * 40), yaxis=dict(autorange="reversed"), xaxis=dict(range=[0, 100]) ) return fig def display_results(predictions, sequence=None): """Display prediction results""" st.success("✅ Analysis Complete!") # Show sequence properties if provided if sequence: st.markdown("### 🔬 Sequence Properties") props = calculate_properties(sequence) col1, col2, col3, col4 = st.columns(4) with col1: st.markdown(f"""
Length (aa)
MW (kDa)
Hydrophobic %
Charged %
MFO Predictions (>50%)
BPO Predictions (>50%)
CCO Predictions (>50%)
AI-Powered Function Prediction