"""BLIP2 adapter for the DNA/GRPO training infrastructure.

Defines ``Blip2DNAModule``, which tells the trainer how to load, prompt and
reward a BLIP2 stage-2 model, and ``Blip2Processor``, a thin wrapper that
exposes the model's tokenizers through a processor-like interface.
"""

import os
import re
from datetime import datetime
from typing import Dict, Any, Union, List, Optional, Callable, Type

import torch
from transformers import (
    AutoProcessor,
    AutoTokenizer,
)
from trl.data_utils import maybe_apply_chat_template

from bioreason.dna_modules.dna_module import DNABaseModule
from model.blip2_stage2 import Blip2Stage2


# Expected completion format: a brace-delimited span containing a bracketed
# list of four integers. Compiled once so reward calls do not re-parse it.
_REC_FORMAT_PATTERN = re.compile(
    r".*?\s*.*?\{.*\[\d+,\s*\d+,\s*\d+,\s*\d+\].*\}.*?", re.DOTALL
)


class Blip2DNAModule(DNABaseModule):
    """
    DNA module implementation for BLIP2-based models.

    This module provides the interface between BLIP2 models and the GRPO
    training infrastructure, handling model loading, processing setup, and
    reward functions.
    """

    def __init__(self):
        """Initialize the Blip2DNAModule."""
        super().__init__()

    def get_dnallm_key(self) -> str:
        """
        Get the key identifier for this DNA-LLM implementation.

        Returns:
            String identifier for this module type
        """
        return "blip2"

    def get_model_class(self, model_id: str, model_init_kwargs: Dict[str, Any]) -> Type:
        """
        Return the appropriate model class based on model ID.

        The match is a case-insensitive substring test on ``model_id``.

        Args:
            model_id: Identifier for the model
            model_init_kwargs: Initialization arguments for the model
                (not inspected by this implementation)

        Returns:
            The model class to instantiate

        Raises:
            ValueError: If the model is not supported
        """
        lowered = model_id.lower()
        if "blip2" in lowered or "stage2" in lowered:
            return Blip2Stage2
        raise ValueError(f"Unsupported model: {model_id}")

    def post_model_init(self, model: Any, processing_class: Any) -> None:
        """
        Perform any post-initialization setup on the model.

        If the model exposes ``model.blip2.llm_tokenizer`` and that tokenizer
        has no pad token, the EOS token is reused as the pad token.

        Args:
            model: The initialized model
            processing_class: The processor for the model (unused here)
        """
        blip2 = getattr(model, "blip2", None)
        if blip2 is None or not hasattr(blip2, "llm_tokenizer"):
            return
        tokenizer = blip2.llm_tokenizer
        if getattr(tokenizer, "pad_token", None) is None:
            tokenizer.pad_token = tokenizer.eos_token

    def get_processing_class(self) -> Type:
        """
        Get the processing class to use with this BLIP2 model.

        Returns:
            The processing class
        """
        return Blip2Processor

    def get_dnallm_modules_keywords(self) -> List[str]:
        """
        Get keywords to identify DNA-specific modules in the model.

        Used to exclude DNA modules from LoRA adaptation during training.

        Returns:
            List of keywords that identify DNA modules
        """
        return ["plm", "qformer", "opt_proj"]

    def get_custom_multimodal_keywords(self) -> List[str]:
        """
        Get keywords for multimodal inputs that should be passed to the model.

        Returns:
            List of input keywords for multimodal processing
        """
        return ["prot_batch", "prompt_batch"]

    def get_non_generate_params(self) -> List[str]:
        """
        Get parameter names that should be excluded from generation.

        Returns:
            List of parameter names to exclude from generation calls
        """
        return ["prot_batch"]

    def get_custom_processing_keywords(self) -> List[tuple]:
        """
        Get custom processing keywords for the processor.

        Returns:
            List of (component, parameter) tuples for custom processing
        """
        return [("plm_tokenizer", "max_length"), ("llm_tokenizer", "max_length")]

    @staticmethod
    def _extract_prompt_text(example: Dict[str, Any]) -> str:
        """Return the plain-text prompt for a single input example."""
        if "prompt" not in example:
            return ""

        prompt = example["prompt"]
        if not (isinstance(prompt, list) and prompt):
            # Non-conversational (or empty-list) prompts are stringified as-is.
            return str(prompt)

        # Conversational format: only the first message is consulted.
        content = prompt[0].get("content", "")
        if isinstance(content, list):
            # Multimodal content: keep the text parts, joined by spaces.
            return " ".join(
                part.get("text", "") for part in content if part.get("type") == "text"
            )
        return str(content)

    def prepare_prompt(
        self,
        processing_class: Any,
        inputs: List[Dict[str, Union[torch.Tensor, Any]]],
    ) -> List[str]:
        """
        Prepare prompts from input examples.

        For each example the text is taken from ``example["prompt"]``: the
        first message's ``content`` for conversational prompts (text parts
        joined with spaces when the content is a list), ``str(prompt)``
        otherwise, and an empty string when no ``"prompt"`` key is present.

        Args:
            processing_class: The processor to use (unused here)
            inputs: List of input examples

        Returns:
            List of prepared prompts, one per input example
        """
        return [self._extract_prompt_text(example) for example in inputs]

    def prepare_model_inputs(
        self,
        processing_class: Any,
        model: Any,
        prompts_text: List[str],
        batch_dna_sequences: List[List[str]],
        return_tensors: str = "pt",
        padding: bool = True,
        padding_side: str = "left",
        add_special_tokens: bool = False,
        max_protein_length: int = 512,
        max_prompt_length: int = 256,
    ) -> Dict[str, Any]:
        """
        Prepare inputs for the BLIP2 model.

        Args:
            processing_class: The processor to use (unused; the tokenizers
                are taken from the model itself)
            model: The model to prepare inputs for; ``model.blip2`` is used
                when present, otherwise ``model`` itself
            prompts_text: List of text prompts
            batch_dna_sequences: List of lists of DNA sequences (treated as
                protein sequences); flattened into a single batch
            return_tensors: Return format for tensors
            padding: Whether to pad inputs
            padding_side: Accepted for interface compatibility; not
                forwarded to the tokenizers by this implementation
            add_special_tokens: Accepted for interface compatibility; not
                forwarded to the tokenizers by this implementation
            max_protein_length: Truncation length for protein sequences
            max_prompt_length: Truncation length for text prompts

        Returns:
            Dict with ``prot_batch``, ``prompt_batch`` and, for
            compatibility, the prompt's ``input_ids`` and ``attention_mask``
        """
        # Get the BLIP2 model from the wrapper
        blip2_model = model.blip2 if hasattr(model, "blip2") else model

        # Flatten all DNA sequences to treat them as individual protein sequences
        all_sequences = [
            sequence for sequences in batch_dna_sequences for sequence in sequences
        ]

        if all_sequences:
            prot_batch = blip2_model.plm_tokenizer(
                all_sequences,
                padding=padding,
                truncation=True,
                max_length=max_protein_length,
                return_tensors=return_tensors,
            )
        else:
            # Empty batch handling: zero-row placeholders instead of
            # invoking the tokenizer on an empty list.
            prot_batch = {
                "input_ids": torch.empty(0, 1, dtype=torch.long),
                "attention_mask": torch.empty(0, 1, dtype=torch.long),
            }

        prompt_batch = blip2_model.llm_tokenizer(
            prompts_text,
            padding=padding,
            truncation=True,
            max_length=max_prompt_length,
            return_tensors=return_tensors,
        )

        return {
            "prot_batch": prot_batch,
            "prompt_batch": prompt_batch,
            "input_ids": prompt_batch["input_ids"],  # For compatibility
            "attention_mask": prompt_batch["attention_mask"],  # For compatibility
        }

    def is_embeds_input(self) -> bool:
        """
        Whether the model uses embeddings as input (instead of token IDs).

        Returns:
            Boolean indicating if the model takes embedding inputs
        """
        return True  # BLIP2 uses embeddings internally

    @staticmethod
    def get_question_template() -> str:
        """
        Get the template for formatting questions.

        Returns:
            String template for questions
        """
        return "{Question}"

    @staticmethod
    def format_reward_rec(completions: List[List[Dict[str, Any]]], **kwargs) -> List[float]:
        """
        Check if the BLIP2 model output matches a specific format.

        When the ``DEBUG_MODE`` environment variable equals ``"true"`` and
        ``LOG_PATH`` is set, the per-completion results are appended to the
        log file named by ``LOG_PATH`` with ``.txt`` replaced by
        ``_format.txt``. If ``LOG_PATH`` is unset, logging is skipped rather
        than failing the reward computation.

        Args:
            completions: List of model completions; the ``"content"`` of the
                first message of each completion is checked
            **kwargs: Additional arguments (ignored)

        Returns:
            List of reward scores (1.0 for match, 0.0 for no match)
        """
        completion_contents = [completion[0]["content"] for completion in completions]
        matches = [
            _REC_FORMAT_PATTERN.search(content) is not None
            for content in completion_contents
        ]

        # Log format results if in debug mode
        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            if log_path:
                current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
                with open(
                    log_path.replace(".txt", "_format.txt"), "a", encoding="utf-8"
                ) as f:
                    f.write(f"------------- {current_time} Format reward -------------\n")
                    for content, match in zip(completion_contents, matches):
                        f.write(f"Content: {content}\n")
                        f.write(f"Has format: {bool(match)}\n")

        return [1.0 if match else 0.0 for match in matches]

    @staticmethod
    def select_reward_func(func: str, task_type: str) -> Callable:
        """
        Select the appropriate reward function based on function name and task type.

        Args:
            func: The type of reward function ('accuracy' or 'format')
            task_type: The type of task ('rec')

        Returns:
            The reward function to use

        Raises:
            ValueError: If the function or task type is not supported
        """
        # Built at call time so the class attributes are already bound.
        reward_funcs = {
            "accuracy": {"rec": Blip2DNAModule.iou_reward},
            "format": {"rec": Blip2DNAModule.format_reward_rec},
        }

        if func not in reward_funcs:
            raise ValueError(f"Unsupported reward function: {func}")

        funcs_by_task = reward_funcs[func]
        if task_type not in funcs_by_task:
            raise ValueError(
                f"Unsupported task type for reward function {func}: {task_type}"
            )
        return funcs_by_task[task_type]

    @staticmethod
    def iou_reward(completions: List[Dict[str, Any]], **kwargs) -> List[float]:
        """
        Placeholder IoU reward function.

        No IoU is computed yet: every completion receives a score of 1.0.

        Args:
            completions: List of model completions
            **kwargs: Additional arguments (ignored)

        Returns:
            List of reward scores, one per completion
        """
        # Placeholder implementation
        return [1.0] * len(completions)


