| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Abstract base class for PPO actors. |
| """ |
|
|
| from abc import ABC, abstractmethod |
|
|
| import torch |
|
|
| from verl import DataProto |
|
|
| __all__ = ["BasePPOActor"] |
|
|
|
|
class BasePPOActor(ABC):
    """Abstract interface that PPO actor implementations must satisfy.

    Concrete subclasses have to override both ``compute_log_prob`` and
    ``update_policy``; until they do, the class cannot be instantiated.
    """

    def __init__(self, config):
        """Store the configuration on the actor.

        Args:
            config (DictConfig): the configuration handed to the PPO actor. A
                DictConfig (https://omegaconf.readthedocs.io/) is expected,
                although in general any namedtuple may be used.
        """
        super().__init__()
        # Kept as a plain attribute so subclasses can read their settings.
        self.config = config

    @abstractmethod
    def compute_log_prob(self, data: DataProto) -> torch.Tensor:
        """Compute log-probabilities for a batch of data.

        Args:
            data (DataProto): a batch of data represented by DataProto. It
                must contain the keys ```input_ids```, ```attention_mask```
                and ```position_ids```.

        Returns:
            torch.Tensor: the log-probabilities, as declared by the return
                annotation.
                NOTE(review): earlier documentation described the result as a
                DataProto containing the key ```log_probs``` -- confirm which
                form concrete actors actually return.
        """

    @abstractmethod
    def update_policy(self, data: DataProto) -> dict:
        """Update the policy using the supplied training data.

        Args:
            data (DataProto): the data to train on.
                NOTE(review): earlier documentation described this argument
                as an iterator over DataProto, as returned by
                ```make_minibatch_iterator``` -- confirm against the callers.

        Returns:
            dict: a free-form dictionary. Typically it contains statistics
                gathered while updating the model, such as ```loss``` and
                ```grad_norm```.
        """
|
|