File size: 7,410 Bytes
5945074
 
 
2b065a1
 
 
 
 
db29e5c
2b065a1
 
 
 
 
 
 
 
f808770
db29e5c
2b065a1
 
 
 
 
 
 
 
 
 
f808770
 
 
 
 
5945074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import os
import numpy as np
from typing import List
import scipy.io

def load_precomputed_fr_matrix(mat_file_path: str):
    """
    从 .mat 文件直接加载预先计算好的 Fr 矩阵并进行缓存。
    ...
    """
    global _cached_fr_matrix, _expected_length_after_processing
    
    print(f"正在从 {mat_file_path} 加载预计算的 Fr 矩阵...")
    
    try:
        mat_data = scipy.io.loadmat(mat_file_path)
        
        matrix_key = 'Fr'  # 修改后的变量名
        
        if matrix_key not in mat_data:
            raise KeyError(f"在 {mat_file_path} 中未找到变量名 '{matrix_key}'。 "
                           f"文件中可用的变量有: {list(mat_data.keys())}")
            
        _cached_fr_matrix = mat_data[matrix_key]
        _expected_length_after_processing = _cached_fr_matrix.shape[1] + 2

        print(f"Fr 矩阵加载并缓存成功。形状: {_cached_fr_matrix.shape}")
        print(f"推断出的序列期望长度 (处理后): {_expected_length_after_processing}")

    except Exception as e:
        print(f"❌ 加载 Fr 矩阵失败: {e}")
        raise



# --- 模块级缓存 ---
# 这个变量将会在内存中存储计算好的Fr矩阵,避免重复计算和文件IO。
_cached_fr_matrix = None
# 存储预处理后的序列长度,用于后续校验
_expected_length_after_processing = None

# --- 内部辅助函数 (从你的mat计算代码中提取) ---

