File size: 633 Bytes
7feac49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Optional, Union
import torch.nn as nn
from transformers import PreTrainedModel
from trl import ORPOTrainer as HFORPOTrainer
from ..mixin import SwiftMixin
from .rlhf_mixin import RLHFTrainerMixin
del HFORPOTrainer.__init__
class ORPOTrainer(RLHFTrainerMixin, SwiftMixin, HFORPOTrainer):
def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, *_args, **kwargs):
ref_model = kwargs.get('ref_model')
assert ref_model is None, 'ORPO does not require a ref_model.'
super().__init__(model, *_args, **kwargs)
|