# DeepKMulti / app.py
# Xianfish9's picture
# Update app.py
# 0bea2f5 verified
import numpy as np
import os
import re
import torch
import gradio as gr # 必须导入 gradio
# --- 依赖导入 ---
# 请确保这些文件存在,否则会报错
from model import CAFN
from Feature_extraction_algorithms.PSTAAP import PSTAAP_feature, load_precomputed_fr_matrix
from Feature_extraction_algorithms.Physicochemical import PC_feature
# Pick the compute device once at import time: CUDA when available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- 1. Model initialisation ---
# Path of the precomputed matrix handed to load_precomputed_fr_matrix.
FR_MATRIX_PATH = 'Fr_train.mat'
# Stays None until trained weights are actually loaded (see TODO below).
model = None
try:
    if os.path.exists(FR_MATRIX_PATH):
        load_precomputed_fr_matrix(FR_MATRIX_PATH)
    else:
        # Only warn (instead of aborting) so the demo UI can still start;
        # a real deployment must ship this file.
        print(f"警告:找不到矩阵文件 {FR_MATRIX_PATH}")
    # TODO: load the trained weights here, e.g.
    # model = CAFN().to(device)
    # model.load_state_dict(torch.load('model_weights.pth'))
    # model.eval()
except Exception as init_error:
    # Best-effort startup: report the problem but keep the module importable.
    print(f"初始化过程中发生错误: {init_error}")
# --- 3. Feature extraction ---
def extract_features_from_seq(sequence_list):
    """Build the two model input arrays for a batch of sequences.

    PLACEHOLDER: this currently ignores the sequence contents and returns
    uniform random float32 arrays of shape ``(len(sequence_list), 1, 49)``.
    It must be replaced with the real feature pipeline (presumably the
    ``PC_feature`` / ``PSTAAP_feature`` functions this file imports — TODO
    confirm their signatures and the exact input shape the model expects).

    Args:
        sequence_list: List of sequence strings; only its length is used.

    Returns:
        Tuple ``(x1, x2)`` of float32 NumPy arrays.
    """
    batch_size = len(sequence_list)
    # (Batch, Channel, Length) — the channel/length dims are a guess and
    # must be adapted to the real model input.
    placeholder_shape = (batch_size, 1, 49)
    # Generated in order: x1 first, then x2.
    x1_dummy, x2_dummy = (
        np.random.rand(*placeholder_shape).astype(np.float32) for _ in range(2)
    )
    return x1_dummy, x2_dummy
# --- 4. Core prediction ---
# Output class names, shared by the demo branch and the real-model branch so
# the two can never drift apart. Assumed to follow the order of the model's
# output units — TODO confirm against the training code.
_PTM_LABELS = (
    "Lysine-Acetyllysine (K-Ac)",
    "Lysine-Crotonyllysine (K-Cr)",
    "Lysine-Methyllysine (K-Me)",
    "Lysine-Succinyllysine (K-Succ)",
)


def predict_single_49mer(sequence_49mer):
    """Predict modification probabilities for a single sequence fragment.

    The fragment is expected to be 49 residues long (the caller passes a
    window centred on a K); the length is not checked here.

    When the module-level ``model`` is ``None`` (weights not loaded), a
    warning is printed and random scores are returned so the UI can still
    be exercised.

    Args:
        sequence_49mer: The residue string to classify.

    Returns:
        Dict mapping each label in ``_PTM_LABELS`` to a float probability.
    """
    if model is None:
        print("警告:模型未加载,返回随机结果用于演示。")
        return {label: float(np.random.rand()) for label in _PTM_LABELS}

    x1_np, x2_np = extract_features_from_seq([sequence_49mer])
    # as_tensor avoids an extra host-side copy and places data on `device`.
    tensor_x1 = torch.as_tensor(x1_np, device=device)
    tensor_x2 = torch.as_tensor(x2_np, device=device)

    with torch.no_grad():  # inference only: no autograd graph needed
        outputs = model(tensor_x1, tensor_x2)
        probabilities = torch.sigmoid(outputs).squeeze().cpu().numpy()

    # squeeze() collapses a single-output result to a 0-d array; make it
    # iterable again so the zip below always works.
    probabilities = np.atleast_1d(probabilities)
    # NOTE(review): zip stops at the shorter side, so a model emitting a
    # different number of outputs than _PTM_LABELS is silently truncated.
    return {label: float(prob) for label, prob in zip(_PTM_LABELS, probabilities)}
