Spaces:
Sleeping
Sleeping
Update Feature_extraction_algorithms/PSTAAP.py
Browse files- Feature_extraction_algorithms/PSTAAP.py +161 -121
Feature_extraction_algorithms/PSTAAP.py
CHANGED
|
@@ -1,121 +1,161 @@
|
|
| 1 |
-
|
| 2 |
-
import numpy as np
|
| 3 |
-
import
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
return
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
# --- Module-level cache ---
|
| 6 |
+
# Holds the computed Fr matrix in memory so it is neither recomputed nor reloaded through file I/O.
|
| 7 |
+
_cached_fr_matrix = None
|
| 8 |
+
# Length of a sequence after preprocessing, kept so later inputs can be validated against it.
|
| 9 |
+
_expected_length_after_processing = None
|
| 10 |
+
|
| 11 |
+
# --- Internal helper functions (extracted from the original Fr-matrix "mat" computation code) ---
|
| 12 |
+
|
| 13 |
+
def _read_fasta_sequences(filename: str) -> List[str]:
|
| 14 |
+
"""
|
| 15 |
+
(内部函数) 读取FASTA格式文件,返回序列列表。
|
| 16 |
+
"""
|
| 17 |
+
sequences = []
|
| 18 |
+
current_seq = []
|
| 19 |
+
try:
|
| 20 |
+
with open(filename, 'r', encoding='utf-8') as f:
|
| 21 |
+
for line in f:
|
| 22 |
+
line = line.strip()
|
| 23 |
+
if not line: continue
|
| 24 |
+
if line.startswith('>'):
|
| 25 |
+
if current_seq:
|
| 26 |
+
sequences.append(''.join(current_seq))
|
| 27 |
+
current_seq = []
|
| 28 |
+
else:
|
| 29 |
+
# 确保序列为大写,以匹配氨基酸字典
|
| 30 |
+
current_seq.append(line.upper())
|
| 31 |
+
if current_seq:
|
| 32 |
+
sequences.append(''.join(current_seq))
|
| 33 |
+
except FileNotFoundError:
|
| 34 |
+
raise FileNotFoundError(f"错误:文件 '{filename}' 未找到。")
|
| 35 |
+
return sequences
|
| 36 |
+
|
| 37 |
+
def _process_sequence(sequence: str) -> str:
|
| 38 |
+
"""
|
| 39 |
+
(内部函数) 对单条序列进行预处理:移除正中间的氨基酸。
|
| 40 |
+
这个函数统一了训练和提取时的预处理逻辑。
|
| 41 |
+
"""
|
| 42 |
+
middle_index = (len(sequence) - 1) // 2
|
| 43 |
+
return sequence[:middle_index] + sequence[middle_index + 1:]
|
| 44 |
+
|
| 45 |
+
def _calculate_frequency_matrix(sequences: List[str], aa_map: dict) -> np.ndarray:
|
| 46 |
+
"""
|
| 47 |
+
(内部函数) 为一组序列计算标准化的三肽频率矩阵。
|
| 48 |
+
"""
|
| 49 |
+
if not sequences:
|
| 50 |
+
return np.zeros((20 ** 3, 0))
|
| 51 |
+
|
| 52 |
+
num_sequences = len(sequences)
|
| 53 |
+
seq_length = len(sequences[0])
|
| 54 |
+
freq_matrix = np.zeros((20 ** 3, seq_length - 2))
|
| 55 |
+
|
| 56 |
+
for seq in sequences:
|
| 57 |
+
for j in range(seq_length - 2):
|
| 58 |
+
k1 = aa_map.get(seq[j], -1)
|
| 59 |
+
k2 = aa_map.get(seq[j + 1], -1)
|
| 60 |
+
k3 = aa_map.get(seq[j + 2], -1)
|
| 61 |
+
|
| 62 |
+
if -1 not in {k1, k2, k3}:
|
| 63 |
+
index = 400 * k1 + 20 * k2 + k3
|
| 64 |
+
freq_matrix[index, j] += 1
|
| 65 |
+
|
| 66 |
+
return freq_matrix / num_sequences if num_sequences > 0 else freq_matrix
|
| 67 |
+
|
| 68 |
+
# --- Public API functions ---
|
| 69 |
+
|
| 70 |
+
def initialize_fr_matrix(fasta_files: List[str]):
    """Compute the Fr matrix from FASTA files and cache it in the module.

    This must be called before ``PSTAAP_feature``. Each file is treated as
    one class. For every class ``i`` a tripeptide frequency matrix ``F_i``
    is computed from its own sequences and ``FF_i`` from the sequences of
    all other files combined; the cached matrix is
    ``mean_i(F_i) - mean_i(FF_i)``. A group without sequences (an empty
    file, or "all other files" when only one file is given) contributes an
    all-zero matrix of the common width.

    The module-level cache (``_cached_fr_matrix`` and
    ``_expected_length_after_processing``) is only overwritten after the
    whole computation has succeeded.

    Args:
        fasta_files: Paths of the FASTA files, one file per class.

    Raises:
        ValueError: If no file is given, if no file contains a sequence, or
            if the sequences do not all have the same length.
        FileNotFoundError: Propagated from the FASTA reader when a path does
            not exist.
    """
    global _cached_fr_matrix, _expected_length_after_processing

    print("正在初始化PSTAAP特征提取器...")
    AA_MAP = {char: i for i, char in enumerate('ACDEFGHIKLMNPQRSTVWY')}

    # 1. Read and validate all sequences.
    all_sequences_by_file = [_read_fasta_sequences(f) for f in fasta_files]
    # any([]) is False, so this also covers an empty file list.
    if not any(all_sequences_by_file):
        raise ValueError("输入的文件列表为空或所有文件均不包含序列。")

    # Take the reference length from the first file that actually holds
    # sequences; the first file itself may legitimately be empty.
    first_len = next(len(seqs[0]) for seqs in all_sequences_by_file if seqs)
    for path, seqs in zip(fasta_files, all_sequences_by_file):
        if any(len(s) != first_len for s in seqs):
            raise ValueError(f"文件 '{path}' 中的序列长度不一致或与其他文件不同。")

    # 2. Preprocess all sequences.
    processed_sequences_list = [
        [_process_sequence(seq) for seq in seqs] for seqs in all_sequences_by_file
    ]
    expected_length = next(len(seqs[0]) for seqs in processed_sequences_list if seqs)
    n_positions = max(expected_length - 2, 0)

    def group_frequencies(seqs):
        # An empty group has zero frequency everywhere. Build it with the
        # common width so it can be averaged and subtracted with the others.
        if not seqs:
            return np.zeros((20 ** 3, n_positions))
        return _calculate_frequency_matrix(seqs, AA_MAP)

    # 3. Compute the Fr matrix.
    f_matrices, ff_matrices = [], []
    for i, current_seqs in enumerate(processed_sequences_list):
        other_seqs_combined = [
            seq
            for idx, lst in enumerate(processed_sequences_list)
            if idx != i
            for seq in lst
        ]
        f_matrices.append(group_frequencies(current_seqs))
        ff_matrices.append(group_frequencies(other_seqs_combined))

    fr_matrix = np.mean(f_matrices, axis=0) - np.mean(ff_matrices, axis=0)

    # 4. Cache the results. Both globals are written together, and only now,
    # so a failure above never leaves the module half-initialised.
    _expected_length_after_processing = expected_length
    _cached_fr_matrix = fr_matrix
    print(f"Fr 矩阵计算完成并已缓存。形状: {_cached_fr_matrix.shape}")
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def PSTAAP_feature(protein_sequences: List[str]) -> np.ndarray:
    """Extract PSTAAP features from protein sequences.

    ``initialize_fr_matrix()`` must have been called first so that the Fr
    matrix is cached. Every sequence gets the same preprocessing as the
    training data (central residue removed) and is upper-cased, as the
    training sequences are when they are read. Feature ``j`` of a sequence
    is the cached Fr value, in column ``j``, of the tripeptide starting at
    position ``j``; it stays 0 when that tripeptide contains a character
    outside the 20 standard amino acids.

    Args:
        protein_sequences: Sequences to extract features from.

    Returns:
        Array of shape ``(len(protein_sequences), L - 2)`` where ``L`` is
        the sequence length after preprocessing. An empty input list yields
        an array with zero rows.

    Raises:
        RuntimeError: If the Fr matrix has not been initialised.
        ValueError: If any sequence does not have the length seen during
            initialisation once preprocessed.
    """
    if _cached_fr_matrix is None:
        raise RuntimeError(
            "Fr矩阵尚未初始化。请在使用此函数前,先调用 initialize_fr_matrix(fasta_files) 函数。"
        )

    # Same preprocessing as during initialisation. Upper-case so residues
    # match the lookup table exactly like the training sequences do.
    processed_sequences = [_process_sequence(seq).upper() for seq in protein_sequences]

    # Validate every sequence, not only the first: a shorter one would
    # otherwise fail with IndexError below and a longer one would be
    # silently truncated.
    for processed in processed_sequences:
        if len(processed) != _expected_length_after_processing:
            raise ValueError(f"输入序列处理后的长度 ({len(processed)}) 与训练时"
                             f"的期望长度 ({_expected_length_after_processing}) 不匹配。")

    aa_index = {aa: i for i, aa in enumerate('ACDEFGHIKLMNPQRSTVWY')}
    num_seqs = len(processed_sequences)
    # Derived from the validated length so an empty input list still works.
    feature_dim = max(_expected_length_after_processing - 2, 0)
    PSTAAP = np.zeros((num_seqs, feature_dim))

    for i, seq in enumerate(processed_sequences):
        for j in range(feature_dim):
            position1 = aa_index.get(seq[j])
            position2 = aa_index.get(seq[j + 1])
            position3 = aa_index.get(seq[j + 2])
            # Non-standard residue: leave the feature at its default of 0.
            # Index 0 is valid, hence the explicit None test.
            if position1 is None or position2 is None or position3 is None:
                continue
            index = 400 * position1 + 20 * position2 + position3
            PSTAAP[i, j] = _cached_fr_matrix[index, j]

    return PSTAAP
|