"""Gradio front end for the DeepKMulti lysine-modification classifier.

This section wires up dependencies, picks the compute device, loads start-up
resources and defines the (currently placeholder) feature extractor.
"""
import os
import re

import gradio as gr  # required: the UI and gr.Error are built on it
import numpy as np
import torch

# --- Project dependencies ---
# These local modules must be importable, otherwise start-up fails right here.
from model import CAFN
from Feature_extraction_algorithms.PSTAAP import PSTAAP_feature, load_precomputed_fr_matrix
from Feature_extraction_algorithms.Physicochemical import PC_feature

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- 1. Model initialisation ---
# ``model`` stays None until trained weights are loaded; the prediction code
# tests ``model is None`` and falls back to random demo output in that case.
model = None
try:
    FR_MATRIX_PATH = 'Fr_train.mat'
    if os.path.exists(FR_MATRIX_PATH):
        load_precomputed_fr_matrix(FR_MATRIX_PATH)
    else:
        # Demo-friendly: warn instead of aborting. Real runs need this file.
        print(f"警告:找不到矩阵文件 {FR_MATRIX_PATH}")

    # TODO: load the trained weights here, e.g.
    #     model = CAFN().to(device)
    #     model.load_state_dict(torch.load('model_weights.pth'))
    #     model.eval()
except Exception as e:
    # Best-effort start-up: report the problem but keep the UI usable.
    print(f"初始化过程中发生错误: {e}")


# --- 3. Feature extraction ---
def extract_features_from_seq(sequence_list):
    """Build the two model input arrays for a batch of sequences.

    PLACEHOLDER: the real pipeline (``PC_feature`` / ``PSTAAP_feature``) is not
    wired in yet. This returns two uniformly random float32 arrays of shape
    ``(len(sequence_list), 1, 49)`` purely so the rest of the app can run.
    The shapes the real ``CAFN`` model expects are an assumption — TODO confirm.

    Args:
        sequence_list: list of sequence fragments; only its length is used.

    Returns:
        Tuple ``(x1, x2)`` of float32 ``np.ndarray`` objects.
    """
    # TODO: replace with the real feature code, e.g.
    #     data = PC_feature(sequence_list)
    #     data2 = PSTAAP_feature(sequence_list)
    shape = (len(sequence_list), 1, 49)
    first = np.random.rand(*shape).astype(np.float32)
    second = np.random.rand(*shape).astype(np.float32)
    return first, second
# --- 4. Core prediction ---
# Output classes of the multi-label head. They are zipped positionally with the
# model outputs, so this order must match the model's output order (TODO confirm).
MODIFICATION_LABELS = [
    "Lysine-Acetyllysine (K-Ac)",
    "Lysine-Crotonyllysine (K-Cr)",
    "Lysine-Methyllysine (K-Me)",
    "Lysine-Succinyllysine (K-Succ)",
]
# Length of the window fed to the model, and the flank on each side of the K.
WINDOW_SIZE = 49
FLANK = WINDOW_SIZE // 2  # 24 residues to the left and right of the central K


def predict_single_49mer(sequence_49mer):
    """Predict modification probabilities for a single 49-residue window.

    Args:
        sequence_49mer: a 49-character fragment (callers centre it on a K).

    Returns:
        Dict mapping each entry of ``MODIFICATION_LABELS`` to a float. When no
        model is loaded the values are random and meant only for UI debugging.
    """
    if model is None:
        print("警告:模型未加载,返回随机结果用于演示。")
        return {label: float(np.random.rand()) for label in MODIFICATION_LABELS}

    x1_np, x2_np = extract_features_from_seq([sequence_49mer])
    tensor_x1 = torch.tensor(x1_np).to(device)
    tensor_x2 = torch.tensor(x2_np).to(device)

    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(tensor_x1, tensor_x2)
    # Independent sigmoid per class (multi-label). Flatten so a batch of one
    # always yields a 1-D array, even if the head has a single output.
    probabilities = torch.sigmoid(outputs).cpu().numpy().reshape(-1)
    return {label: float(p) for label, p in zip(MODIFICATION_LABELS, probabilities)}


# --- 5. FASTA parsing and main pipeline ---
def _count_fasta_records(fasta_string):
    """Return how many FASTA header ('>') lines the text contains."""
    return sum(1 for line in fasta_string.splitlines() if line.strip().startswith('>'))


def parse_fasta(fasta_string):
    """Return the uppercase residue string contained in FASTA text.

    Header lines (starting with '>', leading whitespace tolerated) are skipped.
    All whitespace inside sequence lines (spaces, tabs, ...) is removed. If the
    text holds several records their residues are concatenated, so callers
    that need a single record must check for that themselves.
    """
    residues = []
    for line in fasta_string.splitlines():
        stripped = line.strip()
        if stripped.startswith('>'):
            continue
        residues.append("".join(stripped.split()))
    return "".join(residues).upper()


