Xianfish9 committed on
Commit
2020c05
·
verified ·
1 Parent(s): c027c6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -107
app.py CHANGED
@@ -1,229 +1,207 @@
1
- #test
2
- import gradio as gr
3
- import torch
4
  import numpy as np
5
  import os
6
  import re
7
- import pandas as pd # --- 新增:引入 pandas 用于处理表格数据 ---
8
 
9
  # --- 依赖导入 ---
10
- # 请确保 model.py, Feature_extraction_algorithms 文件夹在同一目录下
11
  from model import CAFN
12
  from Feature_extraction_algorithms.PSTAAP import PSTAAP_feature, load_precomputed_fr_matrix
13
  from Feature_extraction_algorithms.Physicochemical import PC_feature
14
-
15
- # --- 1. 模型加载 (与之前相同) ---
16
- MODEL_PATH = "Adam_lr7e-05_weightdecay0.0001_epochs3480.pth"
17
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18
-
19
- def load_model(model_path):
20
- model = CAFN().to(device)
21
- if os.path.exists(model_path):
22
- model.load_state_dict(torch.load(model_path, map_location=device))
23
- model.eval()
24
- print("模型加载成功!")
25
- return model
26
- else:
27
- print(f"错误:在路径 {model_path} 未找到模型文件")
28
- return None
29
-
30
- model = load_model(MODEL_PATH)
31
-
32
- # --- 2. PSTAAP 特征提取器初始化 (与之前相同) ---
33
  try:
34
  FR_MATRIX_PATH = 'Fr_train.mat'
35
  if not os.path.exists(FR_MATRIX_PATH):
36
- # 如果是本地运行且文件确实存在,请忽略此模拟错误;
37
- # 这里为了防止代码报错,如果文件不存在可以仅打印警告
38
- print(f"警告:找不到矩阵文件 {FR_MATRIX_PATH},如果是测试环境请忽略。")
39
- else:
40
- load_precomputed_fr_matrix(FR_MATRIX_PATH)
41
  except Exception as e:
42
  print(f"PSTAAP 初始化过程中发生严重错误: {e}")
43
- # model = None # 暂时注释掉,以免本地测试时因为缺文件直接无法运行
44
 
45
 
46
  # --- 3. 特征提取函数 (与之前相同) ---
47
- def extract_features_from_seq(sequence_list):
48
- data2 = PC_feature(sequence_list)
49
- N = len(sequence_list)
50
- empty_list_array = [[] for _ in range(N)]
51
- data = np.array(empty_list_array, dtype=object)
52
- feature = PSTAAP_feature(sequence_list)
53
  data = np.hstack((data, feature))
54
  return data.astype(np.float32), data2.astype(np.float32)
55
 
56
- # --- 4. 核心预测函数 (保持不变,依旧返回浮点数) ---
57
  def predict_single_49mer(sequence_49mer):
58
  """
59
  对单个、长度为49的序列片段进行预测。
 
60
  """
61
  if model is None:
 
62
  print("错误:模型核心未加载。")
63
  return None
64
 
65
  sequence_list = [sequence_49mer]
66
- # 注意:如果缺少依赖文件,这里可能会报错,请确保环境完整
67
- try:
68
- x1_np, x2_np = extract_features_from_seq(sequence_list)
69
- except Exception as e:
70
- print(f"特征提取失败: {e}")
71
- return None
72
 
73
  tensor_x1 = torch.tensor(x1_np).to(device)
74
  tensor_x2 = torch.tensor(x2_np).to(device)
75
-
76
- with torch.no_grad():
77
  outputs = model(tensor_x1, tensor_x2)
78
 
79
  probabilities = torch.sigmoid(outputs).squeeze().cpu().numpy()
80
-
81
  labels = ["Lysine-Acetyllysine (K-Ac)", "Lysine-Crotonyllysine (K-Cr)", "Lysine-Methyllysine (K-Me)", "Lysine-Succinyllysine (K-Succ)"]
