Xianfish9 committed on
Commit
0bea2f5
·
verified ·
1 Parent(s): 75c3a94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -105
app.py CHANGED
@@ -1,207 +1,195 @@
1
  import numpy as np
2
  import os
3
  import re
4
-
 
5
 
6
  # --- 依赖导入 ---
7
-
8
  from model import CAFN
9
  from Feature_extraction_algorithms.PSTAAP import PSTAAP_feature, load_precomputed_fr_matrix
10
  from Feature_extraction_algorithms.Physicochemical import PC_feature
 
 
 
 
 
 
11
  try:
12
  FR_MATRIX_PATH = 'Fr_train.mat'
13
  if not os.path.exists(FR_MATRIX_PATH):
14
- raise FileNotFoundError(f"PSTAAP初始化失败:找不到矩阵文件 {FR_MATRIX_PATH}")
15
- load_precomputed_fr_matrix(FR_MATRIX_PATH)
16
-
17
-
 
 
 
 
 
18
 
19
  except Exception as e:
20
- print(f"PSTAAP 初始化过程中发生严重错误: {e}")
21
- model = None
22
-
23
 
24
- # --- 3. 特征提取函数 (与之前相同) ---
25
- data = np.hstack((data, feature))
26
- return data.astype(np.float32), data2.astype(np.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- # --- 4. 核心预测函数 (重构为处理单个49-mer片段) ---
29
  def predict_single_49mer(sequence_49mer):
30
  """
31
  对单个、长度为49的序列片段进行预测。
32
- 这是底层的预测引擎。
33
  """
 
34
  if model is None:
35
- # 这个错误不应该在UI层面抛出,而是在后台日志中记录
36
- print("错误:模型核心未加载。")
37
- return None
38
 
39
  sequence_list = [sequence_49mer]
40
  x1_np, x2_np = extract_features_from_seq(sequence_list)
41
-
42
-
43
-
44
-
45
-
46
 
47
  tensor_x1 = torch.tensor(x1_np).to(device)
48
  tensor_x2 = torch.tensor(x2_np).to(device)
49
- outputs = model(tensor_x1, tensor_x2)
50
 
51
- probabilities = torch.sigmoid(outputs).squeeze().cpu().numpy()
52
- #Lysine-Acetylation(K-Ac)
 
 
53
  labels = ["Lysine-Acetyllysine (K-Ac)", "Lysine-Crotonyllysine (K-Cr)", "Lysine-Methyllysine (K-Me)", "Lysine-Succinyllysine (K-Succ)"]
54
-
55
-
56
- result = {label: float(prob) for label, prob in zip(labels, probabilities)}
57
 
 
 
 
 
 
58
  return result
59
 
60
- # --- 5. 新增:FASTA格式解析与主处理流程 ---
61
  def parse_fasta(fasta_string):
62
- """从FASTA格式文本中提取序列。"""
63
- # 移除FASTA头(以'>'开头的行)
64
  sequence_lines = [line for line in fasta_string.splitlines() if not line.startswith('>')]
65
- # 连接所有行并移除任何空白字符
66
  return "".join(sequence_lines).replace(" ", "").replace("\n", "").upper()
67
 
68
  def process_fasta_and_predict(fasta_input):
69
- """
70
- 接收FASTA输入,找到所有K位点,进行切片和预测,
71
- 并返回用于Gradio HighlightedText组件的数据和一个包含预测结果的状态字典。
72
- """
73
  if not fasta_input or not isinstance(fasta_input, str):
74
  raise gr.Error("Please enter a valid FASTA format sequence.")
75
 
76
  sequence = parse_fasta(fasta_input)
77
 
78
  if len(sequence) < 49:
79
- raise gr.Error(f"The sequence is too short! It needs to be at least 49 amino acids. The current length is {len(sequence)}")
80
 
81
- # 存储每个可预测K位点(索引)及其预测结果
82
  predictions_map = {}
83
-
84
- # 寻找所有 'K' 的索引
85
  k_indices = [m.start() for m in re.finditer('K', sequence)]
86
 
87
  for k_index in k_indices:
88
- # 尝试以K为中心截取片段 (K前24个, K, K后24个)
89
  start, end = k_index - 24, k_index + 25
90
-
91
- # 边界检查,如果长度不足49则跳过
92
  if start >= 0 and end <= len(sequence):
93
  fragment = sequence[start:end]
94
  prediction_result = predict_single_49mer(fragment)
95
  if prediction_result:
96
- # 使用K的原始索引作为键
97
  predictions_map[k_index] = prediction_result
98
 
99
  if not predictions_map:
100
- # 如果没有一个K位点可以被成功预测
101
- return [(sequence, None)], {}, "No valid K sites were found in the sequence for prediction (i.e., there were not enough amino acids before and after K)."
102
 
103
- # --- 构建Gradio HighlightedText的输入格式 ---
104
  highlight_data = []
105
  last_pos = 0
106
- # 按索引排序,确保我们按顺序处理序列
107
  sorted_predictable_indices = sorted(predictions_map.keys())
108
 
109
  for k_index in sorted_predictable_indices:
110
- # 添加K之前未高亮的部分
111
  highlight_data.append((sequence[last_pos:k_index], None))
112
- # 添加需要高亮的K,并用其索引作为标签
113
  highlight_data.append(("K", str(k_index)))
114
  last_pos = k_index + 1
115
 
116
- # 添加最后一个K之后剩余的部分
117
  highlight_data.append((sequence[last_pos:], None))
118
 
119
- initial_info = "Processing complete! Click on the highlighted 'K' site in the sequence below to see its prediction."
120
-
121
  return highlight_data, predictions_map, initial_info
122
 
123
- # --- 6. 新增:Gradio事件处理函数 ---
124
  def show_results_for_site(evt: gr.SelectData, state_data):
125
- """
126
- 当用户点击高亮的K时,此函数被触发。
127
- 它从state_data中查找并返回该位点的预测结果。
128
- """
129
  if evt.value:
130
- # evt.value ('K', '索引字符串')
131
- k_index_str = evt.value[1]
 
 
132
  k_index = int(k_index_str)
133
-
134
-
135
-
136
-
137
- # 从状态字典中获取结果
138
  result_dict = state_data.get(k_index)
139
 
140
  if result_dict:
141
- site_info = f"Prediction results for the segment centered at 'K' at position {k_index + 1}:"
142
  return result_dict, site_info
143
-
144
-
145
-
146
-
147
-
148
-
149
-
150
-
151
-
152
-
153
-
154
-
155
 
156
- # 如果没有选择或出现错误
157
- return None, "Please click on the highlighted 'K' site in the sequence above to view the results."
158
 
 
159
 
160
- # --- 7. 创建并启动 Gradio 界面 (使用 gr.Blocks) ---
161
  fasta_example = """>sp|P05141|ADT2_HUMAN ADP/ATP translocase 2 OS=Homo sapiens OX=9606 GN=SLC25A5 PE=1 SV=7
162
  MTDAAVSFAKDFLAGGVAAAISKTAVAPIERVKLLLQVQHASKQITADKQYKGIIDCVVR
163
- IPKEQGVLSFWRGNLANVIRYFPTQALNFAFKDKYKQIFLGGVDKRTQFWLYFAGNLASG
 
 
 
164
  gr.Markdown(
165
  """
166
  # DeepKMulti Model: Multi-label Classifier for Lysine Modifications
167
- # **Supports FASTA format input, allowing interactive viewing of the modification possibilities of each lysine site in the protein sequence.**
168
  """
169
  )
 
170
  with gr.Row():
 
171
  fasta_input = gr.Textbox(
172
  lines=10,
173
- label="Input FASTA format protein sequence",
174
- placeholder="Please paste your FASTA formatted sequence here (we provide an example sequence below)..."
 
175
  )
176
  submit_btn = gr.Button("Submit Prediction", variant="primary")
177
-
178
  with gr.Column(scale=3):
179
  gr.Markdown("### Prediction Results")
180
- info_text = gr.Textbox(label="State", interactive=False, value="Waiting for input...")
181
- # 用于��储所有位点的预测结果,对用户不可见
182
  predictions_state = gr.State({})
183
- results_output = gr.Label(num_top_classes=4, label="After clicking on the colored 'K' site, the results will be displayed here")
184
-
185
-
186
-
187
-
188
-
189
-
190
-
191
-
192
-
193
  gr.Markdown("---")
194
- gr.Markdown("### Visualized Sequence")
195
- # 使用 a[class='predictable-k'] 来应用CSS
196
  highlighted_output = gr.HighlightedText(
197
  label="Sequence Analysis",
198
- color_map={"predictable-k": "red"}, # 旧版Gradio的用法
199
- # 在新版Gradio中,CSS通过gr.Blocks的css参数全局定义更可靠
 
 
 
 
 
 
 
 
200
  )
201
 
202
- gr.Examples(
 
 
203
  outputs=[results_output, info_text]
204
  )
205
 
206
- # 启动应用
207
- demo.launch(debug=True)
 
1
  import numpy as np
2
  import os
3
  import re
4
+ import torch
5
+ import gradio as gr # 必须导入 gradio
6
 
7
  # --- 依赖导入 ---
8
+ # 请确保这些文件存在,否则会报错
9
  from model import CAFN
10
  from Feature_extraction_algorithms.PSTAAP import PSTAAP_feature, load_precomputed_fr_matrix
11
  from Feature_extraction_algorithms.Physicochemical import PC_feature
12
+
13
# Select the compute device: CUDA when it is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- 1. Model initialisation ---
# `model` stays None unless the commented-out loading code below is enabled;
# predict_single_49mer checks for None before attempting inference.
model = None
try:
    FR_MATRIX_PATH = 'Fr_train.mat'
    if not os.path.exists(FR_MATRIX_PATH):
        # Demo-friendly behaviour: warn instead of aborting start-up.
        # For a real run, make sure the matrix file is present.
        print(f"警告:找不到矩阵文件 {FR_MATRIX_PATH}")
    else:
        load_precomputed_fr_matrix(FR_MATRIX_PATH)

    # NOTE(review): the model weights presumably still need to be loaded here;
    # the file name below is a guess left by the author -- confirm before enabling.
    # model = CAFN().to(device)
    # model.load_state_dict(torch.load('model_weights.pth'))
    # model.eval()

except Exception as e:
    # Any failure during initialisation is only logged; the app still starts.
    print(f"初始化过程中发生错误: {e}")
 
 
33
 
34
+ # --- 3. 特征提取函数 (修复了你代码中缺失的函数定义头) ---
35
def extract_features_from_seq(sequence_list):
    """Return two placeholder feature arrays for a batch of sequences.

    This is a stand-in for the real feature extraction: it does not look at
    the sequence contents at all and only uses the batch size.

    Args:
        sequence_list: Sequence fragments; only ``len(sequence_list)`` is used.

    Returns:
        A pair of ``float32`` arrays of random values, each shaped
        ``(len(sequence_list), 1, 49)``.
    """
    # TODO: replace the random data with the real feature pipeline
    # (presumably PC_feature / PSTAAP_feature -- confirm the expected
    # shapes against the model's inputs).
    placeholder_shape = (len(sequence_list), 1, 49)
    first_features = np.random.rand(*placeholder_shape).astype(np.float32)
    second_features = np.random.rand(*placeholder_shape).astype(np.float32)
    return first_features, second_features
50
 
51
+ # --- 4. 核心预测函数 ---
52
def predict_single_49mer(sequence_49mer):
    """Predict modification probabilities for one 49-residue fragment.

    Args:
        sequence_49mer: The sequence fragment to score.

    Returns:
        A dict mapping each of the four modification labels to a float.
        When the module-level ``model`` is None, the values are random
        numbers meant only for exercising the UI.
    """
    site_labels = [
        "Lysine-Acetyllysine (K-Ac)",
        "Lysine-Crotonyllysine (K-Cr)",
        "Lysine-Methyllysine (K-Me)",
        "Lysine-Succinyllysine (K-Succ)",
    ]

    # No model loaded: hand back random scores so the interface stays usable.
    if model is None:
        print("警告:模型未加载,返回随机结果用于演示。")
        demo_scores = {}
        for name in site_labels:
            demo_scores[name] = float(np.random.rand())
        return demo_scores

    features_a, features_b = extract_features_from_seq([sequence_49mer])
    input_a = torch.tensor(features_a).to(device)
    input_b = torch.tensor(features_b).to(device)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        logits = model(input_a, input_b)
        scores = torch.sigmoid(logits).squeeze().cpu().numpy()

    # squeeze() can collapse the output to a 0-d array; wrap it so that
    # zip() below can still iterate over it.
    if scores.ndim == 0:
        scores = [scores]

    return {name: float(score) for name, score in zip(site_labels, scores)}
80
 
81
+ # --- 5. FASTA格式解析与主处理流程 ---
82
def parse_fasta(fasta_string):
    """Extract the residue sequence from FASTA-formatted text.

    Header lines (those whose first non-blank character is ``>``) are
    dropped; all remaining lines are concatenated, every whitespace
    character is removed, and the result is upper-cased. If the text holds
    several records, their sequences are concatenated.

    Args:
        fasta_string: Raw text, with or without FASTA header lines.

    Returns:
        The upper-case sequence with no whitespace.
    """
    residue_lines = [
        line
        for line in fasta_string.splitlines()
        # lstrip() so a header pasted with leading blanks is still skipped.
        if not line.lstrip().startswith('>')
    ]
    # str.split() with no argument splits on any run of whitespace, so tabs
    # and other stray blanks are removed too; otherwise they would end up in
    # the sequence and shift every downstream residue index.
    return "".join("".join(residue_lines).split()).upper()
85
 
86
def process_fasta_and_predict(fasta_input):
    """Score every lysine that has a full 49-residue window around it.

    Args:
        fasta_input: FASTA text entered by the user.

    Returns:
        A 3-tuple ``(highlight_data, predictions_map, message)``:
        ``highlight_data`` is a list of ``(text, label)`` pairs where each
        scored 'K' is labelled with its 0-based index as a string and all
        other stretches have label None; ``predictions_map`` maps the
        0-based index of each scored 'K' to its prediction dict; ``message``
        is a status string.

    Raises:
        gr.Error: If the input is empty or not a string, or the parsed
            sequence is shorter than 49 residues.
    """
    if not (fasta_input and isinstance(fasta_input, str)):
        raise gr.Error("Please enter a valid FASTA format sequence.")

    sequence = parse_fasta(fasta_input)
    total_length = len(sequence)
    if total_length < 49:
        raise gr.Error(f"Sequence too short! Needs at least 49 AA. Current: {total_length}.")

    # A window is 24 residues on each side of the central K.
    flank = 24
    predictions_map = {}
    for hit in re.finditer('K', sequence):
        center = hit.start()
        window_start = center - flank
        window_end = center + flank + 1
        # Skip sites too close to either end to give a full window.
        if window_start < 0 or window_end > total_length:
            continue
        site_prediction = predict_single_49mer(sequence[window_start:window_end])
        if site_prediction:
            predictions_map[center] = site_prediction

    if not predictions_map:
        return [(sequence, None)], {}, "No valid K sites (with enough context) found."

    # Alternate unlabelled stretches with the labelled K sites, left to right.
    highlight_data = []
    cursor = 0
    for center in sorted(predictions_map):
        highlight_data.append((sequence[cursor:center], None))
        highlight_data.append(("K", str(center)))
        cursor = center + 1
    highlight_data.append((sequence[cursor:], None))

    return (
        highlight_data,
        predictions_map,
        "Processing complete! Click on the highlighted 'K' site below.",
    )
122
 
123
+ # --- 6. Gradio事件处理函数 ---
124
def show_results_for_site(evt: gr.SelectData, state_data):
    """Look up the stored prediction for the span the user clicked.

    Args:
        evt: Selection event; ``evt.value[1]`` is read as the label of the
            clicked span.
        state_data: Mapping from integer site index to prediction dict.

    Returns:
        ``(prediction_dict, message)`` when the clicked label resolves to a
        stored prediction, otherwise ``(None, hint_message)``.
    """
    nothing_selected = (None, "Please click on the highlighted 'K' site.")

    selection = evt.value
    if not selection:
        return nothing_selected

    site_label = selection[1]
    # Unlabelled spans carry None as their label.
    if site_label is None:
        return None, "Please click on a highlighted 'K'."

    site_index = int(site_label)
    site_scores = state_data.get(site_index)
    if not site_scores:
        return nothing_selected

    # Positions are reported 1-based to the user.
    return site_scores, f"Prediction results for 'K' at position {site_index + 1}:"
 
138
 
139
+ # --- 7. 创建并启动 Gradio 界面 ---
140
 
141
# Example FASTA record; used below as the default content of the input box.
fasta_example = """>sp|P05141|ADT2_HUMAN ADP/ATP translocase 2 OS=Homo sapiens OX=9606 GN=SLC25A5 PE=1 SV=7
MTDAAVSFAKDFLAGGVAAAISKTAVAPIERVKLLLQVQHASKQITADKQYKGIIDCVVR
IPKEQGVLSFWRGNLANVIRYFPTQALNFAFKDKYKQIFLGGVDKRTQFWLYFAGNLASG"""

# All components and event bindings are declared inside the gr.Blocks context.
# NOTE(review): the `.predictable-k` CSS class is not referenced by any
# component in this file -- confirm whether the rule has any effect.
with gr.Blocks(css=".predictable-k { color: red; font-weight: bold; }") as demo:
    gr.Markdown(
        """
        # DeepKMulti Model: Multi-label Classifier for Lysine Modifications
        **Supports FASTA format input. Click on highlighted 'K' sites to view predictions.**
        """
    )

    with gr.Row():
        # Left column: sequence input and the submit button.
        with gr.Column(scale=2):
            fasta_input = gr.Textbox(
                lines=10,
                label="Input FASTA Sequence",
                value=fasta_example,  # pre-filled so the demo can be tried immediately
                placeholder="Paste FASTA sequence here..."
            )
            submit_btn = gr.Button("Submit Prediction", variant="primary")

        # Right column: status line and per-site probabilities.
        with gr.Column(scale=3):
            gr.Markdown("### Prediction Results")
            info_text = gr.Textbox(label="Status", interactive=False, value="Waiting for input...")
            # Receives the predictions map returned by process_fasta_and_predict.
            predictions_state = gr.State({})
            results_output = gr.Label(num_top_classes=4, label="Probabilities")

    gr.Markdown("---")
    gr.Markdown("### Visualized Sequence (Click 'K' to view details)")

    highlighted_output = gr.HighlightedText(
        label="Sequence Analysis",
        combine_adjacent=False,
        show_legend=False,
        # NOTE(review): process_fasta_and_predict labels each site with
        # str(k_index), never with "K", so this key looks like it will not
        # match any emitted label -- confirm the intended colouring.
        color_map={"K": "red"}
    )

    # Event wiring: submitting runs the predictions; selecting a span in the
    # highlighted sequence shows the stored result for that site.
    submit_btn.click(
        fn=process_fasta_and_predict,
        inputs=[fasta_input],
        outputs=[highlighted_output, predictions_state, info_text]
    )

    highlighted_output.select(
        fn=show_results_for_site,
        inputs=[predictions_state],
        outputs=[results_output, info_text]
    )

# Launch the app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch(debug=True)