| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import pytest |
|
|
| from verl.model_merger.megatron_model_merger import get_dynamic_pipeline_shards |
| from verl.utils.megatron.pipeline_parallel import make_batch_generator |
|
|
|
|
| def test_make_batch_generator_no_vpp(): |
| batches = [1, 2, 3] |
| vpp_size = 1 |
| generator = make_batch_generator(batches, vpp_size) |
| assert list(generator) == batches |
|
|
|
|
| def test_make_batch_generator_with_vpp(): |
| batches = [{"data": 1}, {"data": 2}] |
| vpp_size = 2 |
| generators = make_batch_generator(batches, vpp_size) |
| assert isinstance(generators, list) |
| assert len(generators) == vpp_size |
|
|
| |
| for gen in generators: |
| assert list(gen) == batches |
|
|
|
|
| def test_make_batch_generator_empty(): |
| batches = [] |
| vpp_size = 1 |
| generator = make_batch_generator(batches, vpp_size) |
| assert list(generator) == [] |
|
|
| vpp_size = 3 |
| generators = make_batch_generator(batches, vpp_size) |
| assert len(generators) == vpp_size |
| for gen in generators: |
| assert list(gen) == [] |
|
|
|
|
| @pytest.mark.parametrize( |
| "layer_num,pp_size,gt", |
| [ |
| (61, 8, [6, 8, 8, 8, 8, 8, 8, 7]), |
| (61, 7, [8, 9, 9, 9, 9, 9, 8]), |
| (61, 1, [61]), |
| (61, 0, ValueError), |
| (10, 16, ValueError), |
| ], |
| ) |
| def test_get_dynamic_pipeline_shards(layer_num, pp_size, gt): |
| if isinstance(gt, list): |
| shards = get_dynamic_pipeline_shards(layer_num, pp_size) |
| assert len(shards) == len(gt) == pp_size, f"Expected {pp_size} shards, got {len(shards)}" |
| assert all([shard == gt[i] for i, shard in enumerate(shards)]), f"Expected shards {gt}, got {shards}" |
| elif issubclass(gt, Exception): |
| with pytest.raises(gt): |
| shards = get_dynamic_pipeline_shards(layer_num, pp_size) |
|
|