import json
import os
import re
import shutil

# ==========================
# 0. Basic setup & cache cleanup
# ==========================
# These variables must be set BEFORE gradio / transformers are imported:
# huggingface_hub reads HF_HOME, TRANSFORMERS_CACHE and the symlink-warning
# flag into module-level constants when it is first imported, so assigning
# them after those imports has no effect.
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

# Start every run from an empty Hugging Face cache directory.
# NOTE(review): this also deletes the user-level ~/.cache/huggingface directory
# (other cached models, any saved login token) and forces the PLM to be
# re-downloaded on every start. Presumably intentional for a disposable
# deployment container -- confirm before running on a developer machine.
for path in ["/tmp/hf_cache", os.path.expanduser("~/.cache/huggingface")]:
    shutil.rmtree(path, ignore_errors=True)
    os.makedirs(path, exist_ok=True)

import torch  # noqa: E402  (must follow the env setup above)
import torch.nn as nn  # noqa: E402
import torch.nn.functional as F  # noqa: E402
import gradio as gr  # noqa: E402
from transformers import AutoModel, AutoTokenizer  # noqa: E402


# ==========================
# 1. Model definition
# ==========================
class AttentionPooling(nn.Module):
    """Learned attention pooling over the token axis.

    A single linear layer scores every token; the scores are softmax-normalised
    over the (unmasked) tokens and used to take a weighted sum of the features.
    """

    def __init__(self, d_model):
        super().__init__()
        self.attention_net = nn.Linear(d_model, 1)

    def forward(self, x, mask):
        """Pool token features into one vector per sequence.

        Args:
            x: token features, 3-D (batch, seq_len, d_model) -- ``squeeze(2)``
                and ``torch.bmm`` below require that rank.
            mask: (batch, seq_len); positions where ``mask == 0`` get zero weight.

        Returns:
            Tensor of shape (batch, d_model): attention-weighted sum of ``x``.
        """
        attn_logits = self.attention_net(x).squeeze(2)  # (batch, seq_len)
        # Use the most negative finite value instead of -inf: normal rows get the
        # same weights, but a row whose positions are all masked now produces
        # uniform weights instead of NaN.
        attn_logits = attn_logits.masked_fill(
            mask == 0, torch.finfo(attn_logits.dtype).min
        )
        attn_weights = F.softmax(attn_logits, dim=1)
        return torch.bmm(attn_weights.unsqueeze(1), x).squeeze(1)


class ProtDualBranchEnhancedClassifier(nn.Module):
    """Two-branch classification head on top of PLM embeddings.

    * CLS branch: linear projection of the sequence-level CLS embedding.
    * Token branch: Conv1d refinement of the per-token embeddings, attention
      pooling, then a linear projection.

    The two projections are concatenated, re-weighted by a learned sigmoid gate,
    and classified by an MLP head.
    """

    def __init__(self, d_model, projection_dim, num_classes, dropout, kernel_size):
        super().__init__()
        self.cls_projector = nn.Linear(d_model, projection_dim)
        self.token_refiner = nn.Sequential(
            nn.Conv1d(d_model, d_model, kernel_size, padding="same"),
            nn.ReLU(),
        )
        self.attention_pooling = AttentionPooling(d_model)
        self.tok_projector = nn.Linear(d_model, projection_dim)

        fused_dim = projection_dim * 2
        self.gate = nn.Sequential(nn.Linear(fused_dim, fused_dim), nn.Sigmoid())
        self.classifier_head = nn.Sequential(
            nn.LayerNorm(fused_dim),
            nn.Linear(fused_dim, fused_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fused_dim * 2, num_classes),
        )

    def forward(self, cls_embedding, token_embeddings, mask):
        """Return unnormalised class logits of shape (batch, num_classes).

        Args:
            cls_embedding: (batch, d_model) CLS-token embedding.
            token_embeddings: (batch, seq_len, d_model) per-token embeddings.
            mask: (batch, seq_len) attention mask for ``token_embeddings``.
        """
        z_cls = self.cls_projector(cls_embedding)

        # Conv1d expects channels first: (batch, d_model, seq_len).
        tok_emb_permuted = token_embeddings.permute(0, 2, 1)
        refined_tok_emb = self.token_refiner(tok_emb_permuted).permute(0, 2, 1)
        z_tok_pooled = self.attention_pooling(refined_tok_emb, mask)
        z_tok = self.tok_projector(z_tok_pooled)

        z_fused_concat = torch.cat([z_cls, z_tok], dim=1)
        gate_values = self.gate(z_fused_concat)
        z_fused_gated = z_fused_concat * gate_values
        return self.classifier_head(z_fused_gated)


# ==========================
# 2. Load models
# ==========================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
LABEL_MAP_PATH = "label_map.json"

