Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,9 +9,9 @@ import matplotlib.pyplot as plt
|
|
| 9 |
import numpy as np
|
| 10 |
from transformers import AutoTokenizer, AutoModel
|
| 11 |
|
| 12 |
-
#
|
| 13 |
-
# 0. 环境与缓存设置
|
| 14 |
-
#
|
| 15 |
# 强制使用非交互式后端,防止 matplotlib 在服务器报错
|
| 16 |
plt.switch_backend('Agg')
|
| 17 |
|
|
@@ -24,9 +24,9 @@ for path in ["/tmp/hf_cache", os.path.expanduser("~/.cache/huggingface")]:
|
|
| 24 |
shutil.rmtree(path, ignore_errors=True)
|
| 25 |
os.makedirs(path, exist_ok=True)
|
| 26 |
|
| 27 |
-
#
|
| 28 |
-
# 1. 模型架构定义 (
|
| 29 |
-
#
|
| 30 |
class AttentionPooling(nn.Module):
|
| 31 |
def __init__(self, d_model):
|
| 32 |
super().__init__()
|
|
@@ -81,9 +81,9 @@ class ProtDualBranchEnhancedClassifier(nn.Module):
|
|
| 81 |
|
| 82 |
return self.classifier_head(z_fused_gated), pooling_weights
|
| 83 |
|
| 84 |
-
#
|
| 85 |
-
# 2. 加载模型与配置
|
| 86 |
-
#
|
| 87 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 88 |
PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
|
| 89 |
CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
|
|
@@ -105,14 +105,13 @@ tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME)
|
|
| 105 |
plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE).eval()
|
| 106 |
|
| 107 |
classifier = ProtDualBranchEnhancedClassifier(D_MODEL, 32, NUM_CLASSES, 0.3, 3).to(DEVICE)
|
| 108 |
-
# strict=False 允许加载即使权重文件中没有 pooling_weights 相关的特定状态(通常不影响)
|
| 109 |
classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
|
| 110 |
classifier.eval()
|
| 111 |
print("✅ Ready.")
|
| 112 |
|
| 113 |
-
#
|
| 114 |
-
# 3. Panel B: SVG 绘图引擎 (
|
| 115 |
-
#
|
| 116 |
def generate_bacterial_svg(target_class):
|
| 117 |
target = target_class.lower() if target_class else ""
|
| 118 |
|
|
@@ -133,28 +132,29 @@ def generate_bacterial_svg(target_class):
|
|
| 133 |
"bg_text": "#78909C", "bg_line": "#CFD8DC", "bg_dot": "#B0BEC5"
|
| 134 |
}
|
| 135 |
|
| 136 |
-
# 样式生成器
|
| 137 |
def style(active, base_fill, base_stroke, w_act="4", w_norm="2"):
|
| 138 |
-
if active:
|
| 139 |
-
|
|
|
|
|
|
|
| 140 |
|
| 141 |
om_f, om_s, om_w = style(is_peri, c["bg_fill_om"], c["hl_stroke"] if is_om else c["bg_stroke"])
|
| 142 |
cw_s = c["hl_stroke"] if is_cw else "#B0BEC5"
|
| 143 |
cw_w, cw_d = ("3", "0") if is_cw else ("1.5", "6,4")
|
| 144 |
im_f, im_s, im_w = style(is_cyto, c["bg_fill_im"], c["hl_stroke"] if is_im else c["bg_stroke"])
|
| 145 |
|
| 146 |
-
# 标签样式
|
| 147 |
def label_style(active):
|
| 148 |
if active: return c["hl_text"], "bold", c["hl_stroke"], "2.5", c["hl_dot"], "5"
|
| 149 |
return c["bg_text"], "normal", c["bg_line"], "1.5", c["bg_dot"], "3"
|
| 150 |
|
| 151 |
l_om, l_peri, l_cw, l_im, l_cyto = label_style(is_om), label_style(is_peri), label_style(is_cw), label_style(is_im), label_style(is_cyto)
|
| 152 |
|
| 153 |
-
#
|
| 154 |
bx, by = 280, 210 # 细菌中心
|
| 155 |
tx = 600 # 标签文字起始 X 坐标
|
| 156 |
|
| 157 |
-
# 目标锚点 (Target Anchor Points) - 精确落在结构上
|
| 158 |
targets = {
|
| 159 |
"om": (bx + 140, by - 120), # 外膜线
|
| 160 |
"peri": (bx + 120, by - 90), # 周质间隙
|
|
@@ -165,13 +165,13 @@ def generate_bacterial_svg(target_class):
|
|
| 165 |
|
| 166 |
text_y = {"om": 90, "peri": 150, "cw": 210, "im": 270, "cyto": 330}
|
| 167 |
|
| 168 |
-
#
|
| 169 |
def draw_connector(key, style_tuple, label_text):
|
| 170 |
txt_col, weight, line_col, width, dot_col, r = style_tuple
|
| 171 |
tx_pos, ty_pos = tx, text_y[key]
|
| 172 |
ex, ey = targets[key]
|
| 173 |
|
| 174 |
-
#
|
| 175 |
c1x, c1y = tx_pos - 100, ty_pos
|
| 176 |
c2x, c2y = ex + 50, ey
|
| 177 |
|
|
@@ -213,9 +213,9 @@ def generate_bacterial_svg(target_class):
|
|
| 213 |
</svg>"""
|
| 214 |
return svg
|
| 215 |
|
| 216 |
-
#
|
| 217 |
-
# 4. Panel D: Attention 绘图引擎
|
| 218 |
-
#
|
| 219 |
def draw_pooling_weights(weights, sequence):
|
| 220 |
"""
|
| 221 |
Visualize Attention Pooling Weights (1D Heatmap/Bar).
|
|
@@ -239,8 +239,8 @@ def draw_pooling_weights(weights, sequence):
|
|
| 239 |
ax.spines['left'].set_visible(False)
|
| 240 |
ax.set_yticks([])
|
| 241 |
|
| 242 |
-
# 标注最高峰 (
|
| 243 |
-
threshold = np.percentile(weights, 98)
|
| 244 |
if weights.max() > threshold:
|
| 245 |
peak_idx = np.argmax(weights)
|
| 246 |
ax.annotate('Key Motif', xy=(peak_idx, weights[peak_idx]), xytext=(peak_idx, weights[peak_idx]+0.2),
|
|
@@ -250,9 +250,9 @@ def draw_pooling_weights(weights, sequence):
|
|
| 250 |
plt.tight_layout()
|
| 251 |
return fig
|
| 252 |
|
| 253 |
-
#
|
| 254 |
-
# 5. 预测主逻辑
|
| 255 |
-
#
|
| 256 |
def predict(sequence_input):
|
| 257 |
if not sequence_input or sequence_input.isspace(): raise gr.Error("Empty Input")
|
| 258 |
|
|
@@ -269,11 +269,11 @@ def predict(sequence_input):
|
|
| 269 |
token_embeddings = hidden_states[:, 1:-1, :] # No CLS/EOS
|
| 270 |
token_mask = inputs['attention_mask'][:, 1:-1]
|
| 271 |
|
| 272 |
-
# ⚠️ 获取 logits 和
|
| 273 |
logits, pooling_weights = classifier(cls_embedding, token_embeddings, token_mask)
|
| 274 |
probs = F.softmax(logits, dim=1)[0]
|
| 275 |
|
| 276 |
-
# 1. 结果
|
| 277 |
top_label = idx_to_label[torch.max(probs, dim=0)[1].item()]
|
| 278 |
confidences = {idx_to_label[i]: float(p) for i, p in enumerate(probs)}
|
| 279 |
|
|
@@ -287,9 +287,9 @@ def predict(sequence_input):
|
|
| 287 |
|
| 288 |
return confidences, svg, attn_plot
|
| 289 |
|
| 290 |
-
#
|
| 291 |
-
# 6. UI
|
| 292 |
-
#
|
| 293 |
layout_css = """
|
| 294 |
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;800&display=swap');
|
| 295 |
body { background-color: #ffffff; font-family: 'Inter', sans-serif; }
|
|
@@ -330,6 +330,7 @@ theme = gr.themes.Soft(primary_hue="sky").set(body_background_fill="white", bloc
|
|
| 330 |
|
| 331 |
with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
|
| 332 |
|
|
|
|
| 333 |
gr.HTML("""
|
| 334 |
<div class="header-div">
|
| 335 |
<div class="header-title">LocPred-Prok</div>
|
|
@@ -337,7 +338,7 @@ with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
|
|
| 337 |
</div>
|
| 338 |
""")
|
| 339 |
|
| 340 |
-
# Row 1: A & B
|
| 341 |
with gr.Row():
|
| 342 |
with gr.Column(elem_classes="panel-card"):
|
| 343 |
gr.Markdown("<div class='panel-header'><span class='panel-label'>A</span>Sequence Input</div>")
|
|
@@ -353,7 +354,7 @@ with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
|
|
| 353 |
gr.Markdown("<div class='panel-header'><span class='panel-label'>B</span>Localization Visualization</div>")
|
| 354 |
output_svg = gr.HTML(label="Visual", show_label=False)
|
| 355 |
|
| 356 |
-
# Row 2: C & D
|
| 357 |
with gr.Row():
|
| 358 |
with gr.Column(elem_classes="panel-card"):
|
| 359 |
gr.Markdown("<div class='panel-header'><span class='panel-label'>C</span>Prediction Confidence</div>")
|
|
|
|
| 9 |
import numpy as np
|
| 10 |
from transformers import AutoTokenizer, AutoModel
|
| 11 |
|
| 12 |
+
# ==============================================================================
|
| 13 |
+
# 0. 环境与缓存设置 (Environment Setup)
|
| 14 |
+
# ==============================================================================
|
| 15 |
# 强制使用非交互式后端,防止 matplotlib 在服务器报错
|
| 16 |
plt.switch_backend('Agg')
|
| 17 |
|
|
|
|
| 24 |
shutil.rmtree(path, ignore_errors=True)
|
| 25 |
os.makedirs(path, exist_ok=True)
|
| 26 |
|
| 27 |
+
# ==============================================================================
|
| 28 |
+
# 1. 模型架构定义 (Model Architecture)
|
| 29 |
+
# ==============================================================================
|
| 30 |
class AttentionPooling(nn.Module):
|
| 31 |
def __init__(self, d_model):
|
| 32 |
super().__init__()
|
|
|
|
| 81 |
|
| 82 |
return self.classifier_head(z_fused_gated), pooling_weights
|
| 83 |
|
| 84 |
+
# ==============================================================================
|
| 85 |
+
# 2. 加载模型与配置 (Load Resources)
|
| 86 |
+
# ==============================================================================
|
| 87 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 88 |
PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
|
| 89 |
CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
|
|
|
|
| 105 |
plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE).eval()
|
| 106 |
|
| 107 |
classifier = ProtDualBranchEnhancedClassifier(D_MODEL, 32, NUM_CLASSES, 0.3, 3).to(DEVICE)
|
|
|
|
| 108 |
classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
|
| 109 |
classifier.eval()
|
| 110 |
print("✅ Ready.")
|
| 111 |
|
| 112 |
+
# ==============================================================================
|
| 113 |
+
# 3. Panel B: SVG 绘图引擎 (Visualization Engine)
|
| 114 |
+
# ==============================================================================
|
| 115 |
def generate_bacterial_svg(target_class):
|
| 116 |
target = target_class.lower() if target_class else ""
|
| 117 |
|
|
|
|
| 132 |
"bg_text": "#78909C", "bg_line": "#CFD8DC", "bg_dot": "#B0BEC5"
|
| 133 |
}
|
| 134 |
|
| 135 |
+
# 3. 样式生成器 (这里修复了之前的 bug)
|
| 136 |
def style(active, base_fill, base_stroke, w_act="4", w_norm="2"):
|
| 137 |
+
if active:
|
| 138 |
+
return c["hl_fill"], c["hl_stroke"], w_act
|
| 139 |
+
# ✅ 修复点:这里原来写成了 width_norm,现已修正为 w_norm
|
| 140 |
+
return base_fill, base_stroke, w_norm
|
| 141 |
|
| 142 |
om_f, om_s, om_w = style(is_peri, c["bg_fill_om"], c["hl_stroke"] if is_om else c["bg_stroke"])
|
| 143 |
cw_s = c["hl_stroke"] if is_cw else "#B0BEC5"
|
| 144 |
cw_w, cw_d = ("3", "0") if is_cw else ("1.5", "6,4")
|
| 145 |
im_f, im_s, im_w = style(is_cyto, c["bg_fill_im"], c["hl_stroke"] if is_im else c["bg_stroke"])
|
| 146 |
|
| 147 |
+
# 标签样式
|
| 148 |
def label_style(active):
|
| 149 |
if active: return c["hl_text"], "bold", c["hl_stroke"], "2.5", c["hl_dot"], "5"
|
| 150 |
return c["bg_text"], "normal", c["bg_line"], "1.5", c["bg_dot"], "3"
|
| 151 |
|
| 152 |
l_om, l_peri, l_cw, l_im, l_cyto = label_style(is_om), label_style(is_peri), label_style(is_cw), label_style(is_im), label_style(is_cyto)
|
| 153 |
|
| 154 |
+
# 4. 坐标定义
|
| 155 |
bx, by = 280, 210 # 细菌中心
|
| 156 |
tx = 600 # 标签文字起始 X 坐标
|
| 157 |
|
|
|
|
| 158 |
targets = {
|
| 159 |
"om": (bx + 140, by - 120), # 外膜线
|
| 160 |
"peri": (bx + 120, by - 90), # 周质间隙
|
|
|
|
| 165 |
|
| 166 |
text_y = {"om": 90, "peri": 150, "cw": 210, "im": 270, "cyto": 330}
|
| 167 |
|
| 168 |
+
# 5. 贝塞尔曲线连接器
|
| 169 |
def draw_connector(key, style_tuple, label_text):
|
| 170 |
txt_col, weight, line_col, width, dot_col, r = style_tuple
|
| 171 |
tx_pos, ty_pos = tx, text_y[key]
|
| 172 |
ex, ey = targets[key]
|
| 173 |
|
| 174 |
+
# 贝塞尔控制点
|
| 175 |
c1x, c1y = tx_pos - 100, ty_pos
|
| 176 |
c2x, c2y = ex + 50, ey
|
| 177 |
|
|
|
|
| 213 |
</svg>"""
|
| 214 |
return svg
|
| 215 |
|
| 216 |
+
# ==============================================================================
|
| 217 |
+
# 4. Panel D: Attention 绘图引擎 (Interpretability)
|
| 218 |
+
# ==============================================================================
|
| 219 |
def draw_pooling_weights(weights, sequence):
|
| 220 |
"""
|
| 221 |
Visualize Attention Pooling Weights (1D Heatmap/Bar).
|
|
|
|
| 239 |
ax.spines['left'].set_visible(False)
|
| 240 |
ax.set_yticks([])
|
| 241 |
|
| 242 |
+
# 标注最高峰 (Key Motif)
|
| 243 |
+
threshold = np.percentile(weights, 98)
|
| 244 |
if weights.max() > threshold:
|
| 245 |
peak_idx = np.argmax(weights)
|
| 246 |
ax.annotate('Key Motif', xy=(peak_idx, weights[peak_idx]), xytext=(peak_idx, weights[peak_idx]+0.2),
|
|
|
|
| 250 |
plt.tight_layout()
|
| 251 |
return fig
|
| 252 |
|
| 253 |
+
# ==============================================================================
|
| 254 |
+
# 5. 预测主逻辑 (Prediction Logic)
|
| 255 |
+
# ==============================================================================
|
| 256 |
def predict(sequence_input):
|
| 257 |
if not sequence_input or sequence_input.isspace(): raise gr.Error("Empty Input")
|
| 258 |
|
|
|
|
| 269 |
token_embeddings = hidden_states[:, 1:-1, :] # No CLS/EOS
|
| 270 |
token_mask = inputs['attention_mask'][:, 1:-1]
|
| 271 |
|
| 272 |
+
# ⚠️ 获取 logits 和 pooling_weights
|
| 273 |
logits, pooling_weights = classifier(cls_embedding, token_embeddings, token_mask)
|
| 274 |
probs = F.softmax(logits, dim=1)[0]
|
| 275 |
|
| 276 |
+
# 1. 结果 (Panel C)
|
| 277 |
top_label = idx_to_label[torch.max(probs, dim=0)[1].item()]
|
| 278 |
confidences = {idx_to_label[i]: float(p) for i, p in enumerate(probs)}
|
| 279 |
|
|
|
|
| 287 |
|
| 288 |
return confidences, svg, attn_plot
|
| 289 |
|
| 290 |
+
# ==============================================================================
|
| 291 |
+
# 6. UI 布局 (Four-Block Paper Style)
|
| 292 |
+
# ==============================================================================
|
| 293 |
layout_css = """
|
| 294 |
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;800&display=swap');
|
| 295 |
body { background-color: #ffffff; font-family: 'Inter', sans-serif; }
|
|
|
|
| 330 |
|
| 331 |
with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
|
| 332 |
|
| 333 |
+
# --- Header ---
|
| 334 |
gr.HTML("""
|
| 335 |
<div class="header-div">
|
| 336 |
<div class="header-title">LocPred-Prok</div>
|
|
|
|
| 338 |
</div>
|
| 339 |
""")
|
| 340 |
|
| 341 |
+
# --- Row 1: Panels A & B ---
|
| 342 |
with gr.Row():
|
| 343 |
with gr.Column(elem_classes="panel-card"):
|
| 344 |
gr.Markdown("<div class='panel-header'><span class='panel-label'>A</span>Sequence Input</div>")
|
|
|
|
| 354 |
gr.Markdown("<div class='panel-header'><span class='panel-label'>B</span>Localization Visualization</div>")
|
| 355 |
output_svg = gr.HTML(label="Visual", show_label=False)
|
| 356 |
|
| 357 |
+
# --- Row 2: Panels C & D ---
|
| 358 |
with gr.Row():
|
| 359 |
with gr.Column(elem_classes="panel-card"):
|
| 360 |
gr.Markdown("<div class='panel-header'><span class='panel-label'>C</span>Prediction Confidence</div>")
|