def process_fasta_and_predict(fasta_input):
    """Score every K that has a full window of context in a FASTA sequence.

    Args:
        fasta_input: FASTA text from the UI; exactly one record is expected.

    Returns:
        Tuple ``(highlight_data, predictions_map, status)`` where
        ``highlight_data`` is a list of ``(text, label)`` pairs for
        ``gr.HighlightedText``. Scored K sites carry their 1-based residue
        position as label; other text has label None. ``predictions_map`` maps
        the 0-based index of each scored K to its prediction dict, and
        ``status`` is a message for the status box.

    Raises:
        gr.Error: on empty/non-string input, more than one FASTA record, or a
            sequence shorter than ``WINDOW_SIZE``.
    """
    if not fasta_input or not isinstance(fasta_input, str):
        raise gr.Error("Please enter a valid FASTA format sequence.")

    # Concatenating several records would create windows spanning two
    # unrelated proteins, so reject that instead of silently mis-predicting.
    n_records = _count_fasta_records(fasta_input)
    if n_records > 1:
        raise gr.Error(f"Please submit a single FASTA record (found {n_records}).")

    sequence = parse_fasta(fasta_input)
    if len(sequence) < WINDOW_SIZE:
        raise gr.Error(f"Sequence too short! Needs at least {WINDOW_SIZE} AA. Current: {len(sequence)}.")

    predictions_map = {}
    # Only K residues with FLANK residues on both sides fit a full window.
    for k_index in range(FLANK, len(sequence) - FLANK):
        if sequence[k_index] != 'K':
            continue
        fragment = sequence[k_index - FLANK:k_index + FLANK + 1]
        prediction_result = predict_single_49mer(fragment)
        if prediction_result:
            predictions_map[k_index] = prediction_result

    if not predictions_map:
        return [(sequence, None)], {}, "No valid K sites (with enough context) found."

    highlight_data = []
    last_pos = 0
    for k_index in sorted(predictions_map):
        if k_index > last_pos:  # avoid empty segments between adjacent K sites
            highlight_data.append((sequence[last_pos:k_index], None))
        # HighlightedText shows this label inline (show_legend=False); use the
        # same 1-based position that show_results_for_site reports.
        highlight_data.append(("K", str(k_index + 1)))
        last_pos = k_index + 1
    if last_pos < len(sequence):
        highlight_data.append((sequence[last_pos:], None))

    initial_info = "Processing complete! Click on the highlighted 'K' site below."
    return highlight_data, predictions_map, initial_info


# --- 6. Gradio event handlers ---
def show_results_for_site(evt: gr.SelectData, state_data):
    """Show the stored prediction for the K the user clicked.

    ``evt.value`` is read as a ``[text, label]`` pair. ``label`` is the 1-based
    position string produced by ``process_fasta_and_predict``, or None for
    unscored text. ``state_data`` is that function's ``predictions_map``,
    which is keyed by 0-based index.

    Returns:
        Tuple ``(result_dict_or_None, status_message)``.
    """
    if evt.value:
        label = evt.value[1]
        if label is None:  # clicked a non-highlighted segment
            return None, "Please click on a highlighted 'K'."
        position = int(label)
        result_dict = (state_data or {}).get(position - 1)
        if result_dict:
            site_info = f"Prediction results for 'K' at position {position}:"
            return result_dict, site_info
    return None, "Please click on the highlighted 'K' site."


# --- 7. Build and launch the Gradio interface ---
# Each record must keep its line breaks: on a single line the residues would be
# read as part of the '>' header and the example would parse to "".
fasta_example = """>sp|P05141|ADT2_HUMAN ADP/ATP translocase 2 OS=Homo sapiens OX=9606 GN=SLC25A5 PE=1 SV=7
MTDAAVSFAKDFLAGGVAAAISKTAVAPIERVKLLLQVQHASKQITADKQYKGIIDCVVR
IPKEQGVLSFWRGNLANVIRYFPTQALNFAFKDKYKQIFLGGVDKRTQFWLYFAGNLASG"""

# NOTE(review): no component below uses the `predictable-k` CSS class.
with gr.Blocks(css=".predictable-k { color: red; font-weight: bold; }") as demo:
    gr.Markdown(
        """
# DeepKMulti Model: Multi-label Classifier for Lysine Modifications
**Supports FASTA format input. Click on highlighted 'K' sites to view predictions.**
"""
    )
    with gr.Row():
        with gr.Column(scale=2):
            fasta_input = gr.Textbox(
                lines=10,
                label="Input FASTA Sequence",
                value=fasta_example,  # pre-filled to make manual testing easy
                placeholder="Paste FASTA sequence here...",
            )
            submit_btn = gr.Button("Submit Prediction", variant="primary")
        with gr.Column(scale=3):
            gr.Markdown("### Prediction Results")
            info_text = gr.Textbox(label="Status", interactive=False, value="Waiting for input...")
            # Per-session store of predictions_map from process_fasta_and_predict.
            predictions_state = gr.State({})
            results_output = gr.Label(num_top_classes=4, label="Probabilities")
    gr.Markdown("---")
    gr.Markdown("### Visualized Sequence (Click 'K' to view details)")
    # NOTE(review): highlighted spans are labelled with position strings, not
    # "K", so the "K" entry in color_map below never matches a label.
    highlighted_output = gr.HighlightedText(
        label="Sequence Analysis",
        combine_adjacent=False,
        show_legend=False,
        color_map={"K": "red"},
    )

    # Wire up events.
    submit_btn.click(
        fn=process_fasta_and_predict,
        inputs=[fasta_input],
        outputs=[highlighted_output, predictions_state, info_text],
    )
    highlighted_output.select(
        fn=show_results_for_site,
        inputs=[predictions_state],
        outputs=[results_output, info_text],
    )

if __name__ == "__main__":
    demo.launch(debug=True)