Xianfish9 committed on
Commit
6a13dea
·
verified ·
1 Parent(s): 9c08eb6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -187
app.py CHANGED
@@ -1,246 +1,186 @@
 
1
  import numpy as np
2
  import os
3
  import re
4
- import pandas as pd
5
- import torch
6
- import gradio as gr
7
-
8
  # --- 依赖导入 ---
9
- # 请确保目录结构正确
10
- try:
11
- from model import CAFN
12
- from Feature_extraction_algorithms.PSTAAP import PSTAAP_feature, load_precomputed_fr_matrix
13
- from Feature_extraction_algorithms.Physicochemical import PC_feature
14
- except ImportError as e:
15
- print(f"警告:依赖导入失败,请检查文件路径。错误: {e}")
16
- # 设置占位符防止直接崩溃
17
- CAFN = None
18
- PSTAAP_feature = None
19
- PC_feature = None
20
- load_precomputed_fr_matrix = lambda x: None
21
-
22
- # --- 1. 初始化设置 ---
23
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
24
- FR_MATRIX_PATH = 'Fr_train.mat'
25
- MODEL_WEIGHTS_PATH = 'DeepKMulti.pth' # 请确保此文件存在
26
-
27
- # 初始化 PSTAAP
28
  try:
 
29
  if not os.path.exists(FR_MATRIX_PATH):
30
- print(f"警告:找不到矩阵文件 {FR_MATRIX_PATH},如果是测试环境请忽略。")
31
- else:
32
- load_precomputed_fr_matrix(FR_MATRIX_PATH)
 
 
33
  except Exception as e:
34
- print(f"PSTAAP 初始化错误: {e}")
35
-
36
- # --- 2. 加载模型 ---
37
- model = None
38
- if CAFN is not None:
39
- try:
40
- # 这里需要根据实际参数实例化模型
41
- model = CAFN().to(device)
42
- if os.path.exists(MODEL_WEIGHTS_PATH):
43
- model.load_state_dict(torch.load(MODEL_WEIGHTS_PATH, map_location=device))
44
- model.eval()
45
- print("模型加载成功!")
46
- else:
47
- print(f"警告: 权重文件 {MODEL_WEIGHTS_PATH} 不存在")
48
- except Exception as e:
49
- print(f"模型加载失败: {e}")
50
-
51
- # --- 3. 特征提取函数 ---
52
- def extract_features_from_seq(sequence_list):
53
  """
54
- 包装特征提取逻辑
 
55
  """
56
- if PSTAAP_feature is None:
57
- raise RuntimeError("特征提取模块未加载")
58
-
59
- # 模拟特征提取,请根据你实际的 Feature_extraction_algorithms 逻辑调整
60
- x1_features = PSTAAP_feature(sequence_list)
61
- x2_features = PC_feature(sequence_list)
 
 
 
 
 
 
62
 
63
- # 转换为 Numpy 数组
64
- x1_np = np.array(x1_features, dtype=np.float32)
65
- x2_np = np.array(x2_features, dtype=np.float32)
66
 
67
- return x1_np, x2_np
 
 
68
 
69
- # --- 4. 核心预测函数 ---
70
- def predict_single_49mer(sequence_49mer):
71
- if model is None:
72
- print("错误:模型未加载")
73
- return None
74
 
75
- try:
76
- sequence_list = [sequence_49mer]
77
- x1_np, x2_np = extract_features_from_seq(sequence_list)
78
-
79
- tensor_x1 = torch.tensor(x1_np).to(device)
80
- tensor_x2 = torch.tensor(x2_np).to(device)
81
-
82
- with torch.no_grad():
83
- outputs = model(tensor_x1, tensor_x2)
84
- probabilities = torch.sigmoid(outputs).squeeze().cpu().numpy()
85
-
86
- # 处理 batch_size=1 的维度问题
87
- if probabilities.ndim == 0:
88
- probabilities = np.array([probabilities])
89
-
90
- labels = ["Lysine-Acetyllysine (K-Ac)", "Lysine-Crotonyllysine (K-Cr)", "Lysine-Methyllysine (K-Me)", "Lysine-Succinyllysine (K-Succ)"]
91
-
92
- # 保持原始 float,格式化留给前端展示函数
93
- result = {label: float(prob) for label, prob in zip(labels, probabilities)}
94
- return result
95
-
96
- except Exception as e:
97
- print(f"预测出错: {e}")
98
- return None
99
 
