# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional

from swift.llm import MODEL_MAPPING
from swift.trainers.arguments import GRPOArgumentsMixin
from swift.utils import get_logger, is_master, set_default_ddp_config
from .train_args import TrainArguments

logger = get_logger()


@dataclass
class RewardModelArguments:
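    """Arguments for specifying an (optional) external reward model, e.g. for RM/PPO/GRPO training."""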
    reward_model: Optional[str] = None
    reward_adapters: List[str] = field(default_factory=list)
    reward_model_type: Optional[str] = field(
        default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'})
    reward_model_revision: Optional[str] = None


@dataclass
class PPOArguments:
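    """PPO hyperparameters; these fields mirror trl's PPO (on-policy) config."""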
    num_ppo_epochs: int = 4
    whiten_rewards: bool = False
    kl_coef: float = 0.05
    cliprange: float = 0.2
    vf_coef: float = 0.1
    cliprange_value: float = 0.2
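    # Generalized Advantage Estimation (GAE): `gamma` discounts rewards, `lam` is the GAE lambda.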
    gamma: float = 1.0
    lam: float = 0.95

    num_mini_batches: int = 1
    local_rollout_forward_batch_size: int = 64
    num_sample_generations: int = 10
    response_length: int = 512
    missing_eos_penalty: Optional[float] = None


@dataclass
class GRPOArguments(GRPOArgumentsMixin):
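    """GRPO-specific arguments; see the GRPO paper (https://arxiv.org/abs/2402.03300)."""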
    num_generations: int = 8  # G in the GRPO paper
    max_completion_length: int = 512
    ds3_gather_for_generation: bool = True
    reward_funcs: List[str] = field(default_factory=list)
    reward_weights: Optional[List[float]] = None
    log_completions: bool = False

    # vLLM in GRPO
    use_vllm: bool = False

    # multi step
    num_iterations: int = 1

    truncation_strategy: Literal['delete', 'left', 'right', None] = None


@dataclass
class RLHFArguments(GRPOArguments, PPOArguments, RewardModelArguments, TrainArguments):
    """
    RLHFArguments is a dataclass that holds arguments specific to the Reinforcement
        Learning with Human Feedback (RLHF) training backend.

    Args:
        rlhf_type (Literal): Specifies the type of RLHF to use. Default is 'dpo'.
            Allowed values are 'dpo', 'orpo', 'simpo', 'kto', 'cpo'.
        ref_model_type (Optional[str]): Type of reference model. Default is None.
        ref_model_revision (Optional[str]): Revision of the reference model. Default is None.
        beta (Optional[float]): Beta parameter for RLHF. Default is None.
        label_smoothing (float): Label smoothing value. Default is 0.
        rpo_alpha (float): Alpha parameter for RPO. Default is 1.
        cpo_alpha (float): Alpha parameter for CPO. Default is 1.
        simpo_gamma (float): Gamma parameter for SimPO. Default is 1.
        desirable_weight (float): Weight for desirable outcomes in KTO. Default is 1.0.
        undesirable_weight (float): Weight for undesirable outcomes in KTO. Default is 1.0.
    """
    rlhf_type: Literal['dpo', 'orpo', 'simpo', 'kto', 'cpo', 'rm', 'ppo', 'grpo'] = 'dpo'
    ref_model: Optional[str] = None
    ref_model_type: Optional[str] = field(
        default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'})
    ref_model_revision: Optional[str] = None

    beta: Optional[float] = None
    label_smoothing: float = 0
    loss_scale: Optional[str] = None  # 'last_round'
    # DPO
    rpo_alpha: float = 1.
    # CPO
    cpo_alpha: float = 1.
    # SimPO
    simpo_gamma: float = 1
    # KTO
    desirable_weight: float = 1.0
    undesirable_weight: float = 1.0
    # PPO/GRPO
    temperature: float = 0.9
    # RM
    center_rewards_coefficient: Optional[float] = None

    def _prepare_training_args(self, training_args: Dict[str, Any]) -> None:
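        # trl's PPO config tracks `world_size` for batch-size bookkeeping; pass the global value through.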
        if self.rlhf_type == 'ppo':
            training_args['world_size'] = self.global_world_size

    def __post_init__(self):
        self._init_grpo()
        self._init_rm()
        self._init_simpo()
        self._init_ppo()
        self._set_default()
        self._init_external_vllm()
        super().__post_init__()
        self._check_grpo()
        self._external_vllm_warning()

        if self.loss_scale is None:
            if self.rlhf_type == 'orpo' and not self.model_meta.is_multimodal:
                # Avoid padding labels during the model's forward pass in multimodal models.
                # Some multimodal models do not expand the image pad token.
                self.loss_scale = 'default'
            else:
                self.loss_scale = 'last_round'
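        # A reference model is only needed when a KL/reference term is computed: GRPO with beta == 0 drops
        # the KL penalty, and LoRA training can reuse the base model with adapters disabled as the reference.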
        if self.rlhf_type == 'grpo' and self.beta == 0.0:
            self.ref_model = None
        elif self.rlhf_type in ['dpo', 'kto', 'ppo', 'grpo'] and self.train_type == 'full':
            self.ref_model = self.ref_model or self.model
            self.ref_model_type = self.ref_model_type or self.model_type
            self.ref_model_revision = self.ref_model_revision or self.model_revision
        elif self.ref_model is not None:
            raise ValueError('CPO/ORPO and LoRA training do not require a ref_model to be passed in.')

    def _init_grpo(self):
        if self.rlhf_type == 'grpo':
            if self.use_vllm or self.use_lmdeploy:
                os.environ['USE_FAST_INFERENCE'] = '1'
                set_default_ddp_config()
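            # sleep_level only applies when vLLM is colocated with training; disable it otherwise.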
            if self.async_generate or not self.use_vllm:
                self.sleep_level = 0
            self.remove_unused_columns = False
            logger.info(f'Setting args.remove_unused_columns: {self.remove_unused_columns}')
            if self.truncation_strategy is None:
                self.truncation_strategy = 'left'
            assert self.truncation_strategy == 'left', \
                "GRPO requires `truncation_strategy='left'`, " \
                f"but the current value is `truncation_strategy='{self.truncation_strategy}'`."
            if self.beta is None:
                self.beta = 0.04  # https://arxiv.org/abs/2402.03300
            if self.async_generate:
                logger.info('Using async mode. This is an approximate version that '
                            'generates responses with the old (stale) weights to accelerate training. '
                            'It ignores the `CLIP` term of the advantages; if you find the training '
                            'unstable, consider setting --async_generate false.')
            if 'soft_overlong' in self.reward_funcs:
                assert self.soft_cache_length is not None, \
                    'The soft_cache_length must be set when using soft overlong rewards.'
                if self.soft_max_length is None:
                    self.soft_max_length = self.max_completion_length
                    logger.info(f'Auto-configured soft_max_length = max_completion_length {self.max_completion_length}')

    def _init_ppo(self):
        if self.rlhf_type == 'ppo':
            self.padding_side = 'left'
            # TODO: streaming, MLLM

    def _init_metric_for_best_model(self):
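        # PPO/GRPO do not use the generic eval-loss metric; GRPO defaults to tracking 'reward'.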
        if self.rlhf_type not in {'ppo', 'grpo'}:
            super()._init_metric_for_best_model()
        elif self.rlhf_type == 'grpo' and self.metric_for_best_model is None:
            self.metric_for_best_model = 'reward'

    def _init_simpo(self):
        if self.rlhf_type != 'simpo':
            return

        self.rlhf_type = 'cpo'
        if self.loss_type is None:
            self.loss_type = 'simpo'
        if self.beta is None:
            self.beta = 2.

    def _init_rm(self):
        if self.rlhf_type == 'rm':
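            # Reward modeling trains a scalar value head via single-label sequence classification.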
            self.task_type = 'seq_cls'
            self.num_labels = 1

    def _init_external_vllm(self):
        if self.rlhf_type != 'grpo' or self.vllm_server_host is None:
            return
        from swift.trainers.rlhf_trainer.vllm_client import VLLMClient
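        # Only the master rank connects to the external vLLM server and initializes its communicator.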
        if is_master():
            self.vllm_client = VLLMClient(
                self.vllm_server_host, self.vllm_server_port, connection_timeout=self.vllm_server_timeout)
            self.vllm_client.init_communicator()

    def _set_default(self):
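        # Fallback defaults for anything the type-specific init hooks above did not set.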
        if self.beta is None:
            self.beta = 0.1
        if self.loss_type is None:
            if self.rlhf_type in ['dpo', 'cpo']:
                self.loss_type = 'sigmoid'  # else None
            elif self.rlhf_type in ['kto']:
                self.loss_type = 'kto'
            elif self.rlhf_type == 'grpo':
                self.loss_type = 'grpo'

    def _check_grpo(self):
        if self.rlhf_type != 'grpo':
            return

        from packaging import version
        import trl
        trl_version = version.parse(trl.__version__)
        assert trl_version >= version.parse('0.17'), ('Your current version of `trl` is outdated. '
                                                      'Please update it by running: pip install -U trl')

        if self.num_generations < 2:
            raise ValueError(
                'GRPO requires at least 2 generations per prompt to calculate the advantages. You provided '
                f'{self.num_generations}, which is less than the minimum required.')
        from swift.utils import get_device_count, get_dist_setting
        device_count = get_device_count()
        _, _, _, local_world_size = get_dist_setting()
        num_infer_workers = self.num_infer_workers
        fast_infer = self.use_vllm or self.use_lmdeploy
        if fast_infer and self.vllm_server_host is None:
            is_colocate_mode = (device_count == num_infer_workers)
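            # e.g. 8 GPUs with num_infer_workers=8 -> colocate (training and rollout share all GPUs);
            #      8 GPUs with NPROC_PER_NODE=4 and num_infer_workers=4 -> async (dedicated rollout GPUs).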

            if is_colocate_mode:
                # colocate mode
                assert device_count == local_world_size, (
                    f'Colocate mode requires device_count({device_count}) == local_world_size({local_world_size}). '
                    'Please check that your device count matches the NPROC_PER_NODE setting.')
                logger.info(
                    'You are using colocate mode because num_infer_workers equals the total device count, '
                    'so model training and sampling share the same GPUs. '
                    'If you encounter an Out-of-Memory (OOM) error, it is recommended to set the `sleep_level`, '
                    '`offload_model`, and `offload_optimizer` parameters.')
                assert not self.async_generate, 'async_generate requires async mode, but you are in colocate mode'
                if self.use_lmdeploy and self.tensor_parallel_size > 1:
                    raise ValueError('Currently LMDeploy does not support tensor parallelism.')
                if self.use_vllm and not self.sleep_level:
                    logger.warning('It is highly recommended to use `sleep_level==1` in colocate mode, '
                                   'otherwise it may lead to an OOM (Out of Memory) error.')
            else:
                # async mode
                assert device_count == (local_world_size + num_infer_workers), (
                    f'Async mode requires total GPUs({device_count}) = training GPUs({local_world_size}) + '
                    f'inference workers({num_infer_workers}). Please adjust your GPU allocation.')
                logger.info(
                    'You are using async mode, where model training and sampling will be performed on different GPUs.')
                if self.sleep_level > 0:
                    logger.warning('You are using different GPUs for training and rollout, '
                                   'so there is no need to set sleep_level > 0.')

                assert self.tensor_parallel_size == 1, 'async mode does not support tensor parallelism right now'

    def _external_vllm_warning(self):
        if self.rlhf_type != 'grpo' or not self.vllm_server_host:
            return

        if self.vllm_device != 'auto':
            logger.warning("Configuration conflict: External vLLM engine detected, but 'vllm_device' is set to '%s'. ",
                           self.vllm_device)

        if self.num_infer_workers != 1:
            logger.warning(
                "Auto-adjustment: changing 'num_infer_workers' from %s to 1 because an external vLLM engine "
                'was detected.', self.num_infer_workers)
            self.num_infer_workers = 1

        if self.vllm_max_model_len is not None:
            logger.warning(
                "Configuration conflict: 'vllm_max_model_len=%s' is ignored for external vLLM. "
                'Please specify it when launching the inference service: '
                '`swift deploy --max_model_len <value>`', self.vllm_max_model_len)