from typing import Dict, List, Optional, Tuple, Union

import torch

from transformers import AutoTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_base import BatchEncoding
from transformers.utils import logging


logger = logging.get_logger(__name__)


class EsmTokenizer(PreTrainedTokenizer):
    """
    Tokenizer for ESM models. Wraps a pretrained ESM tokenizer so it is compatible with
    HuggingFace interfaces, and falls back to a minimal amino-acid vocabulary if the
    pretrained tokenizer cannot be loaded. Handles protein (amino-acid) sequences.
    """

    vocab_files_names = {}
    model_input_names = ["input_ids", "attention_mask"]

    # The 20 standard amino acids (one-letter codes, alphabetical order).
    AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"

    def __init__(
        self,
        esm_model_name: str = "facebook/esm2_t33_650M_UR50D",
        bos_token="<cls>",
        eos_token="<eos>",
        pad_token="<pad>",
        unk_token="<unk>",
        mask_token="<mask>",
        **kwargs,
    ):
        """
        Initialize the ESM tokenizer.

        Args:
            esm_model_name: Name of the ESM model to load the tokenizer from. If loading
                fails, a minimal built-in amino-acid vocabulary is used instead.
            bos_token: Beginning-of-sequence token (the CLS token in ESM).
            eos_token: End-of-sequence token.
            pad_token: Padding token.
            unk_token: Unknown token.
            mask_token: Mask token for masked language modeling.
        """
        # Load the pretrained ESM tokenizer; fall back to a manual vocabulary if that
        # fails (e.g. no network access or an unknown model name).
        try:
            self.esm_tokenizer = AutoTokenizer.from_pretrained(esm_model_name, trust_remote_code=True)
        except Exception as exc:
            logger.warning(
                "Could not load the ESM tokenizer from %s (%s); falling back to a built-in amino-acid vocabulary.",
                esm_model_name,
                exc,
            )
            self.esm_tokenizer = None
            self._create_manual_tokenizer()

        # Keep local copies of the special-token strings; the fallback tokenization and
        # decoding methods below filter on these directly.
        self._pad_token = pad_token
        self._eos_token = eos_token
        self._bos_token = bos_token
        self._unk_token = unk_token
        self._mask_token = mask_token

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            unk_token=unk_token,
            mask_token=mask_token,
            **kwargs,
        )

        # Mirror the special-token ids of the underlying ESM tokenizer, or use the
        # ESM-2 defaults when running on the manual fallback vocabulary.
        if self.esm_tokenizer is not None:
            self.pad_token_id = self.esm_tokenizer.pad_token_id
            self.eos_token_id = self.esm_tokenizer.eos_token_id
            self.bos_token_id = getattr(self.esm_tokenizer, "cls_token_id", 0)
            self.unk_token_id = self.esm_tokenizer.unk_token_id
            self.mask_token_id = getattr(self.esm_tokenizer, "mask_token_id", 32)
        else:
            self.pad_token_id = 1
            self.eos_token_id = 2
            self.bos_token_id = 0
            self.unk_token_id = 3
            self.mask_token_id = 32

    def _create_manual_tokenizer(self):
        """Create a manual token/id mapping if loading the ESM tokenizer fails."""
        special_tokens = ["<cls>", "<pad>", "<eos>", "<unk>"]
        amino_acids = list(self.AMINO_ACIDS)

        self.token_to_id = {}
        self.id_to_token = {}

        # Special tokens occupy ids 0-3, matching the fallback ids set in __init__.
        for i, token in enumerate(special_tokens):
            self.token_to_id[token] = i
            self.id_to_token[i] = token

        # Amino acids follow at ids 4-23.
        for i, aa in enumerate(amino_acids):
            token_id = i + len(special_tokens)
            self.token_to_id[aa] = token_id
            self.id_to_token[token_id] = aa

        # The mask token is pinned to id 32, the id it has in the ESM-2 vocabulary.
        mask_id = 32
        self.token_to_id["<mask>"] = mask_id
        self.id_to_token[mask_id] = "<mask>"

        self._vocab_size = max(self.id_to_token.keys()) + 1
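
        # NOTE (assumption about the reference vocabulary): the special-token and <mask>
        # ids above mirror the ESM-2 layout, but the amino-acid ids do not. ESM-2 assigns
        # residue ids in a non-alphabetical order (L, A, G, V, ...), so ids produced by
        # this fallback are not interchangeable with those expected by pretrained ESM
        # checkpoints; the fallback only keeps the interface usable without the download.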

    @property
    def vocab_size(self) -> int:
        """Return the vocab size of the tokenizer."""
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer.vocab_size
        else:
            return self._vocab_size

    def get_vocab(self) -> Dict:
        """Return the vocab as a dictionary."""
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer.get_vocab()
        else:
            return self.token_to_id.copy()

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize a protein sequence string."""
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer.tokenize(text)
        else:
            # Character-level tokenization: each residue is a token, anything outside
            # the 20 standard amino acids maps to the unknown token.
            tokens = []
            for char in text.upper():
                if char in self.AMINO_ACIDS:
                    tokens.append(char)
                else:
                    tokens.append(self._unk_token)
            return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token to an id."""
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer.convert_tokens_to_ids(token)
        else:
            return self.token_to_id.get(token, self.unk_token_id)

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an id to a token."""
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer.convert_ids_to_tokens(index)
        else:
            return self.id_to_token.get(index, self._unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Convert a sequence of tokens to a single string."""
        # Drop structural special tokens; amino-acid tokens concatenate directly.
        filtered_tokens = []
        for token in tokens:
            if token not in [self._bos_token, self._eos_token, self._pad_token]:
                filtered_tokens.append(token)
        return "".join(filtered_tokens)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """The ESM tokenizer does not need vocabulary saving, so return an empty tuple."""
        return ()

    def __call__(
        self,
        text: Union[str, List[str]],
        text_pair: Optional[Union[str, List[str]]] = None,
        padding: Union[bool, str] = False,
        truncation: Union[bool, str] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = True,
        add_special_tokens: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """
        Main tokenization method that handles batching and converts to tensors.
        """
        # Delegate to the pretrained ESM tokenizer when it is available.
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer(
                text=text,
                text_pair=text_pair,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                add_special_tokens=add_special_tokens,
                **kwargs,
            )

        # Fallback path: tokenize with the manual vocabulary.
        if isinstance(text, str):
            text = [text]

        input_ids_list = []
        for seq in text:
            # Normalize the sequence: strip spaces, upper-case residues.
            seq = seq.replace(" ", "").upper()

            tokens = self._tokenize(seq)
            token_ids = [self._convert_token_to_id(token) for token in tokens]

            # Wrap with <cls> ... <eos> when requested.
            if add_special_tokens:
                token_ids = [self.bos_token_id] + token_ids + [self.eos_token_id]

            # Truncate to max_length, preserving the special tokens at both ends.
            if truncation and max_length and len(token_ids) > max_length:
                if add_special_tokens:
                    token_ids = [token_ids[0]] + token_ids[1 : max_length - 1] + [token_ids[-1]]
                else:
                    token_ids = token_ids[:max_length]

            input_ids_list.append(token_ids)

        # Pad to max_length if given, otherwise to the longest sequence in the batch.
        if padding:
            if max_length:
                max_len = max_length
            else:
                max_len = max(len(ids) for ids in input_ids_list)

            padded_input_ids = []
            attention_mask = []

            for ids in input_ids_list:
                padding_length = max_len - len(ids)
                padded_ids = ids + [self.pad_token_id] * padding_length
                mask = [1] * len(ids) + [0] * padding_length

                padded_input_ids.append(padded_ids)
                attention_mask.append(mask)

            input_ids_list = padded_input_ids
        else:
            attention_mask = [[1] * len(ids) for ids in input_ids_list]

        result = {"input_ids": input_ids_list}
        if return_attention_mask:
            result["attention_mask"] = attention_mask

        if return_tensors == "pt":
            result = {k: torch.tensor(v) for k, v in result.items()}

        return BatchEncoding(
            data=result,
            tensor_type=return_tensors,
            prepend_batch_axis=False,
            encoding=None,
        )

    def batch_decode(
        self,
        sequences: Union[List[int], List[List[int]], torch.Tensor],
        skip_special_tokens: bool = True,
        **kwargs,
    ) -> List[str]:
        """
        Decode a batch of token ids to strings.
        """
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer.batch_decode(sequences, skip_special_tokens=skip_special_tokens, **kwargs)

        if isinstance(sequences, torch.Tensor):
            sequences = sequences.tolist()

        results = []
        for seq in sequences:
            tokens = [self._convert_id_to_token(token_id) for token_id in seq]
            if skip_special_tokens:
                tokens = [token for token in tokens if token not in [
                    self._bos_token, self._eos_token, self._pad_token, self._unk_token
                ]]
            results.append("".join(tokens))

        return results

    def decode(
        self,
        token_ids: Union[int, List[int], torch.Tensor],
        skip_special_tokens: bool = True,
        **kwargs,
    ) -> str:
        """
        Decode a single sequence of token ids to a string.
        """
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens, **kwargs)

        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        # A flat list of ids is a single sequence; a nested list is treated as a batch.
        if not token_ids or not isinstance(token_ids[0], (list, torch.Tensor)):
            tokens = [self._convert_id_to_token(token_id) for token_id in token_ids]
            if skip_special_tokens:
                tokens = [token for token in tokens if token not in [
                    self._bos_token, self._eos_token, self._pad_token, self._unk_token
                ]]
            return "".join(tokens)

        # Nested input: decode as a batch and return the first sequence.
        return self.batch_decode(token_ids, skip_special_tokens=skip_special_tokens, **kwargs)[0]


def register_esm_tokenizer():
    """Register the EsmTokenizer with HuggingFace's AutoTokenizer."""
    AutoTokenizer.register("esm", EsmTokenizer)
    print("EsmTokenizer registered with AutoTokenizer")


if __name__ == "__main__":
    register_esm_tokenizer()
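
    # Minimal usage sketch (assumes the facebook/esm2_t33_650M_UR50D tokenizer can be
    # downloaded; otherwise the manual fallback vocabulary is used transparently).
    tokenizer = EsmTokenizer()
    encoded = tokenizer(["MKTAYIAKQR", "MKT"], padding=True, return_tensors="pt")
    print(encoded["input_ids"].shape)    # (2, 12): <cls> + residues + <eos>, padded to the longest
    print(encoded["attention_mask"][1])  # trailing zeros mark the padding of the shorter sequence
    print(tokenizer.decode(encoded["input_ids"][0]))  # recovers the residues of the first sequence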