Spaces:
Sleeping
Sleeping
File size: 7,434 Bytes
ec1b79a f41061a 0bea2f5 fc4fd4d ec1b79a 0bea2f5 6a13dea 0bea2f5 ad61623 6a13dea ad61623 0bea2f5 6a13dea 2009bd2 0bea2f5 6a13dea 0bea2f5 6a13dea 0bea2f5 6a13dea f41061a 6a13dea f41061a 0bea2f5 6a13dea 0bea2f5 6a13dea 9c08eb6 6a13dea 9c08eb6 0bea2f5 6a13dea 0bea2f5 6a13dea ec1b79a 0bea2f5 f41061a f135a23 f41061a 93f4b80 f41061a 0bea2f5 f41061a 6a13dea f41061a 0bea2f5 f41061a 6a13dea 7a27d8c 6a13dea f41061a 6a13dea f41061a 0bea2f5 6a13dea 0bea2f5 f41061a 6a13dea 0bea2f5 6a13dea 0bea2f5 6a13dea fc4fd4d 0bea2f5 6a13dea 0bea2f5 6a13dea 0bea2f5 6a13dea 0bea2f5 6a13dea 0bea2f5 6a13dea 0bea2f5 f41061a 0bea2f5 f41061a 6a13dea 0bea2f5 f41061a 6a13dea 0bea2f5 f41061a 6a13dea 0bea2f5 6a13dea 0bea2f5 6a13dea 0bea2f5 f41061a 6a13dea 0bea2f5 f41061a 0bea2f5 f41061a ec1b79a 0bea2f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import numpy as np
import os
import re
import torch
import gradio as gr # 必须导入 gradio
# --- 依赖导入 ---
# 请确保这些文件存在,否则会报错
from model import CAFN
from Feature_extraction_algorithms.PSTAAP import PSTAAP_feature, load_precomputed_fr_matrix
from Feature_extraction_algorithms.Physicochemical import PC_feature
# Compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- 1. Model initialization ---
# `model` stays None until trained weights are loaded; while it is None,
# predict_single_49mer returns random demo scores instead of model output.
model = None
FR_MATRIX_PATH = 'Fr_train.mat'
try:
    if os.path.exists(FR_MATRIX_PATH):
        # Presumably caches the precomputed matrix used by PSTAAP features — confirm.
        load_precomputed_fr_matrix(FR_MATRIX_PATH)
    else:
        # Demo-friendly: warn instead of aborting when the matrix file is missing.
        print(f"警告:找不到矩阵文件 {FR_MATRIX_PATH}")
    # TODO: load the trained weights once available, e.g.
    #   model = CAFN().to(device)
    #   model.load_state_dict(torch.load('model_weights.pth'))
    #   model.eval()
except Exception as e:
    # Best-effort startup: report the failure but keep the UI launchable.
    print(f"初始化过程中发生错误: {e}")
# --- 3. Feature extraction ---
def extract_features_from_seq(sequence_list):
    """Build the two model-input arrays for a batch of sequences.

    PLACEHOLDER: the real PC_feature / PSTAAP_feature extraction is not wired
    in yet. This returns uniform random values in [0, 1) that only have the
    right container shape, so the rest of the pipeline can run.

    Args:
        sequence_list: List of peptide strings; only its length is used here.

    Returns:
        Tuple ``(x1, x2)`` of float32 arrays, each of shape
        ``(len(sequence_list), 1, 49)``.
        NOTE(review): the (batch, channel, length) layout is an assumption
        carried over from earlier comments — confirm against CAFN's inputs.
    """
    # TODO: replace with the real features, e.g.
    #   data = PC_feature(sequence_list)
    #   data2 = PSTAAP_feature(sequence_list)
    batch_shape = (len(sequence_list), 1, 49)
    # Two independent draws, consumed in order: first x1, then x2.
    return tuple(
        np.random.rand(*batch_shape).astype(np.float32) for _ in range(2)
    )
# --- 4. Core prediction ---
# Output classes of the multi-label classifier, zipped in this order against
# the model's output values.
# NOTE(review): assumes CAFN emits exactly one logit per label, in this order.
MODIFICATION_LABELS = (
    "Lysine-Acetyllysine (K-Ac)",
    "Lysine-Crotonyllysine (K-Cr)",
    "Lysine-Methyllysine (K-Me)",
    "Lysine-Succinyllysine (K-Succ)",
)


