# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional
from swift.llm import MODEL_MAPPING
from swift.trainers.arguments import GRPOArgumentsMixin
from swift.utils import get_logger, is_master, set_default_ddp_config
from .train_args import TrainArguments
logger = get_logger()
@dataclass
class RewardModelArguments:
reward_model: Optional[str] = None
reward_adapters: List[str] = field(default_factory=list)
reward_model_type: Optional[str] = field(
default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'})
reward_model_revision: Optional[str] = None
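
# Illustrative sketch: the fields above surface as same-named CLI flags of
# `swift rlhf` (the model ids below are placeholders):
#   swift rlhf --rlhf_type ppo --model <policy-model-id> \
#       --reward_model <reward-model-id> --reward_model_type <model_type>
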
@dataclass
class PPOArguments:
num_ppo_epochs: int = 4
whiten_rewards: bool = False
kl_coef: float = 0.05
cliprange: float = 0.2
vf_coef: float = 0.1
cliprange_value: float = 0.2
gamma: float = 1.0
lam: float = 0.95
num_mini_batches: int = 1
local_rollout_forward_batch_size: int = 64
num_sample_generations: int = 10
response_length: int = 512
missing_eos_penalty: Optional[float] = None
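
# The fields above follow the standard PPO recipe (a general description, not a
# transcription of the trainer's internals): advantages come from GAE with
# discount `gamma` and smoothing `lam`; the policy ratio is clipped to
# [1 - cliprange, 1 + cliprange]; the value loss is clipped by `cliprange_value`
# and weighted by `vf_coef`; and `kl_coef` scales the KL penalty against the
# reference model.
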
@dataclass
class GRPOArguments(GRPOArgumentsMixin):
num_generations: int = 8 # G in the GRPO paper
max_completion_length: int = 512
ds3_gather_for_generation: bool = True
reward_funcs: List[str] = field(default_factory=list)
    reward_weights: Optional[List[float]] = None
log_completions: bool = False
# vLLM in GRPO
use_vllm: bool = False
# multi step
num_iterations: int = 1
truncation_strategy: Literal['delete', 'left', 'right', None] = None
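
# Illustrative launch sketch (model/dataset ids are placeholders; reward
# function names must be ones registered in swift):
#   swift rlhf --rlhf_type grpo --model <model-id> --dataset <dataset-id> \
#       --num_generations 8 --max_completion_length 512 --use_vllm true \
#       --reward_funcs <registered-reward-func>
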
@dataclass
class RLHFArguments(GRPOArguments, PPOArguments, RewardModelArguments, TrainArguments):
"""
    RLHFArguments is a dataclass that holds arguments specific to the Reinforcement
    Learning from Human Feedback (RLHF) training backend.
Args:
        rlhf_type (Literal): Specifies the type of RLHF to use. Default is 'dpo'.
            Allowed values are 'dpo', 'orpo', 'simpo', 'kto', 'cpo', 'rm', 'ppo', 'grpo'.
        ref_model (Optional[str]): Model id or path of the reference model. Default is None.
        ref_model_type (Optional[str]): Type of reference model. Default is None.
ref_model_revision (Optional[str]): Revision of the reference model. Default is None.
beta (Optional[float]): Beta parameter for RLHF. Default is None.
label_smoothing (float): Label smoothing value. Default is 0.
rpo_alpha (float): Alpha parameter for RPO. Default is 1.
cpo_alpha (float): Alpha parameter for CPO. Default is 1.
simpo_gamma (float): Gamma parameter for SimPO. Default is 1.
desirable_weight (float): Weight for desirable outcomes in KTO. Default is 1.0.
        undesirable_weight (float): Weight for undesirable outcomes in KTO. Default is 1.0.
        temperature (float): Sampling temperature for PPO/GRPO rollouts. Default is 0.9.
        center_rewards_coefficient (Optional[float]): Coefficient for the auxiliary loss that
            centers reward model outputs around zero (RM only). Default is None.
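
    Example (illustrative CLI launch; ids are placeholders):
        swift rlhf --rlhf_type dpo --model <model-id> --dataset <dataset-id>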
"""
rlhf_type: Literal['dpo', 'orpo', 'simpo', 'kto', 'cpo', 'rm', 'ppo', 'grpo'] = 'dpo'
ref_model: Optional[str] = None
ref_model_type: Optional[str] = field(
default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'})
ref_model_revision: Optional[str] = None
beta: Optional[float] = None
label_smoothing: float = 0
loss_scale: Optional[str] = None # 'last_round'
# DPO
rpo_alpha: float = 1.
# CPO
cpo_alpha: float = 1.
# SimPO
simpo_gamma: float = 1
# KTO
desirable_weight: float = 1.0
undesirable_weight: float = 1.0
# PPO/GRPO
temperature: float = 0.9
# RM
center_rewards_coefficient: Optional[float] = None
def _prepare_training_args(self, training_args: Dict[str, Any]) -> None:
if self.rlhf_type == 'ppo':
training_args['world_size'] = self.global_world_size
def __post_init__(self):
self._init_grpo()
self._init_rm()
self._init_simpo()
self._init_ppo()
self._set_default()
self._init_external_vllm()
super().__post_init__()
self._check_grpo()
self._external_vllm_warning()
if self.loss_scale is None:
if self.rlhf_type == 'orpo' and not self.model_meta.is_multimodal:
                # Use 'default' only for non-multimodal models: multimodal models must avoid
                # padding labels during the forward pass, since some of them do not expand
                # the image pad token.
self.loss_scale = 'default'
else:
self.loss_scale = 'last_round'
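        # With beta == 0 the KL regularization term vanishes, so GRPO does not
        # need a reference model at all.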
if self.rlhf_type == 'grpo' and self.beta == 0.0:
self.ref_model = None
elif self.rlhf_type in ['dpo', 'kto', 'ppo', 'grpo'] and self.train_type == 'full':
self.ref_model = self.ref_model or self.model
self.ref_model_type = self.ref_model_type or self.model_type
self.ref_model_revision = self.ref_model_revision or self.model_revision
elif self.ref_model is not None:
            raise ValueError('CPO/ORPO or LoRA training does not use a ref_model; please do not pass one in.')
def _init_grpo(self):
if self.rlhf_type == 'grpo':
if self.use_vllm or self.use_lmdeploy:
os.environ['USE_FAST_INFERENCE'] = '1'
set_default_ddp_config()
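            # The engine sleep/wake-up cycle only helps when a colocated vLLM
            # engine is used synchronously; otherwise disable it.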
if self.async_generate or not self.use_vllm:
self.sleep_level = 0
self.remove_unused_columns = False
logger.info(f'Setting args.remove_unused_columns: {self.remove_unused_columns}')
if self.truncation_strategy is None:
self.truncation_strategy = 'left'
            assert self.truncation_strategy == 'left', \
                "GRPO requires `truncation_strategy='left'`. " \
                f"Current value: `truncation_strategy='{self.truncation_strategy}'`."
if self.beta is None:
self.beta = 0.04 # https://arxiv.org/abs/2402.03300
if self.async_generate:
            logger.info('Using async mode. This is an approximate version that generates '
                        'responses with the old weights to accelerate training. '
                        'It ignores the `CLIP` of advantages; if you find the training '
                        'unstable, consider setting `--async_generate false`.')
if 'soft_overlong' in self.reward_funcs:
assert self.soft_cache_length is not None, \
'The soft_cache_length must be set when using soft overlong rewards.'
if self.soft_max_length is None:
self.soft_max_length = self.max_completion_length
logger.info(f'Auto-configured soft_max_length = max_completion_length {self.max_completion_length}')
def _init_ppo(self):
if self.rlhf_type == 'ppo':
self.padding_side = 'left'
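            # Batched generation conventionally left-pads prompts so that newly
            # generated tokens stay contiguous on the right.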
# TODO: streaming, MLLM
def _init_metric_for_best_model(self):
if self.rlhf_type not in {'ppo', 'grpo'}:
super()._init_metric_for_best_model()
elif self.rlhf_type == 'grpo' and self.metric_for_best_model is None:
self.metric_for_best_model = 'reward'
def _init_simpo(self):
if self.rlhf_type != 'simpo':
return
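        # TRL implements SimPO as a CPO variant, so reroute to the CPO trainer
        # with `loss_type='simpo'` and SimPO's larger default beta.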
self.rlhf_type = 'cpo'
if self.loss_type is None:
self.loss_type = 'simpo'
if self.beta is None:
self.beta = 2.
def _init_rm(self):
if self.rlhf_type == 'rm':
self.task_type = 'seq_cls'
self.num_labels = 1
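            # Reward modeling is plain sequence classification with a single
            # scalar head, hence `seq_cls` with one label.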
def _init_external_vllm(self):
if self.rlhf_type != 'grpo' or self.vllm_server_host is None:
return
from swift.trainers.rlhf_trainer.vllm_client import VLLMClient
if is_master():
self.vllm_client = VLLMClient(
self.vllm_server_host, self.vllm_server_port, connection_timeout=self.vllm_server_timeout)
self.vllm_client.init_communicator()
def _set_default(self):
if self.beta is None:
self.beta = 0.1
if self.loss_type is None:
if self.rlhf_type in ['dpo', 'cpo']:
self.loss_type = 'sigmoid' # else None
elif self.rlhf_type in ['kto']:
self.loss_type = 'kto'
elif self.rlhf_type == 'grpo':
self.loss_type = 'grpo'
def _check_grpo(self):
if self.rlhf_type != 'grpo':
return
from packaging import version
import trl
trl_version = version.parse(trl.__version__)
assert trl_version >= version.parse('0.17'), ('Your current version of `trl` is outdated. '
'Please update it by running: pip install -U trl')
if self.num_generations < 2:
raise ValueError(
'GRPO requires at least 2 generations per prompt to calculate the advantages. You provided '
f'{self.num_generations}, which is less than the minimum required.')
from swift.utils import get_device_count, get_dist_setting
device_count = get_device_count()
_, _, _, local_world_size = get_dist_setting()
num_infer_workers = self.num_infer_workers
fast_infer = self.use_vllm or self.use_lmdeploy
if fast_infer and self.vllm_server_host is None:
is_colocate_mode = (device_count == num_infer_workers)
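            # Worked example (illustrative): on an 8-GPU node, colocate mode means
            # NPROC_PER_NODE=8 and num_infer_workers=8 (training and rollout share
            # every GPU), while async mode could be NPROC_PER_NODE=6 with
            # num_infer_workers=2, since 6 + 2 == 8.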
if is_colocate_mode:
# colocate mode
assert device_count == local_world_size, (
                    f'Colocate mode requires device_count({device_count}) == local_world_size({local_world_size}). '
                    'Please check if your device count matches the NPROC_PER_NODE setting.')
                logger.info(
                    'You are using colocate mode because num_infer_workers equals NPROC_PER_NODE: '
                    'model training and sampling will share the same GPUs. '
                    'If you encounter an Out-of-Memory (OOM) error, it is recommended to set the `sleep_level`, '
                    '`offload_model`, and `offload_optimizer` parameters.')
                assert not self.async_generate, 'async_generate requires async mode, but you are in colocate mode'
if self.use_lmdeploy and self.tensor_parallel_size > 1:
                    raise ValueError('LMDeploy does not currently support tensor parallelism')
if self.use_vllm and self.sleep_level:
                    logger.warning('It is highly recommended to use `sleep_level==1` in colocate mode, '
                                   'otherwise it may lead to an OOM (Out of Memory) error.')
else:
# async mode
assert device_count == (local_world_size + num_infer_workers), (
f'Async mode requires total GPUs({device_count}) = training GPUs({local_world_size}) + '
f'inference workers({num_infer_workers}). Please adjust your GPU allocation.')
logger.info(
'You are using async mode, where model training and sampling will be performed on different GPUs.')
if self.sleep_level > 0:
logger.warning('You are using different GPUs for training and rollout, '
'so you do not need to use sleep_level > 0')
                assert self.tensor_parallel_size == 1, 'async mode does not support tensor parallelism right now'
def _external_vllm_warning(self):
if self.rlhf_type != 'grpo' or not self.vllm_server_host:
return
if self.vllm_device != 'auto':
logger.warning("Configuration conflict: External vLLM engine detected, but 'vllm_device' is set to '%s'. ",
self.vllm_device)
if self.num_infer_workers != 1:
logger.warning(
"Auto-adjustment: Changing 'num_infer_workers' from %s to 1 because external vLLM engine is detected",
self.num_infer_workers)
self.num_infer_workers = 1
if self.vllm_max_model_len is not None:
logger.warning(
"Configuration conflict: 'vllm_max_model_len=%s' is ignored for external vLLM. "
'Please specify it when launching the inference service: '
'`swift deploy --max_model_len <value>`', self.vllm_max_model_len)