# nas / BioReason-0813 / blip2_dna_module.py
# yuccaaa's picture
# Add files using upload-large-folder tool
# acbfbc3 verified
from transformers import (
AutoProcessor,
AutoTokenizer,
)
from typing import Dict, Any, Union, List, Optional, Callable, Type
from trl.data_utils import maybe_apply_chat_template
import torch
from bioreason.dna_modules.dna_module import DNABaseModule
from model.blip2_stage2 import Blip2Stage2
class Blip2DNAModule(DNABaseModule):
    """
    DNA module implementation for BLIP2-based models.

    This module provides the interface between BLIP2 models and the GRPO training
    infrastructure, handling model loading, processing setup, and reward functions.
    """

    # Truncation limits used by ``prepare_model_inputs``. Kept as class
    # attributes so a subclass can override them without re-implementing
    # the whole method.
    PLM_MAX_LENGTH = 512  # Default protein sequence length
    PROMPT_MAX_LENGTH = 256  # Default prompt length

    def __init__(self):
        """Initialize the Blip2DNAModule."""
        super().__init__()

    def get_dnallm_key(self) -> str:
        """
        Get the key identifier for this DNA-LLM implementation.

        Returns:
            String identifier for this module type
        """
        return "blip2"

    def get_model_class(self, model_id: str, model_init_kwargs: Dict[str, Any]) -> Type:
        """
        Return the appropriate model class based on model ID.

        The match is a case-insensitive substring test for "blip2" or "stage2".

        Args:
            model_id: Identifier for the model
            model_init_kwargs: Initialization arguments for the model (not
                inspected by this implementation)

        Returns:
            The model class to instantiate

        Raises:
            ValueError: If the model is not supported
        """
        normalized_id = model_id.lower()
        if "blip2" in normalized_id or "stage2" in normalized_id:
            return Blip2Stage2
        raise ValueError(f"Unsupported model: {model_id}")

    def post_model_init(self, model: Any, processing_class: Any) -> None:
        """
        Perform any post-initialization setup on the model.

        If the model exposes ``model.blip2.llm_tokenizer`` and that tokenizer
        has no pad token, the EOS token is reused as the pad token. Models
        without that attribute chain (or with a ``None`` tokenizer) are left
        untouched.

        Args:
            model: The initialized model
            processing_class: The processor for the model (unused here)
        """
        blip2_model = getattr(model, "blip2", None)
        tokenizer = getattr(blip2_model, "llm_tokenizer", None)
        if tokenizer is None:
            return
        if getattr(tokenizer, "pad_token", None) is None:
            tokenizer.pad_token = tokenizer.eos_token

    def get_processing_class(self) -> Type:
        """
        Get the processing class to use with this BLIP2 model.

        Returns:
            The processing class
        """
        return Blip2Processor

    def get_dnallm_modules_keywords(self) -> List[str]:
        """
        Get keywords to identify DNA-specific modules in the model.

        Used to exclude DNA modules from LoRA adaptation during training.

        Returns:
            List of keywords that identify DNA modules
        """
        return ["plm", "qformer", "opt_proj"]

    def get_custom_multimodal_keywords(self) -> List[str]:
        """
        Get keywords for multimodal inputs that should be passed to the model.

        Returns:
            List of input keywords for multimodal processing
        """
        return ["prot_batch", "prompt_batch"]

    def get_non_generate_params(self) -> List[str]:
        """
        Get parameter names that should be excluded from generation.

        Returns:
            List of parameter names to exclude from generation calls
        """
        return ["prot_batch"]

    def get_custom_processing_keywords(self) -> List[tuple]:
        """
        Get custom processing keywords for the processor.

        Returns:
            List of (component, parameter) tuples for custom processing
        """
        return [("plm_tokenizer", "max_length"), ("llm_tokenizer", "max_length")]

    @staticmethod
    def _extract_prompt_text(prompt: Any) -> str:
        """Flatten one example's ``prompt`` value into a plain string."""
        if not isinstance(prompt, list):
            return str(prompt)
        if not prompt:
            # An empty conversation carries no text.
            return ""

        # Prefer the first user turn so that a leading system message does not
        # replace the actual question; fall back to the first message when no
        # message is tagged with role "user".
        message = next(
            (m for m in prompt if isinstance(m, dict) and m.get("role") == "user"),
            prompt[0],
        )
        if not isinstance(message, dict):
            return str(message)

        content = message.get("content", "")
        if isinstance(content, list):
            # Multimodal content: keep only the parts typed as "text".
            text_parts = [
                item.get("text", "")
                for item in content
                if isinstance(item, dict) and item.get("type") == "text"
            ]
            return " ".join(text_parts)
        return str(content)

    def prepare_prompt(
        self, processing_class: Any, inputs: List[Dict[str, Union[torch.Tensor, Any]]]
    ) -> List[str]:
        """
        Prepare prompts from input examples.

        For each example the ``"prompt"`` entry is reduced to plain text:
        a conversational prompt (list of message dicts) contributes the
        content of its first user message (or of its first message when none
        has role "user"); multimodal content lists contribute their "text"
        parts joined by spaces; any other value is converted with ``str``.
        Examples without a ``"prompt"`` key yield an empty string.

        Args:
            processing_class: The processor to use (unused here)
            inputs: List of input examples

        Returns:
            List of prepared prompts, one per input example
        """
        return [
            self._extract_prompt_text(example["prompt"]) if "prompt" in example else ""
            for example in inputs
        ]

    def prepare_model_inputs(
        self,
        processing_class: Any,
        model: Any,
        prompts_text: List[str],
        batch_dna_sequences: List[List[str]],
        return_tensors: str = "pt",
        padding: bool = True,
        padding_side: str = "left",
        add_special_tokens: bool = False,
    ) -> Dict[str, Any]:
        """
        Prepare inputs for the BLIP2 model.

        Args:
            processing_class: The processor to use (unused; tokenizers are
                taken from the model itself)
            model: The model to prepare inputs for; either a wrapper exposing
                ``.blip2`` or the BLIP2 model directly
            prompts_text: List of text prompts
            batch_dna_sequences: List of lists of DNA sequences (treated as protein sequences)
            return_tensors: Return format for tensors
            padding: Whether to pad inputs
            padding_side: Side to pad the prompt batch on. Applied to the LLM
                tokenizer for the duration of the call and restored afterwards.
            add_special_tokens: Accepted for interface compatibility; it is
                not forwarded, so each tokenizer's own default applies.
                NOTE(review): confirm whether forwarding is wanted before
                changing this, as it alters the produced token ids.

        Returns:
            Dict with ``prot_batch`` and ``prompt_batch`` tokenizer outputs,
            plus ``input_ids`` / ``attention_mask`` aliases of the prompt batch.
        """
        # Get the BLIP2 model from the wrapper
        blip2_model = model.blip2 if hasattr(model, "blip2") else model

        # Prepare protein batch (using DNA sequences as protein sequences).
        # All per-example sequence lists are flattened into one batch.
        all_sequences = [
            sequence for sequences in batch_dna_sequences for sequence in sequences
        ]
        if all_sequences:
            prot_batch = blip2_model.plm_tokenizer(
                all_sequences,
                padding=padding,
                truncation=True,
                max_length=self.PLM_MAX_LENGTH,
                return_tensors=return_tensors,
            )
        else:
            # Empty batch handling
            prot_batch = {
                "input_ids": torch.empty(0, 1, dtype=torch.long),
                "attention_mask": torch.empty(0, 1, dtype=torch.long),
            }

        # Prepare prompt batch, honouring the requested padding side without
        # permanently changing the shared tokenizer's configuration.
        llm_tokenizer = blip2_model.llm_tokenizer
        previous_side = getattr(llm_tokenizer, "padding_side", None)
        if previous_side is not None:
            llm_tokenizer.padding_side = padding_side
        try:
            prompt_batch = llm_tokenizer(
                prompts_text,
                padding=padding,
                truncation=True,
                max_length=self.PROMPT_MAX_LENGTH,
                return_tensors=return_tensors,
            )
        finally:
            if previous_side is not None:
                llm_tokenizer.padding_side = previous_side

        return {
            "prot_batch": prot_batch,
            "prompt_batch": prompt_batch,
            "input_ids": prompt_batch["input_ids"],  # For compatibility
            "attention_mask": prompt_batch["attention_mask"],  # For compatibility
        }

    def is_embeds_input(self) -> bool:
        """
        Whether the model uses embeddings as input (instead of token IDs).

        Returns:
            Boolean indicating if the model takes embedding inputs
        """
        return True  # BLIP2 uses embeddings internally

    @staticmethod
    def get_question_template() -> str:
        """
        Get the template for formatting questions.

        Returns:
            String template for questions
        """
        return "{Question}"

    @staticmethod
    def format_reward_rec(completions: List[List[Dict[str, Any]]], **kwargs) -> List[float]:
        """
        Check if the BLIP2 model output matches a specific format.

        The expected format is a ``<think>...</think>`` section followed by an
        ``<answer>...</answer>`` section whose body contains a braced payload
        with a bracketed list of four integers.

        When the ``DEBUG_MODE`` environment variable equals "true" and
        ``LOG_PATH`` is set, each completion and its result are appended to
        the log file derived from ``LOG_PATH`` (".txt" replaced by
        "_format.txt"). If ``LOG_PATH`` is unset, logging is skipped rather
        than failing the reward computation.

        Args:
            completions: List of model completions; the text is read from
                ``completion[0]["content"]`` of each entry
            **kwargs: Additional arguments (ignored)

        Returns:
            List of reward scores (1.0 for match, 0.0 for no match)
        """
        import os
        import re
        from datetime import datetime

        # Pattern to match the expected output format
        pattern = re.compile(
            r"<think>.*?</think>\s*<answer>.*?\{.*\[\d+,\s*\d+,\s*\d+,\s*\d+\].*\}.*?</answer>",
            re.DOTALL,
        )
        completion_contents = [completion[0]["content"] for completion in completions]
        matches = [pattern.search(content) is not None for content in completion_contents]

        # Log format results if in debug mode
        log_path = os.getenv("LOG_PATH")
        if os.getenv("DEBUG_MODE") == "true" and log_path:
            current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
            log_lines = [f"------------- {current_time} Format reward -------------\n"]
            for content, matched in zip(completion_contents, matches):
                log_lines.append(f"Content: {content}\n")
                log_lines.append(f"Has format: {matched}\n")
            with open(
                log_path.replace(".txt", "_format.txt"), "a", encoding="utf-8"
            ) as f:
                f.writelines(log_lines)

        return [1.0 if matched else 0.0 for matched in matches]

    @staticmethod
    def select_reward_func(func: str, task_type: str) -> Callable:
        """
        Select the appropriate reward function based on function name and task type.

        Args:
            func: The type of reward function ('accuracy' or 'format')
            task_type: The type of task (only 'rec' is supported)

        Returns:
            The reward function to use

        Raises:
            ValueError: If the function or task type is not supported
        """
        reward_funcs = {
            "accuracy": {"rec": Blip2DNAModule.iou_reward},
            "format": {"rec": Blip2DNAModule.format_reward_rec},
        }
        if func not in reward_funcs:
            raise ValueError(f"Unsupported reward function: {func}")
        rewards_by_task = reward_funcs[func]
        if task_type not in rewards_by_task:
            # Name the task type: the reward function itself is valid here.
            raise ValueError(
                f"Unsupported task type for reward function {func}: {task_type}"
            )
        return rewards_by_task[task_type]

    @staticmethod
    def iou_reward(completions: List[Dict[str, Any]], **kwargs) -> List[float]:
        """
        Placeholder IoU reward function.

        No IoU is computed: every completion receives a reward of 1.0.

        Args:
            completions: List of model completions
            **kwargs: Additional arguments (ignored)

        Returns:
            List of reward scores, one 1.0 per completion
        """
        # Placeholder implementation
        return [1.0] * len(completions)