82
-
83
- # 这里保持返回原始 float 数据,方便后续处理
84
  result = {label: float(prob) for label, prob in zip(labels, probabilities)}
85
 
86
  return result
87
 
88
- # --- 5. FASTA格式解析与主处理流程 (与之前相同) ---
89
  def parse_fasta(fasta_string):
 
 
90
  sequence_lines = [line for line in fasta_string.splitlines() if not line.startswith('>')]
 
91
  return "".join(sequence_lines).replace(" ", "").replace("\n", "").upper()
92
 
93
  def process_fasta_and_predict(fasta_input):
 
 
 
 
94
  if not fasta_input or not isinstance(fasta_input, str):
95
  raise gr.Error("Please enter a valid FASTA format sequence.")
96
 
97
  sequence = parse_fasta(fasta_input)
98
 
99
  if len(sequence) < 49:
100
- raise gr.Error(f"The sequence is too short! It needs to be at least 49 amino acids. The current length is {len(sequence)}.")
101
 
 
102
  predictions_map = {}
 
 
103
  k_indices = [m.start() for m in re.finditer('K', sequence)]
104
 
105
  for k_index in k_indices:
 
106
  start, end = k_index - 24, k_index + 25
 
 
107
  if start >= 0 and end <= len(sequence):
108
  fragment = sequence[start:end]
109
  prediction_result = predict_single_49mer(fragment)
110
  if prediction_result:
 
111
  predictions_map[k_index] = prediction_result
112
 
113
  if not predictions_map:
114
- return [(sequence, None)], {}, "No valid K sites were found in the sequence for prediction."
 
115
 
 
116
  highlight_data = []
117
  last_pos = 0
 
118
  sorted_predictable_indices = sorted(predictions_map.keys())
119
 
120
  for k_index in sorted_predictable_indices:
 
121
  highlight_data.append((sequence[last_pos:k_index], None))
 
122
  highlight_data.append(("K", str(k_index)))
123
  last_pos = k_index + 1
124
 
 
125
  highlight_data.append((sequence[last_pos:], None))
126
 
127
  initial_info = "Processing complete! Click on the highlighted 'K' site in the sequence below to see its prediction."
128
 
129
  return highlight_data, predictions_map, initial_info
130
 
131
- # --- 6. 修改重点:Gradio事件处理函数 ---
132
  def show_results_for_site(evt: gr.SelectData, state_data):
133
  """
134
- 当用户点击高亮的K时触发。
135
- 此处我们将结果格式化为DataFrame,并精确控制百分比格式。
136
  """
137
  if evt.value:
 
138
  k_index_str = evt.value[1]
139
- try:
140
- k_index = int(k_index_str)
141
- except ValueError:
142
- return None, "Invalid selection."
143
 
 
144
  result_dict = state_data.get(k_index)
145
 
146
  if result_dict:
147
- site_info = f"Prediction results for 'K' at position {k_index + 1}:"
148
-
149
- # --- 修改开始:构建详细的表格数据 ---
150
- table_data = []
151
- for label, score in result_dict.items():
152
- # 使用 f-string 的 :.2% 语法,将 0.9299 转换为 92.99%
153
- percentage_str = f"{score:.2%}"
154
- table_data.append([label, percentage_str])
155
-
156
- # 创建 Pandas DataFrame
157
- df_result = pd.DataFrame(table_data, columns=["Modification Type", "Probability"])
158
- # --- 修改结束 ---
159
-
160
- return df_result, site_info
161
 
 
162
  return None, "Please click on the highlighted 'K' site in the sequence above to view the results."
163
 
164
 
