Spaces:
Running
Running
| import os | |
| import json | |
| import re | |
| import uuid | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModel | |
# ==========================
# 0. Environment initialization
# ==========================
# Headless backend: figures are only rendered to images for the web UI.
plt.switch_backend('Agg')
# Point the Hugging Face caches at /tmp and silence the symlink warning.
# NOTE(review): transformers is imported before this runs; huggingface_hub
# usually reads these variables when it is first imported, so they may not
# actually relocate the cache — confirm, or move this above the imports.
os.environ.update({
    "HF_HOME": "/tmp/hf_cache",
    "TRANSFORMERS_CACHE": "/tmp/hf_cache",
    "HF_HUB_DISABLE_SYMLINKS_WARNING": "1",
})
import shutil
# Start each run from an empty cache: delete, then recreate, both cache roots.
# WARNING: when run locally this also wipes the user's ~/.cache/huggingface.
_cache_roots = ("/tmp/hf_cache", os.path.expanduser("~/.cache/huggingface"))
for cache_root in _cache_roots:
    shutil.rmtree(cache_root, ignore_errors=True)
    os.makedirs(cache_root, exist_ok=True)
# ==========================
# 1. Model architecture
# ==========================
class AttentionPooling(nn.Module):
    """Pool a sequence of token vectors into one vector using learned attention.

    A single linear layer scores every token; positions whose mask is 0 are
    excluded from the softmax, and the result is the attention-weighted sum of
    the tokens.
    """

    def __init__(self, d_model):
        """Create the scoring layer.

        Args:
            d_model: size of the last (feature) dimension of the token vectors.
        """
        super().__init__()
        # The attribute name is part of the state_dict keys of saved checkpoints.
        self.attention_net = nn.Linear(d_model, 1)

    def forward(self, x, mask):
        """Attention-pool ``x`` over its sequence dimension.

        Args:
            x: token vectors laid out as (batch, seq_len, d_model); the
                ``bmm`` below relies on this layout.
            mask: (batch, seq_len); positions equal to 0 are ignored.

        Returns:
            tuple: ``(pooled, weights)`` with ``pooled`` of shape
            (batch, d_model) and ``weights`` of shape (batch, seq_len). Each
            row of ``weights`` sums to 1 over unmasked positions. A row with
            no unmasked position gets all-zero weights and a zero vector.
        """
        attn_logits = self.attention_net(x).squeeze(2)
        masked_out = mask == 0
        # Out-of-place so the Linear output is not mutated behind autograd's back.
        attn_logits = attn_logits.masked_fill(masked_out, -float('inf'))
        attn_weights = F.softmax(attn_logits, dim=1)
        # softmax over an all -inf row is 0/0 = NaN, which would poison the
        # pooled vector and every downstream logit; zero such rows instead.
        fully_masked = masked_out.all(dim=1, keepdim=True)
        attn_weights = attn_weights.masked_fill(fully_masked, 0.0)
        pooled = torch.bmm(attn_weights.unsqueeze(1), x).squeeze(1)
        return pooled, attn_weights
