import json
import os
from typing import override, Optional
from transformers import PreTrainedTokenizer
from Bio.Data import IUPACData
from Bio.SeqUtils import seq3
from itertools import product
class AA2CodonTokenizer(PreTrainedTokenizer):
    """Tokenizer over a joint vocabulary of DNA codons and amino-acid codes.

    Protein sequences (uppercase one-letter amino acids terminated by ``'*'``)
    are expanded into three-letter amino-acid tokens (``'Ala'``, ``'Cys'``, …);
    any other input is split into consecutive nucleotide triplets (codons).
    Decoded token sequences are joined with ``'/'``.
    """

    def __init__(self, **kwargs):
        # Three-letter amino-acid codes taken from the IUPAC table values.
        self._aas = list(IUPACData.protein_letters_1to3.values())
        # All 64 codons: every 3-letter combination over the alphabet ATGC.
        self._codons = ["".join(p) for p in product(list("ATGC"), repeat=3)]
        special_tokens = {
            'bos_token': '<s>',
            'pad_token': '<pad>',
            'eos_token': '</s>',
            'unk_token': '<unk>'
        }
        # Vocabulary order fixes the token ids: codons first, then amino
        # acids, then the stop symbol '*', then the special tokens.
        self._vocab = self._codons + self._aas + ['*'] + list(special_tokens.values())
        self._token_to_id = {token: idx for idx, token in enumerate(self._vocab)}
        self._id_to_token = {idx: token for idx, token in enumerate(self._vocab)}
        # Register the special tokens with the base class via kwargs.
        kwargs.update(special_tokens)
        super().__init__(**kwargs)

    @override
    def _tokenize(self, text: str, **kwargs) -> list[str]:
        """Split *text* into amino-acid tokens (protein) or codons (DNA)."""
        text = text.strip()
        if text.endswith('*') and text[:-1].isupper():
            # Protein: expand each one-letter residue to its three-letter
            # code and keep the trailing stop symbol as its own token.
            return [str(seq3(token)) for token in list(text)[:-1]] + ['*']
        else:
            # DNA: chunk the string into consecutive triplets; a trailing
            # partial triplet becomes a (likely unknown) short token.
            return [text[i:i+3] for i in range(0, len(text), 3)]

    @override
    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to the unk id."""
        return self._token_to_id.get(token, self.unk_token_id)

    @override
    def _convert_id_to_token(self, index: int) -> str:
        """Map an id back to its token, falling back to the unk token."""
        return self._id_to_token.get(index, self.unk_token)

    @override
    def get_vocab(self) -> dict[str, int]:
        """Return a copy of the token-to-id mapping."""
        return self._token_to_id.copy()

    @property
    @override
    def vocab_size(self) -> int:
        """Total number of tokens, special tokens included."""
        return len(self._vocab)

    @override
    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Join tokens with '/' so codon/AA boundaries stay visible."""
        return "/".join(tokens)

    @override
    def build_inputs_with_special_tokens(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
    ) -> list[int]:
        """Append EOS to a single sequence (pair inputs are not supported)."""
        return token_ids_0 + [self.eos_token_id]

    @override
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
        """Write the token-to-id map to ``[<prefix>-]vocab.json``.

        Returns a 1-tuple with the path of the written file, per the
        ``PreTrainedTokenizer.save_vocabulary`` contract.
        """
        filename = f"{filename_prefix}-" if filename_prefix else ""
        # BUG FIX: the computed prefix was previously dropped — the path was
        # the literal "(unknown)vocab.json" regardless of filename_prefix.
        vocab_file = os.path.join(save_directory, f"{filename}vocab.json")
        with open(vocab_file, 'w') as f:
            json.dump(self._token_to_id, f)
        return (vocab_file, )