Xianfish9 committed on
Commit
9c08eb6
·
verified ·
1 Parent(s): 7a27d8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +168 -128
app.py CHANGED
@@ -1,206 +1,246 @@
1
  import numpy as np
2
  import os
3
  import re
4
- import pandas as pd # --- 新增:引入 pandas 用于处理表格数据 ---
 
 
5
 
6
  # --- 依赖导入 ---
7
- # 请确保 model.py, Feature_extraction_algorithms 文件夹在同一目录下
8
- from model import CAFN
9
- from Feature_extraction_algorithms.PSTAAP import PSTAAP_feature, load_precomputed_fr_matrix
10
- from Feature_extraction_algorithms.Physicochemical import PC_feature
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  try:
12
- FR_MATRIX_PATH = 'Fr_train.mat'
13
  if not os.path.exists(FR_MATRIX_PATH):
14
- # 如果是本地运行且文件确实存在,请忽略此模拟错误;
15
- # 这里为了防止代码报错,如果文件不存在可以仅打印警告
16
  print(f"警告:找不到矩阵文件 {FR_MATRIX_PATH},如果是测试环境请忽略。")
17
  else:
18
  load_precomputed_fr_matrix(FR_MATRIX_PATH)
19
  except Exception as e:
20
- print(f"PSTAAP 初始化过程中发生严重错误: {e}")
21
- # model = None # 暂时注释掉,以免本地测试时因为缺文件直接无法运行
22
-
23
 
24
- # --- 3. 特征提取函数 (与之前相同) ---
25
- data = np.hstack((data, feature))
26
- return data.astype(np.float32), data2.astype(np.float32)
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- # --- 4. 核心预测函数 (保持不变,依旧返回浮点数) ---
29
- def predict_single_49mer(sequence_49mer):
30
  """
31
- 对单个、长度为49的序列片段进行预测。
32
-
33
  """
34
- if model is None:
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- print("错误:模型核心未加载。")
 
 
 
37
  return None
38
 
39
- sequence_list = [sequence_49mer]
40
- # 注意:如果缺少依赖文件,这里可能会报错,请确保环境完整
41
  try:
 