# Verify both local artifacts up front, so a missing file is reported
# immediately instead of after the PLM has been downloaded.
if not os.path.exists(LABEL_MAP_PATH):
    raise FileNotFoundError(f"Error: Missing '{LABEL_MAP_PATH}'.")
if not os.path.exists(CLASSIFIER_PATH):
    raise FileNotFoundError(f"Error: Could not find '{CLASSIFIER_PATH}'.")

# Invert the label -> index map so logits can be decoded back to label names.
# NOTE(review): assumes the JSON values are integer class indices (predict()
# looks labels up by int position) -- confirm against the training script.
with open(LABEL_MAP_PATH, "r", encoding="utf-8") as f:
    label_to_idx = json.load(f)
idx_to_label = {v: k for k, v in label_to_idx.items()}
NUM_CLASSES = len(idx_to_label)

print("🔹 Loading models...")
tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME)
plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE)
plm_model.eval()

# PLM embedding width (640 for esm2_t30_150M, the previous hard-coded value);
# read from the config so it tracks PLM_MODEL_NAME.
D_MODEL = plm_model.config.hidden_size

classifier = ProtDualBranchEnhancedClassifier(
    d_model=D_MODEL,
    projection_dim=32,
    num_classes=NUM_CLASSES,
    dropout=0.3,
    kernel_size=3,
).to(DEVICE)
# NOTE(review): torch.load unpickles the file; acceptable for this bundled,
# trusted checkpoint, but never point it at user-supplied files.
classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
classifier.eval()
print("✅ Ready.")
# ==========================
# 3. Predict logic
# ==========================
def _extract_sequence(raw_text):
    """Return the cleaned, upper-case residue string of the first record in *raw_text*.

    Accepts a bare sequence or FASTA text. Surrounding whitespace is ignored, a
    leading ``>`` header line is dropped, and reading stops at the next ``>``
    line, so header text can never leak into the sequence. Every character
    that is not A-Z after upper-casing is removed.
    """
    lines = raw_text.strip().splitlines()
    if lines and lines[0].startswith(">"):
        lines = lines[1:]
    body = []
    for line in lines:
        if line.lstrip().startswith(">"):
            break  # a second FASTA record starts here; only the first is analysed
        body.append(line)
    return re.sub(r"[^A-Z]", "", "".join(body).upper())


def predict(sequence_input):
    """Predict subcellular-localization probabilities for one protein sequence.

    Args:
        sequence_input: raw text from the input box (bare sequence or FASTA).

    Returns:
        dict mapping each label name to its softmax probability (float).

    Raises:
        gr.Error: if the input is empty/whitespace or contains no residues.
    """
    if not sequence_input or sequence_input.isspace():
        raise gr.Error("Sequence cannot be empty.")

    sequence = _extract_sequence(sequence_input)
    if not sequence:
        raise gr.Error("Invalid sequence.")

    with torch.no_grad():
        # Sequences longer than max_length tokens are silently truncated.
        inputs = tokenizer(
            sequence, return_tensors="pt", truncation=True, max_length=1024
        ).to(DEVICE)
        outputs = plm_model(**inputs)
        hidden_states = outputs.last_hidden_state
        # The slicing assumes one special token before (CLS) and one after
        # (EOS) the residues.
        cls_embedding = hidden_states[:, 0, :]
        token_embeddings = hidden_states[:, 1:-1, :]
        token_mask = inputs["attention_mask"][:, 1:-1]
        logits = classifier(cls_embedding, token_embeddings, token_mask)
        probabilities = F.softmax(logits, dim=1)[0]

    return {idx_to_label[i]: float(prob) for i, prob in enumerate(probabilities)}


# ==========================
# 4. Ultra-Modern UI Design
# ==========================
# Minimal, modern-style CSS
modern_css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600;800&display=swap');

body {
    font-family: 'Inter', sans-serif !important;
    background-color: #f8fafc;
}

/* 1. 顶部 Hero Section */
.hero-container {
    text-align: center;
    padding: 3rem 1rem;
    margin-bottom: 1rem;
}
.hero-title {
    font-size: 3rem;
    font-weight: 800;
    margin-bottom: 0.5rem;
    background: -webkit-linear-gradient(45deg, #0f172a, #334155);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    letter-spacing: -1px;
}
.hero-subtitle {
    font-size: 1.25rem;
    color: #64748b;
    font-weight: 300;
    max-width: 600px;
    margin: 0 auto;
}

/* 2. 卡片风格 */
.modern-card {
    background: white;
    border-radius: 16px;
    padding: 24px;
    border: 1px solid #e2e8f0;
    box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.05), 0 2px 4px -1px rgba(0, 0, 0, 0.03);
    transition: all 0.3s ease;
}
.modern-card:hover {
    box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05);
}

