from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from peft import PeftModel
from transformers import PreTrainedModel
from trl import DPOTrainer as HFDPOTrainer

from ..mixin import DataLoaderMixin, SwiftMixin
from .rlhf_mixin import RLHFTrainerMixin

# Remove TRL's DPOTrainer.__init__ so that initialization falls through the Swift
# mixins to the base transformers Trainer instead of running TRL's initializer.
del HFDPOTrainer.__init__


class DPOTrainer(RLHFTrainerMixin, SwiftMixin, DataLoaderMixin, HFDPOTrainer): |
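    """DPO trainer that plugs TRL's ``DPOTrainer`` into Swift's trainer mixins
    (``RLHFTrainerMixin``, ``SwiftMixin``, ``DataLoaderMixin``)."""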
    def __init__(self,
                 model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
                 ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
                 *_args,
                 **kwargs):
        from trl.trainer import FDivergenceConstants
        args = kwargs['args']
        # Attributes that TRL's DPOTrainer.__init__ would normally set; they are
        # initialized here because that initializer was removed above.
        self.label_smoothing = args.label_smoothing
        self.loss_type = args.loss_type
        self.precompute_ref_log_probs = args.precompute_ref_log_probs
        self.f_divergence_type = args.f_divergence_type
        self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef}
        self.is_peft_model = isinstance(model, PeftModel)

        self.ref_adapter_name = args.ref_adapter_name
        self.reference_free = args.reference_free
        self.use_weighting = False

        super().__init__(model, ref_model, *_args, **kwargs)

    def get_nll_loss(self, logits, labels):
        if not self.is_encoder_decoder:
            # Shift so that tokens < n predict token n.
            logits = logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()

        # Flatten the tokens and compute cross-entropy, ignoring padded label positions.
        loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_token_id)
        logits = logits.view(-1, logits.shape[-1])
        labels = labels.view(-1)
        # Labels may live on a different device when model parallelism is used.
        labels = labels.to(logits.device)
        return loss_fct(logits, labels)

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Dict[str, torch.Tensor]:
        batch = batch.copy()
        # The batch concatenates chosen and rejected examples; the first half
        # corresponds to the chosen responses.
        num_examples = batch['labels'].shape[0] // 2
        labels = batch.pop('labels', None)
        if self.is_encoder_decoder:
            batch['labels'] = labels

        if self.aux_loss_enabled:
            batch['output_router_logits'] = True
        outputs = model(**batch, use_cache=False)
        batch['labels'] = labels
        if outputs.logits.shape[1] != labels.shape[1]:
            # Some models return logits for the full input (e.g. with visual tokens
            # prepended); keep only the suffix that aligns with the labels.
            outputs.logits = outputs.logits[:, -labels.shape[1]:]
        for key in ['input_ids', 'attention_mask', 'labels']:
            batch[f'concatenated_{key}'] = batch.pop(key, None)
        if self.__class__.__name__ == 'ORPOTrainer':
            batch['concatenated_input_ids'] = batch['concatenated_labels']

        all_logits = outputs.logits

        if all_logits.shape[:2] != batch['concatenated_labels'].shape[:2]:
            # Safeguard: trim the logits to the label length (same alignment as above).
            seq_len = batch['concatenated_labels'].shape[1]
            all_logits = all_logits[:, -seq_len:]

        all_logps, size_completion = self.get_batch_logps(
            all_logits,
            batch['concatenated_labels'],
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        output = {}

        if self.args.rpo_alpha is not None:
            labels = batch['concatenated_labels'].clone()
            output['nll_loss'] = self.get_nll_loss(all_logits[:num_examples], labels[:num_examples])

        if self.loss_type == 'ipo':
            # IPO uses length-normalized (average) log probabilities.
            all_logps = all_logps / size_completion

        output['chosen_logps'] = all_logps[:num_examples]
        output['rejected_logps'] = all_logps[num_examples:]
        output['mean_chosen_logits'] = all_logits[:num_examples].mean()
        output['mean_rejected_logits'] = all_logits[num_examples:].mean()

        if self.aux_loss_enabled:
            output['aux_loss'] = outputs.aux_loss

        return output

    @staticmethod
    def get_batch_logps(
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        label_pad_token_id: int = -100,
        is_encoder_decoder: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """Return the per-sequence sum of label log probabilities under ``logits``
        and the per-sequence count of non-padding label tokens."""
        if logits.shape[:-1] != labels.shape:
            raise ValueError(f'Logits (batch and sequence length dim) {logits.shape[:-1]} '
                             f'and labels must have the same shape {labels.shape}')
        if not is_encoder_decoder:
            # Shift so that tokens < n predict token n.
            labels = labels[:, 1:].clone()
            logits = logits[:, :-1, :]
        else:
            labels = labels.clone()

        loss_mask = labels != label_pad_token_id

        # Replace padding labels with a dummy index; their log-probs are masked out below.
        labels[labels == label_pad_token_id] = 0

        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

        return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)