| | import abc |
| | from typing import Dict |
| |
|
| | from mlagents.trainers.buffer import AgentBuffer |
| |
|
| |
|
class Optimizer(abc.ABC):
    """
    Abstract base for objects that build the loss functions and auxiliary
    networks (e.g. Q or Value) required for training, and that expose the
    methods used to update the Policy.

    Concrete subclasses must implement `update`.
    """

    def __init__(self):
        # Starts out empty; nothing in this base class populates it.
        self.reward_signals = dict()

    @abc.abstractmethod
    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
        """
        Perform one update of the Policy using the supplied minibatch.

        :param batch: AgentBuffer holding the minibatch of data for this update.
        :param num_sequences: Count of recurrent sequences present in the minibatch.
        :return: Dict of statistics (name, value) produced by the update (e.g. loss).
        """
| |
|