Xianfish9 commited on
Commit
ad61623
·
verified ·
1 Parent(s): 2020c05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -85
app.py CHANGED
@@ -1,207 +1,257 @@
 
 
 
1
  import numpy as np
2
  import os
3
  import re
4
-
5
 
6
  # --- 依赖导入 ---
7
-
8
  from model import CAFN
9
  from Feature_extraction_algorithms.PSTAAP import PSTAAP_feature, load_precomputed_fr_matrix
10
  from Feature_extraction_algorithms.Physicochemical import PC_feature
11
- try:
12
- FR_MATRIX_PATH = 'Fr_train.mat'
13
- if not os.path.exists(FR_MATRIX_PATH):
14
- raise FileNotFoundError(f"PSTAAP初始化失败:找不到矩阵文件 {FR_MATRIX_PATH}")
15
- load_precomputed_fr_matrix(FR_MATRIX_PATH)
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
 
18
 
 
 
 
 
 
 
 
 
 
19
  except Exception as e:
20
  print(f"PSTAAP 初始化过程中发生严重错误: {e}")
21
- model = None
22
 
23
 
24
  # --- 3. 特征提取函数 (与之前相同) ---
 
 
 
 
 
 
25
  data = np.hstack((data, feature))
26
  return data.astype(np.float32), data2.astype(np.float32)
27
 
28
- # --- 4. 核心预测函数 (重构为处理单个49-mer片段) ---
29
  def predict_single_49mer(sequence_49mer):
30
  """
31
  对单个、长度为49的序列片段进行预测。
32
- 这是底层的预测引擎。
33
  """
34
  if model is None:
35
- # 这个错误不应该在UI层面抛出,而是在后台日志中记录
36
  print("错误:模型核心未加载。")
37
  return None
38
 
39
  sequence_list = [sequence_49mer]
40
- x1_np, x2_np = extract_features_from_seq(sequence_list)
41
-
42
-
43
-
44
-
45
-
46
 
47
  tensor_x1 = torch.tensor(x1_np).to(device)
48
  tensor_x2 = torch.tensor(x2_np).to(device)
 
 
49
  outputs = model(tensor_x1, tensor_x2)
50
 
51
  probabilities = torch.sigmoid(outputs).squeeze().cpu().numpy()
52
- #Lysine-Acetylation(K-Ac)
53
  labels = ["Lysine-Acetyllysine (K-Ac)", "Lysine-Crotonyllysine (K-Cr)", "Lysine-Methyllysine (K-Me)", "Lysine-Succinyllysine (K-Succ)"]
54
-
55
-
56
  result = {label: float(prob) for label, prob in zip(labels, probabilities)}
57
 
58
  return result
59
 
60
- # --- 5. 新增:FASTA格式解析与主处理流程 ---
61
  def parse_fasta(fasta_string):
62
- """从FASTA格式文本中提取序列。"""
63
- # 移除FASTA头(以'>'开头的行)
64
  sequence_lines = [line for line in fasta_string.splitlines() if not line.startswith('>')]
65
- # 连接所有行并移除任何空白字符
66
  return "".join(sequence_lines).replace(" ", "").replace("\n", "").upper()
67
 
68
  def process_fasta_and_predict(fasta_input):
69
- """
70
- 接收FASTA输入,找到所有K位点,进行切片和预测,
71
- 并返回用于Gradio HighlightedText组件的数据和一个包含预测结果的状态字典。
72
- """
73
  if not fasta_input or not isinstance(fasta_input, str):
74
  raise gr.Error("Please enter a valid FASTA format sequence.")
75
 
76
  sequence = parse_fasta(fasta_input)
77
 
78
  if len(sequence) < 49:
79
- raise gr.Error(f"The sequence is too short! It needs to be at least 49 amino acids. The current length is {len(sequence)}")
 
80
 
81
- # 存储每个可预测K位���(索引)及其预测结果
82
  predictions_map = {}
83
-
84
- # 寻找所有 'K' 的索引
85
  k_indices = [m.start() for m in re.finditer('K', sequence)]
