| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import ray |
|
|
| from verl.trainer.ppo.ray_trainer import ResourcePoolManager |
| from verl.trainer.ppo.utils import Role, need_reference_policy |
|
|
|
|
def create_resource_pool_manager(config, roles: list) -> ResourcePoolManager:
    """
    Create a resource pool manager for the requested roles.

    Training-side roles (Actor, ActorRollout, Critic, RefPolicy, RewardModel)
    are all mapped to a single shared pool named "trainer_pool", sized from
    ``config.trainer``. Role.Rollout is mapped to its own pool named
    "rollout_pool", sized from ``config.rollout``. Pools are only created for
    roles that actually appear in ``roles``.

    Args:
        config: Configuration object. Reads ``trainer.n_gpus_per_node`` and
            ``trainer.nnodes`` when any training-side role is requested, and
            ``rollout.n_gpus_per_node`` and ``rollout.nnodes`` when
            Role.Rollout is requested.
        roles: List of roles that need to create resource pools.

    Returns:
        ResourcePoolManager: Resource pool manager whose spec contains one
        entry per needed pool (a list of per-node GPU counts, one element per
        node) and whose mapping assigns each requested role to its pool name.

    Raises:
        AssertionError: If a required GPUs-per-node or node count is not
            greater than 0.
    """
    resource_pool_spec = {}
    mapping = {}

    # Roles that share the training pool. Kept in one list so the "is any
    # requested" check and the mapping loop cannot drift apart.
    trainer_roles = [Role.Actor, Role.ActorRollout, Role.Critic, Role.RefPolicy, Role.RewardModel]
    requested_trainer_roles = [role for role in trainer_roles if role in roles]

    if requested_trainer_roles:
        assert config.trainer.n_gpus_per_node > 0, "config.trainer.n_gpus_per_node must be greater than 0"
        assert config.trainer.nnodes > 0, "config.trainer.nnodes must be greater than 0"

        # One entry per node, each holding that node's GPU count.
        resource_pool_spec["trainer_pool"] = [config.trainer.n_gpus_per_node] * config.trainer.nnodes
        for role in requested_trainer_roles:
            mapping[role] = "trainer_pool"

    if Role.Rollout in roles:
        assert config.rollout.n_gpus_per_node > 0, "config.rollout.n_gpus_per_node must be greater than 0"
        assert config.rollout.nnodes > 0, "config.rollout.nnodes must be greater than 0"

        # Register the rollout pool and map the role to it; validating the
        # config alone would leave Role.Rollout without any assigned pool.
        resource_pool_spec["rollout_pool"] = [config.rollout.n_gpus_per_node] * config.rollout.nnodes
        mapping[Role.Rollout] = "rollout_pool"

    return ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
|
|
|
|
def create_role_worker_mapping(config):
    """
    Build the role -> Ray worker class mapping used by the detached trainer.

    Only the engine-based worker implementation is supported. The function
    raises unless ``config.trainer.use_legacy_worker_impl`` is explicitly
    ``"disable"``; a missing key is treated as ``"auto"`` and is rejected.

    Args:
        config: Configuration object.

    Returns:
        tuple: ``(role_worker_mapping, ray_worker_group_cls)``, where
        ``role_worker_mapping`` maps ``Role`` members to ``ray.remote``-wrapped
        worker classes and ``ray_worker_group_cls`` is ``RayWorkerGroup``.

    Raises:
        NotImplementedError: If the legacy worker implementation is not
            explicitly disabled.
    """
    legacy_mode = config.trainer.get("use_legacy_worker_impl", "auto")
    if legacy_mode != "disable":
        raise NotImplementedError(
            "Fully async policy or One step off policy does not support legacy worker implementation"
        )

    # Deferred imports: only loaded once the legacy path has been ruled out.
    from verl.experimental.separation.engine_workers import DetachActorWorker
    from verl.single_controller.ray import RayWorkerGroup
    from verl.workers.engine_workers import TrainingWorker

    # If async_training.use_trainer_do_validate is set, the actor worker is
    # registered under Role.ActorRollout instead of Role.Actor.
    async_training_cfg = config.get("async_training", {})
    trainer_validates = async_training_cfg.get("use_trainer_do_validate", False)
    actor_role = Role.ActorRollout if trainer_validates else Role.Actor

    workers = {}
    workers[actor_role] = ray.remote(DetachActorWorker)
    workers[Role.Critic] = ray.remote(TrainingWorker)

    if need_reference_policy(config):
        # The reference policy reuses the actor worker class.
        workers[Role.RefPolicy] = ray.remote(DetachActorWorker)

    return workers, RayWorkerGroup
|
|