import os
from typing import List, Union

from swift.llm import safe_snapshot_download
from swift.utils import get_logger, get_model_parameter_info

from ..argument import BaseArguments, RLHFArguments
from ..model import HfConfigFactory
from .kto import prepare_kto_dataset
from .sft import SwiftSft

logger = get_logger()


class SwiftRLHF(SwiftSft):
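    """RLHF training pipeline built on top of ``SwiftSft``.

    In addition to the policy model prepared by ``SwiftSft``, this subclass
    loads the auxiliary models required by the selected ``rlhf_type``: an
    optional reference model, one or more reward models, and, for PPO, a
    value model.
    """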
    args_class = RLHFArguments
    args: args_class

    def _prepare_model_tokenizer(self):
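        """Prepare the policy model plus any ref/reward/value models."""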
        if self.args.sequence_parallel_size > 1:
            from swift.trainers.sequence_parallel import sequence_parallel
            sequence_parallel.init_sequence_parallel(self.args.sequence_parallel_size)
        from swift.llm.infer.utils import prepare_adapter
        args = self.args

        def prepare_single_model(key, origin_key=None):
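            """Download, load, and post-process one auxiliary model.

            ``key`` selects the argument prefix (``{key}_model``,
            ``{key}_model_type``, ``{key}_model_revision``); ``origin_key``
            names the role the model plays ('ref', 'reward' or 'value') when
            the two differ.
            """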
            origin_key = origin_key or key
            model_id_or_path = getattr(args, f'{key}_model')
            if model_id_or_path is None:
                return None

            model_type = getattr(args, f'{key}_model_type')
            model_revision = getattr(args, f'{key}_model_revision')
            model_dir = safe_snapshot_download(
                model_id_or_path=model_id_or_path,
                revision=model_revision,
                download_model=False,
                use_hf=args.use_hf,
                hub_token=args.hub_token,
            )
            task_type = None
            num_labels = None
            if os.path.exists(os.path.join(model_dir, 'args.json')):
                # Checkpoint produced by swift: recover the task type from the
                # saved training arguments.
                model_args = BaseArguments.from_pretrained(model_dir)
                if hasattr(model_args, 'task_type'):
                    task_type = model_args.task_type
            else:
                # Plain checkpoint: fall back to the model config.
                from transformers import AutoConfig
                model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
                if hasattr(model_config, 'num_labels'):
                    num_labels = model_config.num_labels
            if task_type == 'seq_cls':
                # Reward/value heads produce a single scalar score.
                num_labels = 1

            model, processor = args.get_model_processor(
                model=model_id_or_path,
                model_type=model_type,
                model_revision=model_revision,
                task_type=task_type,
                num_labels=num_labels)

            adapters = args.adapters if key == 'ref' else args.reward_adapters
            model = prepare_adapter(args, model, adapters)
            if origin_key in {'ref', 'reward'}:
                if self.args.sequence_parallel_size > 1:
                    from swift.trainers.sequence_parallel import sequence_parallel
                    if hasattr(model, 'model_meta'):
                        is_multimodal = model.model_meta.is_multimodal
                    else:
                        is_multimodal = model.model.model_meta.is_multimodal
                    sequence_parallel.prepare_model(model, processor, split_in_forward=is_multimodal)
                # Reference and reward models are frozen; they are only used
                # for inference.
                model.requires_grad_(False).eval()
            else:
                # The value model (PPO) is trainable, so prepare it the same
                # way as the policy model.
                model = self.prepare_model(args, model, task_type=task_type)
                logger.info(f'value_model: {model}')
                model_parameter_info = get_model_parameter_info(model)
                self.train_msg['value_model_parameter_info'] = model_parameter_info
                logger.info(f'value_model_parameter_info: {model_parameter_info}')

            HfConfigFactory.set_model_config_attr(model, 'use_cache', False)
            return model, processor
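        # Load the optional reference model; for PPO, also load a value model
        # initialized from the reward-model checkpoint.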
        for key in ['ref', 'value']:
            setattr(self, f'{key}_model', None)
            if key == 'value' and args.rlhf_type != 'ppo':
                continue
            model_key = 'reward' if key == 'value' else key
            result = prepare_single_model(model_key, key)
            if result is not None:
                model, _ = result
                setattr(self, f'{key}_model', model)
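        # GRPO accepts a list of reward models; each one is paired with its
        # own template below.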
        self.reward_model = None
        if hasattr(args, 'reward_model') and args.reward_model is not None:
            reward_models = args.reward_model if isinstance(args.reward_model, list) else [args.reward_model]
            self.reward_model = []
            if args.rlhf_type == 'grpo':
                self.reward_template = []

            for reward_model_path in reward_models:
                args.reward_model = reward_model_path
                result = prepare_single_model('reward')
                if result is not None:
                    model, processor = result
                    self.reward_model.append(model)

                    if args.rlhf_type == 'grpo':
                        reward_template = self.args.get_template(processor, processor.model_meta.template)
                        if reward_template.use_model:
                            reward_template.model = model
                        self.reward_template.append(reward_template)
            # Restore the original (possibly list-valued) argument.
            args.reward_model = reward_models

        super()._prepare_model_tokenizer()

    def _prepare_template(self) -> None:
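        """Set the template mode required by the chosen RLHF algorithm."""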
        args = self.args
        super()._prepare_template()
        # KTO uses its own encoding mode; PPO and GRPO use 'pt' (they
        # generate rollouts during training); everything else defaults to
        # the paired 'rlhf' mode.
        model_mapping = {'kto': 'kto', 'ppo': 'pt', 'grpo': 'pt'}
        self.template.set_mode(model_mapping.get(args.rlhf_type, 'rlhf'))

        if args.rlhf_type == 'ppo':
            args.training_args.stop_token_id = self.template.template_meta.stop_token_id

    def _get_dataset(self):
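        """Build the datasets; KTO needs additional dataset preparation."""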
        args = self.args
        train_dataset, val_dataset = super()._get_dataset()
        if args.rlhf_type == 'kto':
            train_dataset, val_dataset = prepare_kto_dataset(args, train_dataset, val_dataset)
        return train_dataset, val_dataset

    def _get_trainer_kwargs(self):
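        """Collect the extra models/templates to pass to the trainer."""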
        trainer_kwargs = {}
        for key in ['ref', 'reward', 'value']:
            key = f'{key}_model'
            model = getattr(self, key, None)
            # PPO expects all three keys to be present, even when a model is
            # None.
            if model or self.args.rlhf_type == 'ppo':
                trainer_kwargs[key] = model
        if hasattr(self, 'reward_template'):
            trainer_kwargs['reward_template'] = self.reward_template
        if self.args.rlhf_type == 'grpo':
            trainer_kwargs['reward_funcs'] = self.args.reward_funcs
            trainer_kwargs['vllm_client'] = self.args.vllm_client
        return trainer_kwargs


def rlhf_main(args: Union[List[str], RLHFArguments, None] = None):
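    """Entry point for RLHF training.

    ``args`` may be a parsed ``RLHFArguments`` instance, a list of
    command-line style arguments, or ``None`` (fall back to the process
    command line), e.g.::

        rlhf_main(['--rlhf_type', 'kto', ...])
    """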
    return SwiftRLHF(args).main()