import os
import json
import re
from typing import List, Optional, Tuple, Dict

from transformers import PreTrainedTokenizer


class SingleNucleotideTokenizer(PreTrainedTokenizer):
    """Tokenizer that splits nucleotide sequences into single-base tokens plus bracketed annotation tags."""

    def __init__(self, **kwargs):
        # Fixed vocabulary: special/annotation tags followed by single-nucleotide
        # symbols; ambiguous IUPAC codes are encoded as bracketed tags (e.g. <K>, <R>).
        self.vocab_list = [
            "<oov>", "<s>", "</s>", "<pad>", "<mask>",
            "<bog>", "<eog>", "<bok>", "<eok>", "<+>", "<->",
            "<mam>", "<vrt>", "<inv>", "<pln>", "<fng>", "<prt>",
            "<arc>", "<bct>", "<mit>", "<plt>", "<plm>", "<vir>",
            "<cds>", "<pseudo>", "<tRNA>", "<rRNA>", "<ncRNA>",
            "<sp0>", "<sp1>", "<sp2>", "<sp3>",
            "A", "C", "G", "<K>", "<M>", "N", "<R>", "<S>", "T", "<W>", "<Y>"
        ]

        # Token <-> id lookup tables derived from the fixed vocabulary.
        self.vocab = {token: idx for idx, token in enumerate(self.vocab_list)}
        self.ids_to_tokens = {idx: token for token, idx in self.vocab.items()}
        self.tokens_to_ids = dict(self.vocab)

        # Special token strings; the unknown token is the ambiguous base "N".
        self.unk_token = "N"
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        self.pad_token = "<pad>"
        self.mask_token = "<mask>"

        # Regex matching any bracketed token from the vocabulary (e.g. <mask>, <cds>, <K>).
        special_tokens_pattern = "|".join(
            re.escape(token) for token in self.vocab_list
            if token.startswith("<") and token.endswith(">")
        )
        self.special_token_re = re.compile(f"({special_tokens_pattern})")

        # Regex matching a single plain nucleotide symbol.
        self.normal_token_re = re.compile(r"[ACGTN]")

        # The corresponding *_token_id values are exposed as properties by the base
        # class once super().__init__ has registered the special tokens, so they are
        # not assigned here.

        super().__init__(
            unk_token=self.unk_token,
            bos_token=self.bos_token,
            eos_token=self.eos_token,
            pad_token=self.pad_token,
            mask_token=self.mask_token,
            **kwargs,
        )
        self.clean_up_tokenization_spaces = True

    @property
    def vocab_size(self) -> int:
        return len(self.vocab)

    def get_vocab(self) -> Dict[str, int]:
        return self.vocab

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        # Greedy left-to-right scan: try a bracketed special/annotation token first,
        # then a single plain nucleotide; anything unrecognized becomes the unk token.
        tokens = []
        pos = 0
        text_length = len(text)

        while pos < text_length:
            special_match = self.special_token_re.match(text, pos)
            if special_match:
                tokens.append(special_match.group())
                pos = special_match.end()
                continue

            normal_match = self.normal_token_re.match(text, pos)
            if normal_match:
                token = normal_match.group()
                tokens.append(token if token in self.vocab else self.unk_token)
                pos = normal_match.end()
                continue

            # Unknown character: emit the unk token and advance one position.
            tokens.append(self.unk_token)
            pos += 1

        return tokens

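    # Illustrative example of the greedy scan above (hand-derived, not taken from a
    # test run):
    #   "<cds>ACGT<mask>NK" tokenizes to
    #   ["<cds>", "A", "C", "G", "T", "<mask>", "N", "N"],
    #   where the bare "K" (not one of A/C/G/T/N) falls back to the unk token "N".
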
    def _convert_token_to_id(self, token: str) -> int:
        # Fall back to the unk token's id straight from the vocab rather than via
        # the unk_token_id property.
        return self.vocab.get(token, self.vocab[self.unk_token])

    def _convert_id_to_token(self, index: int) -> str:
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        # Tokens are single characters or bracketed tags, so plain concatenation
        # reconstructs the sequence.
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        # Single sequence: <s> A </s>; pair of sequences: <s> A </s> B </s>
        if token_ids_1 is None:
            return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        # Lengths must match build_inputs_with_special_tokens:
        #   <s> A </s>        -> all 0s
        #   <s> A </s> B </s> -> 0s for "<s> A </s>", 1s for "B </s>"
        if token_ids_1 is None:
            return [0] * (len(token_ids_0) + 2)
        return [0] * (len(token_ids_0) + 2) + [1] * (len(token_ids_1) + 1)

    def save_pretrained(self, save_directory: str, **kwargs):
        """Override save_pretrained to record the auto_map entry in tokenizer_config.json."""
        vocab_files = super().save_pretrained(save_directory, **kwargs)

        tokenizer_config_path = os.path.join(save_directory, "tokenizer_config.json")

        # Load the config written by the base class, if present.
        if os.path.exists(tokenizer_config_path):
            with open(tokenizer_config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
        else:
            config = {}

        # Register this custom class under auto_map so AutoTokenizer can locate it.
        config.update({
            "auto_map": {
                "AutoTokenizer": [
                    "tokenizer.SingleNucleotideTokenizer",
                    None,
                ]
            },
        })

        with open(tokenizer_config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, ensure_ascii=False, indent=2)

        return vocab_files

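    # After save_pretrained, tokenizer_config.json contains (alongside the fields
    # written by the base class) an entry of the form:
    #   "auto_map": {"AutoTokenizer": ["tokenizer.SingleNucleotideTokenizer", null]}
    # which lets AutoTokenizer.from_pretrained(..., trust_remote_code=True) resolve
    # this class when this module is shipped as tokenizer.py next to the model files.
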
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        os.makedirs(save_directory, exist_ok=True)

        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.txt",
        )

        # One "token id" pair per line, ordered by id.
        with open(vocab_file, "w", encoding="utf-8") as f:
            for token, idx in sorted(self.vocab.items(), key=lambda x: x[1]):
                f.write(f"{token} {idx}\n")

        return (vocab_file,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
        # The vocabulary is hard-coded, so nothing needs to be read from disk;
        # simply construct a fresh tokenizer.
        return cls(**kwargs)
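

# Minimal usage sketch (illustrative only; the sequence and save path are example
# values, and the printed output has not been captured from an actual run).
if __name__ == "__main__":
    tokenizer = SingleNucleotideTokenizer()

    # Tokenize a sequence that mixes annotation tags and raw bases.
    sequence = "<bog><cds>ACGTN<eog>"
    print(tokenizer.tokenize(sequence))

    # Encode with special tokens (<s> ... </s>) and decode back.
    ids = tokenizer.encode(sequence)
    print(ids)
    print(tokenizer.decode(ids))

    # Persist the tokenizer, including the auto_map entry added in save_pretrained.
    tokenizer.save_pretrained("./single_nucleotide_tokenizer")  # hypothetical path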