|
|
from dataclasses import dataclass |
|
|
from typing import Any, Dict, Sequence |
|
|
|
|
|
import torch |
|
|
from transformers import DataCollatorForSeq2Seq |
|
|
|
|
|
|
|
|
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise (chosen / rejected) data.

    Every incoming feature holds two sequences whose keys are prefixed with
    ``chosen_`` and ``rejected_``. They are unpacked into plain
    ``input_ids`` / ``attention_mask`` / ``labels`` examples and padded
    together by ``DataCollatorForSeq2Seq``.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.

        Given n input features, 2 * n examples are produced: the first n are the
        chosen sequences and the last n are the rejected sequences, each half
        kept in input order.

        Args:
            features: per-example dicts with ``{chosen,rejected}_input_ids``,
                ``{chosen,rejected}_attention_mask`` and ``{chosen,rejected}_labels``;
                optionally ``pixel_values`` and ``{chosen,rejected}_token_type_ids``.

        Returns:
            Whatever ``DataCollatorForSeq2Seq.__call__`` returns for the 2 * n examples.
        """

        def _unpack(feature: Dict[str, Any], prefix: str) -> Dict[str, Any]:
            # Strip the "<prefix>_" namespace so the parent collator sees standard keys.
            example = {
                name: feature[f"{prefix}_{name}"]
                for name in ("input_ids", "attention_mask", "labels")
            }
            if "pixel_values" in feature:
                # The same pixel_values are attached to both the chosen and the rejected copy.
                example["pixel_values"] = feature["pixel_values"]

            type_ids_key = f"{prefix}_token_type_ids"
            if type_ids_key in feature:
                example["token_type_ids"] = feature[type_ids_key]

            return example

        # Outer loop over the prefix keeps all chosen examples ahead of all rejected ones.
        flattened = [
            _unpack(feature, prefix)
            for prefix in ("chosen", "rejected")
            for feature in features
        ]
        return super().__call__(flattened)
|
|
|
|
|
|
|
|
@dataclass
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for KTO data.

    Each incoming feature holds a target sequence and a KL sequence (keys
    prefixed with ``kl_``). The two groups are padded in separate calls to
    ``DataCollatorForSeq2Seq`` and then merged into a single batch.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads the target and KL sequences of a KTO batch.

        Args:
            features: per-example dicts with ``input_ids``, ``attention_mask``,
                ``labels``, their ``kl_``-prefixed counterparts and ``kto_tags``.
                They may also carry ``pixel_values``, and ``token_type_ids``; when
                ``token_type_ids`` is present, ``kl_token_type_ids`` is required too.

        Returns:
            The padded target batch, extended with ``kl_input_ids``,
            ``kl_attention_mask``, ``kl_labels`` (plus ``kl_token_type_ids`` when the
            target batch has ``token_type_ids``) and ``kto_tags`` as a tensor.
        """
        base_keys = ("input_ids", "attention_mask", "labels")
        target_features, kl_features, kto_tags = [], [], []

        for feature in features:
            target_feature = {key: feature[key] for key in base_keys}
            kl_feature = {key: feature[f"kl_{key}"] for key in base_keys}

            # NOTE(review): pixel_values are attached to the target sequence only,
            # not to the KL sequence — confirm the KL forward pass does not need them.
            if "pixel_values" in feature:
                target_feature["pixel_values"] = feature["pixel_values"]

            if "token_type_ids" in feature:
                target_feature["token_type_ids"] = feature["token_type_ids"]
                kl_feature["token_type_ids"] = feature["kl_token_type_ids"]

            target_features.append(target_feature)
            kl_features.append(kl_feature)
            kto_tags.append(feature["kto_tags"])

        # Pad each group independently, so the two groups can end up with different padded lengths.
        batch = super().__call__(target_features)
        kl_batch = super().__call__(kl_features)

        for key in base_keys:
            batch[f"kl_{key}"] = kl_batch[key]
        if "token_type_ids" in batch:
            batch["kl_token_type_ids"] = kl_batch["token_type_ids"]

        batch["kto_tags"] = torch.tensor(kto_tags)
        return batch
|
|
|