# Spaces: Sleeping
# (Hosting-page status text captured together with the source; not part of the program.)
import os
import re

import gradio as gr  # required: the whole UI below is built with gradio
import numpy as np
import torch

# --- Project-local dependencies ---
# These modules must be present next to this file, otherwise the import fails.
from Feature_extraction_algorithms.Physicochemical import PC_feature
from Feature_extraction_algorithms.PSTAAP import (
    PSTAAP_feature,
    load_precomputed_fr_matrix,
)
from model import CAFN
# Select the compute device: CUDA when torch reports a usable GPU, else the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# --- 1. Model initialisation ---
# `model` stays None because the weight-loading steps below are not enabled;
# predict_single_49mer returns demo values for as long as it is None.
model = None
FR_MATRIX_PATH = 'Fr_train.mat'
try:
    if not os.path.exists(FR_MATRIX_PATH):
        # Deliberately a warning instead of a hard failure so the UI can still
        # start for demonstration; a real deployment needs this file.
        print(f"警告:找不到矩阵文件 {FR_MATRIX_PATH}")
    else:
        load_precomputed_fr_matrix(FR_MATRIX_PATH)
    # TODO: load the trained network, presumably along these lines:
    #   model = CAFN().to(device)
    #   model.load_state_dict(torch.load('model_weights.pth'))
    #   model.eval()
except Exception as e:
    # Top-level boundary: report the problem and let the app start anyway.
    print(f"初始化过程中发生错误: {e}")
# --- 3. Feature extraction ---
def extract_features_from_seq(sequence_list):
    """Return two placeholder feature arrays for a batch of sequences.

    This is a stand-in: the real feature extraction (the earlier ``data`` /
    ``data2`` logic, presumably built on ``PC_feature`` and ``PSTAAP_feature``
    -- TODO confirm) still has to be put back here. Until then the function
    only produces random values so the rest of the pipeline can run.

    Args:
        sequence_list: Sequence fragments; only its length is used.

    Returns:
        A pair ``(x1, x2)`` of float32 arrays, each of shape
        ``(len(sequence_list), 1, 49)`` and filled with random numbers.
    """
    # The shape must eventually match the model input; (batch, channel,
    # length) is the layout assumed by the original placeholder.
    placeholder_shape = (len(sequence_list), 1, 49)
    x1_dummy = np.random.rand(*placeholder_shape).astype(np.float32)
    x2_dummy = np.random.rand(*placeholder_shape).astype(np.float32)
    return x1_dummy, x2_dummy
# --- 4. Core prediction function ---
# Output classes, in the order the scores are paired with them.
_MODIFICATION_LABELS = (
    "Lysine-Acetyllysine (K-Ac)",
    "Lysine-Crotonyllysine (K-Cr)",
    "Lysine-Methyllysine (K-Me)",
    "Lysine-Succinyllysine (K-Succ)",
)


def predict_single_49mer(sequence_49mer):
    """Score one 49-residue fragment for each lysine modification type.

    Args:
        sequence_49mer: The fragment to score (callers pass a 49-character
            window centred on a 'K').

    Returns:
        A dict mapping each modification label to a float score. While the
        module-level ``model`` is None the scores are random numbers, which
        keeps the UI usable for debugging.
    """
    if model is None:
        print("警告:模型未加载,返回随机结果用于演示。")
        return {label: float(np.random.rand()) for label in _MODIFICATION_LABELS}

    x1_np, x2_np = extract_features_from_seq([sequence_49mer])
    tensor_x1 = torch.tensor(x1_np).to(device)
    tensor_x2 = torch.tensor(x2_np).to(device)

    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(tensor_x1, tensor_x2)
        # Flatten instead of squeeze(): a batch of one always collapses to a
        # 1-D vector, including the single-output case that squeeze() would
        # turn into a 0-d array.
        probabilities = torch.sigmoid(outputs).reshape(-1).cpu().numpy()

    return {
        label: float(prob)
        for label, prob in zip(_MODIFICATION_LABELS, probabilities)
    }
# --- 5. FASTA parsing and main processing flow ---
def parse_fasta(fasta_string):
    """Extract the residue string from FASTA-formatted text.

    Header lines (those whose first non-blank character is '>') are dropped;
    all remaining lines are concatenated, stripped of every kind of
    whitespace and upper-cased. Multiple records, if present, are joined
    into one sequence.

    Args:
        fasta_string: Raw text, with or without header lines.

    Returns:
        The upper-case sequence with no whitespace; ``""`` when the input
        holds no sequence lines.
    """
    pieces = []
    for raw_line in fasta_string.splitlines():
        line = raw_line.strip()
        # Strip first so an indented header is still recognised as a header.
        if not line or line.startswith('>'):
            continue
        # split()/join removes tabs and other whitespace, not only spaces.
        pieces.append("".join(line.split()))
    return "".join(pieces).upper()
