File size: 7,434 Bytes
ec1b79a
 
f41061a
0bea2f5
 
fc4fd4d
ec1b79a
0bea2f5
6a13dea
 
 
0bea2f5
 
 
 
 
 
ad61623
6a13dea
ad61623
0bea2f5
 
 
 
 
 
 
 
 
6a13dea
2009bd2
0bea2f5
6a13dea
0bea2f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a13dea
0bea2f5
6a13dea
f41061a
6a13dea
f41061a
0bea2f5
6a13dea
0bea2f5
 
 
6a13dea
 
 
9c08eb6
6a13dea
 
9c08eb6
0bea2f5
 
 
 
6a13dea
 
0bea2f5
 
 
 
 
6a13dea
ec1b79a
0bea2f5
f41061a
 
 
 
 
 
f135a23
f41061a
 
93f4b80
f41061a
0bea2f5
f41061a
 
 
 
 
 
 
 
6a13dea
 
 
f41061a
 
0bea2f5
f41061a
 
 
6a13dea
7a27d8c
6a13dea
 
f41061a
 
 
6a13dea
f41061a
0bea2f5
6a13dea
 
0bea2f5
f41061a
6a13dea
0bea2f5
 
 
 
6a13dea
 
 
 
0bea2f5
6a13dea
fc4fd4d
0bea2f5
6a13dea
0bea2f5
6a13dea
0bea2f5
6a13dea
 
0bea2f5
 
 
 
6a13dea
 
 
0bea2f5
6a13dea
 
0bea2f5
f41061a
0bea2f5
f41061a
6a13dea
0bea2f5
 
 
f41061a
6a13dea
0bea2f5
f41061a
6a13dea
0bea2f5
6a13dea
0bea2f5
 
6a13dea
0bea2f5
 
f41061a
6a13dea
0bea2f5
 
 
 
 
 
 
 
 
 
f41061a
 
0bea2f5
 
 
f41061a
 
ec1b79a
0bea2f5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import numpy as np
import os
import re
import torch
import gradio as gr  # 必须导入 gradio

# --- 依赖导入 ---
# 请确保这些文件存在,否则会报错
from model import CAFN
from Feature_extraction_algorithms.PSTAAP import PSTAAP_feature, load_precomputed_fr_matrix
from Feature_extraction_algorithms.Physicochemical import PC_feature

# Run on the first CUDA device when torch can see one, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- 1. Model initialisation ---
# `model` stays None here because the weight-loading lines below are still
# commented out; predict_single_49mer then returns random demo scores.
model = None
try:
    FR_MATRIX_PATH = 'Fr_train.mat'
    if os.path.exists(FR_MATRIX_PATH):
        load_precomputed_fr_matrix(FR_MATRIX_PATH)
    else:
        # Warn instead of aborting so the UI can still be demoed without the
        # matrix file; a real run needs the file to be present.
        print(f"警告:找不到矩阵文件 {FR_MATRIX_PATH}")

    # TODO: load the trained weights, e.g.
    #   model = CAFN().to(device)
    #   model.load_state_dict(torch.load('model_weights.pth'))
    #   model.eval()
except Exception as exc:
    # Best-effort start-up: report the problem but keep the app importable.
    print(f"初始化过程中发生错误: {exc}")

# --- 3. Feature extraction (placeholder) ---
def extract_features_from_seq(sequence_list):
    """Return the two model input arrays for a batch of sequences.

    PLACEHOLDER: the real extraction (the original notes mention
    ``PC_feature(sequence_list)`` and ``PSTAAP_feature(sequence_list)``) is
    not wired in yet. Both returned arrays are uniform random noise in
    [0, 1), dtype float32, of shape ``(len(sequence_list), 1, 49)``, and do
    not depend on the sequence contents.

    The real shapes must match what the model expects; the original note
    says (batch, channel, length) -- TODO confirm against CAFN.
    """
    n_seqs = len(sequence_list)
    placeholder_shape = (n_seqs, 1, 49)
    # Drawn in order: first array, then second (same RNG consumption order).
    first, second = (
        np.random.rand(*placeholder_shape).astype(np.float32)
        for _ in range(2)
    )
    return first, second

