| | from typing import List, Optional, Union, Dict, Any, Tuple |
| |
|
| | import torch |
| | from torch import nn |
| | import torch.nn.functional as F |
| |
|
| | from transformers import AutoTokenizer |
| | from transformers.processing_utils import ( |
| | CommonKwargs, |
| | ProcessingKwargs, |
| | ProcessorMixin, |
| | Unpack, |
| | ) |
| | from transformers.feature_extraction_utils import BatchFeature |
| | from transformers.tokenization_utils_base import PreTokenizedInput, TextInput |
| | from transformers.utils import logging |
| |
|
| | from bioreason.utils.protein_utils import ProteinInput |
| |
|
class ProteinLLMKwargs(CommonKwargs):
    """Keyword arguments specific to protein processing.

    Declares the protein-related length caps accepted by the processor:
    ``max_length_text`` bounds tokenized text length and
    ``max_length_protein`` bounds tokenized protein-sequence length.
    """

    max_length_text: Optional[int]
    max_length_protein: Optional[int]
|
| |
|
class ProteinLLMProcessorKwargs(ProcessingKwargs, total=False):
    """Processing keyword arguments for the ProteinLLM processor.

    Extends the standard ``ProcessingKwargs`` with a ``protein_kwargs``
    group and declares the default text-tokenizer options (padding off by
    default; the processor itself decides when to pad).
    """

    protein_kwargs: ProteinLLMKwargs

    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
    }
