| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| test create_rl_sampler |
| """ |
|
|
| from collections.abc import Sized |
|
|
| import pytest |
| import torch |
| from omegaconf import DictConfig, OmegaConf |
| from torch.utils.data import Dataset, RandomSampler |
|
|
| from verl.experimental.dataset.sampler import AbstractCurriculumSampler |
| from verl.trainer.main_ppo import create_rl_sampler |
|
|
|
|
class RandomCurriculumSampler(AbstractCurriculumSampler):
    """Minimal curriculum sampler used to exercise ``create_rl_sampler``.

    Iteration and length are delegated to a ``RandomSampler`` driven by a
    fixed-seed generator. Curriculum feedback passed to ``update`` is ignored.
    """

    def __init__(
        self,
        data_source: Sized,
        data_config: DictConfig,
    ):
        """Build the underlying random sampler.

        Args:
            data_source: Dataset (or any sized collection) to sample indices from.
            data_config: Data configuration; accepted for interface compatibility
                and not used by this sampler.
        """
        # A dedicated, seeded generator must actually be handed to the sampler;
        # otherwise RandomSampler draws its own seed and the order is not reproducible.
        train_dataloader_generator = torch.Generator()
        train_dataloader_generator.manual_seed(1)
        self.sampler = RandomSampler(data_source=data_source, generator=train_dataloader_generator)

    def __iter__(self):
        """Yield dataset indices in the wrapped sampler's order."""
        return iter(self.sampler)

    def __len__(self) -> int:
        """Return the number of indices one pass of the sampler yields."""
        return len(self.sampler)

    def update(self, batch) -> None:
        """Ignore curriculum feedback; the sample order does not depend on ``batch``."""
        return
|
|
|
|
class MockIncorrectSampler:
    """Sampler stand-in that deliberately breaks the curriculum-sampler contract.

    Its constructor takes the same ``(data_source, data_config)`` arguments as a
    real curriculum sampler. It does not derive from ``AbstractCurriculumSampler``
    and defines no sampling methods, so ``create_rl_sampler`` is expected to reject it.
    """

    def __init__(self, data_source, data_config):
        # Both arguments are accepted only to match the constructor call; no state is kept.
        del data_source, data_config
|
|
|
|
class MockChatDataset(Dataset):
    """Tiny in-memory dataset of five prompt/response pairs for sampler tests."""

    def __init__(self):
        conversations = [
            ("What's your name?", "My name is Assistant."),
            ("How are you?", "I'm doing well, thank you."),
            ("What is the capital of France?", "Paris."),
            ("Tell me a joke.", "Why did the chicken cross the road? To get to the other side!"),
            ("What is 2+2?", "4"),
        ]
        # Every item is a dict with exactly the keys "prompt" and "response".
        self.data = [{"prompt": prompt, "response": response} for prompt, response in conversations]

    def __len__(self):
        """Return the number of stored examples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the example dict stored at ``index``."""
        return self.data[index]
|
|
|
|
def test_create_custom_curriculum_samper():
    """A valid custom curriculum sampler is loaded from config and built over the dataset."""
    data_config = OmegaConf.create(
        {
            # Curriculum samplers presumably require 0 dataloader workers — kept explicit.
            "dataloader_num_workers": 0,
            "sampler": {
                "class_path": "pkg://tests.utils.dataset.test_create_rl_sampler_on_cpu",
                "class_name": "RandomCurriculumSampler",
            },
        }
    )

    dataset = MockChatDataset()

    sampler = create_rl_sampler(data_config, dataset)

    # Check the class by name rather than isinstance(..., RandomCurriculumSampler):
    # loading via a ``pkg://`` path may produce a different module object than the
    # one pytest imported this file as.
    assert isinstance(sampler, AbstractCurriculumSampler)
    assert type(sampler).__name__ == "RandomCurriculumSampler"
    # A RandomSampler without replacement visits every dataset index exactly once.
    assert len(sampler) == len(dataset)
    assert sorted(sampler) == list(range(len(dataset)))
|
|
|
|
def test_create_custom_curriculum_samper_wrong_class():
    """A sampler class that does not subclass AbstractCurriculumSampler is rejected."""
    data_config = OmegaConf.create(
        {
            # Same worker setting as the valid-sampler test, so the expected
            # AssertionError cannot come from a possible dataloader_num_workers
            # check — only from the sampler-type check.
            "dataloader_num_workers": 0,
            "sampler": {
                "class_path": "pkg://tests.utils.dataset.test_create_rl_sampler_on_cpu",
                "class_name": "MockIncorrectSampler",
            },
        }
    )

    dataset = MockChatDataset()

    with pytest.raises(AssertionError):
        create_rl_sampler(data_config, dataset)
|
|