import torch
import torch.nn as nn
import numpy as np


class PretrainingPDeepPP:
    def __init__(self, embedding_dim=1280, target_length=33, esm_ratio=None, device=None):
        """

        初始化 PretrainingPDeepPP 类。



        Args:

            embedding_dim: 嵌入维度大小。

            target_length: 目标序列长度。

            esm_ratio: ESM 表征与嵌入表示的权重比例(由外部赋值)。

            device: 设备信息。

        """
        self.embedding_dim = embedding_dim
        self.target_length = target_length
        self.esm_ratio = esm_ratio  # only store esm_ratio; no default value is assigned
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def extract_esm_representations(self, sequences, esm_model, batch_converter, batch_size=32):
        """

        提取 ESM 表征,并直接返回形状为 (batch_size, target_length, embedding_dim) 的结果。

        """
        sequence_representations = []
        print("Sequences to process:", sequences)
        print("Batch size:", batch_size)
        
        # Attach a placeholder label to each sequence, as required by batch_converter
        labeled_sequences = [(None, seq) for seq in sequences]

        for i in range(0, len(labeled_sequences), batch_size):
            batch = labeled_sequences[i:i + batch_size]
            if len(batch) == 0:
                continue
            # Convert the sequence batch into token tensors with batch_converter
            _, batch_strs, batch_tokens = batch_converter(batch)
            batch_tokens = batch_tokens.to(self.device)
            
            # Run the ESM model to obtain per-token representations
            with torch.no_grad():
                results = esm_model(batch_tokens, repr_layers=[33], return_contacts=False)
            
            # Collect the layer-33 representation of each sequence, truncated to
            # target_length tokens (note: this slice starts at the BOS token)
            for token_repr in results["representations"][33]:
                sequence_representations.append(token_repr[:self.target_length])

        if len(sequence_representations) == 0:
            raise ValueError("No ESM representations were generated. Check your input sequences and batch processing logic.")

        # Stack all per-sequence representations into (num_sequences, target_length, embedding_dim)
        return torch.stack(sequence_representations)

    def pad_sequences(self, sequences, max_len=None, pad_value=0):
        # Right-pad (and truncate, if necessary) index sequences to max_len
        if max_len is None:
            max_len = max(len(seq) for seq in sequences)
        padded_sequences = torch.full((len(sequences), max_len), pad_value, dtype=torch.long)
        for i, seq in enumerate(sequences):
            seq = seq[:max_len]
            padded_sequences[i, :len(seq)] = torch.tensor(seq, dtype=torch.long)
        return padded_sequences

    def seq_to_indices(self, seq, vocab_dict):
        # Map each character to its vocabulary index; unknown characters map to 0
        return [vocab_dict.get(char, 0) for char in seq]

    def create_embeddings(self, sequences, vocab, esm_model, esm_alphabet, batch_size=16):
        """

        创建嵌入向量,使用类的 esm_ratio 属性动态控制权重分配。



        Args:

            sequences: 输入序列列表。

            vocab: 字符词汇表。

            esm_model: 预训练的 ESM 模型。

            esm_alphabet: ESM 模型的字母表。

            batch_size: 批量大小。



        Returns:

            结合 ESM 表征与嵌入表示的嵌入结果。

        """
        if self.esm_ratio is None:
            raise ValueError("esm_ratio is not set. Please assign a value before creating embeddings.")

        # Build the character-to-index vocabulary dictionary
        vocab_dict = {char: i for i, char in enumerate(vocab)}

        # Convert the sequences to index lists and pad them to target_length
        indices = [self.seq_to_indices(seq, vocab_dict) for seq in sequences]
        indices_padded = self.pad_sequences(indices, max_len=self.target_length)

        # Define a simple embedding model (randomly initialized; max_len is kept
        # for interface symmetry but is not used here)
        class EmbeddingPretrainedModel(nn.Module):
            def __init__(self, vocab_size, embedding_dim, max_len):
                super(EmbeddingPretrainedModel, self).__init__()
                self.embedding = nn.Embedding(vocab_size, embedding_dim)
                self.fc = nn.Linear(embedding_dim, embedding_dim)
            
            def forward(self, x):
                x = self.embedding(x)
                x = self.fc(x)
                return x

        embedding_model = EmbeddingPretrainedModel(len(vocab), self.embedding_dim, self.target_length).to(self.device)

        # Extract ESM representations
        esm_representations = self.extract_esm_representations(
            sequences,
            esm_model,
            esm_alphabet.get_batch_converter(),
            batch_size=batch_size
        )

        # Compute the embedding-based representations
        with torch.no_grad():
            embedding_output = embedding_model(indices_padded.to(self.device))

        # Blend the ESM and embedding representations, weighted by esm_ratio
        combined_representations = self.esm_ratio * esm_representations + (1 - self.esm_ratio) * embedding_output

        return combined_representations
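

# --- Usage sketch (illustrative only) ---
# A minimal, hedged example of how this class might be driven. It assumes the
# fair-esm package is installed and that esm.pretrained.esm2_t33_650M_UR50D()
# is the 33-layer / 1280-dim model matching the defaults above; the vocabulary
# and the example sequences are placeholders, not values taken from this repo.
if __name__ == "__main__":
    import esm  # fair-esm; assumed dependency

    esm_model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    esm_model.eval()

    pretrainer = PretrainingPDeepPP(embedding_dim=1280, target_length=33)
    pretrainer.esm_ratio = 0.5  # must be assigned before create_embeddings()
    esm_model = esm_model.to(pretrainer.device)

    vocab = list("ACDEFGHIKLMNPQRSTVWY")  # hypothetical 20-letter amino-acid vocabulary
    sequences = [
        "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",  # 33-residue placeholder sequences,
        "MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLS",  # matching target_length
    ]

    embeddings = pretrainer.create_embeddings(
        sequences, vocab, esm_model, esm_alphabet, batch_size=2
    )
    print(embeddings.shape)  # expected: torch.Size([2, 33, 1280])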