from typing import List, Optional, Union, Dict, Any, Tuple

import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoTokenizer
from transformers.processing_utils import (
    CommonKwargs,
    ProcessingKwargs,
    ProcessorMixin,
    Unpack,
)
from transformers.feature_extraction_utils import BatchFeature
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging

from bioreason.utils.protein_utils import ProteinInput


class ProteinLLMProteinKwargs(CommonKwargs):
    """Keyword arguments specific to protein sequence processing"""

    max_length_text: Optional[int]
    max_length_protein: Optional[int]


class ProteinLLMProcessorKwargs(ProcessingKwargs, total=False):
    """Processing keyword arguments for the ProteinLLM processor"""

    protein_kwargs: ProteinLLMProteinKwargs
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
    }


class ProteinLLMProcessor(ProcessorMixin):
    r"""
    Constructs a ProteinLLM processor which wraps an ESM protein tokenizer and a language model tokenizer
    into a single processor.

    This processor handles both text and protein sequence processing to prepare inputs for the ProteinLLMModel.

    Args:
        tokenizer (PreTrainedTokenizerBase, *optional*):
            The text tokenizer used for processing text inputs.
        protein_tokenizer (PreTrainedTokenizerBase, *optional*):
            The protein tokenizer (ESM) used for processing protein sequences.
        chat_template (`str`, *optional*):
            A Jinja template for chat formatting. If None, will use the tokenizer's template.
    """

    attributes = ["tokenizer", "protein_tokenizer"]
    valid_kwargs = ["model", "chat_template"]
    tokenizer_class = (
        "Qwen2Tokenizer",
        "Qwen2TokenizerFast",
        "GPT2TokenizerFast",
        "LlamaTokenizer",
        "LlamaTokenizerFast",
    )
    protein_tokenizer_class = ("EsmTokenizer",)

    def __init__(
        self, tokenizer=None, protein_tokenizer=None, chat_template=None, **kwargs
    ):
        """
        Initialize the processor with text and protein tokenizers.

        Args:
            tokenizer: Text tokenizer (usually from a language model)
            protein_tokenizer: Protein tokenizer (usually ESM tokenizer)
            chat_template: Template for formatting chat conversations. Falls back to
                ``tokenizer.chat_template`` when not given.
            **kwargs: Additional arguments (accepted but not forwarded).
        """
        self.tokenizer = tokenizer
        self.protein_tokenizer = protein_tokenizer

        # Marker in the prompt text that stands for one protein sequence; the text
        # tokenizer may define its own, otherwise "<|protein_pad|>" is used.
        self.protein_token = (
            "<|protein_pad|>"
            if not hasattr(self.tokenizer, "protein_token")
            else self.tokenizer.protein_token
        )

        # Get chat template from tokenizer if not provided
        if chat_template is None and hasattr(self.tokenizer, "chat_template"):
            chat_template = self.tokenizer.chat_template

        super().__init__(tokenizer, protein_tokenizer, chat_template=chat_template)

        # The GRPO trainer might expect this to be set
        if not hasattr(self.tokenizer, "pad_token") or self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def tokenize_protein_sequences(
        self,
        batch_protein_sequences: List[List[str]],
        max_length: int = 1024,
        return_tensors: str = "pt",
        device: str = "cuda",
    ) -> Dict[str, Any]:
        """
        Tokenize a batch of protein sequences.

        All sequences of all batch items are flattened into a single list (in batch
        order, then in per-item order) and tokenized in one padded/truncated call.

        Args:
            batch_protein_sequences: List of lists of protein sequences per batch item
            max_length: Maximum allowed length for protein sequences (truncation limit)
            return_tensors: Return format for tensors ("pt" for PyTorch)
            device: Accepted for API compatibility; currently unused (no ``.to(device)``
                is applied to the returned tensors).

        Returns:
            Dict containing:
                - protein_tokenized: The tokenizer output for the flattened sequences,
                  or None if the batch contains no sequences
                - batch_idx_map: For each flattened sequence, the index of the batch
                  item it came from
        """
        # Flatten all sequences, remembering which batch item each one belongs to.
        batch_idx_map = []
        all_sequences = []
        for batch_idx, protein_sequences in enumerate(batch_protein_sequences):
            for seq in protein_sequences:
                all_sequences.append(seq)
                batch_idx_map.append(batch_idx)

        # No sequences anywhere in the batch: nothing to tokenize.
        if not all_sequences:
            return {"protein_tokenized": None, "batch_idx_map": []}

        # Tokenize all sequences at once
        protein_tokenized = self.protein_tokenizer(
            all_sequences,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors=return_tensors,
            return_attention_mask=True,
            add_special_tokens=True,
        )

        return {"protein_tokenized": protein_tokenized, "batch_idx_map": batch_idx_map}

    def _expand_protein_placeholders(
        self, text: List[Any], protein_tokenized: Optional[Any], num_sequences: int
    ) -> List[Any]:
        """
        Return a new list where each protein marker is repeated once per protein token.

        Markers are consumed left-to-right across the whole batch and paired with the
        flattened protein sequences in order. The input list is not modified.

        Raises:
            ValueError: if ``text`` contains more protein markers than ``num_sequences``.
        """
        placeholder = "<|placeholder|>"
        input_ids = None
        pad_token_id = None
        if protein_tokenized is not None:
            # Loop-invariant lookups, done once.
            input_ids = protein_tokenized["input_ids"]
            pad_token_id = self.protein_tokenizer.pad_token_id

        expanded = []
        index = 0
        for prompt in text:
            while self.protein_token in prompt:
                if index >= num_sequences:
                    raise ValueError(
                        f"Found more '{self.protein_token}' tokens in `text` than protein "
                        f"sequences provided ({num_sequences})."
                    )
                protein_token_ids = input_ids[index]
                # Non-pad tokens minus 2, assuming the protein tokenizer adds exactly one
                # leading and one trailing special token (ESM: <cls>/<eos>) — confirm if
                # a different protein tokenizer is used.
                num_protein_tokens = (protein_token_ids != pad_token_id).sum().item() - 2
                num_protein_tokens = max(1, num_protein_tokens)  # Ensure at least 1 token
                # Use a temporary marker so the while-condition only sees unexpanded markers.
                prompt = prompt.replace(
                    self.protein_token, placeholder * num_protein_tokens, 1
                )
                index += 1
            expanded.append(prompt.replace(placeholder, self.protein_token))
        return expanded

    def __call__(
        self,
        batch_protein_sequences: Optional[List[List[str]]] = None,
        text: Optional[
            Union[
                TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
            ]
        ] = None,
        max_length_text: int = 512,
        max_length_protein: int = 1024,
        return_tensors: str = "pt",
        device: str = "cuda",
        **kwargs: Unpack[ProteinLLMProcessorKwargs],
    ) -> BatchFeature:
        """
        Process text and protein sequences for model input.

        Each protein marker (``self.protein_token``) in ``text`` is expanded to as many
        copies as its protein has (non-special) tokens, then the text is tokenized.

        Args:
            batch_protein_sequences: List of lists of protein sequences per batch item
            text: Input text or list of texts (a list argument is not modified)
            max_length_text: Maximum length for text sequences
            max_length_protein: Maximum length for protein sequences
            return_tensors: Return format for tensors
            device: Forwarded to ``tokenize_protein_sequences`` (currently unused there)
            **kwargs: Additional processor keyword arguments

        Returns:
            BatchFeature with the text tokenizer outputs plus, when proteins were given,
            ``protein_tokenized`` and ``batch_idx_map``.

        Raises:
            ValueError: if ``text`` contains more protein markers than protein sequences.
        """
        output_kwargs = self._merge_kwargs(
            ProteinLLMProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        # Ensure text is a list; copy it so the caller's list is never mutated.
        text = list(text) if isinstance(text, list) else [text]

        protein_inputs = {}
        if batch_protein_sequences is not None:
            protein_processing_result = self.tokenize_protein_sequences(
                batch_protein_sequences,
                max_length=max_length_protein,
                return_tensors=return_tensors,
                device=device,
            )
            protein_tokenized = protein_processing_result["protein_tokenized"]
            batch_idx_map = protein_processing_result["batch_idx_map"]

            text = self._expand_protein_placeholders(
                text, protein_tokenized, len(batch_idx_map)
            )

            protein_inputs = {
                "protein_tokenized": protein_tokenized,
                "batch_idx_map": batch_idx_map,
            }

        # Padding is always forced on below (needed to batch into tensors), so drop any
        # caller/default value to avoid passing it twice.
        text_kwargs = output_kwargs.get("text_kwargs", {})
        text_kwargs.pop("padding", None)

        text_inputs = self.tokenizer(
            text,
            # Head-room for expanded protein markers.
            # NOTE(review): with several proteins per prompt this budget may still truncate.
            max_length=max_length_text + 2 * max_length_protein,
            return_tensors=return_tensors,
            padding=True,
            truncation=True,
            **text_kwargs,
        )

        # The BatchFeature should have all required fields for the model's forward pass
        return BatchFeature(data={**text_inputs, **protein_inputs})

    def batch_decode(self, *args, **kwargs) -> List[str]:
        """
        Forward all arguments to the text tokenizer's ``batch_decode``.

        Returns:
            List of decoded strings
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs) -> str:
        """
        Forward all arguments to the text tokenizer's ``decode``.

        Returns:
            Decoded string
        """
        return self.tokenizer.decode(*args, **kwargs)

    def post_process_protein_to_text(
        self,
        generated_outputs: torch.Tensor,
        skip_special_tokens: bool = True,
        **kwargs,
    ) -> List[str]:
        """
        Decode generated token IDs into text with the text tokenizer.

        Args:
            generated_outputs: The token IDs generated by the model
            skip_special_tokens: Whether to skip special tokens in the output
            **kwargs: Additional arguments for the decoder

        Returns:
            List of decoded strings
        """
        return self.tokenizer.batch_decode(
            generated_outputs,
            skip_special_tokens=skip_special_tokens,
            **kwargs,
        )

    @property
    def model_input_names(self) -> List[str]:
        """
        Input names expected by the model: text tokenizer inputs followed by the
        protein inputs, de-duplicated while preserving order.

        Returns:
            List of input names
        """
        tokenizer_input_names = self.tokenizer.model_input_names
        protein_input_names = ["protein_tokenized", "batch_idx_map"]
        return list(dict.fromkeys(tokenizer_input_names + protein_input_names))