165
- # --- 7. 创建并启动 Gradio 界面 ---
166
  fasta_example = """>sp|P05141|ADT2_HUMAN ADP/ATP translocase 2 OS=Homo sapiens OX=9606 GN=SLC25A5 PE=1 SV=7
167
  MTDAAVSFAKDFLAGGVAAAISKTAVAPIERVKLLLQVQHASKQITADKQYKGIIDCVVR
168
  IPKEQGVLSFWRGNLANVIRYFPTQALNFAFKDKYKQIFLGGVDKRTQFWLYFAGNLASG
169
- GAAGATSLCFVYPLDFARTRLAADVGKAGAEREFRGLGDCLVKIYKSDGIKGLYQGFNVS
170
- VQGIIIYRAAYFGIYDTAKGMLPDPKNTHIVISWMIAQTVTAVAGLTSYPFDTVRRRMMM
171
- QSGRKGTDIMYTGTLDCWRKIARDEGGKAFFKGAWSNVLRGMGGAFVLVLYDEIKKYT"""
172
-
173
- with gr.Blocks(css=".predictable-k {color: red; font-weight: bold;}") as demo:
174
  gr.Markdown(
175
  """
176
  # DeepKMulti Model: Multi-label Classifier for Lysine Modifications
177
- **Supports FASTA format input, allowing interactive viewing of the modification possibilities of each lysine site.**
178
  """
179
  )
180
  with gr.Row():
181
- with gr.Column(scale=2):
182
  fasta_input = gr.Textbox(
183
  lines=10,
184
  label="Input FASTA format protein sequence",
185
- placeholder="Paste your FASTA sequence here..."
186
  )
187
  submit_btn = gr.Button("Submit Prediction", variant="primary")
188
 
189
  with gr.Column(scale=3):
190
  gr.Markdown("### Prediction Results")
191
  info_text = gr.Textbox(label="State", interactive=False, value="Waiting for input...")
 
192
  predictions_state = gr.State({})
193
-
194
- # --- 修改重点:将 gr.Label 替换为 gr.DataFrame ---
195
- results_output = gr.DataFrame(
196
- headers=["Modification Type", "Probability"],
197
- datatype=["str", "str"],
198
- label="Detailed Probabilities",
199
- interactive=False
200
- )
201
- # ------------------------------------------------
202
 
203
  gr.Markdown("---")
204
  gr.Markdown("### Visualized Sequence")
 
205
  highlighted_output = gr.HighlightedText(
206
  label="Sequence Analysis",
207
- color_map={"predictable-k": "red"},
 
208
  )
209
 
210
  gr.Examples(
211
- examples=[[fasta_example]],
212
- inputs=fasta_input,
213
- label="Example sequence"
214
- )
215
-
216
- # --- 设定事件逻辑 ---
217
- submit_btn.click(
218
- fn=process_fasta_and_predict,
219
- inputs=fasta_input,
220
- outputs=[highlighted_output, predictions_state, info_text]
221
- )
222
-
223
- highlighted_output.select(
224
- fn=show_results_for_site,
225
- inputs=[predictions_state],
226
  outputs=[results_output, info_text]
227
  )
228
 
229
- demo.launch()
 
 
 
 
 
1
  import numpy as np
2
  import os
3
  import re
4
+
5
 
6
  # --- 依赖导入 ---
7
+
8
  from model import CAFN
9
  from Feature_extraction_algorithms.PSTAAP import PSTAAP_feature, load_precomputed_fr_matrix
10
  from Feature_extraction_algorithms.Physicochemical import PC_feature
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  try:
12
  FR_MATRIX_PATH = 'Fr_train.mat'
13
  if not os.path.exists(FR_MATRIX_PATH):
14
+ raise FileNotFoundError(f"PSTAAP初始化失败:找不到矩阵文件 {FR_MATRIX_PATH}")
15
+ load_precomputed_fr_matrix(FR_MATRIX_PATH)
16
+
17
+
18
+
19
  except Exception as e:
20
  print(f"PSTAAP 初始化过程中发生严重错误: {e}")
21
+ model = None
22
 
23
 
24
  # --- 3. 特征提取函数 (与之前相同) ---
 
 
 
 
 
 
25
  data = np.hstack((data, feature))
26
  return data.astype(np.float32), data2.astype(np.float32)