class ProtDualBranchEnhancedClassifier(nn.Module):
    """Two-branch classification head over protein-language-model embeddings.

    Branch A linearly projects the sequence-level (CLS) embedding. Branch B
    refines the per-token embeddings with a 1-D convolution, attention-pools
    them, and projects the pooled vector. The two projections are
    concatenated, re-weighted element-wise by a learned sigmoid gate, and fed
    to an MLP classifier.
    """

    def __init__(self, d_model, projection_dim, num_classes, dropout, kernel_size):
        """Build the head.

        Args:
            d_model: embedding width of the backbone.
            projection_dim: width of each branch after projection.
            num_classes: number of output logits.
            dropout: dropout probability inside the classifier MLP.
            kernel_size: Conv1d kernel size of the token refiner.
        """
        super().__init__()
        # Attribute names, creation order and Sequential layouts define the
        # state_dict keys (and the seeded init order) — keep them as they are.
        self.cls_projector = nn.Linear(d_model, projection_dim)
        self.token_refiner = nn.Sequential(
            nn.Conv1d(d_model, d_model, kernel_size, padding='same'),
            nn.ReLU(),
        )
        self.attention_pooling = AttentionPooling(d_model)
        self.tok_projector = nn.Linear(d_model, projection_dim)
        joint_width = projection_dim * 2  # both projected branches side by side
        self.gate = nn.Sequential(nn.Linear(joint_width, joint_width), nn.Sigmoid())
        self.classifier_head = nn.Sequential(
            nn.LayerNorm(joint_width),
            nn.Linear(joint_width, joint_width * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(joint_width * 2, num_classes),
        )

    def forward(self, cls_embedding, token_embeddings, mask):
        """Return ``(logits, pooling_weights)`` for a batch.

        Args:
            cls_embedding: (batch, d_model) sequence-level embedding.
            token_embeddings: (batch, seq_len, d_model); transposed to
                channels-first for the Conv1d and back afterwards.
            mask: (batch, seq_len) token mask, 0 = ignore.
        """
        branch_cls = self.cls_projector(cls_embedding)
        channels_first = token_embeddings.permute(0, 2, 1)
        refined = self.token_refiner(channels_first).permute(0, 2, 1)
        pooled_tokens, weights = self.attention_pooling(refined, mask)
        branch_tok = self.tok_projector(pooled_tokens)
        joint = torch.cat([branch_cls, branch_tok], dim=1)
        gated = joint * self.gate(joint)
        return self.classifier_head(gated), weights
# ==========================
# 2. Load models
# ==========================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
LABEL_MAP_PATH = "label_map.json"
# Classifier-head hyperparameters. PROJECTION_DIM and KERNEL_SIZE determine
# parameter shapes, so they must match the saved checkpoint.
PROJECTION_DIM = 32
DROPOUT = 0.3
KERNEL_SIZE = 3

for _required_path in (LABEL_MAP_PATH, CLASSIFIER_PATH):
    if not os.path.exists(_required_path):
        raise FileNotFoundError(f"Missing {_required_path}")

# label_to_idx: label name -> class index (the inversion below assumes this).
with open(LABEL_MAP_PATH, 'r', encoding='utf-8') as f:
    label_to_idx = json.load(f)
idx_to_label = {v: k for k, v in label_to_idx.items()}
NUM_CLASSES = len(idx_to_label)
# Every classifier output index i is decoded via idx_to_label[i], so the
# indices must be unique and cover exactly 0..NUM_CLASSES-1; fail at startup
# rather than with a KeyError on the first request.
if len(label_to_idx) != NUM_CLASSES or set(idx_to_label) != set(range(NUM_CLASSES)):
    raise ValueError(
        f"{LABEL_MAP_PATH} must map each label to a unique index in 0..{NUM_CLASSES - 1}"
    )

print("🔹 Loading models...")
tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME)
plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE).eval()
# Backbone embedding width (640 for esm2_t30_150M_UR50D), read from the config
# so changing PLM_MODEL_NAME does not require editing a magic number.
D_MODEL = plm_model.config.hidden_size
classifier = ProtDualBranchEnhancedClassifier(
    D_MODEL, PROJECTION_DIM, NUM_CLASSES, DROPOUT, KERNEL_SIZE
).to(DEVICE)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code when loaded.
classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE, weights_only=True))
classifier.eval()
print("✅ Ready.")
# ==========================
# 3. Label standardization mapping
# ==========================
# Known raw label spellings -> display names used by the UI and the SVG diagram.
_LABEL_ALIASES = {
    "OuterMembrane": "Outer membrane", "Outer membrane": "Outer membrane",
    "Periplasmic": "Periplasm", "Periplasm": "Periplasm",
    "Cellwall": "Cell wall", "Cell wall": "Cell wall",
    "CYtoplasmicMembrane": "Cytoplasmic membrane", "InnerMembrane": "Cytoplasmic membrane",
    "Cytoplasmic": "Cytoplasm", "Cytoplasm": "Cytoplasm",
    "Extracellular": "Extracellular", "Secreted": "Extracellular",
}


