File size: 7,246 Bytes
2d48951
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import numpy as np
import torch
import torch.nn.utils.rnn as rnn_utils

def Data2EqlTensor(lines, max_len: int = 51, AminoAcid_vocab=None):
    '''Encode ``(seq_id, sequence)`` pairs as a ``[N, max_len]`` index tensor.

    Args:
        lines: iterable of ``(seq_id, sequence)`` pairs. Only the sequence is
            encoded; the id is ignored.
        max_len: width of the returned tensor. A sequence whose encoded length
            exceeds ``max_len`` keeps its first ``max_len // 2`` and its last
            ``max_len - max_len // 2`` residues (exactly ``max_len`` in total).
        AminoAcid_vocab: ``'esm'`` or ``'protbert'`` selects that model's token
            indices; any other value selects the default A..Y -> 1..20 mapping.

    Returns:
        ``(data, labels)``. ``data`` is a ``torch.long`` tensor of shape
        ``[N, max_len]``, right-padded with the vocabulary's ``[PAD]`` index.
        ``labels`` is always an empty tensor, because this input carries no labels.
    '''
    # Only the 20 standard amino acids get their own index. Rare letters that
    # are listed in the dictionary (e.g. X, U) map to the [PAD] index; letters
    # absent from the dictionary are dropped. The largest index used by the
    # esm/protbert dictionaries is 24, so nn.Embedding(vocab_size=25) covers both.
    if AminoAcid_vocab == 'esm':
        aa_dict = {'[PAD]': 1, 'L': 4, 'A': 5, 'G': 6, 'V': 7, 'S': 8, 'E': 9, 'R': 10, 'T': 11, 'I': 12, 'D': 13, 'P': 14, 'K': 15, 'Q': 16,
                   'N': 17, 'F': 18, 'Y': 19, 'M': 20, 'H': 21, 'W': 22, 'C': 23, 'X': 1, 'B': 1, 'U': 1, 'Z': 1, 'O': 1}
    elif AminoAcid_vocab == 'protbert':
        aa_dict = {'[PAD]': 0, 'L': 5, 'A': 6, 'G': 7, 'V': 8, 'E': 9, 'S': 10, 'I': 11, 'K': 12, 'R': 13, 'D': 14, 'T': 15,
                   'P': 16, 'N': 17, 'Q': 18, 'F': 19, 'Y': 20, 'M': 21, 'H': 22, 'C': 23, 'W': 24, 'X': 0, 'U': 0, 'B': 0, 'Z': 0, 'O': 0}
    else:
        aa_dict = {'[PAD]': 0, 'A': 1, 'C': 2, 'D': 3, 'E': 4, 'F': 5, 'G': 6, 'H': 7, 'I': 8, 'K': 9, 'L': 10, 'M': 11, 'N': 12, 'P': 13, 'Q': 14, 'R': 15,
                   'S': 16, 'T': 17, 'V': 18, 'W': 19, 'Y': 20, 'U': 0, 'X': 0, 'J': 0}
    pad_value = aa_dict.get('[PAD]', 0)

    # Split so that head + tail == max_len even when max_len is odd.
    head_len = max_len // 2
    tail_len = max_len - head_len

    long_pep_counter = 0
    encoded_peps = []
    for _seq_id, pep in lines:
        # Same case-insensitive, unknown-dropping encoding for every sequence,
        # so long sequences cannot raise KeyError where short ones would not.
        encoded = [aa_dict[aa.upper()] for aa in pep if aa.upper() in aa_dict]
        if len(encoded) > max_len:
            # Keep both termini; `len - tail_len` avoids the `[-0:]` pitfall.
            encoded = encoded[:head_len] + encoded[len(encoded) - tail_len:]
            long_pep_counter += 1
        encoded_peps.append(encoded)

    print("length > {}:{}".format(max_len, long_pep_counter))

    # Every row is now <= max_len long, so a [N, max_len] pad-filled tensor
    # always has the equal width the function name promises.
    data = torch.full((len(encoded_peps), max_len), pad_value, dtype=torch.long)
    for row, encoded in enumerate(encoded_peps):
        data[row, :len(encoded)] = torch.tensor(encoded, dtype=torch.long)
    # No labels in this input format; kept for a return signature that matches SeqsData2EqlTensor.
    return data, torch.tensor([])

