Xianfish9 commited on
Commit
2b065a1
·
verified ·
1 Parent(s): 383884c

Update Feature_extraction_algorithms/PSTAAP.py

Browse files
Feature_extraction_algorithms/PSTAAP.py CHANGED
@@ -1,6 +1,46 @@
1
  import os
2
  import numpy as np
3
  from typing import List
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  # --- 模块级缓存 ---
6
  # 这个变量将会在内存中存储计算好的Fr矩阵,避免重复计算和文件IO。
 
1
  import os
2
  import numpy as np
3
  from typing import List
4
+ import scipy.io
5
+
6
+ def load_precomputed_fr_matrix(mat_file_path: str):
7
+ """
8
+ 从 .mat 文件直接加载预先计算好的 Fr 矩阵并进行缓存。
9
+
10
+ Args:
11
+ mat_file_path (str): .mat 文件的路径。
12
+ """
13
+ global _cached_fr_matrix, _expected_length_after_processing
14
+
15
+ print(f"正在从 {mat_file_path} 加载预计算的 Fr 矩阵...")
16
+
17
+ try:
18
+ # 加载 .mat 文件,它会被读成一个字典
19
+ mat_data = scipy.io.loadmat(mat_file_path)
20
+
21
+ # !!! 关键步骤: 你需要知道 .mat 文件中矩阵的变量名 !!!
22
+ # 假设变量名是 'Fr_train'。如果不是,请修改这里的键值。
23
+ # 你可以通过 print(mat_data.keys()) 来查看 .mat 文件中所有的变量名。
24
+ matrix_key = 'Fr_train' # <--- 如果需要,请修改这个变量名
25
+ if matrix_key not in mat_data:
26
+ raise KeyError(f"在 {mat_file_path} 中未找到变量名 '{matrix_key}'。 "
27
+ f"文件中可用的变量有: {list(mat_data.keys())}")
28
+
29
+ _cached_fr_matrix = mat_data[matrix_key]
30
+
31
+ # 从加载的矩阵形状推断出序列的期望长度
32
+ # Fr 矩阵的形状是 (8000, seq_len - 2)
33
+ # 所以,处理后的序列长度 = Fr矩阵的列数 + 2
34
+ _expected_length_after_processing = _cached_fr_matrix.shape[1] + 2
35
+
36
+ print(f"Fr 矩阵加载并缓存成功。形状: {_cached_fr_matrix.shape}")
37
+ print(f"推断出的序列期望长度 (处理后): {_expected_length_after_processing}")
38
+
39
+ except FileNotFoundError:
40
+ raise FileNotFoundError(f"错误:未能找到 .mat 文件于路径: {mat_file_path}")
41
+ except Exception as e:
42
+ # 捕获其他可能的错误,例如文件格式问题或键错误
43
+ raise RuntimeError(f"加载 .mat 文件时发生错误: {e}")
44
 
45
  # --- 模块级缓存 ---
46
  # 这个变量将会在内存中存储计算好的Fr矩阵,避免重复计算和文件IO。