100
- # --- 5. FASTA 解析与处理 ---
101
  def parse_fasta(fasta_string):
 
 
102
  sequence_lines = [line for line in fasta_string.splitlines() if not line.startswith('>')]
 
103
  return "".join(sequence_lines).replace(" ", "").replace("\n", "").upper()
104
 
105
  def process_fasta_and_predict(fasta_input):
 
 
 
 
106
  if not fasta_input or not isinstance(fasta_input, str):
107
  raise gr.Error("Please enter a valid FASTA format sequence.")
108
 
109
  sequence = parse_fasta(fasta_input)
110
 
111
  if len(sequence) < 49:
112
- raise gr.Error(f"Sequence too short (Length: {len(sequence)}). Minimum 49 AA required.")
113
 
 
114
  predictions_map = {}
 
 
115
  k_indices = [m.start() for m in re.finditer('K', sequence)]
116
 
117
  for k_index in k_indices:
 
118
  start, end = k_index - 24, k_index + 25
 
 
119
  if start >= 0 and end <= len(sequence):
120
  fragment = sequence[start:end]
121
- res = predict_single_49mer(fragment)
122
- if res:
123
- predictions_map[k_index] = res
 
124
 
125
  if not predictions_map:
126
- return [(sequence, None)], {}, "No valid K sites found."
 
127
 
128
- # 构建高亮数据
129
  highlight_data = []
130
  last_pos = 0
131
- sorted_indices = sorted(predictions_map.keys())
 
132
 
133
- for k_index in sorted_indices:
134
- if k_index > last_pos:
135
- highlight_data.append((sequence[last_pos:k_index], None))
136
-
137
  highlight_data.append(("K", str(k_index)))
138
  last_pos = k_index + 1
139
 
140
- if last_pos < len(sequence):
141
- highlight_data.append((sequence[last_pos:], None))
142
 
143
- return highlight_data, predictions_map, "Processing complete! Click on a red 'K' to see details."
144
 
145
- # --- 6. 结果展示函数 (这里控制小数位) ---
 
 
146
  def show_results_for_site(evt: gr.SelectData, state_data):
147
- # 处理选中事件
148
- selected_val = evt.value
149
- k_index_str = None
150
-
151
- # 兼容不同 Gradio 版本的返回值
152
- if isinstance(selected_val, (list, tuple)) and len(selected_val) == 2:
153
- if selected_val[0] == "K":
154
- k_index_str = selected_val[1]
155
- elif isinstance(selected_val, str):
156
- # 某些情况可能直接返回 label 字符串,需视具体版本而定
157
- # 这里主要依赖上方的高亮组件传回 index 字符串
158
- pass
159
-
160
- if k_index_str and state_data:
161
- try:
162
- k_index = int(k_index_str)
163
- result_dict = state_data.get(k_index)
164
-
165
- if result_dict:
166
- site_info = f"Prediction results for 'K' at position {k_index + 1}:"
167
-
168
- table_data = []
169
- for label, score in result_dict.items():
170
- # -----------------------------------------------------
171
- # 【核心修改】控制小数位数
172
- # 方式1:百分比 (推荐) -> "95.12%"
173
- val_str = f"{score:.0%}"
174
-
175
- # 方式2:保留4位小数 -> "0.9512"
176
- # val_str = f"{score:.4f}"
177
- # -----------------------------------------------------
178
-
179
- table_data.append([label, val_str])
180
-
181
- df_result = pd.DataFrame(table_data, columns=["Modification Type", "Probability"])
182
- return df_result, site_info
183
-
184
- except ValueError:
185
- pass
186
-
187
- return None, "Please click on a highlighted 'K' site."
188
-
189
- # --- 7. Gradio 界面 ---
190
- fasta_example_str = """>sp|P05141|ADT2_HUMAN Example
191
- MTDAAVSFAKDFLAGGVAAAISKTAVAPIERVKLLLQVQHASKQITADKQYKGIIDCVVR
192
- IPKEQGVLSFWRGNLANVIRYFPTQALNFAFKDKYKQIFLGGVDKRTQFWLYFAGNLASG
193
- """
194
 
