# NOTE(review): the three lines below were web-scrape artifacts (uploader name,
# upload-tool commit message, commit hash) accidentally captured into the file;
# commented out so the module is valid Python.
# Student0809's picture
# Add files using upload-large-folder tool
# 7feac49 verified
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from peft import PeftModel
from transformers import PreTrainedModel
from trl import DPOTrainer as HFDPOTrainer
from ..mixin import DataLoaderMixin, SwiftMixin
from .rlhf_mixin import RLHFTrainerMixin
# Remove TRL's DPOTrainer.__init__ from the MRO so that DPOTrainer.__init__
# below falls through to the mixin/Trainer initializer chain via super();
# the DPO-specific attributes TRL would have set are assigned manually instead.
# Must run before the class definition below relies on the patched base.
del HFDPOTrainer.__init__
class DPOTrainer(RLHFTrainerMixin, SwiftMixin, DataLoaderMixin, HFDPOTrainer):
    """Swift DPO trainer layered on top of TRL's ``DPOTrainer``.

    ``HFDPOTrainer.__init__`` is deleted at module level, so ``super().__init__``
    resolves through the mixin chain instead of TRL's initializer; the
    DPO-specific attributes TRL would normally set are assigned here by hand.
    """

    def __init__(self,
                 model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
                 ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
                 *_args,
                 **kwargs):
        """Set up the DPO-specific attributes TRL's deleted ``__init__`` would
        have configured, then delegate to the mixin/Trainer chain.

        Args:
            model: The policy model to optimize.
            ref_model: The frozen reference model (may be ``None`` when
                ``reference_free`` or when using a PEFT adapter as reference).
            **kwargs: Must contain ``args`` (the training arguments object).
        """
        from trl.trainer import FDivergenceConstants

        # Mirror the attribute setup of trl's DPOTrainer.__init__.
        args = kwargs['args']
        self.label_smoothing = args.label_smoothing
        self.loss_type = args.loss_type
        self.precompute_ref_log_probs = args.precompute_ref_log_probs
        self.f_divergence_type = args.f_divergence_type
        self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef}
        self.is_peft_model = isinstance(model, PeftModel)
        self.ref_adapter_name = args.ref_adapter_name
        self.reference_free = args.reference_free
        self.use_weighting = False
        super().__init__(model, ref_model, *_args, **kwargs)

    def get_nll_loss(self, logits, labels):
        """Token-level cross-entropy (NLL) loss over ``logits`` vs ``labels``.

        Positions equal to ``self.label_pad_token_id`` are ignored. For
        decoder-only models the sequences are shifted so tokens < n predict n.
        """
        if not self.is_encoder_decoder:
            # Shift so that tokens < n predict n
            logits = logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_token_id)
        logits = logits.view(-1, logits.shape[-1])
        labels = labels.view(-1)
        # Enable model parallelism
        labels = labels.to(logits.device)
        return loss_fct(logits, labels)

    def concatenated_forward(
            self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Dict[str, torch.FloatTensor]:
        """Run one forward pass on a batch holding chosen and rejected rows.

        The batch is assumed to be the concatenation [chosen; rejected] along
        dim 0, so the first half of every tensor belongs to the chosen
        responses.  (NOTE(review): the original annotation promised a 5-tuple
        but a dict has always been returned; the annotation now matches.)

        Returns:
            Dict with ``chosen_logps``, ``rejected_logps``,
            ``mean_chosen_logits``, ``mean_rejected_logits`` and, when enabled,
            ``nll_loss`` (rpo_alpha) / ``aux_loss`` (router aux loss).
        """
        batch = batch.copy()
        num_examples = batch['labels'].shape[0] // 2
        labels = batch.pop('labels', None)
        if self.is_encoder_decoder:
            # Encoder-decoder models need labels during the forward pass.
            batch['labels'] = labels

        if self.aux_loss_enabled:
            batch['output_router_logits'] = True
        outputs = model(**batch, use_cache=False)
        batch['labels'] = labels

        if outputs.logits.shape[1] != labels.shape[1]:
            # for llava, the model returns logits for the entire sequence, including the image tokens
            # (placed before the text tokens)
            outputs.logits = outputs.logits[:, -labels.shape[1]:]

        # Rename keys to the `concatenated_*` convention expected below.
        for key in ['input_ids', 'attention_mask', 'labels']:
            batch[f'concatenated_{key}'] = batch.pop(key, None)
        if self.__class__.__name__ == 'ORPOTrainer':  # Pass-through labels
            batch['concatenated_input_ids'] = batch['concatenated_labels']

        all_logits = outputs.logits

        if all_logits.shape[:2] != batch['concatenated_labels'].shape[:2]:
            # for llava, the model returns logits for the entire sequence,
            # including the image tokens (placed before the text tokens)
            seq_len = batch['concatenated_labels'].shape[1]
            all_logits = all_logits[:, -seq_len:]

        all_logps, size_completion = self.get_batch_logps(
            all_logits,
            batch['concatenated_labels'],
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        output = {}

        if self.args.rpo_alpha is not None:
            # RPO adds an NLL term computed on the chosen half only.
            labels = batch['concatenated_labels'].clone()
            output['nll_loss'] = self.get_nll_loss(all_logits[:num_examples], labels[:num_examples])

        if self.loss_type == 'ipo':
            # IPO uses length-averaged log-probs instead of summed ones.
            all_logps = all_logps / size_completion

        output['chosen_logps'] = all_logps[:num_examples]
        output['rejected_logps'] = all_logps[num_examples:]
        output['mean_chosen_logits'] = all_logits[:num_examples].mean()
        output['mean_rejected_logits'] = all_logits[num_examples:].mean()

        if self.aux_loss_enabled:
            output['aux_loss'] = outputs.aux_loss
        return output

    @staticmethod
    def get_batch_logps(
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        label_pad_token_id: int = -100,
        is_encoder_decoder: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """Sum per-token log-probabilities of ``labels`` under ``logits``.

        Args:
            logits: Shape (batch, seq_len, vocab).
            labels: Shape (batch, seq_len); positions equal to
                ``label_pad_token_id`` are masked out of the sum.
            label_pad_token_id: Ignore-index for padding positions.
            is_encoder_decoder: When False, shift logits/labels by one so that
                position t predicts token t+1.

        Returns:
            Tuple of (summed log-probs per sequence, number of unmasked tokens
            per sequence).
        """
        if logits.shape[:-1] != labels.shape:
            # BUGFIX: the second fragment was not an f-string, so
            # "{labels.shape}" was printed literally (and a space was missing).
            raise ValueError(f'Logits (batch and sequence length dim) {logits.shape[:-1]} '
                             f'and labels must have the same shape {labels.shape}')
        if not is_encoder_decoder:
            labels = labels[:, 1:].clone()
            logits = logits[:, :-1, :]
        else:
            labels = labels.clone()
        loss_mask = labels != label_pad_token_id

        # Padding tokens are replaced by index 0 for the gather; they are
        # excluded from the sum by loss_mask afterwards.
        labels[labels == label_pad_token_id] = 0

        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

        return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)