"""Dataset wrappers for SFT, reward-model, PPO and GRPO training.

Every dataset reads a JSON file holding a sequence of records.  Each record
has an ``"input"`` (user message); SFT/RM/GRPO records also have an
``"output"`` (assistant message) and RM records a numeric ``"value"``.
Prompts are produced by ``template.render(messages=..., add_generation_prompt=...)``.
"""

import json
from typing import Any, Dict, List

import torch
from torch.utils.data import Dataset

try:
    from transformers import PreTrainedTokenizer
except ImportError:
    # The name is only used in type annotations below; the datasets themselves
    # just call the tokenizer and read ``pad_token_id`` / ``eos_token_id``.
    PreTrainedTokenizer = Any

# Label value used for positions that must not contribute to the loss
# (matches the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``).
_IGNORE_INDEX = -100


def _load_json(file_path: str) -> Any:
    """Parse the JSON document stored at *file_path*.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's locale encoding.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)


def _pad_token_id(tokenizer) -> int:
    """Return the token id used to right-pad ``input_ids`` in a batch.

    Falls back to ``eos_token_id`` when the tokenizer defines no pad token;
    padded positions are masked out by ``attention_mask`` (and by the ignore
    index in ``labels`` where applicable).

    Raises:
        ValueError: if the tokenizer has neither a pad nor an EOS token id.
    """
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    if pad_id is None:
        raise ValueError(
            "tokenizer defines neither pad_token_id nor eos_token_id; "
            "cannot pad a batch"
        )
    return pad_id


def _pad(sequences: List[torch.Tensor], padding_value: int) -> torch.Tensor:
    """Right-pad 1-D tensors to a common length and stack them batch-first."""
    return torch.nn.utils.rnn.pad_sequence(
        sequences, batch_first=True, padding_value=padding_value
    )


class SFTDataset(Dataset):
    """Supervised fine-tuning examples with the prompt masked out of the labels."""

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        template,
        max_length: int,
    ):
        self.data = self.load_sft_data(data_path)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.template = template

    def load_sft_data(self, file_path):
        """Load the JSON records stored at *file_path*."""
        return _load_json(file_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Tokenize record *idx* and build labels that cover only the response.

        Returns a dict of 1-D tensors ``input_ids``, ``attention_mask`` and
        ``labels``; label positions belonging to the prompt are set to -100.
        """
        item = self.data[idx]
        user_message = {"role": "user", "content": item["input"]}
        assistant_message = {"role": "assistant", "content": item["output"]}

        rendered_text = self.template.render(
            messages=[user_message, assistant_message],
            add_generation_prompt=False,
        )
        tokens = self.tokenizer(
            rendered_text,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tokens["input_ids"]
        attention_mask = tokens["attention_mask"]

        # Render the prompt alone to find how many leading tokens to mask.
        # add_generation_prompt=True is important: presumably it makes the
        # template emit the assistant-turn prefix so the masked span ends
        # exactly where the response starts -- confirm against the template.
        instruction_text = self.template.render(
            messages=[user_message],
            add_generation_prompt=True,
        )
        instruction_tokens = self.tokenizer(
            instruction_text,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )

        labels = input_ids.clone()
        instruction_length = instruction_tokens["input_ids"].size(1)
        # Slicing past the end is harmless: if the prompt alone fills
        # max_length, every label ends up ignored.
        labels[:, :instruction_length] = _IGNORE_INDEX

        # squeeze(0) drops only the batch axis; a bare squeeze() would also
        # collapse a one-token sequence to a 0-dim tensor and break collate_fn.
        return {
            "input_ids": input_ids.squeeze(0),
            "attention_mask": attention_mask.squeeze(0),
            "labels": labels.squeeze(0),
        }

    def collate_fn(self, batch):
        """Right-pad a list of ``__getitem__`` results into batch tensors."""
        return {
            "input_ids": _pad(
                [item["input_ids"] for item in batch],
                _pad_token_id(self.tokenizer),
            ),
            "labels": _pad([item["labels"] for item in batch], _IGNORE_INDEX),
            "attention_mask": _pad(
                [item["attention_mask"] for item in batch], 0
            ),
        }


class RMDataset(Dataset):
    """Reward-model examples: a rendered conversation plus a scalar target."""

    def __init__(self, reward_data_path, tokenizer, template, max_length=512):
        self.data = self.load_reward_data(reward_data_path)
        self.tokenizer = tokenizer
        self.template = template
        self.max_length = max_length

    def load_reward_data(self, file_path):
        """Load the JSON records stored at *file_path*."""
        return _load_json(file_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize record *idx*; ``labels`` is its ``"value"`` as a scalar tensor.

        Raises:
            AssertionError: if the last token id is not ``eos_token_id``.
        """
        item = self.data[idx]
        reward_value = float(item["value"])

        # strip() is important: it removes trailing whitespace (e.g. a final
        # "\n") so the EOS check below looks at the real last token.
        rendered_text = self.template.render(
            messages=[
                {"role": "user", "content": item["input"]},
                {"role": "assistant", "content": item["output"]},
            ],
            add_generation_prompt=False,
        ).strip()

        tokenized_input = self.tokenizer(
            rendered_text,
            return_tensors="pt",
            max_length=self.max_length,
            truncation=True,
        )

        # Make sure there is no "\n" (or anything else) after the EOS token.
        # NOTE(review): a conversation truncated at max_length may lose its
        # EOS and trip this check -- confirm the intended handling.
        last_token_id = tokenized_input["input_ids"][0][-1]
        assert last_token_id == self.tokenizer.eos_token_id, (
            f"sample {idx}: last token id {int(last_token_id)} is not "
            f"eos_token_id {self.tokenizer.eos_token_id}"
        )

        # squeeze(0) removes only the batch axis (see SFTDataset.__getitem__).
        return {
            "input_ids": tokenized_input["input_ids"].squeeze(0),
            "attention_mask": tokenized_input["attention_mask"].squeeze(0),
            "labels": torch.tensor(reward_value),
        }

    def collate_fn(self, batch):
        """Right-pad token tensors and stack the scalar reward targets."""
        return {
            "input_ids": _pad(
                [item["input_ids"] for item in batch],
                _pad_token_id(self.tokenizer),
            ),
            "labels": torch.stack([item["labels"] for item in batch]),
            "attention_mask": _pad(
                [item["attention_mask"] for item in batch], 0
            ),
        }


class PPODataset(Dataset):
    """Prompt-only examples (tokenized) for PPO rollouts."""

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        template,
        max_length,
    ):
        self.data = self.load_sft_data(data_path)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.template = template

    def load_sft_data(self, file_path):
        """Load the JSON records stored at *file_path*."""
        return _load_json(file_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the tokenized generation prompt for record *idx* (1-D tensor)."""
        item = self.data[idx]
        instruction_text = self.template.render(
            messages=[{"role": "user", "content": item["input"]}],
            add_generation_prompt=True,  # important
        )
        instruction_tokens = self.tokenizer(
            instruction_text,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        # squeeze(0) keeps a one-token prompt 1-D so collate_fn can pad it.
        return {"input_ids": instruction_tokens["input_ids"].squeeze(0)}

    def collate_fn(self, batch):
        """Right-pad the prompts of *batch* into one ``input_ids`` tensor."""
        return {
            "input_ids": _pad(
                [item["input_ids"] for item in batch],
                _pad_token_id(self.tokenizer),
            )
        }


class GRPODataset(Dataset):
    """Prompt/completion text pairs for GRPO; no tokenization happens here."""

    def __init__(self, data_path: str, tokenizer, template, max_length: int):
        self.data = self.load_sft_data(data_path)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.template = template

    def load_sft_data(self, file_path: str):
        """Load the JSON records stored at *file_path*."""
        return _load_json(file_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the rendered prompt and the raw reference output as strings."""
        item = self.data[idx]
        rendered_prompt = self.template.render(
            messages=[{"role": "user", "content": item["input"]}],
            add_generation_prompt=True,
        )
        return {
            "prompt": rendered_prompt,
            "completion": item["output"],
        }