195
- css = ".predictable-k { color: white; background-color: #d32f2f; font-weight: bold; }"
 
 
 
 
 
 
196
 
197
- with gr.Blocks(css=css, title="DeepKMulti") as demo:
198
- gr.Markdown("# DeepKMulti Prediction Tool")
199
-
 
 
 
 
 
 
 
 
 
 
 
 
200
  with gr.Row():
201
- with gr.Column(scale=2):
202
  fasta_input = gr.Textbox(
203
- lines=8,
204
- label="Input FASTA",
205
- value=fasta_example_str,
206
- placeholder="Paste sequence here..."
207
  )
208
- submit_btn = gr.Button("Submit", variant="primary")
209
 
210
  with gr.Column(scale=3):
211
- gr.Markdown("### Results")
212
- info_text = gr.Textbox(label="Status", value="Waiting...", interactive=False)
213
-
214
- # 隐藏的状态组件,用于存储数据
215
- predictions_state = gr.State({})
216
-
217
- # 使用 DataFrame 展示表格
218
- results_output = gr.DataFrame(
219
- headers=["Modification Type", "Probability"],
220
- datatype=["str", "str"], # 设置为 str 以保持百分比格式不被自动转回 float
221
- label="Site Probabilities",
222
- interactive=False
223
- )
224
 
225
- gr.Markdown("### Sequence Map")
 
 
 
226
  highlighted_output = gr.HighlightedText(
227
- label="Click 'K' to view",
228
- combine_adjacent=False,
229
- show_legend=False
230
- )
231
-
232
- # 事件绑定
233
- submit_btn.click(
234
- process_fasta_and_predict,
235
- inputs=[fasta_input],
236
- outputs=[highlighted_output, predictions_state, info_text]
237
  )
238
 
239
- highlighted_output.select(
240
- show_results_for_site,
241
- inputs=[predictions_state],
242
  outputs=[results_output, info_text]
243
  )
244
 
245
- if __name__ == "__main__":
246
- demo.launch()
 
1
+ # Below is the app.py code I wrote:
2
import os
import re

import gradio as gr
import numpy as np
import torch
 
 
 
 
5
  # --- 依赖导入 ---
6
+
7
+ from model import CAFN
8
+ from Feature_extraction_algorithms.PSTAAP import PSTAAP_feature, load_precomputed_fr_matrix
9
+ from Feature_extraction_algorithms.Physicochemical import PC_feature
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
# --- 1. Runtime configuration ---
FR_MATRIX_PATH = 'Fr_train.mat'
MODEL_WEIGHTS_PATH = 'DeepKMulti.pth'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Bind `model` unconditionally: predict_single_49mer tests it against None,
# so it must exist whether or not initialisation succeeds.
model = None

# --- 2. PSTAAP initialisation, then model loading ---
try:
    if not os.path.exists(FR_MATRIX_PATH):
        raise FileNotFoundError(f"PSTAAP初始化失败:找不到矩阵文件 {FR_MATRIX_PATH}")
    load_precomputed_fr_matrix(FR_MATRIX_PATH)
except Exception as e:
    # Top-level boundary: report the failure and leave `model` as None so the
    # prediction path is disabled instead of crashing the whole app at import.
    print(f"PSTAAP 初始化过程中发生严重错误: {e}")
else:
    try:
        # NOTE(review): CAFN is instantiated without arguments here; confirm
        # that this matches the constructor of the trained network.
        candidate = CAFN().to(device)
        candidate.load_state_dict(torch.load(MODEL_WEIGHTS_PATH, map_location=device))
        candidate.eval()
    except Exception as e:
        print(f"模型加载失败: {e}")
    else:
        # Only publish the model once its weights are actually loaded.
        model = candidate
        print("模型加载成功!")
