# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import List, Union

from swift.llm import safe_snapshot_download
from swift.utils import get_logger, get_model_parameter_info
from ..argument import BaseArguments, RLHFArguments
from ..model import HfConfigFactory
from .kto import prepare_kto_dataset
from .sft import SwiftSft

logger = get_logger()


class SwiftRLHF(SwiftSft):
    args_class = RLHFArguments
    args: args_class

    def _prepare_model_tokenizer(self):
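        """Prepare the auxiliary ref/reward/value models (and sequence
        parallelism, if enabled) before delegating to SwiftSft for the
        policy model itself."""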
        if self.args.sequence_parallel_size > 1:
            # Calling this more than once is allowed; it guarantees the
            # function runs before the model is initialized.
            from swift.trainers.sequence_parallel import sequence_parallel
            sequence_parallel.init_sequence_parallel(self.args.sequence_parallel_size)

        # prepare ref/reward/value model
        from swift.llm.infer.utils import prepare_adapter
        args = self.args

        def prepare_single_model(key, origin_key=None):
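            """Build one auxiliary model. ``key`` selects the argument prefix
            (e.g. ``ref_model``); ``origin_key`` is the role the model will
            play. Returns ``(model, processor)``, or None if no model of this
            kind is configured."""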
            origin_key = origin_key or key
            model_id_or_path = getattr(args, f'{key}_model')
            if model_id_or_path is None:
                return None
            model_type = getattr(args, f'{key}_model_type')
            model_revision = getattr(args, f'{key}_model_revision')
            model_dir = safe_snapshot_download(
                model_id_or_path=model_id_or_path,
                revision=model_revision,
                download_model=False,
                use_hf=args.use_hf,
                hub_token=args.hub_token,
            )
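
            # Figure out the task: swift-trained checkpoints ship an args.json
            # carrying task_type; otherwise read num_labels from the
            # Hugging Face config.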
            task_type = None
            num_labels = None
            if os.path.exists(os.path.join(model_dir, 'args.json')):
                model_args = BaseArguments.from_pretrained(model_dir)
                if hasattr(model_args, 'task_type'):
                    task_type = model_args.task_type
            else:
                from transformers import AutoConfig
                model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
                if hasattr(model_config, 'num_labels'):
                    num_labels = model_config.num_labels
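            # Sequence classification here means a single scalar output
            # (the reward score), hence num_labels = 1.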
            if task_type == 'seq_cls':
                num_labels = 1
            model, processor = args.get_model_processor(
                model=model_id_or_path,
                model_type=model_type,
                model_revision=model_revision,
                task_type=task_type,
                num_labels=num_labels)

            adapters = args.adapters if key == 'ref' else args.reward_adapters
            model = prepare_adapter(args, model, adapters)
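
            # ref/reward models are used for inference only, so freeze them;
            # the value model is trainable and goes through prepare_model.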
            if origin_key in {'ref', 'reward'}:
                if self.args.sequence_parallel_size > 1:
                    from swift.trainers.sequence_parallel import sequence_parallel
                    if hasattr(model, 'model_meta'):
                        is_multimodal = model.model_meta.is_multimodal
                    else:
                        is_multimodal = model.model.model_meta.is_multimodal
                    sequence_parallel.prepare_model(model, processor, split_in_forward=is_multimodal)
                model.requires_grad_(False).eval()
            else:
                model = self.prepare_model(args, model, task_type=task_type)
                logger.info(f'value_model: {model}')
                model_parameter_info = get_model_parameter_info(model)
                self.train_msg['value_model_parameter_info'] = model_parameter_info
                logger.info(f'value_model_parameter_info: {model_parameter_info}')
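
            # The KV cache is only useful at inference time; disable it for training.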
            HfConfigFactory.set_model_config_attr(model, 'use_cache', False)
            return model, processor

        # Handle ref and value models
        for key in ['ref', 'value']:
            setattr(self, f'{key}_model', None)
            if key == 'value' and args.rlhf_type != 'ppo':
                continue
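            # Only PPO trains a value model; it is initialized from the reward model checkpoint.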
            model_key = 'reward' if key == 'value' else key
            result = prepare_single_model(model_key, key)
            if result is not None:
                model, _ = result
                setattr(self, f'{key}_model', model)

        # Handle reward model(s)
        self.reward_model = None
        if hasattr(args, 'reward_model') and args.reward_model is not None:
            reward_models = args.reward_model if isinstance(args.reward_model, list) else [args.reward_model]
            self.reward_model = []
            if args.rlhf_type == 'grpo':
                self.reward_template = []
            for reward_model_path in reward_models:
                args.reward_model = reward_model_path  # Temporarily set for prepare_single_model
                result = prepare_single_model('reward')
                if result is not None:
                    model, processor = result
                    self.reward_model.append(model)
                    if args.rlhf_type == 'grpo':
                        reward_template = self.args.get_template(processor, processor.model_meta.template)
                        if reward_template.use_model:
                            reward_template.model = model
                        self.reward_template.append(reward_template)
            args.reward_model = reward_models  # Restore original value

        super()._prepare_model_tokenizer()

    def _prepare_template(self) -> None:
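        """Switch the template into the mode required by the chosen RLHF
        algorithm (kto/pt/rlhf) and wire up PPO's stop token."""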
        args = self.args
        super()._prepare_template()
        model_mapping = {'kto': 'kto', 'ppo': 'pt', 'grpo': 'pt'}
        self.template.set_mode(model_mapping.get(args.rlhf_type, 'rlhf'))

        if args.rlhf_type == 'ppo':
            args.training_args.stop_token_id = self.template.template_meta.stop_token_id

    def _get_dataset(self):
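        """Load datasets as in SFT; KTO needs an extra preparation pass."""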
        args = self.args
        train_dataset, val_dataset = super()._get_dataset()
        if args.rlhf_type == 'kto':
            train_dataset, val_dataset = prepare_kto_dataset(args, train_dataset, val_dataset)
        return train_dataset, val_dataset

    def _get_trainer_kwargs(self):
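        """Gather the extra models/templates/functions the RLHF trainer expects."""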
        trainer_kwargs = {}
        for key in ['ref', 'reward', 'value']:
            key = f'{key}_model'
            model = getattr(self, key, None)
            if model or self.args.rlhf_type == 'ppo':
                trainer_kwargs[key] = model
        if hasattr(self, 'reward_template'):
            trainer_kwargs['reward_template'] = self.reward_template
        if self.args.rlhf_type == 'grpo':
            trainer_kwargs['reward_funcs'] = self.args.reward_funcs
            trainer_kwargs['vllm_client'] = self.args.vllm_client
        return trainer_kwargs


def rlhf_main(args: Union[List[str], RLHFArguments, None] = None):
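    """Entry point for RLHF training: accepts argv-style strings, an
    RLHFArguments instance, or None to parse from the command line."""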
    return SwiftRLHF(args).main()
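

# A minimal usage sketch (hypothetical model/dataset names; the exact flags
# are defined by RLHFArguments and may vary between ms-swift versions):
#   rlhf_main(['--rlhf_type', 'dpo', '--model', 'Qwen/Qwen2-7B-Instruct',
#              '--dataset', 'my-org/preference-data'])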