Spaces:
Runtime error
Runtime error
| from dataclasses import dataclass | |
| from typing import Any, Dict, Sequence | |
| import torch | |
| from transformers import DataCollatorWithPadding | |
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
    r"""
    Data collator for pairwise (chosen vs. rejected) data.

    Each incoming feature is split into two unpadded sequences,
    ``prompt_ids + chosen_ids`` and ``prompt_ids + rejected_ids``, which are then
    handed to :class:`transformers.DataCollatorWithPadding` for padding and
    tensor conversion.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.

        We generate 2 * n examples where the first n examples represent chosen examples and
        the last n examples represent rejected examples.

        Args:
            features: ``n`` examples, each a mapping with ``"prompt_ids"``,
                ``"chosen_ids"`` and ``"rejected_ids"`` entries. Any other keys in
                a feature are not forwarded to the parent collator.

        Returns:
            Whatever ``DataCollatorWithPadding.__call__`` produces for the ``2 * n``
            concatenated sequences (each carrying ``input_ids`` and ``attention_mask``).

        Raises:
            KeyError: if a feature lacks one of the three required keys.
        """
        # NOTE(review): the ids are joined with `+`, which assumes list-like
        # sequences (presumably lists of token ids) -- tensors/arrays would add
        # element-wise instead of concatenating. TODO confirm against the dataset
        # preprocessing that builds these features.
        #
        # Outer loop over keys, inner loop over features (key-major order): all
        # chosen sequences come first, then all rejected ones, so rows i and i + n
        # of the padded batch originate from the same example.
        pairwise_features = [
            {
                "input_ids": feature["prompt_ids"] + feature[key],
                # Every real token is attended to; padding (and its zero mask)
                # is added afterwards by the parent collator.
                "attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key])),
            }
            for key in ("chosen_ids", "rejected_ids")
            for feature in features
        ]
        return super().__call__(pairwise_features)