/* 3. 输入框优化 - 模仿代码编辑器 */
textarea {
    font-family: 'SF Mono', 'Menlo', 'Monaco', 'Courier New', monospace !important;
    font-size: 14px !important;
    background-color: #f8fafc !important;
    border: 1px solid #e2e8f0 !important;
    border-radius: 8px !important;
}

/* 4. 按钮优化 */
button.primary {
    background: linear-gradient(135deg, #2563eb 0%, #1d4ed8 100%) !important;
    border: none !important;
    font-weight: 600 !important;
    letter-spacing: 0.5px !important;
    transition: transform 0.1s ease-in-out !important;
}
button.primary:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 12px rgba(37, 99, 235, 0.3);
}

/* 5. 标签页优化 */
.tabs {
    border: none !important;
    background: transparent !important;
}
.tab-nav {
    border-bottom: 1px solid #e2e8f0;
    margin-bottom: 20px;
}
.tab-nav button {
    font-weight: 600;
    color: #64748b;
}
.tab-nav button.selected {
    color: #2563eb;
    border-bottom: 2px solid #2563eb;
}

/* 6. Footer */
.footer-text {
    text-align: center;
    color: #94a3b8;
    font-size: 0.8rem;
    margin-top: 40px;
    padding-bottom: 20px;
}
"""

# Hero banner; the div classes are the .hero-title / .hero-subtitle rules above.
_HERO_HTML = """
<div class="hero-title">LocPred-Prok</div>
<div class="hero-subtitle">Next-generation prokaryotic subcellular localization using dual-branch protein language models.</div>
"""

# Kept flush-left so Markdown never sees 4-space indents (which would render
# as code blocks).
_METHODOLOGY_MD = """
**LocPred-Prok** moves beyond the "bigger is better" paradigm. Instead of relying solely on massive parameter counts, we engineered a specialized **Dual-Branch Architecture**:

* **Global Branch:** Leverages the `ESM-2 (150M)` foundation model to capture deep semantic dependencies.
* **Local Branch:** Utilizes convolutional refinement and attention pooling to detect subtle signal motifs often missed by global pooling.

This synergy allows for precise identification of challenging localization sites, particularly in **Cell Wall** and **Outer Membrane** regions.
"""

_BIBTEX = """@article{LocPredProk2025,
  title={LocPred-Prok: Prokaryotic protein subcellular localization prediction with a dual-branch architecture},
  author={Your Name et al.},
  journal={Bioinformatics},
  year={2025}
}"""

# Use a minimal built-in theme as the base.
theme = gr.themes.Soft(
    primary_hue="blue",
    radius_size="lg",
    font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"],
)

with gr.Blocks(theme=theme, css=modern_css, title="LocPred-Prok") as app:
    # --- Hero section ---
    with gr.Column(elem_classes="hero-container"):
        gr.HTML(_HERO_HTML)

    # --- Main content ---
    with gr.Tabs():
        # === TAB 1: Predict ===
        with gr.TabItem("Predict", id="tab-predict"):
            with gr.Row():
                # Input column
                with gr.Column(scale=3, elem_classes="modern-card"):
                    gr.Markdown("### Sequence Input")
                    sequence_input = gr.Textbox(
                        lines=12,
                        placeholder="> Paste FASTA sequence here...",
                        show_label=False,
                        container=False,
                    )
                    with gr.Row():
                        clear_btn = gr.ClearButton(components=[sequence_input], value="Clear")
                        submit_btn = gr.Button("Analyze Sequence", variant="primary", scale=2)

                # Output column
                with gr.Column(scale=2, elem_classes="modern-card"):
                    gr.Markdown("### Analysis Result")
                    # Hide the Label component's own caption to keep the UI clean.
                    output_label = gr.Label(num_top_classes=NUM_CLASSES, show_label=False)
                    gr.HTML("""
ℹ️ Model Insight: Prediction is based on the fusion of global semantic features (ESM-2) and local structural refinements.
""")

        # === TAB 2: Methodology ===
        with gr.TabItem("Methodology", id="tab-about"):
            with gr.Column(elem_classes="modern-card"):
                gr.Markdown("### The Architecture")
                gr.Markdown(_METHODOLOGY_MD)

        # === TAB 3: Cite ===
        with gr.TabItem("Cite", id="tab-cite"):
            with gr.Column(elem_classes="modern-card"):
                gr.Markdown("### BibTeX Reference")
                gr.Code(
                    value=_BIBTEX,
                    label=None,
                    language=None,  # kept as None; author notes another value previously raised an error
                    interactive=False,
                )

    # --- Footer ---
    # NOTE(review): the footer HTML is empty, so the .footer-text CSS rule is unused.
    gr.HTML(""" """)

    # --- Event wiring ---
    submit_btn.click(fn=predict, inputs=sequence_input, outputs=output_label)
    # The ClearButton only empties the input box; also reset the result panel.
    clear_btn.click(lambda: None, outputs=[output_label])

app.launch()