27
 
28
+ # --- 4. 核心预测函数 (重构为处理单个49-mer片段) ---
29
  def predict_single_49mer(sequence_49mer):
30
  """
31
  对单个、长度为49的序列片段进行预测。
32
+ 这是底层的预测引擎。
33
  """
34
  if model is None:
35
+ # 这个错误不应该在UI层面抛出,而是在后台日志中记录
36
  print("错误:模型核心未加载。")
37
  return None
38
 
39
  sequence_list = [sequence_49mer]
40
+ x1_np, x2_np = extract_features_from_seq(sequence_list)
41
+
42
+
43
+
44
+
45
+
46
 
47
  tensor_x1 = torch.tensor(x1_np).to(device)
48
  tensor_x2 = torch.tensor(x2_np).to(device)
 
 
49
  outputs = model(tensor_x1, tensor_x2)
50
 
51
  probabilities = torch.sigmoid(outputs).squeeze().cpu().numpy()
52
+ #Lysine-Acetylation(K-Ac)
53
  labels = ["Lysine-Acetyllysine (K-Ac)", "Lysine-Crotonyllysine (K-Cr)", "Lysine-Methyllysine (K-Me)", "Lysine-Succinyllysine (K-Succ)"]
54
+
55
+
56
  result = {label: float(prob) for label, prob in zip(labels, probabilities)}
57
 
58
  return result
59
 
60
+ # --- 5. 新增:FASTA格式解析与主处理流程 ---
61
  def parse_fasta(fasta_string):
62
+ """从FASTA格式文本中提取序列。"""
63
+ # 移除FASTA头(以'>'开头的行)
64
  sequence_lines = [line for line in fasta_string.splitlines() if not line.startswith('>')]
65
+ # 连接所有行并移除任何空白字符
66
  return "".join(sequence_lines).replace(" ", "").replace("\n", "").upper()
67
 
68
  def process_fasta_and_predict(fasta_input):
69
+ """
70
+ 接收FASTA输入,找到所有K位点,进行切片和预测,
71
+ 并返回用于Gradio HighlightedText组件的数据和一个包含预测结果的状态字典。
72
+ """
73
  if not fasta_input or not isinstance(fasta_input, str):
74
  raise gr.Error("Please enter a valid FASTA format sequence.")
75
 
76
  sequence = parse_fasta(fasta_input)
77
 
78
  if len(sequence) < 49:
79
+ raise gr.Error(f"The sequence is too short! It needs to be at least 49 amino acids. The current length is {len(sequence)}")
80
 
81
+ # 存储每个可预测K位点(索引)及其预测结果
82
  predictions_map = {}
83
+
84
+ # 寻找所有 'K' 的索引
85
  k_indices = [m.start() for m in re.finditer('K', sequence)]
86
 
87
  for k_index in k_indices:
88
+ # 尝试以K为中心截取片段 (K前24个, K, K后24个)
89
  start, end = k_index - 24, k_index + 25
90
+
91
+ # 边界检查,如果长度不足49则跳过
92
  if start >= 0 and end <= len(sequence):
93
  fragment = sequence[start:end]
94
  prediction_result = predict_single_49mer(fragment)
95
  if prediction_result:
96
+ # 使用K的原始索引作为键
97
  predictions_map[k_index] = prediction_result
98
 
99
  if not predictions_map:
100
+ # 如果没有一个K位点可以被成功预测
101
+ return [(sequence, None)], {}, "No valid K sites were found in the sequence for prediction (i.e., there were not enough amino acids before and after K)."
102
 
103
+ # --- 构建Gradio HighlightedText的输入格式 ---
104
  highlight_data = []
105
  last_pos = 0
106
+ # 按索引排序,确保我们按顺序处理序列
107
  sorted_predictable_indices = sorted(predictions_map.keys())
108
 
109
  for k_index in sorted_predictable_indices:
