File size: 2,440 Bytes
7feac49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
# Copyright (c) Alibaba, Inc. and its affiliates.
import inspect
from contextlib import contextmanager
import transformers
from packaging import version
from torch.utils.data import DataLoader
from transformers import PreTrainedModel
from trl import PPOTrainer as HFPPOTrainer
from swift.utils import patch_getattr
from ..mixin import SwiftMixin
# Save a reference to trl's original ``PPOTrainer.__init__`` and then delete it
# from the class.  After the deletion, ``super().__init__`` in the subclass
# resolves PAST ``HFPPOTrainer`` in the MRO (to the mixin/base Trainer chain);
# the saved initializer is invoked explicitly later with version-adapted kwargs.
ppo_trainer_init = HFPPOTrainer.__init__
del HFPPOTrainer.__init__
class PPOTrainer(SwiftMixin, HFPPOTrainer):
    """PPO trainer combining swift's ``SwiftMixin`` with trl's ``PPOTrainer``.

    ``HFPPOTrainer.__init__`` is deleted at module level, so the
    ``super().__init__`` call below skips trl's initializer and resolves to the
    next class in the MRO; trl's original initializer is then invoked explicitly
    via the module-level ``ppo_trainer_init`` reference, with keyword names
    adapted to the installed trl version.
    """

    @staticmethod
    @contextmanager
    def _patch_dataloader(collate_fn):
        """Context manager: force every ``DataLoader`` constructed inside the
        ``with`` block to use *collate_fn*, restoring the original
        ``DataLoader.__init__`` on exit (even on error).

        NOTE(review): this patches ``DataLoader`` globally for the duration of
        the context — not safe under concurrent DataLoader construction.
        """
        __init__ = DataLoader.__init__
        def __new_init__(self, *args, **kwargs):
            # Override whatever collate_fn the caller passed (or omitted).
            kwargs['collate_fn'] = collate_fn
            __init__(self, *args, **kwargs)
        DataLoader.__init__ = __new_init__
        try:
            yield
        finally:
            # Always undo the global monkey-patch.
            DataLoader.__init__ = __init__

    def __init__(self, model: PreTrainedModel, ref_model: PreTrainedModel, *_args, **kwargs) -> None:
        """Initialize via the mixin chain first, then run trl's original
        ``PPOTrainer.__init__`` with kwargs filtered/renamed for the installed
        trl version.

        Args:
            model: the policy model to train.
            ref_model: the frozen reference model for the KL penalty.
            kwargs: must contain ``args`` and ``data_collator``; may contain
                ``train_dataset``, ``eval_dataset``, ``reward_model``,
                ``value_model``.
        """
        # Mixin/base initialization WITHOUT the trl-specific reward/value
        # models (those are forwarded only to ppo_trainer_init below).  Because
        # HFPPOTrainer.__init__ was deleted, this does not run trl's init.
        super().__init__(model, *_args, **{k: v for k, v in kwargs.items() if k not in {'reward_model', 'value_model'}})
        # Make DataLoaders built inside trl's initializer use our collator.
        with self._patch_dataloader(kwargs['data_collator']):
            new_kwargs = {
                k: v
                for k, v in kwargs.items()
                if k in ['train_dataset', 'data_collator', 'reward_model', 'value_model', 'eval_dataset']
            }
            # Adapt keyword names to the installed trl signature:
            # older versions take ``config``/``tokenizer``, newer ones take
            # ``args``/``processing_class``.
            parameters = inspect.signature(ppo_trainer_init).parameters
            if 'config' in parameters:
                new_kwargs['config'] = kwargs['args']
            else:
                new_kwargs['args'] = kwargs['args']
            if 'processing_class' in parameters:
                new_kwargs['processing_class'] = self.tokenizer
            else:
                new_kwargs['tokenizer'] = self.tokenizer
            # Now run trl's original initializer, saved at module level.
            ppo_trainer_init(self, model=model, ref_model=ref_model, **new_kwargs)
        unwrap_model = self.accelerator.unwrap_model(self.model)
        # Presumably delegates attribute access on the wrapped model class to
        # its ``policy`` submodule — see swift.utils.patch_getattr for the
        # exact semantics.
        patch_getattr(unwrap_model.__class__, 'policy')

    def train(self, *args, **kwargs):
        """Run training, deliberately dropping ``*args``/``**kwargs``."""
        # remove args that are not needed for the HFPPOTrainer
        super().train()

    def _save_checkpoint(self, *args, **kwargs):
        """Save a checkpoint, shimming best-metric tracking for newer
        transformers.

        NOTE(review): transformers>=4.47 appears to have moved ``metrics``
        handling out of ``_save_checkpoint``, so it is popped here and
        ``_determine_best_metric`` is called manually — confirm against the
        installed transformers version's Trainer source.
        """
        if version.parse(transformers.__version__) >= version.parse('4.47'):
            metrics = kwargs.pop('metrics', None)
            trial = kwargs.get('trial')
            self._determine_best_metric(metrics=metrics, trial=trial)
        return super()._save_checkpoint(*args, **kwargs)
|