Delete tokenizer.py
Browse files- tokenizer.py +0 -255
tokenizer.py
DELETED
|
@@ -1,255 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import time
|
| 3 |
-
import argparse
|
| 4 |
-
from functools import cmp_to_key
|
| 5 |
-
from itertools import permutations
|
| 6 |
-
from argparse import ArgumentParser
|
| 7 |
-
from collections import OrderedDict
|
| 8 |
-
from typing import List, Dict, OrderedDict, Union, Optional
|
| 9 |
-
|
| 10 |
-
class BioVocabGenerator():
|
| 11 |
-
|
| 12 |
-
def __init__(self,
|
| 13 |
-
gram_num: Union[int, None] = None,
|
| 14 |
-
sort: bool = True,
|
| 15 |
-
cmp_list: Union[List[str], None] = None,
|
| 16 |
-
aa_list: List[str] = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
|
| 17 |
-
'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y',
|
| 18 |
-
'O', 'U', 'B', 'J', 'Z', 'X'],
|
| 19 |
-
# mmseqs2 aa list: (A S T) (C) (D B N) (E Q Z) (F Y) (G) (H) (I V) (K R) (L J M) (P) (W) (X)
|
| 20 |
-
special_tokens: List[str] = ['[PAD]', '[MASK]', '[CLS]', '[SEP]','[UNK]']) -> None:
|
| 21 |
-
|
| 22 |
-
# 1. Set the gram_num for tokenization.
|
| 23 |
-
# Example: gram_num = 3, 'ABCDE' -> ['ABC', 'BCD', 'CDE']
|
| 24 |
-
if gram_num is not None: assert gram_num % 2 != 0, 'gram_num must be odd!'
|
| 25 |
-
self.gram_num = gram_num
|
| 26 |
-
|
| 27 |
-
# 2. Set the amino acid list and add special_tokens for tokenization.
|
| 28 |
-
self.aa_list = aa_list
|
| 29 |
-
self.special_tokens = special_tokens
|
| 30 |
-
|
| 31 |
-
# 3. Set the bool value for sort, cmp_dict is the dict order to sort.
|
| 32 |
-
self.sort = sort
|
| 33 |
-
self.cmp_dict = self.__fill_cmp_list(self.aa_list if cmp_list is None else cmp_list)
|
| 34 |
-
|
| 35 |
-
if gram_num is not None:
|
| 36 |
-
self.vocab = self.__generate_vocab
|
| 37 |
-
self.vocab_dict = self.__generate_vocab_dict
|
| 38 |
-
|
| 39 |
-
def __fill_cmp_list(self, cmp_list: List[str]) -> Dict[str, int]:
|
| 40 |
-
"""
|
| 41 |
-
fill the start and end syntax for cmp_dict
|
| 42 |
-
"""
|
| 43 |
-
|
| 44 |
-
return {value: index for index, value in enumerate(cmp_list + ['>', '<'])}
|
| 45 |
-
|
| 46 |
-
@property
|
| 47 |
-
def __iter_list(self) -> List[str]:
|
| 48 |
-
"""
|
| 49 |
-
generate iter_list for permutations
|
| 50 |
-
['A', 'B', 'C'] -> ['A', 'B', 'C', 'A', 'B', 'C', 'A', 'B', 'C']
|
| 51 |
-
"""
|
| 52 |
-
|
| 53 |
-
return [i for _ in range(self.gram_num) for i in self.aa_list] + ['>', '<']
|
| 54 |
-
|
| 55 |
-
def __remove_errstr(self, x: str) -> bool:
|
| 56 |
-
"""
|
| 57 |
-
remove error string from raw_vocab
|
| 58 |
-
error str example: 'A>B', '<QW'
|
| 59 |
-
"""
|
| 60 |
-
|
| 61 |
-
if x.count('<') + x.count('>') == 0:
|
| 62 |
-
return True
|
| 63 |
-
elif x.count('<') + x.count('>') == 1:
|
| 64 |
-
if x[0] == '>' or x[-1] == '<':
|
| 65 |
-
return True
|
| 66 |
-
else:
|
| 67 |
-
return False
|
| 68 |
-
|
| 69 |
-
def __vocab_cmp(self, x: str, y: str) -> int:
|
| 70 |
-
"""
|
| 71 |
-
cmp function for sort
|
| 72 |
-
"""
|
| 73 |
-
|
| 74 |
-
for i, j in zip(x, y):
|
| 75 |
-
if self.cmp_dict[i] < self.cmp_dict[j]:
|
| 76 |
-
return -1
|
| 77 |
-
elif self.cmp_dict[i] > self.cmp_dict[j]:
|
| 78 |
-
return 1
|
| 79 |
-
else:
|
| 80 |
-
continue
|
| 81 |
-
|
| 82 |
-
@property
|
| 83 |
-
def __generate_vocab(self) -> List[str]:
|
| 84 |
-
"""
|
| 85 |
-
generate n-mer amino acid vocabulary
|
| 86 |
-
"""
|
| 87 |
-
# generate raw_vocab from permutations
|
| 88 |
-
raw_vocab = permutations(self.__iter_list, r = self.gram_num)
|
| 89 |
-
|
| 90 |
-
# use set to clear duplicate values and remove the error strs
|
| 91 |
-
vocab = list(set([''.join(i) for i in raw_vocab if self.__remove_errstr(i) == True]))
|
| 92 |
-
|
| 93 |
-
# sort the vocab
|
| 94 |
-
if self.sort is True: vocab = sorted(vocab, key = cmp_to_key(self.__vocab_cmp))
|
| 95 |
-
|
| 96 |
-
return self.special_tokens + vocab
|
| 97 |
-
|
| 98 |
-
@property
|
| 99 |
-
def __generate_vocab_dict(self) -> OrderedDict:
|
| 100 |
-
"""
|
| 101 |
-
convert vocabulary from List to OrderedDict
|
| 102 |
-
"""
|
| 103 |
-
|
| 104 |
-
return OrderedDict(zip(self.vocab, [i for i in range(len(self.vocab))]))
|
| 105 |
-
|
| 106 |
-
def get_size(self) -> int:
|
| 107 |
-
return len(self.vocab)
|
| 108 |
-
|
| 109 |
-
def get_vocab_list(self) -> List[str]:
|
| 110 |
-
return self.vocab
|
| 111 |
-
|
| 112 |
-
def get_vocab_dict(self) -> OrderedDict:
|
| 113 |
-
return self.vocab_dict
|
| 114 |
-
|
| 115 |
-
def encode(self, input: str) -> int:
|
| 116 |
-
try:
|
| 117 |
-
token_id = int(self.vocab_dict[input])
|
| 118 |
-
except KeyError as e:
|
| 119 |
-
print('Can not find {} in vocabulary!'.format(e))
|
| 120 |
-
finally:
|
| 121 |
-
return token_id
|
| 122 |
-
|
| 123 |
-
def decode(self, index: int) -> str:
|
| 124 |
-
return self.vocab[index]
|
| 125 |
-
|
| 126 |
-
def save_vocabdict(self, path: Optional[str] = None) -> None:
|
| 127 |
-
|
| 128 |
-
path_name = 'vocab.txt'
|
| 129 |
-
|
| 130 |
-
if path is None:
|
| 131 |
-
path = path_name
|
| 132 |
-
elif os.path.isdir(path):
|
| 133 |
-
path += '/' + path_name
|
| 134 |
-
|
| 135 |
-
try:
|
| 136 |
-
with open(path, 'w') as f:
|
| 137 |
-
data = self.vocab_dict
|
| 138 |
-
for i, j in data.items():
|
| 139 |
-
f.write("{0:>6} {1:>5}\n".format(i, str(j)))
|
| 140 |
-
except:
|
| 141 |
-
print('Writing Error!')
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
class BioVocabLoader(BioVocabGenerator):
|
| 145 |
-
|
| 146 |
-
def __init__(self, path: str) -> None:
|
| 147 |
-
super().__init__()
|
| 148 |
-
assert os.path.exists(path), 'vocab path not exists!'
|
| 149 |
-
self.load_vocab_dict(path)
|
| 150 |
-
self.get_gram_num()
|
| 151 |
-
|
| 152 |
-
def load_vocab_dict(self, path: str) -> None:
|
| 153 |
-
"""
|
| 154 |
-
load the vocabulary dictionary from txt
|
| 155 |
-
"""
|
| 156 |
-
|
| 157 |
-
with open(path, 'r') as f:
|
| 158 |
-
data = [line.strip() for line in f.read().splitlines()]
|
| 159 |
-
self.vocab = [i.split()[0] for i in data]
|
| 160 |
-
self.vocab_dict = OrderedDict({i.split()[0] : i.split()[1] for i in data})
|
| 161 |
-
|
| 162 |
-
def get_gram_num(self) -> None:
|
| 163 |
-
"""
|
| 164 |
-
get the n-gram split from the vocabulary
|
| 165 |
-
"""
|
| 166 |
-
|
| 167 |
-
if isinstance(self.gram_num, int):
|
| 168 |
-
return self.gram_num
|
| 169 |
-
else:
|
| 170 |
-
for i in self.vocab:
|
| 171 |
-
if i not in self.special_tokens: # default 5 special_tokens
|
| 172 |
-
return len(i)
|
| 173 |
-
|
| 174 |
-
class BioTokenizer(BioVocabLoader):
|
| 175 |
-
|
| 176 |
-
@staticmethod
|
| 177 |
-
def add_argparse_args(parent_parser: ArgumentParser) -> ArgumentParser:
|
| 178 |
-
parser = parent_parser.add_argument_group('Tokenizer hyperparameter.')
|
| 179 |
-
parser.add_argument('--vocab_path', type=str)
|
| 180 |
-
return parent_parser
|
| 181 |
-
|
| 182 |
-
def __init__(self, args = None, vocab_path: str = None) -> None:
|
| 183 |
-
|
| 184 |
-
if vocab_path is None:
|
| 185 |
-
super().__init__(args.vocab_path)
|
| 186 |
-
else:
|
| 187 |
-
super().__init__(vocab_path)
|
| 188 |
-
|
| 189 |
-
self.gram_num = self.get_gram_num()
|
| 190 |
-
|
| 191 |
-
def __cut_seq(self, seq: str) -> List[str]:
|
| 192 |
-
"""
|
| 193 |
-
cut a sequence to 3-gram/3-mer token list
|
| 194 |
-
">ABCDE<" -> '>AB', 'ABC', 'BCD', 'CDE', 'DE<'
|
| 195 |
-
"""
|
| 196 |
-
|
| 197 |
-
seq = seq.upper()
|
| 198 |
-
assert len(seq) - self.gram_num + 1 > 0, 'Protein sequence is too short to cut!'
|
| 199 |
-
return [seq[i: i + self.gram_num] for i in range(len(seq) - self.gram_num + 1)]
|
| 200 |
-
|
| 201 |
-
def __single_seq_tokenize(self, seq: str) -> List[int]:
|
| 202 |
-
"""
|
| 203 |
-
convert token to index
|
| 204 |
-
"""
|
| 205 |
-
|
| 206 |
-
# assert len(seq) > 10, 'Too short to process!'
|
| 207 |
-
token_list = self.__cut_seq(seq)
|
| 208 |
-
token_ids = [self.encode(i) for i in token_list]
|
| 209 |
-
|
| 210 |
-
return token_ids
|
| 211 |
-
|
| 212 |
-
def __append_headtail(self, seq: str) -> str:
|
| 213 |
-
"""
|
| 214 |
-
append '>' on sequence head and '<' on sequence tail
|
| 215 |
-
"""
|
| 216 |
-
|
| 217 |
-
if seq[0] != '>':
|
| 218 |
-
seq = '>' + seq
|
| 219 |
-
if seq[-1] != '<':
|
| 220 |
-
seq += '<'
|
| 221 |
-
|
| 222 |
-
return seq
|
| 223 |
-
|
| 224 |
-
def get_token_list(self, seq: str) -> List[str]:
|
| 225 |
-
"""
|
| 226 |
-
split sequence to a list contains all tokens
|
| 227 |
-
"""
|
| 228 |
-
|
| 229 |
-
seq = self.__append_headtail(seq)
|
| 230 |
-
|
| 231 |
-
assert len(seq) > 10, 'Too short to process!'
|
| 232 |
-
token_list = self.__cut_seq(seq)
|
| 233 |
-
|
| 234 |
-
return token_list
|
| 235 |
-
|
| 236 |
-
def tokenize(self, seq: str, pt: bool = False) -> List[int]:
|
| 237 |
-
"""
|
| 238 |
-
tokenize the sequence to ids
|
| 239 |
-
"""
|
| 240 |
-
|
| 241 |
-
assert seq.isalpha(), f'ERROR Seq: {seq}\nProtein Sequence has illegal char!'
|
| 242 |
-
|
| 243 |
-
seq = self.__append_headtail(seq)
|
| 244 |
-
token_ids = self.__single_seq_tokenize(seq)
|
| 245 |
-
|
| 246 |
-
return token_ids
|
| 247 |
-
|
| 248 |
-
def detokenize(self, ids: List[str]) -> str:
|
| 249 |
-
"""
|
| 250 |
-
detokenize ids to sequence
|
| 251 |
-
"""
|
| 252 |
-
|
| 253 |
-
seq = [self.decode(i) for i in ids]
|
| 254 |
-
|
| 255 |
-
return seq
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|