def _label_key(label):
    """Return a comparison key that ignores case, spaces, underscores and hyphens."""
    return re.sub(r"[\s_\-]+", "", label).lower()


# Built once at import instead of on every call.
_ALIAS_BY_KEY = {_label_key(alias): std for alias, std in _LABEL_ALIASES.items()}


def clean_label_name(raw_label):
    """Map a raw compartment label to its standardized display name.

    Matching is case-insensitive and ignores surrounding whitespace as well as
    internal spaces, underscores and hyphens, so "OuterMembrane",
    "outer_membrane" and "Outer membrane" all give "Outer membrane", and
    "Inner membrane" gives "Cytoplasmic membrane".

    Args:
        raw_label: label string, e.g. a key of label_map.json.

    Returns:
        str: the standardized name, or the stripped input if it is not a
        known alias.
    """
    raw = raw_label.strip()
    return _ALIAS_BY_KEY.get(_label_key(raw), raw)
# ==========================
# 4. SVG engine (display-only version, no download button)
# ==========================
def infer_gram_type(std_label):
    """Choose which cell-envelope diagram to draw for a standardized label.

    Only "Cell wall" selects the Gram-positive drawing; every other label,
    recognized or not, falls back to the Gram-negative drawing.

    Returns:
        str: "positive" or "negative".
    """
    return "positive" if std_label == "Cell wall" else "negative"
