| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import warnings |
| from enum import Enum |
|
|
| from omegaconf import DictConfig |
|
|
| from verl.single_controller.base import Worker |
| from verl.trainer.ppo.core_algos import AdvantageEstimator |
|
|
# Type alias: a class object that is ``Worker`` or one of its subclasses (not an instance).
WorkerType = type[Worker]
|
|
|
|
class Role(Enum):
    """Enumeration of roles, each with a short string form.

    ``str(role)`` yields the short name (e.g. ``"ref"`` for ``RefPolicy``) and
    :meth:`from_string` performs the reverse, case-insensitive lookup.

    NOTE(review): the previous docstring suggested creating more roles by
    subclassing ``Role``. Python's ``enum`` does not allow subclassing an
    enumeration that already defines members, so new roles have to be added
    to this class directly.
    """

    Actor = 0
    Rollout = 1
    ActorRollout = 2
    Critic = 3
    RefPolicy = 4
    RewardModel = 5
    ActorRolloutRef = 6
    Env = 7

    def __str__(self):
        """Return the short string form of this role."""
        return self._get_role_string()

    def _get_role_string(self):
        """Map this member to its short name.

        Members without an explicit entry fall back to the lower-cased
        member name (e.g. ``Env`` -> ``"env"``).
        """
        role_mapping = {
            Role.Actor: "actor",
            Role.Rollout: "rollout",
            Role.ActorRollout: "actor_rollout",
            Role.Critic: "critic",
            Role.RefPolicy: "ref",
            Role.RewardModel: "rm",
            Role.ActorRolloutRef: "actor_rollout_ref",
        }
        return role_mapping.get(self, self.name.lower())

    @classmethod
    def from_string(cls, name: str):
        """Return the member whose short string form equals ``name``.

        The comparison is case-insensitive. The lookup is derived from
        ``__str__`` rather than from a second hand-written table, so every
        member (including ``Env``, which the old table omitted) round-trips:
        ``Role.from_string(str(role)) is role``.

        Args:
            name: Short role name such as ``"actor"``, ``"ref"`` or ``"rm"``.

        Returns:
            The matching ``Role`` member.

        Raises:
            ValueError: If no member has that string form.
        """
        wanted = name.lower()
        for role in cls:
            if str(role) == wanted:
                return role
        raise ValueError(f"No Role found for string: {name}")
|
|
|
|
def need_reference_policy(
    config: DictConfig,
) -> bool:
    """Given the config, do we need ref policy.

    Args:
        config: Configuration exposing ``algorithm.use_kl_in_reward`` and
            ``actor_rollout_ref.actor.use_kl_loss``.

    Returns:
        ``True`` when either flag is truthy, ``False`` otherwise. The second
        flag is only read when the first one is falsy.
    """
    # ``a or b`` returns one of its operands, which may be None or another
    # non-bool value; coerce so the declared ``bool`` return type always holds.
    return bool(config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss)
|
|
|
|
def need_reward_model(
    config: DictConfig,
) -> bool:
    """Given the config, do we need reward model.

    Args:
        config: Configuration exposing ``reward.reward_model.enable``.

    Returns:
        The truthiness of ``config.reward.reward_model.enable``.
    """
    # Coerce so an unset (None) or non-bool flag still yields a real ``bool``,
    # as promised by the return annotation.
    return bool(config.reward.reward_model.enable)
|
|
|
|
def need_critic(config: DictConfig) -> bool:
    """Decide from the config whether a critic is required.

    Resolution order:
      1. An explicit ``config.critic.enable`` (anything other than ``None``)
         wins and is returned as a ``bool``.
      2. Otherwise a critic is required when ``config.algorithm.adv_estimator``
         equals ``AdvantageEstimator.GAE``.
      3. Otherwise the critic is disabled and a warning is emitted so the
         caller can override with ``critic.enable=True``.
    """
    explicit_flag = config.critic.enable
    if explicit_flag is not None:
        return bool(explicit_flag)

    if config.algorithm.adv_estimator == AdvantageEstimator.GAE:
        return True

    # stacklevel=2 attributes the warning to the caller of need_critic.
    warnings.warn(
        "Disabled critic as algorithm.adv_estimator != gae. If it is not intended, please set critic.enable=True",
        stacklevel=2,
    )
    return False
|
|