# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import List, Union

from swift.llm import safe_snapshot_download
from swift.utils import get_logger, get_model_parameter_info

from ..argument import BaseArguments, RLHFArguments
from ..model import HfConfigFactory
from .kto import prepare_kto_dataset
from .sft import SwiftSft

logger = get_logger()


class SwiftRLHF(SwiftSft):
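    """Training pipeline for RLHF algorithms (e.g. KTO, PPO, GRPO).

    Extends SwiftSft with preparation of the auxiliary reference, reward
    and value models the selected algorithm requires.
    """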

    args_class = RLHFArguments
    args: args_class

    def _prepare_model_tokenizer(self):
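        """Initialize sequence parallelism (if enabled) and build the
        auxiliary ref/reward/value models before the policy model itself
        is created by the base class."""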
        if self.args.sequence_parallel_size > 1:
            # Calling this more than once is allowed, which guarantees it
            # runs before the model is initialized.
            from swift.trainers.sequence_parallel import sequence_parallel
            sequence_parallel.init_sequence_parallel(self.args.sequence_parallel_size)

        # Prepare the ref/reward/value models.
        from swift.llm.infer.utils import prepare_adapter
        args = self.args

        def prepare_single_model(key, origin_key=None):
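            """Load and prepare the model configured as ``{key}_model`` in args.

            Returns a (model, processor) tuple, or None when no model path
            is configured for ``key``.
            """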
            origin_key = origin_key or key
            model_id_or_path = getattr(args, f'{key}_model')
            if model_id_or_path is None:
                return None
            model_type = getattr(args, f'{key}_model_type')
            model_revision = getattr(args, f'{key}_model_revision')
            # Resolve the local checkpoint directory (configs only; the
            # weights themselves are loaded by get_model_processor below).
            model_dir = safe_snapshot_download(
                model_id_or_path=model_id_or_path,
                revision=model_revision,
                download_model=False,
                use_hf=args.use_hf,
                hub_token=args.hub_token,
            )

            # Infer the task type and label count from the checkpoint.
            task_type = None
            num_labels = None
            if os.path.exists(os.path.join(model_dir, 'args.json')):
                # A checkpoint trained with swift: read its saved arguments.
                model_args = BaseArguments.from_pretrained(model_dir)
                if hasattr(model_args, 'task_type'):
                    task_type = model_args.task_type
            else:
                # A plain Hugging Face checkpoint: fall back to its config.
                from transformers import AutoConfig
                model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
                if hasattr(model_config, 'num_labels'):
                    num_labels = model_config.num_labels
            if task_type == 'seq_cls':
                # Reward/value models use a single scalar head.
                num_labels = 1

            model, processor = args.get_model_processor(
                model=model_id_or_path,
                model_type=model_type,
                model_revision=model_revision,
                task_type=task_type,
                num_labels=num_labels)

            adapters = args.adapters if key == 'ref' else args.reward_adapters
            model = prepare_adapter(args, model, adapters)
            if origin_key in {'ref', 'reward'}:
                # Ref/reward models are frozen: wrap them for sequence
                # parallelism if enabled, then disable gradients.
                if self.args.sequence_parallel_size > 1:
                    from swift.trainers.sequence_parallel import sequence_parallel
                    if hasattr(model, 'model_meta'):
                        is_multimodal = model.model_meta.is_multimodal
                    else:
                        is_multimodal = model.model.model_meta.is_multimodal
                    sequence_parallel.prepare_model(model, processor, split_in_forward=is_multimodal)
                model.requires_grad_(False).eval()
            else:
                # The value model is trained, so it goes through the full
                # training preparation path.
                model = self.prepare_model(args, model, task_type=task_type)
                logger.info(f'value_model: {model}')
                model_parameter_info = get_model_parameter_info(model)
                self.train_msg['value_model_parameter_info'] = model_parameter_info
                logger.info(f'value_model_parameter_info: {model_parameter_info}')

            HfConfigFactory.set_model_config_attr(model, 'use_cache', False)
            return model, processor

        # Handle the ref and value models. The value model is only needed
        # for PPO and is initialized from the reward-model checkpoint.
        for key in ['ref', 'value']:
            setattr(self, f'{key}_model', None)
            if key == 'value' and args.rlhf_type != 'ppo':
                continue
            model_key = 'reward' if key == 'value' else key
            result = prepare_single_model(model_key, key)
            if result is not None:
                model, _ = result
                setattr(self, f'{key}_model', model)

        # Handle the reward model(s). GRPO may use several reward models,
        # each paired with its own template.
        self.reward_model = None
        if hasattr(args, 'reward_model') and args.reward_model is not None:
            reward_models = args.reward_model if isinstance(args.reward_model, list) else [args.reward_model]
            self.reward_model = []
            if args.rlhf_type == 'grpo':
                self.reward_template = []
            for reward_model_path in reward_models:
                args.reward_model = reward_model_path  # Temporarily set for prepare_single_model.
                result = prepare_single_model('reward')
                if result is not None:
                    model, processor = result
                    self.reward_model.append(model)
                    if args.rlhf_type == 'grpo':
                        reward_template = self.args.get_template(processor, processor.model_meta.template)
                        if reward_template.use_model:
                            reward_template.model = model
                        self.reward_template.append(reward_template)
            args.reward_model = reward_models  # Restore the original value.

        super()._prepare_model_tokenizer()

    def _prepare_template(self) -> None:
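        """Set the template mode for the chosen algorithm and, for PPO,
        propagate the template's stop token to the training arguments."""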
        args = self.args
        super()._prepare_template()
        # KTO has a dedicated mode, PPO/GRPO use the plain 'pt' mode, and
        # every other algorithm uses the pairwise 'rlhf' mode.
        model_mapping = {'kto': 'kto', 'ppo': 'pt', 'grpo': 'pt'}
        self.template.set_mode(model_mapping.get(args.rlhf_type, 'rlhf'))
        if args.rlhf_type == 'ppo':
            args.training_args.stop_token_id = self.template.template_meta.stop_token_id

    def _get_dataset(self):
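        """Load the datasets via SwiftSft, applying KTO-specific
        preprocessing when ``rlhf_type == 'kto'``."""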
        args = self.args
        train_dataset, val_dataset = super()._get_dataset()
        if args.rlhf_type == 'kto':
            train_dataset, val_dataset = prepare_kto_dataset(args, train_dataset, val_dataset)
        return train_dataset, val_dataset

    def _get_trainer_kwargs(self):
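        """Collect the extra keyword arguments (auxiliary models, reward
        templates/functions, vllm client) forwarded to the trainer."""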
        trainer_kwargs = {}
        for key in ['ref', 'reward', 'value']:
            key = f'{key}_model'
            model = getattr(self, key, None)
            if model or self.args.rlhf_type == 'ppo':
                # PPO expects all three keys, even when the model is None.
                trainer_kwargs[key] = model
        if hasattr(self, 'reward_template'):
            trainer_kwargs['reward_template'] = self.reward_template
        if self.args.rlhf_type == 'grpo':
            trainer_kwargs['reward_funcs'] = self.args.reward_funcs
            trainer_kwargs['vllm_client'] = self.args.vllm_client
        return trainer_kwargs


def rlhf_main(args: Union[List[str], RLHFArguments, None] = None):
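    """Entry point for RLHF training.

    ``args`` may be a list of CLI-style strings, an ``RLHFArguments``
    instance, or ``None`` to parse arguments from the command line.

    A minimal programmatic sketch (the model and dataset IDs below are
    placeholders, not prescribed by this module)::

        from swift.llm import rlhf_main, RLHFArguments
        rlhf_main(RLHFArguments(
            rlhf_type='dpo',
            model='Qwen/Qwen2.5-0.5B-Instruct',
            dataset=['AI-ModelScope/hh-rlhf#1000'],
        ))
    """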
    return SwiftRLHF(args).main()