import os

import pytest
import ray

from tests.checkpoint_engine.test_utils import create_rollout_worker_group, create_trainer_worker_group
from verl.checkpoint_engine import CheckpointEngineManager
from verl.single_controller.ray.base import (
    RayResourcePool,
    split_resource_pool,
)
from verl.utils.device import get_device_name
from verl.utils.ray_utils import auto_await
from verl.workers.config import CheckpointEngineConfig, HFModelConfig, RolloutConfig


@pytest.mark.asyncio
@pytest.mark.parametrize("rebuild_group", [False, True])
@pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)])
@auto_await
async def test_nccl_checkpoint_engine(
    rebuild_group,
    num_trainer,
    num_rollout,
    num_nodes=1,
    num_gpus_per_node=8,
    check_allclose=True,
    model_path="~/models/Qwen/Qwen3-8B-Base",
):
    model_path = os.path.expanduser(model_path)
    ray.init(
        runtime_env={
            "env_vars": {
                "UCX_TLS": "rc,tcp,cuda",
                "UCX_MAX_RNDV_RAILS": "4",
                "UCX_LOG_LEVEL": "INFO",
                "VERL_LOGGING_LEVEL": "DEBUG",
            }
        }
    )
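
    # NCCL backend; rebuild_group toggles rebuilding the broadcast process group on every update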
    checkpoint_engine_config = CheckpointEngineConfig(
        backend="nccl", engine_kwargs={"nccl": {"rebuild_group": rebuild_group}}
    )
    model_config = HFModelConfig(path=model_path, use_remove_padding=True)
    rollout_config = RolloutConfig(name="vllm", checkpoint_engine=checkpoint_engine_config)
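
    # split one resource pool into disjoint trainer and rollout worker groups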
    resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3)
    trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout])
    trainer = create_trainer_worker_group(trainer_pool, model_config, checkpoint_engine_config)
    trainer.reset()
    rollout, replicas = await create_rollout_worker_group(rollout_pool, model_config, rollout_config, check_allclose)
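
    # push weights from trainer to all rollout replicas several times, verifying after each update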
    checkpoint_manager = CheckpointEngineManager(config=checkpoint_engine_config, trainer=trainer, replicas=replicas)
    for _ in range(3):
        await checkpoint_manager.update_weights()
        rollout.check_weights()

    ray.shutdown()


@pytest.mark.skip(reason="temporarily skipped since the CI environment is not ready")
@pytest.mark.asyncio
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)])
@auto_await
async def test_nixl_checkpoint_engine(
    num_trainer,
    num_rollout,
    device,
    num_nodes=1,
    num_gpus_per_node=8,
    check_allclose=True,
    model_path="~/models/Qwen/Qwen3-8B-Base",
):
    model_path = os.path.expanduser(model_path)
    ray.init(
        runtime_env={
            "env_vars": {
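                # RDMA-capable UCX transports (RC/UD) plus CUDA memory support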
| "UCX_TLS": "rc,ud,cuda", |
| |
| |
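                # fail fast on dead peers: RC retransmit timeout/retries and endpoint keepalive probes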
| "UCX_RC_TIMEOUT": "30s", |
| "UCX_RC_RETRY_COUNT": "7", |
| "UCX_KEEPALIVE_INTERVAL": "1s", |
| "UCX_KEEPALIVE_NUM_EPS": "10", |
| "UCX_MAX_RNDV_RAILS": "4", |
| "UCX_IB_ROCE_REACHABILITY_MODE": "all", |
| "UCX_LOG_LEVEL": "INFO", |
| "VERL_LOGGING_LEVEL": "DEBUG", |
| } |
| } |
| ) |
|
|
| |
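    # NIXL backend; `device` picks the memory type ("cuda" or "cpu") used for weight transfer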
    checkpoint_engine_config = CheckpointEngineConfig(backend="nixl", engine_kwargs={"nixl": {"device": device}})
    model_config = HFModelConfig(path=model_path, use_remove_padding=True)
    rollout_config = RolloutConfig(name="vllm", checkpoint_engine=checkpoint_engine_config)

    resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3)
    trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout])
    trainer = create_trainer_worker_group(trainer_pool, model_config, checkpoint_engine_config)
    trainer.reset()
    rollout, replicas = await create_rollout_worker_group(rollout_pool, model_config, rollout_config, check_allclose)

    checkpoint_manager = CheckpointEngineManager(config=checkpoint_engine_config, trainer=trainer, replicas=replicas)
    for _ in range(3):
        await checkpoint_manager.update_weights()
        rollout.check_weights()

    ray.shutdown()


@pytest.mark.skip(reason="temporarily skipped since the CI environment is not ready")
@pytest.mark.asyncio
@pytest.mark.parametrize("rebuild_group", [False])
@pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)])
@auto_await
async def test_kimi_checkpoint_engine(
    rebuild_group,
    num_trainer,
    num_rollout,
    num_nodes=1,
    num_gpus_per_node=8,
    check_allclose=True,
    model_path="~/models/Qwen/Qwen3-8B-Base",
):
    model_path = os.path.expanduser(model_path)
    ray.init(
        runtime_env={
            "env_vars": {
                "NCCL_IB_HCA": "mlx5",
                "VERL_LOGGING_LEVEL": "DEBUG",
            }
        }
    )
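
    # kimi_ckpt_engine backend; rebuild_group stays False per the parametrization above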
    checkpoint_engine_config = CheckpointEngineConfig(
        backend="kimi_ckpt_engine", engine_kwargs={"kimi_ckpt_engine": {"rebuild_group": rebuild_group}}
    )
    model_config = HFModelConfig(path=model_path, use_remove_padding=True)
    rollout_config = RolloutConfig(name="vllm", checkpoint_engine=checkpoint_engine_config)
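
    # materialize placement groups for the current device before splitting the pool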
    resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3)
    resource_pool.get_placement_groups(device_name=get_device_name())
    trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout])
    trainer = create_trainer_worker_group(trainer_pool, model_config, checkpoint_engine_config)
    trainer.reset()
    rollout, replicas = await create_rollout_worker_group(rollout_pool, model_config, rollout_config, check_allclose)

    checkpoint_manager = CheckpointEngineManager(config=checkpoint_engine_config, trainer=trainer, replicas=replicas)
    for _ in range(3):
        await checkpoint_manager.update_weights()
        rollout.check_weights()

    ray.shutdown()


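# manual large-scale run (4 nodes x 8 GPUs); assumes @auto_await drives the coroutine when called outside pytest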
if __name__ == "__main__":
    test_nccl_checkpoint_engine(
        rebuild_group=False,
        num_trainer=2,
        num_rollout=30,
        num_nodes=4,
        num_gpus_per_node=8,
        check_allclose=False,
        model_path=os.environ["HDFS_ROOT"] + "/model/Qwen3-30B-A3B-Base",
    )