Spaces:
Sleeping
Sleeping
File size: 11,903 Bytes
78f28d5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 |
import torch
import torch.nn as nn
from typing import List, Optional
import numpy as np
from sklearn.preprocessing import StandardScaler
# class PhysicochemicalEncoder(nn.Module):
# """Amino Acid Physicochemical Property Encoder (AAindex版本)"""
# def __init__(self, device, use_aaindex=True, selected_features=None):
# """
# Args:
# device: torch device
# use_aaindex: 是否使用AAindex特征(True)还是简单的5特征(False)
# selected_features: 选择使用哪些AAindex特征(None=使用全部)
# """
# super().__init__()
# self.device = device
# self.use_aaindex = use_aaindex
# if use_aaindex:
# # 从AAindex加载特征
# self.aa_properties, self.feature_names = self._load_aaindex_features(selected_features)
# self.n_features = len(list(self.aa_properties['A'].values()))
# print(f"✓ Loaded {self.n_features} AAindex features")
# else:
# # 使用简单的5特征
# self.aa_properties = self._get_basic_properties()
# self.n_features = 5
# print(f"✓ Using {self.n_features} basic features")
# # 标准化(重要!不同特征范围差异大)
# self.scaler = self._fit_scaler()
# def _load_aaindex_features(self, selected_features=None):
# """从AAindex加载特征"""
# try:
# # 尝试导入生成的文件
# from aa_properties_aaindex import AA_PROPERTIES_AAINDEX, FEATURE_DESCRIPTIONS
# if selected_features is not None:
# # 只选择指定的特征
# filtered_props = {}
# for aa, props in AA_PROPERTIES_AAINDEX.items():
# filtered_props[aa] = {k: v for k, v in props.items()
# if k in selected_features}
# return filtered_props, selected_features
# else:
# # 使用所有特征
# feature_names = list(AA_PROPERTIES_AAINDEX['A'].keys())
# return AA_PROPERTIES_AAINDEX, feature_names
# except ImportError:
# print("⚠ Warning: aa_properties_aaindex.py not found!")
# print(" Falling back to basic 5 features")
# print(" Run 'python aaindex_downloader.py' to download AAindex features")
# return self._get_basic_properties(), ['hydro', 'charge', 'volume', 'flex', 'aroma']
# def _get_basic_properties(self):
# """基础的5特征(作为fallback)"""
# return {
# 'A': [1.8, 0.0, 88.6, 0.36, 0.0],
# 'C': [2.5, 0.0, 108.5, 0.35, 0.0],
# 'D': [-3.5, -1.0, 111.1, 0.51, 0.0],
# 'E': [-3.5, -1.0, 138.4, 0.50, 0.0],
# 'F': [2.8, 0.0, 189.9, 0.31, 1.0],
# 'G': [-0.4, 0.0, 60.1, 0.54, 0.0],
# 'H': [-3.2, 0.5, 153.2, 0.32, 0.5],
# 'I': [4.5, 0.0, 166.7, 0.46, 0.0],
# 'K': [-3.9, 1.0, 168.6, 0.47, 0.0],
# 'L': [3.8, 0.0, 166.7, 0.37, 0.0],
# 'M': [1.9, 0.0, 162.9, 0.30, 0.0],
# 'N': [-3.5, 0.0, 114.1, 0.46, 0.0],
# 'P': [-1.6, 0.0, 112.7, 0.51, 0.0],
# 'Q': [-3.5, 0.0, 143.8, 0.49, 0.0],
# 'R': [-4.5, 1.0, 173.4, 0.53, 0.0],
# 'S': [-0.8, 0.0, 89.0, 0.51, 0.0],
# 'T': [-0.7, 0.0, 116.1, 0.44, 0.0],
# 'V': [4.2, 0.0, 140.0, 0.39, 0.0],
# 'W': [-0.9, 0.0, 227.8, 0.31, 1.0],
# 'Y': [-1.3, 0.0, 193.6, 0.42, 1.0],
# 'X': [0.0, 0.0, 120.0, 0.40, 0.0],
# }
# def _fit_scaler(self):
# """拟合标准化器"""
# # 收集所有氨基酸的特征
# all_features = []
# for aa in 'ARNDCQEGHILKMFPSTWYV': # 20种标准氨基酸
# if isinstance(self.aa_properties[aa], dict):
# # AAindex格式
# features = list(self.aa_properties[aa].values())
# else:
# # 列表格式
# features = self.aa_properties[aa]
# all_features.append(features)
# all_features = np.array(all_features)
# # Z-score标准化
# scaler = StandardScaler()
# scaler.fit(all_features)
# return scaler
# def _get_aa_features(self, aa: str) -> List[float]:
# """获取单个氨基酸的特征"""
# aa = aa.upper()
# if aa not in self.aa_properties:
# aa = 'X' # Unknown
# if isinstance(self.aa_properties[aa], dict):
# # AAindex格式:字典
# features = list(self.aa_properties[aa].values())
# else:
# # 基础格式:列表
# features = self.aa_properties[aa]
# return features
# def forward(self, sequences: List[str]) -> torch.Tensor:
# """
# Args:
# sequences: List of amino acid sequences
# Returns:
# [B, max_len, n_features] 标准化后的特征
# """
# batch_size = len(sequences)
# max_len = max(len(seq) for seq in sequences)
# # 收集特征
# properties = []
# for seq in sequences:
# seq_props = []
# for aa in seq:
# props = self._get_aa_features(aa)
# seq_props.append(props)
# # Padding
# while len(seq_props) < max_len:
# seq_props.append([0.0] * self.n_features)
# properties.append(seq_props)
# properties = np.array(properties) # [B, L, n_features]
# # 标准化(除了padding位置)
# batch_size, seq_len, n_feat = properties.shape
# properties_flat = properties.reshape(-1, n_feat)
# # 标准化
# properties_normalized = self.scaler.transform(properties_flat)
# properties_normalized = properties_normalized.reshape(batch_size, seq_len, n_feat)
# # 转为tensor
# properties_tensor = torch.tensor(
# properties_normalized,
# dtype=torch.float32,
# device=self.device
# )
# return properties_tensor # [B, L, n_features]
import torch
import torch.nn as nn
import numpy as np
from sklearn.preprocessing import StandardScaler
from typing import List
class PhysicochemicalEncoder(nn.Module):
    """Amino acid physicochemical property encoder (AAindex version, vectorized).

    Builds a per-residue feature lookup table and z-score statistics once in
    ``__init__`` so that ``forward`` reduces to a single index lookup plus an
    affine transform — no per-residue Python work on the hot path.
    """

    def __init__(self, device, use_aaindex=True, selected_features=None):
        """
        Args:
            device: torch device for the lookup table and outputs.
            use_aaindex: if True, load AAindex features (falls back to the
                basic 5-feature set when the generated module is missing);
                if False, use the basic 5 features directly.
            selected_features: optional list of AAindex feature names to keep
                (None = use all available features).
        """
        super().__init__()
        self.device = device
        self.use_aaindex = use_aaindex
        # Load per-residue properties: dict of dicts (AAindex) or
        # dict of lists (basic fallback).
        if use_aaindex:
            self.aa_properties, self.feature_names = self._load_aaindex_features(selected_features)
            # len() counts dict keys and list items alike; the previous
            # `.values()` call crashed when the AAindex import fell back
            # to the list-based basic properties.
            self.n_features = len(self.aa_properties['A'])
            print(f"✓ Loaded {self.n_features} AAindex features")
        else:
            self.aa_properties = self._get_basic_properties()
            self.n_features = 5
            print(f"✓ Using {self.n_features} basic features")
        # Fit z-score statistics over the 20 standard amino acids.
        self.scaler = self._fit_scaler()
        # ---------- precomputed lookup structures (hot path) ---------- #
        # Stable residue ordering -> row index in the feature table.
        aa_list = sorted(self.aa_properties.keys())
        self.aa_to_idx = {aa: i for i, aa in enumerate(aa_list)}
        self.pad_idx = len(self.aa_to_idx)  # extra all-zero row for padding
        # BUGFIX: unknown residues must map to the 'X' row (matching
        # _get_aa_features), not to the padding row.
        self.unknown_idx = self.aa_to_idx.get('X', self.pad_idx)
        table = [self._get_aa_features(aa) for aa in aa_list]
        table.append([0.0] * self.n_features)  # padding vector
        # Non-persistent buffers: they follow .to()/.cuda() with the module
        # but keep the state_dict layout identical to the previous version.
        self.register_buffer(
            "aa_feature_table",
            torch.tensor(np.array(table), dtype=torch.float32, device=self.device),
            persistent=False,
        )  # [n_aa + 1, n_feat]
        self.register_buffer(
            "mean_tensor",
            torch.tensor(self.scaler.mean_, dtype=torch.float32, device=self.device),
            persistent=False,
        )
        self.register_buffer(
            "scale_tensor",
            torch.tensor(self.scaler.scale_, dtype=torch.float32, device=self.device),
            persistent=False,
        )

    def _load_aaindex_features(self, selected_features=None):
        """Load AAindex features; fall back to the basic 5 features on ImportError.

        Returns:
            (properties, feature_names): properties maps residue -> feature
            dict (AAindex) or feature list (fallback).
        """
        try:
            from aa_properties_aaindex import AA_PROPERTIES_AAINDEX, FEATURE_DESCRIPTIONS
            if selected_features is not None:
                # Keep only the requested feature columns.
                filtered_props = {
                    aa: {k: v for k, v in props.items() if k in selected_features}
                    for aa, props in AA_PROPERTIES_AAINDEX.items()
                }
                return filtered_props, selected_features
            feature_names = list(AA_PROPERTIES_AAINDEX['A'].keys())
            return AA_PROPERTIES_AAINDEX, feature_names
        except ImportError:
            print("⚠ Warning: aa_properties_aaindex.py not found!")
            return self._get_basic_properties(), ['hydro', 'charge', 'volume', 'flex', 'aroma']

    def _get_basic_properties(self):
        """Basic 5-feature table per residue.

        Columns (per the fallback feature names): hydropathy, charge,
        volume, flexibility, aromaticity.  'X' is the unknown-residue row.
        """
        return {
            'A': [1.8, 0.0, 88.6, 0.36, 0.0],
            'C': [2.5, 0.0, 108.5, 0.35, 0.0],
            'D': [-3.5, -1.0, 111.1, 0.51, 0.0],
            'E': [-3.5, -1.0, 138.4, 0.50, 0.0],
            'F': [2.8, 0.0, 189.9, 0.31, 1.0],
            'G': [-0.4, 0.0, 60.1, 0.54, 0.0],
            'H': [-3.2, 0.5, 153.2, 0.32, 0.5],
            'I': [4.5, 0.0, 166.7, 0.46, 0.0],
            'K': [-3.9, 1.0, 168.6, 0.47, 0.0],
            'L': [3.8, 0.0, 166.7, 0.37, 0.0],
            'M': [1.9, 0.0, 162.9, 0.30, 0.0],
            'N': [-3.5, 0.0, 114.1, 0.46, 0.0],
            'P': [-1.6, 0.0, 112.7, 0.51, 0.0],
            'Q': [-3.5, 0.0, 143.8, 0.49, 0.0],
            'R': [-4.5, 1.0, 173.4, 0.53, 0.0],
            'S': [-0.8, 0.0, 89.0, 0.51, 0.0],
            'T': [-0.7, 0.0, 116.1, 0.44, 0.0],
            'V': [4.2, 0.0, 140.0, 0.39, 0.0],
            'W': [-0.9, 0.0, 227.8, 0.31, 1.0],
            'Y': [-1.3, 0.0, 193.6, 0.42, 1.0],
            'X': [0.0, 0.0, 120.0, 0.40, 0.0],
        }

    def _fit_scaler(self):
        """Compute z-score statistics over the 20 standard amino acids.

        Replaces sklearn's StandardScaler with the equivalent numpy
        computation (population std, zero-variance features clamped to
        scale 1.0) so the class needs no sklearn dependency.  The returned
        object exposes the same ``mean_`` / ``scale_`` attributes that the
        rest of the class reads.
        """
        from types import SimpleNamespace
        # _get_aa_features already handles both dict and list formats.
        feats = np.asarray(
            [self._get_aa_features(aa) for aa in 'ARNDCQEGHILKMFPSTWYV'],
            dtype=np.float64,
        )
        mean = feats.mean(axis=0)
        scale = feats.std(axis=0)   # ddof=0, matching StandardScaler.scale_
        scale[scale == 0.0] = 1.0   # guard against division by zero
        return SimpleNamespace(mean_=mean, scale_=scale)

    def _get_aa_features(self, aa: str) -> List[float]:
        """Return the raw feature vector for one residue ('X' for unknowns)."""
        aa = aa.upper()
        if aa not in self.aa_properties:
            aa = 'X'  # unknown residue
        props = self.aa_properties[aa]
        # AAindex format is a dict; basic format is a plain list.
        return list(props.values()) if isinstance(props, dict) else props

    def forward(self, sequences: List[str]) -> torch.Tensor:
        """Encode a batch of amino-acid sequences.

        Args:
            sequences: list of residue strings (variable length).
        Returns:
            [B, max_len, n_features] float32 tensor of z-scored features.
            NOTE: padding positions are also normalized (they hold the
            normalized zero vector), matching the previous behaviour.
        """
        if not sequences:
            # Degenerate empty batch instead of max() raising ValueError.
            return torch.empty(0, 0, self.n_features, device=self.aa_feature_table.device)
        batch_size = len(sequences)
        max_len = max(len(seq) for seq in sequences)
        # 1) residue -> table index, right-padded with pad_idx.
        idx_batch = np.full((batch_size, max_len), self.pad_idx, dtype=np.int64)
        for i, seq in enumerate(sequences):
            # BUGFIX: unknown residues fall back to the 'X' row (not padding).
            row = [self.aa_to_idx.get(aa.upper(), self.unknown_idx) for aa in seq]
            idx_batch[i, :len(row)] = row
        idx_tensor = torch.from_numpy(idx_batch).to(self.aa_feature_table.device)  # [B, L]
        # 2) vectorized lookup + z-score normalization.
        props = self.aa_feature_table[idx_tensor]  # [B, L, n_feat]
        return (props - self.mean_tensor) / self.scale_tensor
|