Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -85,71 +85,103 @@ classifier.eval()
|
|
| 85 |
print("✅ Ready.")
|
| 86 |
|
| 87 |
# ==========================
|
| 88 |
-
# 3.
|
| 89 |
# ==========================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
def generate_scientific_svg(target_class):
|
| 91 |
-
|
|
|
|
| 92 |
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
# 颜色配置 (Nature Style)
|
| 102 |
c = {
|
| 103 |
'hl_stroke': '#D32F2F', 'hl_fill': '#FFEBEE', 'hl_text': '#B71C1C', 'hl_dot': '#D32F2F',
|
| 104 |
-
'bg_stroke': '#90A4AE', 'bg_fill': '#
|
| 105 |
'bg_text': '#78909C', 'bg_line': '#CFD8DC', 'bg_dot': '#B0BEC5'
|
| 106 |
}
|
| 107 |
|
| 108 |
-
# 几何参数
|
| 109 |
svg_id = f"svg_{str(uuid.uuid4())[:8]}"
|
| 110 |
-
cx, cy = 300, 210
|
| 111 |
-
tx = 620
|
| 112 |
|
| 113 |
-
#
|
| 114 |
shapes = ""
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
-
# DNA Decoration
|
| 135 |
shapes += f"""<g opacity="0.4">
|
| 136 |
<path d="M {cx-30} {cy-10} Q {cx} {cy-50} {cx+30} {cy-10} T {cx+60} {cy}" fill="none" stroke="#CFD8DC" stroke-width="3" />
|
| 137 |
<circle cx="{cx-40}" cy="{cy+20}" r="3" fill="#B0BEC5" /> <circle cx="{cx+20}" cy="{cy+30}" r="3" fill="#B0BEC5" />
|
| 138 |
</g>"""
|
| 139 |
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
"
|
| 146 |
-
"peri": (cx + 180, cy - 30), # 周质间隙
|
| 147 |
-
"cw": (cx + 170, cy), # 细胞壁
|
| 148 |
-
"im": (cx + 140, cy + 30), # 内膜边界
|
| 149 |
-
"cyto": (cx, cy) # 胞质中心
|
| 150 |
-
}
|
| 151 |
|
| 152 |
-
# 标签配置
|
| 153 |
labels_config = [
|
| 154 |
("Extracellular", "sec", is_sec),
|
| 155 |
("Outer Membrane", "om", is_om),
|
|
@@ -158,16 +190,17 @@ def generate_scientific_svg(target_class):
|
|
| 158 |
("Inner Membrane", "im", is_im),
|
| 159 |
("Cytoplasm", "cyto", is_cyto)
|
| 160 |
]
|
|
|
|
|
|
|
| 161 |
|
| 162 |
label_svg = ""
|
| 163 |
y_start = 50
|
| 164 |
-
y_step = 60
|
| 165 |
|
| 166 |
for i, (text, key, active) in enumerate(labels_config):
|
| 167 |
ty = y_start + i * y_step
|
| 168 |
ex, ey = anchors.get(key, (0,0))
|
| 169 |
|
| 170 |
-
# 样式
|
| 171 |
col_txt = c['hl_text'] if active else c['bg_text']
|
| 172 |
w_txt = "bold" if active else "normal"
|
| 173 |
col_line = c['hl_stroke'] if active else c['bg_line']
|
|
@@ -175,31 +208,36 @@ def generate_scientific_svg(target_class):
|
|
| 175 |
col_dot = c['hl_dot'] if active else c['bg_dot']
|
| 176 |
r_dot = "5" if active else "3"
|
| 177 |
|
| 178 |
-
# 贝塞尔 S 形曲线
|
| 179 |
-
# c1: 从文字左侧水平延伸; c2: 向锚点垂直延伸
|
| 180 |
c1x, c1y = tx - 80, ty
|
| 181 |
c2x, c2y = ex + 60, ey
|
| 182 |
path_d = f"M {tx-10} {ty-5} C {c1x} {c1y}, {c2x} {c2y}, {ex} {ey}"
|
| 183 |
|
| 184 |
label_svg += f"""
|
| 185 |
<g>
|
| 186 |
-
<text x="{tx}" y="{ty}" fill="{col_txt}" font-weight="{w_txt}" font-size="14" font-family="Arial">{text}</text>
|
| 187 |
<path d="{path_d}" fill="none" stroke="{col_line}" stroke-width="{w_line}" />
|
| 188 |
<circle cx="{ex}" cy="{ey}" r="{r_dot}" fill="{col_dot}" stroke="white" stroke-width="1" />
|
| 189 |
</g>
|
| 190 |
"""
|
| 191 |
|
| 192 |
final_svg = f"""<svg id="{svg_id}" width="100%" height="100%" viewBox="0 0 800 420" xmlns="http://www.w3.org/2000/svg">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
<rect width="800" height="420" fill="white" />
|
| 194 |
{shapes}
|
|
|
|
| 195 |
{label_svg}
|
| 196 |
-
<text x="400" y="400" text-anchor="middle" font-family="Arial" font-size="16" fill="#546E7A" font-weight="bold">Prediction: {target_class}</text>
|
| 197 |
</svg>"""
|
| 198 |
|
| 199 |
-
# 嵌入 JS 下载
|
| 200 |
html = f"""<div>{final_svg}
|
| 201 |
<div style="display:flex; justify-content:center; gap:10px; margin-top:5px;">
|
| 202 |
-
<button onclick="downloadSVG('{svg_id}')" style="font-size:11px; padding:4px 8px; border:1px solid #ccc; border-radius:4px; cursor:pointer;">Download SVG</button>
|
| 203 |
</div>
|
| 204 |
<script>
|
| 205 |
function downloadSVG(id) {{
|
|
@@ -214,26 +252,24 @@ def generate_scientific_svg(target_class):
|
|
| 214 |
return html
|
| 215 |
|
| 216 |
# ==========================
|
| 217 |
-
# 4.
|
| 218 |
# ==========================
|
| 219 |
def draw_attention_heatmap_strip(weights, sequence):
|
| 220 |
-
# 归一化
|
| 221 |
if weights.max() > 0:
|
| 222 |
weights = (weights - weights.min()) / (weights.max() - weights.min())
|
| 223 |
|
| 224 |
-
|
| 225 |
-
|
|
|
|
| 226 |
|
| 227 |
-
|
|
|
|
| 228 |
im = ax.imshow(data, cmap='Reds', aspect='auto', vmin=0, vmax=1)
|
| 229 |
|
| 230 |
-
ax.set_title("Sequence Attention Heatmap (Darker = Higher Attention)", fontsize=10, fontweight='bold', color='#37474F', pad=
|
| 231 |
ax.set_xlabel("Residue Position", fontsize=9)
|
| 232 |
-
ax.set_yticks([])
|
| 233 |
-
|
| 234 |
-
# 隐藏四周边框
|
| 235 |
-
for spine in ax.spines.values(): spine.set_visible(False)
|
| 236 |
-
|
| 237 |
plt.tight_layout()
|
| 238 |
return fig
|
| 239 |
|
|
@@ -248,46 +284,45 @@ def predict(sequence_input):
|
|
| 248 |
with torch.no_grad():
|
| 249 |
inputs = tokenizer(seq, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
|
| 250 |
outputs = plm_model(**inputs)
|
| 251 |
-
|
| 252 |
-
logits, pooling_weights = classifier(
|
| 253 |
-
outputs.last_hidden_state[:, 0, :],
|
| 254 |
-
outputs.last_hidden_state[:, 1:-1, :],
|
| 255 |
-
inputs['attention_mask'][:, 1:-1]
|
| 256 |
-
)
|
| 257 |
probs = F.softmax(logits, dim=1)[0]
|
| 258 |
|
| 259 |
top_label = idx_to_label[torch.max(probs, dim=0)[1].item()]
|
| 260 |
confidences = {idx_to_label[i]: float(p) for i, p in enumerate(probs)}
|
| 261 |
|
| 262 |
-
# Panel B: SVG
|
| 263 |
svg = generate_scientific_svg(top_label)
|
| 264 |
-
|
| 265 |
-
# Panel D: Heatmap (纯净版)
|
| 266 |
heatmap = draw_attention_heatmap_strip(pooling_weights[0].cpu().numpy(), seq)
|
| 267 |
|
| 268 |
return confidences, svg, heatmap
|
| 269 |
|
| 270 |
# ==========================
|
| 271 |
-
# 6. UI Layout (
|
| 272 |
# ==========================
|
| 273 |
layout_css = """
|
| 274 |
-
@import url('https://fonts.googleapis.com/css2?family=
|
| 275 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
.header-div { background: linear-gradient(to right, #E0F7FA, #E1F5FE); padding: 1.5rem; border-radius: 8px; margin-bottom: 20px; text-align: center; border: 1px solid #B3E5FC; }
|
| 277 |
-
.header-title { font-size: 2.2rem; font-weight: 800; color: #0288D1; margin-bottom: 5px; }
|
| 278 |
.header-sub { font-size: 1.0rem; color: #0277BD; }
|
| 279 |
.panel-card { border: 1px solid #e2e8f0; border-radius: 8px; padding: 15px; background: white; height: 100%; display: flex; flex-direction: column; }
|
| 280 |
.panel-header { font-weight: 700; color: #475569; border-bottom: 2px solid #f1f5f9; padding-bottom: 8px; margin-bottom: 12px; font-size: 1.0rem; }
|
| 281 |
.panel-label { display: inline-block; background: #E0F7FA; color: #0277BD; border: 1px solid #B2EBF2; padding: 2px 8px; border-radius: 4px; font-size: 0.8rem; margin-right: 8px; font-weight: 800; }
|
| 282 |
"""
|
| 283 |
|
| 284 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
|
| 286 |
with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
|
| 287 |
-
|
| 288 |
gr.HTML("""<div class="header-div"><div class="header-title">LocPred-Prok</div><div class="header-sub">Deep Learning Framework for Prokaryotic Subcellular Localization</div></div>""")
|
| 289 |
|
| 290 |
-
# Row 1: A & B
|
| 291 |
with gr.Row():
|
| 292 |
with gr.Column(elem_classes="panel-card"):
|
| 293 |
gr.Markdown("<div class='panel-header'><span class='panel-label'>A</span>Sequence Input</div>")
|
|
@@ -301,7 +336,6 @@ with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
|
|
| 301 |
gr.Markdown("<div class='panel-header'><span class='panel-label'>B</span>Localization Visualization</div>")
|
| 302 |
output_svg = gr.HTML(label="Visual", show_label=False)
|
| 303 |
|
| 304 |
-
# Row 2: C & D
|
| 305 |
with gr.Row():
|
| 306 |
with gr.Column(elem_classes="panel-card"):
|
| 307 |
gr.Markdown("<div class='panel-header'><span class='panel-label'>C</span>Prediction Confidence</div>")
|
|
|
|
| 85 |
print("✅ Ready.")
|
| 86 |
|
| 87 |
# ==========================
|
| 88 |
+
# 3. 标签映射与 SVG 引擎 (UniProt Font)
|
| 89 |
# ==========================
|
| 90 |
+
|
| 91 |
+
def map_model_label_to_svg_key(model_label):
|
| 92 |
+
l = model_label.strip()
|
| 93 |
+
mapping = {
|
| 94 |
+
"Extracellular": "sec",
|
| 95 |
+
"OuterMembrane": "om",
|
| 96 |
+
"Periplasmic": "peri",
|
| 97 |
+
"Cellwall": "cw",
|
| 98 |
+
"CYtoplasmicMembrane": "im",
|
| 99 |
+
"Cytoplasmic": "cyto"
|
| 100 |
+
}
|
| 101 |
+
if l in mapping: return mapping[l]
|
| 102 |
+
l_lower = l.lower()
|
| 103 |
+
for k, v in mapping.items():
|
| 104 |
+
if k.lower() == l_lower: return v
|
| 105 |
+
return None
|
| 106 |
+
|
| 107 |
+
def infer_gram_type(model_label):
|
| 108 |
+
key = map_model_label_to_svg_key(model_label)
|
| 109 |
+
if key in ["om", "peri"]: return "negative"
|
| 110 |
+
if key == "cw": return "positive"
|
| 111 |
+
return "negative"
|
| 112 |
+
|
| 113 |
def generate_scientific_svg(target_class):
|
| 114 |
+
active_key = map_model_label_to_svg_key(target_class)
|
| 115 |
+
gram_type = infer_gram_type(target_class)
|
| 116 |
|
| 117 |
+
is_sec = (active_key == "sec")
|
| 118 |
+
is_om = (active_key == "om")
|
| 119 |
+
is_peri = (active_key == "peri")
|
| 120 |
+
is_cw = (active_key == "cw")
|
| 121 |
+
is_im = (active_key == "im")
|
| 122 |
+
is_cyto = (active_key == "cyto")
|
| 123 |
+
|
|
|
|
|
|
|
| 124 |
c = {
|
| 125 |
'hl_stroke': '#D32F2F', 'hl_fill': '#FFEBEE', 'hl_text': '#B71C1C', 'hl_dot': '#D32F2F',
|
| 126 |
+
'bg_stroke': '#90A4AE', 'bg_fill': '#FAFAFA',
|
| 127 |
'bg_text': '#78909C', 'bg_line': '#CFD8DC', 'bg_dot': '#B0BEC5'
|
| 128 |
}
|
| 129 |
|
|
|
|
| 130 |
svg_id = f"svg_{str(uuid.uuid4())[:8]}"
|
| 131 |
+
cx, cy = 300, 210
|
| 132 |
+
tx = 620
|
| 133 |
|
| 134 |
+
# 绘制主体形状
|
| 135 |
shapes = ""
|
| 136 |
+
if gram_type == 'negative':
|
| 137 |
+
col_om = c['hl_stroke'] if is_om else c['bg_stroke']
|
| 138 |
+
fill_om = c['hl_fill'] if is_peri else c['bg_fill']
|
| 139 |
+
w_om = "4" if is_om else "2"
|
| 140 |
+
shapes += f'<rect x="{cx-200}" y="{cy-120}" width="400" height="240" rx="120" ry="120" fill="{fill_om}" stroke="{col_om}" stroke-width="{w_om}" />'
|
| 141 |
+
|
| 142 |
+
col_cw = c['hl_stroke'] if is_cw else '#B0BEC5'
|
| 143 |
+
w_cw = "3" if is_cw else "1.5"
|
| 144 |
+
dash_cw = "0" if is_cw else "6,4"
|
| 145 |
+
shapes += f'<rect x="{cx-170}" y="{cy-90}" width="340" height="180" rx="90" ry="90" fill="none" stroke="{col_cw}" stroke-width="{w_cw}" stroke-dasharray="{dash_cw}" />'
|
| 146 |
+
|
| 147 |
+
col_im = c['hl_stroke'] if is_im else c['bg_stroke']
|
| 148 |
+
fill_im = c['hl_fill'] if is_cyto else c['bg_fill']
|
| 149 |
+
w_im = "4" if is_im else "2"
|
| 150 |
+
shapes += f'<rect x="{cx-140}" y="{cy-60}" width="280" height="120" rx="60" ry="60" fill="{fill_im}" stroke="{col_im}" stroke-width="{w_im}" />'
|
| 151 |
+
|
| 152 |
+
anchors = {
|
| 153 |
+
"sec": (cx, cy-160), "om": (cx+200, cy-60), "peri": (cx+180, cy-30),
|
| 154 |
+
"cw": (cx+170, cy), "im": (cx+140, cy+30), "cyto": (cx, cy)
|
| 155 |
+
}
|
| 156 |
+
else:
|
| 157 |
+
col_cw = c['hl_stroke'] if is_cw else c['bg_stroke']
|
| 158 |
+
fill_bg = c['hl_fill'] if is_peri else c['bg_fill']
|
| 159 |
+
w_cw = "6" if is_cw else "4"
|
| 160 |
+
shapes += f'<rect x="{cx-180}" y="{cy-100}" width="360" height="200" rx="100" ry="100" fill="{fill_bg}" stroke="{col_cw}" stroke-width="{w_cw}" stroke-opacity="0.7" />'
|
| 161 |
+
|
| 162 |
+
col_im = c['hl_stroke'] if is_im else c['bg_stroke']
|
| 163 |
+
fill_im = c['hl_fill'] if is_cyto else c['bg_fill']
|
| 164 |
+
w_im = "4" if is_im else "2"
|
| 165 |
+
shapes += f'<rect x="{cx-140}" y="{cy-60}" width="280" height="120" rx="60" ry="60" fill="{fill_im}" stroke="{col_im}" stroke-width="{w_im}" />'
|
| 166 |
+
|
| 167 |
+
anchors = {
|
| 168 |
+
"sec": (cx, cy-140), "om": (cx, cy),
|
| 169 |
+
"peri": (cx+160, cy-40), "cw": (cx+180, cy-60),
|
| 170 |
+
"im": (cx+140, cy+30), "cyto": (cx, cy)
|
| 171 |
+
}
|
| 172 |
|
|
|
|
| 173 |
shapes += f"""<g opacity="0.4">
|
| 174 |
<path d="M {cx-30} {cy-10} Q {cx} {cy-50} {cx+30} {cy-10} T {cx+60} {cy}" fill="none" stroke="#CFD8DC" stroke-width="3" />
|
| 175 |
<circle cx="{cx-40}" cy="{cy+20}" r="3" fill="#B0BEC5" /> <circle cx="{cx+20}" cy="{cy+30}" r="3" fill="#B0BEC5" />
|
| 176 |
</g>"""
|
| 177 |
|
| 178 |
+
sec_svg = ""
|
| 179 |
+
if is_sec:
|
| 180 |
+
sec_svg = f"""<g transform="translate({cx}, {cy-170})">
|
| 181 |
+
<text x="0" y="0" text-anchor="middle" fill="{c['hl_text']}" font-weight="bold" font-family="'Lato', 'Helvetica Neue', Arial, sans-serif" font-size="14">SECRETED</text>
|
| 182 |
+
<path d="M 0 5 L 0 30" stroke="{c['hl_stroke']}" stroke-width="2" marker-end="url(#arrow_hl)" />
|
| 183 |
+
</g>"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
|
|
|
|
| 185 |
labels_config = [
|
| 186 |
("Extracellular", "sec", is_sec),
|
| 187 |
("Outer Membrane", "om", is_om),
|
|
|
|
| 190 |
("Inner Membrane", "im", is_im),
|
| 191 |
("Cytoplasm", "cyto", is_cyto)
|
| 192 |
]
|
| 193 |
+
if gram_type == 'positive':
|
| 194 |
+
labels_config = [l for l in labels_config if l[1] != 'om']
|
| 195 |
|
| 196 |
label_svg = ""
|
| 197 |
y_start = 50
|
| 198 |
+
y_step = 60
|
| 199 |
|
| 200 |
for i, (text, key, active) in enumerate(labels_config):
|
| 201 |
ty = y_start + i * y_step
|
| 202 |
ex, ey = anchors.get(key, (0,0))
|
| 203 |
|
|
|
|
| 204 |
col_txt = c['hl_text'] if active else c['bg_text']
|
| 205 |
w_txt = "bold" if active else "normal"
|
| 206 |
col_line = c['hl_stroke'] if active else c['bg_line']
|
|
|
|
| 208 |
col_dot = c['hl_dot'] if active else c['bg_dot']
|
| 209 |
r_dot = "5" if active else "3"
|
| 210 |
|
|
|
|
|
|
|
| 211 |
c1x, c1y = tx - 80, ty
|
| 212 |
c2x, c2y = ex + 60, ey
|
| 213 |
path_d = f"M {tx-10} {ty-5} C {c1x} {c1y}, {c2x} {c2y}, {ex} {ey}"
|
| 214 |
|
| 215 |
label_svg += f"""
|
| 216 |
<g>
|
| 217 |
+
<text x="{tx}" y="{ty}" fill="{col_txt}" font-weight="{w_txt}" font-size="14" font-family="'Lato', 'Helvetica Neue', Arial, sans-serif">{text}</text>
|
| 218 |
<path d="{path_d}" fill="none" stroke="{col_line}" stroke-width="{w_line}" />
|
| 219 |
<circle cx="{ex}" cy="{ey}" r="{r_dot}" fill="{col_dot}" stroke="white" stroke-width="1" />
|
| 220 |
</g>
|
| 221 |
"""
|
| 222 |
|
| 223 |
final_svg = f"""<svg id="{svg_id}" width="100%" height="100%" viewBox="0 0 800 420" xmlns="http://www.w3.org/2000/svg">
|
| 224 |
+
<defs>
|
| 225 |
+
<style>
|
| 226 |
+
@import url('https://fonts.googleapis.com/css2?family=Lato:wght@400;700&display=swap');
|
| 227 |
+
text {{ font-family: 'Lato', 'Helvetica Neue', Arial, sans-serif; }}
|
| 228 |
+
</style>
|
| 229 |
+
<marker id="arrow_hl" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto"><polygon points="0 0, 10 3.5, 0 7" fill="{c['hl_stroke']}" /></marker>
|
| 230 |
+
</defs>
|
| 231 |
<rect width="800" height="420" fill="white" />
|
| 232 |
{shapes}
|
| 233 |
+
{sec_svg}
|
| 234 |
{label_svg}
|
| 235 |
+
<text x="400" y="400" text-anchor="middle" font-family="'Lato', 'Helvetica Neue', Arial, sans-serif" font-size="16" fill="#546E7A" font-weight="bold">Prediction: {target_class}</text>
|
| 236 |
</svg>"""
|
| 237 |
|
|
|
|
| 238 |
html = f"""<div>{final_svg}
|
| 239 |
<div style="display:flex; justify-content:center; gap:10px; margin-top:5px;">
|
| 240 |
+
<button onclick="downloadSVG('{svg_id}')" style="font-size:11px; padding:4px 8px; border:1px solid #ccc; border-radius:4px; cursor:pointer; font-family:'Lato', sans-serif;">Download SVG</button>
|
| 241 |
</div>
|
| 242 |
<script>
|
| 243 |
function downloadSVG(id) {{
|
|
|
|
| 252 |
return html
|
| 253 |
|
| 254 |
# ==========================
|
| 255 |
+
# 4. Attention Heatmap
|
| 256 |
# ==========================
|
| 257 |
def draw_attention_heatmap_strip(weights, sequence):
|
|
|
|
| 258 |
if weights.max() > 0:
|
| 259 |
weights = (weights - weights.min()) / (weights.max() - weights.min())
|
| 260 |
|
| 261 |
+
# 字体设置 (Matplotlib)
|
| 262 |
+
plt.rcParams['font.family'] = 'sans-serif'
|
| 263 |
+
plt.rcParams['font.sans-serif'] = ['Lato', 'Helvetica Neue', 'Arial']
|
| 264 |
|
| 265 |
+
fig, ax = plt.subplots(figsize=(8, 1.5), dpi=150)
|
| 266 |
+
data = weights.reshape(1, -1)
|
| 267 |
im = ax.imshow(data, cmap='Reds', aspect='auto', vmin=0, vmax=1)
|
| 268 |
|
| 269 |
+
ax.set_title("Sequence Attention Heatmap (Darker = Higher Attention)", fontsize=10, fontweight='bold', color='#37474F', pad=8)
|
| 270 |
ax.set_xlabel("Residue Position", fontsize=9)
|
| 271 |
+
ax.set_yticks([])
|
| 272 |
+
for s in ax.spines.values(): s.set_visible(False)
|
|
|
|
|
|
|
|
|
|
| 273 |
plt.tight_layout()
|
| 274 |
return fig
|
| 275 |
|
|
|
|
| 284 |
with torch.no_grad():
|
| 285 |
inputs = tokenizer(seq, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
|
| 286 |
outputs = plm_model(**inputs)
|
| 287 |
+
logits, pooling_weights = classifier(outputs.last_hidden_state[:, 0, :], outputs.last_hidden_state[:, 1:-1, :], inputs['attention_mask'][:, 1:-1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
probs = F.softmax(logits, dim=1)[0]
|
| 289 |
|
| 290 |
top_label = idx_to_label[torch.max(probs, dim=0)[1].item()]
|
| 291 |
confidences = {idx_to_label[i]: float(p) for i, p in enumerate(probs)}
|
| 292 |
|
|
|
|
| 293 |
svg = generate_scientific_svg(top_label)
|
|
|
|
|
|
|
| 294 |
heatmap = draw_attention_heatmap_strip(pooling_weights[0].cpu().numpy(), seq)
|
| 295 |
|
| 296 |
return confidences, svg, heatmap
|
| 297 |
|
| 298 |
# ==========================
|
| 299 |
+
# 6. UI Layout (UniProt Style Fonts)
|
| 300 |
# ==========================
|
| 301 |
layout_css = """
|
| 302 |
+
@import url('https://fonts.googleapis.com/css2?family=Lato:wght@300;400;700&display=swap');
|
| 303 |
+
|
| 304 |
+
/* 全局字体设置为 Lato */
|
| 305 |
+
body, button, input, textarea, .gradio-container {
|
| 306 |
+
font-family: 'Lato', 'Helvetica Neue', Arial, sans-serif !important;
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
.header-div { background: linear-gradient(to right, #E0F7FA, #E1F5FE); padding: 1.5rem; border-radius: 8px; margin-bottom: 20px; text-align: center; border: 1px solid #B3E5FC; }
|
| 310 |
+
.header-title { font-size: 2.2rem; font-weight: 800; color: #0288D1; margin-bottom: 5px; letter-spacing: -0.5px; }
|
| 311 |
.header-sub { font-size: 1.0rem; color: #0277BD; }
|
| 312 |
.panel-card { border: 1px solid #e2e8f0; border-radius: 8px; padding: 15px; background: white; height: 100%; display: flex; flex-direction: column; }
|
| 313 |
.panel-header { font-weight: 700; color: #475569; border-bottom: 2px solid #f1f5f9; padding-bottom: 8px; margin-bottom: 12px; font-size: 1.0rem; }
|
| 314 |
.panel-label { display: inline-block; background: #E0F7FA; color: #0277BD; border: 1px solid #B2EBF2; padding: 2px 8px; border-radius: 4px; font-size: 0.8rem; margin-right: 8px; font-weight: 800; }
|
| 315 |
"""
|
| 316 |
|
| 317 |
+
# 设置 Gradio 主题字体为 Lato
|
| 318 |
+
theme = gr.themes.Soft(
|
| 319 |
+
primary_hue="sky",
|
| 320 |
+
font=[gr.themes.GoogleFont("Lato"), "Helvetica Neue", "Arial", "sans-serif"]
|
| 321 |
+
).set(body_background_fill="white", block_background_fill="white", block_border_width="0px")
|
| 322 |
|
| 323 |
with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
|
|
|
|
| 324 |
gr.HTML("""<div class="header-div"><div class="header-title">LocPred-Prok</div><div class="header-sub">Deep Learning Framework for Prokaryotic Subcellular Localization</div></div>""")
|
| 325 |
|
|
|
|
| 326 |
with gr.Row():
|
| 327 |
with gr.Column(elem_classes="panel-card"):
|
| 328 |
gr.Markdown("<div class='panel-header'><span class='panel-label'>A</span>Sequence Input</div>")
|
|
|
|
| 336 |
gr.Markdown("<div class='panel-header'><span class='panel-label'>B</span>Localization Visualization</div>")
|
| 337 |
output_svg = gr.HTML(label="Visual", show_label=False)
|
| 338 |
|
|
|
|
| 339 |
with gr.Row():
|
| 340 |
with gr.Column(elem_classes="panel-card"):
|
| 341 |
gr.Markdown("<div class='panel-header'><span class='panel-label'>C</span>Prediction Confidence</div>")
|