# RSGM-CDS / tokenization_aa2codon.py
# Uploaded by wangtao2001 via huggingface_hub (commit 589b1bb, verified)
import json
import os
from typing import override, Optional
from transformers import PreTrainedTokenizer
from Bio.Data import IUPACData
from Bio.SeqUtils import seq3
from itertools import product
class AA2CodonTokenizer(PreTrainedTokenizer):
    """Tokenizer with a shared vocabulary of DNA codons and amino acids.

    Vocabulary layout (ids assigned in this order): the 64 codon triplets,
    the 20 three-letter amino-acid codes from Biopython's IUPAC table,
    the stop symbol ``'*'``, then the special tokens
    ``<s>``, ``<pad>``, ``</s>``, ``<unk>``.
    """

    def __init__(self, **kwargs):
        # Three-letter amino-acid codes ('Ala', 'Cys', ...) — matches the
        # capitalization produced by Bio.SeqUtils.seq3 in _tokenize.
        self._aas = list(IUPACData.protein_letters_1to3.values())
        # All 4**3 = 64 possible DNA codons.
        self._codons = ["".join(p) for p in product(list("ATGC"), repeat=3)]
        special_tokens = {
            'bos_token': '<s>',
            'pad_token': '<pad>',
            'eos_token': '</s>',
            'unk_token': '<unk>'
        }
        self._vocab = self._codons + self._aas + ['*'] + list(special_tokens.values())
        self._token_to_id = {token: idx for idx, token in enumerate(self._vocab)}
        self._id_to_token = {idx: token for idx, token in enumerate(self._vocab)}
        # Register the special tokens with the base class.
        kwargs.update(special_tokens)
        super().__init__(**kwargs)

    @override
    def _tokenize(self, text: str, **kwargs) -> list[str]:
        """Split ``text`` into tokens.

        A string ending in ``'*'`` whose remainder is all-uppercase is treated
        as a one-letter amino-acid sequence and expanded to three-letter codes
        (e.g. ``'MK*'`` -> ``['Met', 'Lys', '*']``); anything else is chunked
        into codon triplets. A trailing partial triplet (or any letter seq3
        maps outside the vocabulary, e.g. 'Xaa') later resolves to ``<unk>``.
        """
        text = text.strip()
        if text.endswith('*') and text[:-1].isupper():
            # Amino-acid mode: expand each one-letter code via Biopython.
            return [str(seq3(token)) for token in list(text)[:-1]] + ['*']
        else:
            # Codon mode: fixed-width chunks of three nucleotides.
            return [text[i:i + 3] for i in range(0, len(text), 3)]

    @override
    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to the <unk> id."""
        return self._token_to_id.get(token, self.unk_token_id)

    @override
    def _convert_id_to_token(self, index: int) -> str:
        """Map an id back to its token, falling back to the <unk> token."""
        return self._id_to_token.get(index, self.unk_token)

    @override
    def get_vocab(self) -> dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return self._token_to_id.copy()

    @property
    @override
    def vocab_size(self) -> int:
        """Total number of tokens, special tokens included (64 + 20 + 1 + 4)."""
        return len(self._vocab)

    @override
    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Join tokens with '/' so three-letter codes and codons stay readable."""
        return "/".join(tokens)

    @override
    def build_inputs_with_special_tokens(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
    ) -> list[int]:
        """Append the EOS id to a single sequence (pairs are not supported)."""
        return token_ids_0 + [self.eos_token_id]

    @override
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
        """Write the vocabulary to ``<save_directory>/[<prefix>-]vocab.json``.

        Returns a 1-tuple with the path of the written file, per the
        PreTrainedTokenizer.save_vocabulary contract.
        """
        filename = f"{filename_prefix}-" if filename_prefix else ""
        os.makedirs(save_directory, exist_ok=True)
        # BUG FIX: the original built the path from a stray "(unknown)" literal
        # and never used the computed prefix; honor filename_prefix here.
        vocab_file = os.path.join(save_directory, f"{filename}vocab.json")
        with open(vocab_file, 'w') as f:
            json.dump(self._token_to_id, f)
        return (vocab_file, )