# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import List, Union

from swift.llm import safe_snapshot_download
from swift.utils import get_logger, get_model_parameter_info
from ..argument import BaseArguments, RLHFArguments
from ..model import HfConfigFactory
from .kto import prepare_kto_dataset
from .sft import SwiftSft

logger = get_logger()


class SwiftRLHF(SwiftSft):
    args_class = RLHFArguments
    args: args_class

    def _prepare_model_tokenizer(self):
        if self.args.sequence_parallel_size > 1:
            # Duplicate calls are allowed; this guarantees the function
            # runs before the model is initialized.
            from swift.trainers.sequence_parallel import sequence_parallel
            sequence_parallel.init_sequence_parallel(self.args.sequence_parallel_size)
        # Prepare the ref/reward/value models.
        from swift.llm.infer.utils import prepare_adapter
        args = self.args

        def prepare_single_model(key, origin_key=None):
            # Load the auxiliary model named by the `{key}_model` argument,
            # freezing it (ref/reward) or preparing it for training (value).
            origin_key = origin_key or key
            model_id_or_path = getattr(args, f'{key}_model')
            if model_id_or_path is None:
                return None
            model_type = getattr(args, f'{key}_model_type')
            model_revision = getattr(args, f'{key}_model_revision')
            model_dir = safe_snapshot_download(
                model_id_or_path=model_id_or_path,
                revision=model_revision,
                download_model=False,
                use_hf=args.use_hf,
                hub_token=args.hub_token,
            )

            task_type = None
            num_labels = None
            if os.path.exists(os.path.join(model_dir, 'args.json')):
                # A swift-trained checkpoint: read the task type from its saved args.
                model_args = BaseArguments.from_pretrained(model_dir)
                if hasattr(model_args, 'task_type'):
                    task_type = model_args.task_type
            else:
                from transformers import AutoConfig
                model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
                if hasattr(model_config, 'num_labels'):
                    num_labels = model_config.num_labels
            if task_type == 'seq_cls':
                num_labels = 1

            model, processor = args.get_model_processor(
                model=model_id_or_path,
                model_type=model_type,
                model_revision=model_revision,
                task_type=task_type,
                num_labels=num_labels)

            adapters = args.adapters if key == 'ref' else args.reward_adapters
            model = prepare_adapter(args, model, adapters)
            if origin_key in {'ref', 'reward'}:
                if self.args.sequence_parallel_size > 1:
                    from swift.trainers.sequence_parallel import sequence_parallel
                    if hasattr(model, 'model_meta'):
                        is_multimodal = model.model_meta.is_multimodal
                    else:
                        is_multimodal = model.model.model_meta.is_multimodal
                    sequence_parallel.prepare_model(model, processor, split_in_forward=is_multimodal)
                model.requires_grad_(False).eval()
            else:
                # The value model is trained, so run it through the full preparation path.
                model = self.prepare_model(args, model, task_type=task_type)
                logger.info(f'value_model: {model}')
                model_parameter_info = get_model_parameter_info(model)
                self.train_msg['value_model_parameter_info'] = model_parameter_info
                logger.info(f'value_model_parameter_info: {model_parameter_info}')

            HfConfigFactory.set_model_config_attr(model, 'use_cache', False)
            return model, processor

        # Handle the ref and value models.
        for key in ['ref', 'value']:
            setattr(self, f'{key}_model', None)
            if key == 'value' and args.rlhf_type != 'ppo':
                continue
            # For PPO, the value model is loaded from the reward model path.
            model_key = 'reward' if key == 'value' else key
            result = prepare_single_model(model_key, key)
            if result is not None:
                model, _ = result
                setattr(self, f'{key}_model', model)

        # Handle the reward model(s).
        self.reward_model = None
        if hasattr(args, 'reward_model') and args.reward_model is not None:
            reward_models = args.reward_model if isinstance(args.reward_model, list) else [args.reward_model]
            self.reward_model = []
            if args.rlhf_type == 'grpo':
                self.reward_template = []
            for reward_model_path in reward_models:
                args.reward_model = reward_model_path  # temporarily set for prepare_single_model
                result = prepare_single_model('reward')
                if result is not None:
                    model, processor = result
                    self.reward_model.append(model)
                    if args.rlhf_type == 'grpo':
                        reward_template = self.args.get_template(processor, processor.model_meta.template)
                        if reward_template.use_model:
                            reward_template.model = model
                        self.reward_template.append(reward_template)
            args.reward_model = reward_models  # restore the original value

        super()._prepare_model_tokenizer()

    def _prepare_template(self) -> None:
        args = self.args
        super()._prepare_template()
        # Map each RLHF algorithm to its template mode; everything else uses 'rlhf'.
        mode_mapping = {'kto': 'kto', 'ppo': 'pt', 'grpo': 'pt'}
        self.template.set_mode(mode_mapping.get(args.rlhf_type, 'rlhf'))

        if args.rlhf_type == 'ppo':
            args.training_args.stop_token_id = self.template.template_meta.stop_token_id

    def _get_dataset(self):
        args = self.args
        train_dataset, val_dataset = super()._get_dataset()
        if args.rlhf_type == 'kto':
            train_dataset, val_dataset = prepare_kto_dataset(args, train_dataset, val_dataset)
        return train_dataset, val_dataset

    def _get_trainer_kwargs(self):
        trainer_kwargs = {}
        for key in ['ref', 'reward', 'value']:
            key = f'{key}_model'
            model = getattr(self, key, None)
            # PPO passes all three keys even when the model is None.
            if model or self.args.rlhf_type == 'ppo':
                trainer_kwargs[key] = model
        if hasattr(self, 'reward_template'):
            trainer_kwargs['reward_template'] = self.reward_template
        if self.args.rlhf_type == 'grpo':
            trainer_kwargs['reward_funcs'] = self.args.reward_funcs
            trainer_kwargs['vllm_client'] = self.args.vllm_client
        return trainer_kwargs


def rlhf_main(args: Union[List[str], RLHFArguments, None] = None):
    return SwiftRLHF(args).main()
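
# Usage sketch (illustrative, not part of this file): per the signature above,
# `rlhf_main` accepts either an `RLHFArguments` instance or a CLI-style argument
# list, mirroring `swift rlhf ...` on the command line. The model and dataset
# values below are hypothetical placeholders.
#
#     from swift.llm import rlhf_main
#     rlhf_main(['--rlhf_type', 'dpo', '--model', '<model-id-or-path>',
#                '--dataset', '<dataset-id-or-path>'])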