| |
|
class ProteinLLMProcessor(ProcessorMixin):
    r"""
    Constructs a ProteinLLM processor which wraps a ESM2 protein processor and a Qwen tokenizer into a single processor.
    This processor handles both text and protein sequence processing to prepare inputs for the ProteinLLMModel.

    Args:
        tokenizer (PreTrainedTokenizerBase, *optional*):
            The text tokenizer used for processing text inputs.
        protein_tokenizer (PreTrainedTokenizerBase, *optional*):
            The protein tokenizer used for processing protein sequences.
        chat_template (`str`, *optional*):
            A Jinja template for chat formatting. If None, will use the tokenizer's template.
    """

    attributes = ["tokenizer", "protein_tokenizer"]
    valid_kwargs = ["model", "chat_template"]
    # Accepted tokenizer class names, validated by ProcessorMixin machinery.
    tokenizer_class = (
        "Qwen2Tokenizer",
        "Qwen2TokenizerFast",
        "GPT2TokenizerFast",
    )
    protein_tokenizer_class = ("EsmTokenizer",)

    def __init__(
        self, tokenizer=None, protein_tokenizer=None, chat_template=None, **kwargs
    ):
        """
        Initialize the processor with text and protein tokenizers.

        Args:
            tokenizer: Text tokenizer (usually from a language model)
            protein_tokenizer: Protein tokenizer (usually from ESM2)
            chat_template: Template for formatting chat conversations
            **kwargs: Additional arguments
        """
        self.tokenizer = tokenizer
        self.protein_tokenizer = protein_tokenizer

        # Placeholder token in the text that marks where protein embeddings go.
        # Prefer the tokenizer's own attribute when it defines one; otherwise
        # fall back to the default marker. (getattr handles tokenizer=None too.)
        self.protein_token = getattr(self.tokenizer, "protein_token", "<|protein_pad|>")

        if chat_template is None and hasattr(self.tokenizer, "chat_template"):
            chat_template = self.tokenizer.chat_template
        super().__init__(tokenizer, protein_tokenizer, chat_template=chat_template)

        # Some tokenizers (e.g. GPT-2 style) ship without a pad token, but this
        # processor always pads text batches — reuse EOS as the pad token.
        if getattr(self.tokenizer, "pad_token", None) is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def tokenize_protein_sequences(
        self,
        batch_protein_sequences: List[List[str]],
        max_length: int = 1024,
        return_tensors: str = "pt",
        device: str = "cuda",
    ) -> Dict[str, Any]:
        """
        Tokenize a batch of protein sequences.

        Args:
            batch_protein_sequences: List of lists of protein sequences per batch item
            max_length: Maximum allowed length for protein sequences
            return_tensors: Return format for tensors ("pt" for PyTorch)
            device: Device to place tensors on

        Returns:
            Dict containing:
                - protein_tokenized: The tokenized protein sequences, padded to a
                  common length; ``None`` when no sequences were supplied
                - batch_idx_map: Mapping of which flattened sequence belongs to
                  which batch item
        """
        batch_idx_map: List[int] = []
        all_sequences: List[str] = []

        # Flatten the per-item sequence lists, remembering each sequence's
        # originating batch index so embeddings can be routed back later.
        for batch_idx, protein_sequences in enumerate(batch_protein_sequences):
            for seq in protein_sequences:
                all_sequences.append(seq)
                batch_idx_map.append(batch_idx)

        # Nothing to tokenize: signal the absence explicitly rather than
        # calling the tokenizer with an empty list.
        if not all_sequences:
            return {"protein_tokenized": None, "batch_idx_map": []}

        protein_tokenized = self.protein_tokenizer(
            all_sequences,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors=return_tensors,
            return_attention_mask=True,
        )

        # Move tensors to the requested device (only meaningful for PyTorch).
        if return_tensors == "pt" and device is not None:
            protein_tokenized = {
                k: v.to(device) if isinstance(v, torch.Tensor) else v
                for k, v in protein_tokenized.items()
            }

        return {"protein_tokenized": protein_tokenized, "batch_idx_map": batch_idx_map}

    def __call__(
        self,
        batch_protein_sequences: Optional[List[List[str]]] = None,
        text: Optional[
            Union[
                TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
            ]
        ] = None,
        max_length_text: int = 512,
        max_length_protein: int = 1024,
        return_tensors: str = "pt",
        device: str = "cuda",
        **kwargs: Unpack[ProteinLLMProcessorKwargs],
    ) -> BatchFeature:
        """
        Process text and protein sequences for model input.

        Each occurrence of the protein placeholder token in the text is expanded
        to one placeholder per non-padding protein token, so the model can later
        splice one protein embedding into each placeholder position.

        Args:
            batch_protein_sequences: List of lists of protein sequences per batch item
            text: Input text or list of texts
            max_length_text: Maximum length for text sequences
            max_length_protein: Maximum length for protein sequences
            return_tensors: Return format for tensors
            device: Device to place tensors on
            **kwargs: Additional processor keyword arguments

        Returns:
            BatchFeature with tokenized inputs for the model
        """
        output_kwargs = self._merge_kwargs(
            ProteinLLMProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        # Normalize to a list, and copy so the caller's list is never mutated
        # in place by the placeholder expansion below.
        if not isinstance(text, list):
            text = [text]
        else:
            text = list(text)

        protein_inputs: Dict[str, Any] = {}
        if batch_protein_sequences is not None:
            protein_processing_result = self.tokenize_protein_sequences(
                batch_protein_sequences,
                max_length=max_length_protein,
                return_tensors=return_tensors,
                device=device,
            )
            protein_tokenized = protein_processing_result["protein_tokenized"]

            # Expand each protein token to as many placeholders as the sequence
            # has real (non-pad) protein tokens, consuming flattened sequences
            # in order. Guarded against protein_tokenized being None (all inner
            # lists empty) — the original code would subscript None and crash.
            if protein_tokenized is not None:
                pad_id = self.protein_tokenizer.pad_token_id
                index = 0
                for i in range(len(text)):
                    while self.protein_token in text[i]:
                        num_protein_tokens = (
                            (protein_tokenized["input_ids"][index] != pad_id)
                            .sum()
                            .item()
                        )
                        # Use a temporary marker so the freshly inserted copies
                        # don't themselves match on the next loop iteration.
                        text[i] = text[i].replace(
                            self.protein_token,
                            "<|placeholder|>" * num_protein_tokens,
                            1,
                        )
                        index += 1
                    text[i] = text[i].replace("<|placeholder|>", self.protein_token)

            protein_inputs = {
                "protein_tokenized": protein_tokenized,
                "batch_idx_map": protein_processing_result["batch_idx_map"],
            }

        # Copy before mutating so the merged kwargs dict is left untouched;
        # padding is forced on below regardless of the caller's setting.
        text_kwargs = dict(output_kwargs.get("text_kwargs", {}))
        text_kwargs.pop("padding", None)

        # The budget allows the text plus (generously) two expanded proteins.
        text_inputs = self.tokenizer(
            text,
            max_length=max_length_text + 2 * max_length_protein,
            return_tensors=return_tensors,
            padding=True,
            truncation=True,
            **text_kwargs,
        )

        if return_tensors == "pt" and device is not None:
            text_inputs = {
                k: v.to(device) if isinstance(v, torch.Tensor) else v
                for k, v in text_inputs.items()
            }

        return BatchFeature(data={**text_inputs, **protein_inputs})

    def batch_decode(self, *args, **kwargs) -> List[str]:
        """
        This method forwards all its arguments to the tokenizer's batch_decode.

        Returns:
            List of decoded strings
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs) -> str:
        """
        This method forwards all its arguments to the tokenizer's decode.

        Returns:
            Decoded string
        """
        return self.tokenizer.decode(*args, **kwargs)

    def post_process_protein_to_text(
        self,
        generated_outputs: torch.Tensor,
        skip_special_tokens: bool = True,
        **kwargs,
    ) -> List[str]:
        """
        Post-process the model output to decode the text.

        Args:
            generated_outputs: The token IDs generated by the model
            skip_special_tokens: Whether to skip special tokens in the output
            **kwargs: Additional arguments for the decoder

        Returns:
            List of decoded strings
        """
        return self.tokenizer.batch_decode(
            generated_outputs,
            skip_special_tokens=skip_special_tokens,
            **kwargs,
        )

    @property
    def model_input_names(self) -> List[str]:
        """
        Get the input names expected by the model.

        Returns:
            List of input names (text tokenizer names plus protein inputs),
            de-duplicated while preserving order.
        """
        tokenizer_input_names = self.tokenizer.model_input_names
        protein_input_names = ["protein_tokenized", "batch_idx_map"]

        return list(dict.fromkeys(tokenizer_input_names + protein_input_names))