File size: 11,903 Bytes
78f28d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
import torch
import torch.nn as nn
from typing import List, Optional
import numpy as np
from sklearn.preprocessing import StandardScaler

# class PhysicochemicalEncoder(nn.Module):
#     """Amino Acid Physicochemical Property Encoder (AAindex版本)"""
    
#     def __init__(self, device, use_aaindex=True, selected_features=None):
#         """
#         Args:
#             device: torch device
#             use_aaindex: 是否使用AAindex特征(True)还是简单的5特征(False)
#             selected_features: 选择使用哪些AAindex特征(None=使用全部)
#         """
#         super().__init__()
#         self.device = device
#         self.use_aaindex = use_aaindex
        
#         if use_aaindex:
#             # 从AAindex加载特征
#             self.aa_properties, self.feature_names = self._load_aaindex_features(selected_features)
#             self.n_features = len(list(self.aa_properties['A'].values()))
#             print(f"✓ Loaded {self.n_features} AAindex features")
#         else:
#             # 使用简单的5特征
#             self.aa_properties = self._get_basic_properties()
#             self.n_features = 5
#             print(f"✓ Using {self.n_features} basic features")
        
#         # 标准化(重要!不同特征范围差异大)
#         self.scaler = self._fit_scaler()
    
#     def _load_aaindex_features(self, selected_features=None):
#         """从AAindex加载特征"""
#         try:
#             # 尝试导入生成的文件
#             from aa_properties_aaindex import AA_PROPERTIES_AAINDEX, FEATURE_DESCRIPTIONS
            
#             if selected_features is not None:
#                 # 只选择指定的特征
#                 filtered_props = {}
#                 for aa, props in AA_PROPERTIES_AAINDEX.items():
#                     filtered_props[aa] = {k: v for k, v in props.items() 
#                                          if k in selected_features}
#                 return filtered_props, selected_features
#             else:
#                 # 使用所有特征
#                 feature_names = list(AA_PROPERTIES_AAINDEX['A'].keys())
#                 return AA_PROPERTIES_AAINDEX, feature_names
                
#         except ImportError:
#             print("⚠ Warning: aa_properties_aaindex.py not found!")
#             print("   Falling back to basic 5 features")
#             print("   Run 'python aaindex_downloader.py' to download AAindex features")
#             return self._get_basic_properties(), ['hydro', 'charge', 'volume', 'flex', 'aroma']
    
#     def _get_basic_properties(self):
#         """基础的5特征(作为fallback)"""
#         return {
#             'A': [1.8, 0.0, 88.6, 0.36, 0.0],
#             'C': [2.5, 0.0, 108.5, 0.35, 0.0],
#             'D': [-3.5, -1.0, 111.1, 0.51, 0.0],
#             'E': [-3.5, -1.0, 138.4, 0.50, 0.0],
#             'F': [2.8, 0.0, 189.9, 0.31, 1.0],
#             'G': [-0.4, 0.0, 60.1, 0.54, 0.0],
#             'H': [-3.2, 0.5, 153.2, 0.32, 0.5],
#             'I': [4.5, 0.0, 166.7, 0.46, 0.0],
#             'K': [-3.9, 1.0, 168.6, 0.47, 0.0],
#             'L': [3.8, 0.0, 166.7, 0.37, 0.0],
#             'M': [1.9, 0.0, 162.9, 0.30, 0.0],
#             'N': [-3.5, 0.0, 114.1, 0.46, 0.0],
#             'P': [-1.6, 0.0, 112.7, 0.51, 0.0],
#             'Q': [-3.5, 0.0, 143.8, 0.49, 0.0],
#             'R': [-4.5, 1.0, 173.4, 0.53, 0.0],
#             'S': [-0.8, 0.0, 89.0, 0.51, 0.0],
#             'T': [-0.7, 0.0, 116.1, 0.44, 0.0],
#             'V': [4.2, 0.0, 140.0, 0.39, 0.0],
#             'W': [-0.9, 0.0, 227.8, 0.31, 1.0],
#             'Y': [-1.3, 0.0, 193.6, 0.42, 1.0],
#             'X': [0.0, 0.0, 120.0, 0.40, 0.0],
#         }
    