# --- 5. FASTA parsing and main processing pipeline ---
def parse_fasta(fasta_string):
    """Concatenate the sequence lines of FASTA text into one uppercase string.

    Header lines (first non-blank character ``>``) are skipped, and *all*
    whitespace (spaces, tabs, non-breaking spaces, ...) is removed from the
    sequence lines. Multiple records are not separated: their sequences are
    joined end to end.

    Args:
        fasta_string: Raw FASTA text.

    Returns:
        The uppercase residue string (``""`` when there are no sequence lines).
    """
    residues = []
    for raw_line in fasta_string.splitlines():
        line = raw_line.strip()
        # Strip before testing so indented headers are still recognised.
        if not line or line.startswith(">"):
            continue
        # str.split() with no arguments splits on any whitespace character.
        residues.append("".join(line.split()))
    return "".join(residues).upper()
def process_fasta_and_predict(fasta_input):
    """Parse FASTA input, predict each lysine with full context, build the view.

    A K at 0-based index ``i`` is predicted only when the 49-residue window
    ``sequence[i - 24:i + 25]`` lies entirely inside the sequence, i.e. it
    has 24 residues on each side.

    Args:
        fasta_input: Raw FASTA text from the input textbox.

    Returns:
        ``(highlight_segments, predictions_by_index, status_message)`` where
        ``highlight_segments`` is a list of ``(text, label)`` tuples for
        ``gr.HighlightedText`` (predicted K sites are labelled with
        ``str(index)``, all other text with ``None``), and
        ``predictions_by_index`` maps each predicted 0-based K index to its
        prediction dict.

    Raises:
        gr.Error: If the input is empty / not a string, or the parsed
            sequence is shorter than 49 residues.
    """
    if not (fasta_input and isinstance(fasta_input, str)):
        raise gr.Error("Please enter a valid FASTA format sequence.")

    sequence = parse_fasta(fasta_input)
    seq_len = len(sequence)
    if seq_len < 49:
        raise gr.Error(f"Sequence too short! Needs at least 49 AA. Current: {len(sequence)}.")

    # Scan every residue; keep predictions only for K sites whose window fits.
    predictions_map = {}
    for pos, residue in enumerate(sequence):
        if residue != "K":
            continue
        window_start, window_end = pos - 24, pos + 25
        if window_start < 0 or window_end > seq_len:
            continue
        outcome = predict_single_49mer(sequence[window_start:window_end])
        if outcome:
            predictions_map[pos] = outcome

    if not predictions_map:
        return [(sequence, None)], {}, "No valid K sites (with enough context) found."

    # Interleave plain-text runs with one labelled "K" segment per site.
    segments = []
    cursor = 0
    for pos in sorted(predictions_map):
        segments.extend([(sequence[cursor:pos], None), ("K", str(pos))])
        cursor = pos + 1
    segments.append((sequence[cursor:], None))

    return segments, predictions_map, "Processing complete! Click on the highlighted 'K' site below."