# --- 4. Core prediction ---
def predict_single_49mer(sequence_49mer):
    """Predict lysine-modification probabilities for one 49-residue window.

    Args:
        sequence_49mer: a single sequence fragment (callers in this file pass
            a 49-character window around a 'K').

    Returns:
        dict mapping each modification label to a float probability.

        While the module-level ``model`` is None, the scores are random
        numbers in [0, 1) and a warning is printed on every call; this lets
        the UI be exercised without weights.

        Otherwise each value is ``sigmoid(model output)``. If the model emits
        a different number of outputs than there are labels, ``zip`` pairs
        only the common prefix.
    """
    # Single source of truth for the output labels (previously duplicated in
    # both branches, where the two copies could silently drift apart).
    labels = (
        "Lysine-Acetyllysine (K-Ac)",
        "Lysine-Crotonyllysine (K-Cr)",
        "Lysine-Methyllysine (K-Me)",
        "Lysine-Succinyllysine (K-Succ)",
    )

    if model is None:
        print("警告:模型未加载,返回随机结果用于演示。")
        return {label: float(np.random.rand()) for label in labels}

    x1_np, x2_np = extract_features_from_seq([sequence_49mer])

    # as_tensor places the data directly on the target device (no extra
    # intermediate CPU tensor copy as with torch.tensor(...).to(device)).
    tensor_x1 = torch.as_tensor(x1_np, device=device)
    tensor_x2 = torch.as_tensor(x2_np, device=device)

    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = model(tensor_x1, tensor_x2)
        # Independent sigmoid per label (multi-label). Flattening the
        # batch-of-one always yields a 1-D array, even for a single-output
        # model, replacing the old squeeze() + ndim == 0 special case.
        probabilities = torch.sigmoid(outputs).reshape(-1).cpu().numpy()

    return {label: float(prob) for label, prob in zip(labels, probabilities)}

# --- 5. FASTA parsing and main processing flow ---
def parse_fasta(fasta_string):
    """Concatenate the sequence lines of a FASTA string into one upper-case string.

    Header lines (starting with '>') and legacy comment lines (starting with
    ';') are skipped, even when they are indented. All whitespace inside
    sequence lines is removed, including spaces, tabs and stray carriage
    returns.

    Records are NOT separated: a multi-record input yields the concatenation
    of all its sequences.
    """
    residues = []
    for line in fasta_string.splitlines():
        stripped = line.strip()
        # Check the stripped line so an indented header is not mistaken for
        # sequence data (previously '>' and header text leaked in).
        if stripped.startswith(('>', ';')):
            continue
        # split() with no argument splits on any whitespace run (spaces, tabs).
        residues.append("".join(stripped.split()))
    return "".join(residues).upper()

def process_fasta_and_predict(fasta_input):
    """Parse FASTA input, score every lysine with full context, and build UI data.

    A lysine ('K') is scored only if a 49-residue window fits entirely inside
    the sequence: 24 residues before it and 24 after it.

    Returns:
        A tuple ``(highlight_data, predictions_map, status_message)``:

        - highlight_data: list of ``(text, label)`` pairs for
          gr.HighlightedText. Each scored 'K' is its own segment labelled
          with its 0-based index as a string; all other text has label None.
        - predictions_map: maps that 0-based index (int) to the dict returned
          by predict_single_49mer.
        - status_message: text for the status box.

        When no lysine can be scored, the whole sequence is returned as one
        unlabelled segment with an empty map.

    Raises:
        gr.Error: for empty or non-string input, or a sequence shorter than
            49 residues.
    """
    if not (fasta_input and isinstance(fasta_input, str)):
        raise gr.Error("Please enter a valid FASTA format sequence.")

    sequence = parse_fasta(fasta_input)
    seq_len = len(sequence)
    if seq_len < 49:
        raise gr.Error(f"Sequence too short! Needs at least 49 AA. Current: {seq_len}.")

    # Slice [pos - 24, pos + 25) gives 49 residues with the K at offset 24.
    flank_left, flank_right = 24, 25
    predictions_map = {}
    for pos, residue in enumerate(sequence):
        if residue != 'K':
            continue
        lo, hi = pos - flank_left, pos + flank_right
        if lo < 0 or hi > seq_len:
            continue  # not enough context on one side
        site_result = predict_single_49mer(sequence[lo:hi])
        if site_result:
            predictions_map[pos] = site_result

    if not predictions_map:
        return [(sequence, None)], {}, "No valid K sites (with enough context) found."

    # Interleave unlabelled stretches with one labelled segment per scored K.
    highlight_data = []
    cursor = 0
    for pos in sorted(predictions_map):
        highlight_data += [(sequence[cursor:pos], None), ("K", str(pos))]
        cursor = pos + 1
    highlight_data.append((sequence[cursor:], None))

    return (
        highlight_data,
        predictions_map,
        "Processing complete! Click on the highlighted 'K' site below.",
    )