86
 
87
  for k_index in k_indices:
88
- # 尝试以K为中心截取片段 (K前24个, K, K后24个)
89
  start, end = k_index - 24, k_index + 25
90
-
91
- # 边界检查,如果长度不足49则跳过
92
  if start >= 0 and end <= len(sequence):
93
  fragment = sequence[start:end]
94
  prediction_result = predict_single_49mer(fragment)
95
  if prediction_result:
96
- # 使用K的原始索引作为键
97
  predictions_map[k_index] = prediction_result
98
 
99
  if not predictions_map:
100
- # 如果没有一个K位点可以被成功预测
101
- return [(sequence, None)], {}, "No valid K sites were found in the sequence for prediction (i.e., there were not enough amino acids before and after K)."
 
102
 
103
- # --- 构建Gradio HighlightedText的输入格式 ---
104
  highlight_data = []
105
  last_pos = 0
106
- # 按索引排序,确保我们按顺序处理序列
107
  sorted_predictable_indices = sorted(predictions_map.keys())
108
 
109
  for k_index in sorted_predictable_indices:
110
- # 添加K之前未高亮的部分
111
  highlight_data.append((sequence[last_pos:k_index], None))
112
- # 添加需要高亮的K,并用其索引作为标签
113
  highlight_data.append(("K", str(k_index)))
114
  last_pos = k_index + 1
115
 
116
- # 添加最后一个K之后剩余的部分
117
  highlight_data.append((sequence[last_pos:], None))
118
 
119
  initial_info = "Processing complete! Click on the highlighted 'K' site in the sequence below to see its prediction."
120
 
121
  return highlight_data, predictions_map, initial_info
122
 
123
- # --- 6. 新增:Gradio事件处理函数 ---
124
  def show_results_for_site(evt: gr.SelectData, state_data):
125
  """
126
- 当用户点击高亮的K时,此函数被触发。
127
- 它从state_data中查找并返回该位点的预测结果。
128
  """
129
  if evt.value:
130
- # evt.value 是 ('K', '索引字符串')
131
- k_index_str = evt.value[1]
132
- k_index = int(k_index_str)
133
-
134
-
135
 
 
 
 
 
 
136
 
137
- # 从状态字典中获取结果
138
  result_dict = state_data.get(k_index)
139
 
140
  if result_dict:
141
- site_info = f"Prediction results for the segment centered at 'K' at position {k_index + 1}:"
142
- return result_dict, site_info
143
-
144
-
145
-
146
-
147
-
148
-
149
-
150
-
151
-
152
-
153
-
154
-
155
 
156
- # 如果没有选择或出现错误
 
 
 
 
 
 
157
  return None, "Please click on the highlighted 'K' site in the sequence above to view the results."
158
 
159
 
160
- # --- 7. 创建并启动 Gradio 界面 (使用 gr.Blocks) ---
161
  fasta_example = """>sp|P05141|ADT2_HUMAN ADP/ATP translocase 2 OS=Homo sapiens OX=9606 GN=SLC25A5 PE=1 SV=7
162
  MTDAAVSFAKDFLAGGVAAAISKTAVAPIERVKLLLQVQHASKQITADKQYKGIIDCVVR
163
  IPKEQGVLSFWRGNLANVIRYFPTQALNFAFKDKYKQIFLGGVDKRTQFWLYFAGNLASG
 
 
 
 
 
164
  gr.Markdown(
165
  """
166
  # DeepKMulti Model: Multi-label Classifier for Lysine Modifications
