| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import time |
| import textwrap |
| import pandas as pd |
| from collections import defaultdict |
| from typing import Any, Callable, Optional, Union, Sized |
|
|
| import torch |
| import torch.utils.data |
| import transformers |
| from datasets import Dataset, IterableDataset |
| from packaging import version |
| from transformers import ( |
| AutoModelForCausalLM, |
| AutoModelForSequenceClassification, |
| AutoProcessor, |
| AutoTokenizer, |
| GenerationConfig, |
| PreTrainedModel, |
| PreTrainedTokenizerBase, |
| Trainer, |
| TrainerCallback, |
| is_wandb_available, |
| ) |
| from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled |
| from transformers.utils import is_peft_available |
|
|
| from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template |
| from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation |
| from trl.trainer.grpo_config import GRPOConfig |
| from trl.trainer.utils import generate_model_card, get_comet_experiment_url |
|
|
| from accelerate.utils import is_peft_model, set_seed, gather_object |
| import PIL.Image |
|
|
| import copy |
| from torch.utils.data import Sampler |
| import warnings |
|
|
| if is_peft_available(): |
| from peft import PeftConfig, get_peft_model, prepare_model_for_kbit_training |
|
|
| if is_wandb_available(): |
| import wandb |
|
|
| from bioreason.dna_modules.dna_module import DNABaseModule |
| from bioreason.trainer import DNALLMGRPOConfig |
|
|
| |
| from bioreason.trainer.grpo_trainer import RepeatRandomSampler |
|
|
| |
| |
# A reward source: a model id/path (str, loaded as a sequence-classification
# model), an already-loaded PreTrainedModel, or a callable mapping
# (prompts, completions) -> list of per-sample float rewards.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
|
|
|
|
class Blip2GRPOTrainer(Trainer):
    """
    Modified GRPO Trainer for BLIP2 models.

    This trainer adapts the original GRPO trainer to work with BLIP2 architecture,
    handling the different input formats and forward pass requirements
    (DNA/protein encoder + Q-Former + LLM sub-modules).
    """

    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        reward_funcs: Union[RewardFunc, list[RewardFunc]],
        args: DNALLMGRPOConfig = None,
        dna_module: DNABaseModule = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
        peft_config: Optional["PeftConfig"] = None,
        freeze_dna_modules: Optional[bool] = False,
        attn_implementation: str = "flash_attention_2",
        torch_dtype: str = "bfloat16",
        **kwargs,
    ):
        """
        Args:
            model: An instantiated BLIP2-style model. A string model id is
                rejected below — loading from a name is not implemented.
            reward_funcs: One or more reward sources; string entries are loaded
                as sequence-classification reward models, callables are used as-is.
            args: GRPO training configuration; a default one is built if omitted.
            dna_module: Adapter that builds prompts/model inputs for the
                DNA/protein modality and provides the processing class.
            processing_class: Optional processor; constructed from `dna_module`
                when omitted.
            reward_processing_classes: Tokenizer(s) for model-based reward
                functions; auto-loaded from each reward model when omitted.
            peft_config: Optional LoRA config; its `target_modules` is filled in
                automatically from the LLM's linear layers.
            freeze_dna_modules: If True, freezes the `plm` and `Qformer` weights.
            attn_implementation / torch_dtype: Forwarded into model init kwargs.
        """
        if args is None:
            model_name = model if isinstance(model, str) else "blip2-model"
            args = GRPOConfig(f"{model_name}-GRPO")

        self.dna_module = dna_module

        # Model initialisation kwargs (attention backend / dtype).
        model_init_kwargs = args.model_init_kwargs or {}
        model_init_kwargs["attn_implementation"] = attn_implementation
        if model_init_kwargs.get("torch_dtype") is None:
            model_init_kwargs["torch_dtype"] = torch_dtype

        # Explicit validation instead of `assert` (asserts vanish under `python -O`).
        if isinstance(model, str):
            raise TypeError("model must NOT be a string in the current implementation")

        # Normalise torch_dtype: accept a torch.dtype, "auto", None, or a
        # string naming a torch dtype (e.g. "bfloat16").
        torch_dtype = model_init_kwargs.get("torch_dtype")
        if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
            pass
        elif isinstance(torch_dtype, str):
            torch_dtype = getattr(torch, torch_dtype)
        else:
            raise ValueError(
                "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
                f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
            )

        # Gradient checkpointing is incompatible with the KV cache.
        if hasattr(model, 'blip2') and hasattr(model.blip2, 'llm_model'):
            model.blip2.llm_model.config.use_cache = (
                False if args.gradient_checkpointing else model.blip2.llm_model.config.use_cache
            )

        # Optionally wrap the LLM sub-module with LoRA adapters.
        self.dna_modules_keywords = self.dna_module.get_dnallm_modules_keywords()
        if peft_config is not None:
            print("Applying LoRA...")

            def find_all_linear_names(model, multimodal_keywords):
                """Collect names of the LLM's linear layers eligible as LoRA targets."""
                cls = torch.nn.Linear
                lora_module_names = set()

                if hasattr(model, 'blip2') and hasattr(model.blip2, 'llm_model'):
                    llm_model = model.blip2.llm_model
                    for name, module in llm_model.named_modules():
                        # Skip modality-specific sub-modules.
                        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
                            continue
                        if isinstance(module, cls):
                            lora_module_names.add(name)

                # Embedding layers should not receive LoRA adapters.
                for m in list(lora_module_names):
                    if "embed_tokens" in m or "embedding" in m:
                        lora_module_names.remove(m)

                return list(lora_module_names)

            target_modules = find_all_linear_names(model, self.dna_modules_keywords)
            peft_config.target_modules = target_modules

            if hasattr(model, 'blip2') and hasattr(model.blip2, 'llm_model'):
                model.blip2.llm_model = prepare_model_for_kbit_training(model.blip2.llm_model)
                model.blip2.llm_model = get_peft_model(model.blip2.llm_model, peft_config)

        # Optionally freeze the protein/DNA encoder and the Q-Former.
        if freeze_dna_modules:
            print("Freezing protein/DNA modules...")
            if hasattr(model, 'blip2'):
                if hasattr(model.blip2, 'plm'):
                    for p in model.blip2.plm.parameters():
                        p.requires_grad = False

                if hasattr(model.blip2, 'Qformer'):
                    for p in model.blip2.Qformer.parameters():
                        p.requires_grad = False

        # Report trainable parameter count after LoRA/freezing.
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        total_params = sum(p.numel() for p in trainable_params)
        print(f"Total trainable parameters: {total_params}")

        if args.gradient_checkpointing:
            model = self._enable_gradient_checkpointing(model, args)

        # Reference model for the KL term: skipped when beta == 0; rebuilt
        # under ZeRO-3; implicit for PEFT models (adapters can be disabled to
        # recover the reference policy).
        self.beta = args.beta
        if self.beta == 0.0:
            self.ref_model = None
        elif is_deepspeed_zero3_enabled():
            # NOTE(review): assumes `type(model)(model.args)` can rebuild the
            # model — verify the model class actually exposes an `args` attribute.
            self.ref_model = type(model)(model.args)
        elif is_peft_model(model.blip2.llm_model if hasattr(model, 'blip2') else model):
            self.ref_model = None
        else:
            self.ref_model = create_reference_model(model)

        # Build a processing class from the module if none was provided,
        # wiring through the model's own tokenizers when available.
        if processing_class is None:
            processing_cls = self.dna_module.get_processing_class()

            if hasattr(model, 'blip2'):
                plm_tokenizer = getattr(model.blip2, 'plm_tokenizer', None)
                llm_tokenizer = getattr(model.blip2, 'llm_tokenizer', None)
                processing_class = processing_cls(plm_tokenizer=plm_tokenizer, llm_tokenizer=llm_tokenizer)
            else:
                processing_class = processing_cls()

            if hasattr(processing_class, 'llm_tokenizer') and processing_class.llm_tokenizer:
                processing_class.pad_token_id = processing_class.llm_tokenizer.pad_token_id
                processing_class.eos_token_id = processing_class.llm_tokenizer.eos_token_id
            else:
                # Fallback token ids when no LLM tokenizer is attached.
                processing_class.pad_token_id = 0
                processing_class.eos_token_id = 1

        self.dna_module.post_model_init(model, processing_class)
        # NOTE(review): also invoked when ref_model is None — presumably
        # post_model_init tolerates None; confirm against DNABaseModule.
        self.dna_module.post_model_init(self.ref_model, processing_class)

        # Normalise reward functions; string entries become reward models.
        if not isinstance(reward_funcs, list):
            reward_funcs = [reward_funcs]
        for i, reward_func in enumerate(reward_funcs):
            if isinstance(reward_func, str):
                reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
                    reward_func, num_labels=1, **model_init_kwargs
                )
        self.reward_funcs = reward_funcs

        # One tokenizer per model-based reward function (auto-loaded if absent).
        if reward_processing_classes is None:
            reward_processing_classes = [None] * len(reward_funcs)
        elif not isinstance(reward_processing_classes, list):
            reward_processing_classes = [reward_processing_classes]

        for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
            if isinstance(reward_func, PreTrainedModel):
                if reward_processing_class is None:
                    reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
                if reward_processing_class.pad_token_id is None:
                    reward_processing_class.pad_token = reward_processing_class.eos_token
                # The reward model needs the pad token id to score padded batches.
                reward_func.config.pad_token_id = reward_processing_class.pad_token_id
                reward_processing_classes[i] = reward_processing_class
        self.reward_processing_classes = reward_processing_classes

        # Identity collator: batching is handled inside this trainer itself.
        def data_collator(features):
            return features

        # Prompt truncation is not supported yet; force it off.
        # (Removed a dead `self.max_prompt_length = args.max_prompt_length`
        # that was immediately overwritten.)
        self.max_prompt_length = None
        if args.max_prompt_length is not None:
            warnings.warn("Setting max_prompt_length is currently not supported, it has been set to None")

        self.max_completion_length = args.max_completion_length
        self.num_generations = args.num_generations

        self.generation_config = GenerationConfig(
            max_new_tokens=self.max_completion_length,
            do_sample=True,
            temperature=0.6,
            top_p=0.95,
            top_k=20,
            pad_token_id=processing_class.pad_token_id,
            eos_token_id=processing_class.eos_token_id,
        )

        # Clipping range for the GRPO surrogate objective.
        # (self.beta is already set above; the duplicate assignment was removed.)
        self.epsilon_low = args.epsilon
        self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon

        # Multi-iteration bookkeeping: generations produced once per
        # `num_iterations` steps are buffered and replayed by compute_loss.
        self.num_iterations = args.num_iterations
        self._step = 0
        self._buffered_inputs = [None] * args.gradient_accumulation_steps

        # Rolling per-step metric lists, averaged and flushed by log().
        self._metrics = defaultdict(list)
        self.log_completions = args.log_completions

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            callbacks=callbacks,
            optimizers=optimizers,
        )

        # The global batch must be divisible by num_generations so that each
        # prompt's generation group is complete on every optimisation step.
        num_processes = self.accelerator.num_processes
        global_batch_size = args.per_device_train_batch_size * num_processes
        possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
        if self.num_generations not in possible_values:
            raise ValueError(
                f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
                f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
                f"batch size, the valid values for the number of generations are: {possible_values}."
            )

        # Distinct per-process seeds so sampled generations differ across ranks.
        set_seed(args.seed, device_specific=True)

        # The model forward used here does not accept Trainer's loss kwargs.
        self.model_accepts_loss_kwargs = False

        # Place the reference model on the right device / DeepSpeed engine.
        if self.ref_model is not None:
            if is_deepspeed_zero3_enabled():
                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        # Reward models run in eval mode on the accelerator as well.
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
|
| def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: DNALLMGRPOConfig) -> PreTrainedModel: |
| """Enables gradient checkpointing for BLIP2 model.""" |
| if hasattr(model, 'blip2'): |
| |
| if hasattr(model.blip2, 'llm_model'): |
| model.blip2.llm_model.config.use_cache = False |
| if hasattr(model.blip2.llm_model, 'gradient_checkpointing_enable'): |
| model.blip2.llm_model.gradient_checkpointing_enable() |
| |
| |
| if hasattr(model.blip2, 'plm') and hasattr(model.blip2.plm, 'gradient_checkpointing_enable'): |
| model.blip2.plm.gradient_checkpointing_enable() |
| |
| return model |
|
|
| def _set_signature_columns_if_needed(self): |
| if self._signature_columns is None: |
| self._signature_columns = ["prompt"] |
|
|
| def _get_key_from_inputs(self, x, key): |
| ele = x.get(key, None) |
| assert ele is not None, f"The key {key} is not found in the input" |
| if isinstance(ele, list): |
| return [e for e in ele] |
| else: |
| return [ele] |
|
|
    def _generate_and_score_completions(self, inputs: dict[str, Union[torch.Tensor, Any]], model) -> dict[str, Union[torch.Tensor, Any]]:
        """Generate completions for a batch of prompts, score them with the
        reward functions, and compute group-normalised GRPO advantages.

        `inputs` is a list of per-example dicts (despite the annotation) each
        containing at least "prompt" and optionally "dna_sequences". Returns a
        dict with the generation batches, completions, and this process's
        slice of the advantages.
        """
        device = self.accelerator.device
        prompts = [x["prompt"] for x in inputs]
        prompts_text = self.dna_module.prepare_prompt(self.processing_class, inputs)

        # Collect DNA sequences per example; examples without any get an empty list.
        batch_dna_sequences = []
        print("_generate_and_score_completions (BLIP2 GRPO):")
        for x in inputs:
            if 'dna_sequences' in x:
                dnas = self._get_key_from_inputs(x, "dna_sequences")
                batch_dna_sequences.append(dnas)
            else:
                batch_dna_sequences.append([])

        # Tokenise prompts + DNA into model inputs (left-padded for generation).
        prompt_inputs = self.dna_module.prepare_model_inputs(
            self.processing_class,
            model,
            prompts_text,
            batch_dna_sequences,
            return_tensors="pt",
            padding=True,
            padding_side="left",
            add_special_tokens=False,
        )

        # Move tensors to the right device via the base Trainer.
        prompt_inputs = super()._prepare_inputs(prompt_inputs)

        prot_batch = prompt_inputs.get("prot_batch")
        prompt_batch = prompt_inputs.get("prompt_batch")

        start = time.time()
        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
            samples = {
                'prot_batch': prot_batch,
                'prompt_batch': prompt_batch
            }

            # BLIP2's own generate() returns decoded strings directly.
            if hasattr(unwrapped_model, 'blip2'):
                completions_text = unwrapped_model.blip2.generate(
                    samples,
                    do_sample=True,
                    temperature=0.6,
                    top_p=0.95,
                    num_beams=1,
                    max_length=self.max_completion_length,
                    min_length=1,
                )
            else:
                # Fallback placeholder for models without a blip2 attribute.
                completions_text = ["Generated text"] * len(prompts_text)

        end = time.time()
        print(f"Generation time: {end - start:.9f} seconds")

        # Wrap completions as assistant messages for conversational datasets.
        if is_conversational(inputs[0]):
            completions = [[{"role": "assistant", "content": completion}] for completion in completions_text]
        else:
            completions = completions_text

        # Score each (prompt, completion) pair with every reward function.
        print("Reward calculation...")
        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
        for i, (reward_func, reward_processing_class) in enumerate(
            zip(self.reward_funcs, self.reward_processing_classes)
        ):
            if isinstance(reward_func, PreTrainedModel):
                if is_conversational(inputs[0]):
                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                else:
                    texts = [p + c for p, c in zip(prompts, completions)]
                reward_inputs = reward_processing_class(
                    texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
                )
                reward_inputs = super()._prepare_inputs(reward_inputs)
                with torch.inference_mode():
                    # Single-label head: column 0 of the logits is the reward.
                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]
            else:
                # Custom callable: forward all extra dataset columns as kwargs.
                reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
                for key in reward_kwargs:
                    for example in inputs:
                        reward_kwargs[key].extend([example[key]])
                output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

        # Gather rewards from all processes so grouping spans the global batch.
        rewards_per_func = self.accelerator.gather(rewards_per_func)
        rewards = rewards_per_func.sum(dim=1)

        # Group statistics: each consecutive run of num_generations rows
        # belongs to one prompt (sample std, torch default).
        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

        # Broadcast group stats back to per-completion rows and normalise.
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

        # Keep only this process's slice of the globally-computed advantages.
        process_slice = slice(
            self.accelerator.process_index * len(prompts),
            (self.accelerator.process_index + 1) * len(prompts),
        )
        advantages = advantages[process_slice]

        # NOTE(review): only the first completion's whitespace-token count is
        # logged, not a batch average — confirm this is intended.
        print("Logging metrics...")
        completion_length = len(completions_text[0].split()) if completions_text else 0
        self._metrics["completion_length"].append(completion_length)

        # NOTE(review): rewards_per_func was already gathered above;
        # gather_for_metrics here gathers a second time — verify on multi-GPU.
        reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
            else:
                reward_func_name = reward_func.__name__
            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

        self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
        self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())

        # Optionally log a completion table to Weights & Biases.
        if (
            self.log_completions
            and self.state.global_step % self.args.logging_steps == 0
            and "wandb" in self.args.report_to
        ):
            timestamp = time.time()
            num_items = len(gather_object(prompts_text))

            table = {
                "step": [f"{self.state.global_step}_{timestamp}"] * num_items,
                "prompt": gather_object(prompts_text),
                "completion": gather_object(completions_text),
                "reward": rewards.tolist(),
            }
            df = pd.DataFrame(table)

            if wandb.run is not None and self.accelerator.is_main_process:
                wandb.log({f"completions_{self.state.global_step}_{timestamp}": wandb.Table(dataframe=df)})

        # old/ref per-token logps are unused by this trainer's compute_loss.
        return {
            "prot_batch": prot_batch,
            "prompt_batch": prompt_batch,
            "completions_text": completions_text,
            "old_per_token_logps": None,
            "ref_per_token_logps": None,
            "advantages": advantages,
            "multimodal_inputs": {"prot_batch": prot_batch, "prompt_batch": prompt_batch}
        }
