Xianfish9 commited on
Commit
2009bd2
·
verified ·
1 Parent(s): 279b1d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -30
app.py CHANGED
@@ -1,27 +1,23 @@
1
- #Adam_lr7e-05_weightdecay0.0001_epochs3480.pth
2
  import gradio as gr
3
  import torch
4
  import numpy as np
5
  import os
6
- import re
7
 
8
  # --- 依赖导入 ---
9
- # 从你的代码库中导入必要的模块
10
- # 这要求你的文件结构是正确的 (例如: /Feature_extraction_algorithms/PSTAAP.py)
11
  from model import CAFN
12
- from Feature_extraction_algorithms.PSTAAP import PSTAAP_feature
 
13
  from Feature_extraction_algorithms.Physicochemical import PC_feature
14
 
15
  # --- 1. 模型加载 ---
16
- # 确保 'your_model_name.pth' 和你上传的文件名完全一致
17
- MODEL_PATH = "Adam_lr7e-05_weightdecay0.0001_epochs3480.pth" # <--- 在这里修改成你的 .pth 文件名
18
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
19
 
20
  def load_model(model_path):
21
  model = CAFN().to(device)
22
  if os.path.exists(model_path):
23
  model.load_state_dict(torch.load(model_path, map_location=device))
24
- model.eval() # 设置为评估模式
25
  print("模型加载成功!")
26
  return model
27
  else:
@@ -30,29 +26,21 @@ def load_model(model_path):
30
 
31
  model = load_model(MODEL_PATH)
32
 
33
- # --- 2. 特征提取函数 ---
34
- # 这个函数直接改编自你的 dataProcess.py
35
- # --- 2. 特征提取函数 (已修正) ---
36
- def extract_features_from_seq(sequence_list): # <--- 不再需要 test_PSTAAP 参数
37
- """
38
- 接收一个包含序列的列表,返回模型所需的两个特征张量 x1 和 x2。
39
- """
40
- # 提取 PC_feature (对应 x2)
41
- data2 = PC_feature(sequence_list)
42
-
43
- # 提取 PSTAAP_feature (对应 x1)
44
- N = len(sequence_list)
45
- empty_list_array = [[] for _ in range(N)]
46
- data = np.array(empty_list_array, dtype=object)
47
-
48
- # --- 这是修改的关键点 ---
49
- # 只传递一个参数给 PSTAAP_feature
50
- feature = PSTAAP_feature(sequence_list)
51
-
52
- data = np.hstack((data, feature))
53
 
54
- # 返回 NumPy 数组
55
- return data.astype(np.float32), data2.astype(np.float32)
 
 
 
 
 
 
 
56
 
57
  # --- 3. 核心预测函数 (也需要微调) ---
58
  def predict(sequence_input):
 
 
1
  import gradio as gr
2
  import torch
3
  import numpy as np
4
  import os
 
5
 
6
  # --- 依赖导入 ---
 
 
7
  from model import CAFN
8
+ # --- 修改点 1: 导入 load_precomputed_fr_matrix 而不是 initialize_fr_matrix ---
9
+ from Feature_extraction_algorithms.PSTAAP import PSTAAP_feature, load_precomputed_fr_matrix
10
  from Feature_extraction_algorithms.Physicochemical import PC_feature
11
 
12
  # --- 1. 模型加载 ---
13
+ MODEL_PATH = "Adam_lr7e-05_weightdecay0.0001_epochs3480.pth"
 
14
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
15
 
16
  def load_model(model_path):
17
  model = CAFN().to(device)
18
  if os.path.exists(model_path):
19
  model.load_state_dict(torch.load(model_path, map_location=device))
20
+ model.eval()
21
  print("模型加载成功!")
22
  return model
23
  else:
 
26
 
27
  model = load_model(MODEL_PATH)
28
 
29
+ # --- 修改点 2: PSTAAP 初始化方式改变 ---
30
+ # 直接加载预计算的 .mat 文件,高效且无需原始数据
31
+ try:
32
+ # --- 请将 'Fr_train.mat' 放在与 app.py 相同的目录下,或者提供完整路径 ---
33
+ FR_MATRIX_PATH = 'Fr_train.mat'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ if not os.path.exists(FR_MATRIX_PATH):
36
+ raise FileNotFoundError(f"PSTAAP初始化失败:找不到矩阵文件 {FR_MATRIX_PATH}")
37
+
38
+ # 调用新的加载函数
39
+ load_precomputed_fr_matrix(FR_MATRIX_PATH)
40
+
41
+ except Exception as e:
42
+ print(f"PSTAAP 初始化过程中发生严重错误: {e}")
43
+ model = None
44
 
45
  # --- 3. 核心预测函数 (也需要微调) ---
46
  def predict(sequence_input):