21
+
22
+
23
+ # --- 3. 特征提取函数 (与之前相同) ---
24
+ data = np.hstack((data, feature))
25
+ return data.astype(np.float32), data2.astype(np.float32)
26
+
27
+ # --- 4. 核心预测函数 (重构为处理单个49-mer片段) ---
28
def predict_single_49mer(sequence_49mer):
    """Predict lysine-modification probabilities for one 49-residue window.

    Args:
        sequence_49mer: The sequence fragment to score; callers pass a
            49-character window centred on a 'K'.

    Returns:
        A dict mapping each modification label to its sigmoid probability as
        a plain ``float``, or ``None`` when the model has not been loaded.
    """
    if model is None:
        # Logged rather than raised: this is a backend condition, not a user error.
        print("错误:模型核心未加载。")
        return None

    x1_np, x2_np = extract_features_from_seq([sequence_49mer])
    tensor_x1 = torch.tensor(x1_np).to(device)
    tensor_x2 = torch.tensor(x2_np).to(device)

    # Inference only: without no_grad() the output tracks gradients and
    # Tensor.numpy() raises RuntimeError.
    with torch.no_grad():
        outputs = model(tensor_x1, tensor_x2)
        # reshape(-1) always yields a 1-D array, so zip() below never sees a
        # 0-d array (which squeeze() can produce).
        probabilities = torch.sigmoid(outputs).reshape(-1).cpu().numpy()

    labels = ["Lysine-Acetyllysine (K-Ac)", "Lysine-Crotonyllysine (K-Cr)", "Lysine-Methyllysine (K-Me)", "Lysine-Succinyllysine (K-Succ)"]

    return {label: float(prob) for label, prob in zip(labels, probabilities)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
+ # --- 5. 新增:FASTA格式解析与主处理流程 ---
60
def parse_fasta(fasta_string):
    """Extract the residue sequence from FASTA-formatted text.

    Header lines (those whose first non-blank character is '>') are dropped,
    the remaining lines are concatenated, all whitespace is removed and the
    result is upper-cased.

    Args:
        fasta_string: Raw FASTA text, possibly with several header lines.

    Returns:
        The concatenated, upper-cased sequence (empty string if none).
    """
    residue_lines = (
        line for line in fasta_string.splitlines()
        # lstrip() so an indented header is not mistaken for sequence data.
        if not line.lstrip().startswith('>')
    )
    # split()/join removes every kind of whitespace (spaces, tabs, stray CRs),
    # not just ' ' and '\n'.
    return "".join("".join(residue_lines).split()).upper()
66
 
67
def process_fasta_and_predict(fasta_input):
    """Score every predictable lysine in a FASTA input.

    A 'K' is predictable when 24 residues exist on both sides of it, so that
    a 49-residue window centred on it can be cut from the sequence.

    Args:
        fasta_input: FASTA text entered by the user.

    Returns:
        A 3-tuple ``(highlight_data, predictions_map, message)``:
        ``highlight_data`` is a list of ``(text, label)`` pairs for
        ``gr.HighlightedText``, where each predictable 'K' is labelled with
        its 0-based index as a string and all other text has label ``None``;
        ``predictions_map`` maps that 0-based index to the prediction dict;
        ``message`` is a status string for the UI.

    Raises:
        gr.Error: If the input is empty or not a string, or the parsed
            sequence is shorter than 49 residues.
    """
    if not fasta_input or not isinstance(fasta_input, str):
        raise gr.Error("Please enter a valid FASTA format sequence.")

    sequence = parse_fasta(fasta_input)

    if len(sequence) < 49:
        raise gr.Error(f"The sequence is too short! It needs to be at least 49 amino acids. The current length is {len(sequence)}")

    # 0-based index of each predictable 'K' -> its prediction result.
    predictions_map = {}
    for match in re.finditer('K', sequence):
        k_index = match.start()
        start, end = k_index - 24, k_index + 25
        # Skip sites too close to either end to yield a full 49-mer.
        if start < 0 or end > len(sequence):
            continue
        prediction_result = predict_single_49mer(sequence[start:end])
        if prediction_result:
            predictions_map[k_index] = prediction_result

    if not predictions_map:
        return [(sequence, None)], {}, "No valid K sites were found in the sequence for prediction (i.e., there were not enough amino acids before and after K)."

    # Build the HighlightedText payload in sequence order.
    highlight_data = []
    last_pos = 0
    for k_index in sorted(predictions_map):
        # Guard against empty spans, which occur when two predictable K's
        # are adjacent ("KK").
        if k_index > last_pos:
            highlight_data.append((sequence[last_pos:k_index], None))
        highlight_data.append(("K", str(k_index)))
        last_pos = k_index + 1

    if last_pos < len(sequence):
        highlight_data.append((sequence[last_pos:], None))

    initial_info = "Processing complete! Click on the highlighted 'K' site in the sequence below to see its prediction."

    return highlight_data, predictions_map, initial_info
121
+
122
+ # --- 6. 新增:Gradio事件处理函数 ---
123
def show_results_for_site(evt: gr.SelectData, state_data):
    """Return the stored prediction for the span the user clicked.

    Args:
        evt: Gradio selection event; ``evt.value`` is expected to be a
            ``(text, label)`` pair whose label is the 0-based index of a
            predictable 'K', encoded as a string.
        state_data: Dict mapping 0-based 'K' indices to prediction dicts,
            as produced by ``process_fasta_and_predict``; may be empty or
            ``None``.

    Returns:
        ``(result_dict, message)`` for a predictable site, otherwise
        ``(None, prompt)`` asking the user to click a highlighted 'K'.
    """
    prompt = "Please click on the highlighted 'K' site in the sequence above to view the results."

    selected = evt.value
    # Anything other than a (text, label) pair cannot identify a site.
    if not isinstance(selected, (list, tuple)) or len(selected) != 2:
        return None, prompt

    try:
        # Unhighlighted spans carry label None, hence TypeError as well.
        k_index = int(selected[1])
    except (TypeError, ValueError):
        return None, prompt

    result_dict = (state_data or {}).get(k_index)
    if not result_dict:
        return None, prompt

    # Report the position 1-based, as users count residues.
    site_info = f"Prediction results for the segment centered at 'K' at position {k_index + 1}:"
    return result_dict, site_info
144
+
145
+
146
+ # --- 7. 创建并启动 Gradio 界面 (使用 gr.Blocks) ---
147
+ fasta_example = """>sp|P05141|ADT2_HUMAN ADP/ATP translocase 2 OS=Homo sapiens OX=9606 GN=SLC25A5 PE=1 SV=7
148
+ MTDAAVSFAKDFLAGGVAAAISKTAVAPIERVKLLLQVQHASKQITADKQYKGIIDCVVR
149
+ IPKEQGVLSFWRGNLANVIRYFPTQALNFAFKDKYKQIFLGGVDKRTQFWLYFAGNLASG
150
+ gr.Markdown(
151
+ """
152
+ # DeepKMulti Model: Multi-label Classifier for Lysine Modifications
153
+ **Supports FASTA format input, allowing interactive viewing of the modification possibilities of each lysine site in the protein sequence.**
154
+ """
155
+ )
156
  with gr.Row():
 
157
  fasta_input = gr.Textbox(
158
+ lines=10,
159
+ label="Input FASTA format protein sequence",
160
+ placeholder="Please paste your FASTA formatted sequence here (we provide an example sequence below)..."
 
161
  )
162
+ submit_btn = gr.Button("Submit Prediction", variant="primary")
163
 
164
  with gr.Column(scale=3):
165
+ gr.Markdown("### Prediction Results")
166
+ info_text = gr.Textbox(label="State", interactive=False, value="Waiting for input...")
167
+ # 用于存储所有位点的预测结果,对用户不可见
168
+ predictions_state = gr.State({})
169
+ results_output = gr.Label(num_top_classes=4, label="After clicking on the colored 'K' site, the results will be displayed here")
 
 
 
 
 
 
 
 
170
 
171
+
172
+ gr.Markdown("---")
173
+ gr.Markdown("### Visualized Sequence")
174
+ # 使用 a[class='predictable-k'] 来应用CSS
175
  highlighted_output = gr.HighlightedText(
176
+ label="Sequence Analysis",
177
+ color_map={"predictable-k": "red"}, # 旧版Gradio的用法
178
+ # 在新版Gradio中,CSS通过gr.Blocks的css参数全局定义更可靠
 
 
 
 
 
 
 
179
  )
180
 
181
+ gr.Examples(
 
 
182
  outputs=[results_output, info_text]
183
  )
184
 
185
+ # 启动应用
186
+ demo.launch(debug=True)