from typing import Protocol, runtime_checkable

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.processors import TemplateProcessing
from transformers import PreTrainedTokenizerFast

# Original ESM3 protein vocabulary, kept commented out for reference:
# SEQUENCE_VOCAB = [
#     "<cls>", "<pad>", "<eos>", "<unk>",
#     "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
#     "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
#     "O", ".", "-", "|",
#     "<mask>",
# ]

SEQUENCE_VOCAB = [
    "<cls>", "<pad>", "<eos>", "<unk>",
    "A", "U", "C", "G", ".", "-", "|",
    "<mask>",
]


@runtime_checkable
class EsmTokenizerBase(Protocol):
    def encode(self, *args, **kwargs): ...

    def decode(self, *args, **kwargs): ...

    @property
    def mask_token(self) -> str: ...

    @property
    def mask_token_id(self) -> int: ...

    @property
    def bos_token(self) -> str: ...

    @property
    def bos_token_id(self) -> int: ...

    @property
    def eos_token(self) -> str: ...

    @property
    def eos_token_id(self) -> int: ...

    @property
    def pad_token(self) -> str: ...

    @property
    def pad_token_id(self) -> int: ...

    @property
    def chain_break_token(self) -> str: ...

    @property
    def chain_break_token_id(self) -> int: ...

    @property
    def all_token_ids(self): ...

    @property
    def special_token_ids(self): ...


class EsmSequenceTokenizer(PreTrainedTokenizerFast, EsmTokenizerBase):
    """
    Constructs an ESM-style sequence tokenizer (character-level, over the
    SEQUENCE_VOCAB defined above).
    """

    model_input_names = ["sequence_tokens", "attention_mask"]

    def __init__(
        self,
        unk_token="<unk>",
        cls_token="<cls>",
        pad_token="<pad>",
        mask_token="<mask>",
        eos_token="<eos>",
        chain_break_token="|",
        **kwargs,
    ):
        all_tokens = SEQUENCE_VOCAB
        token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}

        # A character-level tokenizer is the same as BPE with no token merges.
        bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
        tokenizer = Tokenizer(bpe)
        special_tokens = [
            cls_token,
            pad_token,
            mask_token,
            eos_token,
            chain_break_token,
        ]
        self.cb_token = chain_break_token
        additional_special_tokens = [chain_break_token]

        tokenizer.add_special_tokens(special_tokens)

        # This is where we configure the automatic addition of special tokens when we
        # call tokenizer(text, add_special_tokens=True). Note that you can also
        # configure how two sequences are merged if you want.
        tokenizer.post_processor = TemplateProcessing(  # type: ignore
            single="<cls> $A <eos>",
            special_tokens=[
                ("<cls>", tokenizer.token_to_id("<cls>")),
                ("<eos>", tokenizer.token_to_id("<eos>")),
            ],
        )
        super().__init__(
            tokenizer_object=tokenizer,
            unk_token=unk_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

    # These are a footgun: we never use the `bos` token anywhere, so we just
    # override it here to alias `cls`.
    @property
    def bos_token(self):
        return self.cls_token

    @property
    def bos_token_id(self):
        return self.cls_token_id

    @property
    def chain_break_token(self):
        return self.cb_token

    @property
    def chain_break_token_id(self):
        return self.convert_tokens_to_ids(self.chain_break_token)

    @property
    def all_token_ids(self):
        return list(range(self.vocab_size))

    @property
    def special_token_ids(self):
        return self.all_special_ids
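
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the original module).
# The RNA string "AUGC" is an arbitrary example. With the RNA SEQUENCE_VOCAB
# above, the ids are: <cls>=0, <pad>=1, <eos>=2, <unk>=3, A=4, U=5, C=6, G=7,
# .=8, -=9, |=10, <mask>=11.
if __name__ == "__main__":
    tokenizer = EsmSequenceTokenizer()

    # The post-processor wraps every encoded sequence as <cls> ... <eos>.
    enc = tokenizer("AUGC", add_special_tokens=True)
    print(enc["input_ids"])       # [0, 4, 5, 7, 6, 2]
    print(enc["attention_mask"])  # [1, 1, 1, 1, 1, 1]

    # "|" is a single vocab entry used to mark a break between two chains.
    print(tokenizer.chain_break_token_id)  # 10

    # bos_* is deliberately aliased to cls_* (see the properties above).
    assert tokenizer.bos_token_id == tokenizer.cls_token_id == 0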