def generate_scientific_svg(target_class):
    """Render the predicted compartment as an inline SVG cell diagram.

    The label is standardized with clean_label_name(), and infer_gram_type()
    picks the drawing. The Gram-negative drawing has an outer membrane, a thin
    dashed cell wall and an inner membrane. The Gram-positive drawing has a
    thick cell wall and an inner membrane. The predicted compartment uses the
    red highlight palette and its entry in the right-hand label column is
    bold, with a thicker leader line.

    Args:
        target_class: predicted class name (raw or already standardized).

    Returns:
        str: an HTML ``<div>`` wrapping the ``<svg>`` markup (display only,
        no download controls).
    """
    std_target = clean_label_name(target_class)
    gram_type = infer_gram_type(std_target)
    # One flag per compartment; only the matching one switches to highlight styling.
    is_sec = (std_target == "Extracellular")
    is_om = (std_target == "Outer membrane")
    is_peri = (std_target == "Periplasm")
    is_cw = (std_target == "Cell wall")
    is_im = (std_target == "Cytoplasmic membrane")
    is_cyto = (std_target == "Cytoplasm")
    # Palette: hl_* for the predicted compartment, bg_* for everything else.
    c = {
        'hl_stroke': '#D32F2F', 'hl_fill': '#FFEBEE', 'hl_text': '#B71C1C', 'hl_dot': '#D32F2F',
        'bg_stroke': '#90A4AE', 'bg_fill': '#FAFAFA',
        'bg_text': '#78909C', 'bg_line': '#CFD8DC', 'bg_dot': '#B0BEC5'
    }
    # Random id for the root <svg>, so repeated renders never share an element id.
    svg_id = f"svg_{str(uuid.uuid4())[:8]}"
    # (cx, cy): centre of the cell drawing; tx: x position of the label column.
    cx, cy = 300, 210
    tx = 620
    shapes = ""
    if gram_type == 'negative':
        # Outer membrane: outermost capsule. Its fill is only visible between it
        # and the inner membrane drawn on top, so it carries the Periplasm highlight.
        col_om = c['hl_stroke'] if is_om else c['bg_stroke']
        fill_om = c['hl_fill'] if is_peri else c['bg_fill']
        w_om = "4" if is_om else "2"
        shapes += f'<rect x="{cx-200}" y="{cy-120}" width="400" height="240" rx="120" ry="120" fill="{fill_om}" stroke="{col_om}" stroke-width="{w_om}" />'
        # Thin, unfilled, dashed cell wall between the two membranes.
        # NOTE(review): infer_gram_type() routes "Cell wall" to the positive
        # branch, so is_cw appears to always be False here.
        col_cw = c['hl_stroke'] if is_cw else '#B0BEC5'
        w_cw = "3" if is_cw else "1.5"
        dash_cw = "0" if is_cw else "6,4"
        shapes += f'<rect x="{cx-170}" y="{cy-90}" width="340" height="180" rx="90" ry="90" fill="none" stroke="{col_cw}" stroke-width="{w_cw}" stroke-dasharray="{dash_cw}" />'
        # Inner (cytoplasmic) membrane; its fill carries the Cytoplasm highlight.
        col_im = c['hl_stroke'] if is_im else c['bg_stroke']
        fill_im = c['hl_fill'] if is_cyto else c['bg_fill']
        w_im = "4" if is_im else "2"
        shapes += f'<rect x="{cx-140}" y="{cy-60}" width="280" height="120" rx="60" ry="60" fill="{fill_im}" stroke="{col_im}" stroke-width="{w_im}" />'
        # End points of the label leader lines, keyed by compartment.
        anchors = {
            "sec": (cx, cy-160), "om": (cx+200, cy-60), "peri": (cx+180, cy-30),
            "cw": (cx+170, cy), "im": (cx+140, cy+30), "cyto": (cx, cy)
        }
    else:
        # Gram-positive: thick, semi-transparent cell wall around the inner membrane.
        # NOTE(review): only "Cell wall" reaches this branch, so is_peri is always False here.
        col_cw = c['hl_stroke'] if is_cw else c['bg_stroke']
        fill_bg = c['hl_fill'] if is_peri else c['bg_fill']
        w_cw = "6" if is_cw else "4"
        shapes += f'<rect x="{cx-180}" y="{cy-100}" width="360" height="200" rx="100" ry="100" fill="{fill_bg}" stroke="{col_cw}" stroke-width="{w_cw}" stroke-opacity="0.7" />'
        col_im = c['hl_stroke'] if is_im else c['bg_stroke']
        fill_im = c['hl_fill'] if is_cyto else c['bg_fill']
        w_im = "4" if is_im else "2"
        shapes += f'<rect x="{cx-140}" y="{cy-60}" width="280" height="120" rx="60" ry="60" fill="{fill_im}" stroke="{col_im}" stroke-width="{w_im}" />'
        # "om" is a placeholder: the Outer membrane label is dropped for this drawing below.
        anchors = {
            "sec": (cx, cy-140), "om": (cx, cy),
            "peri": (cx+160, cy-40), "cw": (cx+180, cy-60),
            "im": (cx+140, cy+30), "cyto": (cx, cy)
        }
    # Faint decorative curve and two dots inside the cytoplasm.
    shapes += f"""<g opacity="0.4">
<path d="M {cx-30} {cy-10} Q {cx} {cy-50} {cx+30} {cy-10} T {cx+60} {cy}" fill="none" stroke="#CFD8DC" stroke-width="3" />
<circle cx="{cx-40}" cy="{cy+20}" r="3" fill="#B0BEC5" /> <circle cx="{cx+20}" cy="{cy+30}" r="3" fill="#B0BEC5" />
</g>"""
    # Extracellular prediction: bold caption with a downward arrow above the cell
    # (the arrow head is the #arrow_hl marker defined in <defs> below).
    sec_svg = ""
    if is_sec:
        sec_svg = f"""<g transform="translate({cx}, {cy-170})">
<text x="0" y="0" text-anchor="middle" fill="{c['hl_text']}" font-weight="bold" font-family="'Lato', sans-serif" font-size="14">EXTRACELLULAR</text>
<path d="M 0 5 L 0 30" stroke="{c['hl_stroke']}" stroke-width="2" marker-end="url(#arrow_hl)" />
</g>"""
    # Label column entries: (display text, anchor key, is this the prediction?).
    labels_config = [
        ("Extracellular", "sec", is_sec),
        ("Outer membrane", "om", is_om),
        ("Periplasm", "peri", is_peri),
        ("Cell wall", "cw", is_cw),
        ("Cytoplasmic membrane", "im", is_im),
        ("Cytoplasm", "cyto", is_cyto)
    ]
    # The Gram-positive drawing has no outer membrane, so its label is omitted.
    if gram_type == 'positive':
        labels_config = [l for l in labels_config if l[1] != 'om']
    label_svg = ""
    # Labels are stacked vertically, 60 px apart starting at y=50.
    y_start = 50
    y_step = 60
    for i, (text, key, active) in enumerate(labels_config):
        ty = y_start + i * y_step
        ex, ey = anchors.get(key, (0,0))
        col_txt = c['hl_text'] if active else c['bg_text']
        w_txt = "bold" if active else "normal"
        col_line = c['hl_stroke'] if active else c['bg_line']
        w_line = "2.5" if active else "1.0"
        col_dot = c['hl_dot'] if active else c['bg_dot']
        r_dot = "5" if active else "3"
        # Cubic Bezier leader line from just left of the label to its anchor dot.
        c1x, c1y = tx - 80, ty
        c2x, c2y = ex + 60, ey
        path_d = f"M {tx-10} {ty-5} C {c1x} {c1y}, {c2x} {c2y}, {ex} {ey}"
        label_svg += f"""
<g>
<text x="{tx}" y="{ty}" fill="{col_txt}" font-weight="{w_txt}" font-size="14" font-family="'Lato', sans-serif">{text}</text>
<path d="{path_d}" fill="none" stroke="{col_line}" stroke-width="{w_line}" />
<circle cx="{ex}" cy="{ey}" r="{r_dot}" fill="{col_dot}" stroke="white" stroke-width="1" />
</g>
"""
    # Assemble the 800x420 canvas: font import + arrow marker, white background,
    # cell shapes, optional extracellular arrow, label column, caption.
    # NOTE(review): std_target is inserted without XML escaping; harmless for the
    # known label names, but a label containing '<' or '&' would corrupt the SVG.
    final_svg = f"""<svg id="{svg_id}" width="100%" height="100%" viewBox="0 0 800 420" xmlns="http://www.w3.org/2000/svg">
<defs>
<style>@import url('https://fonts.googleapis.com/css2?family=Lato:wght@400;700&display=swap'); text {{ font-family: 'Lato', sans-serif; }}</style>
<marker id="arrow_hl" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto"><polygon points="0 0, 10 3.5, 0 7" fill="{c['hl_stroke']}" /></marker>
</defs>
<rect width="800" height="420" fill="white" />
{shapes}
{sec_svg}
{label_svg}
<text x="400" y="400" text-anchor="middle" font-family="'Lato', sans-serif" font-size="16" fill="#546E7A" font-weight="bold">Prediction: {std_target}</text>
</svg>"""
    # Plain HTML wrapper, no buttons
    html = f"<div style='text-align:center;'>{final_svg}</div>"
    return html
