Files changed (3)
  1. collate.py +105 -0
  2. tokenizer.py +255 -0
  3. tokens.txt +0 -0
collate.py ADDED
@@ -0,0 +1,105 @@
+ import os
+ import torch
+ from copy import deepcopy
+ from argparse import ArgumentParser
+ from typing import List, Tuple, Dict
+
+ class Simple_Collator:
+
+     @staticmethod
+     def add_args(parent_parser: ArgumentParser) -> ArgumentParser:
+         parser = parent_parser.add_argument_group('Data Collator Config & Hyperparameters.')
+
+         parser.add_argument('--max_len', default = 256, type = int)        # max length of a token block
+         parser.add_argument('--ignore_label', default = -100, type = int)  # PyTorch standard ignore label: -100
+         parser.add_argument('--split_aa_num', default = 3, type = int)     # number of amino acids per token (n-gram size)
+
+         parser.add_argument('--truncation', default = True, type = bool)
+         parser.add_argument('--truncation_mode', default = 'cut', type = str, choices = ['window', 'cut'])
+
+         parser.add_argument('--padding', default = True)
+         parser.add_argument('--padding_token', default = '[PAD]', type = str)
+
+         return parent_parser
+
+     def __init__(self, tokenizer, args) -> None:
+
+         self.tokenizer = tokenizer
+         self.max_len = args.max_len
+         self.ignore_label = args.ignore_label
+         self.split_aa_num = args.split_aa_num
+
+         # truncation, padding, mask
+         assert args.truncation_mode in ['window', 'cut'], "truncation_mode must be 'window' or 'cut'."
+         self.trunc = args.truncation
+         self.trunc_mode = args.truncation_mode
+
+         self.padding = args.padding
+         self.padding_token = args.padding_token
+
+     def process_tokens(self, tokens_ids: List[int]) -> Tuple[List[int], List[int]]:
+         # this simple collator marks every position as ignored in the loss
+         tokens_labels = [self.ignore_label] * len(tokens_ids)
+         return tokens_ids, tokens_labels
+
+     def pad_tokens(self,
+                    tokens_ids: List[int],
+                    tokens_labels: List[int]) -> Tuple[List[int], List[int], List[int]]:
+
+         raw_len = len(tokens_ids)
+
+         # pad up to the next multiple of max_len (no extra block when the length is already a multiple)
+         len_diff = (self.max_len - raw_len % self.max_len) % self.max_len
+         tokens_ids += [self.tokenizer.encode(self.padding_token)] * len_diff
+         tokens_labels += [self.ignore_label] * len_diff
+         tokens_attn_mask = [1] * raw_len + [0] * len_diff
+
+         return tokens_ids, tokens_labels, tokens_attn_mask
+
+     def trunc_tokens(self, data: list) -> List[list]:
+
+         res = []
+         tokens_len = len(data)
+
+         if tokens_len <= self.max_len: return [data]
+
+         if self.trunc_mode == 'window':
+             # sliding window: every contiguous block of max_len tokens
+             for i in range(tokens_len - self.max_len + 1):
+                 res.append(deepcopy(data[i: i + self.max_len]))
+         elif self.trunc_mode == 'cut':
+             # non-overlapping chunks of max_len tokens
+             for i in range(0, tokens_len, self.max_len):
+                 res.append(deepcopy(data[i: i + self.max_len]))
+
+         return res
+
+     def seq2data(self, seq: str) -> Tuple[List[int], List[int], List[int]]:
+         tokens_ids = self.tokenizer.tokenize(seq)                       # 1. tokenize the sequence into ids
+
+         tokens_ids, tokens_labels = self.process_tokens(tokens_ids)     # 2. generate the labels
+
+         tokens_attn_mask = [1] * len(tokens_ids)                        # default mask (replaced when padding is enabled)
+
+         if self.padding is True:
+             tokens_ids, tokens_labels, tokens_attn_mask = self.pad_tokens(tokens_ids, tokens_labels)  # 3. pad the sequence
+
+         if self.trunc is True:
+             tokens_ids, tokens_labels, tokens_attn_mask = [self.trunc_tokens(i)
+                                                            for i in [tokens_ids, tokens_labels, tokens_attn_mask]]  # 4. truncate data
+
+         return tokens_ids, tokens_labels, tokens_attn_mask
+
+     def __call__(self, data, HF_dataset: bool = False) -> Dict:
+
+         input_ids, labels, attn_mask = [], [], []
+
+         if HF_dataset is False:
+             if isinstance(data, str): data = [data]                     # accept a single protein sequence for testing
+
+         for i in data:
+             seq = i['seq'] if HF_dataset else i
+             tokens_ids, tokens_labels, tokens_attn_mask = self.seq2data(seq)
+
+             input_ids.extend(deepcopy(tokens_ids))
+             labels.extend(deepcopy(tokens_labels))
+             attn_mask.extend(deepcopy(tokens_attn_mask))
+
+         return {
+             'input_ids': torch.tensor(input_ids),
+             'labels': torch.tensor(labels),
+             'attention_mask': torch.tensor(attn_mask)}
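A minimal usage sketch of the collator (illustrative, not part of the changed files): it assumes collate.py and the tokenizer.py added below are importable as modules, and that a 3-mer vocabulary has already been saved to a file such as vocab.txt (the filename is an example; see the BioVocabGenerator sketch after tokenizer.py).

from argparse import ArgumentParser
from collate import Simple_Collator
from tokenizer import BioTokenizer

# collect the hyperparameters both classes expect to find on `args`
parser = ArgumentParser()
parser = BioTokenizer.add_argparse_args(parser)
parser = Simple_Collator.add_args(parser)
args = parser.parse_args(['--vocab_path', 'vocab.txt'])   # assumed, pre-generated vocabulary file

tokenizer = BioTokenizer(args)
collator = Simple_Collator(tokenizer, args)

# a single sequence is wrapped into a one-element batch; with the default settings
# it is padded to max_len and split into blocks of max_len tokens
batch = collator('MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ')
print(batch['input_ids'].shape, batch['attention_mask'].shape)   # e.g. torch.Size([1, 256]) for both

With padding and truncation enabled (the defaults), each element of the batch dictionary is a 2-D tensor of shape (num_blocks, max_len).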
tokenizer.py ADDED
@@ -0,0 +1,255 @@
+ import os
+ import time
+ import argparse
+ from functools import cmp_to_key
+ from itertools import permutations
+ from argparse import ArgumentParser
+ from collections import OrderedDict
+ from typing import List, Dict, Union, Optional
+
+ class BioVocabGenerator():
+
+     def __init__(self,
+                  gram_num: Union[int, None] = None,
+                  sort: bool = True,
+                  cmp_list: Union[List[str], None] = None,
+                  aa_list: List[str] = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
+                                        'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y',
+                                        'O', 'U', 'B', 'J', 'Z', 'X'],
+                  # mmseqs2 aa groups: (A S T) (C) (D B N) (E Q Z) (F Y) (G) (H) (I V) (K R) (L J M) (P) (W) (X)
+                  special_tokens: List[str] = ['[PAD]', '[MASK]', '[CLS]', '[SEP]', '[UNK]']) -> None:
+
+         # 1. Set the gram_num for tokenization.
+         #    Example: gram_num = 3, 'ABCDE' -> ['ABC', 'BCD', 'CDE']
+         if gram_num is not None: assert gram_num % 2 != 0, 'gram_num must be odd!'
+         self.gram_num = gram_num
+
+         # 2. Set the amino acid list and the special_tokens for tokenization.
+         self.aa_list = aa_list
+         self.special_tokens = special_tokens
+
+         # 3. Set the sort flag; cmp_dict defines the order used to sort the vocabulary.
+         self.sort = sort
+         self.cmp_dict = self.__fill_cmp_list(self.aa_list if cmp_list is None else cmp_list)
+
+         if gram_num is not None:
+             self.vocab = self.__generate_vocab
+             self.vocab_dict = self.__generate_vocab_dict
+
+     def __fill_cmp_list(self, cmp_list: List[str]) -> Dict[str, int]:
+         """
+         Append the start ('>') and end ('<') markers and build the ordering dict.
+         """
+
+         return {value: index for index, value in enumerate(cmp_list + ['>', '<'])}
+
+     @property
+     def __iter_list(self) -> List[str]:
+         """
+         Generate the candidate list fed to permutations: the amino acid list repeated
+         gram_num times plus the start/end markers, e.g. with gram_num = 3,
+         ['A', 'B', 'C'] -> ['A', 'B', 'C', 'A', 'B', 'C', 'A', 'B', 'C', '>', '<'].
+         """
+
+         return [i for _ in range(self.gram_num) for i in self.aa_list] + ['>', '<']
+
+     def __remove_errstr(self, x) -> bool:
+         """
+         Keep only valid n-grams: the markers may appear at most once, and only as a
+         leading '>' or a trailing '<'. Invalid examples: 'A>B', '<QW'.
+         """
+
+         if x.count('<') + x.count('>') == 0:
+             return True
+         elif x.count('<') + x.count('>') == 1:
+             if x[0] == '>' or x[-1] == '<':
+                 return True
+         return False
+
+     def __vocab_cmp(self, x: str, y: str) -> int:
+         """
+         Comparison function for sorting, based on the order defined in cmp_dict.
+         """
+
+         for i, j in zip(x, y):
+             if self.cmp_dict[i] < self.cmp_dict[j]:
+                 return -1
+             elif self.cmp_dict[i] > self.cmp_dict[j]:
+                 return 1
+         return 0
+
+     @property
+     def __generate_vocab(self) -> List[str]:
+         """
+         Generate the n-mer amino acid vocabulary.
+         """
+         # generate raw_vocab from permutations
+         raw_vocab = permutations(self.__iter_list, r = self.gram_num)
+
+         # use a set to drop duplicates, and filter out the invalid strings
+         vocab = list(set([''.join(i) for i in raw_vocab if self.__remove_errstr(i) is True]))
+
+         # sort the vocab
+         if self.sort is True: vocab = sorted(vocab, key = cmp_to_key(self.__vocab_cmp))
+
+         return self.special_tokens + vocab
+
+     @property
+     def __generate_vocab_dict(self) -> OrderedDict:
+         """
+         Convert the vocabulary from a List to an OrderedDict mapping token -> index.
+         """
+
+         return OrderedDict(zip(self.vocab, range(len(self.vocab))))
+
+     def get_size(self) -> int:
+         return len(self.vocab)
+
+     def get_vocab_list(self) -> List[str]:
+         return self.vocab
+
+     def get_vocab_dict(self) -> OrderedDict:
+         return self.vocab_dict
+
+     def encode(self, input: str) -> int:
+         try:
+             token_id = int(self.vocab_dict[input])
+         except KeyError as e:
+             # unknown tokens fall back to [UNK] instead of raising
+             print('Can not find {} in the vocabulary, falling back to [UNK].'.format(e))
+             token_id = int(self.vocab_dict['[UNK]'])
+         return token_id
+
+     def decode(self, index: int) -> str:
+         return self.vocab[index]
+
+     def save_vocabdict(self, path: Optional[str] = None) -> None:
+
+         path_name = 'vocab.txt'
+
+         if path is None:
+             path = path_name
+         elif os.path.isdir(path):
+             path = os.path.join(path, path_name)
+
+         try:
+             with open(path, 'w') as f:
+                 for token, index in self.vocab_dict.items():
+                     f.write("{0:>6} {1:>5}\n".format(token, str(index)))
+         except OSError:
+             print('Writing Error!')
+
+
+ class BioVocabLoader(BioVocabGenerator):
+
+     def __init__(self, path: str) -> None:
+         super().__init__()
+         assert os.path.exists(path), 'vocab path does not exist!'
+         self.load_vocab_dict(path)
+         self.get_gram_num()
+
+     def load_vocab_dict(self, path: str) -> None:
+         """
+         Load the vocabulary dictionary from a txt file.
+         """
+
+         with open(path, 'r') as f:
+             data = [line.strip() for line in f.read().splitlines()]
+         self.vocab = [i.split()[0] for i in data]
+         self.vocab_dict = OrderedDict({i.split()[0]: i.split()[1] for i in data})
+
+     def get_gram_num(self) -> int:
+         """
+         Infer the n-gram size from the vocabulary.
+         """
+
+         if isinstance(self.gram_num, int):
+             return self.gram_num
+         else:
+             for i in self.vocab:
+                 if i not in self.special_tokens:   # default 5 special_tokens
+                     return len(i)
+
+
+ class BioTokenizer(BioVocabLoader):
+
+     @staticmethod
+     def add_argparse_args(parent_parser: ArgumentParser) -> ArgumentParser:
+         parser = parent_parser.add_argument_group('Tokenizer hyperparameters.')
+         parser.add_argument('--vocab_path', type = str)
+         return parent_parser
+
+     def __init__(self, args = None, vocab_path: str = None) -> None:
+
+         if vocab_path is None:
+             super().__init__(args.vocab_path)
+         else:
+             super().__init__(vocab_path)
+
+         self.gram_num = self.get_gram_num()
+
+     def __cut_seq(self, seq: str) -> List[str]:
+         """
+         Cut a sequence into an n-gram/n-mer token list, e.g. with gram_num = 3:
+         '>ABCDE<' -> ['>AB', 'ABC', 'BCD', 'CDE', 'DE<']
+         """
+
+         seq = seq.upper()
+         assert len(seq) - self.gram_num + 1 > 0, 'Protein sequence is too short to cut!'
+         return [seq[i: i + self.gram_num] for i in range(len(seq) - self.gram_num + 1)]
+
+     def __single_seq_tokenize(self, seq: str) -> List[int]:
+         """
+         Convert tokens to indices.
+         """
+
+         # assert len(seq) > 10, 'Too short to process!'
+         token_list = self.__cut_seq(seq)
+         token_ids = [self.encode(i) for i in token_list]
+
+         return token_ids
+
+     def __append_headtail(self, seq: str) -> str:
+         """
+         Append '>' to the sequence head and '<' to the sequence tail.
+         """
+
+         if seq[0] != '>':
+             seq = '>' + seq
+         if seq[-1] != '<':
+             seq += '<'
+
+         return seq
+
+     def get_token_list(self, seq: str) -> List[str]:
+         """
+         Split a sequence into a list containing all of its tokens.
+         """
+
+         seq = self.__append_headtail(seq)
+
+         assert len(seq) > 10, 'Too short to process!'
+         token_list = self.__cut_seq(seq)
+
+         return token_list
+
+     def tokenize(self, seq: str, pt: bool = False) -> List[int]:
+         """
+         Tokenize a sequence into ids.
+         """
+
+         assert seq.isalpha(), f'ERROR Seq: {seq}\nProtein sequence contains an illegal character!'
+
+         seq = self.__append_headtail(seq)
+         token_ids = self.__single_seq_tokenize(seq)
+
+         return token_ids
+
+     def detokenize(self, ids: List[int]) -> List[str]:
+         """
+         Detokenize ids back to their token strings.
+         """
+
+         seq = [self.decode(i) for i in ids]
+
+         return seq
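A minimal usage sketch of the vocabulary generator and tokenizer (illustrative, not part of the changed files): it assumes tokenizer.py is importable as a module, and vocab.txt is just an example output path.

from tokenizer import BioVocabGenerator, BioTokenizer

# build a 3-mer vocabulary (5 special tokens plus all valid 3-grams, including
# those carrying the '>' start and '<' end markers) and write it to disk
generator = BioVocabGenerator(gram_num = 3)
print(generator.get_size())
generator.save_vocabdict('vocab.txt')

# reload the vocabulary through the tokenizer and run a round trip
tok = BioTokenizer(vocab_path = 'vocab.txt')
ids = tok.tokenize('MKTAYIAKQR')
print(ids)                  # ids of the overlapping 3-mers '>MK', 'MKT', ..., 'QR<'
print(tok.detokenize(ids))  # back to the token strings

Note that detokenize returns the overlapping n-mer tokens rather than the original residue string, matching the sliding-window cut used by tokenize.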
tokens.txt ADDED
The diff for this file is too large to render. See raw diff