#     def _fit_scaler(self):
#         """拟合标准化器"""
#         # 收集所有氨基酸的特征
#         all_features = []
#         for aa in 'ARNDCQEGHILKMFPSTWYV':  # 20种标准氨基酸
#             if isinstance(self.aa_properties[aa], dict):
#                 # AAindex格式
#                 features = list(self.aa_properties[aa].values())
#             else:
#                 # 列表格式
#                 features = self.aa_properties[aa]
#             all_features.append(features)
        
#         all_features = np.array(all_features)
        
#         # Z-score标准化
#         scaler = StandardScaler()
#         scaler.fit(all_features)
        
#         return scaler
    
#     def _get_aa_features(self, aa: str) -> List[float]:
#         """获取单个氨基酸的特征"""
#         aa = aa.upper()
#         if aa not in self.aa_properties:
#             aa = 'X'  # Unknown
        
#         if isinstance(self.aa_properties[aa], dict):
#             # AAindex格式:字典
#             features = list(self.aa_properties[aa].values())
#         else:
#             # 基础格式:列表
#             features = self.aa_properties[aa]
        
#         return features
    
#     def forward(self, sequences: List[str]) -> torch.Tensor:
#         """
#         Args:
#             sequences: List of amino acid sequences
#         Returns:
#             [B, max_len, n_features] 标准化后的特征
#         """
#         batch_size = len(sequences)
#         max_len = max(len(seq) for seq in sequences)
        
#         # 收集特征
#         properties = []
#         for seq in sequences:
#             seq_props = []
#             for aa in seq:
#                 props = self._get_aa_features(aa)
#                 seq_props.append(props)
            
#             # Padding
#             while len(seq_props) < max_len:
#                 seq_props.append([0.0] * self.n_features)
            
#             properties.append(seq_props)
        
#         properties = np.array(properties)  # [B, L, n_features]
        
#         # 标准化(除了padding位置)
#         batch_size, seq_len, n_feat = properties.shape
#         properties_flat = properties.reshape(-1, n_feat)
        
#         # 标准化
#         properties_normalized = self.scaler.transform(properties_flat)
#         properties_normalized = properties_normalized.reshape(batch_size, seq_len, n_feat)
        
#         # 转为tensor
#         properties_tensor = torch.tensor(
#             properties_normalized, 
#             dtype=torch.float32, 
#             device=self.device
#         )
        
#         return properties_tensor  # [B, L, n_features]

import torch
import torch.nn as nn
import numpy as np
from sklearn.preprocessing import StandardScaler
from typing import List

