import os
import json
import re
import uuid
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from transformers import AutoTokenizer, AutoModel
# ==========================
# 0. Environment initialisation
# ==========================
plt.switch_backend('Agg')  # headless rendering: figures are drawn off-screen (no display needed)
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
import shutil
# WARNING: this deletes both the /tmp cache AND the user's global Hugging Face
# cache on every start-up, forcing a fresh model download each run — presumably
# intended for a stateless hosted deployment; confirm before running locally.
for path in ["/tmp/hf_cache", os.path.expanduser("~/.cache/huggingface")]:
    shutil.rmtree(path, ignore_errors=True)
    os.makedirs(path, exist_ok=True)
# ==========================
# 1. Model architecture
# ==========================
class AttentionPooling(nn.Module):
    """Learned attention pooling over a padded token sequence.

    Scores each token with a single linear unit, masks out padding
    positions, and returns the attention-weighted sum of the tokens
    together with the attention weights themselves.
    """

    def __init__(self, d_model):
        super().__init__()
        # One scalar score per token embedding.
        self.attention_net = nn.Linear(d_model, 1)

    def forward(self, x, mask):
        # x:    (batch, seq_len, d_model) token embeddings
        # mask: (batch, seq_len), non-zero where the token is real
        scores = self.attention_net(x).squeeze(2)
        # Padding positions get -inf so softmax assigns them exactly zero weight.
        scores.masked_fill_(mask == 0, -float('inf'))
        weights = F.softmax(scores, dim=1)
        # Weighted sum: (batch, 1, seq_len) @ (batch, seq_len, d_model) -> (batch, d_model).
        pooled = torch.bmm(weights.unsqueeze(1), x).squeeze(1)
        return pooled, weights
class ProtDualBranchEnhancedClassifier(nn.Module):
    """Dual-branch classification head for a protein language model.

    Branch A projects the [CLS] summary embedding; branch B refines the
    per-residue embeddings with a convolution, attention-pools them, and
    projects the result. The two branches are concatenated, modulated by a
    learned sigmoid gate, and classified by an MLP head.

    NOTE: submodule attribute names are part of the checkpoint contract
    (load_state_dict) and must not be renamed.
    """

    def __init__(self, d_model, projection_dim, num_classes, dropout, kernel_size):
        super().__init__()
        # Branch A: [CLS] projection.
        self.cls_projector = nn.Linear(d_model, projection_dim)
        # Branch B: length-preserving conv refinement of token embeddings.
        self.token_refiner = nn.Sequential(
            nn.Conv1d(d_model, d_model, kernel_size, padding='same'),
            nn.ReLU(),
        )
        self.attention_pooling = AttentionPooling(d_model)
        self.tok_projector = nn.Linear(d_model, projection_dim)
        fused_dim = projection_dim * 2
        # Element-wise sigmoid gate over the concatenated branches.
        self.gate = nn.Sequential(nn.Linear(fused_dim, fused_dim), nn.Sigmoid())
        self.classifier_head = nn.Sequential(
            nn.LayerNorm(fused_dim),
            nn.Linear(fused_dim, fused_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fused_dim * 2, num_classes),
        )

    def forward(self, cls_embedding, token_embeddings, mask):
        # Branch A: global sequence summary.
        branch_cls = self.cls_projector(cls_embedding)
        # Branch B: Conv1d expects channels-first, so go (B, L, D) -> (B, D, L) and back.
        refined = self.token_refiner(token_embeddings.permute(0, 2, 1)).permute(0, 2, 1)
        pooled, attn_weights = self.attention_pooling(refined, mask)
        branch_tok = self.tok_projector(pooled)
        # Fuse, gate, classify.
        fused = torch.cat([branch_cls, branch_tok], dim=1)
        fused = fused * self.gate(fused)
        return self.classifier_head(fused), attn_weights
# ==========================
# 2. Load models
# ==========================
# Run inference on GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
LABEL_MAP_PATH = "label_map.json"
# Fail fast with a clear message if the deployment is missing its artefacts.
if not os.path.exists(LABEL_MAP_PATH): raise FileNotFoundError(f"Missing {LABEL_MAP_PATH}")
if not os.path.exists(CLASSIFIER_PATH): raise FileNotFoundError(f"Missing {CLASSIFIER_PATH}")
# label_map.json maps label name -> class index; invert it to decode predictions.
with open(LABEL_MAP_PATH, 'r') as f:
    label_to_idx = json.load(f)
