from typing import Protocol, runtime_checkable

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.processors import TemplateProcessing
from transformers import PreTrainedTokenizerFast

# The original ESM protein-sequence vocabulary, kept commented out for reference:
# SEQUENCE_VOCAB = [
#     "<cls>", "<pad>", "<eos>", "<unk>",
#     "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
#     "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
#     "O", ".", "-", "|",
#     "<mask>",
# ]

# RNA nucleotide vocabulary used by this tokenizer.
SEQUENCE_VOCAB = [
    "<cls>", "<pad>", "<eos>", "<unk>",
    "A", "U", "C", "G", ".", "-", "|",
    "<mask>",
]
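
# Token ids follow list position: "<cls>"=0, "<pad>"=1, "<eos>"=2, "<unk>"=3,
# "A"=4, "U"=5, "C"=6, "G"=7, "."=8, "-"=9, "|"=10, "<mask>"=11.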


@runtime_checkable
class EsmTokenizerBase(Protocol):
    """Structural interface shared by ESM-style sequence tokenizers."""

    def encode(self, *args, **kwargs): ...

    def decode(self, *args, **kwargs): ...

    @property
    def mask_token(self) -> str: ...

    @property
    def mask_token_id(self) -> int: ...

    @property
    def bos_token(self) -> str: ...

    @property
    def bos_token_id(self) -> int: ...

    @property
    def eos_token(self) -> str: ...

    @property
    def eos_token_id(self) -> int: ...

    @property
    def pad_token(self) -> str: ...

    @property
    def pad_token_id(self) -> int: ...

    @property
    def chain_break_token(self) -> str: ...

    @property
    def chain_break_token_id(self) -> int: ...

    @property
    def all_token_ids(self): ...

    @property
    def special_token_ids(self): ...
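
# Note: @runtime_checkable makes isinstance() checks against this Protocol
# possible, but they only verify that the listed attributes exist; signatures
# and return types are not checked at runtime.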


class EsmSequenceTokenizer(PreTrainedTokenizerFast, EsmTokenizerBase):
    """Constructs an ESM sequence tokenizer."""

    model_input_names = ["sequence_tokens", "attention_mask"]

    def __init__(
        self,
        unk_token="<unk>",
        cls_token="<cls>",
        pad_token="<pad>",
        mask_token="<mask>",
        eos_token="<eos>",
        chain_break_token="|",
        **kwargs,
    ):
        all_tokens = SEQUENCE_VOCAB
        token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}

        # A character-level tokenizer is the same as BPE with no token merges.
        bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
        tokenizer = Tokenizer(bpe)
        special_tokens = [
            cls_token,
            pad_token,
            mask_token,
            eos_token,
            chain_break_token,
        ]
        self.cb_token = chain_break_token
        additional_special_tokens = [chain_break_token]

        tokenizer.add_special_tokens(special_tokens)

        # This is where we configure the automatic addition of special tokens when
        # we call tokenizer(text, add_special_tokens=True). Note that you can also
        # configure how two sequences are merged if you want.
        tokenizer.post_processor = TemplateProcessing(  # type: ignore
            single="<cls> $A <eos>",
            special_tokens=[
                ("<cls>", tokenizer.token_to_id("<cls>")),
                ("<eos>", tokenizer.token_to_id("<eos>")),
            ],
        )
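
        # With this template, encoding "ACGU" with add_special_tokens=True
        # produces the token sequence <cls> A C G U <eos>.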
        super().__init__(
            tokenizer_object=tokenizer,
            unk_token=unk_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

    # These are a footgun: we never use the `bos` token anywhere, so we just
    # override it here to alias `cls`.
    @property
    def bos_token(self):
        return self.cls_token

    @property
    def bos_token_id(self):
        return self.cls_token_id

    @property
    def chain_break_token(self):
        return self.cb_token

    @property
    def chain_break_token_id(self):
        return self.convert_tokens_to_ids(self.chain_break_token)

    @property
    def all_token_ids(self):
        return list(range(self.vocab_size))

    @property
    def special_token_ids(self):
        return self.all_special_ids
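

# Minimal usage sketch, not part of the module proper: the exact ids shown
# assume the SEQUENCE_VOCAB ordering above and default `tokenizers`/
# `transformers` behavior.
if __name__ == "__main__":
    tok = EsmSequenceTokenizer()

    # Each nucleotide is one token; the post-processor wraps with <cls>/<eos>.
    ids = tok.encode("ACGU")  # expected: [0, 4, 6, 7, 5, 2]
    print(ids)
    print(tok.decode(ids))  # tokens joined back, special tokens included

    # The Protocol is runtime-checkable, so structural isinstance() works.
    assert isinstance(tok, EsmTokenizerBase)
    print(tok.mask_token, tok.mask_token_id)  # expected: <mask> 11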