class PhysicochemicalEncoder(nn.Module):
    """Encode amino-acid sequences as standardized physicochemical features.

    Vectorized AAindex encoder: every residue is mapped to a row of a
    precomputed feature table, and the gathered rows are z-scored with the
    mean/scale fitted on the 20 standard amino acids.

    Args:
        device: torch device on which the lookup table, the scaler
            statistics and the output tensor are created.
        use_aaindex: if True, load features from the generated
            ``aa_properties_aaindex`` module (falling back to the built-in
            5-feature table when that module cannot be imported); if False,
            use the built-in 5-feature table directly.
        selected_features: optional collection of AAindex feature names to
            keep. ``None`` keeps every feature.

    Notes:
        * ``aa_feature_table``, ``mean_tensor`` and ``scale_tensor`` are plain
          tensor attributes (not registered buffers), so they stay on
          ``device`` and do not appear in ``state_dict()``.
        * Padding positions use an all-zero raw feature row that is then
          standardized like any other row, so padded positions are generally
          NOT zero in the output. Callers must mask them by sequence length.
    """

    # Column names of the table returned by ``_get_basic_properties``.
    _BASIC_FEATURE_NAMES = ['hydro', 'charge', 'volume', 'flex', 'aroma']
    # Residues used to fit the scaler (the unknown symbol is excluded).
    _STANDARD_AAS = 'ARNDCQEGHILKMFPSTWYV'
    # Symbol whose features stand in for any residue missing from the table.
    _UNKNOWN_AA = 'X'

    def __init__(self, device, use_aaindex: bool = True,
                 selected_features: Optional[List[str]] = None):
        super().__init__()
        self.device = device
        self.use_aaindex = use_aaindex

        # Load the per-residue feature table.
        if use_aaindex:
            self.aa_properties, self.feature_names = self._load_aaindex_features(selected_features)
        else:
            self.aa_properties = self._get_basic_properties()
            self.feature_names = list(self._BASIC_FEATURE_NAMES)

        # The table may be list-based even when use_aaindex=True (import
        # fallback), so derive the width in a format-agnostic way instead of
        # calling ``.values()`` on it.
        self.n_features = len(self._as_feature_list(self.aa_properties['A']))
        if isinstance(self.aa_properties['A'], dict):
            print(f"✓ Loaded {self.n_features} AAindex features")
        else:
            print(f"✓ Using {self.n_features} basic features")

        # Fit the z-score scaler on the standard residues.
        self.scaler = self._fit_scaler()

        # Build the lookup table: one row per known residue, in sorted key
        # order for a stable index assignment, plus a trailing all-zero row
        # used for padding.
        aa_list = sorted(self.aa_properties)
        self.aa_to_idx = {aa: i for i, aa in enumerate(aa_list)}
        self.pad_idx = len(self.aa_to_idx)

        table_rows = [self._as_feature_list(self.aa_properties[aa]) for aa in aa_list]
        table_rows.append([0.0] * self.n_features)
        self.aa_feature_table = torch.tensor(
            np.array(table_rows),
            dtype=torch.float32,
            device=self.device,
        )  # [n_aa + 1, n_feat]

        # Keep the scaler statistics as tensors on the target device so that
        # forward() never has to round-trip through NumPy/scikit-learn.
        self.mean_tensor = torch.tensor(self.scaler.mean_, dtype=torch.float32, device=self.device)
        self.scale_tensor = torch.tensor(self.scaler.scale_, dtype=torch.float32, device=self.device)

    @staticmethod
    def _as_feature_list(entry):
        """Return one residue's features as a list.

        AAindex entries are dicts (feature name -> value); basic entries are
        already lists and are returned as-is.
        """
        if isinstance(entry, dict):
            return list(entry.values())
        return entry

    def _load_aaindex_features(self, selected_features=None):
        """Load the AAindex feature table.

        Returns:
            ``(aa_properties, feature_names)``. ``feature_names`` lists the
            columns actually present, in table order. If
            ``aa_properties_aaindex`` cannot be imported, the built-in
            5-feature table and its column names are returned instead.
        """
        # Only the import itself is guarded, so errors raised while filtering
        # are not mistaken for a missing module.
        try:
            from aa_properties_aaindex import AA_PROPERTIES_AAINDEX
        except ImportError:
            print("⚠ Warning: aa_properties_aaindex.py not found!")
            return self._get_basic_properties(), list(self._BASIC_FEATURE_NAMES)

        if selected_features is None:
            return AA_PROPERTIES_AAINDEX, list(AA_PROPERTIES_AAINDEX['A'].keys())

        filtered_props = {
            aa: {k: v for k, v in props.items() if k in selected_features}
            for aa, props in AA_PROPERTIES_AAINDEX.items()
        }
        # Report the columns that were really kept: requested names that do
        # not exist are dropped, and column order follows the AAindex table
        # rather than the order of ``selected_features``.
        return filtered_props, list(filtered_props['A'].keys())

    def _get_basic_properties(self):
        """Return the built-in 5-feature table.

        Columns follow ``_BASIC_FEATURE_NAMES``; ``'X'`` is the entry used
        for unknown residues.
        """
        return {
            'A': [1.8, 0.0, 88.6, 0.36, 0.0],
            'C': [2.5, 0.0, 108.5, 0.35, 0.0],
            'D': [-3.5, -1.0, 111.1, 0.51, 0.0],
            'E': [-3.5, -1.0, 138.4, 0.50, 0.0],
            'F': [2.8, 0.0, 189.9, 0.31, 1.0],
            'G': [-0.4, 0.0, 60.1, 0.54, 0.0],
            'H': [-3.2, 0.5, 153.2, 0.32, 0.5],
            'I': [4.5, 0.0, 166.7, 0.46, 0.0],
            'K': [-3.9, 1.0, 168.6, 0.47, 0.0],
            'L': [3.8, 0.0, 166.7, 0.37, 0.0],
            'M': [1.9, 0.0, 162.9, 0.30, 0.0],
            'N': [-3.5, 0.0, 114.1, 0.46, 0.0],
            'P': [-1.6, 0.0, 112.7, 0.51, 0.0],
            'Q': [-3.5, 0.0, 143.8, 0.49, 0.0],
            'R': [-4.5, 1.0, 173.4, 0.53, 0.0],
            'S': [-0.8, 0.0, 89.0, 0.51, 0.0],
            'T': [-0.7, 0.0, 116.1, 0.44, 0.0],
            'V': [4.2, 0.0, 140.0, 0.39, 0.0],
            'W': [-0.9, 0.0, 227.8, 0.31, 1.0],
            'Y': [-1.3, 0.0, 193.6, 0.42, 1.0],
            'X': [0.0, 0.0, 120.0, 0.40, 0.0],
        }

    def _fit_scaler(self):
        """Fit a ``StandardScaler`` on the features of the 20 standard residues."""
        rows = [
            self._as_feature_list(self.aa_properties[aa])
            for aa in self._STANDARD_AAS
        ]
        scaler = StandardScaler()
        scaler.fit(np.array(rows))
        return scaler

    def _get_aa_features(self, aa: str):
        """Return the raw (unscaled) features of one residue.

        The lookup is case-insensitive; residues missing from the table use
        the ``'X'`` entry.
        """
        aa = aa.upper()
        if aa not in self.aa_properties:
            aa = self._UNKNOWN_AA
        return self._as_feature_list(self.aa_properties[aa])

    def forward(self, sequences: List[str]) -> torch.Tensor:
        """Encode a batch of sequences.

        Args:
            sequences: amino-acid strings (case-insensitive). Residues that
                are not in the table are encoded with the ``'X'`` row, the
                same rule as ``_get_aa_features``; only if the table has no
                ``'X'`` entry do they fall back to the padding row.

        Returns:
            Float32 tensor of shape ``[B, max_len, n_features]`` on
            ``self.device``. An empty batch yields shape
            ``[0, 0, n_features]``. Positions past a sequence's length hold
            the standardized padding row (see the class notes).
        """
        batch_size = len(sequences)
        # default=0 keeps an empty batch from raising in max().
        max_len = max((len(seq) for seq in sequences), default=0)

        # Resolved per call (one dict lookup) rather than cached in __init__
        # so forward() depends only on attributes the class has always had.
        unk_idx = self.aa_to_idx.get(self._UNKNOWN_AA, self.pad_idx)

        # 1) Encode residues as table indices, right-padded with pad_idx.
        idx_batch = np.full((batch_size, max_len), self.pad_idx, dtype=np.int64)
        for i, seq in enumerate(sequences):
            if seq:
                # upper() is applied per character so the index count always
                # equals len(seq).
                idx_batch[i, :len(seq)] = [
                    self.aa_to_idx.get(aa.upper(), unk_idx) for aa in seq
                ]

        idx_tensor = torch.tensor(idx_batch, dtype=torch.long, device=self.device)  # [B, L]

        # 2) Gather raw features, then z-score them with the fitted statistics.
        props = self.aa_feature_table[idx_tensor]  # [B, L, n_feat]
        props = (props - self.mean_tensor) / self.scale_tensor

        return props