# StriMap / src/physicochemical.py
# cao
# Add model and predictor files
# 78f28d5
import torch
import torch.nn as nn
from typing import List, Optional
import numpy as np
from sklearn.preprocessing import StandardScaler
# class PhysicochemicalEncoder(nn.Module):
# """Amino Acid Physicochemical Property Encoder (AAindex版本)"""
# def __init__(self, device, use_aaindex=True, selected_features=None):
# """
# Args:
# device: torch device
# use_aaindex: 是否使用AAindex特征(True)还是简单的5特征(False)
# selected_features: 选择使用哪些AAindex特征(None=使用全部)
# """
# super().__init__()
# self.device = device
# self.use_aaindex = use_aaindex
# if use_aaindex:
# # 从AAindex加载特征
# self.aa_properties, self.feature_names = self._load_aaindex_features(selected_features)
# self.n_features = len(list(self.aa_properties['A'].values()))
# print(f"✓ Loaded {self.n_features} AAindex features")
# else:
# # 使用简单的5特征
# self.aa_properties = self._get_basic_properties()
# self.n_features = 5
# print(f"✓ Using {self.n_features} basic features")
# # 标准化(重要!不同特征范围差异大)
# self.scaler = self._fit_scaler()
# def _load_aaindex_features(self, selected_features=None):
# """从AAindex加载特征"""
# try:
# # 尝试导入生成的文件
# from aa_properties_aaindex import AA_PROPERTIES_AAINDEX, FEATURE_DESCRIPTIONS
# if selected_features is not None:
# # 只选择指定的特征
# filtered_props = {}
# for aa, props in AA_PROPERTIES_AAINDEX.items():
# filtered_props[aa] = {k: v for k, v in props.items()
# if k in selected_features}
# return filtered_props, selected_features
# else:
# # 使用所有特征
# feature_names = list(AA_PROPERTIES_AAINDEX['A'].keys())
# return AA_PROPERTIES_AAINDEX, feature_names
# except ImportError:
# print("⚠ Warning: aa_properties_aaindex.py not found!")
# print(" Falling back to basic 5 features")
# print(" Run 'python aaindex_downloader.py' to download AAindex features")
# return self._get_basic_properties(), ['hydro', 'charge', 'volume', 'flex', 'aroma']
# def _get_basic_properties(self):
# """基础的5特征(作为fallback)"""
# return {
# 'A': [1.8, 0.0, 88.6, 0.36, 0.0],
# 'C': [2.5, 0.0, 108.5, 0.35, 0.0],
# 'D': [-3.5, -1.0, 111.1, 0.51, 0.0],
# 'E': [-3.5, -1.0, 138.4, 0.50, 0.0],
# 'F': [2.8, 0.0, 189.9, 0.31, 1.0],
# 'G': [-0.4, 0.0, 60.1, 0.54, 0.0],
# 'H': [-3.2, 0.5, 153.2, 0.32, 0.5],
# 'I': [4.5, 0.0, 166.7, 0.46, 0.0],
# 'K': [-3.9, 1.0, 168.6, 0.47, 0.0],
# 'L': [3.8, 0.0, 166.7, 0.37, 0.0],
# 'M': [1.9, 0.0, 162.9, 0.30, 0.0],
# 'N': [-3.5, 0.0, 114.1, 0.46, 0.0],
# 'P': [-1.6, 0.0, 112.7, 0.51, 0.0],
# 'Q': [-3.5, 0.0, 143.8, 0.49, 0.0],
# 'R': [-4.5, 1.0, 173.4, 0.53, 0.0],
# 'S': [-0.8, 0.0, 89.0, 0.51, 0.0],
# 'T': [-0.7, 0.0, 116.1, 0.44, 0.0],
# 'V': [4.2, 0.0, 140.0, 0.39, 0.0],
# 'W': [-0.9, 0.0, 227.8, 0.31, 1.0],
# 'Y': [-1.3, 0.0, 193.6, 0.42, 1.0],
# 'X': [0.0, 0.0, 120.0, 0.40, 0.0],
# }
# def _fit_scaler(self):
# """拟合标准化器"""
# # 收集所有氨基酸的特征
# all_features = []
# for aa in 'ARNDCQEGHILKMFPSTWYV': # 20种标准氨基酸
# if isinstance(self.aa_properties[aa], dict):
# # AAindex格式
# features = list(self.aa_properties[aa].values())
# else:
# # 列表格式
# features = self.aa_properties[aa]
# all_features.append(features)
# all_features = np.array(all_features)
# # Z-score标准化
# scaler = StandardScaler()
# scaler.fit(all_features)
# return scaler
# def _get_aa_features(self, aa: str) -> List[float]:
# """获取单个氨基酸的特征"""
# aa = aa.upper()
# if aa not in self.aa_properties:
# aa = 'X' # Unknown
# if isinstance(self.aa_properties[aa], dict):
# # AAindex格式:字典
# features = list(self.aa_properties[aa].values())
# else:
# # 基础格式:列表
# features = self.aa_properties[aa]
# return features
# def forward(self, sequences: List[str]) -> torch.Tensor:
# """
# Args:
# sequences: List of amino acid sequences
# Returns:
# [B, max_len, n_features] 标准化后的特征
# """
# batch_size = len(sequences)
# max_len = max(len(seq) for seq in sequences)
# # 收集特征
# properties = []
# for seq in sequences:
# seq_props = []
# for aa in seq:
# props = self._get_aa_features(aa)
# seq_props.append(props)
# # Padding
# while len(seq_props) < max_len:
# seq_props.append([0.0] * self.n_features)
# properties.append(seq_props)
# properties = np.array(properties) # [B, L, n_features]
# # 标准化(除了padding位置)
# batch_size, seq_len, n_feat = properties.shape
# properties_flat = properties.reshape(-1, n_feat)
# # 标准化
# properties_normalized = self.scaler.transform(properties_flat)
# properties_normalized = properties_normalized.reshape(batch_size, seq_len, n_feat)
# # 转为tensor
# properties_tensor = torch.tensor(
# properties_normalized,
# dtype=torch.float32,
# device=self.device
# )
# return properties_tensor # [B, L, n_features]
import torch
import torch.nn as nn
import numpy as np
from sklearn.preprocessing import StandardScaler
from typing import List
class PhysicochemicalEncoder(nn.Module):
    """Amino-acid physicochemical property encoder (AAindex version, vectorized).

    Each residue of every input sequence is mapped to a vector of physicochemical
    properties. The vector is z-score normalized with statistics fitted on the 20
    standard amino acids, and the batch is returned as a padded float32 tensor of
    shape ``[B, max_len, n_features]``.

    Args:
        device: torch device on which the lookup tables are initially created.
        use_aaindex: if True, load features from the generated
            ``aa_properties_aaindex`` module, falling back to the 5 basic
            features when that module cannot be imported. If False, use the
            5 basic features directly.
        selected_features: optional collection of AAindex feature names to keep.
            ``None`` keeps every feature.
    """

    # Column names of the built-in 5-feature table, in column order.
    _BASIC_FEATURE_NAMES = ('hydro', 'charge', 'volume', 'flex', 'aroma')

    def __init__(self, device, use_aaindex=True, selected_features=None):
        super().__init__()
        self.device = device
        self.use_aaindex = use_aaindex

        # Load the per-residue property table.
        if use_aaindex:
            self.aa_properties, self.feature_names = self._load_aaindex_features(selected_features)
            # Count via _get_aa_features: the loader may have fallen back to the
            # list-valued basic table, which has no `.values()`.
            self.n_features = len(self._get_aa_features('A'))
            print(f"✓ Loaded {self.n_features} AAindex features")
        else:
            self.aa_properties = self._get_basic_properties()
            self.feature_names = list(self._BASIC_FEATURE_NAMES)
            self.n_features = len(self._get_aa_features('A'))
            print(f"✓ Using {self.n_features} basic features")

        # Feature ranges differ by orders of magnitude, so standardize them.
        self.scaler = self._fit_scaler()

        # Precompute everything forward() needs for a vectorized lookup.
        self._build_lookup_tables()

    def _build_lookup_tables(self):
        """Build the residue->row index map, the feature table and the scaler tensors.

        The tensors are registered as non-persistent buffers. ``module.to(...)``
        therefore moves them, and they are kept out of ``state_dict`` so
        existing checkpoints still load.
        """
        aa_list = sorted(self.aa_properties.keys())  # stable row order
        self.aa_to_idx = {aa: i for i, aa in enumerate(aa_list)}
        self.pad_idx = len(self.aa_to_idx)  # extra final row used for padding

        table = [self._get_aa_features(aa) for aa in aa_list]
        table.append([0.0] * self.n_features)  # raw padding vector

        self.register_buffer(
            'aa_feature_table',
            torch.tensor(np.array(table), dtype=torch.float32, device=self.device),
            persistent=False,
        )  # [n_aa + 1, n_feat]
        self.register_buffer(
            'mean_tensor',
            torch.tensor(self.scaler.mean_, dtype=torch.float32, device=self.device),
            persistent=False,
        )
        self.register_buffer(
            'scale_tensor',
            torch.tensor(self.scaler.scale_, dtype=torch.float32, device=self.device),
            persistent=False,
        )

    def _load_aaindex_features(self, selected_features=None):
        """Load the AAindex property table, optionally restricted to ``selected_features``.

        Returns:
            ``(properties, feature_names)``. ``properties`` maps residue letter to
            a ``{feature_name: value}`` dict. ``feature_names`` lists the columns
            actually kept, in column order. If ``aa_properties_aaindex`` cannot be
            imported, the list-valued basic table and its 5 names are returned
            instead.
        """
        try:
            # FEATURE_DESCRIPTIONS is unused here, but importing it keeps the
            # original requirement that the generated module defines it.
            from aa_properties_aaindex import AA_PROPERTIES_AAINDEX, FEATURE_DESCRIPTIONS  # noqa: F401
        except ImportError:
            print("⚠ Warning: aa_properties_aaindex.py not found!")
            return self._get_basic_properties(), list(self._BASIC_FEATURE_NAMES)

        if selected_features is None:
            return AA_PROPERTIES_AAINDEX, list(AA_PROPERTIES_AAINDEX['A'].keys())

        wanted = set(selected_features)
        filtered_props = {
            aa: {k: v for k, v in props.items() if k in wanted}
            for aa, props in AA_PROPERTIES_AAINDEX.items()
        }
        # Report the columns really kept, in their real order, rather than echoing
        # the request, which may be in a different order or name missing features.
        feature_names = list(filtered_props['A'].keys())
        missing = [name for name in selected_features if name not in feature_names]
        if missing:
            print(f"⚠ Warning: AAindex features not found and ignored: {missing}")
        return filtered_props, feature_names

    def _get_basic_properties(self):
        """Return the built-in 5-feature table (hydro, charge, volume, flex, aroma) per residue."""
        return {
            'A': [1.8, 0.0, 88.6, 0.36, 0.0],
            'C': [2.5, 0.0, 108.5, 0.35, 0.0],
            'D': [-3.5, -1.0, 111.1, 0.51, 0.0],
            'E': [-3.5, -1.0, 138.4, 0.50, 0.0],
            'F': [2.8, 0.0, 189.9, 0.31, 1.0],
            'G': [-0.4, 0.0, 60.1, 0.54, 0.0],
            'H': [-3.2, 0.5, 153.2, 0.32, 0.5],
            'I': [4.5, 0.0, 166.7, 0.46, 0.0],
            'K': [-3.9, 1.0, 168.6, 0.47, 0.0],
            'L': [3.8, 0.0, 166.7, 0.37, 0.0],
            'M': [1.9, 0.0, 162.9, 0.30, 0.0],
            'N': [-3.5, 0.0, 114.1, 0.46, 0.0],
            'P': [-1.6, 0.0, 112.7, 0.51, 0.0],
            'Q': [-3.5, 0.0, 143.8, 0.49, 0.0],
            'R': [-4.5, 1.0, 173.4, 0.53, 0.0],
            'S': [-0.8, 0.0, 89.0, 0.51, 0.0],
            'T': [-0.7, 0.0, 116.1, 0.44, 0.0],
            'V': [4.2, 0.0, 140.0, 0.39, 0.0],
            'W': [-0.9, 0.0, 227.8, 0.31, 1.0],
            'Y': [-1.3, 0.0, 193.6, 0.42, 1.0],
            'X': [0.0, 0.0, 120.0, 0.40, 0.0],
        }

    def _fit_scaler(self):
        """Fit a StandardScaler on the feature vectors of the 20 standard amino acids."""
        rows = []
        for aa in 'ARNDCQEGHILKMFPSTWYV':
            props = self.aa_properties[aa]
            # AAindex entries are dicts, basic entries are plain lists.
            rows.append(list(props.values()) if isinstance(props, dict) else props)
        scaler = StandardScaler()
        scaler.fit(np.array(rows))
        return scaler

    def _get_aa_features(self, aa: str):
        """Return the raw feature list for one residue (case-insensitive; unknown -> 'X')."""
        aa = aa.upper()
        if aa not in self.aa_properties:
            aa = 'X'
        props = self.aa_properties[aa]
        if isinstance(props, dict):
            return list(props.values())
        return props

    def forward(self, sequences: List[str]) -> torch.Tensor:
        """Encode a batch of amino-acid sequences.

        Args:
            sequences: amino-acid strings; case-insensitive.

        Returns:
            float32 tensor ``[B, max_len, n_features]`` of standardized features.
            Unknown residues use the 'X' row, as ``_get_aa_features`` does; if the
            table has no 'X', the padding row is used. Padded positions hold the
            standardized zero vector ``(0 - mean) / scale``, as before. An empty
            batch yields a ``[0, 0, n_features]`` tensor.
        """
        table = self.aa_feature_table
        device = table.device  # follows module.to(...), unlike self.device

        batch_size = len(sequences)
        if batch_size == 0:
            return torch.empty((0, 0, self.n_features), dtype=table.dtype, device=device)
        max_len = max(len(seq) for seq in sequences)

        # 1) Encode residues as row indices, right-padded with pad_idx.
        lookup = self.aa_to_idx.get
        unk_idx = lookup('X', self.pad_idx)
        idx_batch = np.full((batch_size, max_len), self.pad_idx, dtype=np.int64)
        for i, seq in enumerate(sequences):
            # Upper-case per character so the index count always equals len(seq).
            idx_batch[i, :len(seq)] = [lookup(aa.upper(), unk_idx) for aa in seq]
        idx_tensor = torch.from_numpy(idx_batch).to(device)  # [B, L]

        # 2) Gather the raw properties and standardize them.
        props = table[idx_tensor]  # [B, L, n_feat]
        return (props - self.mean_tensor) / self.scale_tensor