167
- **Supports FASTA format input, allowing interactive viewing of the modification possibilities of each lysine site in the protein sequence.**
168
  """
169
  )
170
  with gr.Row():
 
171
  fasta_input = gr.Textbox(
172
  lines=10,
173
  label="Input FASTA format protein sequence",
174
- placeholder="Please paste your FASTA formatted sequence here (we provide an example sequence below)..."
175
  )
176
  submit_btn = gr.Button("Submit Prediction", variant="primary")
177
 
178
  with gr.Column(scale=3):
179
  gr.Markdown("### Prediction Results")
180
  info_text = gr.Textbox(label="State", interactive=False, value="Waiting for input...")
181
- # 用于存储所有位点的预测结果,对用户不可见
182
- predictions_state = gr.State({})
183
- results_output = gr.Label(num_top_classes=4, label="After clicking on the colored 'K' site, the results will be displayed here")
184
-
185
-
186
-
187
-
188
-
189
-
190
-
191
 
 
 
 
 
 
 
 
 
 
 
192
 
193
  gr.Markdown("---")
194
  gr.Markdown("### Visualized Sequence")
195
- # 使用 a[class='predictable-k'] 来应用CSS
196
  highlighted_output = gr.HighlightedText(
197
  label="Sequence Analysis",
198
- color_map={"predictable-k": "red"}, # 旧版Gradio的用法
199
- # 在新版Gradio中,CSS通过gr.Blocks的css参数全局定义更可靠
200
  )
201
 
202
  gr.Examples(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  outputs=[results_output, info_text]
204
  )
205
 
206
- # 启动应用
207
- demo.launch(debug=True)
 
1
+ #test
2
+ import gradio as gr
3
+ import torch
4
  import numpy as np
5
  import os
6
  import re
7
+ import pandas as pd # --- 新增:引入 pandas 用于处理表格数据 ---
8
 
9
  # --- 依赖导入 ---
10
+ # 请确保 model.py, Feature_extraction_algorithms 文件夹在同一目录下
11
  from model import CAFN
12
  from Feature_extraction_algorithms.PSTAAP import PSTAAP_feature, load_precomputed_fr_matrix
13
  from Feature_extraction_algorithms.Physicochemical import PC_feature
 
 
 
 
 
14
 
15
+ # --- 1. 模型加载 (与之前相同) ---
16
+ MODEL_PATH = "Adam_lr7e-05_weightdecay0.0001_epochs3480.pth"
17
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18
+
19
+ def load_model(model_path):
20
+ model = CAFN().to(device)
21
+ if os.path.exists(model_path):
22
+ model.load_state_dict(torch.load(model_path, map_location=device))
23
+ model.eval()
24
+ print("模型加载成功!")
25
+ return model
26
+ else:
27
+ print(f"错误:在路径 {model_path} 未找到模型文件")
28
+ return None
29
 
30
+ model = load_model(MODEL_PATH)
31
 
32
+ # --- 2. PSTAAP 特征提取器初始化 (与之前相同) ---
33
+ try:
34
+ FR_MATRIX_PATH = 'Fr_train.mat'
35
+ if not os.path.exists(FR_MATRIX_PATH):
36
+ # 如果是本地运行且文件确实存在,请忽略此模拟错误;
37
+ # 这里为了防止代码报错,如果文件不存在可以仅打印警告
38
+ print(f"警告:找不到矩阵文件 {FR_MATRIX_PATH},如果是测试环境请忽略。")
39
+ else:
40
+ load_precomputed_fr_matrix(FR_MATRIX_PATH)
41
  except Exception as e:
42
  print(f"PSTAAP 初始化过程中发生严重错误: {e}")
43
+ # model = None # 暂时注释掉,以免本地测试时因为缺文件直接无法运行
44
 
45
 
46
  # --- 3. 特征提取函数 (与之前相同) ---
47
+ def extract_features_from_seq(sequence_list):
48
+ data2 = PC_feature(sequence_list)
49
+ N = len(sequence_list)
50
+ empty_list_array = [[] for _ in range(N)]
51
+ data = np.array(empty_list_array, dtype=object)
52
+ feature = PSTAAP_feature(sequence_list)
53
  data = np.hstack((data, feature))
54
  return data.astype(np.float32), data2.astype(np.float32)
55
 
56
+ # --- 4. 核心预测函数 (保持不变,依旧返回浮点数) ---
57
  def predict_single_49mer(sequence_49mer):
58
  """
59
  对单个、长度为49的序列片段进行预测。