|
|
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the training loss for one step.

        Completions are (re)generated and scored only every `num_iterations`
        global steps; otherwise the buffered inputs for the current
        gradient-accumulation slot are replayed. The model's own loss is then
        scaled by a scalar derived from the mean advantage.
        """
        if return_outputs:
            raise ValueError("The BLIP2 GRPO Trainer does not support returning outputs")

        print("compute_loss - index 1")
        # Regenerate on iteration boundaries; otherwise reuse the buffered
        # batch stored under this accumulation-step slot.
        if self.state.global_step % self.num_iterations == 0:
            inputs = self._generate_and_score_completions(inputs, model)
            self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
        else:
            inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
        self._step += 1

        print("compute_loss - index 2")

        prot_batch = inputs.get("prot_batch")
        prompt_batch = inputs.get("prompt_batch")
        advantages = inputs.get("advantages")

        print("compute_loss - index 3")

        # The BLIP2 forward expects a (prot, prompt, targets) triple, with the
        # generated completions used as supervision targets.
        text_dict = {"targets": inputs.get("completions_text", [])}
        batch = (prot_batch, prompt_batch, text_dict)

        print("compute_loss - index 4")

        if hasattr(model, 'blip2'):
            loss = model.blip2(batch)
        else:
            loss = model(batch)

        print("compute_loss - index 5")

        # NOTE(review): scaling the whole loss by (1 + mean advantage) is a
        # coarse scalar surrogate for the per-token GRPO objective — confirm
        # this is the intended training signal.
        if advantages is not None:
            advantage_weight = advantages.mean().item()
            loss = loss * (1.0 + advantage_weight)

        print("Computing final loss...")
        return loss
|
|
| def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: |
| metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} |
| logs = {**logs, **metrics} |
| if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): |
| super().log(logs, start_time) |
| else: |
| super().log(logs) |
| self._metrics.clear() |
|
|
| def _get_train_sampler(self) -> Sampler: |
| """Returns a sampler that ensures proper data sampling for GRPO training.""" |
| effective_batch_size = ( |
| self.args.per_device_train_batch_size |
| * self.accelerator.num_processes |
| * self.args.gradient_accumulation_steps |
| ) |
| |
| return RepeatRandomSampler( |
| data_source=self.train_dataset, |
| mini_repeat_count=self.num_generations, |
| batch_size=effective_batch_size // self.num_generations, |
| repeat_count=self.num_iterations, |
| seed=self.args.seed, |
| ) |
|
|
| def _get_eval_sampler(self, eval_dataset) -> Sampler: |
| """Returns a sampler for evaluation.""" |
| return RepeatRandomSampler( |
| data_source=eval_dataset, |
| mini_repeat_count=self.num_generations, |
| seed=self.args.seed, |
| ) |