"""PSTAAP (position-specific tripeptide) feature extraction.

An ``Fr`` matrix with one row per amino-acid triple (20 ** 3 rows) and one
column per tripeptide start position is either loaded from a ``.mat`` file
(``load_precomputed_fr_matrix``) or computed from FASTA files
(``initialize_fr_matrix``).  It is cached at module level and then used by
``PSTAAP_feature`` to turn sequences into feature vectors.
"""

import os
import numpy as np
from typing import List
import scipy.io


# --- Module-level cache ---
# Holds the Fr matrix in memory so it is neither recomputed nor re-read from
# disk on every feature extraction call.
_cached_fr_matrix = None
# Length of a sequence after `_process_sequence`; used to validate inputs.
_expected_length_after_processing = None

# The 20 standard amino acids; the position in this string is the residue code
# used to build the row index 400 * k1 + 20 * k2 + k3.
_AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY'
_AA_INDEX = {residue: code for code, residue in enumerate(_AMINO_ACIDS)}
_NUM_TRIPEPTIDES = 20 ** 3


def load_precomputed_fr_matrix(mat_file_path: str) -> None:
    """Load a precomputed Fr matrix from a ``.mat`` file and cache it.

    The matrix is read from the variable named ``'Fr'``.  The expected
    processed sequence length is inferred as ``number of columns + 2``.

    Args:
        mat_file_path: Path of the ``.mat`` file.

    Raises:
        KeyError: If the file has no variable named ``'Fr'``.
        Exception: Any error raised while reading the file is printed and
            re-raised unchanged.
    """
    global _cached_fr_matrix, _expected_length_after_processing

    print(f"正在从 {mat_file_path} 加载预计算的 Fr 矩阵...")
    try:
        mat_data = scipy.io.loadmat(mat_file_path)
        matrix_key = 'Fr'
        if matrix_key not in mat_data:
            raise KeyError(f"在 {mat_file_path} 中未找到变量名 '{matrix_key}'。 "
                           f"文件中可用的变量有: {list(mat_data.keys())}")
        fr_matrix = mat_data[matrix_key]
        expected_length = fr_matrix.shape[1] + 2
    except Exception as e:
        print(f"❌ 加载 Fr 矩阵失败: {e}")
        raise

    # Publish both values together so a failure above never leaves the cache
    # half-updated.
    _cached_fr_matrix = fr_matrix
    _expected_length_after_processing = expected_length
    print(f"Fr 矩阵加载并缓存成功。形状: {_cached_fr_matrix.shape}")
    print(f"推断出的序列期望长度 (处理后): {_expected_length_after_processing}")


# --- Internal helpers ---

