import os
from functools import cmp_to_key
from itertools import permutations
from argparse import ArgumentParser
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
class BioVocabGenerator:
    def __init__(self,
                 gram_num: Optional[int] = None,
                 sort: bool = True,
                 cmp_list: Optional[List[str]] = None,
                 aa_list: Optional[List[str]] = None,
                 special_tokens: Optional[List[str]] = None) -> None:
        # 1. Set the gram_num for tokenization.
        # Example: gram_num = 3, 'ABCDE' -> ['ABC', 'BCD', 'CDE']
        if gram_num is not None:
            assert gram_num % 2 != 0, 'gram_num must be odd!'
        self.gram_num = gram_num
        # 2. Set the amino acid list and special_tokens for tokenization.
        # None defaults avoid sharing one mutable list across all instances.
        # mmseqs2 aa list: (A S T) (C) (D B N) (E Q Z) (F Y) (G) (H) (I V) (K R) (L J M) (P) (W) (X)
        self.aa_list = aa_list if aa_list is not None else [
            'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
            'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y',
            'O', 'U', 'B', 'J', 'Z', 'X']
        self.special_tokens = special_tokens if special_tokens is not None else [
            '[PAD]', '[MASK]', '[CLS]', '[SEP]', '[UNK]']
        # 3. Set the sort flag; cmp_dict defines the character order used for sorting.
        self.sort = sort
        self.cmp_dict = self.__fill_cmp_list(self.aa_list if cmp_list is None else cmp_list)
        if gram_num is not None:
            self.vocab = self.__generate_vocab
            self.vocab_dict = self.__generate_vocab_dict
    def __fill_cmp_list(self, cmp_list: List[str]) -> Dict[str, int]:
        """
        append the start/end markers ('>', '<') and map each character to its sort rank
        """
        return {value: index for index, value in enumerate(cmp_list + ['>', '<'])}
    @property
    def __iter_list(self) -> List[str]:
        """
        build the candidate pool for permutations: aa_list repeated gram_num
        times, plus the start/end markers.
        gram_num = 3: ['A', 'B', 'C'] -> ['A', 'B', 'C'] * 3 + ['>', '<']
        """
        return [i for _ in range(self.gram_num) for i in self.aa_list] + ['>', '<']
    def __remove_errstr(self, x: Tuple[str, ...]) -> bool:
        """
        filter out invalid candidates from raw_vocab: a marker may only
        appear as '>' at the start or '<' at the end
        error examples: 'A>B', '<QW'
        """
        marker_count = x.count('<') + x.count('>')
        if marker_count == 0:
            return True
        if marker_count == 1:
            return x[0] == '>' or x[-1] == '<'
        # drop candidates with two or more markers (e.g. '>A<')
        return False
    def __vocab_cmp(self, x: str, y: str) -> int:
        """
        comparison function for sorting tokens by cmp_dict rank
        """
        for i, j in zip(x, y):
            if self.cmp_dict[i] < self.cmp_dict[j]:
                return -1
            elif self.cmp_dict[i] > self.cmp_dict[j]:
                return 1
        return 0
    @property
    def __generate_vocab(self) -> List[str]:
        """
        generate the n-mer amino acid vocabulary
        """
        # generate raw_vocab from permutations of the candidate pool
        raw_vocab = permutations(self.__iter_list, r=self.gram_num)
        # use a set to drop duplicates, filtering out invalid marker placements
        vocab = list({''.join(i) for i in raw_vocab if self.__remove_errstr(i)})
        # sort the vocab by cmp_dict order
        if self.sort:
            vocab = sorted(vocab, key=cmp_to_key(self.__vocab_cmp))
        return self.special_tokens + vocab
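    # For the default 26-letter alphabet and gram_num = 3, the filter keeps
    # 26**3 letter triples plus 26**2 '>XY' and 26**2 'XY<' tokens:
    # 17,576 + 1,352 = 18,928 n-grams, i.e. 18,933 entries with the 5 special tokens.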
    @property
    def __generate_vocab_dict(self) -> OrderedDict:
        """
        convert the vocabulary from a List to an OrderedDict of token -> id
        """
        return OrderedDict((token, index) for index, token in enumerate(self.vocab))
def get_size(self) -> int:
return len(self.vocab)
def get_vocab_list(self) -> List[str]:
return self.vocab
def get_vocab_dict(self) -> OrderedDict:
return self.vocab_dict
    def encode(self, token: str) -> int:
        try:
            return int(self.vocab_dict[token])
        except KeyError as e:
            # fall back to the [UNK] id instead of leaving token_id unbound
            print('Cannot find {} in vocabulary!'.format(e))
            return int(self.vocab_dict['[UNK]'])
def decode(self, index: int) -> str:
return self.vocab[index]
    def save_vocabdict(self, path: Optional[str] = None) -> None:
        file_name = 'vocab.txt'
        if path is None:
            path = file_name
        elif os.path.isdir(path):
            path = os.path.join(path, file_name)
        try:
            with open(path, 'w') as f:
                for token, index in self.vocab_dict.items():
                    f.write("{0:>6} {1:>5}\n".format(token, str(index)))
        except OSError:
            print('Writing Error!')
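# Illustrative usage (token counts assume the default 26-letter aa_list):
#   gen = BioVocabGenerator(gram_num=1)
#   gen.get_size()           # 5 special tokens + 26 letters + '>' + '<' = 33
#   gen.encode('A')          # id of 'A' in the sorted vocabulary
#   gen.save_vocabdict('.')  # writes ./vocab.txt, one "token id" pair per line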
class BioVocabLoader(BioVocabGenerator):
    def __init__(self, path: str) -> None:
        super().__init__()
        assert os.path.exists(path), 'vocab path does not exist!'
        self.load_vocab_dict(path)
        # recover the n-gram size from the loaded vocabulary
        self.gram_num = self.get_gram_num()
    def load_vocab_dict(self, path: str) -> None:
        """
        load the vocabulary dictionary from a txt file
        """
        with open(path, 'r') as f:
            data = [line.split() for line in f if line.strip()]
        self.vocab = [token for token, _ in data]
        self.vocab_dict = OrderedDict((token, int(index)) for token, index in data)
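    # Each line of vocab.txt is written by save_vocabdict as "<token> <id>"
    # (right-aligned), e.g. ' [PAD]     0', so split() yields the token and its id.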
    def get_gram_num(self) -> int:
        """
        get the n-gram size from the vocabulary
        """
        if isinstance(self.gram_num, int):
            return self.gram_num
        # infer the size from the first non-special token's length
        for i in self.vocab:
            if i not in self.special_tokens:  # default 5 special_tokens
                return len(i)
class BioTokenizer(BioVocabLoader):
    @staticmethod
    def add_argparse_args(parent_parser: ArgumentParser) -> ArgumentParser:
        parser = parent_parser.add_argument_group('Tokenizer hyperparameter.')
        parser.add_argument('--vocab_path', type=str)
        return parent_parser
    def __init__(self, args=None, vocab_path: Optional[str] = None) -> None:
        # prefer an explicit vocab_path; otherwise fall back to parsed args
        # (BioVocabLoader.__init__ already sets self.gram_num)
        if vocab_path is None:
            super().__init__(args.vocab_path)
        else:
            super().__init__(vocab_path)
    def __cut_seq(self, seq: str) -> List[str]:
        """
        cut a sequence into an n-gram/n-mer token list, e.g. gram_num = 3:
        '>ABCDE<' -> ['>AB', 'ABC', 'BCD', 'CDE', 'DE<']
        """
        seq = seq.upper()
        assert len(seq) - self.gram_num + 1 > 0, 'Protein sequence is too short to cut!'
        return [seq[i: i + self.gram_num] for i in range(len(seq) - self.gram_num + 1)]
    def __single_seq_tokenize(self, seq: str) -> List[int]:
        """
        convert each token of the sequence to its index
        """
token_list = self.__cut_seq(seq)
token_ids = [self.encode(i) for i in token_list]
return token_ids
def __append_headtail(self, seq: str) -> str:
"""
append '>' on sequence head and '<' on sequence tail
"""
if seq[0] != '>':
seq = '>' + seq
if seq[-1] != '<':
seq += '<'
return seq
    def get_token_list(self, seq: str) -> List[str]:
        """
        split the sequence into a list containing all of its tokens
        """
seq = self.__append_headtail(seq)
assert len(seq) > 10, 'Too short to process!'
token_list = self.__cut_seq(seq)
return token_list
    def tokenize(self, seq: str, pt: bool = False) -> List[int]:
        """
        tokenize the sequence into ids (pt is accepted but currently unused)
        """
        assert seq.isalpha(), f'ERROR Seq: {seq}\nProtein Sequence has illegal char!'
        seq = self.__append_headtail(seq)
        return self.__single_seq_tokenize(seq)
    def detokenize(self, ids: List[int]) -> str:
        """
        detokenize ids back to the sequence (inverse of tokenize)
        """
        tokens = [self.decode(i) for i in ids]
        if not tokens:
            return ''
        # consecutive n-grams overlap by gram_num - 1 chars: keep the first
        # token whole, then append the last char of each following token
        return tokens[0] + ''.join(t[-1] for t in tokens[1:])
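if __name__ == '__main__':
    # Minimal round-trip sketch (the vocab file name and sample sequence here
    # are assumptions, not part of the original file): build a 3-gram
    # vocabulary, save it, reload it through the tokenizer, and check that
    # detokenize inverts tokenize.
    generator = BioVocabGenerator(gram_num=3)
    generator.save_vocabdict('vocab.txt')
    tokenizer = BioTokenizer(vocab_path='vocab.txt')
    ids = tokenizer.tokenize('MKTAYIAKQR')
    print(ids)
    print(tokenizer.detokenize(ids))  # '>MKTAYIAKQR<'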