# PDeepPP_ACE / processing_pdeeppp.py
# (Hugging Face page residue kept as comments so the file parses:
#  fondress's picture / "Create processing_pdeeppp.py" / a959f5c verified /
#  raw / history blame / 5.17 kB)
import os

import esm
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split

from processing_pdeeppp import PDeepPPProcessor
# Select the compute device (GPU if available).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Hyperparameters.
batch_size = 16
embedding_dim = 1280  # presumably the esm2_t33_650M_UR50D hidden size -- confirm
esm_ratio = 0.95  # weight given to ESM features when blending with learned embeddings
target_length = 33  # target sequence length for PDeepPPProcessor
ptm_type = "Hydroxyproline_P"
save_dir = f"./pretrained_weights/{ptm_type}/"
os.makedirs(save_dir, exist_ok=True)
# Load the dataset.
data_path = "/path/to/your/dataset.xlsx"  # replace with your dataset path
data = pd.read_excel(data_path)
labels = data["label"].values
sequences = data["sequence"].fillna("").values
# Train/test split (fixed seed for reproducibility).
train_sequences, test_sequences, train_labels, test_labels = train_test_split(
    sequences, labels, test_size=0.2, random_state=42
)
# Initialize the PDeepPPProcessor.
processor = PDeepPPProcessor(pad_char="X", target_length=target_length)
# Process training and test data.
train_inputs = processor(sequences=train_sequences, ptm_mode=True)
test_inputs = processor(sequences=test_sequences, ptm_mode=True)
# Load the pretrained ESM-2 model and its batch converter.
esm_model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = esm_alphabet.get_batch_converter()
esm_model = esm_model.to(device)
esm_model.eval()
def extract_esm_representations(sequences, batch_size=16):
    """Extract per-residue ESM-2 layer-33 representations for each sequence.

    Sequences are run through the module-level ``esm_model`` in mini-batches;
    the BOS token row is stripped and every sequence is padded with zeros --
    or truncated -- to exactly ``target_length`` rows so the results stack.

    Args:
        sequences: iterable of amino-acid strings (indexable, with len()).
        batch_size: number of sequences per ESM forward pass.

    Returns:
        Tensor of shape (len(sequences), target_length, embedding_dim).
    """
    sequence_representations = []
    for start in range(0, len(sequences), batch_size):
        batch_seqs = list(sequences[start : start + batch_size])
        # batch_converter expects (label, sequence) pairs; labels are placeholders.
        batch = [(0, seq) for seq in batch_seqs]
        _, _, batch_tokens = batch_converter(batch)
        batch_tokens = batch_tokens.to(device)
        with torch.no_grad():
            results = esm_model(batch_tokens, repr_layers=[33])
        token_representations = results["representations"][33]
        for seq, token_repr in zip(batch_seqs, token_representations):
            # Truncate overlong sequences so every output has target_length rows;
            # the original padded only, and torch.stack crashed on ragged lengths
            # whenever len(seq) > target_length.
            seq_len = min(len(seq), target_length)
            seq_repr = token_repr[1 : seq_len + 1]  # drop the BOS row
            if seq_len < target_length:
                padding = torch.zeros(
                    target_length - seq_len, embedding_dim, device=seq_repr.device
                )
                seq_repr = torch.cat((seq_repr, padding), dim=0)
            sequence_representations.append(seq_repr)
    return torch.stack(sequence_representations)
# Extract ESM representations for the train and test splits.
print("Extracting ESM representations for training data...")
train_esm_representations = extract_esm_representations(train_sequences, batch_size=batch_size)
print("Extracting ESM representations for testing data...")
test_esm_representations = extract_esm_representations(test_sequences, batch_size=batch_size)
# Learned-embedding model: token embedding followed by a linear projection.
class EmbeddingPretrainedModel(nn.Module):
    """Map token indices to dense vectors, then project them linearly.

    The output keeps the input's leading dimensions and appends an
    ``embedding_dim``-sized feature axis.

    Note: ``max_len`` is accepted for interface compatibility but unused.
    """

    def __init__(self, vocab_size, embedding_dim, max_len):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.fc = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x):
        # Embed indices, then apply the linear projection in one pass.
        return self.fc(self.embedding(x))
# Build a character vocabulary over all sequences. Sorted so the
# char->index mapping is deterministic across runs -- plain `set`
# iteration order depends on PYTHONHASHSEED, which would make the saved
# representations irreproducible between processes.
vocab = sorted(set("".join(sequences)))
vocab_size = len(vocab)
vocab_dict = {char: i for i, char in enumerate(vocab)}
def seq_to_indices(seq, vocab_dict):
    """Convert a character sequence into its list of vocabulary indices.

    Raises KeyError for any character absent from `vocab_dict`.
    """
    return list(map(vocab_dict.__getitem__, seq))
# Convert each split's sequences to per-character index lists.
train_indices = [seq_to_indices(seq, vocab_dict) for seq in train_sequences]
test_indices = [seq_to_indices(seq, vocab_dict) for seq in test_sequences]
def pad_sequences(sequences, max_len=None, pad_value=0):
    """Right-pad (and, if necessary, truncate) index sequences to one length.

    Args:
        sequences: iterable of lists of ints.
        max_len: target length; defaults to the longest sequence present
            (0 when `sequences` is empty).
        pad_value: fill value for padded positions. The original ignored
            this parameter and always zero-padded.

    Returns:
        LongTensor of shape (len(sequences), max_len).
    """
    sequences = list(sequences)
    if max_len is None:
        # default=0 avoids a ValueError on an empty input list
        max_len = max((len(seq) for seq in sequences), default=0)
    # torch.full honors pad_value (the original torch.zeros silently ignored it)
    padded_sequences = torch.full((len(sequences), max_len), pad_value, dtype=torch.long)
    for i, seq in enumerate(sequences):
        # Truncate overlong sequences; the original raised on len(seq) > max_len.
        seq = seq[:max_len]
        if seq:
            padded_sequences[i, : len(seq)] = torch.tensor(seq, dtype=torch.long)
    return padded_sequences
# Pad index sequences to the processor's target length.
train_indices_padded = pad_sequences(train_indices, max_len=target_length)
test_indices_padded = pad_sequences(test_indices, max_len=target_length)
# Initialize the embedding model (weights are randomly initialized here,
# not loaded from a checkpoint).
embedding_model = EmbeddingPretrainedModel(vocab_size, embedding_dim, target_length).to(device)
# Compute embedding representations; no gradients are needed.
with torch.no_grad():
    train_embedding_output = embedding_model(train_indices_padded.to(device))
    test_embedding_output = embedding_model(test_indices_padded.to(device))
# Blend ESM and learned-embedding representations as a weighted sum
# (esm_ratio weights the ESM features).
train_combined_representations = esm_ratio * train_esm_representations + (1 - esm_ratio) * train_embedding_output
test_combined_representations = esm_ratio * test_esm_representations + (1 - esm_ratio) * test_embedding_output
# Save combined representations and labels as .npy files.
np.save(os.path.join(save_dir, "train_combined_representations.npy"), train_combined_representations.cpu().numpy())
np.save(os.path.join(save_dir, "test_combined_representations.npy"), test_combined_representations.cpu().numpy())
np.save(os.path.join(save_dir, "train_labels.npy"), train_labels)
np.save(os.path.join(save_dir, "test_labels.npy"), test_labels)
print(f"Preprocessed data and representations saved to {save_dir}")