| from transformers import ( |
| AutoProcessor, |
| AutoTokenizer, |
| ) |
| from typing import Dict, Any, Union, List, Optional, Callable, Type |
| from trl.data_utils import maybe_apply_chat_template |
| import torch |
|
|
| from bioreason.dna_modules.dna_module import DNABaseModule |
| from model.blip2_stage2 import Blip2Stage2 |
|
|
|
|
| class Blip2DNAModule(DNABaseModule): |
| """ |
| DNA module implementation for BLIP2-based models. |
| |
| This module provides the interface between BLIP2 models and the GRPO training |
| infrastructure, handling model loading, processing setup, and reward functions. |
| """ |
|
|
| def __init__(self): |
| """Initialize the Blip2DNAModule.""" |
| super().__init__() |
|
|
| def get_dnallm_key(self) -> str: |
| """ |
| Get the key identifier for this DNA-LLM implementation. |
| |
| Returns: |
| String identifier for this module type |
| """ |
| return "blip2" |
|
|
| def get_model_class(self, model_id: str, model_init_kwargs: Dict[str, Any]) -> Type: |
| """ |
| Return the appropriate model class based on model ID. |
| |
| Args: |
| model_id: Identifier for the model |
| model_init_kwargs: Initialization arguments for the model |
| |
| Returns: |
| The model class to instantiate |
| |
| Raises: |
| ValueError: If the model is not supported |
| """ |
| if "blip2" in model_id.lower() or "stage2" in model_id.lower(): |
| model_cls = Blip2Stage2 |
| else: |
| raise ValueError(f"Unsupported model: {model_id}") |
| return model_cls |
|
|
| def post_model_init(self, model: Any, processing_class: Any) -> None: |
| """ |
| Perform any post-initialization setup on the model. |
| |
| Args: |
| model: The initialized model |
| processing_class: The processor for the model |
| """ |
| |
| if hasattr(model, 'blip2') and hasattr(model.blip2, 'llm_tokenizer'): |
| |
| if not hasattr(model.blip2.llm_tokenizer, 'pad_token') or model.blip2.llm_tokenizer.pad_token is None: |
| model.blip2.llm_tokenizer.pad_token = model.blip2.llm_tokenizer.eos_token |
|
|
| def get_processing_class(self) -> Type: |
| """ |
| Get the processing class to use with this BLIP2 model. |
| |
| Returns: |
| The processing class |
| """ |
| return Blip2Processor |
|
|
| def get_dnallm_modules_keywords(self) -> List[str]: |
| """ |
| Get keywords to identify DNA-specific modules in the model. |
| |
| Used to exclude DNA modules from LoRA adaptation during training. |
| |
| Returns: |
| List of keywords that identify DNA modules |
| """ |
| return ["plm", "qformer", "opt_proj"] |
|
|
| def get_custom_multimodal_keywords(self) -> List[str]: |
| """ |
| Get keywords for multimodal inputs that should be passed to the model. |
| |
| Returns: |
| List of input keywords for multimodal processing |
| """ |
| return ["prot_batch", "prompt_batch"] |
|
|
| def get_non_generate_params(self) -> List[str]: |
| """ |
| Get parameter names that should be excluded from generation. |
| |
| Returns: |
| List of parameter names to exclude from generation calls |
| """ |
| return ["prot_batch"] |
|
|
| def get_custom_processing_keywords(self) -> List[tuple]: |
| """ |
| Get custom processing keywords for the processor. |
| |
| Returns: |
| List of (component, parameter) tuples for custom processing |
| """ |
| return [("plm_tokenizer", "max_length"), ("llm_tokenizer", "max_length")] |
|
|
| def prepare_prompt( |
| self, processing_class: Any, inputs: List[Dict[str, Union[torch.Tensor, Any]]] |
| ) -> List[str]: |
| """ |
| Prepare prompts from input examples. |
| |
| Args: |
| processing_class: The processor to use |
| inputs: List of input examples |
| |
| Returns: |
| List of prepared prompts |
| """ |
| prompts_text = [] |
| for example in inputs: |
| if "prompt" in example: |
| |
| if isinstance(example["prompt"], list) and len(example["prompt"]) > 0: |
| user_content = example["prompt"][0].get("content", "") |
| if isinstance(user_content, list): |
| |
| text_parts = [item.get("text", "") for item in user_content if item.get("type") == "text"] |
| prompt_text = " ".join(text_parts) |
| else: |
| prompt_text = str(user_content) |
| else: |
| prompt_text = str(example["prompt"]) |
| else: |
| prompt_text = "" |
| prompts_text.append(prompt_text) |
| return prompts_text |
|
|
| def prepare_model_inputs( |
| self, |
| processing_class: Any, |
| model: Any, |
| prompts_text: List[str], |
| batch_dna_sequences: List[List[str]], |
| return_tensors: str = "pt", |
| padding: bool = True, |
| padding_side: str = "left", |
| add_special_tokens: bool = False, |
| ) -> Dict[str, Any]: |
| """ |
| Prepare inputs for the BLIP2 model. |
| |
| Args: |
| processing_class: The processor to use |
| model: The model to prepare inputs for |
| prompts_text: List of text prompts |
| batch_dna_sequences: List of lists of DNA sequences (treated as protein sequences) |
| return_tensors: Return format for tensors |
| padding: Whether to pad inputs |
| padding_side: Side to pad on |
| add_special_tokens: Whether to add special tokens |
| |
| Returns: |
| Processed inputs for the model |
| """ |
| |
| blip2_model = model.blip2 if hasattr(model, 'blip2') else model |
| |
| |
| |
| all_sequences = [] |
| for sequences in batch_dna_sequences: |
| all_sequences.extend(sequences) |
| |
| if all_sequences: |
| prot_batch = blip2_model.plm_tokenizer( |
| all_sequences, |
| padding=padding, |
| truncation=True, |
| max_length=512, |
| return_tensors=return_tensors, |
| ) |
| else: |
| |
| prot_batch = { |
| 'input_ids': torch.empty(0, 1, dtype=torch.long), |
| 'attention_mask': torch.empty(0, 1, dtype=torch.long) |
| } |
|
|
| |
| prompt_batch = blip2_model.llm_tokenizer( |
| prompts_text, |
| padding=padding, |
| truncation=True, |
| max_length=256, |
| return_tensors=return_tensors, |
| ) |
|
|
| return { |
| "prot_batch": prot_batch, |
| "prompt_batch": prompt_batch, |
| "input_ids": prompt_batch["input_ids"], |
| "attention_mask": prompt_batch["attention_mask"], |
| } |
|
|
| def is_embeds_input(self) -> bool: |
| """ |
| Whether the model uses embeddings as input (instead of token IDs). |
| |
| Returns: |
| Boolean indicating if the model takes embedding inputs |
| """ |
| return True |
|
|
| @staticmethod |
| def get_question_template() -> str: |
| """ |
| Get the template for formatting questions. |
| |
| Returns: |
| String template for questions |
| """ |
| return "{Question}" |
|
|
| @staticmethod |
| def format_reward_rec(completions: List[Dict[str, Any]], **kwargs) -> List[float]: |
| """ |
| Check if the BLIP2 model output matches a specific format. |
| |
| Args: |
| completions: List of model completions |
| **kwargs: Additional arguments |
| |
| Returns: |
| List of reward scores (1.0 for match, 0.0 for no match) |
| """ |
| import re |
| import os |
| from datetime import datetime |
|
|
| |
| pattern = r"<think>.*?</think>\s*<answer>.*?\{.*\[\d+,\s*\d+,\s*\d+,\s*\d+\].*\}.*?</answer>" |
| completion_contents = [completion[0]["content"] for completion in completions] |
| matches = [ |
| re.search(pattern, content, re.DOTALL) is not None |
| for content in completion_contents |
| ] |
|
|
| |
| current_time = datetime.now().strftime("%d-%H-%M-%S-%f") |
| if os.getenv("DEBUG_MODE") == "true": |
| log_path = os.getenv("LOG_PATH") |
| with open( |
| log_path.replace(".txt", "_format.txt"), "a", encoding="utf-8" |
| ) as f: |
| f.write(f"------------- {current_time} Format reward -------------\n") |
| for content, match in zip(completion_contents, matches): |
| f.write(f"Content: {content}\n") |
| f.write(f"Has format: {bool(match)}\n") |
|
|
| return [1.0 if match else 0.0 for match in matches] |
|
|
| @staticmethod |
| def select_reward_func(func: str, task_type: str) -> Callable: |
| """ |
| Select the appropriate reward function based on function name and task type. |
| |
| Args: |
| func: The type of reward function ('accuracy', 'format', etc.) |
| task_type: The type of task ('rec', etc.) |
| |
| Returns: |
| The reward function to use |
| |
| Raises: |
| ValueError: If the function or task type is not supported |
| """ |
| if func == "accuracy": |
| match task_type: |
| case "rec": |
| return Blip2DNAModule.iou_reward |
| case _: |
| raise ValueError(f"Unsupported reward function: {func}") |
| elif func == "format": |
| match task_type: |
| case "rec": |
| return Blip2DNAModule.format_reward_rec |
| case _: |
| raise ValueError(f"Unsupported reward function: {func}") |
| else: |
| raise ValueError(f"Unsupported reward function: {func}") |
|
|
| @staticmethod |
| def iou_reward(completions: List[Dict[str, Any]], **kwargs) -> List[float]: |
| """ |
| Placeholder IoU reward function. |
| |
| Args: |
| completions: List of model completions |
| **kwargs: Additional arguments |
| |
| Returns: |
| List of reward scores |
| """ |
| |
| return [1.0] * len(completions) |
|
|
|
|
class Blip2Processor:
    """
    Simple processor wrapper for BLIP2 models to maintain compatibility
    with the GRPO trainer interface.

    Calls and decoding are delegated to ``llm_tokenizer`` when one is provided.
    Otherwise placeholder outputs are returned.
    """

    def __init__(self, plm_tokenizer=None, llm_tokenizer=None):
        """
        Args:
            plm_tokenizer: Tokenizer for the sequence (protein/DNA) encoder; stored only.
            llm_tokenizer: Tokenizer for the language model; used for calls and decoding.
        """
        self.plm_tokenizer = plm_tokenizer
        self.llm_tokenizer = llm_tokenizer

        # Always define these so callers never hit AttributeError; they are
        # snapshots of the tokenizer's ids at construction time (None without one).
        self.eos_token_id = None
        self.pad_token_id = None
        if llm_tokenizer is not None:
            self.eos_token_id = getattr(llm_tokenizer, "eos_token_id", None)
            self.pad_token_id = getattr(llm_tokenizer, "pad_token_id", None)

    def __call__(self, *args, **kwargs):
        """
        Process inputs for BLIP2 model.

        Delegates to ``llm_tokenizer``. Without a tokenizer, a fixed dummy batch
        ``{"input_ids": [[1]], "attention_mask": [[1]]}`` (as tensors) is returned.
        """
        if self.llm_tokenizer is not None:
            return self.llm_tokenizer(*args, **kwargs)
        return {"input_ids": torch.tensor([[1]]), "attention_mask": torch.tensor([[1]])}

    def batch_decode(self, *args, **kwargs):
        """
        Decode token sequences.

        Delegates to ``llm_tokenizer.batch_decode``. Without a tokenizer, one empty
        string is returned per input sequence (or ``[""]`` when no sequences are
        given), so the result stays aligned with the input.
        """
        if self.llm_tokenizer is not None:
            return self.llm_tokenizer.batch_decode(*args, **kwargs)
        sequences = args[0] if args else kwargs.get("sequences")
        if sequences is None:
            return [""]
        return [""] * len(sequences)