# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import List, Union

from swift.llm import safe_snapshot_download
from swift.utils import get_logger, get_model_parameter_info
from ..argument import BaseArguments, RLHFArguments
from ..model import HfConfigFactory
from .kto import prepare_kto_dataset
from .sft import SwiftSft

logger = get_logger()


class SwiftRLHF(SwiftSft):
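    """RLHF training pipeline built on top of SwiftSft.

    Besides the policy model prepared by SwiftSft, this class loads the
    auxiliary models the chosen `rlhf_type` needs: a frozen reference model,
    one or more reward models, and (for PPO) a trainable value model.
    """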
    args_class = RLHFArguments
    args: args_class

    def _prepare_model_tokenizer(self):
        if self.args.sequence_parallel_size > 1:
            # Calling this repeatedly is allowed; doing it here guarantees that
            # sequence parallelism is initialized before the model is created.
            from swift.trainers.sequence_parallel import sequence_parallel
            sequence_parallel.init_sequence_parallel(self.args.sequence_parallel_size)
        # Prepare the ref/reward/value models required by the chosen rlhf_type.
        from swift.llm.infer.utils import prepare_adapter
        args = self.args

        def prepare_single_model(key, origin_key=None):
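            """Download and build one auxiliary model.

            `key` selects the argument prefix (e.g. 'reward' reads
            args.reward_model); `origin_key` is the role the model plays
            ('ref', 'reward' or 'value'). Returns (model, processor), or None
            when no model is configured for that key.
            """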
            origin_key = origin_key or key
            model_id_or_path = getattr(args, f'{key}_model')
            if model_id_or_path is None:
                return None

            model_type = getattr(args, f'{key}_model_type')
            model_revision = getattr(args, f'{key}_model_revision')
            model_dir = safe_snapshot_download(
                model_id_or_path=model_id_or_path,
                revision=model_revision,
                download_model=False,
                use_hf=args.use_hf,
                hub_token=args.hub_token,
            )
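            # Infer task_type/num_labels: prefer the swift 'args.json' saved next
            # to a trained checkpoint; otherwise fall back to the HF model config.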
            task_type = None
            num_labels = None
            if os.path.exists(os.path.join(model_dir, 'args.json')):
                model_args = BaseArguments.from_pretrained(model_dir)
                if hasattr(model_args, 'task_type'):
                    task_type = model_args.task_type
            else:
                from transformers import AutoConfig
                model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
                if hasattr(model_config, 'num_labels'):
                    num_labels = model_config.num_labels
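            # A 'seq_cls' model here acts as a reward/value head: one scalar score.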
            if task_type == 'seq_cls':
                num_labels = 1

            model, processor = args.get_model_processor(
                model=model_id_or_path,
                model_type=model_type,
                model_revision=model_revision,
                task_type=task_type,
                num_labels=num_labels)

            adapters = args.adapters if key == 'ref' else args.reward_adapters
            model = prepare_adapter(args, model, adapters)
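            # ref/reward models are frozen and used for inference only; the PPO
            # value model stays trainable and goes through the standard tuner path.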
            if origin_key in {'ref', 'reward'}:
                if self.args.sequence_parallel_size > 1:
                    from swift.trainers.sequence_parallel import sequence_parallel
                    if hasattr(model, 'model_meta'):
                        is_multimodal = model.model_meta.is_multimodal
                    else:
                        is_multimodal = model.model.model_meta.is_multimodal
                    sequence_parallel.prepare_model(model, processor, split_in_forward=is_multimodal)
                model.requires_grad_(False).eval()
            else:
                model = self.prepare_model(args, model, task_type=task_type)
                logger.info(f'value_model: {model}')
                model_parameter_info = get_model_parameter_info(model)
                self.train_msg['value_model_parameter_info'] = model_parameter_info
                logger.info(f'value_model_parameter_info: {model_parameter_info}')

            HfConfigFactory.set_model_config_attr(model, 'use_cache', False)
            return model, processor

        # Handle ref and value models
        for key in ['ref', 'value']:
            setattr(self, f'{key}_model', None)
            if key == 'value' and args.rlhf_type != 'ppo':
                continue

            model_key = 'reward' if key == 'value' else key
            result = prepare_single_model(model_key, key)
            if result is not None:
                model, _ = result
                setattr(self, f'{key}_model', model)

        # Handle reward model(s)
        self.reward_model = None
        if hasattr(args, 'reward_model') and args.reward_model is not None:
            reward_models = args.reward_model if isinstance(args.reward_model, list) else [args.reward_model]
            self.reward_model = []
            if args.rlhf_type == 'grpo':
                self.reward_template = []
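            # GRPO may combine several reward models; each gets its own template.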

            for reward_model_path in reward_models:
                args.reward_model = reward_model_path  # Temporarily set for prepare_single_model
                result = prepare_single_model('reward')
                if result is not None:
                    model, processor = result
                    self.reward_model.append(model)

                    if args.rlhf_type == 'grpo':
                        reward_template = self.args.get_template(processor, processor.model_meta.template)
                        if reward_template.use_model:
                            reward_template.model = model
                        self.reward_template.append(reward_template)
                args.reward_model = reward_models  # Restore original value

        super()._prepare_model_tokenizer()

    def _prepare_template(self) -> None:
        args = self.args
        super()._prepare_template()
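        # KTO encodes unpaired samples with its own mode; PPO/GRPO sample
        # responses during training, so only prompts are encoded ('pt' mode);
        # everything else uses paired 'rlhf' encoding.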
        model_mapping = {'kto': 'kto', 'ppo': 'pt', 'grpo': 'pt'}
        self.template.set_mode(model_mapping.get(args.rlhf_type, 'rlhf'))

        if args.rlhf_type == 'ppo':
            args.training_args.stop_token_id = self.template.template_meta.stop_token_id

    def _get_dataset(self):
        args = self.args
        train_dataset, val_dataset = super()._get_dataset()
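        # KTO trains on unpaired desirable/undesirable samples, so the dataset
        # needs an extra preprocessing pass.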
        if args.rlhf_type == 'kto':
            train_dataset, val_dataset = prepare_kto_dataset(args, train_dataset, val_dataset)
        return train_dataset, val_dataset

    def _get_trainer_kwargs(self):
        trainer_kwargs = {}
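        # PPO's trainer expects the ref/reward/value model kwargs to be passed
        # explicitly, even when a model is None.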
        for key in ['ref', 'reward', 'value']:
            key = f'{key}_model'
            model = getattr(self, key, None)
            if model or self.args.rlhf_type == 'ppo':
                trainer_kwargs[key] = model
        if hasattr(self, 'reward_template'):
            trainer_kwargs['reward_template'] = self.reward_template
        if self.args.rlhf_type == 'grpo':
            trainer_kwargs['reward_funcs'] = self.args.reward_funcs
            trainer_kwargs['vllm_client'] = self.args.vllm_client
        return trainer_kwargs


def rlhf_main(args: Union[List[str], RLHFArguments, None] = None):
    return SwiftRLHF(args).main()
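

# Usage sketch (hypothetical argument values; the `swift rlhf` CLI normally
# builds RLHFArguments from the command line):
#   rlhf_main(['--rlhf_type', 'dpo',
#              '--model', 'Qwen/Qwen2-7B-Instruct',
#              '--dataset', '<preference-dataset>'])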