|
|
|
|
|
import inspect |
|
|
from contextlib import contextmanager |
|
|
|
|
|
import transformers |
|
|
from packaging import version |
|
|
from torch.utils.data import DataLoader |
|
|
from transformers import PreTrainedModel |
|
|
from trl import PPOTrainer as HFPPOTrainer |
|
|
|
|
|
from swift.utils import patch_getattr |
|
|
from ..mixin import SwiftMixin |
|
|
|
|
|
# Capture trl's original `PPOTrainer.__init__` so it can be invoked manually and
# late (see `PPOTrainer.__init__` in this file), then delete it from the class.
# Deleting it means `super().__init__(...)` inside the subclass below resolves
# past HFPPOTrainer in the MRO (to SwiftMixin's chain) instead of running trl's
# initializer prematurely with incompatible arguments.
ppo_trainer_init = HFPPOTrainer.__init__

del HFPPOTrainer.__init__
|
|
|
|
|
|
|
|
class PPOTrainer(SwiftMixin, HFPPOTrainer):
    """PPO trainer that adapts trl's ``PPOTrainer`` to swift's trainer stack.

    Initialization is deliberately split in two: ``super().__init__`` runs the
    SwiftMixin/transformers setup first (possible only because
    ``HFPPOTrainer.__init__`` was deleted at module level), and trl's original
    initializer (captured as ``ppo_trainer_init``) is invoked afterwards with a
    keyword set adapted to the installed trl version.
    """

    @staticmethod
    @contextmanager
    def _patch_dataloader(collate_fn):
        """Temporarily force every ``DataLoader`` constructed inside the
        ``with`` block to use ``collate_fn``.

        trl's initializer builds its dataloaders internally without exposing a
        collate-fn hook, so the class-level ``DataLoader.__init__`` is swapped
        out for the duration and always restored in ``finally``.
        """
        # Keep a reference to the untouched initializer for restoration.
        __init__ = DataLoader.__init__

        def __new_init__(self, *args, **kwargs):
            # Override whatever collate_fn the caller passed (or omitted).
            kwargs['collate_fn'] = collate_fn
            __init__(self, *args, **kwargs)

        DataLoader.__init__ = __new_init__
        try:
            yield
        finally:
            # Restore even if trl's init raises, so the patch never leaks
            # out of the context.
            DataLoader.__init__ = __init__

    def __init__(self, model: PreTrainedModel, ref_model: PreTrainedModel, *_args, **kwargs):
        # Phase 1: SwiftMixin/transformers setup. `reward_model`/`value_model`
        # are stripped because the mixin chain does not accept them; they are
        # forwarded to trl's initializer below instead.
        super().__init__(model, *_args, **{k: v for k, v in kwargs.items() if k not in {'reward_model', 'value_model'}})
        # Phase 2: run trl's captured initializer while its internally created
        # dataloaders are forced to use our data_collator.
        with self._patch_dataloader(kwargs['data_collator']):
            new_kwargs = {
                k: v
                for k, v in kwargs.items()
                if k in ['train_dataset', 'data_collator', 'reward_model', 'value_model', 'eval_dataset']
            }
            # trl renamed its init parameters across versions:
            # older versions take `config`/`tokenizer`, newer ones take
            # `args`/`processing_class`. Introspect the captured signature
            # rather than pinning a trl version.
            parameters = inspect.signature(ppo_trainer_init).parameters
            if 'config' in parameters:
                new_kwargs['config'] = kwargs['args']
            else:
                new_kwargs['args'] = kwargs['args']
            if 'processing_class' in parameters:
                new_kwargs['processing_class'] = self.tokenizer
            else:
                new_kwargs['tokenizer'] = self.tokenizer
            ppo_trainer_init(self, model=model, ref_model=ref_model, **new_kwargs)
        # trl wraps the policy in a composite model; patch attribute lookup so
        # attributes missing on the wrapper are resolved on its `.policy`
        # submodule (NOTE(review): assumes the unwrapped model exposes a
        # `policy` attribute, as trl's PPO wrapper does — verify on upgrade).
        unwrap_model = self.accelerator.unwrap_model(self.model)
        patch_getattr(unwrap_model.__class__, 'policy')

    def train(self, *args, **kwargs):
        # Intentionally drops *args/**kwargs: trl's PPO `train()` takes no
        # arguments, while swift's generic call sites may pass some
        # (presumably resume/trial arguments — confirm against callers).
        super().train()

    def _save_checkpoint(self, *args, **kwargs):
        # transformers >= 4.47 moved best-metric bookkeeping out of
        # `_save_checkpoint`; replicate it here before delegating, consuming
        # `metrics` so the parent signature stays compatible.
        if version.parse(transformers.__version__) >= version.parse('4.47'):
            metrics = kwargs.pop('metrics', None)
            trial = kwargs.get('trial')
            self._determine_best_metric(metrics=metrics, trial=trial)
        return super()._save_checkpoint(*args, **kwargs)
|
|
|