Spaces:
Sleeping
Sleeping
| import os | |
| import numpy as np | |
| from typing import List | |
| import scipy.io | |
def load_precomputed_fr_matrix(mat_file_path: str):
    """Load a precomputed Fr matrix from a MATLAB ``.mat`` file and cache it.

    The matrix is read from the variable named ``Fr``. On success the
    module-level ``_cached_fr_matrix`` is set to that matrix and
    ``_expected_length_after_processing`` is set to its column count plus 2
    (one column per tripeptide window of a preprocessed sequence).

    Args:
        mat_file_path: Path of the ``.mat`` file to read.

    Raises:
        KeyError: The file has no variable named ``Fr``.
        ValueError: ``Fr`` is not a 2-D matrix with 20**3 rows.
        Exception: Anything raised by ``scipy.io.loadmat`` is reported on
            stdout and re-raised unchanged.
    """
    global _cached_fr_matrix, _expected_length_after_processing
    print(f"正在从 {mat_file_path} 加载预计算的 Fr 矩阵...")
    try:
        mat_data = scipy.io.loadmat(mat_file_path)
        matrix_key = 'Fr'
        if matrix_key not in mat_data:
            raise KeyError(f"在 {mat_file_path} 中未找到变量名 '{matrix_key}'。 "
                           f"文件中可用的变量有: {list(mat_data.keys())}")
        fr_matrix = mat_data[matrix_key]
        # Rows are addressed elsewhere in this module as 400*a + 20*b + c with
        # a, b, c in 0..19, so anything other than 20**3 rows cannot be a
        # usable Fr matrix (typically a transposed or truncated export).
        if fr_matrix.ndim != 2 or fr_matrix.shape[0] != 20 ** 3:
            raise ValueError(
                f"Fr 矩阵的形状应为 ({20 ** 3}, N),实际为 {fr_matrix.shape}。"
            )
        expected_length = fr_matrix.shape[1] + 2
        # Commit both globals together, only after validation, so a failed
        # load never leaves a half-updated cache behind.
        _cached_fr_matrix = fr_matrix
        _expected_length_after_processing = expected_length
        print(f"Fr 矩阵加载并缓存成功。形状: {_cached_fr_matrix.shape}")
        print(f"推断出的序列期望长度 (处理后): {_expected_length_after_processing}")
    except Exception as e:
        print(f"❌ 加载 Fr 矩阵失败: {e}")
        raise
# --- Module-level cache ---
# Both slots start empty; load_precomputed_fr_matrix() and
# initialize_fr_matrix() fill them in.

# Fr matrix kept in memory so it is not recomputed or re-read from disk.
_cached_fr_matrix = None

