from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.processors import TemplateProcessing
from transformers import PreTrainedTokenizerFast

from src.data.esm.tokenization.tokenizer_base import EsmTokenizerBase
from src.data.esm.utils.constants import esm3 as C


class EsmSequenceTokenizer(PreTrainedTokenizerFast, EsmTokenizerBase):
    """
    Constructs an ESM sequence tokenizer: a character-level tokenizer over the
    fixed ESM3 sequence vocabulary (`C.SEQUENCE_VOCAB`) that wraps encoded
    sequences in `<cls>` ... `<eos>`.
    """

    model_input_names = ["sequence_tokens", "attention_mask"]

    def __init__(
        self,
        unk_token="<unk>",
        cls_token="<cls>",
        pad_token="<pad>",
        mask_token="<mask>",
        eos_token="<eos>",
        chain_break_token="|",
        **kwargs,
    ):
        all_tokens = C.SEQUENCE_VOCAB
        token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}

        # A BPE model with an empty merge list degenerates to character-level
        # tokenization: every residue character is its own token.
        bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
        tokenizer = Tokenizer(bpe)
        special_tokens = [
            cls_token,
            pad_token,
            mask_token,
            eos_token,
            chain_break_token,
        ]
        self.cb_token = chain_break_token
        additional_special_tokens = [chain_break_token]

        tokenizer.add_special_tokens(special_tokens)

        # Configure the automatic addition of special tokens when calling
        # tokenizer(text, add_special_tokens=True): a single sequence is
        # rendered as `<cls> $A <eos>`.
        tokenizer.post_processor = TemplateProcessing(
            single="<cls> $A <eos>",
            special_tokens=[
                ("<cls>", tokenizer.token_to_id("<cls>")),
                ("<eos>", tokenizer.token_to_id("<eos>")),
            ],
        )
        super().__init__(
            tokenizer_object=tokenizer,
            unk_token=unk_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

    # There is no separate BOS token in the vocabulary; alias `bos` to `cls`
    # so callers asking for a beginning-of-sequence token get `<cls>`.
    @property
    def bos_token(self):
        return self.cls_token

    @property
    def bos_token_id(self):
        return self.cls_token_id

    @property
    def cls_token(self):
        return self._get_token("cls_token")

    @property
    def cls_token_id(self):
        return self._get_token_id(self.cls_token)

    @property
    def eos_token(self):
        return self._get_token("eos_token")

    @property
    def eos_token_id(self):
        return self._get_token_id(self.eos_token)

    @property
    def mask_token(self):
        return self._get_token("mask_token")

    @property
    def mask_token_id(self):
        return self._get_token_id(self.mask_token)

    @property
    def pad_token(self):
        return self._get_token("pad_token")

    @property
    def pad_token_id(self):
        return self._get_token_id(self.pad_token)

    @property
    def chain_break_token(self):
        return self.cb_token

    @property
    def chain_break_token_id(self):
        return self._get_token_id(self.chain_break_token)

    @property
    def all_token_ids(self):
        return list(range(self.vocab_size))

    @property
    def special_token_ids(self):
        return self.all_special_ids

    def _get_token_id(self, token) -> int:
        token_id = self.convert_tokens_to_ids(token)
        assert isinstance(token_id, int)
        return token_id

    def _get_token(self, token_name: str) -> str:
        # Call the base class's `__getattr__` explicitly: plain attribute
        # access (e.g. `self.cls_token`) would re-enter the overriding
        # properties above. The assert narrows the loosely typed return
        # value to `str`.
        token_str = self.__getattr__(token_name)
        assert isinstance(token_str, str)
        return token_str
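
# A minimal usage sketch, not part of the class above: it assumes
# `C.SEQUENCE_VOCAB` holds the amino-acid alphabet plus the special tokens
# configured in `__init__`, and that this repo is on PYTHONPATH. The sequence
# "MKTAYIAK" is an arbitrary example.
if __name__ == "__main__":
    tokenizer = EsmSequenceTokenizer()

    # With no merges, BPE emits one token per residue character, and the
    # post-processor wraps the ids as <cls> ... <eos>.
    ids = tokenizer.encode("MKTAYIAK")
    assert ids[0] == tokenizer.cls_token_id
    assert ids[-1] == tokenizer.eos_token_id

    # Round-trip back to token strings, e.g. ['<cls>', 'M', 'K', ..., '<eos>'].
    print(tokenizer.convert_ids_to_tokens(ids))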