# ==========================
# 4. Wrapped Attention Heatmap
# ==========================
def draw_wrapped_attention_heatmap(weights, sequence, chars_per_line=60):
    """Draw per-residue attention weights as a heatmap wrapped over several rows.

    Each row shows up to ``chars_per_line`` residues. A cell's colour is the
    residue's weight, min-max rescaled to [0, 1] over the whole sequence, and
    the residue letter is printed on top. The row's 1-based start position is
    shown on the left.

    Args:
        weights: 1-D array-like of attention weights, one per residue. Extra
            weights are ignored; residues without a weight are drawn as 0.
        sequence: residue string; its length sets the number of cells.
        chars_per_line: residues per row (must be >= 1).

    Returns:
        matplotlib.figure.Figure: the finished figure, already removed from
        pyplot's figure registry.

    Raises:
        ValueError: if ``sequence`` is empty or ``chars_per_line`` < 1.
    """
    seq_len = len(sequence)
    if seq_len == 0:
        raise ValueError("sequence must contain at least one residue")
    if chars_per_line < 1:
        raise ValueError("chars_per_line must be at least 1")

    raw = np.asarray(weights, dtype=float).ravel()[:seq_len]
    if raw.size:
        lo, hi = raw.min(), raw.max()
        # Min-max scale into [0, 1]. A constant vector (e.g. a one-residue
        # sequence) has hi == lo and would give 0/0 = NaN, so keep it as is.
        if hi > 0 and hi > lo:
            raw = (raw - lo) / (hi - lo)
    # Residues past the end of `weights` (e.g. dropped by tokenizer
    # truncation) are shown as 0 instead of causing a shape-mismatch error.
    scaled = np.zeros(seq_len)
    scaled[:raw.size] = raw

    num_rows = (seq_len + chars_per_line - 1) // chars_per_line
    fig_height = max(2, num_rows * 0.8)
    # squeeze=False always yields a 2-D array of Axes, even for a single row.
    fig, axes = plt.subplots(num_rows, 1, figsize=(10, fig_height), dpi=150, squeeze=False)
    # NOTE: these set matplotlib's global rcParams (behaviour kept as before).
    plt.rcParams['font.family'] = 'sans-serif'
    plt.rcParams['font.sans-serif'] = ['Lato', 'monospace']
    for row, ax in enumerate(axes[:, 0]):
        start_idx = row * chars_per_line
        end_idx = min(start_idx + chars_per_line, seq_len)
        sub_seq = sequence[start_idx:end_idx]
        # Pad the (last, shorter) row to full width so all cells have equal size.
        display_weights = np.zeros((1, chars_per_line))
        display_weights[0, :len(sub_seq)] = scaled[start_idx:end_idx]
        ax.imshow(display_weights, cmap='Reds', aspect='auto', vmin=0, vmax=1)
        for j, char in enumerate(sub_seq):
            ax.text(j, 0, char, ha='center', va='center', fontsize=9, color='black', fontweight='bold')
        # White minor-grid lines at half-integer positions separate the cells.
        ax.set_xticks(np.arange(chars_per_line) - 0.5, minor=True)
        ax.set_yticks([])
        ax.grid(which="minor", color="w", linestyle='-', linewidth=1)
        ax.tick_params(which="minor", bottom=False, left=False)
        ax.tick_params(which="major", bottom=False, left=False, labelbottom=False)
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.set_ylabel(f"{start_idx+1}", rotation=0, ha='right', va='center', fontsize=10, color='#546E7A')
    fig.tight_layout()
    fig.suptitle(f"Attention Heatmap (Sequence Length: {seq_len})", fontsize=12, fontweight='bold', color='#37474F', y=1.02)
    # Without this, pyplot keeps a reference to every figure ever created and
    # memory grows with each request in this long-running server. The Figure
    # object itself stays usable and can still be rendered by the caller.
    plt.close(fig)
    return fig