42
  x1_np, x2_np = extract_features_from_seq(sequence_list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  except Exception as e:
44
- print(f"特征提取失败: {e}")
45
  return None
46
-
47
- tensor_x1 = torch.tensor(x1_np).to(device)
48
- tensor_x2 = torch.tensor(x2_np).to(device)
49
- outputs = model(tensor_x1, tensor_x2)
50
-
51
- probabilities = torch.sigmoid(outputs).squeeze().cpu().numpy()
52
-
53
- labels = ["Lysine-Acetyllysine (K-Ac)", "Lysine-Crotonyllysine (K-Cr)", "Lysine-Methyllysine (K-Me)", "Lysine-Succinyllysine (K-Succ)"]
54
-
55
- # 这里保持返回原始 float 数据,方便后续处理
56
- result = {label: float(prob) for label, prob in zip(labels, probabilities)}
57
-
58
- return result
59
 
60
- # --- 5. FASTA格式解析与主处理流程 (与之前相同) ---
61
  def parse_fasta(fasta_string):
62
-
63
-
64
  sequence_lines = [line for line in fasta_string.splitlines() if not line.startswith('>')]
65
-
66
  return "".join(sequence_lines).replace(" ", "").replace("\n", "").upper()
67
 
68
  def process_fasta_and_predict(fasta_input):
69
-
70
-
71
-
72
-
73
  if not fasta_input or not isinstance(fasta_input, str):
74
  raise gr.Error("Please enter a valid FASTA format sequence.")
75
 
76
  sequence = parse_fasta(fasta_input)
77
 
78
  if len(sequence) < 49:
79
- raise gr.Error(f"The sequence is too short! It needs to be at least 49 amino acids. The current length is {len(sequence)}.")
80
-
81
 
82
  predictions_map = {}
83
-
84
-
85
  k_indices = [m.start() for m in re.finditer('K', sequence)]
86
 
87
  for k_index in k_indices:
88
-
89
  start, end = k_index - 24, k_index + 25
90
-
91
-
92
  if start >= 0 and end <= len(sequence):
93
  fragment = sequence[start:end]
94
- prediction_result = predict_single_49mer(fragment)
95
- if prediction_result:
96
-
97
- predictions_map[k_index] = prediction_result
98
 
99
  if not predictions_map:
100
- return [(sequence, None)], {}, "No valid K sites were found in the sequence for prediction."
101
-
102
-
103
 
 
104
  highlight_data = []
105
  last_pos = 0
 
106
 
107
- sorted_predictable_indices = sorted(predictions_map.keys())
108
-
109
- for k_index in sorted_predictable_indices:
110
-
111
- highlight_data.append((sequence[last_pos:k_index], None))
112
-
113
  highlight_data.append(("K", str(k_index)))
114
  last_pos = k_index + 1
115
 
116
-
117
- highlight_data.append((sequence[last_pos:], None))
118
 
119
- initial_info = "Processing complete! Click on the highlighted 'K' site in the sequence below to see its prediction."
120
-
121
- return highlight_data, predictions_map, initial_info
122
 
123
- # --- 6. 修改重点:Gradio事件处理函数 ---
124
  def show_results_for_site(evt: gr.SelectData, state_data):
125
- """
126
- 当用户点击高亮的K时触发。
127
- 此处我们将结果格式化为DataFrame,并精确控制百分比格式。
128
- """
129
- if evt.value:
130
-
131
- k_index_str = evt.value[1]
 
 
 
 
 
 
 
132
  try:
133
  k_index = int(k_index_str)
134
- except ValueError:
135
- return None, "Invalid selection."
136
-
137
-
138
- result_dict = state_data.get(k_index)
139
-
140
- if result_dict:
141
- site_info = f"Prediction results for 'K' at position {k_index + 1}:"
142
-
143
- # --- 修改开始:构建详细的表格数据 ---
144
- table_data = []
145
- for label, score in result_dict.items():
146
- # 使用 f-string 的 :.2% 语法,将 0.9299 转换为 92.99%
147
- percentage_str = f"{score:.0%}"
148
- table_data.append([label, percentage_str])
149
-
150
- # 创建 Pandas DataFrame
151
- df_result = pd.DataFrame(table_data, columns=["Modification Type", "Probability"])
152
- # --- 修改结束 ---
153
-
154
- return df_result, site_info
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
- return None, "Please click on the highlighted 'K' site in the sequence above to view the results."
158
-
159
 
160
- # --- 7. 创建并启动 Gradio 界面 ---
161
- fasta_example = """>sp|P05141|ADT2_HUMAN ADP/ATP translocase 2 OS=Homo sapiens OX=9606 GN=SLC25A5 PE=1 SV=7
162
  MTDAAVSFAKDFLAGGVAAAISKTAVAPIERVKLLLQVQHASKQITADKQYKGIIDCVVR
163
  IPKEQGVLSFWRGNLANVIRYFPTQALNFAFKDKYKQIFLGGVDKRTQFWLYFAGNLASG
164
- gr.Markdown(
165
- """
166
- # DeepKMulti Model: Multi-label Classifier for Lysine Modifications
167
- **Supports FASTA format input, allowing interactive viewing of the modification possibilities of each lysine site.**
168
- """
169
- )
 
170
  with gr.Row():
 
171
  fasta_input = gr.Textbox(
172
- lines=10,
173
- label="Input FASTA format protein sequence",
174
- placeholder="Paste your FASTA sequence here..."
 
175
  )
176
- submit_btn = gr.Button("Submit Prediction", variant="primary")
177
 
178
  with gr.Column(scale=3):
179
- gr.Markdown("### Prediction Results")
180
- info_text = gr.Textbox(label="State", interactive=False, value="Waiting for input...")
181
-
182
- predictions_state = gr.State({})
 
183
 
184
- # --- 修改重点:将 gr.Label 替换为 gr.DataFrame ---
185
  results_output = gr.DataFrame(
186
  headers=["Modification Type", "Probability"],
187
- datatype=["str", "str"],
188
- label="Detailed Probabilities",
189
  interactive=False
190
  )
191
- # ------------------------------------------------
192
-
193
- gr.Markdown("---")
194
- gr.Markdown("### Visualized Sequence")
195
 
 
196
  highlighted_output = gr.HighlightedText(
197
- label="Sequence Analysis",
198
- color_map={"predictable-k": "red"},
199
-
 
 
 
 
 
 
 
200
  )
201
 
202
- gr.Examples(
 
 
203
  outputs=[results_output, info_text]
204
  )
205
 
206
- demo.launch()
 
 
1
  import numpy as np
2
  import os
3
  import re
4
+ import pandas as pd
5
+ import torch
6
+ import gradio as gr
7
 
# --- Project dependencies ---
# The model definition and the feature extractors are project-local modules;
# make sure the directory layout is correct.  If any of them cannot be
# imported, inert placeholders are installed so the module does not crash
# outright.
try:
    from model import CAFN
    from Feature_extraction_algorithms.PSTAAP import PSTAAP_feature, load_precomputed_fr_matrix
    from Feature_extraction_algorithms.Physicochemical import PC_feature
except ImportError as e:
    print(f"警告:依赖导入失败,请检查文件路径。错误: {e}")
    CAFN = PSTAAP_feature = PC_feature = None

    def load_precomputed_fr_matrix(x):
        """Placeholder used when the PSTAAP module is unavailable; does nothing."""
        return None
21
+
# --- 1. Runtime configuration ---
# Use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
FR_MATRIX_PATH = 'Fr_train.mat'
MODEL_WEIGHTS_PATH = 'DeepKMulti.pth'  # make sure this file exists

# Initialise PSTAAP from its precomputed matrix.  A missing file only yields a
# warning; any exception raised while loading is printed and start-up continues.
try:
    if os.path.exists(FR_MATRIX_PATH):
        load_precomputed_fr_matrix(FR_MATRIX_PATH)
    else:
        print(f"警告:找不到矩阵文件 {FR_MATRIX_PATH},如果是测试环境请忽略。")
except Exception as e:
    print(f"PSTAAP 初始化错误: {e}")
 
 
35
 
# --- 2. Load the model ---
# `model` is published only after the architecture was imported AND the trained
# weights were loaded successfully.  In every other case it stays None, so the
# app never serves predictions from a randomly initialised network that was
# never switched to eval mode.
model = None
if CAFN is not None:
    try:
        if os.path.exists(MODEL_WEIGHTS_PATH):
            # CAFN() is built with its default arguments; adjust this call if
            # the checkpoint was trained with other constructor parameters.
            candidate = CAFN().to(device)
            # NOTE(review): torch.load unpickles the file -- only load
            # checkpoints from a trusted source.
            candidate.load_state_dict(torch.load(MODEL_WEIGHTS_PATH, map_location=device))
            candidate.eval()
            model = candidate
            print("模型加载成功!")
        else:
            print(f"警告: 权重文件 {MODEL_WEIGHTS_PATH} 不存在")
    except Exception as e:
        # Instantiation or weight loading failed: keep `model` as None.
        model = None
        print(f"模型加载失败: {e}")
50
 
# --- 3. Feature extraction ---
def extract_features_from_seq(sequence_list):
    """
    Compute both feature sets for a list of sequences.

    Calls PSTAAP_feature and PC_feature on ``sequence_list`` and converts each
    result to a float32 NumPy array.

    Returns:
        A pair ``(x1, x2)``: the PSTAAP_feature output followed by the
        PC_feature output.

    Raises:
        RuntimeError: if the feature extraction module was not imported.
    """
    if PSTAAP_feature is None:
        raise RuntimeError("特征提取模块未加载")

    # NOTE(review): the author marks these calls as a stand-in -- confirm they
    # match the real Feature_extraction_algorithms API.
    raw_pstaap, raw_physchem = PSTAAP_feature(sequence_list), PC_feature(sequence_list)

    return (
        np.array(raw_pstaap, dtype=np.float32),
        np.array(raw_physchem, dtype=np.float32),
    )
68
 
# --- 4. Core prediction ---
def predict_single_49mer(sequence_49mer):
    """
    Score a single sequence fragment with the loaded model.

    The fragment is wrapped in a one-element list, turned into two feature
    arrays, and passed through the model; the sigmoid of the model output is
    paired with the modification labels.

    Returns:
        A dict mapping each label to a plain ``float`` probability (formatting
        is left to the display code), or ``None`` when the model is not loaded
        or any step raises (the error is printed).
    """
    if model is None:
        print("错误:模型未加载")
        return None

    class_names = (
        "Lysine-Acetyllysine (K-Ac)",
        "Lysine-Crotonyllysine (K-Cr)",
        "Lysine-Methyllysine (K-Me)",
        "Lysine-Succinyllysine (K-Succ)",
    )

    try:
        feats_a, feats_b = extract_features_from_seq([sequence_49mer])
        input_a = torch.tensor(feats_a).to(device)
        input_b = torch.tensor(feats_b).to(device)

        with torch.no_grad():
            logits = model(input_a, input_b)
            probs = torch.sigmoid(logits).squeeze().cpu().numpy()

        # squeeze() can collapse the result to a 0-d array; restore one axis
        # so it can be iterated.
        probs = np.atleast_1d(probs)

        return {name: float(p) for name, p in zip(class_names, probs)}
    except Exception as e:
        print(f"预测出错: {e}")
        return None
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
# --- 5. FASTA parsing and processing ---
def parse_fasta(fasta_string):
    """
    Extract the bare sequence from FASTA-formatted text.

    Header lines (those starting with ``>``, ignoring leading whitespace) are
    dropped, the remaining lines are concatenated, all whitespace (spaces,
    tabs, stray carriage returns) is removed, and the result is upper-cased.
    Input without a header is accepted as a plain sequence.  If several
    records are present, their sequences are concatenated.

    Args:
        fasta_string: the raw text pasted by the user.

    Returns:
        The upper-case sequence as a single string ("" if nothing remains).
    """
    pieces = []
    for raw_line in fasta_string.splitlines():
        line = raw_line.strip()
        if line.startswith('>'):
            continue
        # split()/join removes every kind of whitespace inside the line,
        # not just the ASCII space.
        pieces.append("".join(line.split()))
    return "".join(pieces).upper()
104
 
def process_fasta_and_predict(fasta_input):
    """
    Run the prediction for every lysine that has a full 49-residue window.

    Args:
        fasta_input: FASTA text entered by the user.

    Returns:
        A triple ``(highlight_data, predictions_map, message)``:
        ``highlight_data`` is a list of ``(text, label)`` pairs in which each
        predicted 'K' is labelled with its 0-based index as a string and all
        other text has label ``None``; ``predictions_map`` maps the 0-based
        index of each predicted 'K' to its prediction dict; ``message`` is a
        status string.

    Raises:
        gr.Error: if the input is empty / not a string, or the parsed sequence
            is shorter than 49 residues.
    """
    if not (fasta_input and isinstance(fasta_input, str)):
        raise gr.Error("Please enter a valid FASTA format sequence.")

    sequence = parse_fasta(fasta_input)
    seq_len = len(sequence)
    if seq_len < 49:
        raise gr.Error(f"Sequence too short (Length: {len(sequence)}). Minimum 49 AA required.")

    # A site is predictable only when 24 residues exist on each side of it,
    # i.e. for positions 24 .. seq_len - 25.
    predictions_map = {}
    for site in range(24, seq_len - 24):
        if sequence[site] != 'K':
            continue
        outcome = predict_single_49mer(sequence[site - 24:site + 25])
        if outcome:
            predictions_map[site] = outcome

    if not predictions_map:
        return [(sequence, None)], {}, "No valid K sites found."

    # Build the highlight segments: plain text between sites, then the site.
    # Keys were inserted in ascending order, so iterating the dict is already
    # sorted.
    highlight_data = []
    cursor = 0
    for site in predictions_map:
        if cursor < site:
            highlight_data.append((sequence[cursor:site], None))
        highlight_data.append(("K", str(site)))
        cursor = site + 1

    remainder = sequence[cursor:]
    if remainder:
        highlight_data.append((remainder, None))

    return highlight_data, predictions_map, "Processing complete! Click on a red 'K' to see details."
 
 
144
 
# --- 6. Result display (the number format is chosen here) ---
def show_results_for_site(evt: gr.SelectData, state_data):
    """
    Build the probability table for the site the user clicked.

    Args:
        evt: the select event; its ``value`` is expected to be a two-item
            list/tuple ``("K", "<index>")``.  Any other payload is ignored.
        state_data: mapping from 0-based site index to its prediction dict.

    Returns:
        ``(DataFrame, message)`` for a known site, where each probability is
        rendered with ``:.0%`` (a whole-number percentage such as "95%");
        otherwise ``(None, hint message)``.
    """
    fallback = (None, "Please click on a highlighted 'K' site.")

    picked = evt.value
    # Only a (text, label) pair whose text is "K" identifies a site.
    is_site_pair = (
        isinstance(picked, (list, tuple))
        and len(picked) == 2
        and picked[0] == "K"
    )
    k_index_str = picked[1] if is_site_pair else None

    if not (k_index_str and state_data):
        return fallback

    try:
        site = int(k_index_str)
        scores = state_data.get(site)
        if scores:
            # Alternative format: f"{score:.4f}" gives e.g. "0.9512".
            rows = [[label, f"{score:.0%}"] for label, score in scores.items()]
            table = pd.DataFrame(rows, columns=["Modification Type", "Probability"])
            # Positions are shown 1-based to the user.
            return table, f"Prediction results for 'K' at position {site + 1}:"
    except ValueError:
        pass

    return fallback
 
188
 
# --- 7. Gradio interface ---
# NOTE(review): the nesting of the layout below was reconstructed from the
# statement order -- confirm it against the intended page layout.
fasta_example_str = (
    ">sp|P05141|ADT2_HUMAN Example\n"
    "MTDAAVSFAKDFLAGGVAAAISKTAVAPIERVKLLLQVQHASKQITADKQYKGIIDCVVR\n"
    "IPKEQGVLSFWRGNLANVIRYFPTQALNFAFKDKYKQIFLGGVDKRTQFWLYFAGNLASG\n"
)

css = ".predictable-k { color: white; background-color: #d32f2f; font-weight: bold; }"

with gr.Blocks(css=css, title="DeepKMulti") as demo:
    gr.Markdown("# DeepKMulti Prediction Tool")

    with gr.Row():
        # Left column: sequence input and the submit button.
        with gr.Column(scale=2):
            fasta_input = gr.Textbox(
                lines=8,
                label="Input FASTA",
                value=fasta_example_str,
                placeholder="Paste sequence here...",
            )
            submit_btn = gr.Button("Submit", variant="primary")

        # Right column: status, per-site table and the clickable sequence.
        with gr.Column(scale=3):
            gr.Markdown("### Results")
            info_text = gr.Textbox(label="Status", value="Waiting...", interactive=False)

            # Hidden state component holding the per-site predictions.
            predictions_state = gr.State({})

            # Both columns are declared as str so the formatted percentage
            # is not converted back to a float.
            results_output = gr.DataFrame(
                headers=["Modification Type", "Probability"],
                datatype=["str", "str"],
                label="Site Probabilities",
                interactive=False,
            )

            gr.Markdown("### Sequence Map")
            highlighted_output = gr.HighlightedText(
                label="Click 'K' to view",
                combine_adjacent=False,
                show_legend=False,
            )

    # Event wiring: submit runs the predictions, selecting a site shows them.
    submit_btn.click(
        process_fasta_and_predict,
        inputs=[fasta_input],
        outputs=[highlighted_output, predictions_state, info_text],
    )

    highlighted_output.select(
        show_results_for_site,
        inputs=[predictions_state],
        outputs=[results_output, info_text],
    )

if __name__ == "__main__":
    demo.launch()