# File: blip2_grpo_trainer.py (from the BioReason-0813 repository)
# Provenance note: uploaded by user "yuccaaa" via the upload-large-folder tool (commit acbfbc3).
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import textwrap
import pandas as pd
from collections import defaultdict
from typing import Any, Callable, Optional, Union, Sized
import torch
import torch.utils.data
import transformers
from datasets import Dataset, IterableDataset
from packaging import version
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoProcessor,
AutoTokenizer,
GenerationConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
Trainer,
TrainerCallback,
is_wandb_available,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.utils import generate_model_card, get_comet_experiment_url
from accelerate.utils import is_peft_model, set_seed, gather_object
import PIL.Image
import copy
from torch.utils.data import Sampler
import warnings
if is_peft_available():
from peft import PeftConfig, get_peft_model, prepare_model_for_kbit_training
if is_wandb_available():
import wandb
from bioreason.dna_modules.dna_module import DNABaseModule
from bioreason.trainer import DNALLMGRPOConfig
# Import the RepeatRandomSampler from the original trainer
from bioreason.trainer.grpo_trainer import RepeatRandomSampler
# A reward function is either: a model ID string (loaded as a sequence-classification
# model), an already-instantiated PreTrainedModel, or a callable that takes a list of
# prompts and a list of completions and returns one float reward per completion.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
class Blip2GRPOTrainer(Trainer):
    """
    Modified GRPO Trainer for BLIP2 models.

    This trainer adapts the original GRPO trainer to work with BLIP2 architecture,
    handling the different input formats and forward pass requirements: prompts and
    DNA/protein sequences are converted into (`prot_batch`, `prompt_batch`) pairs by
    a `DNABaseModule`, completions are produced with `model.blip2.generate`, and
    group-normalized advantages are computed from the configured reward functions.
    """
    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        reward_funcs: Union[RewardFunc, list[RewardFunc]],
        args: DNALLMGRPOConfig = None,
        dna_module: DNABaseModule = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
        peft_config: Optional["PeftConfig"] = None,
        freeze_dna_modules: Optional[bool] = False,
        attn_implementation: str = "flash_attention_2",
        torch_dtype: str = "bfloat16",
        **kwargs,
    ):
        """Build a GRPO trainer around a BLIP2-style model.

        Args:
            model: The policy model. Despite the type hint, a string is rejected
                below (`assert not isinstance(model, str)`).
            reward_funcs: One or more reward functions (see `RewardFunc`).
            args: GRPO training configuration; a default `GRPOConfig` is built
                when omitted.
            dna_module: Adapter that provides tokenization keywords, prompt
                preparation, and model-input preparation for the DNA/protein side.
            train_dataset / eval_dataset: Standard Trainer datasets.
            processing_class: Optional processor; otherwise constructed from
                `dna_module.get_processing_class()` and the model's tokenizers.
            reward_processing_classes: Tokenizers paired with model-based reward
                functions.
            peft_config: When given, LoRA is applied to the LLM part of BLIP2.
            freeze_dna_modules: Freeze the protein LM and Q-Former parameters.
            attn_implementation / torch_dtype: Forwarded into model init kwargs.
            **kwargs: Ignored (accepted for signature compatibility).
        """
        # Args
        if args is None:
            model_name = model if isinstance(model, str) else "blip2-model"
            # NOTE(review): GRPOConfig's first positional arg is used as a run/output
            # name here — confirm this matches the installed trl version's signature.
            args = GRPOConfig(f"{model_name}-GRPO")
        self.dna_module = dna_module
        # Models
        model_init_kwargs = args.model_init_kwargs or {}
        model_init_kwargs["attn_implementation"] = attn_implementation
        if model_init_kwargs.get("torch_dtype") is None:
            model_init_kwargs["torch_dtype"] = torch_dtype
        # A model ID string is not supported: the model object must already exist.
        assert not isinstance(model, str), "model must NOT be a string in the current implementation"
        # Normalize torch_dtype: allow torch.dtype, "auto", None, or a dtype name string.
        torch_dtype = model_init_kwargs.get("torch_dtype")
        if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
            pass  # torch_dtype is already a torch.dtype or "auto" or None
        elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
            torch_dtype = getattr(torch, torch_dtype)
        else:
            raise ValueError(
                "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
                f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
            )
        # Disable caching if gradient checkpointing is enabled (not supported)
        if hasattr(model, 'blip2') and hasattr(model.blip2, 'llm_model'):
            model.blip2.llm_model.config.use_cache = (
                False if args.gradient_checkpointing else model.blip2.llm_model.config.use_cache
            )
        # LoRA setup for BLIP2
        self.dna_modules_keywords = self.dna_module.get_dnallm_modules_keywords()
        if peft_config is not None:
            print("Applying LoRA...")

            def find_all_linear_names(model, multimodal_keywords):
                # Collect names of all nn.Linear modules in the LLM, excluding
                # DNA/multimodal submodules and embedding layers, to use as LoRA targets.
                cls = torch.nn.Linear
                lora_module_names = set()
                # Focus on the LLM part of BLIP2
                if hasattr(model, 'blip2') and hasattr(model.blip2, 'llm_model'):
                    llm_model = model.blip2.llm_model
                    for name, module in llm_model.named_modules():
                        # Skip DNA/multimodal modules
                        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
                            continue
                        if isinstance(module, cls):
                            lora_module_names.add(name)
                # Remove embedding layers
                for m in list(lora_module_names):
                    if "embed_tokens" in m or "embedding" in m:
                        lora_module_names.remove(m)
                return list(lora_module_names)

            target_modules = find_all_linear_names(model, self.dna_modules_keywords)
            peft_config.target_modules = target_modules
            # Apply LoRA to the LLM part only; the rest of BLIP2 is left unwrapped.
            if hasattr(model, 'blip2') and hasattr(model.blip2, 'llm_model'):
                model.blip2.llm_model = prepare_model_for_kbit_training(model.blip2.llm_model)
                model.blip2.llm_model = get_peft_model(model.blip2.llm_model, peft_config)
        # Freeze DNA/protein modules if requested
        if freeze_dna_modules:
            print("Freezing protein/DNA modules...")
            if hasattr(model, 'blip2'):
                # Freeze protein language model
                if hasattr(model.blip2, 'plm'):
                    for p in model.blip2.plm.parameters():
                        p.requires_grad = False
                # Freeze Q-former if specified
                if hasattr(model.blip2, 'Qformer'):
                    for p in model.blip2.Qformer.parameters():
                        p.requires_grad = False
        # Count trainable parameters (after LoRA/freezing) for a sanity log line.
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        total_params = sum(p.numel() for p in trainable_params)
        print(f"Total trainable parameters: {total_params}")
        # Enable gradient checkpointing if requested
        if args.gradient_checkpointing:
            model = self._enable_gradient_checkpointing(model, args)
        # Reference model selection:
        #   beta == 0  -> no KL term, no reference model needed;
        #   ZeRO-3     -> build a fresh model of the same type (DeepSpeed-safe);
        #   PEFT model -> reference obtained by disabling adapters, so None here;
        #   otherwise  -> frozen copy via trl's create_reference_model.
        self.beta = args.beta
        if self.beta == 0.0:
            self.ref_model = None
        elif is_deepspeed_zero3_enabled():
            # Create reference model for DeepSpeed
            # NOTE(review): assumes `model.args` holds the constructor config for
            # `type(model)` — verify against the BLIP2 wrapper's __init__.
            self.ref_model = type(model)(model.args)  # Create same type of model
        elif is_peft_model(model.blip2.llm_model if hasattr(model, 'blip2') else model):
            self.ref_model = None
        else:
            self.ref_model = create_reference_model(model)
        # Processing class setup
        if processing_class is None:
            processing_cls = self.dna_module.get_processing_class()
            # Get tokenizers from BLIP2 model
            if hasattr(model, 'blip2'):
                plm_tokenizer = getattr(model.blip2, 'plm_tokenizer', None)
                llm_tokenizer = getattr(model.blip2, 'llm_tokenizer', None)
                processing_class = processing_cls(plm_tokenizer=plm_tokenizer, llm_tokenizer=llm_tokenizer)
            else:
                processing_class = processing_cls()
        # Mirror the LLM tokenizer's special token ids onto the processor.
        if hasattr(processing_class, 'llm_tokenizer') and processing_class.llm_tokenizer:
            processing_class.pad_token_id = processing_class.llm_tokenizer.pad_token_id
            processing_class.eos_token_id = processing_class.llm_tokenizer.eos_token_id
        else:
            # Fallback: hard-coded ids when no LLM tokenizer is available.
            processing_class.pad_token_id = 0
            processing_class.eos_token_id = 1
        self.dna_module.post_model_init(model, processing_class)
        # NOTE(review): self.ref_model can be None here (beta == 0 or PEFT branch);
        # confirm post_model_init tolerates a None model.
        self.dna_module.post_model_init(self.ref_model, processing_class)
        # Reward functions: load string IDs as 1-label sequence-classification models.
        if not isinstance(reward_funcs, list):
            reward_funcs = [reward_funcs]
        for i, reward_func in enumerate(reward_funcs):
            if isinstance(reward_func, str):
                reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
                    reward_func, num_labels=1, **model_init_kwargs
                )
        self.reward_funcs = reward_funcs
        # Reward processing classes: one tokenizer per model-based reward function.
        if reward_processing_classes is None:
            reward_processing_classes = [None] * len(reward_funcs)
        elif not isinstance(reward_processing_classes, list):
            reward_processing_classes = [reward_processing_classes]
        for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
            if isinstance(reward_func, PreTrainedModel):
                if reward_processing_class is None:
                    reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
                if reward_processing_class.pad_token_id is None:
                    reward_processing_class.pad_token = reward_processing_class.eos_token
                reward_func.config.pad_token_id = reward_processing_class.pad_token_id
                reward_processing_classes[i] = reward_processing_class
        self.reward_processing_classes = reward_processing_classes

        # Identity data collator: batches are kept as raw lists of feature dicts.
        def data_collator(features):
            return features

        # Training arguments
        self.max_prompt_length = args.max_prompt_length
        # NOTE(review): the line above is immediately overridden — prompt truncation
        # is deliberately unsupported, and the warning below tells the user so.
        self.max_prompt_length = None
        if args.max_prompt_length is not None:
            warnings.warn("Setting max_prompt_length is currently not supported, it has been set to None")
        self.max_completion_length = args.max_completion_length
        self.num_generations = args.num_generations
        # Generation config for BLIP2 (sampling hyper-parameters are hard-coded here).
        self.generation_config = GenerationConfig(
            max_new_tokens=self.max_completion_length,
            do_sample=True,
            temperature=0.6,
            top_p=0.95,
            top_k=20,
            pad_token_id=processing_class.pad_token_id,
            eos_token_id=processing_class.eos_token_id,
        )
        self.beta = args.beta
        self.epsilon_low = args.epsilon
        self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
        # Multi-step: reuse each generated batch for `num_iterations` optimizer steps.
        self.num_iterations = args.num_iterations
        self._step = 0
        self._buffered_inputs = [None] * args.gradient_accumulation_steps
        # Initialize metrics accumulated between `log` calls.
        self._metrics = defaultdict(list)
        self.log_completions = args.log_completions
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            callbacks=callbacks,
            optimizers=optimizers,
        )
        # Validate batch sizes: the global batch must split evenly into generation groups.
        num_processes = self.accelerator.num_processes
        global_batch_size = args.per_device_train_batch_size * num_processes
        possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
        if self.num_generations not in possible_values:
            raise ValueError(
                f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
                f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
                f"batch size, the valid values for the number of generations are: {possible_values}."
            )
        # Set unique seed per process so each rank samples different completions.
        set_seed(args.seed, device_specific=True)
        # Gradient accumulation settings
        self.model_accepts_loss_kwargs = False
        # Prepare reference model and model-based reward functions with accelerate.
        if self.ref_model is not None:
            if is_deepspeed_zero3_enabled():
                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: DNALLMGRPOConfig) -> PreTrainedModel:
