import os
import time
import textwrap
import pandas as pd
from collections import defaultdict
from typing import Any, Callable, Optional, Union, Sized

import torch
import torch.utils.data
import transformers
from datasets import Dataset, IterableDataset
from packaging import version
from transformers import (
    AriaForConditionalGeneration,
    AriaProcessor,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoProcessor,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    Trainer,
    TrainerCallback,
    is_wandb_available,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available

from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.utils import generate_model_card, get_comet_experiment_url

from accelerate.utils import is_peft_model, set_seed, gather_object
import PIL.Image

import copy
from torch.utils.data import Sampler
import warnings

if is_peft_available():
    from peft import PeftConfig, get_peft_model, prepare_model_for_kbit_training

if is_wandb_available():
    import wandb

from bioreason.dna_modules.dna_module import DNABaseModule
from bioreason.trainer import DNALLMGRPOConfig

RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]

class RepeatRandomSampler(Sampler):
    """
    Sampler that repeats the indices of a dataset in a structured manner.

    Args:
        data_source (`Sized`):
            Dataset to sample from.
        mini_repeat_count (`int`):
            Number of times to repeat each index per batch.
        batch_size (`int`, *optional*, defaults to `1`):
            Number of unique indices per batch.
        repeat_count (`int`, *optional*, defaults to `1`):
            Number of times to repeat the full sampling process.
        seed (`int` or `None`, *optional*, defaults to `None`):
            Random seed for reproducibility.
    """

    def __init__(
        self,
        data_source: Sized,
        mini_repeat_count: int,
        batch_size: int = 1,
        repeat_count: int = 1,
        seed: Optional[int] = None,
    ):
        self.data_source = data_source
        self.mini_repeat_count = mini_repeat_count
        self.batch_size = batch_size
        self.repeat_count = repeat_count
        self.num_samples = len(data_source)
        self.seed = seed
        self.generator = torch.Generator()
        if seed is not None:
            self.generator.manual_seed(seed)

    def __iter__(self):
        indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
        indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
        indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]

        for chunk in indexes:
            for _ in range(self.repeat_count):
                for index in chunk:
                    for _ in range(self.mini_repeat_count):
                        yield index

    def __len__(self) -> int:
        return self.num_samples * self.mini_repeat_count * self.repeat_count


class DNALLMGRPOTrainer(Trainer):
    """
    Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
    paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).

    Example:

    ```python
    from datasets import load_dataset
    from trl import GRPOTrainer

    dataset = load_dataset("trl-lib/tldr", split="train")

    trainer = GRPOTrainer(
        model="Qwen/Qwen2-0.5B-Instruct",
        reward_funcs="weqweasdas/RM-Gemma-2B",
        train_dataset=dataset,
    )

    trainer.train()
    ```

    Args:
        model (`Union[str, PreTrainedModel]`):
            Model to be trained. Can be either:

            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
              a path to a *directory* containing model weights saved using
              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
              loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
              in `args.model_init_kwargs`.
            - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
        reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
            Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
            functions with the prompts and completions and sum the rewards. Can be either:

            - A single reward function, such as:
                - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
                  path to a *directory* containing model weights saved using
                  [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
                  loaded using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with
                  `num_labels=1` and the keyword arguments in `args.model_init_kwargs`.
                - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
                - A custom reward function: The function is provided with the prompts and the generated completions,
                  plus any additional columns in the dataset. It should return a list of rewards. For more details,
                  see [Using a custom reward function](#using-a-custom-reward-function).
            - A list of reward functions, where each item can independently be any of the above types. Mixing
              different types within the list (e.g., a string model ID and a custom reward function) is allowed.
        args ([`GRPOConfig`], *optional*, defaults to `None`):
            Configuration for this trainer. If `None`, a default configuration is used.
        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
            Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset
            are ignored. The format of the samples can be either:

            - [Standard](dataset_formats#standard): Each sample contains plain text.
            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
              and content).
        eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
            Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
            Processing class used to process the data. The padding side must be set to "left". If `None`, the
            processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
        reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
            Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:

            - A single processing class: Used when `reward_funcs` contains only one reward function.
            - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.

            If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
            `None`, the tokenizer for the model is automatically loaded using
            [`~transformers.AutoTokenizer.from_pretrained`]. For elements in `reward_funcs` that are custom reward
            functions (not [`~transformers.PreTrainedModel`]), the corresponding entries in
            `reward_processing_classes` are ignored.
        callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
            List of callbacks to customize the training loop. Will add those to the list of default callbacks
            detailed [here](https://huggingface.co/docs/transformers/main_classes/callback).

            If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
            method.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on
            your model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
        peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
            PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
    """

    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        reward_funcs: Union[RewardFunc, list[RewardFunc]],
        args: DNALLMGRPOConfig = None,
        dna_module: DNABaseModule = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
        peft_config: Optional["PeftConfig"] = None,
        freeze_dna_modules: Optional[bool] = False,
        attn_implementation: str = "flash_attention_2",
        torch_dtype: str = "bfloat16",
        **kwargs,
    ):
        # Args
        if args is None:
            model_name = model if isinstance(model, str) else model.config._name_or_path
            model_name = model_name.split("/")[-1]
            args = GRPOConfig(f"{model_name}-GRPO")

        self.dna_module = dna_module

        # Model initialization kwargs
        model_init_kwargs = args.model_init_kwargs or {}
        model_init_kwargs["attn_implementation"] = attn_implementation
        if model_init_kwargs.get("torch_dtype") is None:
            model_init_kwargs["torch_dtype"] = torch_dtype

        assert not isinstance(model, str), "model must NOT be a string in the current implementation"

        torch_dtype = model_init_kwargs.get("torch_dtype")
        if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
            pass  # already a `torch.dtype`, "auto", or `None`
        elif isinstance(torch_dtype, str):  # e.g. "bfloat16"
            torch_dtype = getattr(torch, torch_dtype)
        else:
            raise ValueError(
                "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
                f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
            )

        # Caching is incompatible with gradient checkpointing
        model_init_kwargs["use_cache"] = (
            False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
        )

        self.dna_modules_keywords = self.dna_module.get_dnallm_modules_keywords()
        if peft_config is not None:
            print("Applying LoRA...")

            def find_all_linear_names(model, multimodal_keywords):
                cls = torch.nn.Linear
                lora_module_names = set()
                for name, module in model.named_modules():
                    # Skip DNA-module components so LoRA only targets the language model
                    if any(mm_keyword in name for mm_keyword in multimodal_keywords):
                        continue
                    if isinstance(module, cls):
                        lora_module_names.add(name)
                # Exclude embedding layers from the LoRA targets
                return [name for name in lora_module_names if "embed_tokens" not in name]

            target_modules = find_all_linear_names(model, self.dna_modules_keywords)
            peft_config.target_modules = target_modules
            model = prepare_model_for_kbit_training(model)
            model = get_peft_model(model, peft_config)
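
            # Illustration (module names are hypothetical): for a model whose LLM
            # block has layers like `language_model.layers.0.self_attn.q_proj`, the
            # helper above collects those linear projection names while skipping any
            # module whose path contains a DNA-module keyword, so the LoRA adapters
            # attach to the language model only.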

        # Optionally freeze the DNA encoder while keeping the DNA projection trainable
        if freeze_dna_modules:
            print("Freezing DNA modules...")
            for p in model.dna_model.parameters():
                p.requires_grad = False
            for p in model.dna_projection.parameters():
                p.requires_grad = True

        trainable_params = [p for p in model.parameters() if p.requires_grad]
        total_params = sum(p.numel() for p in trainable_params)
        print(f"Total trainable parameters: {total_params}")

        # Enable gradient checkpointing if requested
        if args.gradient_checkpointing:
            model = self._enable_gradient_checkpointing(model, args)

        # Reference model
        self.beta = args.beta
        if self.beta == 0.0:
            # If beta is 0.0, the reference model is not needed
            self.ref_model = None
        elif is_deepspeed_zero3_enabled():
            # Under ZeRO-3 the policy weights are sharded, so load a separate copy of
            # the pretrained weights for the reference model (`model` is never a
            # string here, see the assert above)
            self.ref_model = type(model).from_pretrained(model.config._name_or_path, **model_init_kwargs)
        elif is_peft_model(model):
            # If PEFT is used, the reference model is not needed since the adapter
            # can be disabled to revert to the initial model
            self.ref_model = None
        else:
            # If no PEFT configuration is provided, create a reference model from the initial model
            self.ref_model = create_reference_model(model)

        # Processing class
        if processing_class is None:
            processing_cls = self.dna_module.get_processing_class()
            processing_class = processing_cls(tokenizer=model.text_tokenizer, dna_tokenizer=model.dna_tokenizer)

        for component, processing_keyword in self.dna_module.get_custom_processing_keywords():
            if processing_keyword in kwargs:
                processing_component = getattr(processing_class, component, processing_class)
                setattr(processing_component, processing_keyword, kwargs[processing_keyword])
        if getattr(processing_class, "tokenizer", None) is not None:
            pad_token_id = processing_class.tokenizer.pad_token_id
            processing_class.pad_token_id = pad_token_id
            processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
        else:
            assert isinstance(
                processing_class, PreTrainedTokenizerBase
            ), "processing_class must be an instance of PreTrainedTokenizerBase if it has no tokenizer attribute"
            pad_token_id = processing_class.pad_token_id

        self.dna_module.post_model_init(model, processing_class)
        if self.ref_model is not None:
            self.dna_module.post_model_init(self.ref_model, processing_class)

        # Reward functions
        if not isinstance(reward_funcs, list):
            reward_funcs = [reward_funcs]
        for i, reward_func in enumerate(reward_funcs):
            if isinstance(reward_func, str):
                reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
                    reward_func, num_labels=1, **model_init_kwargs
                )
        self.reward_funcs = reward_funcs

        # Reward processing classes
        if reward_processing_classes is None:
            reward_processing_classes = [None] * len(reward_funcs)
        elif not isinstance(reward_processing_classes, list):
            reward_processing_classes = [reward_processing_classes]
        else:
            if len(reward_processing_classes) != len(reward_funcs):
                raise ValueError("The number of reward processing classes must match the number of reward functions.")

        for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
            if isinstance(reward_func, PreTrainedModel):
                if reward_processing_class is None:
                    reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
                if reward_processing_class.pad_token_id is None:
                    reward_processing_class.pad_token = reward_processing_class.eos_token
                # The reward model computes the reward for the latest non-padded token, so the
                # model's pad token ID must match the processing class's
                reward_func.config.pad_token_id = reward_processing_class.pad_token_id
                reward_processing_classes[i] = reward_processing_class
        self.reward_processing_classes = reward_processing_classes

        # Data collator: inputs are already processed per sample, so simply pass features through
        def data_collator(features):
            return features

        # Training arguments
        # Truncating the prompt is not currently supported, so force `max_prompt_length` to None
        self.max_prompt_length = None
        if args.max_prompt_length is not None:
            warnings.warn("Setting max_prompt_length is currently not supported, it has been set to None")

        self.max_completion_length = args.max_completion_length
        self.num_generations = args.num_generations  # completions sampled per prompt (G in the GRPO paper)
        self.generation_config = GenerationConfig(
            max_new_tokens=self.max_completion_length,
            do_sample=True,
            temperature=0.6,
            top_p=0.95,
            top_k=20,
            pad_token_id=pad_token_id,
        )
        if hasattr(self.dna_module, "get_eos_token_id"):
            self.generation_config.eos_token_id = self.dna_module.get_eos_token_id(processing_class)
        self.epsilon_low = args.epsilon
        self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon

        # Multi-step optimization
        self.num_iterations = args.num_iterations
        self._step = 0
        # Buffer the batch so generated completions can be reused across iterations
        self._buffered_inputs = [None] * args.gradient_accumulation_steps

        # Mark the token-estimation warning as already issued: the Trainer cannot
        # estimate tokens/FLOPs reliably for these multimodal inputs
        model.warnings_issued["estimate_tokens"] = True

        # Initialize the metrics
        self._metrics = defaultdict(list)
        self.log_completions = args.log_completions

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            callbacks=callbacks,
            optimizers=optimizers,
        )

        num_processes = self.accelerator.num_processes
        global_batch_size = args.per_device_train_batch_size * num_processes
        possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
        if self.num_generations not in possible_values:
            raise ValueError(
                f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
                f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
                f"batch size, the valid values for the number of generations are: {possible_values}."
            )
        if self.args.eval_strategy != "no":
            global_batch_size = args.per_device_eval_batch_size * num_processes
            possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
            if self.num_generations not in possible_values:
                raise ValueError(
                    f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
                    f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
                    f"eval batch size, the valid values for the number of generations are: {possible_values}."
                )
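
        # Worked example: with 2 processes and per_device_train_batch_size=4, the
        # global batch holds 8 sequences, so `num_generations` must be 2, 4, or 8;
        # each contiguous group of `num_generations` rows then belongs to one prompt.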

        # Ensure each process receives a unique seed to prevent duplicate completions
        # when generating with transformers if num_generations exceeds per_device_train_batch_size
        set_seed(args.seed, device_specific=True)

        # The parent Trainer scales the loss by 1 / gradient_accumulation_steps only when
        # `model_accepts_loss_kwargs` is False; since we compute a custom loss, force that path
        self.model_accepts_loss_kwargs = False

        if self.ref_model is not None:
            if is_deepspeed_zero3_enabled():
                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)

    def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
        """Enables gradient checkpointing for the model."""
        # Ensure use_cache is disabled (incompatible with gradient checkpointing)
        model.config.use_cache = False

        # Enable gradient checkpointing on the base model for PEFT
        if is_peft_model(model):
            model.base_model.gradient_checkpointing_enable()
        # Enable gradient checkpointing for non-PEFT models
        else:
            if getattr(model, "language_model", None) is not None:
                # Composite DNA-LLM model: enable checkpointing on each submodule
                model.language_model.config.use_cache = False
                model.dna_model.gradient_checkpointing = True
                model.dna_model.encoder.gradient_checkpointing = True
                model.language_model._set_gradient_checkpointing()
                # Prevent the parent Trainer from calling gradient_checkpointing_enable()
                # again on the wrapper model
                args.gradient_checkpointing = False
            else:
                model.gradient_checkpointing_enable()

        gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
        use_reentrant = (
            "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
        )
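
        # Note: PyTorch's reentrant checkpointing does not automatically track gradients
        # through the model inputs, so `enable_input_require_grads()` is needed below when
        # `use_reentrant` is True (which is also the default when unspecified).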

        if use_reentrant:
            model.enable_input_require_grads()

        return model

    def _set_signature_columns_if_needed(self):
        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
        # By default, this method sets `self._signature_columns` to the model's expected inputs,
        # but GRPO preprocesses the data itself, so we instead declare the columns expected
        # by `training_step`, hence this override.
        if self._signature_columns is None:
            self._signature_columns = ["prompt"]

    # Get the per-token log probabilities for the completions for the model and the reference model
    def _get_per_token_logps(self, model, input_ids, attention_mask, **custom_multimodal_inputs):
        logits = model(input_ids=input_ids, attention_mask=attention_mask, **custom_multimodal_inputs).logits
        logits = logits[:, :-1, :]  # exclude the last logit: it predicts the token after the sequence
        input_ids = input_ids[:, 1:]  # exclude the first input ID: we have no logits for it
        # Compute the log probabilities row by row to reduce peak memory
        per_token_logps = []
        for logits_row, input_ids_row in zip(logits, input_ids):
            log_probs = logits_row.log_softmax(dim=-1)
            token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
            per_token_logps.append(token_log_prob)
        return torch.stack(per_token_logps)
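
    # Shape walk-through: shifting drops the final logit and the first token so that
    # logits[:, t] scores input_ids[:, t + 1]; the result is a (B, L - 1) tensor of
    # log p(token_t | tokens_<t), from which the completion part is sliced later.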

    def _prepare_inputs(self, inputs):
        # Simple pass-through; the real preparation happens in `compute_loss`,
        # where completions are generated and scored
        return inputs

    def _get_key_from_inputs(self, x, key):
        ele = x.get(key, None)
        assert ele is not None, f"The key {key} is not found in the input"
        if isinstance(ele, list):
            return list(ele)
        else:
            return [ele]
    def _generate_and_score_completions(
        self, inputs: dict[str, Union[torch.Tensor, Any]], model
    ) -> dict[str, Union[torch.Tensor, Any]]:
        device = self.accelerator.device
        prompts = [x["prompt"] for x in inputs]
        prompts_text = self.dna_module.prepare_prompt(self.processing_class, inputs)

        # Collect the DNA sequences (if any) for each sample in the batch
        batch_dna_sequences = []
        for x in inputs:
            if "dna_sequences" in x:
                batch_dna_sequences.append(self._get_key_from_inputs(x, "dna_sequences"))

        prompt_inputs = self.dna_module.prepare_model_inputs(
            self.processing_class,
            model,
            prompts_text,
            batch_dna_sequences,
            return_tensors="pt",
            padding=True,
            padding_side="left",
            add_special_tokens=False,
        )

        prompt_inputs = super()._prepare_inputs(prompt_inputs)
        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

        # Generate completions
        start = time.time()
        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
            kwargs = {k: v for k, v in prompt_inputs.items() if k not in self.dna_module.get_non_generate_params()}
            generate_returned_result = unwrapped_model.generate(
                **kwargs, generation_config=self.generation_config
            )
        end = time.time()
        print(f"Generation time: {end - start:.9f} seconds")
        prompt_length = prompt_ids.size(1)
        if not self.dna_module.is_embeds_input():
            # `generate` returned the full prompt + completion sequence
            prompt_completion_ids = generate_returned_result
            prompt_ids = prompt_completion_ids[:, :prompt_length]
            completion_ids = prompt_completion_ids[:, prompt_length:]
        else:
            # When generating from input embeddings, `generate` returns only the new tokens
            completion_ids = generate_returned_result
            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)

        # Mask everything after the first EOS token
        is_eos = completion_ids == self.processing_class.eos_token_id
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
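
        # Worked example: a completion row [a, b, EOS, c] gives is_eos = [0, 0, 1, 0],
        # eos_idx = 2, and completion_mask = [1, 1, 1, 0]; the EOS token itself is kept
        # and everything after it is masked. Rows with no EOS keep a full mask of ones.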

        # Concatenate prompt_mask with completion_mask for logit computation
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)

        multimodal_keywords = self.dna_module.get_custom_multimodal_keywords()
        multimodal_inputs = {k: prompt_inputs[k] if k in prompt_inputs else None for k in multimodal_keywords}
        with torch.no_grad():
            # When using num_iterations == 1, old_per_token_logps == per_token_logps,
            # so we can skip its computation here and use per_token_logps.detach() instead
            if self.num_iterations > 1:
                old_per_token_logps = self._get_per_token_logps(
                    model, prompt_completion_ids, attention_mask, **multimodal_inputs
                )
                old_per_token_logps = old_per_token_logps[:, prompt_length - 1 :]
            else:
                old_per_token_logps = None

            if self.beta == 0.0:
                ref_per_token_logps = None
            elif self.ref_model is not None:
                ref_per_token_logps = self._get_per_token_logps(
                    self.ref_model, prompt_completion_ids, attention_mask, **multimodal_inputs
                )
            else:
                # PEFT case: disable the adapter to recover the reference policy
                with self.accelerator.unwrap_model(model).disable_adapter():
                    ref_per_token_logps = self._get_per_token_logps(
                        model, prompt_completion_ids, attention_mask, **multimodal_inputs
                    )
            if ref_per_token_logps is not None:
                ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]

        # Decode the generated completions
        completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
        if is_conversational(inputs[0]):
            completions = [[{"role": "assistant", "content": completion}] for completion in completions_text]
        else:
            completions = completions_text

        # Compute the rewards
        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
        for i, (reward_func, reward_processing_class) in enumerate(
            zip(self.reward_funcs, self.reward_processing_classes)
        ):
            if isinstance(reward_func, PreTrainedModel):
                if is_conversational(inputs[0]):
                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                else:
                    texts = [p + c for p, c in zip(prompts, completions)]
                reward_inputs = reward_processing_class(
                    texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
                )
                reward_inputs = super()._prepare_inputs(reward_inputs)
                with torch.inference_mode():
                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]
            else:
                # Custom reward function: pass any extra dataset columns through as kwargs
                reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
                for key in reward_kwargs:
                    for example in inputs:
                        reward_kwargs[key].extend([example[key]])
                output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

        # Gather rewards across processes so group statistics cover all completions of a prompt
        rewards_per_func = self.accelerator.gather(rewards_per_func)

        # Sum the rewards from all reward functions
        rewards = rewards_per_func.sum(dim=1)

        # Compute group-wise rewards
        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

        # Normalize the rewards to compute the advantages
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
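
        # GRPO advantage: with G = num_generations completions per prompt,
        #   A_i = (r_i - mean(r_1..r_G)) / (std(r_1..r_G) + 1e-4)
        # i.e. rewards are normalized within each prompt's group; the 1e-4 term
        # guards against division by zero when a group's rewards are identical.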

        # Slice to keep only the local part of the data
        process_slice = slice(
            self.accelerator.process_index * len(prompts),
            (self.accelerator.process_index + 1) * len(prompts),
        )
        advantages = advantages[process_slice]

        # Log the metrics
        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics["completion_length"].append(completion_length)

        reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
            else:
                reward_func_name = reward_func.__name__
            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

        self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
        self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())

        if (
            self.log_completions
            and self.state.global_step % self.args.logging_steps == 0
            and "wandb" in self.args.report_to
        ):
            timestamp = time.time()

            # Gather prompts/completions from all processes for logging
            num_items = len(gather_object(prompts_text))
            table = {
                "step": [f"{self.state.global_step}_{timestamp}"] * num_items,
                "prompt": gather_object(prompts_text),
                "completion": gather_object(completions_text),
                "reward": rewards.tolist(),
            }
            df = pd.DataFrame(table)

            if wandb.run is not None and self.accelerator.is_main_process:
                wandb.log({f"completions_{self.state.global_step}_{timestamp}": wandb.Table(dataframe=df)})

        return {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "old_per_token_logps": old_per_token_logps,
            "ref_per_token_logps": ref_per_token_logps,
            "advantages": advantages,
            "multimodal_inputs": multimodal_inputs,
        }

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")

        # Generate completions every `num_iterations` steps and buffer them so they
        # can be reused for the intermediate optimization iterations
        if self.state.global_step % self.num_iterations == 0:
            inputs = self._generate_and_score_completions(inputs, model)
            self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
        else:
            inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
        self._step += 1

        # Get the prepared inputs
        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
        multimodal_inputs = inputs["multimodal_inputs"]

        # Concatenate for the full sequence
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)

        # Get the current policy's log probabilities
        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, **multimodal_inputs)
        # Get rid of the prompt (-1 because of the shift done in _get_per_token_logps)
        per_token_logps = per_token_logps[:, prompt_ids.size(1) - 1 :]

        advantages = inputs["advantages"]
        # When using num_iterations == 1, old_per_token_logps == per_token_logps,
        # so we can skip its computation and use per_token_logps.detach() instead
        old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()

        # Clipped surrogate objective
        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
        coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)

        # Add the KL penalty if beta > 0
        if self.beta > 0:
            ref_per_token_logps = inputs["ref_per_token_logps"]
            # Unbiased, always-positive KL estimator (the "k3" form):
            #   kl ≈ exp(Δ) - Δ - 1,  with Δ = logp_ref - logp
            per_token_kl = (
                torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
            )
            per_token_loss = per_token_loss + self.beta * per_token_kl

            # Log mean KL (only defined when the penalty is active)
            mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
            self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

        # Final loss: average the per-token loss over each completion, then over the batch
        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

        # Log clip ratio: fraction of completion tokens where the clipped term was the active minimum
        is_clipped = (per_token_loss1 < per_token_loss2).float()
        clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
        self._metrics["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())

        return loss

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics
        logs = {**logs, **metrics}
        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
            super().log(logs, start_time)
        else:  # transformers <= 4.46 does not accept `start_time`
            super().log(logs)
        self._metrics.clear()

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        if not self.is_world_process_zero():
            return

        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        tags = tags or []
        if isinstance(tags, str):
            tags = [tags]

        if hasattr(self.model.config, "unsloth_version"):
            tags.append("unsloth")

        citation = textwrap.dedent(
            """\
            @article{zhihong2024deepseekmath,
                title  = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
                author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
                year   = 2024,
                eprint = {arXiv:2402.03300},
            }
            """
        )

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="GRPO",
            trainer_citation=citation,
            paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
            paper_id="2402.03300",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))

    def _get_train_sampler(self) -> Sampler:
        """Returns a sampler that ensures proper data sampling for GRPO training."""
        effective_batch_size = (
            self.args.per_device_train_batch_size
            * self.accelerator.num_processes
            * self.args.gradient_accumulation_steps
        )

        # Each prompt is repeated `num_generations` times within a batch and the whole
        # sampling pass is replayed `num_iterations` times for multi-step optimization
        return RepeatRandomSampler(
            data_source=self.train_dataset,
            mini_repeat_count=self.num_generations,
            batch_size=effective_batch_size // self.num_generations,
            repeat_count=self.num_iterations,
            seed=self.args.seed,
        )
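
    # Worked example: per_device_train_batch_size=4, 2 processes, and
    # gradient_accumulation_steps=2 give an effective batch of 16; with
    # num_generations=4 the sampler draws 16 // 4 = 4 unique prompts per pass,
    # repeats each 4 times, and replays the pass num_iterations times.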

    def _get_eval_sampler(self, eval_dataset) -> Sampler:
        """Returns a sampler for evaluation."""
        return RepeatRandomSampler(
            data_source=eval_dataset,
            mini_repeat_count=self.num_generations,
            seed=self.args.seed,
        )