import json
import os
from typing import override, Optional
from transformers import PreTrainedTokenizer
from Bio.Data import IUPACData
from Bio.SeqUtils import seq3
from itertools import product
| |
|
| |
|
class AA2CodonTokenizer(PreTrainedTokenizer):
    """Tokenizer that handles both protein and nucleotide sequences.

    A '*'-terminated, all-uppercase input is treated as a protein sequence:
    each 1-letter amino acid is mapped to its 3-letter name (e.g. 'A' -> 'Ala')
    plus a final '*' stop token. Any other input is treated as a nucleotide
    sequence and split into 3-character codon tokens.

    Vocabulary layout (ids assigned in order): the 64 ATGC codons, the
    3-letter amino-acid names, '*', then the four special tokens.
    """

    def __init__(self, **kwargs):
        # 3-letter amino-acid names ('Ala', 'Cys', ...) from IUPAC data.
        self._aas = list(IUPACData.protein_letters_1to3.values())
        # All 64 codons over the DNA alphabet.
        self._codons = ["".join(p) for p in product("ATGC", repeat=3)]

        special_tokens = {
            'bos_token': '<s>',
            'pad_token': '<pad>',
            'eos_token': '</s>',
            'unk_token': '<unk>'
        }

        # Order matters: ids are positions in this list.
        self._vocab = self._codons + self._aas + ['*'] + list(special_tokens.values())

        self._token_to_id = {token: idx for idx, token in enumerate(self._vocab)}
        self._id_to_token = {idx: token for idx, token in enumerate(self._vocab)}

        # The vocab maps must exist before super().__init__, which queries
        # get_vocab()/vocab_size during base-class setup.
        kwargs.update(special_tokens)
        super().__init__(**kwargs)

    @override
    def _tokenize(self, text: str, **kwargs) -> list[str]:
        """Split `text` into tokens.

        Protein path: '*'-terminated and otherwise all-uppercase -> 3-letter
        amino-acid tokens plus the trailing '*'. Nucleotide path: fixed-width
        3-character codon chunks (a trailing partial chunk is kept as-is and
        will map to the unk token).
        """
        text = text.strip()
        if text.endswith('*') and text[:-1].isupper():
            return [str(seq3(token)) for token in list(text)[:-1]] + ['*']
        else:
            return [text[i:i+3] for i in range(0, len(text), 3)]

    @override
    def _convert_token_to_id(self, token: str) -> int:
        # Unknown tokens fall back to the unk id rather than raising.
        return self._token_to_id.get(token, self.unk_token_id)

    @override
    def _convert_id_to_token(self, index: int) -> str:
        # Out-of-range ids fall back to the unk token rather than raising.
        return self._id_to_token.get(index, self.unk_token)

    @override
    def get_vocab(self) -> dict[str, int]:
        # Copy so callers can't mutate the tokenizer's internal map.
        return self._token_to_id.copy()

    @property
    @override
    def vocab_size(self) -> int:
        return len(self._vocab)

    @override
    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        # '/'-joined so 3-char codon and 3-letter AA tokens stay delimited.
        return "/".join(tokens)

    @override
    def build_inputs_with_special_tokens(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
    ) -> list[int]:
        # Single-sequence model: append EOS only; token_ids_1 is ignored.
        return token_ids_0 + [self.eos_token_id]

    @override
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
        """Write the token->id map to `<save_directory>/[<prefix>-]vocab.json`.

        Fix: the computed prefix was previously discarded and a hard-coded
        filename was used, so `filename_prefix` had no effect.
        """
        prefix = f"{filename_prefix}-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, f"{prefix}vocab.json")
        with open(vocab_file, 'w', encoding='utf-8') as f:
            json.dump(self._token_to_id, f)
        return (vocab_file,)
| |
|