"""Enables gradient checkpointing for BLIP2 model."""
if hasattr(model, 'blip2'):
# Enable for the LLM component
if hasattr(model.blip2, 'llm_model'):
model.blip2.llm_model.config.use_cache = False
if hasattr(model.blip2.llm_model, 'gradient_checkpointing_enable'):
model.blip2.llm_model.gradient_checkpointing_enable()
# Enable for protein model if needed
if hasattr(model.blip2, 'plm') and hasattr(model.blip2.plm, 'gradient_checkpointing_enable'):
model.blip2.plm.gradient_checkpointing_enable()
return model
def _set_signature_columns_if_needed(self):
if self._signature_columns is None:
self._signature_columns = ["prompt"]
def _get_key_from_inputs(self, x, key):
ele = x.get(key, None)
assert ele is not None, f"The key {key} is not found in the input"
if isinstance(ele, list):
return [e for e in ele]
else:
return [ele]
    def _generate_and_score_completions(self, inputs: dict[str, Union[torch.Tensor, Any]], model) -> dict[str, Union[torch.Tensor, Any]]:
        """Generate completions for a batch of prompts and score them with all reward functions.

        Returns a dict with the BLIP2 batches, the raw completion texts, and the
        group-normalized advantages for this process's slice of the batch.
        `old_per_token_logps`/`ref_per_token_logps` are always None in this
        implementation.
        """
        device = self.accelerator.device
        prompts = [x["prompt"] for x in inputs]
        prompts_text = self.dna_module.prepare_prompt(self.processing_class, inputs)
        # Collect per-example DNA sequences (treated as protein sequences for BLIP2);
        # examples without the field contribute an empty list.
        batch_dna_sequences = []
        print("_generate_and_score_completions (BLIP2 GRPO):")
        for x in inputs:
            if 'dna_sequences' in x:
                dnas = self._get_key_from_inputs(x, "dna_sequences")
                batch_dna_sequences.append(dnas)
            else:
                batch_dna_sequences.append([])
        # Prepare model inputs for BLIP2 via the DNA module adapter.
        prompt_inputs = self.dna_module.prepare_model_inputs(
            self.processing_class,
            model,
            prompts_text,
            batch_dna_sequences,
            return_tensors="pt",
            padding=True,
            padding_side="left",
            add_special_tokens=False,
        )
        prompt_inputs = super()._prepare_inputs(prompt_inputs)
        # Extract BLIP2-specific inputs produced by prepare_model_inputs.
        prot_batch = prompt_inputs.get("prot_batch")
        prompt_batch = prompt_inputs.get("prompt_batch")
        # Generate completions using BLIP2 (timed for diagnostics).
        start = time.time()
        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
            # Prepare samples for BLIP2 generation
            samples = {
                'prot_batch': prot_batch,
                'prompt_batch': prompt_batch
            }
            # Use BLIP2's generate method; sampling settings here are hard-coded
            # and independent of self.generation_config.
            if hasattr(unwrapped_model, 'blip2'):
                completions_text = unwrapped_model.blip2.generate(
                    samples,
                    do_sample=True,
                    temperature=0.6,
                    top_p=0.95,
                    num_beams=1,
                    max_length=self.max_completion_length,
                    min_length=1,
                )
            else:
                # Fallback if not BLIP2 structure (placeholder completions).
                completions_text = ["Generated text"] * len(prompts_text)
        end = time.time()
        print(f"Generation time: {end - start:.9f} seconds")
        # Convert completions to the chat format expected by conversational rewards.
        if is_conversational(inputs[0]):
            completions = [[{"role": "assistant", "content": completion}] for completion in completions_text]
        else:
            completions = completions_text
        # Compute rewards: one column per reward function.
        print("Reward calculation...")
        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
        for i, (reward_func, reward_processing_class) in enumerate(
            zip(self.reward_funcs, self.reward_processing_classes)
        ):
            if isinstance(reward_func, PreTrainedModel):
                if is_conversational(inputs[0]):
                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                else:
                    # NOTE(review): assumes prompts and completions are plain strings
                    # in the non-conversational path — confirm for this dataset.
                    texts = [p + c for p, c in zip(prompts, completions)]
                reward_inputs = reward_processing_class(
                    texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
                )
                reward_inputs = super()._prepare_inputs(reward_inputs)
                with torch.inference_mode():
                    # Score = first logit of the 1-label classification head.
                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]
            else:
                # Custom reward function: forward all extra example fields as kwargs.
                reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
                for key in reward_kwargs:
                    for example in inputs:
                        reward_kwargs[key].extend([example[key]])
                output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
        # Gather rewards across processes so group statistics span the global batch.
        rewards_per_func = self.accelerator.gather(rewards_per_func)
        rewards = rewards_per_func.sum(dim=1)
        # Compute group-wise statistics over each prompt's num_generations completions.
        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
        # Normalize rewards to compute advantages (small epsilon guards zero std).
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
        # Keep only this process's slice of the globally gathered advantages.
        process_slice = slice(
            self.accelerator.process_index * len(prompts),
            (self.accelerator.process_index + 1) * len(prompts),
        )
        advantages = advantages[process_slice]
        # Log metrics
        print("Logging metrics...")
        # NOTE(review): this records the word count of the first completion only,
        # not an average over the batch.
        completion_length = len(completions_text[0].split()) if completions_text else 0
        self._metrics["completion_length"].append(completion_length)
        reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
            else:
                reward_func_name = reward_func.__name__
            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
        self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
        self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
        # Log completions to wandb if enabled (gathered across all processes).
        if (
            self.log_completions
            and self.state.global_step % self.args.logging_steps == 0
            and "wandb" in self.args.report_to
        ):
            timestamp = time.time()
            num_items = len(gather_object(prompts_text))
            table = {
                "step": [f"{self.state.global_step}_{timestamp}"] * num_items,
                "prompt": gather_object(prompts_text),
                "completion": gather_object(completions_text),
                "reward": rewards.tolist(),
            }
            df = pd.DataFrame(table)
            if wandb.run is not None and self.accelerator.is_main_process:
                wandb.log({f"completions_{self.state.global_step}_{timestamp}": wandb.Table(dataframe=df)})
        return {
            "prot_batch": prot_batch,
            "prompt_batch": prompt_batch,
            "completions_text": completions_text,
            "old_per_token_logps": None,  # BLIP2 doesn't need this for current implementation
            "ref_per_token_logps": None,  # BLIP2 doesn't need this for current implementation
            "advantages": advantages,
            "multimodal_inputs": {"prot_batch": prot_batch, "prompt_batch": prompt_batch}
        }
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the (simplified) GRPO loss for one micro-batch.

        Fresh completions are generated every `num_iterations` global steps and
        buffered per gradient-accumulation slot; in between, the buffered batch is
        replayed. The BLIP2 forward pass returns a supervised loss on the sampled
        completions, which is then scaled by the mean advantage.
        """
        if return_outputs:
            raise ValueError("The BLIP2 GRPO Trainer does not support returning outputs")
        print("compute_loss - index 1")
        # NOTE(review): regeneration is keyed on `state.global_step` while buffering
        # is keyed on `self._step` — verify these stay aligned across accumulation.
        if self.state.global_step % self.num_iterations == 0:
            inputs = self._generate_and_score_completions(inputs, model)
            self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
        else:
            inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
        self._step += 1
        print("compute_loss - index 2")
        # For BLIP2, we need to compute loss differently
        # This is a simplified version - you may need to adapt based on your specific BLIP2 implementation
        # Extract the necessary components
        prot_batch = inputs.get("prot_batch")
        prompt_batch = inputs.get("prompt_batch")
        advantages = inputs.get("advantages")
        print("compute_loss - index 3")
        # Create a batch for BLIP2 forward pass
        # This assumes your BLIP2 model expects (prot_batch, prompt_batch, text_dict) format
        text_dict = {"targets": inputs.get("completions_text", [])}
        batch = (prot_batch, prompt_batch, text_dict)
        print("compute_loss - index 4")
        # Forward pass through BLIP2
        if hasattr(model, 'blip2'):
            loss = model.blip2(batch)
        else:
            loss = model(batch)
        print("compute_loss - index 5")
        # For now, return the basic loss
        # You may want to incorporate the advantages into the loss calculation
        # based on your specific GRPO implementation needs
        if advantages is not None:
            # Apply advantages weighting (simplified)
            # NOTE(review): scaling by (1 + mean advantage) is an acknowledged
            # approximation, not the per-token clipped GRPO objective.
            advantage_weight = advantages.mean().item()
            loss = loss * (1.0 + advantage_weight)
        print("Computing final loss...")
        return loss
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}
logs = {**logs, **metrics}
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
super().log(logs, start_time)
else:
super().log(logs)
self._metrics.clear()
def _get_train_sampler(self) -> Sampler:
"""Returns a sampler that ensures proper data sampling for GRPO training."""
effective_batch_size = (
self.args.per_device_train_batch_size
* self.accelerator.num_processes
* self.args.gradient_accumulation_steps
)
return RepeatRandomSampler(
data_source=self.train_dataset,
mini_repeat_count=self.num_generations,
batch_size=effective_batch_size // self.num_generations,
repeat_count=self.num_iterations,
seed=self.args.seed,
)
def _get_eval_sampler(self, eval_dataset) -> Sampler:
"""Returns a sampler for evaluation."""
return RepeatRandomSampler(
data_source=eval_dataset,
mini_repeat_count=self.num_generations,
seed=self.args.seed,
)