# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Optional, Union

import torch.nn as nn
from transformers import PreTrainedModel
from trl import ORPOTrainer as HFORPOTrainer

from ..mixin import SwiftMixin
from .rlhf_mixin import RLHFTrainerMixin

# Remove ``__init__`` from the trl class itself, so attribute lookup for
# ``__init__`` on ``HFORPOTrainer`` falls through to its own base classes.
# NOTE(review): presumably this lets the mixins listed before it in the MRO
# own construction -- confirm against RLHFTrainerMixin / SwiftMixin.
del HFORPOTrainer.__init__


class ORPOTrainer(RLHFTrainerMixin, SwiftMixin, HFORPOTrainer):
    """ORPO trainer assembled from ``RLHFTrainerMixin``, ``SwiftMixin`` and trl's ``ORPOTrainer``."""

    def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, *_args, **kwargs):
        """Validate the ``ref_model`` keyword and delegate to the parent initializers.

        Args:
            model: Forwarded unchanged as the first positional argument of
                ``super().__init__``.
            *_args: Extra positional arguments, forwarded unchanged.
            **kwargs: Keyword arguments, forwarded unchanged (including
                ``ref_model`` when it was supplied as ``None``).

        Raises:
            ValueError: If a non-``None`` ``ref_model`` keyword argument is given.
        """
        # Raise explicitly instead of using ``assert``: assertions are stripped
        # under ``python -O`` and the check would then be skipped silently.
        if kwargs.get('ref_model') is not None:
            raise ValueError('ORPO does not require a ref_model.')
        super().__init__(model, *_args, **kwargs)