# bioreason/models/pl/processing_pl.py
# (Hugging Face Hub upload artifact — "Add files using upload-large-folder tool",
# commit ffcfc75, uploaded by yuccaaa)
from typing import List, Optional, Union, Dict, Any, Tuple
import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoTokenizer
from transformers.processing_utils import (
CommonKwargs,
ProcessingKwargs,
ProcessorMixin,
Unpack,
)
from transformers.feature_extraction_utils import BatchFeature
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging
from bioreason.utils.protein_utils import ProteinInput
class ProteinLLMProteinKwargs(CommonKwargs):
    """Keyword arguments specific to protein sequence processing"""

    # Truncation bound (in tokens) for the text tokenizer.
    max_length_text: Optional[int]
    # Truncation bound (in tokens) for the ESM protein tokenizer.
    max_length_protein: Optional[int]
class ProteinLLMProcessorKwargs(ProcessingKwargs, total=False):
    """Processing keyword arguments for the ProteinLLM processor"""

    # Nested protein-specific options (see ProteinLLMProteinKwargs).
    protein_kwargs: ProteinLLMProteinKwargs
    # Defaults merged by ProcessorMixin._merge_kwargs. Padding defaults to
    # False here because the processor's __call__ handles padding explicitly
    # when it invokes the text tokenizer.
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
    }
