# webshop-hsl-seed123 / hsl_code_snapshot / replay_data_collator.py
# Uploaded by heendung via huggingface_hub ("Upload folder using huggingface_hub"),
# commit d1c897a (verified).
from typing import Any, Dict, List, Optional
from trl.trainer.utils import DPODataCollatorWithPadding
class ReplayDataCollator(DPODataCollatorWithPadding):
    """Data collator that merges DPO and replay samples.

    Each feature passed to ``__call__`` is a dict with a mandatory
    ``"original"`` entry, collated by the TRL base collator, and an optional
    ``"replay"`` entry. A replay entry must provide ``input_ids``,
    ``attention_mask`` and ``labels``, and is padded with the tokenizer.

    The collated result is ``{"original": <base batch>, "replay": <replay
    batch or None>}``.
    """

    def __init__(
        self,
        tokenizer,
        padding: bool | str = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: str = "pt",
        label_pad_token_id: int = -100,
        **kwargs,
    ):
        """Accept a tokenizer but initialise the base collator using the
        parameters expected by the current TRL version.

        Args:
            tokenizer: Tokenizer whose ``pad_token_id`` is given to the base
                collator and whose ``pad`` method pads replay samples.
            padding: Forwarded to ``tokenizer.pad`` for replay samples.
            max_length: Forwarded to ``tokenizer.pad`` for replay samples.
            pad_to_multiple_of: Forwarded to ``tokenizer.pad`` for replay
                samples.
            return_tensors: Tensor type requested from ``tokenizer.pad`` for
                the replay batch.
            label_pad_token_id: Value written into padded ``labels``
                positions, for both the base collator and the replay batch.
            **kwargs: Accepted for call-site compatibility; ignored.
        """
        super().__init__(
            pad_token_id=tokenizer.pad_token_id,
            label_pad_token_id=label_pad_token_id,
            is_encoder_decoder=False,
        )
        # Parameters used by our own replay padding.
        self.tokenizer = tokenizer
        self.padding = padding
        self.max_length = max_length
        self.pad_to_multiple_of = pad_to_multiple_of
        self.return_tensors = return_tensors
        # Also stored by the base collator; set explicitly because
        # _collate_replay relies on it.
        self.label_pad_token_id = label_pad_token_id

    def _collate_replay(self, items: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Pad replay samples into one batch.

        The batch has ``input_ids``, ``attention_mask`` and ``labels``.
        Padded ``labels`` positions hold ``self.label_pad_token_id`` rather
        than the tokenizer's pad token. A loss configured with that value as
        its ignore index (PyTorch's default is -100) therefore skips them.

        NOTE(review): each item's ``labels`` is assumed to have the same
        length as its ``input_ids``, so both pad to the same shape. Confirm
        this upstream.
        """
        batch = self.tokenizer.pad(
            {
                "input_ids": [item["input_ids"] for item in items],
                "attention_mask": [item["attention_mask"] for item in items],
            },
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        # The labels are padded under the "input_ids" key, so tokenizer.pad
        # fills them with pad_token_id. The returned attention_mask marks
        # which positions are real, so padding can be overwritten below.
        label_batch = self.tokenizer.pad(
            {"input_ids": [item["labels"] for item in items]},
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_attention_mask=True,
            return_tensors=self.return_tensors,
        )
        batch["labels"] = self._mask_label_padding(
            label_batch["input_ids"], label_batch["attention_mask"]
        )
        return batch

    def _mask_label_padding(self, labels: Any, mask: Any) -> Any:
        """Return ``labels`` with every position where ``mask`` is 0 set to
        ``self.label_pad_token_id``."""
        fill = self.label_pad_token_id
        if isinstance(labels, list):
            # No tensor conversion happened: nested Python lists, possibly
            # ragged.
            return [
                [tok if keep else fill for tok, keep in zip(row, row_mask)]
                for row, row_mask in zip(labels, mask)
            ]
        # Tensor/array output. The mask holds only 0/1 values, so exactly one
        # term survives per position. Plain arithmetic avoids depending on a
        # specific tensor library's masking API.
        return labels * mask + fill * (1 - mask)

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Collate ``{"original": ..., "replay": ...}`` features.

        Returns:
            ``{"original": batch, "replay": batch_or_None}``.
            - ``"original"`` is produced by the TRL base collator.
            - ``"replay"`` contains only the features whose ``"replay"``
              entry is not None, so it can be smaller than the original
              batch.
            - ``"replay"`` is None when no feature carries a replay sample.

        Raises:
            ValueError: if a feature's ``"original"`` entry is missing or None.
        """
        originals: List[Dict[str, Any]] = []
        replays: List[Dict[str, Any]] = []
        for idx, feat in enumerate(features):
            orig = feat.get("original")
            if orig is None:
                # Explicit check instead of `assert`, which is stripped under -O.
                raise ValueError(f"feature {idx} has no 'original' sample")
            originals.append(orig)
            rep = feat.get("replay")
            if rep is not None:
                replays.append(rep)

        batch: Dict[str, Optional[Dict[str, Any]]] = {
            "original": super().__call__(originals),
            "replay": self._collate_replay(replays) if replays else None,
        }
        return batch