|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
The base class for Actor |
|
|
""" |
|
|
|
|
|
from abc import ABC, abstractmethod |
|
|
from typing import Any, Dict |
|
|
|
|
|
import torch |
|
|
|
|
|
from ...protocol import DataProto |
|
|
from .config import ActorConfig |
|
|
|
|
|
|
|
|
__all__ = ["BasePPOActor"] |
|
|
|
|
|
|
|
|
class BasePPOActor(ABC): |
|
|
def __init__(self, config: ActorConfig): |
|
|
"""The base class for PPO actor |
|
|
|
|
|
Args: |
|
|
config (ActorConfig): a config passed to the PPOActor. |
|
|
""" |
|
|
self.config = config |
|
|
|
|
|
@abstractmethod |
|
|
def compute_log_prob(self, data: DataProto) -> torch.Tensor: |
|
|
"""Compute logits given a batch of data. |
|
|
|
|
|
Args: |
|
|
data (DataProto): a batch of data represented by DataProto. It must contain key ```input_ids```, |
|
|
```attention_mask``` and ```position_ids```. |
|
|
|
|
|
Returns: |
|
|
DataProto: a DataProto containing the key ```log_probs``` |
|
|
""" |
|
|
pass |
|
|
|
|
|
@abstractmethod |
|
|
def update_policy(self, data: DataProto) -> Dict[str, Any]: |
|
|
"""Update the policy with an iterator of DataProto |
|
|
|
|
|
Args: |
|
|
data (DataProto): an iterator over the DataProto that returns by |
|
|
```make_minibatch_iterator``` |
|
|
|
|
|
Returns: |
|
|
Dict: a dictionary contains anything. Typically, it contains the statistics during updating the model |
|
|
such as ```loss```, ```grad_norm```, etc,. |
|
|
""" |
|
|
pass |
|
|
|