# ==========================
# 5. Main prediction logic
# ==========================
# Tokenizer length limit. The model output is sliced [1:-1] below (one
# leading and one trailing special token), so at most _MAX_TOKENS - 2
# residues fit; longer inputs would leave residues without an attention weight.
_MAX_TOKENS = 1024
_MAX_RESIDUES = _MAX_TOKENS - 2


def _extract_sequence(sequence_input, max_residues=_MAX_RESIDUES):
    """Return the upper-cased A-Z residues of a bare sequence or the first FASTA record.

    The header line is skipped (even with leading whitespace), and reading
    stops at the next '>' header. All characters other than A-Z are dropped,
    and the result is cut to ``max_residues``.
    """
    text = sequence_input.strip()
    if text.startswith('>'):
        record_lines = []
        for line in text.splitlines()[1:]:
            if line.lstrip().startswith('>'):
                break  # start of the next record; only the first is used
            record_lines.append(line)
        text = "".join(record_lines)
    return re.sub(r'[^A-Z]', '', text.upper())[:max_residues]


def predict(sequence_input):
    """Classify one protein sequence and build the three UI outputs.

    Args:
        sequence_input: textbox contents — a bare sequence or FASTA text.

    Returns:
        tuple: ``(confidences, svg_html, heatmap_figure)``. ``confidences``
        maps each standardized label to its probability; probabilities of raw
        labels that share a standardized name are summed.

    Raises:
        gr.Error: if the input is empty or contains no residue letters.
    """
    if not sequence_input or sequence_input.isspace():
        raise gr.Error("Empty Input")
    seq = _extract_sequence(sequence_input)
    if not seq:
        # An empty residue string would otherwise crash inside the Conv1d layer.
        raise gr.Error("No amino-acid residues found in the input")
    with torch.no_grad():
        inputs = tokenizer(seq, return_tensors="pt", truncation=True, max_length=_MAX_TOKENS).to(DEVICE)
        hidden = plm_model(**inputs).last_hidden_state
        # Position 0 feeds the CLS branch; [1:-1] are the per-residue tokens.
        logits, pooling_weights = classifier(hidden[:, 0, :], hidden[:, 1:-1, :], inputs['attention_mask'][:, 1:-1])
        probs = F.softmax(logits, dim=1)[0].tolist()
    confidences = {}
    for idx, p in enumerate(probs):
        std_name = clean_label_name(idx_to_label[idx])
        # Accumulate rather than overwrite when two raw labels share a display name.
        confidences[std_name] = confidences.get(std_name, 0.0) + p
    # Highest merged probability; ties go to the label seen first (lowest
    # class index), matching argmax when no labels were merged.
    clean_top_label = max(confidences, key=confidences.get)
    svg = generate_scientific_svg(clean_top_label)
    heatmap = draw_wrapped_attention_heatmap(pooling_weights[0].cpu().numpy(), seq, chars_per_line=60)
    return confidences, svg, heatmap
