| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ |
| | The base class for Actor |
| | """ |
| |
|
| | from abc import ABC, abstractmethod |
| | from typing import Any, Dict |
| |
|
| | import torch |
| |
|
| | from ...protocol import DataProto |
| | from .config import ActorConfig |
| |
|
| |
|
| | __all__ = ["BasePPOActor"] |
| |
|
| |
|
class BasePPOActor(ABC):
    """Abstract interface that every PPO actor implementation must satisfy.

    The base class only stores the actor configuration; the actual work is
    delegated to the two abstract methods, ``compute_log_prob`` and
    ``update_policy``, which concrete subclasses have to provide.
    """

    def __init__(self, config: ActorConfig):
        """Store the actor configuration on the instance.

        Args:
            config (ActorConfig): configuration object for the actor; kept
                unchanged as ``self.config``.
        """
        self.config = config

    @abstractmethod
    def compute_log_prob(self, data: DataProto) -> torch.Tensor:
        """Compute log probabilities for a batch of data.

        Args:
            data (DataProto): a batch of data. It is expected to contain the
                keys ``input_ids``, ``attention_mask`` and ``position_ids``.

        Returns:
            torch.Tensor: the log probabilities, per the signature.

        NOTE(review): the previous documentation described the return value
        as a DataProto holding the key ``log_probs``, which contradicts the
        ``torch.Tensor`` annotation -- confirm against the concrete
        implementations.
        """

    @abstractmethod
    def update_policy(self, data: DataProto) -> Dict[str, Any]:
        """Update the policy from the given training data.

        Args:
            data (DataProto): the data used for the update.
                NOTE(review): the previous documentation described this as an
                iterator over DataProto produced by
                ``make_minibatch_iterator`` -- confirm which form callers
                actually pass.

        Returns:
            Dict[str, Any]: a free-form dictionary, typically holding
            statistics gathered during the update such as ``loss`` and
            ``grad_norm``.
        """
| |
|