# --- 6. Gradio event handlers ---
def show_results_for_site(evt: gr.SelectData, state_data):
    """Select handler: show the stored prediction for the clicked lysine.

    Args:
        evt: the select event. ``evt.value[1]`` is read as the clicked
            segment's label. process_fasta_and_predict sets that label to
            the K's 0-based index as a string, or None for unscored text.
            This assumes Gradio reports ``[text, label]`` here -- confirm for
            the installed Gradio version.
        state_data: dict mapping 0-based K index (int) to a
            label -> probability dict.

    Returns:
        ``(label_value, status_message)`` for the Label and Status outputs.
        label_value is None when there is nothing to show.
    """
    fallback = (None, "Please click on the highlighted 'K' site.")
    if not evt.value:
        return fallback

    segment_label = evt.value[1]
    if segment_label is None:  # an unlabelled (non-K) stretch was clicked
        return None, "Please click on a highlighted 'K'."

    site = int(segment_label)
    site_result = state_data.get(site)
    if not site_result:
        return fallback

    # Report the position 1-based, as biologists count residues.
    return site_result, f"Prediction results for 'K' at position {site + 1}:"

# --- 7. Build and launch the Gradio interface ---

# Example FASTA record (header P05141 / ADT2_HUMAN) used to pre-fill the
# input box.
fasta_example = """>sp|P05141|ADT2_HUMAN ADP/ATP translocase 2 OS=Homo sapiens OX=9606 GN=SLC25A5 PE=1 SV=7
MTDAAVSFAKDFLAGGVAAAISKTAVAPIERVKLLLQVQHASKQITADKQYKGIIDCVVR
IPKEQGVLSFWRGNLANVIRYFPTQALNFAFKDKYKQIFLGGVDKRTQFWLYFAGNLASG"""

# All UI components are created inside the gr.Blocks() context so they belong
# to `demo`.
# NOTE(review): the CSS class `.predictable-k` is not assigned to any component
# in this file, so this rule appears to have no effect -- confirm or remove.
with gr.Blocks(css=".predictable-k { color: red; font-weight: bold; }") as demo:
    gr.Markdown(
        """
        # DeepKMulti Model: Multi-label Classifier for Lysine Modifications
        **Supports FASTA format input. Click on highlighted 'K' sites to view predictions.**
        """
    )
    
    with gr.Row():
        # Left column: FASTA input and submit button.
        with gr.Column(scale=2):
            fasta_input = gr.Textbox(
                lines=10, 
                label="Input FASTA Sequence",
                value=fasta_example,  # pre-filled default for quick testing
                placeholder="Paste FASTA sequence here..."
            )
            submit_btn = gr.Button("Submit Prediction", variant="primary")
            
        # Right column: status text and per-site probabilities.
        with gr.Column(scale=3):
            gr.Markdown("### Prediction Results")
            info_text = gr.Textbox(label="Status", interactive=False, value="Waiting for input...")
            # Holds the predictions dict (0-based K index -> probabilities)
            # returned by process_fasta_and_predict, for the select handler.
            predictions_state = gr.State({}) 
            # 4 classes shown, matching the four labels from predict_single_49mer.
            results_output = gr.Label(num_top_classes=4, label="Probabilities")
            
    gr.Markdown("---")
    gr.Markdown("### Visualized Sequence (Click 'K' to view details)")
    
    # NOTE(review): process_fasta_and_predict labels scored K segments with
    # str(index) (e.g. "24"), not "K", so this color_map entry matches none of
    # the labels it produces -- confirm the intended colouring.
    highlighted_output = gr.HighlightedText(
        label="Sequence Analysis",
        combine_adjacent=False,
        show_legend=False,
        color_map={"K": "red"} 
    )
    
    # Wire events. Submit parses the input and fills the highlight view, the
    # state and the status box.
    submit_btn.click(
        fn=process_fasta_and_predict,
        inputs=[fasta_input],
        outputs=[highlighted_output, predictions_state, info_text]
    )
    
    # Clicking a highlighted segment looks up its stored prediction in state.
    highlighted_output.select(
        fn=show_results_for_site,
        inputs=[predictions_state],
        outputs=[results_output, info_text]
    )

# Launch only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch(debug=True)