| from typing import Any, Dict, List, Optional |
|
|
| from trl.trainer.utils import DPODataCollatorWithPadding |
|
|
|
|
class ReplayDataCollator(DPODataCollatorWithPadding):
    """Data collator that merges DPO and replay samples.

    Each feature passed to :meth:`__call__` is a dict with a required
    ``"original"`` entry (collated by the TRL base class) and an optional
    ``"replay"`` entry holding ``input_ids``, ``attention_mask`` and
    ``labels`` (collated by :meth:`_collate_replay`).
    """

    def __init__(
        self,
        tokenizer,
        padding: bool | str = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: str = "pt",
        label_pad_token_id=-100,
        **kwargs,
    ):
        """
        Accept a tokenizer but initialise the base collator using the
        parameters expected by the current TRL version.

        Args:
            tokenizer: Tokenizer used to pad replay samples; its
                ``pad_token_id`` is also handed to the base DPO collator.
            padding: Forwarded to ``tokenizer.pad`` for replay samples.
            max_length: Forwarded to ``tokenizer.pad`` for replay samples.
            pad_to_multiple_of: Forwarded to ``tokenizer.pad`` for replay
                samples.
            return_tensors: Tensor type of the collated replay batch
                (e.g. ``"pt"``); ``None`` keeps Python lists.
            label_pad_token_id: Value used to pad replay ``labels``; also
                passed to the base collator.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__(
            pad_token_id=tokenizer.pad_token_id,
            label_pad_token_id=label_pad_token_id,
            is_encoder_decoder=False,
        )

        self.tokenizer = tokenizer
        self.padding = padding
        self.max_length = max_length
        self.pad_to_multiple_of = pad_to_multiple_of
        self.return_tensors = return_tensors

    def _collate_replay(self, items: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Pad a list of replay samples into a single batch.

        ``input_ids`` and ``attention_mask`` are padded by ``self.tokenizer.pad``.
        ``labels`` are then padded to the same per-row length with
        ``self.label_pad_token_id`` (on the tokenizer's padding side), so the
        padded positions are not scored as the tokenizer's pad token.

        Args:
            items: Replay samples, each with ``input_ids``,
                ``attention_mask`` and ``labels`` sequences.

        Returns:
            The padded batch with an added ``"labels"`` entry, converted to
            ``self.return_tensors``.
        """
        batch = self.tokenizer.pad(
            {
                "input_ids": [item["input_ids"] for item in items],
                "attention_mask": [item["attention_mask"] for item in items],
            },
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            # Defer tensor conversion until the labels have been added.
            return_tensors=None,
        )

        pad_left = getattr(self.tokenizer, "padding_side", "right") == "left"
        padded_labels: List[List[int]] = []
        for row_ids, item in zip(batch["input_ids"], items):
            labels = item["labels"]
            # Accept tensors/arrays as well as plain sequences.
            labels = labels.tolist() if hasattr(labels, "tolist") else list(labels)
            # Match this row's padded input length so labels stay aligned
            # with input_ids whatever padding strategy was used.
            fill = [self.label_pad_token_id] * (len(row_ids) - len(labels))
            padded_labels.append(fill + labels if pad_left else labels + fill)

        batch["labels"] = padded_labels
        return batch.convert_to_tensors(tensor_type=self.return_tensors)

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Split features into DPO and replay parts and collate each.

        Args:
            features: Dicts with a required ``"original"`` entry and an
                optional ``"replay"`` entry.

        Returns:
            ``{"original": <base-collated batch>, "replay": <replay batch or
            None>}``. The replay batch only contains features that carried
            a replay sample, so its size may be smaller than the original
            batch. It is ``None`` when no feature has one.

        Raises:
            ValueError: If a feature has no (or a ``None``) ``"original"``
                entry.
        """
        originals: List[Dict[str, Any]] = []
        replays: List[Dict[str, Any]] = []
        for idx, feat in enumerate(features):
            orig = feat.get("original")
            # Explicit check instead of `assert`, which is stripped under -O.
            if orig is None:
                raise ValueError(f"feature {idx} has no 'original' entry")
            originals.append(orig)

            rep = feat.get("replay")
            if rep is not None:
                replays.append(rep)

        batch: Dict[str, Optional[Dict[str, Any]]] = {
            "original": super().__call__(originals),
            "replay": self._collate_replay(replays) if replays else None,
        }
        return batch
|
|
|
|