| # Copyright (c) Alibaba, Inc. and its affiliates. | |
| from typing import Optional, Union | |
| import torch.nn as nn | |
| from transformers import PreTrainedModel | |
| from trl import ORPOTrainer as HFORPOTrainer | |
| from ..mixin import SwiftMixin | |
| from .rlhf_mixin import RLHFTrainerMixin | |
# NOTE: import-time patch applied to the TRL class object itself, so it is
# visible to every user of ``trl.ORPOTrainer`` in this process. With
# ``__init__`` removed, attribute lookup on ``HFORPOTrainer`` continues to
# whatever follows it in the MRO. Presumably this lets the mixins below own
# construction -- TODO confirm against ``RLHFTrainerMixin``.
del HFORPOTrainer.__init__


class ORPOTrainer(RLHFTrainerMixin, SwiftMixin, HFORPOTrainer):
    """ORPO trainer composed from the Swift mixins and TRL's ``ORPOTrainer``."""

    def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, *_args, **kwargs):
        """Check that no reference model was supplied, then delegate upward.

        Args:
            model: Forwarded unchanged as the first positional argument of
                the parent constructor.
            *_args: Forwarded unchanged; never inspected here.
            **kwargs: Forwarded unchanged. Only the ``ref_model`` key is read,
                and it must be missing or ``None``.

        Raises:
            AssertionError: If ``kwargs['ref_model']`` is not ``None``
                (only while assertions are enabled).
        """
        # Only the keyword form is examined; positional extras pass through as-is.
        assert kwargs.get('ref_model') is None, 'ORPO does not require a ref_model.'
        super().__init__(model, *_args, **kwargs)