class ProteinLLMProcessor(ProcessorMixin):
    r"""
    Constructs a ProteinLLM processor which wraps an ESM protein tokenizer and a language model tokenizer into a single processor.
    This processor handles both text and protein sequence processing to prepare inputs for the ProteinLLMModel.

    Args:
        tokenizer (PreTrainedTokenizerBase, *optional*):
            The text tokenizer used for processing text inputs.
        protein_tokenizer (PreTrainedTokenizerBase, *optional*):
            The protein tokenizer (ESM) used for processing protein sequences.
        chat_template (`str`, *optional*):
            A Jinja template for chat formatting. If None, will use the tokenizer's template.
    """

    # ProcessorMixin configuration: which attributes are saved/loaded with the
    # processor, and the tokenizer classes accepted for each attribute slot.
    attributes = ["tokenizer", "protein_tokenizer"]
    valid_kwargs = ["model", "chat_template"]
    tokenizer_class = (
        "Qwen2Tokenizer", "Qwen2TokenizerFast",
        "GPT2TokenizerFast", "LlamaTokenizer", "LlamaTokenizerFast",
    )
    protein_tokenizer_class = ("EsmTokenizer",)

    def __init__(
        self, tokenizer=None, protein_tokenizer=None, chat_template=None, **kwargs
    ):
        """
        Initialize the processor with text and protein tokenizers.

        Args:
            tokenizer: Text tokenizer (usually from a language model)
            protein_tokenizer: Protein tokenizer (usually ESM tokenizer)
            chat_template: Template for formatting chat conversations
            **kwargs: Additional arguments
        """
        # NOTE(review): attributes are assigned before super().__init__ because
        # self.tokenizer is read below; ProcessorMixin re-assigns them during
        # its own __init__. Assumes `tokenizer` is not None — confirm callers.
        self.tokenizer = tokenizer
        self.protein_tokenizer = protein_tokenizer
        # Placeholder token marking where protein representations are spliced
        # into the text; falls back to "<|protein_pad|>" when the text
        # tokenizer does not define its own `protein_token` attribute.
        self.protein_token = (
            "<|protein_pad|>"
            if not hasattr(self.tokenizer, "protein_token")
            else self.tokenizer.protein_token
        )
        # Get chat template from tokenizer if not provided
        if chat_template is None and hasattr(self.tokenizer, "chat_template"):
            chat_template = self.tokenizer.chat_template
        super().__init__(tokenizer, protein_tokenizer, chat_template=chat_template)
        # The GRPO trainer might expect this to be set
        if not hasattr(self.tokenizer, 'pad_token') or self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def tokenize_protein_sequences(
        self,
        batch_protein_sequences: List[List[str]],
        max_length: int = 1024,
        return_tensors: str = "pt",
        device: str = "cuda",
    ) -> Dict[str, Any]:
        """
        Tokenize a batch of protein sequences.

        Args:
            batch_protein_sequences: List of lists of protein sequences per batch item
            max_length: Maximum allowed length for protein sequences
            return_tensors: Return format for tensors ("pt" for PyTorch)
            device: Device to place tensors on
                NOTE(review): `device` is currently unused — the tokenized
                tensors are returned on CPU; callers appear responsible for
                moving them. Confirm before relying on device placement.

        Returns:
            Dict containing:
                - protein_tokenized: The tokenized protein sequences
                  (a BatchEncoding with input_ids/attention_mask), or None if
                  the batch contained no sequences at all
                - batch_idx_map: Mapping of which sequences belong to which batch item
        """
        # Create a mapping to track which sequences belong to which batch item
        batch_idx_map = []
        all_sequences = []
        # Flatten all sequences with batch tracking. Flattening order matters:
        # __call__ walks protein tokens in the same order to pair text
        # placeholders with tokenized sequences.
        for batch_idx, protein_sequences in enumerate(batch_protein_sequences):
            for seq in protein_sequences:
                all_sequences.append(seq)
                batch_idx_map.append(batch_idx)
        # If no sequences in the entire batch, return empty dict
        if not all_sequences:
            return {"protein_tokenized": None, "batch_idx_map": []}
        # Tokenize all sequences at once (single padded batch) rather than
        # one call per sequence.
        protein_tokenized = self.protein_tokenizer(
            all_sequences,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors=return_tensors,
            return_attention_mask=True,
            add_special_tokens=True,
        )
        return {"protein_tokenized": protein_tokenized, "batch_idx_map": batch_idx_map}

    def __call__(
        self,
        batch_protein_sequences: Optional[List[List[str]]] = None,
        text: Optional[
            Union[
                TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
            ]
        ] = None,
        max_length_text: int = 512,
        max_length_protein: int = 1024,
        return_tensors: str = "pt",
        device: str = "cuda",
        **kwargs: Unpack[ProteinLLMProcessorKwargs],
    ) -> BatchFeature:
        """
        Process text and protein sequences for model input.

        Each occurrence of `self.protein_token` in the text is expanded into as
        many copies of that token as the corresponding protein sequence has
        (non-special) tokens, so the language model reserves one text position
        per protein token.

        Args:
            batch_protein_sequences: List of lists of protein sequences per batch item
            text: Input text or list of texts (each containing one protein
                placeholder token per protein sequence in the matching batch item)
            max_length_text: Maximum length for text sequences
            max_length_protein: Maximum length for protein sequences
            return_tensors: Return format for tensors
            device: Device to place tensors on
                NOTE(review): unused here except for being forwarded to
                tokenize_protein_sequences, which also ignores it — confirm.
            **kwargs: Additional processor keyword arguments

        Returns:
            BatchFeature with tokenized inputs for the model
        """
        # Merge caller kwargs with the class defaults and tokenizer init kwargs.
        output_kwargs = self._merge_kwargs(
            ProteinLLMProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        # Ensure text is a list
        if not isinstance(text, list):
            text = [text]
        protein_inputs = {}
        if batch_protein_sequences is not None:
            # Tokenize protein sequences
            protein_processing_result = self.tokenize_protein_sequences(
                batch_protein_sequences,
                max_length=max_length_protein,
                return_tensors=return_tensors,
                device=device,
            )
            # Replace protein tokens in text if needed. `index` walks the
            # flattened sequence list in the same order that
            # tokenize_protein_sequences produced it.
            index = 0
            for i in range(len(text)):
                while self.protein_token in text[i]:
                    # For ESM tokenizer, calculate actual tokens excluding special tokens
                    protein_token_ids = protein_processing_result['protein_tokenized']['input_ids'][index]
                    # Exclude BOS and EOS tokens (typically first and last tokens in ESM)
                    # NOTE(review): the -2 assumes exactly two special tokens
                    # (CLS/BOS + EOS) per sequence — confirm for the ESM
                    # tokenizer in use.
                    num_protein_tokens = (protein_token_ids != self.protein_tokenizer.pad_token_id).sum().item() - 2
                    num_protein_tokens = max(1, num_protein_tokens)  # Ensure at least 1 token
                    # Expand via a temporary placeholder so the newly inserted
                    # copies are not re-matched by the `while` condition.
                    text[i] = text[i].replace(
                        self.protein_token, "<|placeholder|>" * num_protein_tokens, 1
                    )
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.protein_token)
            # Add batch info to the output
            protein_inputs = {
                "protein_tokenized": protein_processing_result["protein_tokenized"],
                "batch_idx_map": protein_processing_result["batch_idx_map"],
            }
        # Tokenize text. Padding is forced on below, so drop any conflicting
        # `padding` entry that came in via the merged kwargs.
        text_kwargs = output_kwargs.get("text_kwargs", {})
        if 'padding' in text_kwargs:
            del text_kwargs['padding']
        # NOTE(review): the max_length budget adds 2 * max_length_protein on
        # top of the text budget — presumably headroom for the expanded
        # protein placeholder tokens; confirm the factor of 2 against the
        # expected number of proteins per example.
        text_inputs = self.tokenizer(
            text,
            max_length=max_length_text + 2 * max_length_protein,
            return_tensors=return_tensors,
            padding=True,
            truncation=True,
            **text_kwargs,
        )
        # The BatchFeature should have all required fields for the model's forward pass
        return BatchFeature(data={**text_inputs, **protein_inputs})

    def batch_decode(self, *args, **kwargs) -> List[str]:
        """
        This method forwards all its arguments to the tokenizer's batch_decode.

        Returns:
            List of decoded strings
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs) -> str:
        """
        This method forwards all its arguments to the tokenizer's decode.

        Returns:
            Decoded string
        """
        return self.tokenizer.decode(*args, **kwargs)

    def post_process_protein_to_text(
        self,
        generated_outputs: torch.Tensor,
        skip_special_tokens: bool = True,
        **kwargs,
    ) -> List[str]:
        """
        Post-process the model output to decode the text.

        Args:
            generated_outputs: The token IDs generated by the model
            skip_special_tokens: Whether to skip special tokens in the output
            **kwargs: Additional arguments for the decoder

        Returns:
            List of decoded strings
        """
        return self.tokenizer.batch_decode(
            generated_outputs,
            skip_special_tokens=skip_special_tokens,
            **kwargs,
        )

    @property
    def model_input_names(self) -> List[str]:
        """
        Get the input names expected by the model.

        Returns:
            List of input names (text tokenizer inputs plus the protein
            fields, de-duplicated while preserving order via dict.fromkeys)
        """
        tokenizer_input_names = self.tokenizer.model_input_names
        protein_input_names = ["protein_tokenized", "batch_idx_map"]
        return list(dict.fromkeys(tokenizer_input_names + protein_input_names))