class Blip2Processor:
    """
    Simple processor wrapper for BLIP2 models to maintain compatibility
    with the GRPO trainer interface.
    """

    def __init__(self, plm_tokenizer=None, llm_tokenizer=None):
        """
        Store the tokenizers and mirror the LLM tokenizer's special-token ids.

        Args:
            plm_tokenizer: Tokenizer for protein sequences (may be None)
            llm_tokenizer: Tokenizer for text (may be None)
        """
        self.plm_tokenizer = plm_tokenizer
        self.llm_tokenizer = llm_tokenizer

        # Compatibility attributes. Always defined, so attribute access does
        # not fail when the processor is built without an LLM tokenizer.
        self.eos_token_id = None
        self.pad_token_id = None
        if llm_tokenizer is not None:
            self.eos_token_id = llm_tokenizer.eos_token_id
            self.pad_token_id = llm_tokenizer.pad_token_id

    def __call__(self, *args, **kwargs):
        """
        Process inputs for BLIP2 model.

        Delegates to the LLM tokenizer when one is set; otherwise returns a
        fixed single-token placeholder encoding.
        """
        if self.llm_tokenizer is not None:
            return self.llm_tokenizer(*args, **kwargs)
        # Fallback behavior
        return {"input_ids": torch.tensor([[1]]), "attention_mask": torch.tensor([[1]])}

    def batch_decode(self, *args, **kwargs):
        """Decode token sequences via the LLM tokenizer, or return ``[""]`` if none is set."""
        if self.llm_tokenizer is not None:
            return self.llm_tokenizer.batch_decode(*args, **kwargs)
        return [""]