import os
import json
import re
from typing import List, Optional, Tuple, Dict

from transformers import PreTrainedTokenizer


class SingleNucleotideTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer for nucleotide sequences with annotation special tokens."""

    def __init__(self, **kwargs):
        # Fixed vocabulary: control tokens, annotation tags, then nucleotide symbols.
        self.vocab_list = [
            "<oov>", "<s>", "</s>", "<pad>", "<mask>",
            "<bog>", "<eog>", "<bok>", "<eok>", "<+>", "<->",
            "<mam>", "<vrt>", "<inv>", "<pln>", "<fng>", "<prt>",
            "<arc>", "<bct>", "<mit>", "<plt>", "<plm>", "<vir>",
            "<cds>", "<pseudo>", "<tRNA>", "<rRNA>", "<ncRNA>",
            "<sp0>", "<sp1>", "<sp2>", "<sp3>",
            "A", "C", "G", "<K>", "<M>", "N", "<R>", "<S>", "T", "<W>", "<Y>"
        ]

        self.vocab = {token: idx for idx, token in enumerate(self.vocab_list)}
        self.ids_to_tokens = {idx: token for token, idx in self.vocab.items()}
        self.tokens_to_ids = {token: idx for token, idx in self.vocab.items()}

        # Unknown characters map to "N" rather than to a dedicated <unk> token.
        self.unk_token = "N"
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        self.pad_token = "<pad>"
        self.mask_token = "<mask>"

        # Match any angle-bracketed token from the vocabulary, e.g. "<mam>" or "<cds>".
        special_tokens_pattern = "|".join(
            re.escape(token) for token in self.vocab_list if token.startswith("<") and token.endswith(">")
        )
        self.special_token_re = re.compile(f"({special_tokens_pattern})")

        # Match a single plain nucleotide symbol.
        self.normal_token_re = re.compile(r"[ACGTN]")

        self.unk_token_id = self.vocab[self.unk_token]
        self.bos_token_id = self.vocab[self.bos_token]
        self.eos_token_id = self.vocab[self.eos_token]
        self.pad_token_id = self.vocab[self.pad_token]
        self.mask_token_id = self.vocab[self.mask_token]

        super().__init__(
            unk_token=self.unk_token,
            bos_token=self.bos_token,
            eos_token=self.eos_token,
            pad_token=self.pad_token,
            mask_token=self.mask_token,
            **kwargs
        )
        self.clean_up_tokenization_spaces = True

    @property
    def vocab_size(self) -> int:
        return len(self.vocab)

    def get_vocab(self) -> Dict[str, int]:
        return self.vocab

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        """Scan the text left to right, emitting one token per special tag or nucleotide."""
        tokens = []
        pos = 0
        text_length = len(text)

        while pos < text_length:
            # Angle-bracketed special tokens take precedence over single characters.
            special_match = self.special_token_re.match(text, pos)
            if special_match:
                token = special_match.group()
                tokens.append(token)
                pos = special_match.end()
                continue

            # Plain nucleotide symbols: A, C, G, T, N.
            normal_match = self.normal_token_re.match(text, pos)
            if normal_match:
                token = normal_match.group()
                if token in self.vocab:
                    tokens.append(token)
                else:
                    tokens.append(self.unk_token)
                pos = normal_match.end()
                continue

            # Any other character is mapped to the unknown token.
            tokens.append(self.unk_token)
            pos += 1

        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        return self.vocab.get(token, self.unk_token_id)

    def _convert_id_to_token(self, index: int) -> str:
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        if token_ids_1 is None:
            return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False
    ) -> List[int]:
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True
            )

        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        if token_ids_1 is None:
            return [0] * (len(token_ids_0) + 2)
        return [0] * (len(token_ids_0) + 1) + [1] * (len(token_ids_1) + 1)

    def save_pretrained(self, save_directory: str, **kwargs):
        """Override save_pretrained so the saved config also carries an auto_map entry."""
        vocab_files = super().save_pretrained(save_directory, **kwargs)

        tokenizer_config_path = os.path.join(save_directory, "tokenizer_config.json")

        if os.path.exists(tokenizer_config_path):
            with open(tokenizer_config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
        else:
            config = {}

        config.update({
            "auto_map": {
                "AutoTokenizer": [
                    "tokenizer.SingleNucleotideTokenizer",
                    None
                ]
            },
        })

        with open(tokenizer_config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, ensure_ascii=False, indent=2)

        return vocab_files

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not os.path.exists(save_directory):
            os.makedirs(save_directory)

        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.txt"
        )

        # Write one "token index" pair per line, ordered by token id.
        with open(vocab_file, "w", encoding="utf-8") as f:
            for token, idx in sorted(self.vocab.items(), key=lambda x: x[1]):
                f.write(f"{token} {idx}\n")

        return (vocab_file,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
        # The vocabulary is fixed in code, so nothing is read from the given path.
        return cls(**kwargs)
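

# Minimal usage sketch (illustrative, not part of the class): the input string and the
# expected token lists below assume the fixed vocabulary defined above and standard
# transformers tokenizer behaviour; exact outputs may vary across library versions.
if __name__ == "__main__":
    tokenizer = SingleNucleotideTokenizer()

    # Annotation tags and nucleotides are split into single tokens.
    tokens = tokenizer.tokenize("<mam><+>ACGTN")
    print(tokens)  # expected: ['<mam>', '<+>', 'A', 'C', 'G', 'T', 'N']

    # encode() wraps the sequence with <s>/</s> via build_inputs_with_special_tokens.
    ids = tokenizer.encode("ACGT")
    print(tokenizer.convert_ids_to_tokens(ids))  # expected: ['<s>', 'A', 'C', 'G', 'T', '</s>']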