# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from peft import PeftModel
from transformers import PreTrainedModel
from trl import DPOTrainer as HFDPOTrainer

from ..mixin import DataLoaderMixin, SwiftMixin
from .rlhf_mixin import RLHFTrainerMixin
# Drop trl's __init__ so that initialization is handled by the Swift mixins below.
del HFDPOTrainer.__init__

class DPOTrainer(RLHFTrainerMixin, SwiftMixin, DataLoaderMixin, HFDPOTrainer):

    def __init__(self,
                 model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
                 ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
                 *_args,
                 **kwargs):
        from trl.trainer import FDivergenceConstants
        args = kwargs['args']
        # Set the attributes that trl's DPOTrainer.__init__ would normally set,
        # since that __init__ was deleted above.
        self.label_smoothing = args.label_smoothing
        self.loss_type = args.loss_type
        self.precompute_ref_log_probs = args.precompute_ref_log_probs
        self.f_divergence_type = args.f_divergence_type
        self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef}
        self.is_peft_model = isinstance(model, PeftModel)
        self.ref_adapter_name = args.ref_adapter_name
        self.reference_free = args.reference_free
        self.use_weighting = False
        super().__init__(model, ref_model, *_args, **kwargs)
    def get_nll_loss(self, logits, labels):
        if not self.is_encoder_decoder:
            # Shift so that tokens < n predict n
            logits = logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_token_id)
        logits = logits.view(-1, logits.shape[-1])
        labels = labels.view(-1)
        # Enable model parallelism
        labels = labels.to(logits.device)
        return loss_fct(logits, labels)
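    # Illustrative sketch only (not used by the trainer): for a decoder-only model
    # with label_pad_token_id == -100, `get_nll_loss` is equivalent to the standard
    # shifted cross-entropy below; the tensor shapes are made-up examples.
    #
    #   logits = torch.randn(2, 6, 100)                   # (batch, seq_len, vocab)
    #   labels = torch.randint(0, 100, (2, 6))            # (batch, seq_len)
    #   loss = nn.CrossEntropyLoss(ignore_index=-100)(
    #       logits[..., :-1, :].reshape(-1, 100),         # drop the last position
    #       labels[..., 1:].reshape(-1))                  # drop the first label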
    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Dict[str, torch.Tensor]:
        batch = batch.copy()
        # The concatenated batch stacks the chosen examples first, then the rejected ones.
        num_examples = batch['labels'].shape[0] // 2
        labels = batch.pop('labels', None)
        if self.is_encoder_decoder:
            batch['labels'] = labels

        if self.aux_loss_enabled:
            batch['output_router_logits'] = True
        outputs = model(**batch, use_cache=False)
        batch['labels'] = labels
        if outputs.logits.shape[1] != labels.shape[1]:
            # for llava, the model returns logits for the entire sequence, including the image tokens
            # (placed before the text tokens)
            outputs.logits = outputs.logits[:, -labels.shape[1]:]

        for key in ['input_ids', 'attention_mask', 'labels']:
            batch[f'concatenated_{key}'] = batch.pop(key, None)
        if self.__class__.__name__ == 'ORPOTrainer':  # Pass-through labels
            batch['concatenated_input_ids'] = batch['concatenated_labels']

        all_logits = outputs.logits
        if all_logits.shape[:2] != batch['concatenated_labels'].shape[:2]:
            # for llava, the model returns logits for the entire sequence,
            # including the image tokens (placed before the text tokens)
            seq_len = batch['concatenated_labels'].shape[1]
            all_logits = all_logits[:, -seq_len:]
        all_logps, size_completion = self.get_batch_logps(
            all_logits,
            batch['concatenated_labels'],
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        output = {}
        if self.args.rpo_alpha is not None:
            # RPO: add an NLL term computed on the chosen responses only.
            labels = batch['concatenated_labels'].clone()
            output['nll_loss'] = self.get_nll_loss(all_logits[:num_examples], labels[:num_examples])
        if self.loss_type == 'ipo':
            # IPO averages the log-probs over completion tokens instead of summing them.
            all_logps = all_logps / size_completion

        output['chosen_logps'] = all_logps[:num_examples]
        output['rejected_logps'] = all_logps[num_examples:]
        output['mean_chosen_logits'] = all_logits[:num_examples].mean()
        output['mean_rejected_logits'] = all_logits[num_examples:].mean()
        if self.aux_loss_enabled:
            output['aux_loss'] = outputs.aux_loss
        return output
    @staticmethod
    def get_batch_logps(
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        label_pad_token_id: int = -100,
        is_encoder_decoder: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """Sum the per-token log probabilities of `labels` under `logits`.

        Returns a tuple of (summed log-probs, number of unmasked label tokens), each of shape (batch_size,).
        """
        if logits.shape[:-1] != labels.shape:
            raise ValueError(f'Logits (batch and sequence length dim) {logits.shape[:-1]} '
                             f'and labels must have the same shape {labels.shape}')
        if not is_encoder_decoder:
            # Shift so that tokens < n predict n
            labels = labels[:, 1:].clone()
            logits = logits[:, :-1, :]
        else:
            labels = labels.clone()
        loss_mask = labels != label_pad_token_id
        # Dummy index for padded positions; they are masked out below.
        labels[labels == label_pad_token_id] = 0
        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
        return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
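
# Illustrative sketch only (not executed anywhere in this module): how the padding
# mask in `get_batch_logps` works. Shapes and values are made up for the example.
#
#   logits = torch.randn(2, 6, 100)                  # (batch, seq_len, vocab)
#   labels = torch.full((2, 6), -100)                # -100 == label_pad_token_id
#   labels[:, 3:] = torch.randint(0, 100, (2, 3))    # only 3 completion tokens per row
#   logps, n_tokens = DPOTrainer.get_batch_logps(logits, labels)
#   # logps: shape (2,), summed log-probs over the unmasked (completion) tokens
#   # n_tokens: shape (2,), equals 3 for both rows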