def _read_fasta_sequences(filename: str) -> List[str]:
    """Read a FASTA file and return its sequences, upper-cased.

    Lines starting with ``>`` separate records, blank lines are ignored and
    the remaining lines of one record are concatenated.

    Raises:
        FileNotFoundError: If ``filename`` does not exist.
    """
    sequences = []
    current_seq = []
    try:
        with open(filename, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                if line.startswith('>'):
                    if current_seq:
                        sequences.append(''.join(current_seq))
                        current_seq = []
                else:
                    # Upper-case so residues match the amino-acid table.
                    current_seq.append(line.upper())
        if current_seq:
            sequences.append(''.join(current_seq))
    except FileNotFoundError as err:
        raise FileNotFoundError(f"错误:文件 '{filename}' 未找到。") from err
    return sequences


def _process_sequence(sequence: str) -> str:
    """Return ``sequence`` with its middle residue removed.

    The removed index is ``(len(sequence) - 1) // 2``.  The same step is used
    when building the Fr matrix and when extracting features.
    """
    middle_index = (len(sequence) - 1) // 2
    return sequence[:middle_index] + sequence[middle_index + 1:]


def _count_tripeptides(sequences: List[str], num_positions: int,
                       aa_map: dict) -> np.ndarray:
    """Count, per start position, how often each tripeptide occurs.

    Returns a ``(20 ** 3, num_positions)`` float array of raw counts.
    Tripeptides containing a residue missing from ``aa_map`` are skipped.
    """
    counts = np.zeros((_NUM_TRIPEPTIDES, num_positions))
    for seq in sequences:
        codes = [aa_map.get(residue, -1) for residue in seq]
        for j in range(num_positions):
            k1, k2, k3 = codes[j], codes[j + 1], codes[j + 2]
            if -1 not in (k1, k2, k3):
                counts[400 * k1 + 20 * k2 + k3, j] += 1
    return counts


def _calculate_frequency_matrix(sequences: List[str], aa_map: dict) -> np.ndarray:
    """Compute the normalised tripeptide frequency matrix of ``sequences``.

    The number of positions is taken from the first sequence
    (``len(sequences[0]) - 2``, never below zero).  Counts are divided by the
    number of sequences.  An empty input yields a ``(20 ** 3, 0)`` array.
    """
    if not sequences:
        return np.zeros((_NUM_TRIPEPTIDES, 0))

    # Clamp so that sequences shorter than a tripeptide give zero columns
    # instead of a negative array dimension.
    num_positions = max(len(sequences[0]) - 2, 0)
    counts = _count_tripeptides(sequences, num_positions, aa_map)
    return counts / len(sequences)


# --- Public API ---

def initialize_fr_matrix(fasta_files: List[str]) -> None:
    """Compute the Fr matrix from FASTA files and cache it.

    Each file is treated as one class.  For every file the tripeptide
    frequency matrix of its own sequences (``f``) and of all sequences from
    the other files combined (``ff``) is computed; the cached matrix is
    ``mean(f) - mean(ff)`` over the files.  When there is only one file the
    complement is empty and contributes an all-zero matrix.

    Args:
        fasta_files: Paths of the FASTA files, one per class.

    Raises:
        ValueError: If the list is empty, a file contains no sequences, or
            the sequences do not all have the same length.
        FileNotFoundError: If one of the files does not exist.
    """
    global _cached_fr_matrix, _expected_length_after_processing

    print("正在初始化PSTAAP特征提取器...")

    # 1. Read and validate all sequences.
    all_sequences_by_file = [_read_fasta_sequences(f) for f in fasta_files]
    if not all_sequences_by_file or not any(all_sequences_by_file):
        raise ValueError("输入的文件列表为空或所有文件均不包含序列。")

    # A file without sequences has no frequency matrix of the right shape, so
    # reject it explicitly instead of failing later with an index/shape error.
    for path, seqs in zip(fasta_files, all_sequences_by_file):
        if not seqs:
            raise ValueError(f"文件 '{path}' 不包含任何序列。")

    first_len = len(all_sequences_by_file[0][0])
    for i, seqs in enumerate(all_sequences_by_file):
        if not all(len(s) == first_len for s in seqs):
            raise ValueError(f"文件 '{fasta_files[i]}' 中的序列长度不一致或与其他文件不同。")

    # 2. Pre-process all sequences (drop the middle residue).
    processed_sequences_list = [[_process_sequence(seq) for seq in seqs]
                                for seqs in all_sequences_by_file]
    expected_length = len(processed_sequences_list[0][0])
    num_positions = max(expected_length - 2, 0)

    # 3. Compute the Fr matrix.
    # Count each file once; the counts of "all other files" are then the
    # grand total minus the file's own counts.  Counts are exact integers, so
    # this equals recounting the combined complement for every file.
    count_matrices = [_count_tripeptides(seqs, num_positions, _AA_INDEX)
                      for seqs in processed_sequences_list]
    total_counts = np.sum(count_matrices, axis=0)
    total_sequences = sum(len(seqs) for seqs in processed_sequences_list)

    f_matrices, ff_matrices = [], []
    for counts, seqs in zip(count_matrices, processed_sequences_list):
        f_matrices.append(counts / len(seqs))
        num_other = total_sequences - len(seqs)
        if num_other:
            ff_matrices.append((total_counts - counts) / num_other)
        else:
            # Single input file: the complement set is empty.
            ff_matrices.append(np.zeros_like(counts))

    # NOTE(review): with exactly two files, or whenever every file holds the
    # same number of sequences, mean(f) and mean(ff) are algebraically equal
    # and the resulting Fr matrix is all zeros.  Confirm this averaging
    # scheme is the intended definition of Fr.
    F_avg = np.mean(f_matrices, axis=0)
    FF_avg = np.mean(ff_matrices, axis=0)

    # 4. Cache the result; both globals are published only after success.
    _cached_fr_matrix = F_avg - FF_avg
    _expected_length_after_processing = expected_length
    print(f"Fr 矩阵计算完成并已缓存。形状: {_cached_fr_matrix.shape}")


def PSTAAP_feature(protein_sequences: List[str]) -> np.ndarray:
    """Extract PSTAAP features from protein sequences.

    ``initialize_fr_matrix()`` or ``load_precomputed_fr_matrix()`` must have
    been called first.  Each sequence is upper-cased (as is done when reading
    FASTA files for the Fr matrix), its middle residue is removed, and for
    every tripeptide start position ``j`` the feature is the cached Fr value
    at ``[400 * k1 + 20 * k2 + k3, j]``.  Tripeptides containing a
    non-standard residue get the value 0.

    Args:
        protein_sequences: Sequences to extract features from.

    Returns:
        Array of shape ``(number of sequences, expected length - 2)``; an
        empty input gives zero rows.

    Raises:
        RuntimeError: If no Fr matrix has been cached yet.
        ValueError: If any processed sequence does not have the expected
            length.
    """
    if _cached_fr_matrix is None:
        raise RuntimeError(
            "Fr矩阵尚未初始化。请在使用此函数前,先调用 initialize_fr_matrix(fasta_files) 函数。"
        )

    # Same pre-processing as used when the Fr matrix was built.
    processed_sequences = [_process_sequence(seq.upper()) for seq in protein_sequences]

    # Check every sequence, not only the first: a shorter one would index out
    # of range and a longer one would be silently truncated.
    for seq in processed_sequences:
        if len(seq) != _expected_length_after_processing:
            raise ValueError(f"输入序列处理后的长度 ({len(seq)}) 与训练时"
                             f"的期望长度 ({_expected_length_after_processing}) 不匹配。")

    feature_dim = max(_expected_length_after_processing - 2, 0)
    PSTAAP = np.zeros((len(processed_sequences), feature_dim))

    for i, seq in enumerate(processed_sequences):
        codes = [_AA_INDEX.get(residue, -1) for residue in seq]
        for j in range(feature_dim):
            k1, k2, k3 = codes[j], codes[j + 1], codes[j + 2]
            if -1 in (k1, k2, k3):
                # Non-standard residue: leave the feature at 0.
                continue
            PSTAAP[i, j] = _cached_fr_matrix[400 * k1 + 20 * k2 + k3, j]

    return PSTAAP