def _read_fasta_sequences(filename: str) -> List[str]:
    """
    (内部函数) 读取FASTA格式文件,返回序列列表。
    """
    sequences = []
    current_seq = []
    try:
        with open(filename, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line: continue
                if line.startswith('>'):
                    if current_seq:
                        sequences.append(''.join(current_seq))
                    current_seq = []
                else:
                    # 确保序列为大写,以匹配氨基酸字典
                    current_seq.append(line.upper()) 
            if current_seq:
                sequences.append(''.join(current_seq))
    except FileNotFoundError:
        raise FileNotFoundError(f"错误:文件 '{filename}' 未找到。")
    return sequences

def _process_sequence(sequence: str) -> str:
    """
    (内部函数) 对单条序列进行预处理:移除正中间的氨基酸。
    这个函数统一了训练和提取时的预处理逻辑。
    """
    middle_index = (len(sequence) - 1) // 2
    return sequence[:middle_index] + sequence[middle_index + 1:]

def _calculate_frequency_matrix(sequences: List[str], aa_map: dict) -> np.ndarray:
    """
    (内部函数) 为一组序列计算标准化的三肽频率矩阵。
    """
    if not sequences:
        return np.zeros((20 ** 3, 0))

    num_sequences = len(sequences)
    seq_length = len(sequences[0])
    freq_matrix = np.zeros((20 ** 3, seq_length - 2))

    for seq in sequences:
        for j in range(seq_length - 2):
            k1 = aa_map.get(seq[j], -1)
            k2 = aa_map.get(seq[j + 1], -1)
            k3 = aa_map.get(seq[j + 2], -1)
            
            if -1 not in {k1, k2, k3}:
                index = 400 * k1 + 20 * k2 + k3
                freq_matrix[index, j] += 1

    return freq_matrix / num_sequences if num_sequences > 0 else freq_matrix

# --- 公共API函数 ---

def initialize_fr_matrix(fasta_files: List[str]):
    """
    根据输入的FASTA文件列表计算并缓存Fr矩阵。
    这是使用 PSTAAP_feature 前必须调用的初始化函数。

    Args:
        fasta_files (List[str]): FASTA文件的路径列表。每个文件被视为一个独立的类别。
    """
    global _cached_fr_matrix, _expected_length_after_processing
    
    print("正在初始化PSTAAP特征提取器...")
    AA_MAP = {char: i for i, char in enumerate('ACDEFGHIKLMNPQRSTVWY')}
    
    # 1. 读取并验证所有序列
    all_sequences_by_file = [_read_fasta_sequences(f) for f in fasta_files]
    if not all_sequences_by_file or not any(all_sequences_by_file):
        raise ValueError("输入的文件列表为空或所有文件均不包含序列。")
    
    first_len = len(all_sequences_by_file[0][0])
    for i, seqs in enumerate(all_sequences_by_file):
        if not all(len(s) == first_len for s in seqs):
            raise ValueError(f"文件 '{fasta_files[i]}' 中的序列长度不一致或与其他文件不同。")

    # 2. 预处理所有序列
    processed_sequences_list = [[_process_sequence(seq) for seq in seqs] for seqs in all_sequences_by_file]
    _expected_length_after_processing = len(processed_sequences_list[0][0])
    
    # 3. 计算 Fr 矩阵
    f_matrices, ff_matrices = [], []
    num_files = len(processed_sequences_list)

    for i in range(num_files):
        current_seqs = processed_sequences_list[i]
        other_seqs_combined = [seq for idx, lst in enumerate(processed_sequences_list) if idx != i for seq in lst]
        f_matrices.append(_calculate_frequency_matrix(current_seqs, AA_MAP))
        ff_matrices.append(_calculate_frequency_matrix(other_seqs_combined, AA_MAP))
        
    F_avg = np.mean(f_matrices, axis=0)
    FF_avg = np.mean(ff_matrices, axis=0)
    
    # 4. 缓存计算结果
    _cached_fr_matrix = F_avg - FF_avg
    print(f"Fr 矩阵计算完成并已缓存。形状: {_cached_fr_matrix.shape}")


def PSTAAP_feature(protein_sequences: List[str]) -> np.ndarray:
    """
    从蛋白质序列中提取PSTAAP特征。

    在使用此函数之前,必须先调用 initialize_fr_matrix() 来计算并缓存Fr矩阵。
    
    Args:
        protein_sequences (List[str]): 需要提取特征的蛋白质序列列表。

    Returns:
        np.ndarray: PSTAAP特征矩阵,形状为 (序列数, 特征维度)。
    """
    if _cached_fr_matrix is None:
        raise RuntimeError(
            "Fr矩阵尚未初始化。请在使用此函数前,先调用 initialize_fr_matrix(fasta_files) 函数。"
        )

    # 统一的预处理步骤
    processed_sequences = [_process_sequence(seq) for seq in protein_sequences]
    
    if len(processed_sequences[0]) != _expected_length_after_processing:
        raise ValueError(f"输入序列处理后的长度 ({len(processed_sequences[0])}) 与训练时"
                         f"的期望长度 ({_expected_length_after_processing}) 不匹配。")

    AA = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']
    num_seqs = len(processed_sequences)
    feature_dim = len(processed_sequences[0]) - 2
    PSTAAP = np.zeros((num_seqs, feature_dim))

    for i in range(num_seqs):
        for j in range(feature_dim):
            t1 = processed_sequences[i][j]
            t2 = processed_sequences[i][j+1]
            t3 = processed_sequences[i][j+2]
            
            try:
                position1 = AA.index(t1)
                position2 = AA.index(t2)
                position3 = AA.index(t3)
                index = 400 * position1 + 20 * position2 + position3
                PSTAAP[i][j] = _cached_fr_matrix[index][j]
            except ValueError:
                # 如果遇到非标准氨基酸,可以选择跳过、设为0或报错
                # 这里我们默认该特征值为0
                pass 
                
    return PSTAAP