# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional

from swift.llm import MODEL_MAPPING
from swift.trainers.arguments import GRPOArgumentsMixin
from swift.utils import get_logger, is_master, set_default_ddp_config
from .train_args import TrainArguments

logger = get_logger()


@dataclass
class RewardModelArguments:
reward_model: Optional[str] = None
reward_adapters: List[str] = field(default_factory=list)
reward_model_type: Optional[str] = field(
default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'})
    reward_model_revision: Optional[str] = None


@dataclass
class PPOArguments:
num_ppo_epochs: int = 4
whiten_rewards: bool = False
kl_coef: float = 0.05
cliprange: float = 0.2
vf_coef: float = 0.1
cliprange_value: float = 0.2
    gamma: float = 1.0  # discount factor for generalized advantage estimation (GAE)
    lam: float = 0.95  # GAE lambda
num_mini_batches: int = 1
local_rollout_forward_batch_size: int = 64
num_sample_generations: int = 10
response_length: int = 512
    missing_eos_penalty: Optional[float] = None


@dataclass
class GRPOArguments(GRPOArgumentsMixin):
num_generations: int = 8 # G in the GRPO paper
max_completion_length: int = 512
ds3_gather_for_generation: bool = True
reward_funcs: List[str] = field(default_factory=list)
    reward_weights: Optional[List[float]] = None
log_completions: bool = False
# vLLM in GRPO
use_vllm: bool = False
    # multi-step: number of optimization iterations per batch of rollouts (μ in the GRPO paper)
    num_iterations: int = 1
    truncation_strategy: Literal['delete', 'left', 'right', None] = None


@dataclass
class RLHFArguments(GRPOArguments, PPOArguments, RewardModelArguments, TrainArguments):
"""
RLHFArguments is a dataclass that holds arguments specific to the Reinforcement
Learning with Human Feedback (RLHF) training backend.
Args:
rlhf_type (Literal): Specifies the type of RLHF to use. Default is 'dpo'.
Allowed values are 'dpo', 'orpo', 'simpo', 'kto', 'cpo'.
ref_model_type (Optional[str]): Type of reference model. Default is None.
ref_model_revision (Optional[str]): Revision of the reference model. Default is None.
beta (Optional[float]): Beta parameter for RLHF. Default is None.
label_smoothing (float): Label smoothing value. Default is 0.
rpo_alpha (float): Alpha parameter for RPO. Default is 1.
cpo_alpha (float): Alpha parameter for CPO. Default is 1.
simpo_gamma (float): Gamma parameter for SimPO. Default is 1.
desirable_weight (float): Weight for desirable outcomes in KTO. Default is 1.0.
undesirable_weight (float): Weight for undesirable outcomes in KTO. Default is 1.0.
"""
rlhf_type: Literal['dpo', 'orpo', 'simpo', 'kto', 'cpo', 'rm', 'ppo', 'grpo'] = 'dpo'
ref_model: Optional[str] = None
ref_model_type: Optional[str] = field(
default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'})
ref_model_revision: Optional[str] = None
beta: Optional[float] = None
label_smoothing: float = 0
loss_scale: Optional[str] = None # 'last_round'
# DPO
rpo_alpha: float = 1.
# CPO
cpo_alpha: float = 1.
# SimPO
simpo_gamma: float = 1
# KTO
desirable_weight: float = 1.0
undesirable_weight: float = 1.0
# PPO/GRPO
temperature: float = 0.9
# RM
    center_rewards_coefficient: Optional[float] = None

    def _prepare_training_args(self, training_args: Dict[str, Any]) -> None:
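        # Only PPO needs the global world size forwarded into the training arguments.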
if self.rlhf_type == 'ppo':
            training_args['world_size'] = self.global_world_size

    def __post_init__(self):
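        # Resolve per-algorithm defaults before the base __post_init__, then run GRPO/vLLM validation.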
self._init_grpo()
self._init_rm()
self._init_simpo()
self._init_ppo()
self._set_default()
self._init_external_vllm()
super().__post_init__()
self._check_grpo()
self._external_vllm_warning()
if self.loss_scale is None:
if self.rlhf_type == 'orpo' and not self.model_meta.is_multimodal:
# Avoid padding labels during the model's forward pass in multimodal models.
# Some multimodal models do not expand the image pad token.
self.loss_scale = 'default'
else:
self.loss_scale = 'last_round'
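        # With beta == 0 the KL penalty vanishes, so GRPO needs no reference model.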
if self.rlhf_type == 'grpo' and self.beta == 0.0:
self.ref_model = None
elif self.rlhf_type in ['dpo', 'kto', 'ppo', 'grpo'] and self.train_type == 'full':
self.ref_model = self.ref_model or self.model
self.ref_model_type = self.ref_model_type or self.model_type
self.ref_model_revision = self.ref_model_revision or self.model_revision
elif self.ref_model is not None:
            raise ValueError('CPO/ORPO or LoRA training does not require a ref_model to be passed in.')

    def _init_grpo(self):
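        # GRPO defaults: fast-inference setup, left truncation, and the KL coefficient beta.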
if self.rlhf_type == 'grpo':
if self.use_vllm or self.use_lmdeploy:
os.environ['USE_FAST_INFERENCE'] = '1'
set_default_ddp_config()
if self.async_generate or not self.use_vllm:
self.sleep_level = 0
self.remove_unused_columns = False
logger.info(f'Setting args.remove_unused_columns: {self.remove_unused_columns}')
if self.truncation_strategy is None:
self.truncation_strategy = 'left'
            assert self.truncation_strategy == 'left', \
                "GRPO requires `truncation_strategy='left'`, " \
                f"but got `truncation_strategy='{self.truncation_strategy}'`."
if self.beta is None:
self.beta = 0.04 # https://arxiv.org/abs/2402.03300
            if self.async_generate:
                logger.info('Using async mode. This is an approximate version that '
                            'generates responses with the old weights to accelerate training. '
                            'It ignores the clipping of advantages; if you find training '
                            'unstable, consider setting --async_generate false.')
if 'soft_overlong' in self.reward_funcs:
assert self.soft_cache_length is not None, \
'The soft_cache_length must be set when using soft overlong rewards.'
                if self.soft_max_length is None:
                    self.soft_max_length = self.max_completion_length
                    logger.info(f'Auto-configured soft_max_length = max_completion_length ({self.max_completion_length}).')

    def _init_ppo(self):
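        # Generation with decoder-only models requires left padding.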
if self.rlhf_type == 'ppo':
self.padding_side = 'left'
            # TODO: streaming, MLLM

    def _init_metric_for_best_model(self):
if self.rlhf_type not in {'ppo', 'grpo'}:
super()._init_metric_for_best_model()
elif self.rlhf_type == 'grpo' and self.metric_for_best_model is None:
            self.metric_for_best_model = 'reward'

    def _init_simpo(self):
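        # SimPO is implemented via the CPO trainer with loss_type='simpo'.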
if self.rlhf_type != 'simpo':
return
self.rlhf_type = 'cpo'
if self.loss_type is None:
self.loss_type = 'simpo'
if self.beta is None:
            self.beta = 2.

    def _init_rm(self):
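        # Reward modeling is trained as single-label sequence classification.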
if self.rlhf_type == 'rm':
self.task_type = 'seq_cls'
            self.num_labels = 1

    def _init_external_vllm(self):
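        # Connect to an external vLLM rollout server; only the master rank creates the client.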
if self.rlhf_type != 'grpo' or self.vllm_server_host is None:
return
from swift.trainers.rlhf_trainer.vllm_client import VLLMClient
if is_master():
self.vllm_client = VLLMClient(
self.vllm_server_host, self.vllm_server_port, connection_timeout=self.vllm_server_timeout)
            self.vllm_client.init_communicator()

    def _set_default(self):
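        # Fallback defaults for beta and loss_type, shared across the RLHF algorithms.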
if self.beta is None:
self.beta = 0.1
if self.loss_type is None:
if self.rlhf_type in ['dpo', 'cpo']:
self.loss_type = 'sigmoid' # else None
elif self.rlhf_type in ['kto']:
self.loss_type = 'kto'
elif self.rlhf_type == 'grpo':
                self.loss_type = 'grpo'

    def _check_grpo(self):
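        # Validate the trl version, the generation count, and the GPU allocation for GRPO.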
if self.rlhf_type != 'grpo':
return
from packaging import version
import trl
trl_version = version.parse(trl.__version__)
assert trl_version >= version.parse('0.17'), ('Your current version of `trl` is outdated. '
'Please update it by running: pip install -U trl')
if self.num_generations < 2:
raise ValueError(
'GRPO requires at least 2 generations per prompt to calculate the advantages. You provided '
f'{self.num_generations}, which is less than the minimum required.')
from swift.utils import get_device_count, get_dist_setting
device_count = get_device_count()
_, _, _, local_world_size = get_dist_setting()
num_infer_workers = self.num_infer_workers
fast_infer = self.use_vllm or self.use_lmdeploy
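        # Colocate mode shares GPUs between training and rollout; async mode dedicates separate GPUs to rollout.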
if fast_infer and self.vllm_server_host is None:
is_colocate_mode = (device_count == num_infer_workers)
if is_colocate_mode:
# colocate mode
                assert device_count == local_world_size, (
                    f'Colocate mode requires device_count({device_count}) == local_world_size({local_world_size}). '
                    'Please check if your device count matches the NPROC_PER_NODE setting.')
                logger.info(
                    'You are using colocate mode because num_infer_workers equals NPROC_PER_NODE: '
                    'model training and sampling will be performed on the same GPUs. '
                    'If you encounter an Out-of-Memory (OOM) error, it is recommended to set the `sleep_level`, '
                    '`offload_model`, and `offload_optimizer` parameters.')
                assert not self.async_generate, 'async_generate requires async mode, but you are in colocate mode.'
                if self.use_lmdeploy and self.tensor_parallel_size > 1:
                    raise ValueError('Currently LMDeploy does not support tensor parallelism.')
                if self.use_vllm and not self.sleep_level:
                    logger.warning('It is highly recommended to set `sleep_level=1` in colocate mode, '
                                   'otherwise it may lead to an OOM (Out of Memory) error.')
else:
# async mode
assert device_count == (local_world_size + num_infer_workers), (
f'Async mode requires total GPUs({device_count}) = training GPUs({local_world_size}) + '
f'inference workers({num_infer_workers}). Please adjust your GPU allocation.')
logger.info(
'You are using async mode, where model training and sampling will be performed on different GPUs.')
                if self.sleep_level > 0:
                    logger.warning('You are using different GPUs for training and rollout, '
                                   'so you do not need to set sleep_level > 0.')
                assert self.tensor_parallel_size == 1, 'Async mode does not support tensor parallelism right now.'

    def _external_vllm_warning(self):
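        # Warn about local rollout settings that are ignored when an external vLLM server is used.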
if self.rlhf_type != 'grpo' or not self.vllm_server_host:
return
        if self.vllm_device != 'auto':
            logger.warning("Configuration conflict: external vLLM engine detected, but 'vllm_device' is set to '%s'.",
                           self.vllm_device)
if self.num_infer_workers != 1:
logger.warning(
"Auto-adjustment: Changing 'num_infer_workers' from %s to 1 because external vLLM engine is detected",
self.num_infer_workers)
self.num_infer_workers = 1
if self.vllm_max_model_len is not None:
logger.warning(
"Configuration conflict: 'vllm_max_model_len=%s' is ignored for external vLLM. "
'Please specify it when launching the inference service: '
'`swift deploy --max_model_len <value>`', self.vllm_max_model_len)