110
+ # 添加K之前未高亮的部分
111
  highlight_data.append((sequence[last_pos:k_index], None))
112
+ # 添加需要高亮的K,并用其索引作为标签
113
  highlight_data.append(("K", str(k_index)))
114
  last_pos = k_index + 1
115
 
116
+ # 添加最后一个K之后剩余的部分
117
  highlight_data.append((sequence[last_pos:], None))
118
 
119
  initial_info = "Processing complete! Click on the highlighted 'K' site in the sequence below to see its prediction."
120
 
121
  return highlight_data, predictions_map, initial_info
122
 
123
+ # --- 6. 新增:Gradio事件处理函数 ---
124
  def show_results_for_site(evt: gr.SelectData, state_data):
125
  """
126
+ 当用户点击高亮的K时,此函数被触发。
127
+ 它从state_data中查找并返回该位点的预测结果。
128
  """
129
  if evt.value:
130
+ # evt.value 是 ('K', '索引字符串')
131
  k_index_str = evt.value[1]
132
+ k_index = int(k_index_str)
133
+
134
+
135
+
136
 
137
+ # 从状态字典中获取结果
138
  result_dict = state_data.get(k_index)
139
 
140
  if result_dict:
141
+ site_info = f"Prediction results for the segment centered at 'K' at position {k_index + 1}:"
142
+ return result_dict, site_info
143
+
144
+
145
+
146
+
147
+
148
+
149
+
150
+
151
+
152
+
153
+
154
+
155
 
156
+ # 如果没有选择或出现错误
157
  return None, "Please click on the highlighted 'K' site in the sequence above to view the results."
158
 
159
 
160
+ # --- 7. 创建并启动 Gradio 界面 (使用 gr.Blocks) ---
161
  fasta_example = """>sp|P05141|ADT2_HUMAN ADP/ATP translocase 2 OS=Homo sapiens OX=9606 GN=SLC25A5 PE=1 SV=7
162
  MTDAAVSFAKDFLAGGVAAAISKTAVAPIERVKLLLQVQHASKQITADKQYKGIIDCVVR
163
  IPKEQGVLSFWRGNLANVIRYFPTQALNFAFKDKYKQIFLGGVDKRTQFWLYFAGNLASG
 
 
 
 
 
164
  gr.Markdown(
165
  """
166
  # DeepKMulti Model: Multi-label Classifier for Lysine Modifications
167
+ **Supports FASTA format input, allowing interactive viewing of the modification possibilities of each lysine site in the protein sequence.**
168
  """
169
  )
170
  with gr.Row():
 
171
  fasta_input = gr.Textbox(
172
  lines=10,
173
  label="Input FASTA format protein sequence",
174
+ placeholder="Please paste your FASTA formatted sequence here (we provide an example sequence below)..."
175
  )
176
  submit_btn = gr.Button("Submit Prediction", variant="primary")
177
 
178
  with gr.Column(scale=3):
179
  gr.Markdown("### Prediction Results")
180
  info_text = gr.Textbox(label="State", interactive=False, value="Waiting for input...")
181
+ # 用于存储所有位点的预测结果,对用户不可见
182
  predictions_state = gr.State({})
183
+ results_output = gr.Label(num_top_classes=4, label="After clicking on the colored 'K' site, the results will be displayed here")
184
+
185
+
186
+
187
+
188
+
189
+
190
+
191
+
192
 
193
  gr.Markdown("---")
194
  gr.Markdown("### Visualized Sequence")
195
+ # 使用 a[class='predictable-k'] 来应用CSS
196
  highlighted_output = gr.HighlightedText(
197
  label="Sequence Analysis",
198
+ color_map={"predictable-k": "red"}, # 旧版Gradio的用法
199
+ # 在新版Gradio中,CSS通过gr.Blocks的css参数全局定义更可靠
200
  )
201
 
202
  gr.Examples(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  outputs=[results_output, info_text]
204
  )
205
 
206
+ # 启动应用
207
+ demo.launch(debug=True)