60
+
61
  """
62
  if model is None:
63
+
64
  print("错误:模型核心未加载。")
65
  return None
66
 
67
  sequence_list = [sequence_49mer]
68
+ # 注意:如果缺少依赖文件,这里可能会报错,请确保环境完整
69
+ try:
70
+ x1_np, x2_np = extract_features_from_seq(sequence_list)
71
+ except Exception as e:
72
+ print(f"特征提取失败: {e}")
73
+ return None
74
 
75
  tensor_x1 = torch.tensor(x1_np).to(device)
76
  tensor_x2 = torch.tensor(x2_np).to(device)
77
+
78
+ with torch.no_grad():
79
  outputs = model(tensor_x1, tensor_x2)
80
 
81
  probabilities = torch.sigmoid(outputs).squeeze().cpu().numpy()
82
+
83
  labels = ["Lysine-Acetyllysine (K-Ac)", "Lysine-Crotonyllysine (K-Cr)", "Lysine-Methyllysine (K-Me)", "Lysine-Succinyllysine (K-Succ)"]
84
+
85
+ # 这里保持返回原始 float 数据,方便后续处理
86
  result = {label: float(prob) for label, prob in zip(labels, probabilities)}
87
 
88
  return result
89
 
90
+ # --- 5. FASTA格式解析与主处理流程 (与之前相同) ---
91
  def parse_fasta(fasta_string):
92
+
93
+
94
  sequence_lines = [line for line in fasta_string.splitlines() if not line.startswith('>')]
95
+
96
  return "".join(sequence_lines).replace(" ", "").replace("\n", "").upper()
97
 
98
  def process_fasta_and_predict(fasta_input):
99
+
100
+
101
+
102
+
103
  if not fasta_input or not isinstance(fasta_input, str):
104
  raise gr.Error("Please enter a valid FASTA format sequence.")
105
 
106
  sequence = parse_fasta(fasta_input)
107
 
108
  if len(sequence) < 49:
109
+ raise gr.Error(f"The sequence is too short! It needs to be at least 49 amino acids. The current length is {len(sequence)}.")
110
+
111
 
 
112
  predictions_map = {}
113
+
114
+
115
  k_indices = [m.start() for m in re.finditer('K', sequence)]
116
 
117
  for k_index in k_indices:
118
+
119
  start, end = k_index - 24, k_index + 25
120
+
121
+
122
  if start >= 0 and end <= len(sequence):
123
  fragment = sequence[start:end]
124
  prediction_result = predict_single_49mer(fragment)
125
  if prediction_result:
126
+
127
  predictions_map[k_index] = prediction_result
128
 
129
  if not predictions_map:
130
+ return [(sequence, None)], {}, "No valid K sites were found in the sequence for prediction."
131
+
132
+
133
 
 
134
  highlight_data = []
135
  last_pos = 0
136
+
137
  sorted_predictable_indices = sorted(predictions_map.keys())
138
 
139
  for k_index in sorted_predictable_indices:
140
+
141
  highlight_data.append((sequence[last_pos:k_index], None))
142
+
143
  highlight_data.append(("K", str(k_index)))
144
  last_pos = k_index + 1
145
 
146
+
147
  highlight_data.append((sequence[last_pos:], None))
148
 
149
  initial_info = "Processing complete! Click on the highlighted 'K' site in the sequence below to see its prediction."
150
 
151
  return highlight_data, predictions_map, initial_info
152
 
153
+ # --- 6. 修改重点:Gradio事件处理函数 ---
154
  def show_results_for_site(evt: gr.SelectData, state_data):
155
  """
