import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional

from swift.llm import MODEL_MAPPING
from swift.trainers.arguments import GRPOArgumentsMixin
from swift.utils import get_logger, is_master, set_default_ddp_config
from .train_args import TrainArguments

logger = get_logger()


@dataclass
class RewardModelArguments:
    reward_model: Optional[str] = None
    reward_adapters: List[str] = field(default_factory=list)
    reward_model_type: Optional[str] = field(
        default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'})
    reward_model_revision: Optional[str] = None
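
# A minimal usage sketch (not part of the library); the model id below is a
# hypothetical placeholder, not a real checkpoint:
#
#     rm_args = RewardModelArguments(
#         reward_model='my-org/my-reward-model',  # hypothetical id
#         reward_model_type=None,  # None lets swift resolve the type from the checkpoint
#     )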


@dataclass
class PPOArguments:
    num_ppo_epochs: int = 4
    whiten_rewards: bool = False
    kl_coef: float = 0.05  # coefficient of the KL penalty against the reference policy
    cliprange: float = 0.2  # PPO clip range for the policy ratio
    vf_coef: float = 0.1  # weight of the value-function loss
    cliprange_value: float = 0.2  # clip range for value-function updates
    gamma: float = 1.0  # discount factor
    lam: float = 0.95  # lambda for Generalized Advantage Estimation (GAE)

    num_mini_batches: int = 1
    local_rollout_forward_batch_size: int = 64
    num_sample_generations: int = 10
    response_length: int = 512
    missing_eos_penalty: Optional[float] = None  # penalty applied when a completion has no EOS token
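
# For orientation only (standard PPO/GAE definitions, not swift-specific code):
# `gamma` and `lam` parameterize Generalized Advantage Estimation,
#     delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
#     A_t = sum_l (gamma * lam)**l * delta_{t+l}
# and `cliprange` bounds the policy ratio in the PPO objective:
#     L = min(ratio * A, clip(ratio, 1 - cliprange, 1 + cliprange) * A)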


@dataclass
class GRPOArguments(GRPOArgumentsMixin):
    num_generations: int = 8  # the group size G in the GRPO paper
    max_completion_length: int = 512
    ds3_gather_for_generation: bool = True
    reward_funcs: List[str] = field(default_factory=list)
    reward_weights: Optional[List[float]] = None
    log_completions: bool = False

    # rollout backend
    use_vllm: bool = False

    # optimization iterations per generation batch
    num_iterations: int = 1

    truncation_strategy: Literal['delete', 'left', 'right', None] = None
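
# For orientation: GRPO samples `num_generations` completions per prompt and normalizes
# rewards within each group, removing the need for a value model:
#     A_i = (r_i - mean(r_1..r_G)) / std(r_1..r_G)
# This is also why `_check_grpo` below requires at least 2 generations per prompt.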


@dataclass
class RLHFArguments(GRPOArguments, PPOArguments, RewardModelArguments, TrainArguments):
    """
    RLHFArguments is a dataclass that holds arguments specific to the Reinforcement
    Learning from Human Feedback (RLHF) training backend.

    Args:
        rlhf_type (Literal): Specifies the type of RLHF to use. Default is 'dpo'.
            Allowed values are 'dpo', 'orpo', 'simpo', 'kto', 'cpo', 'rm', 'ppo', 'grpo'.
        ref_model (Optional[str]): The reference model for KL regularization. Default is None.
        ref_model_type (Optional[str]): Type of the reference model. Default is None.
        ref_model_revision (Optional[str]): Revision of the reference model. Default is None.
        beta (Optional[float]): Beta parameter for RLHF. Default is None; a per-algorithm
            default is set in __post_init__.
        label_smoothing (float): Label smoothing value. Default is 0.
        loss_scale (Optional[str]): Loss-scale strategy. Default is None; set in __post_init__.
        rpo_alpha (float): Alpha parameter for RPO. Default is 1.
        cpo_alpha (float): Alpha parameter for CPO. Default is 1.
        simpo_gamma (float): Gamma (target reward margin) parameter for SimPO. Default is 1.
        desirable_weight (float): Weight for desirable outcomes in KTO. Default is 1.0.
        undesirable_weight (float): Weight for undesirable outcomes in KTO. Default is 1.0.
        temperature (float): Sampling temperature for PPO/GRPO rollouts. Default is 0.9.
        center_rewards_coefficient (Optional[float]): Coefficient for centering rewards in
            reward modeling. Default is None.
    """
    rlhf_type: Literal['dpo', 'orpo', 'simpo', 'kto', 'cpo', 'rm', 'ppo', 'grpo'] = 'dpo'
    ref_model: Optional[str] = None
    ref_model_type: Optional[str] = field(
        default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'})
    ref_model_revision: Optional[str] = None

    beta: Optional[float] = None
    label_smoothing: float = 0
    loss_scale: Optional[str] = None
    # RPO
    rpo_alpha: float = 1.
    # CPO
    cpo_alpha: float = 1.
    # SimPO
    simpo_gamma: float = 1
    # KTO
    desirable_weight: float = 1.0
    undesirable_weight: float = 1.0
    # PPO/GRPO rollout
    temperature: float = 0.9
    # reward modeling
    center_rewards_coefficient: Optional[float] = None
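
    # A minimal launch sketch (model and dataset ids are placeholders):
    #     swift rlhf --rlhf_type dpo --model my-org/my-base-model \
    #         --dataset my-dpo-dataset --train_type lora
    # With --train_type full, __post_init__ below defaults ref_model to the trained model.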

    def _prepare_training_args(self, training_args: Dict[str, Any]) -> None:
        if self.rlhf_type == 'ppo':
            training_args['world_size'] = self.global_world_size

    def __post_init__(self):
        self._init_grpo()
        self._init_rm()
        self._init_simpo()
        self._init_ppo()
        self._set_default()
        self._init_external_vllm()
        super().__post_init__()
        self._check_grpo()
        self._external_vllm_warning()
        if self.loss_scale is None:
            if self.rlhf_type == 'orpo' and not self.model_meta.is_multimodal:
                # ORPO on text-only models uses the default loss scale;
                # all other cases compute the loss only on the last round.
                self.loss_scale = 'default'
            else:
                self.loss_scale = 'last_round'
        if self.rlhf_type == 'grpo' and self.beta == 0.0:
            # beta == 0 disables the KL term, so no reference model is needed.
            self.ref_model = None
        elif self.rlhf_type in ['dpo', 'kto', 'ppo', 'grpo'] and self.train_type == 'full':
            self.ref_model = self.ref_model or self.model
            self.ref_model_type = self.ref_model_type or self.model_type
            self.ref_model_revision = self.ref_model_revision or self.model_revision
        elif self.ref_model is not None:
            raise ValueError('CPO/ORPO and LoRA training do not use a ref_model; please remove the argument.')

    def _init_grpo(self):
        if self.rlhf_type == 'grpo':
            if self.use_vllm or self.use_lmdeploy:
                os.environ['USE_FAST_INFERENCE'] = '1'
                set_default_ddp_config()
            if self.async_generate or not self.use_vllm:
                self.sleep_level = 0
            self.remove_unused_columns = False
            logger.info(f'Setting args.remove_unused_columns: {self.remove_unused_columns}')
            if self.truncation_strategy is None:
                self.truncation_strategy = 'left'
            assert self.truncation_strategy == 'left', \
                "GRPO requires `truncation_strategy='left'`, " \
                f"but got `truncation_strategy='{self.truncation_strategy}'`."
            if self.beta is None:
                self.beta = 0.04  # default KL coefficient from the GRPO paper
            if self.async_generate:
                logger.info('Using async mode. This is an approximate mode that generates responses '
                            'with stale weights to accelerate training. It ignores the `CLIP` term on '
                            'the advantages; if training is unstable, consider setting '
                            '--async_generate false.')
            if 'soft_overlong' in self.reward_funcs:
                assert self.soft_cache_length is not None, \
                    'soft_cache_length must be set when using the soft overlong reward.'
                if self.soft_max_length is None:
                    self.soft_max_length = self.max_completion_length
                    logger.info(f'Auto-configured soft_max_length = max_completion_length '
                                f'{self.max_completion_length}')
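
    # For orientation: the 'soft_overlong' reward applies soft length punishment in the style
    # of DAPO's overlong reward shaping. A sketch, with L = completion length,
    # L_max = soft_max_length and L_cache = soft_cache_length (assumed semantics):
    #     penalty = 0                                  if L <= L_max - L_cache
    #     penalty = ((L_max - L_cache) - L) / L_cache  if L_max - L_cache < L <= L_max
    #     penalty = -1                                 if L > L_max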

    def _init_ppo(self):
        if self.rlhf_type == 'ppo':
            self.padding_side = 'left'  # decoder-only generation requires left padding

    def _init_metric_for_best_model(self):
        if self.rlhf_type not in {'ppo', 'grpo'}:
            super()._init_metric_for_best_model()
        elif self.rlhf_type == 'grpo' and self.metric_for_best_model is None:
            self.metric_for_best_model = 'reward'

    def _init_simpo(self):
        if self.rlhf_type != 'simpo':
            return
        # SimPO is implemented on top of the CPO trainer with loss_type 'simpo'.
        self.rlhf_type = 'cpo'
        if self.loss_type is None:
            self.loss_type = 'simpo'
        if self.beta is None:
            self.beta = 2.
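
    # For orientation (SimPO): the implicit reward is the length-normalized average
    # log-probability, compared against a target margin (`simpo_gamma`); beta defaults to 2 here:
    #     L = -log sigmoid(beta / |y_w| * log pi(y_w|x) - beta / |y_l| * log pi(y_l|x) - simpo_gamma)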

    def _init_rm(self):
        if self.rlhf_type == 'rm':
            # reward modeling is trained as sequence classification with a single scalar label
            self.task_type = 'seq_cls'
            self.num_labels = 1

    def _init_external_vllm(self):
        if self.rlhf_type != 'grpo' or self.vllm_server_host is None:
            return
        from swift.trainers.rlhf_trainer.vllm_client import VLLMClient
        if is_master():
            self.vllm_client = VLLMClient(
                self.vllm_server_host, self.vllm_server_port, connection_timeout=self.vllm_server_timeout)
            self.vllm_client.init_communicator()
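
    # A launch sketch for the external-server setup (the model id is a placeholder):
    # start the inference service first, then point training at it, e.g.
    #     swift deploy --model my-org/my-base-model --max_model_len 8192
    #     swift rlhf --rlhf_type grpo ... --vllm_server_host <ip> --vllm_server_port <port>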

    def _set_default(self):
        if self.beta is None:
            self.beta = 0.1
        if self.loss_type is None:
            if self.rlhf_type in ['dpo', 'cpo']:
                self.loss_type = 'sigmoid'
            elif self.rlhf_type in ['kto']:
                self.loss_type = 'kto'
            elif self.rlhf_type == 'grpo':
                self.loss_type = 'grpo'

    def _check_grpo(self):
        if self.rlhf_type != 'grpo':
            return
        from packaging import version
        import trl
        trl_version = version.parse(trl.__version__)
        assert trl_version >= version.parse('0.17'), ('Your current version of `trl` is outdated. '
                                                      'Please update it by running: pip install -U trl')

        if self.num_generations < 2:
            raise ValueError(
                'GRPO requires at least 2 generations per prompt to calculate the advantages. You provided '
                f'{self.num_generations}, which is less than the minimum required.')
        from swift.utils import get_device_count, get_dist_setting
        device_count = get_device_count()
        _, _, _, local_world_size = get_dist_setting()
        num_infer_workers = self.num_infer_workers
        fast_infer = self.use_vllm or self.use_lmdeploy
        if fast_infer and self.vllm_server_host is None:
            is_colocate_mode = (device_count == num_infer_workers)

            if is_colocate_mode:
                # colocate mode: training and rollout share the same GPUs
                assert device_count == local_world_size, (
                    f'Colocate mode requires device_count({device_count}) == local_world_size({local_world_size}). '
                    'Please check if your device count matches the NPROC_PER_NODE setting.')
                logger.info(
                    'You are using colocate mode because num_infer_workers equals NPROC_PER_NODE: '
                    'model training and sampling will be performed on the same GPUs. '
                    'If you encounter an Out-of-Memory (OOM) error, it is recommended to set the `sleep_level`, '
                    '`offload_model`, and `offload_optimizer` parameters.')
                assert not self.async_generate, 'async_generate requires async mode, but you are in colocate mode'
                if self.use_lmdeploy and self.tensor_parallel_size > 1:
                    raise ValueError('LMDeploy does not currently support tensor parallelism')
                if self.use_vllm and not self.sleep_level:
                    logger.warning('It is highly recommended to set `sleep_level=1` in colocate mode, '
                                   'otherwise it may lead to an OOM (Out of Memory) error.')
            else:
                # async mode: training and rollout run on disjoint GPUs
                assert device_count == (local_world_size + num_infer_workers), (
                    f'Async mode requires total GPUs({device_count}) = training GPUs({local_world_size}) + '
                    f'inference workers({num_infer_workers}). Please adjust your GPU allocation.')
                logger.info(
                    'You are using async mode, where model training and sampling are performed on different GPUs.')
                if self.sleep_level > 0:
                    logger.warning('You are using different GPUs for training and rollout, '
                                   'so you do not need sleep_level > 0')

                assert self.tensor_parallel_size == 1, 'async mode does not support tensor parallelism yet'
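
        # Worked example of the two layouts above (numbers are illustrative):
        #     8 GPUs, NPROC_PER_NODE=8, num_infer_workers=8 -> colocate mode (8 == 8)
        #     8 GPUs, NPROC_PER_NODE=6, num_infer_workers=2 -> async mode (8 == 6 + 2)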

    def _external_vllm_warning(self):
        if self.rlhf_type != 'grpo' or not self.vllm_server_host:
            return

        if self.vllm_device != 'auto':
            logger.warning("Configuration conflict: external vLLM engine detected, but 'vllm_device' is set to "
                           "'%s'. This setting will be ignored.", self.vllm_device)

        if self.num_infer_workers != 1:
            logger.warning(
                "Auto-adjustment: changing 'num_infer_workers' from %s to 1 because an external vLLM engine "
                'was detected.', self.num_infer_workers)
            self.num_infer_workers = 1

        if self.vllm_max_model_len is not None:
            logger.warning(
                "Configuration conflict: 'vllm_max_model_len=%s' is ignored for external vLLM. "
                'Please specify it when launching the inference service: '
                '`swift deploy --max_model_len <value>`', self.vllm_max_model_len)