from typing import TYPE_CHECKING, Callable, List, Tuple, Union, Dict, Optional
from easydict import EasyDict
from collections import deque

from ding.framework import task
from ding.data import Buffer
from .functional import trainer, offpolicy_data_fetcher, reward_estimator, her_data_enhancer

if TYPE_CHECKING:
    from ding.framework import Context, OnlineRLContext
    from ding.policy import Policy
    from ding.reward_model import BaseRewardModel


class OffPolicyLearner:
    """
    Overview:
        The class of the off-policy learner, including data fetching and model training. Use \
        the `__call__` method to execute the whole learning process.
    """
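
    # In a distributed pipeline, only processes that hold the LEARNER role
    # should instantiate this middleware; every other process receives a
    # no-op placeholder via task.void(), so the pipeline code stays identical.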
    def __new__(cls, *args, **kwargs):
        if task.router.is_active and not task.has_role(task.role.LEARNER):
            return task.void()
        return super(OffPolicyLearner, cls).__new__(cls)

    def __init__(
            self,
            cfg: EasyDict,
            policy: 'Policy',
            buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]],
            reward_model: Optional['BaseRewardModel'] = None,
            log_freq: int = 100,
    ) -> None:
        """
        Arguments:
            - cfg (:obj:`EasyDict`): Config.
            - policy (:obj:`Policy`): The policy to be trained.
            - buffer\_ (:obj:`Buffer`): The replay buffer that stores the data for training.
            - reward_model (:obj:`BaseRewardModel`): An optional additional reward estimator \
                such as RND or ICM; defaults to None.
            - log_freq (:obj:`int`): The frequency (in training iterations) of logging.
        """
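        # Compose the functional middleware pieces: a data fetcher, an optional
        # reward estimator, and a trainer, each wrapped to run on the shared
        # task context.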
        self.cfg = cfg
        self._fetcher = task.wrap(offpolicy_data_fetcher(cfg, buffer_))
        self._trainer = task.wrap(trainer(cfg, policy, log_freq=log_freq))
        if reward_model is not None:
            self._reward_estimator = task.wrap(reward_estimator(cfg, reward_model))
        else:
            self._reward_estimator = None

    def __call__(self, ctx: "OnlineRLContext") -> None:
        """
        Output of ctx:
            - train_output (:obj:`Deque`): The deque of training outputs, one per training iteration.
        """
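        # Run up to `update_per_collect` training iterations per call: fetch a
        # batch, optionally reshape its reward, then train. Stop early if the
        # buffer cannot provide enough data.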
        train_output_queue = deque()
        for _ in range(self.cfg.policy.learn.update_per_collect):
            self._fetcher(ctx)
            if ctx.train_data is None:
                break
            if self._reward_estimator:
                self._reward_estimator(ctx)
            self._trainer(ctx)
            train_output_queue.append(ctx.train_output)
        ctx.train_output = train_output_queue
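

# A minimal usage sketch (for illustration, not part of this module):
# OffPolicyLearner is chained as task middleware after a collector and a data
# pusher. The names `cfg`, `policy`, `collector_env`, and `buffer_` are
# assumed to be built by the surrounding pipeline code.
#
#     from ding.framework import task, OnlineRLContext
#     from ding.framework.middleware import StepCollector, data_pusher
#
#     with task.start(async_mode=False, ctx=OnlineRLContext()):
#         task.use(StepCollector(cfg, policy.collect_mode, collector_env))
#         task.use(data_pusher(cfg, buffer_))
#         task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
#         task.run(max_step=10000)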


class HERLearner:
    """
    Overview:
        The class of the learner with Hindsight Experience Replay (HER). \
        Use the `__call__` method to execute the data fetching and training \
        process.
    """

    def __init__(
            self,
            cfg: EasyDict,
            policy: 'Policy',
            buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]],
            her_reward_model,
    ) -> None:
        """
        Arguments:
            - cfg (:obj:`EasyDict`): Config.
            - policy (:obj:`Policy`): The policy to be trained.
            - buffer\_ (:obj:`Buffer`): The replay buffer that stores the data for training.
            - her_reward_model (:obj:`HerRewardModel`): The HER reward model that relabels goals \
                and recomputes rewards for sampled episodes.
        """
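        # her_data_enhancer both samples from the buffer and applies hindsight
        # goal relabeling before the data reaches the trainer.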
        self.cfg = cfg
        self._fetcher = task.wrap(her_data_enhancer(cfg, buffer_, her_reward_model))
        self._trainer = task.wrap(trainer(cfg, policy))

    def __call__(self, ctx: "OnlineRLContext") -> None:
        """
        Output of ctx:
            - train_output (:obj:`Deque`): The deque of training outputs, one per training iteration.
        """
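        # Same training loop as OffPolicyLearner, except the fetched data has
        # already been enhanced with hindsight goals.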
        train_output_queue = deque()
        for _ in range(self.cfg.policy.learn.update_per_collect):
            self._fetcher(ctx)
            if ctx.train_data is None:
                break
            self._trainer(ctx)
            train_output_queue.append(ctx.train_output)
        ctx.train_output = train_output_queue
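

# A minimal usage sketch (for illustration, not part of this module):
# HERLearner replaces OffPolicyLearner in a goal-conditioned pipeline. The
# HerRewardModel import path and constructor arguments below are assumptions;
# check your DI-engine version.
#
#     from ding.reward_model import HerRewardModel  # assumed import path
#
#     her_model = HerRewardModel(...)  # constructor args depend on your config
#     task.use(HERLearner(cfg, policy.learn_mode, buffer_, her_model))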