didididadada committed (verified)
Commit 1ee15bf · Parent: c196cad

Delete tokenizer.py

Files changed (1)
  1. tokenizer.py +0 -255
tokenizer.py DELETED
import os
from functools import cmp_to_key
from itertools import permutations
from argparse import ArgumentParser
from collections import OrderedDict
from typing import List, Dict, Union, Optional

class BioVocabGenerator:

    def __init__(self,
                 gram_num: Union[int, None] = None,
                 sort: bool = True,
                 cmp_list: Union[List[str], None] = None,
                 aa_list: List[str] = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
                                       'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y',
                                       'O', 'U', 'B', 'J', 'Z', 'X'],
                 # mmseqs2 aa list: (A S T) (C) (D B N) (E Q Z) (F Y) (G) (H) (I V) (K R) (L J M) (P) (W) (X)
                 special_tokens: List[str] = ['[PAD]', '[MASK]', '[CLS]', '[SEP]', '[UNK]']) -> None:

        # 1. Set the gram_num for tokenization.
        # Example: gram_num = 3, 'ABCDE' -> ['ABC', 'BCD', 'CDE']
        if gram_num is not None:
            assert gram_num % 2 != 0, 'gram_num must be odd!'
        self.gram_num = gram_num

        # 2. Set the amino acid list and the special tokens for tokenization.
        self.aa_list = aa_list
        self.special_tokens = special_tokens

        # 3. Set the sort flag; cmp_dict gives the character order used when sorting.
        self.sort = sort
        self.cmp_dict = self.__fill_cmp_list(self.aa_list if cmp_list is None else cmp_list)

        if gram_num is not None:
            self.vocab = self.__generate_vocab
            self.vocab_dict = self.__generate_vocab_dict

    def __fill_cmp_list(self, cmp_list: List[str]) -> Dict[str, int]:
        """
        Fill in the start ('>') and end ('<') markers and build cmp_dict.
        """

        return {value: index for index, value in enumerate(cmp_list + ['>', '<'])}

    @property
    def __iter_list(self) -> List[str]:
        """
        Generate the iteration list for permutations by repeating aa_list gram_num times:
        ['A', 'B', 'C'] -> ['A', 'B', 'C', 'A', 'B', 'C', 'A', 'B', 'C', '>', '<'] (gram_num = 3)
        """

        return [i for _ in range(self.gram_num) for i in self.aa_list] + ['>', '<']

    def __remove_errstr(self, x) -> bool:
        """
        Filter out invalid strings from raw_vocab; `x` may be a str or a tuple of chars.
        Invalid examples: 'A>B', '<QW' ('>' may only start a token, '<' may only end one).
        """

        marker_count = x.count('<') + x.count('>')
        if marker_count == 0:
            return True
        if marker_count == 1:
            return x[0] == '>' or x[-1] == '<'
        return False

    def __vocab_cmp(self, x: str, y: str) -> int:
        """
        Comparison function for sorting, following the character order in cmp_dict.
        """

        for i, j in zip(x, y):
            if self.cmp_dict[i] < self.cmp_dict[j]:
                return -1
            elif self.cmp_dict[i] > self.cmp_dict[j]:
                return 1
        return 0

    @property
    def __generate_vocab(self) -> List[str]:
        """
        Generate the n-mer amino acid vocabulary.
        """

        # Generate raw_vocab from permutations of the repeated amino acid list.
        raw_vocab = permutations(self.__iter_list, r=self.gram_num)

        # Use a set to drop duplicates, filtering out the invalid strings.
        vocab = list(set(''.join(i) for i in raw_vocab if self.__remove_errstr(i)))

        # Sort the vocabulary.
        if self.sort:
            vocab = sorted(vocab, key=cmp_to_key(self.__vocab_cmp))

        return self.special_tokens + vocab

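    # Illustrative size check (assuming gram_num = 3 and the default 26-letter
    # aa_list): __generate_vocab yields 26**3 plain 3-mers, 26**2 head tokens
    # of the form '>XY', and 26**2 tail tokens of the form 'XY<', plus the 5
    # special tokens: 17576 + 676 + 676 + 5 = 18933 entries in total.
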
    @property
    def __generate_vocab_dict(self) -> OrderedDict:
        """
        Convert the vocabulary from a List to an OrderedDict of token -> index.
        """

        return OrderedDict((token, index) for index, token in enumerate(self.vocab))

    def get_size(self) -> int:
        return len(self.vocab)

    def get_vocab_list(self) -> List[str]:
        return self.vocab

    def get_vocab_dict(self) -> OrderedDict:
        return self.vocab_dict

    def encode(self, token: str) -> int:
        try:
            token_id = int(self.vocab_dict[token])
        except KeyError as e:
            # Fall back to the [UNK] id for out-of-vocabulary tokens.
            print('Can not find {} in vocabulary!'.format(e))
            token_id = int(self.vocab_dict['[UNK]'])
        return token_id

    def decode(self, index: int) -> str:
        return self.vocab[index]

    def save_vocabdict(self, path: Optional[str] = None) -> None:

        file_name = 'vocab.txt'

        if path is None:
            path = file_name
        elif os.path.isdir(path):
            path = os.path.join(path, file_name)

        try:
            with open(path, 'w') as f:
                for token, index in self.vocab_dict.items():
                    f.write("{0:>6} {1:>5}\n".format(token, str(index)))
        except OSError:
            print('Writing Error!')

class BioVocabLoader(BioVocabGenerator):

    def __init__(self, path: str) -> None:
        super().__init__()
        assert os.path.exists(path), 'vocab path does not exist!'
        self.load_vocab_dict(path)
        self.gram_num = self.get_gram_num()

    def load_vocab_dict(self, path: str) -> None:
        """
        Load the vocabulary dictionary from a txt file.
        """

        with open(path, 'r') as f:
            data = [line.strip() for line in f.read().splitlines()]
        self.vocab = [line.split()[0] for line in data]
        # Cast the indices to int so loaded and generated vocab_dicts match.
        self.vocab_dict = OrderedDict((line.split()[0], int(line.split()[1])) for line in data)

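    # Illustrative vocab.txt layout, matching the "{0:>6} {1:>5}" format
    # used by save_vocabdict above:
    #  [PAD]     0
    # [MASK]     1
    #  [CLS]     2
    #  [SEP]     3
    #  [UNK]     4
    #    AAA     5
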
    def get_gram_num(self) -> int:
        """
        Get the n-gram size from the vocabulary.
        """

        if isinstance(self.gram_num, int):
            return self.gram_num
        for token in self.vocab:
            if token not in self.special_tokens:  # default 5 special_tokens
                return len(token)

class BioTokenizer(BioVocabLoader):

    @staticmethod
    def add_argparse_args(parent_parser: ArgumentParser) -> ArgumentParser:
        parser = parent_parser.add_argument_group('Tokenizer hyperparameter.')
        parser.add_argument('--vocab_path', type=str)
        return parent_parser

    def __init__(self, args=None, vocab_path: Optional[str] = None) -> None:

        if vocab_path is None:
            super().__init__(args.vocab_path)
        else:
            super().__init__(vocab_path)

        self.gram_num = self.get_gram_num()

    def __cut_seq(self, seq: str) -> List[str]:
        """
        Cut a sequence into an n-gram/n-mer token list, e.g. for gram_num = 3:
        '>ABCDE<' -> ['>AB', 'ABC', 'BCD', 'CDE', 'DE<']
        """

        seq = seq.upper()
        assert len(seq) - self.gram_num + 1 > 0, 'Protein sequence is too short to cut!'
        return [seq[i: i + self.gram_num] for i in range(len(seq) - self.gram_num + 1)]

    def __single_seq_tokenize(self, seq: str) -> List[int]:
        """
        Convert the tokens of a single sequence to indices.
        """

        # assert len(seq) > 10, 'Too short to process!'
        token_list = self.__cut_seq(seq)
        token_ids = [self.encode(i) for i in token_list]

        return token_ids

    def __append_headtail(self, seq: str) -> str:
        """
        Append '>' at the sequence head and '<' at the sequence tail.
        """

        if seq[0] != '>':
            seq = '>' + seq
        if seq[-1] != '<':
            seq += '<'

        return seq

    def get_token_list(self, seq: str) -> List[str]:
        """
        Split a sequence into a list containing all of its tokens.
        """

        seq = self.__append_headtail(seq)

        assert len(seq) > 10, 'Too short to process!'
        token_list = self.__cut_seq(seq)

        return token_list

    def tokenize(self, seq: str, pt: bool = False) -> List[int]:
        """
        Tokenize the sequence into ids.
        """

        assert seq.isalpha(), f'ERROR Seq: {seq}\nProtein sequence has an illegal char!'

        seq = self.__append_headtail(seq)
        token_ids = self.__single_seq_tokenize(seq)

        return token_ids

    def detokenize(self, ids: List[int]) -> List[str]:
        """
        Detokenize ids back into the token sequence.
        """

        seq = [self.decode(i) for i in ids]

        return seq
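
A minimal usage sketch of how these classes fit together, assuming a 3-mer vocabulary written to a local vocab.txt (the protein sequence below is made up for illustration):

if __name__ == '__main__':
    # Build a 3-mer vocabulary and write it to ./vocab.txt.
    generator = BioVocabGenerator(gram_num=3)
    generator.save_vocabdict()

    # Reload the vocabulary and round-trip a protein sequence.
    tokenizer = BioTokenizer(vocab_path='vocab.txt')
    ids = tokenizer.tokenize('MKTAYIAKQR')  # '>' and '<' are added automatically
    tokens = tokenizer.detokenize(ids)      # ['>MK', 'MKT', ..., 'QR<']
    print(ids)
    print(tokens)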