# Sequence length after preprocessing; PSTAAP_feature() checks its input
# against this value.
_expected_length_after_processing = None
| # --- 内部辅助函数 (从你的mat计算代码中提取) --- | |
| def _read_fasta_sequences(filename: str) -> List[str]: | |
| """ | |
| (内部函数) 读取FASTA格式文件,返回序列列表。 | |
| """ | |
| sequences = [] | |
| current_seq = [] | |
| try: | |
| with open(filename, 'r', encoding='utf-8') as f: | |
| for line in f: | |
| line = line.strip() | |
| if not line: continue | |
| if line.startswith('>'): | |
| if current_seq: | |
| sequences.append(''.join(current_seq)) | |
| current_seq = [] | |
| else: | |
| # 确保序列为大写,以匹配氨基酸字典 | |
| current_seq.append(line.upper()) | |
| if current_seq: | |
| sequences.append(''.join(current_seq)) | |
| except FileNotFoundError: | |
| raise FileNotFoundError(f"错误:文件 '{filename}' 未找到。") | |
| return sequences | |
| def _process_sequence(sequence: str) -> str: | |
| """ | |
| (内部函数) 对单条序列进行预处理:移除正中间的氨基酸。 | |
| 这个函数统一了训练和提取时的预处理逻辑。 | |
| """ | |
| middle_index = (len(sequence) - 1) // 2 | |
| return sequence[:middle_index] + sequence[middle_index + 1:] | |
| def _calculate_frequency_matrix(sequences: List[str], aa_map: dict) -> np.ndarray: | |
| """ | |
| (内部函数) 为一组序列计算标准化的三肽频率矩阵。 | |
| """ | |
| if not sequences: | |
| return np.zeros((20 ** 3, 0)) | |
| num_sequences = len(sequences) | |
| seq_length = len(sequences[0]) | |
| freq_matrix = np.zeros((20 ** 3, seq_length - 2)) | |
| for seq in sequences: | |
| for j in range(seq_length - 2): | |
| k1 = aa_map.get(seq[j], -1) | |
| k2 = aa_map.get(seq[j + 1], -1) | |
| k3 = aa_map.get(seq[j + 2], -1) | |
| if -1 not in {k1, k2, k3}: | |
| index = 400 * k1 + 20 * k2 + k3 | |
| freq_matrix[index, j] += 1 | |
| return freq_matrix / num_sequences if num_sequences > 0 else freq_matrix | |
| # --- 公共API函数 --- | |
def initialize_fr_matrix(fasta_files: List[str]):
    """Compute the Fr matrix from FASTA files and cache it at module level.

    Each file is treated as one class. For every class a tripeptide frequency
    matrix is computed from its own sequences (F) and another from the
    sequences of all other classes combined (FF). The cached Fr matrix is the
    mean of the F matrices minus the mean of the FF matrices. Sequences are
    preprocessed with ``_process_sequence`` first, and the preprocessed length
    is stored in ``_expected_length_after_processing``.

    Args:
        fasta_files: Paths of the FASTA files, one per class. At least two
            files are required, and every file must contain at least one
            sequence.

    Raises:
        ValueError: The list is empty, fewer than two files are given, a file
            contains no sequences, or sequence lengths differ.
        FileNotFoundError: One of the files does not exist.
    """
    global _cached_fr_matrix, _expected_length_after_processing
    print("正在初始化PSTAAP特征提取器...")
    AA_MAP = {char: i for i, char in enumerate('ACDEFGHIKLMNPQRSTVWY')}

    # 1. Read and validate all sequences.
    all_sequences_by_file = [_read_fasta_sequences(f) for f in fasta_files]
    if not all_sequences_by_file or not any(all_sequences_by_file):
        raise ValueError("输入的文件列表为空或所有文件均不包含序列。")
    # With a single class there are no "other" sequences, so the FF matrix
    # would be empty and the subtraction below could not produce a usable
    # result.
    if len(all_sequences_by_file) < 2:
        raise ValueError("至少需要两个FASTA文件(每个文件代表一个类别)才能计算 Fr 矩阵。")
    # An empty file would yield a zero-width matrix that cannot be averaged
    # with the others; report it by name instead of failing later.
    for path, seqs in zip(fasta_files, all_sequences_by_file):
        if not seqs:
            raise ValueError(f"文件 '{path}' 不包含任何序列。")
    first_len = len(all_sequences_by_file[0][0])
    for path, seqs in zip(fasta_files, all_sequences_by_file):
        if not all(len(s) == first_len for s in seqs):
            raise ValueError(f"文件 '{path}' 中的序列长度不一致或与其他文件不同。")

    # 2. Preprocess all sequences.
    processed_sequences_list = [
        [_process_sequence(seq) for seq in seqs] for seqs in all_sequences_by_file
    ]
    expected_length = len(processed_sequences_list[0][0])

    # 3. Compute the per-class (F) and rest-of-classes (FF) matrices.
    f_matrices, ff_matrices = [], []
    for i, current_seqs in enumerate(processed_sequences_list):
        other_seqs_combined = [
            seq
            for idx, lst in enumerate(processed_sequences_list)
            if idx != i
            for seq in lst
        ]
        f_matrices.append(_calculate_frequency_matrix(current_seqs, AA_MAP))
        ff_matrices.append(_calculate_frequency_matrix(other_seqs_combined, AA_MAP))
    F_avg = np.mean(f_matrices, axis=0)
    FF_avg = np.mean(ff_matrices, axis=0)

    # 4. Commit both cache slots together, only once the computation has
    #    succeeded, so a failure never leaves them out of sync.
    _cached_fr_matrix = F_avg - FF_avg
    _expected_length_after_processing = expected_length
    print(f"Fr 矩阵计算完成并已缓存。形状: {_cached_fr_matrix.shape}")
def PSTAAP_feature(protein_sequences: List[str]) -> np.ndarray:
    """Extract PSTAAP features for a list of protein sequences.

    Each sequence is preprocessed with ``_process_sequence``. For every
    window of three consecutive residues at position ``j``, the feature is
    the cached Fr matrix entry at row ``400*a + 20*b + c`` (the residues'
    indices in ``ACDEFGHIKLMNPQRSTVWY``) and column ``j``. Windows containing
    a residue outside that alphabet keep the value 0.

    The Fr matrix must have been cached beforehand by
    ``initialize_fr_matrix()`` or ``load_precomputed_fr_matrix()``.

    Args:
        protein_sequences: Sequences to extract features from.

    Returns:
        Array of shape ``(number of sequences, preprocessed length - 2)``.
        An empty input yields an array with zero rows.

    Raises:
        RuntimeError: No Fr matrix has been cached yet.
        ValueError: A preprocessed sequence does not have the expected length.
    """
    if _cached_fr_matrix is None:
        raise RuntimeError(
            "Fr矩阵尚未初始化。请在使用此函数前,先调用 initialize_fr_matrix(fasta_files) 函数。"
        )

    # Same preprocessing as used when the Fr matrix was built.
    processed_sequences = [_process_sequence(seq) for seq in protein_sequences]
    feature_dim = _expected_length_after_processing - 2

    # Nothing to extract: return an empty feature matrix instead of failing
    # on processed_sequences[0].
    if not processed_sequences:
        return np.zeros((0, max(feature_dim, 0)))

    # Check every sequence, not only the first: a shorter one would raise an
    # IndexError in the loop below and a longer one would be silently cut.
    for processed in processed_sequences:
        if len(processed) != _expected_length_after_processing:
            raise ValueError(f"输入序列处理后的长度 ({len(processed)}) 与训练时"
                             f"的期望长度 ({_expected_length_after_processing}) 不匹配。")

    aa_index = {residue: code for code, residue in enumerate('ACDEFGHIKLMNPQRSTVWY')}
    PSTAAP = np.zeros((len(processed_sequences), feature_dim))
    for row, processed in enumerate(processed_sequences):
        # Encode once per sequence; -1 marks a non-standard residue.
        codes = [aa_index.get(residue, -1) for residue in processed]
        for col in range(feature_dim):
            first, second, third = codes[col:col + 3]
            if first < 0 or second < 0 or third < 0:
                # Non-standard amino acid in this window: leave the feature at 0.
                continue
            PSTAAP[row, col] = _cached_fr_matrix[400 * first + 20 * second + third, col]
    return PSTAAP