def predict_single_49mer(sequence_49mer):
    """Predict modification probabilities for one 49-residue window.

    Args:
        sequence_49mer: Amino-acid string. Callers in this file pass a
            49-character window centred on a lysine.

    Returns:
        Dict mapping each label in ``MODIFICATION_LABELS`` to a float
        probability (sigmoid of the model output). When no model is loaded,
        the values are random and only useful for exercising the UI.

    Raises:
        ValueError: if the model output does not contain exactly one value
            per label.
    """
    # Demo fallback: with no model loaded, return random scores so the UI can
    # still be exercised end to end.
    if model is None:
        print("警告:模型未加载,返回随机结果用于演示。")
        return {label: float(np.random.rand()) for label in MODIFICATION_LABELS}

    x1_np, x2_np = extract_features_from_seq([sequence_49mer])
    tensor_x1 = torch.tensor(x1_np).to(device)
    tensor_x2 = torch.tensor(x2_np).to(device)

    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = model(tensor_x1, tensor_x2)

    # squeeze() drops the size-1 batch axis; atleast_1d turns a 0-d result
    # (single-logit output) back into a 1-element array.
    probabilities = np.atleast_1d(torch.sigmoid(outputs).squeeze().cpu().numpy())
    # Guard against silently mislabelled output: zip() would truncate.
    if probabilities.ndim != 1 or len(probabilities) != len(MODIFICATION_LABELS):
        raise ValueError(
            f"Expected {len(MODIFICATION_LABELS)} probabilities (one per label), "
            f"got array of shape {probabilities.shape}."
        )
    return {label: float(prob) for label, prob in zip(MODIFICATION_LABELS, probabilities)}
# --- 5. FASTA parsing and main processing pipeline ---
def parse_fasta(fasta_string):
    """Concatenate the sequence lines of FASTA text into one uppercase string.

    Header lines (first non-whitespace character ``>``) are dropped, even when
    indented. Every whitespace character (spaces, tabs, CR/LF, other Unicode
    whitespace) is removed from the remaining lines. If the text holds several
    records, their sequences are concatenated into a single string.

    Args:
        fasta_string: Raw FASTA text.

    Returns:
        The uppercase residue string, or ``""`` if there are no sequence lines.
    """
    residues = []
    for line in fasta_string.splitlines():
        stripped = line.strip()
        # Check for '>' after stripping, so indented headers are still skipped.
        if stripped.startswith('>'):
            continue
        # str.split() with no argument splits on any whitespace run, so joining
        # the pieces removes tabs and inner spaces as well as plain spaces.
        residues.append("".join(stripped.split()))
    return "".join(residues).upper()
def process_fasta_and_predict(fasta_input):
    """Parse FASTA input, predict every lysine with full context, and build UI data.

    A lysine at 0-based index ``i`` is predicted only when its 49-residue
    window ``sequence[i - 24 : i + 25]`` lies entirely inside the sequence.

    Args:
        fasta_input: Raw text from the input textbox.

    Returns:
        A 3-tuple ``(highlight_data, predictions_map, info)``:

        - ``highlight_data``: list of ``(text, label)`` pairs for
          ``gr.HighlightedText``. Each predicted K carries its 0-based index
          as a string label; other text has label ``None``.
        - ``predictions_map``: maps each predicted 0-based index to its
          prediction dict.
        - ``info``: status message.

    Raises:
        gr.Error: if the input is empty or not a string, or if the parsed
            sequence is shorter than 49 residues.
    """
    if not fasta_input or not isinstance(fasta_input, str):
        raise gr.Error("Please enter a valid FASTA format sequence.")
    sequence = parse_fasta(fasta_input)
    if len(sequence) < 49:
        raise gr.Error(f"Sequence too short! Needs at least 49 AA. Current: {len(sequence)}.")

    predictions_map = _predict_k_sites(sequence)
    if not predictions_map:
        return [(sequence, None)], {}, "No valid K sites (with enough context) found."

    highlight_data = _build_highlight_data(sequence, predictions_map)
    initial_info = "Processing complete! Click on the highlighted 'K' site below."
    return highlight_data, predictions_map, initial_info


def _predict_k_sites(sequence):
    """Return ``{index: prediction}`` for each 'K' whose 49-residue window fits."""
    # 24 residues before the K, the K itself, and 24 after: end is exclusive.
    left, right = 24, 25
    predictions = {}
    for index, residue in enumerate(sequence):
        if residue != 'K':
            continue
        start, end = index - left, index + right
        if start < 0 or end > len(sequence):
            continue
        result = predict_single_49mer(sequence[start:end])
        if result:
            predictions[index] = result
    return predictions


def _build_highlight_data(sequence, predictions_map):
    """Split *sequence* into ``(text, label)`` segments, labelling each predicted K."""
    segments = []
    last_pos = 0
    for k_index in sorted(predictions_map):
        # Adjacent predicted Ks leave no gap; skip the empty segment.
        if k_index > last_pos:
            segments.append((sequence[last_pos:k_index], None))
        segments.append(("K", str(k_index)))
        last_pos = k_index + 1
    if last_pos < len(sequence):
        segments.append((sequence[last_pos:], None))
    return segments
