import itertools
import json
import os
import re
from typing import List, Optional, Tuple

from transformers import PreTrainedTokenizer


class DNAKmerTokenizer(PreTrainedTokenizer):
    """Tokenizes DNA sequences into fixed-length k-mers plus annotation special tokens."""

    def __init__(self, k: int, **kwargs):
        self.k = k
        # Special tokens; their order defines their ids, so it must not be changed.
        self.special_tokens = [
"<oov>", |
|
|
"<s>", |
|
|
"</s>", |
|
|
"<pad>", |
|
|
"<mask>", |
|
|
"<bog>", |
|
|
"<eog>", |
|
|
"<bok>", |
|
|
"<eok>", |
|
|
"<+>", |
|
|
"<->", |
|
|
"<cds>", |
|
|
"<pseudo>", |
|
|
"<tRNA>", |
|
|
"<rRNA>", |
|
|
"<ncRNA>", |
|
|
"<miscRNA>", |
|
|
"<mam>", |
|
|
"<vrt>", |
|
|
"<inv>", |
|
|
"<pln>", |
|
|
"<fng>", |
|
|
"<prt>", |
|
|
"<arc>", |
|
|
"<bct>", |
|
|
"<mit>", |
|
|
"<plt>", |
|
|
"<plm>", |
|
|
"<vir>", |
|
|
"<sp0>", |
|
|
"<sp1>", |
|
|
"<sp2>", |
|
|
"<crispr_spacer>", |
|
|
"<crispr_repeat>", |
|
|
"<cas1>", |
|
|
"<cas2>", |
|
|
"<tracrrna>", |
|
|
"<cas5>", |
|
|
"<cas3>", |
|
|
"<cas4>", |
|
|
"<cas9>", |
|
|
"<cas7>", |
|
|
"<cas8c>", |
|
|
"<cas6>", |
|
|
"<csm3gr7>", |
|
|
"<csn2>", |
|
|
"<cas10>", |
|
|
"<cas7b>", |
|
|
"<cas6e>", |
|
|
"<cas8e>", |
|
|
"<cas12f>", |
|
|
"<cse2gr11>", |
|
|
"<csx1>", |
|
|
"<csm2gr11>", |
|
|
"<csm4>", |
|
|
"<wyl>", |
|
|
"<cas12a>", |
|
|
"<csm6>", |
|
|
"<deddh>", |
|
|
"<csm5>", |
|
|
"<casr>", |
|
|
"<cas8b1>", |
|
|
"<csx19>", |
|
|
"<csx20>", |
|
|
"<csm5gr7>", |
|
|
"<cas6f>", |
|
|
"<cas8b2>", |
|
|
"<cas5f>", |
|
|
"<rt>", |
|
|
"<cas7f>", |
|
|
"<cas3-cas2>", |
|
|
"<primpol>", |
|
|
"<cas8f>", |
|
|
"<cysh>", |
|
|
"<cas3hd>", |
|
|
"<tnib>", |
|
|
"<csx10gr5>", |
|
|
"<cas8a1>", |
|
|
"<csa3>", |
|
|
"<recd>", |
|
|
"<cmr1gr7>", |
|
|
"<cmr4>", |
|
|
"<cmr6gr7>", |
|
|
"<cmr3gr5>", |
|
|
"<cmr5gr11>", |
|
|
"<cas8b6>", |
|
|
"<csb2>", |
|
|
"<cora>", |
|
|
"<csm4gr5>", |
|
|
"<abieii>", |
|
|
"<can2>", |
|
|
"<cas13d>", |
|
|
"<csb1gr7>", |
|
|
"<iscb-hnh>", |
|
|
"<pd>", |
|
|
"<tnpa>", |
|
|
"<cse2>", |
|
|
"<csb3>", |
|
|
"<csm3>", |
|
|
"<cas13b>", |
|
|
"<unk>", |
|
|
"<csx16>", |
|
|
"<tpr>", |
|
|
"<dhh>", |
|
|
"<2og>", |
|
|
"<cas12m>", |
|
|
"<mem>", |
|
|
"<csf4>", |
|
|
"<hearo>", |
|
|
"<tn7>", |
|
|
"<tniq>", |
|
|
"<csf2>", |
|
|
"<csf3>", |
|
|
"<csf1>", |
|
|
"<cas8b4>", |
|
|
"<tnsd>", |
|
|
"<heat>", |
|
|
"<csx17>", |
|
|
"<cas8u1>", |
|
|
"<csx3>", |
|
|
"<htpx>", |
|
|
"<cas12b>", |
|
|
"<csm2>", |
|
|
"<cas10d>", |
|
|
"<csc2>", |
|
|
"<cmr3>", |
|
|
"<cmr5>", |
|
|
"<csc1gr5>", |
|
|
"<gramp>", |
|
|
"<cmr6>", |
|
|
"<cas8b12>", |
|
|
"<cas11b>", |
|
|
"<cas12c>", |
|
|
"<cas8a4>", |
|
|
"<tnsb>", |
|
|
"<nyn>", |
|
|
"<iscb-nterm>", |
|
|
"<cas8b3>", |
|
|
"<cas8a2>", |
|
|
"<cas5u>", |
|
|
"<csx27>", |
|
|
"<csx21>", |
|
|
"<csx23>", |
|
|
"<tm>", |
|
|
"<cas3d>", |
|
|
"<cas12lambda>", |
|
|
"<tnsc>", |
|
|
"<cas8b5>", |
|
|
"<stand>", |
|
|
"<st>", |
|
|
"<iscb-ruvciii-cterm>", |
|
|
"<cas11>", |
|
|
"<cas11d2>", |
|
|
"<cas12j>", |
|
|
"<cas12d>", |
|
|
"<cas8b8>", |
|
|
"<cmr1>", |
|
|
"<cas12k>", |
|
|
"<cas12g>", |
|
|
"<cas13f>", |
|
|
"<cas8b10>", |
|
|
"<cas13i>", |
|
|
"<toprim>", |
|
|
"<cas12e>", |
|
|
] |
|
|
        # All 4**k k-mers over the DNA alphabet; their ids follow the special tokens,
        # so the total vocabulary size is len(special_tokens) + 4**k.
        self.kmers = [
            "".join(kmer) for kmer in itertools.product("ATCG", repeat=self.k)
        ]
        self.vocab = {
            token: i for i, token in enumerate(self.special_tokens + self.kmers)
        }
        self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
        # Matches any literal special token at the current position.
        self.special_token_pattern = re.compile(
            "|".join(re.escape(token) for token in self.special_tokens)
        )
        # Matches a full k-mer first, otherwise a shorter run of uppercase letters
        # (e.g. a trailing fragment at the end of a sequence).
        self.dna_pattern = re.compile(f"[A-Z]{{{self.k}}}|[A-Z]+")
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        self.bos_token_id = self._convert_token_to_id(self.bos_token)
        self.eos_token_id = self._convert_token_to_id(self.eos_token)
        super().__init__(**kwargs)
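
    # Illustrative note (an assumption, not from the original source): with k=4 there
    # are 4**4 = 256 k-mers, so ids 0..163 cover the 164 special tokens above and
    # ids 164..419 cover "AAAA" through "GGGG" in itertools.product order.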

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return dict(self.vocab)

    def _tokenize(self, text, **kwargs) -> List[str]:
        """Split text into special tokens, k-mers, and single fallback characters."""
        tokens = []
        pos = 0
        while pos < len(text):
            special_match = self.special_token_pattern.match(text, pos)
            if special_match:
                tokens.append(special_match.group())
                pos = special_match.end()
            else:
                dna_match = self.dna_pattern.match(text, pos)
                if dna_match:
                    dna_seq = dna_match.group()
                    tokens.append(dna_seq)
                    pos = dna_match.end()
                else:
                    # Anything else becomes a single-character token, which maps
                    # to <oov> unless it happens to be in the vocabulary.
                    tokens.append(text[pos])
                    pos += 1
        return tokens
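
    # Illustrative example (an assumption, not from the original source): with k=4,
    # _tokenize("<s>ATCGATCGAT</s>") yields ["<s>", "ATCG", "ATCG", "AT", "</s>"];
    # the trailing "AT" is shorter than k and is later mapped to the <oov> id.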

    def _convert_token_to_id(self, token: str) -> int:
        # Unknown tokens (including partial k-mers) fall back to the <oov> id.
        return self.vocab.get(token, self.vocab["<oov>"])

    def _convert_id_to_token(self, index: int) -> str:
        return self.ids_to_tokens.get(index, "<oov>")

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return "".join(tokens)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        # Single sequence: <s> A </s>; sequence pair: <s> A </s> B </s>.
        if token_ids_1 is None:
            return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        return (
            [self.bos_token_id]
            + token_ids_0
            + [self.eos_token_id]
            + token_ids_1
            + [self.eos_token_id]
        )

    def get_special_tokens_mask(
        self, token_ids_0, token_ids_1=None, already_has_special_tokens=False
    ):
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0, token_ids_1, already_has_special_tokens=True
            )
        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    def prepare_for_model(self, *args, **kwargs):
        encoding = super().prepare_for_model(*args, **kwargs)
        # This tokenizer has no segment semantics, so drop token_type_ids.
        if "token_type_ids" in encoding:
            del encoding["token_type_ids"]
        return encoding

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.txt",
        )
        # Write one token per line, ordered by token id.
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, _ in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                writer.write(token + "\n")
        return (vocab_file,)

    def save_pretrained(self, save_directory: str, **kwargs):
        vocab_files = super().save_pretrained(save_directory, **kwargs)
        tokenizer_config_path = os.path.join(save_directory, "tokenizer_config.json")

        if os.path.exists(tokenizer_config_path):
            with open(tokenizer_config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
        else:
            config = {}

        # Register this module for AutoTokenizer loading and persist k so the
        # tokenizer can be reconstructed from the saved config.
        config.update(
            {
                "auto_map": {
                    "AutoTokenizer": ["tokenizer.DNAKmerTokenizer", None],
                },
                "k": self.k,
            }
        )

        with open(tokenizer_config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, ensure_ascii=False, indent=2)

        return vocab_files
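

# Minimal usage sketch (an illustrative assumption, not part of the original module):
# build a 4-mer tokenizer and inspect how an annotated sequence is split and mapped to ids.
if __name__ == "__main__":
    tokenizer = DNAKmerTokenizer(k=4)
    text = "<s><bct><+><cds>ATCGATCGAT</s>"
    tokens = tokenizer.tokenize(text)
    ids = tokenizer.convert_tokens_to_ids(tokens)
    print(tokens)  # e.g. ['<s>', '<bct>', '<+>', '<cds>', 'ATCG', 'ATCG', 'AT', '</s>']
    print(ids)
    print(tokenizer.convert_tokens_to_string(tokens))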