# --- 6. Gradio event handlers ---
def show_results_for_site(evt: gr.SelectData, state_data):
    """Show the stored prediction for the highlighted segment the user clicked.

    ``evt`` is the selection event passed to the HighlightedText ``.select``
    listener. Its ``value`` is expected to be a ``(text, label)`` pair —
    TODO confirm for the installed Gradio version. Clickable K sites carry
    their 0-based sequence index as a string label; plain text carries
    ``None``.

    Args:
        evt: Selection event from the HighlightedText component.
        state_data: Mapping ``{k_index: {label: probability}}`` held in
            ``gr.State``; may be empty or ``None``.

    Returns:
        ``(result_dict_or_None, status_message)`` for the Label and the
        status Textbox outputs.
    """
    selected = evt.value
    if not selected:
        return None, "Please click on the highlighted 'K' site."

    # Only a (text, label) pair can carry a site index. Anything else (e.g. a
    # bare string) is treated like a click on plain text rather than indexed
    # into, which would pick out an arbitrary character.
    if isinstance(selected, (list, tuple)) and len(selected) > 1:
        k_index_str = selected[1]
    else:
        k_index_str = None
    if k_index_str is None:  # clicked a non-highlighted segment
        return None, "Please click on a highlighted 'K'."

    try:
        k_index = int(k_index_str)
    except (TypeError, ValueError):
        # Not one of the numeric index labels this app produces.
        return None, "Please click on the highlighted 'K' site."

    result_dict = (state_data or {}).get(k_index)
    if result_dict:
        # Report a 1-based position to the user.
        site_info = f"Prediction results for 'K' at position {k_index + 1}:"
        return result_dict, site_info
    return None, "Please click on the highlighted 'K' site."
# --- 7. Build and launch the Gradio interface ---
# Default input: the UniProt record P05141 (ADT2_HUMAN) in FASTA format,
# pre-filled into the textbox so the app can be tried immediately.
fasta_example = """>sp|P05141|ADT2_HUMAN ADP/ATP translocase 2 OS=Homo sapiens OX=9606 GN=SLC25A5 PE=1 SV=7
MTDAAVSFAKDFLAGGVAAAISKTAVAPIERVKLLLQVQHASKQITADKQYKGIIDCVVR
IPKEQGVLSFWRGNLANVIRYFPTQALNFAFKDKYKQIFLGGVDKRTQFWLYFAGNLASG"""
# All components and event bindings are created inside the Blocks context;
# the order of the statements below determines the on-screen layout.
# NOTE(review): the `.predictable-k` CSS class is not attached to any
# component in this file, so this rule currently has no effect.
with gr.Blocks(css=".predictable-k { color: red; font-weight: bold; }") as demo:
    gr.Markdown(
        """
# DeepKMulti Model: Multi-label Classifier for Lysine Modifications
**Supports FASTA format input. Click on highlighted 'K' sites to view predictions.**
"""
    )
    with gr.Row():
        # Left column: FASTA input and submit button.
        with gr.Column(scale=2):
            fasta_input = gr.Textbox(
                lines=10,
                label="Input FASTA Sequence",
                value=fasta_example,  # pre-filled default for quick testing
                placeholder="Paste FASTA sequence here..."
            )
            submit_btn = gr.Button("Submit Prediction", variant="primary")
        # Right column: status line, per-session prediction store, and the
        # probability display for the selected site.
        with gr.Column(scale=3):
            gr.Markdown("### Prediction Results")
            info_text = gr.Textbox(label="Status", interactive=False, value="Waiting for input...")
            # Holds the {k_index: prediction_dict} mapping returned on submit.
            predictions_state = gr.State({})
            # num_top_classes=4 matches the four modification labels.
            results_output = gr.Label(num_top_classes=4, label="Probabilities")
    gr.Markdown("---")
    gr.Markdown("### Visualized Sequence (Click 'K' to view details)")
    # NOTE(review): predicted K sites are labelled with their index string
    # (str(k_index)), not "K", so the "K" key below never matches and sites
    # get Gradio's default colours — confirm the intended colouring.
    highlighted_output = gr.HighlightedText(
        label="Sequence Analysis",
        combine_adjacent=False,
        show_legend=False,
        color_map={"K": "red"}
    )
    # Event wiring: submit runs the full prediction pass; clicking a
    # highlighted segment looks up that site's stored prediction.
    submit_btn.click(
        fn=process_fasta_and_predict,
        inputs=[fasta_input],
        outputs=[highlighted_output, predictions_state, info_text]
    )
    highlighted_output.select(
        fn=show_results_for_site,
        inputs=[predictions_state],
        outputs=[results_output, info_text]
    )
# Start the Gradio server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch(debug=True)