from typing import List, Optional, Union, Dict, Any, Tuple import torch from torch import nn import torch.nn.functional as F from transformers import AutoTokenizer from transformers.processing_utils import ( CommonKwargs, ProcessingKwargs, ProcessorMixin, Unpack, ) from transformers.feature_extraction_utils import BatchFeature from transformers.tokenization_utils_base import PreTokenizedInput, TextInput from transformers.utils import logging from bioreason.utils.dna_utils import DNAInput class DLDNAKwargs(CommonKwargs): """Keyword arguments specific to DNA processing""" max_length_text: Optional[int] max_length_dna: Optional[int] class DLProcessorKwargs(ProcessingKwargs, total=False): """Processing keyword arguments for the DL processor""" dna_kwargs: DLDNAKwargs _defaults = { "text_kwargs": { "padding": False, }, } class DLProcessor(ProcessorMixin): r""" Constructs a DL processor which wraps a NucleotideTransformer DNA processor and a Qwen2_5 tokenizer into a single processor. This processor handles both text and DNA sequence processing to prepare inputs for the DNALLMModel. Args: tokenizer (PreTrainedTokenizerBase, *optional*): The text tokenizer used for processing text inputs. dna_tokenizer (PreTrainedTokenizerBase, *optional*): The DNA tokenizer used for processing DNA sequences. chat_template (`str`, *optional*): A Jinja template for chat formatting. If None, will use the tokenizer's template. """ attributes = ["tokenizer", "dna_tokenizer"] valid_kwargs = ["model", "chat_template"] tokenizer_class = ( "Qwen2Tokenizer", "Qwen2TokenizerFast", "GPT2TokenizerFast", ) dna_tokenizer_class = ("EsmTokenizer", "Evo2Tokenizer") def __init__( self, tokenizer=None, dna_tokenizer=None, chat_template=None, **kwargs ): """ Initialize the processor with text and DNA tokenizers. Args: tokenizer: Text tokenizer (usually from a language model) dna_tokenizer: DNA tokenizer (usually from a DNA model) chat_template: Template for formatting chat conversations **kwargs: Additional arguments """ self.tokenizer = tokenizer self.dna_tokenizer = dna_tokenizer self.dna_token = ( "<|dna_pad|>" if not hasattr(self.tokenizer, "dna_token") else self.tokenizer.dna_token ) # Get chat template from tokenizer if not provided if chat_template is None and hasattr(self.tokenizer, "chat_template"): chat_template = self.tokenizer.chat_template super().__init__(tokenizer, dna_tokenizer, chat_template=chat_template) # The GRPO trainer might expect this to be set if not hasattr(self.tokenizer, 'pad_token') or self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token def tokenize_dna_sequences( self, batch_dna_sequences: List[List[str]], max_length: int = 2048, return_tensors: str = "pt", device: str = "cuda", ) -> Dict[str, Any]: """ Tokenize a batch of DNA sequences. Args: batch_dna_sequences: List of lists of DNA sequences per batch item max_length: Maximum allowed length for DNA sequences return_tensors: Return format for tensors ("pt" for PyTorch) device: Device to place tensors on Returns: Dict containing: - dna_tokenized: The tokenized DNA sequences - batch_idx_map: Mapping of which sequences belong to which batch item """ # Create a mapping to track which sequences belong to which batch item batch_idx_map = [] all_sequences = [] # Flatten all sequences with batch tracking for batch_idx, dna_sequences in enumerate(batch_dna_sequences): for seq in dna_sequences: all_sequences.append(seq) batch_idx_map.append(batch_idx) # If no sequences in the entire batch, return empty dict if not all_sequences: return {"dna_tokenized": None, "batch_idx_map": []} # Tokenize all sequences at once dna_tokenized = self.dna_tokenizer( all_sequences, padding=True, truncation=True, max_length=max_length, return_tensors=return_tensors, return_attention_mask=True, ) return {"dna_tokenized": dna_tokenized, "batch_idx_map": batch_idx_map} def __call__( self, batch_dna_sequences: Optional[List[List[str]]] = None, text: Optional[ Union[ TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput] ] ] = None, max_length_text: int = 512, max_length_dna: int = 2048, return_tensors: str = "pt", device: str = "cuda", **kwargs: Unpack[DLProcessorKwargs], ) -> BatchFeature: """ Process text and DNA sequences for model input. Args: batch_dna_sequences: List of lists of DNA sequences per batch item text: Input text or list of texts max_length_text: Maximum length for text sequences max_length_dna: Maximum length for DNA sequences return_tensors: Return format for tensors device: Device to place tensors on **kwargs: Additional processor keyword arguments Returns: BatchFeature with tokenized inputs for the model """ output_kwargs = self._merge_kwargs( DLProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) # Ensure text is a list if not isinstance(text, list): text = [text] # flattened_dna_sequences = [dna_sequence for dna_sequences in batch_dna_sequences for dna_sequence in dna_sequences] dna_inputs = {} if batch_dna_sequences is not None: # Tokenize DNA sequences dna_processing_result = self.tokenize_dna_sequences( batch_dna_sequences, max_length=max_length_dna, return_tensors=return_tensors, device=device, ) # Replace DNA tokens in text if needed index = 0 for i in range(len(text)): while self.dna_token in text[i]: num_dna_tokens = (dna_processing_result['dna_tokenized']['input_ids'][index] != 1).sum().item() text[i] = text[i].replace( self.dna_token, "<|placeholder|>" * num_dna_tokens, 1 ) index += 1 text[i] = text[i].replace("<|placeholder|>", self.dna_token) # Add batch info to the output dna_inputs = { # "batch_dna_sequences": batch_dna_sequences, "dna_tokenized": dna_processing_result["dna_tokenized"], "batch_idx_map": dna_processing_result["batch_idx_map"], } # Tokenize text text_kwargs = output_kwargs.get("text_kwargs", {}) if 'padding' in text_kwargs: del text_kwargs['padding'] # print("__call__ (processor):", text) text_inputs = self.tokenizer( text, max_length=max_length_text + 2 * max_length_dna, return_tensors=return_tensors, padding=True, truncation=True, **text_kwargs, ) # The BatchFeature should have all required fields for the model's forward pass return BatchFeature(data={**text_inputs, **dna_inputs}) def batch_decode(self, *args, **kwargs) -> List[str]: """ This method forwards all its arguments to the tokenizer's batch_decode. Returns: List of decoded strings """ return self.tokenizer.batch_decode(*args, **kwargs) def decode(self, *args, **kwargs) -> str: """ This method forwards all its arguments to the tokenizer's decode. Returns: Decoded string """ return self.tokenizer.decode(*args, **kwargs) def post_process_dna_to_text( self, generated_outputs: torch.Tensor, skip_special_tokens: bool = True, **kwargs, ) -> List[str]: """ Post-process the model output to decode the text. Args: generated_outputs: The token IDs generated by the model skip_special_tokens: Whether to skip special tokens in the output **kwargs: Additional arguments for the decoder Returns: List of decoded strings """ return self.tokenizer.batch_decode( generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs, ) @property def model_input_names(self) -> List[str]: """ Get the input names expected by the model. Returns: List of input names """ tokenizer_input_names = self.tokenizer.model_input_names dna_input_names = ["dna_tokenized", "batch_idx_map"] return list(dict.fromkeys(tokenizer_input_names + dna_input_names))