# Source: nas/EasyR1/verl/trainer/data_loader.py
# Uploaded by yuccaaa using the upload-large-folder tool (commit 5ea4862, verified).
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple

import torch
from torch.utils.data import RandomSampler, SequentialSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizer, ProcessorMixin

from ..utils.dataset import RLHFDataset, collate_fn
from .config import DataConfig
def _build_rlhf_dataset(
    config: DataConfig,
    data_path,
    tokenizer: PreTrainedTokenizer,
    processor: Optional[ProcessorMixin],
) -> RLHFDataset:
    """Build an ``RLHFDataset`` for ``data_path`` using the shared settings in ``config``.

    Train and validation datasets are built through this single helper so that
    their prompt/answer/image keys, length limits, and pixel bounds stay identical.

    Args:
        config: Data configuration supplying the dataset keys and limits.
        data_path: Value forwarded as ``RLHFDataset(data_path=...)``; presumably a
            file path or list of paths (e.g. ``config.train_files``) — TODO confirm
            against ``RLHFDataset``.
        tokenizer: Tokenizer forwarded to the dataset.
        processor: Optional multimodal processor forwarded to the dataset.

    Returns:
        The constructed ``RLHFDataset``.
    """
    return RLHFDataset(
        data_path=data_path,
        tokenizer=tokenizer,
        processor=processor,
        prompt_key=config.prompt_key,
        answer_key=config.answer_key,
        image_key=config.image_key,
        image_dir=config.image_dir,
        max_prompt_length=config.max_prompt_length,
        truncation="right",
        format_prompt=config.format_prompt,
        min_pixels=config.min_pixels,
        max_pixels=config.max_pixels,
        filter_overlong_prompts=config.filter_overlong_prompts,
    )


def create_dataloader(
    config: DataConfig,
    tokenizer: PreTrainedTokenizer,
    processor: Optional[ProcessorMixin],
    *,
    num_workers: int = 8,
) -> Tuple[StatefulDataLoader, StatefulDataLoader]:
    """Create the training and validation dataloaders.

    The training loader iterates ``config.train_files`` in batches of
    ``config.rollout_batch_size``. It drops the final incomplete batch. Its order
    comes from a ``RandomSampler`` seeded with ``config.seed`` when
    ``config.shuffle`` is set, and from a ``SequentialSampler`` otherwise.

    The validation loader iterates ``config.val_files`` in order without
    dropping anything. It uses the whole validation set as one batch when
    ``config.val_batch_size == -1``, and batches of ``config.val_batch_size``
    otherwise.

    Args:
        config: Data configuration (file paths, keys, batch sizes, seed, shuffle).
        tokenizer: Tokenizer forwarded to both datasets.
        processor: Optional multimodal processor forwarded to both datasets.
        num_workers: Number of worker processes for each dataloader. The default
            of 8 matches the previous hard-coded value.

    Returns:
        A ``(train_dataloader, val_dataloader)`` tuple.

    Raises:
        AssertionError: If either dataloader would yield zero batches.
    """
    train_dataset = _build_rlhf_dataset(config, config.train_files, tokenizer, processor)

    # Use an explicit sampler (seeded when shuffling) for better checkpoint resume:
    # the iteration order is then reproducible from config.seed.
    if config.shuffle:
        train_dataloader_generator = torch.Generator()
        train_dataloader_generator.manual_seed(config.seed)
        sampler = RandomSampler(data_source=train_dataset, generator=train_dataloader_generator)
    else:
        sampler = SequentialSampler(data_source=train_dataset)

    train_dataloader = StatefulDataLoader(
        dataset=train_dataset,
        batch_size=config.rollout_batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=False,
        drop_last=True,
    )

    val_dataset = _build_rlhf_dataset(config, config.val_files, tokenizer, processor)
    # val_batch_size == -1 means "evaluate the whole validation set as a single batch".
    val_batch_size = len(val_dataset) if config.val_batch_size == -1 else config.val_batch_size
    val_dataloader = StatefulDataLoader(
        dataset=val_dataset,
        batch_size=val_batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=False,
        drop_last=False,
    )

    # With drop_last=True, a training set smaller than one rollout batch yields no batches.
    assert len(train_dataloader) >= 1, (
        "Train dataloader is empty: the training set has fewer samples than rollout_batch_size."
    )
    assert len(val_dataloader) >= 1, "Validation dataloader is empty."
    print(f"Size of train dataloader: {len(train_dataloader)}")
    print(f"Size of val dataloader: {len(val_dataloader)}")
    return train_dataloader, val_dataloader