Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from typing import List, Optional | |
| import numpy as np | |
| from sklearn.preprocessing import StandardScaler | |
| # class PhysicochemicalEncoder(nn.Module): | |
| # """Amino Acid Physicochemical Property Encoder (AAindex版本)""" | |
| # def __init__(self, device, use_aaindex=True, selected_features=None): | |
| # """ | |
| # Args: | |
| # device: torch device | |
| # use_aaindex: 是否使用AAindex特征(True)还是简单的5特征(False) | |
| # selected_features: 选择使用哪些AAindex特征(None=使用全部) | |
| # """ | |
| # super().__init__() | |
| # self.device = device | |
| # self.use_aaindex = use_aaindex | |
| # if use_aaindex: | |
| # # 从AAindex加载特征 | |
| # self.aa_properties, self.feature_names = self._load_aaindex_features(selected_features) | |
| # self.n_features = len(list(self.aa_properties['A'].values())) | |
| # print(f"✓ Loaded {self.n_features} AAindex features") | |
| # else: | |
| # # 使用简单的5特征 | |
| # self.aa_properties = self._get_basic_properties() | |
| # self.n_features = 5 | |
| # print(f"✓ Using {self.n_features} basic features") | |
| # # 标准化(重要!不同特征范围差异大) | |
| # self.scaler = self._fit_scaler() | |
| # def _load_aaindex_features(self, selected_features=None): | |
| # """从AAindex加载特征""" | |
| # try: | |
| # # 尝试导入生成的文件 | |
| # from aa_properties_aaindex import AA_PROPERTIES_AAINDEX, FEATURE_DESCRIPTIONS | |
| # if selected_features is not None: | |
| # # 只选择指定的特征 | |
| # filtered_props = {} | |
| # for aa, props in AA_PROPERTIES_AAINDEX.items(): | |
| # filtered_props[aa] = {k: v for k, v in props.items() | |
| # if k in selected_features} | |
| # return filtered_props, selected_features | |
| # else: | |
| # # 使用所有特征 | |
| # feature_names = list(AA_PROPERTIES_AAINDEX['A'].keys()) | |
| # return AA_PROPERTIES_AAINDEX, feature_names | |
| # except ImportError: | |
| # print("⚠ Warning: aa_properties_aaindex.py not found!") | |
| # print(" Falling back to basic 5 features") | |
| # print(" Run 'python aaindex_downloader.py' to download AAindex features") | |
| # return self._get_basic_properties(), ['hydro', 'charge', 'volume', 'flex', 'aroma'] | |
| # def _get_basic_properties(self): | |
| # """基础的5特征(作为fallback)""" | |
| # return { | |
| # 'A': [1.8, 0.0, 88.6, 0.36, 0.0], | |
| # 'C': [2.5, 0.0, 108.5, 0.35, 0.0], | |
| # 'D': [-3.5, -1.0, 111.1, 0.51, 0.0], | |
| # 'E': [-3.5, -1.0, 138.4, 0.50, 0.0], | |
| # 'F': [2.8, 0.0, 189.9, 0.31, 1.0], | |
| # 'G': [-0.4, 0.0, 60.1, 0.54, 0.0], | |
| # 'H': [-3.2, 0.5, 153.2, 0.32, 0.5], | |
| # 'I': [4.5, 0.0, 166.7, 0.46, 0.0], | |
| # 'K': [-3.9, 1.0, 168.6, 0.47, 0.0], | |
| # 'L': [3.8, 0.0, 166.7, 0.37, 0.0], | |
| # 'M': [1.9, 0.0, 162.9, 0.30, 0.0], | |
| # 'N': [-3.5, 0.0, 114.1, 0.46, 0.0], | |
| # 'P': [-1.6, 0.0, 112.7, 0.51, 0.0], | |
| # 'Q': [-3.5, 0.0, 143.8, 0.49, 0.0], | |
| # 'R': [-4.5, 1.0, 173.4, 0.53, 0.0], | |
| # 'S': [-0.8, 0.0, 89.0, 0.51, 0.0], | |
| # 'T': [-0.7, 0.0, 116.1, 0.44, 0.0], | |
| # 'V': [4.2, 0.0, 140.0, 0.39, 0.0], | |
| # 'W': [-0.9, 0.0, 227.8, 0.31, 1.0], | |
| # 'Y': [-1.3, 0.0, 193.6, 0.42, 1.0], | |
| # 'X': [0.0, 0.0, 120.0, 0.40, 0.0], | |
| # } | |
| # def _fit_scaler(self): | |
| # """拟合标准化器""" | |
| # # 收集所有氨基酸的特征 | |
| # all_features = [] | |
| # for aa in 'ARNDCQEGHILKMFPSTWYV': # 20种标准氨基酸 | |
| # if isinstance(self.aa_properties[aa], dict): | |
| # # AAindex格式 | |
| # features = list(self.aa_properties[aa].values()) | |
| # else: | |
| # # 列表格式 | |
| # features = self.aa_properties[aa] | |
| # all_features.append(features) | |
| # all_features = np.array(all_features) | |
| # # Z-score标准化 | |
| # scaler = StandardScaler() | |
| # scaler.fit(all_features) | |
| # return scaler | |
| # def _get_aa_features(self, aa: str) -> List[float]: | |
| # """获取单个氨基酸的特征""" | |
| # aa = aa.upper() | |
| # if aa not in self.aa_properties: | |
| # aa = 'X' # Unknown | |
| # if isinstance(self.aa_properties[aa], dict): | |
| # # AAindex格式:字典 | |
| # features = list(self.aa_properties[aa].values()) | |
| # else: | |
| # # 基础格式:列表 | |
| # features = self.aa_properties[aa] | |
| # return features | |
| # def forward(self, sequences: List[str]) -> torch.Tensor: | |
| # """ | |
| # Args: | |
| # sequences: List of amino acid sequences | |
| # Returns: | |
| # [B, max_len, n_features] 标准化后的特征 | |
| # """ | |
| # batch_size = len(sequences) | |
| # max_len = max(len(seq) for seq in sequences) | |
| # # 收集特征 | |
| # properties = [] | |
| # for seq in sequences: | |
| # seq_props = [] | |
| # for aa in seq: | |
| # props = self._get_aa_features(aa) | |
| # seq_props.append(props) | |
| # # Padding | |
| # while len(seq_props) < max_len: | |
| # seq_props.append([0.0] * self.n_features) | |
| # properties.append(seq_props) | |
| # properties = np.array(properties) # [B, L, n_features] | |
| # # 标准化(除了padding位置) | |
| # batch_size, seq_len, n_feat = properties.shape | |
| # properties_flat = properties.reshape(-1, n_feat) | |
| # # 标准化 | |
| # properties_normalized = self.scaler.transform(properties_flat) | |
| # properties_normalized = properties_normalized.reshape(batch_size, seq_len, n_feat) | |
| # # 转为tensor | |
| # properties_tensor = torch.tensor( | |
| # properties_normalized, | |
| # dtype=torch.float32, | |
| # device=self.device | |
| # ) | |
| # return properties_tensor # [B, L, n_features] | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from sklearn.preprocessing import StandardScaler | |
| from typing import List | |
class PhysicochemicalEncoder(nn.Module):
    """Amino-acid physicochemical property encoder backed by a lookup table.

    Each residue of every input sequence is replaced by a z-score-normalised
    feature vector.  Features come either from an optional
    ``aa_properties_aaindex`` module (AAindex features) or from a built-in
    table of five basic properties.

    The per-residue features are stored once as a tensor in ``__init__`` so
    that ``forward`` is a single indexed lookup followed by normalisation.

    Attributes set by ``__init__``:
        aa_properties: mapping residue letter -> features (dict or list).
        feature_names: names of the feature columns.
        n_features: number of feature columns.
        scaler: ``StandardScaler`` fitted on the 20 standard residues.
        aa_to_idx: mapping residue letter -> row of ``aa_feature_table``.
        pad_idx: row index of the all-zero padding vector (last row).
        aa_feature_table: raw (un-normalised) features, one row per residue
            plus the padding row.
        mean_tensor, scale_tensor: the scaler's ``mean_`` / ``scale_`` as
            float32 tensors on ``device``.
    """

    def __init__(self, device, use_aaindex=True, selected_features=None):
        """Load the property table, fit the scaler and build lookup tensors.

        Args:
            device: torch device on which the lookup tensors are created and
                on which ``forward`` builds its output.
            use_aaindex: if True, try to load AAindex features; otherwise use
                the built-in five basic features.
            selected_features: optional collection of AAindex feature names
                to keep (``None`` keeps all of them).  Ignored when
                ``use_aaindex`` is False.
        """
        super().__init__()
        self.device = device
        self.use_aaindex = use_aaindex

        if use_aaindex:
            self.aa_properties, self.feature_names = self._load_aaindex_features(selected_features)
            # The loader falls back to list-valued basic properties when the
            # AAindex module is missing, so the width must be read through
            # _get_aa_features, which handles both dict and list entries.
            self.n_features = len(self._get_aa_features('A'))
            print(f"✓ Loaded {self.n_features} AAindex features")
        else:
            self.aa_properties = self._get_basic_properties()
            self.feature_names = ['hydro', 'charge', 'volume', 'flex', 'aroma']
            self.n_features = 5
            print(f"✓ Using {self.n_features} basic features")

        self.scaler = self._fit_scaler()

        # Lookup table: one row per known residue, in sorted key order so the
        # index assignment is stable, followed by an all-zero padding row.
        aa_list = sorted(self.aa_properties.keys())
        self.aa_to_idx = {aa: i for i, aa in enumerate(aa_list)}
        self.pad_idx = len(self.aa_to_idx)

        table_rows = [self._get_aa_features(aa) for aa in aa_list]
        table_rows.append([0.0] * self.n_features)
        self.aa_feature_table = torch.tensor(
            np.array(table_rows),
            dtype=torch.float32,
        ).to(self.device)  # [n_aa + 1, n_feat]

        # Keep the normalisation parameters as tensors so forward() does not
        # have to go through the scaler object.
        self.mean_tensor = torch.tensor(self.scaler.mean_, dtype=torch.float32, device=self.device)
        self.scale_tensor = torch.tensor(self.scaler.scale_, dtype=torch.float32, device=self.device)

    def _load_aaindex_features(self, selected_features=None):
        """Return ``(properties, feature_names)`` from ``aa_properties_aaindex``.

        If the module cannot be imported, a warning is printed and the basic
        five-feature table (list-valued) is returned instead.

        When ``selected_features`` is given, each residue's dict is filtered
        to those names.  Column order follows the AAindex dict, and the
        returned names are taken from the filtered table so they always match
        the columns actually present.
        """
        try:
            from aa_properties_aaindex import AA_PROPERTIES_AAINDEX, FEATURE_DESCRIPTIONS  # noqa: F401
        except ImportError:
            print("⚠ Warning: aa_properties_aaindex.py not found!")
            return self._get_basic_properties(), ['hydro', 'charge', 'volume', 'flex', 'aroma']

        if selected_features is None:
            return AA_PROPERTIES_AAINDEX, list(AA_PROPERTIES_AAINDEX['A'].keys())

        filtered_props = {
            aa: {k: v for k, v in props.items() if k in selected_features}
            for aa, props in AA_PROPERTIES_AAINDEX.items()
        }
        # Derive the names from what was kept: requested names that do not
        # exist are dropped, and the order is the table's, not the caller's.
        return filtered_props, list(filtered_props['A'].keys())

    def _get_basic_properties(self):
        """Return the built-in table of five features per residue.

        Column order matches the names used elsewhere in this class:
        ``['hydro', 'charge', 'volume', 'flex', 'aroma']``.  ``'X'`` is the
        entry used for unknown residues.
        """
        return {
            'A': [1.8, 0.0, 88.6, 0.36, 0.0],
            'C': [2.5, 0.0, 108.5, 0.35, 0.0],
            'D': [-3.5, -1.0, 111.1, 0.51, 0.0],
            'E': [-3.5, -1.0, 138.4, 0.50, 0.0],
            'F': [2.8, 0.0, 189.9, 0.31, 1.0],
            'G': [-0.4, 0.0, 60.1, 0.54, 0.0],
            'H': [-3.2, 0.5, 153.2, 0.32, 0.5],
            'I': [4.5, 0.0, 166.7, 0.46, 0.0],
            'K': [-3.9, 1.0, 168.6, 0.47, 0.0],
            'L': [3.8, 0.0, 166.7, 0.37, 0.0],
            'M': [1.9, 0.0, 162.9, 0.30, 0.0],
            'N': [-3.5, 0.0, 114.1, 0.46, 0.0],
            'P': [-1.6, 0.0, 112.7, 0.51, 0.0],
            'Q': [-3.5, 0.0, 143.8, 0.49, 0.0],
            'R': [-4.5, 1.0, 173.4, 0.53, 0.0],
            'S': [-0.8, 0.0, 89.0, 0.51, 0.0],
            'T': [-0.7, 0.0, 116.1, 0.44, 0.0],
            'V': [4.2, 0.0, 140.0, 0.39, 0.0],
            'W': [-0.9, 0.0, 227.8, 0.31, 1.0],
            'Y': [-1.3, 0.0, 193.6, 0.42, 1.0],
            'X': [0.0, 0.0, 120.0, 0.40, 0.0],
        }

    def _fit_scaler(self):
        """Fit a ``StandardScaler`` on the 20 standard residues' features.

        ``'X'`` and any other non-standard entries are deliberately left out
        of the statistics.  Raises ``KeyError`` if a standard residue is
        missing from ``aa_properties``.
        """
        all_features = []
        for aa in 'ARNDCQEGHILKMFPSTWYV':
            props = self.aa_properties[aa]
            all_features.append(list(props.values()) if isinstance(props, dict) else props)
        scaler = StandardScaler()
        scaler.fit(np.array(all_features))
        return scaler

    def _get_aa_features(self, aa: str):
        """Return the raw feature list for one residue letter.

        The letter is upper-cased first; anything not present in
        ``aa_properties`` is looked up as ``'X'``.
        """
        aa = aa.upper()
        if aa not in self.aa_properties:
            aa = 'X'
        props = self.aa_properties[aa]
        if isinstance(props, dict):
            return list(props.values())
        return props

    def forward(self, sequences: List[str]) -> torch.Tensor:
        """Encode a batch of sequences as normalised feature vectors.

        Args:
            sequences: list of amino-acid strings (case-insensitive).

        Returns:
            Float tensor of shape ``[B, max_len, n_features]`` on
            ``self.device``.  Shorter sequences are right-padded with the
            all-zero padding row; normalisation is applied to every position,
            so padded positions hold ``-mean / scale`` rather than zeros.
            An empty batch yields a tensor of shape ``[0, 0, n_features]``.
        """
        batch_size = len(sequences)
        # default=0 keeps an empty batch from raising in max().
        max_len = max((len(seq) for seq in sequences), default=0)

        # Unknown residues share the 'X' row, consistent with
        # _get_aa_features; only if the table has no 'X' entry do they fall
        # back to the padding row.
        unknown_idx = self.aa_to_idx.get('X', self.pad_idx)

        idx_batch = np.full((batch_size, max_len), self.pad_idx, dtype=np.int64)
        for i, seq in enumerate(sequences):
            if not seq:
                continue
            idx_batch[i, :len(seq)] = [
                self.aa_to_idx.get(aa.upper(), unknown_idx) for aa in seq
            ]
        idx_tensor = torch.tensor(idx_batch, dtype=torch.long, device=self.device)  # [B, L]

        props = self.aa_feature_table[idx_tensor]  # [B, L, n_feat]
        return (props - self.mean_tensor) / self.scale_tensor