idx_to_label = {v: k for k, v in label_to_idx.items()}
NUM_CLASSES = len(idx_to_label)
# Hidden size of esm2_t30_150M_UR50D.
D_MODEL = 640
print("🔹 Loading models...")
tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME)
plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE).eval()
# Head hyper-parameters must match the trained checkpoint:
# projection_dim=32, dropout=0.3, kernel_size=3.
classifier = ProtDualBranchEnhancedClassifier(D_MODEL, 32, NUM_CLASSES, 0.3, 3).to(DEVICE)
classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
classifier.eval()
print("✅ Ready.")
# ==========================
# 3. Label-name normalisation
# ==========================
# Raw label spellings mapped to one canonical display name per compartment.
# Keys are stored lower-cased once at module load, instead of rebuilding the
# dict and re-lowering every key on each call (the original also carried a
# typo'd key, "CYtoplasmicMembrane", that only worked via the lowercase pass).
_CANONICAL_LABELS = {
    "outermembrane": "Outer membrane", "outer membrane": "Outer membrane",
    "periplasmic": "Periplasm", "periplasm": "Periplasm",
    "cellwall": "Cell wall", "cell wall": "Cell wall",
    "cytoplasmicmembrane": "Cytoplasmic membrane", "innermembrane": "Cytoplasmic membrane",
    "cytoplasmic": "Cytoplasm", "cytoplasm": "Cytoplasm",
    "extracellular": "Extracellular", "secreted": "Extracellular",
}

def clean_label_name(raw_label):
    """Normalise *raw_label* to its canonical compartment display name.

    The lookup ignores surrounding whitespace and letter case; unknown
    labels are returned unchanged (minus surrounding whitespace).
    """
    raw = raw_label.strip()
    return _CANONICAL_LABELS.get(raw.lower(), raw)
# ==========================
# 4. SVG engine (clean display version - no download button)
# ==========================
def infer_gram_type(std_label):
    """Infer the Gram stain type to draw for a canonical localisation label.

    Only "Cell wall" selects the Gram-positive diagram; every other
    compartment — including the Gram-negative-specific "Outer membrane"
    and "Periplasm" — falls through to the Gram-negative diagram.
    """
    return "positive" if std_label == "Cell wall" else "negative"
