| from typing import Any, Dict, List, Optional |
|
|
|
|
| class ReplaySFTDataCollator: |
| """Pad SFT batches for original and replay samples from ``ReplayDataset``.""" |
|
|
| def __init__( |
| self, |
| tokenizer, |
| padding: bool | str = True, |
| max_length: Optional[int] = None, |
| pad_to_multiple_of: Optional[int] = None, |
| return_tensors: str = "pt", |
| ) -> None: |
| self.tokenizer = tokenizer |
| self.padding = padding |
| self.max_length = max_length |
| self.pad_to_multiple_of = pad_to_multiple_of |
| self.return_tensors = return_tensors |
|
|
| def _collate(self, items: List[Dict[str, Any]]) -> Dict[str, Any]: |
| input_ids = [i["input_ids"] for i in items] |
| attention_mask = [i["attention_mask"] for i in items] |
| labels = [i["labels"] for i in items] |
| batch = self.tokenizer.pad( |
| {"input_ids": input_ids, "attention_mask": attention_mask}, |
| padding=self.padding, |
| max_length=self.max_length, |
| pad_to_multiple_of=self.pad_to_multiple_of, |
| return_tensors=self.return_tensors, |
| ) |
| label_batch = self.tokenizer.pad( |
| {"input_ids": labels}, |
| padding=self.padding, |
| max_length=self.max_length, |
| pad_to_multiple_of=self.pad_to_multiple_of, |
| return_tensors=self.return_tensors, |
| ) |
| batch["labels"] = label_batch["input_ids"] |
| return batch |
|
|
| def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: |
| originals: List[Dict[str, Any]] = [] |
| replays: List[Dict[str, Any]] = [] |
| |
| for feat in features: |
| orig = feat.get("original") |
| rep = feat.get("replay") |
| assert orig is not None |
| originals.append(orig) |
| |
| if rep is not None: |
| replays.append(rep) |
| batch: Dict[str, Optional[Dict[str, Any]]] = {} |
| batch["original"] = self._collate(originals) |
| batch["replay"] = self._collate(replays) if replays else None |
| |
| return batch |
|
|
|
|