# --- 6. Gradio event handlers ---
def show_results_for_site(evt: gr.SelectData, state_data):
    """Show the stored prediction for the clicked segment of the sequence view.

    Args:
        evt: Gradio selection event. ``evt.value[1]`` is read as the clicked
            segment's label: ``None`` for plain text, otherwise the 0-based
            K index as a string.
            NOTE(review): assumes ``evt.value`` is the (text, label) pair of
            the clicked HighlightedText segment — confirm for the Gradio
            version in use. Keep the ``gr.SelectData`` annotation; Gradio
            presumably relies on it to supply the event object.
        state_data: Dict mapping 0-based K indices to prediction dicts.

    Returns:
        ``(result_dict, message)`` when a stored prediction exists (position
        reported 1-based); otherwise ``(None, hint message)``.
    """
    fallback = (None, "Please click on the highlighted 'K' site.")
    if not evt.value:
        return fallback

    label = evt.value[1]
    if label is None:  # a plain, non-highlighted segment was clicked
        return None, "Please click on a highlighted 'K'."

    site = int(label)
    site_result = state_data.get(site)
    if not site_result:
        return fallback
    return site_result, f"Prediction results for 'K' at position {site + 1}:"
# --- 7. Build and launch the Gradio interface ---
# Default FASTA record pre-filled in the input textbox: a UniProt-style
# header line followed by two 60-residue sequence lines.
fasta_example = (
    ">sp|P05141|ADT2_HUMAN ADP/ATP translocase 2 OS=Homo sapiens OX=9606 GN=SLC25A5 PE=1 SV=7\n"
    "MTDAAVSFAKDFLAGGVAAAISKTAVAPIERVKLLLQVQHASKQITADKQYKGIIDCVVR\n"
    "IPKEQGVLSFWRGNLANVIRYFPTQALNFAFKDKYKQIFLGGVDKRTQFWLYFAGNLASG"
)
# Fix 2: wrap the interface in `with gr.Blocks()` and restore indentation.
# The page layout follows the order and nesting of the component constructors
# inside these `with` blocks, so statements here must not be reordered.
# NOTE(review): indentation was lost in this copy. The nesting below (results
# column inside the Row, sequence view below the Row) is a reconstruction —
# confirm against the intended layout.
# NOTE(review): no component below is given the `predictable-k` CSS class, so
# this stylesheet appears to have no effect.
with gr.Blocks(css=".predictable-k { color: red; font-weight: bold; }") as demo:
    gr.Markdown(
        """
# DeepKMulti Model: Multi-label Classifier for Lysine Modifications
**Supports FASTA format input. Click on highlighted 'K' sites to view predictions.**
"""
    )
    with gr.Row():
        # Left column: FASTA input and submit button.
        with gr.Column(scale=2):
            fasta_input = gr.Textbox(
                lines=10,
                label="Input FASTA Sequence",
                value=fasta_example, # default value for quick testing
                placeholder="Paste FASTA sequence here..."
            )
            submit_btn = gr.Button("Submit Prediction", variant="primary")
        # Right column: status text, stored predictions, per-site probabilities.
        with gr.Column(scale=3):
            gr.Markdown("### Prediction Results")
            info_text = gr.Textbox(label="Status", interactive=False, value="Waiting for input...")
            # Receives the {k_index: prediction dict} map returned by
            # process_fasta_and_predict on submit.
            predictions_state = gr.State({})
            results_output = gr.Label(num_top_classes=4, label="Probabilities")
    gr.Markdown("---")
    gr.Markdown("### Visualized Sequence (Click 'K' to view details)")
    highlighted_output = gr.HighlightedText(
        label="Sequence Analysis",
        combine_adjacent=False,
        show_legend=False,
        # NOTE(review): process_fasta_and_predict labels each K with its index
        # string (e.g. "24"), never "K", so this entry matches no label —
        # confirm the intended coloring.
        color_map={"K": "red"}
    )
    # Event wiring.
    # Submit: parse and predict, then fill the sequence view, state, and status.
    submit_btn.click(
        fn=process_fasta_and_predict,
        inputs=[fasta_input],
        outputs=[highlighted_output, predictions_state, info_text]
    )
    # Clicking a segment of the sequence view shows that site's probabilities.
    highlighted_output.select(
        fn=show_results_for_site,
        inputs=[predictions_state],
        outputs=[results_output, info_text]
    )
if __name__ == "__main__":
    # Start the Gradio app only when this file is run as a script, not on import.
    demo.launch(debug=True)