def generate_scientific_svg(target_class):
    """Build the HTML/SVG cell diagram highlighting the predicted compartment.

    NOTE(review): every SVG/HTML fragment below is an EMPTY f-string — the
    markup appears to have been stripped when this file was exported, so the
    function currently returns broken, empty markup (and the final `html`
    literal is a syntax error as written). The surviving structure (colours,
    anchors, label layout) documents what the original drew; the literals must
    be restored from the upstream version before this can run.
    """
    std_target = clean_label_name(target_class)
    gram_type = infer_gram_type(std_target)
    # One flag per compartment: which one is the predicted (highlighted) target.
    is_sec = (std_target == "Extracellular")
    is_om = (std_target == "Outer membrane")
    is_peri = (std_target == "Periplasm")
    is_cw = (std_target == "Cell wall")
    is_im = (std_target == "Cytoplasmic membrane")
    is_cyto = (std_target == "Cytoplasm")
    # Colour palette: 'hl_*' for the highlighted compartment, 'bg_*' for the rest.
    c = {
        'hl_stroke': '#D32F2F', 'hl_fill': '#FFEBEE', 'hl_text': '#B71C1C', 'hl_dot': '#D32F2F',
        'bg_stroke': '#90A4AE', 'bg_fill': '#FAFAFA',
        'bg_text': '#78909C', 'bg_line': '#CFD8DC', 'bg_dot': '#B0BEC5'
    }
    # Unique element id so several diagrams can coexist on one page.
    svg_id = f"svg_{str(uuid.uuid4())[:8]}"
    cx, cy = 300, 210  # cell centre
    tx = 620  # x position of the label column
    shapes = ""
    if gram_type == 'negative':
        # Gram-negative: outer membrane + thin (dashed) cell wall + inner membrane.
        col_om = c['hl_stroke'] if is_om else c['bg_stroke']
        fill_om = c['hl_fill'] if is_peri else c['bg_fill']  # periplasm highlight fills between membranes
        w_om = "4" if is_om else "2"
        shapes += f''  # NOTE(review): stripped SVG literal (outer-membrane shape)
        col_cw = c['hl_stroke'] if is_cw else '#B0BEC5'
        w_cw = "3" if is_cw else "1.5"
        dash_cw = "0" if is_cw else "6,4"  # dashed when not highlighted
        shapes += f''  # NOTE(review): stripped SVG literal (thin cell wall)
        col_im = c['hl_stroke'] if is_im else c['bg_stroke']
        fill_im = c['hl_fill'] if is_cyto else c['bg_fill']
        w_im = "4" if is_im else "2"
        shapes += f''  # NOTE(review): stripped SVG literal (inner membrane)
        # Where each label's leader line ends on the diagram (keyed by label id).
        anchors = {
            "sec": (cx, cy-160), "om": (cx+200, cy-60), "peri": (cx+180, cy-30),
            "cw": (cx+170, cy), "im": (cx+140, cy+30), "cyto": (cx, cy)
        }
    else:
        # Gram-positive: thick cell wall, no outer membrane.
        col_cw = c['hl_stroke'] if is_cw else c['bg_stroke']
        fill_bg = c['hl_fill'] if is_peri else c['bg_fill']
        w_cw = "6" if is_cw else "4"
        shapes += f''  # NOTE(review): stripped SVG literal (thick cell wall)
        col_im = c['hl_stroke'] if is_im else c['bg_stroke']
        fill_im = c['hl_fill'] if is_cyto else c['bg_fill']
        w_im = "4" if is_im else "2"
        shapes += f''  # NOTE(review): stripped SVG literal (cytoplasmic membrane)
        anchors = {
            "sec": (cx, cy-140), "om": (cx, cy),
            "peri": (cx+160, cy-40), "cw": (cx+180, cy-60),
            "im": (cx+140, cy+30), "cyto": (cx, cy)
        }
    # NOTE(review): stripped SVG literal (shared decoration appended for both Gram types).
    shapes += f"""
"""
    sec_svg = ""
    if is_sec:
        # NOTE(review): markup around the "EXTRACELLULAR" caption was stripped.
        sec_svg = f"""
EXTRACELLULAR
"""
    # (display text, anchor key, is-the-predicted-compartment) per label row.
    labels_config = [
        ("Extracellular", "sec", is_sec),
        ("Outer membrane", "om", is_om),
        ("Periplasm", "peri", is_peri),
        ("Cell wall", "cw", is_cw),
        ("Cytoplasmic membrane", "im", is_im),
        ("Cytoplasm", "cyto", is_cyto)
    ]
    if gram_type == 'positive':
        # Gram-positive cells have no outer membrane entry.
        labels_config = [l for l in labels_config if l[1] != 'om']
    label_svg = ""
    y_start = 50  # y of the first label; labels stack in a column at x = tx
    y_step = 60
    for i, (text, key, active) in enumerate(labels_config):
        ty = y_start + i * y_step
        ex, ey = anchors.get(key, (0,0))
        # Highlighted label: bold red text, thicker leader line, larger dot.
        col_txt = c['hl_text'] if active else c['bg_text']
        w_txt = "bold" if active else "normal"
        col_line = c['hl_stroke'] if active else c['bg_line']
        w_line = "2.5" if active else "1.0"
        col_dot = c['hl_dot'] if active else c['bg_dot']
        r_dot = "5" if active else "3"
        # Cubic Bezier leader line from the label text to its diagram anchor.
        c1x, c1y = tx - 80, ty
        c2x, c2y = ex + 60, ey
        path_d = f"M {tx-10} {ty-5} C {c1x} {c1y}, {c2x} {c2y}, {ex} {ey}"
        # NOTE(review): stripped markup around the label text / leader path.
        label_svg += f"""
{text}
"""
    final_svg = f""""""  # NOTE(review): stripped — originally composed shapes/sec_svg/label_svg into the full <svg>
    # Plain HTML, no buttons.
    # NOTE(review): broken literal below — presumably a triple-quoted f-string
    # before the markup was stripped; as written it is a syntax error.
    html = f"
{final_svg}
"
    return html
