Upload 3 files (#1) by SY-Bai · opened

- collate.py +105 -0
- tokenizer.py +255 -0
- tokens.txt +0 -0
collate.py
ADDED
import os
import torch
from copy import deepcopy
from argparse import ArgumentParser
from typing import List, Tuple, Dict


class Simple_Collator:

    @staticmethod
    def add_args(parent_parser: ArgumentParser) -> ArgumentParser:
        parser = parent_parser.add_argument_group('Data Collator Config & Hyperparameters.')

        parser.add_argument('--max_len', default=256, type=int)        # max sequence length
        parser.add_argument('--ignore_label', default=-100, type=int)  # PyTorch standard ignore label: -100
        parser.add_argument('--split_aa_num', default=3, type=int)     # number of amino acids per token (n-gram size)

        # note: argparse's type=bool treats any non-empty string as True
        parser.add_argument('--truncation', default=True, type=bool)
        parser.add_argument('--truncation_mode', default='cut', type=str, choices=['window', 'cut'])

        parser.add_argument('--padding', default=True)
        parser.add_argument('--padding_token', default='[PAD]', type=str)

        return parent_parser

    def __init__(self, tokenizer, args) -> None:

        self.tokenizer = tokenizer              # the tokenizer instance
        self.max_len = args.max_len
        self.ignore_label = args.ignore_label
        self.split_aa_num = args.split_aa_num

        # truncation, padding, mask
        assert args.truncation_mode in ['window', 'cut'], "truncation_mode must be 'window' or 'cut'."
        self.trunc = args.truncation
        self.trunc_mode = args.truncation_mode

        self.padding = args.padding
        self.padding_token = args.padding_token

    def process_tokens(self, tokens_ids: List[int]) -> Tuple[List[int], List[int]]:
        # every position gets the ignore label; no masking is applied in this simple collator
        tokens_labels = [self.ignore_label] * len(tokens_ids)
        return tokens_ids, tokens_labels

    def pad_tokens(self,
                   tokens_ids: List[int],
                   tokens_labels: List[int]) -> Tuple[List[int], List[int], List[int]]:

        raw_len = len(tokens_ids)

        # pad up to the next multiple of max_len (a sequence whose length is already
        # a multiple of max_len gains one extra fully padded block)
        len_diff = self.max_len - (raw_len % self.max_len)
        tokens_ids += [self.tokenizer.encode(self.padding_token)] * len_diff
        tokens_labels += [self.ignore_label] * len_diff
        tokens_attn_mask = [1] * raw_len + [0] * len_diff

        return tokens_ids, tokens_labels, tokens_attn_mask

    def trunc_tokens(self, data: list) -> List[list]:

        res = []
        tokens_len = len(data)

        if tokens_len <= self.max_len:
            return [data]

        if self.trunc_mode == 'window':         # sliding window with stride 1
            for i in range(tokens_len - self.max_len + 1):
                res.append(deepcopy(data[i: i + self.max_len]))
        elif self.trunc_mode == 'cut':          # non-overlapping chunks of length max_len
            for i in range(0, tokens_len, self.max_len):
                res.append(deepcopy(data[i: i + self.max_len]))

        return res

    def seq2data(self, seq: str) -> Tuple[List[int], List[int], List[int]]:
        tokens_ids = self.tokenizer.tokenize(seq)                       # 1. tokenize the sequence

        tokens_ids, tokens_labels = self.process_tokens(tokens_ids)     # 2. generate labels for the tokens

        if self.padding is True:
            tokens_ids, tokens_labels, tokens_attn_mask = self.pad_tokens(tokens_ids, tokens_labels)  # 3. pad the sequence
        else:
            tokens_attn_mask = [1] * len(tokens_ids)                    # no padding: attend to every token

        if self.trunc is True:
            tokens_ids, tokens_labels, tokens_attn_mask = [self.trunc_tokens(i)
                                                           for i in [tokens_ids, tokens_labels, tokens_attn_mask]]  # 4. split into max_len chunks

        return tokens_ids, tokens_labels, tokens_attn_mask

    def __call__(self, data, HF_dataset: bool = False) -> Dict:

        input_ids, labels, attn_mask = [], [], []

        if HF_dataset is False:
            if isinstance(data, str):
                data = [data]                   # wrap a single protein sequence for testing

        for i in data:
            seq = i['seq'] if HF_dataset else i
            tokens_ids, tokens_labels, tokens_attn_mask = self.seq2data(seq)

            input_ids.extend(deepcopy(tokens_ids))
            labels.extend(deepcopy(tokens_labels))
            attn_mask.extend(deepcopy(tokens_attn_mask))

        return {
            'input_ids': torch.tensor(input_ids),
            'labels': torch.tensor(labels),
            'attention_mask': torch.tensor(attn_mask)}
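A minimal usage sketch (not part of the upload): it assumes the BioTokenizer from tokenizer.py below, and that the uploaded tokens.txt is the vocabulary file expected by --vocab_path; the test sequence is illustrative.

from argparse import ArgumentParser

from tokenizer import BioTokenizer
from collate import Simple_Collator

# build one argparse namespace holding both tokenizer and collator options
parser = ArgumentParser()
parser = BioTokenizer.add_argparse_args(parser)
parser = Simple_Collator.add_args(parser)
args = parser.parse_args(['--vocab_path', 'tokens.txt'])   # assumption: tokens.txt is the 3-mer vocab

tokenizer = BioTokenizer(args)
collator = Simple_Collator(tokenizer, args)

# collate a single test sequence into model-ready tensors
batch = collator('MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ')
print(batch['input_ids'].shape)                             # (1, max_len) with the default settings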
tokenizer.py
ADDED
import os
import time
import argparse
from functools import cmp_to_key
from itertools import permutations
from argparse import ArgumentParser
from collections import OrderedDict
from typing import List, Dict, Union, Optional


class BioVocabGenerator:

    def __init__(self,
                 gram_num: Union[int, None] = None,
                 sort: bool = True,
                 cmp_list: Union[List[str], None] = None,
                 aa_list: List[str] = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
                                       'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y',
                                       'O', 'U', 'B', 'J', 'Z', 'X'],
                 # mmseqs2 aa groups: (A S T) (C) (D B N) (E Q Z) (F Y) (G) (H) (I V) (K R) (L J M) (P) (W) (X)
                 special_tokens: List[str] = ['[PAD]', '[MASK]', '[CLS]', '[SEP]', '[UNK]']) -> None:

        # 1. Set the gram_num for tokenization.
        #    Example: gram_num = 3, 'ABCDE' -> ['ABC', 'BCD', 'CDE']
        if gram_num is not None:
            assert gram_num % 2 != 0, 'gram_num must be odd!'
        self.gram_num = gram_num

        # 2. Set the amino acid list and the special tokens for tokenization.
        self.aa_list = aa_list
        self.special_tokens = special_tokens

        # 3. Set the sort flag; cmp_dict defines the character order used for sorting.
        self.sort = sort
        self.cmp_dict = self.__fill_cmp_list(self.aa_list if cmp_list is None else cmp_list)

        if gram_num is not None:
            self.vocab = self.__generate_vocab
            self.vocab_dict = self.__generate_vocab_dict

    def __fill_cmp_list(self, cmp_list: List[str]) -> Dict[str, int]:
        """
        Add the start ('>') and end ('<') markers to cmp_dict.
        """

        return {value: index for index, value in enumerate(cmp_list + ['>', '<'])}

    @property
    def __iter_list(self) -> List[str]:
        """
        Generate the iterable fed to permutations: the amino acid list repeated
        gram_num times, plus the start/end markers.
        ['A', 'B', 'C'] -> ['A', 'B', 'C', 'A', 'B', 'C', 'A', 'B', 'C', '>', '<']
        """

        return [i for _ in range(self.gram_num) for i in self.aa_list] + ['>', '<']

    def __remove_errstr(self, x) -> bool:
        """
        Filter invalid candidates out of raw_vocab.
        Invalid examples: 'A>B', '<QW' ('>' may only start a token, '<' may only end one).
        """

        if x.count('<') + x.count('>') == 0:
            return True
        elif x.count('<') + x.count('>') == 1:
            if x[0] == '>' or x[-1] == '<':
                return True
        return False

    def __vocab_cmp(self, x: str, y: str) -> int:
        """
        Comparison function used for sorting.
        """

        for i, j in zip(x, y):
            if self.cmp_dict[i] < self.cmp_dict[j]:
                return -1
            elif self.cmp_dict[i] > self.cmp_dict[j]:
                return 1
        return 0

    @property
    def __generate_vocab(self) -> List[str]:
        """
        Generate the n-mer amino acid vocabulary.
        """
        # generate raw_vocab from permutations
        raw_vocab = permutations(self.__iter_list, r=self.gram_num)

        # use a set to drop duplicates, then remove the invalid strings
        vocab = list(set([''.join(i) for i in raw_vocab if self.__remove_errstr(i) is True]))

        # sort the vocabulary
        if self.sort is True:
            vocab = sorted(vocab, key=cmp_to_key(self.__vocab_cmp))

        return self.special_tokens + vocab

    @property
    def __generate_vocab_dict(self) -> OrderedDict:
        """
        Convert the vocabulary from a list to an OrderedDict.
        """

        return OrderedDict(zip(self.vocab, range(len(self.vocab))))

    def get_size(self) -> int:
        return len(self.vocab)

    def get_vocab_list(self) -> List[str]:
        return self.vocab

    def get_vocab_dict(self) -> OrderedDict:
        return self.vocab_dict

    def encode(self, input: str) -> int:
        try:
            token_id = int(self.vocab_dict[input])
        except KeyError as e:
            # fall back to the [UNK] token for out-of-vocabulary inputs
            print('Can not find {} in vocabulary!'.format(e))
            token_id = int(self.vocab_dict['[UNK]'])
        return token_id

    def decode(self, index: int) -> str:
        return self.vocab[index]

    def save_vocabdict(self, path: Optional[str] = None) -> None:

        path_name = 'vocab.txt'

        if path is None:
            path = path_name
        elif os.path.isdir(path):
            path += '/' + path_name

        try:
            with open(path, 'w') as f:
                for token, index in self.vocab_dict.items():
                    f.write("{0:>6} {1:>5}\n".format(token, str(index)))
        except OSError:
            print('Writing Error!')


class BioVocabLoader(BioVocabGenerator):

    def __init__(self, path: str) -> None:
        super().__init__()
        assert os.path.exists(path), 'vocab path does not exist!'
        self.load_vocab_dict(path)
        self.gram_num = self.get_gram_num()

    def load_vocab_dict(self, path: str) -> None:
        """
        Load the vocabulary dictionary from a txt file.
        """

        with open(path, 'r') as f:
            data = [line.strip() for line in f.read().splitlines()]
            self.vocab = [i.split()[0] for i in data]
            self.vocab_dict = OrderedDict({i.split()[0]: i.split()[1] for i in data})

    def get_gram_num(self) -> int:
        """
        Infer the n-gram size from the vocabulary.
        """

        if isinstance(self.gram_num, int):
            return self.gram_num
        else:
            for i in self.vocab:
                if i not in self.special_tokens:    # skip the 5 default special tokens
                    return len(i)


class BioTokenizer(BioVocabLoader):

    @staticmethod
    def add_argparse_args(parent_parser: ArgumentParser) -> ArgumentParser:
        parser = parent_parser.add_argument_group('Tokenizer hyperparameters.')
        parser.add_argument('--vocab_path', type=str)
        return parent_parser

    def __init__(self, args=None, vocab_path: str = None) -> None:

        if vocab_path is None:
            super().__init__(args.vocab_path)
        else:
            super().__init__(vocab_path)

        self.gram_num = self.get_gram_num()

    def __cut_seq(self, seq: str) -> List[str]:
        """
        Cut a sequence into an n-gram/n-mer token list.
        For gram_num = 3: ">ABCDE<" -> '>AB', 'ABC', 'BCD', 'CDE', 'DE<'
        """

        seq = seq.upper()
        assert len(seq) - self.gram_num + 1 > 0, 'Protein sequence is too short to cut!'
        return [seq[i: i + self.gram_num] for i in range(len(seq) - self.gram_num + 1)]

    def __single_seq_tokenize(self, seq: str) -> List[int]:
        """
        Convert each token to its index.
        """

        # assert len(seq) > 10, 'Too short to process!'
        token_list = self.__cut_seq(seq)
        token_ids = [self.encode(i) for i in token_list]

        return token_ids

    def __append_headtail(self, seq: str) -> str:
        """
        Append '>' to the sequence head and '<' to the sequence tail.
        """

        if seq[0] != '>':
            seq = '>' + seq
        if seq[-1] != '<':
            seq += '<'

        return seq

    def get_token_list(self, seq: str) -> List[str]:
        """
        Split a sequence into a list containing all of its tokens.
        """

        seq = self.__append_headtail(seq)

        assert len(seq) > 10, 'Too short to process!'
        token_list = self.__cut_seq(seq)

        return token_list

    def tokenize(self, seq: str, pt: bool = False) -> List[int]:
        """
        Tokenize a sequence into token ids. The pt flag is reserved and currently unused.
        """

        assert seq.isalpha(), f'ERROR Seq: {seq}\nProtein sequence contains an illegal character!'

        seq = self.__append_headtail(seq)
        token_ids = self.__single_seq_tokenize(seq)

        return token_ids

    def detokenize(self, ids: List[int]) -> List[str]:
        """
        Detokenize ids back into their overlapping n-mer tokens.
        """

        seq = [self.decode(i) for i in ids]

        return seq
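A minimal round-trip sketch (not part of the upload): generate a 3-mer vocabulary, save it, reload it through the tokenizer, and tokenize a short sequence; the file name vocab.txt and the sequence are illustrative.

from tokenizer import BioVocabGenerator, BioTokenizer

# 1. generate and persist a 3-mer vocabulary (gram_num must be odd)
generator = BioVocabGenerator(gram_num=3)
generator.save_vocabdict('.')                   # writes ./vocab.txt
print(generator.get_size())                     # 26^3 + 2*26^2 + 5 special tokens = 18933

# 2. reload the vocabulary, then tokenize / detokenize a sequence
tokenizer = BioTokenizer(vocab_path='vocab.txt')
ids = tokenizer.tokenize('MKTAYIAKQR')          # '>MKTAYIAKQR<' -> overlapping 3-mers -> ids
print(tokenizer.detokenize(ids))                # ['>MK', 'MKT', ..., 'QR<']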
tokens.txt
ADDED
The diff for this file is too large to render. See raw diff.