import json
import os
from itertools import product
from typing import Optional, override

from Bio.Data import IUPACData
from Bio.SeqUtils import seq3
from transformers import PreTrainedTokenizer


class AA2CodonTokenizer(PreTrainedTokenizer):
    """Tokenizer that accepts both protein and coding DNA sequences.

    Protein sequences (one-letter codes ending in '*') are tokenized into
    three-letter amino acid codes; any other input is split into codons.
    """

    def __init__(self, **kwargs):
        # Three-letter amino acid codes, e.g. 'Ala', 'Cys', ...
        self._aas = list(IUPACData.protein_letters_1to3.values())
        # All 64 codons over the DNA alphabet.
        self._codons = ["".join(p) for p in product("ATGC", repeat=3)]
        special_tokens = {
            'bos_token': '<bos>',
            'pad_token': '<pad>',
            'eos_token': '<eos>',
            'unk_token': '<unk>',
        }
        self._vocab = self._codons + self._aas + ['*'] + list(special_tokens.values())
        self._token_to_id = {token: idx for idx, token in enumerate(self._vocab)}
        self._id_to_token = {idx: token for idx, token in enumerate(self._vocab)}
        kwargs.update(special_tokens)
        # The vocab maps must exist before the base class initializes,
        # because PreTrainedTokenizer.__init__ looks up the special tokens.
        super().__init__(**kwargs)

    @override
    def _tokenize(self, text: str, **kwargs) -> list[str]:
        text = text.strip()
        if text.endswith('*') and text[:-1].isupper():
            # Protein sequence: map one-letter codes to three-letter codes
            # and keep the trailing stop symbol as its own token.
            return [seq3(aa) for aa in text[:-1]] + ['*']
        # Otherwise treat the input as coding DNA and split it into
        # non-overlapping codons (3-mers).
        return [text[i:i + 3] for i in range(0, len(text), 3)]

    @override
    def _convert_token_to_id(self, token: str) -> int:
        return self._token_to_id.get(token, self.unk_token_id)

    @override
    def _convert_id_to_token(self, index: int) -> str:
        return self._id_to_token.get(index, self.unk_token)

    @override
    def get_vocab(self) -> dict[str, int]:
        return self._token_to_id.copy()

    @property
    @override
    def vocab_size(self) -> int:
        return len(self._vocab)

    @override
    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        # Join with '/' so codon and amino acid boundaries stay visible
        # after decoding.
        return "/".join(tokens)

    @override
    def build_inputs_with_special_tokens(
        self,
        token_ids_0: list[int],
        token_ids_1: Optional[list[int]] = None,
    ) -> list[int]:
        # Single-sequence model: append EOS; sequence pairs are not supported.
        return token_ids_0 + [self.eos_token_id]

    @override
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
        prefix = f"{filename_prefix}-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, f"{prefix}vocab.json")
        with open(vocab_file, 'w') as f:
            json.dump(self._token_to_id, f)
        return (vocab_file,)
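

# Minimal usage sketch. The sequences below are illustrative examples,
# not from the original source; they show the two tokenization paths
# (protein vs. codon) and the EOS handling in encode().
if __name__ == "__main__":
    tokenizer = AA2CodonTokenizer()

    # Protein input: one-letter codes ending in '*' become three-letter codes.
    print(tokenizer.tokenize("MKV*"))       # ['Met', 'Lys', 'Val', '*']

    # DNA input: anything else is chunked into codons.
    print(tokenizer.tokenize("ATGAAAGTG"))  # ['ATG', 'AAA', 'GTG']

    # encode() appends EOS via build_inputs_with_special_tokens.
    ids = tokenizer.encode("ATGAAAGTG")
    print(ids[-1] == tokenizer.eos_token_id)  # True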