File size: 6,188 Bytes
bcdf9fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
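SFT dataset
- Accepts a single parquet file or a list of parquet files.
- Loads all the data into memory.
Each parquet file must contain a prompt column (chat messages rendered with the
tokenizer's chat template) and a response column (the target text); the column
names are configurable via ``prompt_key`` / ``response_key`` in the data config.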
"""
from collections.abc import Sequence
from typing import List, Union

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer

from verl.utils import hf_tokenizer
from verl.utils.fs import copy_to_local
from verl.utils.model import compute_position_id_with_mask
class SFTDataset(Dataset):
    """
    This is an in-memory SFTDataset.

    Each row provides a prompt (rendered with the tokenizer's chat template,
    with a generation prompt appended) and a response string (followed by the
    tokenizer's EOS token). ``__getitem__`` returns fixed-length tensors of
    size ``max_length``.

    Arguments:
        parquet_files (str or sequence of str): path(s) of the parquet file(s);
            each one is passed through ``copy_to_local`` before being read.
        tokenizer (PreTrainedTokenizer or str): the tokenizer, or a name/path
            that is loaded with ``hf_tokenizer``.
        config (OmegaConf or dict-like): the data config. Keys read:
            ``prompt_key`` (default "prompt"), ``response_key`` (default
            "response"), ``prompt_dict_keys``, ``response_dict_keys``,
            ``max_length`` (default 1024) and ``truncation`` (one of "error",
            "left", "right"; default "error").

    Raises:
        ValueError: if ``truncation`` is not one of the supported modes.
    """

    def __init__(self, parquet_files: Union[str, List[str]], tokenizer, config):
        prompt_key = config.get("prompt_key", "prompt")
        response_key = config.get("response_key", "response")
        prompt_dict_keys = config.get("prompt_dict_keys", None)
        response_dict_keys = config.get("response_dict_keys", None)
        max_length = config.get("max_length", 1024)
        truncation = config.get("truncation", "error")
        # Use a real exception: ``assert`` statements are stripped under ``python -O``.
        if truncation not in ("error", "left", "right"):
            raise ValueError(f"truncation must be one of 'error', 'left', 'right', got {truncation!r}")
        self.truncation = truncation

        # Private copy: ``_download`` rewrites entries and must not touch the caller's list.
        self.parquet_files = self._as_file_list(parquet_files)

        if isinstance(tokenizer, str):
            tokenizer = hf_tokenizer(tokenizer)
        self.tokenizer: PreTrainedTokenizer = tokenizer

        self.prompt_key = prompt_key if isinstance(prompt_key, (tuple, list)) else [prompt_key]
        self.response_key = response_key if isinstance(response_key, (tuple, list)) else [response_key]
        # NOTE(review): the *_dict_keys options are stored but not used anywhere in this class.
        self.prompt_dict_keys = [] if not prompt_dict_keys else prompt_dict_keys
        self.response_dict_keys = [] if not response_dict_keys else response_dict_keys
        self.max_length = max_length

        self._download()
        self._read_files_and_tokenize()

    @staticmethod
    def _as_file_list(parquet_files) -> List[str]:
        """Return ``parquet_files`` as a new list; a lone path becomes a one-element list."""
        # str/bytes are Sequences too, but they denote a single path here.
        if isinstance(parquet_files, (str, bytes)) or not isinstance(parquet_files, Sequence):
            return [parquet_files]
        return list(parquet_files)

    def _download(self):
        """Replace every entry of ``self.parquet_files`` with its ``copy_to_local`` result."""
        self.parquet_files = [copy_to_local(parquet_file, verbose=True) for parquet_file in self.parquet_files]

    def _read_files_and_tokenize(self):
        """Read all parquet files into one DataFrame and select the prompt/response columns.

        Despite the name, no tokenization happens here; rows are tokenized
        lazily in ``__getitem__``.
        """
        dataframes = [pd.read_parquet(parquet_file) for parquet_file in self.parquet_files]
        # Clean 0..N-1 index; rows are accessed positionally through ``.iloc`` regardless.
        self.dataframe = pd.concat(dataframes, ignore_index=True)
        self.prompts = self.dataframe[self.prompt_key]
        self.responses = self.dataframe[self.response_key]

    def __len__(self):
        return len(self.prompts)

    @staticmethod
    def _build_loss_mask(attention_mask, prompt_length, response_length, left_offset=0):
        """Build the loss mask for one (possibly padded or truncated) sample.

        Starting from ``attention_mask``, zero every prompt position except the
        last one, and zero the position of the final response token.

        Args:
            attention_mask: 1-D mask of the final fixed-length sequence.
            prompt_length: number of prompt tokens before truncation.
            response_length: number of response tokens before truncation.
            left_offset: number of tokens removed from the front by left
                truncation (0 when no left truncation happened).

        Returns:
            A new tensor with the same shape and dtype as ``attention_mask``.
        """
        loss_mask = attention_mask.clone()
        seq_len = loss_mask.size(0)
        # Left truncation shifts the prompt/response boundary towards index 0;
        # without this shift surviving response tokens would be masked as prompt.
        kept_prompt_length = prompt_length - left_offset
        if kept_prompt_length > 1:
            # mask out prompt for SFT.
            loss_mask[: min(kept_prompt_length, seq_len) - 1] = 0
        # mask out the last token in response
        loss_mask[min(prompt_length + response_length - left_offset, seq_len) - 1] = 0
        return loss_mask

    def __getitem__(self, item):
        """Tokenize row ``item`` into fixed-length training tensors.

        Returns:
            dict with 1-D tensors of length ``self.max_length``: ``input_ids``,
            ``attention_mask``, ``position_ids`` and ``loss_mask``.

        Raises:
            NotImplementedError: if the sequence is longer than ``max_length``
                and ``truncation`` is "error" (or an unknown mode).
        """
        tokenizer = self.tokenizer
        prompt_chat = self.prompts.iloc[item].item()
        response = self.responses.iloc[item].item()
        # List-valued parquet cells may come back as numpy arrays; give the template a plain list.
        if isinstance(prompt_chat, np.ndarray):
            prompt_chat = prompt_chat.tolist()

        # string
        prompt_chat_str = tokenizer.apply_chat_template(prompt_chat, add_generation_prompt=True, tokenize=False)
        response_chat_str = response + tokenizer.eos_token

        # tokenize
        prompt_ids_output = tokenizer(prompt_chat_str, return_tensors="pt", add_special_tokens=False)
        prompt_ids = prompt_ids_output["input_ids"][0]
        prompt_attention_mask = prompt_ids_output["attention_mask"][0]

        response_ids_output = tokenizer(response_chat_str, return_tensors="pt", add_special_tokens=False)
        response_ids = response_ids_output["input_ids"][0]
        response_attention_mask = response_ids_output["attention_mask"][0]

        prompt_length = prompt_ids.shape[0]
        response_length = response_ids.shape[0]

        input_ids = torch.cat((prompt_ids, response_ids), dim=-1)
        attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1)

        # Tokens dropped from the front by left truncation; the loss mask needs it.
        left_offset = 0
        sequence_length = input_ids.shape[0]
        if sequence_length < self.max_length:
            # padding to max length
            pad_length = self.max_length - sequence_length
            pad_token_id = tokenizer.pad_token_id
            if pad_token_id is None:
                # Some tokenizers define no pad token; padded positions get attention/loss mask 0 anyway.
                pad_token_id = tokenizer.eos_token_id
            padded_input_ids = torch.full((pad_length,), pad_token_id, dtype=input_ids.dtype)
            padded_attention_mask = torch.zeros((pad_length,), dtype=attention_mask.dtype)
            input_ids = torch.cat((input_ids, padded_input_ids))
            attention_mask = torch.cat((attention_mask, padded_attention_mask))
        elif sequence_length > self.max_length:
            if self.truncation == "left":
                # actually, left truncation may not be reasonable
                left_offset = sequence_length - self.max_length
                input_ids = input_ids[-self.max_length :]
                attention_mask = attention_mask[-self.max_length :]
            elif self.truncation == "right":
                input_ids = input_ids[: self.max_length]
                attention_mask = attention_mask[: self.max_length]
            elif self.truncation == "error":
                raise NotImplementedError(f"{sequence_length=} is larger than {self.max_length=}")
            else:
                raise NotImplementedError(f"Unknown truncation method {self.truncation}")

        position_ids = compute_position_id_with_mask(attention_mask)
        loss_mask = self._build_loss_mask(attention_mask, prompt_length, response_length, left_offset)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "loss_mask": loss_mask,
        }
|