# bioreason/models/esm_tokenizer.py
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.utils import logging
from transformers import AutoTokenizer
from transformers.tokenization_utils_base import BatchEncoding
import torch
import numpy as np
from typing import List, Dict, Optional, Union, Tuple
logger = logging.get_logger(__name__)
class EsmTokenizer(PreTrainedTokenizer):
    """
    Tokenizer for ESM models - wraps the ESM tokenizer to be compatible with HuggingFace interfaces.

    This tokenizer handles protein sequences (amino acid sequences). It first
    tries to load the real ESM tokenizer via ``AutoTokenizer``; if that fails,
    it falls back to a minimal manual vocabulary covering the 20 standard
    amino acids plus the special tokens.
    """

    vocab_files_names = {}  # ESM tokenizer doesn't require vocab files
    model_input_names = ["input_ids", "attention_mask"]

    # Standard amino acid alphabet used by ESM
    AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"

    def __init__(
        self,
        esm_model_name: str = "facebook/esm2_t33_650M_UR50D",
        bos_token="<cls>",
        eos_token="<eos>",
        pad_token="<pad>",
        unk_token="<unk>",
        mask_token="<mask>",
        **kwargs
    ):
        """
        Initialize the ESM Tokenizer.

        Args:
            esm_model_name: Name of the ESM model to load the tokenizer from
            bos_token: Beginning of sequence token (CLS token in ESM)
            eos_token: End of sequence token
            pad_token: Padding token
            unk_token: Unknown token
            mask_token: Mask token for masked language modeling
        """
        # Load the actual ESM tokenizer. Catch Exception (not a bare
        # `except:`) so SystemExit/KeyboardInterrupt are never swallowed,
        # and log the failure instead of hiding it.
        try:
            self.esm_tokenizer = AutoTokenizer.from_pretrained(esm_model_name, trust_remote_code=True)
        except Exception as exc:
            # Fallback to manual tokenizer if auto loading fails
            logger.warning(
                "Could not load ESM tokenizer %s (%s); falling back to manual vocabulary.",
                esm_model_name,
                exc,
            )
            self.esm_tokenizer = None
            self._create_manual_tokenizer()

        # Set special tokens before super().__init__ so helper methods that
        # reference them work during base-class initialization.
        self._pad_token = pad_token
        self._eos_token = eos_token
        self._bos_token = bos_token
        self._unk_token = unk_token
        self._mask_token = mask_token

        # Initialize with special tokens
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            unk_token=unk_token,
            mask_token=mask_token,
            **kwargs
        )

        # Set token IDs, mirroring the wrapped tokenizer when available
        if self.esm_tokenizer is not None:
            self.pad_token_id = self.esm_tokenizer.pad_token_id
            self.eos_token_id = self.esm_tokenizer.eos_token_id
            self.bos_token_id = getattr(self.esm_tokenizer, 'cls_token_id', 0)
            self.unk_token_id = self.esm_tokenizer.unk_token_id
            self.mask_token_id = getattr(self.esm_tokenizer, 'mask_token_id', 32)
        else:
            # Manual token IDs for fallback; these match the layout produced
            # by _create_manual_tokenizer below.
            self.pad_token_id = 1
            self.eos_token_id = 2
            self.bos_token_id = 0  # CLS token
            self.unk_token_id = 3
            self.mask_token_id = 32

    def _create_manual_tokenizer(self):
        """Create a manual tokenizer mapping if ESM tokenizer loading fails.

        Special tokens occupy ids 0-3 (<cls>, <pad>, <eos>, <unk>), the 20
        amino acids follow, and <mask> is pinned at id 32.
        """
        # Create vocabulary mapping
        special_tokens = ["<cls>", "<pad>", "<eos>", "<unk>"]
        amino_acids = list(self.AMINO_ACIDS)

        self.token_to_id = {}
        self.id_to_token = {}

        # Add special tokens first
        for i, token in enumerate(special_tokens):
            self.token_to_id[token] = i
            self.id_to_token[i] = token

        # Add amino acids
        for i, aa in enumerate(amino_acids):
            token_id = i + len(special_tokens)
            self.token_to_id[aa] = token_id
            self.id_to_token[token_id] = aa

        # Add mask token at a fixed id (leaves a gap in the id space)
        mask_id = 32
        self.token_to_id["<mask>"] = mask_id
        self.id_to_token[mask_id] = "<mask>"

        # Vocab size spans up to and including the highest assigned id
        self._vocab_size = max(self.id_to_token.keys()) + 1

    @property
    def vocab_size(self) -> int:
        """Return the vocab size of the tokenizer."""
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer.vocab_size
        else:
            return self._vocab_size

    def get_vocab(self) -> Dict:
        """Return vocab as a dictionary mapping token string to id."""
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer.get_vocab()
        else:
            # Copy so callers cannot mutate the internal mapping
            return self.token_to_id.copy()

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize a protein sequence string into per-residue tokens."""
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer.tokenize(text)
        else:
            # Manual tokenization - split into individual amino acids;
            # anything outside the canonical alphabet becomes <unk>.
            tokens = []
            for char in text.upper():
                if char in self.AMINO_ACIDS:
                    tokens.append(char)
                else:
                    tokens.append(self._unk_token)
            return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token to an id (unk id for unknown tokens)."""
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer.convert_tokens_to_ids(token)
        else:
            return self.token_to_id.get(token, self.unk_token_id)

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an id to a token (unk token for unknown ids)."""
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer.convert_ids_to_tokens(index)
        else:
            return self.id_to_token.get(index, self._unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Convert a sequence of tokens to a single string.

        BOS/EOS/PAD tokens are dropped; residues are concatenated without
        separators.
        """
        # Filter out special tokens and join
        filtered_tokens = []
        for token in tokens:
            if token not in [self._bos_token, self._eos_token, self._pad_token]:
                filtered_tokens.append(token)
        return "".join(filtered_tokens)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, ...]:
        """ESM tokenizer doesn't need vocabulary saving, return empty tuple."""
        return ()

    def __call__(
        self,
        text: Union[str, List[str]],
        text_pair: Optional[Union[str, List[str]]] = None,
        padding: Union[bool, str] = False,
        truncation: Union[bool, str] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = True,
        add_special_tokens: bool = True,
        **kwargs
    ) -> BatchEncoding:
        """
        Main tokenization method that handles batching and converts to tensors.

        Delegates to the wrapped ESM tokenizer when available; otherwise uses
        the manual per-residue fallback implemented below.
        """
        # Use ESM tokenizer if available
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer(
                text=text,
                text_pair=text_pair,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                add_special_tokens=add_special_tokens,
                **kwargs
            )

        # Manual tokenization fallback
        # Handle single string vs list of strings
        if isinstance(text, str):
            text = [text]

        # Tokenize all sequences
        input_ids_list = []
        for seq in text:
            # Clean sequence (remove spaces, convert to uppercase)
            seq = seq.replace(" ", "").upper()
            # Tokenize sequence
            tokens = self._tokenize(seq)
            token_ids = [self._convert_token_to_id(token) for token in tokens]
            # Add special tokens if requested
            if add_special_tokens:
                token_ids = [self.bos_token_id] + token_ids + [self.eos_token_id]
            # Truncate if needed
            if truncation and max_length and len(token_ids) > max_length:
                if add_special_tokens:
                    # Keep BOS, truncate middle, keep EOS
                    token_ids = [token_ids[0]] + token_ids[1:max_length - 1] + [token_ids[-1]]
                else:
                    token_ids = token_ids[:max_length]
            input_ids_list.append(token_ids)

        # Apply padding if needed
        if padding:
            if max_length:
                max_len = max_length
            else:
                max_len = max(len(ids) for ids in input_ids_list)

            # Create padded sequences and attention masks
            padded_input_ids = []
            attention_mask = []
            for ids in input_ids_list:
                # Apply right padding (pad on the right for protein sequences).
                # Clamp at zero: without truncation a sequence may exceed
                # max_length, and a negative padding length would otherwise
                # produce an attention mask longer than the ids (ragged batch).
                padding_length = max(0, max_len - len(ids))
                padded_ids = ids + [self.pad_token_id] * padding_length
                mask = [1] * len(ids) + [0] * padding_length
                padded_input_ids.append(padded_ids)
                attention_mask.append(mask)
            input_ids_list = padded_input_ids
        else:
            # Create attention mask without padding
            attention_mask = [[1] * len(ids) for ids in input_ids_list]

        # Create result dictionary
        result = {"input_ids": input_ids_list}
        if return_attention_mask:
            result["attention_mask"] = attention_mask

        # Convert to tensors if requested
        if return_tensors == "pt":
            result = {k: torch.tensor(v) for k, v in result.items()}

        # Return a BatchEncoding object
        return BatchEncoding(
            data=result,
            tensor_type=return_tensors,
            prepend_batch_axis=False,
            encoding=None
        )

    def batch_decode(
        self,
        sequences: Union[List[int], List[List[int]], torch.Tensor],
        skip_special_tokens: bool = True,
        **kwargs
    ) -> List[str]:
        """
        Decode a batch of token ids to strings.

        Accepts a flat ``List[int]`` (treated as a batch of one), a list of
        id lists, or a tensor.
        """
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer.batch_decode(sequences, skip_special_tokens=skip_special_tokens, **kwargs)

        if isinstance(sequences, torch.Tensor):
            sequences = sequences.tolist()
        # A flat list of ints is a single sequence, not a batch; wrapping it
        # avoids iterating over an int below (the signature advertises
        # List[int] as valid input).
        if sequences and not isinstance(sequences[0], (list, tuple)):
            sequences = [sequences]

        results = []
        for seq in sequences:
            tokens = [self._convert_id_to_token(token_id) for token_id in seq]
            if skip_special_tokens:
                tokens = [token for token in tokens if token not in [
                    self._bos_token, self._eos_token, self._pad_token, self._unk_token
                ]]
            results.append("".join(tokens))
        return results

    def decode(
        self,
        token_ids: Union[int, List[int], torch.Tensor],
        skip_special_tokens: bool = True,
        **kwargs
    ) -> str:
        """
        Decode a single sequence of token ids to a string.
        """
        if self.esm_tokenizer is not None:
            return self.esm_tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens, **kwargs)

        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()

        # Handle both single sequence and batch with one item
        if not isinstance(token_ids, list) or not token_ids or not isinstance(token_ids[0], (list, torch.Tensor)):
            # Single sequence
            tokens = [self._convert_id_to_token(token_id) for token_id in token_ids]
            if skip_special_tokens:
                tokens = [token for token in tokens if token not in [
                    self._bos_token, self._eos_token, self._pad_token, self._unk_token
                ]]
            return "".join(tokens)
        # Batch with one item
        return self.batch_decode(token_ids, skip_special_tokens=skip_special_tokens, **kwargs)[0]
def register_esm_tokenizer():
    """Register the EsmTokenizer with HuggingFace's AutoTokenizer.

    After registration, ``AutoTokenizer`` can resolve the ``"esm"`` tokenizer
    type to :class:`EsmTokenizer`.
    """
    AutoTokenizer.register("esm", EsmTokenizer)
    # Use the module logger (defined at top of file) instead of print,
    # consistent with how the rest of the module reports status.
    logger.info("EsmTokenizer registered with AutoTokenizer")
# Running this module as a script performs the AutoTokenizer registration.
if __name__ == "__main__":
    register_esm_tokenizer()