# Copyright (c) Alibaba, Inc. and its affiliates.
import inspect
from contextlib import contextmanager

import transformers
from packaging import version
from torch.utils.data import DataLoader
from transformers import PreTrainedModel
from trl import PPOTrainer as HFPPOTrainer

from swift.utils import patch_getattr
from ..mixin import SwiftMixin
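
# Capture trl's PPOTrainer.__init__ and delete it from the class: super().__init__
# in the subclass below then resolves through SwiftMixin to transformers.Trainer,
# and the captured init is invoked manually once its kwargs have been adapted.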
ppo_trainer_init = HFPPOTrainer.__init__
del HFPPOTrainer.__init__


class PPOTrainer(SwiftMixin, HFPPOTrainer):
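
    # trl's PPOTrainer builds its train/eval DataLoaders internally, so a custom
    # collate_fn cannot be passed through directly; this context manager temporarily
    # monkey-patches DataLoader.__init__ to inject one and restores the original on
    # exit.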
    @staticmethod
    @contextmanager
    def _patch_dataloader(collate_fn):
        __init__ = DataLoader.__init__

        def __new_init__(self, *args, **kwargs):
            kwargs['collate_fn'] = collate_fn
            __init__(self, *args, **kwargs)

        DataLoader.__init__ = __new_init__
        try:
            yield
        finally:
            DataLoader.__init__ = __init__

    def __init__(self, model: PreTrainedModel, ref_model: PreTrainedModel, *_args, **kwargs):
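        # Keep reward_model/value_model out of SwiftMixin/Trainer's __init__; they
        # are passed to the captured trl init below instead.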
        super().__init__(model, *_args, **{k: v for k, v in kwargs.items() if k not in {'reward_model', 'value_model'}})
        with self._patch_dataloader(kwargs['data_collator']):
            new_kwargs = {
                k: v
                for k, v in kwargs.items()
                if k in ['train_dataset', 'data_collator', 'reward_model', 'value_model', 'eval_dataset']
            }
            parameters = inspect.signature(ppo_trainer_init).parameters
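            # trl renamed `config` -> `args` and `tokenizer` -> `processing_class`
            # across releases; inspect the captured signature to use the names this
            # installation expects.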
            if 'config' in parameters:
                new_kwargs['config'] = kwargs['args']
            else:
                new_kwargs['args'] = kwargs['args']
            if 'processing_class' in parameters:
                new_kwargs['processing_class'] = self.tokenizer
            else:
                new_kwargs['tokenizer'] = self.tokenizer
            ppo_trainer_init(self, model=model, ref_model=ref_model, **new_kwargs)
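        # trl wraps the policy in a composite policy/value model; patch_getattr
        # forwards attribute lookups on the wrapper class to its `policy` module.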
        unwrap_model = self.accelerator.unwrap_model(self.model)
        patch_getattr(unwrap_model.__class__, 'policy')

    def train(self, *args, **kwargs):
        # Drop args that HFPPOTrainer.train does not accept.
        super().train()

    def _save_checkpoint(self, *args, **kwargs):
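        # transformers >= 4.47 removed the `metrics` argument from _save_checkpoint
        # and tracks the best metric separately, so update it here before delegating.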
        if version.parse(transformers.__version__) >= version.parse('4.47'):
            metrics = kwargs.pop('metrics', None)
            trial = kwargs.get('trial')
            self._determine_best_metric(metrics=metrics, trial=trial)
        return super()._save_checkpoint(*args, **kwargs)