# ==========================
# 6. UI Layout (Enhanced Header)
# ==========================
# Stylesheet passed to gr.Blocks(css=...). It loads the Lato web font and
# styles the header banner (.header-*), the link badges (.badge-*) and the
# result panels (.panel-*). The CSS comments inside the string are Chinese:
# "样式" = "styles", "链接样式" = "link styles".
layout_css = """
@import url('https://fonts.googleapis.com/css2?family=Lato:wght@300;400;700&display=swap');
body, button, input, textarea, .gradio-container { font-family: 'Lato', sans-serif !important; }
/* Header 样式 */
.header-div {
background: linear-gradient(to right, #E0F7FA, #E1F5FE);
padding: 1.5rem; border-radius: 8px; margin-bottom: 20px;
text-align: center; border: 1px solid #B3E5FC;
}
.header-title { font-size: 2.2rem; font-weight: 800; color: #0288D1; margin-bottom: 5px; }
.header-sub { font-size: 1.0rem; color: #0277BD; margin-bottom: 12px; }
/* Badge 链接样式 */
.badge-container { display: flex; justify-content: center; gap: 10px; flex-wrap: wrap; }
.badge-link {
text-decoration: none; display: inline-flex; align-items: center;
background-color: #ffffff; color: #334155;
padding: 4px 10px; border-radius: 6px;
font-size: 0.85rem; font-weight: 600;
border: 1px solid #cbd5e1; transition: all 0.2s;
}
.badge-link:hover { background-color: #f1f5f9; border-color: #0288D1; color: #0288D1; }
.badge-icon { margin-right: 5px; }
/* Panel 样式 */
.panel-card { border: 1px solid #e2e8f0; border-radius: 8px; padding: 15px; background: white; height: 100%; display: flex; flex-direction: column; }
.panel-header { font-weight: 700; color: #475569; border-bottom: 2px solid #f1f5f9; padding-bottom: 8px; margin-bottom: 12px; font-size: 1.0rem; }
.panel-label { display: inline-block; background: #E0F7FA; color: #0277BD; border: 1px solid #B2EBF2; padding: 2px 8px; border-radius: 4px; font-size: 0.8rem; margin-right: 8px; font-weight: 800; }
"""
# Soft theme with a sky-blue accent, white backgrounds and no block borders.
theme = gr.themes.Soft(primary_hue="sky").set(body_background_fill="white", block_background_fill="white", block_border_width="0px")
# Page layout. Inside the Blocks/Row/Column context managers, the statement
# order below determines where each component appears on the page.
with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
    # --- Enhanced Header ---
    # NOTE(review): the "Paper" badge links to "#", which looks like a placeholder.
    gr.HTML("""
<div class="header-div">
<div class="header-title">LocPred-Prok</div>
<div class="header-sub">Dual-Branch Deep Learning for Prokaryotic Subcellular Localization</div>
<div class="badge-container">
<a href="https://github.com/protein-ailab/LocPred-Prok" target="_blank" class="badge-link">
GitHub
</a>
<a href="#" target="_blank" class="badge-link">
Paper
</a>
<span class="badge-link" style="cursor:default">
ESM-2
</span>
<span class="badge-link" style="cursor:default">
⚖️ MIT License
</span>
</div>
</div>
""")
    # Row 1: sequence input with buttons and an example (left), compartment diagram (right).
    with gr.Row():
        with gr.Column(elem_classes="panel-card"):
            gr.Markdown("<div class='panel-header'>Sequence Input</div>")
            sequence_input = gr.Textbox(lines=8, show_label=False, placeholder=">Sequence...")
            with gr.Row():
                clear_btn = gr.ClearButton(sequence_input, value="Clear")
                submit_btn = gr.Button("Predict Analysis", variant="primary")
            gr.Examples([[">A0A0C5CJR8|Extracellular\nMSKAKDKAIVSAAQASTAYSQIDSFSHLYDRGGNLTINGKPSYTVDQAATQLLRDGAAYRDFDGNGKIDLTYTFLTSASSSTMNKHGISGFSQFNAQQKAQAALAMQSWSDVANVTFTEKASGGDGHMTFGNYSSGQDGAAAFAYLPGTGAGYDGTSWYLTNNSYTPNKTPDLNNYGRQTLTHEIGHTLGLAHPGDYNAGEGAPTYNDATYGQDTRGYSLMSYWSESNTNQNFSKGGVEAYASGPLIDDIAAIQKLYGANYNTRAGDTTYGFNSNTGRDFLSATSNADKLVFSVWDGGGNDTLDFSGFTQNQKINLNEASFSDVGGLVGNVSIAKGVTIENAFGGAGNDLIIGNNAANVIKGGAGNDLIYGAGGADQLWGGAGNDTFVFGASSDSKPGAADKIFDFTSGSDKIDLSGITKGAGLTFVNAFTGHAGDAVLTYAAGTNLGTLAVDFSGHGVADFLVTTVGQAAVSDIVA"]], inputs=sequence_input, label=None)
        with gr.Column(elem_classes="panel-card"):
            gr.Markdown("<div class='panel-header'>Sublocalization Visualization</div>")
            output_svg = gr.HTML(label="Visual", show_label=False)
    # Row 2: per-class probabilities (left), residue attention heatmap (right).
    with gr.Row():
        with gr.Column(elem_classes="panel-card"):
            gr.Markdown("<div class='panel-header'>Prediction Confidence</div>")
            output_label = gr.Label(num_top_classes=NUM_CLASSES, show_label=False)
        with gr.Column(elem_classes="panel-card"):
            gr.Markdown("<div class='panel-header'>Attention Heatmap (Sequence Weights)</div>")
            output_plot = gr.Plot(label="Attention", show_label=False)
    # predict() returns (confidences, svg_html, figure) in this output order.
    submit_btn.click(fn=predict, inputs=sequence_input, outputs=[output_label, output_svg, output_plot])
    # The ClearButton was constructed with only sequence_input; this extra
    # handler also blanks the three output components.
    clear_btn.click(lambda: [None, None, None], outputs=[output_label, output_svg, output_plot])
# Start the web server only when this file is executed as a script
# (e.g. `python app.py`), so importing the module does not block on launch().
if __name__ == "__main__":
    app.launch()