class Blip2Processor:
    """
    Simple processor wrapper for BLIP2 models to maintain compatibility
    with the GRPO trainer interface.

    Attributes set on every instance:
        plm_tokenizer: Tokenizer passed in for protein sequences (or None).
        llm_tokenizer: Tokenizer passed in for text (or None).
        eos_token_id / pad_token_id: Copied from ``llm_tokenizer`` when one is
            given; None otherwise, so reading them never raises.
    """

    def __init__(self, plm_tokenizer=None, llm_tokenizer=None):
        """
        Store the two tokenizers and mirror the LLM tokenizer's special ids.

        Args:
            plm_tokenizer: Optional tokenizer for protein sequences
            llm_tokenizer: Optional tokenizer for text prompts
        """
        self.plm_tokenizer = plm_tokenizer
        self.llm_tokenizer = llm_tokenizer
        # Set compatibility attributes. They are always defined; presence of
        # the tokenizer is an identity check because tokenizer objects may
        # define __len__, which would make a truthiness test unreliable.
        self.eos_token_id = getattr(llm_tokenizer, "eos_token_id", None)
        self.pad_token_id = getattr(llm_tokenizer, "pad_token_id", None)

    def __call__(self, *args, **kwargs):
        """
        Process inputs for BLIP2 model.

        This is a simplified version that delegates to the LLM tokenizer,
        forwarding all arguments unchanged. Without an LLM tokenizer a fixed
        single-token placeholder batch is returned.
        """
        if self.llm_tokenizer is not None:
            return self.llm_tokenizer(*args, **kwargs)
        # Fallback behavior
        return {"input_ids": torch.tensor([[1]]), "attention_mask": torch.tensor([[1]])}

    def batch_decode(self, *args, **kwargs):
        """
        Decode token sequences.

        Delegates to the LLM tokenizer's ``batch_decode``; without an LLM
        tokenizer a list holding one empty string is returned.
        """
        if self.llm_tokenizer is not None:
            return self.llm_tokenizer.batch_decode(*args, **kwargs)
        return [""]