| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
from typing import Optional, Tuple

import torch
from torch.utils.data import RandomSampler, SequentialSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizer, ProcessorMixin

from ..utils.dataset import RLHFDataset, collate_fn
from .config import DataConfig
| |
|
| |
|
def _build_dataset(
    config: DataConfig,
    data_path,
    tokenizer: PreTrainedTokenizer,
    processor: Optional[ProcessorMixin],
) -> RLHFDataset:
    """Construct an ``RLHFDataset`` for ``data_path`` using the shared settings in ``config``.

    The train and validation datasets differ only in the path they read from;
    every other option is taken verbatim from ``config``.
    """
    return RLHFDataset(
        data_path=data_path,
        tokenizer=tokenizer,
        processor=processor,
        prompt_key=config.prompt_key,
        answer_key=config.answer_key,
        image_key=config.image_key,
        image_dir=config.image_dir,
        max_prompt_length=config.max_prompt_length,
        truncation="right",
        format_prompt=config.format_prompt,
        min_pixels=config.min_pixels,
        max_pixels=config.max_pixels,
        filter_overlong_prompts=config.filter_overlong_prompts,
    )


def create_dataloader(
    config: DataConfig, tokenizer: PreTrainedTokenizer, processor: Optional[ProcessorMixin]
) -> Tuple[StatefulDataLoader, StatefulDataLoader]:
    """Build the training and validation dataloaders described by ``config``.

    Args:
        config: Data settings. Supplies the train/val file paths, dataset
            options, ``shuffle``/``seed`` for the training sampler, and the
            ``rollout_batch_size`` / ``val_batch_size`` batch sizes.
        tokenizer: Tokenizer handed to both datasets.
        processor: Optional processor handed to both datasets.

    Returns:
        A ``(train_dataloader, val_dataloader)`` tuple.

    Raises:
        AssertionError: If either dataloader would yield zero batches.
    """
    # Both loaders use the same fixed number of worker processes.
    num_workers = 8

    train_dataset = _build_dataset(config, config.train_files, tokenizer, processor)

    if config.shuffle:
        # A dedicated, explicitly seeded generator makes the shuffle order
        # depend only on ``config.seed`` rather than on the global RNG state.
        train_dataloader_generator = torch.Generator()
        train_dataloader_generator.manual_seed(config.seed)
        sampler = RandomSampler(data_source=train_dataset, generator=train_dataloader_generator)
    else:
        sampler = SequentialSampler(data_source=train_dataset)

    train_dataloader = StatefulDataLoader(
        dataset=train_dataset,
        batch_size=config.rollout_batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=False,
        drop_last=True,  # only full rollout batches are produced
    )

    val_dataset = _build_dataset(config, config.val_files, tokenizer, processor)

    # A val_batch_size of -1 means "evaluate the whole validation set in one batch".
    val_batch_size = len(val_dataset) if config.val_batch_size == -1 else config.val_batch_size
    val_dataloader = StatefulDataLoader(
        dataset=val_dataset,
        batch_size=val_batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=False,
        drop_last=False,
    )

    train_batches = len(train_dataloader)
    val_batches = len(val_dataloader)
    # With drop_last=True, a train set smaller than rollout_batch_size yields zero batches.
    assert train_batches >= 1, (
        "Train dataloader is empty: the training dataset has fewer samples than "
        f"rollout_batch_size={config.rollout_batch_size} (drop_last=True)."
    )
    assert val_batches >= 1, "Val dataloader is empty: the validation dataset has no samples."
    print(f"Size of train dataloader: {train_batches}")
    print(f"Size of val dataloader: {val_batches}")
    return train_dataloader, val_dataloader
| |
|