156
+ 当用户点击高亮的K时触发。
157
+ 此处我们将结果格式化为DataFrame,并精确控制百分比格式。
158
  """
159
  if evt.value:
 
 
 
 
 
160
 
161
+ k_index_str = evt.value[1]
162
+ try:
163
+ k_index = int(k_index_str)
164
+ except ValueError:
165
+ return None, "Invalid selection."
166
 
167
+
168
  result_dict = state_data.get(k_index)
169
 
170
  if result_dict:
171
+ site_info = f"Prediction results for 'K' at position {k_index + 1}:"
172
+
173
+ # --- 修改开始:构建详细的表格数据 ---
174
+ table_data = []
175
+ for label, score in result_dict.items():
176
+ # 使用 f-string 的 :.2% 语法,将 0.9299 转换为 92.99%
177
+ percentage_str = f"{score:.2%}"
178
+ table_data.append([label, percentage_str])
 
 
 
 
 
 
179
 
180
+ # 创建 Pandas DataFrame
181
+ df_result = pd.DataFrame(table_data, columns=["Modification Type", "Probability"])
182
+ # --- 修改结束 ---
183
+
184
+ return df_result, site_info
185
+
186
+
187
  return None, "Please click on the highlighted 'K' site in the sequence above to view the results."
188
 
189
 
190
+ # --- 7. 创建并启动 Gradio 界面 ---
191
  fasta_example = """>sp|P05141|ADT2_HUMAN ADP/ATP translocase 2 OS=Homo sapiens OX=9606 GN=SLC25A5 PE=1 SV=7
192
  MTDAAVSFAKDFLAGGVAAAISKTAVAPIERVKLLLQVQHASKQITADKQYKGIIDCVVR
193
  IPKEQGVLSFWRGNLANVIRYFPTQALNFAFKDKYKQIFLGGVDKRTQFWLYFAGNLASG
194
+ GAAGATSLCFVYPLDFARTRLAADVGKAGAEREFRGLGDCLVKIYKSDGIKGLYQGFNVS
195
+ VQGIIIYRAAYFGIYDTAKGMLPDPKNTHIVISWMIAQTVTAVAGLTSYPFDTVRRRMMM
196
+ QSGRKGTDIMYTGTLDCWRKIARDEGGKAFFKGAWSNVLRGMGGAFVLVLYDEIKKYT"""
197
+
198
+ with gr.Blocks(css=".predictable-k {color: red; font-weight: bold;}") as demo:
199
  gr.Markdown(
200
  """
201
  # DeepKMulti Model: Multi-label Classifier for Lysine Modifications
202
+ **Supports FASTA format input, allowing interactive viewing of the modification possibilities of each lysine site.**
203
  """
204
  )
205
  with gr.Row():
206
+ with gr.Column(scale=2):
207
  fasta_input = gr.Textbox(
208
  lines=10,
209
  label="Input FASTA format protein sequence",
210
+ placeholder="Paste your FASTA sequence here..."
211
  )
212
  submit_btn = gr.Button("Submit Prediction", variant="primary")
213
 
214
  with gr.Column(scale=3):
215
  gr.Markdown("### Prediction Results")
216
  info_text = gr.Textbox(label="State", interactive=False, value="Waiting for input...")
 
 
 
 
 
 
 
 
 
 
217
 
218
+ predictions_state = gr.State({})
219
+
220
+ # --- 修改重点:将 gr.Label 替换为 gr.DataFrame ---
221
+ results_output = gr.DataFrame(
222
+ headers=["Modification Type", "Probability"],
223
+ datatype=["str", "str"],
224
+ label="Detailed Probabilities",
225
+ interactive=False
226
+ )
227
+ # ------------------------------------------------
228
 
229
  gr.Markdown("---")
230
  gr.Markdown("### Visualized Sequence")
231
+
232
  highlighted_output = gr.HighlightedText(
233
  label="Sequence Analysis",
234
+ color_map={"predictable-k": "red"},
235
+
236
  )
237
 
238
  gr.Examples(
239
+ examples=[[fasta_example]],
240
+ inputs=fasta_input,
241
+ label="Example sequence"
242
+ )
243
+
244
+ # --- 设定事件逻辑 ---
245
+ submit_btn.click(
246
+ fn=process_fasta_and_predict,
247
+ inputs=fasta_input,
248
+ outputs=[highlighted_output, predictions_state, info_text]
249
+ )
250
+
251
+ highlighted_output.select(
252
+ fn=show_results_for_site,
253
+ inputs=[predictions_state],
254
  outputs=[results_output, info_text]
255
  )
256
 
257
+ demo.launch()