# ==========================
# 4b. Wrapped Attention Heatmap
# ==========================
def draw_wrapped_attention_heatmap(weights, sequence, chars_per_line=60):
    """Render per-residue attention weights as a line-wrapped heatmap.

    Parameters
    ----------
    weights : 1-D array-like of attention weights, one per residue.
    sequence : amino-acid string being visualised.
    chars_per_line : residues per heatmap row.

    Returns the matplotlib Figure (Agg backend, suitable for gr.Plot).
    """
    # Set fonts before creating the figure so they apply to all of its text.
    plt.rcParams['font.family'] = 'sans-serif'
    plt.rcParams['font.sans-serif'] = ['Lato', 'monospace']
    # Defensive alignment: tokenizer truncation can make weights shorter than
    # the raw sequence (or vice versa) — trim both to the common length
    # instead of crashing in the per-row assignment below.
    n = min(len(sequence), len(weights))
    weights = np.asarray(weights, dtype=float)[:n]
    sequence = sequence[:n]
    # Min-max normalise to [0, 1]; guard the zero-range case (uniform
    # attention), which previously produced a 0/0 -> NaN array.
    if n:
        w_min, w_max = weights.min(), weights.max()
        if w_max > w_min:
            weights = (weights - w_min) / (w_max - w_min)
    seq_len = len(sequence)
    # At least one row so plt.subplots never gets nrows=0 on empty input.
    num_rows = max(1, (seq_len + chars_per_line - 1) // chars_per_line)
    fig_height = max(2, num_rows * 0.8)
    fig, axes = plt.subplots(num_rows, 1, figsize=(10, fig_height), dpi=150)
    if num_rows == 1:
        axes = [axes]
    for i in range(num_rows):
        ax = axes[i]
        start_idx = i * chars_per_line
        end_idx = min((i + 1) * chars_per_line, seq_len)
        sub_weights = weights[start_idx:end_idx]
        sub_seq = sequence[start_idx:end_idx]
        current_len = len(sub_seq)
        # Pad the (possibly shorter) last row with zeros to a uniform width.
        display_weights = np.zeros((1, chars_per_line))
        display_weights[0, :current_len] = sub_weights
        ax.imshow(display_weights, cmap='Reds', aspect='auto', vmin=0, vmax=1)
        # Overlay each residue letter on its heatmap cell.
        for j, char in enumerate(sub_seq):
            ax.text(j, 0, char, ha='center', va='center', fontsize=9, color='black', fontweight='bold')
        # White minor-grid lines separate cells; hide all tick machinery.
        ax.set_xticks(np.arange(chars_per_line) - 0.5, minor=True)
        ax.set_yticks([])
        ax.grid(which="minor", color="w", linestyle='-', linewidth=1)
        ax.tick_params(which="minor", bottom=False, left=False)
        ax.tick_params(which="major", bottom=False, left=False, labelbottom=False)
        for spine in ax.spines.values():
            spine.set_visible(False)
        # Row label: 1-based index of the first residue in the row.
        ax.set_ylabel(f"{start_idx+1}", rotation=0, ha='right', va='center', fontsize=10, color='#546E7A')
    plt.tight_layout()
    fig.suptitle(f"Attention Heatmap (Sequence Length: {seq_len})", fontsize=12, fontweight='bold', color='#37474F', y=1.02)
    return fig
# ==========================
# 5. Prediction pipeline
# ==========================
def predict(sequence_input):
    """Run the full inference pipeline on one FASTA or plain sequence.

    Returns (per-class confidences dict for gr.Label, SVG html string,
    attention-heatmap Figure). Raises gr.Error on empty/invalid input.
    """
    if not sequence_input or sequence_input.isspace():
        raise gr.Error("Empty Input")
    # Drop a FASTA header line if present, then keep only A-Z letters.
    seq = "".join(sequence_input.split('\n')[1:]) if sequence_input.startswith('>') else sequence_input
    # Truncate to 1022 residues: the tokenizer's max_length=1024 includes the
    # BOS/EOS special tokens, so only 1022 residue positions survive encoding.
    # The previous [:1024] made the attention weights shorter than the
    # displayed sequence and crashed the heatmap for long inputs.
    seq = re.sub(r'[^A-Z]', '', seq.upper())[:1022]
    if not seq:
        # e.g. the input contained only digits/punctuation after cleaning.
        raise gr.Error("Empty Input")
    with torch.no_grad():
        inputs = tokenizer(seq, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
        outputs = plm_model(**inputs)
        # CLS embedding + residue embeddings/mask (special tokens stripped at both ends).
        logits, pooling_weights = classifier(
            outputs.last_hidden_state[:, 0, :],
            outputs.last_hidden_state[:, 1:-1, :],
            inputs['attention_mask'][:, 1:-1],
        )
        probs = F.softmax(logits, dim=1)[0]
    top_id = int(torch.argmax(probs).item())
    clean_top_label = clean_label_name(idx_to_label[top_id])
    # Per-class confidences keyed by the canonical label names.
    confidences = {clean_label_name(idx_to_label[i]): float(p) for i, p in enumerate(probs)}
    svg = generate_scientific_svg(clean_top_label)
    heatmap = draw_wrapped_attention_heatmap(pooling_weights[0].cpu().numpy(), seq, chars_per_line=60)
    return confidences, svg, heatmap
# ==========================
# 6. UI Layout (Enhanced Header)
# ==========================
layout_css = """
@import url('https://fonts.googleapis.com/css2?family=Lato:wght@300;400;700&display=swap');
body, button, input, textarea, .gradio-container { font-family: 'Lato', sans-serif !important; }
/* Header 样式 */
.header-div {
background: linear-gradient(to right, #E0F7FA, #E1F5FE);
padding: 1.5rem; border-radius: 8px; margin-bottom: 20px;
text-align: center; border: 1px solid #B3E5FC;
}
.header-title { font-size: 2.2rem; font-weight: 800; color: #0288D1; margin-bottom: 5px; }
.header-sub { font-size: 1.0rem; color: #0277BD; margin-bottom: 12px; }
/* Badge 链接样式 */
.badge-container { display: flex; justify-content: center; gap: 10px; flex-wrap: wrap; }
.badge-link {
text-decoration: none; display: inline-flex; align-items: center;
background-color: #ffffff; color: #334155;
padding: 4px 10px; border-radius: 6px;
font-size: 0.85rem; font-weight: 600;
border: 1px solid #cbd5e1; transition: all 0.2s;
}
.badge-link:hover { background-color: #f1f5f9; border-color: #0288D1; color: #0288D1; }
.badge-icon { margin-right: 5px; }
/* Panel 样式 */
.panel-card { border: 1px solid #e2e8f0; border-radius: 8px; padding: 15px; background: white; height: 100%; display: flex; flex-direction: column; }
.panel-header { font-weight: 700; color: #475569; border-bottom: 2px solid #f1f5f9; padding-bottom: 8px; margin-bottom: 12px; font-size: 1.0rem; }
.panel-label { display: inline-block; background: #E0F7FA; color: #0277BD; border: 1px solid #B2EBF2; padding: 2px 8px; border-radius: 4px; font-size: 0.8rem; margin-right: 8px; font-weight: 800; }
"""
theme = gr.themes.Soft(primary_hue="sky").set(body_background_fill="white", block_background_fill="white", block_border_width="0px")
with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
# --- Enhanced Header ---
gr.HTML("""
""")
with gr.Row():
with gr.Column(elem_classes="panel-card"):
gr.Markdown("")
sequence_input = gr.Textbox(lines=8, show_label=False, placeholder=">Sequence...")
with gr.Row():
clear_btn = gr.ClearButton(sequence_input, value="Clear")
submit_btn = gr.Button("Predict Analysis", variant="primary")
gr.Examples([[">A0A0C5CJR8|Extracellular\nMSKAKDKAIVSAAQASTAYSQIDSFSHLYDRGGNLTINGKPSYTVDQAATQLLRDGAAYRDFDGNGKIDLTYTFLTSASSSTMNKHGISGFSQFNAQQKAQAALAMQSWSDVANVTFTEKASGGDGHMTFGNYSSGQDGAAAFAYLPGTGAGYDGTSWYLTNNSYTPNKTPDLNNYGRQTLTHEIGHTLGLAHPGDYNAGEGAPTYNDATYGQDTRGYSLMSYWSESNTNQNFSKGGVEAYASGPLIDDIAAIQKLYGANYNTRAGDTTYGFNSNTGRDFLSATSNADKLVFSVWDGGGNDTLDFSGFTQNQKINLNEASFSDVGGLVGNVSIAKGVTIENAFGGAGNDLIIGNNAANVIKGGAGNDLIYGAGGADQLWGGAGNDTFVFGASSDSKPGAADKIFDFTSGSDKIDLSGITKGAGLTFVNAFTGHAGDAVLTYAAGTNLGTLAVDFSGHGVADFLVTTVGQAAVSDIVA"]], inputs=sequence_input, label=None)
with gr.Column(elem_classes="panel-card"):
gr.Markdown("")
output_svg = gr.HTML(label="Visual", show_label=False)
with gr.Row():
with gr.Column(elem_classes="panel-card"):
gr.Markdown("")
output_label = gr.Label(num_top_classes=NUM_CLASSES, show_label=False)
with gr.Column(elem_classes="panel-card"):
gr.Markdown("")
output_plot = gr.Plot(label="Attention", show_label=False)
submit_btn.click(fn=predict, inputs=sequence_input, outputs=[output_label, output_svg, output_plot])
clear_btn.click(lambda: [None, None, None], outputs=[output_label, output_svg, output_plot])
app.launch()