Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,6 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import json
|
| 3 |
import re
|
|
|
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
| 6 |
import torch.nn.functional as F
|
|
@@ -9,10 +27,8 @@ import matplotlib.pyplot as plt
|
|
| 9 |
import numpy as np
|
| 10 |
from transformers import AutoTokenizer, AutoModel
|
| 11 |
|
| 12 |
-
#
|
| 13 |
-
|
| 14 |
-
# ==========================
|
| 15 |
-
plt.switch_backend('Agg')
|
| 16 |
os.environ["HF_HOME"] = "/tmp/hf_cache"
|
| 17 |
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
|
| 18 |
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
|
|
@@ -22,17 +38,16 @@ for path in ["/tmp/hf_cache", os.path.expanduser("~/.cache/huggingface")]:
|
|
| 22 |
shutil.rmtree(path, ignore_errors=True)
|
| 23 |
os.makedirs(path, exist_ok=True)
|
| 24 |
|
| 25 |
-
#
|
| 26 |
-
# 1. 模型架构 (含 Attention 输出)
|
| 27 |
-
# ==========================
|
| 28 |
class AttentionPooling(nn.Module):
|
| 29 |
def __init__(self, d_model):
|
| 30 |
super().__init__()
|
| 31 |
self.attention_net = nn.Linear(d_model, 1)
|
| 32 |
|
| 33 |
def forward(self, x, mask):
|
| 34 |
-
|
| 35 |
-
attn_logits.
|
|
|
|
| 36 |
attn_weights = F.softmax(attn_logits, dim=1)
|
| 37 |
return torch.bmm(attn_weights.unsqueeze(1), x).squeeze(1), attn_weights
|
| 38 |
|
|
@@ -58,16 +73,16 @@ class ProtDualBranchEnhancedClassifier(nn.Module):
|
|
| 58 |
z_fused_gated = z_fused_concat * gate_values
|
| 59 |
return self.classifier_head(z_fused_gated), pooling_weights
|
| 60 |
|
| 61 |
-
#
|
| 62 |
-
# 2. 加载模型
|
| 63 |
-
# ==========================
|
| 64 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 65 |
PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
|
| 66 |
CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
|
| 67 |
LABEL_MAP_PATH = "label_map.json"
|
| 68 |
|
| 69 |
-
if not os.path.exists(LABEL_MAP_PATH):
|
| 70 |
-
|
|
|
|
|
|
|
| 71 |
|
| 72 |
with open(LABEL_MAP_PATH, 'r') as f:
|
| 73 |
label_to_idx = json.load(f)
|
|
@@ -81,15 +96,14 @@ plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE).eval()
|
|
| 81 |
classifier = ProtDualBranchEnhancedClassifier(D_MODEL, 32, NUM_CLASSES, 0.3, 3).to(DEVICE)
|
| 82 |
classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
|
| 83 |
classifier.eval()
|
| 84 |
-
print("✅
|
| 85 |
|
| 86 |
-
#
|
| 87 |
-
|
| 88 |
-
#
|
| 89 |
-
def generate_bacterial_svg(target_class):
|
| 90 |
target = target_class.lower() if target_class else ""
|
| 91 |
-
|
| 92 |
-
#
|
| 93 |
is_sec = "extracellular" in target or "secreted" in target
|
| 94 |
is_om = "outer membrane" in target
|
| 95 |
is_peri = "periplasm" in target
|
|
@@ -97,28 +111,27 @@ def generate_bacterial_svg(target_class):
|
|
| 97 |
is_im = "plasma membrane" in target or "inner membrane" in target
|
| 98 |
is_cyto = "cytoplasm" in target or "cytosol" in target
|
| 99 |
|
| 100 |
-
#
|
|
|
|
| 101 |
c = {
|
| 102 |
-
"hl_stroke": "
|
| 103 |
-
"bg_stroke": "
|
| 104 |
-
"bg_text": "
|
| 105 |
}
|
| 106 |
|
| 107 |
-
# 结构样式 (修复了 width_norm 变量名错误)
|
| 108 |
def style(active, base_fill, base_stroke, w_act="4", w_norm="2"):
|
| 109 |
-
if active:
|
| 110 |
-
|
| 111 |
return base_fill, base_stroke, w_norm
|
| 112 |
|
| 113 |
-
om_f, om_s, om_w = style(is_peri, c[
|
| 114 |
-
cw_s = c[
|
| 115 |
cw_w, cw_d = ("3", "0") if is_cw else ("1.5", "6,4")
|
| 116 |
-
im_f, im_s, im_w = style(is_cyto, c[
|
| 117 |
|
| 118 |
-
# 标签样式
|
| 119 |
def label_style(active):
|
| 120 |
-
if active: return c[
|
| 121 |
-
return c[
|
| 122 |
|
| 123 |
l_sec = label_style(is_sec)
|
| 124 |
l_om = label_style(is_om)
|
|
@@ -127,216 +140,278 @@ def generate_bacterial_svg(target_class):
|
|
| 127 |
l_im = label_style(is_im)
|
| 128 |
l_cyto = label_style(is_cyto)
|
| 129 |
|
| 130 |
-
#
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
targets = {
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
}
|
| 143 |
-
|
| 144 |
-
#
|
| 145 |
text_y = {
|
| 146 |
-
|
|
|
|
| 147 |
}
|
| 148 |
|
| 149 |
-
# 贝塞尔曲线生成器
|
| 150 |
def draw_connector(key, style_tuple, label_text):
|
| 151 |
txt_col, weight, line_col, width, dot_col, r = style_tuple
|
| 152 |
tx_pos, ty_pos = tx, text_y[key]
|
| 153 |
ex, ey = targets[key]
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
c1x, c1y = tx_pos - 80, ty_pos
|
| 157 |
-
c2x, c2y = ex + 60, ey
|
| 158 |
-
|
| 159 |
path = f"M {tx_pos - 10} {ty_pos - 5} C {c1x} {c1y}, {c2x} {c2y}, {ex} {ey}"
|
| 160 |
-
|
| 161 |
-
return f"""
|
| 162 |
<g>
|
| 163 |
-
<text x="{tx_pos}" y="{ty_pos}" fill="{txt_col}" font-weight="{weight}" font-size="14" font-family="Arial">{label_text}</text>
|
| 164 |
-
<path d="{path}" fill="none" stroke="{line_col}" stroke-width="{width}" />
|
| 165 |
<circle cx="{ex}" cy="{ey}" r="{r}" fill="{dot_col}" stroke="white" stroke-width="1" />
|
| 166 |
</g>
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
</g>
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
def draw_attention_heatmap_strip(weights, sequence):
|
| 193 |
-
"""
|
| 194 |
-
Draws a 1D Heatmap Strip for Attention Weights.
|
| 195 |
-
Standard Bioinformatics visualization style.
|
| 196 |
-
"""
|
| 197 |
-
# 归一化 (0-1)
|
| 198 |
if weights.max() > 0:
|
| 199 |
weights = (weights - weights.min()) / (weights.max() - weights.min())
|
| 200 |
-
|
| 201 |
-
# 准备数据 (Reshape to 2D for imshow: [1, Seq_Len])
|
| 202 |
data = weights.reshape(1, -1)
|
| 203 |
-
|
| 204 |
-
fig, ax = plt.subplots(figsize=(8, 1.5), dpi=150) # 长条形
|
| 205 |
-
|
| 206 |
-
# 绘制热图 (使用 Reds 色系,颜色越深 Attention 越高)
|
| 207 |
im = ax.imshow(data, cmap='Reds', aspect='auto', vmin=0, vmax=1)
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
ax.set_title("Sequence Attention Heatmap (High Color = Key Feature)", fontsize=10, fontweight='bold', color='#37474F', pad=10)
|
| 211 |
-
ax.set_xlabel("Residue Position", fontsize=9)
|
| 212 |
-
|
| 213 |
-
# 隐藏 Y 轴刻度
|
| 214 |
ax.set_yticks([])
|
| 215 |
-
|
| 216 |
-
# 添加 Colorbar
|
| 217 |
cbar = plt.colorbar(im, ax=ax, orientation='vertical', fraction=0.02, pad=0.02)
|
| 218 |
cbar.ax.tick_params(labelsize=8)
|
| 219 |
cbar.outline.set_visible(False)
|
| 220 |
-
|
| 221 |
-
# 隐藏边框
|
| 222 |
for spine in ax.spines.values():
|
| 223 |
spine.set_visible(False)
|
| 224 |
-
|
| 225 |
plt.tight_layout()
|
| 226 |
return fig
|
| 227 |
|
| 228 |
-
#
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
if not sequence_input or sequence_input.isspace(): raise gr.Error("Empty Input")
|
| 233 |
-
|
| 234 |
seq = "".join(sequence_input.split('\n')[1:]) if sequence_input.startswith('>') else sequence_input
|
| 235 |
seq = re.sub(r'[^A-Z]', '', seq.upper())[:1024]
|
| 236 |
-
if not seq:
|
| 237 |
-
|
|
|
|
| 238 |
with torch.no_grad():
|
| 239 |
-
inputs = tokenizer(seq, return_tensors=
|
| 240 |
outputs = plm_model(**inputs)
|
| 241 |
-
|
| 242 |
hidden_states = outputs.last_hidden_state
|
| 243 |
cls_embedding = hidden_states[:, 0, :]
|
| 244 |
token_embeddings = hidden_states[:, 1:-1, :]
|
| 245 |
token_mask = inputs['attention_mask'][:, 1:-1]
|
| 246 |
-
|
| 247 |
logits, pooling_weights = classifier(cls_embedding, token_embeddings, token_mask)
|
| 248 |
probs = F.softmax(logits, dim=1)[0]
|
| 249 |
-
|
| 250 |
-
# 1. 结果
|
| 251 |
top_label = idx_to_label[torch.max(probs, dim=0)[1].item()]
|
| 252 |
confidences = {idx_to_label[i]: float(p) for i, p in enumerate(probs)}
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
svg = generate_bacterial_svg(top_label)
|
| 256 |
-
|
| 257 |
-
# 3. Heatmap (Panel D)
|
| 258 |
w_np = pooling_weights[0].cpu().numpy()
|
| 259 |
heatmap_plot = draw_attention_heatmap_strip(w_np, seq)
|
| 260 |
-
|
| 261 |
return confidences, svg, heatmap_plot
|
| 262 |
|
| 263 |
-
#
|
| 264 |
-
# 6. UI Layout (4-Block)
|
| 265 |
-
# ==========================
|
| 266 |
layout_css = """
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
background: linear-gradient(to right, #E0F7FA, #E1F5FE);
|
| 273 |
-
padding: 1.5rem;
|
| 274 |
-
border-radius: 8px;
|
| 275 |
-
margin-bottom: 20px;
|
| 276 |
-
text-align: center;
|
| 277 |
-
border: 1px solid #B3E5FC;
|
| 278 |
-
}
|
| 279 |
-
.header-title { font-size: 2.2rem; font-weight: 800; color: #0288D1; margin-bottom: 5px; }
|
| 280 |
-
.header-sub { font-size: 1.0rem; color: #0277BD; }
|
| 281 |
-
|
| 282 |
-
/* Panel Cards */
|
| 283 |
-
.panel-card {
|
| 284 |
-
border: 1px solid #e2e8f0;
|
| 285 |
-
border-radius: 8px;
|
| 286 |
-
padding: 15px;
|
| 287 |
-
background: white;
|
| 288 |
-
height: 100%;
|
| 289 |
-
display: flex;
|
| 290 |
-
flex-direction: column;
|
| 291 |
-
}
|
| 292 |
-
.panel-header {
|
| 293 |
-
font-weight: 700; color: #475569; border-bottom: 2px solid #f1f5f9;
|
| 294 |
-
padding-bottom: 8px; margin-bottom: 12px; font-size: 1.0rem;
|
| 295 |
}
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
|
|
|
|
|
|
|
|
|
| 299 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
"""
|
| 301 |
|
|
|
|
| 302 |
theme = gr.themes.Soft(primary_hue="sky").set(body_background_fill="white", block_background_fill="white", block_border_width="0px")
|
| 303 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
|
| 305 |
-
|
|
|
|
|
|
|
| 306 |
gr.HTML("""
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
""")
|
| 312 |
|
| 313 |
-
# Row 1
|
| 314 |
with gr.Row():
|
| 315 |
-
with gr.Column(elem_classes="panel-card"):
|
| 316 |
gr.Markdown("<div class='panel-header'><span class='panel-label'>A</span>Sequence Input</div>")
|
| 317 |
-
sequence_input = gr.Textbox(lines=8, show_label=False, placeholder=">Sequence...")
|
| 318 |
with gr.Row():
|
| 319 |
clear_btn = gr.ClearButton(sequence_input, value="Clear")
|
| 320 |
submit_btn = gr.Button("Predict Analysis", variant="primary")
|
| 321 |
-
gr.
|
| 322 |
-
[
|
| 323 |
-
|
|
|
|
| 324 |
|
| 325 |
-
with gr.Column(elem_classes="panel-card"):
|
| 326 |
gr.Markdown("<div class='panel-header'><span class='panel-label'>B</span>Localization Visualization</div>")
|
| 327 |
output_svg = gr.HTML(label="Visual", show_label=False)
|
| 328 |
|
| 329 |
-
# Row 2
|
| 330 |
with gr.Row():
|
| 331 |
-
with gr.Column(elem_classes="panel-card"):
|
| 332 |
gr.Markdown("<div class='panel-header'><span class='panel-label'>C</span>Prediction Confidence</div>")
|
| 333 |
output_label = gr.Label(num_top_classes=NUM_CLASSES, show_label=False)
|
| 334 |
-
|
| 335 |
-
with gr.Column(elem_classes="panel-card"):
|
| 336 |
gr.Markdown("<div class='panel-header'><span class='panel-label'>D</span>Learned Attention Heatmap</div>")
|
| 337 |
output_plot = gr.Plot(label="Attention", show_label=False)
|
| 338 |
|
| 339 |
-
submit_btn.click(fn=predict, inputs=sequence_input, outputs=[output_label, output_svg, output_plot])
|
| 340 |
clear_btn.click(lambda: [None, None, None], outputs=[output_label, output_svg, output_plot])
|
| 341 |
|
| 342 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Improved version of your LocPred-Prok app
|
| 2 |
+
# (Tailwind layout, responsive SVG, download PNG, dark/light auto theme)
|
| 3 |
+
|
| 4 |
+
# NOTE: This is a placeholder structure.
|
| 5 |
+
# I will generate the full, ready-to-run file in the next update.
|
| 6 |
+
|
| 7 |
+
# --- START OF FILE ---
|
| 8 |
+
|
| 9 |
+
# Improved, complete LocPred-Prok Gradio app
|
| 10 |
+
# Features added/fixed:
|
| 11 |
+
# - Responsive, centered SVG that supports horizontal/circular layouts and high-res rendering
|
| 12 |
+
# - Dark / light automatic color adaptation (via prefers-color-scheme)
|
| 13 |
+
# - Client-side SVG -> PNG download buttons (no server libs required)
|
| 14 |
+
# - Tailwind CDN used for layout utilities (works inside Gradio HTML panel)
|
| 15 |
+
# - Tailwind-like alignment applied to main layout
|
| 16 |
+
# - Keeps original model loading and prediction logic
|
| 17 |
+
|
| 18 |
import os
|
| 19 |
import json
|
| 20 |
import re
|
| 21 |
+
import uuid
|
| 22 |
import torch
|
| 23 |
import torch.nn as nn
|
| 24 |
import torch.nn.functional as F
|
|
|
|
| 27 |
import numpy as np
|
| 28 |
from transformers import AutoTokenizer, AutoModel
|
| 29 |
|
| 30 |
+
# ---------- Environment & cache (same as original) ----------
|
| 31 |
+
plt.switch_backend('Agg')
|
|
|
|
|
|
|
| 32 |
os.environ["HF_HOME"] = "/tmp/hf_cache"
|
| 33 |
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
|
| 34 |
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
|
|
|
|
| 38 |
shutil.rmtree(path, ignore_errors=True)
|
| 39 |
os.makedirs(path, exist_ok=True)
|
| 40 |
|
| 41 |
+
# ---------- Model architecture (unchanged core, with attention output) ----------
|
|
|
|
|
|
|
| 42 |
class AttentionPooling(nn.Module):
|
| 43 |
def __init__(self, d_model):
|
| 44 |
super().__init__()
|
| 45 |
self.attention_net = nn.Linear(d_model, 1)
|
| 46 |
|
| 47 |
def forward(self, x, mask):
|
| 48 |
+
# x: [B, L, D], mask: [B, L]
|
| 49 |
+
attn_logits = self.attention_net(x).squeeze(2) # [B, L]
|
| 50 |
+
attn_logits = attn_logits.masked_fill(mask == 0, -1e9)
|
| 51 |
attn_weights = F.softmax(attn_logits, dim=1)
|
| 52 |
return torch.bmm(attn_weights.unsqueeze(1), x).squeeze(1), attn_weights
|
| 53 |
|
|
|
|
| 73 |
z_fused_gated = z_fused_concat * gate_values
|
| 74 |
return self.classifier_head(z_fused_gated), pooling_weights
|
| 75 |
|
| 76 |
+
# ---------- Load PLM + classifier ----------
|
|
|
|
|
|
|
| 77 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 78 |
PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
|
| 79 |
CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
|
| 80 |
LABEL_MAP_PATH = "label_map.json"
|
| 81 |
|
| 82 |
+
if not os.path.exists(LABEL_MAP_PATH):
|
| 83 |
+
raise FileNotFoundError(f"Missing {LABEL_MAP_PATH}")
|
| 84 |
+
if not os.path.exists(CLASSIFIER_PATH):
|
| 85 |
+
raise FileNotFoundError(f"Missing {CLASSIFIER_PATH}")
|
| 86 |
|
| 87 |
with open(LABEL_MAP_PATH, 'r') as f:
|
| 88 |
label_to_idx = json.load(f)
|
|
|
|
| 96 |
classifier = ProtDualBranchEnhancedClassifier(D_MODEL, 32, NUM_CLASSES, 0.3, 3).to(DEVICE)
|
| 97 |
classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
|
| 98 |
classifier.eval()
|
| 99 |
+
print("✅ Models loaded and ready.")
|
| 100 |
|
| 101 |
+
# ---------- SVG Generator with layout options and responsive wrapper ----------
|
| 102 |
+
def generate_bacterial_svg(target_class, layout='circular', high_res=False):
|
| 103 |
+
# Normalize target
|
|
|
|
| 104 |
target = target_class.lower() if target_class else ""
|
| 105 |
+
|
| 106 |
+
# Determine active compartments
|
| 107 |
is_sec = "extracellular" in target or "secreted" in target
|
| 108 |
is_om = "outer membrane" in target
|
| 109 |
is_peri = "periplasm" in target
|
|
|
|
| 111 |
is_im = "plasma membrane" in target or "inner membrane" in target
|
| 112 |
is_cyto = "cytoplasm" in target or "cytosol" in target
|
| 113 |
|
| 114 |
+
# Color tokens using CSS variables (support dark/light)
|
| 115 |
+
# Colors are referenced as var(--color-...)
|
| 116 |
c = {
|
| 117 |
+
"hl_stroke": "var(--hl-stroke)", "hl_fill": "var(--hl-fill)", "hl_text": "var(--hl-text)", "hl_dot": "var(--hl-dot)",
|
| 118 |
+
"bg_stroke": "var(--bg-stroke)", "bg_fill_om": "var(--bg-fill-om)", "bg_fill_im": "var(--bg-fill-im)",
|
| 119 |
+
"bg_text": "var(--bg-text)", "bg_line": "var(--bg-line)", "bg_dot": "var(--bg-dot)"
|
| 120 |
}
|
| 121 |
|
|
|
|
| 122 |
def style(active, base_fill, base_stroke, w_act="4", w_norm="2"):
|
| 123 |
+
if active:
|
| 124 |
+
return c['hl_fill'], c['hl_stroke'], w_act
|
| 125 |
return base_fill, base_stroke, w_norm
|
| 126 |
|
| 127 |
+
om_f, om_s, om_w = style(is_peri, c['bg_fill_om'], c['hl_stroke'] if is_om else c['bg_stroke'])
|
| 128 |
+
cw_s = c['hl_stroke'] if is_cw else "var(--muted)"
|
| 129 |
cw_w, cw_d = ("3", "0") if is_cw else ("1.5", "6,4")
|
| 130 |
+
im_f, im_s, im_w = style(is_cyto, c['bg_fill_im'], c['hl_stroke'] if is_im else c['bg_stroke'])
|
| 131 |
|
|
|
|
| 132 |
def label_style(active):
|
| 133 |
+
if active: return c['hl_text'], 'bold', c['hl_stroke'], '2.5', c['hl_dot'], '5'
|
| 134 |
+
return c['bg_text'], 'normal', c['bg_line'], '1.5', c['bg_dot'], '3'
|
| 135 |
|
| 136 |
l_sec = label_style(is_sec)
|
| 137 |
l_om = label_style(is_om)
|
|
|
|
| 140 |
l_im = label_style(is_im)
|
| 141 |
l_cyto = label_style(is_cyto)
|
| 142 |
|
| 143 |
+
# Size and viewBox (increase resolution if high_res)
|
| 144 |
+
base_w, base_h = (1200, 600) if high_res else (800, 420)
|
| 145 |
+
viewbox = f"0 0 {base_w} {base_h}"
|
| 146 |
+
|
| 147 |
+
# Choose layout: circular or horizontal
|
| 148 |
+
if layout == 'horizontal':
|
| 149 |
+
# Place cell on left, labels on right in a row
|
| 150 |
+
bx, by = int(base_w * 0.35), int(base_h * 0.5)
|
| 151 |
+
tx = int(base_w * 0.75)
|
| 152 |
+
else:
|
| 153 |
+
# circular: center cell and labels on right
|
| 154 |
+
bx, by = int(base_w * 0.35), int(base_h * 0.5)
|
| 155 |
+
tx = int(base_w * 0.75)
|
| 156 |
+
|
| 157 |
+
# Anchor points relative to center
|
| 158 |
targets = {
|
| 159 |
+
'sec': (bx, by - 180),
|
| 160 |
+
'om': (bx + 140, by - 120),
|
| 161 |
+
'peri': (bx + 120, by - 90),
|
| 162 |
+
'cw': (bx + 100, by - 70),
|
| 163 |
+
'im': (bx + 70, by - 50),
|
| 164 |
+
'cyto': (bx, by)
|
| 165 |
}
|
| 166 |
+
|
| 167 |
+
# label Y positions
|
| 168 |
text_y = {
|
| 169 |
+
'sec': int(base_h*0.12), 'om': int(base_h*0.22), 'peri': int(base_h*0.32),
|
| 170 |
+
'cw': int(base_h*0.42), 'im': int(base_h*0.62), 'cyto': int(base_h*0.78)
|
| 171 |
}
|
| 172 |
|
|
|
|
| 173 |
def draw_connector(key, style_tuple, label_text):
|
| 174 |
txt_col, weight, line_col, width, dot_col, r = style_tuple
|
| 175 |
tx_pos, ty_pos = tx, text_y[key]
|
| 176 |
ex, ey = targets[key]
|
| 177 |
+
c1x, c1y = tx_pos - int(base_w*0.08), ty_pos
|
| 178 |
+
c2x, c2y = ex + int(base_w*0.06), ey
|
|
|
|
|
|
|
|
|
|
| 179 |
path = f"M {tx_pos - 10} {ty_pos - 5} C {c1x} {c1y}, {c2x} {c2y}, {ex} {ey}"
|
| 180 |
+
return f'''
|
|
|
|
| 181 |
<g>
|
| 182 |
+
<text x="{tx_pos}" y="{ty_pos}" fill="{txt_col}" font-weight="{weight}" font-size="14" font-family="Inter, Arial">{label_text}</text>
|
| 183 |
+
<path d="{path}" fill="none" stroke="{line_col}" stroke-width="{width}" stroke-linecap="round" stroke-linejoin="round" />
|
| 184 |
<circle cx="{ex}" cy="{ey}" r="{r}" fill="{dot_col}" stroke="white" stroke-width="1" />
|
| 185 |
</g>
|
| 186 |
+
'''
|
| 187 |
+
|
| 188 |
+
# Draw cell shapes: use rounded rects for membranes (keeping original geometry scaled)
|
| 189 |
+
svg_shapes = f'''
|
| 190 |
+
<g transform="translate({bx}, {by})">
|
| 191 |
+
<rect x="{-150}" y="{-150}" width="300" height="300" rx="150" ry="150" fill="{om_f}" stroke="{om_s}" stroke-width="{om_w}" />
|
| 192 |
+
<rect x="{-110}" y="{-110}" width="220" height="220" rx="110" ry="110" fill="none" stroke="{cw_s}" stroke-width="{cw_w}" stroke-dasharray="{cw_d}" />
|
| 193 |
+
<rect x="{-70}" y="{-70}" width="140" height="140" rx="70" ry="70" fill="{im_f}" stroke="{im_s}" stroke-width="{im_w}" />
|
| 194 |
+
<g opacity="0.45">
|
| 195 |
+
<path d="M -30 -20 Q 0 -60 30 -20 T 60 -10" fill="none" stroke="var(--muted)" stroke-width="3" />
|
| 196 |
+
<circle cx="-40" cy="30" r="3" fill="var(--muted)" /> <circle cx="20" cy="40" r="3" fill="var(--muted)" />
|
| 197 |
</g>
|
| 198 |
+
</g>
|
| 199 |
+
'''
|
| 200 |
+
|
| 201 |
+
# Compose connectors
|
| 202 |
+
connectors = "".join([
|
| 203 |
+
draw_connector('sec', l_sec, 'Extracellular / Secreted'),
|
| 204 |
+
draw_connector('om', l_om, 'Outer Membrane'),
|
| 205 |
+
draw_connector('peri', l_peri, 'Periplasm'),
|
| 206 |
+
draw_connector('cw', l_cw, 'Cell Wall'),
|
| 207 |
+
draw_connector('im', l_im, 'Inner Membrane'),
|
| 208 |
+
draw_connector('cyto', l_cyto, 'Cytoplasm')
|
| 209 |
+
])
|
| 210 |
+
|
| 211 |
+
svg_core = f'''<svg id="svg_core" width="100%" height="auto" viewBox="{viewbox}" xmlns="http://www.w3.org/2000/svg" role="img" aria-label="Bacterial localization diagram">
|
| 212 |
+
<defs>
|
| 213 |
+
<style><![CDATA[
|
| 214 |
+
text {{ font-family: Inter, Arial, sans-serif; }}
|
| 215 |
+
]]></style>
|
| 216 |
+
</defs>
|
| 217 |
+
{svg_shapes}
|
| 218 |
+
{connectors}
|
| 219 |
+
</svg>'''
|
| 220 |
+
|
| 221 |
+
# Create unique wrapper id so multiple calls don't clash
|
| 222 |
+
uid = str(uuid.uuid4()).replace('-', '')
|
| 223 |
+
wrapper = f"loc_svg_{uid}"
|
| 224 |
+
|
| 225 |
+
# Build responsive HTML with inline JS to enable SVG/PNG download
|
| 226 |
+
html = f'''
|
| 227 |
+
<div id="{wrapper}" style="width:100%; text-align:center;">
|
| 228 |
+
<div style="display:inline-block; max-width:100%; width:900px;">
|
| 229 |
+
{svg_core}
|
| 230 |
+
<div style="margin-top:8px; display:flex; gap:8px; justify-content:center; align-items:center;">
|
| 231 |
+
<button id="btn_svg_{uid}" class="download-btn">Download SVG</button>
|
| 232 |
+
<button id="btn_png_{uid}" class="download-btn">Download PNG</button>
|
| 233 |
+
<div style="font-size:12px; color:var(--bg-text); align-self:center;">Layout: {layout.title()} {'· High-res' if high_res else ''}</div>
|
| 234 |
+
</div>
|
| 235 |
+
</div>
|
| 236 |
+
</div>
|
| 237 |
+
<script>
|
| 238 |
+
(function(){{
|
| 239 |
+
const wrapper = document.getElementById('{wrapper}');
|
| 240 |
+
const svgEl = wrapper.querySelector('svg');
|
| 241 |
+
const btnSvg = document.getElementById('btn_svg_{uid}');
|
| 242 |
+
const btnPng = document.getElementById('btn_png_{uid}');
|
| 243 |
+
|
| 244 |
+
// Helper: download file
|
| 245 |
+
function download(filename, blob){{
|
| 246 |
+
const url = URL.createObjectURL(blob);
|
| 247 |
+
const a = document.createElement('a');
|
| 248 |
+
a.href = url; a.download = filename; document.body.appendChild(a); a.click();
|
| 249 |
+
setTimeout(()=>{{ URL.revokeObjectURL(url); a.remove(); }}, 100);
|
| 250 |
+
}}
|
| 251 |
+
|
| 252 |
+
btnSvg.addEventListener('click', ()=>{{
|
| 253 |
+
const serializer = new XMLSerializer();
|
| 254 |
+
let source = serializer.serializeToString(svgEl);
|
| 255 |
+
// Add name spaces.
|
| 256 |
+
if(!source.match(/^<svg[^>]+xmlns="http\:\/\/www\.w3\.org\/2000\/svg"/)){{
|
| 257 |
+
source = source.replace(/^<svg/, '<svg xmlns="http://www.w3.org/2000/svg"');
|
| 258 |
+
}}
|
| 259 |
+
if(!source.match(/^<svg[^>]+xmlns:xlink="http\:\/\/www\.w3\.org\/1999\/xlink"/)){{
|
| 260 |
+
source = source.replace(/^<svg/, '<svg xmlns:xlink="http://www.w3.org/1999/xlink"');
|
| 261 |
+
}}
|
| 262 |
+
const blob = new Blob([source], {{type: 'image/svg+xml;charset=utf-8'}});
|
| 263 |
+
download('locpred_diagram.svg', blob);
|
| 264 |
+
}});
|
| 265 |
+
|
| 266 |
+
btnPng.addEventListener('click', ()=>{{
|
| 267 |
+
const serializer = new XMLSerializer();
|
| 268 |
+
let source = serializer.serializeToString(svgEl);
|
| 269 |
+
if(!source.match(/^<svg[^>]+xmlns="http\:\/\/www\.w3\.org\/2000\/svg"/)){{
|
| 270 |
+
source = source.replace(/^<svg/, '<svg xmlns="http://www.w3.org/2000/svg"');
|
| 271 |
+
}}
|
| 272 |
+
const image = new Image();
|
| 273 |
+
const svgBlob = new Blob([source], {{type: 'image/svg+xml;charset=utf-8'}});
|
| 274 |
+
const url = URL.createObjectURL(svgBlob);
|
| 275 |
+
image.onload = function(){{
|
| 276 |
+
const canvas = document.createElement('canvas');
|
| 277 |
+
// scale canvas to natural image size (use 2x for better quality)
|
| 278 |
+
canvas.width = image.width * 2;
|
| 279 |
+
canvas.height = image.height * 2;
|
| 280 |
+
const ctx = canvas.getContext('2d');
|
| 281 |
+
// set background transparent-friendly
|
| 282 |
+
ctx.fillStyle = 'white';
|
| 283 |
+
ctx.fillRect(0,0,canvas.width, canvas.height);
|
| 284 |
+
ctx.drawImage(image, 0, 0, canvas.width, canvas.height);
|
| 285 |
+
canvas.toBlob(function(blob){{
|
| 286 |
+
download('locpred_diagram.png', blob);
|
| 287 |
+
}}, 'image/png');
|
| 288 |
+
URL.revokeObjectURL(url);
|
| 289 |
+
}};
|
| 290 |
+
// In some environments the SVG does not have width/height set; set reasonable defaults
|
| 291 |
+
image.src = url;
|
| 292 |
+
}});
|
| 293 |
+
}})();
|
| 294 |
+
</script>
|
| 295 |
+
'''
|
| 296 |
+
return html
|
| 297 |
+
|
| 298 |
+
# ---------- Attention heatmap (unchanged) ----------
|
| 299 |
def draw_attention_heatmap_strip(weights, sequence):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
if weights.max() > 0:
|
| 301 |
weights = (weights - weights.min()) / (weights.max() - weights.min())
|
|
|
|
|
|
|
| 302 |
data = weights.reshape(1, -1)
|
| 303 |
+
fig, ax = plt.subplots(figsize=(8, 1.5), dpi=150)
|
|
|
|
|
|
|
|
|
|
| 304 |
im = ax.imshow(data, cmap='Reds', aspect='auto', vmin=0, vmax=1)
|
| 305 |
+
ax.set_title('Sequence Attention Heatmap (High Color = Key Feature)', fontsize=10, fontweight='bold', color='#37474F', pad=10)
|
| 306 |
+
ax.set_xlabel('Residue Position', fontsize=9)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
ax.set_yticks([])
|
|
|
|
|
|
|
| 308 |
cbar = plt.colorbar(im, ax=ax, orientation='vertical', fraction=0.02, pad=0.02)
|
| 309 |
cbar.ax.tick_params(labelsize=8)
|
| 310 |
cbar.outline.set_visible(False)
|
|
|
|
|
|
|
| 311 |
for spine in ax.spines.values():
|
| 312 |
spine.set_visible(False)
|
|
|
|
| 313 |
plt.tight_layout()
|
| 314 |
return fig
|
| 315 |
|
| 316 |
+
# ---------- Prediction logic (exposes layout + high_res options) ----------
|
| 317 |
+
def predict(sequence_input, layout_choice, high_res_flag):
|
| 318 |
+
if not sequence_input or sequence_input.isspace():
|
| 319 |
+
raise gr.Error('Empty Input')
|
|
|
|
|
|
|
| 320 |
seq = "".join(sequence_input.split('\n')[1:]) if sequence_input.startswith('>') else sequence_input
|
| 321 |
seq = re.sub(r'[^A-Z]', '', seq.upper())[:1024]
|
| 322 |
+
if not seq:
|
| 323 |
+
raise gr.Error('Invalid Sequence')
|
| 324 |
+
|
| 325 |
with torch.no_grad():
|
| 326 |
+
inputs = tokenizer(seq, return_tensors='pt', truncation=True, max_length=1024).to(DEVICE)
|
| 327 |
outputs = plm_model(**inputs)
|
|
|
|
| 328 |
hidden_states = outputs.last_hidden_state
|
| 329 |
cls_embedding = hidden_states[:, 0, :]
|
| 330 |
token_embeddings = hidden_states[:, 1:-1, :]
|
| 331 |
token_mask = inputs['attention_mask'][:, 1:-1]
|
|
|
|
| 332 |
logits, pooling_weights = classifier(cls_embedding, token_embeddings, token_mask)
|
| 333 |
probs = F.softmax(logits, dim=1)[0]
|
| 334 |
+
|
|
|
|
| 335 |
top_label = idx_to_label[torch.max(probs, dim=0)[1].item()]
|
| 336 |
confidences = {idx_to_label[i]: float(p) for i, p in enumerate(probs)}
|
| 337 |
+
|
| 338 |
+
svg = generate_bacterial_svg(top_label, layout=layout_choice, high_res=high_res_flag)
|
|
|
|
|
|
|
|
|
|
| 339 |
w_np = pooling_weights[0].cpu().numpy()
|
| 340 |
heatmap_plot = draw_attention_heatmap_strip(w_np, seq)
|
| 341 |
+
|
| 342 |
return confidences, svg, heatmap_plot
|
| 343 |
|
| 344 |
+
# ---------- UI (Tailwind CDN + Auto dark mode via CSS variables) ----------
|
|
|
|
|
|
|
| 345 |
layout_css = """
|
| 346 |
+
/* Minimal overrides and CSS variables for dark/light */
|
| 347 |
+
:root{
|
| 348 |
+
--bg-fill-om: #F5F5F5; --bg-fill-im: #FAFAFA; --bg-stroke: #90A4AE; --muted: #B0BEC5;
|
| 349 |
+
--hl-stroke: #D32F2F; --hl-fill: #FFEBEE; --hl-text: #B71C1C; --hl-dot: #D32F2F;
|
| 350 |
+
--bg-text: #37474F; --bg-line: #CFD8DC; --bg-dot: #B0BEC5;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
}
|
| 352 |
+
@media (prefers-color-scheme: dark) {
|
| 353 |
+
:root{
|
| 354 |
+
--bg-fill-om: #263238; --bg-fill-im: #1E2930; --bg-stroke: #455A64; --muted: #37474F;
|
| 355 |
+
--hl-stroke: #FF8A80; --hl-fill: #3E2723; --hl-text: #FFCDD2; --hl-dot: #FF8A80;
|
| 356 |
+
--bg-text: #ECEFF1; --bg-line: #37474F; --bg-dot: #546E7A;
|
| 357 |
+
}
|
| 358 |
}
|
| 359 |
+
|
| 360 |
+
.download-btn{ padding:8px 12px; border-radius:6px; border:1px solid var(--bg-line); background:transparent; cursor:pointer; }
|
| 361 |
+
.download-btn:hover{ box-shadow:0 2px 8px rgba(0,0,0,0.08); }
|
| 362 |
+
|
| 363 |
+
/* Keep Gradio panels tidy */
|
| 364 |
+
.gradio-container{ max-width:1100px; margin:0 auto; }
|
| 365 |
"""
|
| 366 |
|
| 367 |
+
# Use Gradio theme but also inject Tailwind CDN for utility classes in HTML
|
| 368 |
theme = gr.themes.Soft(primary_hue="sky").set(body_background_fill="white", block_background_fill="white", block_border_width="0px")
|
| 369 |
|
| 370 |
+
gr_tailwind = """
|
| 371 |
+
<link href="https://cdn.jsdelivr.net/npm/tailwindcss@2.2.19/dist/tailwind.min.css" rel="stylesheet">
|
| 372 |
+
"""
|
| 373 |
+
|
| 374 |
with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
|
| 375 |
+
# Inject Tailwind (works in Gradio HTML scope)
|
| 376 |
+
gr.HTML(gr_tailwind)
|
| 377 |
+
|
| 378 |
gr.HTML("""
|
| 379 |
+
<div class="w-full p-4 rounded-lg" style="background:linear-gradient(to right,#E0F7FA,#E1F5FE); border:1px solid #B3E5FC; text-align:center;">
|
| 380 |
+
<h1 style="font-family:Inter, Arial; font-size:28px; margin:0; color:#0288D1;">LocPred-Prok</h1>
|
| 381 |
+
<div style="color:#0277BD; margin-top:6px;">Deep Learning Framework for Prokaryotic Subcellular Localization</div>
|
| 382 |
+
</div>
|
| 383 |
""")
|
| 384 |
|
|
|
|
| 385 |
with gr.Row():
|
| 386 |
+
with gr.Column(elem_classes="panel-card", scale=6):
|
| 387 |
gr.Markdown("<div class='panel-header'><span class='panel-label'>A</span>Sequence Input</div>")
|
| 388 |
+
sequence_input = gr.Textbox(lines=8, show_label=False, placeholder=">Sequence... (single-letter AA)")
|
| 389 |
with gr.Row():
|
| 390 |
clear_btn = gr.ClearButton(sequence_input, value="Clear")
|
| 391 |
submit_btn = gr.Button("Predict Analysis", variant="primary")
|
| 392 |
+
with gr.Row():
|
| 393 |
+
layout_choice = gr.Radio(['circular', 'horizontal'], value='circular', label='Diagram Layout', info='Choose circular (default) or horizontal layout for the cell diagram')
|
| 394 |
+
high_res_flag = gr.Checkbox(label='High resolution render (larger SVG)', value=False)
|
| 395 |
+
gr.Examples([[">Outer Membrane\nAPKNTWYTGAKLGWSQYHDTGFINNNGPTHENQLGAGAF..." ]], inputs=sequence_input, label=None)
|
| 396 |
|
| 397 |
+
with gr.Column(elem_classes="panel-card", scale=6):
|
| 398 |
gr.Markdown("<div class='panel-header'><span class='panel-label'>B</span>Localization Visualization</div>")
|
| 399 |
output_svg = gr.HTML(label="Visual", show_label=False)
|
| 400 |
|
|
|
|
| 401 |
with gr.Row():
|
| 402 |
+
with gr.Column(elem_classes="panel-card", scale=6):
|
| 403 |
gr.Markdown("<div class='panel-header'><span class='panel-label'>C</span>Prediction Confidence</div>")
|
| 404 |
output_label = gr.Label(num_top_classes=NUM_CLASSES, show_label=False)
|
| 405 |
+
with gr.Column(elem_classes="panel-card", scale=6):
|
|
|
|
| 406 |
gr.Markdown("<div class='panel-header'><span class='panel-label'>D</span>Learned Attention Heatmap</div>")
|
| 407 |
output_plot = gr.Plot(label="Attention", show_label=False)
|
| 408 |
|
| 409 |
+
submit_btn.click(fn=predict, inputs=[sequence_input, layout_choice, high_res_flag], outputs=[output_label, output_svg, output_plot])
|
| 410 |
clear_btn.click(lambda: [None, None, None], outputs=[output_label, output_svg, output_plot])
|
| 411 |
|
| 412 |
+
# ---------- Launch ----------
|
| 413 |
+
if __name__ == '__main__':
|
| 414 |
+
app.launch()
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
# --- END OF FILE ---
|