def process_fasta_and_predict(fasta_input):
    """Parse FASTA text, score every usable lysine and build the UI payload.

    A 'K' is usable when 24 residues exist on each side of it, so that a
    49-residue window centred on the site can be cut out and passed to
    ``predict_single_49mer``.

    Args:
        fasta_input: FASTA text entered by the user.

    Returns:
        A 3-tuple ``(highlight_data, predictions_map, info)``:
        ``highlight_data`` is a list of ``(text, label)`` pairs where scored
        sites are ``("K", "<0-based index>")`` and other text has label
        None; ``predictions_map`` maps the 0-based site index to its
        prediction dict; ``info`` is a status message.

    Raises:
        gr.Error: If the input is empty or not a string, or the parsed
            sequence is shorter than 49 residues.
    """
    if not fasta_input or not isinstance(fasta_input, str):
        raise gr.Error("Please enter a valid FASTA format sequence.")

    sequence = parse_fasta(fasta_input)
    if len(sequence) < 49:
        raise gr.Error(f"Sequence too short! Needs at least 49 AA. Current: {len(sequence)}.")

    flank = 24  # residues required on each side of the central K
    predictions_map = {}
    for match in re.finditer('K', sequence):
        k_index = match.start()
        start, end = k_index - flank, k_index + flank + 1
        if start < 0 or end > len(sequence):
            continue  # not enough context for a full window
        prediction_result = predict_single_49mer(sequence[start:end])
        if prediction_result:
            predictions_map[k_index] = prediction_result

    if not predictions_map:
        return [(sequence, None)], {}, "No valid K sites (with enough context) found."

    highlight_data = []
    last_pos = 0
    for k_index in sorted(predictions_map):
        # Skip the plain-text piece when two scored sites are adjacent, so no
        # empty segment is emitted.
        if k_index > last_pos:
            highlight_data.append((sequence[last_pos:k_index], None))
        # The label carries the site index so the click handler can look the
        # prediction up again.
        highlight_data.append(("K", str(k_index)))
        last_pos = k_index + 1
    if last_pos < len(sequence):
        highlight_data.append((sequence[last_pos:], None))

    initial_info = "Processing complete! Click on the highlighted 'K' site below."
    return highlight_data, predictions_map, initial_info
# --- 6. Gradio event handler ---
def show_results_for_site(evt: gr.SelectData, state_data):
    """Return the stored prediction for the segment the user clicked.

    Args:
        evt: Selection event; ``evt.value[1]`` is read as the segment label
            (presumably ``evt.value`` is ``[text, label]`` -- TODO confirm
            against the installed Gradio version).
        state_data: Mapping from 0-based site index to prediction dict.

    Returns:
        A pair ``(prediction, message)``. ``prediction`` is None when the
        click did not land on a scored site.
    """
    not_a_site = (None, "Please click on the highlighted 'K' site.")

    selection = evt.value
    if not selection:
        return not_a_site

    k_index_str = selection[1]
    if k_index_str is None:  # a non-highlighted stretch was clicked
        return None, "Please click on a highlighted 'K'."

    k_index = int(k_index_str)
    result_dict = state_data.get(k_index)
    if not result_dict:
        return not_a_site

    # Positions are shown 1-based to the user.
    site_info = f"Prediction results for 'K' at position {k_index + 1}:"
    return result_dict, site_info
# --- 7. Build and launch the Gradio interface ---
# Example record used as the default content of the input box.
fasta_example = """>sp|P05141|ADT2_HUMAN ADP/ATP translocase 2 OS=Homo sapiens OX=9606 GN=SLC25A5 PE=1 SV=7
MTDAAVSFAKDFLAGGVAAAISKTAVAPIERVKLLLQVQHASKQITADKQYKGIIDCVVR
IPKEQGVLSFWRGNLANVIRYFPTQALNFAFKDKYKQIFLGGVDKRTQFWLYFAGNLASG"""
# All components and event bindings must be created inside the Blocks context.
# NOTE(review): the nesting below was reconstructed from flattened text; in
# particular, placing the sequence view under the row (full width) rather
# than inside the right-hand column is an assumption -- confirm.
with gr.Blocks(css=".predictable-k { color: red; font-weight: bold; }") as demo:
    gr.Markdown(
        "# DeepKMulti Model: Multi-label Classifier for Lysine Modifications\n"
        "**Supports FASTA format input. Click on highlighted 'K' sites to view predictions.**"
    )

    with gr.Row():
        with gr.Column(scale=2):
            fasta_input = gr.Textbox(
                lines=10,
                label="Input FASTA Sequence",
                value=fasta_example,  # pre-filled so the app can be tried at once
                placeholder="Paste FASTA sequence here...",
            )
            submit_btn = gr.Button("Submit Prediction", variant="primary")

        with gr.Column(scale=3):
            gr.Markdown("### Prediction Results")
            info_text = gr.Textbox(label="Status", interactive=False, value="Waiting for input...")
            # Holds the {site index: prediction} map between the two events.
            predictions_state = gr.State({})
            results_output = gr.Label(num_top_classes=4, label="Probabilities")

    gr.Markdown("---")
    gr.Markdown("### Visualized Sequence (Click 'K' to view details)")
    # NOTE(review): process_fasta_and_predict labels each site with its index,
    # not with "K", so this colour-map entry may never match -- confirm.
    highlighted_output = gr.HighlightedText(
        label="Sequence Analysis",
        combine_adjacent=False,
        show_legend=False,
        color_map={"K": "red"},
    )

    # Event wiring: submit runs the predictions, clicking a site shows them.
    submit_btn.click(
        fn=process_fasta_and_predict,
        inputs=[fasta_input],
        outputs=[highlighted_output, predictions_state, info_text],
    )
    highlighted_output.select(
        fn=show_results_for_site,
        inputs=[predictions_state],
        outputs=[results_output, info_text],
    )
# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch(debug=True)