Xianfish9 commited on
Commit
db29e5c
·
verified ·
1 Parent(s): 2009bd2

Update Feature_extraction_algorithms/PSTAAP.py

Browse files
Feature_extraction_algorithms/PSTAAP.py CHANGED
@@ -6,41 +6,30 @@ import scipy.io
6
  def load_precomputed_fr_matrix(mat_file_path: str):
7
  """
8
  从 .mat 文件直接加载预先计算好的 Fr 矩阵并进行缓存。
9
-
10
- Args:
11
- mat_file_path (str): .mat 文件的路径。
12
  """
13
  global _cached_fr_matrix, _expected_length_after_processing
14
 
15
  print(f"正在从 {mat_file_path} 加载预计算的 Fr 矩阵...")
16
 
17
  try:
18
- # 加载 .mat 文件,它会被读成一个字典
19
  mat_data = scipy.io.loadmat(mat_file_path)
20
 
21
- # !!! 关键步骤: 你需要知道 .mat 文件中矩阵的变量名 !!!
22
- # 假设变量名是 'Fr_train'。如果不是,请修改这里的键值。
23
- # 你可以通过 print(mat_data.keys()) 来查看 .mat 文件中所有的变量名。
24
- matrix_key = 'Fr_train' # <--- 如果需要,请修改这个变量名
25
  if matrix_key not in mat_data:
26
  raise KeyError(f"在 {mat_file_path} 中未找到变量名 '{matrix_key}'。 "
27
  f"文件中可用的变量有: {list(mat_data.keys())}")
28
 
29
  _cached_fr_matrix = mat_data[matrix_key]
30
 
31
- # 从加载的矩阵形状推断出序列的期望长度
32
- # Fr 矩阵的形状是 (8000, seq_len - 2)
33
- # 所以,处理后的序列长度 = Fr矩阵的列数 + 2
34
  _expected_length_after_processing = _cached_fr_matrix.shape[1] + 2
35
 
36
  print(f"Fr 矩阵加载并缓存成功。形状: {_cached_fr_matrix.shape}")
37
  print(f"推断出的序列期望长度 (处理后): {_expected_length_after_processing}")
38
 
39
- except FileNotFoundError:
40
- raise FileNotFoundError(f"错误:未能找到 .mat 文件于路径: {mat_file_path}")
41
- except Exception as e:
42
- # 捕获其他可能的错误,例如文件格式问题或键错误
43
- raise RuntimeError(f"加载 .mat 文件时发生错误: {e}")
44
 
45
  # --- 模块级缓存 ---
46
  # 这个变量将会在内存中存储计算好的Fr矩阵,避免重复计算和文件IO。
 
6
  def load_precomputed_fr_matrix(mat_file_path: str):
7
  """
8
  从 .mat 文件直接加载预先计算好的 Fr 矩阵并进行缓存。
9
+ ...
 
 
10
  """
11
  global _cached_fr_matrix, _expected_length_after_processing
12
 
13
  print(f"正在从 {mat_file_path} 加载预计算的 Fr 矩阵...")
14
 
15
  try:
 
16
  mat_data = scipy.io.loadmat(mat_file_path)
17
 
18
+ # --- !!! 这里是唯一的修改点 !!! ---
19
+ # 'Fr_train' 修改为 'Fr',以匹配您的 .mat 文件
20
+ matrix_key = 'Fr' # <--- 修改这里!
21
+
22
  if matrix_key not in mat_data:
23
  raise KeyError(f"在 {mat_file_path} 中未找到变量名 '{matrix_key}'。 "
24
  f"文件中可用的变量有: {list(mat_data.keys())}")
25
 
26
  _cached_fr_matrix = mat_data[matrix_key]
27
 
 
 
 
28
  _expected_length_after_processing = _cached_fr_matrix.shape[1] + 2
29
 
30
  print(f"Fr 矩阵加载并缓存成功。形状: {_cached_fr_matrix.shape}")
31
  print(f"推断出的序列期望长度 (处理后): {_expected_length_after_processing}")
32
 
 
 
 
 
 
33
 
34
  # --- 模块级缓存 ---
35
  # 这个变量将会在内存中存储计算好的Fr矩阵,避免重复计算和文件IO。