# webshop-hsl-seed2026 / hsl_code_snapshot / replay_sft_data_collator.py
# Hugging Face Hub page residue kept for provenance: "heendung's picture";
# commit message "Upload folder using huggingface_hub"; revision 753b40f (verified).
from typing import Any, Dict, List, Optional
class ReplaySFTDataCollator:
"""Pad SFT batches for original and replay samples from ``ReplayDataset``."""
def __init__(
self,
tokenizer,
padding: bool | str = True,
max_length: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
return_tensors: str = "pt",
) -> None:
self.tokenizer = tokenizer
self.padding = padding
self.max_length = max_length
self.pad_to_multiple_of = pad_to_multiple_of
self.return_tensors = return_tensors
def _collate(self, items: List[Dict[str, Any]]) -> Dict[str, Any]:
input_ids = [i["input_ids"] for i in items]
attention_mask = [i["attention_mask"] for i in items]
labels = [i["labels"] for i in items]
batch = self.tokenizer.pad(
{"input_ids": input_ids, "attention_mask": attention_mask},
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)
label_batch = self.tokenizer.pad(
{"input_ids": labels},
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)
batch["labels"] = label_batch["input_ids"]
return batch
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
originals: List[Dict[str, Any]] = []
replays: List[Dict[str, Any]] = []
#print("features", features)
for feat in features:
orig = feat.get("original")
rep = feat.get("replay")
assert orig is not None
originals.append(orig)
#print("rep", rep)
if rep is not None:
replays.append(rep)
batch: Dict[str, Optional[Dict[str, Any]]] = {}
batch["original"] = self._collate(originals)
batch["replay"] = self._collate(replays) if replays else None
#print("batch_replay", batch["replay"])
return batch