| | from typing import List, Optional, Union, Dict, Any, Tuple |
| |
|
| | import torch |
| | from torch import nn |
| | import torch.nn.functional as F |
| |
|
| | from transformers import AutoTokenizer |
| | from transformers.processing_utils import ( |
| | CommonKwargs, |
| | ProcessingKwargs, |
| | ProcessorMixin, |
| | Unpack, |
| | ) |
| | from transformers.feature_extraction_utils import BatchFeature |
| | from transformers.tokenization_utils_base import PreTokenizedInput, TextInput |
| | from transformers.utils import logging |
| |
|
| | from bioreason.utils.dna_utils import DNAInput |
| |
|
| | class DLDNAKwargs(CommonKwargs): |
| | """Keyword arguments specific to DNA processing""" |
| | max_length_text: Optional[int] |
| | max_length_dna: Optional[int] |
| |
|
| |
|
| | class DLProcessorKwargs(ProcessingKwargs, total=False): |
| | """Processing keyword arguments for the DL processor""" |
| | dna_kwargs: DLDNAKwargs |
| | _defaults = { |
| | "text_kwargs": { |
| | "padding": False, |
| | }, |
| | } |
| |
|
| | class DLProcessor(ProcessorMixin): |
| | r""" |
| | Constructs a DL processor which wraps a NucleotideTransformer DNA processor and a Qwen2_5 tokenizer into a single processor. |
| | This processor handles both text and DNA sequence processing to prepare inputs for the DNALLMModel. |
| | |
| | Args: |
| | tokenizer (PreTrainedTokenizerBase, *optional*): |
| | The text tokenizer used for processing text inputs. |
| | dna_tokenizer (PreTrainedTokenizerBase, *optional*): |
| | The DNA tokenizer used for processing DNA sequences. |
| | chat_template (`str`, *optional*): |
| | A Jinja template for chat formatting. If None, will use the tokenizer's template. |
| | """ |
| |
|
| | attributes = ["tokenizer", "dna_tokenizer"] |
| | valid_kwargs = ["model", "chat_template"] |
| | tokenizer_class = ( |
| | "Qwen2Tokenizer", "Qwen2TokenizerFast", |
| | "GPT2TokenizerFast", |
| | ) |
| | dna_tokenizer_class = ("EsmTokenizer", "Evo2Tokenizer") |
| |
|
| | def __init__( |
| | self, tokenizer=None, dna_tokenizer=None, chat_template=None, **kwargs |
| | ): |
| | """ |
| | Initialize the processor with text and DNA tokenizers. |
| | |
| | Args: |
| | tokenizer: Text tokenizer (usually from a language model) |
| | dna_tokenizer: DNA tokenizer (usually from a DNA model) |
| | chat_template: Template for formatting chat conversations |
| | **kwargs: Additional arguments |
| | """ |
| | self.tokenizer = tokenizer |
| | self.dna_tokenizer = dna_tokenizer |
| |
|
| | self.dna_token = ( |
| | "<|dna_pad|>" |
| | if not hasattr(self.tokenizer, "dna_token") |
| | else self.tokenizer.dna_token |
| | ) |
| | |
| | |
| | if chat_template is None and hasattr(self.tokenizer, "chat_template"): |
| | chat_template = self.tokenizer.chat_template |
| | super().__init__(tokenizer, dna_tokenizer, chat_template=chat_template) |
| | |
| | |
| | if not hasattr(self.tokenizer, 'pad_token') or self.tokenizer.pad_token is None: |
| | self.tokenizer.pad_token = self.tokenizer.eos_token |
| |
|
| | def tokenize_dna_sequences( |
| | self, |
| | batch_dna_sequences: List[List[str]], |
| | max_length: int = 2048, |
| | return_tensors: str = "pt", |
| | device: str = "cuda", |
| | ) -> Dict[str, Any]: |
| | """ |
| | Tokenize a batch of DNA sequences. |
| | |
| | Args: |
| | batch_dna_sequences: List of lists of DNA sequences per batch item |
| | max_length: Maximum allowed length for DNA sequences |
| | return_tensors: Return format for tensors ("pt" for PyTorch) |
| | device: Device to place tensors on |
| | |
| | Returns: |
| | Dict containing: |
| | - dna_tokenized: The tokenized DNA sequences |
| | - batch_idx_map: Mapping of which sequences belong to which batch item |
| | """ |
| | |
| | batch_idx_map = [] |
| | all_sequences = [] |
| |
|
| | |
| | for batch_idx, dna_sequences in enumerate(batch_dna_sequences): |
| | for seq in dna_sequences: |
| | all_sequences.append(seq) |
| | batch_idx_map.append(batch_idx) |
| |
|
| | |
| | if not all_sequences: |
| | return {"dna_tokenized": None, "batch_idx_map": []} |
| |
|
| | |
| | dna_tokenized = self.dna_tokenizer( |
| | all_sequences, |
| | padding=True, |
| | truncation=True, |
| | max_length=max_length, |
| | return_tensors=return_tensors, |
| | return_attention_mask=True, |
| | ) |
| | |
| | return {"dna_tokenized": dna_tokenized, "batch_idx_map": batch_idx_map} |
| |
|
| | def __call__( |
| | self, |
| | batch_dna_sequences: Optional[List[List[str]]] = None, |
| | text: Optional[ |
| | Union[ |
| | TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput] |
| | ] |
| | ] = None, |
| | max_length_text: int = 512, |
| | max_length_dna: int = 2048, |
| | return_tensors: str = "pt", |
| | device: str = "cuda", |
| | **kwargs: Unpack[DLProcessorKwargs], |
| | ) -> BatchFeature: |
| | """ |
| | Process text and DNA sequences for model input. |
| | |
| | Args: |
| | batch_dna_sequences: List of lists of DNA sequences per batch item |
| | text: Input text or list of texts |
| | max_length_text: Maximum length for text sequences |
| | max_length_dna: Maximum length for DNA sequences |
| | return_tensors: Return format for tensors |
| | device: Device to place tensors on |
| | **kwargs: Additional processor keyword arguments |
| | |
| | Returns: |
| | BatchFeature with tokenized inputs for the model |
| | """ |
| | output_kwargs = self._merge_kwargs( |
| | DLProcessorKwargs, |
| | tokenizer_init_kwargs=self.tokenizer.init_kwargs, |
| | **kwargs, |
| | ) |
| |
|
| | |
| | if not isinstance(text, list): |
| | text = [text] |
| |
|
| | |
| | dna_inputs = {} |
| | if batch_dna_sequences is not None: |
| | |
| | dna_processing_result = self.tokenize_dna_sequences( |
| | batch_dna_sequences, |
| | max_length=max_length_dna, |
| | return_tensors=return_tensors, |
| | device=device, |
| | ) |
| | |
| | |
| | index = 0 |
| | for i in range(len(text)): |
| | while self.dna_token in text[i]: |
| | num_dna_tokens = (dna_processing_result['dna_tokenized']['input_ids'][index] != 1).sum().item() |
| | text[i] = text[i].replace( |
| | self.dna_token, "<|placeholder|>" * num_dna_tokens, 1 |
| | ) |
| | index += 1 |
| | text[i] = text[i].replace("<|placeholder|>", self.dna_token) |
| | |
| | |
| | |
| | |
| | dna_inputs = { |
| | |
| | "dna_tokenized": dna_processing_result["dna_tokenized"], |
| | "batch_idx_map": dna_processing_result["batch_idx_map"], |
| | } |
| |
|
| | |
| | text_kwargs = output_kwargs.get("text_kwargs", {}) |
| | |
| | if 'padding' in text_kwargs: |
| | del text_kwargs['padding'] |
| | |
| | |
| | text_inputs = self.tokenizer( |
| | text, |
| | max_length=max_length_text + 2 * max_length_dna, |
| | return_tensors=return_tensors, |
| | padding=True, |
| | truncation=True, |
| | **text_kwargs, |
| | ) |
| | |
| | |
| | return BatchFeature(data={**text_inputs, **dna_inputs}) |
| |
|
| | def batch_decode(self, *args, **kwargs) -> List[str]: |
| | """ |
| | This method forwards all its arguments to the tokenizer's batch_decode. |
| | |
| | Returns: |
| | List of decoded strings |
| | """ |
| | return self.tokenizer.batch_decode(*args, **kwargs) |
| |
|
| | def decode(self, *args, **kwargs) -> str: |
| | """ |
| | This method forwards all its arguments to the tokenizer's decode. |
| | |
| | Returns: |
| | Decoded string |
| | """ |
| | return self.tokenizer.decode(*args, **kwargs) |
| |
|
| | def post_process_dna_to_text( |
| | self, |
| | generated_outputs: torch.Tensor, |
| | skip_special_tokens: bool = True, |
| | **kwargs, |
| | ) -> List[str]: |
| | """ |
| | Post-process the model output to decode the text. |
| | |
| | Args: |
| | generated_outputs: The token IDs generated by the model |
| | skip_special_tokens: Whether to skip special tokens in the output |
| | **kwargs: Additional arguments for the decoder |
| | |
| | Returns: |
| | List of decoded strings |
| | """ |
| | return self.tokenizer.batch_decode( |
| | generated_outputs, |
| | skip_special_tokens=skip_special_tokens, |
| | **kwargs, |
| | ) |
| |
|
| | @property |
| | def model_input_names(self) -> List[str]: |
| | """ |
| | Get the input names expected by the model. |
| | |
| | Returns: |
| | List of input names |
| | """ |
| | tokenizer_input_names = self.tokenizer.model_input_names |
| | dna_input_names = ["dna_tokenized", "batch_idx_map"] |
| | |
| | return list(dict.fromkeys(tokenizer_input_names + dna_input_names)) |
| |
|