def SeqsData2EqlTensor(file_path: str, max_len: int, AminoAcid_vocab=None):
    '''Read a labelled sequence file and encode it as padded index tensors.

    The file alternates header and sequence lines. A header line starts with
    ``>``, and every character after it is parsed as one integer label
    (e.g. ``>0101``). Blank lines are ignored.

    Args:
        file_path: path to the text file.
        max_len: maximum encoded sequence length. A longer sequence keeps its
            first ``max_len // 2`` and its last ``max_len - max_len // 2``
            residues (exactly ``max_len`` in total).
        AminoAcid_vocab: ``'esm'`` or ``'protbert'`` selects that model's token
            indices; any other value selects the default A..Y -> 1..20 mapping.

    Returns:
        ``(data, labels)``.
        ``data`` is a ``torch.long`` tensor of shape ``[N, L]``. ``L`` is the
        longest encoded sequence (at most ``max_len``), and shorter rows are
        right-padded with the vocabulary's ``[PAD]`` index.
        ``labels`` is ``torch.tensor`` of the per-header label lists. This
        requires every header to carry the same number of digits.
    '''
    # Only the 20 standard amino acids get their own index. Rare letters that
    # are listed in the dictionary (e.g. X, U) map to the [PAD] index; letters
    # absent from the dictionary are dropped. The largest index used by the
    # esm/protbert dictionaries is 24, so nn.Embedding(vocab_size=25) covers both.
    if AminoAcid_vocab == 'esm':
        aa_dict = {'[PAD]': 1, 'L': 4, 'A': 5, 'G': 6, 'V': 7, 'S': 8, 'E': 9, 'R': 10, 'T': 11, 'I': 12, 'D': 13, 'P': 14, 'K': 15, 'Q': 16,
                   'N': 17, 'F': 18, 'Y': 19, 'M': 20, 'H': 21, 'W': 22, 'C': 23, 'X': 1, 'B': 1, 'U': 1, 'Z': 1, 'O': 1}
    elif AminoAcid_vocab == 'protbert':
        aa_dict = {'[PAD]': 0, 'L': 5, 'A': 6, 'G': 7, 'V': 8, 'E': 9, 'S': 10, 'I': 11, 'K': 12, 'R': 13, 'D': 14, 'T': 15,
                   'P': 16, 'N': 17, 'Q': 18, 'F': 19, 'Y': 20, 'M': 21, 'H': 22, 'C': 23, 'W': 24, 'X': 0, 'U': 0, 'B': 0, 'Z': 0, 'O': 0}
    else:
        aa_dict = {'[PAD]': 0, 'A': 1, 'C': 2, 'D': 3, 'E': 4, 'F': 5, 'G': 6, 'H': 7, 'I': 8, 'K': 9, 'L': 10, 'M': 11, 'N': 12, 'P': 13, 'Q': 14, 'R': 15,
                   'S': 16, 'T': 17, 'V': 18, 'W': 19, 'Y': 20, 'U': 0, 'X': 0, 'J': 0}
    pad_value = aa_dict.get('[PAD]', 0)

    # Split so that head + tail == max_len even when max_len is odd.
    head_len = max_len // 2
    tail_len = max_len - head_len

    with open(file_path, 'r', encoding='utf-8') as inf:
        # A stray blank line would otherwise break header/sequence pairing
        # and crash on line[0].
        lines = [line for line in inf.read().splitlines() if line.strip()]
    assert len(lines) % 2 == 0, "Invalid file format. Number of lines should be even."

    long_pep_counter = 0
    pep_codes = []
    labels = []
    for line in lines:
        if line[0] == '>':
            # Each character after '>' is one integer label.
            labels.append([int(i) for i in line[1:].strip()])
            continue
        # Same case-insensitive, unknown-dropping encoding for every sequence,
        # so long sequences cannot raise KeyError where short ones would not.
        encoded = [aa_dict[aa.upper()] for aa in line if aa.upper() in aa_dict]
        if len(encoded) > max_len:
            # Keep both termini; `len - tail_len` avoids the `[-0:]` pitfall.
            encoded = encoded[:head_len] + encoded[len(encoded) - tail_len:]
            long_pep_counter += 1
        # Explicit dtype: torch.tensor([]) would otherwise be float32.
        pep_codes.append(torch.tensor(encoded, dtype=torch.long))

    print("length > {}:{}".format(max_len, long_pep_counter))
    if not pep_codes:
        # pad_sequence rejects an empty list.
        return torch.empty((0, 0), dtype=torch.long), torch.tensor(labels)
    data = rnn_utils.pad_sequence(pep_codes, batch_first=True, padding_value=pad_value)
    return data, torch.tensor(labels)

def index_alignment(batch, condition_num=0, subtraction_num1=4, subtraction_num2=1):
    '''Shift another language model's token indices into the default index range.

    After alignment, amino-acid indices lie in [1, 20] and [PAD] becomes 0.
    Only the numeric range is aligned, not the per-residue identity. For
    example, esm 'L'=4 becomes 1, whereas the default dictionary has 'L'=10.

    Typical settings:
        "esm":      condition_num=1, subtraction_num1=3, subtraction_num2=1
        "protbert": condition_num=0, subtraction_num1=4

    Args:
        batch: integer tensor of token indices, e.g. ``[batch_size, seq_len]``.
        condition_num: the [PAD] index in the source dictionary (0 or 1).
        subtraction_num1: offset subtracted from every non-[PAD] element.
        subtraction_num2: offset subtracted from [PAD] elements. Used only
            when ``condition_num == 1``; a [PAD] of 0 is already aligned.

    Returns:
        Tensor with the same shape as ``batch``.

    Raises:
        ValueError: if ``condition_num`` is neither 0 nor 1.
    '''
    if condition_num == 0:
        pad_shift = 0
    elif condition_num == 1:
        pad_shift = subtraction_num2
    else:
        # Previously this fell through and raised UnboundLocalError.
        raise ValueError(f"condition_num must be 0 or 1, got {condition_num!r}")

    is_pad = batch == condition_num
    return torch.where(is_pad, batch - pad_shift, batch - subtraction_num1)