|
|
import json |
|
|
import os |
|
|
from typing import override, Optional |
|
|
from transformers import PreTrainedTokenizer |
|
|
from Bio.Data import IUPACData |
|
|
from Bio.SeqUtils import seq3 |
|
|
from itertools import product |
|
|
|
|
|
|
|
|
class AA2CodonTokenizer(PreTrainedTokenizer):
    """Tokenizer over a shared amino-acid / codon vocabulary.

    The vocabulary is, in ID order:
      * all 64 DNA codons ("AAA" ... "CCC", 3-letter strings over A/T/G/C),
      * the three-letter amino-acid codes from IUPAC (e.g. "Ala", "Cys"),
      * the stop symbol "*",
      * the special tokens "<s>", "<pad>", "</s>", "<unk>".

    Input text is auto-detected in ``_tokenize``: an upper-case one-letter
    amino-acid string terminated by "*" is expanded to three-letter codes;
    anything else is split into consecutive 3-character codons.
    """

    def __init__(self, **kwargs):
        # Three-letter amino-acid codes, e.g. "Ala", "Cys", ...
        self._aas = list(IUPACData.protein_letters_1to3.values())
        # All 64 possible codons over the DNA alphabet.
        self._codons = ["".join(p) for p in product(list("ATGC"), repeat=3)]

        special_tokens = {
            'bos_token': '<s>',
            'pad_token': '<pad>',
            'eos_token': '</s>',
            'unk_token': '<unk>'
        }

        # Order matters: token IDs are assigned by position in this list.
        self._vocab = self._codons + self._aas + ['*'] + list(special_tokens.values())

        self._token_to_id = {token: idx for idx, token in enumerate(self._vocab)}
        self._id_to_token = {idx: token for idx, token in enumerate(self._vocab)}

        # The parent constructor queries get_vocab()/vocab_size, so the
        # lookup tables above must be built before calling it.
        kwargs.update(special_tokens)
        super().__init__(**kwargs)

    @override
    def _tokenize(self, text: str, **kwargs) -> list[str]:
        """Split *text* into vocabulary tokens.

        A stripped, all-upper-case string ending in "*" is treated as a
        one-letter amino-acid sequence and expanded to three-letter codes
        (stop "*" kept as its own token); any other input is chunked into
        3-character codons.
        """
        text = text.strip()
        if text.endswith('*') and text[:-1].isupper():
            # seq3 maps one-letter codes to title-case three-letter codes.
            return [str(seq3(token)) for token in list(text)[:-1]] + ['*']
        else:
            return [text[i:i+3] for i in range(0, len(text), 3)]

    @override
    def _convert_token_to_id(self, token: str) -> int:
        """Return the ID for *token*, or the <unk> ID when unknown."""
        return self._token_to_id.get(token, self.unk_token_id)

    @override
    def _convert_id_to_token(self, index):
        """Return the token for *index*, or the <unk> token when out of range."""
        return self._id_to_token.get(index, self.unk_token)

    @override
    def get_vocab(self) -> dict[str, int]:
        """Return a copy of the token -> ID mapping."""
        return self._token_to_id.copy()

    @property
    @override
    def vocab_size(self) -> int:
        """Number of tokens in the vocabulary, special tokens included."""
        return len(self._vocab)

    @override
    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Join tokens with "/" so 3-letter units stay visually separated."""
        return "/".join(tokens)

    @override
    def build_inputs_with_special_tokens(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
    ) -> list[int]:
        """Append the EOS token to a single sequence (pairs are not supported;
        *token_ids_1* is accepted for API compatibility but ignored)."""
        return token_ids_0 + [self.eos_token_id]

    @override
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
        """Write the token -> ID mapping to ``[prefix-]vocab.json``.

        Returns a 1-tuple containing the path of the written file, per the
        ``PreTrainedTokenizer.save_vocabulary`` contract.
        """
        prefix = f"{filename_prefix}-" if filename_prefix else ""
        # BUG FIX: the computed prefix was previously unused and a stray
        # literal "(unknown)" was baked into the filename.
        vocab_file = os.path.join(save_directory, f"{prefix}vocab.json")
        with open(vocab_file, 'w') as f:
            json.dump(self._token_to_id, f)
        return (vocab_file, )
|
|
|