from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.utils import logging
from transformers import AutoTokenizer
from transformers.tokenization_utils_base import BatchEncoding
import torch
import numpy as np
from typing import List, Dict, Optional, Union, Tuple

logger = logging.get_logger(__name__)


class EsmTokenizer(PreTrainedTokenizer):
    """
    Tokenizer for ESM models - wraps the ESM tokenizer to be compatible with
    HuggingFace interfaces.

    This tokenizer handles protein sequences (amino acid sequences). It first
    tries to load the real ESM tokenizer via ``AutoTokenizer``; if that fails,
    it falls back to a small manual vocabulary whose IDs mirror the ESM-2
    layout (``<cls>``=0, ``<pad>``=1, ``<eos>``=2, ``<unk>``=3, amino acids
    next, ``<mask>``=32).
    """

    vocab_files_names = {}  # ESM tokenizer doesn't require vocab files
    model_input_names = ["input_ids", "attention_mask"]

    # Standard amino acid alphabet used by ESM
    AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"

    def __init__(
        self,
        esm_model_name: str = "facebook/esm2_t33_650M_UR50D",
        bos_token="<cls>",
        eos_token="<eos>",
        pad_token="<pad>",
        unk_token="<unk>",
        mask_token="<mask>",
        **kwargs,
    ):
        """
        Initialize the ESM Tokenizer.

        Args:
            esm_model_name: Name of the ESM model to load the tokenizer from.
            bos_token: Beginning of sequence token (CLS token in ESM).
            eos_token: End of sequence token.
            pad_token: Padding token.
            unk_token: Unknown token.
            mask_token: Mask token for masked language modeling.
        """
        # Load the actual ESM tokenizer; fall back to a manual vocabulary
        # if loading fails (e.g. no network / model not cached).
        try:
            self.esm_tokenizer = AutoTokenizer.from_pretrained(
                esm_model_name, trust_remote_code=True
            )
        except Exception as e:
            logger.warning(
                "Could not load ESM tokenizer '%s' (%s); "
                "falling back to manual tokenizer.",
                esm_model_name,
                e,
            )
            self.esm_tokenizer = None
            self._create_manual_tokenizer()

        # Set special tokens (stored before super().__init__ so that
        # helper methods can reference them during initialization).
        self._pad_token = pad_token
        self._eos_token = eos_token
        self._bos_token = bos_token
        self._unk_token = unk_token
        self._mask_token = mask_token

        # Initialize with special tokens
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            unk_token=unk_token,
            mask_token=mask_token,
            **kwargs,
        )

        # Set token IDs
        if self.esm_tokenizer is not None:
            self.pad_token_id = self.esm_tokenizer.pad_token_id
            self.eos_token_id = self.esm_tokenizer.eos_token_id
            self.bos_token_id = getattr(self.esm_tokenizer, "cls_token_id", 0)
            self.unk_token_id = self.esm_tokenizer.unk_token_id
            self.mask_token_id = getattr(self.esm_tokenizer, "mask_token_id", 32)
        else:
            # Manual token IDs for fallback (mirror the ESM-2 vocabulary)
            self.pad_token_id = 1
            self.eos_token_id = 2
            self.bos_token_id = 0  # CLS token
            self.unk_token_id = 3
            self.mask_token_id = 32

    def _create_manual_tokenizer(self):
        """Create a manual tokenizer mapping if ESM tokenizer loading fails."""
        # Create vocabulary mapping; order matters so that IDs match the
        # ESM-2 layout (<cls>=0, <pad>=1, <eos>=2, <unk>=3).
        special_tokens = ["<cls>", "<pad>", "<eos>", "<unk>"]
        amino_acids = list(self.AMINO_ACIDS)

        self.token_to_id = {}
        self.id_to_token = {}

        # Add special tokens first
        for i, token in enumerate(special_tokens):
            self.token_to_id[token] = i
            self.id_to_token[i] = token

        # Add amino acids
        for i, aa in enumerate(amino_acids):
            token_id = i + len(special_tokens)
            self.token_to_id[aa] = token_id
            self.id_to_token[token_id] = aa

        # Add mask token at the same ID ESM-2 uses (leaves a gap in IDs,
        # which is intentional for compatibility with ESM checkpoints).
        mask_id = 32
        self.token_to_id["<mask>"] = mask_id
        self.id_to_token[mask_id] = "<mask>"

        self._vocab_size = max(self.id_to_token.keys()) + 1

    @property
    def vocab_size(self) -> int:
        """Return the vocab size of the tokenizer."""
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer.vocab_size
        else:
            return self._vocab_size

    def get_vocab(self) -> Dict:
        """Return vocab as a dictionary."""
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer.get_vocab()
        else:
            return self.token_to_id.copy()

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize a protein sequence string into per-residue tokens."""
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer.tokenize(text)
        else:
            # Manual tokenization - split into individual amino acids;
            # anything outside the canonical alphabet becomes <unk>.
            tokens = []
            for char in text.upper():
                if char in self.AMINO_ACIDS:
                    tokens.append(char)
                else:
                    tokens.append(self._unk_token)
            return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token to an id."""
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer.convert_tokens_to_ids(token)
        else:
            return self.token_to_id.get(token, self.unk_token_id)

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an id to a token."""
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer.convert_ids_to_tokens(index)
        else:
            return self.id_to_token.get(index, self._unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Convert a sequence of tokens to a single string."""
        # Filter out special tokens and join
        filtered_tokens = []
        for token in tokens:
            if token not in [self._bos_token, self._eos_token, self._pad_token]:
                filtered_tokens.append(token)
        return "".join(filtered_tokens)

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """ESM tokenizer doesn't need vocabulary saving, return empty tuple."""
        return ()

    def __call__(
        self,
        text: Union[str, List[str]],
        text_pair: Optional[Union[str, List[str]]] = None,
        padding: Union[bool, str] = False,
        truncation: Union[bool, str] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = True,
        add_special_tokens: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """
        Main tokenization method that handles batching and converts to tensors.

        Delegates to the real ESM tokenizer when available; otherwise performs
        manual per-residue tokenization with optional special tokens,
        truncation, right-padding and tensor conversion.
        """
        # Use ESM tokenizer if available
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer(
                text=text,
                text_pair=text_pair,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                add_special_tokens=add_special_tokens,
                **kwargs,
            )

        # Manual tokenization fallback
        # Handle single string vs list of strings
        if isinstance(text, str):
            text = [text]

        # Tokenize all sequences
        input_ids_list = []
        for seq in text:
            # Clean sequence (remove spaces, convert to uppercase)
            seq = seq.replace(" ", "").upper()

            # Tokenize sequence
            tokens = self._tokenize(seq)
            token_ids = [self._convert_token_to_id(token) for token in tokens]

            # Add special tokens if requested
            if add_special_tokens:
                token_ids = [self.bos_token_id] + token_ids + [self.eos_token_id]

            # Truncate if needed
            if truncation and max_length and len(token_ids) > max_length:
                if add_special_tokens:
                    # Keep BOS, truncate middle, keep EOS
                    token_ids = (
                        [token_ids[0]]
                        + token_ids[1 : max_length - 1]
                        + [token_ids[-1]]
                    )
                else:
                    token_ids = token_ids[:max_length]

            input_ids_list.append(token_ids)

        # Apply padding if needed
        if padding:
            if max_length:
                max_len = max_length
            else:
                max_len = max(len(ids) for ids in input_ids_list)

            # Create padded sequences and attention masks
            padded_input_ids = []
            attention_mask = []
            for ids in input_ids_list:
                # Apply right padding (pad on the right for protein sequences)
                padding_length = max_len - len(ids)
                padded_ids = ids + [self.pad_token_id] * padding_length
                mask = [1] * len(ids) + [0] * padding_length
                padded_input_ids.append(padded_ids)
                attention_mask.append(mask)
            input_ids_list = padded_input_ids
        else:
            # Create attention mask without padding
            attention_mask = [[1] * len(ids) for ids in input_ids_list]

        # Create result dictionary
        result = {"input_ids": input_ids_list}
        if return_attention_mask:
            result["attention_mask"] = attention_mask

        # Convert to tensors if requested
        if return_tensors == "pt":
            result = {k: torch.tensor(v) for k, v in result.items()}

        # Return a BatchEncoding object
        return BatchEncoding(
            data=result,
            tensor_type=return_tensors,
            prepend_batch_axis=False,
            encoding=None,
        )

    def batch_decode(
        self,
        sequences: Union[List[int], List[List[int]], torch.Tensor],
        skip_special_tokens: bool = True,
        **kwargs,
    ) -> List[str]:
        """
        Decode a batch of token ids to strings.
        """
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer.batch_decode(
                sequences, skip_special_tokens=skip_special_tokens, **kwargs
            )

        if isinstance(sequences, torch.Tensor):
            sequences = sequences.tolist()

        results = []
        for seq in sequences:
            tokens = [self._convert_id_to_token(token_id) for token_id in seq]
            if skip_special_tokens:
                tokens = [
                    token
                    for token in tokens
                    if token
                    not in [
                        self._bos_token,
                        self._eos_token,
                        self._pad_token,
                        self._unk_token,
                    ]
                ]
            results.append("".join(tokens))
        return results

    def decode(
        self,
        token_ids: Union[int, List[int], torch.Tensor],
        skip_special_tokens: bool = True,
        **kwargs,
    ) -> str:
        """
        Decode a single sequence of token ids to a string.
        """
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer.decode(
                token_ids, skip_special_tokens=skip_special_tokens, **kwargs
            )

        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()

        # Handle both single sequence and batch with one item
        if (
            not isinstance(token_ids, list)
            or not token_ids
            or not isinstance(token_ids[0], (list, torch.Tensor))
        ):
            # Single sequence
            tokens = [self._convert_id_to_token(token_id) for token_id in token_ids]
            if skip_special_tokens:
                tokens = [
                    token
                    for token in tokens
                    if token
                    not in [
                        self._bos_token,
                        self._eos_token,
                        self._pad_token,
                        self._unk_token,
                    ]
                ]
            return "".join(tokens)

        # Batch with one item
        return self.batch_decode(token_ids, skip_special_tokens, **kwargs)[0]


def register_esm_tokenizer():
    """Register the EsmTokenizer with HuggingFace's AutoTokenizer."""
    AutoTokenizer.register("esm", EsmTokenizer)
    print("EsmTokenizer registered with AutoTokenizer")


if __name__ == "__main__":
    register_esm_tokenizer()